Skip to content

Commit

Permalink
test(learner.extractor): rewrite tests
Browse files Browse the repository at this point in the history
  • Loading branch information
breakthewall committed Dec 2, 2024
1 parent c983045 commit aafc500
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 117 deletions.
103 changes: 40 additions & 63 deletions icfree/learner/extractor.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,16 @@
import pandas as pd
import argparse

def find_n_m_from_sampling(sampling_file_path):
def find_n_m_from_sampling(df_sampling):
"""
Find the number of unique combinations (n) and determine if there are repetitions in the sampling file.
Parameters:
sampling_file_path (str): Path to the sampling file (Excel, CSV, or TSV).
df_sampling (DataFrame): The sampling DataFrame.
Returns:
tuple: Number of unique combinations (n) and a boolean indicating if repetitions exist.
"""
file_extension = sampling_file_path.split('.')[-1].lower()

# Load the sampling file
if file_extension == 'xlsx':
df_sampling = pd.read_excel(sampling_file_path)
elif file_extension == 'csv':
df_sampling = pd.read_csv(sampling_file_path)
elif file_extension == 'tsv':
df_sampling = pd.read_csv(sampling_file_path, sep='\t')
else:
raise ValueError("Unsupported file type. Please provide an Excel (.xlsx), CSV (.csv), or TSV (.tsv) file.")

# Drop the unnamed column if it exists
if df_sampling.columns[0].startswith('Unnamed'):
df_sampling = df_sampling.drop(columns=df_sampling.columns[0])
Expand All @@ -33,37 +21,26 @@ def find_n_m_from_sampling(sampling_file_path):

return n, has_repetitions

def infer_replicates(initial_data_file, sampling_file_path, num_samples):
def infer_replicates(df_initial, df_sampling, num_samples):
"""
Infer the number of replicates from the initial data file and sampling file.
Parameters:
initial_data_file (str): Path to the initial data file (Excel, CSV, or TSV).
sampling_file_path (str): Path to the sampling file (Excel, CSV, or TSV).
df_initial (DataFrame): The initial data DataFrame.
df_sampling (DataFrame): The sampling DataFrame.
num_samples (int): Number of unique combinations (n).
Returns:
int: Inferred number of replicates.
"""
# Load initial data file
file_extension = initial_data_file.split('.')[-1].lower()
if file_extension == 'xlsx':
df_initial = pd.read_excel(initial_data_file)
elif file_extension == 'csv':
df_initial = pd.read_csv(initial_data_file)
elif file_extension == 'tsv':
df_initial = pd.read_csv(initial_data_file, sep='\t')
else:
raise ValueError("Unsupported file type. Please provide an Excel (.xlsx), CSV (.csv), or TSV (.tsv) file.")

# Remove the first two columns
# Remove the first two columns from the initial data
df_initial = df_initial.iloc[:, 2:]

# Infer replicates based on the number of columns
total_columns = df_initial.shape[1]

# Load sampling file to check for repetitions
_, has_repetitions = find_n_m_from_sampling(sampling_file_path)
# Check for repetitions in the sampling file
_, has_repetitions = find_n_m_from_sampling(df_sampling)

if has_repetitions:
num_replicates = total_columns // num_samples
Expand All @@ -73,46 +50,32 @@ def infer_replicates(initial_data_file, sampling_file_path, num_samples):

return num_replicates

def process_data(df_initial, num_samples, num_replicates):
    """
    Reshape the fluorescence values of the initial data into one row per sample.

    The first two columns of the input are discarded. The remaining values are
    read row by row, truncated to num_samples * num_replicates entries and laid
    out column by column: the first num_samples values become replicate 1, the
    next num_samples values replicate 2, and so on.

    Parameters:
    df_initial (DataFrame): The initial data DataFrame.
    num_samples (int): Number of samples (rows of the result).
    num_replicates (int): Number of replicates (value columns of the result).

    Returns:
    DataFrame: The reshaped DataFrame, with one "Fluorescence Value i" column
    per replicate followed by a "Fluorescence Average" column.

    Raises:
    ValueError: If the data holds fewer than num_samples * num_replicates values.
    """
    # Remove the first two columns; the input frame itself is left untouched
    fluorescence = df_initial.iloc[:, 2:]

    total_values = num_samples * num_replicates
    flat_values = fluorescence.values.flatten()

    # Fail early with an explicit message instead of an opaque reshape error
    if flat_values.size < total_values:
        raise ValueError(
            f"Not enough fluorescence values: expected at least {total_values} "
            f"({num_samples} samples x {num_replicates} replicates), "
            f"found {flat_values.size}."
        )

    # order='F' fills column by column, so each replicate ends up in its own column
    reshaped_values = flat_values[:total_values].reshape(
        (num_samples, num_replicates), order='F'
    )

    # Create reshaped DataFrame
    value_columns = [f"Fluorescence Value {i + 1}" for i in range(num_replicates)]
    df_reshaped = pd.DataFrame(reshaped_values, columns=value_columns)
    df_reshaped["Fluorescence Average"] = df_reshaped.mean(axis=1)

    return df_reshaped

def process(initial_data_file, output_file_path, sampling_file, num_samples=None, num_replicates=None, display=True):
"""
Expand All @@ -129,32 +92,46 @@ def process(initial_data_file, output_file_path, sampling_file, num_samples=None
Returns:
DataFrame: The combined DataFrame.
"""
if num_samples is None or num_replicates is None:
n, _ = find_n_m_from_sampling(sampling_file)
num_samples = num_samples if num_samples is not None else n
num_replicates = num_replicates if num_replicates is not None else infer_replicates(initial_data_file, sampling_file, num_samples)

df_reshaped, sheet_name = process_data(initial_data_file, num_samples, num_replicates)

# Load sampling file
file_extension = sampling_file.split('.')[-1].lower()
# Load files once
file_extension = initial_data_file.split('.')[-1].lower()
if file_extension == 'xlsx':
df_sampling = pd.read_excel(sampling_file)
df_initial = pd.read_excel(initial_data_file)
elif file_extension == 'csv':
df_sampling = pd.read_csv(sampling_file)
df_initial = pd.read_csv(initial_data_file)
elif file_extension == 'tsv':
df_initial = pd.read_csv(initial_data_file, sep='\t')
else:
raise ValueError("Unsupported file type. Please provide an Excel (.xlsx), CSV (.csv), or TSV (.tsv) file.")

sampling_extension = sampling_file.split('.')[-1].lower()
if sampling_extension == 'xlsx':
df_sampling = pd.read_excel(sampling_file)
elif sampling_extension == 'csv':
df_sampling = pd.read_csv(sampling_file)
elif sampling_extension == 'tsv':
df_sampling = pd.read_csv(sampling_file, sep='\t')
else:
raise ValueError("Unsupported file type. Please provide an Excel (.xlsx), CSV (.csv), or TSV (.tsv) file.")

# Infer num_samples and num_replicates if not provided
if num_samples is None or num_replicates is None:
n, _ = find_n_m_from_sampling(df_sampling)
num_samples = num_samples if num_samples is not None else n
num_replicates = num_replicates if num_replicates is not None else infer_replicates(df_initial, df_sampling, num_samples)

# Process data
df_reshaped = process_data(df_initial, num_samples, num_replicates)

# Combine sampling file and reshaped data
df_sampling = df_sampling.head(num_samples)
df_combined = pd.concat([df_sampling, df_reshaped], axis=1)

if display:
print(df_combined)

# Save output
if output_file_path.endswith('.xlsx'):
df_combined.to_excel(output_file_path, index=False, sheet_name=sheet_name or "Sheet1")
df_combined.to_excel(output_file_path, index=False, sheet_name="Sheet1")
elif output_file_path.endswith('.csv'):
df_combined.to_csv(output_file_path, index=False)
elif output_file_path.endswith('.tsv'):
Expand Down
115 changes: 61 additions & 54 deletions tests/learner/test_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import unittest
import pandas as pd
import sys
from io import StringIO
from os import path as os_path
from icfree.learner.extractor import find_n_m, process_data, load_sampling_file, clean_sampling_file, process
from icfree.learner.extractor import (
find_n_m_from_sampling,
infer_replicates,
process_data,
process,
)

class TestDataExtractor(unittest.TestCase):

Expand All @@ -21,66 +26,68 @@ def setUpClass(cls):
)
cls.sampling_file = os_path.join(
cls.data_path, 'input',
"plate1_sampling.tsv"
"plate1_sampling.csv"
)
cls.reference_output_file = os_path.join(
cls.data_path, 'output',
"plate1.csv"
)
cls.output_file = "output.csv"

def test_find_n_m(self):
n, m = find_n_m(self.sampling_file)
self.assertIsInstance(n, int)
self.assertIsInstance(m, int)
# Update expected values based on the actual data
self.assertEqual(n, 57) # Actual expected value from the data
self.assertEqual(m, 6) # Actual expected value from the data


cls.num_samples = 57
cls.num_replicates = 6

cls.df_sampling = pd.read_csv(cls.sampling_file)
cls.df_initial = pd.read_excel(cls.initial_data_file)
cls.reference_output = pd.read_csv(cls.reference_output_file)

def test_find_n_m_from_sampling(self):
    """Check the unique-combination count and repetition flag for the sampling DataFrame."""
    result = find_n_m_from_sampling(self.df_sampling)
    unique_count, repeated = result
    # The number of unique combinations must match the expected sample count
    self.assertEqual(unique_count, self.num_samples)
    # The fixture is expected to be reported as containing repetitions
    self.assertTrue(repeated)

def test_infer_replicates(self):
    """Check that the replicate count inferred from the fixture DataFrames is the expected one."""
    inferred = infer_replicates(
        df_initial=self.df_initial,
        df_sampling=self.df_sampling,
        num_samples=self.num_samples,
    )
    # Expected replicates based on the fixture's data structure
    self.assertEqual(inferred, self.num_replicates)

def test_process_data(self):
    """Test reshaping of fluorescence data."""
    df_reshaped = process_data(self.df_initial, self.num_samples, self.num_replicates)
    self.assertIsInstance(df_reshaped, pd.DataFrame)
    # Rows = num_samples, columns = num_replicates + 1 (average column)
    self.assertEqual(
        df_reshaped.shape,
        (self.num_samples, self.num_replicates + 1),
    )
    # Verify the first row's average against the fixture's known value
    self.assertAlmostEqual(df_reshaped["Fluorescence Average"].iloc[0], 1551)

def test_process(self):
    """Test end-to-end processing of data and file saving."""
    import tempfile

    # Write into a throwaway directory so the test is portable and leaves no file behind
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_file = os_path.join(tmp_dir, "test_output.csv")
        combined_df = process(
            initial_data_file=self.initial_data_file,
            output_file_path=output_file,
            sampling_file=self.sampling_file,
            num_samples=self.num_samples,
            num_replicates=self.num_replicates,
            display=False
        )
    self.assertIsInstance(combined_df, pd.DataFrame)
    self.assertEqual(combined_df.shape[0], self.num_samples)  # Number of samples
    # assertEqual reports both values on failure, unlike assertTrue(a == b)
    self.assertEqual(combined_df.columns[-1], "Fluorescence Average")  # Check last column name

def test_process_with_inference(self):
    """Test end-to-end processing with automatic inference."""
    import tempfile

    # Write into a throwaway directory so the test is portable and leaves no file behind
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_file = os_path.join(tmp_dir, "test_output_inferred.csv")
        combined_df = process(
            initial_data_file=self.initial_data_file,
            output_file_path=output_file,
            sampling_file=self.sampling_file,
            num_samples=None,  # Let the function infer num_samples
            num_replicates=None,  # Let the function infer num_replicates
            display=False
        )
    self.assertIsInstance(combined_df, pd.DataFrame)
    self.assertEqual(combined_df.shape[0], self.num_samples)  # Number of samples inferred
    self.assertEqual(combined_df.shape[1], 18)  # Sampling columns + reshaped fluorescence columns

if __name__ == "__main__":
    # Run the suite once; exit=False keeps the interpreter alive after the run
    unittest.main(argv=[''], verbosity=2, exit=False)

0 comments on commit aafc500

Please sign in to comment.