fix: route local OpenAI-compatible models

This commit is contained in:
bellman
2026-06-03 23:16:46 +09:00
parent 9522674c87
commit bcc5bfde9c
7 changed files with 264 additions and 40 deletions

View File

@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::collections::{BTreeMap, VecDeque};
use std::net::Ipv4Addr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -131,13 +132,22 @@ impl OpenAiCompatClient {
}
pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
return Err(ApiError::missing_credentials(
config.provider_name,
config.credential_env_vars(),
));
let base_url = read_base_url(config);
let api_key = match read_env_non_empty(config.api_key_env)? {
Some(api_key) => api_key,
None if config.provider_name == "OpenAI"
&& is_local_openai_compatible_base_url(&base_url) =>
{
"local-dev-token".to_string()
}
None => {
return Err(ApiError::missing_credentials(
config.provider_name,
config.credential_env_vars(),
));
}
};
Ok(Self::new(api_key, config))
Ok(Self::new(api_key, config).with_base_url(base_url))
}
#[must_use]
@@ -915,14 +925,18 @@ pub fn model_requires_reasoning_content_in_history(model: &str) -> bool {
/// Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire.
/// The prefix is used only to select transport; the backend expects the
/// bare model id.
/// bare model id. Use `local/` to force OpenAI-compatible routing while
/// preserving any slashes that follow the prefix.
#[allow(dead_code)]
fn strip_routing_prefix(model: &str) -> &str {
if let Some(pos) = model.find('/') {
let prefix = &model[..pos];
// Only strip if the prefix before "/" is a known routing prefix,
// not if "/" appears in the middle of the model name for other reasons.
if matches!(prefix, "openai" | "xai" | "grok" | "qwen" | "kimi") {
if matches!(
prefix,
"openai" | "xai" | "grok" | "qwen" | "kimi" | "local"
) {
&model[pos + 1..]
} else {
model
@@ -932,6 +946,44 @@ fn strip_routing_prefix(model: &str) -> &str {
}
}
fn normalize_base_url_for_model_routing(url: &str) -> &str {
let trimmed = url.trim_end_matches('/');
trimmed
.strip_suffix("/chat/completions")
.map(|value| value.trim_end_matches('/'))
.unwrap_or(trimmed)
}
fn url_host(url: &str) -> &str {
let after_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
let authority = after_scheme.split(['/', '?', '#']).next().unwrap_or("");
let host_port = authority
.rsplit_once('@')
.map_or(authority, |(_, host_port)| host_port);
if host_port.starts_with('[') {
return host_port
.split(']')
.next()
.unwrap_or("")
.trim_start_matches('[');
}
host_port.split(':').next().unwrap_or("")
}
fn is_local_openai_compatible_base_url(url: &str) -> bool {
let host = url_host(url.trim());
if host.eq_ignore_ascii_case("localhost") || host == "::1" {
return true;
}
let Ok(address) = host.parse::<Ipv4Addr>() else {
return false;
};
let [first, second, ..] = address.octets();
matches!(first, 10 | 127)
|| first == 192 && second == 168
|| first == 172 && (16..=31).contains(&second)
}
fn wire_model_for_base_url<'a>(
model: &'a str,
config: OpenAiCompatConfig,
@@ -944,26 +996,22 @@ fn wire_model_for_base_url<'a>(
let lowered_prefix = prefix.to_ascii_lowercase();
if lowered_prefix == "openai" {
let trimmed_base_url = base_url.trim_end_matches('/');
let default_openai = DEFAULT_OPENAI_BASE_URL.trim_end_matches('/');
if matches!(
lowered_prefix.as_str(),
"xai" | "grok" | "kimi" | "gemini" | "gemma"
) {
let normalized_base_url = normalize_base_url_for_model_routing(base_url);
let default_base_url = normalize_base_url_for_model_routing(config.default_base_url);
if normalized_base_url.eq_ignore_ascii_case(default_base_url)
|| is_local_openai_compatible_base_url(base_url)
{
return Cow::Borrowed(&model[pos + 1..]);
}
if config.provider_name == "OpenAI" && trimmed_base_url != default_openai {
// Only preserve the full slug if it's NOT a model we want to strip
if !model.contains("gemini") && !model.contains("gemma") {
return Cow::Borrowed(model);
}
}
return Cow::Borrowed(&model[pos + 1..]);
return Cow::Borrowed(model);
}
if matches!(lowered_prefix.as_str(), "xai" | "grok" | "qwen" | "kimi") {
return Cow::Borrowed(&model[pos + 1..]);
}
if lowered_prefix == "local" {
return Cow::Borrowed(&model[pos + 1..]);
}
Cow::Borrowed(model)
}
@@ -1708,6 +1756,7 @@ mod tests {
ToolChoice, ToolDefinition, ToolResultContentBlock,
};
use serde_json::json;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::sync::{Mutex, OnceLock};
@@ -2147,6 +2196,28 @@ mod tests {
));
}
#[test]
fn local_openai_base_url_does_not_require_api_key() {
let _lock = env_lock();
let original_base_url = std::env::var_os("OPENAI_BASE_URL");
let original_api_key = std::env::var_os("OPENAI_API_KEY");
std::env::set_var("OPENAI_BASE_URL", "http://127.0.0.1:11434/v1");
std::env::remove_var("OPENAI_API_KEY");
let client = OpenAiCompatClient::from_env(OpenAiCompatConfig::openai())
.expect("local OpenAI-compatible endpoint should not require an API key");
assert_eq!(client.base_url(), "http://127.0.0.1:11434/v1");
match original_base_url {
Some(value) => std::env::set_var("OPENAI_BASE_URL", value),
None => std::env::remove_var("OPENAI_BASE_URL"),
}
match original_api_key {
Some(value) => std::env::set_var("OPENAI_API_KEY", value),
None => std::env::remove_var("OPENAI_API_KEY"),
}
}
#[test]
fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
assert_eq!(
@@ -2762,6 +2833,66 @@ mod tests {
}
}
#[test]
fn wire_model_strips_openai_prefix_for_default_and_local_preserves_custom_gateways() {
assert_eq!(
super::wire_model_for_base_url(
"openai/gpt-4o",
OpenAiCompatConfig::openai(),
super::DEFAULT_OPENAI_BASE_URL,
),
Cow::Borrowed("gpt-4o")
);
assert_eq!(
super::wire_model_for_base_url(
"openai/qwen2.5-coder:7b",
OpenAiCompatConfig::openai(),
"http://127.0.0.1:11434/v1",
),
Cow::Borrowed("qwen2.5-coder:7b")
);
assert_eq!(
super::wire_model_for_base_url(
"openai/llama3.2",
OpenAiCompatConfig::openai(),
"http://localhost:11434/v1/chat/completions",
),
Cow::Borrowed("llama3.2")
);
assert_eq!(
super::wire_model_for_base_url(
"openai/gpt-4.1-mini",
OpenAiCompatConfig::openai(),
"https://openrouter.ai/api/v1",
),
Cow::Borrowed("openai/gpt-4.1-mini")
);
assert_eq!(
super::wire_model_for_base_url(
"openai/gpt-4.1-mini",
OpenAiCompatConfig::openai(),
"https://not-localhost.example.com/v1",
),
Cow::Borrowed("openai/gpt-4.1-mini")
);
}
#[test]
fn local_routing_prefix_strips_only_escape_hatch() {
assert_eq!(
super::strip_routing_prefix("local/Qwen/Qwen3.6-27B-FP8"),
"Qwen/Qwen3.6-27B-FP8"
);
assert_eq!(
super::wire_model_for_base_url(
"local/Qwen/Qwen3.6-27B-FP8",
OpenAiCompatConfig::openai(),
"http://127.0.0.1:8000/v1",
),
Cow::Borrowed("Qwen/Qwen3.6-27B-FP8")
);
}
#[test]
fn check_request_body_size_allows_large_requests_for_openai() {
// Create a request that exceeds DashScope's limit but is under OpenAI's 100MB limit