1use crate::{CellstateError, CellstateResult, VectorError};
4use serde::{Deserialize, Serialize};
5
/// An embedding vector together with the identifier of its model.
///
/// `dimensions` is stored separately from `data`. [`EmbeddingVector::new`]
/// keeps them in sync, but every field is public, so a hand-built or
/// deserialized value may be inconsistent. Use [`EmbeddingVector::is_valid`]
/// to check.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct EmbeddingVector {
    /// The vector components.
    pub data: Vec<f32>,
    /// Model identifier stored with the vector. `cosine_similarity` does not
    /// compare it, so vectors tagged with different models can be compared
    /// without error.
    pub model_id: String,
    /// Declared number of components; expected to equal `data.len()`.
    pub dimensions: i32,
}
18
19impl EmbeddingVector {
20 pub fn new(data: Vec<f32>, model_id: String) -> Self {
22 let dimensions = data.len() as i32;
23 Self {
24 data,
25 model_id,
26 dimensions,
27 }
28 }
29
30 pub fn cosine_similarity(&self, other: &EmbeddingVector) -> CellstateResult<f32> {
32 if self.dimensions != other.dimensions {
33 return Err(CellstateError::Vector(VectorError::DimensionMismatch {
34 expected: self.dimensions,
35 got: other.dimensions,
36 }));
37 }
38
39 if self.data.iter().any(|x| !x.is_finite()) || other.data.iter().any(|x| !x.is_finite()) {
40 return Err(CellstateError::Vector(VectorError::NonFiniteValues));
41 }
42
43 let mut dot_product = 0.0f32;
44 let mut norm_a = 0.0f32;
45 let mut norm_b = 0.0f32;
46
47 for (a, b) in self.data.iter().zip(other.data.iter()) {
48 dot_product += a * b;
49 norm_a += a * a;
50 norm_b += b * b;
51 }
52
53 let norm_a = norm_a.sqrt();
54 let norm_b = norm_b.sqrt();
55
56 if norm_a == 0.0 || norm_b == 0.0 {
57 return Ok(0.0);
58 }
59
60 Ok(dot_product / (norm_a * norm_b))
61 }
62
63 pub fn is_valid(&self) -> bool {
65 self.dimensions > 0 && self.data.len() == self.dimensions as usize
66 }
67}
68
/// Cosine similarity of two raw slices, rescaled from `[-1, 1]` to `[0, 1]`.
///
/// `1.0` means the vectors point the same way, `0.5` that they are
/// orthogonal, and `0.0` that they are opposite.
///
/// This function is infallible and never returns NaN:
/// - empty slices, or slices of different lengths, return `0.0`;
/// - non-finite components (NaN, ±infinity) are treated as `0.0`;
/// - if either vector has zero magnitude, the result is `0.0`.
///
/// These degenerate results share the value `0.0` with "opposite" vectors.
/// Callers that need to tell them apart should use
/// `EmbeddingVector::cosine_similarity`, which reports invalid input as an
/// error.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || b.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Zero out non-finite components, then accumulate in f64. This stops
    // large finite f32 values from overflowing when squared (inf/inf = NaN)
    // and stops tiny ones from underflowing to a zero norm.
    let sanitize = |v: f32| if v.is_finite() { f64::from(v) } else { 0.0 };
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, aa, bb), (&x, &y)| {
            let (x, y) = (sanitize(x), sanitize(y));
            (dot + x * y, aa + x * x, bb + y * y)
        },
    );

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Clamp guards against rounding drifting just outside [-1, 1].
    let cosine = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    ((cosine + 1.0) / 2.0) as f32
}
104
/// Time-budgeted variant of `cosine_similarity`.
///
/// Returns `(score, timed_out)`. The score uses the same `[0, 1]` scale and
/// the same handling of degenerate input as `cosine_similarity`:
/// - empty or length-mismatched slices give `(0.0, false)`;
/// - non-finite components count as `0.0`;
/// - a zero-magnitude vector gives `(0.0, false)`.
///
/// Elapsed time is checked once after every full chunk of `CHECK_INTERVAL`
/// components, so the budget is a soft limit. Vectors shorter than one chunk
/// are never interrupted. When a check finds the budget exceeded, the
/// computation stops and `(0.0, true)` is returned.
///
/// A zero `budget` disables the time limit entirely.
pub fn cosine_similarity_budgeted(
    a: &[f32],
    b: &[f32],
    budget: std::time::Duration,
) -> (f32, bool) {
    /// Number of components processed between clock reads; this amortises
    /// the cost of querying the clock.
    const CHECK_INTERVAL: usize = 128;

    if a.is_empty() || b.is_empty() || a.len() != b.len() {
        return (0.0, false);
    }

    // `None` means "no limit", so the clock is never consulted.
    let started = (!budget.is_zero()).then(std::time::Instant::now);
    let sanitize = |v: f32| if v.is_finite() { f64::from(v) } else { 0.0 };

    // Accumulate in f64 so large finite inputs cannot overflow when squared
    // (inf/inf = NaN) and tiny ones do not underflow to a zero norm.
    let mut dot = 0.0f64;
    let mut norm_a_sq = 0.0f64;
    let mut norm_b_sq = 0.0f64;

    for (idx, (&x, &y)) in a.iter().zip(b).enumerate() {
        let (x, y) = (sanitize(x), sanitize(y));
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;

        let at_checkpoint = idx % CHECK_INTERVAL == CHECK_INTERVAL - 1;
        if at_checkpoint && matches!(started, Some(t) if t.elapsed() > budget) {
            return (0.0, true);
        }
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return (0.0, false);
    }

    // Clamp guards against rounding drifting just outside [-1, 1].
    let cosine = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    (((cosine + 1.0) / 2.0) as f32, false)
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CellstateError, VectorError};
    use std::time::Duration;

    // ---- EmbeddingVector construction and validity ----

    #[test]
    fn test_new_sets_dimensions() {
        let data = vec![0.0, 1.0, 0.5];
        let vec = EmbeddingVector::new(data.clone(), "model".to_string());
        assert_eq!(vec.dimensions, data.len() as i32);
        assert_eq!(vec.data, data);
        assert_eq!(vec.model_id, "model");
    }

    #[test]
    fn test_is_valid_checks_dimensions_and_length() {
        let valid = EmbeddingVector {
            data: vec![0.0, 1.0],
            model_id: "m".to_string(),
            dimensions: 2,
        };
        assert!(valid.is_valid());

        let invalid_len = EmbeddingVector {
            data: vec![0.0, 1.0],
            model_id: "m".to_string(),
            dimensions: 3,
        };
        assert!(!invalid_len.is_valid());

        let invalid_dim = EmbeddingVector {
            data: vec![0.0, 1.0],
            model_id: "m".to_string(),
            dimensions: 0,
        };
        assert!(!invalid_dim.is_valid());
    }

    #[test]
    fn test_empty_vector_is_invalid() {
        let vec = EmbeddingVector::new(vec![], "model".to_string());
        assert_eq!(vec.dimensions, 0);
        assert!(!vec.is_valid());
    }

    // ---- EmbeddingVector::cosine_similarity (raw [-1, 1] scale) ----

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let a = EmbeddingVector::new(vec![1.0, 0.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 0.0, 0.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("identical vectors should compute similarity");
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let a = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![0.0, 1.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("orthogonal vectors should compute similarity");
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_returns_zero() {
        let a = EmbeddingVector::new(vec![0.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("zero vector should compute similarity");
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 0.0, 0.0], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("dimension mismatch should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let a = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![-1.0, 0.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("opposite vectors should compute similarity");
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_scaled_vectors() {
        let a = EmbeddingVector::new(vec![1.0, 2.0, 3.0], "model".to_string());
        let b = EmbeddingVector::new(vec![2.0, 4.0, 6.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("scaled vectors should compute similarity");
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_is_symmetric() {
        let a = EmbeddingVector::new(vec![1.0, 2.0], "model".to_string());
        let b = EmbeddingVector::new(vec![3.0, 4.0], "model".to_string());
        let ab = a
            .cosine_similarity(&b)
            .expect("similarity should compute successfully");
        let ba = b
            .cosine_similarity(&a)
            .expect("similarity should compute successfully");
        assert!((ab - ba).abs() < 1e-6);
    }

    #[test]
    fn test_is_valid_negative_dimensions() {
        let invalid = EmbeddingVector {
            data: vec![0.0, 1.0],
            model_id: "m".to_string(),
            dimensions: -1,
        };
        assert!(!invalid.is_valid());
    }

    // ---- Non-finite input is rejected by the method ----

    #[test]
    fn test_nan_vector_cosine_similarity() {
        let a = EmbeddingVector::new(vec![f32::NAN, 1.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 0.0, 0.0], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("NaN vector should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_infinity_vector_cosine_similarity() {
        let a = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let b = EmbeddingVector::new(vec![f32::INFINITY, 0.0], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("Infinity vector should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_neg_infinity_vector_cosine_similarity() {
        let a = EmbeddingVector::new(vec![f32::NEG_INFINITY, 1.0], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 0.0], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("NEG_INFINITY vector should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_mixed_nan_valid_vector() {
        let a = EmbeddingVector::new(vec![1.0, 2.0, f32::NAN], "model".to_string());
        let b = EmbeddingVector::new(vec![1.0, 2.0, 3.0], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("mixed NaN vector should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_both_vectors_nan() {
        let a = EmbeddingVector::new(vec![f32::NAN, f32::NAN], "model".to_string());
        let b = EmbeddingVector::new(vec![f32::NAN, f32::NAN], "model".to_string());
        let err = a
            .cosine_similarity(&b)
            .expect_err("all-NaN vectors should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_valid_vectors_still_work_after_guard() {
        let a = EmbeddingVector::new(vec![1.0, 2.0, 3.0], "model".to_string());
        let b = EmbeddingVector::new(vec![4.0, 5.0, 6.0], "model".to_string());
        let sim = a
            .cosine_similarity(&b)
            .expect("valid vectors should compute similarity");
        assert!(
            sim > 0.9,
            "parallel-ish vectors should have high similarity"
        );
    }

    // ---- Free function `cosine_similarity` (rescaled [0, 1] scale) ----

    #[test]
    fn test_free_cosine_similarity_rescales_to_unit_interval() {
        let same = cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((same - 1.0).abs() < 1e-6);
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!((orthogonal - 0.5).abs() < 1e-6);
        let opposite = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!(opposite.abs() < 1e-6);
    }

    #[test]
    fn test_free_cosine_similarity_degenerate_inputs_return_zero() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_free_cosine_similarity_treats_non_finite_as_zero() {
        let sim = cosine_similarity(&[f32::NAN, 1.0], &[0.0, 1.0]);
        assert!((sim - 1.0).abs() < 1e-6);
        let sim = cosine_similarity(&[f32::INFINITY, 1.0], &[0.0, 1.0]);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // ---- `cosine_similarity_budgeted` ----

    #[test]
    fn test_budgeted_matches_unbudgeted_without_timeout() {
        let a: Vec<f32> = (0..300).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..300).map(|i| (300 - i) as f32).collect();
        let expected = cosine_similarity(&a, &b);

        // A zero budget means "unbounded"; 60s is far beyond what 300
        // components need.
        for budget in [Duration::ZERO, Duration::from_secs(60)] {
            let (sim, timed_out) = cosine_similarity_budgeted(&a, &b, budget);
            assert!(!timed_out);
            assert!((sim - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_budgeted_degenerate_inputs() {
        let budget = Duration::from_secs(60);
        assert_eq!(cosine_similarity_budgeted(&[], &[], budget), (0.0, false));
        assert_eq!(
            cosine_similarity_budgeted(&[1.0], &[1.0, 2.0], budget),
            (0.0, false)
        );
    }

    #[test]
    fn test_budgeted_reports_timeout() {
        // 65536 components give many checkpoints, and a 1ns budget is
        // guaranteed to be exceeded by the last of them.
        let a = vec![1.0f32; 1 << 16];
        let (sim, timed_out) = cosine_similarity_budgeted(&a, &a, Duration::from_nanos(1));
        assert!(timed_out);
        assert_eq!(sim, 0.0);
    }
}