diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs
index 673ec29e..e6624a3e 100644
--- a/rust/crates/api/src/lib.rs
+++ b/rust/crates/api/src/lib.rs
@@ -27,10 +27,8 @@ pub use providers::openai_compat::{
 };
 pub use providers::{
     detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override,
-    model_family_identity_for, model_family_identity_for_kind, provider_capabilities_for_model,
-    provider_diagnostics_for_request, resolve_model_alias, ProviderCapabilityReport,
-    ProviderDiagnostic, ProviderDiagnosticSeverity, ProviderFeatureSupport, ProviderKind,
-    ProviderWireProtocol,
+    model_family_identity_for, model_family_identity_for_kind, provider_diagnostics_for_model,
+    resolve_model_alias, ProviderDiagnostics, ProviderKind,
 };
 pub use sse::{parse_frame, SseParser};
 pub use types::{
diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs
index 37d21c57..2d37bd69 100644
--- a/rust/crates/api/src/providers/mod.rs
+++ b/rust/crates/api/src/providers/mod.rs
@@ -29,7 +29,6 @@ pub trait Provider {
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
-#[serde(rename_all = "snake_case")]
 pub enum ProviderKind {
     Anthropic,
     Xai,
@@ -50,53 +49,22 @@ pub struct ModelTokenLimit {
     pub context_window_tokens: u32,
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
-#[serde(rename_all = "snake_case")]
-pub enum ProviderWireProtocol {
-    AnthropicMessages,
-    OpenAiChatCompletions,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
-#[serde(rename_all = "snake_case")]
-pub enum ProviderFeatureSupport {
-    Supported,
-    Unsupported,
-    PassthroughAsTool,
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
-pub struct ProviderCapabilityReport {
+pub struct ProviderDiagnostics {
+    pub requested_model: String,
+    pub resolved_model: String,
     pub provider: ProviderKind,
-    pub wire_protocol: ProviderWireProtocol,
     pub auth_env: &'static str,
     pub base_url_env: &'static str,
     pub default_base_url: &'static str,
-    pub tool_calls: ProviderFeatureSupport,
-    pub streaming: ProviderFeatureSupport,
-    pub streaming_usage: ProviderFeatureSupport,
-    pub prompt_cache: ProviderFeatureSupport,
-    pub custom_parameters: ProviderFeatureSupport,
-    pub reasoning_effort: ProviderFeatureSupport,
-    pub reasoning_content_history: ProviderFeatureSupport,
-    pub fixed_sampling_reasoning_models: ProviderFeatureSupport,
-    pub web_search: ProviderFeatureSupport,
-    pub web_fetch: ProviderFeatureSupport,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
-#[serde(rename_all = "snake_case")]
-pub enum ProviderDiagnosticSeverity {
-    Info,
-    Warning,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
-pub struct ProviderDiagnostic {
-    pub code: &'static str,
-    pub severity: ProviderDiagnosticSeverity,
-    pub message: String,
-    pub action: String,
+    pub openai_compatible: bool,
+    pub reasoning_model: bool,
+    pub preserves_reasoning_content_in_history: bool,
+    pub strips_tuning_params: bool,
+    pub supports_stream_usage: bool,
+    pub honors_proxy_env: bool,
+    pub supports_extra_body_params: bool,
+    pub preserves_slash_model_ids_on_custom_base_url: bool,
 }
 
 const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
@@ -269,6 +237,55 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
     None
 }
 
+#[must_use]
+pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
+    let resolved_model = resolve_model_alias(model);
+    let metadata =
+        metadata_for_model(&resolved_model).unwrap_or_else(|| {
+            match detect_provider_kind(&resolved_model) {
+                ProviderKind::Anthropic => ProviderMetadata {
+                    provider: ProviderKind::Anthropic,
+                    auth_env: "ANTHROPIC_API_KEY",
+                    base_url_env: "ANTHROPIC_BASE_URL",
+                    default_base_url: anthropic::DEFAULT_BASE_URL,
+                },
+                ProviderKind::Xai => ProviderMetadata {
+                    provider: ProviderKind::Xai,
+                    auth_env: "XAI_API_KEY",
+                    base_url_env: "XAI_BASE_URL",
+                    default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
+                },
+                ProviderKind::OpenAi => ProviderMetadata {
+                    provider: ProviderKind::OpenAi,
+                    auth_env: "OPENAI_API_KEY",
+                    base_url_env: "OPENAI_BASE_URL",
+                    default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
+                },
+            }
+        });
+    let openai_compatible = matches!(metadata.provider, ProviderKind::OpenAi | ProviderKind::Xai);
+    let reasoning_model = openai_compatible && openai_compat::is_reasoning_model(&resolved_model);
+
+    ProviderDiagnostics {
+        requested_model: model.to_string(),
+        resolved_model: resolved_model.clone(),
+        provider: metadata.provider,
+        auth_env: metadata.auth_env,
+        base_url_env: metadata.base_url_env,
+        default_base_url: metadata.default_base_url,
+        openai_compatible,
+        reasoning_model,
+        preserves_reasoning_content_in_history: openai_compatible
+            && openai_compat::model_requires_reasoning_content_in_history(&resolved_model),
+        strips_tuning_params: reasoning_model,
+        supports_stream_usage: metadata.provider == ProviderKind::OpenAi
+            && metadata.default_base_url == openai_compat::DEFAULT_OPENAI_BASE_URL,
+        honors_proxy_env: true,
+        supports_extra_body_params: openai_compatible,
+        preserves_slash_model_ids_on_custom_base_url: metadata.provider == ProviderKind::OpenAi,
+    }
+}
+
 #[must_use]
 pub fn detect_provider_kind(model: &str) -> ProviderKind {
     if let Some(metadata) = metadata_for_model(model) {
@@ -1026,6 +1043,19 @@ mod tests {
         assert_eq!(super::resolve_model_alias("KIMI"), "kimi-k2.5"); // case insensitive
     }
 
+    #[test]
+    fn provider_diagnostics_explain_openai_compatible_capabilities() {
+        let diagnostics = super::provider_diagnostics_for_model("openai/deepseek-v4-pro");
+
+        assert_eq!(diagnostics.provider, ProviderKind::OpenAi);
+        assert_eq!(diagnostics.auth_env, "OPENAI_API_KEY");
+        assert!(diagnostics.openai_compatible);
+        assert!(diagnostics.preserves_reasoning_content_in_history);
+        assert!(diagnostics.supports_extra_body_params);
+        assert!(diagnostics.honors_proxy_env);
+        assert!(diagnostics.preserves_slash_model_ids_on_custom_base_url);
+    }
+
     #[test]
     fn keeps_existing_max_token_heuristic() {
         assert_eq!(max_tokens_for_model("opus"), 32_000);
diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs
index 2202dc61..84ad775e 100644
--- a/rust/crates/api/src/providers/openai_compat.rs
+++ b/rust/crates/api/src/providers/openai_compat.rs
@@ -2053,6 +2053,39 @@ mod tests {
         assert_eq!(payload["stop"], json!(["\n"]));
     }
 
+    #[test]
+    fn extra_body_params_are_passed_through_without_overriding_core_fields() {
+        let mut extra_body = BTreeMap::new();
+        extra_body.insert(
+            "web_search_options".to_string(),
+            json!({"search_context_size": "medium"}),
+        );
+        extra_body.insert("parallel_tool_calls".to_string(), json!(false));
+        extra_body.insert("model".to_string(), json!("bad-override"));
+        extra_body.insert("messages".to_string(), json!([]));
+        extra_body.insert("max_tokens".to_string(), json!(1));
+
+        let payload = build_chat_completion_request(
+            &MessageRequest {
+                model: "gpt-4o".to_string(),
+                max_tokens: 1024,
+                messages: vec![InputMessage::user_text("hello")],
+                extra_body,
+                ..Default::default()
+            },
+            OpenAiCompatConfig::openai(),
+        );
+
+        assert_eq!(payload["model"], json!("gpt-4o"));
+        assert_eq!(payload["max_tokens"], json!(1024));
+        assert_eq!(payload["messages"].as_array().map(Vec::len), Some(1));
+        assert_eq!(
+            payload["web_search_options"],
+            json!({"search_context_size": "medium"})
+        );
+        assert_eq!(payload["parallel_tool_calls"], json!(false));
+    }
+
     #[test]
     fn reasoning_model_strips_tuning_params() {
         let request = MessageRequest {
diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs
index d9883a22..f7754551 100644
--- a/rust/crates/api/tests/openai_compat_integration.rs
+++ b/rust/crates/api/tests/openai_compat_integration.rs
@@ -161,6 +161,59 @@ async fn send_message_preserves_deepseek_reasoning_content_before_text() {
     );
 }
 
+#[tokio::test]
+async fn custom_openai_gateway_preserves_slash_model_ids_and_extra_body_params() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let body = concat!(
+        "{",
+        "\"id\":\"chatcmpl_slash_model\",",
+        "\"model\":\"openai/gpt-4.1-mini\",",
+        "\"choices\":[{",
+        "\"message\":{\"role\":\"assistant\",\"content\":\"Gateway accepted slug\",\"tool_calls\":[]},",
+        "\"finish_reason\":\"stop\"",
+        "}],",
+        "\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":2}",
+        "}"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response("200 OK", "application/json", body)],
+    )
+    .await;
+
+    let mut extra_body = std::collections::BTreeMap::new();
+    extra_body.insert(
+        "web_search_options".to_string(),
+        json!({"search_context_size": "low"}),
+    );
+    extra_body.insert("parallel_tool_calls".to_string(), json!(false));
+    extra_body.insert("model".to_string(), json!("malicious-override"));
+
+    let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
+        .with_base_url(server.base_url());
+    let response = client
+        .send_message(&MessageRequest {
+            model: "openai/gpt-4.1-mini".to_string(),
+            extra_body,
+            ..sample_request(false)
+        })
+        .await
+        .expect("gateway request should succeed");
+
+    assert_eq!(response.model, "openai/gpt-4.1-mini");
+    assert_eq!(response.total_tokens(), 5);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("captured request");
+    let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
+    assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
+    assert_eq!(
+        body["web_search_options"],
+        json!({"search_context_size": "low"})
+    );
+    assert_eq!(body["parallel_tool_calls"], json!(false));
+}
+
 #[tokio::test]
 async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
     let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
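A caller's-eye sketch of the new surface, for review context: `provider_diagnostics_for_model` resolves aliases first, then folds the old `ProviderCapabilityReport`/`ProviderDiagnostic` pair into one flat, serializable struct. The `api` crate name and the caller-side `serde_json` dependency are assumptions here, not part of this diff:

```rust
use api::{provider_diagnostics_for_model, ProviderKind};

fn main() {
    // resolve_model_alias runs first, so aliases and "openai/<model>" slugs
    // both land on a concrete resolved_model before any flags are derived.
    let diagnostics = provider_diagnostics_for_model("openai/deepseek-v4-pro");
    assert_eq!(diagnostics.provider, ProviderKind::OpenAi);
    assert!(diagnostics.supports_extra_body_params);

    // ProviderDiagnostics derives Serialize, so the whole report can be
    // dumped as JSON, e.g. behind a doctor-style CLI flag (hypothetical use).
    println!("{}", serde_json::to_string_pretty(&diagnostics).expect("serializable"));
}
```

One serialization consequence worth flagging: with `#[serde(rename_all = "snake_case")]` dropped from `ProviderKind`, the `provider` field now serializes under serde's default variant naming (`"OpenAi"` rather than `"open_ai"`), so any consumer matching on the old snake_case strings needs updating.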
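Similarly, a sketch of the gateway scenario the new integration test exercises, written as an external caller would. The gateway URL and API key are placeholders, and it assumes `OpenAiCompatClient`, `MessageRequest`, and `InputMessage` are re-exported at the crate root as the truncated `pub use` lists in lib.rs suggest:

```rust
use std::collections::BTreeMap;

use api::{InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig};
use serde_json::json;

#[tokio::main]
async fn main() {
    let mut extra_body = BTreeMap::new();
    // Gateway-specific knobs ride along at the top level of the payload.
    // Core fields stay authoritative: a colliding "model" key here would be
    // dropped, exactly as the passthrough unit test asserts.
    extra_body.insert("parallel_tool_calls".to_string(), json!(false));

    let client = OpenAiCompatClient::new("sk-example", OpenAiCompatConfig::openai())
        .with_base_url("https://gateway.example.com/v1".to_string());
    let response = client
        .send_message(&MessageRequest {
            // Slash-qualified IDs are preserved on custom OpenAI base URLs.
            model: "openai/gpt-4.1-mini".to_string(),
            max_tokens: 1024,
            messages: vec![InputMessage::user_text("hello")],
            extra_body,
            ..Default::default()
        })
        .await
        .expect("gateway request should succeed");

    println!("{} -> {} tokens", response.model, response.total_tokens());
}
```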