cellstate_core/
embedding.rs

1//! Embedding vector operations
2
3use crate::{CellstateError, CellstateResult, VectorError};
4use serde::{Deserialize, Serialize};
5
6/// Embedding vector with dynamic dimensions.
7/// Supports any embedding model dimension (e.g., 384, 768, 1536, 3072).
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
10pub struct EmbeddingVector {
11    /// The embedding data as a vector of f32 values.
12    pub data: Vec<f32>,
13    /// Identifier of the model that produced this embedding.
14    pub model_id: String,
15    /// Number of dimensions (must match data.len()).
16    pub dimensions: i32,
17}
18
19impl EmbeddingVector {
20    /// Create a new embedding vector.
21    pub fn new(data: Vec<f32>, model_id: String) -> Self {
22        let dimensions = data.len() as i32;
23        Self {
24            data,
25            model_id,
26            dimensions,
27        }
28    }
29
30    /// Compute cosine similarity between two embedding vectors.
31    pub fn cosine_similarity(&self, other: &EmbeddingVector) -> CellstateResult<f32> {
32        if self.dimensions != other.dimensions {
33            return Err(CellstateError::Vector(VectorError::DimensionMismatch {
34                expected: self.dimensions,
35                got: other.dimensions,
36            }));
37        }
38
39        if self.data.iter().any(|x| !x.is_finite()) || other.data.iter().any(|x| !x.is_finite()) {
40            return Err(CellstateError::Vector(VectorError::NonFiniteValues));
41        }
42
43        let mut dot_product = 0.0f32;
44        let mut norm_a = 0.0f32;
45        let mut norm_b = 0.0f32;
46
47        for (a, b) in self.data.iter().zip(other.data.iter()) {
48            dot_product += a * b;
49            norm_a += a * a;
50            norm_b += b * b;
51        }
52
53        let norm_a = norm_a.sqrt();
54        let norm_b = norm_b.sqrt();
55
56        if norm_a == 0.0 || norm_b == 0.0 {
57            return Ok(0.0);
58        }
59
60        Ok(dot_product / (norm_a * norm_b))
61    }
62
63    /// Check if this vector has valid dimensions.
64    pub fn is_valid(&self) -> bool {
65        self.dimensions > 0 && self.data.len() == self.dimensions as usize
66    }
67}
68
69// =============================================================================
70// TESTS
71// =============================================================================
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CellstateError, VectorError};

    /// Tolerance for comparing a similarity against an exact expectation.
    const EPSILON: f32 = 1e-6;

    /// Build an embedding from `values` via `EmbeddingVector::new`, tagged
    /// with the model id "model".
    fn embed(values: &[f32]) -> EmbeddingVector {
        EmbeddingVector::new(values.to_vec(), "model".to_string())
    }

    /// Build a two-element embedding by hand with an arbitrary `dimensions`
    /// field, bypassing the constructor.
    fn with_declared_dimensions(dimensions: i32) -> EmbeddingVector {
        EmbeddingVector {
            data: vec![0.0, 1.0],
            model_id: "m".to_string(),
            dimensions,
        }
    }

    /// Cosine similarity of two slices, panicking with `why` on error.
    fn similarity(lhs: &[f32], rhs: &[f32], why: &str) -> f32 {
        embed(lhs).cosine_similarity(&embed(rhs)).expect(why)
    }

    /// Assert that comparing the two slices fails with `NonFiniteValues`,
    /// panicking with `why` if the call unexpectedly succeeds.
    fn assert_rejects_non_finite(lhs: &[f32], rhs: &[f32], why: &str) {
        let err = embed(lhs).cosine_similarity(&embed(rhs)).expect_err(why);
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_new_sets_dimensions() {
        let values = vec![0.0, 1.0, 0.5];
        let built = EmbeddingVector::new(values.clone(), "model".to_string());
        assert_eq!(built.dimensions, values.len() as i32);
        assert_eq!(built.data, values);
        assert_eq!(built.model_id, "model");
    }

    #[test]
    fn test_is_valid_checks_dimensions_and_length() {
        // Declared dimensions agree with the two stored values.
        assert!(with_declared_dimensions(2).is_valid());
        // Declared dimensions exceed the stored length.
        assert!(!with_declared_dimensions(3).is_valid());
        // Zero dimensions are never valid.
        assert!(!with_declared_dimensions(0).is_valid());
    }

    #[test]
    fn test_empty_vector_is_invalid() {
        let empty = embed(&[]);
        assert_eq!(empty.dimensions, 0);
        assert!(!empty.is_valid());
    }

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let sim = similarity(
            &[1.0, 0.0, 0.0],
            &[1.0, 0.0, 0.0],
            "identical vectors should compute similarity",
        );
        assert!((sim - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let sim = similarity(
            &[1.0, 0.0],
            &[0.0, 1.0],
            "orthogonal vectors should compute similarity",
        );
        assert!(sim.abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_returns_zero() {
        let sim = similarity(
            &[0.0, 0.0],
            &[1.0, 0.0],
            "zero vector should compute similarity",
        );
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let two = embed(&[1.0, 0.0]);
        let three = embed(&[1.0, 0.0, 0.0]);
        let err = two
            .cosine_similarity(&three)
            .expect_err("dimension mismatch should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let sim = similarity(
            &[1.0, 0.0],
            &[-1.0, 0.0],
            "opposite vectors should compute similarity",
        );
        assert!((sim + 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_scaled_vectors() {
        let sim = similarity(
            &[1.0, 2.0, 3.0],
            &[2.0, 4.0, 6.0],
            "scaled vectors should compute similarity",
        );
        assert!((sim - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_is_symmetric() {
        let (left, right) = ([1.0, 2.0], [3.0, 4.0]);
        let forward = similarity(&left, &right, "similarity should compute successfully");
        let backward = similarity(&right, &left, "similarity should compute successfully");
        assert!((forward - backward).abs() < EPSILON);
    }

    #[test]
    fn test_is_valid_negative_dimensions() {
        assert!(!with_declared_dimensions(-1).is_valid());
    }

    // ── NaN / Infinity production guard tests ────────────────────────

    #[test]
    fn test_nan_vector_cosine_similarity() {
        assert_rejects_non_finite(
            &[f32::NAN, 1.0, 0.0],
            &[1.0, 0.0, 0.0],
            "NaN vector should return error",
        );
    }

    #[test]
    fn test_infinity_vector_cosine_similarity() {
        assert_rejects_non_finite(
            &[1.0, 0.0],
            &[f32::INFINITY, 0.0],
            "Infinity vector should return error",
        );
    }

    #[test]
    fn test_neg_infinity_vector_cosine_similarity() {
        assert_rejects_non_finite(
            &[f32::NEG_INFINITY, 1.0],
            &[1.0, 0.0],
            "NEG_INFINITY vector should return error",
        );
    }

    #[test]
    fn test_mixed_nan_valid_vector() {
        // A single NaN among otherwise valid values is enough to fail.
        assert_rejects_non_finite(
            &[1.0, 2.0, f32::NAN],
            &[1.0, 2.0, 3.0],
            "mixed NaN vector should return error",
        );
    }

    #[test]
    fn test_both_vectors_nan() {
        assert_rejects_non_finite(
            &[f32::NAN, f32::NAN],
            &[f32::NAN, f32::NAN],
            "all-NaN vectors should return error",
        );
    }

    #[test]
    fn test_valid_vectors_still_work_after_guard() {
        // The non-finite guard must not interfere with ordinary inputs.
        let sim = similarity(
            &[1.0, 2.0, 3.0],
            &[4.0, 5.0, 6.0],
            "valid vectors should compute similarity",
        );
        assert!(
            sim > 0.9,
            "parallel-ish vectors should have high similarity"
        );
    }
}
288}