-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #80 from 6Coders/FIX_Backend
Fix Generazione Prompt Backend
- Loading branch information
Showing
13 changed files
with
119 additions
and
74 deletions.
There are no files selected for viewing
25 changes: 6 additions & 19 deletions
25
backend/chatsql/adapter/incoming/EmbeddingGeneratorAdapters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,20 @@ | ||
|
||
from typing import List | ||
from transformers import AutoTokenizer, AutoModel | ||
from sentence_transformers import SentenceTransformer | ||
import torch | ||
from backend.chatsql.domain.Embedding import Embedding | ||
from backend.chatsql.application.port.outcoming.EmbeddingGeneratorPort import EmbeddingGeneratorPort | ||
import numpy as np | ||
|
||
class TestEmbeddingAdapter(EmbeddingGeneratorPort): | ||
|
||
def generate(self, texts: List[str]) -> List[Embedding]: | ||
return [Embedding( | ||
text='test', | ||
data=np.array([1, 2, 3], dtype=np.float32) | ||
) for _ in texts] | ||
|
||
|
||
class HuggingfaceEmbeddingAdapter(EmbeddingGeneratorPort): | ||
|
||
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> None: | ||
self.tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
self.model = AutoModel.from_pretrained(model_name) | ||
self.model = SentenceTransformer(model_name) | ||
|
||
def generate(self, texts: List[str]) -> List[Embedding]: | ||
def generate(self, texts: List[str], table_names: List[str]) -> List[Embedding]: | ||
embeddings = [] | ||
for text in texts: | ||
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) | ||
with torch.no_grad(): | ||
outputs = self.model(**inputs) | ||
last_hidden_states = outputs.last_hidden_state | ||
mean_embedding = np.mean(last_hidden_states.numpy(), axis=1) | ||
embeddings.append(Embedding(text=text, data=mean_embedding)) | ||
for text, table_name in zip(texts, table_names): | ||
embedding = self.model.encode(text) | ||
embeddings.append(Embedding(text=text, table_name = table_name, data=embedding)) | ||
return embeddings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 1 addition & 3 deletions
4
backend/chatsql/application/port/incoming/LoadDizionarioUseCase.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,10 @@ | ||
from typing import List | ||
from abc import ABC, abstractmethod | ||
|
||
from chatsql.domain.Embedding import Embedding | ||
|
||
class LoadDizionarioUseCase(ABC): | ||
|
||
@abstractmethod | ||
def load(self, filename: str) -> List[Embedding]: | ||
def load(self, filename: str) -> bool: | ||
pass | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
@dataclass | ||
class Embedding: | ||
text: str | ||
table_name: str | ||
data: np.ndarray | ||
|
||
def __post_init__(self): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters