Skip to content

Commit

Permalink
feat: endpoint to register function calling on shortcuts
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Feb 23, 2025
1 parent ff5764b commit 99bcf70
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 39 deletions.
Binary file modified pipes/search/bun.lockb
Binary file not shown.
2 changes: 1 addition & 1 deletion pipes/search/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"@radix-ui/react-switch": "^1.1.2",
"@radix-ui/react-toast": "^1.2.3",
"@radix-ui/react-tooltip": "^1.1.5",
"@screenpipe/browser": "^0.1.28",
"@screenpipe/browser": "^0.1.33",
"@screenpipe/js": "^1.0.12",
"@shadcn/ui": "^0.0.4",
"@tanstack/react-query": "^5.62.7",
Expand Down
83 changes: 82 additions & 1 deletion screenpipe-app-tauri/src-tauri/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::window_api::{close_window, show_specific_window};
use crate::{get_base_dir, get_store};
use crate::{get_base_dir, get_store, register_shortcut};
use axum::body::Bytes;
use axum::response::sse::{Event, Sse};
use axum::response::IntoResponse;
Expand All @@ -12,6 +12,7 @@ use futures::stream::Stream;
use http::header::{HeaderValue, CONTENT_TYPE};
use notify::RecursiveMode;
use notify::Watcher;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::net::SocketAddr;
Expand Down Expand Up @@ -95,6 +96,15 @@ struct WindowSizePayload {
height: f64,
}

#[derive(Deserialize, Debug)]
struct ShortcutRegistrationPayload {
shortcut: String,
endpoint: String,
method: String,
#[serde(default)]
body: Option<serde_json::Value>,
}

async fn settings_stream(
State(state): State<ServerState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
Expand Down Expand Up @@ -166,6 +176,10 @@ pub async fn run_server(app_handle: tauri::AppHandle, port: u16) {
.route("/sidecar/stop", axum::routing::post(stop_sidecar))
.route("/window", axum::routing::post(show_specific_window))
.route("/window/close", axum::routing::post(close_window))
.route(
"/shortcuts/register",
axum::routing::post(register_http_shortcut),
)
.layer(cors)
.layer(
TraceLayer::new_for_http()
Expand Down Expand Up @@ -412,6 +426,73 @@ async fn stop_sidecar(
}
}

async fn register_http_shortcut(
State(state): State<ServerState>,
Json(payload): Json<ShortcutRegistrationPayload>,
) -> Result<Json<ApiResponse>, (StatusCode, String)> {
info!("registering http shortcut: {:?}", payload);

let client = Client::new();
let endpoint = payload.endpoint.clone();
let method = payload.method.clone();
let body = payload.body.clone();

let handler = move |_app: &tauri::AppHandle| {
info!("executing http shortcut");
let client = client.clone();
let endpoint = endpoint.clone();
let method = method.clone();
let body = body.clone();

tokio::spawn(async move {
let request = match method.to_uppercase().as_str() {
"GET" => client.get(&endpoint),
"POST" => client.post(&endpoint),
"PUT" => client.put(&endpoint),
"DELETE" => client.delete(&endpoint),
_ => {
error!("unsupported http method: {}", method);
return;
}
};

let request = if let Some(body) = body {
request.json(&body)
} else {
request
};

match request.send().await {
Ok(response) => {
info!(
"http shortcut request completed with status: {}",
response.status()
);
}
Err(e) => {
error!("http shortcut request failed: {}", e);
}
}
});
};

// TODO persist in settings?

match register_shortcut(&state.app_handle, &payload.shortcut, false, handler).await {
Ok(_) => Ok(Json(ApiResponse {
success: true,
message: format!("shortcut {} registered successfully", payload.shortcut),
})),
Err(e) => {
error!("failed to register shortcut: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to register shortcut: {}", e),
))
}
}
}

pub fn spawn_server(app_handle: tauri::AppHandle, port: u16) -> mpsc::Sender<()> {
let (tx, mut rx) = mpsc::channel(1);

Expand Down
2 changes: 1 addition & 1 deletion screenpipe-core/src/embedding/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl EmbeddingModel {
let (model_path, tokenizer_path) = if model_path.is_none() || tokenizer_path.is_none() {
let api = Api::new()?;
let repo = api.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
"jinaai/jina-embeddings-v3".to_string(),
RepoType::Model,
));
(repo.get("model.safetensors")?, repo.get("tokenizer.json")?)
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-js/browser-sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@screenpipe/browser",
"version": "0.1.32",
"version": "0.1.33",
"type": "module",
"main": "./dist/index.cjs",
"module": "./dist/index.js",
Expand Down
65 changes: 42 additions & 23 deletions screenpipe-js/browser-sdk/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -604,44 +604,63 @@ class BrowserPipeImpl implements BrowserPipe {
}

try {
const response = await fetch("http://localhost:3030/v1/embeddings", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: "all-MiniLM-L6-v2",
// Send as Multiple variant even for single strings to be consistent
input: texts.length === 1 ? texts[0] : texts,
encoding_format: "float",
}),
});
const BATCH_SIZE = 50;
const batches = [];

if (!response.ok) {
const errorText = await response.text();
throw new Error(
`http error! status: ${response.status}, error: ${errorText}`
);
// Split into batches of 50
for (let i = 0; i < texts.length; i += BATCH_SIZE) {
batches.push(texts.slice(i, i + BATCH_SIZE));
}

const data = await response.json();
const embeddings = data.data.map((d: any) => d.embedding);
// Process batches in parallel
const batchResults = await Promise.all(
batches.map(async (batch) => {
const response = await fetch("http://localhost:3030/v1/embeddings", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: "all-MiniLM-L6-v2",
input: batch.length === 1 ? batch[0] : batch,
encoding_format: "float",
}),
});

if (!response.ok) {
throw new Error(`http error! status: ${response.status}`);
}

// group similar texts using cosine similarity
const data = await response.json();
return data.data.map((d: any) => ({
text: batch[data.data.indexOf(d)],
embedding: d.embedding,
}));
})
);

// Flatten batch results
const allEmbeddings = batchResults.flat();

// Group similar texts using cosine similarity
const similarityThreshold = 0.9;
const groups: { text: string; similar: string[] }[] = [];
const used = new Set<number>();

for (let i = 0; i < texts.length; i++) {
for (let i = 0; i < allEmbeddings.length; i++) {
if (used.has(i)) continue;

const group = { text: texts[i], similar: [] as string[] };
const group = { text: allEmbeddings[i].text, similar: [] as string[] };
used.add(i);

for (let j = i + 1; j < texts.length; j++) {
for (let j = i + 1; j < allEmbeddings.length; j++) {
if (used.has(j)) continue;

const similarity = cosineSimilarity(embeddings[i], embeddings[j]);
const similarity = cosineSimilarity(
allEmbeddings[i].embedding,
allEmbeddings[j].embedding
);

if (similarity > similarityThreshold) {
group.similar.push(texts[j]);
group.similar.push(allEmbeddings[j].text);
used.add(j);
}
}
Expand Down
48 changes: 36 additions & 12 deletions screenpipe-server/src/embedding/embedding_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,52 @@ pub async fn get_or_initialize_model() -> anyhow::Result<Arc<Mutex<EmbeddingMode
pub async fn create_embeddings(
Json(request): Json<EmbeddingRequest>,
) -> Result<Json<EmbeddingResponse>, (axum::http::StatusCode, String)> {
let model = get_or_initialize_model()
.await
.map_err(|e| (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
tracing::debug!("processing embedding request for model: {}", request.model);

let model = get_or_initialize_model().await.map_err(|e| {
tracing::error!("failed to initialize embedding model: {}", e);
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"failed to initialize embedding model".to_string(),
)
})?;

let model = model.lock().await;

let (texts, _is_single) = match request.input {
EmbeddingInput::Single(text) => (vec![text], true),
EmbeddingInput::Multiple(texts) => (texts, false),
let (texts, _) = match request.input {
EmbeddingInput::Single(text) => {
tracing::debug!("processing single text embedding, length: {}", text.len());
(vec![text], true)
}
EmbeddingInput::Multiple(texts) => {
tracing::debug!("processing batch embedding, count: {}", texts.len());
(texts, false)
}
};

// Generate embeddings
let embeddings = if texts.len() == 1 {
vec![model
.generate_embedding(&texts[0])
.map_err(|e| (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?]
tracing::debug!("generating single embedding");
vec![model.generate_embedding(&texts[0]).map_err(|e| {
tracing::error!("failed to generate single embedding: {}", e);
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"embedding generation failed".to_string(),
)
})?]
} else {
model
.generate_batch_embeddings(&texts)
.map_err(|e| (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
tracing::debug!("generating batch embeddings");
model.generate_batch_embeddings(&texts).map_err(|e| {
tracing::error!("failed to generate batch embeddings: {}", e);
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"batch embedding generation failed".to_string(),
)
})?
};

tracing::debug!("successfully generated {} embeddings", embeddings.len());

// Create response
let data = embeddings
.into_iter()
Expand Down

0 comments on commit 99bcf70

Please sign in to comment.