From 1a12200c12a31d240d9b80256e94a45c7a448709 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 18 Mar 2024 15:27:17 -0400 Subject: [PATCH] Add sharing (#38) Add sharing of extractors --- backend/db/models.py | 110 +++++++++++++----- backend/server/api/extractors.py | 67 ++++++++++- backend/server/api/shared.py | 51 ++++++++ backend/server/main.py | 3 +- .../api/test_api_defining_extractors.py | 44 +++++++ 5 files changed, 239 insertions(+), 36 deletions(-) create mode 100644 backend/server/api/shared.py diff --git a/backend/db/models.py b/backend/db/models.py index 1d75eb2..08bfd14 100644 --- a/backend/db/models.py +++ b/backend/db/models.py @@ -2,7 +2,15 @@ from datetime import datetime from typing import Generator -from sqlalchemy import Column, DateTime, ForeignKey, String, Text, create_engine +from sqlalchemy import ( + Column, + DateTime, + ForeignKey, + String, + Text, + UniqueConstraint, + create_engine, +) from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, relationship, sessionmaker @@ -56,37 +64,6 @@ class TimestampedModel(Base): ) -class Extractor(TimestampedModel): - __tablename__ = "extractors" - - name = Column( - String(100), - nullable=False, - server_default="", - comment="The name of the extractor.", - ) - schema = Column( - JSONB, - nullable=False, - comment="JSON Schema that describes what content will be " - "extracted from the document", - ) - description = Column( - String(100), - nullable=False, - server_default="", - comment="Surfaced via UI to the users.", - ) - instruction = Column( - Text, nullable=False, comment="The prompt to the language model." - ) # TODO: This will need to evolve - - examples = relationship("Example", backref="extractor") - - def __repr__(self) -> str: - return f"" - - class Example(TimestampedModel): """A representation of an example. 
@@ -122,3 +99,72 @@ class Example(TimestampedModel): def __repr__(self) -> str: return f"" + + +class SharedExtractors(TimestampedModel): + """A table for managing sharing of extractors.""" + + __tablename__ = "shared_extractors" + + extractor_id = Column( + UUID(as_uuid=True), + ForeignKey("extractors.uuid", ondelete="CASCADE"), + index=True, + nullable=False, + comment="The extractor that is being shared.", + ) + + share_token = Column( + UUID(as_uuid=True), + index=True, + nullable=False, + unique=True, + comment="The token that is used to access the shared extractor.", + ) + + # Add unique constraint for (extractor_id, share_token) + __table_args__ = ( + UniqueConstraint("extractor_id", "share_token", name="unique_share_token"), + ) + + def __repr__(self) -> str: + """Return a string representation of the object.""" + return f"" + + +class Extractor(TimestampedModel): + __tablename__ = "extractors" + + name = Column( + String(100), + nullable=False, + server_default="", + comment="The name of the extractor.", + ) + schema = Column( + JSONB, + nullable=False, + comment="JSON Schema that describes what content will be " + "extracted from the document", + ) + description = Column( + String(100), + nullable=False, + server_default="", + comment="Surfaced via UI to the users.", + ) + instruction = Column( + Text, nullable=False, comment="The prompt to the language model." + ) # TODO: This will need to evolve + + examples = relationship("Example", backref="extractor") + + # Used for sharing the extractor with others. 
+ share_uuid = Column( + UUID(as_uuid=True), + nullable=True, + comment="The uuid of the shareable link.", + ) + + def __repr__(self) -> str: + return f"" diff --git a/backend/server/api/extractors.py b/backend/server/api/extractors.py index 9fb44fe..859d082 100644 --- a/backend/server/api/extractors.py +++ b/backend/server/api/extractors.py @@ -1,12 +1,13 @@ """Endpoints for managing definition of extractors.""" from typing import Any, Dict, List -from uuid import UUID +from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field, validator +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from db.models import Extractor, get_session +from db.models import Extractor, SharedExtractors, get_session from server.validators import validate_json_schema router = APIRouter( @@ -39,7 +40,66 @@ def validate_schema(cls, v: Any) -> Dict[str, Any]: class CreateExtractorResponse(BaseModel): """Response for creating an extractor.""" - uuid: UUID + uuid: UUID = Field(..., description="The UUID of the created extractor.") + + +class ShareExtractorRequest(BaseModel): + """Request for sharing an extractor.""" + + uuid: UUID = Field(..., description="The UUID of the extractor to share.") + + +class ShareExtractorResponse(BaseModel): + """Response for sharing an extractor.""" + + share_uuid: UUID = Field(..., description="The UUID for the shared extractor.") + + +@router.post("/{uuid}/share", response_model=ShareExtractorResponse) +def share( + uuid: UUID, + *, + session: Session = Depends(get_session), +) -> ShareExtractorResponse: + """Endpoint to share an extractor. + + Look up a shared extractor by UUID and return the share UUID if it exists. + If not shared, create a new shared extractor entry and return the new share UUID. + + Args: + uuid: The UUID of the extractor to share. + session: The database session. + + Returns: + The UUID for the shared extractor. 
+ """ + # Check if the extractor is already shared + shared_extractor = ( + session.query(SharedExtractors) + .filter(SharedExtractors.extractor_id == uuid) + .scalar() + ) + + if shared_extractor: + # The extractor is already shared, return the existing share_uuid + return ShareExtractorResponse(share_uuid=shared_extractor.share_token) + + # If not shared, create a new shared extractor entry + new_shared_extractor = SharedExtractors( + extractor_id=uuid, + # This will automatically generate a new UUID for share_token + share_token=uuid4(), + ) + + session.add(new_shared_extractor) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException(status_code=400, detail="Failed to share the extractor.") + + # Return the new share_uuid + return ShareExtractorResponse(share_uuid=new_shared_extractor.share_token) @router.post("") @@ -47,6 +107,7 @@ def create( create_request: CreateExtractor, *, session: Session = Depends(get_session) ) -> CreateExtractorResponse: """Endpoint to create an extractor.""" + instance = Extractor( name=create_request.name, schema=create_request.json_schema, diff --git a/backend/server/api/shared.py b/backend/server/api/shared.py new file mode 100644 index 0000000..77e4c30 --- /dev/null +++ b/backend/server/api/shared.py @@ -0,0 +1,51 @@ +"""Endpoints for working with shared resources.""" +from typing import Any, Dict +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from db.models import Extractor, SharedExtractors, get_session + +router = APIRouter( + prefix="/s", + tags=["extractor definitions"], + responses={404: {"description": "Not found"}}, +) + + +class SharedExtractorResponse(BaseModel): + """Response for sharing an extractor.""" + + # UUID should not be included in the response since it is not a public identifier! 
+ name: str + description: str + # `schema` collides with an existing pydantic BaseModel attribute, so expose it via an alias + schema_: Dict[str, Any] = Field(..., alias="schema") + instruction: str + + +@router.get("/{uuid}") +def get( + uuid: UUID, + *, + session: Session = Depends(get_session), +) -> SharedExtractorResponse: + """Get a shared extractor.""" + extractor = ( + session.query(Extractor) + .join(SharedExtractors, Extractor.uuid == SharedExtractors.extractor_id) + .filter(SharedExtractors.share_token == uuid) + .first() + ) + + if not extractor: + raise HTTPException(status_code=404, detail="Extractor not found.") + + return SharedExtractorResponse( + name=extractor.name, + description=extractor.description, + schema=extractor.schema, + instruction=extractor.instruction, + ) diff --git a/backend/server/main.py b/backend/server/main.py index 2945100..2361884 100644 --- a/backend/server/main.py +++ b/backend/server/main.py @@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware from langserve import add_routes -from server.api import examples, extract, extractors, suggest +from server.api import examples, extract, extractors, shared, suggest from server.extraction_runnable import ( ExtractRequest, ExtractResponse, @@ -46,6 +46,7 @@ def ready() -> str: app.include_router(examples.router) app.include_router(extract.router) app.include_router(suggest.router) +app.include_router(shared.router) add_routes( app, diff --git a/backend/tests/unit_tests/api/test_api_defining_extractors.py b/backend/tests/unit_tests/api/test_api_defining_extractors.py index 28e5ac0..2a29c55 100644 --- a/backend/tests/unit_tests/api/test_api_defining_extractors.py +++ b/backend/tests/unit_tests/api/test_api_defining_extractors.py @@ -70,3 +70,47 @@ async def test_extractors_api() -> None: } response = await client.post("/extractors", json=create_request) assert response.status_code == 200 + + +async def test_sharing_extractor() -> None: + """Test sharing an extractor.""" + async with get_async_client() as client: 
+ response = await client.get("/extractors") + assert response.status_code == 200 + assert response.json() == [] + # Verify that we can create an extractor + create_request = { + "name": "Test Name", + "description": "Test Description", + "schema": {"type": "object"}, + "instruction": "Test Instruction", + } + response = await client.post("/extractors", json=create_request) + assert response.status_code == 200 + + uuid = response.json()["uuid"] + + # Verify that the extractor can be shared + response = await client.post(f"/extractors/{uuid}/share") + assert response.status_code == 200 + assert "share_uuid" in response.json() + share_uuid = response.json()["share_uuid"] + + # Test idempotency + response = await client.post(f"/extractors/{uuid}/share") + assert response.status_code == 200 + assert "share_uuid" in response.json() + assert response.json()["share_uuid"] == share_uuid + + # Check that we can retrieve the shared extractor + response = await client.get(f"/s/{share_uuid}") + assert response.status_code == 200 + keys = sorted(response.json()) + assert keys == ["description", "instruction", "name", "schema"] + + assert response.json() == { + "description": "Test Description", + "instruction": "Test Instruction", + "name": "Test Name", + "schema": {"type": "object"}, + }