Skip to content

Commit fe7ef60

Browse files
committed
Add tests for _highlevel_open_unix_listeners
1 parent c922a52 commit fe7ef60

1 file changed

Lines changed: 128 additions & 0 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import socket as stdlib_socket
4+
import sys
5+
from typing import TYPE_CHECKING, cast
6+
7+
import pytest
8+
9+
import trio
10+
import trio.socket as tsocket
11+
from trio import (
12+
SocketListener,
13+
open_unix_listener,
14+
serve_unix,
15+
)
16+
from trio.testing import open_stream_to_socket_listener
17+
18+
if TYPE_CHECKING:
19+
from pathlib import Path
20+
21+
from trio.abc import SendStream
22+
23+
assert not TYPE_CHECKING or sys.platform != "win32"
24+
25+
26+
skip_if_not_unix = pytest.mark.skipif(
27+
not hasattr(tsocket, "AF_UNIX"),
28+
reason="Needs unix socket support",
29+
)
30+
31+
32+
@pytest.fixture
33+
def temp_unix_socket_path(tmp_path: Path) -> str:
34+
"""Fixture to create a temporary Unix socket path."""
35+
# Create a temporary file in the tmp_path directory
36+
temp_socket_path = tmp_path / "socket.sock"
37+
return str(temp_socket_path)
38+
39+
40+
@skip_if_not_unix
41+
async def test_open_unix_listener_basic(temp_unix_socket_path: str) -> None:
42+
listener = await open_unix_listener(temp_unix_socket_path)
43+
44+
assert isinstance(listener, SocketListener)
45+
# Check that the listener is using the Unix socket family
46+
assert listener.socket.family == tsocket.AF_UNIX
47+
assert listener.socket.getsockname() == temp_unix_socket_path
48+
49+
# Make sure the backlog is at least 2
50+
c1 = await open_stream_to_socket_listener(listener)
51+
c2 = await open_stream_to_socket_listener(listener)
52+
53+
s1 = await listener.accept()
54+
s2 = await listener.accept()
55+
56+
# Note that we don't know which client stream is connected to which server
57+
# stream
58+
await s1.send_all(b"x")
59+
await s2.send_all(b"x")
60+
assert await c1.receive_some(1) == b"x"
61+
assert await c2.receive_some(1) == b"x"
62+
63+
for resource in [c1, c2, s1, s2, listener]:
64+
await resource.aclose()
65+
66+
67+
@skip_if_not_unix
68+
async def test_open_unix_listener_specific_path(temp_unix_socket_path: str) -> None:
69+
listener = await open_unix_listener(temp_unix_socket_path)
70+
async with listener:
71+
assert listener.socket.getsockname() == temp_unix_socket_path
72+
73+
74+
@skip_if_not_unix
75+
async def test_open_unix_listener_rebind(temp_unix_socket_path: str) -> None:
76+
listener = await open_unix_listener(temp_unix_socket_path)
77+
sockaddr1 = listener.socket.getsockname()
78+
79+
# Attempt to bind again to the same socket should fail
80+
with stdlib_socket.socket(tsocket.AF_UNIX) as probe:
81+
with pytest.raises(
82+
OSError,
83+
match=r"(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
84+
):
85+
probe.bind(temp_unix_socket_path)
86+
87+
# Now use the listener to set up some connections
88+
c_established = await open_stream_to_socket_listener(listener)
89+
s_established = await listener.accept()
90+
await listener.aclose()
91+
92+
# Attempt to bind again should succeed after closing the listener
93+
listener2 = await open_unix_listener(temp_unix_socket_path)
94+
sockaddr2 = listener2.socket.getsockname()
95+
96+
assert sockaddr1 == sockaddr2
97+
assert s_established.socket.getsockname() == sockaddr2
98+
99+
for resource in [listener2, c_established, s_established]:
100+
await resource.aclose()
101+
102+
103+
@skip_if_not_unix
104+
async def test_serve_unix(temp_unix_socket_path: str) -> None:
105+
async def handler(stream: SendStream) -> None:
106+
await stream.send_all(b"x")
107+
108+
async with trio.open_nursery() as nursery:
109+
# nursery.start is incorrectly typed, awaiting #2773
110+
value = await nursery.start(serve_unix, handler, temp_unix_socket_path)
111+
assert isinstance(value, list)
112+
listeners = cast("list[SocketListener]", value)
113+
stream = await open_stream_to_socket_listener(listeners[0])
114+
async with stream:
115+
assert await stream.receive_some(1) == b"x"
116+
nursery.cancel_scope.cancel()
117+
for listener in listeners:
118+
await listener.aclose()
119+
120+
121+
@pytest.mark.skipif(hasattr(tsocket, "AF_UNIX"), reason="Test for non-unix platforms")
122+
async def test_error_on_no_unix(temp_unix_socket_path: str) -> None:
123+
with pytest.raises(
124+
RuntimeError,
125+
match=r"^Unix sockets are not supported on this platform$",
126+
):
127+
async with await open_unix_listener(temp_unix_socket_path):
128+
pass

0 commit comments

Comments
 (0)