Skip to content

Commit

Permalink
temporary fix for path problem
Browse files Browse the repository at this point in the history
  • Loading branch information
adamspd committed Sep 1, 2024
1 parent b4465e8 commit b6e1dce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
9 changes: 8 additions & 1 deletion spam_detector_ai/trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# spam_detector_ai/trainer.py
import os
import sys
from pathlib import Path

from sklearn.model_selection import train_test_split

project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from classifiers.classifier_types import ClassifierType
from logger_config import init_logging
from training.train_models import ModelTrainer
Expand All @@ -18,7 +24,8 @@ def train_model(classifier_type, model_filename, vectoriser_filename, X_train, y

if __name__ == '__main__':
# Load and preprocess data once
initial_trainer = ModelTrainer(data_path='data/spam.csv', logger=logger)
data_path = os.path.join(project_root, 'spam_detector_ai', 'data', 'spam.csv')
initial_trainer = ModelTrainer(data_path=data_path, logger=logger)
processed_data = initial_trainer.preprocess_data_()

# Split the data once
Expand Down
19 changes: 9 additions & 10 deletions spam_detector_ai/training/train_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# spam_detector_ai/training/train_models.py

import os
from pathlib import Path

from sklearn.model_selection import train_test_split

Expand Down Expand Up @@ -77,19 +78,17 @@ def get_directory_path(self):
raise ValueError(f"Invalid classifier type: {self.classifier_type}")

def save_model(self, model_filename, vectoriser_filename):
# Determine the directory of this file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Assuming the spam_detector_ai directory is one level up from the current directory
base_dir = os.path.dirname(current_dir)
# Use the project root to construct the paths
project_root = Path(__file__).parent.parent
models_dir = project_root
directory_path = self.get_directory_path()

# Ensure the directory exists
if not os.path.exists(directory_path):
os.makedirs(directory_path)
model_filepath = models_dir / directory_path / model_filename
vectoriser_filepath = models_dir / directory_path / vectoriser_filename

model_filepath = os.path.join(base_dir, directory_path, model_filename)
vectoriser_filepath = os.path.join(base_dir, directory_path, vectoriser_filename)
# Ensure the directory exists
model_filepath.parent.mkdir(parents=True, exist_ok=True)

self.logger.info(f'Saving model to {model_filepath}')
self.classifier.save_model(model_filepath, vectoriser_filepath)
self.classifier.save_model(str(model_filepath), str(vectoriser_filepath))
self.logger.info('Model saved.\n')

0 comments on commit b6e1dce

Please sign in to comment.