mirror of
https://github.com/instructkr/claude-code.git
synced 2026-05-14 01:46:44 +00:00
wip: cache-tracking progress
This commit is contained in:
@@ -8,8 +8,9 @@ use runtime::{
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::prompt_cache::{PromptCache, PromptCacheStats};
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
@@ -108,6 +109,7 @@ pub struct AnthropicClient {
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
prompt_cache: Option<PromptCache>,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
@@ -120,6 +122,7 @@ impl AnthropicClient {
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
prompt_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +135,7 @@ impl AnthropicClient {
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
prompt_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +193,22 @@ impl AnthropicClient {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
|
||||
self.prompt_cache = Some(prompt_cache);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn prompt_cache(&self) -> Option<&PromptCache> {
|
||||
self.prompt_cache.as_ref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
|
||||
self.prompt_cache.as_ref().map(PromptCache::stats)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn auth_source(&self) -> &AuthSource {
|
||||
&self.auth
|
||||
@@ -202,6 +222,11 @@ impl AnthropicClient {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
if let Some(prompt_cache) = &self.prompt_cache {
|
||||
if let Some(response) = prompt_cache.lookup_completion(&request) {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let mut response = response
|
||||
@@ -211,6 +236,9 @@ impl AnthropicClient {
|
||||
if response.request_id.is_none() {
|
||||
response.request_id = request_id;
|
||||
}
|
||||
if let Some(prompt_cache) = &self.prompt_cache {
|
||||
let _ = prompt_cache.record_response(&request, &response);
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -227,6 +255,15 @@ impl AnthropicClient {
|
||||
parser: SseParser::new(),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
cache_tracking: self
|
||||
.prompt_cache
|
||||
.as_ref()
|
||||
.map(|prompt_cache| StreamCacheTracking {
|
||||
prompt_cache: prompt_cache.clone(),
|
||||
request: request.clone().with_streaming(),
|
||||
last_usage: None,
|
||||
finalized: false,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -527,6 +564,7 @@ pub struct MessageStream {
|
||||
parser: SseParser,
|
||||
pending: VecDeque<StreamEvent>,
|
||||
done: bool,
|
||||
cache_tracking: Option<StreamCacheTracking>,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
@@ -538,6 +576,9 @@ impl MessageStream {
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
loop {
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
if let Some(cache_tracking) = &mut self.cache_tracking {
|
||||
cache_tracking.observe(&event);
|
||||
}
|
||||
return Ok(Some(event));
|
||||
}
|
||||
|
||||
@@ -545,8 +586,14 @@ impl MessageStream {
|
||||
let remaining = self.parser.finish()?;
|
||||
self.pending.extend(remaining);
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
if let Some(cache_tracking) = &mut self.cache_tracking {
|
||||
cache_tracking.observe(&event);
|
||||
}
|
||||
return Ok(Some(event));
|
||||
}
|
||||
if let Some(cache_tracking) = &mut self.cache_tracking {
|
||||
cache_tracking.finalize();
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
@@ -562,6 +609,41 @@ impl MessageStream {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct StreamCacheTracking {
|
||||
prompt_cache: PromptCache,
|
||||
request: MessageRequest,
|
||||
last_usage: Option<Usage>,
|
||||
finalized: bool,
|
||||
}
|
||||
|
||||
impl StreamCacheTracking {
|
||||
fn observe(&mut self, event: &StreamEvent) {
|
||||
match event {
|
||||
StreamEvent::MessageStart(event) => {
|
||||
self.last_usage = Some(event.message.usage.clone());
|
||||
}
|
||||
StreamEvent::MessageDelta(event) => {
|
||||
self.last_usage = Some(event.usage.clone());
|
||||
}
|
||||
StreamEvent::ContentBlockStart(_)
|
||||
| StreamEvent::ContentBlockDelta(_)
|
||||
| StreamEvent::ContentBlockStop(_)
|
||||
| StreamEvent::MessageStop(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize(&mut self) {
|
||||
if self.finalized {
|
||||
return;
|
||||
}
|
||||
if let Some(usage) = &self.last_usage {
|
||||
let _ = self.prompt_cache.record_usage(&self.request, usage);
|
||||
}
|
||||
self.finalized = true;
|
||||
}
|
||||
}
|
||||
|
||||
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
@@ -606,6 +688,7 @@ mod tests {
|
||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::thread;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
@@ -622,13 +705,15 @@ mod tests {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.expect("env lock")
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
fn temp_config_home() -> std::path::PathBuf {
|
||||
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
|
||||
std::env::temp_dir().join(format!(
|
||||
"api-oauth-test-{}-{}",
|
||||
"api-oauth-test-{}-{}-{}",
|
||||
std::process::id(),
|
||||
NEXT_ID.fetch_add(1, Ordering::Relaxed),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
|
||||
Reference in New Issue
Block a user