cellstate/
client.rs

1//! HTTP client for the CELLSTATE REST API.
2//!
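//! `CellstateClient` provides typed methods for entity CRUD operations
//! (trajectories, scopes, artifacts, notes, turns, agents), context assembly,
//! memory commits, recall queries, and a health check. Response bodies are
//! decoded (JSON or MessagePack, according to the response `Content-Type`) into
//! the corresponding `types::*` structs.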
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use cellstate::CellstateClient;
11//! use cellstate::types::CreateTrajectoryRequest;
12//!
13//! # async fn example() -> cellstate::error::Result<()> {
14//! let client = CellstateClient::new("https://cst.batterypack.dev", "cst_your_api_key")?;
15//!
16//! let trajectory = client.create_trajectory(&CreateTrajectoryRequest {
17//!     name: "Research task".to_string(),
18//!     description: Some("Investigating memory models".to_string()),
19//!     parent_trajectory_id: None,
20//!     agent_id: None,
21//!     metadata: None,
22//! }).await?;
23//!
24//! println!("Created trajectory: {}", trajectory.trajectory_id);
25//! # Ok(())
26//! # }
27//! ```
28
use reqwest::{
    header::{HeaderValue, ACCEPT, CONTENT_TYPE},
    Client, RequestBuilder, Response,
};
use serde::de::DeserializeOwned;
use url::Url;

use crate::error::{ApiErrorBody, Error, Result};
use crate::types::*;
use cellstate_core::{AgentId, ArtifactId, NoteId, ScopeId, SecretString, TrajectoryId, TurnId};
36
/// Preferred encoding for API response bodies.
///
/// The selected format only controls the `Accept` header sent with each
/// request; responses are decoded according to the `Content-Type` the server
/// actually returns, so requesting MessagePack still works against a server
/// that answers in JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ResponseFormat {
    /// Request standard JSON responses (the default).
    #[default]
    Json,
    /// Request MessagePack responses when the API supports it, listing JSON as
    /// a lower-priority fallback.
    MsgPack,
}
48
49impl ResponseFormat {
50    fn accept_header_value(self) -> &'static str {
51        match self {
52            Self::Json => "application/json",
53            Self::MsgPack => "application/msgpack, application/json;q=0.9",
54        }
55    }
56}
57
58#[derive(Clone)]
59pub struct CellstateClient {
60    http: Client,
61    base_url: Url,
62    api_key: SecretString,
63    response_format: ResponseFormat,
64}
65
66impl std::fmt::Debug for CellstateClient {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("CellstateClient")
69            .field("http", &"<reqwest::Client>")
70            .field("base_url", &self.base_url)
71            .field(
72                "api_key",
73                &format!("[REDACTED, {} chars]", self.api_key.len()),
74            )
75            .field("response_format", &self.response_format)
76            .finish()
77    }
78}
79
80impl CellstateClient {
81    /// Create a new client pointing at the given API base URL.
82    ///
83    /// # Arguments
84    /// - `base_url`: Full URL to the CELLSTATE API (e.g., `"https://cst.batterypack.dev"`)
85    /// - `api_key`: API key for authentication (e.g., `"cst_..."`)
86    ///
87    /// # Errors
88    /// Returns `Error::InvalidUrl` if `base_url` cannot be parsed.
89    pub fn new(base_url: &str, api_key: &str) -> Result<Self> {
90        let base_url = Url::parse(base_url)?;
91        let http = Client::builder()
92            .user_agent("cellstate-rs/0.1.0")
93            .build()
94            .map_err(Error::Http)?;
95
96        Ok(Self {
97            http,
98            base_url,
99            api_key: SecretString::new(api_key),
100            response_format: ResponseFormat::Json,
101        })
102    }
103
104    /// Create a client from environment variables.
105    ///
106    /// Reads `CELLSTATE_API_URL` (or `CELLSTATE_BASE_URL`) and `CELLSTATE_API_KEY`.
107    ///
108    /// # Errors
109    /// Returns `Error::InvalidUrl` if the URL env var is missing or invalid.
110    pub fn from_env() -> Result<Self> {
111        let base_url = std::env::var("CELLSTATE_API_URL")
112            .or_else(|_| std::env::var("CELLSTATE_BASE_URL"))
113            .unwrap_or_else(|_| "https://cst.batterypack.dev".to_string());
114        let api_key = std::env::var("CELLSTATE_API_KEY").unwrap_or_default();
115
116        Self::new(&base_url, &api_key)
117    }
118
119    /// Build the full URL for an API path.
120    fn url(&self, path: &str) -> Result<Url> {
121        let mut url = self.base_url.clone();
122        let base_path = url.path().trim_end_matches('/');
123        let suffix = path.trim_start_matches('/');
124        url.set_path(&format!("{base_path}/{suffix}"));
125        Ok(url)
126    }
127
128    /// Returns a cloned client configured to request the given response format.
129    pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
130        self.response_format = response_format;
131        self
132    }
133
134    /// Set the preferred response format for this client.
135    pub fn set_response_format(&mut self, response_format: ResponseFormat) {
136        self.response_format = response_format;
137    }
138
139    /// Attach authentication and common headers to a request.
140    fn auth(&self, builder: RequestBuilder) -> RequestBuilder {
141        builder
142            .header("x-api-key", self.api_key.expose_secret())
143            .header("accept", self.response_format.accept_header_value())
144    }
145
146    fn is_msgpack_content_type(content_type: &str) -> bool {
147        let lower = content_type.to_ascii_lowercase();
148        lower.contains("application/msgpack") || lower.contains("application/x-msgpack")
149    }
150
151    async fn parse_response<T>(&self, response: Response) -> Result<T>
152    where
153        T: DeserializeOwned,
154    {
155        let content_type = response
156            .headers()
157            .get(CONTENT_TYPE)
158            .and_then(|v| v.to_str().ok())
159            .unwrap_or("")
160            .to_string();
161        let bytes = response.bytes().await?;
162
163        if Self::is_msgpack_content_type(&content_type) {
164            rmp_serde::from_slice::<T>(&bytes).map_err(|e| Error::Decode(e.to_string()))
165        } else {
166            Ok(serde_json::from_slice::<T>(&bytes)?)
167        }
168    }
169
170    /// Execute a request and handle error responses.
171    async fn send(&self, builder: RequestBuilder) -> Result<Response> {
172        let response = self.auth(builder).send().await?;
173        let status = response.status();
174
175        if status.is_success() {
176            Ok(response)
177        } else {
178            let status_code = status.as_u16();
179            let content_type = response
180                .headers()
181                .get(CONTENT_TYPE)
182                .and_then(|v| v.to_str().ok())
183                .unwrap_or("")
184                .to_string();
185            let bytes = response.bytes().await.unwrap_or_default();
186            let body = if Self::is_msgpack_content_type(&content_type) {
187                rmp_serde::from_slice::<ApiErrorBody>(&bytes).ok()
188            } else {
189                serde_json::from_slice::<ApiErrorBody>(&bytes).ok()
190            }
191            .unwrap_or_else(|| ApiErrorBody {
192                code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
193                message: format!("HTTP {}", status_code),
194                details: None,
195            });
196            Err(Error::Api {
197                status: status_code,
198                body,
199            })
200        }
201    }
202
203    // ========================================================================
204    // TRAJECTORY CRUD
205    // ========================================================================
206
207    /// Create a new trajectory.
208    pub async fn create_trajectory(
209        &self,
210        request: &CreateTrajectoryRequest,
211    ) -> Result<TrajectoryResponse> {
212        let url = self.url("/api/v1/trajectories")?;
213        let resp = self.send(self.http.post(url).json(request)).await?;
214        self.parse_response(resp).await
215    }
216
217    /// Get a trajectory by ID.
218    pub async fn get_trajectory(&self, id: TrajectoryId) -> Result<TrajectoryResponse> {
219        let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
220        let resp = self.send(self.http.get(url)).await?;
221        self.parse_response(resp).await
222    }
223
224    /// List trajectories with optional pagination.
225    pub async fn list_trajectories(
226        &self,
227        params: &ListParams,
228    ) -> Result<ListResponse<TrajectoryResponse>> {
229        let mut url = self.url("/api/v1/trajectories")?;
230        for (key, value) in params.to_query_pairs() {
231            url.query_pairs_mut().append_pair(key, &value);
232        }
233        let resp = self.send(self.http.get(url)).await?;
234        self.parse_response(resp).await
235    }
236
237    /// Update a trajectory.
238    pub async fn update_trajectory(
239        &self,
240        id: TrajectoryId,
241        request: &UpdateTrajectoryRequest,
242    ) -> Result<TrajectoryResponse> {
243        let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
244        let resp = self.send(self.http.patch(url).json(request)).await?;
245        self.parse_response(resp).await
246    }
247
248    /// Delete a trajectory.
249    pub async fn delete_trajectory(&self, id: TrajectoryId) -> Result<()> {
250        let url = self.url(&format!("/api/v1/trajectories/{}", id))?;
251        self.send(self.http.delete(url)).await?;
252        Ok(())
253    }
254
255    // ========================================================================
256    // SCOPE CRUD
257    // ========================================================================
258
259    /// Create a new scope.
260    pub async fn create_scope(&self, request: &CreateScopeRequest) -> Result<ScopeResponse> {
261        let url = self.url("/api/v1/scopes")?;
262        let resp = self.send(self.http.post(url).json(request)).await?;
263        self.parse_response(resp).await
264    }
265
266    /// Get a scope by ID.
267    pub async fn get_scope(&self, id: ScopeId) -> Result<ScopeResponse> {
268        let url = self.url(&format!("/api/v1/scopes/{}", id))?;
269        let resp = self.send(self.http.get(url)).await?;
270        self.parse_response(resp).await
271    }
272
273    /// List scopes with optional pagination.
274    pub async fn list_scopes(&self, params: &ListParams) -> Result<ListResponse<ScopeResponse>> {
275        let mut url = self.url("/api/v1/scopes")?;
276        for (key, value) in params.to_query_pairs() {
277            url.query_pairs_mut().append_pair(key, &value);
278        }
279        let resp = self.send(self.http.get(url)).await?;
280        self.parse_response(resp).await
281    }
282
283    /// Update a scope.
284    pub async fn update_scope(
285        &self,
286        id: ScopeId,
287        request: &UpdateScopeRequest,
288    ) -> Result<ScopeResponse> {
289        let url = self.url(&format!("/api/v1/scopes/{}", id))?;
290        let resp = self.send(self.http.patch(url).json(request)).await?;
291        self.parse_response(resp).await
292    }
293
294    /// Delete a scope.
295    pub async fn delete_scope(&self, id: ScopeId) -> Result<()> {
296        let url = self.url(&format!("/api/v1/scopes/{}", id))?;
297        self.send(self.http.delete(url)).await?;
298        Ok(())
299    }
300
301    // ========================================================================
302    // ARTIFACT CRUD
303    // ========================================================================
304
305    /// Create a new artifact.
306    pub async fn create_artifact(
307        &self,
308        request: &CreateArtifactRequest,
309    ) -> Result<ArtifactResponse> {
310        let url = self.url("/api/v1/artifacts")?;
311        let resp = self.send(self.http.post(url).json(request)).await?;
312        self.parse_response(resp).await
313    }
314
315    /// Get an artifact by ID.
316    pub async fn get_artifact(&self, id: ArtifactId) -> Result<ArtifactResponse> {
317        let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
318        let resp = self.send(self.http.get(url)).await?;
319        self.parse_response(resp).await
320    }
321
322    /// List artifacts with optional pagination.
323    pub async fn list_artifacts(
324        &self,
325        params: &ListParams,
326    ) -> Result<ListResponse<ArtifactResponse>> {
327        let mut url = self.url("/api/v1/artifacts")?;
328        for (key, value) in params.to_query_pairs() {
329            url.query_pairs_mut().append_pair(key, &value);
330        }
331        let resp = self.send(self.http.get(url)).await?;
332        self.parse_response(resp).await
333    }
334
335    /// Update an artifact.
336    pub async fn update_artifact(
337        &self,
338        id: ArtifactId,
339        request: &UpdateArtifactRequest,
340    ) -> Result<ArtifactResponse> {
341        let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
342        let resp = self.send(self.http.patch(url).json(request)).await?;
343        self.parse_response(resp).await
344    }
345
346    /// Delete an artifact.
347    pub async fn delete_artifact(&self, id: ArtifactId) -> Result<()> {
348        let url = self.url(&format!("/api/v1/artifacts/{}", id))?;
349        self.send(self.http.delete(url)).await?;
350        Ok(())
351    }
352
353    // ========================================================================
354    // NOTE CRUD
355    // ========================================================================
356
357    /// Create a new note.
358    pub async fn create_note(&self, request: &CreateNoteRequest) -> Result<NoteResponse> {
359        let url = self.url("/api/v1/notes")?;
360        let resp = self.send(self.http.post(url).json(request)).await?;
361        self.parse_response(resp).await
362    }
363
364    /// Get a note by ID.
365    pub async fn get_note(&self, id: NoteId) -> Result<NoteResponse> {
366        let url = self.url(&format!("/api/v1/notes/{}", id))?;
367        let resp = self.send(self.http.get(url)).await?;
368        self.parse_response(resp).await
369    }
370
371    /// List notes with optional pagination.
372    pub async fn list_notes(&self, params: &ListParams) -> Result<ListResponse<NoteResponse>> {
373        let mut url = self.url("/api/v1/notes")?;
374        for (key, value) in params.to_query_pairs() {
375            url.query_pairs_mut().append_pair(key, &value);
376        }
377        let resp = self.send(self.http.get(url)).await?;
378        self.parse_response(resp).await
379    }
380
381    /// Update a note.
382    pub async fn update_note(
383        &self,
384        id: NoteId,
385        request: &UpdateNoteRequest,
386    ) -> Result<NoteResponse> {
387        let url = self.url(&format!("/api/v1/notes/{}", id))?;
388        let resp = self.send(self.http.patch(url).json(request)).await?;
389        self.parse_response(resp).await
390    }
391
392    /// Delete a note.
393    pub async fn delete_note(&self, id: NoteId) -> Result<()> {
394        let url = self.url(&format!("/api/v1/notes/{}", id))?;
395        self.send(self.http.delete(url)).await?;
396        Ok(())
397    }
398
399    // ========================================================================
400    // TURN CRUD
401    // ========================================================================
402
403    /// Create a new turn.
404    pub async fn create_turn(&self, request: &CreateTurnRequest) -> Result<TurnResponse> {
405        let url = self.url("/api/v1/turns")?;
406        let resp = self.send(self.http.post(url).json(request)).await?;
407        self.parse_response(resp).await
408    }
409
410    /// Get a turn by ID.
411    pub async fn get_turn(&self, id: TurnId) -> Result<TurnResponse> {
412        let url = self.url(&format!("/api/v1/turns/{}", id))?;
413        let resp = self.send(self.http.get(url)).await?;
414        self.parse_response(resp).await
415    }
416
417    /// List turns with optional pagination.
418    pub async fn list_turns(&self, params: &ListParams) -> Result<ListResponse<TurnResponse>> {
419        let mut url = self.url("/api/v1/turns")?;
420        for (key, value) in params.to_query_pairs() {
421            url.query_pairs_mut().append_pair(key, &value);
422        }
423        let resp = self.send(self.http.get(url)).await?;
424        self.parse_response(resp).await
425    }
426
427    // ========================================================================
428    // AGENT
429    // ========================================================================
430
431    /// Create a new agent.
432    pub async fn create_agent(&self, request: &CreateAgentRequest) -> Result<AgentResponse> {
433        let url = self.url("/api/v1/agents")?;
434        let resp = self.send(self.http.post(url).json(request)).await?;
435        self.parse_response(resp).await
436    }
437
438    /// Get an agent by ID.
439    pub async fn get_agent(&self, id: AgentId) -> Result<AgentResponse> {
440        let url = self.url(&format!("/api/v1/agents/{}", id))?;
441        let resp = self.send(self.http.get(url)).await?;
442        self.parse_response(resp).await
443    }
444
445    /// List agents with optional pagination.
446    pub async fn list_agents(&self, params: &ListParams) -> Result<ListResponse<AgentResponse>> {
447        let mut url = self.url("/api/v1/agents")?;
448        for (key, value) in params.to_query_pairs() {
449            url.query_pairs_mut().append_pair(key, &value);
450        }
451        let resp = self.send(self.http.get(url)).await?;
452        self.parse_response(resp).await
453    }
454
455    // ========================================================================
456    // MEMORY OPERATIONS
457    // ========================================================================
458
459    /// Commit a memory interaction (PCP memory commit).
460    pub async fn commit_memory(
461        &self,
462        request: &CommitMemoryRequest,
463    ) -> Result<CommitMemoryResponse> {
464        let url = self.url("/api/v1/context/commit")?;
465        let resp = self.send(self.http.post(url).json(request)).await?;
466        self.parse_response(resp).await
467    }
468
469    /// Recall previous interactions for a trajectory.
470    pub async fn recall(&self, request: &RecallRequest) -> Result<RecallResponse> {
471        let url = self.url("/api/v1/context/recall")?;
472        let resp = self.send(self.http.post(url).json(request)).await?;
473        self.parse_response(resp).await
474    }
475
476    /// Assemble a context window for a scope.
477    pub async fn assemble_context(
478        &self,
479        request: &AssembleContextRequest,
480    ) -> Result<AssembleContextResponse> {
481        let url = self.url("/api/v1/context/assemble")?;
482        let resp = self.send(self.http.post(url).json(request)).await?;
483        self.parse_response(resp).await
484    }
485
486    // ========================================================================
487    // HEALTH
488    // ========================================================================
489
490    /// Check API health.
491    ///
492    /// Returns `Ok(true)` if the API is healthy, `Ok(false)` if it returns
493    /// a non-success status, or `Err` on connection failure.
494    pub async fn health(&self) -> Result<bool> {
495        let url = self.url("/health/live")?;
496        let resp = self.auth(self.http.get(url)).send().await?;
497        Ok(resp.status().is_success())
498    }
499}
500
501// ============================================================================
502// TESTS
503// ============================================================================
504
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = CellstateClient::new("https://cst.batterypack.dev", "cst_test_key");
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_default_response_format_is_json() {
        let client = CellstateClient::new("https://cst.batterypack.dev", "cst_test_key").unwrap();
        assert_eq!(client.response_format, ResponseFormat::Json);
    }

    #[test]
    fn test_client_with_response_format_overrides_default() {
        let client = CellstateClient::new("https://cst.batterypack.dev", "cst_test_key")
            .unwrap()
            .with_response_format(ResponseFormat::MsgPack);
        assert_eq!(client.response_format, ResponseFormat::MsgPack);
    }

    #[test]
    fn test_client_invalid_url() {
        let client = CellstateClient::new("not a url", "cst_test_key");
        assert!(matches!(client, Err(Error::InvalidUrl(_))));
    }

    #[test]
    fn test_url_construction() {
        let client = CellstateClient::new("https://cst.batterypack.dev", "cst_test").unwrap();
        let url = client.url("/api/v1/trajectories").unwrap();
        assert_eq!(
            url.as_str(),
            "https://cst.batterypack.dev/api/v1/trajectories"
        );
    }

    #[test]
    fn test_url_construction_trailing_slash() {
        let client = CellstateClient::new("https://cst.batterypack.dev/", "cst_test").unwrap();
        let url = client.url("/api/v1/scopes").unwrap();
        assert_eq!(url.as_str(), "https://cst.batterypack.dev/api/v1/scopes");
    }

    #[test]
    fn test_url_construction_preserves_base_path() {
        // A path on the base URL is a prefix for every API path.
        let client = CellstateClient::new("https://example.com/prefix/", "cst_test").unwrap();
        let url = client.url("/api/v1/notes").unwrap();
        assert_eq!(url.as_str(), "https://example.com/prefix/api/v1/notes");

        let client = CellstateClient::new("https://example.com/prefix", "cst_test").unwrap();
        let url = client.url("api/v1/notes").unwrap();
        assert_eq!(url.as_str(), "https://example.com/prefix/api/v1/notes");
    }

    #[test]
    fn test_accept_header_values() {
        assert_eq!(ResponseFormat::Json.accept_header_value(), "application/json");
        assert_eq!(
            ResponseFormat::MsgPack.accept_header_value(),
            "application/msgpack, application/json;q=0.9"
        );
    }

    #[test]
    fn test_is_msgpack_content_type() {
        assert!(CellstateClient::is_msgpack_content_type("application/msgpack"));
        assert!(CellstateClient::is_msgpack_content_type(
            "Application/X-MsgPack; charset=binary"
        ));
        assert!(!CellstateClient::is_msgpack_content_type("application/json"));
        assert!(!CellstateClient::is_msgpack_content_type(""));
    }

    #[test]
    fn test_list_params_empty() {
        let params = ListParams::default();
        assert!(params.to_query_pairs().is_empty());
    }

    #[test]
    fn test_list_params_with_values() {
        let params = ListParams {
            limit: Some(10),
            offset: Some(20),
            cursor: Some("abc".to_string()),
        };
        let pairs = params.to_query_pairs();
        assert_eq!(pairs.len(), 3);
        assert_eq!(pairs[0], ("limit", "10".to_string()));
        assert_eq!(pairs[1], ("offset", "20".to_string()));
        assert_eq!(pairs[2], ("cursor", "abc".to_string()));
    }

    #[test]
    fn test_api_error_body_display() {
        let body = ApiErrorBody {
            code: "ENTITY_NOT_FOUND".to_string(),
            message: "Trajectory 123 not found".to_string(),
            details: None,
        };
        assert_eq!(
            format!("{}", body),
            "[ENTITY_NOT_FOUND] Trajectory 123 not found"
        );
    }

    #[test]
    fn test_api_error_body_serialization() {
        let body = ApiErrorBody {
            code: "UNAUTHORIZED".to_string(),
            message: "Invalid API key".to_string(),
            details: Some(serde_json::json!({"hint": "Check x-api-key header"})),
        };
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["code"], "UNAUTHORIZED");
        assert_eq!(json["message"], "Invalid API key");
        assert!(json["details"]["hint"].is_string());
    }

    #[test]
    fn test_client_debug_redacts_api_key() {
        let client =
            CellstateClient::new("https://cst.batterypack.dev", "cst_real_secret").unwrap();
        let dbg = format!("{client:?}");
        assert!(dbg.contains("[REDACTED"));
        assert!(!dbg.contains("cst_real_secret"));
    }
}