Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/XXXX.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Used an atomic group in ``_WS_EXT_RE`` on Python 3.11+ to prevent
unnecessary backtracking when parsing ``Sec-WebSocket-Extensions`` headers
-- by :user:`HarshithReddy01`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Günther Jena
Hans Adema
Harmon Y.
Harry Liu
Harshith Reddy
Hiroshi Ogawa
Hrishikesh Paranjape
Hu Bo
Expand Down
9 changes: 9 additions & 0 deletions aiohttp/_websocket/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import re
import sys
from re import Pattern
from struct import Struct
from typing import TYPE_CHECKING, Final
Expand Down Expand Up @@ -64,12 +65,20 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
websocket_mask = _websocket_mask_python


# On 3.11+ use an atomic outer group to avoid backtracking over already-matched
# iterations when the tail of the string doesn't match.
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
r"^(?:;\s*(?:"
r"(server_no_context_takeover)|"
r"(client_no_context_takeover)|"
r"(server_max_window_bits(?:=(\d+))?)|"
r"(client_max_window_bits(?:=(\d+))?)))*$"
if sys.version_info < (3, 11)
else r"^(?>;\s*(?:"
r"(server_no_context_takeover)|"
r"(client_no_context_takeover)|"
r"(server_max_window_bits(?:=(\d+))?)|"
r"(client_max_window_bits(?:=(\d+))?)))*$"
)

_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
Expand Down
80 changes: 80 additions & 0 deletions tests/test_ws_ext_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import time

import pytest

from aiohttp._websocket.helpers import ws_ext_parse
from aiohttp.http_websocket import WSHandshakeError


class TestWsExtParse:
def test_empty(self) -> None:
assert ws_ext_parse(None) == (0, False)
assert ws_ext_parse("") == (0, False)

def test_permessage_deflate_only(self) -> None:
compress, notakeover = ws_ext_parse("permessage-deflate")
assert compress == 15
assert notakeover is False

def test_server_no_context_takeover(self) -> None:
compress, notakeover = ws_ext_parse(
"permessage-deflate; server_no_context_takeover", isserver=True
)
assert compress == 15
assert notakeover is True

def test_client_no_context_takeover(self) -> None:
compress, notakeover = ws_ext_parse(
"permessage-deflate; client_no_context_takeover", isserver=False
)
assert compress == 15
assert notakeover is True

def test_server_max_window_bits(self) -> None:
compress, notakeover = ws_ext_parse(
"permessage-deflate; server_max_window_bits=12", isserver=True
)
assert compress == 12
assert notakeover is False
Comment thread
Dreamsorcerer marked this conversation as resolved.
Outdated

def test_client_max_window_bits(self) -> None:
compress, notakeover = ws_ext_parse(
"permessage-deflate; client_max_window_bits=10", isserver=False
)
assert compress == 10
assert notakeover is False

def test_window_bits_out_of_range_server(self) -> None:
# out-of-range wbits on server side → skip, return 0
compress, _ = ws_ext_parse(
"permessage-deflate; server_max_window_bits=8", isserver=True
)
assert compress == 0

def test_window_bits_out_of_range_client(self) -> None:
with pytest.raises(WSHandshakeError):
ws_ext_parse("permessage-deflate; client_max_window_bits=8", isserver=False)

def test_invalid_extension_client_raises(self) -> None:
with pytest.raises(WSHandshakeError):
ws_ext_parse("permessage-deflate; unknown_param", isserver=False)

def test_no_match_server_returns_zero(self) -> None:
compress, notakeover = ws_ext_parse(
"permessage-deflate; unknown_param", isserver=True
)
assert compress == 0
assert notakeover is False

def test_backtracking_performance(self) -> None:
# Crafted input: many valid tokens followed by an invalid suffix.
# Without the atomic group fix this causes exponential backtracking.
evil = "permessage-deflate" + ("; server_no_context_takeover" * 30) + ";INVALID"
start = time.perf_counter()
try:
Comment thread
Dreamsorcerer marked this conversation as resolved.
Outdated
ws_ext_parse(evil, isserver=True)
except WSHandshakeError:
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
pass
elapsed = time.perf_counter() - start

Check notice

Code scanning / CodeQL

Empty except Note test

'except' clause does nothing but pass and there is no explanatory comment.
# Should complete in well under a second on any reasonable hardware.
assert elapsed < 1.0, f"possible backtracking regression: took {elapsed:.3f}s"
Loading