cellstate_core/
llm.rs

1//! LLM-related primitive types and traits.
2//!
3//! Pure data types and interface definitions for LLM operations.
4//! Runtime orchestration (ProviderRegistry, CircuitBreaker) lives in crates/api/src/providers/.
5
6use crate::{ArtifactType, CellstateResult, EmbeddingVector, EnumParseError};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, OnceLock};
10
11// ============================================================================
12// SUMMARIZATION TYPES
13// ============================================================================
14
/// Style of summarization output.
///
/// Serialized as `snake_case` over the wire (serde) and persisted via
/// [`SummarizeStyle::as_db_str`] / [`SummarizeStyle::from_db_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum SummarizeStyle {
    /// Brief, high-level summary
    Brief,
    /// Detailed, comprehensive summary
    Detailed,
    /// Structured summary with sections
    Structured,
}
27
28impl SummarizeStyle {
29    /// Convert to database string representation.
30    pub fn as_db_str(&self) -> &'static str {
31        match self {
32            Self::Brief => "brief",
33            Self::Detailed => "detailed",
34            Self::Structured => "structured",
35        }
36    }
37
38    /// Parse from database string representation.
39    pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
40        match s.to_lowercase().as_str() {
41            "brief" => Ok(Self::Brief),
42            "detailed" => Ok(Self::Detailed),
43            "structured" => Ok(Self::Structured),
44            _ => Err(EnumParseError {
45                enum_name: "SummarizeStyle",
46                input: s.to_string(),
47            }),
48        }
49    }
50}
51
/// Configuration for summarization requests.
///
/// The `Default` impl yields a 256-token [`SummarizeStyle::Brief`] summary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct SummarizeConfig {
    /// Maximum tokens in the summary
    pub max_tokens: i32,
    /// Style of summary to generate
    pub style: SummarizeStyle,
}
61
62impl Default for SummarizeConfig {
63    fn default() -> Self {
64        Self {
65            max_tokens: 256,
66            style: SummarizeStyle::Brief,
67        }
68    }
69}
70
71// ============================================================================
72// PROVIDER CAPABILITY
73// ============================================================================
74
/// Capabilities a provider can offer.
///
/// Serialized as `snake_case` (serde) and persisted via
/// [`ProviderCapability::as_db_str`] / [`ProviderCapability::from_db_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum ProviderCapability {
    /// Generate embeddings
    Embedding,
    /// Generate summaries
    Summarization,
    /// Extract artifacts from content
    ArtifactExtraction,
    /// Detect contradictions between content
    ContradictionDetection,
    /// Chat completion (conversational LLM)
    ChatCompletion,
    /// Stateful chat sessions (provider maintains context server-side)
    ChatSession,
}
93
94impl ProviderCapability {
95    /// Convert to database string representation.
96    pub fn as_db_str(&self) -> &'static str {
97        match self {
98            Self::Embedding => "embedding",
99            Self::Summarization => "summarization",
100            Self::ArtifactExtraction => "artifact_extraction",
101            Self::ContradictionDetection => "contradiction_detection",
102            Self::ChatCompletion => "chat_completion",
103            Self::ChatSession => "chat_session",
104        }
105    }
106
107    /// Parse from database string representation.
108    pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
109        match s.to_lowercase().replace('_', "").as_str() {
110            "embedding" => Ok(Self::Embedding),
111            "summarization" => Ok(Self::Summarization),
112            "artifactextraction" => Ok(Self::ArtifactExtraction),
113            "contradictiondetection" => Ok(Self::ContradictionDetection),
114            "chatcompletion" => Ok(Self::ChatCompletion),
115            "chatsession" => Ok(Self::ChatSession),
116            _ => Err(EnumParseError {
117                enum_name: "ProviderCapability",
118                input: s.to_string(),
119            }),
120        }
121    }
122}
123
124// ============================================================================
125// CIRCUIT STATE
126// ============================================================================
127
/// Circuit breaker state.
///
/// Discriminants (0/1/2) line up with the `From<u8>` decoding below, so a
/// state stored as a raw `u8` decodes back to the same variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum CircuitState {
    /// Circuit is closed, requests flow normally
    Closed = 0,
    /// Circuit is open, requests are rejected
    Open = 1,
    /// Circuit is half-open, testing if service recovered
    HalfOpen = 2,
}
140
141impl From<u8> for CircuitState {
142    fn from(v: u8) -> Self {
143        match v {
144            0 => CircuitState::Closed,
145            1 => CircuitState::Open,
146            _ => CircuitState::HalfOpen,
147        }
148    }
149}
150
151impl CircuitState {
152    /// Convert to database string representation.
153    pub fn as_db_str(&self) -> &'static str {
154        match self {
155            Self::Closed => "closed",
156            Self::Open => "open",
157            Self::HalfOpen => "half_open",
158        }
159    }
160
161    /// Parse from database string representation.
162    pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
163        match s.to_lowercase().replace('_', "").as_str() {
164            "closed" => Ok(Self::Closed),
165            "open" => Ok(Self::Open),
166            "halfopen" => Ok(Self::HalfOpen),
167            _ => Err(EnumParseError {
168                enum_name: "CircuitState",
169                input: s.to_string(),
170            }),
171        }
172    }
173}
174
175// ============================================================================
176// ROUTING STRATEGY
177// ============================================================================
178
/// Strategy for routing requests to providers.
///
/// `Default` is [`RoutingStrategy::RoundRobin`]. `Capability` carries the
/// required [`ProviderCapability`] as a payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum RoutingStrategy {
    /// Round-robin between providers
    #[default]
    RoundRobin,
    /// Route to provider with lowest latency
    LeastLatency,
    /// Random selection
    Random,
    /// Route based on capability
    Capability(ProviderCapability),
    /// Always use first available provider
    First,
}
196
197impl RoutingStrategy {
198    /// Convert to database string representation.
199    pub fn as_db_str(&self) -> &'static str {
200        match self {
201            Self::RoundRobin => "round_robin",
202            Self::LeastLatency => "least_latency",
203            Self::Random => "random",
204            Self::Capability(_) => "capability",
205            Self::First => "first",
206        }
207    }
208}
209
210// ============================================================================
211// EXTRACTED ARTIFACT
212// ============================================================================
213
/// An artifact extracted from content.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct ExtractedArtifact {
    /// Type of artifact extracted
    pub artifact_type: ArtifactType,
    /// The extracted content
    pub content: String,
    /// Confidence score (0.0 - 1.0) — the range is documented, not enforced
    /// by this type; producers are expected to stay within it
    pub confidence: f32,
}
225
226// ============================================================================
227// PROVIDER TRAITS
228// ============================================================================
229
/// Async trait for embedding providers.
///
/// Implementations must be thread-safe (Send + Sync).
/// This is the interface definition only - implementations live in crates/api/src/providers/.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate an embedding for a single text.
    async fn embed(&self, text: &str) -> CellstateResult<EmbeddingVector>;

    /// Generate embeddings for multiple texts in a batch.
    ///
    /// NOTE(review): presumably returns one vector per input, in input
    /// order — confirm against implementations; the trait does not enforce it.
    async fn embed_batch(&self, texts: &[&str]) -> CellstateResult<Vec<EmbeddingVector>>;

    /// Get the number of dimensions this provider produces.
    fn dimensions(&self) -> i32;

    /// Get the model identifier for this provider.
    fn model_id(&self) -> &str;
}
248
/// Async trait for summarization providers.
///
/// This is the interface definition only - implementations live in crates/api/src/providers/.
#[async_trait]
pub trait SummarizationProvider: Send + Sync {
    /// Summarize content according to the provided configuration.
    async fn summarize(&self, content: &str, config: &SummarizeConfig) -> CellstateResult<String>;

    /// Extract artifacts of specified types from content.
    ///
    /// `types` restricts which artifact kinds the provider should look for.
    async fn extract_artifacts(
        &self,
        content: &str,
        types: &[ArtifactType],
    ) -> CellstateResult<Vec<ExtractedArtifact>>;

    /// Detect if two pieces of content contradict each other.
    ///
    /// NOTE(review): `true` presumably means "contradiction found" — the
    /// trait itself does not pin this down; confirm against implementations.
    async fn detect_contradiction(&self, a: &str, b: &str) -> CellstateResult<bool>;
}
267
268// ============================================================================
269// TOKENIZER TRAIT
270// ============================================================================
271
/// Trait for counting tokens in text.
///
/// Used for token budget management in context assembly.
/// Implementations can provide exact counts (using actual tokenizer)
/// or heuristic estimates based on character ratios.
pub trait Tokenizer: Send + Sync {
    /// Count tokens in the given text.
    fn count(&self, text: &str) -> i32;

    /// Get the model family this tokenizer is for (e.g., "gpt-4", "claude").
    fn model_family(&self) -> &str;

    /// Encode text to token IDs (for advanced use cases).
    /// Returns empty vec if not supported.
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Decode token IDs back to text.
    /// Returns empty string if not supported.
    fn decode(&self, tokens: &[u32]) -> String;
}
292
293/// Estimate token count using the default heuristic tokenizer.
294pub fn estimate_tokens_heuristic(text: &str) -> i32 {
295    HeuristicTokenizer::default().count(text)
296}
297
298static GLOBAL_TOKENIZER: OnceLock<Arc<dyn Tokenizer>> = OnceLock::new();
299
300/// Register a global tokenizer used by [`estimate_tokens`].
301///
302/// Returns `false` when a tokenizer has already been registered.
303pub fn register_global_tokenizer(tokenizer: Arc<dyn Tokenizer>) -> bool {
304    GLOBAL_TOKENIZER.set(tokenizer).is_ok()
305}
306
307/// Estimate token count using the configured tokenizer if present.
308///
309/// Falls back to [`estimate_tokens_heuristic`] when no global tokenizer has
310/// been registered.
311pub fn estimate_tokens(text: &str) -> i32 {
312    match GLOBAL_TOKENIZER.get() {
313        Some(tokenizer) => tokenizer.count(text),
314        None => estimate_tokens_heuristic(text),
315    }
316}
317
/// Heuristic tokenizer using character-to-token ratios.
///
/// This provides fast, approximate token counts without requiring
/// an actual tokenizer model. Good for quick estimates.
#[derive(Debug, Clone)]
pub struct HeuristicTokenizer {
    /// Tokens per character ratio (model-specific)
    ratio: f32,
    /// Model family identifier
    model_family: String,
}
329
330impl HeuristicTokenizer {
331    /// Create a new heuristic tokenizer for a specific model.
332    ///
333    /// Uses empirically-derived ratios based on model family.
334    pub fn for_model(model: &str) -> Self {
335        let (ratio, family) = if model.contains("gpt-4") || model.contains("gpt-3.5") {
336            // GPT models: ~4 characters per token on average
337            (0.25, "gpt")
338        } else if model.contains("claude") {
339            // Claude models: slightly higher token density
340            (0.28, "claude")
341        } else if model.contains("text-embedding") {
342            // OpenAI embedding models
343            (0.25, "openai-embedding")
344        } else if model.contains("llama") || model.contains("mistral") {
345            // Open source models vary more
346            (0.27, "open-source")
347        } else {
348            // Conservative default
349            (0.30, "unknown")
350        };
351
352        Self {
353            ratio,
354            model_family: family.to_string(),
355        }
356    }
357
358    /// Create with a custom ratio.
359    pub fn with_ratio(ratio: f32, model_family: impl Into<String>) -> Self {
360        Self {
361            ratio,
362            model_family: model_family.into(),
363        }
364    }
365
366    /// Get the current ratio.
367    pub fn ratio(&self) -> f32 {
368        self.ratio
369    }
370}
371
372impl Default for HeuristicTokenizer {
373    fn default() -> Self {
374        Self::for_model("gpt-4")
375    }
376}
377
378impl Tokenizer for HeuristicTokenizer {
379    fn count(&self, text: &str) -> i32 {
380        // Multiply character count by ratio
381        (text.len() as f32 * self.ratio).ceil() as i32
382    }
383
384    fn model_family(&self) -> &str {
385        &self.model_family
386    }
387
388    fn encode(&self, _text: &str) -> Vec<u32> {
389        // Heuristic tokenizer doesn't support encoding
390        Vec::new()
391    }
392
393    fn decode(&self, _tokens: &[u32]) -> String {
394        // Heuristic tokenizer doesn't support decoding
395        String::new()
396    }
397}
398
399// ============================================================================
400// TESTS
401// ============================================================================
402
#[cfg(test)]
mod tests {
    use super::*;

    // Every variant must survive an as_db_str -> from_db_str round trip.
    #[test]
    fn test_summarize_style_roundtrip() {
        for style in [
            SummarizeStyle::Brief,
            SummarizeStyle::Detailed,
            SummarizeStyle::Structured,
        ] {
            let s = style.as_db_str();
            let parsed =
                SummarizeStyle::from_db_str(s).expect("SummarizeStyle roundtrip should succeed");
            assert_eq!(style, parsed);
        }
    }

    #[test]
    fn test_provider_capability_roundtrip() {
        for cap in [
            ProviderCapability::Embedding,
            ProviderCapability::Summarization,
            ProviderCapability::ArtifactExtraction,
            ProviderCapability::ContradictionDetection,
            ProviderCapability::ChatCompletion,
            ProviderCapability::ChatSession,
        ] {
            let s = cap.as_db_str();
            let parsed = ProviderCapability::from_db_str(s)
                .expect("ProviderCapability roundtrip should succeed");
            assert_eq!(cap, parsed);
        }
    }

    // Unknown u8 values (e.g. 255) deliberately map to HalfOpen.
    #[test]
    fn test_circuit_state_from_u8() {
        assert_eq!(CircuitState::from(0), CircuitState::Closed);
        assert_eq!(CircuitState::from(1), CircuitState::Open);
        assert_eq!(CircuitState::from(2), CircuitState::HalfOpen);
        assert_eq!(CircuitState::from(255), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_state_roundtrip() {
        for state in [
            CircuitState::Closed,
            CircuitState::Open,
            CircuitState::HalfOpen,
        ] {
            let s = state.as_db_str();
            let parsed =
                CircuitState::from_db_str(s).expect("CircuitState roundtrip should succeed");
            assert_eq!(state, parsed);
        }
    }

    #[test]
    fn test_summarize_config_default() {
        let config = SummarizeConfig::default();
        assert_eq!(config.max_tokens, 256);
        assert_eq!(config.style, SummarizeStyle::Brief);
    }

    #[test]
    fn test_routing_strategy_default() {
        assert_eq!(RoutingStrategy::default(), RoutingStrategy::RoundRobin);
    }

    // Ratios below are exact f32 literals copied from for_model, so
    // equality comparison is deterministic.
    #[test]
    fn test_heuristic_tokenizer_gpt4() {
        let tokenizer = HeuristicTokenizer::for_model("gpt-4");
        assert_eq!(tokenizer.model_family(), "gpt");
        assert_eq!(tokenizer.ratio(), 0.25);

        // 100 chars * 0.25 = 25 tokens
        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 25);
    }

    #[test]
    fn test_heuristic_tokenizer_claude() {
        let tokenizer = HeuristicTokenizer::for_model("claude-3-opus");
        assert_eq!(tokenizer.model_family(), "claude");
        assert_eq!(tokenizer.ratio(), 0.28);

        // 100 chars * 0.28 = 28 tokens
        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 28);
    }

    #[test]
    fn test_heuristic_tokenizer_unknown() {
        let tokenizer = HeuristicTokenizer::for_model("some-random-model");
        assert_eq!(tokenizer.model_family(), "unknown");
        assert_eq!(tokenizer.ratio(), 0.30);
    }

    #[test]
    fn test_heuristic_tokenizer_custom() {
        let tokenizer = HeuristicTokenizer::with_ratio(0.5, "custom");
        assert_eq!(tokenizer.model_family(), "custom");
        assert_eq!(tokenizer.ratio(), 0.5);

        // 100 chars * 0.5 = 50 tokens
        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 50);
    }

    #[test]
    fn test_tokenizer_trait_object() {
        // Verify it can be used as a trait object
        let tokenizer: Box<dyn Tokenizer> = Box::new(HeuristicTokenizer::default());
        assert!(!tokenizer.model_family().is_empty());
    }
}