Skip to content

Commit

Permalink
Merge pull request #80 from 6Coders/FIX_Backend
Browse files Browse the repository at this point in the history
Fix Generazione Prompt Backend
  • Loading branch information
ylovato01 authored May 12, 2024
2 parents 7adff72 + 9a86c5a commit e97ec70
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 74 deletions.
25 changes: 6 additions & 19 deletions backend/chatsql/adapter/incoming/EmbeddingGeneratorAdapters.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,20 @@

from typing import List
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
from backend.chatsql.domain.Embedding import Embedding
from backend.chatsql.application.port.outcoming.EmbeddingGeneratorPort import EmbeddingGeneratorPort
import numpy as np

class TestEmbeddingAdapter(EmbeddingGeneratorPort):
    """Stub adapter that returns a fixed embedding per input text.

    Used in tests so no model download / inference is needed.
    """

    def generate(self, texts: List[str], table_names: List[str] = None) -> List[Embedding]:
        """Return one constant Embedding per entry of `texts`.

        `table_names` was added to match the updated EmbeddingGeneratorPort
        signature; it defaults to a dummy name per text so existing callers
        that pass only `texts` keep working. Fix: `Embedding` now has a
        required `table_name` field, so it must be supplied here.
        """
        if table_names is None:
            table_names = ['test'] * len(texts)
        return [Embedding(
            text='test',
            table_name=name,
            data=np.array([1, 2, 3], dtype=np.float32)
        ) for _, name in zip(texts, table_names)]


class HuggingfaceEmbeddingAdapter(EmbeddingGeneratorPort):
    """Adapter that produces embeddings with a SentenceTransformer model.

    The merged version replaced the manual AutoTokenizer/AutoModel +
    mean-pooling pipeline with SentenceTransformer, which handles
    tokenisation and pooling internally; the dead tokenizer lines left
    over from the diff are removed here.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
        # One model object does tokenisation, encoding and pooling.
        self.model = SentenceTransformer(model_name)

    def generate(self, texts: List[str], table_names: List[str]) -> List[Embedding]:
        """Return one Embedding per (text, table_name) pair.

        `texts` and `table_names` are consumed pairwise; extra entries in
        the longer list are ignored (zip semantics).
        """
        embeddings = []
        for text, table_name in zip(texts, table_names):
            vector = self.model.encode(text)
            embeddings.append(Embedding(text=text, table_name=table_name, data=vector))
        return embeddings
22 changes: 14 additions & 8 deletions backend/chatsql/adapter/incoming/SearchAlgorithmAdapters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List

from sklearn.metrics.pairwise import cosine_similarity
from transformers.tokenization_utils_base import np

from backend.chatsql.application.port.outcoming.SearchAlgorithmPort import SearchAlgorithmPort
import numpy as np
from chatsql.application.port.outcoming.SearchAlgorithmPort import SearchAlgorithmPort


from backend.chatsql.domain.Embedding import Embedding

Expand All @@ -17,12 +18,17 @@ def search(self, query: Embedding, context: List[Embedding]) -> List[Embedding]:
class KNN(SearchAlgorithmPort):
    """Rank context embeddings by cosine similarity to a query embedding."""

    def __init__(self, top_k: int) -> None:
        # Number of most-similar tables returned by `search`.
        self._top_k = top_k

    def search(self, query: Embedding, context: List[Embedding]):
        """Return the tables most similar to `query`.

        Returns a list of (table_name, similarity) tuples, most similar
        first, truncated to `top_k` entries.
        NOTE(review): the declared port return type was List[Embedding];
        since the merge this returns name/score tuples — callers were
        updated accordingly (see PromptService.query).
        """
        # If two context entries share a table_name the later one wins,
        # matching the original dict-based accumulation.
        similarities = {}
        for emb in context:
            similarities[emb.table_name] = cosine_similarity([query.data], [emb.data])[0][0]
        ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        # Bug fix: honour the configured top_k instead of a hard-coded 3.
        return ranked[:self._top_k]
15 changes: 11 additions & 4 deletions backend/chatsql/adapter/incoming/web/ManagerController.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,19 @@ def handle_upload(self):

def handle_list_files(self):
    """Return metadata for every stored dictionary file.

    Each entry contains the file's base name, extension, creation date,
    size in kilobytes and whether it is the currently selected dictionary.
    """
    data = []
    selected = self._visualizzaDizionarioCorrenteUseCase.selected
    for filename in self._visualizzaListaDizionariUseCase.list_all():
        # One stat call per file instead of two.
        info = os.stat(os.path.join(Settings.folder, filename))
        data.append({
            'name': '.'.join(filename.split('.')[:-1]),
            'loaded': filename == selected,
            'extension': filename.split('.')[-1],
            # NOTE(review): st_ctime is metadata-change time on Unix,
            # creation time only on Windows — confirm intended semantics.
            'date': datetime.datetime.fromtimestamp(info.st_ctime),
            'size': f"{info.st_size / 1024.0:.2f} Kb",
        })
    return data

def handle_selection(self):

Expand Down
8 changes: 6 additions & 2 deletions backend/chatsql/adapter/incoming/web/QueryController.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def handle_prompt_generation(self):
self._loadDizionarioUseCase.load(
self._visualizzaDizionarioCorrenteUseCase.selected
)
risposta = self._richiestaPromptUseCase.query(richiesta)
risposta = self._richiestaPromptUseCase.query(richiesta, self._visualizzaDizionarioCorrenteUseCase.selected)

return { 'result': risposta }
data = {
'result': risposta
}

return data
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from backend.chatsql.utils import Exceptions, Common

import pickle

class EmbeddingRepositoryAdapter(BaseEmbeddingRepository):

def __init__(self) -> None:
Expand All @@ -14,7 +16,9 @@ def __init__(self) -> None:
def save(self, filename: str, embeddings: List[Embedding]) -> bool:
    """Persist `embeddings` to disk with pickle.

    The target file is named after `filename` with its extension stripped,
    inside the repository folder.

    Returns:
        True on success.
    Raises:
        Exceptions.FileNotSaved: wrapping any underlying failure.
    """
    try:
        filepath = os.path.join(self._folder, '.'.join(filename.split('.')[:-1]))
        # pickle replaced np.save so arbitrary Embedding objects round-trip.
        with open(filepath, 'wb') as file:
            pickle.dump(embeddings, file)
        return True
    except Exception as e:
        raise Exceptions.FileNotSaved(f"Error while saving embeddings for {filename}: {e}")
def load(self, filename: str) -> List[Embedding]:
    """Load the pickled embeddings previously saved for `filename`.

    Raises:
        Exceptions.EmbeddingsNotLoaded: if no filename is selected.
    """
    if filename is None:
        # Bug fix: the original message interpolated an undefined `e`,
        # which itself raised NameError before the intended exception.
        raise Exceptions.EmbeddingsNotLoaded("No dictionary selected; cannot load embeddings")

    filepath = os.path.join(self._folder, '.'.join(filename.split('.')[:-1]))
    # SECURITY NOTE(review): pickle.load executes arbitrary code if the
    # file was tampered with — only load files this app wrote itself.
    with open(filepath, 'rb') as file:
        embeddings = pickle.load(file)
    return embeddings

def remove(self, filename: str) -> bool:
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
from datetime import datetime
from typing import List, IO

from backend.chatsql.application.port.outcoming.persistance.BaseJSONRepository import BaseJsonRepository
Expand Down Expand Up @@ -57,41 +55,30 @@ def save(self, filename: str, stream: IO[bytes]) -> bool:

def remove(self, filename: str) -> bool:
    """Delete `filename` from the managed folder.

    Returns:
        True if the file existed (and was removed), False otherwise.
    """
    if filename in self.list_all():
        remove(join(self._folder, filename))
        return True
    return False

def list_all(self) -> List[str]:
    """Return the names (with extension) of all JSON files in the folder.

    The merge folded the former private `__filenames` helper into this
    method; metadata assembly moved to the web controller.
    """
    return [filename for filename in listdir(self._folder)
            if isfile(join(self._folder, filename))
            and filename.split('.')[-1] == 'json']

@staticmethod
def open(filename: str):
    # Read and parse the JSON dictionary `filename` from the settings folder.
    # NOTE(review): this method shadows the builtin `open` in the class
    # namespace; the call below still resolves to the builtin because class
    # attributes are not part of the method's lexical scope.
    # secure_filename sanitises the name (path traversal protection).
    secured_filename = secure_filename(filename)
    with open(join(Settings.folder, secured_filename), "r") as file:
        return json.load(file)

def __is_valid(self, content: str) -> bool:
    """Return True when `content` is JSON with a valid dictionary structure.

    Propagates json.JSONDecodeError for malformed JSON, exactly as before.
    """
    parsed = json.loads(content)
    return JSONValidator.is_valid_structure(parsed)

def __already_present(self, filename: str) -> bool:
    """True if a file with the secured version of `filename` already exists."""
    # Diff residue removed: only the merged `list_all` lookup is kept.
    secured_filename = secure_filename(filename)
    return secured_filename in self.list_all()

def __create_folder(self) -> None:
if not exists(self._folder):
Expand Down
27 changes: 22 additions & 5 deletions backend/chatsql/application/EmbeddingManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from backend.chatsql.application.port.outcoming.EmbeddingGeneratorPort import EmbeddingGeneratorPort
from backend.chatsql.application.port.outcoming.persistance.BaseEmbeddingRepository import BaseEmbeddingRepository

from backend.chatsql.application.EmbeddingSaver import EmbeddingSaver
from backend.chatsql.domain.Embedding import Embedding

from chatsql.adapter.outcoming.persistance.JSONRepositoryAdapter import JSONRepositoryAdapter

from chatsql.application.EmbeddingSaver import EmbeddingSaver
from chatsql.domain.Embedding import Embedding


from backend.chatsql.utils.Common import Settings
import os
import json

class EmbeddingManager(EmbeddingSaver):

Expand All @@ -20,11 +25,23 @@ def __init__(self, embeddingRepository: BaseEmbeddingRepository,
def save(self, filename: str) -> bool:
    """Generate and persist embeddings for every table described in `filename`.

    Reads the JSON data dictionary, embeds each table's description
    (column descriptions are intentionally not embedded), and stores the
    result through the embedding repository.

    Returns:
        True on success, False on any failure — honouring the declared
        -> bool contract. The original caught BaseException and *returned*
        the exception object, which callers would treat as truthy.
    """
    try:
        data = JSONRepositoryAdapter.open(filename)

        tables = data['tables_info']
        table_descriptions = []
        table_names = []
        for table_name, table_info in tables.items():
            table_descriptions.append(table_info['table_description'])
            table_names.append(table_name)

        embeddings = self._embeddingGeneratorPort.generate(table_descriptions, table_names)
        self._embeddingRepository.save(filename, embeddings)
        return True
    except Exception:
        return False

2 changes: 1 addition & 1 deletion backend/chatsql/application/EmbeddingSaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
class EmbeddingSaver(ABC):
    """Inbound port: generate and persist embeddings for a stored dictionary."""

    @abstractmethod
    def save(self, filename: str) -> bool:
        """Embed the contents of `filename` and persist them; True on success."""
        pass
38 changes: 34 additions & 4 deletions backend/chatsql/application/PromptService.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .port.outcoming.EmbeddingGeneratorPort import EmbeddingGeneratorPort
from .port.outcoming.SearchAlgorithmPort import SearchAlgorithmPort

from chatsql.adapter.outcoming.persistance.JSONRepositoryAdapter import JSONRepositoryAdapter


from backend.chatsql.utils import Exceptions
from backend.chatsql.domain.Embedding import Embedding
Expand All @@ -31,17 +33,45 @@ def __init__(self,
self._context = None


def query(self, question: str, filename: str) -> str:
    """Build an LLM prompt describing the tables most relevant to `question`.

    Embeds the question, asks the search algorithm for the most similar
    tables, then renders each table's description, columns, primary keys
    and foreign keys from the JSON dictionary `filename`.
    """
    # Fix: pass a list, matching generate(texts, table_names); the old
    # call passed the string "null", which only worked because zip
    # iterated its characters. The table name is irrelevant for a query.
    query_embs = self._embeddingGeneratorPort.generate([question], ["null"])[0]
    # (table_name, similarity) tuples, best first.
    context = self._searchAlgorithm.search(query_embs, self._context)

    file_content = JSONRepositoryAdapter.open(filename)
    tables = file_content['tables_info']

    # Build with a list + join instead of repeated string concatenation.
    parts = [" "]
    for table_name, _similarity in context:
        table_info = tables[table_name]
        parts.append("Tabella: " + table_name)
        parts.append("\nDescrizione: " + table_info['table_description'])
        parts.append("\nColonne:")
        for column in table_info['columns']:
            parts.append("\n- Nome: " + column['column_name'])
            parts.append("\n Descrizione :" + column['column_description'])
            parts.append("\n Tipo: " + column['attribute_type'])
            parts.append("\n Indice: " + str(column['index']))
        # Bug fix: the label was previously used as the *join separator*,
        # so a table with a single primary key got no label at all.
        parts.append("\nChiavi primarie: " + ", ".join(file_content['primary_key'][table_name]))
        parts.append("\nChiavi esterne: ")
        for foreign_key in file_content['foreign_keys']:
            if foreign_key['table'] == table_name:
                parts.append("\n- Nome: " + foreign_key['foreign_key'])
                parts.append("\n Attributo: " + foreign_key['attribute'])
                parts.append("\n Tabella di riferimento: " + foreign_key['reference_table'])
                parts.append("\n Attributo di riferimento: " + foreign_key['reference_attribute'])
        parts.append("\n\n")
    presult = "".join(parts)

    return f"""
Act as a SQL engineer. \n
Given the context below, generate a query for MariaDB to answer the following question: {question}. \n
{presult}
"""

def load(self, filename: str) -> List[Embedding]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import List
from abc import ABC, abstractmethod

from chatsql.domain.Embedding import Embedding

class LoadDizionarioUseCase(ABC):
    """Inbound port: load the data dictionary identified by a filename."""

    @abstractmethod
    def load(self, filename: str) -> bool:
        """Load the dictionary `filename`; True on success."""
        pass


Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class EmbeddingGeneratorPort(ABC):
    """Outbound port: turn texts (and their table names) into embeddings."""

    @abstractmethod
    def generate(self, texts: List[str], table_names: List[str]) -> "List[Embedding]":
        """Return one Embedding per (text, table_name) pair.

        The return annotation is quoted as a forward reference; Embedding
        is imported at the top of the module.
        """
        pass


1 change: 1 addition & 0 deletions backend/chatsql/domain/Embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@dataclass
class Embedding:
text: str
table_name: str
data: np.ndarray

def __post_init__(self):
Expand Down
4 changes: 3 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def heartbeat():
def handle(e):
return jsonify(400, e.message)

from backend.chatsql.adapter.incoming.EmbeddingGeneratorAdapters import HuggingfaceEmbeddingAdapter, TestEmbeddingAdapter

from chatsql.adapter.incoming.EmbeddingGeneratorAdapters import HuggingfaceEmbeddingAdapter


if __name__ == '__main__':

Expand Down

0 comments on commit e97ec70

Please sign in to comment.