diff --git a/crates/ironrdp-rdpsnd/Cargo.toml b/crates/ironrdp-rdpsnd/Cargo.toml index 6aeaac9d7..4ca7cda1d 100644 --- a/crates/ironrdp-rdpsnd/Cargo.toml +++ b/crates/ironrdp-rdpsnd/Cargo.toml @@ -14,7 +14,7 @@ categories.workspace = true [lib] doctest = false -test = false +# test = false [features] default = [] diff --git a/crates/ironrdp-rdpsnd/src/server.rs b/crates/ironrdp-rdpsnd/src/server.rs index dd0026dee..ff5ad173b 100644 --- a/crates/ironrdp-rdpsnd/src/server.rs +++ b/crates/ironrdp-rdpsnd/src/server.rs @@ -28,34 +28,77 @@ pub enum RdpsndServerMessage { Error(Box), } +/// A server-offered audio format that the client also advertised support for, +/// paired with the `wFormatNo` the client expects for it on the wire. +/// +/// The crate computes the set of these — the intersection of the server's +/// [`get_formats`] and the client's accepted formats — and hands it to +/// [`RdpsndServerHandler::choose_format`], which returns the one to stream. +/// +/// `wformat_no` is intentionally private and there is no public constructor: +/// a handler can neither build nor mutate a `NegotiatedFormat`, so the index +/// stamped onto every Wave/Wave2 PDU is always a valid position in the +/// client's own format list. This makes it impossible to emit an out-of-range +/// `wFormatNo` (which a compliant client rejects, silently dropping all audio +/// — the classic footgun of the old index-returning API). +/// +/// [`get_formats`]: RdpsndServerHandler::get_formats +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiatedFormat { + /// The negotiated audio format (common to server and client). + format: pdu::AudioFormat, + /// Position of `format` in the client's Client Audio Formats list — the + /// `wFormatNo` the client resolves each wave against. Crate-owned. + wformat_no: u16, +} + +impl NegotiatedFormat { + /// The negotiated audio format — common to both server and client, and the + /// one the returned wave data should match. + pub fn format(&self) -> &pdu::AudioFormat { + &self.format + } +} + /// Handler for the server side of the Audio Output Virtual Channel (`RDPSND`). /// -/// Implementations supply the list of audio formats the server offers, decide -/// which format to use once the client replies, and produce the audio waves to -/// stream (via [`RdpsndServer::wave`]). +/// Implementations supply the list of audio formats the server offers, choose +/// which negotiated format to use once the client replies, and produce the +/// audio waves to stream (via [`RdpsndServer::wave`]). pub trait RdpsndServerHandler: Send + core::fmt::Debug { /// The audio formats the server advertises in the Server Audio Formats and /// Version PDU (MS-RDPEA 2.2.2.1). fn get_formats(&self) -> &[pdu::AudioFormat]; - /// Called once the client has replied with the formats it accepts - /// (`client_format`, the Client Audio Formats and Version PDU). Returns the - /// `wFormatNo` to stamp on every subsequent Wave/Wave2 PDU, or [`None`] if - /// no offered format is acceptable (no audio is then streamed). + /// Select which format to stream, once the client has replied with the + /// formats it accepts. /// - /// **The returned index addresses `client_format.formats` — the formats the - /// client just echoed back — NOT the server's own [`get_formats`] list.** - /// The client resolves each wave's format as `ClientFormats[wFormatNo]` - /// against the list *it* sent, and a compliant client rejects any - /// `wFormatNo >= client_format.formats.len()`, silently dropping all audio. - /// The client's list is its accepted subset of the server's formats, so the - /// two lists generally differ in both length and ordering; an index into - /// [`get_formats`] only happens to work when the chosen format sits at the - /// same position in both. Pick the format you intend to send, then return - /// its position within `client_format.formats`. + /// `common` is the set of formats from [`get_formats`] that the client also + /// advertised, in the server's preference order; each carries the + /// `wFormatNo` the client expects, so the crate — not the handler — owns + /// the index arithmetic and the MS-RDPEA rule that `wFormatNo` addresses + /// the *client's* list. `common` is never empty: when server and client + /// share no format, this method is not called and no audio is streamed. + /// + /// Return the [`NegotiatedFormat`] to stream (a reference borrowed from + /// `common`), or [`None`] to decline. Returning a borrow from `common` + /// — rather than an index or a constructed value — makes it impossible to + /// pick a format the client did not accept or to produce an invalid + /// `wFormatNo`. This is a pure selection step: any encoder/producer setup + /// belongs in [`start`], which the crate calls next with the chosen format. /// /// [`get_formats`]: RdpsndServerHandler::get_formats - fn start(&mut self, client_format: &ClientAudioFormatPdu) -> Option; + /// [`start`]: RdpsndServerHandler::start + fn choose_format<'a>(&mut self, common: &'a [NegotiatedFormat]) -> Option<&'a NegotiatedFormat>; + + /// Begin streaming with the `format` just selected by [`choose_format`]. + /// + /// Called once per session, immediately after a successful + /// [`choose_format`]. This is the lifecycle hook: initialize encoder state, + /// spawn the producer, etc. Waves are then emitted via [`RdpsndServer::wave`]. + /// + /// [`choose_format`]: RdpsndServerHandler::choose_format + fn start(&mut self, format: &NegotiatedFormat); /// Called when the audio stream is torn down (e.g. the client closed the /// channel or the session ended). @@ -173,6 +216,43 @@ impl RdpsndServer { } } +/// Build the set of formats common to the server (`server_formats`, kept in the +/// server's preference order) and the client (`client_formats`), each tagged +/// with its `wFormatNo` — its index in the *client's* list, which is what the +/// client resolves waves against (MS-RDPEA). The result mirrors the server's +/// ordering so the handler can express preference simply by `get_formats` +/// order, while the `wFormatNo` always points into the client list. +fn negotiate_formats( + server_formats: &[pdu::AudioFormat], + client_formats: &[pdu::AudioFormat], +) -> Vec { + server_formats + .iter() + .filter_map(|server_format| { + client_formats + .iter() + .position(|client_fmt| audio_format_eq(client_fmt, server_format)) + .and_then(|idx| u16::try_from(idx).ok()) + .map(|wformat_no| NegotiatedFormat { + format: server_format.clone(), + wformat_no, + }) + }) + .collect() +} + +/// Compare two audio formats by their WAVEFORMATEX identity — wave format tag, +/// channel count, sample rate, and bit depth. Derived fields +/// (`n_avg_bytes_per_sec`, `n_block_align`) and the codec-specific `data` blob +/// are deliberately ignored: a client echoes back a format it accepts but is +/// not guaranteed to reproduce those byte-for-byte. +fn audio_format_eq(a: &pdu::AudioFormat, b: &pdu::AudioFormat) -> bool { + a.format == b.format + && a.n_channels == b.n_channels + && a.n_samples_per_sec == b.n_samples_per_sec + && a.bits_per_sample == b.bits_per_sample +} + impl_as_any!(RdpsndServer); impl SvcProcessor for RdpsndServer { @@ -220,8 +300,27 @@ impl SvcProcessor for RdpsndServer { return Ok(vec![]); }; let client_format = self.client_format.as_ref().expect("available in this state"); + // Formats common to server and client, in the server's + // preference order, each tagged with its wFormatNo (its + // position in the *client's* list). Keeping this in the crate + // means the handler never does index arithmetic and can't emit + // an out-of-range wFormatNo. + let common = negotiate_formats(self.handler.get_formats(), &client_format.formats); self.state = RdpsndState::Ready; - self.format_no = self.handler.start(client_format); + self.format_no = if common.is_empty() { + debug!("No audio format in common with the client; audio disabled"); + None + } else if let Some(chosen) = self.handler.choose_format(&common) { + // `chosen` borrows `common` (not `self`), so the encoder + // is read off it and the handler is free to borrow again + // for the `start` lifecycle hook. + let wformat_no = chosen.wformat_no; + self.handler.start(chosen); + Some(wformat_no) + } else { + debug!("Handler declined every common audio format; audio disabled"); + None + }; vec![] } RdpsndState::Ready => { @@ -260,3 +359,77 @@ impl Drop for RdpsndServer { } impl SvcServerProcessor for RdpsndServer {} + +#[cfg(test)] +mod tests { + use super::{audio_format_eq, negotiate_formats}; + use crate::pdu::{AudioFormat, WaveFormat}; + + fn fmt(format: WaveFormat, rate: u32) -> AudioFormat { + AudioFormat { + format, + n_channels: 2, + n_samples_per_sec: rate, + n_avg_bytes_per_sec: rate * 4, + n_block_align: 4, + bits_per_sample: 16, + data: None, + } + } + + #[test] + fn wformat_no_addresses_the_client_list_not_the_server_list() { + // Server prefers AAC over PCM; the client lists them in the opposite + // order. wFormatNo must follow the CLIENT's indices. + let server = [fmt(WaveFormat::AAC_MS, 44100), fmt(WaveFormat::PCM, 44100)]; + let client = [fmt(WaveFormat::PCM, 44100), fmt(WaveFormat::AAC_MS, 44100)]; + + let common = negotiate_formats(&server, &client); + + // Ordering follows the server's preference (AAC first)... + assert_eq!(common.len(), 2); + assert_eq!(common[0].format().format, WaveFormat::AAC_MS); + assert_eq!(common[1].format().format, WaveFormat::PCM); + // ...but each wFormatNo is the position in the CLIENT list. + assert_eq!(common[0].wformat_no, 1); // AAC is client index 1 + assert_eq!(common[1].wformat_no, 0); // PCM is client index 0 + } + + #[test] + fn pcm_only_client_gets_a_valid_client_index() { + // Regression for the --enable-aac trap: server advertises [AAC, PCM] + // but a PCM-only client must get wFormatNo 0 (its sole index), not + // PCM's server-list index of 1 (which the client would reject). + let server = [fmt(WaveFormat::AAC_MS, 44100), fmt(WaveFormat::PCM, 44100)]; + let client = [fmt(WaveFormat::PCM, 44100)]; + + let common = negotiate_formats(&server, &client); + + assert_eq!(common.len(), 1); + assert_eq!(common[0].format().format, WaveFormat::PCM); + assert_eq!(common[0].wformat_no, 0); + } + + #[test] + fn no_shared_format_yields_empty() { + let server = [fmt(WaveFormat::OPUS, 48000)]; + let client = [fmt(WaveFormat::PCM, 44100)]; + assert!(negotiate_formats(&server, &client).is_empty()); + } + + #[test] + fn equality_uses_waveformatex_identity_only() { + let mut a = fmt(WaveFormat::PCM, 44100); + let mut b = fmt(WaveFormat::PCM, 44100); + // Differ only in derived/codec fields — still the same format. + b.n_avg_bytes_per_sec = 0; + b.n_block_align = 99; + a.data = Some(vec![1, 2, 3]); + b.data = None; + assert!(audio_format_eq(&a, &b)); + + // A differing identity field (sample rate) is a different format. + let c = fmt(WaveFormat::PCM, 48000); + assert!(!audio_format_eq(&a, &c)); + } +} diff --git a/crates/ironrdp/examples/server.rs b/crates/ironrdp/examples/server.rs index 71db6d021..785f0ebe3 100644 --- a/crates/ironrdp/examples/server.rs +++ b/crates/ironrdp/examples/server.rs @@ -11,8 +11,8 @@ use std::sync::{Arc, Mutex}; use anyhow::Context as _; use ironrdp::cliprdr::backend::{CliprdrBackend, CliprdrBackendFactory}; use ironrdp::connector::DesktopSize; -use ironrdp::rdpsnd::pdu::{AudioFormat, ClientAudioFormatPdu, WaveFormat}; -use ironrdp::rdpsnd::server::{RdpsndServerHandler, RdpsndServerMessage}; +use ironrdp::rdpsnd::pdu::{AudioFormat, WaveFormat}; +use ironrdp::rdpsnd::server::{NegotiatedFormat, RdpsndServerHandler, RdpsndServerMessage}; use ironrdp::server::tokio::sync::mpsc::UnboundedSender; use ironrdp::server::tokio::time::{self, Duration, sleep}; use ironrdp::server::{ @@ -255,17 +255,6 @@ struct SndHandler { task: Option>, } -impl SndHandler { - fn choose_format(&self, client_formats: &[AudioFormat]) -> Option { - for (n, fmt) in client_formats.iter().enumerate() { - if self.get_formats().contains(fmt) { - return u16::try_from(n).ok(); - } - } - None - } -} - impl RdpsndServerHandler for SndHandler { fn get_formats(&self) -> &[AudioFormat] { &[ @@ -290,14 +279,16 @@ impl RdpsndServerHandler for SndHandler { ] } - fn start(&mut self, client_format: &ClientAudioFormatPdu) -> Option { - debug!(?client_format); + fn choose_format<'a>(&mut self, common: &'a [NegotiatedFormat]) -> Option<&'a NegotiatedFormat> { + debug!(?common); - let Some(nfmt) = self.choose_format(&client_format.formats) else { - return Some(0); - }; + // The crate hands us the formats common to both peers in our preference + // order; take the most-preferred one. + common.first() + } - let fmt = client_format.formats[usize::from(nfmt)].clone(); + fn start(&mut self, format: &NegotiatedFormat) { + let fmt = format.format().clone(); let mut opus_enc = if fmt.format == WaveFormat::OPUS { let n_channels: opus2::Channels = match fmt.n_channels { @@ -305,7 +296,7 @@ impl RdpsndServerHandler for SndHandler { 2 => opus2::Channels::Stereo, n => { warn!("Invalid OPUS channels: {}", n); - return Some(0); + return; } }; @@ -313,7 +304,7 @@ impl RdpsndServerHandler for SndHandler { Ok(enc) => Some(enc), Err(err) => { warn!("Failed to create OPUS encoder: {}", err); - return Some(0); + return; } } } else { @@ -348,8 +339,6 @@ impl RdpsndServerHandler for SndHandler { ts = ts.wrapping_add(100); } })); - - Some(nfmt) } fn stop(&mut self) {