Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add MP3Reader class for mp3 file reader #194

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@
"config": {
"supported_file_types": (
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3"
),
"private": False,
},
Expand All @@ -336,7 +336,7 @@
"config": {
"supported_file_types": (
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3"
),
"private": False,
},
Expand Down
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/indices/ingests/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HtmlReader,
MathpixPDFReader,
MhtmlReader,
MP3Reader,
OCRReader,
PandasExcelReader,
PDFThumbnailReader,
Expand Down Expand Up @@ -53,6 +54,7 @@
".tiff": unstructured,
".tif": unstructured,
".pdf": PDFThumbnailReader(),
".mp3": MP3Reader(),
".txt": TxtReader(),
".md": TxtReader(),
}
Expand Down
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .excel_loader import ExcelReader, PandasExcelReader
from .html_loader import HtmlReader, MhtmlReader
from .mathpix_loader import MathpixPDFReader
from .mp3_loader import MP3Reader
from .ocr_loader import ImageReader, OCRReader
from .pdf_loader import PDFThumbnailReader
from .txt_loader import TxtReader
Expand All @@ -30,6 +31,7 @@
"AdobeReader",
"TxtReader",
"PDFThumbnailReader",
"MP3Reader",
"WebReader",
"DoclingReader",
]
101 changes: 101 additions & 0 deletions libs/kotaemon/kotaemon/loaders/mp3_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from loguru import logger

from kotaemon.base import Document, Param

from .base import BaseReader

if TYPE_CHECKING:
from transformers import pipeline


class MP3Reader(BaseReader):
model_name_or_path: str = Param(
help="The model name or path to use for speech recognition.",
default="distil-whisper/distil-large-v3",
)
cache_dir: str = Param(
help="The cache directory to use for the model.",
default="models",
)

@Param.auto()
def asr_pipeline(self) -> "pipeline":
"""Setup the ASR pipeline for speech recognition"""
try:
import accelerate # noqa: F401
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
except ImportError:
raise ImportError(
"Please install the required packages to use the MP3Reader: "
"'pip install accelerate torch transformers'"
)

try:
# Device and model configuration
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Model and processor initialization
model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.model_name_or_path,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
cache_dir=self.cache_dir,
).to(device)

processor = AutoProcessor.from_pretrained(
self.model_name_or_path,
)

# ASR pipeline setup
asr_pipeline = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
return_timestamps=True,
)
logger.info("ASR pipeline setup successful.")
except Exception as e:
logger.error(f"Error occurred during ASR pipeline setup: {e}")
raise

return asr_pipeline

def speech_to_text(self, audio_path: str) -> str:
try:
import librosa

# Performing speech recognition
audio_array, _ = librosa.load(audio_path, sr=16000) # 16kHz sampling rate
result = self.asr_pipeline(audio_array)

text = result.get("text", "").strip()
if text == "":
logger.warning("No text found in the audio file.")
return text
except Exception as e:
logger.error(f"Error occurred during speech recognition: {e}")
return ""

def run(
self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
) -> list[Document]:
return self.load_data(str(file_path), extra_info=extra_info, **kwargs)

def load_data(
self, audio_file: str, extra_info: Optional[dict] = None, **kwargs
) -> List[Document]:
# Get text from the audio file
text = self.speech_to_text(audio_file)
metadata = extra_info or {}

return [Document(text=text, metadata=metadata)]
13 changes: 13 additions & 0 deletions libs/kotaemon/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
return False


def if_librosa_not_installed():
try:
import librosa # noqa: F401
except ImportError:
return True
else:
return False


skip_when_haystack_not_installed = pytest.mark.skipif(
if_haystack_not_installed(), reason="Haystack is not installed"
)
Expand Down Expand Up @@ -97,3 +106,7 @@ def if_llama_cpp_not_installed():
skip_llama_cpp_not_installed = pytest.mark.skipif(
if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
)

skip_when_librosa_not_installed = pytest.mark.skipif(
if_librosa_not_installed(), reason="librosa is not installed"
)
Binary file added libs/kotaemon/tests/resources/dummy.mp3
Binary file not shown.
21 changes: 20 additions & 1 deletion libs/kotaemon/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
DocxReader,
HtmlReader,
MhtmlReader,
MP3Reader,
UnstructuredReader,
)

from .conftest import skip_when_unstructured_pdf_not_installed
from .conftest import (
skip_when_librosa_not_installed,
skip_when_unstructured_pdf_not_installed,
)


def test_docx_reader():
Expand Down Expand Up @@ -93,3 +97,18 @@ def test_azureai_document_intelligence_reader(mock_client):

assert len(docs) == 1
mock_client.assert_called_once()


@skip_when_librosa_not_installed
@patch("kotaemon.loaders.MP3Reader.asr_pipeline")
def test_mp3_reader(mock_pipeline):
# Mock the return value
mock_pipeline.return_value = "This is the transcript"

reader = MP3Reader()
docs = reader.load_data(str(Path(__file__).parent / "resources" / "dummy.mp3"))

assert len(docs) == 1

# Assert that the ASR pipeline was called
mock_pipeline.assert_called_once()