1use crate::{AgentId, AgentState, DecayParams, Event, Timestamp};
20use serde::{Deserialize, Serialize};
21
/// Reasons a `DriftWeights` value can be rejected by `DriftWeights::validate`.
#[derive(Debug, Clone, PartialEq)]
pub enum DriftWeightsError {
    /// At least one weight is NaN or ±Infinity.
    NonFinite,
    /// At least one weight is below zero.
    Negative,
    /// The weights do not sum to (approximately) `1.0`; the offending sum is
    /// carried in the variant.
    SumNotOne(f32),
}

impl std::fmt::Display for DriftWeightsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages are written directly; only the sum variant needs
        // formatting.
        let message = match *self {
            Self::NonFinite => "one or more drift weights are NaN or Infinity",
            Self::Negative => "one or more drift weights are negative",
            Self::SumNotOne(total) => {
                return write!(f, "drift weights sum to {} instead of 1.0", total);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for DriftWeightsError {}
46
47#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
51pub struct DriftWeights {
52 pub state: f32,
54 pub causal: f32,
56 pub context: f32,
58}
59
60impl Default for DriftWeights {
61 fn default() -> Self {
62 Self {
63 state: 0.40,
64 causal: 0.35,
65 context: 0.25,
66 }
67 }
68}
69
70impl DriftWeights {
71 pub fn validate(&self) -> Result<(), DriftWeightsError> {
73 const EPSILON: f32 = 0.01;
74 if !self.state.is_finite() || !self.causal.is_finite() || !self.context.is_finite() {
75 return Err(DriftWeightsError::NonFinite);
76 }
77 if self.state < 0.0 || self.causal < 0.0 || self.context < 0.0 {
78 return Err(DriftWeightsError::Negative);
79 }
80 let sum = self.state + self.causal + self.context;
81 if (sum - 1.0).abs() > EPSILON {
82 return Err(DriftWeightsError::SumNotOne(sum));
83 }
84 Ok(())
85 }
86
87 pub fn normalize(&mut self) {
92 let sum = self.state + self.causal + self.context;
93 if sum.is_finite() && sum > 0.0 {
94 self.state /= sum;
95 self.causal /= sum;
96 self.context /= sum;
97 }
98 }
100}
101
102#[derive(Debug, Clone)]
104pub struct DriftInput<'a> {
105 pub agent_a: AgentId,
107 pub agent_b: AgentId,
109 pub events_a: &'a [Event<serde_json::Value>],
111 pub events_b: &'a [Event<serde_json::Value>],
113 pub context_items_a: &'a [uuid::Uuid],
121 pub context_items_b: &'a [uuid::Uuid],
123 pub decay: &'a DecayParams,
125 pub weights: &'a DriftWeights,
127 pub threshold: f64,
129}
130
131#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
135#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
136pub struct DriftMeter {
137 pub agent_a: AgentId,
139 pub agent_b: AgentId,
141 pub state_divergence: f64,
143 pub causal_divergence: f64,
145 pub context_divergence: f64,
147 pub composite_score: f64,
149 pub threshold: f64,
153 pub is_drifting: bool,
155 #[cfg_attr(feature = "openapi", schema(value_type = String, format = "date-time"))]
157 pub computed_at: Timestamp,
158}
159
160impl DriftMeter {
161 pub fn compute(input: DriftInput<'_>) -> Self {
165 let now = chrono::Utc::now();
166
167 let state_a = AgentState::from_events(input.events_a);
169 let state_b = AgentState::from_events(input.events_b);
170 let state_divergence = Self::compute_state_divergence(&state_a, &state_b);
171
172 let causal_divergence =
174 Self::compute_causal_divergence(input.events_a, input.events_b, input.decay);
175
176 let context_divergence =
178 Self::compute_context_divergence(input.context_items_a, input.context_items_b);
179
180 let composite_score = input.weights.state as f64 * state_divergence
182 + input.weights.causal as f64 * causal_divergence
183 + input.weights.context as f64 * context_divergence;
184
185 let composite_score = if composite_score.is_finite() {
188 composite_score.clamp(0.0, 1.0)
189 } else {
190 0.0 };
192
193 let alignment = 1.0 - composite_score;
195 let is_drifting = alignment < input.threshold;
196
197 Self {
198 agent_a: input.agent_a,
199 agent_b: input.agent_b,
200 state_divergence,
201 causal_divergence,
202 context_divergence,
203 composite_score,
204 threshold: input.threshold,
205 is_drifting,
206 computed_at: now,
207 }
208 }
209
210 fn compute_state_divergence(a: &AgentState, b: &AgentState) -> f64 {
215 use std::mem::discriminant;
216
217 if a == b {
218 return 0.0;
219 }
220 if discriminant(a) == discriminant(b) {
221 return 0.5;
223 }
224 1.0
225 }
226
227 fn compute_causal_divergence(
232 events_a: &[Event<serde_json::Value>],
233 events_b: &[Event<serde_json::Value>],
234 decay: &DecayParams,
235 ) -> f64 {
236 if events_a.is_empty() && events_b.is_empty() {
237 return 0.0;
238 }
239 if events_a.is_empty() || events_b.is_empty() {
240 return 1.0;
241 }
242
243 let weighted_kinds_a = Self::decay_weighted_kind_histogram(events_a, decay);
245 let weighted_kinds_b = Self::decay_weighted_kind_histogram(events_b, decay);
246
247 let all_kinds: std::collections::HashSet<u16> = weighted_kinds_a
249 .keys()
250 .chain(weighted_kinds_b.keys())
251 .copied()
252 .collect();
253
254 if all_kinds.is_empty() {
255 return 0.0;
256 }
257
258 let mut dot = 0.0f64;
259 let mut norm_a = 0.0f64;
260 let mut norm_b = 0.0f64;
261
262 for kind in &all_kinds {
263 let wa = *weighted_kinds_a.get(kind).unwrap_or(&0.0) as f64;
264 let wb = *weighted_kinds_b.get(kind).unwrap_or(&0.0) as f64;
265 dot += wa * wb;
266 norm_a += wa * wa;
267 norm_b += wb * wb;
268 }
269
270 let denominator = norm_a.sqrt() * norm_b.sqrt();
271 if denominator == 0.0 {
272 return 1.0;
273 }
274
275 let cosine_similarity = dot / denominator;
276 (1.0 - cosine_similarity).clamp(0.0, 1.0)
277 }
278
279 fn decay_weighted_kind_histogram(
281 events: &[Event<serde_json::Value>],
282 decay: &DecayParams,
283 ) -> std::collections::HashMap<u16, f32> {
284 let mut histogram = std::collections::HashMap::new();
285 let len = events.len();
286
287 for (i, event) in events.iter().enumerate() {
288 let age = (len - 1 - i) as f32;
290 let weight = crate::parametric_decay(age, decay);
291 *histogram.entry(event.header.event_kind.0).or_insert(0.0f32) += weight;
292 }
293
294 histogram
295 }
296
297 fn compute_context_divergence(items_a: &[uuid::Uuid], items_b: &[uuid::Uuid]) -> f64 {
299 if items_a.is_empty() && items_b.is_empty() {
300 return 0.0;
301 }
302
303 let set_a: std::collections::HashSet<uuid::Uuid> = items_a.iter().copied().collect();
304 let set_b: std::collections::HashSet<uuid::Uuid> = items_b.iter().copied().collect();
305
306 let intersection = set_a.intersection(&set_b).count();
307 let union = set_a.union(&set_b).count();
308
309 if union == 0 {
310 return 0.0;
311 }
312
313 let jaccard = intersection as f64 / union as f64;
314 1.0 - jaccard
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::EntityIdType;
    use crate::{DagPosition, Event, EventHeader, EventId, EventKind};
    use serde_json::json;

    /// Builds a root-positioned event of `kind` whose payload records
    /// `scope_id`.
    fn make_scope_event(kind: EventKind, scope_id: crate::ScopeId) -> Event<serde_json::Value> {
        Event::new(
            EventHeader::new(
                EventId::now_v7(),
                EventId::now_v7(),
                chrono::Utc::now().timestamp_micros(),
                DagPosition::root(),
                0,
                kind,
            ),
            json!({ "scope_id": scope_id }),
        )
    }

    #[test]
    fn test_identical_agents_zero_divergence() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_id = crate::ScopeId::now_v7();

        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let context: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events,
            events_b: &events,
            context_items_a: &context,
            context_items_b: &context,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        assert_eq!(meter.state_divergence, 0.0);
        assert_eq!(meter.causal_divergence, 0.0);
        assert_eq!(meter.context_divergence, 0.0);
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_fully_diverged_agents() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_a = crate::ScopeId::now_v7();
        let scope_b = crate::ScopeId::now_v7();

        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_a)];
        let events_b = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_b),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_b),
        ];
        let context_a: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let context_b: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events_a,
            events_b: &events_b,
            context_items_a: &context_a,
            context_items_b: &context_b,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        assert_eq!(meter.state_divergence, 1.0);
        assert_eq!(meter.context_divergence, 1.0);
        assert!(meter.composite_score > 0.5);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_empty_events_no_divergence() {
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_one_empty_stream_fully_diverged() {
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &events,
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        assert_eq!(meter.state_divergence, 1.0);
        assert_eq!(meter.causal_divergence, 1.0);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_identical_multi_kind_streams_causally_aligned() {
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_id),
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
        ];
        let divergence = DriftMeter::compute_causal_divergence(
            &events,
            &events,
            &DecayParams::exponential(10.0),
        );
        assert!(
            divergence.abs() < 1e-9,
            "identical streams should be causally aligned, got {}",
            divergence
        );
    }

    #[test]
    fn test_disjoint_event_kinds_fully_causally_diverged() {
        let scope_id = crate::ScopeId::now_v7();
        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let events_b = vec![make_scope_event(EventKind::SCOPE_CLOSED, scope_id)];
        let divergence = DriftMeter::compute_causal_divergence(
            &events_a,
            &events_b,
            &DecayParams::exponential(10.0),
        );
        assert_eq!(divergence, 1.0);
    }

    #[test]
    fn test_drift_weights_validate() {
        assert!(DriftWeights::default().validate().is_ok());
        assert!(DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        }
        .validate()
        .is_err());
    }

    #[test]
    fn test_drift_weights_normalize() {
        let mut weights = DriftWeights {
            state: 2.0,
            causal: 2.0,
            context: 1.0,
        };
        weights.normalize();
        assert!(weights.validate().is_ok());
        assert!((weights.state - 0.4).abs() < 0.01);
        // Ratios between all three components must be preserved.
        assert!((weights.causal - 0.4).abs() < 1e-6);
        assert!((weights.context - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_context_divergence_partial_overlap() {
        let shared = uuid::Uuid::now_v7();
        let only_a = uuid::Uuid::now_v7();
        let only_b = uuid::Uuid::now_v7();

        let items_a = vec![shared, only_a];
        let items_b = vec![shared, only_b];

        let divergence = DriftMeter::compute_context_divergence(&items_a, &items_b);
        assert!((divergence - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_context_divergence_edge_cases() {
        let a = uuid::Uuid::now_v7();
        let b = uuid::Uuid::now_v7();

        assert_eq!(DriftMeter::compute_context_divergence(&[], &[]), 0.0);
        assert_eq!(DriftMeter::compute_context_divergence(&[a], &[a]), 0.0);
        assert_eq!(DriftMeter::compute_context_divergence(&[a], &[b]), 1.0);
        assert_eq!(DriftMeter::compute_context_divergence(&[a], &[]), 1.0);
        // Duplicates collapse: the inputs are compared as sets.
        assert_eq!(DriftMeter::compute_context_divergence(&[a, a], &[a]), 0.0);
    }

    #[test]
    fn test_nan_weight_detected() {
        let w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_neg_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::NEG_INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_negative_weight_detected() {
        let w = DriftWeights {
            state: -0.1,
            causal: 0.6,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::Negative));
    }

    #[test]
    fn test_sum_not_one_detected() {
        let w = DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        };
        match w.validate() {
            Err(DriftWeightsError::SumNotOne(sum)) => {
                assert!((sum - 1.5).abs() < 0.001);
            }
            other => panic!("Expected SumNotOne, got {:?}", other),
        }
    }

    #[test]
    fn test_all_zero_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: 0.0,
            causal: 0.0,
            context: 0.0,
        };
        let before = w.clone();
        w.normalize();
        assert_eq!(w, before, "all-zero weights should be left unchanged");
    }

    #[test]
    fn test_nan_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        w.normalize();
        // NaN != NaN, so compare the NaN field by predicate.
        assert!(w.state.is_nan());
        assert_eq!(w.causal, 0.5);
        assert_eq!(w.context, 0.5);
    }

    #[test]
    fn test_infinity_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        let causal_before = w.causal;
        let context_before = w.context;
        w.normalize();
        assert_eq!(w.state, f32::INFINITY);
        assert_eq!(w.causal, causal_before);
        assert_eq!(w.context, context_before);
    }

    #[test]
    fn test_nan_composite_score_guarded() {
        let nan_weights = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &nan_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_infinity_in_composite_score_guarded() {
        // INFINITY * 0.0 divergence yields NaN inside the blend.
        let inf_weights = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &inf_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
    }

    #[test]
    fn test_validate_returns_specific_error_variants() {
        // Non-finite takes precedence over negative.
        let w = DriftWeights {
            state: f32::NAN,
            causal: -1.0,
            context: 0.5,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::NonFinite)));

        // Negative takes precedence over a bad sum.
        let w = DriftWeights {
            state: -0.5,
            causal: 1.0,
            context: 1.0,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::Negative)));

        let w = DriftWeights {
            state: 0.1,
            causal: 0.1,
            context: 0.1,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::SumNotOne(_))));
    }

    #[test]
    fn test_mixed_nan_valid_weights() {
        let w = DriftWeights {
            state: 0.4,
            causal: f32::NAN,
            context: 0.25,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_drift_weights_error_display() {
        let e = DriftWeightsError::NonFinite;
        assert!(e.to_string().contains("NaN or Infinity"));
        let e = DriftWeightsError::Negative;
        assert!(e.to_string().contains("negative"));
        let e = DriftWeightsError::SumNotOne(1.5);
        assert!(e.to_string().contains("1.5"));
    }
}