Skip to content

Commit

Permalink
Add backend suggest endpoint (#21)
Browse files Browse the repository at this point in the history
Add suggest endpoints on backend
  • Loading branch information
eyurtsev authored Mar 18, 2024
1 parent 33d08ae commit 8673ffd
Show file tree
Hide file tree
Showing 7 changed files with 632 additions and 506 deletions.
8 changes: 4 additions & 4 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ RUN set -eux && \
apt-get update && \
apt-get install -y \
build-essential \
# to install poetry
curl \
# to install psycopg2
libpq-dev python3-dev
libpq-dev \
python3-dev \
libmagic1

# https://python-poetry.org/docs/master/#installing-with-the-official-installer
RUN curl -sSL https://install.python-poetry.org | python -
Expand All @@ -42,4 +42,4 @@ EXPOSE 8000
###
FROM base as development

ENTRYPOINT ["bash", "./scripts/local_entry_point.sh"]
ENTRYPOINT ["bash", "./scripts/local_entry_point.sh"]
1 change: 0 additions & 1 deletion backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker
from sqlalchemy.sql import func

from server.settings import get_postgres_url

Expand Down
1,010 changes: 513 additions & 497 deletions backend/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fastapi = "^0.109.2"
langserve = "^0.0.45"
uvicorn = "^0.27.1"
pydantic = "^1.10"
langchain-openai = "^0.0.6"
langchain-openai = "^0.0.8"
jsonschema = "^4.21.1"
sse-starlette = "^2.0.0"
alembic = "^1.13.1"
Expand Down
111 changes: 111 additions & 0 deletions backend/server/api/suggest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Module to handle the suggest API endpoint.
This is logic that leverages LLMs to suggest an extractor for a given task.
"""
from typing import Optional

from fastapi import APIRouter
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from server.settings import get_model

router = APIRouter(
prefix="/suggest",
tags=["Suggest an extractor"],
responses={404: {"description": "Not found"}},
)


model = get_model()


class SuggestExtractor(BaseModel):
"""A request to create an extractor."""

description: str = Field(
default="",
description=(
"Short description of what information the extractor is extracting."
),
)
jsonSchema: Optional[str] = Field(
default=None,
description=(
"Existing JSON Schema that describes the entity / "
"information that should be extracted."
),
)


class ExtractorDefinition(BaseModel):
"""Define an information extractor to be used in an information extraction system.""" # noqa: E501

json_schema: str = Field(
...,
description=(
"JSON Schema that describes the entity / "
"information that should be extracted. "
"This schema is specified in JSON Schema format. "
),
)


SUGGEST_PROMPT = ChatPromptTemplate.from_messages(
[
(
"system",
"You are are an expert ontologist and have been asked to help a user "
"define an information extractor.The user will describe an entity, "
"a topic or a piece of information that they would like to extract from "
"text. Based on the user input, you are to provide a schema and "
"description for the extractor. The schema should be a JSON Schema that "
"describes the entity or information to be extracted. information to be "
"extracted. Make sure to include title and description for all the "
"attributes in the schema.The JSON Schema should describe a top level "
"object. The object MUST have a title and description.Unless otherwise "
"stated all entity properties in the schema should be considered optional.",
),
("human", "{input}"),
]
)

suggestion_chain = SUGGEST_PROMPT | model.with_structured_output(
schema=ExtractorDefinition
)

UPDATE_PROMPT = ChatPromptTemplate.from_messages(
[
(
"system",
"You are are an expert ontologist and have been asked to help a user "
"define an information extractor.gThe existing extractor schema is "
"provided.\ng```\n{json_schema}\n```\nThe user will describe a desired "
"modification to the schema (e.g., adding a new field, changing a field "
"type, etc.).Your goal is to provide a new schema that incorporates the "
"user's desired modification.The user may also request a completely new "
"schema, in which case you should provide a new schema based on the "
"user's input, and ignore the existing schema.The JSON Schema should "
"describe a top level object. The object MUST have a title and "
"description.Unless otherwise stated all entity properties in the schema "
"should be considered optional.",
),
("human", "{input}"),
]
)

UPDATE_CHAIN = UPDATE_PROMPT | model.with_structured_output(schema=ExtractorDefinition)


# PUBLIC API


@router.post("")
async def suggest(request: SuggestExtractor) -> ExtractorDefinition:
"""Endpoint to create an extractor."""
if len(request.jsonSchema) > 10:
print(f"Using update chain with {request.jsonSchema}")
return await UPDATE_CHAIN.ainvoke(
{"input": request.description, "json_schema": request.jsonSchema}
)
return await suggestion_chain.ainvoke({"input": request.description})
4 changes: 2 additions & 2 deletions backend/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi.middleware.cors import CORSMiddleware
from langserve import add_routes

from server.api import examples, extract, extractors
from server.api import examples, extract, extractors, suggest
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
Expand All @@ -23,7 +23,6 @@
],
)


origins = [
"http://localhost:5173",
]
Expand All @@ -46,6 +45,7 @@ def ready():
app.include_router(extractors.router)
app.include_router(examples.router)
app.include_router(extract.router)
app.include_router(suggest.router)

add_routes(
app,
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/integration_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Person(BaseModel):
json={
"input": {
"text": text,
"schema": Person.schema(),
"schema": Person(),
"instructions": "Redact all names using the characters `######`",
"examples": examples,
}
Expand Down

0 comments on commit 8673ffd

Please sign in to comment.