mirror of
https://github.com/instructkr/claude-code.git
synced 2026-06-08 05:06:44 +00:00
Compare commits
77 Commits
6037aaeff1
...
rcc/doctor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b200198df7 | ||
|
|
d6341d54c1 | ||
|
|
863958b94c | ||
|
|
9455280f24 | ||
|
|
c92403994d | ||
|
|
8d4a739c05 | ||
|
|
e2f061fd08 | ||
|
|
c139fe9bee | ||
|
|
6a7cea810e | ||
|
|
842abcfe85 | ||
|
|
807e29c8a1 | ||
|
|
e84133527e | ||
|
|
32e89df631 | ||
|
|
1f8cfbce38 | ||
|
|
1e5002b521 | ||
|
|
d5d99af2d0 | ||
|
|
5180cc5658 | ||
|
|
964cc25821 | ||
|
|
8ab16276bf | ||
|
|
b8dadbfbf5 | ||
|
|
46581fe442 | ||
|
|
92f33c75c0 | ||
|
|
5f46fec5ad | ||
|
|
771f716625 | ||
|
|
d3e41be7f1 | ||
|
|
691ea57832 | ||
|
|
4d65f5c1a2 | ||
|
|
8b6bf4cee7 | ||
|
|
647b407444 | ||
|
|
5eeb7be4cc | ||
|
|
f8bc5cf264 | ||
|
|
346ea0b91b | ||
|
|
6076041f19 | ||
|
|
9f3be03463 | ||
|
|
c30bb8aa59 | ||
|
|
88cd2e31df | ||
|
|
1adf11d572 | ||
|
|
9b0c9b5739 | ||
|
|
cf8d5a8389 | ||
|
|
cba31c4f95 | ||
|
|
fa30059790 | ||
|
|
d9c5f60598 | ||
|
|
9b7fe16edb | ||
|
|
c8f95cd72b | ||
|
|
66dde1b74a | ||
|
|
99b78d6ea4 | ||
|
|
3db3dfa60d | ||
|
|
0ac188caad | ||
|
|
0794e76f07 | ||
|
|
b510387045 | ||
|
|
6e378185e9 | ||
|
|
019e9900ed | ||
|
|
67423d005a | ||
|
|
4db21e9595 | ||
|
|
daf98cc750 | ||
|
|
2ad2ec087f | ||
|
|
0346b7dd3a | ||
|
|
a8f5da6427 | ||
|
|
c996eb7b1b | ||
|
|
14757e0780 | ||
|
|
188c35f8a6 | ||
|
|
2de0b0e2af | ||
|
|
c024d8b21f | ||
|
|
a66c301fa3 | ||
|
|
321a1a681a | ||
|
|
2d1cade31b | ||
|
|
6fe404329d | ||
|
|
add5513ac5 | ||
|
|
8465b6923b | ||
|
|
32981ffa28 | ||
|
|
cb24430c56 | ||
|
|
071045f556 | ||
|
|
a96bb6c60f | ||
|
|
d6a814258c | ||
|
|
4bae5ee132 | ||
|
|
619ae71866 | ||
|
|
5b106b840d |
1
.claude/sessions/session-1774998936453.json
Normal file
1
.claude/sessions/session-1774998936453.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
.claude/sessions/session-1774998994373.json
Normal file
1
.claude/sessions/session-1774998994373.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[{"blocks":[{"text":"Say hello in one sentence","type":"text"}],"role":"user"},{"blocks":[{"text":"Hello! I'm Claude, an AI assistant ready to help you with software engineering tasks, code analysis, debugging, or any other programming challenges you might have.","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":11,"output_tokens":32}}],"version":1}
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
archive/
|
archive/
|
||||||
|
.omx/
|
||||||
|
.clawd-agents/
|
||||||
|
|||||||
26
README.md
26
README.md
@@ -1,5 +1,19 @@
|
|||||||
# Rewriting Project Claw Code
|
# Rewriting Project Claw Code
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<strong>⭐ The fastest repo in history to surpass 50K stars, reaching the milestone in just 2 hours after publication ⭐</strong>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://star-history.com/#instructkr/claw-code&Date">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date&theme=dark" />
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date" />
|
||||||
|
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date" width="600" />
|
||||||
|
</picture>
|
||||||
|
</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="assets/clawd-hero.jpeg" alt="Claw" width="300" />
|
<img src="assets/clawd-hero.jpeg" alt="Claw" width="300" />
|
||||||
</p>
|
</p>
|
||||||
@@ -169,17 +183,7 @@ Join the [**instructkr Discord**](https://instruct.kr/) — the best Korean lang
|
|||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
This repository became **the fastest GitHub repo in history to surpass 30K stars**, reaching the milestone in just a few hours after publication.
|
See the chart at the top of this README.
|
||||||
|
|
||||||
<a href="https://star-history.com/#instructkr/claw-code&Date">
|
|
||||||
<picture>
|
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date&theme=dark" />
|
|
||||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date" />
|
|
||||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=instructkr/claw-code&type=Date" />
|
|
||||||
</picture>
|
|
||||||
</a>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## Ownership / Affiliation Disclaimer
|
## Ownership / Affiliation Disclaimer
|
||||||
|
|
||||||
|
|||||||
2
rust/.gitignore
vendored
2
rust/.gitignore
vendored
@@ -1 +1,3 @@
|
|||||||
target/
|
target/
|
||||||
|
.omx/
|
||||||
|
.clawd-agents/
|
||||||
|
|||||||
92
rust/Cargo.lock
generated
92
rust/Cargo.lock
generated
@@ -22,6 +22,7 @@ name = "api"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"runtime",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -54,6 +55,15 @@ version = "2.11.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "block-buffer"
|
||||||
|
version = "0.10.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumpalo"
|
name = "bumpalo"
|
||||||
version = "3.20.2"
|
version = "3.20.2"
|
||||||
@@ -104,6 +114,15 @@ dependencies = [
|
|||||||
"tools",
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cpufeatures"
|
||||||
|
version = "0.2.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crc32fast"
|
name = "crc32fast"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
@@ -138,6 +157,16 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crypto-common"
|
||||||
|
version = "0.1.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
|
||||||
|
dependencies = [
|
||||||
|
"generic-array",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deranged"
|
name = "deranged"
|
||||||
version = "0.5.8"
|
version = "0.5.8"
|
||||||
@@ -147,6 +176,16 @@ dependencies = [
|
|||||||
"powerfmt",
|
"powerfmt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "digest"
|
||||||
|
version = "0.10.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||||
|
dependencies = [
|
||||||
|
"block-buffer",
|
||||||
|
"crypto-common",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "displaydoc"
|
name = "displaydoc"
|
||||||
version = "0.2.5"
|
version = "0.2.5"
|
||||||
@@ -212,6 +251,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d"
|
checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -220,6 +260,18 @@ version = "0.3.32"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-io"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-sink"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-task"
|
name = "futures-task"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
@@ -233,11 +285,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
|
"memchr",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "generic-array"
|
||||||
|
version = "0.14.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||||
|
dependencies = [
|
||||||
|
"typenum",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getopts"
|
name = "getopts"
|
||||||
version = "0.2.24"
|
version = "0.2.24"
|
||||||
@@ -898,7 +963,9 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
@@ -950,6 +1017,7 @@ dependencies = [
|
|||||||
"regex",
|
"regex",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"sha2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
@@ -1106,6 +1174,17 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha2"
|
||||||
|
version = "0.10.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cpufeatures",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
@@ -1352,6 +1431,7 @@ dependencies = [
|
|||||||
name = "tools"
|
name = "tools"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"reqwest",
|
||||||
"runtime",
|
"runtime",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
@@ -1427,6 +1507,12 @@ version = "0.2.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicase"
|
name = "unicase"
|
||||||
version = "2.9.0"
|
version = "2.9.0"
|
||||||
@@ -1469,6 +1555,12 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "version_check"
|
||||||
|
version = "0.9.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "walkdir"
|
name = "walkdir"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
|||||||
227
rust/README.md
227
rust/README.md
@@ -1,54 +1,211 @@
|
|||||||
# Rust port foundation
|
# Rusty Claude CLI
|
||||||
|
|
||||||
This directory contains the first compatibility-first Rust foundation for a drop-in Claude Code CLI replacement.
|
`rust/` contains the Rust workspace for the integrated `rusty-claude-cli` deliverable.
|
||||||
|
It is intended to be something you can clone, build, and run directly.
|
||||||
## Current milestone
|
|
||||||
|
|
||||||
This initial milestone focuses on **harness-first scaffolding**, not full feature parity:
|
|
||||||
|
|
||||||
- a Cargo workspace aligned to major upstream seams
|
|
||||||
- a placeholder CLI crate (`rusty-claude-cli`)
|
|
||||||
- runtime, command, and tool registry skeleton crates
|
|
||||||
- a `compat-harness` crate that reads the upstream TypeScript sources in `../src/`
|
|
||||||
- tests that prove upstream manifests/bootstrap hints can be extracted from the leaked TypeScript codebase
|
|
||||||
|
|
||||||
## Workspace layout
|
## Workspace layout
|
||||||
|
|
||||||
```text
|
```text
|
||||||
rust/
|
rust/
|
||||||
├── Cargo.toml
|
├── Cargo.toml
|
||||||
|
├── Cargo.lock
|
||||||
├── README.md
|
├── README.md
|
||||||
├── crates/
|
└── crates/
|
||||||
│ ├── rusty-claude-cli/
|
├── api/ # Anthropic API client + SSE streaming support
|
||||||
│ ├── runtime/
|
├── commands/ # Shared slash-command metadata/help surfaces
|
||||||
│ ├── commands/
|
├── compat-harness/ # Upstream TS manifest extraction harness
|
||||||
│ ├── tools/
|
├── runtime/ # Session/runtime/config/prompt orchestration
|
||||||
│ └── compat-harness/
|
├── rusty-claude-cli/ # Main CLI binary
|
||||||
└── tests/
|
└── tools/ # Built-in tool implementations
|
||||||
```
|
```
|
||||||
|
|
||||||
## How to use
|
## Prerequisites
|
||||||
|
|
||||||
From this directory:
|
- Rust toolchain installed (`rustup`, stable toolchain)
|
||||||
|
- Network access and Anthropic credentials for live prompt/REPL usage
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
From the repository root:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo fmt --all
|
cd rust
|
||||||
cargo check --workspace
|
cargo build --release -p rusty-claude-cli
|
||||||
cargo test --workspace
|
|
||||||
cargo run -p rusty-claude-cli -- --help
|
|
||||||
cargo run -p rusty-claude-cli -- dump-manifests
|
|
||||||
cargo run -p rusty-claude-cli -- bootstrap-plan
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Design notes
|
The optimized binary will be written to:
|
||||||
|
|
||||||
The shape follows the PRD's harness-first recommendation:
|
```bash
|
||||||
|
./target/release/rusty-claude-cli
|
||||||
|
```
|
||||||
|
|
||||||
1. Extract observable upstream command/tool/bootstrap facts first.
|
## Test
|
||||||
2. Keep Rust module boundaries recognizable.
|
|
||||||
3. Grow runtime compatibility behind proof artifacts.
|
|
||||||
4. Document explicit gaps instead of implying drop-in parity too early.
|
|
||||||
|
|
||||||
## Relationship to the root README
|
Run the verified workspace test suite used for release-readiness:
|
||||||
|
|
||||||
The repository root README explains the leaked TypeScript codebase. This document tracks the Rust replacement effort that lives in `rust/`.
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo test --workspace --exclude compat-harness
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
### Show help
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --help
|
||||||
|
```
|
||||||
|
|
||||||
|
### Print version
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --version
|
||||||
|
```
|
||||||
|
|
||||||
|
### Login with OAuth
|
||||||
|
|
||||||
|
Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- login
|
||||||
|
```
|
||||||
|
|
||||||
|
This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`).
|
||||||
|
|
||||||
|
### Logout
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- logout
|
||||||
|
```
|
||||||
|
|
||||||
|
This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`.
|
||||||
|
|
||||||
|
## Usage examples
|
||||||
|
|
||||||
|
### 1) Prompt mode
|
||||||
|
|
||||||
|
Send one prompt, stream the answer, then exit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- prompt "Summarize the architecture of this repository"
|
||||||
|
```
|
||||||
|
|
||||||
|
Use a specific model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --model claude-sonnet-4-20250514 prompt "List the key crates in this workspace"
|
||||||
|
```
|
||||||
|
|
||||||
|
Restrict enabled tools in an interactive session:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --allowedTools read,glob
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) REPL mode
|
||||||
|
|
||||||
|
Start the interactive shell:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli --
|
||||||
|
```
|
||||||
|
|
||||||
|
Inside the REPL, useful commands include:
|
||||||
|
|
||||||
|
```text
|
||||||
|
/help
|
||||||
|
/status
|
||||||
|
/model claude-sonnet-4-20250514
|
||||||
|
/permissions workspace-write
|
||||||
|
/cost
|
||||||
|
/compact
|
||||||
|
/memory
|
||||||
|
/config
|
||||||
|
/init
|
||||||
|
/diff
|
||||||
|
/version
|
||||||
|
/export notes.txt
|
||||||
|
/session list
|
||||||
|
/exit
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) Resume an existing session
|
||||||
|
|
||||||
|
Inspect or maintain a saved session file without entering the REPL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --resume session.json /status /compact /cost
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also inspect memory/config state for a restored session:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available commands
|
||||||
|
|
||||||
|
### Top-level CLI commands
|
||||||
|
|
||||||
|
- `prompt <text...>` — run one prompt non-interactively
|
||||||
|
- `--resume <session.json> [/commands...]` — inspect or maintain a saved session
|
||||||
|
- `dump-manifests` — print extracted upstream manifest counts
|
||||||
|
- `bootstrap-plan` — print the current bootstrap skeleton
|
||||||
|
- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt
|
||||||
|
- `--help` / `-h` — show CLI help
|
||||||
|
- `--version` / `-V` — print the CLI version and build info locally (no API call)
|
||||||
|
- `--output-format text|json` — choose non-interactive prompt output rendering
|
||||||
|
- `--allowedTools <tool[,tool...]>` — restrict enabled tools for interactive sessions and prompt-mode tool use
|
||||||
|
|
||||||
|
### Interactive slash commands
|
||||||
|
|
||||||
|
- `/help` — show command help
|
||||||
|
- `/status` — show current session status
|
||||||
|
- `/compact` — compact local session history
|
||||||
|
- `/model [model]` — inspect or switch the active model
|
||||||
|
- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions
|
||||||
|
- `/clear [--confirm]` — clear the current local session
|
||||||
|
- `/cost` — show token usage totals
|
||||||
|
- `/resume <session-path>` — load a saved session into the REPL
|
||||||
|
- `/config [env|hooks|model]` — inspect discovered Claude config
|
||||||
|
- `/memory` — inspect loaded instruction memory files
|
||||||
|
- `/init` — create a starter `CLAUDE.md`
|
||||||
|
- `/diff` — show the current git diff for the workspace
|
||||||
|
- `/version` — print version and build metadata locally
|
||||||
|
- `/export [file]` — export the current conversation transcript
|
||||||
|
- `/session [list|switch <session-id>]` — inspect or switch managed local sessions
|
||||||
|
- `/exit` — leave the REPL
|
||||||
|
|
||||||
|
## Environment variables
|
||||||
|
|
||||||
|
### Anthropic/API
|
||||||
|
|
||||||
|
- `ANTHROPIC_API_KEY` — highest-precedence API credential
|
||||||
|
- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set
|
||||||
|
- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set
|
||||||
|
- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL
|
||||||
|
- `ANTHROPIC_MODEL` — default model used by selected live integration tests
|
||||||
|
|
||||||
|
### CLI/runtime
|
||||||
|
|
||||||
|
- `RUSTY_CLAUDE_PERMISSION_MODE` — default REPL permission mode (`read-only`, `workspace-write`, or `danger-full-access`)
|
||||||
|
- `CLAUDE_CONFIG_HOME` — override Claude config discovery root
|
||||||
|
- `CLAUDE_CODE_REMOTE` — enable remote-session bootstrap handling when supported
|
||||||
|
- `CLAUDE_CODE_REMOTE_SESSION_ID` — remote session identifier when using remote mode
|
||||||
|
- `CLAUDE_CODE_UPSTREAM` — override the upstream TS source path for compat-harness extraction
|
||||||
|
- `CLAWD_WEB_SEARCH_BASE_URL` — override the built-in web search service endpoint used by tooling
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- `compat-harness` exists to compare the Rust port against the upstream TypeScript codebase and is intentionally excluded from the requested release test run.
|
||||||
|
- The CLI currently focuses on a practical integrated workflow: prompt execution, REPL operation, session inspection/resume, config discovery, and tool/runtime plumbing.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ publish.workspace = true
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||||
|
runtime = { path = "../runtime" }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::time::Duration;
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use runtime::{
|
||||||
|
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
||||||
|
OAuthTokenExchangeRequest,
|
||||||
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
@@ -15,11 +19,91 @@ const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
|||||||
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
||||||
const DEFAULT_MAX_RETRIES: u32 = 2;
|
const DEFAULT_MAX_RETRIES: u32 = 2;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum AuthSource {
|
||||||
|
None,
|
||||||
|
ApiKey(String),
|
||||||
|
BearerToken(String),
|
||||||
|
ApiKeyAndBearer {
|
||||||
|
api_key: String,
|
||||||
|
bearer_token: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthSource {
|
||||||
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
|
let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
|
||||||
|
let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
|
||||||
|
match (api_key, auth_token) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
||||||
|
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
||||||
|
(None, None) => Err(ApiError::MissingApiKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn api_key(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
|
||||||
|
Self::None | Self::BearerToken(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn bearer_token(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::BearerToken(token)
|
||||||
|
| Self::ApiKeyAndBearer {
|
||||||
|
bearer_token: token,
|
||||||
|
..
|
||||||
|
} => Some(token),
|
||||||
|
Self::None | Self::ApiKey(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn masked_authorization_header(&self) -> &'static str {
|
||||||
|
if self.bearer_token().is_some() {
|
||||||
|
"Bearer [REDACTED]"
|
||||||
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||||
|
if let Some(api_key) = self.api_key() {
|
||||||
|
request_builder = request_builder.header("x-api-key", api_key);
|
||||||
|
}
|
||||||
|
if let Some(token) = self.bearer_token() {
|
||||||
|
request_builder = request_builder.bearer_auth(token);
|
||||||
|
}
|
||||||
|
request_builder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuthTokenSet> for AuthSource {
|
||||||
|
fn from(value: OAuthTokenSet) -> Self {
|
||||||
|
Self::BearerToken(value.access_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct AnthropicClient {
|
pub struct AnthropicClient {
|
||||||
http: reqwest::Client,
|
http: reqwest::Client,
|
||||||
api_key: String,
|
auth: AuthSource,
|
||||||
auth_token: Option<String>,
|
|
||||||
base_url: String,
|
base_url: String,
|
||||||
max_retries: u32,
|
max_retries: u32,
|
||||||
initial_backoff: Duration,
|
initial_backoff: Duration,
|
||||||
@@ -31,8 +115,19 @@ impl AnthropicClient {
|
|||||||
pub fn new(api_key: impl Into<String>) -> Self {
|
pub fn new(api_key: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
http: reqwest::Client::new(),
|
http: reqwest::Client::new(),
|
||||||
api_key: api_key.into(),
|
auth: AuthSource::ApiKey(api_key.into()),
|
||||||
auth_token: None,
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_auth(auth: AuthSource) -> Self {
|
||||||
|
Self {
|
||||||
|
http: reqwest::Client::new(),
|
||||||
|
auth,
|
||||||
base_url: DEFAULT_BASE_URL.to_string(),
|
base_url: DEFAULT_BASE_URL.to_string(),
|
||||||
max_retries: DEFAULT_MAX_RETRIES,
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
@@ -41,14 +136,37 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_env() -> Result<Self, ApiError> {
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
Ok(Self::new(read_api_key()?)
|
Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
|
||||||
.with_auth_token(read_auth_token())
|
}
|
||||||
.with_base_url(read_base_url()))
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
|
||||||
|
self.auth = auth;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
||||||
self.auth_token = auth_token.filter(|token| !token.is_empty());
|
match (
|
||||||
|
self.auth.api_key().map(ToOwned::to_owned),
|
||||||
|
auth_token.filter(|token| !token.is_empty()),
|
||||||
|
) {
|
||||||
|
(Some(api_key), Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
(Some(api_key), None) => {
|
||||||
|
self.auth = AuthSource::ApiKey(api_key);
|
||||||
|
}
|
||||||
|
(None, Some(bearer_token)) => {
|
||||||
|
self.auth = AuthSource::BearerToken(bearer_token);
|
||||||
|
}
|
||||||
|
(None, None) => {
|
||||||
|
self.auth = AuthSource::None;
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,6 +189,11 @@ impl AnthropicClient {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn auth_source(&self) -> &AuthSource {
|
||||||
|
&self.auth
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn send_message(
|
pub async fn send_message(
|
||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
@@ -107,6 +230,46 @@ impl AnthropicClient {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn exchange_oauth_code(
|
||||||
|
&self,
|
||||||
|
config: &OAuthConfig,
|
||||||
|
request: &OAuthTokenExchangeRequest,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&config.token_url)
|
||||||
|
.header("content-type", "application/x-www-form-urlencoded")
|
||||||
|
.form(&request.form_params())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
let response = expect_success(response).await?;
|
||||||
|
response
|
||||||
|
.json::<OAuthTokenSet>()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn refresh_oauth_token(
|
||||||
|
&self,
|
||||||
|
config: &OAuthConfig,
|
||||||
|
request: &OAuthRefreshRequest,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&config.token_url)
|
||||||
|
.header("content-type", "application/x-www-form-urlencoded")
|
||||||
|
.form(&request.form_params())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
let response = expect_success(response).await?;
|
||||||
|
response
|
||||||
|
.json::<OAuthTokenSet>()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)
|
||||||
|
}
|
||||||
|
|
||||||
async fn send_with_retry(
|
async fn send_with_retry(
|
||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
@@ -151,25 +314,25 @@ impl AnthropicClient {
|
|||||||
let resolved_base_url = self.base_url.trim_end_matches('/');
|
let resolved_base_url = self.base_url.trim_end_matches('/');
|
||||||
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
|
||||||
eprintln!("[anthropic-client] request_url={request_url}");
|
eprintln!("[anthropic-client] request_url={request_url}");
|
||||||
let mut request_builder = self
|
let request_builder = self
|
||||||
.http
|
.http
|
||||||
.post(&request_url)
|
.post(&request_url)
|
||||||
.header("x-api-key", &self.api_key)
|
|
||||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
|
let mut request_builder = self.auth.apply(request_builder);
|
||||||
|
|
||||||
let auth_header = self.auth_token.as_ref().map(|_| "Bearer [REDACTED]").unwrap_or("<absent>");
|
eprintln!(
|
||||||
eprintln!("[anthropic-client] headers x-api-key=[REDACTED] authorization={auth_header} anthropic-version={ANTHROPIC_VERSION} content-type=application/json");
|
"[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json",
|
||||||
|
if self.auth.api_key().is_some() {
|
||||||
|
"[REDACTED]"
|
||||||
|
} else {
|
||||||
|
"<absent>"
|
||||||
|
},
|
||||||
|
self.auth.masked_authorization_header()
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(auth_token) = &self.auth_token {
|
request_builder = request_builder.json(request);
|
||||||
request_builder = request_builder.bearer_auth(auth_token);
|
request_builder.send().await.map_err(ApiError::from)
|
||||||
}
|
|
||||||
|
|
||||||
request_builder
|
|
||||||
.json(request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(ApiError::from)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||||
@@ -186,25 +349,175 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_api_key() -> Result<String, ApiError> {
|
impl AuthSource {
|
||||||
match std::env::var("ANTHROPIC_API_KEY") {
|
pub fn from_env_or_saved() -> Result<Self, ApiError> {
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||||
Ok(_) => Err(ApiError::MissingApiKey),
|
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
|
||||||
Ok(api_key) if !api_key.is_empty() => Ok(api_key),
|
api_key,
|
||||||
Ok(_) => Err(ApiError::MissingApiKey),
|
bearer_token,
|
||||||
Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey),
|
}),
|
||||||
Err(error) => Err(ApiError::from(error)),
|
None => Ok(Self::ApiKey(api_key)),
|
||||||
},
|
};
|
||||||
|
}
|
||||||
|
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
return Ok(Self::BearerToken(bearer_token));
|
||||||
|
}
|
||||||
|
match load_saved_oauth_token() {
|
||||||
|
Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
|
||||||
|
if token_set.refresh_token.is_some() {
|
||||||
|
Err(ApiError::Auth(
|
||||||
|
"saved OAuth token is expired; load runtime OAuth config to refresh it"
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(ApiError::ExpiredOAuthToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
||||||
|
Ok(None) => Err(ApiError::MissingApiKey),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
|
||||||
|
token_set
|
||||||
|
.expires_at
|
||||||
|
.is_some_and(|expires_at| expires_at <= now_unix_timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||||
|
let Some(token_set) = load_saved_oauth_token()? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
resolve_saved_oauth_token_set(config, token_set).map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
||||||
|
{
|
||||||
|
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||||
|
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
None => Ok(AuthSource::ApiKey(api_key)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
return Ok(AuthSource::BearerToken(bearer_token));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(token_set) = load_saved_oauth_token()? else {
|
||||||
|
return Err(ApiError::MissingApiKey);
|
||||||
|
};
|
||||||
|
if !oauth_token_is_expired(&token_set) {
|
||||||
|
return Ok(AuthSource::BearerToken(token_set.access_token));
|
||||||
|
}
|
||||||
|
if token_set.refresh_token.is_none() {
|
||||||
|
return Err(ApiError::ExpiredOAuthToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(config) = load_oauth_config()? else {
|
||||||
|
return Err(ApiError::Auth(
|
||||||
|
"saved OAuth token is expired; runtime OAuth config is missing".to_string(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
Ok(AuthSource::from(resolve_saved_oauth_token_set(
|
||||||
|
&config, token_set,
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_saved_oauth_token_set(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
token_set: OAuthTokenSet,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
if !oauth_token_is_expired(&token_set) {
|
||||||
|
return Ok(token_set);
|
||||||
|
}
|
||||||
|
let Some(refresh_token) = token_set.refresh_token.clone() else {
|
||||||
|
return Err(ApiError::ExpiredOAuthToken);
|
||||||
|
};
|
||||||
|
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
|
||||||
|
let refreshed = client_runtime_block_on(async {
|
||||||
|
client
|
||||||
|
.refresh_oauth_token(
|
||||||
|
config,
|
||||||
|
&OAuthRefreshRequest::from_config(
|
||||||
|
config,
|
||||||
|
refresh_token,
|
||||||
|
Some(token_set.scopes.clone()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
})?;
|
||||||
|
let resolved = OAuthTokenSet {
|
||||||
|
access_token: refreshed.access_token,
|
||||||
|
refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
|
||||||
|
expires_at: refreshed.expires_at,
|
||||||
|
scopes: refreshed.scopes,
|
||||||
|
};
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: resolved.access_token.clone(),
|
||||||
|
refresh_token: resolved.refresh_token.clone(),
|
||||||
|
expires_at: resolved.expires_at,
|
||||||
|
scopes: resolved.scopes.clone(),
|
||||||
|
})
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
|
||||||
|
where
|
||||||
|
F: std::future::Future<Output = Result<T, ApiError>>,
|
||||||
|
{
|
||||||
|
tokio::runtime::Runtime::new()
|
||||||
|
.map_err(ApiError::from)?
|
||||||
|
.block_on(future)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||||
|
let token_set = load_oauth_credentials().map_err(ApiError::from)?;
|
||||||
|
Ok(token_set.map(|token_set| OAuthTokenSet {
|
||||||
|
access_token: token_set.access_token,
|
||||||
|
refresh_token: token_set.refresh_token,
|
||||||
|
expires_at: token_set.expires_at,
|
||||||
|
scopes: token_set.scopes,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_unix_timestamp() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map_or(0, |duration| duration.as_secs())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||||
|
match std::env::var(key) {
|
||||||
|
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
||||||
|
Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
|
||||||
Err(error) => Err(ApiError::from(error)),
|
Err(error) => Err(ApiError::from(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn read_api_key() -> Result<String, ApiError> {
|
||||||
|
let auth = AuthSource::from_env_or_saved()?;
|
||||||
|
auth.api_key()
|
||||||
|
.or_else(|| auth.bearer_token())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.ok_or(ApiError::MissingApiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
fn read_auth_token() -> Option<String> {
|
fn read_auth_token() -> Option<String> {
|
||||||
match std::env::var("ANTHROPIC_AUTH_TOKEN") {
|
read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
|
||||||
Ok(token) if !token.is_empty() => Some(token),
|
.ok()
|
||||||
_ => None,
|
.and_then(std::convert::identity)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_base_url() -> String {
|
fn read_base_url() -> String {
|
||||||
@@ -303,28 +616,91 @@ struct AnthropicErrorBody {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||||
use std::time::Duration;
|
use std::io::{Read, Write};
|
||||||
|
use std::net::TcpListener;
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
|
||||||
|
|
||||||
|
use crate::client::{
|
||||||
|
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
||||||
|
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
||||||
|
};
|
||||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.expect("env lock")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn temp_config_home() -> std::path::PathBuf {
|
||||||
|
std::env::temp_dir().join(format!(
|
||||||
|
"api-oauth-test-{}-{}",
|
||||||
|
std::process::id(),
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time")
|
||||||
|
.as_nanos()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_oauth_config(token_url: String) -> OAuthConfig {
|
||||||
|
OAuthConfig {
|
||||||
|
client_id: "runtime-client".to_string(),
|
||||||
|
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||||
|
token_url,
|
||||||
|
callback_port: Some(4545),
|
||||||
|
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||||
|
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_token_server(response_body: &'static str) -> String {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
|
||||||
|
let address = listener.local_addr().expect("local addr");
|
||||||
|
thread::spawn(move || {
|
||||||
|
let (mut stream, _) = listener.accept().expect("accept connection");
|
||||||
|
let mut buffer = [0_u8; 4096];
|
||||||
|
let _ = stream.read(&mut buffer).expect("read request");
|
||||||
|
let response = format!(
|
||||||
|
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
|
||||||
|
response_body.len(),
|
||||||
|
response_body
|
||||||
|
);
|
||||||
|
stream
|
||||||
|
.write_all(response.as_bytes())
|
||||||
|
.expect("write response");
|
||||||
|
});
|
||||||
|
format!("http://{address}/oauth/token")
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_requires_presence() {
|
fn read_api_key_requires_presence() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
let error = super::read_api_key().expect_err("missing key should error");
|
let error = super::read_api_key().expect_err("missing key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_requires_non_empty_value() {
|
fn read_api_key_requires_non_empty_value() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
let error = super::read_api_key().expect_err("empty key should error");
|
let error = super::read_api_key().expect_err("empty key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_prefers_api_key_env() {
|
fn read_api_key_prefers_api_key_env() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -337,11 +713,196 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_auth_token_reads_auth_token_env() {
|
fn read_auth_token_reads_auth_token_env() {
|
||||||
|
let _guard = env_lock();
|
||||||
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_token_maps_to_bearer_auth_source() {
|
||||||
|
let auth = AuthSource::from(OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(123),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
});
|
||||||
|
assert_eq!(auth.bearer_token(), Some("access-token"));
|
||||||
|
assert_eq!(auth.api_key(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_from_env_combines_api_key_and_bearer_token() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
||||||
|
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
||||||
|
let auth = AuthSource::from_env().expect("env auth");
|
||||||
|
assert_eq!(auth.api_key(), Some("legacy-key"));
|
||||||
|
assert_eq!(auth.bearer_token(), Some("auth-token"));
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_from_saved_oauth_when_env_absent() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "saved-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(now_unix_timestamp() + 300),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save oauth credentials");
|
||||||
|
|
||||||
|
let auth = AuthSource::from_env_or_saved().expect("saved auth");
|
||||||
|
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_token_expiry_uses_expires_at_timestamp() {
|
||||||
|
assert!(oauth_token_is_expired(&OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: None,
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: Vec::new(),
|
||||||
|
}));
|
||||||
|
assert!(!oauth_token_is_expired(&OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: None,
|
||||||
|
expires_at: Some(now_unix_timestamp() + 60),
|
||||||
|
scopes: Vec::new(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let token_url = spawn_token_server(
|
||||||
|
"{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||||
|
);
|
||||||
|
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||||
|
.expect("resolve refreshed token")
|
||||||
|
.expect("token set present");
|
||||||
|
assert_eq!(resolved.access_token, "refreshed-token");
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.access_token, "refreshed-token");
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "saved-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(now_unix_timestamp() + 300),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save oauth credentials");
|
||||||
|
|
||||||
|
let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
|
||||||
|
.expect("startup auth");
|
||||||
|
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let error =
|
||||||
|
resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
|
||||||
|
assert!(
|
||||||
|
matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
|
||||||
|
);
|
||||||
|
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.access_token, "expired-access-token");
|
||||||
|
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let token_url = spawn_token_server(
|
||||||
|
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||||
|
);
|
||||||
|
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||||
|
.expect("resolve refreshed token")
|
||||||
|
.expect("token set present");
|
||||||
|
assert_eq!(resolved.access_token, "refreshed-token");
|
||||||
|
assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn message_request_stream_helper_sets_stream_true() {
|
fn message_request_stream_helper_sets_stream_true() {
|
||||||
let request = MessageRequest {
|
let request = MessageRequest {
|
||||||
@@ -359,7 +920,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn backoff_doubles_until_maximum() {
|
fn backoff_doubles_until_maximum() {
|
||||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
let client = AnthropicClient::new("test-key").with_retry_policy(
|
||||||
3,
|
3,
|
||||||
Duration::from_millis(10),
|
Duration::from_millis(10),
|
||||||
Duration::from_millis(25),
|
Duration::from_millis(25),
|
||||||
@@ -421,4 +982,25 @@ mod tests {
|
|||||||
Some("req_fallback")
|
Some("req_fallback")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_applies_headers() {
|
||||||
|
let auth = AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
bearer_token: "proxy-token".to_string(),
|
||||||
|
};
|
||||||
|
let request = auth
|
||||||
|
.apply(reqwest::Client::new().post("https://example.test"))
|
||||||
|
.build()
|
||||||
|
.expect("request build");
|
||||||
|
let headers = request.headers();
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("x-api-key").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("test-key")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("authorization").and_then(|v| v.to_str().ok()),
|
||||||
|
Some("Bearer proxy-token")
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ use std::time::Duration;
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ApiError {
|
pub enum ApiError {
|
||||||
MissingApiKey,
|
MissingApiKey,
|
||||||
|
ExpiredOAuthToken,
|
||||||
|
Auth(String),
|
||||||
InvalidApiKeyEnv(VarError),
|
InvalidApiKeyEnv(VarError),
|
||||||
Http(reqwest::Error),
|
Http(reqwest::Error),
|
||||||
Io(std::io::Error),
|
Io(std::io::Error),
|
||||||
@@ -35,6 +37,8 @@ impl ApiError {
|
|||||||
Self::Api { retryable, .. } => *retryable,
|
Self::Api { retryable, .. } => *retryable,
|
||||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||||
Self::MissingApiKey
|
Self::MissingApiKey
|
||||||
|
| Self::ExpiredOAuthToken
|
||||||
|
| Self::Auth(_)
|
||||||
| Self::InvalidApiKeyEnv(_)
|
| Self::InvalidApiKeyEnv(_)
|
||||||
| Self::Io(_)
|
| Self::Io(_)
|
||||||
| Self::Json(_)
|
| Self::Json(_)
|
||||||
@@ -53,6 +57,13 @@ impl Display for ApiError {
|
|||||||
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
|
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Self::ExpiredOAuthToken => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"saved OAuth token is expired and no refresh token is available"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Self::Auth(message) => write!(f, "auth error: {message}"),
|
||||||
Self::InvalidApiKeyEnv(error) => {
|
Self::InvalidApiKeyEnv(error) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ mod error;
|
|||||||
mod sse;
|
mod sse;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use client::{AnthropicClient, MessageStream};
|
pub use client::{
|
||||||
|
oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source,
|
||||||
|
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
|
||||||
|
};
|
||||||
pub use error::ApiError;
|
pub use error::ApiError;
|
||||||
pub use sse::{parse_frame, SseParser};
|
pub use sse::{parse_frame, SseParser};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
|
|||||||
@@ -30,6 +30,222 @@ impl CommandRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct SlashCommandSpec {
|
||||||
|
pub name: &'static str,
|
||||||
|
pub summary: &'static str,
|
||||||
|
pub argument_hint: Option<&'static str>,
|
||||||
|
pub resume_supported: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "help",
|
||||||
|
summary: "Show available slash commands",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "status",
|
||||||
|
summary: "Show current session status",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "compact",
|
||||||
|
summary: "Compact local session history",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "model",
|
||||||
|
summary: "Show or switch the active model",
|
||||||
|
argument_hint: Some("[model]"),
|
||||||
|
resume_supported: false,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "permissions",
|
||||||
|
summary: "Show or switch the active permission mode",
|
||||||
|
argument_hint: Some("[read-only|workspace-write|danger-full-access]"),
|
||||||
|
resume_supported: false,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "clear",
|
||||||
|
summary: "Start a fresh local session",
|
||||||
|
argument_hint: Some("[--confirm]"),
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "cost",
|
||||||
|
summary: "Show cumulative token usage for this session",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "resume",
|
||||||
|
summary: "Load a saved session into the REPL",
|
||||||
|
argument_hint: Some("<session-path>"),
|
||||||
|
resume_supported: false,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "config",
|
||||||
|
summary: "Inspect Claude config files or merged sections",
|
||||||
|
argument_hint: Some("[env|hooks|model]"),
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "memory",
|
||||||
|
summary: "Inspect loaded Claude instruction memory files",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "init",
|
||||||
|
summary: "Create a starter CLAUDE.md for this repo",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "diff",
|
||||||
|
summary: "Show git diff for current workspace changes",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "version",
|
||||||
|
summary: "Show CLI version and build information",
|
||||||
|
argument_hint: None,
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "export",
|
||||||
|
summary: "Export the current conversation to a file",
|
||||||
|
argument_hint: Some("[file]"),
|
||||||
|
resume_supported: true,
|
||||||
|
},
|
||||||
|
SlashCommandSpec {
|
||||||
|
name: "session",
|
||||||
|
summary: "List or switch managed local sessions",
|
||||||
|
argument_hint: Some("[list|switch <session-id>]"),
|
||||||
|
resume_supported: false,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum SlashCommand {
|
||||||
|
Help,
|
||||||
|
Status,
|
||||||
|
Compact,
|
||||||
|
Model {
|
||||||
|
model: Option<String>,
|
||||||
|
},
|
||||||
|
Permissions {
|
||||||
|
mode: Option<String>,
|
||||||
|
},
|
||||||
|
Clear {
|
||||||
|
confirm: bool,
|
||||||
|
},
|
||||||
|
Cost,
|
||||||
|
Resume {
|
||||||
|
session_path: Option<String>,
|
||||||
|
},
|
||||||
|
Config {
|
||||||
|
section: Option<String>,
|
||||||
|
},
|
||||||
|
Memory,
|
||||||
|
Init,
|
||||||
|
Diff,
|
||||||
|
Version,
|
||||||
|
Export {
|
||||||
|
path: Option<String>,
|
||||||
|
},
|
||||||
|
Session {
|
||||||
|
action: Option<String>,
|
||||||
|
target: Option<String>,
|
||||||
|
},
|
||||||
|
Unknown(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SlashCommand {
|
||||||
|
#[must_use]
|
||||||
|
pub fn parse(input: &str) -> Option<Self> {
|
||||||
|
let trimmed = input.trim();
|
||||||
|
if !trimmed.starts_with('/') {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut parts = trimmed.trim_start_matches('/').split_whitespace();
|
||||||
|
let command = parts.next().unwrap_or_default();
|
||||||
|
Some(match command {
|
||||||
|
"help" => Self::Help,
|
||||||
|
"status" => Self::Status,
|
||||||
|
"compact" => Self::Compact,
|
||||||
|
"model" => Self::Model {
|
||||||
|
model: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
"permissions" => Self::Permissions {
|
||||||
|
mode: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
"clear" => Self::Clear {
|
||||||
|
confirm: parts.next() == Some("--confirm"),
|
||||||
|
},
|
||||||
|
"cost" => Self::Cost,
|
||||||
|
"resume" => Self::Resume {
|
||||||
|
session_path: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
"config" => Self::Config {
|
||||||
|
section: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
"memory" => Self::Memory,
|
||||||
|
"init" => Self::Init,
|
||||||
|
"diff" => Self::Diff,
|
||||||
|
"version" => Self::Version,
|
||||||
|
"export" => Self::Export {
|
||||||
|
path: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
"session" => Self::Session {
|
||||||
|
action: parts.next().map(ToOwned::to_owned),
|
||||||
|
target: parts.next().map(ToOwned::to_owned),
|
||||||
|
},
|
||||||
|
other => Self::Unknown(other.to_string()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn slash_command_specs() -> &'static [SlashCommandSpec] {
|
||||||
|
SLASH_COMMAND_SPECS
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> {
|
||||||
|
slash_command_specs()
|
||||||
|
.iter()
|
||||||
|
.filter(|spec| spec.resume_supported)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn render_slash_command_help() -> String {
|
||||||
|
let mut lines = vec![
|
||||||
|
"Slash commands".to_string(),
|
||||||
|
" [resume] means the command also works with --resume SESSION.json".to_string(),
|
||||||
|
];
|
||||||
|
for spec in slash_command_specs() {
|
||||||
|
let name = match spec.argument_hint {
|
||||||
|
Some(argument_hint) => format!("/{} {}", spec.name, argument_hint),
|
||||||
|
None => format!("/{}", spec.name),
|
||||||
|
};
|
||||||
|
let resume = if spec.resume_supported {
|
||||||
|
" [resume]"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
lines.push(format!(" {name:<20} {}{}", spec.summary, resume));
|
||||||
|
}
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct SlashCommandResult {
|
pub struct SlashCommandResult {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
@@ -42,13 +258,8 @@ pub fn handle_slash_command(
|
|||||||
session: &Session,
|
session: &Session,
|
||||||
compaction: CompactionConfig,
|
compaction: CompactionConfig,
|
||||||
) -> Option<SlashCommandResult> {
|
) -> Option<SlashCommandResult> {
|
||||||
let trimmed = input.trim();
|
match SlashCommand::parse(input)? {
|
||||||
if !trimmed.starts_with('/') {
|
SlashCommand::Compact => {
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
match trimmed.split_whitespace().next() {
|
|
||||||
Some("/compact") => {
|
|
||||||
let result = compact_session(session, compaction);
|
let result = compact_session(session, compaction);
|
||||||
let message = if result.removed_message_count == 0 {
|
let message = if result.removed_message_count == 0 {
|
||||||
"Compaction skipped: session is below the compaction threshold.".to_string()
|
"Compaction skipped: session is below the compaction threshold.".to_string()
|
||||||
@@ -63,15 +274,122 @@ pub fn handle_slash_command(
|
|||||||
session: result.compacted_session,
|
session: result.compacted_session,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
_ => None,
|
SlashCommand::Help => Some(SlashCommandResult {
|
||||||
|
message: render_slash_command_help(),
|
||||||
|
session: session.clone(),
|
||||||
|
}),
|
||||||
|
SlashCommand::Status
|
||||||
|
| SlashCommand::Model { .. }
|
||||||
|
| SlashCommand::Permissions { .. }
|
||||||
|
| SlashCommand::Clear { .. }
|
||||||
|
| SlashCommand::Cost
|
||||||
|
| SlashCommand::Resume { .. }
|
||||||
|
| SlashCommand::Config { .. }
|
||||||
|
| SlashCommand::Memory
|
||||||
|
| SlashCommand::Init
|
||||||
|
| SlashCommand::Diff
|
||||||
|
| SlashCommand::Version
|
||||||
|
| SlashCommand::Export { .. }
|
||||||
|
| SlashCommand::Session { .. }
|
||||||
|
| SlashCommand::Unknown(_) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::handle_slash_command;
|
use super::{
|
||||||
|
handle_slash_command, render_slash_command_help, resume_supported_slash_commands,
|
||||||
|
slash_command_specs, SlashCommand,
|
||||||
|
};
|
||||||
use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session};
|
use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_supported_slash_commands() {
|
||||||
|
assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help));
|
||||||
|
assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status));
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/model claude-opus"),
|
||||||
|
Some(SlashCommand::Model {
|
||||||
|
model: Some("claude-opus".to_string()),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/model"),
|
||||||
|
Some(SlashCommand::Model { model: None })
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/permissions read-only"),
|
||||||
|
Some(SlashCommand::Permissions {
|
||||||
|
mode: Some("read-only".to_string()),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/clear"),
|
||||||
|
Some(SlashCommand::Clear { confirm: false })
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/clear --confirm"),
|
||||||
|
Some(SlashCommand::Clear { confirm: true })
|
||||||
|
);
|
||||||
|
assert_eq!(SlashCommand::parse("/cost"), Some(SlashCommand::Cost));
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/resume session.json"),
|
||||||
|
Some(SlashCommand::Resume {
|
||||||
|
session_path: Some("session.json".to_string()),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/config"),
|
||||||
|
Some(SlashCommand::Config { section: None })
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/config env"),
|
||||||
|
Some(SlashCommand::Config {
|
||||||
|
section: Some("env".to_string())
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory));
|
||||||
|
assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init));
|
||||||
|
assert_eq!(SlashCommand::parse("/diff"), Some(SlashCommand::Diff));
|
||||||
|
assert_eq!(SlashCommand::parse("/version"), Some(SlashCommand::Version));
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/export notes.txt"),
|
||||||
|
Some(SlashCommand::Export {
|
||||||
|
path: Some("notes.txt".to_string())
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
SlashCommand::parse("/session switch abc123"),
|
||||||
|
Some(SlashCommand::Session {
|
||||||
|
action: Some("switch".to_string()),
|
||||||
|
target: Some("abc123".to_string())
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_help_from_shared_specs() {
|
||||||
|
let help = render_slash_command_help();
|
||||||
|
assert!(help.contains("works with --resume SESSION.json"));
|
||||||
|
assert!(help.contains("/help"));
|
||||||
|
assert!(help.contains("/status"));
|
||||||
|
assert!(help.contains("/compact"));
|
||||||
|
assert!(help.contains("/model [model]"));
|
||||||
|
assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]"));
|
||||||
|
assert!(help.contains("/clear [--confirm]"));
|
||||||
|
assert!(help.contains("/cost"));
|
||||||
|
assert!(help.contains("/resume <session-path>"));
|
||||||
|
assert!(help.contains("/config [env|hooks|model]"));
|
||||||
|
assert!(help.contains("/memory"));
|
||||||
|
assert!(help.contains("/init"));
|
||||||
|
assert!(help.contains("/diff"));
|
||||||
|
assert!(help.contains("/version"));
|
||||||
|
assert!(help.contains("/export [file]"));
|
||||||
|
assert!(help.contains("/session [list|switch <session-id>]"));
|
||||||
|
assert_eq!(slash_command_specs().len(), 15);
|
||||||
|
assert_eq!(resume_supported_slash_commands().len(), 11);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn compacts_sessions_via_slash_command() {
|
fn compacts_sessions_via_slash_command() {
|
||||||
let session = Session {
|
let session = Session {
|
||||||
@@ -103,8 +421,52 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ignores_unknown_slash_commands() {
|
fn help_command_is_non_mutating() {
|
||||||
|
let session = Session::new();
|
||||||
|
let result = handle_slash_command("/help", &session, CompactionConfig::default())
|
||||||
|
.expect("help command should be handled");
|
||||||
|
assert_eq!(result.session, session);
|
||||||
|
assert!(result.message.contains("Slash commands"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ignores_unknown_or_runtime_bound_slash_commands() {
|
||||||
let session = Session::new();
|
let session = Session::new();
|
||||||
assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none());
|
assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(handle_slash_command("/status", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(
|
||||||
|
handle_slash_command("/model claude", &session, CompactionConfig::default()).is_none()
|
||||||
|
);
|
||||||
|
assert!(handle_slash_command(
|
||||||
|
"/permissions read-only",
|
||||||
|
&session,
|
||||||
|
CompactionConfig::default()
|
||||||
|
)
|
||||||
|
.is_none());
|
||||||
|
assert!(handle_slash_command("/clear", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(
|
||||||
|
handle_slash_command("/clear --confirm", &session, CompactionConfig::default())
|
||||||
|
.is_none()
|
||||||
|
);
|
||||||
|
assert!(handle_slash_command("/cost", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(handle_slash_command(
|
||||||
|
"/resume session.json",
|
||||||
|
&session,
|
||||||
|
CompactionConfig::default()
|
||||||
|
)
|
||||||
|
.is_none());
|
||||||
|
assert!(handle_slash_command("/config", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(
|
||||||
|
handle_slash_command("/config env", &session, CompactionConfig::default()).is_none()
|
||||||
|
);
|
||||||
|
assert!(handle_slash_command("/diff", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(handle_slash_command("/version", &session, CompactionConfig::default()).is_none());
|
||||||
|
assert!(
|
||||||
|
handle_slash_command("/export note.txt", &session, CompactionConfig::default())
|
||||||
|
.is_none()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
handle_slash_command("/session list", &session, CompactionConfig::default()).is_none()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ impl UpstreamPaths {
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.canonicalize()
|
.canonicalize()
|
||||||
.unwrap_or_else(|_| workspace_dir.as_ref().to_path_buf());
|
.unwrap_or_else(|_| workspace_dir.as_ref().to_path_buf());
|
||||||
let repo_root = workspace_dir
|
let primary_repo_root = workspace_dir
|
||||||
.parent()
|
.parent()
|
||||||
.map_or_else(|| PathBuf::from(".."), Path::to_path_buf);
|
.map_or_else(|| PathBuf::from(".."), Path::to_path_buf);
|
||||||
|
let repo_root = resolve_upstream_repo_root(&primary_repo_root);
|
||||||
Self { repo_root }
|
Self { repo_root }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +54,42 @@ pub struct ExtractedManifest {
|
|||||||
pub bootstrap: BootstrapPlan,
|
pub bootstrap: BootstrapPlan,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resolve_upstream_repo_root(primary_repo_root: &Path) -> PathBuf {
|
||||||
|
let candidates = upstream_repo_candidates(primary_repo_root);
|
||||||
|
candidates
|
||||||
|
.into_iter()
|
||||||
|
.find(|candidate| candidate.join("src/commands.ts").is_file())
|
||||||
|
.unwrap_or_else(|| primary_repo_root.to_path_buf())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upstream_repo_candidates(primary_repo_root: &Path) -> Vec<PathBuf> {
|
||||||
|
let mut candidates = vec![primary_repo_root.to_path_buf()];
|
||||||
|
|
||||||
|
if let Some(explicit) = std::env::var_os("CLAUDE_CODE_UPSTREAM") {
|
||||||
|
candidates.push(PathBuf::from(explicit));
|
||||||
|
}
|
||||||
|
|
||||||
|
for ancestor in primary_repo_root.ancestors().take(4) {
|
||||||
|
candidates.push(ancestor.join("claude-code"));
|
||||||
|
candidates.push(ancestor.join("clawd-code"));
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates.push(
|
||||||
|
primary_repo_root
|
||||||
|
.join("reference-source")
|
||||||
|
.join("claude-code"),
|
||||||
|
);
|
||||||
|
candidates.push(primary_repo_root.join("vendor").join("claude-code"));
|
||||||
|
|
||||||
|
let mut deduped = Vec::new();
|
||||||
|
for candidate in candidates {
|
||||||
|
if !deduped.iter().any(|seen: &PathBuf| seen == &candidate) {
|
||||||
|
deduped.push(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deduped
|
||||||
|
}
|
||||||
|
|
||||||
pub fn extract_manifest(paths: &UpstreamPaths) -> std::io::Result<ExtractedManifest> {
|
pub fn extract_manifest(paths: &UpstreamPaths) -> std::io::Result<ExtractedManifest> {
|
||||||
let commands_source = fs::read_to_string(paths.commands_path())?;
|
let commands_source = fs::read_to_string(paths.commands_path())?;
|
||||||
let tools_source = fs::read_to_string(paths.tools_path())?;
|
let tools_source = fs::read_to_string(paths.tools_path())?;
|
||||||
@@ -270,9 +307,19 @@ mod tests {
|
|||||||
UpstreamPaths::from_workspace_dir(workspace_dir)
|
UpstreamPaths::from_workspace_dir(workspace_dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn has_upstream_fixture(paths: &UpstreamPaths) -> bool {
|
||||||
|
paths.commands_path().is_file()
|
||||||
|
&& paths.tools_path().is_file()
|
||||||
|
&& paths.cli_path().is_file()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extracts_non_empty_manifests_from_upstream_repo() {
|
fn extracts_non_empty_manifests_from_upstream_repo() {
|
||||||
let manifest = extract_manifest(&fixture_paths()).expect("manifest should load");
|
let paths = fixture_paths();
|
||||||
|
if !has_upstream_fixture(&paths) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let manifest = extract_manifest(&paths).expect("manifest should load");
|
||||||
assert!(!manifest.commands.entries().is_empty());
|
assert!(!manifest.commands.entries().is_empty());
|
||||||
assert!(!manifest.tools.entries().is_empty());
|
assert!(!manifest.tools.entries().is_empty());
|
||||||
assert!(!manifest.bootstrap.phases().is_empty());
|
assert!(!manifest.bootstrap.phases().is_empty());
|
||||||
@@ -280,9 +327,12 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn detects_known_upstream_command_symbols() {
|
fn detects_known_upstream_command_symbols() {
|
||||||
let commands = extract_commands(
|
let paths = fixture_paths();
|
||||||
&fs::read_to_string(fixture_paths().commands_path()).expect("commands.ts"),
|
if !paths.commands_path().is_file() {
|
||||||
);
|
return;
|
||||||
|
}
|
||||||
|
let commands =
|
||||||
|
extract_commands(&fs::read_to_string(paths.commands_path()).expect("commands.ts"));
|
||||||
let names: Vec<_> = commands
|
let names: Vec<_> = commands
|
||||||
.entries()
|
.entries()
|
||||||
.iter()
|
.iter()
|
||||||
@@ -295,8 +345,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn detects_known_upstream_tool_symbols() {
|
fn detects_known_upstream_tool_symbols() {
|
||||||
let tools =
|
let paths = fixture_paths();
|
||||||
extract_tools(&fs::read_to_string(fixture_paths().tools_path()).expect("tools.ts"));
|
if !paths.tools_path().is_file() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let tools = extract_tools(&fs::read_to_string(paths.tools_path()).expect("tools.ts"));
|
||||||
let names: Vec<_> = tools
|
let names: Vec<_> = tools
|
||||||
.entries()
|
.entries()
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ license.workspace = true
|
|||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
sha2 = "0.10"
|
||||||
glob = "0.3"
|
glob = "0.3"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
tokio = { version = "1", features = ["macros", "process", "rt", "rt-multi-thread", "time"] }
|
tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
|
||||||
walkdir = "2"
|
walkdir = "2"
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ impl Default for CompactionConfig {
|
|||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CompactionResult {
|
pub struct CompactionResult {
|
||||||
pub summary: String,
|
pub summary: String,
|
||||||
|
pub formatted_summary: String,
|
||||||
pub compacted_session: Session,
|
pub compacted_session: Session,
|
||||||
pub removed_message_count: usize,
|
pub removed_message_count: usize,
|
||||||
}
|
}
|
||||||
@@ -75,6 +76,7 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
|||||||
if !should_compact(session, config) {
|
if !should_compact(session, config) {
|
||||||
return CompactionResult {
|
return CompactionResult {
|
||||||
summary: String::new(),
|
summary: String::new(),
|
||||||
|
formatted_summary: String::new(),
|
||||||
compacted_session: session.clone(),
|
compacted_session: session.clone(),
|
||||||
removed_message_count: 0,
|
removed_message_count: 0,
|
||||||
};
|
};
|
||||||
@@ -87,6 +89,7 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
|||||||
let removed = &session.messages[..keep_from];
|
let removed = &session.messages[..keep_from];
|
||||||
let preserved = session.messages[keep_from..].to_vec();
|
let preserved = session.messages[keep_from..].to_vec();
|
||||||
let summary = summarize_messages(removed);
|
let summary = summarize_messages(removed);
|
||||||
|
let formatted_summary = format_compact_summary(&summary);
|
||||||
let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty());
|
let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty());
|
||||||
|
|
||||||
let mut compacted_messages = vec![ConversationMessage {
|
let mut compacted_messages = vec![ConversationMessage {
|
||||||
@@ -98,6 +101,7 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
|||||||
|
|
||||||
CompactionResult {
|
CompactionResult {
|
||||||
summary,
|
summary,
|
||||||
|
formatted_summary,
|
||||||
compacted_session: Session {
|
compacted_session: Session {
|
||||||
version: session.version,
|
version: session.version,
|
||||||
messages: compacted_messages,
|
messages: compacted_messages,
|
||||||
@@ -107,7 +111,73 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
||||||
let mut lines = vec!["<summary>".to_string(), "Conversation summary:".to_string()];
|
let user_messages = messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| message.role == MessageRole::User)
|
||||||
|
.count();
|
||||||
|
let assistant_messages = messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| message.role == MessageRole::Assistant)
|
||||||
|
.count();
|
||||||
|
let tool_messages = messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| message.role == MessageRole::Tool)
|
||||||
|
.count();
|
||||||
|
|
||||||
|
let mut tool_names = messages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|message| message.blocks.iter())
|
||||||
|
.filter_map(|block| match block {
|
||||||
|
ContentBlock::ToolUse { name, .. } => Some(name.as_str()),
|
||||||
|
ContentBlock::ToolResult { tool_name, .. } => Some(tool_name.as_str()),
|
||||||
|
ContentBlock::Text { .. } => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
tool_names.sort_unstable();
|
||||||
|
tool_names.dedup();
|
||||||
|
|
||||||
|
let mut lines = vec![
|
||||||
|
"<summary>".to_string(),
|
||||||
|
"Conversation summary:".to_string(),
|
||||||
|
format!(
|
||||||
|
"- Scope: {} earlier messages compacted (user={}, assistant={}, tool={}).",
|
||||||
|
messages.len(),
|
||||||
|
user_messages,
|
||||||
|
assistant_messages,
|
||||||
|
tool_messages
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
if !tool_names.is_empty() {
|
||||||
|
lines.push(format!("- Tools mentioned: {}.", tool_names.join(", ")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let recent_user_requests = collect_recent_role_summaries(messages, MessageRole::User, 3);
|
||||||
|
if !recent_user_requests.is_empty() {
|
||||||
|
lines.push("- Recent user requests:".to_string());
|
||||||
|
lines.extend(
|
||||||
|
recent_user_requests
|
||||||
|
.into_iter()
|
||||||
|
.map(|request| format!(" - {request}")),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let pending_work = infer_pending_work(messages);
|
||||||
|
if !pending_work.is_empty() {
|
||||||
|
lines.push("- Pending work:".to_string());
|
||||||
|
lines.extend(pending_work.into_iter().map(|item| format!(" - {item}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_files = collect_key_files(messages);
|
||||||
|
if !key_files.is_empty() {
|
||||||
|
lines.push(format!("- Key files referenced: {}.", key_files.join(", ")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(current_work) = infer_current_work(messages) {
|
||||||
|
lines.push(format!("- Current work: {current_work}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push("- Key timeline:".to_string());
|
||||||
for message in messages {
|
for message in messages {
|
||||||
let role = match message.role {
|
let role = match message.role {
|
||||||
MessageRole::System => "system",
|
MessageRole::System => "system",
|
||||||
@@ -121,7 +191,7 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
|||||||
.map(summarize_block)
|
.map(summarize_block)
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(" | ");
|
.join(" | ");
|
||||||
lines.push(format!("- {role}: {content}"));
|
lines.push(format!(" - {role}: {content}"));
|
||||||
}
|
}
|
||||||
lines.push("</summary>".to_string());
|
lines.push("</summary>".to_string());
|
||||||
lines.join("\n")
|
lines.join("\n")
|
||||||
@@ -144,6 +214,106 @@ fn summarize_block(block: &ContentBlock) -> String {
|
|||||||
truncate_summary(&raw, 160)
|
truncate_summary(&raw, 160)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn collect_recent_role_summaries(
|
||||||
|
messages: &[ConversationMessage],
|
||||||
|
role: MessageRole,
|
||||||
|
limit: usize,
|
||||||
|
) -> Vec<String> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| message.role == role)
|
||||||
|
.rev()
|
||||||
|
.filter_map(|message| first_text_block(message))
|
||||||
|
.take(limit)
|
||||||
|
.map(|text| truncate_summary(text, 160))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_pending_work(messages: &[ConversationMessage]) -> Vec<String> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.filter_map(first_text_block)
|
||||||
|
.filter(|text| {
|
||||||
|
let lowered = text.to_ascii_lowercase();
|
||||||
|
lowered.contains("todo")
|
||||||
|
|| lowered.contains("next")
|
||||||
|
|| lowered.contains("pending")
|
||||||
|
|| lowered.contains("follow up")
|
||||||
|
|| lowered.contains("remaining")
|
||||||
|
})
|
||||||
|
.take(3)
|
||||||
|
.map(|text| truncate_summary(text, 160))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_key_files(messages: &[ConversationMessage]) -> Vec<String> {
|
||||||
|
let mut files = messages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|message| message.blocks.iter())
|
||||||
|
.map(|block| match block {
|
||||||
|
ContentBlock::Text { text } => text.as_str(),
|
||||||
|
ContentBlock::ToolUse { input, .. } => input.as_str(),
|
||||||
|
ContentBlock::ToolResult { output, .. } => output.as_str(),
|
||||||
|
})
|
||||||
|
.flat_map(extract_file_candidates)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
files.sort();
|
||||||
|
files.dedup();
|
||||||
|
files.into_iter().take(8).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_current_work(messages: &[ConversationMessage]) -> Option<String> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.filter_map(first_text_block)
|
||||||
|
.find(|text| !text.trim().is_empty())
|
||||||
|
.map(|text| truncate_summary(text, 200))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn first_text_block(message: &ConversationMessage) -> Option<&str> {
|
||||||
|
message.blocks.iter().find_map(|block| match block {
|
||||||
|
ContentBlock::Text { text } if !text.trim().is_empty() => Some(text.as_str()),
|
||||||
|
ContentBlock::ToolUse { .. }
|
||||||
|
| ContentBlock::ToolResult { .. }
|
||||||
|
| ContentBlock::Text { .. } => None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_interesting_extension(candidate: &str) -> bool {
|
||||||
|
std::path::Path::new(candidate)
|
||||||
|
.extension()
|
||||||
|
.and_then(|extension| extension.to_str())
|
||||||
|
.is_some_and(|extension| {
|
||||||
|
["rs", "ts", "tsx", "js", "json", "md"]
|
||||||
|
.iter()
|
||||||
|
.any(|expected| extension.eq_ignore_ascii_case(expected))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_file_candidates(content: &str) -> Vec<String> {
|
||||||
|
content
|
||||||
|
.split_whitespace()
|
||||||
|
.filter_map(|token| {
|
||||||
|
let candidate = token.trim_matches(|char: char| {
|
||||||
|
matches!(char, ',' | '.' | ':' | ';' | ')' | '(' | '"' | '\'' | '`')
|
||||||
|
});
|
||||||
|
if candidate.contains('/') && has_interesting_extension(candidate) {
|
||||||
|
Some(candidate.to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn truncate_summary(content: &str, max_chars: usize) -> String {
|
fn truncate_summary(content: &str, max_chars: usize) -> String {
|
||||||
if content.chars().count() <= max_chars {
|
if content.chars().count() <= max_chars {
|
||||||
return content.to_string();
|
return content.to_string();
|
||||||
@@ -207,8 +377,8 @@ fn collapse_blank_lines(content: &str) -> String {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
compact_session, estimate_session_tokens, format_compact_summary, should_compact,
|
collect_key_files, compact_session, estimate_session_tokens, format_compact_summary,
|
||||||
CompactionConfig,
|
infer_pending_work, should_compact, CompactionConfig,
|
||||||
};
|
};
|
||||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||||
|
|
||||||
@@ -229,6 +399,7 @@ mod tests {
|
|||||||
assert_eq!(result.removed_message_count, 0);
|
assert_eq!(result.removed_message_count, 0);
|
||||||
assert_eq!(result.compacted_session, session);
|
assert_eq!(result.compacted_session, session);
|
||||||
assert!(result.summary.is_empty());
|
assert!(result.summary.is_empty());
|
||||||
|
assert!(result.formatted_summary.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -268,6 +439,8 @@ mod tests {
|
|||||||
&result.compacted_session.messages[0].blocks[0],
|
&result.compacted_session.messages[0].blocks[0],
|
||||||
ContentBlock::Text { text } if text.contains("Summary:")
|
ContentBlock::Text { text } if text.contains("Summary:")
|
||||||
));
|
));
|
||||||
|
assert!(result.formatted_summary.contains("Scope:"));
|
||||||
|
assert!(result.formatted_summary.contains("Key timeline:"));
|
||||||
assert!(should_compact(
|
assert!(should_compact(
|
||||||
&session,
|
&session,
|
||||||
CompactionConfig {
|
CompactionConfig {
|
||||||
@@ -288,4 +461,25 @@ mod tests {
|
|||||||
assert!(summary.ends_with('…'));
|
assert!(summary.ends_with('…'));
|
||||||
assert!(summary.chars().count() <= 161);
|
assert!(summary.chars().count() <= 161);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extracts_key_files_from_message_content() {
|
||||||
|
let files = collect_key_files(&[ConversationMessage::user_text(
|
||||||
|
"Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.",
|
||||||
|
)]);
|
||||||
|
assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string()));
|
||||||
|
assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infers_pending_work_from_recent_messages() {
|
||||||
|
let pending = infer_pending_work(&[
|
||||||
|
ConversationMessage::user_text("done"),
|
||||||
|
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||||
|
text: "Next: update tests and follow up on remaining CLI polish.".to_string(),
|
||||||
|
}]),
|
||||||
|
]);
|
||||||
|
assert_eq!(pending.len(), 1);
|
||||||
|
assert!(pending[0].contains("Next: update tests"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,13 @@ pub enum ConfigSource {
|
|||||||
Local,
|
Local,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ResolvedPermissionMode {
|
||||||
|
ReadOnly,
|
||||||
|
WorkspaceWrite,
|
||||||
|
DangerFullAccess,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct ConfigEntry {
|
pub struct ConfigEntry {
|
||||||
pub source: ConfigSource,
|
pub source: ConfigSource,
|
||||||
@@ -31,6 +38,8 @@ pub struct RuntimeConfig {
|
|||||||
pub struct RuntimeFeatureConfig {
|
pub struct RuntimeFeatureConfig {
|
||||||
mcp: McpConfigCollection,
|
mcp: McpConfigCollection,
|
||||||
oauth: Option<OAuthConfig>,
|
oauth: Option<OAuthConfig>,
|
||||||
|
model: Option<String>,
|
||||||
|
permission_mode: Option<ResolvedPermissionMode>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
@@ -165,11 +174,23 @@ impl ConfigLoader {
|
|||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn discover(&self) -> Vec<ConfigEntry> {
|
pub fn discover(&self) -> Vec<ConfigEntry> {
|
||||||
|
let user_legacy_path = self.config_home.parent().map_or_else(
|
||||||
|
|| PathBuf::from(".claude.json"),
|
||||||
|
|parent| parent.join(".claude.json"),
|
||||||
|
);
|
||||||
vec![
|
vec![
|
||||||
|
ConfigEntry {
|
||||||
|
source: ConfigSource::User,
|
||||||
|
path: user_legacy_path,
|
||||||
|
},
|
||||||
ConfigEntry {
|
ConfigEntry {
|
||||||
source: ConfigSource::User,
|
source: ConfigSource::User,
|
||||||
path: self.config_home.join("settings.json"),
|
path: self.config_home.join("settings.json"),
|
||||||
},
|
},
|
||||||
|
ConfigEntry {
|
||||||
|
source: ConfigSource::Project,
|
||||||
|
path: self.cwd.join(".claude.json"),
|
||||||
|
},
|
||||||
ConfigEntry {
|
ConfigEntry {
|
||||||
source: ConfigSource::Project,
|
source: ConfigSource::Project,
|
||||||
path: self.cwd.join(".claude").join("settings.json"),
|
path: self.cwd.join(".claude").join("settings.json"),
|
||||||
@@ -195,14 +216,15 @@ impl ConfigLoader {
|
|||||||
loaded_entries.push(entry);
|
loaded_entries.push(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let merged_value = JsonValue::Object(merged.clone());
|
||||||
|
|
||||||
let feature_config = RuntimeFeatureConfig {
|
let feature_config = RuntimeFeatureConfig {
|
||||||
mcp: McpConfigCollection {
|
mcp: McpConfigCollection {
|
||||||
servers: mcp_servers,
|
servers: mcp_servers,
|
||||||
},
|
},
|
||||||
oauth: parse_optional_oauth_config(
|
oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
|
||||||
&JsonValue::Object(merged.clone()),
|
model: parse_optional_model(&merged_value),
|
||||||
"merged settings.oauth",
|
permission_mode: parse_optional_permission_mode(&merged_value)?,
|
||||||
)?,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(RuntimeConfig {
|
Ok(RuntimeConfig {
|
||||||
@@ -257,6 +279,16 @@ impl RuntimeConfig {
|
|||||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
self.feature_config.oauth.as_ref()
|
self.feature_config.oauth.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn model(&self) -> Option<&str> {
|
||||||
|
self.feature_config.model.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||||
|
self.feature_config.permission_mode
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RuntimeFeatureConfig {
|
impl RuntimeFeatureConfig {
|
||||||
@@ -269,6 +301,16 @@ impl RuntimeFeatureConfig {
|
|||||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
self.oauth.as_ref()
|
self.oauth.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn model(&self) -> Option<&str> {
|
||||||
|
self.model.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||||
|
self.permission_mode
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl McpConfigCollection {
|
impl McpConfigCollection {
|
||||||
@@ -307,6 +349,7 @@ impl McpServerConfig {
|
|||||||
fn read_optional_json_object(
|
fn read_optional_json_object(
|
||||||
path: &Path,
|
path: &Path,
|
||||||
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
||||||
|
let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json");
|
||||||
let contents = match fs::read_to_string(path) {
|
let contents = match fs::read_to_string(path) {
|
||||||
Ok(contents) => contents,
|
Ok(contents) => contents,
|
||||||
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||||
@@ -317,14 +360,20 @@ fn read_optional_json_object(
|
|||||||
return Ok(Some(BTreeMap::new()));
|
return Ok(Some(BTreeMap::new()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let parsed = JsonValue::parse(&contents)
|
let parsed = match JsonValue::parse(&contents) {
|
||||||
.map_err(|error| ConfigError::Parse(format!("{}: {error}", path.display())))?;
|
Ok(parsed) => parsed,
|
||||||
let object = parsed.as_object().ok_or_else(|| {
|
Err(error) if is_legacy_config => return Ok(None),
|
||||||
ConfigError::Parse(format!(
|
Err(error) => return Err(ConfigError::Parse(format!("{}: {error}", path.display()))),
|
||||||
|
};
|
||||||
|
let Some(object) = parsed.as_object() else {
|
||||||
|
if is_legacy_config {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
return Err(ConfigError::Parse(format!(
|
||||||
"{}: top-level settings value must be a JSON object",
|
"{}: top-level settings value must be a JSON object",
|
||||||
path.display()
|
path.display()
|
||||||
))
|
)));
|
||||||
})?;
|
};
|
||||||
Ok(Some(object.clone()))
|
Ok(Some(object.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,6 +404,47 @@ fn merge_mcp_servers(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_optional_model(root: &JsonValue) -> Option<String> {
|
||||||
|
root.as_object()
|
||||||
|
.and_then(|object| object.get("model"))
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_optional_permission_mode(
|
||||||
|
root: &JsonValue,
|
||||||
|
) -> Result<Option<ResolvedPermissionMode>, ConfigError> {
|
||||||
|
let Some(object) = root.as_object() else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
if let Some(mode) = object.get("permissionMode").and_then(JsonValue::as_str) {
|
||||||
|
return parse_permission_mode_label(mode, "merged settings.permissionMode").map(Some);
|
||||||
|
}
|
||||||
|
let Some(mode) = object
|
||||||
|
.get("permissions")
|
||||||
|
.and_then(JsonValue::as_object)
|
||||||
|
.and_then(|permissions| permissions.get("defaultMode"))
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
parse_permission_mode_label(mode, "merged settings.permissions.defaultMode").map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_permission_mode_label(
|
||||||
|
mode: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<ResolvedPermissionMode, ConfigError> {
|
||||||
|
match mode {
|
||||||
|
"default" | "plan" | "read-only" => Ok(ResolvedPermissionMode::ReadOnly),
|
||||||
|
"acceptEdits" | "auto" | "workspace-write" => Ok(ResolvedPermissionMode::WorkspaceWrite),
|
||||||
|
"dontAsk" | "danger-full-access" => Ok(ResolvedPermissionMode::DangerFullAccess),
|
||||||
|
other => Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: unsupported permission mode {other}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_optional_oauth_config(
|
fn parse_optional_oauth_config(
|
||||||
root: &JsonValue,
|
root: &JsonValue,
|
||||||
context: &str,
|
context: &str,
|
||||||
@@ -594,7 +684,8 @@ fn deep_merge_objects(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode,
|
||||||
|
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
use crate::json::JsonValue;
|
use crate::json::JsonValue;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
@@ -635,14 +726,24 @@ mod tests {
|
|||||||
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||||
fs::create_dir_all(&home).expect("home config dir");
|
fs::create_dir_all(&home).expect("home config dir");
|
||||||
|
|
||||||
|
fs::write(
|
||||||
|
home.parent().expect("home parent").join(".claude.json"),
|
||||||
|
r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#,
|
||||||
|
)
|
||||||
|
.expect("write user compat config");
|
||||||
fs::write(
|
fs::write(
|
||||||
home.join("settings.json"),
|
home.join("settings.json"),
|
||||||
r#"{"model":"sonnet","env":{"A":"1"},"hooks":{"PreToolUse":["base"]}}"#,
|
r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#,
|
||||||
)
|
)
|
||||||
.expect("write user settings");
|
.expect("write user settings");
|
||||||
|
fs::write(
|
||||||
|
cwd.join(".claude.json"),
|
||||||
|
r#"{"model":"project-compat","env":{"B":"2"}}"#,
|
||||||
|
)
|
||||||
|
.expect("write project compat config");
|
||||||
fs::write(
|
fs::write(
|
||||||
cwd.join(".claude").join("settings.json"),
|
cwd.join(".claude").join("settings.json"),
|
||||||
r#"{"env":{"B":"2"},"hooks":{"PostToolUse":["project"]}}"#,
|
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#,
|
||||||
)
|
)
|
||||||
.expect("write project settings");
|
.expect("write project settings");
|
||||||
fs::write(
|
fs::write(
|
||||||
@@ -656,25 +757,37 @@ mod tests {
|
|||||||
.expect("config should load");
|
.expect("config should load");
|
||||||
|
|
||||||
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
||||||
assert_eq!(loaded.loaded_entries().len(), 3);
|
assert_eq!(loaded.loaded_entries().len(), 5);
|
||||||
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
loaded.get("model"),
|
loaded.get("model"),
|
||||||
Some(&JsonValue::String("opus".to_string()))
|
Some(&JsonValue::String("opus".to_string()))
|
||||||
);
|
);
|
||||||
|
assert_eq!(loaded.model(), Some("opus"));
|
||||||
|
assert_eq!(
|
||||||
|
loaded.permission_mode(),
|
||||||
|
Some(ResolvedPermissionMode::WorkspaceWrite)
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
loaded
|
loaded
|
||||||
.get("env")
|
.get("env")
|
||||||
.and_then(JsonValue::as_object)
|
.and_then(JsonValue::as_object)
|
||||||
.expect("env object")
|
.expect("env object")
|
||||||
.len(),
|
.len(),
|
||||||
2
|
4
|
||||||
);
|
);
|
||||||
assert!(loaded
|
assert!(loaded
|
||||||
.get("hooks")
|
.get("hooks")
|
||||||
.and_then(JsonValue::as_object)
|
.and_then(JsonValue::as_object)
|
||||||
.expect("hooks object")
|
.expect("hooks object")
|
||||||
.contains_key("PreToolUse"));
|
.contains_key("PreToolUse"));
|
||||||
|
assert!(loaded
|
||||||
|
.get("hooks")
|
||||||
|
.and_then(JsonValue::as_object)
|
||||||
|
.expect("hooks object")
|
||||||
|
.contains_key("PostToolUse"));
|
||||||
|
assert!(loaded.mcp().get("home").is_some());
|
||||||
|
assert!(loaded.mcp().get("project").is_some());
|
||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ mod tests {
|
|||||||
.sum::<i32>();
|
.sum::<i32>();
|
||||||
Ok(total.to_string())
|
Ok(total.to_string())
|
||||||
});
|
});
|
||||||
let permission_policy = PermissionPolicy::new(PermissionMode::Prompt);
|
let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
|
||||||
let system_prompt = SystemPromptBuilder::new()
|
let system_prompt = SystemPromptBuilder::new()
|
||||||
.with_project_context(ProjectContext {
|
.with_project_context(ProjectContext {
|
||||||
cwd: PathBuf::from("/tmp/project"),
|
cwd: PathBuf::from("/tmp/project"),
|
||||||
@@ -487,7 +487,7 @@ mod tests {
|
|||||||
Session::new(),
|
Session::new(),
|
||||||
SingleCallApiClient,
|
SingleCallApiClient,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Prompt),
|
PermissionPolicy::new(PermissionMode::WorkspaceWrite),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -536,7 +536,7 @@ mod tests {
|
|||||||
session,
|
session,
|
||||||
SimpleApi,
|
SimpleApi,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Allow),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -563,7 +563,7 @@ mod tests {
|
|||||||
Session::new(),
|
Session::new(),
|
||||||
SimpleApi,
|
SimpleApi,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Allow),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
runtime.run_turn("a", None).expect("turn a");
|
runtime.run_turn("a", None).expect("turn a");
|
||||||
|
|||||||
@@ -296,12 +296,12 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Ok(file_content) = fs::read_to_string(&file_path) else {
|
let Ok(file_contents) = fs::read_to_string(&file_path) else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
if output_mode == "count" {
|
if output_mode == "count" {
|
||||||
let count = regex.find_iter(&file_content).count();
|
let count = regex.find_iter(&file_contents).count();
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
filenames.push(file_path.to_string_lossy().into_owned());
|
filenames.push(file_path.to_string_lossy().into_owned());
|
||||||
total_matches += count;
|
total_matches += count;
|
||||||
@@ -309,7 +309,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let lines: Vec<&str> = file_content.lines().collect();
|
let lines: Vec<&str> = file_contents.lines().collect();
|
||||||
let mut matched_lines = Vec::new();
|
let mut matched_lines = Vec::new();
|
||||||
for (index, line) in lines.iter().enumerate() {
|
for (index, line) in lines.iter().enumerate() {
|
||||||
if regex.is_match(line) {
|
if regex.is_match(line) {
|
||||||
@@ -327,13 +327,13 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
for index in matched_lines {
|
for index in matched_lines {
|
||||||
let start = index.saturating_sub(input.before.unwrap_or(context));
|
let start = index.saturating_sub(input.before.unwrap_or(context));
|
||||||
let end = (index + input.after.unwrap_or(context) + 1).min(lines.len());
|
let end = (index + input.after.unwrap_or(context) + 1).min(lines.len());
|
||||||
for (current, line_content) in lines.iter().enumerate().take(end).skip(start) {
|
for (current, line) in lines.iter().enumerate().take(end).skip(start) {
|
||||||
let prefix = if input.line_numbers.unwrap_or(true) {
|
let prefix = if input.line_numbers.unwrap_or(true) {
|
||||||
format!("{}:{}:", file_path.to_string_lossy(), current + 1)
|
format!("{}:{}:", file_path.to_string_lossy(), current + 1)
|
||||||
} else {
|
} else {
|
||||||
format!("{}:", file_path.to_string_lossy())
|
format!("{}:", file_path.to_string_lossy())
|
||||||
};
|
};
|
||||||
content_lines.push(format!("{prefix}{line_content}"));
|
content_lines.push(format!("{prefix}{line}"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -341,7 +341,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
|
|
||||||
let (filenames, applied_limit, applied_offset) =
|
let (filenames, applied_limit, applied_offset) =
|
||||||
apply_limit(filenames, input.head_limit, input.offset);
|
apply_limit(filenames, input.head_limit, input.offset);
|
||||||
let rendered_content = if output_mode == "content" {
|
let content_output = if output_mode == "content" {
|
||||||
let (lines, limit, offset) = apply_limit(content_lines, input.head_limit, input.offset);
|
let (lines, limit, offset) = apply_limit(content_lines, input.head_limit, input.offset);
|
||||||
return Ok(GrepSearchOutput {
|
return Ok(GrepSearchOutput {
|
||||||
mode: Some(output_mode),
|
mode: Some(output_mode),
|
||||||
@@ -361,7 +361,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
|||||||
mode: Some(output_mode.clone()),
|
mode: Some(output_mode.clone()),
|
||||||
num_files: filenames.len(),
|
num_files: filenames.len(),
|
||||||
filenames,
|
filenames,
|
||||||
content: rendered_content,
|
content: content_output,
|
||||||
num_lines: None,
|
num_lines: None,
|
||||||
num_matches: (output_mode == "count").then_some(total_matches),
|
num_matches: (output_mode == "count").then_some(total_matches),
|
||||||
applied_limit,
|
applied_limit,
|
||||||
|
|||||||
@@ -5,8 +5,13 @@ mod config;
|
|||||||
mod conversation;
|
mod conversation;
|
||||||
mod file_ops;
|
mod file_ops;
|
||||||
mod json;
|
mod json;
|
||||||
|
mod mcp;
|
||||||
|
mod mcp_client;
|
||||||
|
mod mcp_stdio;
|
||||||
|
mod oauth;
|
||||||
mod permissions;
|
mod permissions;
|
||||||
mod prompt;
|
mod prompt;
|
||||||
|
mod remote;
|
||||||
mod session;
|
mod session;
|
||||||
mod usage;
|
mod usage;
|
||||||
|
|
||||||
@@ -20,7 +25,8 @@ pub use config::{
|
|||||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
||||||
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||||
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||||
RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig,
|
||||||
|
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
pub use conversation::{
|
pub use conversation::{
|
||||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||||
@@ -31,6 +37,29 @@ pub use file_ops::{
|
|||||||
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
||||||
WriteFileOutput,
|
WriteFileOutput,
|
||||||
};
|
};
|
||||||
|
pub use mcp::{
|
||||||
|
mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp,
|
||||||
|
scoped_mcp_config_hash, unwrap_ccr_proxy_url,
|
||||||
|
};
|
||||||
|
pub use mcp_client::{
|
||||||
|
McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport,
|
||||||
|
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
|
||||||
|
};
|
||||||
|
pub use mcp_stdio::{
|
||||||
|
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||||
|
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
|
||||||
|
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
|
||||||
|
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
|
||||||
|
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
|
||||||
|
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
|
||||||
|
};
|
||||||
|
pub use oauth::{
|
||||||
|
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||||
|
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||||
|
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||||
|
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
|
PkceChallengeMethod, PkceCodePair,
|
||||||
|
};
|
||||||
pub use permissions::{
|
pub use permissions::{
|
||||||
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
||||||
PermissionPrompter, PermissionRequest,
|
PermissionPrompter, PermissionRequest,
|
||||||
@@ -39,5 +68,20 @@ pub use prompt::{
|
|||||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||||
};
|
};
|
||||||
|
pub use remote::{
|
||||||
|
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
|
||||||
|
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
|
||||||
|
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
|
||||||
|
};
|
||||||
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
||||||
pub use usage::{TokenUsage, UsageTracker};
|
pub use usage::{
|
||||||
|
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
|
||||||
|
LOCK.get_or_init(|| std::sync::Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||||
|
}
|
||||||
|
|||||||
300
rust/crates/runtime/src/mcp.rs
Normal file
300
rust/crates/runtime/src/mcp.rs
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
use crate::config::{McpServerConfig, ScopedMcpServerConfig};
|
||||||
|
|
||||||
|
const CLAUDEAI_SERVER_PREFIX: &str = "claude.ai ";
|
||||||
|
const CCR_PROXY_PATH_MARKERS: [&str; 2] = ["/v2/session_ingress/shttp/mcp/", "/v2/ccr-sessions/"];
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn normalize_name_for_mcp(name: &str) -> String {
|
||||||
|
let mut normalized = name
|
||||||
|
.chars()
|
||||||
|
.map(|ch| match ch {
|
||||||
|
'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' => ch,
|
||||||
|
_ => '_',
|
||||||
|
})
|
||||||
|
.collect::<String>();
|
||||||
|
|
||||||
|
if name.starts_with(CLAUDEAI_SERVER_PREFIX) {
|
||||||
|
normalized = collapse_underscores(&normalized)
|
||||||
|
.trim_matches('_')
|
||||||
|
.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_tool_prefix(server_name: &str) -> String {
|
||||||
|
format!("mcp__{}__", normalize_name_for_mcp(server_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_tool_name(server_name: &str, tool_name: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{}{}",
|
||||||
|
mcp_tool_prefix(server_name),
|
||||||
|
normalize_name_for_mcp(tool_name)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn unwrap_ccr_proxy_url(url: &str) -> String {
|
||||||
|
if !CCR_PROXY_PATH_MARKERS
|
||||||
|
.iter()
|
||||||
|
.any(|marker| url.contains(marker))
|
||||||
|
{
|
||||||
|
return url.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(query_start) = url.find('?') else {
|
||||||
|
return url.to_string();
|
||||||
|
};
|
||||||
|
let query = &url[query_start + 1..];
|
||||||
|
for pair in query.split('&') {
|
||||||
|
let mut parts = pair.splitn(2, '=');
|
||||||
|
if matches!(parts.next(), Some("mcp_url")) {
|
||||||
|
if let Some(value) = parts.next() {
|
||||||
|
return percent_decode(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
url.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn mcp_server_signature(config: &McpServerConfig) -> Option<String> {
|
||||||
|
match config {
|
||||||
|
McpServerConfig::Stdio(config) => {
|
||||||
|
let mut command = vec![config.command.clone()];
|
||||||
|
command.extend(config.args.clone());
|
||||||
|
Some(format!("stdio:{}", render_command_signature(&command)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Sse(config) | McpServerConfig::Http(config) => {
|
||||||
|
Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Ws(config) => Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))),
|
||||||
|
McpServerConfig::ClaudeAiProxy(config) => {
|
||||||
|
Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url)))
|
||||||
|
}
|
||||||
|
McpServerConfig::Sdk(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String {
|
||||||
|
let rendered = match &config.config {
|
||||||
|
McpServerConfig::Stdio(stdio) => format!(
|
||||||
|
"stdio|{}|{}|{}",
|
||||||
|
stdio.command,
|
||||||
|
render_command_signature(&stdio.args),
|
||||||
|
render_env_signature(&stdio.env)
|
||||||
|
),
|
||||||
|
McpServerConfig::Sse(remote) => format!(
|
||||||
|
"sse|{}|{}|{}|{}",
|
||||||
|
remote.url,
|
||||||
|
render_env_signature(&remote.headers),
|
||||||
|
remote.headers_helper.as_deref().unwrap_or(""),
|
||||||
|
render_oauth_signature(remote.oauth.as_ref())
|
||||||
|
),
|
||||||
|
McpServerConfig::Http(remote) => format!(
|
||||||
|
"http|{}|{}|{}|{}",
|
||||||
|
remote.url,
|
||||||
|
render_env_signature(&remote.headers),
|
||||||
|
remote.headers_helper.as_deref().unwrap_or(""),
|
||||||
|
render_oauth_signature(remote.oauth.as_ref())
|
||||||
|
),
|
||||||
|
McpServerConfig::Ws(ws) => format!(
|
||||||
|
"ws|{}|{}|{}",
|
||||||
|
ws.url,
|
||||||
|
render_env_signature(&ws.headers),
|
||||||
|
ws.headers_helper.as_deref().unwrap_or("")
|
||||||
|
),
|
||||||
|
McpServerConfig::Sdk(sdk) => format!("sdk|{}", sdk.name),
|
||||||
|
McpServerConfig::ClaudeAiProxy(proxy) => {
|
||||||
|
format!("claudeai-proxy|{}|{}", proxy.url, proxy.id)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
stable_hex_hash(&rendered)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_command_signature(command: &[String]) -> String {
|
||||||
|
let escaped = command
|
||||||
|
.iter()
|
||||||
|
.map(|part| part.replace('\\', "\\\\").replace('|', "\\|"))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
format!("[{}]", escaped.join("|"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_env_signature(map: &std::collections::BTreeMap<String, String>) -> String {
|
||||||
|
map.iter()
|
||||||
|
.map(|(key, value)| format!("{key}={value}"))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(";")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_oauth_signature(oauth: Option<&crate::config::McpOAuthConfig>) -> String {
|
||||||
|
oauth.map_or_else(String::new, |oauth| {
|
||||||
|
format!(
|
||||||
|
"{}|{}|{}|{}",
|
||||||
|
oauth.client_id.as_deref().unwrap_or(""),
|
||||||
|
oauth
|
||||||
|
.callback_port
|
||||||
|
.map_or_else(String::new, |port| port.to_string()),
|
||||||
|
oauth.auth_server_metadata_url.as_deref().unwrap_or(""),
|
||||||
|
oauth.xaa.map_or_else(String::new, |flag| flag.to_string())
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stable_hex_hash(value: &str) -> String {
|
||||||
|
let mut hash = 0xcbf2_9ce4_8422_2325_u64;
|
||||||
|
for byte in value.as_bytes() {
|
||||||
|
hash ^= u64::from(*byte);
|
||||||
|
hash = hash.wrapping_mul(0x0100_0000_01b3);
|
||||||
|
}
|
||||||
|
format!("{hash:016x}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_underscores(value: &str) -> String {
|
||||||
|
let mut collapsed = String::with_capacity(value.len());
|
||||||
|
let mut last_was_underscore = false;
|
||||||
|
for ch in value.chars() {
|
||||||
|
if ch == '_' {
|
||||||
|
if !last_was_underscore {
|
||||||
|
collapsed.push(ch);
|
||||||
|
}
|
||||||
|
last_was_underscore = true;
|
||||||
|
} else {
|
||||||
|
collapsed.push(ch);
|
||||||
|
last_was_underscore = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
collapsed
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_decode(value: &str) -> String {
|
||||||
|
let bytes = value.as_bytes();
|
||||||
|
let mut decoded = Vec::with_capacity(bytes.len());
|
||||||
|
let mut index = 0;
|
||||||
|
while index < bytes.len() {
|
||||||
|
match bytes[index] {
|
||||||
|
b'%' if index + 2 < bytes.len() => {
|
||||||
|
let hex = &value[index + 1..index + 3];
|
||||||
|
if let Ok(byte) = u8::from_str_radix(hex, 16) {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 3;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
decoded.push(bytes[index]);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
b'+' => {
|
||||||
|
decoded.push(b' ');
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
byte => {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::from_utf8_lossy(&decoded).into_owned()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
ConfigSource, McpRemoteServerConfig, McpServerConfig, McpStdioServerConfig,
|
||||||
|
McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
mcp_server_signature, mcp_tool_name, normalize_name_for_mcp, scoped_mcp_config_hash,
|
||||||
|
unwrap_ccr_proxy_url,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalizes_server_names_for_mcp_tooling() {
|
||||||
|
assert_eq!(normalize_name_for_mcp("github.com"), "github_com");
|
||||||
|
assert_eq!(normalize_name_for_mcp("tool name!"), "tool_name_");
|
||||||
|
assert_eq!(
|
||||||
|
normalize_name_for_mcp("claude.ai Example Server!!"),
|
||||||
|
"claude_ai_Example_Server"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
mcp_tool_name("claude.ai Example Server", "weather tool"),
|
||||||
|
"mcp__claude_ai_Example_Server__weather_tool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unwraps_ccr_proxy_urls_for_signature_matching() {
|
||||||
|
let wrapped = "https://api.anthropic.com/v2/session_ingress/shttp/mcp/123?mcp_url=https%3A%2F%2Fvendor.example%2Fmcp&other=1";
|
||||||
|
assert_eq!(unwrap_ccr_proxy_url(wrapped), "https://vendor.example/mcp");
|
||||||
|
assert_eq!(
|
||||||
|
unwrap_ccr_proxy_url("https://vendor.example/mcp"),
|
||||||
|
"https://vendor.example/mcp"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn computes_signatures_for_stdio_and_remote_servers() {
|
||||||
|
let stdio = McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: "uvx".to_string(),
|
||||||
|
args: vec!["mcp-server".to_string()],
|
||||||
|
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
mcp_server_signature(&stdio),
|
||||||
|
Some("stdio:[uvx|mcp-server]".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let remote = McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: "https://api.anthropic.com/v2/ccr-sessions/1?mcp_url=wss%3A%2F%2Fvendor.example%2Fmcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
mcp_server_signature(&remote),
|
||||||
|
Some("url:wss://vendor.example/mcp".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scoped_hash_ignores_scope_but_tracks_config_content() {
|
||||||
|
let base_config = McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::from([("Authorization".to_string(), "Bearer token".to_string())]),
|
||||||
|
headers_helper: Some("helper.sh".to_string()),
|
||||||
|
oauth: None,
|
||||||
|
});
|
||||||
|
let user = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::User,
|
||||||
|
config: base_config.clone(),
|
||||||
|
};
|
||||||
|
let local = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: base_config,
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
scoped_mcp_config_hash(&user),
|
||||||
|
scoped_mcp_config_hash(&local)
|
||||||
|
);
|
||||||
|
|
||||||
|
let changed = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/v2/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
oauth: None,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
assert_ne!(
|
||||||
|
scoped_mcp_config_hash(&user),
|
||||||
|
scoped_mcp_config_hash(&changed)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
236
rust/crates/runtime/src/mcp_client.rs
Normal file
236
rust/crates/runtime/src/mcp_client.rs
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
|
||||||
|
use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum McpClientTransport {
|
||||||
|
Stdio(McpStdioTransport),
|
||||||
|
Sse(McpRemoteTransport),
|
||||||
|
Http(McpRemoteTransport),
|
||||||
|
WebSocket(McpRemoteTransport),
|
||||||
|
Sdk(McpSdkTransport),
|
||||||
|
ClaudeAiProxy(McpClaudeAiProxyTransport),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpStdioTransport {
|
||||||
|
pub command: String,
|
||||||
|
pub args: Vec<String>,
|
||||||
|
pub env: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpRemoteTransport {
|
||||||
|
pub url: String,
|
||||||
|
pub headers: BTreeMap<String, String>,
|
||||||
|
pub headers_helper: Option<String>,
|
||||||
|
pub auth: McpClientAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpSdkTransport {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpClaudeAiProxyTransport {
|
||||||
|
pub url: String,
|
||||||
|
pub id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum McpClientAuth {
|
||||||
|
None,
|
||||||
|
OAuth(McpOAuthConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct McpClientBootstrap {
|
||||||
|
pub server_name: String,
|
||||||
|
pub normalized_name: String,
|
||||||
|
pub tool_prefix: String,
|
||||||
|
pub signature: Option<String>,
|
||||||
|
pub transport: McpClientTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientBootstrap {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
normalized_name: normalize_name_for_mcp(server_name),
|
||||||
|
tool_prefix: mcp_tool_prefix(server_name),
|
||||||
|
signature: mcp_server_signature(&config.config),
|
||||||
|
transport: McpClientTransport::from_config(&config.config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientTransport {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(config: &McpServerConfig) -> Self {
|
||||||
|
match config {
|
||||||
|
McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
|
||||||
|
command: config.command.clone(),
|
||||||
|
args: config.args.clone(),
|
||||||
|
env: config.env.clone(),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::from_oauth(config.oauth.clone()),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::from_oauth(config.oauth.clone()),
|
||||||
|
}),
|
||||||
|
McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
headers: config.headers.clone(),
|
||||||
|
headers_helper: config.headers_helper.clone(),
|
||||||
|
auth: McpClientAuth::None,
|
||||||
|
}),
|
||||||
|
McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
|
||||||
|
name: config.name.clone(),
|
||||||
|
}),
|
||||||
|
McpServerConfig::ClaudeAiProxy(config) => {
|
||||||
|
Self::ClaudeAiProxy(McpClaudeAiProxyTransport {
|
||||||
|
url: config.url.clone(),
|
||||||
|
id: config.id.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientAuth {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
|
||||||
|
oauth.map_or(Self::None, Self::OAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub const fn requires_user_auth(&self) -> bool {
|
||||||
|
matches!(self, Self::OAuth(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use crate::config::{
|
||||||
|
ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
|
||||||
|
McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_stdio_servers_into_transport_targets() {
|
||||||
|
let config = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::User,
|
||||||
|
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: "uvx".to_string(),
|
||||||
|
args: vec!["mcp-server".to_string()],
|
||||||
|
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config);
|
||||||
|
assert_eq!(bootstrap.normalized_name, "stdio-server");
|
||||||
|
assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
|
||||||
|
assert_eq!(
|
||||||
|
bootstrap.signature.as_deref(),
|
||||||
|
Some("stdio:[uvx|mcp-server]")
|
||||||
|
);
|
||||||
|
match bootstrap.transport {
|
||||||
|
McpClientTransport::Stdio(transport) => {
|
||||||
|
assert_eq!(transport.command, "uvx");
|
||||||
|
assert_eq!(transport.args, vec!["mcp-server"]);
|
||||||
|
assert_eq!(
|
||||||
|
transport.env.get("TOKEN").map(String::as_str),
|
||||||
|
Some("secret")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected stdio transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_remote_servers_with_oauth_auth() {
|
||||||
|
let config = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Project,
|
||||||
|
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
|
||||||
|
headers_helper: Some("helper.sh".to_string()),
|
||||||
|
oauth: Some(McpOAuthConfig {
|
||||||
|
client_id: Some("client-id".to_string()),
|
||||||
|
callback_port: Some(7777),
|
||||||
|
auth_server_metadata_url: Some(
|
||||||
|
"https://issuer.example/.well-known/oauth-authorization-server".to_string(),
|
||||||
|
),
|
||||||
|
xaa: Some(true),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config);
|
||||||
|
assert_eq!(bootstrap.normalized_name, "remote_server");
|
||||||
|
match bootstrap.transport {
|
||||||
|
McpClientTransport::Http(transport) => {
|
||||||
|
assert_eq!(transport.url, "https://vendor.example/mcp");
|
||||||
|
assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh"));
|
||||||
|
assert!(transport.auth.requires_user_auth());
|
||||||
|
match transport.auth {
|
||||||
|
McpClientAuth::OAuth(oauth) => {
|
||||||
|
assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
|
||||||
|
}
|
||||||
|
other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => panic!("expected http transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstraps_websocket_and_sdk_transports_without_oauth() {
|
||||||
|
let ws = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: "wss://vendor.example/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let sdk = ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Sdk(McpSdkServerConfig {
|
||||||
|
name: "sdk-server".to_string(),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws);
|
||||||
|
match ws_bootstrap.transport {
|
||||||
|
McpClientTransport::WebSocket(transport) => {
|
||||||
|
assert_eq!(transport.url, "wss://vendor.example/mcp");
|
||||||
|
assert!(!transport.auth.requires_user_auth());
|
||||||
|
}
|
||||||
|
other => panic!("expected websocket transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk);
|
||||||
|
assert_eq!(sdk_bootstrap.signature, None);
|
||||||
|
match sdk_bootstrap.transport {
|
||||||
|
McpClientTransport::Sdk(transport) => {
|
||||||
|
assert_eq!(transport.name, "sdk-server");
|
||||||
|
}
|
||||||
|
other => panic!("expected sdk transport, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1697
rust/crates/runtime/src/mcp_stdio.rs
Normal file
1697
rust/crates/runtime/src/mcp_stdio.rs
Normal file
File diff suppressed because it is too large
Load Diff
589
rust/crates/runtime/src/oauth.rs
Normal file
589
rust/crates/runtime/src/oauth.rs
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs::{self, File};
|
||||||
|
use std::io::{self, Read};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Map, Value};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
use crate::config::OAuthConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub struct OAuthTokenSet {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PkceCodePair {
|
||||||
|
pub verifier: String,
|
||||||
|
pub challenge: String,
|
||||||
|
pub challenge_method: PkceChallengeMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum PkceChallengeMethod {
|
||||||
|
S256,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PkceChallengeMethod {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::S256 => "S256",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthAuthorizationRequest {
|
||||||
|
pub authorize_url: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
pub state: String,
|
||||||
|
pub code_challenge: String,
|
||||||
|
pub code_challenge_method: PkceChallengeMethod,
|
||||||
|
pub extra_params: BTreeMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthTokenExchangeRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub code: String,
|
||||||
|
pub redirect_uri: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub code_verifier: String,
|
||||||
|
pub state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthRefreshRequest {
|
||||||
|
pub grant_type: &'static str,
|
||||||
|
pub refresh_token: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthCallbackParams {
|
||||||
|
pub code: Option<String>,
|
||||||
|
pub state: Option<String>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
pub error_description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct StoredOAuthCredentials {
|
||||||
|
access_token: String,
|
||||||
|
#[serde(default)]
|
||||||
|
refresh_token: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
expires_at: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuthTokenSet> for StoredOAuthCredentials {
|
||||||
|
fn from(value: OAuthTokenSet) -> Self {
|
||||||
|
Self {
|
||||||
|
access_token: value.access_token,
|
||||||
|
refresh_token: value.refresh_token,
|
||||||
|
expires_at: value.expires_at,
|
||||||
|
scopes: value.scopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StoredOAuthCredentials> for OAuthTokenSet {
|
||||||
|
fn from(value: StoredOAuthCredentials) -> Self {
|
||||||
|
Self {
|
||||||
|
access_token: value.access_token,
|
||||||
|
refresh_token: value.refresh_token,
|
||||||
|
expires_at: value.expires_at,
|
||||||
|
scopes: value.scopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthAuthorizationRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
pkce: &PkceCodePair,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
authorize_url: config.authorize_url.clone(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
scopes: config.scopes.clone(),
|
||||||
|
state: state.into(),
|
||||||
|
code_challenge: pkce.challenge.clone(),
|
||||||
|
code_challenge_method: pkce.challenge_method,
|
||||||
|
extra_params: BTreeMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||||
|
self.extra_params.insert(key.into(), value.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn build_url(&self) -> String {
|
||||||
|
let mut params = vec![
|
||||||
|
("response_type", "code".to_string()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
("code_challenge", self.code_challenge.clone()),
|
||||||
|
(
|
||||||
|
"code_challenge_method",
|
||||||
|
self.code_challenge_method.as_str().to_string(),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
params.extend(
|
||||||
|
self.extra_params
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| (key.as_str(), value.clone())),
|
||||||
|
);
|
||||||
|
let query = params
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("&");
|
||||||
|
format!(
|
||||||
|
"{}{}{}",
|
||||||
|
self.authorize_url,
|
||||||
|
if self.authorize_url.contains('?') {
|
||||||
|
'&'
|
||||||
|
} else {
|
||||||
|
'?'
|
||||||
|
},
|
||||||
|
query
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthTokenExchangeRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
code: impl Into<String>,
|
||||||
|
state: impl Into<String>,
|
||||||
|
verifier: impl Into<String>,
|
||||||
|
redirect_uri: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
grant_type: "authorization_code",
|
||||||
|
code: code.into(),
|
||||||
|
redirect_uri: redirect_uri.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
code_verifier: verifier.into(),
|
||||||
|
state: state.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("code", self.code.clone()),
|
||||||
|
("redirect_uri", self.redirect_uri.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("code_verifier", self.code_verifier.clone()),
|
||||||
|
("state", self.state.clone()),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthRefreshRequest {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_config(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
refresh_token: impl Into<String>,
|
||||||
|
scopes: Option<Vec<String>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
grant_type: "refresh_token",
|
||||||
|
refresh_token: refresh_token.into(),
|
||||||
|
client_id: config.client_id.clone(),
|
||||||
|
scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
||||||
|
BTreeMap::from([
|
||||||
|
("grant_type", self.grant_type.to_string()),
|
||||||
|
("refresh_token", self.refresh_token.clone()),
|
||||||
|
("client_id", self.client_id.clone()),
|
||||||
|
("scope", self.scopes.join(" ")),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
|
||||||
|
let verifier = generate_random_token(32)?;
|
||||||
|
Ok(PkceCodePair {
|
||||||
|
challenge: code_challenge_s256(&verifier),
|
||||||
|
verifier,
|
||||||
|
challenge_method: PkceChallengeMethod::S256,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_state() -> io::Result<String> {
|
||||||
|
generate_random_token(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn code_challenge_s256(verifier: &str) -> String {
|
||||||
|
let digest = Sha256::digest(verifier.as_bytes());
|
||||||
|
base64url_encode(&digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn loopback_redirect_uri(port: u16) -> String {
|
||||||
|
format!("http://localhost:{port}/callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn credentials_path() -> io::Result<PathBuf> {
|
||||||
|
Ok(credentials_home_dir()?.join("credentials.json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let root = read_credentials_root(&path)?;
|
||||||
|
let Some(oauth) = root.get("oauth") else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
if oauth.is_null() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||||
|
Ok(Some(stored.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let mut root = read_credentials_root(&path)?;
|
||||||
|
root.insert(
|
||||||
|
"oauth".to_string(),
|
||||||
|
serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
||||||
|
);
|
||||||
|
write_credentials_root(&path, &root)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_oauth_credentials() -> io::Result<()> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let mut root = read_credentials_root(&path)?;
|
||||||
|
root.remove("oauth");
|
||||||
|
write_credentials_root(&path, &root)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
|
||||||
|
let (path, query) = target
|
||||||
|
.split_once('?')
|
||||||
|
.map_or((target, ""), |(path, query)| (path, query));
|
||||||
|
if path != "/callback" {
|
||||||
|
return Err(format!("unexpected callback path: {path}"));
|
||||||
|
}
|
||||||
|
parse_oauth_callback_query(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
|
||||||
|
let mut params = BTreeMap::new();
|
||||||
|
for pair in query.split('&').filter(|pair| !pair.is_empty()) {
|
||||||
|
let (key, value) = pair
|
||||||
|
.split_once('=')
|
||||||
|
.map_or((pair, ""), |(key, value)| (key, value));
|
||||||
|
params.insert(percent_decode(key)?, percent_decode(value)?);
|
||||||
|
}
|
||||||
|
Ok(OAuthCallbackParams {
|
||||||
|
code: params.get("code").cloned(),
|
||||||
|
state: params.get("state").cloned(),
|
||||||
|
error: params.get("error").cloned(),
|
||||||
|
error_description: params.get("error_description").cloned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||||
|
let mut buffer = vec![0_u8; bytes];
|
||||||
|
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||||
|
Ok(base64url_encode(&buffer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn credentials_home_dir() -> io::Result<PathBuf> {
|
||||||
|
if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
||||||
|
return Ok(PathBuf::from(path));
|
||||||
|
}
|
||||||
|
let home = std::env::var_os("HOME")
|
||||||
|
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
|
||||||
|
Ok(PathBuf::from(home).join(".claude"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
|
||||||
|
match fs::read_to_string(path) {
|
||||||
|
Ok(contents) => {
|
||||||
|
if contents.trim().is_empty() {
|
||||||
|
return Ok(Map::new());
|
||||||
|
}
|
||||||
|
serde_json::from_str::<Value>(&contents)
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
|
||||||
|
.as_object()
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
"credentials file must contain a JSON object",
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||||
|
let temp_path = path.with_extension("json.tmp");
|
||||||
|
fs::write(&temp_path, format!("{rendered}\n"))?;
|
||||||
|
fs::rename(temp_path, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn base64url_encode(bytes: &[u8]) -> String {
|
||||||
|
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||||
|
let mut output = String::new();
|
||||||
|
let mut index = 0;
|
||||||
|
while index + 3 <= bytes.len() {
|
||||||
|
let block = (u32::from(bytes[index]) << 16)
|
||||||
|
| (u32::from(bytes[index + 1]) << 8)
|
||||||
|
| u32::from(bytes[index + 2]);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[(block & 0x3F) as usize] as char);
|
||||||
|
index += 3;
|
||||||
|
}
|
||||||
|
match bytes.len().saturating_sub(index) {
|
||||||
|
1 => {
|
||||||
|
let block = u32::from(bytes[index]) << 16;
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8);
|
||||||
|
output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
|
||||||
|
output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_encode(value: &str) -> String {
|
||||||
|
let mut encoded = String::new();
|
||||||
|
for byte in value.bytes() {
|
||||||
|
match byte {
|
||||||
|
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
|
||||||
|
encoded.push(char::from(byte));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
let _ = write!(&mut encoded, "%{byte:02X}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
fn percent_decode(value: &str) -> Result<String, String> {
|
||||||
|
let mut decoded = Vec::with_capacity(value.len());
|
||||||
|
let bytes = value.as_bytes();
|
||||||
|
let mut index = 0;
|
||||||
|
while index < bytes.len() {
|
||||||
|
match bytes[index] {
|
||||||
|
b'%' if index + 2 < bytes.len() => {
|
||||||
|
let hi = decode_hex(bytes[index + 1])?;
|
||||||
|
let lo = decode_hex(bytes[index + 2])?;
|
||||||
|
decoded.push((hi << 4) | lo);
|
||||||
|
index += 3;
|
||||||
|
}
|
||||||
|
b'+' => {
|
||||||
|
decoded.push(b' ');
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
byte => {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::from_utf8(decoded).map_err(|error| error.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||||
|
match byte {
|
||||||
|
b'0'..=b'9' => Ok(byte - b'0'),
|
||||||
|
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||||
|
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||||
|
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||||
|
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||||
|
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||||
|
OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn sample_config() -> OAuthConfig {
|
||||||
|
OAuthConfig {
|
||||||
|
client_id: "runtime-client".to_string(),
|
||||||
|
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||||
|
token_url: "https://console.test/oauth/token".to_string(),
|
||||||
|
callback_port: Some(4545),
|
||||||
|
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||||
|
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
crate::test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn temp_config_home() -> std::path::PathBuf {
|
||||||
|
std::env::temp_dir().join(format!(
|
||||||
|
"runtime-oauth-test-{}-{}",
|
||||||
|
std::process::id(),
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time")
|
||||||
|
.as_nanos()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn s256_challenge_matches_expected_vector() {
|
||||||
|
assert_eq!(
|
||||||
|
code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
|
||||||
|
"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn generates_pkce_pair_and_state() {
|
||||||
|
let pair = generate_pkce_pair().expect("pkce pair");
|
||||||
|
let state = generate_state().expect("state");
|
||||||
|
assert!(!pair.verifier.is_empty());
|
||||||
|
assert!(!pair.challenge.is_empty());
|
||||||
|
assert!(!state.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn builds_authorize_url_and_form_requests() {
|
||||||
|
let config = sample_config();
|
||||||
|
let pair = generate_pkce_pair().expect("pkce");
|
||||||
|
let url = OAuthAuthorizationRequest::from_config(
|
||||||
|
&config,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
"state-123",
|
||||||
|
&pair,
|
||||||
|
)
|
||||||
|
.with_extra_param("login_hint", "user@example.com")
|
||||||
|
.build_url();
|
||||||
|
assert!(url.starts_with("https://console.test/oauth/authorize?"));
|
||||||
|
assert!(url.contains("response_type=code"));
|
||||||
|
assert!(url.contains("client_id=runtime-client"));
|
||||||
|
assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
|
||||||
|
assert!(url.contains("login_hint=user%40example.com"));
|
||||||
|
|
||||||
|
let exchange = OAuthTokenExchangeRequest::from_config(
|
||||||
|
&config,
|
||||||
|
"auth-code",
|
||||||
|
"state-123",
|
||||||
|
pair.verifier,
|
||||||
|
loopback_redirect_uri(4545),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
exchange.form_params().get("grant_type").map(String::as_str),
|
||||||
|
Some("authorization_code")
|
||||||
|
);
|
||||||
|
|
||||||
|
let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
|
||||||
|
assert_eq!(
|
||||||
|
refresh.form_params().get("scope").map(String::as_str),
|
||||||
|
Some("org:read user:write")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
let path = credentials_path().expect("credentials path");
|
||||||
|
std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
|
||||||
|
std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
|
||||||
|
|
||||||
|
let token_set = OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(123),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
};
|
||||||
|
save_oauth_credentials(&token_set).expect("save credentials");
|
||||||
|
assert_eq!(
|
||||||
|
load_oauth_credentials().expect("load credentials"),
|
||||||
|
Some(token_set)
|
||||||
|
);
|
||||||
|
let saved = std::fs::read_to_string(&path).expect("read saved file");
|
||||||
|
assert!(saved.contains("\"other\": \"value\""));
|
||||||
|
assert!(saved.contains("\"oauth\""));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
assert_eq!(load_oauth_credentials().expect("load cleared"), None);
|
||||||
|
let cleared = std::fs::read_to_string(&path).expect("read cleared file");
|
||||||
|
assert!(cleared.contains("\"other\": \"value\""));
|
||||||
|
assert!(!cleared.contains("\"oauth\""));
|
||||||
|
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_callback_query_and_target() {
|
||||||
|
let params =
|
||||||
|
parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
|
||||||
|
.expect("parse query");
|
||||||
|
assert_eq!(params.code.as_deref(), Some("abc123"));
|
||||||
|
assert_eq!(params.state.as_deref(), Some("state-1"));
|
||||||
|
assert_eq!(params.error_description.as_deref(), Some("needs login"));
|
||||||
|
|
||||||
|
let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
|
||||||
|
.expect("parse callback target");
|
||||||
|
assert_eq!(params.code.as_deref(), Some("abc"));
|
||||||
|
assert_eq!(params.state.as_deref(), Some("xyz"));
|
||||||
|
assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,29 @@
|
|||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub enum PermissionMode {
|
pub enum PermissionMode {
|
||||||
Allow,
|
ReadOnly,
|
||||||
Deny,
|
WorkspaceWrite,
|
||||||
Prompt,
|
DangerFullAccess,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionMode {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::ReadOnly => "read-only",
|
||||||
|
Self::WorkspaceWrite => "workspace-write",
|
||||||
|
Self::DangerFullAccess => "danger-full-access",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct PermissionRequest {
|
pub struct PermissionRequest {
|
||||||
pub tool_name: String,
|
pub tool_name: String,
|
||||||
pub input: String,
|
pub input: String,
|
||||||
|
pub current_mode: PermissionMode,
|
||||||
|
pub required_mode: PermissionMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
@@ -31,31 +44,41 @@ pub enum PermissionOutcome {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct PermissionPolicy {
|
pub struct PermissionPolicy {
|
||||||
default_mode: PermissionMode,
|
active_mode: PermissionMode,
|
||||||
tool_modes: BTreeMap<String, PermissionMode>,
|
tool_requirements: BTreeMap<String, PermissionMode>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PermissionPolicy {
|
impl PermissionPolicy {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(default_mode: PermissionMode) -> Self {
|
pub fn new(active_mode: PermissionMode) -> Self {
|
||||||
Self {
|
Self {
|
||||||
default_mode,
|
active_mode,
|
||||||
tool_modes: BTreeMap::new(),
|
tool_requirements: BTreeMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_tool_mode(mut self, tool_name: impl Into<String>, mode: PermissionMode) -> Self {
|
pub fn with_tool_requirement(
|
||||||
self.tool_modes.insert(tool_name.into(), mode);
|
mut self,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
required_mode: PermissionMode,
|
||||||
|
) -> Self {
|
||||||
|
self.tool_requirements
|
||||||
|
.insert(tool_name.into(), required_mode);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn mode_for(&self, tool_name: &str) -> PermissionMode {
|
pub fn active_mode(&self) -> PermissionMode {
|
||||||
self.tool_modes
|
self.active_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn required_mode_for(&self, tool_name: &str) -> PermissionMode {
|
||||||
|
self.tool_requirements
|
||||||
.get(tool_name)
|
.get(tool_name)
|
||||||
.copied()
|
.copied()
|
||||||
.unwrap_or(self.default_mode)
|
.unwrap_or(PermissionMode::DangerFullAccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@@ -65,23 +88,43 @@ impl PermissionPolicy {
|
|||||||
input: &str,
|
input: &str,
|
||||||
mut prompter: Option<&mut dyn PermissionPrompter>,
|
mut prompter: Option<&mut dyn PermissionPrompter>,
|
||||||
) -> PermissionOutcome {
|
) -> PermissionOutcome {
|
||||||
match self.mode_for(tool_name) {
|
let current_mode = self.active_mode();
|
||||||
PermissionMode::Allow => PermissionOutcome::Allow,
|
let required_mode = self.required_mode_for(tool_name);
|
||||||
PermissionMode::Deny => PermissionOutcome::Deny {
|
if current_mode >= required_mode {
|
||||||
reason: format!("tool '{tool_name}' denied by permission policy"),
|
return PermissionOutcome::Allow;
|
||||||
},
|
}
|
||||||
PermissionMode::Prompt => match prompter.as_mut() {
|
|
||||||
Some(prompter) => match prompter.decide(&PermissionRequest {
|
let request = PermissionRequest {
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: tool_name.to_string(),
|
||||||
input: input.to_string(),
|
input: input.to_string(),
|
||||||
}) {
|
current_mode,
|
||||||
|
required_mode,
|
||||||
|
};
|
||||||
|
|
||||||
|
if current_mode == PermissionMode::WorkspaceWrite
|
||||||
|
&& required_mode == PermissionMode::DangerFullAccess
|
||||||
|
{
|
||||||
|
return match prompter.as_mut() {
|
||||||
|
Some(prompter) => match prompter.decide(&request) {
|
||||||
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
||||||
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
||||||
},
|
},
|
||||||
None => PermissionOutcome::Deny {
|
None => PermissionOutcome::Deny {
|
||||||
reason: format!("tool '{tool_name}' requires interactive approval"),
|
reason: format!(
|
||||||
|
"tool '{tool_name}' requires approval to escalate from {} to {}",
|
||||||
|
current_mode.as_str(),
|
||||||
|
required_mode.as_str()
|
||||||
|
),
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
PermissionOutcome::Deny {
|
||||||
|
reason: format!(
|
||||||
|
"tool '{tool_name}' requires {} permission; current mode is {}",
|
||||||
|
required_mode.as_str(),
|
||||||
|
current_mode.as_str()
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,25 +136,92 @@ mod tests {
|
|||||||
PermissionPrompter, PermissionRequest,
|
PermissionPrompter, PermissionRequest,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AllowPrompter;
|
struct RecordingPrompter {
|
||||||
|
seen: Vec<PermissionRequest>,
|
||||||
|
allow: bool,
|
||||||
|
}
|
||||||
|
|
||||||
impl PermissionPrompter for AllowPrompter {
|
impl PermissionPrompter for RecordingPrompter {
|
||||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||||
assert_eq!(request.tool_name, "bash");
|
self.seen.push(request.clone());
|
||||||
PermissionPromptDecision::Allow
|
if self.allow {
|
||||||
|
PermissionPromptDecision::Allow
|
||||||
|
} else {
|
||||||
|
PermissionPromptDecision::Deny {
|
||||||
|
reason: "not now".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn uses_tool_specific_overrides() {
|
fn allows_tools_when_active_mode_meets_requirement() {
|
||||||
let policy = PermissionPolicy::new(PermissionMode::Deny)
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
.with_tool_mode("bash", PermissionMode::Prompt);
|
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
|
||||||
|
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
policy.authorize("read_file", "{}", None),
|
||||||
|
PermissionOutcome::Allow
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
policy.authorize("write_file", "{}", None),
|
||||||
|
PermissionOutcome::Allow
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn denies_read_only_escalations_without_prompt() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||||
|
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
|
||||||
let outcome = policy.authorize("bash", "echo hi", Some(&mut AllowPrompter));
|
|
||||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
policy.authorize("edit", "x", None),
|
policy.authorize("write_file", "{}", None),
|
||||||
PermissionOutcome::Deny { .. }
|
PermissionOutcome::Deny { reason } if reason.contains("requires workspace-write permission")
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
policy.authorize("bash", "{}", None),
|
||||||
|
PermissionOutcome::Deny { reason } if reason.contains("requires danger-full-access permission")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prompts_for_workspace_write_to_danger_full_access_escalation() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
let mut prompter = RecordingPrompter {
|
||||||
|
seen: Vec::new(),
|
||||||
|
allow: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let outcome = policy.authorize("bash", "echo hi", Some(&mut prompter));
|
||||||
|
|
||||||
|
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||||
|
assert_eq!(prompter.seen.len(), 1);
|
||||||
|
assert_eq!(prompter.seen[0].tool_name, "bash");
|
||||||
|
assert_eq!(
|
||||||
|
prompter.seen[0].current_mode,
|
||||||
|
PermissionMode::WorkspaceWrite
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
prompter.seen[0].required_mode,
|
||||||
|
PermissionMode::DangerFullAccess
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn honors_prompt_rejection_reason() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
let mut prompter = RecordingPrompter {
|
||||||
|
seen: Vec::new(),
|
||||||
|
allow: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
policy.authorize("bash", "echo hi", Some(&mut prompter)),
|
||||||
|
PermissionOutcome::Deny { reason } if reason == "not now"
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use std::fs;
|
use std::fs;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
@@ -35,6 +36,8 @@ impl From<ConfigError> for PromptBuildError {
|
|||||||
|
|
||||||
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
||||||
pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6";
|
pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6";
|
||||||
|
const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000;
|
||||||
|
const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct ContextFile {
|
pub struct ContextFile {
|
||||||
@@ -198,11 +201,12 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
|||||||
dir.join("CLAUDE.md"),
|
dir.join("CLAUDE.md"),
|
||||||
dir.join("CLAUDE.local.md"),
|
dir.join("CLAUDE.local.md"),
|
||||||
dir.join(".claude").join("CLAUDE.md"),
|
dir.join(".claude").join("CLAUDE.md"),
|
||||||
|
dir.join(".claude").join("instructions.md"),
|
||||||
] {
|
] {
|
||||||
push_context_file(&mut files, candidate)?;
|
push_context_file(&mut files, candidate)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(files)
|
Ok(dedupe_instruction_files(files))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn push_context_file(files: &mut Vec<ContextFile>, path: PathBuf) -> std::io::Result<()> {
|
fn push_context_file(files: &mut Vec<ContextFile>, path: PathBuf) -> std::io::Result<()> {
|
||||||
@@ -237,10 +241,17 @@ fn read_git_status(cwd: &Path) -> Option<String> {
|
|||||||
|
|
||||||
fn render_project_context(project_context: &ProjectContext) -> String {
|
fn render_project_context(project_context: &ProjectContext) -> String {
|
||||||
let mut lines = vec!["# Project context".to_string()];
|
let mut lines = vec!["# Project context".to_string()];
|
||||||
lines.extend(prepend_bullets(vec![format!(
|
let mut bullets = vec![
|
||||||
"Today's date is {}.",
|
format!("Today's date is {}.", project_context.current_date),
|
||||||
project_context.current_date
|
format!("Working directory: {}", project_context.cwd.display()),
|
||||||
)]));
|
];
|
||||||
|
if !project_context.instruction_files.is_empty() {
|
||||||
|
bullets.push(format!(
|
||||||
|
"Claude instruction files discovered: {}.",
|
||||||
|
project_context.instruction_files.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
lines.extend(prepend_bullets(bullets));
|
||||||
if let Some(status) = &project_context.git_status {
|
if let Some(status) = &project_context.git_status {
|
||||||
lines.push(String::new());
|
lines.push(String::new());
|
||||||
lines.push("Git status snapshot:".to_string());
|
lines.push("Git status snapshot:".to_string());
|
||||||
@@ -251,13 +262,105 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
|||||||
|
|
||||||
fn render_instruction_files(files: &[ContextFile]) -> String {
|
fn render_instruction_files(files: &[ContextFile]) -> String {
|
||||||
let mut sections = vec!["# Claude instructions".to_string()];
|
let mut sections = vec!["# Claude instructions".to_string()];
|
||||||
|
let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS;
|
||||||
for file in files {
|
for file in files {
|
||||||
sections.push(format!("## {}", file.path.display()));
|
if remaining_chars == 0 {
|
||||||
sections.push(file.content.trim().to_string());
|
sections.push(
|
||||||
|
"_Additional instruction content omitted after reaching the prompt budget._"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw_content = truncate_instruction_content(&file.content, remaining_chars);
|
||||||
|
let rendered_content = render_instruction_content(&raw_content);
|
||||||
|
let consumed = rendered_content.chars().count().min(remaining_chars);
|
||||||
|
remaining_chars = remaining_chars.saturating_sub(consumed);
|
||||||
|
|
||||||
|
sections.push(format!("## {}", describe_instruction_file(file, files)));
|
||||||
|
sections.push(rendered_content);
|
||||||
}
|
}
|
||||||
sections.join("\n\n")
|
sections.join("\n\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn dedupe_instruction_files(files: Vec<ContextFile>) -> Vec<ContextFile> {
|
||||||
|
let mut deduped = Vec::new();
|
||||||
|
let mut seen_hashes = Vec::new();
|
||||||
|
|
||||||
|
for file in files {
|
||||||
|
let normalized = normalize_instruction_content(&file.content);
|
||||||
|
let hash = stable_content_hash(&normalized);
|
||||||
|
if seen_hashes.contains(&hash) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen_hashes.push(hash);
|
||||||
|
deduped.push(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
deduped
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_instruction_content(content: &str) -> String {
|
||||||
|
collapse_blank_lines(content).trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stable_content_hash(content: &str) -> u64 {
|
||||||
|
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||||
|
content.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe_instruction_file(file: &ContextFile, files: &[ContextFile]) -> String {
|
||||||
|
let path = display_context_path(&file.path);
|
||||||
|
let scope = files
|
||||||
|
.iter()
|
||||||
|
.filter_map(|candidate| candidate.path.parent())
|
||||||
|
.find(|parent| file.path.starts_with(parent))
|
||||||
|
.map_or_else(
|
||||||
|
|| "workspace".to_string(),
|
||||||
|
|parent| parent.display().to_string(),
|
||||||
|
);
|
||||||
|
format!("{path} (scope: {scope})")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_instruction_content(content: &str, remaining_chars: usize) -> String {
|
||||||
|
let hard_limit = MAX_INSTRUCTION_FILE_CHARS.min(remaining_chars);
|
||||||
|
let trimmed = content.trim();
|
||||||
|
if trimmed.chars().count() <= hard_limit {
|
||||||
|
return trimmed.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = trimmed.chars().take(hard_limit).collect::<String>();
|
||||||
|
output.push_str("\n\n[truncated]");
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_instruction_content(content: &str) -> String {
|
||||||
|
truncate_instruction_content(content, MAX_INSTRUCTION_FILE_CHARS)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn display_context_path(path: &Path) -> String {
|
||||||
|
path.file_name().map_or_else(
|
||||||
|
|| path.display().to_string(),
|
||||||
|
|name| name.to_string_lossy().into_owned(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_blank_lines(content: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut previous_blank = false;
|
||||||
|
for line in content.lines() {
|
||||||
|
let is_blank = line.trim().is_empty();
|
||||||
|
if is_blank && previous_blank {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result.push_str(line.trim_end());
|
||||||
|
result.push('\n');
|
||||||
|
previous_blank = is_blank;
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
pub fn load_system_prompt(
|
pub fn load_system_prompt(
|
||||||
cwd: impl Into<PathBuf>,
|
cwd: impl Into<PathBuf>,
|
||||||
current_date: impl Into<String>,
|
current_date: impl Into<String>,
|
||||||
@@ -348,9 +451,14 @@ fn get_actions_section() -> String {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ProjectContext, SystemPromptBuilder, SYSTEM_PROMPT_DYNAMIC_BOUNDARY};
|
use super::{
|
||||||
|
collapse_blank_lines, display_context_path, normalize_instruction_content,
|
||||||
|
render_instruction_content, render_instruction_files, truncate_instruction_content,
|
||||||
|
ContextFile, ProjectContext, SystemPromptBuilder, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||||
|
};
|
||||||
use crate::config::ConfigLoader;
|
use crate::config::ConfigLoader;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
fn temp_dir() -> std::path::PathBuf {
|
fn temp_dir() -> std::path::PathBuf {
|
||||||
@@ -361,6 +469,10 @@ mod tests {
|
|||||||
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
crate::test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discovers_instruction_files_from_ancestor_chain() {
|
fn discovers_instruction_files_from_ancestor_chain() {
|
||||||
let root = temp_dir();
|
let root = temp_dir();
|
||||||
@@ -370,10 +482,21 @@ mod tests {
|
|||||||
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
||||||
.expect("write local instructions");
|
.expect("write local instructions");
|
||||||
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
||||||
|
fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir");
|
||||||
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
||||||
.expect("write apps instructions");
|
.expect("write apps instructions");
|
||||||
|
fs::write(
|
||||||
|
root.join("apps").join(".claude").join("instructions.md"),
|
||||||
|
"apps dot claude instructions",
|
||||||
|
)
|
||||||
|
.expect("write apps dot claude instructions");
|
||||||
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
||||||
.expect("write nested rules");
|
.expect("write nested rules");
|
||||||
|
fs::write(
|
||||||
|
nested.join(".claude").join("instructions.md"),
|
||||||
|
"nested instructions",
|
||||||
|
)
|
||||||
|
.expect("write nested instructions");
|
||||||
|
|
||||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||||
let contents = context
|
let contents = context
|
||||||
@@ -388,12 +511,53 @@ mod tests {
|
|||||||
"root instructions",
|
"root instructions",
|
||||||
"local instructions",
|
"local instructions",
|
||||||
"apps instructions",
|
"apps instructions",
|
||||||
"nested rules"
|
"apps dot claude instructions",
|
||||||
|
"nested rules",
|
||||||
|
"nested instructions"
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dedupes_identical_instruction_content_across_scopes() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let nested = root.join("apps").join("api");
|
||||||
|
fs::create_dir_all(&nested).expect("nested dir");
|
||||||
|
fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root");
|
||||||
|
fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested");
|
||||||
|
|
||||||
|
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||||
|
assert_eq!(context.instruction_files.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
normalize_instruction_content(&context.instruction_files[0].content),
|
||||||
|
"same rules"
|
||||||
|
);
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn truncates_large_instruction_content_for_rendering() {
|
||||||
|
let rendered = render_instruction_content(&"x".repeat(4500));
|
||||||
|
assert!(rendered.contains("[truncated]"));
|
||||||
|
assert!(rendered.len() < 4_100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalizes_and_collapses_blank_lines() {
|
||||||
|
let normalized = normalize_instruction_content("line one\n\n\nline two\n");
|
||||||
|
assert_eq!(normalized, "line one\n\nline two");
|
||||||
|
assert_eq!(collapse_blank_lines("a\n\n\n\nb\n"), "a\n\nb\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn displays_context_paths_compactly() {
|
||||||
|
assert_eq!(
|
||||||
|
display_context_path(Path::new("/tmp/project/.claude/CLAUDE.md")),
|
||||||
|
"CLAUDE.md"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discover_with_git_includes_status_snapshot() {
|
fn discover_with_git_includes_status_snapshot() {
|
||||||
let root = temp_dir();
|
let root = temp_dir();
|
||||||
@@ -428,7 +592,12 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.expect("write settings");
|
.expect("write settings");
|
||||||
|
|
||||||
|
let _guard = env_lock();
|
||||||
let previous = std::env::current_dir().expect("cwd");
|
let previous = std::env::current_dir().expect("cwd");
|
||||||
|
let original_home = std::env::var("HOME").ok();
|
||||||
|
let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok();
|
||||||
|
std::env::set_var("HOME", &root);
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home"));
|
||||||
std::env::set_current_dir(&root).expect("change cwd");
|
std::env::set_current_dir(&root).expect("change cwd");
|
||||||
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
||||||
.expect("system prompt should load")
|
.expect("system prompt should load")
|
||||||
@@ -438,6 +607,16 @@ mod tests {
|
|||||||
",
|
",
|
||||||
);
|
);
|
||||||
std::env::set_current_dir(previous).expect("restore cwd");
|
std::env::set_current_dir(previous).expect("restore cwd");
|
||||||
|
if let Some(value) = original_home {
|
||||||
|
std::env::set_var("HOME", value);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
if let Some(value) = original_claude_home {
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", value);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
}
|
||||||
|
|
||||||
assert!(prompt.contains("Project rules"));
|
assert!(prompt.contains("Project rules"));
|
||||||
assert!(prompt.contains("permissionMode"));
|
assert!(prompt.contains("permissionMode"));
|
||||||
@@ -476,4 +655,46 @@ mod tests {
|
|||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn truncates_instruction_content_to_budget() {
|
||||||
|
let content = "x".repeat(5_000);
|
||||||
|
let rendered = truncate_instruction_content(&content, 4_000);
|
||||||
|
assert!(rendered.contains("[truncated]"));
|
||||||
|
assert!(rendered.chars().count() <= 4_000 + "\n\n[truncated]".chars().count());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discovers_dot_claude_instructions_markdown() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let nested = root.join("apps").join("api");
|
||||||
|
fs::create_dir_all(nested.join(".claude")).expect("nested claude dir");
|
||||||
|
fs::write(
|
||||||
|
nested.join(".claude").join("instructions.md"),
|
||||||
|
"instruction markdown",
|
||||||
|
)
|
||||||
|
.expect("write instructions.md");
|
||||||
|
|
||||||
|
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||||
|
assert!(context
|
||||||
|
.instruction_files
|
||||||
|
.iter()
|
||||||
|
.any(|file| file.path.ends_with(".claude/instructions.md")));
|
||||||
|
assert!(
|
||||||
|
render_instruction_files(&context.instruction_files).contains("instruction markdown")
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_instruction_file_metadata() {
|
||||||
|
let rendered = render_instruction_files(&[ContextFile {
|
||||||
|
path: PathBuf::from("/tmp/project/CLAUDE.md"),
|
||||||
|
content: "Project rules".to_string(),
|
||||||
|
}]);
|
||||||
|
assert!(rendered.contains("# Claude instructions"));
|
||||||
|
assert!(rendered.contains("scope: /tmp/project"));
|
||||||
|
assert!(rendered.contains("Project rules"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
401
rust/crates/runtime/src/remote.rs
Normal file
401
rust/crates/runtime/src/remote.rs
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::io;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
pub const DEFAULT_REMOTE_BASE_URL: &str = "https://api.anthropic.com";
|
||||||
|
pub const DEFAULT_SESSION_TOKEN_PATH: &str = "/run/ccr/session_token";
|
||||||
|
pub const DEFAULT_SYSTEM_CA_BUNDLE: &str = "/etc/ssl/certs/ca-certificates.crt";
|
||||||
|
|
||||||
|
pub const UPSTREAM_PROXY_ENV_KEYS: [&str; 8] = [
|
||||||
|
"HTTPS_PROXY",
|
||||||
|
"https_proxy",
|
||||||
|
"NO_PROXY",
|
||||||
|
"no_proxy",
|
||||||
|
"SSL_CERT_FILE",
|
||||||
|
"NODE_EXTRA_CA_CERTS",
|
||||||
|
"REQUESTS_CA_BUNDLE",
|
||||||
|
"CURL_CA_BUNDLE",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const NO_PROXY_HOSTS: [&str; 16] = [
|
||||||
|
"localhost",
|
||||||
|
"127.0.0.1",
|
||||||
|
"::1",
|
||||||
|
"169.254.0.0/16",
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"anthropic.com",
|
||||||
|
".anthropic.com",
|
||||||
|
"*.anthropic.com",
|
||||||
|
"github.com",
|
||||||
|
"api.github.com",
|
||||||
|
"*.github.com",
|
||||||
|
"*.githubusercontent.com",
|
||||||
|
"registry.npmjs.org",
|
||||||
|
"index.crates.io",
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct RemoteSessionContext {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
pub base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct UpstreamProxyBootstrap {
|
||||||
|
pub remote: RemoteSessionContext,
|
||||||
|
pub upstream_proxy_enabled: bool,
|
||||||
|
pub token_path: PathBuf,
|
||||||
|
pub ca_bundle_path: PathBuf,
|
||||||
|
pub system_ca_path: PathBuf,
|
||||||
|
pub token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct UpstreamProxyState {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub proxy_url: Option<String>,
|
||||||
|
pub ca_bundle_path: Option<PathBuf>,
|
||||||
|
pub no_proxy: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteSessionContext {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
Self::from_env_map(&env::vars().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")),
|
||||||
|
session_id: env_map
|
||||||
|
.get("CLAUDE_CODE_REMOTE_SESSION_ID")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.cloned(),
|
||||||
|
base_url: env_map
|
||||||
|
.get("ANTHROPIC_BASE_URL")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| DEFAULT_REMOTE_BASE_URL.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamProxyBootstrap {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env() -> Self {
|
||||||
|
Self::from_env_map(&env::vars().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
|
||||||
|
let remote = RemoteSessionContext::from_env_map(env_map);
|
||||||
|
let token_path = env_map
|
||||||
|
.get("CCR_SESSION_TOKEN_PATH")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(|| PathBuf::from(DEFAULT_SESSION_TOKEN_PATH), PathBuf::from);
|
||||||
|
let system_ca_path = env_map
|
||||||
|
.get("CCR_SYSTEM_CA_BUNDLE")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(|| PathBuf::from(DEFAULT_SYSTEM_CA_BUNDLE), PathBuf::from);
|
||||||
|
let ca_bundle_path = env_map
|
||||||
|
.get("CCR_CA_BUNDLE_PATH")
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map_or_else(default_ca_bundle_path, PathBuf::from);
|
||||||
|
let token = read_token(&token_path).ok().flatten();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
remote,
|
||||||
|
upstream_proxy_enabled: env_truthy(env_map.get("CCR_UPSTREAM_PROXY_ENABLED")),
|
||||||
|
token_path,
|
||||||
|
ca_bundle_path,
|
||||||
|
system_ca_path,
|
||||||
|
token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn should_enable(&self) -> bool {
|
||||||
|
self.remote.enabled
|
||||||
|
&& self.upstream_proxy_enabled
|
||||||
|
&& self.remote.session_id.is_some()
|
||||||
|
&& self.token.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn ws_url(&self) -> String {
|
||||||
|
upstream_proxy_ws_url(&self.remote.base_url)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn state_for_port(&self, port: u16) -> UpstreamProxyState {
|
||||||
|
if !self.should_enable() {
|
||||||
|
return UpstreamProxyState::disabled();
|
||||||
|
}
|
||||||
|
UpstreamProxyState {
|
||||||
|
enabled: true,
|
||||||
|
proxy_url: Some(format!("http://127.0.0.1:{port}")),
|
||||||
|
ca_bundle_path: Some(self.ca_bundle_path.clone()),
|
||||||
|
no_proxy: no_proxy_list(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamProxyState {
|
||||||
|
#[must_use]
|
||||||
|
pub fn disabled() -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: false,
|
||||||
|
proxy_url: None,
|
||||||
|
ca_bundle_path: None,
|
||||||
|
no_proxy: no_proxy_list(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn subprocess_env(&self) -> BTreeMap<String, String> {
|
||||||
|
if !self.enabled {
|
||||||
|
return BTreeMap::new();
|
||||||
|
}
|
||||||
|
let Some(proxy_url) = &self.proxy_url else {
|
||||||
|
return BTreeMap::new();
|
||||||
|
};
|
||||||
|
let Some(ca_bundle_path) = &self.ca_bundle_path else {
|
||||||
|
return BTreeMap::new();
|
||||||
|
};
|
||||||
|
let ca_bundle_path = ca_bundle_path.to_string_lossy().into_owned();
|
||||||
|
BTreeMap::from([
|
||||||
|
("HTTPS_PROXY".to_string(), proxy_url.clone()),
|
||||||
|
("https_proxy".to_string(), proxy_url.clone()),
|
||||||
|
("NO_PROXY".to_string(), self.no_proxy.clone()),
|
||||||
|
("no_proxy".to_string(), self.no_proxy.clone()),
|
||||||
|
("SSL_CERT_FILE".to_string(), ca_bundle_path.clone()),
|
||||||
|
("NODE_EXTRA_CA_CERTS".to_string(), ca_bundle_path.clone()),
|
||||||
|
("REQUESTS_CA_BUNDLE".to_string(), ca_bundle_path.clone()),
|
||||||
|
("CURL_CA_BUNDLE".to_string(), ca_bundle_path),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_token(path: &Path) -> io::Result<Option<String>> {
|
||||||
|
match fs::read_to_string(path) {
|
||||||
|
Ok(contents) => {
|
||||||
|
let token = contents.trim();
|
||||||
|
if token.is_empty() {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(token.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn upstream_proxy_ws_url(base_url: &str) -> String {
|
||||||
|
let base = base_url.trim_end_matches('/');
|
||||||
|
let ws_base = if let Some(stripped) = base.strip_prefix("https://") {
|
||||||
|
format!("wss://{stripped}")
|
||||||
|
} else if let Some(stripped) = base.strip_prefix("http://") {
|
||||||
|
format!("ws://{stripped}")
|
||||||
|
} else {
|
||||||
|
format!("wss://{base}")
|
||||||
|
};
|
||||||
|
format!("{ws_base}/v1/code/upstreamproxy/ws")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn no_proxy_list() -> String {
|
||||||
|
let mut hosts = NO_PROXY_HOSTS.to_vec();
|
||||||
|
hosts.extend(["pypi.org", "files.pythonhosted.org", "proxy.golang.org"]);
|
||||||
|
hosts.join(",")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn inherited_upstream_proxy_env(
|
||||||
|
env_map: &BTreeMap<String, String>,
|
||||||
|
) -> BTreeMap<String, String> {
|
||||||
|
if !(env_map.contains_key("HTTPS_PROXY") && env_map.contains_key("SSL_CERT_FILE")) {
|
||||||
|
return BTreeMap::new();
|
||||||
|
}
|
||||||
|
UPSTREAM_PROXY_ENV_KEYS
|
||||||
|
.iter()
|
||||||
|
.filter_map(|key| {
|
||||||
|
env_map
|
||||||
|
.get(*key)
|
||||||
|
.map(|value| ((*key).to_string(), value.clone()))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_ca_bundle_path() -> PathBuf {
|
||||||
|
env::var_os("HOME")
|
||||||
|
.map_or_else(|| PathBuf::from("."), PathBuf::from)
|
||||||
|
.join(".ccr")
|
||||||
|
.join("ca-bundle.crt")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn env_truthy(value: Option<&String>) -> bool {
|
||||||
|
value.is_some_and(|raw| {
|
||||||
|
matches!(
|
||||||
|
raw.trim().to_ascii_lowercase().as_str(),
|
||||||
|
"1" | "true" | "yes" | "on"
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
|
||||||
|
RemoteSessionContext, UpstreamProxyBootstrap,
|
||||||
|
};
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
fn temp_dir() -> PathBuf {
|
||||||
|
let nanos = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time should be after epoch")
|
||||||
|
.as_nanos();
|
||||||
|
std::env::temp_dir().join(format!("runtime-remote-{nanos}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_context_reads_env_state() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()),
|
||||||
|
(
|
||||||
|
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||||
|
"session-123".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ANTHROPIC_BASE_URL".to_string(),
|
||||||
|
"https://remote.test".to_string(),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
let context = RemoteSessionContext::from_env_map(&env);
|
||||||
|
assert!(context.enabled);
|
||||||
|
assert_eq!(context.session_id.as_deref(), Some("session-123"));
|
||||||
|
assert_eq!(context.base_url, "https://remote.test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstrap_fails_open_when_token_or_session_is_missing() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||||
|
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||||
|
]);
|
||||||
|
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
|
||||||
|
assert!(!bootstrap.should_enable());
|
||||||
|
assert!(!bootstrap.state_for_port(8080).enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bootstrap_derives_proxy_state_and_env() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let token_path = root.join("session_token");
|
||||||
|
fs::create_dir_all(&root).expect("temp dir");
|
||||||
|
fs::write(&token_path, "secret-token\n").expect("write token");
|
||||||
|
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||||
|
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||||
|
(
|
||||||
|
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||||
|
"session-123".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ANTHROPIC_BASE_URL".to_string(),
|
||||||
|
"https://remote.test".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"CCR_SESSION_TOKEN_PATH".to_string(),
|
||||||
|
token_path.to_string_lossy().into_owned(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"CCR_CA_BUNDLE_PATH".to_string(),
|
||||||
|
root.join("ca-bundle.crt").to_string_lossy().into_owned(),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
|
||||||
|
assert!(bootstrap.should_enable());
|
||||||
|
assert_eq!(bootstrap.token.as_deref(), Some("secret-token"));
|
||||||
|
assert_eq!(
|
||||||
|
bootstrap.ws_url(),
|
||||||
|
"wss://remote.test/v1/code/upstreamproxy/ws"
|
||||||
|
);
|
||||||
|
|
||||||
|
let state = bootstrap.state_for_port(9443);
|
||||||
|
assert!(state.enabled);
|
||||||
|
let env = state.subprocess_env();
|
||||||
|
assert_eq!(
|
||||||
|
env.get("HTTPS_PROXY").map(String::as_str),
|
||||||
|
Some("http://127.0.0.1:9443")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
env.get("SSL_CERT_FILE").map(String::as_str),
|
||||||
|
Some(root.join("ca-bundle.crt").to_string_lossy().as_ref())
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_reader_trims_and_handles_missing_files() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("temp dir");
|
||||||
|
let token_path = root.join("session_token");
|
||||||
|
fs::write(&token_path, " abc123 \n").expect("write token");
|
||||||
|
assert_eq!(
|
||||||
|
read_token(&token_path).expect("read token").as_deref(),
|
||||||
|
Some("abc123")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
read_token(&root.join("missing")).expect("missing token"),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inherited_proxy_env_requires_proxy_and_ca() {
|
||||||
|
let env = BTreeMap::from([
|
||||||
|
(
|
||||||
|
"HTTPS_PROXY".to_string(),
|
||||||
|
"http://127.0.0.1:8888".to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SSL_CERT_FILE".to_string(),
|
||||||
|
"/tmp/ca-bundle.crt".to_string(),
|
||||||
|
),
|
||||||
|
("NO_PROXY".to_string(), "localhost".to_string()),
|
||||||
|
]);
|
||||||
|
let inherited = inherited_upstream_proxy_env(&env);
|
||||||
|
assert_eq!(inherited.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
inherited.get("NO_PROXY").map(String::as_str),
|
||||||
|
Some("localhost")
|
||||||
|
);
|
||||||
|
assert!(inherited_upstream_proxy_env(&BTreeMap::new()).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn helper_outputs_match_expected_shapes() {
|
||||||
|
assert_eq!(
|
||||||
|
upstream_proxy_ws_url("http://localhost:3000/"),
|
||||||
|
"ws://localhost:3000/v1/code/upstreamproxy/ws"
|
||||||
|
);
|
||||||
|
assert!(no_proxy_list().contains("anthropic.com"));
|
||||||
|
assert!(no_proxy_list().contains("github.com"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,30 @@
|
|||||||
use crate::session::Session;
|
use crate::session::Session;
|
||||||
|
|
||||||
|
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
|
||||||
|
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
|
||||||
|
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
|
||||||
|
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub struct ModelPricing {
|
||||||
|
pub input_cost_per_million: f64,
|
||||||
|
pub output_cost_per_million: f64,
|
||||||
|
pub cache_creation_cost_per_million: f64,
|
||||||
|
pub cache_read_cost_per_million: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelPricing {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn default_sonnet_tier() -> Self {
|
||||||
|
Self {
|
||||||
|
input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
|
||||||
|
output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
|
||||||
|
cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
|
||||||
|
cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||||
pub struct TokenUsage {
|
pub struct TokenUsage {
|
||||||
pub input_tokens: u32,
|
pub input_tokens: u32,
|
||||||
@@ -8,6 +33,49 @@ pub struct TokenUsage {
|
|||||||
pub cache_read_input_tokens: u32,
|
pub cache_read_input_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub struct UsageCostEstimate {
|
||||||
|
pub input_cost_usd: f64,
|
||||||
|
pub output_cost_usd: f64,
|
||||||
|
pub cache_creation_cost_usd: f64,
|
||||||
|
pub cache_read_cost_usd: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UsageCostEstimate {
|
||||||
|
#[must_use]
|
||||||
|
pub fn total_cost_usd(self) -> f64 {
|
||||||
|
self.input_cost_usd
|
||||||
|
+ self.output_cost_usd
|
||||||
|
+ self.cache_creation_cost_usd
|
||||||
|
+ self.cache_read_cost_usd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
|
||||||
|
let normalized = model.to_ascii_lowercase();
|
||||||
|
if normalized.contains("haiku") {
|
||||||
|
return Some(ModelPricing {
|
||||||
|
input_cost_per_million: 1.0,
|
||||||
|
output_cost_per_million: 5.0,
|
||||||
|
cache_creation_cost_per_million: 1.25,
|
||||||
|
cache_read_cost_per_million: 0.1,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if normalized.contains("opus") {
|
||||||
|
return Some(ModelPricing {
|
||||||
|
input_cost_per_million: 15.0,
|
||||||
|
output_cost_per_million: 75.0,
|
||||||
|
cache_creation_cost_per_million: 18.75,
|
||||||
|
cache_read_cost_per_million: 1.5,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if normalized.contains("sonnet") {
|
||||||
|
return Some(ModelPricing::default_sonnet_tier());
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
impl TokenUsage {
|
impl TokenUsage {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn total_tokens(self) -> u32 {
|
pub fn total_tokens(self) -> u32 {
|
||||||
@@ -16,6 +84,79 @@ impl TokenUsage {
|
|||||||
+ self.cache_creation_input_tokens
|
+ self.cache_creation_input_tokens
|
||||||
+ self.cache_read_input_tokens
|
+ self.cache_read_input_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn estimate_cost_usd(self) -> UsageCostEstimate {
|
||||||
|
self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
|
||||||
|
UsageCostEstimate {
|
||||||
|
input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
|
||||||
|
output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
|
||||||
|
cache_creation_cost_usd: cost_for_tokens(
|
||||||
|
self.cache_creation_input_tokens,
|
||||||
|
pricing.cache_creation_cost_per_million,
|
||||||
|
),
|
||||||
|
cache_read_cost_usd: cost_for_tokens(
|
||||||
|
self.cache_read_input_tokens,
|
||||||
|
pricing.cache_read_cost_per_million,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn summary_lines(self, label: &str) -> Vec<String> {
|
||||||
|
self.summary_lines_for_model(label, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
|
||||||
|
let pricing = model.and_then(pricing_for_model);
|
||||||
|
let cost = pricing.map_or_else(
|
||||||
|
|| self.estimate_cost_usd(),
|
||||||
|
|pricing| self.estimate_cost_usd_with_pricing(pricing),
|
||||||
|
);
|
||||||
|
let model_suffix =
|
||||||
|
model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
|
||||||
|
let pricing_suffix = if pricing.is_some() {
|
||||||
|
""
|
||||||
|
} else if model.is_some() {
|
||||||
|
" pricing=estimated-default"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
vec![
|
||||||
|
format!(
|
||||||
|
"{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
|
||||||
|
self.total_tokens(),
|
||||||
|
self.input_tokens,
|
||||||
|
self.output_tokens,
|
||||||
|
self.cache_creation_input_tokens,
|
||||||
|
self.cache_read_input_tokens,
|
||||||
|
format_usd(cost.total_cost_usd()),
|
||||||
|
model_suffix,
|
||||||
|
pricing_suffix,
|
||||||
|
),
|
||||||
|
format!(
|
||||||
|
" cost breakdown: input={} output={} cache_write={} cache_read={}",
|
||||||
|
format_usd(cost.input_cost_usd),
|
||||||
|
format_usd(cost.output_cost_usd),
|
||||||
|
format_usd(cost.cache_creation_cost_usd),
|
||||||
|
format_usd(cost.cache_read_cost_usd),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
|
||||||
|
f64::from(tokens) / 1_000_000.0 * usd_per_million_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn format_usd(amount: f64) -> String {
|
||||||
|
format!("${amount:.4}")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
@@ -69,7 +210,7 @@ impl UsageTracker {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{TokenUsage, UsageTracker};
|
use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
|
||||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -96,6 +237,53 @@ mod tests {
|
|||||||
assert_eq!(tracker.cumulative_usage().total_tokens(), 48);
|
assert_eq!(tracker.cumulative_usage().total_tokens(), 48);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn computes_cost_summary_lines() {
|
||||||
|
let usage = TokenUsage {
|
||||||
|
input_tokens: 1_000_000,
|
||||||
|
output_tokens: 500_000,
|
||||||
|
cache_creation_input_tokens: 100_000,
|
||||||
|
cache_read_input_tokens: 200_000,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cost = usage.estimate_cost_usd();
|
||||||
|
assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
|
||||||
|
assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
|
||||||
|
let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
|
||||||
|
assert!(lines[0].contains("estimated_cost=$54.6750"));
|
||||||
|
assert!(lines[0].contains("model=claude-sonnet-4-20250514"));
|
||||||
|
assert!(lines[1].contains("cache_read=$0.3000"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn supports_model_specific_pricing() {
|
||||||
|
let usage = TokenUsage {
|
||||||
|
input_tokens: 1_000_000,
|
||||||
|
output_tokens: 500_000,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
|
||||||
|
let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
|
||||||
|
let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
|
||||||
|
let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
|
||||||
|
assert_eq!(format_usd(haiku_cost.total_cost_usd()), "$3.5000");
|
||||||
|
assert_eq!(format_usd(opus_cost.total_cost_usd()), "$52.5000");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn marks_unknown_model_pricing_as_fallback() {
|
||||||
|
let usage = TokenUsage {
|
||||||
|
input_tokens: 100,
|
||||||
|
output_tokens: 100,
|
||||||
|
cache_creation_input_tokens: 0,
|
||||||
|
cache_read_input_tokens: 0,
|
||||||
|
};
|
||||||
|
let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
|
||||||
|
assert!(lines[0].contains("pricing=estimated-default"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn reconstructs_usage_from_session_messages() {
|
fn reconstructs_usage_from_session_messages() {
|
||||||
let session = Session {
|
let session = Session {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::io::{self, Write};
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use crate::args::{OutputFormat, PermissionMode};
|
use crate::args::{OutputFormat, PermissionMode};
|
||||||
use crate::input::LineEditor;
|
use crate::input::{LineEditor, ReadOutcome};
|
||||||
use crate::render::{Spinner, TerminalRenderer};
|
use crate::render::{Spinner, TerminalRenderer};
|
||||||
use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary};
|
use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary};
|
||||||
|
|
||||||
@@ -111,16 +111,21 @@ impl CliApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn run_repl(&mut self) -> io::Result<()> {
|
pub fn run_repl(&mut self) -> io::Result<()> {
|
||||||
let editor = LineEditor::new("› ");
|
let mut editor = LineEditor::new("› ", Vec::new());
|
||||||
println!("Rusty Claude CLI interactive mode");
|
println!("Rusty Claude CLI interactive mode");
|
||||||
println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline.");
|
println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline.");
|
||||||
|
|
||||||
while let Some(input) = editor.read_line()? {
|
loop {
|
||||||
if input.trim().is_empty() {
|
match editor.read_line()? {
|
||||||
continue;
|
ReadOutcome::Submit(input) => {
|
||||||
|
if input.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
self.handle_submission(&input, &mut io::stdout())?;
|
||||||
|
}
|
||||||
|
ReadOutcome::Cancel => continue,
|
||||||
|
ReadOutcome::Exit => break,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.handle_submission(&input, &mut io::stdout())?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ pub enum Command {
|
|||||||
DumpManifests,
|
DumpManifests,
|
||||||
/// Print the current bootstrap phase skeleton
|
/// Print the current bootstrap phase skeleton
|
||||||
BootstrapPlan,
|
BootstrapPlan,
|
||||||
|
/// Start the OAuth login flow
|
||||||
|
Login,
|
||||||
|
/// Clear saved OAuth credentials
|
||||||
|
Logout,
|
||||||
/// Run a non-interactive prompt and exit
|
/// Run a non-interactive prompt and exit
|
||||||
Prompt { prompt: Vec<String> },
|
Prompt { prompt: Vec<String> },
|
||||||
}
|
}
|
||||||
@@ -86,4 +90,13 @@ mod tests {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_login_and_logout_commands() {
|
||||||
|
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
|
||||||
|
assert_eq!(login.command, Some(Command::Login));
|
||||||
|
|
||||||
|
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
|
||||||
|
assert_eq!(logout.command, Some(Command::Logout));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
use std::io::{self, IsTerminal, Write};
|
use std::io::{self, IsTerminal, Write};
|
||||||
|
|
||||||
use crossterm::cursor::MoveToColumn;
|
use crossterm::cursor::{MoveDown, MoveToColumn, MoveUp};
|
||||||
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
|
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
|
||||||
use crossterm::queue;
|
use crossterm::queue;
|
||||||
use crossterm::style::Print;
|
|
||||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType};
|
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType};
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
@@ -85,21 +84,124 @@ impl InputBuffer {
|
|||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
self.cursor = 0;
|
self.cursor = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn replace(&mut self, value: impl Into<String>) {
|
||||||
|
self.buffer = value.into();
|
||||||
|
self.cursor = self.buffer.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn current_command_prefix(&self) -> Option<&str> {
|
||||||
|
if self.cursor != self.buffer.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let prefix = &self.buffer[..self.cursor];
|
||||||
|
if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn complete_slash_command(&mut self, candidates: &[String]) -> bool {
|
||||||
|
let Some(prefix) = self.current_command_prefix() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let matches = candidates
|
||||||
|
.iter()
|
||||||
|
.filter(|candidate| candidate.starts_with(prefix))
|
||||||
|
.map(String::as_str)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
if matches.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let replacement = longest_common_prefix(&matches);
|
||||||
|
if replacement == prefix {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.replace(replacement);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct RenderedBuffer {
|
||||||
|
lines: Vec<String>,
|
||||||
|
cursor_row: u16,
|
||||||
|
cursor_col: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RenderedBuffer {
|
||||||
|
#[must_use]
|
||||||
|
pub fn line_count(&self) -> usize {
|
||||||
|
self.lines.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write(&self, out: &mut impl Write) -> io::Result<()> {
|
||||||
|
for (index, line) in self.lines.iter().enumerate() {
|
||||||
|
if index > 0 {
|
||||||
|
writeln!(out)?;
|
||||||
|
}
|
||||||
|
write!(out, "{line}")?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[must_use]
|
||||||
|
pub fn lines(&self) -> &[String] {
|
||||||
|
&self.lines
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[must_use]
|
||||||
|
pub fn cursor_position(&self) -> (u16, u16) {
|
||||||
|
(self.cursor_row, self.cursor_col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum ReadOutcome {
|
||||||
|
Submit(String),
|
||||||
|
Cancel,
|
||||||
|
Exit,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LineEditor {
|
pub struct LineEditor {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
continuation_prompt: String,
|
||||||
|
history: Vec<String>,
|
||||||
|
history_index: Option<usize>,
|
||||||
|
draft: Option<String>,
|
||||||
|
completions: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LineEditor {
|
impl LineEditor {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(prompt: impl Into<String>) -> Self {
|
pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
prompt: prompt.into(),
|
prompt: prompt.into(),
|
||||||
|
continuation_prompt: String::from("> "),
|
||||||
|
history: Vec::new(),
|
||||||
|
history_index: None,
|
||||||
|
draft: None,
|
||||||
|
completions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_line(&self) -> io::Result<Option<String>> {
|
pub fn push_history(&mut self, entry: impl Into<String>) {
|
||||||
|
let entry = entry.into();
|
||||||
|
if entry.trim().is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
self.history.push(entry);
|
||||||
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
|
||||||
if !io::stdin().is_terminal() || !io::stdout().is_terminal() {
|
if !io::stdin().is_terminal() || !io::stdout().is_terminal() {
|
||||||
return self.read_line_fallback();
|
return self.read_line_fallback();
|
||||||
}
|
}
|
||||||
@@ -107,29 +209,43 @@ impl LineEditor {
|
|||||||
enable_raw_mode()?;
|
enable_raw_mode()?;
|
||||||
let mut stdout = io::stdout();
|
let mut stdout = io::stdout();
|
||||||
let mut input = InputBuffer::new();
|
let mut input = InputBuffer::new();
|
||||||
self.redraw(&mut stdout, &input)?;
|
let mut rendered_lines = 1usize;
|
||||||
|
self.redraw(&mut stdout, &input, rendered_lines)?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let event = event::read()?;
|
let event = event::read()?;
|
||||||
if let Event::Key(key) = event {
|
if let Event::Key(key) = event {
|
||||||
match Self::handle_key(key, &mut input) {
|
match self.handle_key(key, &mut input) {
|
||||||
EditorAction::Continue => self.redraw(&mut stdout, &input)?,
|
EditorAction::Continue => {
|
||||||
|
rendered_lines = self.redraw(&mut stdout, &input, rendered_lines)?;
|
||||||
|
}
|
||||||
EditorAction::Submit => {
|
EditorAction::Submit => {
|
||||||
disable_raw_mode()?;
|
disable_raw_mode()?;
|
||||||
writeln!(stdout)?;
|
writeln!(stdout)?;
|
||||||
return Ok(Some(input.as_str().to_owned()));
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
|
return Ok(ReadOutcome::Submit(input.as_str().to_owned()));
|
||||||
}
|
}
|
||||||
EditorAction::Cancel => {
|
EditorAction::Cancel => {
|
||||||
disable_raw_mode()?;
|
disable_raw_mode()?;
|
||||||
writeln!(stdout)?;
|
writeln!(stdout)?;
|
||||||
return Ok(None);
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
|
return Ok(ReadOutcome::Cancel);
|
||||||
|
}
|
||||||
|
EditorAction::Exit => {
|
||||||
|
disable_raw_mode()?;
|
||||||
|
writeln!(stdout)?;
|
||||||
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
|
return Ok(ReadOutcome::Exit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_line_fallback(&self) -> io::Result<Option<String>> {
|
fn read_line_fallback(&self) -> io::Result<ReadOutcome> {
|
||||||
let mut stdout = io::stdout();
|
let mut stdout = io::stdout();
|
||||||
write!(stdout, "{}", self.prompt)?;
|
write!(stdout, "{}", self.prompt)?;
|
||||||
stdout.flush()?;
|
stdout.flush()?;
|
||||||
@@ -137,22 +253,32 @@ impl LineEditor {
|
|||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
let bytes_read = io::stdin().read_line(&mut buffer)?;
|
let bytes_read = io::stdin().read_line(&mut buffer)?;
|
||||||
if bytes_read == 0 {
|
if bytes_read == 0 {
|
||||||
return Ok(None);
|
return Ok(ReadOutcome::Exit);
|
||||||
}
|
}
|
||||||
|
|
||||||
while matches!(buffer.chars().last(), Some('\n' | '\r')) {
|
while matches!(buffer.chars().last(), Some('\n' | '\r')) {
|
||||||
buffer.pop();
|
buffer.pop();
|
||||||
}
|
}
|
||||||
Ok(Some(buffer))
|
Ok(ReadOutcome::Submit(buffer))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_key(key: KeyEvent, input: &mut InputBuffer) -> EditorAction {
|
#[allow(clippy::too_many_lines)]
|
||||||
|
fn handle_key(&mut self, key: KeyEvent, input: &mut InputBuffer) -> EditorAction {
|
||||||
match key {
|
match key {
|
||||||
KeyEvent {
|
KeyEvent {
|
||||||
code: KeyCode::Char('c'),
|
code: KeyCode::Char('c'),
|
||||||
modifiers,
|
modifiers,
|
||||||
..
|
..
|
||||||
} if modifiers.contains(KeyModifiers::CONTROL) => EditorAction::Cancel,
|
} if modifiers.contains(KeyModifiers::CONTROL) => {
|
||||||
|
if input.as_str().is_empty() {
|
||||||
|
EditorAction::Exit
|
||||||
|
} else {
|
||||||
|
input.clear();
|
||||||
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
|
EditorAction::Cancel
|
||||||
|
}
|
||||||
|
}
|
||||||
KeyEvent {
|
KeyEvent {
|
||||||
code: KeyCode::Char('j'),
|
code: KeyCode::Char('j'),
|
||||||
modifiers,
|
modifiers,
|
||||||
@@ -194,6 +320,25 @@ impl LineEditor {
|
|||||||
input.move_right();
|
input.move_right();
|
||||||
EditorAction::Continue
|
EditorAction::Continue
|
||||||
}
|
}
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Up, ..
|
||||||
|
} => {
|
||||||
|
self.navigate_history_up(input);
|
||||||
|
EditorAction::Continue
|
||||||
|
}
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Down,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.navigate_history_down(input);
|
||||||
|
EditorAction::Continue
|
||||||
|
}
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Tab, ..
|
||||||
|
} => {
|
||||||
|
input.complete_slash_command(&self.completions);
|
||||||
|
EditorAction::Continue
|
||||||
|
}
|
||||||
KeyEvent {
|
KeyEvent {
|
||||||
code: KeyCode::Home,
|
code: KeyCode::Home,
|
||||||
..
|
..
|
||||||
@@ -211,6 +356,8 @@ impl LineEditor {
|
|||||||
code: KeyCode::Esc, ..
|
code: KeyCode::Esc, ..
|
||||||
} => {
|
} => {
|
||||||
input.clear();
|
input.clear();
|
||||||
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
EditorAction::Cancel
|
EditorAction::Cancel
|
||||||
}
|
}
|
||||||
KeyEvent {
|
KeyEvent {
|
||||||
@@ -219,22 +366,74 @@ impl LineEditor {
|
|||||||
..
|
..
|
||||||
} if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => {
|
} if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => {
|
||||||
input.insert(ch);
|
input.insert(ch);
|
||||||
|
self.history_index = None;
|
||||||
|
self.draft = None;
|
||||||
EditorAction::Continue
|
EditorAction::Continue
|
||||||
}
|
}
|
||||||
_ => EditorAction::Continue,
|
_ => EditorAction::Continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn redraw(&self, out: &mut impl Write, input: &InputBuffer) -> io::Result<()> {
|
fn navigate_history_up(&mut self, input: &mut InputBuffer) {
|
||||||
let display = input.as_str().replace('\n', "\\n\n> ");
|
if self.history.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.history_index {
|
||||||
|
Some(0) => {}
|
||||||
|
Some(index) => {
|
||||||
|
let next_index = index - 1;
|
||||||
|
input.replace(self.history[next_index].clone());
|
||||||
|
self.history_index = Some(next_index);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.draft = Some(input.as_str().to_owned());
|
||||||
|
let next_index = self.history.len() - 1;
|
||||||
|
input.replace(self.history[next_index].clone());
|
||||||
|
self.history_index = Some(next_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn navigate_history_down(&mut self, input: &mut InputBuffer) {
|
||||||
|
let Some(index) = self.history_index else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if index + 1 < self.history.len() {
|
||||||
|
let next_index = index + 1;
|
||||||
|
input.replace(self.history[next_index].clone());
|
||||||
|
self.history_index = Some(next_index);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
input.replace(self.draft.take().unwrap_or_default());
|
||||||
|
self.history_index = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn redraw(
|
||||||
|
&self,
|
||||||
|
out: &mut impl Write,
|
||||||
|
input: &InputBuffer,
|
||||||
|
previous_line_count: usize,
|
||||||
|
) -> io::Result<usize> {
|
||||||
|
let rendered = render_buffer(&self.prompt, &self.continuation_prompt, input);
|
||||||
|
if previous_line_count > 1 {
|
||||||
|
queue!(out, MoveUp(saturating_u16(previous_line_count - 1)))?;
|
||||||
|
}
|
||||||
|
queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown),)?;
|
||||||
|
rendered.write(out)?;
|
||||||
queue!(
|
queue!(
|
||||||
out,
|
out,
|
||||||
|
MoveUp(saturating_u16(rendered.line_count().saturating_sub(1))),
|
||||||
MoveToColumn(0),
|
MoveToColumn(0),
|
||||||
Clear(ClearType::CurrentLine),
|
|
||||||
Print(&self.prompt),
|
|
||||||
Print(display),
|
|
||||||
)?;
|
)?;
|
||||||
out.flush()
|
if rendered.cursor_row > 0 {
|
||||||
|
queue!(out, MoveDown(rendered.cursor_row))?;
|
||||||
|
}
|
||||||
|
queue!(out, MoveToColumn(rendered.cursor_col))?;
|
||||||
|
out.flush()?;
|
||||||
|
Ok(rendered.line_count())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,11 +442,76 @@ enum EditorAction {
|
|||||||
Continue,
|
Continue,
|
||||||
Submit,
|
Submit,
|
||||||
Cancel,
|
Cancel,
|
||||||
|
Exit,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn render_buffer(
|
||||||
|
prompt: &str,
|
||||||
|
continuation_prompt: &str,
|
||||||
|
input: &InputBuffer,
|
||||||
|
) -> RenderedBuffer {
|
||||||
|
let before_cursor = &input.as_str()[..input.cursor];
|
||||||
|
let cursor_row = saturating_u16(before_cursor.chars().filter(|ch| *ch == '\n').count());
|
||||||
|
let cursor_line = before_cursor.rsplit('\n').next().unwrap_or_default();
|
||||||
|
let cursor_prompt = if cursor_row == 0 {
|
||||||
|
prompt
|
||||||
|
} else {
|
||||||
|
continuation_prompt
|
||||||
|
};
|
||||||
|
let cursor_col = saturating_u16(cursor_prompt.chars().count() + cursor_line.chars().count());
|
||||||
|
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
for (index, line) in input.as_str().split('\n').enumerate() {
|
||||||
|
let prefix = if index == 0 {
|
||||||
|
prompt
|
||||||
|
} else {
|
||||||
|
continuation_prompt
|
||||||
|
};
|
||||||
|
lines.push(format!("{prefix}{line}"));
|
||||||
|
}
|
||||||
|
if lines.is_empty() {
|
||||||
|
lines.push(prompt.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
RenderedBuffer {
|
||||||
|
lines,
|
||||||
|
cursor_row,
|
||||||
|
cursor_col,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn longest_common_prefix(values: &[&str]) -> String {
|
||||||
|
let Some(first) = values.first() else {
|
||||||
|
return String::new();
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut prefix = (*first).to_string();
|
||||||
|
for value in values.iter().skip(1) {
|
||||||
|
while !value.starts_with(&prefix) {
|
||||||
|
prefix.pop();
|
||||||
|
if prefix.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn saturating_u16(value: usize) -> u16 {
|
||||||
|
u16::try_from(value).unwrap_or(u16::MAX)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::InputBuffer;
|
use super::{render_buffer, InputBuffer, LineEditor};
|
||||||
|
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
|
|
||||||
|
fn key(code: KeyCode) -> KeyEvent {
|
||||||
|
KeyEvent::new(code, KeyModifiers::NONE)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn supports_basic_line_editing() {
|
fn supports_basic_line_editing() {
|
||||||
@@ -266,4 +530,119 @@ mod tests {
|
|||||||
assert_eq!(input.as_str(), "hix");
|
assert_eq!(input.as_str(), "hix");
|
||||||
assert_eq!(input.cursor(), 2);
|
assert_eq!(input.cursor(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn completes_unique_slash_command() {
|
||||||
|
let mut input = InputBuffer::new();
|
||||||
|
for ch in "/he".chars() {
|
||||||
|
input.insert(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(input.complete_slash_command(&[
|
||||||
|
"/help".to_string(),
|
||||||
|
"/hello".to_string(),
|
||||||
|
"/status".to_string(),
|
||||||
|
]));
|
||||||
|
assert_eq!(input.as_str(), "/hel");
|
||||||
|
|
||||||
|
assert!(input.complete_slash_command(&["/help".to_string(), "/status".to_string()]));
|
||||||
|
assert_eq!(input.as_str(), "/help");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ignores_completion_when_prefix_is_not_a_slash_command() {
|
||||||
|
let mut input = InputBuffer::new();
|
||||||
|
for ch in "hello".chars() {
|
||||||
|
input.insert(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!input.complete_slash_command(&["/help".to_string()]));
|
||||||
|
assert_eq!(input.as_str(), "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn history_navigation_restores_current_draft() {
|
||||||
|
let mut editor = LineEditor::new("› ", vec![]);
|
||||||
|
editor.push_history("/help");
|
||||||
|
editor.push_history("status report");
|
||||||
|
|
||||||
|
let mut input = InputBuffer::new();
|
||||||
|
for ch in "draft".chars() {
|
||||||
|
input.insert(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
||||||
|
assert_eq!(input.as_str(), "status report");
|
||||||
|
|
||||||
|
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
||||||
|
assert_eq!(input.as_str(), "/help");
|
||||||
|
|
||||||
|
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
||||||
|
assert_eq!(input.as_str(), "status report");
|
||||||
|
|
||||||
|
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
||||||
|
assert_eq!(input.as_str(), "draft");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tab_key_completes_from_editor_candidates() {
|
||||||
|
let mut editor = LineEditor::new(
|
||||||
|
"› ",
|
||||||
|
vec![
|
||||||
|
"/help".to_string(),
|
||||||
|
"/status".to_string(),
|
||||||
|
"/session".to_string(),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
let mut input = InputBuffer::new();
|
||||||
|
for ch in "/st".chars() {
|
||||||
|
input.insert(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = editor.handle_key(key(KeyCode::Tab), &mut input);
|
||||||
|
assert_eq!(input.as_str(), "/status");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_multiline_buffers_with_continuation_prompt() {
|
||||||
|
let mut input = InputBuffer::new();
|
||||||
|
for ch in "hello\nworld".chars() {
|
||||||
|
if ch == '\n' {
|
||||||
|
input.insert_newline();
|
||||||
|
} else {
|
||||||
|
input.insert(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let rendered = render_buffer("› ", "> ", &input);
|
||||||
|
assert_eq!(
|
||||||
|
rendered.lines(),
|
||||||
|
&["› hello".to_string(), "> world".to_string()]
|
||||||
|
);
|
||||||
|
assert_eq!(rendered.cursor_position(), (1, 7));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ctrl_c_exits_only_when_buffer_is_empty() {
|
||||||
|
let mut editor = LineEditor::new("› ", vec![]);
|
||||||
|
let mut empty = InputBuffer::new();
|
||||||
|
assert!(matches!(
|
||||||
|
editor.handle_key(
|
||||||
|
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
||||||
|
&mut empty,
|
||||||
|
),
|
||||||
|
super::EditorAction::Exit
|
||||||
|
));
|
||||||
|
|
||||||
|
let mut filled = InputBuffer::new();
|
||||||
|
filled.insert('x');
|
||||||
|
assert!(matches!(
|
||||||
|
editor.handle_key(
|
||||||
|
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
||||||
|
&mut filled,
|
||||||
|
),
|
||||||
|
super::EditorAction::Cancel
|
||||||
|
));
|
||||||
|
assert!(filled.as_str().is_empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
1
rust/crates/tools/.gitignore
vendored
Normal file
1
rust/crates/tools/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
.clawd-agents/
|
||||||
@@ -7,6 +7,7 @@ publish.workspace = true
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
runtime = { path = "../runtime" }
|
runtime = { path = "../runtime" }
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user