This repository has been archived by the owner on Feb 15, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): assistants endpoint (#424)
* Adds assistants endpoints for CRUD operations and DB migration * Adds tests for assistants and files (requires Supabase) * Adds Supabase config to API --------- Co-authored-by: gharvey <gato.harvey@defenseunicorns.com>
- Loading branch information
1 parent
064cb84
commit 0c483a1
Showing
22 changed files
with
778 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""CRUD Operations for Assistant.""" | ||
|
||
from openai.types.beta import Assistant | ||
from supabase_py_async import AsyncClient | ||
from leapfrogai_api.data.crud_base import CRUDBase | ||
|
||
|
||
class CRUDAssistant(CRUDBase[Assistant]): | ||
"""CRUD Operations for Assistant""" | ||
|
||
def __init__(self, model: type[Assistant], table_name: str = "assistant_objects"): | ||
super().__init__(model=model, table_name=table_name) | ||
|
||
async def create(self, db: AsyncClient, object_: Assistant) -> Assistant | None: | ||
"""Create a new assistant.""" | ||
return await super().create(db=db, object_=object_) | ||
|
||
async def get(self, id_: str, db: AsyncClient) -> Assistant | None: | ||
"""Get an assistant by its ID.""" | ||
return await super().get(db=db, id_=id_) | ||
|
||
async def list(self, db: AsyncClient) -> list[Assistant] | None: | ||
"""List all assistants.""" | ||
return await super().list(db=db) | ||
|
||
async def update( | ||
self, id_: str, db: AsyncClient, object_: Assistant | ||
) -> Assistant | None: | ||
"""Update an assistant by its ID.""" | ||
return await super().update(id_=id_, db=db, object_=object_) | ||
|
||
async def delete(self, id_: str, db: AsyncClient) -> bool: | ||
"""Delete an assistant by its ID.""" | ||
return await super().delete(id_=id_, db=db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""CRUD Operations for VectorStore.""" | ||
|
||
from typing import Generic, TypeVar | ||
from supabase_py_async import AsyncClient | ||
from pydantic import BaseModel | ||
|
||
ModelType = TypeVar("ModelType", bound=BaseModel) | ||
|
||
|
||
class CRUDBase(Generic[ModelType]): | ||
"""CRUD Operations""" | ||
|
||
def __init__(self, model: type[ModelType], table_name: str): | ||
self.model = model | ||
self.table_name = table_name | ||
|
||
async def create(self, db: AsyncClient, object_: ModelType) -> ModelType | None: | ||
"""Create new row.""" | ||
dict_ = object_.model_dump() | ||
del dict_["id"] # Ensure this is created by the database | ||
del dict_["created_at"] # Ensure this is created by the database | ||
data, _count = await db.table(self.table_name).insert(dict_).execute() | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
|
||
async def get(self, id_: str, db: AsyncClient) -> ModelType | None: | ||
"""Get row by ID.""" | ||
data, _count = ( | ||
await db.table(self.table_name).select("*").eq("id", id_).execute() | ||
) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
|
||
async def list(self, db: AsyncClient) -> list[ModelType] | None: | ||
"""List all rows.""" | ||
data, _count = await db.table(self.table_name).select("*").execute() | ||
|
||
_, response = data | ||
|
||
if response: | ||
return [self.model(**item) for item in response] | ||
return None | ||
|
||
async def update( | ||
self, id_: str, db: AsyncClient, object_: ModelType | ||
) -> ModelType | None: | ||
"""Update a vector store by its ID.""" | ||
data, _count = ( | ||
await db.table(self.table_name) | ||
.update(object_.model_dump()) | ||
.eq("id", id_) | ||
.execute() | ||
) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
|
||
async def delete(self, id_: str, db: AsyncClient) -> bool: | ||
"""Delete a vector store by its ID.""" | ||
data, _count = await db.table(self.table_name).delete().eq("id", id_).execute() | ||
|
||
_, response = data | ||
|
||
return bool(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,34 @@ | ||
"""CRUD Operations for FileObject""" | ||
|
||
from openai.types import FileObject | ||
from supabase_py_async import AsyncClient | ||
from openai.types import FileObject, FileDeleted | ||
from leapfrogai_api.data.crud_base import CRUDBase | ||
|
||
|
||
class CRUDFileObject: | ||
class CRUDFileObject(CRUDBase[FileObject]): | ||
"""CRUD Operations for FileObject""" | ||
|
||
def __init__(self, model: type[FileObject]): | ||
self.model = model | ||
def __init__(self, model: type[FileObject], table_name: str = "file_objects"): | ||
super().__init__(model=model, table_name=table_name) | ||
|
||
async def create( | ||
self, client: AsyncClient, file_object: FileObject | ||
) -> FileObject | None: | ||
async def create(self, db: AsyncClient, object_: FileObject) -> FileObject | None: | ||
"""Create a new file object.""" | ||
file_object_dict = file_object.model_dump() | ||
if file_object_dict.get("id") == "": | ||
del file_object_dict["id"] | ||
data, _count = ( | ||
await client.table("file_objects").insert(file_object_dict).execute() | ||
) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
return await super().create(db=db, object_=object_) | ||
|
||
async def get(self, client: AsyncClient, file_id: str) -> FileObject | None: | ||
async def get(self, id_: str, db: AsyncClient) -> FileObject | None: | ||
"""Get a file object by its ID.""" | ||
data, _count = ( | ||
await client.table("file_objects").select("*").eq("id", file_id).execute() | ||
) | ||
return await super().get(db=db, id_=id_) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
|
||
async def list(self, client: AsyncClient) -> list[FileObject] | None: | ||
async def list(self, db: AsyncClient) -> list[FileObject] | None: | ||
"""List all file objects.""" | ||
data, _count = await client.table("file_objects").select("*").execute() | ||
|
||
_, response = data | ||
|
||
if response: | ||
return [self.model(**item) for item in response] | ||
return None | ||
return await super().list(db=db) | ||
|
||
async def update( | ||
self, client: AsyncClient, file_id: str, file_object: FileObject | ||
self, id_: str, db: AsyncClient, object_: FileObject | ||
) -> FileObject | None: | ||
"""Update a file object by its ID.""" | ||
data, _count = ( | ||
await client.table("file_objects") | ||
.update(file_object.model_dump()) | ||
.eq("id", file_id) | ||
.execute() | ||
) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
return await super().update(id_=id_, db=db, object_=object_) | ||
|
||
async def delete(self, client: AsyncClient, file_id: str) -> FileDeleted: | ||
async def delete(self, id_: str, db: AsyncClient) -> bool: | ||
"""Delete a file object by its ID.""" | ||
data, _count = ( | ||
await client.table("file_objects").delete().eq("id", file_id).execute() | ||
) | ||
|
||
_, response = data | ||
|
||
return FileDeleted(id=file_id, deleted=bool(response), object="file") | ||
return await super().delete(id_=id_, db=db) |
Oops, something went wrong.