1
1
import asyncio
2
2
import pickle
3
+ import random
3
4
import struct
4
5
from typing import Optional , Union
5
6
from unittest import mock
6
7
7
8
import pytest
8
9
9
10
from aiohttp ._websocket import helpers as _websocket_helpers
10
- from aiohttp ._websocket .helpers import PACK_CLOSE_CODE , PACK_LEN1 , PACK_LEN2
11
+ from aiohttp ._websocket .helpers import (
12
+ PACK_CLOSE_CODE ,
13
+ PACK_LEN1 ,
14
+ PACK_LEN2 ,
15
+ PACK_LEN3 ,
16
+ PACK_RANDBITS ,
17
+ websocket_mask ,
18
+ )
11
19
from aiohttp ._websocket .models import WS_DEFLATE_TRAILING
12
20
from aiohttp ._websocket .reader import WebSocketDataQueue
13
21
from aiohttp .base_protocol import BaseProtocol
@@ -52,6 +60,7 @@ def build_frame(
52
60
noheader : bool = False ,
53
61
is_fin : bool = True ,
54
62
ZLibBackend : Optional [ZLibBackendWrapper ] = None ,
63
+ mask : bool = False ,
55
64
) -> bytes :
56
65
# Send a frame over the websocket with message as its payload.
57
66
compress = False
@@ -72,11 +81,21 @@ def build_frame(
72
81
if compress :
73
82
header_first_byte |= 0x40
74
83
84
+ mask_bit = 0x80 if mask else 0
85
+
75
86
if msg_length < 126 :
76
- header = PACK_LEN1 (header_first_byte , msg_length )
87
+ header = PACK_LEN1 (header_first_byte , msg_length | mask_bit )
88
+ elif msg_length < 65536 :
89
+ header = PACK_LEN2 (header_first_byte , 126 | mask_bit , msg_length )
77
90
else :
78
- assert msg_length < (1 << 16 )
79
- header = PACK_LEN2 (header_first_byte , 126 , msg_length )
91
+ header = PACK_LEN3 (header_first_byte , 127 | mask_bit , msg_length )
92
+
93
+ if mask :
94
+ assert not noheader
95
+ mask_bytes = PACK_RANDBITS (random .getrandbits (32 ))
96
+ message_arr = bytearray (message )
97
+ websocket_mask (mask_bytes , message_arr )
98
+ return header + mask_bytes + message_arr
80
99
81
100
if noheader :
82
101
return message
@@ -352,6 +371,51 @@ def test_fragmentation_header(
352
371
assert res == WSMessageText (data = "a" , size = 1 , extra = "" )
353
372
354
373
374
+ def test_large_message (
375
+ out : WebSocketDataQueue , parser : PatchableWebSocketReader
376
+ ) -> None :
377
+ large_payload = b"b" * 131072
378
+ data = build_frame (large_payload , WSMsgType .BINARY )
379
+ parser ._feed_data (data )
380
+
381
+ res = out ._buffer [0 ]
382
+ assert res == WSMessageBinary (data = large_payload , size = 131072 , extra = "" )
383
+
384
+
385
+ def test_large_masked_message (
386
+ out : WebSocketDataQueue , parser : PatchableWebSocketReader
387
+ ) -> None :
388
+ large_payload = b"b" * 131072
389
+ data = build_frame (large_payload , WSMsgType .BINARY , mask = True )
390
+ parser ._feed_data (data )
391
+
392
+ res = out ._buffer [0 ]
393
+ assert res == WSMessageBinary (data = large_payload , size = 131072 , extra = "" )
394
+
395
+
396
+ def test_fragmented_masked_message (
397
+ out : WebSocketDataQueue , parser : PatchableWebSocketReader
398
+ ) -> None :
399
+ large_payload = b"b" * 100
400
+ data = build_frame (large_payload , WSMsgType .BINARY , mask = True )
401
+ for i in range (len (data )):
402
+ parser ._feed_data (data [i : i + 1 ])
403
+
404
+ res = out ._buffer [0 ]
405
+ assert res == WSMessageBinary (data = large_payload , size = 100 , extra = "" )
406
+
407
+
408
+ def test_large_fragmented_masked_message (
409
+ out : WebSocketDataQueue , parser : PatchableWebSocketReader
410
+ ) -> None :
411
+ large_payload = b"b" * 131072
412
+ data = build_frame (large_payload , WSMsgType .BINARY , mask = True )
413
+ for i in range (0 , len (data ), 16384 ):
414
+ parser ._feed_data (data [i : i + 16384 ])
415
+ res = out ._buffer [0 ]
416
+ assert res == WSMessageBinary (data = large_payload , size = 131072 , extra = "" )
417
+
418
+
355
419
def test_continuation (
356
420
out : WebSocketDataQueue , parser : PatchableWebSocketReader
357
421
) -> None :
0 commit comments