diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index b3800d6a..f66710ab 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::{BTreeMap, VecDeque}; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -267,14 +268,18 @@ impl OpenAiCompatClient { request: &MessageRequest, ) -> Result { // Pre-flight check: verify request body size against provider limits - check_request_body_size(request, self.config())?; + check_request_body_size_for_base_url(request, self.config(), &self.base_url)?; let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") .bearer_auth(&self.api_key) - .json(&build_chat_completion_request(request, self.config())) + .json(&build_chat_completion_request_for_base_url( + request, + self.config(), + &self.base_url, + )) .send() .await .map_err(ApiError::from) @@ -882,10 +887,50 @@ fn strip_routing_prefix(model: &str) -> &str { } } +fn wire_model_for_base_url<'a>( + model: &'a str, + config: OpenAiCompatConfig, + base_url: &str, +) -> Cow<'a, str> { + let Some(pos) = model.find('/') else { + return Cow::Borrowed(model); + }; + let prefix = &model[..pos]; + let lowered_prefix = prefix.to_ascii_lowercase(); + + if lowered_prefix == "openai" { + let trimmed_base_url = base_url.trim_end_matches('/'); + let default_openai = DEFAULT_OPENAI_BASE_URL.trim_end_matches('/'); + if config.provider_name == "OpenAI" && trimmed_base_url != default_openai { + // OpenAI-compatible gateways such as OpenRouter commonly use + // slash-containing model slugs (for example `openai/gpt-4.1-mini`). + // Preserve the slug when the user configured a non-default OpenAI + // base URL; the prefix still routed to the OpenAI-compatible client, + // but the gateway owns the final model namespace. + return Cow::Borrowed(model); + } + return Cow::Borrowed(&model[pos + 1..]); + } + + if matches!(lowered_prefix.as_str(), "xai" | "grok" | "qwen" | "kimi") { + return Cow::Borrowed(&model[pos + 1..]); + } + + Cow::Borrowed(model) +} + /// Estimate the serialized JSON size of a request payload in bytes. /// This is a pre-flight check to avoid hitting provider-specific size limits. pub fn estimate_request_body_size(request: &MessageRequest, config: OpenAiCompatConfig) -> usize { - let payload = build_chat_completion_request(request, config); + estimate_request_body_size_for_base_url(request, config, &read_base_url(config)) +} + +fn estimate_request_body_size_for_base_url( + request: &MessageRequest, + config: OpenAiCompatConfig, + base_url: &str, +) -> usize { + let payload = build_chat_completion_request_for_base_url(request, config, base_url); // serde_json::to_vec gives us the exact byte size of the serialized JSON serde_json::to_vec(&payload).map_or(0, |v| v.len()) } @@ -897,7 +942,15 @@ pub fn check_request_body_size( request: &MessageRequest, config: OpenAiCompatConfig, ) -> Result<(), ApiError> { - let estimated_bytes = estimate_request_body_size(request, config); + check_request_body_size_for_base_url(request, config, &read_base_url(config)) +} + +fn check_request_body_size_for_base_url( + request: &MessageRequest, + config: OpenAiCompatConfig, + base_url: &str, +) -> Result<(), ApiError> { + let estimated_bytes = estimate_request_body_size_for_base_url(request, config, base_url); let max_bytes = config.max_request_body_bytes; if estimated_bytes > max_bytes { @@ -916,6 +969,14 @@ pub fn check_request_body_size( pub fn build_chat_completion_request( request: &MessageRequest, config: OpenAiCompatConfig, +) -> Value { + build_chat_completion_request_for_base_url(request, config, &read_base_url(config)) +} + +fn build_chat_completion_request_for_base_url( + request: &MessageRequest, + config: OpenAiCompatConfig, + base_url: &str, ) -> Value { let mut messages = Vec::new(); if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { @@ -924,8 +985,10 @@ pub fn build_chat_completion_request( "content": system, })); } - // Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire. - let wire_model = strip_routing_prefix(&request.model); + // Resolve the transport routing prefix into the wire model. Custom + // OpenAI-compatible gateways may require slash-containing slugs intact. + let wire_model = wire_model_for_base_url(&request.model, config, base_url); + let wire_model = wire_model.as_ref(); for message in &request.messages { messages.extend(translate_message(message, wire_model)); } @@ -994,9 +1057,29 @@ pub fn build_chat_completion_request( payload["reasoning_effort"] = json!(effort); } + for (key, value) in &request.extra_body { + if is_protected_extra_body_key(key) { + continue; + } + payload[key] = value.clone(); + } + payload } +fn is_protected_extra_body_key(key: &str) -> bool { + matches!( + key, + "model" + | "messages" + | "stream" + | "tools" + | "tool_choice" + | "max_tokens" + | "max_completion_tokens" + ) +} + /// Returns true for models that do NOT support the `is_error` field in tool results. /// kimi models (via Moonshot AI/Dashscope) reject this field with 400 Bad Request. /// Returns true for models that do NOT support the `is_error` field in tool results. diff --git a/rust/crates/api/src/types.rs b/rust/crates/api/src/types.rs index 0d41db19..3ec7f879 100644 --- a/rust/crates/api/src/types.rs +++ b/rust/crates/api/src/types.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -31,6 +33,14 @@ pub struct MessageRequest { /// Silently ignored by backends that do not support it. #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_effort: Option, + /// Provider-specific OpenAI-compatible request body parameters. These are + /// copied into the final JSON payload after core fields are populated so + /// users can opt into gateway features such as `web_search_options`, + /// `parallel_tool_calls`, or custom local-server switches without waiting + /// for first-class typed fields. Core protocol keys are protected and cannot + /// be overridden through this map. + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub extra_body: BTreeMap, } impl MessageRequest {