Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/mcp/server/connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""`Connection` — per-client connection state and the standalone outbound channel.

Always present on `Context` (never ``None``), even in stateless deployments.
Holds peer info populated at ``initialize`` time, the per-connection lifespan
output, and an `Outbound` for the standalone stream (the SSE GET stream in
streamable HTTP, or the single duplex stream in stdio).
Holds peer info populated at ``initialize`` time, per-connection scratch
``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the
standalone stream (the SSE GET stream in streamable HTTP, or the single duplex
stream in stdio).

`notify` is best-effort: it never raises. If there's no standalone channel
(stateless HTTP) or the stream has been dropped, the notification is
Expand All @@ -14,6 +15,7 @@

import logging
from collections.abc import Mapping
from contextlib import AsyncExitStack
from typing import Any

import anyio
Expand Down Expand Up @@ -44,17 +46,27 @@ class Connection(TypedServerRequestMixin):
``None`` until ``initialize`` completes; ``initialized`` is set then.
"""

def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None:
self._outbound = outbound
self.has_standalone_channel = has_standalone_channel
self.session_id: str | None = session_id

self.client_info: Implementation | None = None
self.client_capabilities: ClientCapabilities | None = None
self.protocol_version: str | None = None
self.initialized: anyio.Event = anyio.Event()
# TODO: make this generic (Connection[StateT]) once connection_lifespan
# wiring lands in ServerRunner.
self.state: Any = None

self.state: dict[str, Any] = {}
"""Per-connection scratch state. Handlers and middleware may read and
write freely; persists across requests on this connection."""

self.exit_stack: AsyncExitStack = AsyncExitStack()
"""Cleanup stack unwound by `ServerRunner` when the connection closes.

Push context managers (``await exit_stack.enter_async_context(...)``)
or callbacks (``exit_stack.push_async_callback(...)``) from handlers or
middleware to register per-connection teardown. Unwound LIFO after
`dispatcher.run()` returns, shielded from cancellation."""

async def send_raw_request(
self,
Expand Down
36 changes: 26 additions & 10 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, Generic, Protocol

Expand Down Expand Up @@ -33,10 +33,9 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex


LifespanT = TypeVar("LifespanT", default=Any, covariant=True)
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)


class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]):
"""Server-side per-request context.

Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
Expand All @@ -50,7 +49,7 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener

def __init__(
self,
dctx: DispatchContext[TransportT],
dctx: DispatchContext[TransportContext],
*,
lifespan: LifespanT,
connection: Connection,
Expand All @@ -70,6 +69,23 @@ def connection(self) -> Connection:
"""The per-client `Connection` for this request's connection."""
return self._connection

@property
def session_id(self) -> str | None:
"""The transport's session id for this connection, when one exists.

Convenience for ``ctx.connection.session_id``. ``None`` on stdio and
stateless HTTP.
"""
return self._connection.session_id

@property
def headers(self) -> Mapping[str, str] | None:
"""Request headers carried by this message, when the transport has them.

Convenience for ``ctx.transport.headers``. ``None`` on stdio.
"""
return self.transport.headers

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a request-scoped ``notifications/message`` log entry.

Expand All @@ -94,23 +110,23 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)


class ContextMiddleware(Protocol[_MwLifespanT]):
class ServerMiddleware(Protocol[_MwLifespanT]):
"""Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.

Runs *inside* `ServerRunner._on_request` after params validation and
`Context` construction. Wraps registered handlers (including ``ping``) but
not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
outermost-first on `Server.middleware`.

`Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
`LifespanT`, so it registers on any `Server[L]`.
`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
middleware sees `ctx.lifespan: L`. A reusable middleware can be typed
`ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it
registers on any `Server[L]`.
"""

async def __call__(
self,
ctx: Context[_MwLifespanT, TransportContext],
ctx: Context[_MwLifespanT],
method: str,
params: BaseModel,
call_next: CallNext,
Expand Down
147 changes: 101 additions & 46 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ async def main():
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from importlib.metadata import version as importlib_version
from typing import Any, Generic, cast

import anyio
from opentelemetry.trace import SpanKind, StatusCode
from pydantic import BaseModel
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
Expand All @@ -58,7 +60,7 @@ async def main():
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
from mcp.server.auth.settings import AuthSettings
from mcp.server.context import ContextMiddleware, ServerRequestContext
from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext
from mcp.server.experimental.request_context import Experimental
from mcp.server.lowlevel.experimental import ExperimentalHandlers
from mcp.server.models import InitializationOptions
Expand All @@ -76,6 +78,30 @@ async def main():

LifespanResultT = TypeVar("LifespanResultT", default=Any)

_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel)

RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]]
"""A registered request handler: ``(ctx, params) -> result``."""

NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]]
"""A registered notification handler: ``(ctx, params) -> None``."""


@dataclass(frozen=True, slots=True)
class HandlerEntry(Generic[LifespanResultT]):
"""A registered handler and the params model to validate incoming params against.

Stored in `Server._request_handlers` / `_notification_handlers` and consumed
by `ServerRunner` to validate, build `Context`, and invoke. The handler's
second-argument type is erased to ``Any`` in storage (each entry has a
different concrete params type and `Callable` parameters are contravariant);
the precise type is recoverable via `params_type`. The correlation is
enforced at registration time by `Server.add_request_handler`.
"""

params_type: type[BaseModel]
handler: RequestHandler[LifespanResultT, Any]


class NotificationOptions:
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
Expand All @@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals


@asynccontextmanager
async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]:
async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]:
"""Default lifespan context manager that does nothing.

Returns:
Expand All @@ -109,6 +135,8 @@ def __init__(
instructions: str | None = None,
website_url: str | None = None,
icons: list[types.Icon] | None = None,
notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
lifespan: Callable[
[Server[LifespanResultT]],
AbstractAsyncContextManager[LifespanResultT],
Expand Down Expand Up @@ -193,72 +221,96 @@ def __init__(
self.website_url = website_url
self.icons = icons
self.lifespan = lifespan
self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {}
self._notification_handlers: dict[
str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]]
] = {}
self._notification_options = notification_options or NotificationOptions()
self._experimental_capabilities = experimental_capabilities or {}
self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
self._session_manager: StreamableHTTPSessionManager | None = None
# Context-tier middleware consumed by `ServerRunner`. Additive; the
# existing `run()` path ignores it.
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
self.middleware: list[ServerMiddleware[LifespanResultT]] = []
logger.debug("Initializing server %r", name)

# Populate internal handler dicts from on_* kwargs
self._request_handlers.update(
{
method: handler
for method, handler in {
"ping": on_ping,
"prompts/list": on_list_prompts,
"prompts/get": on_get_prompt,
"resources/list": on_list_resources,
"resources/templates/list": on_list_resource_templates,
"resources/read": on_read_resource,
"resources/subscribe": on_subscribe_resource,
"resources/unsubscribe": on_unsubscribe_resource,
"tools/list": on_list_tools,
"tools/call": on_call_tool,
"logging/setLevel": on_set_logging_level,
"completion/complete": on_completion,
}.items()
if handler is not None
}
)
_spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [
("ping", types.RequestParams, on_ping),
("prompts/list", types.PaginatedRequestParams, on_list_prompts),
("prompts/get", types.GetPromptRequestParams, on_get_prompt),
("resources/list", types.PaginatedRequestParams, on_list_resources),
("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates),
("resources/read", types.ReadResourceRequestParams, on_read_resource),
("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource),
("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource),
("tools/list", types.PaginatedRequestParams, on_list_tools),
("tools/call", types.CallToolRequestParams, on_call_tool),
("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level),
("completion/complete", types.CompleteRequestParams, on_completion),
]
self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None})

_spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [
("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed),
("notifications/progress", types.ProgressNotificationParams, on_progress),
]
self._notification_handlers.update(
{
method: handler
for method, handler in {
"notifications/roots/list_changed": on_roots_list_changed,
"notifications/progress": on_progress,
}.items()
if handler is not None
}
{m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None}
)

def add_request_handler(
self,
method: str,
params_type: type[_ParamsT],
handler: RequestHandler[LifespanResultT, _ParamsT],
) -> None:
"""Register a request handler for ``method``.

``params_type`` is the model incoming params are validated against
before the handler is invoked. It should subclass `RequestParams` so
``_meta`` parses uniformly. Replaces any existing handler for the same
method (no collision guard against spec methods).
"""
self._request_handlers[method] = HandlerEntry(params_type, handler)

def add_notification_handler(
self,
method: str,
params_type: type[_ParamsT],
handler: NotificationHandler[LifespanResultT, _ParamsT],
) -> None:
"""Register a notification handler for ``method``.

``params_type`` should subclass `NotificationParams` so ``_meta``
parses uniformly. Replaces any existing handler.
"""
self._notification_handlers[method] = HandlerEntry(params_type, handler)

def _add_request_handler(
self,
method: str,
handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]],
handler: RequestHandler[LifespanResultT, Any],
) -> None:
"""Add a request handler, silently replacing any existing handler for the same method."""
self._request_handlers[method] = handler
# TODO: remove once experimental tasks plumbing and remaining callers
# migrate to `add_request_handler` with an explicit params_type.
self.add_request_handler(method, types.RequestParams, handler)

def _has_handler(self, method: str) -> bool:
"""Check if a handler is registered for the given method."""
return method in self._request_handlers or method in self._notification_handlers

# --- ServerRegistry protocol (consumed by ServerRunner) ------------------

def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
"""Return the handler for a request method, or ``None``."""
def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None:
"""Return the registered entry for a request method, or ``None``."""
return self._request_handlers.get(method)

def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
"""Return the handler for a notification method, or ``None``."""
def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None:
"""Return the registered entry for a notification method, or ``None``."""
return self._notification_handlers.get(method)

def capabilities(self) -> types.ServerCapabilities:
"""Derive `ServerCapabilities` from registered handlers and constructor options."""
return self.get_capabilities(self._notification_options, self._experimental_capabilities)

# TODO: Rethink capabilities API. Currently capabilities are derived from registered
# handlers but require NotificationOptions to be passed externally for list_changed
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities
Expand Down Expand Up @@ -474,7 +526,8 @@ async def _handle_request(
attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id},
context=parent_context,
) as span:
if handler := self._request_handlers.get(req.method):
if entry := self._request_handlers.get(req.method):
handler = entry.handler
logger.debug("Dispatching request of type %s", type(req).__name__)

try:
Expand Down Expand Up @@ -533,7 +586,8 @@ async def _handle_request(
span.set_status(StatusCode.ERROR, response.message)

try:
await message.respond(response)
# TODO: cast goes away when `_handle_request` is deleted.
await message.respond(cast(types.ServerResult | types.ErrorData, response))
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
# Transport closed between handler unblocking and respond. Happens
# when _receive_loop's finally wakes a handler blocked on
Expand All @@ -552,7 +606,8 @@ async def _handle_notification(
session: ServerSession,
lifespan_context: LifespanResultT,
) -> None:
if handler := self._notification_handlers.get(notify.method):
if entry := self._notification_handlers.get(notify.method):
handler = entry.handler
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
Expand Down
Loading
Loading