1use crate::{CellstateError, CellstateResult, VectorError};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
10pub struct EmbeddingVector {
11 pub data: Vec<f32>,
13 pub model_id: String,
15 pub dimensions: i32,
17}
18
19impl EmbeddingVector {
20 pub fn new(data: Vec<f32>, model_id: String) -> Self {
22 let dimensions = data.len() as i32;
23 Self {
24 data,
25 model_id,
26 dimensions,
27 }
28 }
29
30 pub fn cosine_similarity(&self, other: &EmbeddingVector) -> CellstateResult<f32> {
32 if self.dimensions != other.dimensions {
33 return Err(CellstateError::Vector(VectorError::DimensionMismatch {
34 expected: self.dimensions,
35 got: other.dimensions,
36 }));
37 }
38
39 if self.data.iter().any(|x| !x.is_finite()) || other.data.iter().any(|x| !x.is_finite()) {
40 return Err(CellstateError::Vector(VectorError::NonFiniteValues));
41 }
42
43 let mut dot_product = 0.0f32;
44 let mut norm_a = 0.0f32;
45 let mut norm_b = 0.0f32;
46
47 for (a, b) in self.data.iter().zip(other.data.iter()) {
48 dot_product += a * b;
49 norm_a += a * a;
50 norm_b += b * b;
51 }
52
53 let norm_a = norm_a.sqrt();
54 let norm_b = norm_b.sqrt();
55
56 if norm_a == 0.0 || norm_b == 0.0 {
57 return Ok(0.0);
58 }
59
60 Ok(dot_product / (norm_a * norm_b))
61 }
62
63 pub fn is_valid(&self) -> bool {
65 self.dimensions > 0 && self.data.len() == self.dimensions as usize
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CellstateError, VectorError};

    /// Builds an embedding through the public constructor with the shared test model id.
    fn embed(data: Vec<f32>) -> EmbeddingVector {
        EmbeddingVector::new(data, "model".to_string())
    }

    /// Builds an embedding with an explicitly chosen (possibly inconsistent) dimension.
    fn raw(data: Vec<f32>, dimensions: i32) -> EmbeddingVector {
        EmbeddingVector {
            data,
            model_id: "m".to_string(),
            dimensions,
        }
    }

    /// Computes `a`·`b` similarity, failing the test with `why` on error.
    fn similarity(a: &EmbeddingVector, b: &EmbeddingVector, why: &str) -> f32 {
        a.cosine_similarity(b).expect(why)
    }

    /// Asserts that comparing `a` with `b` is rejected as containing non-finite values.
    fn assert_non_finite(a: &EmbeddingVector, b: &EmbeddingVector, why: &str) {
        let err = a.cosine_similarity(b).expect_err(why);
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::NonFiniteValues)
        ));
    }

    #[test]
    fn test_new_sets_dimensions() {
        let data = vec![0.0, 1.0, 0.5];
        let built = embed(data.clone());
        assert_eq!(built.model_id, "model");
        assert_eq!(built.data, data);
        assert_eq!(built.dimensions, data.len() as i32);
    }

    #[test]
    fn test_is_valid_checks_dimensions_and_length() {
        assert!(raw(vec![0.0, 1.0], 2).is_valid());
        assert!(!raw(vec![0.0, 1.0], 3).is_valid());
        assert!(!raw(vec![0.0, 1.0], 0).is_valid());
    }

    #[test]
    fn test_empty_vector_is_invalid() {
        let empty = embed(Vec::new());
        assert_eq!(empty.dimensions, 0);
        assert!(!empty.is_valid());
    }

    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let sim = similarity(
            &embed(vec![1.0, 0.0, 0.0]),
            &embed(vec![1.0, 0.0, 0.0]),
            "identical vectors should compute similarity",
        );
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let sim = similarity(
            &embed(vec![1.0, 0.0]),
            &embed(vec![0.0, 1.0]),
            "orthogonal vectors should compute similarity",
        );
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_returns_zero() {
        let sim = similarity(
            &embed(vec![0.0, 0.0]),
            &embed(vec![1.0, 0.0]),
            "zero vector should compute similarity",
        );
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let short = embed(vec![1.0, 0.0]);
        let long = embed(vec![1.0, 0.0, 0.0]);
        let err = short
            .cosine_similarity(&long)
            .expect_err("dimension mismatch should return error");
        assert!(matches!(
            err,
            CellstateError::Vector(VectorError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let sim = similarity(
            &embed(vec![1.0, 0.0]),
            &embed(vec![-1.0, 0.0]),
            "opposite vectors should compute similarity",
        );
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_scaled_vectors() {
        let sim = similarity(
            &embed(vec![1.0, 2.0, 3.0]),
            &embed(vec![2.0, 4.0, 6.0]),
            "scaled vectors should compute similarity",
        );
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_is_symmetric() {
        let left = embed(vec![1.0, 2.0]);
        let right = embed(vec![3.0, 4.0]);
        let forward = similarity(&left, &right, "similarity should compute successfully");
        let backward = similarity(&right, &left, "similarity should compute successfully");
        assert!((forward - backward).abs() < 1e-6);
    }

    #[test]
    fn test_is_valid_negative_dimensions() {
        assert!(!raw(vec![0.0, 1.0], -1).is_valid());
    }

    #[test]
    fn test_nan_vector_cosine_similarity() {
        assert_non_finite(
            &embed(vec![f32::NAN, 1.0, 0.0]),
            &embed(vec![1.0, 0.0, 0.0]),
            "NaN vector should return error",
        );
    }

    #[test]
    fn test_infinity_vector_cosine_similarity() {
        assert_non_finite(
            &embed(vec![1.0, 0.0]),
            &embed(vec![f32::INFINITY, 0.0]),
            "Infinity vector should return error",
        );
    }

    #[test]
    fn test_neg_infinity_vector_cosine_similarity() {
        assert_non_finite(
            &embed(vec![f32::NEG_INFINITY, 1.0]),
            &embed(vec![1.0, 0.0]),
            "NEG_INFINITY vector should return error",
        );
    }

    #[test]
    fn test_mixed_nan_valid_vector() {
        assert_non_finite(
            &embed(vec![1.0, 2.0, f32::NAN]),
            &embed(vec![1.0, 2.0, 3.0]),
            "mixed NaN vector should return error",
        );
    }

    #[test]
    fn test_both_vectors_nan() {
        assert_non_finite(
            &embed(vec![f32::NAN, f32::NAN]),
            &embed(vec![f32::NAN, f32::NAN]),
            "all-NaN vectors should return error",
        );
    }

    #[test]
    fn test_valid_vectors_still_work_after_guard() {
        let sim = similarity(
            &embed(vec![1.0, 2.0, 3.0]),
            &embed(vec![4.0, 5.0, 6.0]),
            "valid vectors should compute similarity",
        );
        assert!(
            sim > 0.9,
            "parallel-ish vectors should have high similarity"
        );
    }
}