Skip to content

Commit

Permalink
Merge branch 'feature/main-physical-formula' of https://github.com/BP…
Browse files Browse the repository at this point in the history
…-TPSE-Projektgruppe-80/NaPyTau into feature/main-physical-formula
  • Loading branch information
Madddiiiin committed Dec 19, 2024
2 parents aab8ec0 + d127839 commit 6434e34
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 1 deletion.
17 changes: 17 additions & 0 deletions napytau/import_export/model/datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class Datapoint:
unshifted_intensity: Optional[ValueErrorPair[float]] = None
feeding_shifted_intensity: Optional[ValueErrorPair[float]] = None
feeding_unshifted_intensity: Optional[ValueErrorPair[float]] = None
tau: Optional[ValueErrorPair[float]] = None
active: bool = True

def get_distance(self) -> ValueErrorPair[float]:
return self.distance
Expand Down Expand Up @@ -70,3 +72,18 @@ def set_feeding_intensity(
) -> None:
self.feeding_shifted_intensity = feeding_shifted_intensity
self.feeding_unshifted_intensity = feeding_unshifted_intensity

def get_tau(self) -> ValueErrorPair[float]:
if self.tau is None:
raise ValueError("Tau was accessed before initialization.")

return self.tau

def set_tau(self, tau: ValueErrorPair[float]) -> None:
self.tau = tau

def is_active(self) -> bool:
return self.active

def set_active(self, active: bool) -> None:
self.active = active
13 changes: 13 additions & 0 deletions napytau/import_export/model/datapoint_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,16 @@ def get_feeding_unshifted_intensities(self) -> List[ValueErrorPair[float]]:
).elements.values(),
)
)

def get_taus(self) -> List[ValueErrorPair[float]]:
return list(
map(
lambda datapoint: coalesce(datapoint.tau),
self.filter(
lambda datapoint: datapoint.tau is not None
).elements.values(),
)
)

def get_active_datapoints(self) -> DatapointCollection:
return self.filter(lambda datapoint: datapoint.active)
37 changes: 37 additions & 0 deletions napytau/import_export/model/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Optional, List

from napytau.import_export.model.datapoint_collection import DatapointCollection
from napytau.import_export.model.polynomial import Polynomial
from napytau.import_export.model.relative_velocity import RelativeVelocity
from napytau.util.model.value_error_pair import ValueErrorPair

Expand All @@ -14,9 +16,44 @@ class DataSet:

relative_velocity: ValueErrorPair[RelativeVelocity]
datapoints: DatapointCollection
tau_factor: Optional[float] = None
weighted_mean_tau: Optional[ValueErrorPair[float]] = None
sampling_points: Optional[List[float]] = None
polynomial_count: Optional[int] = None
polynomials: Optional[List[Polynomial]] = None

def get_relative_velocity(self) -> ValueErrorPair[RelativeVelocity]:
return self.relative_velocity

def get_datapoints(self) -> DatapointCollection:
return self.datapoints

def get_tau_factor(self) -> Optional[float]:
return self.tau_factor

def set_tau_factor(self, tau_factor: float) -> None:
self.tau_factor = tau_factor

def get_weighted_mean_tau(self) -> Optional[ValueErrorPair[float]]:
return self.weighted_mean_tau

def set_weighted_mean_tau(self, weighted_mean_tau: ValueErrorPair[float]) -> None:
self.weighted_mean_tau = weighted_mean_tau

def get_sampling_points(self) -> Optional[List[float]]:
return self.sampling_points

def set_sampling_points(self, sampling_points: List[float]) -> None:
self.sampling_points = sampling_points

def get_polynomial_count(self) -> Optional[int]:
return self.polynomial_count

def set_polynomial_count(self, polynomial_count: int) -> None:
self.polynomial_count = polynomial_count

def get_polynomials(self) -> Optional[List[Polynomial]]:
return self.polynomials

def set_polynomials(self, polynomials: List[Polynomial]) -> None:
self.polynomials = polynomials
17 changes: 17 additions & 0 deletions napytau/import_export/model/polynomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass


@dataclass
class Polynomial:
"""
A class to represent a polynomial.
A polynomial is a mathematical expression consisting of variables and coefficients.
"""

coefficients: list[float]

def get_coefficients(self) -> list[float]:
return self.coefficients

def set_coefficients(self, coefficients: list[float]) -> None:
self.coefficients = coefficients
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[lint.per-file-ignores]
# Ignore line length in tests, to incentivice longer more descriptive test names.
# Disable line length check in test files, to encourage longer, more descriptive test names.
"tests/*" = ["E501"]

[format]
Expand Down
47 changes: 47 additions & 0 deletions tests/import_export/model/datapoint_collection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,53 @@ def test_canRetrieveFeedingUnshiftedIntensities(self):
[ValueErrorPair(1.0, 0.1), ValueErrorPair(2.0, 0.1)],
)

def test_canRetrieveTaus(self):
"""Can retrieve taus"""
collection = DatapointCollection(
[
Datapoint(
distance=ValueErrorPair(12.12, 0.1),
tau=ValueErrorPair(1.0, 0.1),
),
Datapoint(
distance=ValueErrorPair(12.13, 0.1),
tau=ValueErrorPair(2.0, 0.1),
),
Datapoint(
distance=ValueErrorPair(12.14, 0.1),
),
]
)

self.assertEqual(
collection.get_taus(),
[ValueErrorPair(1.0, 0.1), ValueErrorPair(2.0, 0.1)],
)

def test_canRetrieveActiveDatapoints(self):
"""Can retrieve active datapoints"""
collection = DatapointCollection(
[
Datapoint(
distance=ValueErrorPair(12.12, 0.1),
active=True,
),
Datapoint(
distance=ValueErrorPair(12.13, 0.1),
active=False,
),
Datapoint(
distance=ValueErrorPair(12.14, 0.1),
active=True,
),
]
)

self.assertEqual(
list(collection.get_active_datapoints().as_dict().values()),
[collection.elements[hash(12.12)], collection.elements[hash(12.14)]],
)


if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions tests/import_export/model/datapoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def test_raisesAnExceptionIfIntensityIsAccessedBeforeInitialization(self):
with self.assertRaises(Exception):
datapoint.get_intensity()

def test_raisesAnExceptionIfTauIsAccessedBeforeInitialization(self):
"""Raise an exception if tau is accessed before initialization."""
datapoint = Datapoint(ValueErrorPair(1.0, 0.1))
with self.assertRaises(Exception):
datapoint.get_tau()


if __name__ == "__main__":
unittest.main()

0 comments on commit 6434e34

Please sign in to comment.