1use crate::*;
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
10pub struct SectionPriorities {
11 pub user: i32,
12 pub system: i32,
13 pub persona: i32,
14 pub artifacts: i32,
15 pub notes: i32,
16 pub history: i32,
17 #[cfg_attr(feature = "openapi", schema(value_type = Vec<Vec<Object>>))]
18 pub custom: Vec<(String, i32)>,
19}
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
25pub enum ContextPersistence {
26 Ephemeral,
28 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
30 Ttl(Duration),
31 Permanent,
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
39pub enum ValidationMode {
40 OnMutation,
42 Always,
44}
45
46#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
48#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
49pub struct ProviderConfig {
50 pub provider_type: ProviderType,
51 pub endpoint: Option<String>,
52 pub model: String,
53 pub dimensions: Option<i32>,
54}
55
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
58#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
59pub struct RetryConfig {
60 pub max_retries: i32,
61 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
63 pub initial_backoff: Duration,
64 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
66 pub max_backoff: Duration,
67 pub backoff_multiplier: f32,
68}
69
70#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
73#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
74pub struct CellstateConfig {
75 pub token_budget: i32,
77 pub section_priorities: SectionPriorities,
78
79 pub checkpoint_retention: i32,
81 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
83 pub stale_threshold: Duration,
84 pub contradiction_threshold: f32,
85
86 pub context_window_persistence: ContextPersistence,
88 pub validation_mode: ValidationMode,
89
90 pub embedding_provider: Option<ProviderConfig>,
92 pub summarization_provider: Option<ProviderConfig>,
93 pub llm_retry_config: RetryConfig,
94
95 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
98 pub lock_timeout: Duration,
99 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
101 pub message_retention: Duration,
102 #[cfg_attr(feature = "openapi", schema(value_type = u64))]
104 pub delegation_timeout: Duration,
105}
106
107impl CellstateConfig {
108 pub fn default_context(token_budget: i32) -> Self {
113 Self {
114 token_budget,
115 section_priorities: SectionPriorities {
116 user: 100,
117 system: 90,
118 persona: 85,
119 artifacts: 80,
120 notes: 70,
121 history: 60,
122 custom: vec![],
123 },
124 checkpoint_retention: 10,
125 stale_threshold: Duration::from_secs(3600),
126 contradiction_threshold: 0.8,
127 context_window_persistence: ContextPersistence::Ephemeral,
128 validation_mode: ValidationMode::OnMutation,
129 embedding_provider: None,
130 summarization_provider: None,
131 llm_retry_config: RetryConfig {
132 max_retries: 3,
133 initial_backoff: Duration::from_millis(100),
134 max_backoff: Duration::from_secs(10),
135 backoff_multiplier: 2.0,
136 },
137 lock_timeout: Duration::from_secs(30),
138 message_retention: Duration::from_secs(86400),
139 delegation_timeout: Duration::from_secs(300),
140 }
141 }
142
143 pub fn validate(&self) -> CellstateResult<()> {
152 if self.token_budget <= 0 {
154 return Err(CellstateError::Config(ConfigError::InvalidValue {
155 field: "token_budget".to_string(),
156 value: self.token_budget.to_string(),
157 reason: "token_budget must be greater than 0".to_string(),
158 }));
159 }
160
161 if self.contradiction_threshold < 0.0 || self.contradiction_threshold > 1.0 {
163 return Err(CellstateError::Config(ConfigError::InvalidValue {
164 field: "contradiction_threshold".to_string(),
165 value: self.contradiction_threshold.to_string(),
166 reason: "contradiction_threshold must be between 0.0 and 1.0".to_string(),
167 }));
168 }
169
170 if self.checkpoint_retention < 0 {
172 return Err(CellstateError::Config(ConfigError::InvalidValue {
173 field: "checkpoint_retention".to_string(),
174 value: self.checkpoint_retention.to_string(),
175 reason: "checkpoint_retention must be non-negative".to_string(),
176 }));
177 }
178
179 if self.stale_threshold.is_zero() {
181 return Err(CellstateError::Config(ConfigError::InvalidValue {
182 field: "stale_threshold".to_string(),
183 value: format!("{:?}", self.stale_threshold),
184 reason: "stale_threshold must be positive".to_string(),
185 }));
186 }
187
188 if self.lock_timeout.is_zero() {
190 return Err(CellstateError::Config(ConfigError::InvalidValue {
191 field: "lock_timeout".to_string(),
192 value: format!("{:?}", self.lock_timeout),
193 reason: "lock_timeout must be positive".to_string(),
194 }));
195 }
196
197 if self.message_retention.is_zero() {
199 return Err(CellstateError::Config(ConfigError::InvalidValue {
200 field: "message_retention".to_string(),
201 value: format!("{:?}", self.message_retention),
202 reason: "message_retention must be positive".to_string(),
203 }));
204 }
205
206 if self.delegation_timeout.is_zero() {
208 return Err(CellstateError::Config(ConfigError::InvalidValue {
209 field: "delegation_timeout".to_string(),
210 value: format!("{:?}", self.delegation_timeout),
211 reason: "delegation_timeout must be positive".to_string(),
212 }));
213 }
214
215 if self.llm_retry_config.max_retries < 0 {
217 return Err(CellstateError::Config(ConfigError::InvalidValue {
218 field: "llm_retry_config.max_retries".to_string(),
219 value: self.llm_retry_config.max_retries.to_string(),
220 reason: "max_retries must be non-negative".to_string(),
221 }));
222 }
223
224 if self.llm_retry_config.backoff_multiplier <= 0.0 {
225 return Err(CellstateError::Config(ConfigError::InvalidValue {
226 field: "llm_retry_config.backoff_multiplier".to_string(),
227 value: self.llm_retry_config.backoff_multiplier.to_string(),
228 reason: "backoff_multiplier must be positive".to_string(),
229 }));
230 }
231
232 Ok(())
233 }
234}
235
236#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
246pub struct ContextAssemblyDefaults {
247 pub rest_token_budget: i32,
249 pub max_notes: usize,
251 pub max_artifacts: usize,
253 pub max_turns: usize,
255 pub max_summaries: usize,
257}
258
259impl Default for ContextAssemblyDefaults {
260 fn default() -> Self {
261 Self {
262 rest_token_budget: 8000,
263 max_notes: 10,
264 max_artifacts: 5,
265 max_turns: 20,
266 max_summaries: 5,
267 }
268 }
269}
270
271impl ContextAssemblyDefaults {
272 pub fn from_env() -> Self {
281 let defaults = Self::default();
282
283 Self {
284 rest_token_budget: std::env::var("CELLSTATE_CONTEXT_REST_TOKEN_BUDGET")
285 .ok()
286 .and_then(|s| s.parse().ok())
287 .unwrap_or(defaults.rest_token_budget),
288 max_notes: std::env::var("CELLSTATE_CONTEXT_MAX_NOTES")
289 .ok()
290 .and_then(|s| s.parse().ok())
291 .unwrap_or(defaults.max_notes),
292 max_artifacts: std::env::var("CELLSTATE_CONTEXT_MAX_ARTIFACTS")
293 .ok()
294 .and_then(|s| s.parse().ok())
295 .unwrap_or(defaults.max_artifacts),
296 max_turns: std::env::var("CELLSTATE_CONTEXT_MAX_TURNS")
297 .ok()
298 .and_then(|s| s.parse().ok())
299 .unwrap_or(defaults.max_turns),
300 max_summaries: std::env::var("CELLSTATE_CONTEXT_MAX_SUMMARIES")
301 .ok()
302 .and_then(|s| s.parse().ok())
303 .unwrap_or(defaults.max_summaries),
304 }
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variables read by `ContextAssemblyDefaults::from_env`.
    const CONTEXT_ENV_VARS: [&str; 5] = [
        "CELLSTATE_CONTEXT_REST_TOKEN_BUDGET",
        "CELLSTATE_CONTEXT_MAX_NOTES",
        "CELLSTATE_CONTEXT_MAX_ARTIFACTS",
        "CELLSTATE_CONTEXT_MAX_TURNS",
        "CELLSTATE_CONTEXT_MAX_SUMMARIES",
    ];

    /// A configuration that is expected to pass validation.
    fn valid_config() -> CellstateConfig {
        CellstateConfig::default_context(8000)
    }

    // ---- default_context ----------------------------------------------------

    #[test]
    fn default_context_passes_validation() {
        assert!(valid_config().validate().is_ok());
    }

    #[test]
    fn default_context_has_positive_token_budget() {
        assert!(valid_config().token_budget > 0);
    }

    #[test]
    fn default_context_uses_given_token_budget() {
        assert_eq!(CellstateConfig::default_context(1234).token_budget, 1234);
    }

    #[test]
    fn default_context_priorities_are_ordered() {
        let c = valid_config();
        let p = &c.section_priorities;
        assert!(p.user > p.system);
        assert!(p.system > p.persona);
        assert!(p.persona > p.artifacts);
        assert!(p.artifacts > p.notes);
        assert!(p.notes > p.history);
    }

    // ---- validate -----------------------------------------------------------

    #[test]
    fn validate_rejects_zero_token_budget() {
        let mut c = valid_config();
        c.token_budget = 0;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_token_budget() {
        let mut c = valid_config();
        c.token_budget = -1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_contradiction_threshold_above_one() {
        let mut c = valid_config();
        c.contradiction_threshold = 1.1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_contradiction_threshold() {
        let mut c = valid_config();
        c.contradiction_threshold = -0.1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_accepts_boundary_contradiction_thresholds() {
        let mut c = valid_config();
        c.contradiction_threshold = 0.0;
        assert!(c.validate().is_ok());
        c.contradiction_threshold = 1.0;
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_negative_checkpoint_retention() {
        let mut c = valid_config();
        c.checkpoint_retention = -1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_accepts_zero_checkpoint_retention() {
        let mut c = valid_config();
        c.checkpoint_retention = 0;
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_zero_stale_threshold() {
        let mut c = valid_config();
        c.stale_threshold = Duration::ZERO;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_zero_lock_timeout() {
        let mut c = valid_config();
        c.lock_timeout = Duration::ZERO;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_zero_message_retention() {
        let mut c = valid_config();
        c.message_retention = Duration::ZERO;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_zero_delegation_timeout() {
        let mut c = valid_config();
        c.delegation_timeout = Duration::ZERO;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_max_retries() {
        let mut c = valid_config();
        c.llm_retry_config.max_retries = -1;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_accepts_zero_max_retries() {
        let mut c = valid_config();
        c.llm_retry_config.max_retries = 0;
        assert!(c.validate().is_ok());
    }

    #[test]
    fn validate_rejects_zero_backoff_multiplier() {
        let mut c = valid_config();
        c.llm_retry_config.backoff_multiplier = 0.0;
        assert!(c.validate().is_err());
    }

    #[test]
    fn validate_rejects_negative_backoff_multiplier() {
        let mut c = valid_config();
        c.llm_retry_config.backoff_multiplier = -1.0;
        assert!(c.validate().is_err());
    }

    // ---- serde --------------------------------------------------------------

    #[test]
    fn config_serde_roundtrip() {
        let c = valid_config();
        let json = serde_json::to_string(&c).unwrap();
        let deserialized: CellstateConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(c, deserialized);
    }

    #[test]
    fn config_with_custom_sections_serde_roundtrip() {
        let mut c = valid_config();
        c.section_priorities.custom = vec![("tools".to_string(), 75), ("memory".to_string(), 65)];
        c.context_window_persistence = ContextPersistence::Ttl(Duration::from_millis(1500));
        let json = serde_json::to_string(&c).unwrap();
        let deserialized: CellstateConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(c, deserialized);
    }

    #[test]
    fn context_persistence_serde_roundtrip() {
        let variants = vec![
            ContextPersistence::Ephemeral,
            ContextPersistence::Ttl(Duration::from_secs(3600)),
            ContextPersistence::Permanent,
        ];
        for v in variants {
            let json = serde_json::to_string(&v).unwrap();
            let d: ContextPersistence = serde_json::from_str(&json).unwrap();
            assert_eq!(v, d);
        }
    }

    /// Pins the JSON wire format, including serde's `{secs, nanos}` encoding of `Duration`.
    #[test]
    fn context_persistence_wire_format() {
        assert_eq!(
            serde_json::to_string(&ContextPersistence::Ephemeral).unwrap(),
            "\"ephemeral\""
        );
        assert_eq!(
            serde_json::to_string(&ContextPersistence::Permanent).unwrap(),
            "\"permanent\""
        );
        assert_eq!(
            serde_json::to_string(&ContextPersistence::Ttl(Duration::from_secs(3600))).unwrap(),
            r#"{"ttl":{"secs":3600,"nanos":0}}"#
        );
    }

    #[test]
    fn validation_mode_serde_roundtrip() {
        for v in [ValidationMode::OnMutation, ValidationMode::Always] {
            let json = serde_json::to_string(&v).unwrap();
            let d: ValidationMode = serde_json::from_str(&json).unwrap();
            assert_eq!(v, d);
        }
    }

    #[test]
    fn validation_mode_wire_format() {
        assert_eq!(
            serde_json::to_string(&ValidationMode::OnMutation).unwrap(),
            "\"on_mutation\""
        );
        assert_eq!(
            serde_json::to_string(&ValidationMode::Always).unwrap(),
            "\"always\""
        );
    }

    // ---- ContextAssemblyDefaults ---------------------------------------------

    #[test]
    fn context_assembly_defaults_are_sane() {
        let d = ContextAssemblyDefaults::default();
        assert!(d.rest_token_budget > 0);
        assert!(d.max_notes > 0);
        assert!(d.max_artifacts > 0);
        assert!(d.max_turns > 0);
        assert!(d.max_summaries > 0);
    }

    #[test]
    fn context_assembly_from_env_returns_defaults_without_env() {
        // `from_env` reads the process environment. If the surrounding environment
        // (CI, developer shell) sets any override, the expected values legitimately
        // differ, so the comparison would be meaningless.
        if CONTEXT_ENV_VARS
            .iter()
            .any(|key| std::env::var_os(key).is_some())
        {
            return;
        }
        let d = ContextAssemblyDefaults::from_env();
        let expected = ContextAssemblyDefaults::default();
        assert_eq!(d, expected);
    }

    #[test]
    fn context_assembly_defaults_serde_roundtrip() {
        let d = ContextAssemblyDefaults::default();
        let json = serde_json::to_string(&d).unwrap();
        let deserialized: ContextAssemblyDefaults = serde_json::from_str(&json).unwrap();
        assert_eq!(d, deserialized);
    }
}