diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs
index 0cbe8732..fa3789fd 100644
--- a/rust/crates/api/tests/openai_compat_integration.rs
+++ b/rust/crates/api/tests/openai_compat_integration.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
 use std::ffi::OsString;
 use std::sync::Arc;
 use std::sync::{Mutex as StdMutex, OnceLock};
+use std::time::Duration;
 
 use api::{
     ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
@@ -66,6 +67,56 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
     assert_eq!(body["tools"][0]["type"], json!("function"));
 }
 
+#[tokio::test]
+async fn send_message_passes_optional_openai_compatible_parameters_on_wire() {
+    let state = Arc::new(Mutex::new(Vec::new()));
+    let body = concat!(
+        "{",
+        "\"id\":\"chatcmpl_params\",",
+        "\"model\":\"gpt-4o\",",
+        "\"choices\":[{",
+        "\"message\":{\"role\":\"assistant\",\"content\":\"Parameters preserved\",\"tool_calls\":[]},",
+        "\"finish_reason\":\"stop\"",
+        "}],",
+        "\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":2}",
+        "}"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response("200 OK", "application/json", body)],
+    )
+    .await;
+
+    let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
+        .with_base_url(server.base_url());
+    let response = client
+        .send_message(&MessageRequest {
+            model: "gpt-4o".to_string(),
+            temperature: Some(0.2),
+            top_p: Some(0.8),
+            frequency_penalty: Some(0.15),
+            presence_penalty: Some(0.25),
+            stop: Some(vec!["END".to_string()]),
+            reasoning_effort: Some("low".to_string()),
+            ..sample_request(false)
+        })
+        .await
+        .expect("request should succeed");
+
+    assert_eq!(response.total_tokens(), 5);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("server should capture request");
+    let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
+    assert_eq!(body["model"], json!("gpt-4o"));
+    assert_eq!(body["temperature"], json!(0.2));
+    assert_eq!(body["top_p"], json!(0.8));
+    assert_eq!(body["frequency_penalty"], json!(0.15));
+    assert_eq!(body["presence_penalty"], json!(0.25));
+    assert_eq!(body["stop"], json!(["END"]));
+    assert_eq!(body["reasoning_effort"], json!("low"));
+}
+
 #[tokio::test]
 async fn send_message_preserves_deepseek_reasoning_content_before_text() {
     let state = Arc::new(Mutex::new(Vec::new()));
@@ -280,6 +331,66 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
     assert!(request.body.contains("\"stream\":true"));
 }
 
+#[allow(clippy::await_holding_lock)]
+#[tokio::test]
+async fn stream_message_retries_retryable_sse_handshake_failures() {
+    let state = Arc::new(Mutex::new(Vec::new()));
+    let sse = concat!(
+        "data: {\"id\":\"chatcmpl_stream_retry\",\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"Recovered\"}}]}\n\n",
+        "data: {\"id\":\"chatcmpl_stream_retry\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
+        "data: [DONE]\n\n"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![
+            http_response(
+                "500 Internal Server Error",
+                "application/json",
+                "{\"error\":{\"message\":\"try again\",\"type\":\"server_error\",\"code\":500}}",
+            ),
+            http_response_with_headers(
+                "200 OK",
+                "text/event-stream",
+                sse,
+                &[("x-request-id", "req_stream_retry")],
+            ),
+        ],
+    )
+    .await;
+
+    let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
+        .with_base_url(server.base_url())
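+        // Allow exactly one retry; the zero Durations (assumed to be back-off delays) keep the retried handshake instant.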
+        .with_retry_policy(1, Duration::ZERO, Duration::ZERO);
+    let mut stream = client
+        .stream_message(&MessageRequest {
+            model: "gpt-4o".to_string(),
+            ..sample_request(false)
+        })
+        .await
+        .expect("stream should retry once then start");
+
+    assert_eq!(stream.request_id(), Some("req_stream_retry"));
+    let mut events = Vec::new();
+    while let Some(event) = stream.next_event().await.expect("event should parse") {
+        events.push(event);
+    }
+    assert!(events.iter().any(|event| matches!(
+        event,
+        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+            delta: ContentBlockDelta::TextDelta { text },
+            ..
+        }) if text == "Recovered"
+    )));
+
+    let captured = state.lock().await;
+    assert_eq!(captured.len(), 2, "one original request plus one retry");
+    for request in captured.iter() {
+        let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
+        assert_eq!(body["stream"], json!(true));
+    }
+}
+
 #[allow(clippy::await_holding_lock)]
 #[tokio::test]
 async fn openai_streaming_requests_opt_into_usage_chunks() {
@@ -358,6 +469,47 @@ async fn openai_streaming_requests_opt_into_usage_chunks() {
     assert_eq!(body["stream_options"], json!({"include_usage": true}));
 }
 
+#[allow(clippy::await_holding_lock)]
+#[tokio::test]
+async fn openai_compatible_client_honors_http_proxy_for_requests() {
+    let _lock = env_lock();
+    let state = Arc::new(Mutex::new(Vec::new()));
+    let proxy = spawn_server(
+        state.clone(),
+        vec![http_response(
+            "200 OK",
+            "application/json",
+            "{\"id\":\"chatcmpl_proxy\",\"model\":\"gpt-4o\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Via proxy\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":4,\"completion_tokens\":3}}",
+        )],
+    )
+    .await;
+    let _http_proxy = ScopedEnvVar::set("HTTP_PROXY", proxy.base_url());
+    let _https_proxy = ScopedEnvVar::unset("HTTPS_PROXY");
+    let _no_proxy = ScopedEnvVar::unset("NO_PROXY");
+    let _http_proxy_lower = ScopedEnvVar::unset("http_proxy");
+    let _https_proxy_lower = ScopedEnvVar::unset("https_proxy");
+    let _no_proxy_lower = ScopedEnvVar::unset("no_proxy");
+
+    let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
+        .with_base_url("http://origin.invalid/v1");
+    let response = client
+        .send_message(&MessageRequest {
+            model: "gpt-4o".to_string(),
+            ..sample_request(false)
+        })
+        .await
+        .expect("proxy should return the OpenAI-compatible response");
+
+    assert_eq!(response.total_tokens(), 7);
+    let captured = state.lock().await;
+    let request = captured.first().expect("proxy should capture request");
+    assert_eq!(request.path, "http://origin.invalid/v1/chat/completions");
+    assert_eq!(
+        request.headers.get("authorization").map(String::as_str),
+        Some("Bearer openai-test-key")
+    );
+}
+
 #[allow(clippy::await_holding_lock)]
 #[tokio::test]
 async fn provider_client_dispatches_xai_requests_from_env() {
@@ -568,6 +720,12 @@ impl ScopedEnvVar {
         std::env::set_var(key, value);
         Self { key, previous }
     }
+
+    fn unset(key: &'static str) -> Self {
+        let previous = std::env::var_os(key);
+        std::env::remove_var(key);
+        Self { key, previous }
+    }
 }
 
 impl Drop for ScopedEnvVar {