// cellstate_core/prompt_template.rs

//! Prompt Template — type-safe system prompt construction.
//!
//! `PromptTemplate` replaces `System(String)` in `MessageContent` to prevent
//! user data from being smuggled through the system prompt escape hatch.
//!
//! A `PromptTemplate` can only be built from:
//! - `&'static str` literals (known at compile time)
//! - `ScrubbedText` (already passed through PII redaction)
//!
//! The builder API offers **no** path from an arbitrary `String` into a
//! `PromptTemplate`. The one exception is serde: `PromptTemplate` derives
//! `Deserialize`, so deserialized input can carry arbitrary `static` segment
//! text. Only deserialize templates from trusted sources.

12use crate::redaction::ScrubbedText;
13use serde::{Deserialize, Serialize};
14use std::fmt;
15
16// ============================================================================
17// PROMPT SEGMENT
18// ============================================================================
19
20/// A single segment of a prompt template.
21#[derive(Clone, Serialize, Deserialize)]
22#[serde(tag = "type", content = "value")]
23pub enum PromptSegment {
24    /// Fixed text known at compile time. Serialized as the string content.
25    #[serde(rename = "static")]
26    Static(String),
27    /// Pre-scrubbed user/stored content.
28    #[serde(rename = "scrubbed")]
29    Scrubbed(ScrubbedText),
30}
31
32impl fmt::Debug for PromptSegment {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Self::Static(s) => write!(f, "Static({:?})", s),
36            Self::Scrubbed(t) => write!(f, "Scrubbed({:?})", t),
37        }
38    }
39}
40
41// ============================================================================
42// PROMPT TEMPLATE
43// ============================================================================
44
45/// A system prompt built from static templates + pre-scrubbed inserts.
46///
47/// Cannot be constructed from arbitrary strings. Use [`PromptTemplateBuilder`]
48/// to construct via `PromptTemplate::builder()`.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct PromptTemplate {
51    segments: Vec<PromptSegment>,
52}
53
54impl PromptTemplate {
55    /// Start building a new prompt template.
56    pub fn builder() -> PromptTemplateBuilder {
57        PromptTemplateBuilder {
58            segments: Vec::new(),
59        }
60    }
61
62    /// Concatenate all segments into a single string.
63    ///
64    /// Static segments are included verbatim. Scrubbed segments use
65    /// the redacted text (PII replaced with placeholders).
66    pub fn as_str_lossy(&self) -> String {
67        let mut result = String::new();
68        for segment in &self.segments {
69            match segment {
70                PromptSegment::Static(s) => result.push_str(s),
71                PromptSegment::Scrubbed(t) => result.push_str(t.as_redacted_str()),
72            }
73        }
74        result
75    }
76
77    /// Number of segments in this template.
78    pub fn segment_count(&self) -> usize {
79        self.segments.len()
80    }
81
82    /// Whether this template is empty (no segments).
83    pub fn is_empty(&self) -> bool {
84        self.segments.is_empty()
85    }
86
87    /// Access the segments for inspection.
88    pub fn segments(&self) -> &[PromptSegment] {
89        &self.segments
90    }
91}
92
93impl fmt::Display for PromptTemplate {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.write_str(&self.as_str_lossy())
96    }
97}
98
99// ============================================================================
100// BUILDER
101// ============================================================================
102
103/// Builder for [`PromptTemplate`].
104///
105/// Only accepts `&'static str` and [`ScrubbedText`] — **no arbitrary `String`**.
106pub struct PromptTemplateBuilder {
107    segments: Vec<PromptSegment>,
108}
109
110impl PromptTemplateBuilder {
111    /// Append a static text segment (must be a `&'static str` literal).
112    pub fn static_text(mut self, s: &'static str) -> Self {
113        if !s.is_empty() {
114            self.segments.push(PromptSegment::Static(s.to_owned()));
115        }
116        self
117    }
118
119    /// Append pre-scrubbed content.
120    pub fn scrubbed(mut self, t: ScrubbedText) -> Self {
121        self.segments.push(PromptSegment::Scrubbed(t));
122        self
123    }
124
125    /// Build the prompt template.
126    pub fn build(self) -> PromptTemplate {
127        PromptTemplate {
128            segments: self.segments,
129        }
130    }
131}
132
133// ============================================================================
134// TESTS
135// ============================================================================
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::redaction::RedactionManifest;

    fn make_scrubbed(text: &str) -> ScrubbedText {
        ScrubbedText::new_verified(text.to_string(), RedactionManifest::empty())
    }

    #[test]
    fn empty_template() {
        let t = PromptTemplate::builder().build();
        assert!(t.is_empty());
        assert_eq!(t.segment_count(), 0);
        assert_eq!(t.as_str_lossy(), "");
    }

    #[test]
    fn static_only_template() {
        let t = PromptTemplate::builder()
            .static_text("<cellstate-memory>\n")
            .static_text("</cellstate-memory>\n")
            .build();
        assert_eq!(t.segment_count(), 2);
        assert_eq!(
            t.as_str_lossy(),
            "<cellstate-memory>\n</cellstate-memory>\n"
        );
    }

    #[test]
    fn static_plus_scrubbed() {
        let scrubbed = make_scrubbed("artifact content with [REDACTED:ssn:abc12345]");
        let t = PromptTemplate::builder()
            .static_text("<cellstate-memory>\n")
            .scrubbed(scrubbed)
            .static_text("\n</cellstate-memory>\n")
            .build();
        assert_eq!(t.segment_count(), 3);
        let text = t.as_str_lossy();
        assert!(text.starts_with("<cellstate-memory>\n"));
        assert!(text.contains("[REDACTED:ssn:abc12345]"));
        assert!(text.ends_with("\n</cellstate-memory>\n"));
    }

    #[test]
    fn display_matches_as_str_lossy() {
        let t = PromptTemplate::builder()
            .static_text("Hello ")
            .scrubbed(make_scrubbed("world"))
            .build();
        assert_eq!(t.to_string(), t.as_str_lossy());
    }

    #[test]
    fn empty_static_text_skipped() {
        let t = PromptTemplate::builder()
            .static_text("")
            .static_text("hello")
            .static_text("")
            .build();
        assert_eq!(t.segment_count(), 1);
        assert_eq!(t.as_str_lossy(), "hello");
    }

    #[test]
    fn empty_scrubbed_segment_is_kept() {
        // `is_empty` counts segments, not rendered characters.
        let t = PromptTemplate::builder()
            .scrubbed(make_scrubbed(""))
            .build();
        assert!(!t.is_empty());
        assert_eq!(t.segment_count(), 1);
    }

    #[test]
    fn segments_accessor_preserves_order() {
        let t = PromptTemplate::builder()
            .static_text("a")
            .scrubbed(make_scrubbed("b"))
            .static_text("c")
            .build();
        let segments = t.segments();
        assert_eq!(segments.len(), 3);
        assert!(matches!(&segments[0], PromptSegment::Static(s) if s == "a"));
        assert!(matches!(&segments[1], PromptSegment::Scrubbed(_)));
        assert!(matches!(&segments[2], PromptSegment::Static(s) if s == "c"));
    }

    #[test]
    fn debug_format() {
        let seg = PromptSegment::Static("hi".to_owned());
        assert_eq!(format!("{:?}", seg), r#"Static("hi")"#);
    }

    #[test]
    fn clone_works() {
        let t = PromptTemplate::builder()
            .static_text("test")
            .scrubbed(make_scrubbed("data"))
            .build();
        let cloned = t.clone();
        assert_eq!(t.as_str_lossy(), cloned.as_str_lossy());
        assert_eq!(t.segment_count(), cloned.segment_count());
    }

    #[test]
    fn serde_roundtrip() {
        let t = PromptTemplate::builder()
            .static_text("<system>\n")
            .scrubbed(make_scrubbed("scrubbed content"))
            .static_text("\n</system>")
            .build();
        let json = serde_json::to_string(&t).unwrap();
        let deserialized: PromptTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(t.as_str_lossy(), deserialized.as_str_lossy());
        assert_eq!(t.segment_count(), deserialized.segment_count());
    }

    #[test]
    fn serde_wire_format_is_adjacently_tagged() {
        // Pins the persisted shape so renames of the serde tags are caught.
        let t = PromptTemplate::builder().static_text("hi").build();
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, r#"{"segments":[{"type":"static","value":"hi"}]}"#);
    }
}