omx(team): auto-checkpoint worker-1 [1]

This commit is contained in:
bellman
2026-05-15 10:27:02 +09:00
parent 2cac66cd38
commit a212c662e5
4 changed files with 161 additions and 47 deletions

View File

@@ -27,10 +27,8 @@ pub use providers::openai_compat::{
};
pub use providers::{
detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override,
model_family_identity_for, model_family_identity_for_kind, provider_capabilities_for_model,
provider_diagnostics_for_request, resolve_model_alias, ProviderCapabilityReport,
ProviderDiagnostic, ProviderDiagnosticSeverity, ProviderFeatureSupport, ProviderKind,
ProviderWireProtocol,
model_family_identity_for, model_family_identity_for_kind, provider_diagnostics_for_model,
resolve_model_alias, ProviderDiagnostics, ProviderKind,
};
pub use sse::{parse_frame, SseParser};
pub use types::{

View File

@@ -29,7 +29,6 @@ pub trait Provider {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
Anthropic,
Xai,
@@ -50,53 +49,22 @@ pub struct ModelTokenLimit {
pub context_window_tokens: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderWireProtocol {
AnthropicMessages,
OpenAiChatCompletions,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderFeatureSupport {
Supported,
Unsupported,
PassthroughAsTool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProviderCapabilityReport {
pub struct ProviderDiagnostics {
pub requested_model: String,
pub resolved_model: String,
pub provider: ProviderKind,
pub wire_protocol: ProviderWireProtocol,
pub auth_env: &'static str,
pub base_url_env: &'static str,
pub default_base_url: &'static str,
pub tool_calls: ProviderFeatureSupport,
pub streaming: ProviderFeatureSupport,
pub streaming_usage: ProviderFeatureSupport,
pub prompt_cache: ProviderFeatureSupport,
pub custom_parameters: ProviderFeatureSupport,
pub reasoning_effort: ProviderFeatureSupport,
pub reasoning_content_history: ProviderFeatureSupport,
pub fixed_sampling_reasoning_models: ProviderFeatureSupport,
pub web_search: ProviderFeatureSupport,
pub web_fetch: ProviderFeatureSupport,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderDiagnosticSeverity {
Info,
Warning,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProviderDiagnostic {
pub code: &'static str,
pub severity: ProviderDiagnosticSeverity,
pub message: String,
pub action: String,
pub openai_compatible: bool,
pub reasoning_model: bool,
pub preserves_reasoning_content_in_history: bool,
pub strips_tuning_params: bool,
pub supports_stream_usage: bool,
pub honors_proxy_env: bool,
pub supports_extra_body_params: bool,
pub preserves_slash_model_ids_on_custom_base_url: bool,
}
const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
@@ -269,6 +237,55 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
None
}
#[must_use]
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
let resolved_model = resolve_model_alias(model);
let metadata =
metadata_for_model(&resolved_model).unwrap_or_else(|| {
match detect_provider_kind(&resolved_model) {
ProviderKind::Anthropic => ProviderMetadata {
provider: ProviderKind::Anthropic,
auth_env: "ANTHROPIC_API_KEY",
base_url_env: "ANTHROPIC_BASE_URL",
default_base_url: anthropic::DEFAULT_BASE_URL,
},
ProviderKind::Xai => ProviderMetadata {
provider: ProviderKind::Xai,
auth_env: "XAI_API_KEY",
base_url_env: "XAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
},
ProviderKind::OpenAi => ProviderMetadata {
provider: ProviderKind::OpenAi,
auth_env: "OPENAI_API_KEY",
base_url_env: "OPENAI_BASE_URL",
default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
},
}
});
let openai_compatible = matches!(metadata.provider, ProviderKind::OpenAi | ProviderKind::Xai);
let reasoning_model = openai_compatible && openai_compat::is_reasoning_model(&resolved_model);
ProviderDiagnostics {
requested_model: model.to_string(),
resolved_model: resolved_model.clone(),
provider: metadata.provider,
auth_env: metadata.auth_env,
base_url_env: metadata.base_url_env,
default_base_url: metadata.default_base_url,
openai_compatible,
reasoning_model,
preserves_reasoning_content_in_history: openai_compatible
&& openai_compat::model_requires_reasoning_content_in_history(&resolved_model),
strips_tuning_params: reasoning_model,
supports_stream_usage: metadata.provider == ProviderKind::OpenAi
&& metadata.default_base_url == openai_compat::DEFAULT_OPENAI_BASE_URL,
honors_proxy_env: true,
supports_extra_body_params: openai_compatible,
preserves_slash_model_ids_on_custom_base_url: metadata.provider == ProviderKind::OpenAi,
}
}
#[must_use]
pub fn detect_provider_kind(model: &str) -> ProviderKind {
if let Some(metadata) = metadata_for_model(model) {
@@ -1026,6 +1043,19 @@ mod tests {
assert_eq!(super::resolve_model_alias("KIMI"), "kimi-k2.5"); // case insensitive
}
#[test]
fn provider_diagnostics_explain_openai_compatible_capabilities() {
let diagnostics = super::provider_diagnostics_for_model("openai/deepseek-v4-pro");
assert_eq!(diagnostics.provider, ProviderKind::OpenAi);
assert_eq!(diagnostics.auth_env, "OPENAI_API_KEY");
assert!(diagnostics.openai_compatible);
assert!(diagnostics.preserves_reasoning_content_in_history);
assert!(diagnostics.supports_extra_body_params);
assert!(diagnostics.honors_proxy_env);
assert!(diagnostics.preserves_slash_model_ids_on_custom_base_url);
}
#[test]
fn keeps_existing_max_token_heuristic() {
assert_eq!(max_tokens_for_model("opus"), 32_000);

View File

@@ -2053,6 +2053,39 @@ mod tests {
assert_eq!(payload["stop"], json!(["\n"]));
}
#[test]
fn extra_body_params_are_passed_through_without_overriding_core_fields() {
let mut extra_body = BTreeMap::new();
extra_body.insert(
"web_search_options".to_string(),
json!({"search_context_size": "medium"}),
);
extra_body.insert("parallel_tool_calls".to_string(), json!(false));
extra_body.insert("model".to_string(), json!("bad-override"));
extra_body.insert("messages".to_string(), json!([]));
extra_body.insert("max_tokens".to_string(), json!(1));
let payload = build_chat_completion_request(
&MessageRequest {
model: "gpt-4o".to_string(),
max_tokens: 1024,
messages: vec![InputMessage::user_text("hello")],
extra_body,
..Default::default()
},
OpenAiCompatConfig::openai(),
);
assert_eq!(payload["model"], json!("gpt-4o"));
assert_eq!(payload["max_tokens"], json!(1024));
assert_eq!(payload["messages"].as_array().map(Vec::len), Some(1));
assert_eq!(
payload["web_search_options"],
json!({"search_context_size": "medium"})
);
assert_eq!(payload["parallel_tool_calls"], json!(false));
}
#[test]
fn reasoning_model_strips_tuning_params() {
let request = MessageRequest {

View File

@@ -161,6 +161,59 @@ async fn send_message_preserves_deepseek_reasoning_content_before_text() {
);
}
#[tokio::test]
async fn custom_openai_gateway_preserves_slash_model_ids_and_extra_body_params() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"chatcmpl_slash_model\",",
"\"model\":\"openai/gpt-4.1-mini\",",
"\"choices\":[{",
"\"message\":{\"role\":\"assistant\",\"content\":\"Gateway accepted slug\",\"tool_calls\":[]},",
"\"finish_reason\":\"stop\"",
"}],",
"\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":2}",
"}"
);
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", body)],
)
.await;
let mut extra_body = std::collections::BTreeMap::new();
extra_body.insert(
"web_search_options".to_string(),
json!({"search_context_size": "low"}),
);
extra_body.insert("parallel_tool_calls".to_string(), json!(false));
extra_body.insert("model".to_string(), json!("malicious-override"));
let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
.with_base_url(server.base_url());
let response = client
.send_message(&MessageRequest {
model: "openai/gpt-4.1-mini".to_string(),
extra_body,
..sample_request(false)
})
.await
.expect("gateway request should succeed");
assert_eq!(response.model, "openai/gpt-4.1-mini");
assert_eq!(response.total_tokens(), 5);
let captured = state.lock().await;
let request = captured.first().expect("captured request");
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
assert_eq!(
body["web_search_options"],
json!({"search_context_size": "low"})
);
assert_eq!(body["parallel_tool_calls"], json!(false));
}
#[tokio::test]
async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));