1use crate::{AgentId, AgentState, DecayParams, Event, Timestamp};
20use serde::{Deserialize, Serialize};
21
/// Reasons a set of drift weights is rejected by `DriftWeights::validate`.
#[derive(Debug, Clone, PartialEq)]
pub enum DriftWeightsError {
    /// At least one weight is NaN, `+Infinity` or `-Infinity`.
    NonFinite,
    /// At least one weight is below zero.
    Negative,
    /// The weights do not sum to `1.0` within tolerance; carries the actual sum.
    SumNotOne(f32),
}

impl std::fmt::Display for DriftWeightsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages are written verbatim; only `SumNotOne` needs formatting.
        let message = match *self {
            Self::NonFinite => "one or more drift weights are NaN or Infinity",
            Self::Negative => "one or more drift weights are negative",
            Self::SumNotOne(sum) => {
                return write!(f, "drift weights sum to {} instead of 1.0", sum);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for DriftWeightsError {}
46
47#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
51pub struct DriftWeights {
52 pub state: f32,
54 pub causal: f32,
56 pub context: f32,
58}
59
60impl Default for DriftWeights {
61 fn default() -> Self {
62 Self {
63 state: 0.40,
64 causal: 0.35,
65 context: 0.25,
66 }
67 }
68}
69
70impl DriftWeights {
71 pub fn validate(&self) -> Result<(), DriftWeightsError> {
73 const EPSILON: f32 = 0.01;
74 if !self.state.is_finite() || !self.causal.is_finite() || !self.context.is_finite() {
75 return Err(DriftWeightsError::NonFinite);
76 }
77 if self.state < 0.0 || self.causal < 0.0 || self.context < 0.0 {
78 return Err(DriftWeightsError::Negative);
79 }
80 let sum = self.state + self.causal + self.context;
81 if (sum - 1.0).abs() > EPSILON {
82 return Err(DriftWeightsError::SumNotOne(sum));
83 }
84 Ok(())
85 }
86
87 pub fn normalize(&mut self) {
92 let sum = self.state + self.causal + self.context;
93 if sum.is_finite() && sum > 0.0 {
94 self.state /= sum;
95 self.causal /= sum;
96 self.context /= sum;
97 }
98 }
100}
101
/// Borrowed inputs for [`DriftMeter::compute`], describing the two agents being compared.
#[derive(Debug, Clone)]
pub struct DriftInput<'a> {
    /// First agent of the pair (copied into the resulting [`DriftMeter`]).
    pub agent_a: AgentId,
    /// Second agent of the pair (copied into the resulting [`DriftMeter`]).
    pub agent_b: AgentId,
    /// Event stream of agent A. Order matters for decay weighting: the last element is
    /// treated as age 0 and the first as the oldest.
    pub events_a: &'a [Event<serde_json::Value>],
    /// Event stream of agent B, with the same ordering convention as `events_a`.
    pub events_b: &'a [Event<serde_json::Value>],
    /// Context item ids of agent A. They are compared as a set, so duplicates are ignored.
    pub context_items_a: &'a [uuid::Uuid],
    /// Context item ids of agent B. They are compared as a set, so duplicates are ignored.
    pub context_items_b: &'a [uuid::Uuid],
    /// Decay parameters used to weight events by age in the causal dimension.
    pub decay: &'a DecayParams,
    /// Blend weights for the state, causal and context divergences. They are not validated
    /// by `compute`.
    pub weights: &'a DriftWeights,
    /// Minimum acceptable alignment (`1.0 - composite_score`). A pair whose alignment falls
    /// below this value is flagged as drifting.
    pub threshold: f64,
}
130
/// Result of comparing two agents with [`DriftMeter::compute`].
///
/// Divergences use `0.0` for identical inputs; larger values mean more drift.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct DriftMeter {
    /// First agent of the compared pair.
    pub agent_a: AgentId,
    /// Second agent of the compared pair.
    pub agent_b: AgentId,
    /// Divergence of the reconstructed agent states. It is one of `0.0` (equal),
    /// `0.5` (same variant, different data) or `1.0` (different variant).
    pub state_divergence: f64,
    /// Cosine distance between decay-weighted event-kind histograms, clamped to `[0, 1]`.
    pub causal_divergence: f64,
    /// Jaccard distance between the two context-item sets.
    pub context_divergence: f64,
    /// Weighted blend of the three divergences, clamped to `[0, 1]`. It is `0.0` when the
    /// blend was non-finite.
    pub composite_score: f64,
    /// Alignment threshold that was used (copied from the input).
    pub threshold: f64,
    /// `true` when `1.0 - composite_score` is below `threshold`.
    pub is_drifting: bool,
    /// Wall-clock time at which the measurement was computed.
    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "date-time"))]
    pub computed_at: Timestamp,
}
159
160impl DriftMeter {
161 pub fn compute(input: DriftInput<'_>) -> Self {
165 let now = chrono::Utc::now();
166
167 let state_a = AgentState::from_events(input.events_a);
169 let state_b = AgentState::from_events(input.events_b);
170 let state_divergence = Self::compute_state_divergence(&state_a, &state_b);
171
172 let causal_divergence =
174 Self::compute_causal_divergence(input.events_a, input.events_b, input.decay);
175
176 let context_divergence =
178 Self::compute_context_divergence(input.context_items_a, input.context_items_b);
179
180 let composite_score = input.weights.state as f64 * state_divergence
182 + input.weights.causal as f64 * causal_divergence
183 + input.weights.context as f64 * context_divergence;
184
185 let composite_score = if composite_score.is_finite() {
188 composite_score.clamp(0.0, 1.0)
189 } else {
190 0.0 };
192
193 let alignment = 1.0 - composite_score;
195 let is_drifting = alignment < input.threshold;
196
197 Self {
198 agent_a: input.agent_a,
199 agent_b: input.agent_b,
200 state_divergence,
201 causal_divergence,
202 context_divergence,
203 composite_score,
204 threshold: input.threshold,
205 is_drifting,
206 computed_at: now,
207 }
208 }
209
210 fn compute_state_divergence(a: &AgentState, b: &AgentState) -> f64 {
215 use std::mem::discriminant;
216
217 if a == b {
218 return 0.0;
219 }
220 if discriminant(a) == discriminant(b) {
221 return 0.5;
223 }
224 1.0
225 }
226
227 fn compute_causal_divergence(
232 events_a: &[Event<serde_json::Value>],
233 events_b: &[Event<serde_json::Value>],
234 decay: &DecayParams,
235 ) -> f64 {
236 if events_a.is_empty() && events_b.is_empty() {
237 return 0.0;
238 }
239 if events_a.is_empty() || events_b.is_empty() {
240 return 1.0;
241 }
242
243 let weighted_kinds_a = Self::decay_weighted_kind_histogram(events_a, decay);
245 let weighted_kinds_b = Self::decay_weighted_kind_histogram(events_b, decay);
246
247 let all_kinds: std::collections::HashSet<u16> = weighted_kinds_a
249 .keys()
250 .chain(weighted_kinds_b.keys())
251 .copied()
252 .collect();
253
254 if all_kinds.is_empty() {
255 return 0.0;
256 }
257
258 let mut dot = 0.0f64;
259 let mut norm_a = 0.0f64;
260 let mut norm_b = 0.0f64;
261
262 for kind in &all_kinds {
263 let wa = *weighted_kinds_a.get(kind).unwrap_or(&0.0) as f64;
264 let wb = *weighted_kinds_b.get(kind).unwrap_or(&0.0) as f64;
265 dot += wa * wb;
266 norm_a += wa * wa;
267 norm_b += wb * wb;
268 }
269
270 let denominator = norm_a.sqrt() * norm_b.sqrt();
271 if denominator == 0.0 {
272 return 1.0;
273 }
274
275 let cosine_similarity = dot / denominator;
276 (1.0 - cosine_similarity).clamp(0.0, 1.0)
277 }
278
279 fn decay_weighted_kind_histogram(
281 events: &[Event<serde_json::Value>],
282 decay: &DecayParams,
283 ) -> std::collections::HashMap<u16, f32> {
284 let mut histogram = std::collections::HashMap::new();
285 let len = events.len();
286
287 for (i, event) in events.iter().enumerate() {
288 let age = (len - 1 - i) as f32;
290 let weight = crate::parametric_decay(age, decay);
291 *histogram.entry(event.header.event_kind.0).or_insert(0.0f32) += weight;
292 }
293
294 histogram
295 }
296
297 fn compute_context_divergence(items_a: &[uuid::Uuid], items_b: &[uuid::Uuid]) -> f64 {
299 if items_a.is_empty() && items_b.is_empty() {
300 return 0.0;
301 }
302
303 let set_a: std::collections::HashSet<uuid::Uuid> = items_a.iter().copied().collect();
304 let set_b: std::collections::HashSet<uuid::Uuid> = items_b.iter().copied().collect();
305
306 let intersection = set_a.intersection(&set_b).count();
307 let union = set_a.union(&set_b).count();
308
309 if union == 0 {
310 return 0.0;
311 }
312
313 let jaccard = intersection as f64 / union as f64;
314 1.0 - jaccard
315 }
316}
317
#[cfg(test)]
mod tests {
    //! Unit tests for drift computation and drift-weight validation.

    use super::*;
    use crate::identity::EntityIdType;
    use crate::{DagPosition, Event, EventFlags, EventHeader, EventId, EventKind};
    use serde_json::json;

    /// Builds a root-position event of `kind` whose payload references `scope_id`.
    fn make_scope_event(kind: EventKind, scope_id: crate::ScopeId) -> Event<serde_json::Value> {
        Event::new(
            EventHeader::new(
                EventId::now_v7(),
                EventId::now_v7(),
                chrono::Utc::now().timestamp_micros(),
                DagPosition::root(),
                0,
                kind,
                EventFlags::empty(),
                None,
            ),
            json!({ "scope_id": scope_id }),
        )
    }

    #[test]
    fn test_identical_agents_zero_divergence() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_id = crate::ScopeId::now_v7();

        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let context: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events,
            events_b: &events,
            context_items_a: &context,
            context_items_b: &context,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        assert_eq!(meter.state_divergence, 0.0);
        assert_eq!(meter.causal_divergence, 0.0);
        assert_eq!(meter.context_divergence, 0.0);
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_fully_diverged_agents() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_a = crate::ScopeId::now_v7();
        let scope_b = crate::ScopeId::now_v7();

        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_a)];
        let events_b = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_b),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_b),
        ];
        let context_a: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let context_b: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events_a,
            events_b: &events_b,
            context_items_a: &context_a,
            context_items_b: &context_b,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        assert_eq!(meter.state_divergence, 1.0);
        // Disjoint context sets.
        assert_eq!(meter.context_divergence, 1.0);
        assert!((0.0..=1.0).contains(&meter.causal_divergence));
        // 0.40 * 1.0 + 0.25 * 1.0 already exceeds 0.5 regardless of the causal term.
        assert!(meter.composite_score > 0.5);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_empty_events_no_divergence() {
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_one_empty_stream_fully_diverged() {
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &events,
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        assert_eq!(meter.state_divergence, 1.0);
        // One empty stream versus a non-empty one is maximal causal divergence.
        assert_eq!(meter.causal_divergence, 1.0);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_drift_weights_validate() {
        assert!(DriftWeights::default().validate().is_ok());
        assert!(DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        }
        .validate()
        .is_err());
    }

    #[test]
    fn test_drift_weights_normalize() {
        let mut weights = DriftWeights {
            state: 2.0,
            causal: 2.0,
            context: 1.0,
        };
        weights.normalize();
        assert!(weights.validate().is_ok());
        assert!((weights.state - 0.4).abs() < 0.01);
    }

    #[test]
    fn test_context_divergence_partial_overlap() {
        let shared = uuid::Uuid::now_v7();
        let only_a = uuid::Uuid::now_v7();
        let only_b = uuid::Uuid::now_v7();

        let items_a = vec![shared, only_a];
        let items_b = vec![shared, only_b];

        let divergence = DriftMeter::compute_context_divergence(&items_a, &items_b);
        // Jaccard: |{shared}| / |{shared, only_a, only_b}| = 1/3, so divergence = 2/3.
        assert!((divergence - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_context_divergence_edge_cases() {
        let x = uuid::Uuid::now_v7();
        let y = uuid::Uuid::now_v7();

        // Disjoint sets diverge completely.
        assert_eq!(DriftMeter::compute_context_divergence(&[x], &[y]), 1.0);
        // Order does not matter.
        assert_eq!(DriftMeter::compute_context_divergence(&[x, y], &[y, x]), 0.0);
        // Duplicates are ignored (set semantics).
        assert_eq!(DriftMeter::compute_context_divergence(&[x, x], &[x]), 0.0);
        // One side empty, the other not.
        assert_eq!(DriftMeter::compute_context_divergence(&[x], &[]), 1.0);
    }

    #[test]
    fn test_causal_divergence_identical_multi_kind_streams() {
        // Identical streams mixing several kinds must not report drift caused by
        // floating-point rounding in the cosine computation.
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_id),
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
        ];
        let decay = DecayParams::exponential(10.0);

        let divergence = DriftMeter::compute_causal_divergence(&events, &events, &decay);
        assert!(
            (0.0..1e-12).contains(&divergence),
            "identical streams should not diverge, got {}",
            divergence
        );
    }

    #[test]
    fn test_causal_divergence_symmetric_and_bounded() {
        let scope_id = crate::ScopeId::now_v7();
        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let events_b = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_id),
        ];
        let decay = DecayParams::exponential(10.0);

        let ab = DriftMeter::compute_causal_divergence(&events_a, &events_b, &decay);
        let ba = DriftMeter::compute_causal_divergence(&events_b, &events_a, &decay);
        assert!((0.0..=1.0).contains(&ab), "divergence out of range: {}", ab);
        assert!((ab - ba).abs() < 1e-12, "expected symmetry, got {} vs {}", ab, ba);
    }

    #[test]
    fn test_nan_weight_detected() {
        let w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_neg_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::NEG_INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_negative_weight_detected() {
        let w = DriftWeights {
            state: -0.1,
            causal: 0.6,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::Negative));
    }

    #[test]
    fn test_sum_not_one_detected() {
        let w = DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        };
        match w.validate() {
            Err(DriftWeightsError::SumNotOne(sum)) => {
                assert!((sum - 1.5).abs() < 0.001);
            }
            other => panic!("Expected SumNotOne, got {:?}", other),
        }
    }

    #[test]
    fn test_all_zero_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: 0.0,
            causal: 0.0,
            context: 0.0,
        };
        let before = w.clone();
        w.normalize();
        assert_eq!(w, before, "all-zero weights should be left unchanged");
    }

    #[test]
    fn test_nan_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        w.normalize();
        // NaN != NaN, so check the NaN field separately from the finite ones.
        assert!(w.state.is_nan());
        assert_eq!(w.causal, 0.5);
        assert_eq!(w.context, 0.5);
    }

    #[test]
    fn test_infinity_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        let causal_before = w.causal;
        let context_before = w.context;
        w.normalize();
        assert_eq!(w.state, f32::INFINITY);
        assert_eq!(w.causal, causal_before);
        assert_eq!(w.context, context_before);
    }

    #[test]
    fn test_nan_composite_score_guarded() {
        // Unvalidated NaN weights must not leak a NaN composite score.
        let nan_weights = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &nan_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_infinity_in_composite_score_guarded() {
        // Infinity * 0.0 divergence is NaN; the guard must map it to 0.0.
        let inf_weights = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &inf_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
    }

    #[test]
    fn test_validate_returns_specific_error_variants() {
        // Non-finite takes precedence over negative.
        let w = DriftWeights {
            state: f32::NAN,
            causal: -1.0,
            context: 0.5,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::NonFinite)));

        // Negative takes precedence over a bad sum.
        let w = DriftWeights {
            state: -0.5,
            causal: 1.0,
            context: 1.0,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::Negative)));

        // Valid components that do not sum to 1.0.
        let w = DriftWeights {
            state: 0.1,
            causal: 0.1,
            context: 0.1,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::SumNotOne(_))));
    }

    #[test]
    fn test_mixed_nan_valid_weights() {
        // A single NaN among otherwise valid weights is still rejected.
        let w = DriftWeights {
            state: 0.4,
            causal: f32::NAN,
            context: 0.25,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_drift_weights_error_display() {
        let e = DriftWeightsError::NonFinite;
        assert!(e.to_string().contains("NaN or Infinity"));
        let e = DriftWeightsError::Negative;
        assert!(e.to_string().contains("negative"));
        let e = DriftWeightsError::SumNotOne(1.5);
        assert!(e.to_string().contains("1.5"));
    }
}