cellstate_core/
embedding.rs

1//! Embedding vector operations
2
3use crate::{CellstateError, CellstateResult, VectorError};
4use serde::{Deserialize, Serialize};
5
6/// Embedding vector with dynamic dimensions.
7/// Supports any embedding model dimension (e.g., 384, 768, 1536, 3072).
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
10pub struct EmbeddingVector {
11    /// The embedding data as a vector of f32 values.
12    pub data: Vec<f32>,
13    /// Identifier of the model that produced this embedding.
14    pub model_id: String,
15    /// Number of dimensions (must match data.len()).
16    pub dimensions: i32,
17}
18
19impl EmbeddingVector {
20    /// Create a new embedding vector.
21    pub fn new(data: Vec<f32>, model_id: String) -> Self {
22        let dimensions = data.len() as i32;
23        Self {
24            data,
25            model_id,
26            dimensions,
27        }
28    }
29
30    /// Compute cosine similarity between two embedding vectors.
31    pub fn cosine_similarity(&self, other: &EmbeddingVector) -> CellstateResult<f32> {
32        if self.dimensions != other.dimensions {
33            return Err(CellstateError::Vector(VectorError::DimensionMismatch {
34                expected: self.dimensions,
35                got: other.dimensions,
36            }));
37        }
38
39        if self.data.iter().any(|x| !x.is_finite()) || other.data.iter().any(|x| !x.is_finite()) {
40            return Err(CellstateError::Vector(VectorError::NonFiniteValues));
41        }
42
43        let mut dot_product = 0.0f32;
44        let mut norm_a = 0.0f32;
45        let mut norm_b = 0.0f32;
46
47        for (a, b) in self.data.iter().zip(other.data.iter()) {
48            dot_product += a * b;
49            norm_a += a * a;
50            norm_b += b * b;
51        }
52
53        let norm_a = norm_a.sqrt();
54        let norm_b = norm_b.sqrt();
55
56        if norm_a == 0.0 || norm_b == 0.0 {
57            return Ok(0.0);
58        }
59
60        Ok(dot_product / (norm_a * norm_b))
61    }
62
63    /// Check if this vector has valid dimensions.
64    pub fn is_valid(&self) -> bool {
65        self.dimensions > 0 && self.data.len() == self.dimensions as usize
66    }
67}
68
69// =============================================================================
70// FREE FUNCTIONS — raw slice cosine similarity
71// =============================================================================
72
/// Compute cosine similarity between two f32 slices.
///
/// Returns a value in `[0, 1]` (mapped from raw `[-1, 1]` via `(sim + 1) / 2`).
/// Returns `0.0` for empty, mismatched-length, or zero-norm vectors.
/// Non-finite values (NaN, Inf) are treated as `0.0` per element, so a vector
/// consisting only of non-finite values counts as zero-norm.
///
/// Sums are accumulated in `f64`. With `f32` accumulators, large finite
/// components overflow to infinity (yielding `NaN`) and very small ones
/// underflow to a zero norm; neither can happen for finite `f32` input here.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // A non-empty `a` paired with an empty `b` is caught by the length check.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0f64;
    let mut sq_norm_a = 0.0f64;
    let mut sq_norm_b = 0.0f64;

    for (&x, &y) in a.iter().zip(b) {
        let x = if x.is_finite() { f64::from(x) } else { 0.0 };
        let y = if y.is_finite() { f64::from(y) } else { 0.0 };
        dot += x * y;
        sq_norm_a += x * x;
        sq_norm_b += y * y;
    }

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    // Clamp before remapping so rounding noise cannot leave `[0, 1]`.
    let raw = (dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())).clamp(-1.0, 1.0);
    ((raw + 1.0) / 2.0) as f32
}
104
/// Time-budgeted cosine similarity.
///
/// Computes the same `[0, 1]` score as [`cosine_similarity`] and returns it as
/// `(score, false)`. If the computation exceeds `budget`, it aborts and
/// returns `(0.0, true)`. The clock is checked every 128 element pairs.
///
/// A zero `budget` means "no limit": the clock is never consulted and the
/// second tuple element is always `false`.
///
/// Empty, mismatched-length, or zero-norm inputs yield `(0.0, false)`.
/// Non-finite values (NaN, Inf) are treated as `0.0` per element. Sums are
/// accumulated in `f64` so finite `f32` components cannot overflow to
/// infinity or underflow to a zero norm.
pub fn cosine_similarity_budgeted(
    a: &[f32],
    b: &[f32],
    budget: std::time::Duration,
) -> (f32, bool) {
    /// Number of element pairs processed between clock checks.
    const CHECK_INTERVAL: usize = 128;

    // A non-empty `a` paired with an empty `b` is caught by the length check.
    if a.is_empty() || a.len() != b.len() {
        return (0.0, false);
    }

    let limited = !budget.is_zero();
    let started = std::time::Instant::now();
    let mut dot = 0.0f64;
    let mut sq_norm_a = 0.0f64;
    let mut sq_norm_b = 0.0f64;

    for (idx, (&x, &y)) in a.iter().zip(b).enumerate() {
        let x = if x.is_finite() { f64::from(x) } else { 0.0 };
        let y = if y.is_finite() { f64::from(y) } else { 0.0 };
        dot += x * y;
        sq_norm_a += x * x;
        sq_norm_b += y * y;

        // Reading the clock is far costlier than one multiply-add, so only
        // do it once per full batch of pairs.
        if limited && idx % CHECK_INTERVAL == CHECK_INTERVAL - 1 && started.elapsed() > budget {
            return (0.0, true);
        }
    }

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return (0.0, false);
    }

    // Clamp before remapping so rounding noise cannot leave `[0, 1]`.
    let raw = (dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())).clamp(-1.0, 1.0);
    (((raw + 1.0) / 2.0) as f32, false)
}
148
149// =============================================================================
150// TESTS
151// =============================================================================
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CellstateError, VectorError};

    /// Build an embedding over `data` via `EmbeddingVector::new`.
    fn embedding(data: &[f32]) -> EmbeddingVector {
        EmbeddingVector::new(data.to_vec(), "model".to_string())
    }

    /// Build an embedding whose `dimensions` field is set independently of
    /// `data`, bypassing the constructor.
    fn raw_embedding(data: &[f32], dimensions: i32) -> EmbeddingVector {
        EmbeddingVector {
            data: data.to_vec(),
            model_id: "m".to_string(),
            dimensions,
        }
    }

    /// Assert that comparing `a` with `b` fails with `NonFiniteValues`.
    fn assert_non_finite_error(a: &[f32], b: &[f32], why: &str) {
        let err = embedding(a)
            .cosine_similarity(&embedding(b))
            .expect_err(why);
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    // ── Construction and validity ────────────────────────────────────

    #[test]
    fn test_new_sets_dimensions() {
        let data = vec![0.0, 1.0, 0.5];
        let vec = EmbeddingVector::new(data.clone(), "model".to_string());
        assert_eq!(vec.dimensions, data.len() as i32);
        assert_eq!(vec.data, data);
        assert_eq!(vec.model_id, "model");
    }

    #[test]
    fn test_is_valid_checks_dimensions_and_length() {
        assert!(raw_embedding(&[0.0, 1.0], 2).is_valid());
        // Declared dimensions larger than the data.
        assert!(!raw_embedding(&[0.0, 1.0], 3).is_valid());
        // Zero declared dimensions despite having data.
        assert!(!raw_embedding(&[0.0, 1.0], 0).is_valid());
    }

    #[test]
    fn test_empty_vector_is_invalid() {
        let vec = embedding(&[]);
        assert_eq!(vec.dimensions, 0);
        assert!(!vec.is_valid());
    }

    #[test]
    fn test_is_valid_negative_dimensions() {
        assert!(!raw_embedding(&[0.0, 1.0], -1).is_valid());
    }

    // ── EmbeddingVector::cosine_similarity ───────────────────────────

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let sim = embedding(&[1.0, 0.0, 0.0])
            .cosine_similarity(&embedding(&[1.0, 0.0, 0.0]))
            .expect("identical vectors should compute similarity");
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let sim = embedding(&[1.0, 0.0])
            .cosine_similarity(&embedding(&[0.0, 1.0]))
            .expect("orthogonal vectors should compute similarity");
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_returns_zero() {
        let sim = embedding(&[0.0, 0.0])
            .cosine_similarity(&embedding(&[1.0, 0.0]))
            .expect("zero vector should compute similarity");
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let err = embedding(&[1.0, 0.0])
            .cosine_similarity(&embedding(&[1.0, 0.0, 0.0]))
            .expect_err("dimension mismatch should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let sim = embedding(&[1.0, 0.0])
            .cosine_similarity(&embedding(&[-1.0, 0.0]))
            .expect("opposite vectors should compute similarity");
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_scaled_vectors() {
        let sim = embedding(&[1.0, 2.0, 3.0])
            .cosine_similarity(&embedding(&[2.0, 4.0, 6.0]))
            .expect("scaled vectors should compute similarity");
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_is_symmetric() {
        let a = embedding(&[1.0, 2.0]);
        let b = embedding(&[3.0, 4.0]);
        let ab = a
            .cosine_similarity(&b)
            .expect("similarity should compute successfully");
        let ba = b
            .cosine_similarity(&a)
            .expect("similarity should compute successfully");
        assert!((ab - ba).abs() < 1e-6);
    }

    // ── NaN / Infinity production guard tests ────────────────────────

    #[test]
    fn test_nan_vector_cosine_similarity() {
        assert_non_finite_error(
            &[f32::NAN, 1.0, 0.0],
            &[1.0, 0.0, 0.0],
            "NaN vector should return error",
        );
    }

    #[test]
    fn test_infinity_vector_cosine_similarity() {
        assert_non_finite_error(
            &[1.0, 0.0],
            &[f32::INFINITY, 0.0],
            "Infinity vector should return error",
        );
    }

    #[test]
    fn test_neg_infinity_vector_cosine_similarity() {
        assert_non_finite_error(
            &[f32::NEG_INFINITY, 1.0],
            &[1.0, 0.0],
            "NEG_INFINITY vector should return error",
        );
    }

    #[test]
    fn test_mixed_nan_valid_vector() {
        // One NaN among valid values still triggers error
        assert_non_finite_error(
            &[1.0, 2.0, f32::NAN],
            &[1.0, 2.0, 3.0],
            "mixed NaN vector should return error",
        );
    }

    #[test]
    fn test_both_vectors_nan() {
        assert_non_finite_error(
            &[f32::NAN, f32::NAN],
            &[f32::NAN, f32::NAN],
            "all-NaN vectors should return error",
        );
    }

    #[test]
    fn test_valid_vectors_still_work_after_guard() {
        // Ensure the guard doesn't break normal operation
        let sim = embedding(&[1.0, 2.0, 3.0])
            .cosine_similarity(&embedding(&[4.0, 5.0, 6.0]))
            .expect("valid vectors should compute similarity");
        assert!(
            sim > 0.9,
            "parallel-ish vectors should have high similarity"
        );
    }

    // ── Free functions over raw slices ───────────────────────────────

    #[test]
    fn test_free_cosine_similarity_maps_to_unit_interval() {
        let same = cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        let opposite = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((same - 1.0).abs() < 1e-6);
        assert!((orthogonal - 0.5).abs() < 1e-6);
        assert!(opposite.abs() < 1e-6);
    }

    #[test]
    fn test_free_cosine_similarity_degenerate_inputs_return_zero() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn test_free_cosine_similarity_treats_non_finite_as_zero() {
        // The NaN component is zeroed, leaving two identical unit vectors.
        let sim = cosine_similarity(&[f32::NAN, 1.0], &[0.0, 1.0]);
        assert!((sim - 1.0).abs() < 1e-6);
        // A vector with only non-finite components collapses to zero norm.
        assert_eq!(cosine_similarity(&[f32::INFINITY], &[1.0]), 0.0);
    }

    #[test]
    fn test_budgeted_matches_unbudgeted_within_budget() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        let generous = std::time::Duration::from_secs(60);
        let (sim, timed_out) = cosine_similarity_budgeted(&a, &b, generous);
        assert!(!timed_out);
        assert!((sim - cosine_similarity(&a, &b)).abs() < 1e-6);
    }

    #[test]
    fn test_budgeted_zero_budget_never_times_out() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        let unlimited = std::time::Duration::from_secs(0);
        let (sim, timed_out) = cosine_similarity_budgeted(&a, &b, unlimited);
        assert!(!timed_out);
        assert!((sim - cosine_similarity(&a, &b)).abs() < 1e-6);
    }

    #[test]
    fn test_budgeted_degenerate_inputs_return_zero() {
        let empty: [f32; 0] = [];
        let generous = std::time::Duration::from_secs(60);
        assert_eq!(
            cosine_similarity_budgeted(&empty, &empty, generous),
            (0.0, false)
        );
        assert_eq!(
            cosine_similarity_budgeted(&[1.0], &[1.0, 2.0], generous),
            (0.0, false)
        );
    }
}