1use reqwest::{header::CONTENT_TYPE, Client, RequestBuilder, Response};
30use serde::de::DeserializeOwned;
31use url::Url;
32
33use crate::error::{ApiErrorBody, Error, Result};
34use crate::types::*;
35use cellstate_core::{AgentId, ArtifactId, NoteId, ScopeId, SecretString, TrajectoryId, TurnId};
36
/// Response body encoding the client asks the server for.
///
/// The selected variant determines the `Accept` header value produced by
/// `accept_header_value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    /// Ask for JSON only.
    Json,
    /// Ask for MessagePack first, listing JSON as a lower-weighted alternative.
    MsgPack,
}

impl ResponseFormat {
    /// Returns the `Accept` header value corresponding to this format.
    fn accept_header_value(self) -> &'static str {
        const JSON_ONLY: &str = "application/json";
        const MSGPACK_THEN_JSON: &str = "application/msgpack, application/json;q=0.9";

        match self {
            Self::MsgPack => MSGPACK_THEN_JSON,
            Self::Json => JSON_ONLY,
        }
    }
}
57
58#[derive(Clone)]
59pub struct CellstateClient {
60 http: Client,
61 base_url: Url,
62 api_key: SecretString,
63 response_format: ResponseFormat,
64}
65
66impl std::fmt::Debug for CellstateClient {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("CellstateClient")
69 .field("http", &"<reqwest::Client>")
70 .field("base_url", &self.base_url)
71 .field(
72 "api_key",
73 &format!("[REDACTED, {} chars]", self.api_key.len()),
74 )
75 .field("response_format", &self.response_format)
76 .finish()
77 }
78}
79
80impl CellstateClient {
81 pub fn new(base_url: &str, api_key: &str) -> Result<Self> {
90 let base_url = Url::parse(base_url)?;
91 let http = Client::builder()
92 .user_agent("cellstate-rs/0.1.0")
93 .build()
94 .map_err(Error::Http)?;
95
96 Ok(Self {
97 http,
98 base_url,
99 api_key: SecretString::new(api_key),
100 response_format: ResponseFormat::Json,
101 })
102 }
103
104 pub fn from_env() -> Result<Self> {
111 let base_url = std::env::var("CELLSTATE_API_URL")
112 .or_else(|_| std::env::var("CELLSTATE_BASE_URL"))
113 .unwrap_or_else(|_| "https://cst.batterypack.dev".to_string());
114 let api_key = std::env::var("CELLSTATE_API_KEY").unwrap_or_default();
115
116 Self::new(&base_url, &api_key)
117 }
118
119 fn url(&self, path: &str) -> Result<Url> {
121 let mut url = self.base_url.clone();
122 let base_path = url.path().trim_end_matches('/');
123 let suffix = path.trim_start_matches('/');
124 url.set_path(&format!("{base_path}/{suffix}"));
125 Ok(url)
126 }
127
128 pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
130 self.response_format = response_format;
131 self
132 }
133
134 pub fn set_response_format(&mut self, response_format: ResponseFormat) {
136 self.response_format = response_format;
137 }
138
139 fn auth(&self, builder: RequestBuilder) -> RequestBuilder {
141 builder
142 .header("x-api-key", self.api_key.expose_secret())
143 .header("accept", self.response_format.accept_header_value())
144 }
145
/// Reports whether a `Content-Type` header value names a MessagePack media type.
///
/// Only the media type itself (the text before the first `;`) is examined, so
/// parameters such as `charset` are ignored and cannot trigger a match by
/// merely mentioning MessagePack. The comparison is ASCII case-insensitive and
/// accepts both `application/msgpack` and `application/x-msgpack`.
fn is_msgpack_content_type(content_type: &str) -> bool {
    const MSGPACK_MEDIA_TYPES: [&str; 2] = ["application/msgpack", "application/x-msgpack"];

    // `split` always yields at least one item, so the fallback is never used in
    // practice; it just avoids an `unwrap`.
    let media_type = content_type.split(';').next().unwrap_or("").trim();

    MSGPACK_MEDIA_TYPES
        .iter()
        .any(|known| media_type.eq_ignore_ascii_case(known))
}
150
151 async fn parse_response<T>(&self, response: Response) -> Result<T>
152 where
153 T: DeserializeOwned,
154 {
155 let content_type = response
156 .headers()
157 .get(CONTENT_TYPE)
158 .and_then(|v| v.to_str().ok())
159 .unwrap_or("")
160 .to_string();
161 let bytes = response.bytes().await?;
162
163 if Self::is_msgpack_content_type(&content_type) {
164 rmp_serde::from_slice::<T>(&bytes).map_err(|e| Error::Decode(e.to_string()))
165 } else {
166 Ok(serde_json::from_slice::<T>(&bytes)?)
167 }
168 }
169
170 async fn send(&self, builder: RequestBuilder) -> Result<Response> {
172 let response = self.auth(builder).send().await?;
173 let status = response.status();
174
175 if status.is_success() {
176 Ok(response)
177 } else {
178 let status_code = status.as_u16();
179 let content_type = response
180 .headers()
181 .get(CONTENT_TYPE)
182 .and_then(|v| v.to_str().ok())
183 .unwrap_or("")
184 .to_string();
185 let bytes = response.bytes().await.unwrap_or_default();
186 let body = if Self::is_msgpack_content_type(&content_type) {
187 rmp_serde::from_slice::<ApiErrorBody>(&bytes).ok()
188 } else {
189 serde_json::from_slice::<ApiErrorBody>(&bytes).ok()
190 }
191 .unwrap_or_else(|| ApiErrorBody {
192 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
193 message: format!("HTTP {}", status_code),
194 details: None,
195 });
196 Err(Error::Api {
197 status: status_code,
198 body,
199 })
200 }
201 }
202
203 pub async fn create_trajectory(
209 &self,
210 request: &CreateTrajectoryRequest,
211 ) -> Result<TrajectoryResponse> {
212 let url = self.url("/api/v1/trajectories")?;
213 let resp = self.send(self.http.post(url).json(request)).await?;
214 self.parse_response(resp).await
215 }
216
217 pub async fn get_trajectory(&self, id: TrajectoryId) -> Result<TrajectoryResponse> {
219 let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
220 let resp = self.send(self.http.get(url)).await?;
221 self.parse_response(resp).await
222 }
223
224 pub async fn list_trajectories(
226 &self,
227 params: &ListParams,
228 ) -> Result<ListResponse<TrajectoryResponse>> {
229 let mut url = self.url("/api/v1/trajectories")?;
230 for (key, value) in params.to_query_pairs() {
231 url.query_pairs_mut().append_pair(key, &value);
232 }
233 let resp = self.send(self.http.get(url)).await?;
234 self.parse_response(resp).await
235 }
236
237 pub async fn update_trajectory(
239 &self,
240 id: TrajectoryId,
241 request: &UpdateTrajectoryRequest,
242 ) -> Result<TrajectoryResponse> {
243 let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
244 let resp = self.send(self.http.patch(url).json(request)).await?;
245 self.parse_response(resp).await
246 }
247
248 pub async fn delete_trajectory(&self, id: TrajectoryId) -> Result<()> {
250 let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
251 self.send(self.http.delete(url)).await?;
252 Ok(())
253 }
254
255 pub async fn create_scope(&self, request: &CreateScopeRequest) -> Result<ScopeResponse> {
261 let url = self.url("/api/v1/scopes")?;
262 let resp = self.send(self.http.post(url).json(request)).await?;
263 self.parse_response(resp).await
264 }
265
266 pub async fn get_scope(&self, id: ScopeId) -> Result<ScopeResponse> {
268 let url = self.url(&format!("/api/v1/scopes/{}", id))?;
269 let resp = self.send(self.http.get(url)).await?;
270 self.parse_response(resp).await
271 }
272
273 pub async fn list_scopes(&self, params: &ListParams) -> Result<ListResponse<ScopeResponse>> {
275 let mut url = self.url("/api/v1/scopes")?;
276 for (key, value) in params.to_query_pairs() {
277 url.query_pairs_mut().append_pair(key, &value);
278 }
279 let resp = self.send(self.http.get(url)).await?;
280 self.parse_response(resp).await
281 }
282
283 pub async fn update_scope(
285 &self,
286 id: ScopeId,
287 request: &UpdateScopeRequest,
288 ) -> Result<ScopeResponse> {
289 let url = self.url(&format!("/api/v1/scopes/{}", id))?;
290 let resp = self.send(self.http.patch(url).json(request)).await?;
291 self.parse_response(resp).await
292 }
293
294 pub async fn delete_scope(&self, id: ScopeId) -> Result<()> {
296 let url = self.url(&format!("/api/v1/scopes/{}", id))?;
297 self.send(self.http.delete(url)).await?;
298 Ok(())
299 }
300
301 pub async fn create_artifact(
307 &self,
308 request: &CreateArtifactRequest,
309 ) -> Result<ArtifactResponse> {
310 let url = self.url("/api/v1/artifacts")?;
311 let resp = self.send(self.http.post(url).json(request)).await?;
312 self.parse_response(resp).await
313 }
314
315 pub async fn get_artifact(&self, id: ArtifactId) -> Result<ArtifactResponse> {
317 let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
318 let resp = self.send(self.http.get(url)).await?;
319 self.parse_response(resp).await
320 }
321
322 pub async fn list_artifacts(
324 &self,
325 params: &ListParams,
326 ) -> Result<ListResponse<ArtifactResponse>> {
327 let mut url = self.url("/api/v1/artifacts")?;
328 for (key, value) in params.to_query_pairs() {
329 url.query_pairs_mut().append_pair(key, &value);
330 }
331 let resp = self.send(self.http.get(url)).await?;
332 self.parse_response(resp).await
333 }
334
335 pub async fn update_artifact(
337 &self,
338 id: ArtifactId,
339 request: &UpdateArtifactRequest,
340 ) -> Result<ArtifactResponse> {
341 let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
342 let resp = self.send(self.http.patch(url).json(request)).await?;
343 self.parse_response(resp).await
344 }
345
346 pub async fn delete_artifact(&self, id: ArtifactId) -> Result<()> {
348 let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
349 self.send(self.http.delete(url)).await?;
350 Ok(())
351 }
352
353 pub async fn create_note(&self, request: &CreateNoteRequest) -> Result<NoteResponse> {
359 let url = self.url("/api/v1/notes")?;
360 let resp = self.send(self.http.post(url).json(request)).await?;
361 self.parse_response(resp).await
362 }
363
364 pub async fn get_note(&self, id: NoteId) -> Result<NoteResponse> {
366 let url = self.url(&format!("/api/v1/notes/{}", id))?;
367 let resp = self.send(self.http.get(url)).await?;
368 self.parse_response(resp).await
369 }
370
371 pub async fn list_notes(&self, params: &ListParams) -> Result<ListResponse<NoteResponse>> {
373 let mut url = self.url("/api/v1/notes")?;
374 for (key, value) in params.to_query_pairs() {
375 url.query_pairs_mut().append_pair(key, &value);
376 }
377 let resp = self.send(self.http.get(url)).await?;
378 self.parse_response(resp).await
379 }
380
381 pub async fn update_note(
383 &self,
384 id: NoteId,
385 request: &UpdateNoteRequest,
386 ) -> Result<NoteResponse> {
387 let url = self.url(&format!("/api/v1/notes/{}", id))?;
388 let resp = self.send(self.http.patch(url).json(request)).await?;
389 self.parse_response(resp).await
390 }
391
392 pub async fn delete_note(&self, id: NoteId) -> Result<()> {
394 let url = self.url(&format!("/api/v1/notes/{}", id))?;
395 self.send(self.http.delete(url)).await?;
396 Ok(())
397 }
398
399 pub async fn create_turn(&self, request: &CreateTurnRequest) -> Result<TurnResponse> {
405 let url = self.url("/api/v1/turns")?;
406 let resp = self.send(self.http.post(url).json(request)).await?;
407 self.parse_response(resp).await
408 }
409
410 pub async fn get_turn(&self, id: TurnId) -> Result<TurnResponse> {
412 let url = self.url(&format!("/api/v1/turns/{}", id))?;
413 let resp = self.send(self.http.get(url)).await?;
414 self.parse_response(resp).await
415 }
416
417 pub async fn list_turns(&self, params: &ListParams) -> Result<ListResponse<TurnResponse>> {
419 let mut url = self.url("/api/v1/turns")?;
420 for (key, value) in params.to_query_pairs() {
421 url.query_pairs_mut().append_pair(key, &value);
422 }
423 let resp = self.send(self.http.get(url)).await?;
424 self.parse_response(resp).await
425 }
426
427 pub async fn create_agent(&self, request: &CreateAgentRequest) -> Result<AgentResponse> {
433 let url = self.url("/api/v1/agents")?;
434 let resp = self.send(self.http.post(url).json(request)).await?;
435 self.parse_response(resp).await
436 }
437
438 pub async fn get_agent(&self, id: AgentId) -> Result<AgentResponse> {
440 let url = self.url(&format!("/api/v1/agents/{}", id))?;
441 let resp = self.send(self.http.get(url)).await?;
442 self.parse_response(resp).await
443 }
444
445 pub async fn list_agents(&self, params: &ListParams) -> Result<ListResponse<AgentResponse>> {
447 let mut url = self.url("/api/v1/agents")?;
448 for (key, value) in params.to_query_pairs() {
449 url.query_pairs_mut().append_pair(key, &value);
450 }
451 let resp = self.send(self.http.get(url)).await?;
452 self.parse_response(resp).await
453 }
454
455 pub async fn commit_memory(
461 &self,
462 request: &CommitMemoryRequest,
463 ) -> Result<CommitMemoryResponse> {
464 let url = self.url("/api/v1/context/commit")?;
465 let resp = self.send(self.http.post(url).json(request)).await?;
466 self.parse_response(resp).await
467 }
468
469 pub async fn recall(&self, request: &RecallRequest) -> Result<RecallResponse> {
471 let url = self.url("/api/v1/context/recall")?;
472 let resp = self.send(self.http.post(url).json(request)).await?;
473 self.parse_response(resp).await
474 }
475
476 pub async fn assemble_context(
478 &self,
479 request: &AssembleContextRequest,
480 ) -> Result<AssembleContextResponse> {
481 let url = self.url("/api/v1/context/assemble")?;
482 let resp = self.send(self.http.post(url).json(request)).await?;
483 self.parse_response(resp).await
484 }
485
486 pub async fn health(&self) -> Result<bool> {
495 let url = self.url("/health/live")?;
496 let resp = self.auth(self.http.get(url)).send().await?;
497 Ok(resp.status().is_success())
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL shared by the tests that do not care about its exact shape.
    const BASE_URL: &str = "https://cst.batterypack.dev";

    /// Builds a client against `BASE_URL`, panicking if construction fails.
    fn client_with_key(api_key: &str) -> CellstateClient {
        CellstateClient::new(BASE_URL, api_key).unwrap()
    }

    #[test]
    fn test_client_creation() {
        assert!(CellstateClient::new(BASE_URL, "cst_test_key").is_ok());
    }

    #[test]
    fn test_client_default_response_format_is_json() {
        let client = client_with_key("cst_test_key");
        assert_eq!(client.response_format, ResponseFormat::Json);
    }

    #[test]
    fn test_client_with_response_format_overrides_default() {
        let client = client_with_key("cst_test_key").with_response_format(ResponseFormat::MsgPack);
        assert_eq!(client.response_format, ResponseFormat::MsgPack);
    }

    #[test]
    fn test_client_invalid_url() {
        let outcome = CellstateClient::new("not a url", "cst_test_key");
        assert!(matches!(outcome, Err(Error::InvalidUrl(_))));
    }

    #[test]
    fn test_url_construction() {
        let endpoint = client_with_key("cst_test")
            .url("/api/v1/trajectories")
            .unwrap();
        assert_eq!(
            endpoint.as_str(),
            "https://cst.batterypack.dev/api/v1/trajectories"
        );
    }

    #[test]
    fn test_url_construction_trailing_slash() {
        // The base URL deliberately ends with a slash here.
        let client = CellstateClient::new("https://cst.batterypack.dev/", "cst_test").unwrap();
        let endpoint = client.url("/api/v1/scopes").unwrap();
        assert_eq!(
            endpoint.as_str(),
            "https://cst.batterypack.dev/api/v1/scopes"
        );
    }

    #[test]
    fn test_list_params_empty() {
        let pairs = ListParams::default().to_query_pairs();
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_list_params_with_values() {
        let pairs = ListParams {
            limit: Some(10),
            offset: Some(20),
            cursor: Some("abc".to_string()),
        }
        .to_query_pairs();

        assert_eq!(pairs.len(), 3);
        assert_eq!(pairs[0], ("limit", "10".to_string()));
        assert_eq!(pairs[1], ("offset", "20".to_string()));
        assert_eq!(pairs[2], ("cursor", "abc".to_string()));
    }

    #[test]
    fn test_api_error_body_display() {
        let body = ApiErrorBody {
            code: "ENTITY_NOT_FOUND".to_string(),
            message: "Trajectory 123 not found".to_string(),
            details: None,
        };
        let rendered = format!("{}", body);
        assert_eq!(rendered, "[ENTITY_NOT_FOUND] Trajectory 123 not found");
    }

    #[test]
    fn test_api_error_body_serialization() {
        let details = serde_json::json!({"hint": "Check x-api-key header"});
        let body = ApiErrorBody {
            code: "UNAUTHORIZED".to_string(),
            message: "Invalid API key".to_string(),
            details: Some(details),
        };

        let encoded = serde_json::to_value(&body).unwrap();
        assert_eq!(encoded["code"], "UNAUTHORIZED");
        assert_eq!(encoded["message"], "Invalid API key");
        assert!(encoded["details"]["hint"].is_string());
    }

    #[test]
    fn test_client_debug_redacts_api_key() {
        let rendered = format!("{:?}", client_with_key("cst_real_secret"));
        assert!(rendered.contains("[REDACTED"));
        assert!(!rendered.contains("cst_real_secret"));
    }
}