diff --git a/CHANGES/12346.misc.rst b/CHANGES/12346.misc.rst new file mode 100644 index 00000000000..eaec57183ce --- /dev/null +++ b/CHANGES/12346.misc.rst @@ -0,0 +1,3 @@ +Improved performance of ``_WS_EXT_RE`` regular expression on Python 3.11+ +by using atomic grouping when parsing ``Sec-WebSocket-Extensions`` headers +-- by :user:`HarshithReddy01`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index e61c5e8e328..4e757380d31 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -160,6 +160,7 @@ Günther Jena Hans Adema Harmon Y. Harry Liu +Harshith Reddy Hiroshi Ogawa Hrishikesh Paranjape Hu Bo diff --git a/aiohttp/_websocket/helpers.py b/aiohttp/_websocket/helpers.py index f9a44cdd39b..5080b8a1712 100644 --- a/aiohttp/_websocket/helpers.py +++ b/aiohttp/_websocket/helpers.py @@ -2,6 +2,7 @@ import functools import re +import sys from re import Pattern from struct import Struct from typing import TYPE_CHECKING, Final @@ -70,6 +71,12 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None: r"(client_no_context_takeover)|" r"(server_max_window_bits(?:=(\d+))?)|" r"(client_max_window_bits(?:=(\d+))?)))*$" + if sys.version_info < (3, 11) + else r"^(?>;\s*(?:" + r"(server_no_context_takeover)|" + r"(client_no_context_takeover)|" + r"(server_max_window_bits(?:=(\d+))?)|" + r"(client_max_window_bits(?:=(\d+))?)))*$" ) _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") diff --git a/tests/test_websocket_helpers.py b/tests/test_websocket_helpers.py new file mode 100644 index 00000000000..8adb7126ec9 --- /dev/null +++ b/tests/test_websocket_helpers.py @@ -0,0 +1,52 @@ +import time + +import pytest + +from aiohttp._websocket.helpers import ws_ext_parse +from aiohttp.http_websocket import WSHandshakeError + + +@pytest.mark.parametrize( + ("msg", "server", "expected"), + ( + ("permessage-deflate", False, (15, False)), + ("permessage-deflate; server_no_context_takeover", True, (15, True)), + ("permessage-deflate; client_no_context_takeover", False, (15, True)), + ("permessage-deflate; server_max_window_bits=12", True, (12, False)), + ("permessage-deflate; client_max_window_bits=10", False, (10, False)), + # out-of-range wbits on server side → skip rather than fail + ("permessage-deflate; server_max_window_bits=8", True, (0, False)), + # unknown param on server side → no match, return zero + ("permessage-deflate; unknown_param", True, (0, False)), + ), +) +def test_ws_ext_parse(msg: str, server: bool, expected: tuple[int, bool]) -> None: + assert ws_ext_parse(msg, isserver=server) == expected + + +@pytest.mark.parametrize( + ("msg", "server"), + ( + ("permessage-deflate; client_max_window_bits=8", False), + ("permessage-deflate; unknown_param", False), + ), +) +def test_ws_ext_parse_raises(msg: str, server: bool) -> None: + with pytest.raises(WSHandshakeError): + ws_ext_parse(msg, isserver=server) + + +def test_ws_ext_parse_empty() -> None: + assert ws_ext_parse(None) == (0, False) + assert ws_ext_parse("") == (0, False) + + +def test_ws_ext_parse_backtracking_performance() -> None: + # Many valid tokens followed by an invalid suffix — the classic input that + # triggers exponential backtracking in the outer repeating group. + evil = "permessage-deflate" + ("; server_no_context_takeover" * 30) + ";INVALID" + start = time.perf_counter() + with pytest.raises(WSHandshakeError): + ws_ext_parse(evil, isserver=False) + elapsed = time.perf_counter() - start + assert elapsed < 1.0, f"backtracking regression: took {elapsed:.3f}s"