cellstate_core/
session.rs

1//! Session typestate for compile-time safety of stateful LLM session lifecycle.
2//!
3//! Uses the typestate pattern to make invalid state transitions uncompilable.
4//! Sessions track stateful LLM connections where only deltas (new messages)
5//! are sent after the initial context, avoiding the "Groundhog Day" problem
6//! of re-sending full context every tool call.
7//!
8//! # State Transition Diagram
9//!
10//! ```text
11//! Session::new() → Created ── activate() ──→ Active ── close() ──→ Closed (terminal)
12//!                                              │
13//!                                         record_delta() ↺
14//!                                              │
15//!                                         expire() ──→ Expired (terminal)
16//! ```
17
18use crate::{EnumParseError, SessionId, TenantId, Timestamp};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use std::marker::PhantomData;
22use std::str::FromStr;
23
24// ============================================================================
25// SESSION STATUS ENUM
26// ============================================================================
27
28/// Status of a stateful LLM session.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
32pub enum SessionStatus {
33    /// Session has been created but not yet activated
34    Created,
35    /// Session is active and accepting deltas
36    Active,
37    /// Session was closed normally
38    Closed,
39    /// Session expired due to TTL
40    Expired,
41}
42
43impl SessionStatus {
44    /// Convert to database string representation.
45    pub fn as_db_str(&self) -> &'static str {
46        match self {
47            SessionStatus::Created => "created",
48            SessionStatus::Active => "active",
49            SessionStatus::Closed => "closed",
50            SessionStatus::Expired => "expired",
51        }
52    }
53
54    /// Parse from database string representation.
55    pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
56        match s.to_lowercase().as_str() {
57            "created" => Ok(SessionStatus::Created),
58            "active" => Ok(SessionStatus::Active),
59            "closed" => Ok(SessionStatus::Closed),
60            "expired" => Ok(SessionStatus::Expired),
61            _ => Err(EnumParseError {
62                enum_name: "SessionStatus",
63                input: s.to_string(),
64            }),
65        }
66    }
67
68    /// Check if this is a terminal state (no further transitions possible).
69    pub fn is_terminal(&self) -> bool {
70        matches!(self, SessionStatus::Closed | SessionStatus::Expired)
71    }
72}
73
74impl fmt::Display for SessionStatus {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        write!(f, "{}", self.as_db_str())
77    }
78}
79
80impl FromStr for SessionStatus {
81    type Err = EnumParseError;
82
83    fn from_str(s: &str) -> Result<Self, Self::Err> {
84        Self::from_db_str(s)
85    }
86}
87
88// ============================================================================
89// SESSION DATA (internal storage, state-independent)
90// ============================================================================
91
92/// Internal data storage for a session, independent of typestate.
93/// This is what gets persisted to the database.
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
96pub struct SessionRecord {
97    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
98    pub session_id: SessionId,
99    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "uuid"))]
100    pub tenant_id: TenantId,
101    /// Provider that owns this session
102    pub provider_id: String,
103    /// Model locked to this session
104    pub model: String,
105    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "date-time"))]
106    pub created_at: Timestamp,
107    #[cfg_attr(feature = "openapi", schema(value_type = Option<String>, format = "date-time"))]
108    pub activated_at: Option<Timestamp>,
109    #[cfg_attr(feature = "openapi", schema(value_type = Option<String>, format = "date-time"))]
110    pub closed_at: Option<Timestamp>,
111    /// Time-to-live in seconds for this session
112    pub ttl_secs: u64,
113    /// Number of tool loop rounds completed
114    pub round_count: u64,
115    /// Token count of the initial context (first request)
116    pub initial_token_count: u64,
117    /// Cumulative token count of all deltas sent
118    pub delta_token_count: u64,
119    /// Arbitrary metadata
120    #[serde(default)]
121    pub metadata: serde_json::Value,
122}
123
124// ============================================================================
125// TYPESTATE MARKERS
126// ============================================================================
127
/// Typestate marker implemented by every session state.
///
/// The trait is sealed: its `private::Sealed` supertrait lives in a private module, so no
/// code outside this module can introduce additional states. The `Send + Sync`
/// supertraits keep the `PhantomData<S>` tag in `Session<S>` from removing its
/// thread-safety.
pub trait SessionState: private::Sealed + Send + Sync {}

/// State: the session exists but has not been activated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Created;
impl private::Sealed for Created {}
impl SessionState for Created {}

/// State: the session is active and accepting deltas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Active;
impl private::Sealed for Active {}
impl SessionState for Active {}

/// State: the session was closed normally (terminal).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Closed;
impl private::Sealed for Closed {}
impl SessionState for Closed {}

/// State: the session expired because its TTL elapsed (terminal).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionExpired;
impl private::Sealed for SessionExpired {}
impl SessionState for SessionExpired {}

mod private {
    /// Unnameable outside the parent module, which is what seals `SessionState`.
    pub trait Sealed {}
}
158
159// ============================================================================
160// SESSION TYPESTATE WRAPPER
161// ============================================================================
162
163/// A stateful LLM session with compile-time state tracking.
164///
165/// The type parameter `S` indicates the current state of the session.
166/// Methods are only available in appropriate states:
167/// - `Session<Created>`: Can be activated
168/// - `Session<Active>`: Can record deltas, close, or expire
169/// - `Session<Closed>`: Terminal, no further transitions
170/// - `Session<SessionExpired>`: Terminal, no further transitions
171#[derive(Debug, Clone)]
172pub struct Session<S: SessionState> {
173    data: SessionRecord,
174    _state: PhantomData<S>,
175}
176
177impl<S: SessionState> Session<S> {
178    /// Access the underlying session data (read-only).
179    pub fn data(&self) -> &SessionRecord {
180        &self.data
181    }
182
183    /// Get the session ID.
184    pub fn session_id(&self) -> SessionId {
185        self.data.session_id
186    }
187
188    /// Get the tenant ID.
189    pub fn tenant_id(&self) -> TenantId {
190        self.data.tenant_id
191    }
192
193    /// Get the provider ID.
194    pub fn provider_id(&self) -> &str {
195        &self.data.provider_id
196    }
197
198    /// Get the model.
199    pub fn model(&self) -> &str {
200        &self.data.model
201    }
202
203    /// Get when the session was created.
204    pub fn created_at(&self) -> Timestamp {
205        self.data.created_at
206    }
207
208    /// Get the round count.
209    pub fn round_count(&self) -> u64 {
210        self.data.round_count
211    }
212
213    /// Consume and return the underlying data (for serialization).
214    pub fn into_data(self) -> SessionRecord {
215        self.data
216    }
217}
218
219impl Session<Created> {
220    /// Create a new session in the Created state.
221    pub fn new(data: SessionRecord) -> Self {
222        Session {
223            data,
224            _state: PhantomData,
225        }
226    }
227
228    /// Activate the session.
229    ///
230    /// Transitions to `Session<Active>`.
231    /// Consumes the current session.
232    pub fn activate(mut self, activated_at: Timestamp) -> Session<Active> {
233        self.data.activated_at = Some(activated_at);
234        Session {
235            data: self.data,
236            _state: PhantomData,
237        }
238    }
239}
240
241impl Session<Active> {
242    /// Get when the session was activated.
243    pub fn activated_at(&self) -> Timestamp {
244        self.data
245            .activated_at
246            .expect("Active session must have activated_at")
247    }
248
249    /// Record a delta (new messages sent to the provider).
250    ///
251    /// This is a looping state — it does not consume self.
252    /// Increments round_count and accumulates delta_tokens.
253    pub fn record_delta(&mut self, delta_tokens: u64) {
254        self.data.round_count += 1;
255        self.data.delta_token_count += delta_tokens;
256    }
257
258    /// Calculate total tokens saved by using stateful sessions.
259    ///
260    /// Without sessions, each round would re-send the full initial context.
261    /// Savings = (round_count * initial_tokens) - delta_tokens
262    pub fn tokens_saved(&self) -> u64 {
263        let would_have_sent = self.data.round_count * self.data.initial_token_count;
264        would_have_sent.saturating_sub(self.data.delta_token_count)
265    }
266
267    /// Close the session normally.
268    ///
269    /// Transitions to `Session<Closed>` (terminal state).
270    /// Consumes the current session.
271    pub fn close(mut self, closed_at: Timestamp) -> Session<Closed> {
272        self.data.closed_at = Some(closed_at);
273        Session {
274            data: self.data,
275            _state: PhantomData,
276        }
277    }
278
279    /// Expire the session due to TTL.
280    ///
281    /// Transitions to `Session<SessionExpired>` (terminal state).
282    /// Consumes the current session.
283    pub fn expire(mut self, expired_at: Timestamp) -> Session<SessionExpired> {
284        self.data.closed_at = Some(expired_at);
285        Session {
286            data: self.data,
287            _state: PhantomData,
288        }
289    }
290}
291
292impl Session<Closed> {
293    /// Get when the session was closed.
294    pub fn closed_at(&self) -> Timestamp {
295        self.data
296            .closed_at
297            .expect("Closed session must have closed_at")
298    }
299}
300
301impl Session<SessionExpired> {
302    /// Get when the session expired.
303    pub fn expired_at(&self) -> Timestamp {
304        self.data
305            .closed_at
306            .expect("Expired session must have closed_at")
307    }
308}
309
310// ============================================================================
311// DATABASE BOUNDARY: STORED SESSION
312// ============================================================================
313
314/// A session as stored in the database (status-agnostic).
315///
316/// When loading from the database, we don't know the state at compile time.
317/// Use the `into_typed` method to validate and convert to a typed session.
318#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
319pub struct StoredSession {
320    pub data: SessionRecord,
321    pub status: SessionStatus,
322}
323
324/// Enum representing all possible runtime states of a session.
325/// Use this when you need to handle sessions loaded from the database.
326#[derive(Debug, Clone)]
327pub enum LoadedSession {
328    Created(Session<Created>),
329    Active(Session<Active>),
330    Closed(Session<Closed>),
331    Expired(Session<SessionExpired>),
332}
333
334impl StoredSession {
335    /// Convert to a typed session based on the stored status.
336    pub fn into_typed(self) -> LoadedSession {
337        match self.status {
338            SessionStatus::Created => LoadedSession::Created(Session {
339                data: self.data,
340                _state: PhantomData,
341            }),
342            SessionStatus::Active => LoadedSession::Active(Session {
343                data: self.data,
344                _state: PhantomData,
345            }),
346            SessionStatus::Closed => LoadedSession::Closed(Session {
347                data: self.data,
348                _state: PhantomData,
349            }),
350            SessionStatus::Expired => LoadedSession::Expired(Session {
351                data: self.data,
352                _state: PhantomData,
353            }),
354        }
355    }
356
357    /// Try to convert to a created session.
358    pub fn into_created(self) -> Result<Session<Created>, SessionStateError> {
359        if self.status != SessionStatus::Created {
360            return Err(SessionStateError::WrongState {
361                session_id: self.data.session_id,
362                expected: SessionStatus::Created,
363                actual: self.status,
364            });
365        }
366        Ok(Session {
367            data: self.data,
368            _state: PhantomData,
369        })
370    }
371
372    /// Try to convert to an active session.
373    pub fn into_active(self) -> Result<Session<Active>, SessionStateError> {
374        if self.status != SessionStatus::Active {
375            return Err(SessionStateError::WrongState {
376                session_id: self.data.session_id,
377                expected: SessionStatus::Active,
378                actual: self.status,
379            });
380        }
381        Ok(Session {
382            data: self.data,
383            _state: PhantomData,
384        })
385    }
386
387    /// Get the underlying data without state validation.
388    pub fn data(&self) -> &SessionRecord {
389        &self.data
390    }
391
392    /// Get the current status.
393    pub fn status(&self) -> SessionStatus {
394        self.status
395    }
396}
397
398/// Errors when transitioning session states.
399#[derive(Debug, Clone, PartialEq, Eq)]
400pub enum SessionStateError {
401    /// Session is not in the expected state.
402    WrongState {
403        session_id: SessionId,
404        expected: SessionStatus,
405        actual: SessionStatus,
406    },
407}
408
409impl fmt::Display for SessionStateError {
410    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
411        match self {
412            SessionStateError::WrongState {
413                session_id,
414                expected,
415                actual,
416            } => {
417                write!(
418                    f,
419                    "Session {} is in state {} but expected {}",
420                    session_id, actual, expected
421                )
422            }
423        }
424    }
425}
426
427impl std::error::Error for SessionStateError {}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityIdType;
    use chrono::Utc;

    /// Fresh record: not yet activated or closed, 5000 initial tokens, no deltas.
    fn make_session_data() -> SessionRecord {
        let now = Utc::now();
        SessionRecord {
            session_id: SessionId::now_v7(),
            tenant_id: TenantId::now_v7(),
            provider_id: "openai".to_string(),
            model: "gpt-4".to_string(),
            created_at: now,
            activated_at: None,
            closed_at: None,
            ttl_secs: 300,
            round_count: 0,
            initial_token_count: 5000,
            delta_token_count: 0,
            metadata: serde_json::Value::Object(serde_json::Map::new()),
        }
    }

    /// Wraps `make_session_data()` in a session and activates it at `activated_at`.
    fn activated_session(activated_at: Timestamp) -> Session<Active> {
        Session::<Created>::new(make_session_data()).activate(activated_at)
    }

    #[test]
    fn test_session_status_roundtrip() {
        for status in [
            SessionStatus::Created,
            SessionStatus::Active,
            SessionStatus::Closed,
            SessionStatus::Expired,
        ] {
            let db_str = status.as_db_str();
            let parsed =
                SessionStatus::from_db_str(db_str).expect("SessionStatus roundtrip should succeed");
            assert_eq!(status, parsed);
        }
    }

    #[test]
    fn test_session_status_parse_is_case_insensitive() {
        assert_eq!(
            SessionStatus::from_db_str("ACTIVE").expect("uppercase should parse"),
            SessionStatus::Active
        );
        assert_eq!(
            SessionStatus::from_db_str("Expired").expect("mixed case should parse"),
            SessionStatus::Expired
        );
    }

    #[test]
    fn test_session_status_parse_rejects_unknown() {
        for input in ["", "bogus", " active", "activated"] {
            let err =
                SessionStatus::from_db_str(input).expect_err("unknown status must not parse");
            assert_eq!(err.input, input);
        }
    }

    #[test]
    fn test_session_status_terminal() {
        assert!(!SessionStatus::Created.is_terminal());
        assert!(!SessionStatus::Active.is_terminal());
        assert!(SessionStatus::Closed.is_terminal());
        assert!(SessionStatus::Expired.is_terminal());
    }

    #[test]
    fn test_session_create_activate_delta_close() {
        let now = Utc::now();
        let data = make_session_data();
        let session = Session::<Created>::new(data);

        // Activate
        let mut active = session.activate(now);
        assert_eq!(active.activated_at(), now);
        assert_eq!(active.round_count(), 0);

        // Record 3 deltas
        active.record_delta(100);
        active.record_delta(120);
        active.record_delta(150);

        assert_eq!(active.round_count(), 3);
        assert_eq!(active.data().delta_token_count, 370);

        // Verify tokens saved: 3 rounds * 5000 initial = 15000; deltas = 370; saved = 14630
        assert_eq!(active.tokens_saved(), 14630);

        // Close
        let closed = active.close(now);
        assert_eq!(closed.closed_at(), now);
    }

    #[test]
    fn test_session_create_activate_expire() {
        let now = Utc::now();
        let expired = activated_session(now).expire(now);

        // Terminal — verify expired_at is set
        assert_eq!(expired.expired_at(), now);
    }

    #[test]
    fn test_tokens_saved_calculation() {
        let mut active = activated_session(Utc::now());

        // 3 rounds × 5000 initial = 15000 would-have-sent
        // deltas = 370 → saved = 14630
        active.record_delta(100);
        active.record_delta(120);
        active.record_delta(150);

        assert_eq!(active.tokens_saved(), 14630);
    }

    #[test]
    fn test_tokens_saved_zero_rounds() {
        let active = activated_session(Utc::now());

        // 0 rounds → 0 saved
        assert_eq!(active.tokens_saved(), 0);
    }

    #[test]
    fn test_stored_session_into_typed() {
        let data = make_session_data();

        // Created
        let stored = StoredSession {
            data: data.clone(),
            status: SessionStatus::Created,
        };
        assert!(matches!(stored.into_typed(), LoadedSession::Created(_)));

        // Active
        let stored = StoredSession {
            data: data.clone(),
            status: SessionStatus::Active,
        };
        assert!(matches!(stored.into_typed(), LoadedSession::Active(_)));

        // Closed
        let stored = StoredSession {
            data: data.clone(),
            status: SessionStatus::Closed,
        };
        assert!(matches!(stored.into_typed(), LoadedSession::Closed(_)));

        // Expired
        let stored = StoredSession {
            data,
            status: SessionStatus::Expired,
        };
        assert!(matches!(stored.into_typed(), LoadedSession::Expired(_)));
    }

    #[test]
    fn test_stored_session_wrong_state() {
        let data = make_session_data();
        let stored = StoredSession {
            data,
            status: SessionStatus::Active,
        };

        assert!(matches!(
            stored.into_created(),
            Err(SessionStateError::WrongState { .. })
        ));
    }

    #[test]
    fn test_stored_session_into_active_wrong_state() {
        let data = make_session_data();
        let session_id = data.session_id;
        let stored = StoredSession {
            data,
            status: SessionStatus::Closed,
        };

        let err = stored
            .into_active()
            .expect_err("closed session must not convert to active");
        assert_eq!(
            err,
            SessionStateError::WrongState {
                session_id,
                expected: SessionStatus::Active,
                actual: SessionStatus::Closed,
            }
        );
    }

    #[test]
    fn test_session_state_error_display() {
        let data = make_session_data();
        let session_id = data.session_id;
        let err = StoredSession {
            data,
            status: SessionStatus::Active,
        }
        .into_created()
        .expect_err("active session must not convert to created");

        assert_eq!(
            err.to_string(),
            format!(
                "Session {} is in state active but expected created",
                session_id
            )
        );
    }

    #[test]
    fn test_stored_session_into_active() {
        let data = make_session_data();
        let stored = StoredSession {
            data: data.clone(),
            status: SessionStatus::Active,
        };

        let active = stored.into_active().expect("should convert to active");
        assert_eq!(active.session_id(), data.session_id);
    }

    #[test]
    fn test_session_status_display_fromstr() {
        for status in [
            SessionStatus::Created,
            SessionStatus::Active,
            SessionStatus::Closed,
            SessionStatus::Expired,
        ] {
            let s = status.to_string();
            let parsed: SessionStatus = s.parse().expect("should parse from display string");
            assert_eq!(status, parsed);
        }
    }
}