From 1efdf0de38d3a91e3e62a7f282c08a106d5bcd12 Mon Sep 17 00:00:00 2001 From: Robert Date: Tue, 7 May 2024 21:52:42 -0700 Subject: [PATCH] Groq API now works. --- diarize.py | 51 ++++++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/diarize.py b/diarize.py index 3379fd2d4..da27c73f1 100644 --- a/diarize.py +++ b/diarize.py @@ -817,7 +817,7 @@ def summarize_with_cohere(api_key, file_path, model): segments = json.load(file) logging.debug(f"cohere: Extracting text from segments file") - text = extract_text_from_segments(segments) # Make sure this function is defined + text = extract_text_from_segments(segments) headers = { 'accept': 'application/json', @@ -862,40 +862,41 @@ def summarize_with_cohere(api_key, file_path, model): # https://console.groq.com/docs/quickstart -def summarize_with_groq(api_url, file_path, token): +def summarize_with_groq(api_key, file_path, model): try: logging.debug("groq: Loading JSON data") with open(file_path, 'r') as file: segments = json.load(file) logging.debug(f"groq: Extracting text from segments file") - text = extract_text_from_segments(segments) # Define this function to extract text properly + text = extract_text_from_segments(segments) headers = { - 'accept': 'application/json', - 'content-type': 'application/json', - 'Authorization': f'Bearer {token}' + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json' } prompt_text = f"{text} \n\nAs a professional summarizer, create a concise and comprehensive summary of the provided text." data = { - "model": "groq-gpt-j-6b", - "prompt": prompt_text, - "max_tokens": 1024, - "temperature": 0.7, - "top_p": 1.0, - "stop": None + "messages": [ + { + "role": "user", + "content": prompt_text + } + ], + "model": model } logging.debug("groq: Submitting request to API endpoint") print("groq: Submitting request to API endpoint") - response = requests.post(api_url, headers=headers, json=data) + response = requests.post('https://api.groq.com/openai/v1/chat/completions', headers=headers, json=data) + response_data = response.json() logging.debug("API Response Data: %s", response_data) if response.status_code == 200: if 'choices' in response_data and len(response_data['choices']) > 0: - summary = response_data['choices'][0]['text'].strip() + summary = response_data['choices'][0]['message']['content'].strip() logging.debug("groq: Summarization successful") print("Summarization successful.") return summary @@ -911,6 +912,9 @@ def summarize_with_groq(api_url, file_path, token): return f"groq: Error occurred while processing summary with groq: {str(e)}" +################################# +# +# Local Summarization def summarize_with_llama(api_url, file_path, token): try: @@ -957,14 +961,6 @@ def summarize_with_llama(api_url, file_path, token): -def save_summary_to_file(summary, file_path): - summary_file_path = file_path.replace('.segments.json', '_summary.txt') - logging.debug("Opening summary file for writing, *segments.json with *_summary.txt") - with open(summary_file_path, 'w') as file: - file.write(summary) - logging.info(f"Summary saved to file: {summary_file_path}") - - # https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API def summarize_with_oobabooga(api_url, file_path): try: @@ -1023,6 +1019,15 @@ def summarize_with_oobabooga(api_url, file_path): logging.error("oobabooga: Error in processing: %s", str(e)) return f"oobabooga: Error occurred while processing summary with oobabooga: {str(e)}" + + +def save_summary_to_file(summary, file_path): + summary_file_path = file_path.replace('.segments.json', '_summary.txt') + logging.debug("Opening summary file for writing, *segments.json with *_summary.txt") + with open(summary_file_path, 'w') as file: + file.write(summary) + logging.info(f"Summary saved to file: {summary_file_path}") + # # #################################################################################################################################### @@ -1094,7 +1099,7 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model= summary = summarize_with_cohere(api_key, json_file_path, cohere_model) elif api_name.lower() == 'groq': api_key = groq_api_key - summary = summarize_with_llama(api_key, json_file_path, groq_model) + summary = summarize_with_groq(api_key, json_file_path, groq_model) elif api_name.lower() == 'llama': token = llama_api_key llama_ip = llama_api_IP