From 59d7f431837b848e89720b23c3c68c0070a00c6c Mon Sep 17 00:00:00 2001 From: Fiber Date: Sat, 30 May 2026 13:13:54 +0800 Subject: [PATCH 1/2] fix: add timeout protection for event hook calls --- astrbot/core/pipeline/context_utils.py | 31 +++- tests/unit/test_call_event_hook.py | 219 +++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_call_event_hook.py diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 9402ce3e62..31c91db701 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -1,3 +1,4 @@ +import asyncio import inspect import traceback import typing as T @@ -76,13 +77,21 @@ async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, *args, + timeout: float = 300.0, **kwargs, ) -> bool: """调用事件钩子函数 + Args: + event: 事件对象 + hook_type: 钩子事件类型 + *args: 传递给钩子处理器的位置参数 + timeout: 单个钩子处理器的超时时间(秒),超时后跳过该处理器继续执行。 + 设为 0 或负数则不启用超时。默认 300 秒。 + **kwargs: 传递给钩子处理器的关键字参数 + Returns: bool: 如果事件被终止,返回 True - # """ handlers = star_handlers_registry.get_handlers_by_event_type( @@ -92,16 +101,30 @@ async def call_event_hook( for handler in handlers: try: assert inspect.iscoroutinefunction(handler.handler) + plugin_name = star_map[handler.handler_module_path].name + handler_name = handler.handler_name logger.debug( - f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", + f"hook({hook_type.name}) -> {plugin_name} - {handler_name}", ) - await handler.handler(event, *args, **kwargs) + if timeout > 0: + try: + await asyncio.wait_for( + handler.handler(event, *args, **kwargs), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.warning( + f"hook({hook_type.name}) -> {plugin_name} - {handler_name} " + f"timed out after {timeout}s, skipping.", + ) + else: + await handler.handler(event, *args, **kwargs) except BaseException: logger.error(traceback.format_exc()) if event.is_stopped(): logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。", + f"{plugin_name} - {handler_name} 终止了事件传播。", ) return True diff --git a/tests/unit/test_call_event_hook.py b/tests/unit/test_call_event_hook.py new file mode 100644 index 0000000000..76573d59d7 --- /dev/null +++ b/tests/unit/test_call_event_hook.py @@ -0,0 +1,219 @@ +"""Tests for call_event_hook timeout protection.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.pipeline.context_utils import call_event_hook +from astrbot.core.star.star_handler import EventType + + +def _make_handler_metadata( + handler_coro, module_path="test_module", handler_name="test_handler" +): + handler = MagicMock() + handler.handler_module_path = module_path + handler.handler_name = handler_name + handler.handler = handler_coro + handler.enabled = True + return handler + + +def _make_event(stopped=False, plugins_name=None): + event = MagicMock() + event.unified_msg_origin = "test_umo" + event.plugins_name = plugins_name or [] + event.is_stopped = MagicMock(return_value=stopped) + return event + + +@pytest.fixture +def mock_star_map(): + with patch("astrbot.core.pipeline.context_utils.star_map") as sm: + sm.__getitem__ = MagicMock(return_value=MagicMock(name="TestPlugin")) + yield sm + + +@pytest.fixture +def mock_handlers_registry(): + with patch( + "astrbot.core.pipeline.context_utils.star_handlers_registry" + ) as registry: + yield registry + + +@pytest.mark.asyncio +async def test_hook_completes_within_timeout(mock_star_map, mock_handlers_registry): + handler_fn = AsyncMock() + handler_md = _make_handler_metadata(handler_fn) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=5.0) + + handler_fn.assert_awaited_once() + assert result is False + + +@pytest.mark.asyncio +async def test_hook_timeout_skips_handler(mock_star_map, mock_handlers_registry): + async def slow_handler(*args, **kwargs): + await asyncio.sleep(10) + + handler_md = _make_handler_metadata(slow_handler) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.5) + + assert result is False + + +@pytest.mark.asyncio +async def test_hook_timeout_does_not_block_subsequent_handlers( + mock_star_map, mock_handlers_registry +): + async def slow_handler(*args, **kwargs): + await asyncio.sleep(10) + + fast_handler_fn = AsyncMock() + slow_md = _make_handler_metadata( + slow_handler, module_path="slow_mod", handler_name="slow_h" + ) + fast_md = _make_handler_metadata( + fast_handler_fn, module_path="fast_mod", handler_name="fast_h" + ) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[slow_md, fast_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.5) + + fast_handler_fn.assert_awaited_once() + assert result is False + + +@pytest.mark.asyncio +async def test_hook_timeout_zero_disables_timeout( + mock_star_map, mock_handlers_registry +): + async def slow_handler(*args, **kwargs): + await asyncio.sleep(0.3) + + handler_md = _make_handler_metadata(slow_handler) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0) + + assert result is False + + +@pytest.mark.asyncio +async def test_hook_timeout_negative_disables_timeout( + mock_star_map, mock_handlers_registry +): + async def slow_handler(*args, **kwargs): + await asyncio.sleep(0.3) + + handler_md = _make_handler_metadata(slow_handler) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=-1) + + assert result is False + + +@pytest.mark.asyncio +async def test_hook_exception_continues(mock_star_map, mock_handlers_registry): + async def failing_handler(*args, **kwargs): + raise RuntimeError("test error") + + handler_md = _make_handler_metadata(failing_handler) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook(event, EventType.OnLLMRequestEvent) + + assert result is False + + +@pytest.mark.asyncio +async def test_hook_stops_event_propagation(mock_star_map, mock_handlers_registry): + handler_fn = AsyncMock() + handler_md = _make_handler_metadata(handler_fn) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event(stopped=True) + + result = await call_event_hook(event, EventType.OnLLMRequestEvent) + + assert result is True + + +@pytest.mark.asyncio +async def test_default_timeout_value(mock_star_map, mock_handlers_registry): + import inspect + + sig = inspect.signature(call_event_hook) + timeout_param = sig.parameters["timeout"] + assert timeout_param.default == 300.0 + + +@pytest.mark.asyncio +async def test_timeout_logs_plugin_name(mock_star_map, mock_handlers_registry): + async def slow_handler(*args, **kwargs): + await asyncio.sleep(10) + + handler_md = _make_handler_metadata( + slow_handler, module_path="my_plugin_module", handler_name="on_llm_req" + ) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + with patch("astrbot.core.pipeline.context_utils.logger") as mock_logger: + await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.2) + + warning_calls = [ + call for call in mock_logger.warning.call_args_list if "timed out" in str(call) + ] + assert len(warning_calls) == 1 + warning_msg = str(warning_calls[0]) + assert "on_llm_req" in warning_msg + + +@pytest.mark.asyncio +async def test_args_kwargs_passed_to_handler(mock_star_map, mock_handlers_registry): + handler_fn = AsyncMock() + handler_md = _make_handler_metadata(handler_fn) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + extra_arg = MagicMock() + await call_event_hook( + event, EventType.OnLLMRequestEvent, extra_arg, timeout=5.0, extra_kwarg="test" + ) + + handler_fn.assert_awaited_once() + call_args = handler_fn.call_args + assert call_args[0][0] is event + assert call_args[0][1] is extra_arg + assert call_args[1].get("extra_kwarg") == "test" From 88e3dd9fdc91a94dfb49301bca35b1009c3e980e Mon Sep 17 00:00:00 2001 From: Fiber Date: Sat, 30 May 2026 15:38:13 +0800 Subject: [PATCH 2/2] =?UTF-8?q?refactor(context=5Futils):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E4=BA=8B=E4=BB=B6=E9=92=A9=E5=AD=90=E8=B6=85=E6=97=B6?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 提取默认超时时间为常量_DEFAULT_HOOK_TIMEOUT 2. 将timeout参数重命名为hook_timeout提升可读性 3. 添加参数合法性校验,非法值自动回退到默认超时 4. 更新所有测试用例适配参数变更 5. 新增测试用例验证非法超时参数的处理逻辑 --- astrbot/core/pipeline/context_utils.py | 14 +++-- tests/unit/test_call_event_hook.py | 87 +++++++++++++++++++------- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 31c91db701..33e94f193b 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -9,6 +9,8 @@ from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry +_DEFAULT_HOOK_TIMEOUT: float = 300.0 + async def call_handler( event: AstrMessageEvent, @@ -77,7 +79,7 @@ async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, *args, - timeout: float = 300.0, + hook_timeout: float = _DEFAULT_HOOK_TIMEOUT, **kwargs, ) -> bool: """调用事件钩子函数 @@ -86,7 +88,7 @@ async def call_event_hook( event: 事件对象 hook_type: 钩子事件类型 *args: 传递给钩子处理器的位置参数 - timeout: 单个钩子处理器的超时时间(秒),超时后跳过该处理器继续执行。 + hook_timeout: 单个钩子处理器的超时时间(秒),超时后跳过该处理器继续执行。 设为 0 或负数则不启用超时。默认 300 秒。 **kwargs: 传递给钩子处理器的关键字参数 @@ -98,6 +100,8 @@ async def call_event_hook( hook_type, plugins_name=event.plugins_name, ) + if hook_timeout is None or not isinstance(hook_timeout, int | float): + hook_timeout = _DEFAULT_HOOK_TIMEOUT for handler in handlers: try: assert inspect.iscoroutinefunction(handler.handler) @@ -106,16 +110,16 @@ async def call_event_hook( logger.debug( f"hook({hook_type.name}) -> {plugin_name} - {handler_name}", ) - if timeout > 0: + if hook_timeout > 0: try: await asyncio.wait_for( handler.handler(event, *args, **kwargs), - timeout=timeout, + timeout=hook_timeout, ) except asyncio.TimeoutError: logger.warning( f"hook({hook_type.name}) -> {plugin_name} - {handler_name} " - f"timed out after {timeout}s, skipping.", + f"timed out after {hook_timeout}s, skipping.", ) else: await handler.handler(event, *args, **kwargs) diff --git a/tests/unit/test_call_event_hook.py b/tests/unit/test_call_event_hook.py index 76573d59d7..086f023e14 100644 --- a/tests/unit/test_call_event_hook.py +++ b/tests/unit/test_call_event_hook.py @@ -5,7 +5,7 @@ import pytest -from astrbot.core.pipeline.context_utils import call_event_hook +from astrbot.core.pipeline.context_utils import _DEFAULT_HOOK_TIMEOUT, call_event_hook from astrbot.core.star.star_handler import EventType @@ -52,7 +52,7 @@ async def test_hook_completes_within_timeout(mock_star_map, mock_handlers_regist ) event = _make_event() - result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=5.0) + result = await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=5.0) handler_fn.assert_awaited_once() assert result is False @@ -60,16 +60,17 @@ async def test_hook_completes_within_timeout(mock_star_map, mock_handlers_regist @pytest.mark.asyncio async def test_hook_timeout_skips_handler(mock_star_map, mock_handlers_registry): - async def slow_handler(*args, **kwargs): - await asyncio.sleep(10) + async def stuck_handler(*args, **kwargs): + event = asyncio.Event() + await event.wait() - handler_md = _make_handler_metadata(slow_handler) + handler_md = _make_handler_metadata(stuck_handler) mock_handlers_registry.get_handlers_by_event_type = MagicMock( return_value=[handler_md] ) event = _make_event() - result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.5) + result = await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=0.5) assert result is False @@ -78,12 +79,13 @@ async def slow_handler(*args, **kwargs): async def test_hook_timeout_does_not_block_subsequent_handlers( mock_star_map, mock_handlers_registry ): - async def slow_handler(*args, **kwargs): - await asyncio.sleep(10) + async def stuck_handler(*args, **kwargs): + event = asyncio.Event() + await event.wait() fast_handler_fn = AsyncMock() slow_md = _make_handler_metadata( - slow_handler, module_path="slow_mod", handler_name="slow_h" + stuck_handler, module_path="slow_mod", handler_name="slow_h" ) fast_md = _make_handler_metadata( fast_handler_fn, module_path="fast_mod", handler_name="fast_h" @@ -93,7 +95,7 @@ async def slow_handler(*args, **kwargs): ) event = _make_event() - result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.5) + result = await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=0.5) fast_handler_fn.assert_awaited_once() assert result is False @@ -112,7 +114,7 @@ async def slow_handler(*args, **kwargs): ) event = _make_event() - result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0) + result = await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=0) assert result is False @@ -130,7 +132,7 @@ async def slow_handler(*args, **kwargs): ) event = _make_event() - result = await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=-1) + result = await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=-1) assert result is False @@ -166,21 +168,18 @@ async def test_hook_stops_event_propagation(mock_star_map, mock_handlers_registr @pytest.mark.asyncio -async def test_default_timeout_value(mock_star_map, mock_handlers_registry): - import inspect - - sig = inspect.signature(call_event_hook) - timeout_param = sig.parameters["timeout"] - assert timeout_param.default == 300.0 +async def test_default_timeout_value(): + assert _DEFAULT_HOOK_TIMEOUT == 300.0 @pytest.mark.asyncio async def test_timeout_logs_plugin_name(mock_star_map, mock_handlers_registry): - async def slow_handler(*args, **kwargs): - await asyncio.sleep(10) + async def stuck_handler(*args, **kwargs): + event = asyncio.Event() + await event.wait() handler_md = _make_handler_metadata( - slow_handler, module_path="my_plugin_module", handler_name="on_llm_req" + stuck_handler, module_path="my_plugin_module", handler_name="on_llm_req" ) mock_handlers_registry.get_handlers_by_event_type = MagicMock( return_value=[handler_md] @@ -188,7 +187,7 @@ async def slow_handler(*args, **kwargs): event = _make_event() with patch("astrbot.core.pipeline.context_utils.logger") as mock_logger: - await call_event_hook(event, EventType.OnLLMRequestEvent, timeout=0.2) + await call_event_hook(event, EventType.OnLLMRequestEvent, hook_timeout=0.2) warning_calls = [ call for call in mock_logger.warning.call_args_list if "timed out" in str(call) @@ -209,7 +208,11 @@ async def test_args_kwargs_passed_to_handler(mock_star_map, mock_handlers_regist extra_arg = MagicMock() await call_event_hook( - event, EventType.OnLLMRequestEvent, extra_arg, timeout=5.0, extra_kwarg="test" + event, + EventType.OnLLMRequestEvent, + extra_arg, + hook_timeout=5.0, + extra_kwarg="test", ) handler_fn.assert_awaited_once() @@ -217,3 +220,41 @@ async def test_args_kwargs_passed_to_handler(mock_star_map, mock_handlers_regist assert call_args[0][0] is event assert call_args[0][1] is extra_arg assert call_args[1].get("extra_kwarg") == "test" + + +@pytest.mark.asyncio +async def test_timeout_none_falls_back_to_default( + mock_star_map, mock_handlers_registry +): + handler_fn = AsyncMock() + handler_md = _make_handler_metadata(handler_fn) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook( + event, EventType.OnLLMRequestEvent, hook_timeout=None + ) + + handler_fn.assert_awaited_once() + assert result is False + + +@pytest.mark.asyncio +async def test_timeout_string_falls_back_to_default( + mock_star_map, mock_handlers_registry +): + handler_fn = AsyncMock() + handler_md = _make_handler_metadata(handler_fn) + mock_handlers_registry.get_handlers_by_event_type = MagicMock( + return_value=[handler_md] + ) + event = _make_event() + + result = await call_event_hook( + event, EventType.OnLLMRequestEvent, hook_timeout="30" + ) + + handler_fn.assert_awaited_once() + assert result is False