Skip to content

Commit

Permalink
Prompts with owned templates
Browse files Browse the repository at this point in the history
  • Loading branch information
timonv committed Jan 1, 2025
1 parent 4086843 commit 2092aa1
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 46 deletions.
4 changes: 2 additions & 2 deletions swiftide-agents/src/system_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct SystemPrompt {

/// The template to use for the system prompt
#[builder(default = default_prompt_template())]
template: Template,
template: Template<'static>,
}

impl SystemPrompt {
Expand Down Expand Up @@ -76,7 +76,7 @@ impl SystemPromptBuilder {
}
}

fn default_prompt_template() -> Template {
fn default_prompt_template() -> Template<'static> {
include_str!("system_prompt_template.md").into()
}

Expand Down
36 changes: 22 additions & 14 deletions swiftide-core/src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,24 @@
//! assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
//! # }
//! ```
use anyhow::{Context as _, Result};
use lazy_static::lazy_static;
use tera::Tera;
use tokio::sync::RwLock;
use uuid::Uuid;
use anyhow::Result;

use crate::{node::Node, template::Template};

/// A Prompt can be used with large language models to prompt.
#[derive(Clone, Debug)]
pub struct Prompt {
template: Template,
// Should always have an owned template
template: Template<'static>,
context: Option<tera::Context>,
}

#[deprecated(
since = "0.16.0",
note = "Use `Template` instead; they serve a more general purpose"
)]
pub type PromptTemplate = Template;
pub type PromptTemplate<'inner> = Template<'inner>;

impl Prompt {
/// Adds an `ingestion::Node` to the context of the Prompt
Expand Down Expand Up @@ -95,7 +93,7 @@ impl Prompt {
impl From<&'static str> for Prompt {
fn from(prompt: &'static str) -> Self {
Prompt {
template: Template::Static(prompt),
template: Template::OneOff(prompt.into()).to_owned(),
context: None,
}
}
Expand All @@ -104,23 +102,34 @@ impl From<&'static str> for Prompt {
impl From<String> for Prompt {
fn from(prompt: String) -> Self {
Prompt {
template: Template::String(prompt),
template: Template::OneOff(prompt.into()).to_owned(),
context: None,
}
}
}

impl From<&Template> for Prompt {
fn from(template: &Template) -> Self {
impl From<&Template<'static>> for Prompt {
fn from(template: &Template<'static>) -> Self {
Prompt {
template: template.clone(),
context: None,
}
}
}

impl From<Template<'static>> for Prompt {
fn from(template: Template<'static>) -> Self {
Prompt {
template,
context: None,
}
}
}

#[cfg(test)]
mod test {
use tera::Tera;

use super::*;

#[tokio::test]
Expand Down Expand Up @@ -160,9 +169,8 @@ mod test {

Template::extend(&custom_tera).await.unwrap();

let prompt = Template::from_compiled_template_name("hello")
.to_prompt()
.with_context_value("world", "swiftide");
let template = Template::from_compiled_template_name("hello");
let prompt = template.to_prompt().with_context_value("world", "swiftide");

assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
}
Expand Down
60 changes: 39 additions & 21 deletions swiftide-core/src/template.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use anyhow::{Context as _, Result};
use tokio::sync::RwLock;

Expand Down Expand Up @@ -25,20 +27,24 @@ lazy_static! {
}
/// A `Template` defines a template for a prompt
#[derive(Clone, Debug)]
pub enum Template {
CompiledTemplate(String),
String(String),
Static(&'static str),
pub enum Template<'inner> {
/// A reference to a compiled template stored in the template repository
/// These can also be created on the fly with `Template::try_compiled_from_str`,
/// or retrieved at runtime with `Template::from_compiled_template_name`
CompiledTemplate(Cow<'inner, str>),

/// A one-off template that is not stored in the repository
OneOff(Cow<'inner, str>),
}

impl Template {
impl<'inner> Template<'inner> {
/// Creates a reference to a template already stored in the repository
pub fn from_compiled_template_name(name: impl Into<String>) -> Template {
pub fn from_compiled_template_name(name: impl Into<Cow<'inner, str>>) -> Template<'inner> {
Template::CompiledTemplate(name.into())
}

pub fn from_string(template: impl Into<String>) -> Template {
Template::String(template.into())
pub fn from_string(template: impl Into<Cow<'inner, str>>) -> Template<'inner> {
Template::OneOff(template.into())
}

/// Extends the prompt repository with a custom [`tera::Tera`] instance.
Expand Down Expand Up @@ -69,13 +75,13 @@ impl Template {
/// Errors if the template fails to compile
pub async fn try_compiled_from_str(
template: impl AsRef<str> + Send + 'static,
) -> Result<Template> {
) -> Result<Template<'inner>> {
let id = Uuid::new_v4().to_string();
let mut lock = TEMPLATE_REPOSITORY.write().await;
lock.add_raw_template(&id, template.as_ref())
.context("Failed to add raw template")?;

Ok(Template::CompiledTemplate(id))
Ok(Template::CompiledTemplate(id.into()))
}

/// Renders a template with an optional `tera::Context`
Expand All @@ -86,13 +92,13 @@ impl Template {
/// - One-off template has errors
/// - Context is missing that is required by the template
pub async fn render(&self, context: &tera::Context) -> Result<String> {
use Template::{CompiledTemplate, Static, String};
use Template::{CompiledTemplate, OneOff};

let template = match self {
CompiledTemplate(id) => {
let lock = TEMPLATE_REPOSITORY.read().await;
tracing::debug!(
id,
?id,
available = ?lock.get_template_names().collect::<Vec<_>>(),
"Rendering template ..."
);
Expand All @@ -107,28 +113,40 @@ impl Template {
}
result.with_context(|| format!("Failed to render template '{id}'"))?
}
String(template) => Tera::one_off(template, context, false)
.context("Failed to render one-off template")?,
Static(template) => Tera::one_off(template, context, false)
OneOff(template) => Tera::one_off(template, context, false)
.context("Failed to render one-off template")?,
};
Ok(template)
}
}

impl Template<'_> {
/// Creates an owned version of the template
///
// NOTE: std ToOwned and Clone preserve the Cow types, which is not what we want
pub fn to_owned(&self) -> Template<'static> {
match self {
Template::CompiledTemplate(template) => {
Template::CompiledTemplate(template.clone().into_owned().into())
}
Template::OneOff(template) => Template::OneOff(template.clone().into_owned().into()),
}
}

/// Builds a Prompt from a template with an empty context
pub fn to_prompt(&self) -> Prompt {
self.into()
Prompt::from(self.to_owned())
}
}

impl From<&'static str> for Template {
fn from(template: &'static str) -> Self {
Template::Static(template)
impl<'inner> From<&'inner str> for Template<'inner> {
fn from(template: &'inner str) -> Self {
Template::OneOff(template.into())
}
}

impl From<String> for Template {
impl From<String> for Template<'_> {
fn from(template: String) -> Self {
Template::String(template)
Template::OneOff(template.into())
}
}
2 changes: 1 addition & 1 deletion swiftide-macros/src/indexing_transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(crate) fn indexing_transformer_impl(args: TokenStream, input: ItemStruct) ->

let default_prompt_fn = match &args.default_prompt_file {
Some(file) => quote! {
fn default_prompt() -> hidden::Template {
fn default_prompt() -> hidden::Template<'static> {
include_str!(#file).into()
}
},
Expand Down
2 changes: 2 additions & 0 deletions swiftide-query/src/query/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, sta

impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
/// Generates an answer based on previous transformations
///
/// For a lot of use cases, `answers::Simple` should be sufficient
#[must_use]
pub fn then_answer<T: Answer + 'stream>(
self,
Expand Down
16 changes: 8 additions & 8 deletions swiftide-query/src/response_transformers/summary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use swiftide_core::{
};

#[derive(Debug, Clone, Builder)]
pub struct Summary {
pub struct Summary<'a> {
#[builder(setter(custom))]
client: Arc<dyn SimplePrompt>,
#[builder(default = "default_prompt()")]
prompt_template: Template,
prompt_template: Template<'a>,
}

impl Summary {
pub fn builder() -> SummaryBuilder {
impl Summary<'_> {
pub fn builder() -> SummaryBuilder<'static> {
SummaryBuilder::default()
}

Expand All @@ -28,7 +28,7 @@ impl Summary {
/// # Panics
///
/// Panics if the build failed
pub fn from_client(client: impl SimplePrompt + 'static) -> Summary {
pub fn from_client(client: impl SimplePrompt + 'static) -> Self {
SummaryBuilder::default()
.client(client)
.to_owned()
Expand All @@ -37,14 +37,14 @@ impl Summary {
}
}

impl SummaryBuilder {
impl SummaryBuilder<'_> {
pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
self
}
}

fn default_prompt() -> Template {
fn default_prompt() -> Template<'static> {
indoc::indoc!(
"
Your job is to help a query tool find the right context.
Expand All @@ -69,7 +69,7 @@ fn default_prompt() -> Template {
}

#[async_trait]
impl TransformResponse for Summary {
impl<'a> TransformResponse for Summary<'a> {
#[tracing::instrument(skip_all)]
async fn transform_response(
&self,
Expand Down

0 comments on commit 2092aa1

Please sign in to comment.