Skip to content

Commit

Permalink
feat(query): Generic templates with document rendering (#520)
Browse files Browse the repository at this point in the history
Reworks `PromptTemplate` to a more generic `Template`, such that they
can also be used elsewhere. This deprecates `PromptTemplate`.

As an example, an optional `Template` in the `Simple` answer
transformer, which can be used to customize the output of retrieved
documents. This has excellent synergy with the metadata changes in #504.
  • Loading branch information
timonv authored Jan 2, 2025
1 parent c35df55 commit 3254bd3
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 212 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/hybrid_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
//
// ### Usage of Prompts in Transformers
//
// Swiftide utilizes the [`PromptTemplate`] for templating prompts, making it easy to define and manage prompts within transformers.
// Swiftide utilizes the [`Template`] for templating prompts, making it easy to define and manage prompts within transformers.
//
// ```rust
// let template = PromptTemplate::try_compiled_from_str("hello {{world}}").await.unwrap();
Expand Down
8 changes: 4 additions & 4 deletions swiftide-agents/src/system_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! be provided on the agent level.
use derive_builder::Builder;
use swiftide_core::prompt::{Prompt, PromptTemplate};
use swiftide_core::{prompt::Prompt, template::Template};

#[derive(Clone, Debug, Builder)]
#[builder(setter(into, strip_option))]
Expand All @@ -26,9 +26,9 @@ pub struct SystemPrompt {
#[builder(default, setter(custom))]
constraints: Vec<String>,

/// The template to use
/// The template to use for the system prompt
#[builder(default = default_prompt_template())]
template: PromptTemplate,
template: Template,
}

impl SystemPrompt {
Expand Down Expand Up @@ -76,7 +76,7 @@ impl SystemPromptBuilder {
}
}

fn default_prompt_template() -> PromptTemplate {
fn default_prompt_template() -> Template {
include_str!("system_prompt_template.md").into()
}

Expand Down
2 changes: 2 additions & 0 deletions swiftide-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod type_aliases;

pub mod document;
pub mod prompt;
pub mod template;
pub use type_aliases::*;

mod metadata;
Expand All @@ -34,6 +35,7 @@ pub mod indexing {
}

pub mod querying {
pub use crate::document::*;
pub use crate::query::*;
pub use crate::query_evaluation::*;
pub use crate::query_stream::*;
Expand Down
201 changes: 36 additions & 165 deletions swiftide-core/src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
//! Transformers in Swiftide come with default prompts, and they can be customized or replaced as
//! needed.
//!
//! [`PromptTemplate`] can be added with [`PromptTemplate::try_compiled_from_str`]. Prompts can also be
//! [`Template`] can be added with [`Template::try_compiled_from_str`]. Prompts can also be
//! created on the fly from anything that implements [`Into<String>`]. Compiled prompts are stored in
//! an internal repository.
//!
//! Additionally, `PromptTemplate::String` and `PromptTemplate::Static` can be used to create
//! Additionally, `Template::String` and `Template::Static` can be used to create
//! templates on the fly as well.
//!
//! It's recommended to precompile your templates.
Expand All @@ -24,171 +24,29 @@
//! ```
//! #[tokio::main]
//! # async fn main() {
//! # use swiftide_core::prompt::PromptTemplate;
//! let template = PromptTemplate::try_compiled_from_str("hello {{world}}").await.unwrap();
//! # use swiftide_core::template::Template;
//! let template = Template::try_compiled_from_str("hello {{world}}").await.unwrap();
//! let prompt = template.to_prompt().with_context_value("world", "swiftide");
//!
//! assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
//! # }
//! ```
use anyhow::{Context as _, Result};
use lazy_static::lazy_static;
use tera::Tera;
use tokio::sync::RwLock;
use uuid::Uuid;
use anyhow::Result;

use crate::node::Node;

lazy_static! {
/// Tera repository for templates
static ref TEMPLATE_REPOSITORY: RwLock<Tera> = {
let prefix = env!("CARGO_MANIFEST_DIR");
let path = format!("{prefix}/src/transformers/prompts/**/*.prompt.md");

match Tera::new(&path)
{
Ok(t) => RwLock::new(t),
Err(e) => {
tracing::error!("Parsing error(s): {e}");
::std::process::exit(1);
}
}
};
}
use crate::{node::Node, template::Template};

/// A Prompt can be used with large language models to prompt.
#[derive(Clone, Debug)]
pub struct Prompt {
template: PromptTemplate,
template: Template,
context: Option<tera::Context>,
}

/// A `PromptTemplate` defines a template for a prompt
#[derive(Clone, Debug)]
pub enum PromptTemplate {
CompiledTemplate(String),
String(String),
Static(&'static str),
}

impl PromptTemplate {
/// Creates a reference to a template already stored in the repository
pub fn from_compiled_template_name(name: impl Into<String>) -> PromptTemplate {
PromptTemplate::CompiledTemplate(name.into())
}

pub fn from_string(template: impl Into<String>) -> PromptTemplate {
PromptTemplate::String(template.into())
}

/// Extends the prompt repository with a custom [`tera::Tera`] instance.
///
/// If you have your own prompt templates or want to add other functionality, you can extend
/// the repository with your own [`tera::Tera`] instance.
///
/// WARN: Do not use this inside a pipeline or any form of load, as it will lock the repository
///
/// # Errors
///
/// Errors if the repository could not be extended
pub async fn extend(tera: &Tera) -> Result<()> {
TEMPLATE_REPOSITORY
.write()
.await
.extend(tera)
.context("Could not extend prompt repository with custom Tera instance")
}

/// Compiles a template from a string and returns a `PromptTemplate` with a reference to the
/// string.
///
/// WARN: Do not use this inside a pipeline or any form of load, as it will lock the repository
///
/// # Errors
///
/// Errors if the template fails to compile
pub async fn try_compiled_from_str(
template: impl AsRef<str> + Send + 'static,
) -> Result<PromptTemplate> {
let id = Uuid::new_v4().to_string();
let mut lock = TEMPLATE_REPOSITORY.write().await;
lock.add_raw_template(&id, template.as_ref())
.context("Failed to add raw template")?;

Ok(PromptTemplate::CompiledTemplate(id))
}

/// Renders a template with an optional `tera::Context`
///
/// # Errors
///
/// - Template cannot be found
/// - One-off template has errors
/// - Context is missing that is required by the template
pub async fn render(&self, context: &Option<tera::Context>) -> Result<String> {
use PromptTemplate::{CompiledTemplate, Static, String};

let template = match self {
CompiledTemplate(id) => {
let context = match &context {
Some(context) => context,
None => &tera::Context::default(),
};

let lock = TEMPLATE_REPOSITORY.read().await;
let available = lock.get_template_names().collect::<Vec<_>>().join(", ");
tracing::debug!(id, available, "Rendering template ...");
let result = lock.render(id, context);

if result.is_err() {
tracing::error!(
error = result.as_ref().unwrap_err().to_string(),
available,
"Error rendering template {id}"
);
}
result.with_context(|| format!("Failed to render template '{id}'"))?
}
String(template) => {
if let Some(context) = context {
Tera::one_off(template, context, false)
.context("Failed to render one-off template")?
} else {
template.to_string()
}
}
Static(template) => {
if let Some(context) = context {
Tera::one_off(template, context, false)
.context("Failed to render one-off template")?
} else {
(*template).to_string()
}
}
};
Ok(template)
}

/// Builds a Prompt from a template with an empty context
pub fn to_prompt(&self) -> Prompt {
Prompt {
template: self.clone(),
context: Some(tera::Context::default()),
}
}
}

impl From<&'static str> for PromptTemplate {
fn from(template: &'static str) -> Self {
PromptTemplate::Static(template)
}
}

impl From<String> for PromptTemplate {
fn from(template: String) -> Self {
PromptTemplate::String(template)
}
}
#[deprecated(
since = "0.16.0",
note = "Use `Template` instead; they serve a more general purpose"
)]
pub type PromptTemplate = Template;

impl Prompt {
/// Adds an `ingestion::Node` to the context of the Prompt
Expand Down Expand Up @@ -220,16 +78,20 @@ impl Prompt {
///
/// # Errors
///
/// See `PromptTemplate::render`
/// See `Template::render`
pub async fn render(&self) -> Result<String> {
self.template.render(&self.context).await
if let Some(context) = &self.context {
self.template.render(context).await
} else {
self.template.render(&tera::Context::default()).await
}
}
}

impl From<&'static str> for Prompt {
fn from(prompt: &'static str) -> Self {
Prompt {
template: PromptTemplate::Static(prompt),
template: Template::Static(prompt),
context: None,
}
}
Expand All @@ -238,7 +100,16 @@ impl From<&'static str> for Prompt {
impl From<String> for Prompt {
fn from(prompt: String) -> Self {
Prompt {
template: PromptTemplate::String(prompt),
template: Template::String(prompt),
context: None,
}
}
}

impl From<&Template> for Prompt {
fn from(template: &Template) -> Self {
Prompt {
template: template.clone(),
context: None,
}
}
Expand All @@ -250,7 +121,7 @@ mod test {

#[tokio::test]
async fn test_prompt() {
let template = PromptTemplate::try_compiled_from_str("hello {{world}}")
let template = Template::try_compiled_from_str("hello {{world}}")
.await
.unwrap();
let prompt = template.to_prompt().with_context_value("world", "swiftide");
Expand All @@ -259,7 +130,7 @@ mod test {

#[tokio::test]
async fn test_prompt_with_node() {
let template = PromptTemplate::try_compiled_from_str("hello {{node.chunk}}")
let template = Template::try_compiled_from_str("hello {{node.chunk}}")
.await
.unwrap();
let node = Node::new("test");
Expand All @@ -277,15 +148,15 @@ mod test {

#[tokio::test]
async fn test_extending_with_custom_repository() {
let mut custom_tera = Tera::new("**/some/prompts.md").unwrap();
let mut custom_tera = tera::Tera::new("**/some/prompts.md").unwrap();

custom_tera
.add_raw_template("hello", "hello {{world}}")
.unwrap();

PromptTemplate::extend(&custom_tera).await.unwrap();
Template::extend(&custom_tera).await.unwrap();

let prompt = PromptTemplate::from_compiled_template_name("hello")
let prompt = Template::from_compiled_template_name("hello")
.to_prompt()
.with_context_value("world", "swiftide");

Expand Down Expand Up @@ -322,7 +193,7 @@ mod test {
async fn test_coercion_to_template() {
let raw: &str = "hello {{world}}";

let prompt: PromptTemplate = raw.into();
let prompt: Template = raw.into();
assert_eq!(
prompt
.to_prompt()
Expand All @@ -333,7 +204,7 @@ mod test {
"hello swiftide"
);

let prompt: PromptTemplate = raw.to_string().into();
let prompt: Template = raw.to_string().into();
assert_eq!(
prompt
.to_prompt()
Expand Down
Loading

0 comments on commit 3254bd3

Please sign in to comment.