1use crate::{ArtifactType, CellstateResult, EmbeddingVector, EnumParseError};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, OnceLock};
10
/// Output style requested for a summarization run.
///
/// Persisted via [`SummarizeStyle::as_db_str`] / [`SummarizeStyle::from_db_str`]
/// and serialized as snake_case on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum SummarizeStyle {
    /// Compact summary (the [`SummarizeConfig`] default).
    Brief,
    /// More thorough summary.
    Detailed,
    /// Summary with explicit structure.
    Structured,
}
27
28impl SummarizeStyle {
29 pub fn as_db_str(&self) -> &'static str {
31 match self {
32 Self::Brief => "brief",
33 Self::Detailed => "detailed",
34 Self::Structured => "structured",
35 }
36 }
37
38 pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
40 match s.to_lowercase().as_str() {
41 "brief" => Ok(Self::Brief),
42 "detailed" => Ok(Self::Detailed),
43 "structured" => Ok(Self::Structured),
44 _ => Err(EnumParseError {
45 enum_name: "SummarizeStyle",
46 input: s.to_string(),
47 }),
48 }
49 }
50}
51
/// Tuning options for a summarization request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct SummarizeConfig {
    /// Token budget for the produced summary (defaults to 256).
    pub max_tokens: i32,
    /// Desired output style (defaults to [`SummarizeStyle::Brief`]).
    pub style: SummarizeStyle,
}
61
62impl Default for SummarizeConfig {
63 fn default() -> Self {
64 Self {
65 max_tokens: 256,
66 style: SummarizeStyle::Brief,
67 }
68 }
69}
70
/// A feature an AI provider can offer.
///
/// Used as the payload of [`RoutingStrategy::Capability`] to restrict
/// routing to providers that advertise a given capability.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum ProviderCapability {
    /// Text embedding (see [`EmbeddingProvider`]).
    Embedding,
    /// Content summarization (see [`SummarizationProvider::summarize`]).
    Summarization,
    /// Artifact extraction (see [`SummarizationProvider::extract_artifacts`]).
    ArtifactExtraction,
    /// Contradiction detection (see [`SummarizationProvider::detect_contradiction`]).
    ContradictionDetection,
    /// Chat completion support.
    ChatCompletion,
    /// Chat session support.
    ChatSession,
}
93
94impl ProviderCapability {
95 pub fn as_db_str(&self) -> &'static str {
97 match self {
98 Self::Embedding => "embedding",
99 Self::Summarization => "summarization",
100 Self::ArtifactExtraction => "artifact_extraction",
101 Self::ContradictionDetection => "contradiction_detection",
102 Self::ChatCompletion => "chat_completion",
103 Self::ChatSession => "chat_session",
104 }
105 }
106
107 pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
109 match s.to_lowercase().replace('_', "").as_str() {
110 "embedding" => Ok(Self::Embedding),
111 "summarization" => Ok(Self::Summarization),
112 "artifactextraction" => Ok(Self::ArtifactExtraction),
113 "contradictiondetection" => Ok(Self::ContradictionDetection),
114 "chatcompletion" => Ok(Self::ChatCompletion),
115 "chatsession" => Ok(Self::ChatSession),
116 _ => Err(EnumParseError {
117 enum_name: "ProviderCapability",
118 input: s.to_string(),
119 }),
120 }
121 }
122}
123
/// State of a circuit breaker guarding a provider.
///
/// Discriminants are explicit (0/1/2) and stable; the `From<u8>` impl
/// below relies on them for round-tripping stored values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum CircuitState {
    /// Normal operation; requests pass through.
    Closed = 0,
    /// Breaker tripped; requests are short-circuited.
    Open = 1,
    /// Probing recovery before fully closing again.
    HalfOpen = 2,
}
140
141impl From<u8> for CircuitState {
142 fn from(v: u8) -> Self {
143 match v {
144 0 => CircuitState::Closed,
145 1 => CircuitState::Open,
146 _ => CircuitState::HalfOpen,
147 }
148 }
149}
150
151impl CircuitState {
152 pub fn as_db_str(&self) -> &'static str {
154 match self {
155 Self::Closed => "closed",
156 Self::Open => "open",
157 Self::HalfOpen => "half_open",
158 }
159 }
160
161 pub fn from_db_str(s: &str) -> Result<Self, EnumParseError> {
163 match s.to_lowercase().replace('_', "").as_str() {
164 "closed" => Ok(Self::Closed),
165 "open" => Ok(Self::Open),
166 "halfopen" => Ok(Self::HalfOpen),
167 _ => Err(EnumParseError {
168 enum_name: "CircuitState",
169 input: s.to_string(),
170 }),
171 }
172 }
173}
174
/// How a provider is chosen for a request.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub enum RoutingStrategy {
    /// Cycle through providers in turn (the default).
    #[default]
    RoundRobin,
    /// Prefer the lowest-latency provider.
    LeastLatency,
    /// Pick a provider at random.
    Random,
    /// Restrict to providers with the given capability.
    Capability(ProviderCapability),
    /// Always use the first provider.
    First,
}
196
197impl RoutingStrategy {
198 pub fn as_db_str(&self) -> &'static str {
200 match self {
201 Self::RoundRobin => "round_robin",
202 Self::LeastLatency => "least_latency",
203 Self::Random => "random",
204 Self::Capability(_) => "capability",
205 Self::First => "first",
206 }
207 }
208}
209
/// A single artifact pulled out of content by
/// [`SummarizationProvider::extract_artifacts`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct ExtractedArtifact {
    /// Category of the extracted artifact.
    pub artifact_type: ArtifactType,
    /// The extracted text itself.
    pub content: String,
    /// Provider-reported confidence; presumably in [0.0, 1.0] — TODO confirm
    /// against concrete provider implementations.
    pub confidence: f32,
}
225
/// A provider that turns text into embedding vectors.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single text.
    async fn embed(&self, text: &str) -> CellstateResult<EmbeddingVector>;

    /// Embeds several texts in one call.
    async fn embed_batch(&self, texts: &[&str]) -> CellstateResult<Vec<EmbeddingVector>>;

    /// Dimensionality of the vectors this provider produces.
    fn dimensions(&self) -> i32;

    /// Identifier of the underlying embedding model.
    fn model_id(&self) -> &str;
}
248
/// A provider that can summarize and analyze content.
#[async_trait]
pub trait SummarizationProvider: Send + Sync {
    /// Summarizes `content` according to `config` (style and token budget).
    async fn summarize(&self, content: &str, config: &SummarizeConfig) -> CellstateResult<String>;

    /// Extracts artifacts of the requested `types` from `content`.
    async fn extract_artifacts(
        &self,
        content: &str,
        types: &[ArtifactType],
    ) -> CellstateResult<Vec<ExtractedArtifact>>;

    /// Returns `true` when the provider judges `a` and `b` to contradict
    /// each other.
    async fn detect_contradiction(&self, a: &str, b: &str) -> CellstateResult<bool>;
}
267
/// Token counting and (optionally) encoding/decoding for a model family.
pub trait Tokenizer: Send + Sync {
    /// Number of tokens in `text`.
    fn count(&self, text: &str) -> i32;

    /// Name of the model family this tokenizer targets.
    fn model_family(&self) -> &str;

    /// Encodes `text` into token ids. Estimating implementations without a
    /// real vocabulary may return an empty vector (see [`HeuristicTokenizer`]).
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Decodes token ids back into text. Estimating implementations may
    /// return an empty string (see [`HeuristicTokenizer`]).
    fn decode(&self, tokens: &[u32]) -> String;
}
292
293pub fn estimate_tokens_heuristic(text: &str) -> i32 {
295 HeuristicTokenizer::default().count(text)
296}
297
/// Process-wide tokenizer consulted by [`estimate_tokens`]; installed at
/// most once via [`register_global_tokenizer`].
static GLOBAL_TOKENIZER: OnceLock<Arc<dyn Tokenizer>> = OnceLock::new();
299
300pub fn register_global_tokenizer(tokenizer: Arc<dyn Tokenizer>) -> bool {
304 GLOBAL_TOKENIZER.set(tokenizer).is_ok()
305}
306
307pub fn estimate_tokens(text: &str) -> i32 {
312 match GLOBAL_TOKENIZER.get() {
313 Some(tokenizer) => tokenizer.count(text),
314 None => estimate_tokens_heuristic(text),
315 }
316}
317
/// Character-ratio token estimator for when no real tokenizer is available.
///
/// `count` is `ceil(byte_length * ratio)`; `encode`/`decode` are stubs
/// that return empty values.
#[derive(Debug, Clone)]
pub struct HeuristicTokenizer {
    /// Estimated tokens per byte of input text.
    ratio: f32,
    /// Family label reported by `model_family()`.
    model_family: String,
}
329
330impl HeuristicTokenizer {
331 pub fn for_model(model: &str) -> Self {
335 let (ratio, family) = if model.contains("gpt-4") || model.contains("gpt-3.5") {
336 (0.25, "gpt")
338 } else if model.contains("claude") {
339 (0.28, "claude")
341 } else if model.contains("text-embedding") {
342 (0.25, "openai-embedding")
344 } else if model.contains("llama") || model.contains("mistral") {
345 (0.27, "open-source")
347 } else {
348 (0.30, "unknown")
350 };
351
352 Self {
353 ratio,
354 model_family: family.to_string(),
355 }
356 }
357
358 pub fn with_ratio(ratio: f32, model_family: impl Into<String>) -> Self {
360 Self {
361 ratio,
362 model_family: model_family.into(),
363 }
364 }
365
366 pub fn ratio(&self) -> f32 {
368 self.ratio
369 }
370}
371
impl Default for HeuristicTokenizer {
    /// Defaults to the GPT-4 heuristic (ratio 0.25, family "gpt").
    fn default() -> Self {
        Self::for_model("gpt-4")
    }
}
377
378impl Tokenizer for HeuristicTokenizer {
379 fn count(&self, text: &str) -> i32 {
380 (text.len() as f32 * self.ratio).ceil() as i32
382 }
383
384 fn model_family(&self) -> &str {
385 &self.model_family
386 }
387
388 fn encode(&self, _text: &str) -> Vec<u32> {
389 Vec::new()
391 }
392
393 fn decode(&self, _tokens: &[u32]) -> String {
394 String::new()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_summarize_style_roundtrip() {
        for style in [
            SummarizeStyle::Brief,
            SummarizeStyle::Detailed,
            SummarizeStyle::Structured,
        ] {
            let s = style.as_db_str();
            let parsed =
                SummarizeStyle::from_db_str(s).expect("SummarizeStyle roundtrip should succeed");
            assert_eq!(style, parsed);
        }
    }

    #[test]
    fn test_provider_capability_roundtrip() {
        for cap in [
            ProviderCapability::Embedding,
            ProviderCapability::Summarization,
            ProviderCapability::ArtifactExtraction,
            ProviderCapability::ContradictionDetection,
            ProviderCapability::ChatCompletion,
            ProviderCapability::ChatSession,
        ] {
            let s = cap.as_db_str();
            let parsed = ProviderCapability::from_db_str(s)
                .expect("ProviderCapability roundtrip should succeed");
            assert_eq!(cap, parsed);
        }
    }

    #[test]
    fn test_circuit_state_from_u8() {
        assert_eq!(CircuitState::from(0), CircuitState::Closed);
        assert_eq!(CircuitState::from(1), CircuitState::Open);
        assert_eq!(CircuitState::from(2), CircuitState::HalfOpen);
        // Unknown discriminants deliberately collapse to HalfOpen.
        assert_eq!(CircuitState::from(255), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_state_roundtrip() {
        for state in [
            CircuitState::Closed,
            CircuitState::Open,
            CircuitState::HalfOpen,
        ] {
            let s = state.as_db_str();
            let parsed =
                CircuitState::from_db_str(s).expect("CircuitState roundtrip should succeed");
            assert_eq!(state, parsed);
        }
    }

    // New: the parsers must reject garbage instead of defaulting.
    #[test]
    fn test_from_db_str_rejects_unknown() {
        assert!(SummarizeStyle::from_db_str("bogus").is_err());
        assert!(SummarizeStyle::from_db_str("").is_err());
        assert!(ProviderCapability::from_db_str("bogus").is_err());
        assert!(CircuitState::from_db_str("bogus").is_err());
    }

    // New: parsing is case-insensitive, and capability/state parsing also
    // ignores underscores — pin that normalization contract.
    #[test]
    fn test_from_db_str_normalization() {
        assert_eq!(
            SummarizeStyle::from_db_str("BRIEF").unwrap(),
            SummarizeStyle::Brief
        );
        assert_eq!(
            ProviderCapability::from_db_str("Chat_Completion").unwrap(),
            ProviderCapability::ChatCompletion
        );
        assert_eq!(
            ProviderCapability::from_db_str("chatcompletion").unwrap(),
            ProviderCapability::ChatCompletion
        );
        assert_eq!(
            CircuitState::from_db_str("HALFOPEN").unwrap(),
            CircuitState::HalfOpen
        );
    }

    #[test]
    fn test_summarize_config_default() {
        let config = SummarizeConfig::default();
        assert_eq!(config.max_tokens, 256);
        assert_eq!(config.style, SummarizeStyle::Brief);
    }

    #[test]
    fn test_routing_strategy_default() {
        assert_eq!(RoutingStrategy::default(), RoutingStrategy::RoundRobin);
    }

    #[test]
    fn test_heuristic_tokenizer_gpt4() {
        let tokenizer = HeuristicTokenizer::for_model("gpt-4");
        assert_eq!(tokenizer.model_family(), "gpt");
        assert_eq!(tokenizer.ratio(), 0.25);

        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 25);
    }

    #[test]
    fn test_heuristic_tokenizer_claude() {
        let tokenizer = HeuristicTokenizer::for_model("claude-3-opus");
        assert_eq!(tokenizer.model_family(), "claude");
        assert_eq!(tokenizer.ratio(), 0.28);

        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 28);
    }

    #[test]
    fn test_heuristic_tokenizer_unknown() {
        let tokenizer = HeuristicTokenizer::for_model("some-random-model");
        assert_eq!(tokenizer.model_family(), "unknown");
        assert_eq!(tokenizer.ratio(), 0.30);
    }

    #[test]
    fn test_heuristic_tokenizer_custom() {
        let tokenizer = HeuristicTokenizer::with_ratio(0.5, "custom");
        assert_eq!(tokenizer.model_family(), "custom");
        assert_eq!(tokenizer.ratio(), 0.5);

        let text = "a".repeat(100);
        assert_eq!(tokenizer.count(&text), 50);
    }

    #[test]
    fn test_tokenizer_trait_object() {
        let tokenizer: Box<dyn Tokenizer> = Box::new(HeuristicTokenizer::default());
        assert!(!tokenizer.model_family().is_empty());
    }
}