// cellstate_core/drift.rs

//! Composite divergence metric for multi-agent drift detection.
//!
//! The `DriftMeter` computes a weighted divergence score between two agents
//! by composing three orthogonal signals:
//!
//! 1. **State divergence**: How different are the agents' derived states?
//! 2. **Causal divergence**: How far apart are their recent event histories
//!    (a proxy for how much their Event DAG lanes have diverged)?
//! 3. **Context divergence**: How little do their scored contexts overlap?
//!
//! Each signal produces a value in [0, 1] (0 = identical, 1 = fully diverged).
//! Within the causal signal, a Weibull decay over event position (the newest
//! event has age 0) weights recent events more heavily than older ones.
//!
//! The composite score is a weighted sum of the three signals. Its weights
//! ([`DriftWeights`]) follow the same `validate()` / `normalize()` pattern as
//! `ScoringWeights`, keeping drift scoring consistent with the rest of the
//! scoring system.
//!
//! Re-export path: `cellstate_core::drift::*`

19use crate::{AgentId, AgentState, DecayParams, Event, Timestamp};
20use serde::{Deserialize, Serialize};
21
/// Reasons a [`DriftWeights`] value can fail [`DriftWeights::validate`].
#[derive(Debug, Clone, PartialEq)]
pub enum DriftWeightsError {
    /// At least one weight is NaN, `+Infinity` or `-Infinity`.
    NonFinite,
    /// At least one weight is below zero.
    Negative,
    /// Every weight is finite and non-negative, but their sum lies outside
    /// the validation tolerance around 1.0. Carries the offending sum.
    SumNotOne(f32),
}

impl std::fmt::Display for DriftWeightsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages are written directly; only `SumNotOne` needs formatting.
        let message = match self {
            DriftWeightsError::NonFinite => "one or more drift weights are NaN or Infinity",
            DriftWeightsError::Negative => "one or more drift weights are negative",
            DriftWeightsError::SumNotOne(sum) => {
                return write!(f, "drift weights sum to {} instead of 1.0", sum);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for DriftWeightsError {}

47/// Divergence weights for the three DriftMeter signals.
48///
49/// Must sum to 1.0 (same constraint as `ScoringWeights`).
50#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
51pub struct DriftWeights {
52    /// Weight for agent state divergence signal.
53    pub state: f32,
54    /// Weight for causal (DAG lane) divergence signal.
55    pub causal: f32,
56    /// Weight for context overlap divergence signal.
57    pub context: f32,
58}
59
60impl Default for DriftWeights {
61    fn default() -> Self {
62        Self {
63            state: 0.40,
64            causal: 0.35,
65            context: 0.25,
66        }
67    }
68}
69
70impl DriftWeights {
71    /// Validate weights are finite, non-negative, and sum to 1.0 (±epsilon).
72    pub fn validate(&self) -> Result<(), DriftWeightsError> {
73        const EPSILON: f32 = 0.01;
74        if !self.state.is_finite() || !self.causal.is_finite() || !self.context.is_finite() {
75            return Err(DriftWeightsError::NonFinite);
76        }
77        if self.state < 0.0 || self.causal < 0.0 || self.context < 0.0 {
78            return Err(DriftWeightsError::Negative);
79        }
80        let sum = self.state + self.causal + self.context;
81        if (sum - 1.0).abs() > EPSILON {
82            return Err(DriftWeightsError::SumNotOne(sum));
83        }
84        Ok(())
85    }
86
87    /// Normalize weights to sum to exactly 1.0.
88    ///
89    /// If the sum is NaN, Infinity, or zero the weights are left unchanged
90    /// to avoid silently producing all-NaN values.
91    pub fn normalize(&mut self) {
92        let sum = self.state + self.causal + self.context;
93        if sum.is_finite() && sum > 0.0 {
94            self.state /= sum;
95            self.causal /= sum;
96            self.context /= sum;
97        }
98        // If sum is NaN/Inf/zero, leave weights unchanged
99    }
100}
101
102/// Input parameters for computing drift between two agents.
103#[derive(Debug, Clone)]
104pub struct DriftInput<'a> {
105    /// First agent ID.
106    pub agent_a: AgentId,
107    /// Second agent ID.
108    pub agent_b: AgentId,
109    /// Event stream for agent A.
110    pub events_a: &'a [Event<serde_json::Value>],
111    /// Event stream for agent B.
112    pub events_b: &'a [Event<serde_json::Value>],
113    /// Entity IDs scored into agent A's L1 context.
114    ///
115    /// Uses raw `uuid::Uuid` because context items are heterogeneous (notes,
116    /// artifacts, trajectories, etc.). Type checking happens at insertion time
117    /// via the `EntityIdType` trait on each concrete ID type. When Phase 2
118    /// populates these from the database, consider migrating to a typed
119    /// `ContextEntity` enum if cross-type confusion becomes a risk.
120    pub context_items_a: &'a [uuid::Uuid],
121    /// Entity IDs scored into agent B's L1 context. See `context_items_a` docs.
122    pub context_items_b: &'a [uuid::Uuid],
123    /// Temporal decay parameters (Weibull).
124    pub decay: &'a DecayParams,
125    /// Divergence signal weights.
126    pub weights: &'a DriftWeights,
127    /// Alignment threshold (from IntentDef).
128    pub threshold: f64,
129}
130
131/// Composite divergence metric between two agents.
132///
133/// All scores are in [0, 1]: 0.0 = identical, 1.0 = fully diverged.
134#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
135#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
136pub struct DriftMeter {
137    /// First agent in the comparison.
138    pub agent_a: AgentId,
139    /// Second agent in the comparison.
140    pub agent_b: AgentId,
141    /// State-level divergence (0 = same state, 1 = fully different).
142    pub state_divergence: f64,
143    /// Causal divergence from DAG lane comparison.
144    pub causal_divergence: f64,
145    /// Context overlap divergence (1 - Jaccard similarity of scored items).
146    pub context_divergence: f64,
147    /// Weighted composite score.
148    pub composite_score: f64,
149    /// Threshold from IntentDef.drift_threshold (default 0.85).
150    /// Drift is detected when `1.0 - composite_score < threshold`,
151    /// i.e. when composite divergence is high enough.
152    pub threshold: f64,
153    /// Whether the agents are considered to be drifting apart.
154    pub is_drifting: bool,
155    /// When this measurement was taken.
156    #[cfg_attr(feature = "openapi", schema(value_type = String, format = "date-time"))]
157    pub computed_at: Timestamp,
158}
159
160impl DriftMeter {
161    /// Compute the drift between two agents from their event streams.
162    ///
163    /// This is a pure function — no I/O.
164    pub fn compute(input: DriftInput<'_>) -> Self {
165        let now = chrono::Utc::now();
166
167        // Signal 1: State divergence
168        let state_a = AgentState::from_events(input.events_a);
169        let state_b = AgentState::from_events(input.events_b);
170        let state_divergence = Self::compute_state_divergence(&state_a, &state_b);
171
172        // Signal 2: Causal divergence (DAG lane comparison)
173        let causal_divergence =
174            Self::compute_causal_divergence(input.events_a, input.events_b, input.decay);
175
176        // Signal 3: Context divergence (Jaccard distance of scored items)
177        let context_divergence =
178            Self::compute_context_divergence(input.context_items_a, input.context_items_b);
179
180        // Composite weighted score
181        let composite_score = input.weights.state as f64 * state_divergence
182            + input.weights.causal as f64 * causal_divergence
183            + input.weights.context as f64 * context_divergence;
184
185        // Guard NaN/Infinity: treat as no divergence rather than silently
186        // poisoning downstream decisions.
187        let composite_score = if composite_score.is_finite() {
188            composite_score.clamp(0.0, 1.0)
189        } else {
190            0.0 // fail-safe: treat as no divergence
191        };
192
193        // Drift detected when alignment (1 - divergence) drops below threshold
194        let alignment = 1.0 - composite_score;
195        let is_drifting = alignment < input.threshold;
196
197        Self {
198            agent_a: input.agent_a,
199            agent_b: input.agent_b,
200            state_divergence,
201            causal_divergence,
202            context_divergence,
203            composite_score,
204            threshold: input.threshold,
205            is_drifting,
206            computed_at: now,
207        }
208    }
209
210    /// Compare two AgentState values for divergence.
211    ///
212    /// Returns 0.0 if identical variant, 0.5 if same category but different
213    /// details, 1.0 if completely different states.
214    fn compute_state_divergence(a: &AgentState, b: &AgentState) -> f64 {
215        use std::mem::discriminant;
216
217        if a == b {
218            return 0.0;
219        }
220        if discriminant(a) == discriminant(b) {
221            // Same variant but different inner data (e.g. different scope_ids)
222            return 0.5;
223        }
224        1.0
225    }
226
227    /// Compare DAG lane sequences between two agents.
228    ///
229    /// Uses the ratio of shared event kinds in recent history as a proxy
230    /// for causal alignment. Apply Weibull decay to weight recent events more.
231    fn compute_causal_divergence(
232        events_a: &[Event<serde_json::Value>],
233        events_b: &[Event<serde_json::Value>],
234        decay: &DecayParams,
235    ) -> f64 {
236        if events_a.is_empty() && events_b.is_empty() {
237            return 0.0;
238        }
239        if events_a.is_empty() || events_b.is_empty() {
240            return 1.0;
241        }
242
243        // Extract recent event kinds with decay-weighted counts
244        let weighted_kinds_a = Self::decay_weighted_kind_histogram(events_a, decay);
245        let weighted_kinds_b = Self::decay_weighted_kind_histogram(events_b, decay);
246
247        // Compute cosine-like similarity between the two histograms
248        let all_kinds: std::collections::HashSet<u16> = weighted_kinds_a
249            .keys()
250            .chain(weighted_kinds_b.keys())
251            .copied()
252            .collect();
253
254        if all_kinds.is_empty() {
255            return 0.0;
256        }
257
258        let mut dot = 0.0f64;
259        let mut norm_a = 0.0f64;
260        let mut norm_b = 0.0f64;
261
262        for kind in &all_kinds {
263            let wa = *weighted_kinds_a.get(kind).unwrap_or(&0.0) as f64;
264            let wb = *weighted_kinds_b.get(kind).unwrap_or(&0.0) as f64;
265            dot += wa * wb;
266            norm_a += wa * wa;
267            norm_b += wb * wb;
268        }
269
270        let denominator = norm_a.sqrt() * norm_b.sqrt();
271        if denominator == 0.0 {
272            return 1.0;
273        }
274
275        let cosine_similarity = dot / denominator;
276        (1.0 - cosine_similarity).clamp(0.0, 1.0)
277    }
278
279    /// Build a decay-weighted histogram of event kinds.
280    fn decay_weighted_kind_histogram(
281        events: &[Event<serde_json::Value>],
282        decay: &DecayParams,
283    ) -> std::collections::HashMap<u16, f32> {
284        let mut histogram = std::collections::HashMap::new();
285        let len = events.len();
286
287        for (i, event) in events.iter().enumerate() {
288            // Use position from end as the "age" for decay
289            let age = (len - 1 - i) as f32;
290            let weight = crate::parametric_decay(age, decay);
291            *histogram.entry(event.header.event_kind.0).or_insert(0.0f32) += weight;
292        }
293
294        histogram
295    }
296
297    /// Compute context divergence as 1 - Jaccard similarity.
298    fn compute_context_divergence(items_a: &[uuid::Uuid], items_b: &[uuid::Uuid]) -> f64 {
299        if items_a.is_empty() && items_b.is_empty() {
300            return 0.0;
301        }
302
303        let set_a: std::collections::HashSet<uuid::Uuid> = items_a.iter().copied().collect();
304        let set_b: std::collections::HashSet<uuid::Uuid> = items_b.iter().copied().collect();
305
306        let intersection = set_a.intersection(&set_b).count();
307        let union = set_a.union(&set_b).count();
308
309        if union == 0 {
310            return 0.0;
311        }
312
313        let jaccard = intersection as f64 / union as f64;
314        1.0 - jaccard
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::EntityIdType;
    use crate::{DagPosition, Event, EventFlags, EventHeader, EventId, EventKind};
    use serde_json::json;

    /// Build a root-position event of `kind` whose payload carries `scope_id`.
    fn make_scope_event(kind: EventKind, scope_id: crate::ScopeId) -> Event<serde_json::Value> {
        Event::new(
            EventHeader::new(
                EventId::now_v7(),
                EventId::now_v7(),
                chrono::Utc::now().timestamp_micros(),
                DagPosition::root(),
                0,
                kind,
                EventFlags::empty(),
                None,
            ),
            json!({ "scope_id": scope_id }),
        )
    }

    #[test]
    fn test_identical_agents_zero_divergence() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_id = crate::ScopeId::now_v7();

        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let context: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events,
            events_b: &events,
            context_items_a: &context,
            context_items_b: &context,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        assert_eq!(meter.state_divergence, 0.0);
        assert_eq!(meter.causal_divergence, 0.0);
        assert_eq!(meter.context_divergence, 0.0);
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_fully_diverged_agents() {
        let agent_a = AgentId::now_v7();
        let agent_b = AgentId::now_v7();
        let scope_a = crate::ScopeId::now_v7();
        let scope_b = crate::ScopeId::now_v7();

        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_a)];
        let events_b = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_b),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_b),
        ];
        let context_a: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let context_b: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let decay = DecayParams::exponential(10.0);
        let weights = DriftWeights::default();

        let meter = DriftMeter::compute(DriftInput {
            agent_a,
            agent_b,
            events_a: &events_a,
            events_b: &events_b,
            context_items_a: &context_a,
            context_items_b: &context_b,
            decay: &decay,
            weights: &weights,
            threshold: 0.85,
        });

        // State: Gathering vs Complete = 1.0 divergence
        assert_eq!(meter.state_divergence, 1.0);
        // Context: disjoint sets = 1.0 divergence
        assert_eq!(meter.context_divergence, 1.0);
        // Composite should be high
        assert!(meter.composite_score > 0.5);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_empty_events_no_divergence() {
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_one_empty_stream_fully_diverged() {
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &events,
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &DriftWeights::default(),
            threshold: 0.85,
        });
        // State: Gathering (has scope_id) vs Idle = 1.0
        assert_eq!(meter.state_divergence, 1.0);
        // Causal: one empty = 1.0
        assert_eq!(meter.causal_divergence, 1.0);
        assert!(meter.is_drifting);
    }

    #[test]
    fn test_drift_weights_validate() {
        assert!(DriftWeights::default().validate().is_ok());
        assert!(DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        }
        .validate()
        .is_err());
    }

    #[test]
    fn test_drift_weights_normalize() {
        let mut weights = DriftWeights {
            state: 2.0,
            causal: 2.0,
            context: 1.0,
        };
        weights.normalize();
        assert!(weights.validate().is_ok());
        assert!((weights.state - 0.4).abs() < 0.01);
    }

    #[test]
    fn test_context_divergence_partial_overlap() {
        let shared = uuid::Uuid::now_v7();
        let only_a = uuid::Uuid::now_v7();
        let only_b = uuid::Uuid::now_v7();

        let items_a = vec![shared, only_a];
        let items_b = vec![shared, only_b];

        let divergence = DriftMeter::compute_context_divergence(&items_a, &items_b);
        // Jaccard: 1 intersection / 3 union = 0.333, divergence = 0.667
        assert!((divergence - 0.667).abs() < 0.01);
    }

    // ── Signal-level coverage ────────────────────────────────────────

    #[test]
    fn test_context_divergence_ignores_duplicates() {
        let id = uuid::Uuid::now_v7();
        // Inputs are compared as sets, so repeated IDs must not count twice.
        let divergence = DriftMeter::compute_context_divergence(&[id, id], &[id]);
        assert_eq!(divergence, 0.0);
    }

    #[test]
    fn test_context_divergence_one_side_empty() {
        let divergence = DriftMeter::compute_context_divergence(&[uuid::Uuid::now_v7()], &[]);
        assert_eq!(divergence, 1.0);
    }

    #[test]
    fn test_causal_divergence_disjoint_kinds_fully_diverged() {
        let scope_id = crate::ScopeId::now_v7();
        let events_a = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let events_b = vec![make_scope_event(EventKind::SCOPE_CLOSED, scope_id)];
        let decay = DecayParams::exponential(10.0);

        let divergence = DriftMeter::compute_causal_divergence(&events_a, &events_b, &decay);
        // No shared event kinds: orthogonal histograms, cosine similarity 0.
        assert_eq!(divergence, 1.0);
    }

    #[test]
    fn test_causal_divergence_identical_multi_kind_streams() {
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![
            make_scope_event(EventKind::SCOPE_CREATED, scope_id),
            make_scope_event(EventKind::SCOPE_CLOSED, scope_id),
        ];
        let decay = DecayParams::exponential(10.0);

        let divergence = DriftMeter::compute_causal_divergence(&events, &events, &decay);
        assert!(
            divergence.abs() < 1e-6,
            "identical streams should not diverge, got {}",
            divergence
        );
    }

    #[test]
    fn test_unnormalized_weights_composite_clamped() {
        // Weights summing to 3.0 are not validated by `compute`; the composite
        // must still be clamped into [0, 1].
        let scope_id = crate::ScopeId::now_v7();
        let events = vec![make_scope_event(EventKind::SCOPE_CREATED, scope_id)];
        let context_a: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let context_b: Vec<uuid::Uuid> = vec![uuid::Uuid::now_v7()];
        let weights = DriftWeights {
            state: 1.0,
            causal: 1.0,
            context: 1.0,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &events,
            events_b: &[],
            context_items_a: &context_a,
            context_items_b: &context_b,
            decay: &DecayParams::exponential(10.0),
            weights: &weights,
            threshold: 0.85,
        });
        assert_eq!(meter.composite_score, 1.0);
        assert!(meter.is_drifting);
    }

    // ── NaN / Infinity production guard tests ────────────────────────

    #[test]
    fn test_nan_weight_detected() {
        let w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_neg_infinity_weight_detected() {
        let w = DriftWeights {
            state: f32::NEG_INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_negative_weight_detected() {
        let w = DriftWeights {
            state: -0.1,
            causal: 0.6,
            context: 0.5,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::Negative));
    }

    #[test]
    fn test_sum_not_one_detected() {
        let w = DriftWeights {
            state: 0.5,
            causal: 0.5,
            context: 0.5,
        };
        match w.validate() {
            Err(DriftWeightsError::SumNotOne(sum)) => {
                assert!((sum - 1.5).abs() < 0.001);
            }
            other => panic!("Expected SumNotOne, got {:?}", other),
        }
    }

    #[test]
    fn test_all_zero_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: 0.0,
            causal: 0.0,
            context: 0.0,
        };
        let before = w.clone();
        w.normalize();
        assert_eq!(w, before, "all-zero weights should be left unchanged");
    }

    #[test]
    fn test_nan_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        // NaN sum means normalize should leave weights unchanged
        w.normalize();
        // state stays NaN (NaN != NaN, so check with is_nan)
        assert!(w.state.is_nan());
        assert_eq!(w.causal, 0.5);
        assert_eq!(w.context, 0.5);
    }

    #[test]
    fn test_infinity_weights_normalize_unchanged() {
        let mut w = DriftWeights {
            state: f32::INFINITY,
            causal: 0.5,
            context: 0.5,
        };
        let causal_before = w.causal;
        let context_before = w.context;
        w.normalize();
        // Infinity sum is not finite, so weights should be unchanged
        assert_eq!(w.state, f32::INFINITY);
        assert_eq!(w.causal, causal_before);
        assert_eq!(w.context, context_before);
    }

    #[test]
    fn test_nan_composite_score_guarded() {
        // NaN weights should produce composite_score = 0.0, not NaN
        let nan_weights = DriftWeights {
            state: f32::NAN,
            causal: 0.5,
            context: 0.5,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &nan_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
        assert!(!meter.is_drifting);
    }

    #[test]
    fn test_infinity_in_composite_score_guarded() {
        // Infinity weights should produce composite_score = 0.0, not Infinity
        let inf_weights = DriftWeights {
            state: f32::INFINITY,
            causal: 0.0,
            context: 0.0,
        };
        let meter = DriftMeter::compute(DriftInput {
            agent_a: AgentId::now_v7(),
            agent_b: AgentId::now_v7(),
            events_a: &[],
            events_b: &[],
            context_items_a: &[],
            context_items_b: &[],
            decay: &DecayParams::exponential(10.0),
            weights: &inf_weights,
            threshold: 0.85,
        });
        assert!(
            meter.composite_score.is_finite(),
            "composite_score must be finite, got {}",
            meter.composite_score
        );
        assert_eq!(meter.composite_score, 0.0);
    }

    #[test]
    fn test_validate_returns_specific_error_variants() {
        // NonFinite takes priority over other checks
        let w = DriftWeights {
            state: f32::NAN,
            causal: -1.0,
            context: 0.5,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::NonFinite)));

        // Negative takes priority over SumNotOne
        let w = DriftWeights {
            state: -0.5,
            causal: 1.0,
            context: 1.0,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::Negative)));

        // SumNotOne when all are valid but don't sum to 1
        let w = DriftWeights {
            state: 0.1,
            causal: 0.1,
            context: 0.1,
        };
        assert!(matches!(w.validate(), Err(DriftWeightsError::SumNotOne(_))));
    }

    #[test]
    fn test_mixed_nan_valid_weights() {
        // Only one NaN among otherwise valid values
        let w = DriftWeights {
            state: 0.4,
            causal: f32::NAN,
            context: 0.25,
        };
        assert_eq!(w.validate(), Err(DriftWeightsError::NonFinite));
    }

    #[test]
    fn test_drift_weights_error_display() {
        let e = DriftWeightsError::NonFinite;
        assert!(e.to_string().contains("NaN or Infinity"));
        let e = DriftWeightsError::Negative;
        assert!(e.to_string().contains("negative"));
        let e = DriftWeightsError::SumNotOne(1.5);
        assert!(e.to_string().contains("1.5"));
    }
}