From 0755ddff3ca649866bdd954396e43b4e286d2495 Mon Sep 17 00:00:00 2001 From: Ajinkya-Ghuge Date: Sat, 6 Jun 2026 22:29:59 +0530 Subject: [PATCH] fix(providers): strip provider prefix from model names for openai_compat endpoints --- rust/crates/api/src/providers/mod.rs | 14 ++ .../crates/api/src/providers/openai_compat.rs | 127 ++++++++++++++---- 2 files changed, 112 insertions(+), 29 deletions(-) diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 57dce27e..054c335c 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -296,6 +296,20 @@ pub fn metadata_for_model(model: &str) -> Option { None } + + + +#[must_use] +pub fn strip_provider_prefix(canonical_model: &str) -> String { + if let Some(pos) = canonical_model.find('/') { + canonical_model[pos + 1..].to_string() + } else { + canonical_model.to_string() + } +} + + + #[must_use] pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics { let resolved_model = resolve_model_alias(model); diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 7f82d00c..c378b585 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -16,7 +16,8 @@ use crate::types::{ ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; -use super::{preflight_message_request, Provider, ProviderFuture}; +use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix}; + pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; @@ -212,17 +213,76 @@ impl OpenAiCompatClient { } pub async fn send_message( - &self, - request: &MessageRequest, - ) -> Result { - let request = MessageRequest { - stream: false, - ..request.clone() - }; - preflight_message_request(&request)?; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let body = response.text().await.map_err(ApiError::from)?; + &self, + request: &MessageRequest, +) -> Result { + // 1. Keep track of what Claw originally asked for + let original_model = request.model.clone(); + let canonical = resolve_model_alias(&request.model); + + // 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash") + let downstream_model = strip_provider_prefix(&canonical); + + let mut request = MessageRequest { + stream: false, + ..request.clone() + }; + request.model = downstream_model; // Use the clean name for the API payload + + preflight_message_request(&request)?; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let body = response.text().await.map_err(ApiError::from)?; + + // Some backends return {"error":{"message":"...","type":"...","code":...}} + // instead of a valid completion object. Check for this before attempting + // full deserialization so the user sees the actual error, not a cryptic. + if let Ok(raw) = serde_json::from_str::(&body) { + if let Some(err_obj) = raw.get("error") { + let msg = err_obj + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("provider returned an error") + .to_string(); + let code = err_obj + .get("code") + .and_then(serde_json::Value::as_u64) + .map(|c| c as u16); + return Err(ApiError::Api { + status: reqwest::StatusCode::from_u16(code.unwrap_or(400)) + .unwrap_or(reqwest::StatusCode::BAD_REQUEST), + error_type: err_obj + .get("type") + .and_then(|t| t.as_str()) + .map(str::to_owned), + message: Some(msg), + request_id, + body, + retryable: false, + suggested_action: suggested_action_for_status( + reqwest::StatusCode::from_u16(code.unwrap_or(400)) + .unwrap_or(reqwest::StatusCode::BAD_REQUEST), + ), + retry_after: None, + }); + } + } + + // Pass original_model to the deserializer error context so debugging logs are accurate + let payload = serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error) + })?; + + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + + // 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy + normalized.model = original_model; + + Ok(normalized) +} // Some backends return {"error":{"message":"...","type":"...","code":...}} // instead of a valid completion object. Check for this before attempting // full deserialization so the user sees the actual error, not a cryptic @@ -267,23 +327,32 @@ impl OpenAiCompatClient { Ok(normalized) } - pub async fn stream_message( - &self, - request: &MessageRequest, - ) -> Result { - preflight_message_request(request)?; - let response = self - .send_with_retry(&request.clone().with_streaming()) - .await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()), - pending: VecDeque::new(), - done: false, - state: StreamState::new(request.model.clone()), - }) - } +pub async fn stream_message( + &self, + request: &MessageRequest, +) -> Result { + // 1. Keep track of the original model name + let original_model = request.model.clone(); + let canonical = resolve_model_alias(&request.model); + + // 2. Clean it up for DeepSeek + let downstream_model = strip_provider_prefix(&canonical); + + let mut streaming_request = request.clone().with_streaming(); + streaming_request.model = downstream_model; + + preflight_message_request(&streaming_request)?; + let response = self.send_with_retry(&streaming_request).await?; + + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()), + pending: VecDeque::new(), + done: false, + state: StreamState::new(original_model), // 3. Use the original name here + }) +} async fn send_with_retry( &self,