fix(providers): preserve OpenAI-compatible reasoning history

This commit is contained in:
YeonGyu-Kim
2026-06-08 01:23:13 +09:00
parent ae2f203eb5
commit c1646613d1
3 changed files with 58 additions and 104 deletions

View File

@@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
None None
} }
#[must_use] #[must_use]
pub fn strip_provider_prefix(canonical_model: &str) -> String { pub fn strip_provider_prefix(canonical_model: &str) -> String {
if let Some(pos) = canonical_model.find('/') { if let Some(pos) = canonical_model.find('/') {
@@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
} }
} }
#[must_use] #[must_use]
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics { pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
let resolved_model = resolve_model_alias(model); let resolved_model = resolve_model_alias(model);

View File

@@ -16,8 +16,9 @@ use crate::types::{
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
}; };
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix}; use super::{
preflight_message_request, resolve_model_alias, strip_provider_prefix, Provider, ProviderFuture,
};
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
@@ -213,80 +214,23 @@ impl OpenAiCompatClient {
} }
pub async fn send_message( pub async fn send_message(
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageResponse, ApiError> { ) -> Result<MessageResponse, ApiError> {
// 1. Keep track of what Claw originally asked for let original_model = request.model.clone();
let original_model = request.model.clone(); let canonical = resolve_model_alias(&request.model);
let canonical = resolve_model_alias(&request.model); let downstream_model = strip_provider_prefix(&canonical);
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
let downstream_model = strip_provider_prefix(&canonical);
let mut request = MessageRequest { let mut request = MessageRequest {
stream: false, stream: false,
..request.clone() ..request.clone()
}; };
request.model = downstream_model; // Use the clean name for the API payload request.model = downstream_model;
preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;
// Some backends return {"error":{"message":"...","type":"...","code":...}} preflight_message_request(&request)?;
// instead of a valid completion object. Check for this before attempting let response = self.send_with_retry(&request).await?;
// full deserialization so the user sees the actual error, not a cryptic. let request_id = request_id_from_headers(response.headers());
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) { let body = response.text().await.map_err(ApiError::from)?;
if let Some(err_obj) = raw.get("error") {
let msg = err_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("provider returned an error")
.to_string();
let code = err_obj
.get("code")
.and_then(serde_json::Value::as_u64)
.map(|c| c as u16);
return Err(ApiError::Api {
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
error_type: err_obj
.get("type")
.and_then(|t| t.as_str())
.map(str::to_owned),
message: Some(msg),
request_id,
body,
retryable: false,
suggested_action: suggested_action_for_status(
reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
),
retry_after: None,
});
}
}
// Pass original_model to the deserializer error context so debugging logs are accurate
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?;
let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() {
normalized.request_id = request_id;
}
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
normalized.model = original_model;
Ok(normalized)
}
// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic
// "missing field 'id'" parse failure.
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) { if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(err_obj) = raw.get("error") { if let Some(err_obj) = raw.get("error") {
let msg = err_obj let msg = err_obj
@@ -318,41 +262,42 @@ impl OpenAiCompatClient {
} }
} }
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| { let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error) ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?; })?;
let mut normalized = normalize_response(&request.model, payload)?; let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() { if normalized.request_id.is_none() {
normalized.request_id = request_id; normalized.request_id = request_id;
} }
normalized.model = original_model;
Ok(normalized) Ok(normalized)
} }
pub async fn stream_message( pub async fn stream_message(
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageStream, ApiError> { ) -> Result<MessageStream, ApiError> {
// 1. Keep track of the original model name let original_model = request.model.clone();
let original_model = request.model.clone(); let canonical = resolve_model_alias(&request.model);
let canonical = resolve_model_alias(&request.model); let downstream_model = strip_provider_prefix(&canonical);
// 2. Clean it up for DeepSeek
let downstream_model = strip_provider_prefix(&canonical);
let mut streaming_request = request.clone().with_streaming(); let mut streaming_request = request.clone().with_streaming();
streaming_request.model = downstream_model; streaming_request.model = downstream_model;
preflight_message_request(&streaming_request)?; preflight_message_request(&streaming_request)?;
let response = self.send_with_retry(&streaming_request).await?; let response = self.send_with_retry(&streaming_request).await?;
Ok(MessageStream { Ok(MessageStream {
request_id: request_id_from_headers(response.headers()), request_id: request_id_from_headers(response.headers()),
response, response,
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()), parser: OpenAiSseParser::with_context(
pending: VecDeque::new(), self.config.provider_name,
done: false, original_model.clone(),
state: StreamState::new(original_model), // 3. Use the original name here ),
}) pending: VecDeque::new(),
} done: false,
state: StreamState::new(original_model),
})
}
async fn send_with_retry( async fn send_with_retry(
&self, &self,

View File

@@ -13737,8 +13737,15 @@ fn push_output_block(
}; };
*pending_tool = Some((id, name, initial_input)); *pending_tool = Some((id, name, initial_input));
} }
OutputContentBlock::Thinking { thinking, .. } => { OutputContentBlock::Thinking {
thinking,
signature,
} => {
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?; render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
events.push(AssistantEvent::Thinking {
thinking,
signature,
});
*block_has_thinking_summary = true; *block_has_thinking_summary = true;
} }
OutputContentBlock::RedactedThinking { .. } => { OutputContentBlock::RedactedThinking { .. } => {
@@ -19073,6 +19080,13 @@ UU conflicted.rs",
assert!(matches!( assert!(matches!(
&events[0], &events[0],
AssistantEvent::Thinking {
thinking,
signature
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
));
assert!(matches!(
&events[1],
AssistantEvent::TextDelta(text) if text == "Final answer" AssistantEvent::TextDelta(text) if text == "Final answer"
)); ));
let rendered = String::from_utf8(out).expect("utf8"); let rendered = String::from_utf8(out).expect("utf8");