diff --git a/benches/request_processing.rs b/benches/request_processing.rs index 9c1ed893..af8c7463 100644 --- a/benches/request_processing.rs +++ b/benches/request_processing.rs @@ -210,6 +210,10 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { tool_calls: None, function_call: None, reasoning: None, + reasoning_content: None, + think: None, + think_fast: None, + think_faster: None, }); } diff --git a/src/protocols/spec.rs b/src/protocols/spec.rs index 8b20bf9e..37d77761 100644 --- a/src/protocols/spec.rs +++ b/src/protocols/spec.rs @@ -87,6 +87,18 @@ pub enum ChatMessage { /// Reasoning content for reasoning models #[serde(skip_serializing_if = "Option::is_none")] reasoning: Option, + /// Reasoning content for reasoning models using vLLM-compatible field names + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, + /// Alternative reasoning field accepted by some chat templates + #[serde(skip_serializing_if = "Option::is_none")] + think: Option, + /// Alternative reasoning field accepted by some chat templates + #[serde(skip_serializing_if = "Option::is_none")] + think_fast: Option, + /// Alternative reasoning field accepted by some chat templates + #[serde(skip_serializing_if = "Option::is_none")] + think_faster: Option, }, Tool { role: String, // "tool" @@ -151,6 +163,34 @@ impl<'de> Deserialize<'de> for ChatMessage { r.as_str().map(String::from) } }), + reasoning_content: value.get("reasoning_content").and_then(|r| { + if r.is_null() { + None + } else { + r.as_str().map(String::from) + } + }), + think: value.get("think").and_then(|r| { + if r.is_null() { + None + } else { + r.as_str().map(String::from) + } + }), + think_fast: value.get("think_fast").and_then(|r| { + if r.is_null() { + None + } else { + r.as_str().map(String::from) + } + }), + think_faster: value.get("think_faster").and_then(|r| { + if r.is_null() { + None + } else { + r.as_str().map(String::from) + } + }), }), "system" => Ok(ChatMessage::System { role: role.to_string(), @@ -3672,6 +3712,10 @@ mod tests { tool_calls: None, function_call: None, reasoning: Some("Thinking...".to_string()), + reasoning_content: None, + think: None, + think_fast: None, + think_faster: None, }; let serialized = serde_json::to_string(&original).unwrap(); diff --git a/tests/test_extra_args_chat.rs b/tests/test_extra_args_chat.rs index 5599df2e..7a6b54d4 100644 --- a/tests/test_extra_args_chat.rs +++ b/tests/test_extra_args_chat.rs @@ -1,5 +1,20 @@ //! Tests that unknown/extra fields in ChatCompletionRequest are preserved through serde roundtrip. -use vllm_router_rs::protocols::spec::ChatCompletionRequest; +use vllm_router_rs::protocols::spec::{ChatCompletionRequest, ChatMessage}; + +fn assert_assistant_field_roundtrip(field: &str, expected: &str) { + let request_json = serde_json::json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi", field: expected}, + ], + }); + + let req: ChatCompletionRequest = serde_json::from_value(request_json).unwrap(); + let value = serde_json::to_value(&req).unwrap(); + + assert_eq!(value["messages"][1][field], serde_json::json!(expected)); +} #[test] fn test_extra_fields_preserved_on_deserialize() { @@ -71,3 +86,107 @@ fn test_no_other_fields_gives_empty_map() { let req: ChatCompletionRequest = serde_json::from_str(json).unwrap(); assert!(req.other.is_empty()); } + +#[test] +fn test_assistant_reasoning_content_empty_survives_roundtrip() { + let json = r#"{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "reasoning_content": "", "content": "Hi"} + ] + }"#; + + let req: ChatCompletionRequest = serde_json::from_str(json).unwrap(); + match &req.messages[1] { + ChatMessage::Assistant { + reasoning_content, .. + } => { + assert_eq!(reasoning_content.as_deref(), Some("")); + } + other => panic!("expected assistant message, got {other:?}"), + } + + let value = serde_json::to_value(&req).unwrap(); + assert_eq!( + value["messages"][1]["reasoning_content"], + serde_json::json!("") + ); +} + +#[test] +fn test_assistant_reasoning_content_nonempty_survives_roundtrip() { + let json = r#"{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "reasoning_content": "scratchpad", "content": "Hi"} + ] + }"#; + + let req: ChatCompletionRequest = serde_json::from_str(json).unwrap(); + match &req.messages[1] { + ChatMessage::Assistant { + reasoning_content, .. + } => { + assert_eq!(reasoning_content.as_deref(), Some("scratchpad")); + } + other => panic!("expected assistant message, got {other:?}"), + } + + let value = serde_json::to_value(&req).unwrap(); + assert_eq!( + value["messages"][1]["reasoning_content"], + serde_json::json!("scratchpad") + ); +} + +#[test] +fn test_assistant_reasoning_and_reasoning_content_survive_roundtrip() { + let json = r#"{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "reasoning": "", + "reasoning_content": "", + "content": "Hi" + } + ] + }"#; + + let req: ChatCompletionRequest = serde_json::from_str(json).unwrap(); + match &req.messages[1] { + ChatMessage::Assistant { + reasoning, + reasoning_content, + .. + } => { + assert_eq!(reasoning.as_deref(), Some("")); + assert_eq!(reasoning_content.as_deref(), Some("")); + } + other => panic!("expected assistant message, got {other:?}"), + } + + let value = serde_json::to_value(&req).unwrap(); + assert_eq!(value["messages"][1]["reasoning"], serde_json::json!("")); + assert_eq!( + value["messages"][1]["reasoning_content"], + serde_json::json!("") + ); +} + +#[test] +fn test_assistant_thinking_aliases_empty_survive_roundtrip() { + for field in ["think", "think_fast", "think_faster"] { + assert_assistant_field_roundtrip(field, ""); + } +} + +#[test] +fn test_assistant_thinking_aliases_nonempty_survive_roundtrip() { + for field in ["think", "think_fast", "think_faster"] { + assert_assistant_field_roundtrip(field, "scratchpad"); + } +}