cellstate_core/
prompt_template.rs1use crate::redaction::ScrubbedText;
13use serde::{Deserialize, Serialize};
14use std::fmt;
15
16#[derive(Clone, Serialize, Deserialize)]
22#[serde(tag = "type", content = "value")]
23pub enum PromptSegment {
24 #[serde(rename = "static")]
26 Static(String),
27 #[serde(rename = "scrubbed")]
29 Scrubbed(ScrubbedText),
30}
31
32impl fmt::Debug for PromptSegment {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::Static(s) => write!(f, "Static({:?})", s),
36 Self::Scrubbed(t) => write!(f, "Scrubbed({:?})", t),
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct PromptTemplate {
51 segments: Vec<PromptSegment>,
52}
53
54impl PromptTemplate {
55 pub fn builder() -> PromptTemplateBuilder {
57 PromptTemplateBuilder {
58 segments: Vec::new(),
59 }
60 }
61
62 pub fn as_str_lossy(&self) -> String {
67 let mut result = String::new();
68 for segment in &self.segments {
69 match segment {
70 PromptSegment::Static(s) => result.push_str(s),
71 PromptSegment::Scrubbed(t) => result.push_str(t.as_redacted_str()),
72 }
73 }
74 result
75 }
76
77 pub fn segment_count(&self) -> usize {
79 self.segments.len()
80 }
81
82 pub fn is_empty(&self) -> bool {
84 self.segments.is_empty()
85 }
86
87 pub fn segments(&self) -> &[PromptSegment] {
89 &self.segments
90 }
91}
92
93impl fmt::Display for PromptTemplate {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 f.write_str(&self.as_str_lossy())
96 }
97}
98
99pub struct PromptTemplateBuilder {
107 segments: Vec<PromptSegment>,
108}
109
110impl PromptTemplateBuilder {
111 pub fn static_text(mut self, s: &'static str) -> Self {
113 if !s.is_empty() {
114 self.segments.push(PromptSegment::Static(s.to_owned()));
115 }
116 self
117 }
118
119 pub fn scrubbed(mut self, t: ScrubbedText) -> Self {
121 self.segments.push(PromptSegment::Scrubbed(t));
122 self
123 }
124
125 pub fn build(self) -> PromptTemplate {
127 PromptTemplate {
128 segments: self.segments,
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::redaction::RedactionManifest;

    /// Wraps `text` in a `ScrubbedText` with an empty redaction manifest.
    fn make_scrubbed(text: &str) -> ScrubbedText {
        ScrubbedText::new_verified(text.to_string(), RedactionManifest::empty())
    }

    #[test]
    fn empty_template() {
        let t = PromptTemplate::builder().build();
        assert!(t.is_empty());
        assert_eq!(t.segment_count(), 0);
        assert_eq!(t.as_str_lossy(), "");
    }

    #[test]
    fn static_only_template() {
        let t = PromptTemplate::builder()
            .static_text("<cellstate-memory>\n")
            .static_text("</cellstate-memory>\n")
            .build();
        assert_eq!(t.segment_count(), 2);
        assert_eq!(
            t.as_str_lossy(),
            "<cellstate-memory>\n</cellstate-memory>\n"
        );
    }

    #[test]
    fn static_plus_scrubbed() {
        let scrubbed = make_scrubbed("artifact content with [REDACTED:ssn:abc12345]");
        let t = PromptTemplate::builder()
            .static_text("<cellstate-memory>\n")
            .scrubbed(scrubbed)
            .static_text("\n</cellstate-memory>\n")
            .build();
        assert_eq!(t.segment_count(), 3);
        let text = t.as_str_lossy();
        assert!(text.starts_with("<cellstate-memory>\n"));
        assert!(text.contains("[REDACTED:ssn:abc12345]"));
        assert!(text.ends_with("\n</cellstate-memory>\n"));
    }

    #[test]
    fn display_matches_as_str_lossy() {
        let t = PromptTemplate::builder()
            .static_text("Hello ")
            .scrubbed(make_scrubbed("world"))
            .build();
        assert_eq!(t.to_string(), t.as_str_lossy());
    }

    #[test]
    fn empty_static_text_skipped() {
        let t = PromptTemplate::builder()
            .static_text("")
            .static_text("hello")
            .static_text("")
            .build();
        assert_eq!(t.segment_count(), 1);
        assert_eq!(t.as_str_lossy(), "hello");
    }

    #[test]
    fn clone_works() {
        let t = PromptTemplate::builder()
            .static_text("test")
            .scrubbed(make_scrubbed("data"))
            .build();
        let cloned = t.clone();
        assert_eq!(t.as_str_lossy(), cloned.as_str_lossy());
        assert_eq!(t.segment_count(), cloned.segment_count());
    }

    #[test]
    fn serde_roundtrip() {
        let t = PromptTemplate::builder()
            .static_text("<system>\n")
            .scrubbed(make_scrubbed("scrubbed content"))
            .static_text("\n</system>")
            .build();
        let json = serde_json::to_string(&t).unwrap();
        let deserialized: PromptTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(t.as_str_lossy(), deserialized.as_str_lossy());
        assert_eq!(t.segment_count(), deserialized.segment_count());
    }

    /// Pins the adjacently-tagged JSON shape of a static segment. A round trip alone would still
    /// pass if the tag or variant names changed, silently breaking stored/serialized templates.
    #[test]
    fn static_segment_wire_format() {
        let t = PromptTemplate::builder().static_text("hi").build();
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, r#"{"segments":[{"type":"static","value":"hi"}]}"#);
        let parsed: PromptTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.segment_count(), 1);
        assert_eq!(parsed.as_str_lossy(), "hi");
    }

    /// Pins the hand-written `Debug` output for a static segment.
    #[test]
    fn static_segment_debug_format() {
        let seg = PromptSegment::Static("hi".to_owned());
        assert_eq!(format!("{:?}", seg), r#"Static("hi")"#);
    }
}