wip: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 04:30:24 +00:00
parent ac6c5d00a8
commit 0cf2204d43
4 changed files with 916 additions and 5 deletions

View File

@@ -8,8 +8,9 @@ use runtime::{
use serde::Deserialize;
use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheStats};
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -108,6 +109,7 @@ pub struct AnthropicClient {
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
prompt_cache: Option<PromptCache>,
}
impl AnthropicClient {
@@ -120,6 +122,7 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
}
}
@@ -132,6 +135,7 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
}
}
@@ -189,6 +193,22 @@ impl AnthropicClient {
self
}
#[must_use]
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
self.prompt_cache = Some(prompt_cache);
self
}
#[must_use]
/// Borrow the configured prompt cache, or `None` when caching is disabled.
pub fn prompt_cache(&self) -> Option<&PromptCache> {
    let cache = &self.prompt_cache;
    cache.as_ref()
}
#[must_use]
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
self.prompt_cache.as_ref().map(PromptCache::stats)
}
#[must_use]
pub fn auth_source(&self) -> &AuthSource {
&self.auth
@@ -202,6 +222,11 @@ impl AnthropicClient {
stream: false,
..request.clone()
};
if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) {
return Ok(response);
}
}
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let mut response = response
@@ -211,6 +236,9 @@ impl AnthropicClient {
if response.request_id.is_none() {
response.request_id = request_id;
}
if let Some(prompt_cache) = &self.prompt_cache {
let _ = prompt_cache.record_response(&request, &response);
}
Ok(response)
}
@@ -227,6 +255,15 @@ impl AnthropicClient {
parser: SseParser::new(),
pending: VecDeque::new(),
done: false,
cache_tracking: self
.prompt_cache
.as_ref()
.map(|prompt_cache| StreamCacheTracking {
prompt_cache: prompt_cache.clone(),
request: request.clone().with_streaming(),
last_usage: None,
finalized: false,
}),
})
}
@@ -527,6 +564,7 @@ pub struct MessageStream {
parser: SseParser,
pending: VecDeque<StreamEvent>,
done: bool,
cache_tracking: Option<StreamCacheTracking>,
}
impl MessageStream {
@@ -538,6 +576,9 @@ impl MessageStream {
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop {
if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event));
}
@@ -545,8 +586,14 @@ impl MessageStream {
let remaining = self.parser.finish()?;
self.pending.extend(remaining);
if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event));
}
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.finalize();
}
return Ok(None);
}
@@ -562,6 +609,41 @@ impl MessageStream {
}
}
/// Per-stream bookkeeping that records token usage observed on streaming
/// events into the prompt cache (see `observe`/`finalize` in the impl).
#[derive(Debug, Clone)]
struct StreamCacheTracking {
/// Cache that ultimately receives the usage record.
prompt_cache: PromptCache,
/// The originating request in its streaming form — presumably the cache
/// key passed to `record_usage`; TODO confirm against `PromptCache`.
request: MessageRequest,
/// Latest usage seen on a `MessageStart`/`MessageDelta` event; `None`
/// until one arrives.
last_usage: Option<Usage>,
/// Set once `finalize` has run, making it idempotent.
finalized: bool,
}
impl StreamCacheTracking {
    /// Capture usage information carried by a stream event.
    ///
    /// Only `MessageStart` and `MessageDelta` carry usage; the most recent
    /// one seen wins. All other event kinds are intentionally ignored.
    fn observe(&mut self, event: &StreamEvent) {
        match event {
            StreamEvent::MessageStart(start) => {
                self.last_usage = Some(start.message.usage.clone());
            }
            StreamEvent::MessageDelta(delta) => {
                self.last_usage = Some(delta.usage.clone());
            }
            // Content-block events and MessageStop do not carry usage we track.
            StreamEvent::ContentBlockStart(_)
            | StreamEvent::ContentBlockDelta(_)
            | StreamEvent::ContentBlockStop(_)
            | StreamEvent::MessageStop(_) => {}
        }
    }

    /// Flush the last observed usage into the prompt cache, at most once.
    ///
    /// Any error from `record_usage` is deliberately discarded: cache
    /// bookkeeping must never fail the stream itself.
    fn finalize(&mut self) {
        if self.finalized {
            return;
        }
        self.finalized = true;
        if let Some(usage) = self.last_usage.as_ref() {
            let _ = self.prompt_cache.record_usage(&self.request, usage);
        }
    }
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status();
if status.is_success() {
@@ -606,6 +688,7 @@ mod tests {
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -622,13 +705,15 @@ mod tests {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock")
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
fn temp_config_home() -> std::path::PathBuf {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
std::env::temp_dir().join(format!(
"api-oauth-test-{}-{}",
"api-oauth-test-{}-{}-{}",
std::process::id(),
NEXT_ID.fetch_add(1, Ordering::Relaxed),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")