Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ jobs:
LIVEKIT_API_SECRET: ${{ secrets.LIVEKIT_API_SECRET }}
run: |
source .test-venv/bin/activate
pytest tests/
pytest tests/ livekit-rtc/tests/

- name: Run tests (Windows)
if: runner.os == 'Windows'
env:
LIVEKIT_URL: ${{ secrets.LIVEKIT_URL }}
LIVEKIT_API_KEY: ${{ secrets.LIVEKIT_API_KEY }}
LIVEKIT_API_SECRET: ${{ secrets.LIVEKIT_API_SECRET }}
run: .test-venv\Scripts\python.exe -m pytest tests/
run: .test-venv\Scripts\python.exe -m pytest tests/ livekit-rtc/tests/
shell: pwsh

373 changes: 373 additions & 0 deletions livekit-rtc/tests/test_dc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
"""End-to-end Test for data-channel scenarios.

Covers one-to-one delivery, broadcasting to all, topic filtering, and targeted
delivery via `destination_identities`.

Requires the following environment variables to run:
LIVEKIT_URL
LIVEKIT_API_KEY
LIVEKIT_API_SECRET
"""

from __future__ import annotations

import asyncio
import os
import uuid
from typing import Callable, Optional

import pytest

from livekit import api, rtc
from livekit.rtc.room import EventTypes


WAIT_TIMEOUT = 20.0
WAIT_INTERVAL = 0.1


def skip_if_no_credentials():
required_vars = ["LIVEKIT_URL", "LIVEKIT_API_KEY", "LIVEKIT_API_SECRET"]
missing = [var for var in required_vars if not os.getenv(var)]
return pytest.mark.skipif(
bool(missing), reason=f"Missing environment variables: {', '.join(missing)}"
)


def create_token(identity: str, room_name: str) -> str:
Comment thread
cloudwebrtc marked this conversation as resolved.
Outdated
return (
api.AccessToken()
.with_identity(identity)
.with_name(identity)
.with_grants(
api.VideoGrants(
room_join=True,
room=room_name,
)
)
.to_jwt()
)


def unique_room_name(base: str) -> str:
return f"{base}-{uuid.uuid4().hex[:8]}"


async def _wait_until(
predicate: Callable[[], bool],
*,
timeout: float = WAIT_TIMEOUT,
interval: float = WAIT_INTERVAL,
message: str = "condition not met",
) -> None:
loop = asyncio.get_event_loop()
deadline = loop.time() + timeout
while loop.time() < deadline:
if predicate():
return
await asyncio.sleep(interval)
raise AssertionError(f"timeout waiting: {message}")


async def _connect(room: rtc.Room, identity: str, room_name: str) -> str:
token = create_token(identity, room_name)
url = os.environ["LIVEKIT_URL"]
await room.connect(url, token)
return token


async def _ensure_all_connected(rooms: list[rtc.Room]) -> None:
await _wait_until(
lambda: all(r.connection_state == rtc.ConnectionState.CONN_CONNECTED for r in rooms),
message="not all rooms reached CONN_CONNECTED",
)


async def _ensure_visible(observer: rtc.Room, identities: list[str]) -> None:
"""Wait until `observer` sees every identity in `identities` as a remote participant.

Targeted publishes resolve identities at publish time, so we must let the
sender's room state catch up before sending."""

def _all_visible() -> bool:
seen = {p.identity for p in observer.remote_participants.values()}
return all(ident in seen for ident in identities)

await _wait_until(
_all_visible,
message=f"not all identities visible to {observer.local_participant.identity}: {identities}",
)


def _expect_event(
room: rtc.Room,
event: EventTypes,
predicate: Optional[Callable[..., bool]] = None,
) -> asyncio.Future:
loop = asyncio.get_event_loop()
fut: asyncio.Future = loop.create_future()

def _on_event(*args, **kwargs) -> None:
if fut.done():
return
if predicate is None or predicate(*args, **kwargs):
fut.set_result(args)

room.on(event, _on_event)
return fut


async def _await_event(fut: asyncio.Future, timeout: float = WAIT_TIMEOUT) -> None:
try:
await asyncio.wait_for(fut, timeout=timeout)
except asyncio.TimeoutError as e:
raise AssertionError("timed out waiting for event") from e


class _DataCollector:
"""Collects `data_received` packets matching `sender_identity` (when set)."""

def __init__(self, room: rtc.Room, sender_identity: Optional[str] = None) -> None:
self.packets: list[rtc.DataPacket] = []
self._sender_identity = sender_identity

def _on_data(packet: rtc.DataPacket) -> None:
if self._sender_identity is not None and (
packet.participant is None or packet.participant.identity != self._sender_identity
):
return
self.packets.append(packet)

room.on("data_received", _on_data)

def payloads(self) -> list[bytes]:
return [p.data for p in self.packets]

def topics(self) -> list[str | None]:
return [p.topic for p in self.packets]


async def _assert_no_data(
room: rtc.Room, collector: _DataCollector, *, settle: float = 1.0
) -> None:
"""Give the server time to deliver, then assert nothing arrived."""
await asyncio.sleep(settle)
assert collector.packets == [], (
f"{room.local_participant.identity} unexpectedly received "
f"{len(collector.packets)} packet(s): {collector.payloads()}"
)


@skip_if_no_credentials()
@pytest.mark.asyncio
async def test_data_one_to_one() -> None:
"""sender targets a single identity; only that identity receives."""
room_name = unique_room_name("py-dc-1to1")

sender = rtc.Room()
receiver = rtc.Room()
bystander = rtc.Room()

await _connect(sender, "sender", room_name)
await _connect(receiver, "receiver", room_name)
await _connect(bystander, "bystander", room_name)
await _ensure_all_connected([sender, receiver, bystander])
await _ensure_visible(sender, ["receiver", "bystander"])

receiver_collector = _DataCollector(receiver, sender_identity="sender")
bystander_collector = _DataCollector(bystander, sender_identity="sender")

receiver_got = _expect_event(
receiver,
"data_received",
predicate=lambda packet: (
packet.participant is not None and packet.participant.identity == "sender"
),
)

payload = b"hello receiver"
await sender.local_participant.publish_data(payload, destination_identities=["receiver"])

await _await_event(receiver_got)
assert receiver_collector.payloads() == [payload]
await _assert_no_data(bystander, bystander_collector)

await asyncio.gather(sender.disconnect(), receiver.disconnect(), bystander.disconnect())


@skip_if_no_credentials()
@pytest.mark.asyncio
async def test_data_one_to_many_targeted() -> None:
"""sender targets a subset of identities; only that subset receives."""
room_name = unique_room_name("py-dc-1tomany")

sender = rtc.Room()
r1 = rtc.Room()
r2 = rtc.Room()
excluded = rtc.Room()

await _connect(sender, "sender", room_name)
await _connect(r1, "r1", room_name)
await _connect(r2, "r2", room_name)
await _connect(excluded, "excluded", room_name)
await _ensure_all_connected([sender, r1, r2, excluded])
await _ensure_visible(sender, ["r1", "r2", "excluded"])

r1_collector = _DataCollector(r1, sender_identity="sender")
r2_collector = _DataCollector(r2, sender_identity="sender")
excluded_collector = _DataCollector(excluded, sender_identity="sender")

r1_got = _expect_event(
r1,
"data_received",
predicate=lambda packet: (
packet.participant is not None and packet.participant.identity == "sender"
),
)
r2_got = _expect_event(
r2,
"data_received",
predicate=lambda packet: (
packet.participant is not None and packet.participant.identity == "sender"
),
)

payload = b"hello selected"
await sender.local_participant.publish_data(payload, destination_identities=["r1", "r2"])

await asyncio.gather(_await_event(r1_got), _await_event(r2_got))
assert r1_collector.payloads() == [payload]
assert r2_collector.payloads() == [payload]
await _assert_no_data(excluded, excluded_collector)

await asyncio.gather(
sender.disconnect(), r1.disconnect(), r2.disconnect(), excluded.disconnect()
)


@skip_if_no_credentials()
@pytest.mark.asyncio
async def test_data_broadcast() -> None:
"""Empty `destination_identities` broadcasts to every other participant."""
room_name = unique_room_name("py-dc-broadcast")

sender = rtc.Room()
receivers = [rtc.Room() for _ in range(3)]
receiver_idents = [f"r{i}" for i in range(len(receivers))]

await _connect(sender, "sender", room_name)
for room, ident in zip(receivers, receiver_idents):
await _connect(room, ident, room_name)
await _ensure_all_connected([sender, *receivers])
await _ensure_visible(sender, receiver_idents)

collectors = [_DataCollector(room, sender_identity="sender") for room in receivers]
received_futures = [
_expect_event(
room,
"data_received",
predicate=lambda packet: (
packet.participant is not None and packet.participant.identity == "sender"
),
)
for room in receivers
]

payload = b"hello everyone"
await sender.local_participant.publish_data(payload)

await asyncio.gather(*(_await_event(f) for f in received_futures))
for ident, collector in zip(receiver_idents, collectors):
assert collector.payloads() == [payload], f"{ident} payloads mismatch"

await asyncio.gather(sender.disconnect(), *(r.disconnect() for r in receivers))


@skip_if_no_credentials()
@pytest.mark.asyncio
async def test_data_topic_passthrough() -> None:
"""Topic field is preserved end-to-end and observable by every receiver."""
room_name = unique_room_name("py-dc-topic")

sender = rtc.Room()
r1 = rtc.Room()
r2 = rtc.Room()

await _connect(sender, "sender", room_name)
await _connect(r1, "r1", room_name)
await _connect(r2, "r2", room_name)
await _ensure_all_connected([sender, r1, r2])
await _ensure_visible(sender, ["r1", "r2"])

r1_collector = _DataCollector(r1, sender_identity="sender")
r2_collector = _DataCollector(r2, sender_identity="sender")

# Send three messages: two on "chat", one on "telemetry".
messages = [
(b"chat-1", "chat"),
(b"telemetry-1", "telemetry"),
(b"chat-2", "chat"),
]

def _all_received(collector: _DataCollector) -> bool:
return len(collector.packets) >= len(messages)

for payload, topic in messages:
await sender.local_participant.publish_data(payload, topic=topic)

await _wait_until(
lambda: _all_received(r1_collector) and _all_received(r2_collector),
message="receivers did not get all topic messages",
)

expected_pairs = [(payload, topic) for payload, topic in messages]
for collector, ident in [(r1_collector, "r1"), (r2_collector, "r2")]:
got = list(zip(collector.payloads(), collector.topics()))
assert got == expected_pairs, f"{ident} mismatch: expected {expected_pairs}, got {got}"

# Also verify `chat`-only filtering at the consumer side works as expected.
chat_only_r1 = [p for p in r1_collector.packets if p.topic == "chat"]
assert [p.data for p in chat_only_r1] == [b"chat-1", b"chat-2"]

await asyncio.gather(sender.disconnect(), r1.disconnect(), r2.disconnect())


@skip_if_no_credentials()
@pytest.mark.asyncio
async def test_data_targeted_with_topic() -> None:
"""Targeted send carries the topic; non-targets receive nothing."""
room_name = unique_room_name("py-dc-targeted-topic")

sender = rtc.Room()
target = rtc.Room()
other = rtc.Room()

await _connect(sender, "sender", room_name)
await _connect(target, "target", room_name)
await _connect(other, "other", room_name)
await _ensure_all_connected([sender, target, other])
await _ensure_visible(sender, ["target", "other"])

target_collector = _DataCollector(target, sender_identity="sender")
other_collector = _DataCollector(other, sender_identity="sender")

target_got = _expect_event(
target,
"data_received",
predicate=lambda packet: (
packet.participant is not None and packet.participant.identity == "sender"
),
)

payload = b"private ping"
topic = "private"
await sender.local_participant.publish_data(
payload, destination_identities=["target"], topic=topic
)

await _await_event(target_got)
assert target_collector.payloads() == [payload]
assert target_collector.topics() == [topic]
await _assert_no_data(other, other_collector)

await asyncio.gather(sender.disconnect(), target.disconnect(), other.disconnect())
Loading