-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
94 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,106 @@ | ||
import os | ||
import sys | ||
|
||
sys.path.append(os.path.abspath('../src')) | ||
|
||
import unittest | ||
from data.data_loader import load_data | ||
from data.data_cleaning import handle_missing_values, remove_outliers | ||
from data.data_transformation import scale_data | ||
from unittest.mock import patch, MagicMock | ||
import pandas as pd | ||
import numpy as np | ||
from datetime import datetime | ||
|
||
# Add `src` directory to the Python path | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) | ||
|
||
# Import functions from data_loader | ||
from data.data_loader import load_dataset, load_stock_data, retrieve_stock_data, get_close_prices | ||
|
||
|
||
class TestDataLoader(unittest.TestCase): | ||
def setUp(self): | ||
# Sample data | ||
self.data = pd.DataFrame({ | ||
'Close': [100, 105, np.nan, 110, 115, 120], | ||
'Volume': [200, 210, 220, np.nan, 230, 240] | ||
|
||
@patch('data.data_loader.fetch_data') | ||
def test_load_dataset(self, mock_fetch_data): | ||
# Mock the data returned by fetch_data | ||
assets = ['TSLA', 'BND'] | ||
start_date = '2023-01-01' | ||
end_date = '2023-12-31' | ||
|
||
# Ensure mock data has the correct number of columns (Date + len(assets)) | ||
mock_data = pd.DataFrame({ | ||
'Date': pd.date_range(start=start_date, periods=10), | ||
'TSLA': range(10), | ||
'BND': range(10, 20) | ||
}) | ||
mock_data.set_index('Date', inplace=True) # Set 'Date' as the index to match real data | ||
mock_fetch_data.return_value = mock_data | ||
|
||
# Call load_dataset | ||
result = load_dataset(assets, start_date, end_date) | ||
|
||
# Assert the results | ||
self.assertIsInstance(result, pd.DataFrame) | ||
self.assertIn('Date', result.columns) | ||
self.assertEqual(result.columns.tolist(), ['Date', 'TSLA', 'BND']) # Check column names | ||
self.assertEqual(len(result), 10) | ||
|
||
@patch('data.data_loader.yf.download') | ||
def test_load_stock_data(self, mock_yf_download): | ||
# Mock the data returned by yfinance download | ||
ticker = 'TSLA' | ||
start_date = '2023-01-01' | ||
end_date = '2023-12-31' | ||
mock_data = pd.DataFrame({ | ||
'Date': pd.date_range(start=start_date, periods=10), | ||
'Adj Close': range(10) | ||
}) | ||
mock_data.set_index('Date', inplace=True) | ||
mock_yf_download.return_value = mock_data | ||
|
||
# Call load_stock_data | ||
result = load_stock_data(ticker, start_date, end_date) | ||
|
||
# Assert the results | ||
self.assertIsInstance(result, pd.DataFrame) | ||
self.assertEqual(result.shape, (10, 1)) | ||
self.assertIn('Adj Close', result.columns) | ||
|
||
@patch('data.data_loader.yf.download') | ||
def test_retrieve_stock_data(self, mock_yf_download): | ||
# Mock the data returned by yfinance download for multiple tickers | ||
tickers = ['TSLA', 'BND'] | ||
start_date = '2023-01-01' | ||
end_date = '2023-12-31' | ||
mock_data_TSLA = pd.DataFrame({ | ||
'Date': pd.date_range(start=start_date, periods=10), | ||
'Adj Close': range(10) | ||
}).set_index('Date') | ||
mock_data_BND = pd.DataFrame({ | ||
'Date': pd.date_range(start=start_date, periods=10), | ||
'Adj Close': range(10, 20) | ||
}).set_index('Date') | ||
mock_yf_download.side_effect = [mock_data_TSLA, mock_data_BND] | ||
|
||
# Call retrieve_stock_data | ||
result = retrieve_stock_data(tickers, start_date, end_date) | ||
|
||
# Assert the results | ||
self.assertIsInstance(result, dict) | ||
self.assertIn('TSLA', result) | ||
self.assertIn('BND', result) | ||
self.assertIsInstance(result['TSLA'], pd.DataFrame) | ||
self.assertIsInstance(result['BND'], pd.DataFrame) | ||
|
||
def test_load_data(self): | ||
data = load_data("AAPL", "2022-01-01", "2022-02-01") | ||
self.assertFalse(data.empty) | ||
def test_get_close_prices(self): | ||
# Create mock stock_data dictionary | ||
stock_data = { | ||
'TSLA': pd.DataFrame({'Adj Close': range(10)}), | ||
'BND': pd.DataFrame({'Adj Close': range(10, 20)}) | ||
} | ||
|
||
def test_handle_missing_values(self): | ||
cleaned_data = handle_missing_values(self.data) | ||
self.assertFalse(cleaned_data.isnull().values.any()) | ||
# Call get_close_prices | ||
result = get_close_prices(stock_data) | ||
|
||
def test_remove_outliers(self): | ||
data_no_outliers = remove_outliers(self.data, "Close") | ||
self.assertTrue(data_no_outliers.shape[0] <= self.data.shape[0]) | ||
# Assert the results | ||
self.assertIsInstance(result, pd.DataFrame) | ||
self.assertEqual(result.shape, (10, 2)) | ||
self.assertEqual(result.columns.tolist(), ['TSLA', 'BND']) | ||
|
||
def test_scale_data(self): | ||
scaled_data, scaler = scale_data(self.data.copy(), ["Close", "Volume"]) | ||
self.assertEqual(scaled_data.shape, self.data.shape) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.