diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..5d72804
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[report]
+show_missing = True
+omit = src/traffic_weaver/datasets/*
diff --git a/.flake8 b/.flake8
index e4e5eb3..01eaacb 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,3 @@
[flake8]
-max-line-length = 88
+max-line-length = 120
extend-ignore = E203, W503
diff --git a/.gitignore b/.gitignore
index 12d6d20..199a41f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,7 @@
.husky
__pycache__/
/htmlcov/
+/src/traffic_weaver/test_runners/
+
+.venv/
+venv/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cb98f08..ad7bca6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,6 +17,7 @@ repos:
hooks:
- id: flake8
stages: [pre-commit]
+ args: [--config=.flake8]
- repo: local
hooks:
- id: conventional-commit
diff --git a/Makefile b/Makefile
index 00e5287..89f33f6 100644
--- a/Makefile
+++ b/Makefile
@@ -10,8 +10,8 @@ build-requirements:
pip-compile -o requirements.txt pyproject.toml
pip-compile --extra dev -o dev-requirements.txt pyproject.toml
-install: clean
- pip-sync requirements.txt dev-requirements.txt
+build: clean
+ pip-sync requirements.txt dev-requirements.txt
tag-release:
commit-and-tag-version
@@ -27,6 +27,6 @@ docs: clean
cd docs && make html
test: clean
- pytest --cov=traffic_weaver --cov-report term-missing --cov-report html
+ pytest --cov=traffic_weaver --cov-report term-missing --cov-report html --cov-config .coveragerc
mkdir -p _images
coverage-badge -f -o badges/coverage.svg
diff --git a/badges/coverage.svg b/badges/coverage.svg
index b6c4e36..318685c 100644
--- a/badges/coverage.svg
+++ b/badges/coverage.svg
@@ -9,13 +9,13 @@
-
+
coverage
coverage
- 95%
- 95%
+ 85%
+ 85%
diff --git a/dev-requirements.in b/dev-requirements.in
index 684ffd9..36244d5 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -3,6 +3,7 @@ sphinx
sphinx-rtd-theme
sphinx-mdinclude
sphinxcontrib-bibtex
+pandas
pytest
pytest-cov
pytest-mock
diff --git a/pytest.ini b/pytest.ini
index c021fab..ce9a298 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,3 +1,3 @@
[pytest]
pythonpath = src/
-addopts = --doctest-modules tests src -W ignore::DeprecationWarning
+addopts = --doctest-modules -W ignore::DeprecationWarning
diff --git a/src/traffic_weaver/__init__.py b/src/traffic_weaver/__init__.py
index cb9a993..e8f2dcc 100644
--- a/src/traffic_weaver/__init__.py
+++ b/src/traffic_weaver/__init__.py
@@ -1,33 +1,20 @@
from . import datasets
-from . import oversample
+from . import rfa
from . import match
from . import process
from . import interval
-from . import array_utils
+from . import sorted_array_utils
from ._version import __version__
-from .oversample import (
- LinearFixedOversample,
- LinearAdaptiveOversample,
- ExpFixedOversample,
- ExpAdaptiveOversample,
- CubicSplineOversample,
- PiecewiseConstantOversample,
-)
from .weaver import Weaver
+# @formatter:off
__all__ = [
Weaver,
__version__,
datasets,
- oversample,
+ rfa,
match,
process,
interval,
- array_utils,
- LinearFixedOversample,
- LinearAdaptiveOversample,
- ExpFixedOversample,
- ExpAdaptiveOversample,
- CubicSplineOversample,
- PiecewiseConstantOversample,
+ sorted_array_utils,
]
diff --git a/src/traffic_weaver/array_utils.py b/src/traffic_weaver/array_utils.py
deleted file mode 100644
index b90fc6e..0000000
--- a/src/traffic_weaver/array_utils.py
+++ /dev/null
@@ -1,251 +0,0 @@
-r"""Array utilities.
-"""
-from typing import Tuple, List, Union
-
-import numpy as np
-
-
-def append_one_sample(
- x: Union[np.ndarray, List], y: Union[np.ndarray, List], make_periodic=False
-) -> Tuple[np.ndarray, np.ndarray]:
- r"""Add one sample to the end of time series.
-
- Add one sample to `x` and `y` array. Newly added point `x_i` point is distant from
- the last point of `x` same as the last from the one before last point.
- If `make_periodic` is False, newly added `y_i` point is the same as the last point
- of `y`. If `make_periodic` is True, newly added point is the same as the first point
- of `y`.
-
- Parameters
- ----------
- x: 1-D array-like of size n
- Independent variable in strictly increasing order.
- y: 1-D array-like of size n
- Dependent variable.
- make_periodic: bool, default: False
- If false, append the last `y` point to `y` array.
- If true, append the first `y` point to `y` array.
-
- Returns
- -------
- ndarray
- x, independent variable.
- ndarray
- y, dependent variable.
- """
- x = np.asarray(x, dtype=np.float64)
- y = np.asarray(y, dtype=np.float64)
-
- x = np.append(x, 2 * x[-1] - x[-2])
- if not make_periodic:
- y = np.append(y, y[-1])
- else:
- y = np.append(y, y[0])
- return x, y
-
-
-def oversample_linspace(a: np.ndarray, num: int):
- r"""Oversample array using linspace between each consecutive pair of array elements.
-
- E.g., Array [1, 2, 3] oversampled by 2 becomes [1, 1.5, 2, 2.5, 3].
-
- If input array is of size `n`, then resulting array is of size `(n - 1) * num + 1`.
-
- If `n` is lower than 2, the original array is returned.
-
- Parameters
- ----------
- a: 1-D array
- Input array to oversample.
- num: int
- Number of elements inserted between each pair of array elements. Larger or
- equal to 2.
-
- Returns
- -------
- ndarray
- 1-D array containing `num` linspaced elements between each array elements' pair.
- Its length is equal to `(len(a) - 1) * num + 1`
-
- Examples
- --------
- >>> import numpy as np
- >>> from traffic_weaver.array_utils import oversample_linspace
- >>> oversample_linspace(np.asarray([1, 2, 3]), 4).tolist()
- [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]
-
- """
- if num < 2:
- return a
- a = np.asarray(a, dtype=float)
- return np.append(np.linspace(a[:-1], a[1:], num=num + 1)[:-1].T.flatten(), a[-1])
-
-
-def oversample_piecewise_constant(a: np.ndarray, num: int):
- r"""Oversample array using same left value between each consecutive pair of array
- elements.
-
- E.g., Array [1, 2, 3] oversampled by 2 becomes [1, 1, 2, 2, 3].
-
- If input array is of size `n`, then resulting array is of size `(n - 1) * num + 1`.
-
- If `n` is lower than 2, the original array is returned.
-
- Parameters
- ----------
- a: 1-D array
- Input array to oversample.
- num: int
- Number of elements inserted between each pair of array elements. Larger or
- equal to 2.
-
- Returns
- -------
- ndarray
- 1-D array containing `num` elements between each array elements' pair.
- Its length is equal to `(len(a) - 1) * num + 1`
-
- Examples
- --------
- >>> import numpy as np
- >>> from traffic_weaver.array_utils import oversample_piecewise_constant
- >>> oversample_piecewise_constant(np.asarray([1.0, 2.0, 3.0]), 4).tolist()
- [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0]
-
- """
- if num < 2:
- return a
- a = np.asarray(a)
- return a.repeat(num)[: -num + 1]
-
-
-def extend_linspace(
- a: np.ndarray, n: int, direction="both", lstart: float = None, rstop: float = None
-):
- """Extends array using linspace with n elements.
-
- Extends array `a` from left and/or right with `n` elements each side.
-
- When extending to the left,
- the starting value is `lstart` (inclusive) and ending value as `a[0]` (exclusive).
- By default, `lstart` is `a[0] - (a[n] - a[0])`.
-
- When extending to the right,
- the starting value `a[-1]` (exclusive) and ending value is `rstop` (inclusive).
- By default, `rstop` is `a[-1] + (a[-1] - a[-1 - n])`
-
- `direction` determines whether to extend to `both`, `left` or `right`.
- By default, it is 'both'.
-
- Parameters
- ----------
- a: 1-D array
- n: int
- Number of elements to extend
- direction: 'both', 'left' or 'right', default: 'both'
- Direction in which array should be extended.
- lstart: float, optional
- Starting value of the left extension.
- By default, it is `a[0] - (a[n] - a[0])`.
- rstop: float, optional
- Ending value of the right extension.
- By default, it is `a[-1] + (a[-1] - a[-1 - n])`.
-
- Returns
- -------
- ndarray
- 1-D extended array.
-
- Examples
- --------
- >>> import numpy as np
- >>> from traffic_weaver.array_utils import extend_linspace
- >>> a = np.array([1, 2, 3])
- >>> extend_linspace(a, 2, direction='both').tolist()
- [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
-
- >>> extend_linspace(a, 4, direction='right', rstop=4).tolist()
- [1.0, 2.0, 3.0, 3.25, 3.5, 3.75, 4.0]
-
- """
- a = np.asarray(a, dtype=float)
- if direction == "both" or direction == "left":
- if lstart is None:
- lstart = 2 * a[0] - a[n]
- ext = np.linspace(lstart, a[0], n + 1)[:-1]
- a = np.insert(a, 0, ext)
-
- if direction == "both" or direction == "right":
- if rstop is None:
- rstop = 2 * a[-1] - a[-n - 1]
- ext = np.linspace(a[-1], rstop, n + 1)[1:]
- a = np.insert(a, len(a), ext)
-
- return a
-
-
-def extend_constant(a: np.ndarray, n: int, direction="both"):
- """Extends array with first/last value with n elements.
-
- Extends array `a` from left and/or right with `n` elements each side.
-
- When extending to the left, value `a[0]` is repeated.
- When extending to the right, value `a[-1]` is repeated.
-
- `direction` determines whether to extend to `both`, `left` or `right`.
- By default, it is 'both'.
-
- Parameters
- ----------
- a: 1-D array
- n: int
- Number of elements to extend
- direction: 'both', 'left' or 'right', optional: 'both'
- Direction in which array should be extended.
-
- Returns
- -------
- ndarray
- 1-D extended array.
-
- Examples
- --------
- >>> import numpy as np
- >>> from traffic_weaver.array_utils import extend_constant
- >>> a = np.array([1, 2, 3])
- >>> extend_constant(a, 2, direction='both').tolist()
- [1, 1, 1, 2, 3, 3, 3]
-
- """
- a = np.asarray(a)
- if direction == "both" or direction == "left":
- a = np.insert(a, 0, [a[0]] * n)
- if direction == "both" or direction == "right":
- a = np.insert(a, len(a), [a[-1]] * n)
- return a
-
-
-def left_piecewise_integral(x, y):
- r"""Integral values between each pair of points using piecewise constant approx.
-
- In particular, if function contains average values, then it corresponds to the
- exact value of the integral.
-
- Parameters
- ----------
- x: 1-D array-like of size n
- Independent variable in strictly increasing order.
- y: 1-D array-like of size n
- Dependent variable.
-
- Returns
- -------
- 1-D array-like of size n-1
- Values of the integral.
- """
- d = np.diff(x)
- return y[:-1] * d
-
-
-def trapezoidal_integral_between_each_pair(x, y):
- np.trapz(y, x)
diff --git a/src/traffic_weaver/datasets.py b/src/traffic_weaver/datasets.py
deleted file mode 100644
index 6a7a262..0000000
--- a/src/traffic_weaver/datasets.py
+++ /dev/null
@@ -1,119 +0,0 @@
-r"""Small datasets used in Weaver."""
-import json
-import os
-
-import numpy as np
-
-
-def __open_datasets_dict():
- filename = os.path.join(
- os.path.dirname(__file__), "./datasets/example_datasets.json"
- )
- with open(filename) as f:
- dataset_dict = json.load(f)
- return dataset_dict
-
-
-def __example_dataset(fun):
- def wrapper():
- dataset_name = fun.__name__[fun.__name__.find("_") + 1 :]
- y = np.array(eval(__open_datasets_dict()[dataset_name]))
- x = np.arange(0, len(y))
- return x, y
-
- return wrapper
-
-
-@__example_dataset
-def load_mobile_video():
- pass
-
-
-@__example_dataset
-def load_mobile_youtube():
- pass
-
-
-@__example_dataset
-def load_mobile_social_media():
- pass
-
-
-@__example_dataset
-def load_fixed_social_media():
- pass
-
-
-@__example_dataset
-def load_tiktok():
- pass
-
-
-@__example_dataset
-def load_snapchat():
- pass
-
-
-@__example_dataset
-def load_mobile_messaging():
- pass
-
-
-@__example_dataset
-def load_mobile_zoom():
- pass
-
-
-@__example_dataset
-def load_measurements():
- pass
-
-
-@__example_dataset
-def load_social_networking():
- pass
-
-
-@__example_dataset
-def load_web():
- pass
-
-
-@__example_dataset
-def load_video_streaming():
- pass
-
-
-@__example_dataset
-def load_cloud():
- pass
-
-
-@__example_dataset
-def load_messaging():
- pass
-
-
-@__example_dataset
-def load_audio():
- pass
-
-
-@__example_dataset
-def load_vpn_and_security():
- pass
-
-
-@__example_dataset
-def load_marketplace():
- pass
-
-
-@__example_dataset
-def load_file_sharing():
- pass
-
-
-@__example_dataset
-def load_gaming():
- pass
diff --git a/src/traffic_weaver/datasets/__init__.py b/src/traffic_weaver/datasets/__init__.py
new file mode 100644
index 0000000..d8b1177
--- /dev/null
+++ b/src/traffic_weaver/datasets/__init__.py
@@ -0,0 +1,149 @@
+from ._base import (get_data_home)
+
+# @formatter:off
+from ._sandvine import (
+ sandvine_dataset_description,
+ load_sandvine_audio,
+ load_sandvine_cloud,
+ load_sandvine_file_sharing,
+ load_sandvine_fixed_social_media,
+ load_sandvine_gaming,
+ load_sandvine_marketplace,
+ load_sandvine_measurements,
+ load_sandvine_messaging,
+ load_sandvine_mobile_messaging,
+ load_sandvine_mobile_social_media,
+ load_sandvine_mobile_video,
+ load_sandvine_mobile_youtube,
+ load_sandvine_mobile_zoom,
+ load_sandvine_snapchat,
+ load_sandvine_social_networking,
+ load_sandvine_tiktok,
+ load_sandvine_video_streaming,
+ load_sandvine_vpn_and_security,
+ load_sandvine_web,
+)
+
+from ._mix_it import (
+ mix_it_dataset_description,
+ fetch_mix_it_bologna_daily,
+ fetch_mix_it_bologna_monthly,
+ fetch_mix_it_bologna_weekly,
+ fetch_mix_it_bologna_yearly,
+ fetch_mix_it_milan_daily,
+ fetch_mix_it_milan_weekly,
+ fetch_mix_it_milan_monthly,
+ fetch_mix_it_milan_yearly,
+ fetch_mix_it_palermo_daily,
+ fetch_mix_it_palermo_weekly,
+ fetch_mix_it_palermo_monthly,
+ fetch_mix_it_palermo_yearly,
+)
+
+from ._ams_ix import (
+ ams_ix_dataset_description,
+ fetch_ams_ix_yearly_by_day,
+ fetch_ams_ix_daily,
+ fetch_ams_ix_monthly,
+ fetch_ams_ix_weekly,
+ fetch_ams_ix_yearly_input,
+ fetch_ams_ix_yearly_output,
+ fetch_ams_ix_grx_yearly_by_day,
+ fetch_ams_ix_grx_daily,
+ fetch_ams_ix_grx_monthly,
+ fetch_ams_ix_grx_yearly_input,
+ fetch_ams_ix_grx_yearly_output,
+ fetch_ams_ix_i_ipx_diameter_daily,
+ fetch_ams_ix_i_ipx_diameter_monthly,
+ fetch_ams_ix_i_ipx_diameter_weekly,
+ fetch_ams_ix_i_ipx_diameter_yearly_input,
+ fetch_ams_ix_i_ipx_diameter_yearly_output,
+ fetch_ams_ix_i_ipx_yearly_by_day,
+ fetch_ams_ix_i_ipx_daily,
+ fetch_ams_ix_i_ipx_monthly,
+ fetch_ams_ix_i_ipx_weekly,
+ fetch_ams_ix_i_ipx_yearly_input,
+ fetch_ams_ix_i_ipx_yearly_output,
+ fetch_ams_ix_isp_yearly_by_day,
+ fetch_ams_ix_isp_daily,
+ fetch_ams_ix_isp_monthly,
+ fetch_ams_ix_isp_weekly,
+ fetch_ams_ix_isp_yearly_input,
+ fetch_ams_ix_isp_yearly_output,
+ fetch_ams_ix_nawas_anti_ddos_daily,
+ fetch_ams_ix_nawas_anti_ddos_monthly,
+ fetch_ams_ix_nawas_anti_ddos_weekly,
+ fetch_ams_ix_nawas_anti_ddos_yearly_input,
+ fetch_ams_ix_nawas_anti_ddos_yearly_output,
+)
+
+__all__ = [
+ get_data_home,
+ sandvine_dataset_description,
+ load_sandvine_audio,
+ load_sandvine_cloud,
+ load_sandvine_file_sharing,
+ load_sandvine_fixed_social_media,
+ load_sandvine_gaming,
+ load_sandvine_marketplace,
+ load_sandvine_measurements,
+ load_sandvine_messaging,
+ load_sandvine_mobile_messaging,
+ load_sandvine_mobile_social_media,
+ load_sandvine_mobile_video,
+ load_sandvine_mobile_youtube,
+ load_sandvine_mobile_zoom,
+ load_sandvine_snapchat,
+ load_sandvine_social_networking,
+ load_sandvine_tiktok,
+ load_sandvine_video_streaming,
+ load_sandvine_vpn_and_security,
+ load_sandvine_web,
+ mix_it_dataset_description,
+ fetch_mix_it_bologna_daily,
+ fetch_mix_it_bologna_monthly,
+ fetch_mix_it_bologna_weekly,
+ fetch_mix_it_bologna_yearly,
+ fetch_mix_it_milan_daily,
+ fetch_mix_it_milan_weekly,
+ fetch_mix_it_milan_monthly,
+ fetch_mix_it_milan_yearly,
+ fetch_mix_it_palermo_daily,
+ fetch_mix_it_palermo_weekly,
+ fetch_mix_it_palermo_monthly,
+ fetch_mix_it_palermo_yearly,
+ ams_ix_dataset_description,
+ fetch_ams_ix_yearly_by_day,
+ fetch_ams_ix_daily,
+ fetch_ams_ix_monthly,
+ fetch_ams_ix_weekly,
+ fetch_ams_ix_yearly_input,
+ fetch_ams_ix_yearly_output,
+ fetch_ams_ix_grx_yearly_by_day,
+ fetch_ams_ix_grx_daily,
+ fetch_ams_ix_grx_monthly,
+ fetch_ams_ix_grx_yearly_input,
+ fetch_ams_ix_grx_yearly_output,
+ fetch_ams_ix_i_ipx_diameter_daily,
+ fetch_ams_ix_i_ipx_diameter_monthly,
+ fetch_ams_ix_i_ipx_diameter_weekly,
+ fetch_ams_ix_i_ipx_diameter_yearly_input,
+ fetch_ams_ix_i_ipx_diameter_yearly_output,
+ fetch_ams_ix_i_ipx_yearly_by_day,
+ fetch_ams_ix_i_ipx_daily,
+ fetch_ams_ix_i_ipx_monthly,
+ fetch_ams_ix_i_ipx_weekly,
+ fetch_ams_ix_i_ipx_yearly_input,
+ fetch_ams_ix_i_ipx_yearly_output,
+ fetch_ams_ix_isp_yearly_by_day,
+ fetch_ams_ix_isp_daily,
+ fetch_ams_ix_isp_monthly,
+ fetch_ams_ix_isp_weekly,
+ fetch_ams_ix_isp_yearly_input,
+ fetch_ams_ix_isp_yearly_output,
+ fetch_ams_ix_nawas_anti_ddos_daily,
+ fetch_ams_ix_nawas_anti_ddos_monthly,
+ fetch_ams_ix_nawas_anti_ddos_weekly,
+ fetch_ams_ix_nawas_anti_ddos_yearly_input,
+ fetch_ams_ix_nawas_anti_ddos_yearly_output,
+]
diff --git a/src/traffic_weaver/datasets/_ams_ix.py b/src/traffic_weaver/datasets/_ams_ix.py
new file mode 100644
index 0000000..08a35bf
--- /dev/null
+++ b/src/traffic_weaver/datasets/_ams_ix.py
@@ -0,0 +1,305 @@
+from traffic_weaver.datasets._base import RemoteFileMetadata, load_csv_dataset_from_remote, load_dataset_description
+
+DATASET_FOLDER = 'ams-ix'
+
+
+def ams_ix_dataset_description():
+ """Get description of this dataset."""
+ return load_dataset_description("ams_ix.md")
+
+
+def fetch_ams_ix_yearly_by_day(**kwargs):
+ """Load and return AMS-IX yearly by day dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-yearly-by-day_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49348042",
+ checksum="56d31d4f0469599a80b5e952d484fe7b6fde8aec0a88ae6fc35e8b450e078447")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-yearly-by-day",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_daily(**kwargs):
+ """Load and return AMS-IX daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix_daily_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49347946",
+ checksum="2f606b0adecbbae50727539cebd2d107c6d5a962298d34cbeb1bf4b7cab0d3a9")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix_daily", dataset_folder=DATASET_FOLDER,
+ validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_monthly(**kwargs):
+ """Load and return AMS-IX monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix_monthly_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49347949",
+ checksum="a8aeaabbd9089bf25455ab8d164f69a868032f7f0ba2c1a771bf5d04a2e16581")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix_monthly", dataset_folder=DATASET_FOLDER,
+ validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_weekly(**kwargs):
+ """Load and return AMS-IX weekly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix_weekly_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49347952",
+ checksum="2273530aeca328721764491d770a8259b255df5c028c899aa5c6c3b2001e33f4")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix_weekly", dataset_folder=DATASET_FOLDER,
+ validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_yearly_input(**kwargs):
+ """Load and return AMS-IX yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix_yearly_2024-09-21-input.csv",
+ url="https://figshare.com/ndownloader/files/49347955",
+ checksum="19fe5560606477ccacead54a993e856be45d59b5beb1dab819b612a437a301d3")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_yearly_output(**kwargs):
+ """Load and return AMS-IX yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix_yearly_2024-09-21-output.csv",
+ url="https://figshare.com/ndownloader/files/49347958",
+ checksum="9f67208e8b6155634bb517d78c796e5344c1400d174e8ab62a65580a24b553f5")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_grx_yearly_by_day(**kwargs):
+ """Load and return AMS-IX GRX yearly by day dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-grx-yearly-by-day_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49351339",
+ checksum="aff17528c4b3855cfb52bc42d1d67c1cb8d24fc44153f6def0febe30ce7c5892")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-grx-yearly-by-day",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_grx_daily(**kwargs):
+ """Load and return AMS-IX GRX daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-grx_daily_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49347961",
+ checksum="cc69b78859fcf8a328bde5cf407decf01493930efa1d31397af56b8060895c15")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-grx_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_grx_monthly(**kwargs):
+ """Load and return AMS-IX GRX monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-grx_monthly_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49347964",
+ checksum="4a9e45d2bf647c6eba6b2550c2d7b06da363e736a8d801763dba5244ed8f491d")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-grx_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_grx_yearly_input(**kwargs):
+ """Load and return AMS-IX GRX yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-grx_yearly_2024-09-21-input.csv",
+ url="https://figshare.com/ndownloader/files/49347967",
+ checksum="e06a0ab9073057e618e7d41cf3cb4171650ee22bd5411a9bc95cd25104c44bc4")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-grx_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_grx_yearly_output(**kwargs):
+ """Load and return AMS-IX GRX yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-grx_yearly_2024-09-21-output.csv",
+ url="https://figshare.com/ndownloader/files/49347970",
+ checksum="e2d2b1dda84328effca9b461f07a99afd598a6a13011a2c41338bf5a347c5d70")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-grx_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_diameter_daily(**kwargs):
+ """Load and return AMS-IX I-IPX diameter daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-diameter_daily_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347991",
+ checksum="abbe54c558d3cc954f361d7f5eab66c194ec6f0866332410386ab39678ee15c2")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-diameter_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_diameter_monthly(**kwargs):
+ """Load and return AMS-IX I-IPX diameter monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-diameter_monthly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347994",
+ checksum="cebf44d0c585a0685e3446af44d001bddf36e975f8963f60fb36d0c0583eb82b")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-diameter_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_diameter_weekly(**kwargs):
+ """Load and return AMS-IX I-IPX diameter weekly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-diameter_weekly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347997",
+ checksum="2b5b622d041c4ad1f0e282420620b96da9ddee01c14eaf4457319515bbb1d286")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-diameter_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_diameter_yearly_input(**kwargs):
+ """Load and return AMS-IX I-IPX diameter yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-diameter_yearly_2024-09-22-input.csv",
+ url="https://figshare.com/ndownloader/files/49348000",
+ checksum="eba37cdf6131d6d9ddd668e919ab5ef5f222171cf3f33170b9aa9af158e9025e")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-diameter_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_diameter_yearly_output(**kwargs):
+ """Load and return AMS-IX I-IPX diameter yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-diameter_yearly_2024-09-22-output.csv",
+ url="https://figshare.com/ndownloader/files/49348003",
+ checksum="1a098f35c6b541569f0a5e3cbec5cafc020b98b038e0a3f2276de1b30231aed1")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-diameter_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_yearly_by_day(**kwargs):
+ """Load and return AMS-IX I-IPX yearly by day dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx-yearly-by-day_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49351342",
+ checksum="eee61d792e8427e5d4ea55b7e881acd646c676a8681270b63485102ca4062ebf")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx-yearly-by-day",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_daily(**kwargs):
+ """Load and return AMS-IX I-IPX daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx_daily_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347976",
+ checksum="d9752817c7b635dab6cddd329e0d4238d7e94199de242d9d3327208c77cd3aa2")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_monthly(**kwargs):
+ """Load and return AMS-IX I-IPX monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx_monthly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347979",
+ checksum="665b8841c0858e86db9aa8144d9404c175754055da5d1d23047f77f850c5a7ff")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_weekly(**kwargs):
+ """Load and return AMS-IX I-IPX weekly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx_weekly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49347982",
+ checksum="13e4cc3bb2124e03c58066e25d4beac8a323c7cfde6ad2ec6219d8798b81f69c")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_yearly_input(**kwargs):
+ """Load and return AMS-IX I-IPX yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx_yearly_2024-09-22-input.csv",
+ url="https://figshare.com/ndownloader/files/49347985",
+ checksum="8549d3cd62a3b8074aac82450676fe359bcc679c897c487a8519322123f4bd93")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_i_ipx_yearly_output(**kwargs):
+ """Load and return AMS-IX I-IPX yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-i-ipx_yearly_2024-09-22-output.csv",
+ url="https://figshare.com/ndownloader/files/49347988",
+ checksum="450f50f262543c266503de7f89c9c5c5b07fdb5e40c0c39e82600e47e5d41ff8")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-i-ipx_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_yearly_by_day(**kwargs):
+ """Load and return AMS-IX ISP yearly by day dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp-yearly-by-day_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49351345",
+ checksum="d9efb4dd7158c223c45ea2c66f2455ed2f15c5a94d8db437ad0cc6e29c8b0e03")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp-yearly-by-day",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_daily(**kwargs):
+ """Load and return AMS-IX ISP daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp_daily_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49348009",
+ checksum="b839bef4522fdfd19eee291f713834caf812a1da20d871adb429b03c10c3b692")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_monthly(**kwargs):
+ """Load and return AMS-IX ISP monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp_monthly_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49348012",
+ checksum="11cb8057c1984072285c841553e478cacb0672e9153e4d72930c5af40c899875")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_weekly(**kwargs):
+ """Load and return AMS-IX ISP weekly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp_weekly_2024-09-21.csv",
+ url="https://figshare.com/ndownloader/files/49348015",
+ checksum="02ec0fbe6fdd0429ef79427a9b3c1210a0e912cb8b2d146305fdab35c9c22928")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_yearly_input(**kwargs):
+ """Load and return AMS-IX ISP yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp_yearly_2024-09-21-input.csv",
+ url="https://figshare.com/ndownloader/files/49348018",
+ checksum="d6cac12520f3ebcb33b04e2a096106b42fa082510187f83d36e53dd4a81d96a0")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_isp_yearly_output(**kwargs):
+ """Load and return AMS-IX ISP yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-isp_yearly_2024-09-21-output.csv",
+ url="https://figshare.com/ndownloader/files/49348021",
+ checksum="b7ec0614c03704388528005be5899948a84a70c418c3fd7de8bddc1d0d4db0c1")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-isp_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_nawas_anti_ddos_daily(**kwargs):
+ """Load and return AMS-IX NAWAS anti-DDoS daily dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-nawas-anti-ddos_daily_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49348027",
+ checksum="89682274e43228392120f1c28aaad1e2daa8c3781d1667944d7156e73c4363e2")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-nawas-anti-ddos_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_nawas_anti_ddos_monthly(**kwargs):
+ """Load and return AMS-IX NAWAS anti-DDoS monthly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-nawas-anti-ddos_monthly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49348030",
+ checksum="3d8332ac9761751604ce9f21ff03152a6051d8e2e7a3de512fb1cb3869746f36")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-nawas-anti-ddos_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_nawas_anti_ddos_weekly(**kwargs):
+ """Load and return AMS-IX NAWAS anti-DDoS weekly dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-nawas-anti-ddos_weekly_2024-09-22.csv",
+ url="https://figshare.com/ndownloader/files/49348033",
+ checksum="1b70c8e7701d2fe6a5d737c3b46cfeeff1db1d577fe06ff67cefba788bdb807b")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-nawas-anti-ddos_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_nawas_anti_ddos_yearly_input(**kwargs):
+ """Load and return AMS-IX NAWAS anti-DDoS yearly input dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-nawas-anti-ddos_yearly_2024-09-22-input.csv",
+ url="https://figshare.com/ndownloader/files/49348036",
+ checksum="3cc39c26e667b09c1eae6e31665867e1aa89dbb9e614660a7705a385222734d1")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-nawas-anti-ddos_yearly_input",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_ams_ix_nawas_anti_ddos_yearly_output(**kwargs):
+ """Load and return AMS-IX NAWAS anti-DDoS yearly output dataset."""
+ remote = RemoteFileMetadata(filename="ams-ix-nawas-anti-ddos_yearly_2024-09-22-output.csv",
+ url="https://figshare.com/ndownloader/files/49348039",
+ checksum="fa4f9c6e887aa9ecf1e98df856a894b125e892e773129a6e632549248f338776")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="ams-ix-nawas-anti-ddos_yearly_output",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
diff --git a/src/traffic_weaver/datasets/_base.py b/src/traffic_weaver/datasets/_base.py
new file mode 100644
index 0000000..3969ff8
--- /dev/null
+++ b/src/traffic_weaver/datasets/_base.py
@@ -0,0 +1,249 @@
+import hashlib
+import logging
+import os
+import pickle
+import shutil
+import time
+import warnings
+from collections import namedtuple
+from gzip import GzipFile
+from importlib import resources
+from os import environ, path, makedirs
+from tempfile import TemporaryDirectory
+from urllib.error import URLError
+from urllib.request import urlretrieve
+
+import numpy as np
+
+RESOURCES_DATASETS = 'traffic_weaver.datasets.data'
+RESOURCES_DATASETS_DESCRIPTION = 'traffic_weaver.datasets.data_description'
+
+RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
+
+logger = logging.getLogger(__name__)
+
+
+def get_data_home(data_home: str = None) -> str:
+ """Return the path of the data directory.
+
+ Datasets are stored in '.traffic-weaver-data' directory in the user directory.
+
+    This directory can be changed by setting the `TRAFFIC_WEAVER_DATA` environment variable.
+
+ Parameters
+ ----------
+ data_home: str, default=None
+ The path to the data directory. If `None`, the default directory is `.traffic-weaver-data`.
+
+ Returns
+ -------
+ data_home: str
+ The path to the data directory.
+
+ Examples
+ --------
+ >>> import os
+ >>> from traffic_weaver.datasets import get_data_home
+ >>> data_home = get_data_home()
+ >>> os.path.exists(data_home)
+ True
+ """
+ if data_home is None:
+ data_home = environ.get("TRAFFIC_WEAVER_DATA", path.join("~", ".traffic-weaver-data"))
+ data_home = path.expanduser(data_home)
+ makedirs(data_home, exist_ok=True)
+ return data_home
+
+
+def clear_data_home(data_home: str = None):
+ """Remove all files in the data directory.
+
+ Parameters
+ ----------
+ data_home: str, default=None
+ The path to the data directory. If `None`, the default directory is `.traffic-weaver-data`.
+ """
+ data_home = get_data_home(data_home)
+ shutil.rmtree(data_home)
+
+
+def load_csv_dataset_from_resources(file_name, resources_module=RESOURCES_DATASETS, unpack_dataset_columns=False):
+ """Load dataset from resources.
+
+ Parameters
+ ----------
+ file_name: str
+ name of the file to load.
+ resources_module: str, default='traffic_weaver.datasets.data'
+ The package name where the resources are located.
+ unpack_dataset_columns: bool, default=False
+ If True, the dataset is unpacked to two separate arrays x and y.
+
+ Returns
+ -------
+ dataset: np.ndarray of shape (nr_of_samples, 2)
+ 2D array with each row representing one point in time series.
+ The first column is the x-variable and the second column is the y-variable.
+
+ """
+ data_path = resources.files(resources_module) / file_name
+ data_file = np.loadtxt(data_path, delimiter=',', dtype=np.float64)
+ if unpack_dataset_columns:
+ return data_file[:, 0], data_file[:, 1]
+ else:
+ return data_file
+
+
+def _sha256(path):
+ """Calculate the sha256 hash of the file at path."""
+ sha256hash = hashlib.sha256()
+ chunk_size = 8192
+ with open(path, "rb") as f:
+ while True:
+ buffer = f.read(chunk_size)
+ if not buffer:
+ break
+ sha256hash.update(buffer)
+ return sha256hash.hexdigest()
+
+
+def _fetch_remote(remote: RemoteFileMetadata, dirname=None, n_retries=3, delay=1.0, validate_checksum=True):
+ """Download remote dataset into path.
+
+ Fetch a dataset pointed by remote's url, save into path using remote's filename and
+ ensure integrity based on the SHA256 Checksum of the downloaded file.
+
+ Parameters
+ ----------
+ remote: RemoteFileMetadata
+ Named tuple containing remote dataset meta information: url, filename, checksum.
+
+ dirname: str
+ Directory to save the file to.
+
+ n_retries: int, default=3
+ Number of retries when HTTP errors are encountered.
+
+ delay: float, default=1.0
+ Number of seconds between retries.
+
+ validate_checksum: bool, default=True
+ If True, check the SHA256 checksum of the downloaded file.
+
+ Returns
+ -------
+ file_path: str
+ Full path of the created file.
+ """
+ file_path = remote.filename if dirname is None else path.join(dirname, remote.filename)
+
+ while True:
+ try:
+ urlretrieve(remote.url, file_path)
+ break
+ except (URLError, TimeoutError):
+ if n_retries == 0:
+ # If no more retries are left, re-raise the caught exception.
+ raise
+ warnings.warn(f"Retry downloading from url: {remote.url}")
+ n_retries -= 1
+ time.sleep(delay)
+
+ if validate_checksum:
+ checksum = _sha256(file_path)
+ if remote.checksum != checksum:
+ raise OSError("{} has an SHA256 checksum ({}) "
+ "differing from expected ({}), "
+ "file may be corrupted.".format(file_path, checksum, remote.checksum))
+ return file_path
+
+
+def load_csv_dataset_from_remote(remote: RemoteFileMetadata, dataset_filename, dataset_folder, data_home=None,
+ download_if_missing: bool = True, download_even_if_available: bool = False,
+ validate_checksum: bool = True, n_retries=3, delay=1.0, gzip=False,
+ unpack_dataset_columns=False, ):
+ """
+    Load a dataset from a remote location in csv format (optionally gzip-compressed, see `gzip`).
+ After downloading the dataset it is stored in the cache folder for further use in pickle format.
+
+ Parameters
+ ----------
+ remote: RemoteFileMetadata
+ Named tuple containing remote dataset meta information: url, filename, checksum.
+ dataset_filename: str
+ Name for the dataset file.
+ dataset_folder: str
+ Folder in `data_home` where the dataset is stored.
+ data_home: str, default=None
+        Download cache folder for the dataset. By default, data is stored in `~/.traffic-weaver-data`.
+ download_if_missing: bool, default=True
+ If False, raise an OSError if the data is not locally available instead of
+ trying to download the data from the source.
+ download_even_if_available: bool, default=False
+ If True, download the data even if it is already available locally.
+ validate_checksum: bool, default=True
+ If True, check the SHA256 checksum of the downloaded file.
+ n_retries: int, default=3
+ Number of retries in case of HTTPError or URLError when downloading the data.
+ delay: float, default=1.0
+ Number of seconds between retries.
+ gzip: bool, default=False
+ If True, the file is assumed to be compressed in gzip format in the remote.
+ unpack_dataset_columns: bool, default=False
+ If True, the dataset is unpacked to two separate arrays x and y.
+
+ Returns
+ -------
+ dataset: np.ndarray of shape (nr_of_samples, 2)
+ 2D array with each row representing one point in time series.
+ The first column is the x-variable and the second column is the y-variable.
+ """
+ data_home = get_data_home(data_home)
+
+ dataset_dir = path.join(data_home, dataset_folder)
+ dataset_file_path = path.join(dataset_dir, dataset_filename)
+
+ available = path.exists(dataset_file_path)
+
+ dataset = None
+ if (download_if_missing and not available) or (download_if_missing and download_even_if_available and available):
+ os.makedirs(dataset_dir, exist_ok=True)
+ with TemporaryDirectory(dir=dataset_dir) as tmp_dir:
+ logger.info(f"Downloading {remote.url}")
+ archive_path = _fetch_remote(remote, dirname=tmp_dir, n_retries=n_retries, delay=delay,
+ validate_checksum=validate_checksum)
+ if gzip:
+ dataset = np.loadtxt(GzipFile(filename=archive_path), delimiter=',', dtype=np.float64)
+ else:
+ dataset = np.loadtxt(archive_path, delimiter=',', dtype=np.float64)
+ dataset_tmp_file_path = path.join(tmp_dir, dataset_filename)
+ pickle.dump(dataset, open(dataset_tmp_file_path, "wb"))
+ os.rename(dataset_tmp_file_path, dataset_file_path)
+ elif not available and not download_if_missing:
+ raise OSError("Data not found and `download_if_missing` is False")
+ if dataset is None:
+ dataset = pickle.load(open(dataset_file_path, "rb"))
+ if unpack_dataset_columns:
+ return dataset[:, 0], dataset[:, 1]
+ else:
+ return dataset
+
+
+def load_dataset_description(datasetsource_filename, resources_module=RESOURCES_DATASETS_DESCRIPTION):
+ """Load source of the dataset from filename from resources.
+
+ Parameters
+ ----------
+ datasetsource_filename: str
+ name of the file to load.
+    resources_module: str, default='traffic_weaver.datasets.data_description'
+ The package name where the resources are located.
+
+ Returns
+ -------
+ description: str
+        Description of the dataset.
+
+ """
+ data_path = resources.files(resources_module) / datasetsource_filename
+ return data_path.read_text(encoding='utf-8')
diff --git a/src/traffic_weaver/datasets/_mix_it.py b/src/traffic_weaver/datasets/_mix_it.py
new file mode 100644
index 0000000..692d1a8
--- /dev/null
+++ b/src/traffic_weaver/datasets/_mix_it.py
@@ -0,0 +1,116 @@
+from traffic_weaver.datasets._base import RemoteFileMetadata, load_csv_dataset_from_remote, load_dataset_description
+
+DATASET_FOLDER = "mix-it"
+
+
+def mix_it_dataset_description():
+ """Get description of this dataset."""
+ return load_dataset_description("mix_it.md")
+
+
+def fetch_mix_it_bologna_daily(**kwargs):
+ """Load and return MIX Bologna daily dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-bologna_daily_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347910",
+ checksum="9f0970dfeca937818f40eab2fbc62c72a4270b6d93d4b2b9d91e3db0f6092c2a")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-bologna_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_bologna_weekly(**kwargs):
+ """Load and return MIX Bologna weekly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-bologna_weekly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347916",
+ checksum="b852c310c6f543659e7fa194d19c3a6cadd7de6b47f184843909acfee98cb781")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-bologna_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_bologna_monthly(**kwargs):
+ """Load and return MIX Bologna monthly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-bologna_monthly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347913",
+ checksum="e29881dc7c44782da783f70d9123548c4aeb75bdcd82f31e6d8622d51617db99")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-bologna_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_bologna_yearly(**kwargs):
+ """Load and return MIX Bologna yearly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-bologna_yearly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347919",
+ checksum="0cbd8c03d46f0ae76ab958fca384f3d5692fefcbbb4c99995d17d5a86e5bd401")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-bologna_yearly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_milan_daily(**kwargs):
+ """Load and return MIX Milan daily dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-milan_daily_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347922",
+ checksum="fbd873d3f91896d992508b00f42c98ac44d1a03ad42551fb09903168831e42f1")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-milan_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_milan_weekly(**kwargs):
+ """Load and return MIX Milan weekly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-milan_weekly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347928",
+ checksum="a38147bb0a4d857ac80f6440f64d7c5983faf326bae6433cad7a4b05fa98afab")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-milan_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_milan_monthly(**kwargs):
+ """Load and return MIX Milan monthly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-milan-monthly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347925",
+ checksum="30d6b7c5b8bbfbff92992052cde3ac9ed3b31aa47103fd5fdc6ab34a6ca9ef59")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-milan_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_milan_yearly(**kwargs):
+ """Load and return MIX Milan yearly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-milan_yearly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347931",
+ checksum="d3d925d1ffae871a65a7ef4f158722953352cc5f2e0a4165880c69115c56f17c")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-milan_yearly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_palermo_daily(**kwargs):
+ """Load and return MIX Palermo daily dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-palermo_daily_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347934",
+ checksum="3b1f43504f26c38e5c81247da20ce9194fc138ecb4e549f3c3af35d9bc60fb9e")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-palermo_daily",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_palermo_weekly(**kwargs):
+ """Load and return MIX Palermo weekly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-palermo_weekly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347940",
+ checksum="a239292b440a6f9cf6f4ce1b5b8766164c6aafca9b12f5352fb56247bc9a28ce")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-palermo_weekly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_palermo_monthly(**kwargs):
+ """Load and return MIX Palermo monthly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-palermo-monthly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347937",
+ checksum="8b94d22ef455ba61d16557b2c587db7aee030e052da7c8c3da9507a5f1074e6b")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-palermo_monthly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
+
+
+def fetch_mix_it_palermo_yearly(**kwargs):
+ """Load and return MIX Palermo yearly dataset."""
+ remote = RemoteFileMetadata(filename="mix-it-palermo_yearly_2024-09_04.csv",
+ url="https://figshare.com/ndownloader/files/49347943",
+ checksum="b3f0d9240803edfa6000086df38613b719b85eaa9c39d5f031fdfb3c9bee3e4f")
+ return load_csv_dataset_from_remote(remote=remote, dataset_filename="mix-it-palermo_yearly",
+ dataset_folder=DATASET_FOLDER, validate_checksum=True, **kwargs)
diff --git a/src/traffic_weaver/datasets/_sandvine.py b/src/traffic_weaver/datasets/_sandvine.py
new file mode 100644
index 0000000..172e265
--- /dev/null
+++ b/src/traffic_weaver/datasets/_sandvine.py
@@ -0,0 +1,104 @@
+from os import path
+
+from ._base import load_csv_dataset_from_resources, load_dataset_description
+
+
+def sandvine_dataset_description():
+ """Get description of this dataset.
+ """
+ return load_dataset_description("sandvine.md")
+
+
+def load_sandvine_audio():
+ """Load and return sandvine audio dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "audio.csv"))
+
+
+def load_sandvine_cloud():
+ """Load and return sandvine cloud dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "cloud.csv"))
+
+
+def load_sandvine_file_sharing():
+ """Load and return sandvine file sharing dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "file_sharing.csv"))
+
+
+def load_sandvine_fixed_social_media():
+ """Load and return sandvine fixed social media dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "fixed_social_media.csv"))
+
+
+def load_sandvine_gaming():
+ """Load and return sandvine gaming dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "gaming.csv"))
+
+
+def load_sandvine_marketplace():
+ """Load and return sandvine marketplace dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "marketplace.csv"))
+
+
+def load_sandvine_measurements():
+ """Load and return sandvine measurements dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "measurements.csv"))
+
+
+def load_sandvine_messaging():
+ """Load and return sandvine messaging dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "messaging.csv"))
+
+
+def load_sandvine_mobile_messaging():
+ """Load and return sandvine mobile messaging dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "mobile_messaging.csv"))
+
+
+def load_sandvine_mobile_social_media():
+ """Load and return sandvine mobile social media dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "mobile_social_media.csv"))
+
+
+def load_sandvine_mobile_video():
+ """Load and return sandvine mobile video dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "mobile_video.csv"))
+
+
+def load_sandvine_mobile_youtube():
+ """Load and return sandvine mobile youtube dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "mobile_youtube.csv"))
+
+
+def load_sandvine_mobile_zoom():
+ """Load and return sandvine mobile zoom dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "mobile_zoom.csv"))
+
+
+def load_sandvine_snapchat():
+ """Load and return sandvine snapchat dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "snapchat.csv"))
+
+
+def load_sandvine_social_networking():
+ """Load and return sandvine social networking dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "social_networking.csv"))
+
+
+def load_sandvine_tiktok():
+ """Load and return sandvine tiktok dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "tiktok.csv"))
+
+
+def load_sandvine_video_streaming():
+ """Load and return sandvine video streaming dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "video_streaming.csv"))
+
+
+def load_sandvine_vpn_and_security():
+ """Load and return sandvine vpn and security dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "vpn_and_security.csv"))
+
+
+def load_sandvine_web():
+ """Load and return sandvine web dataset."""
+ return load_csv_dataset_from_resources(path.join("sandvine", "web.csv"))
diff --git a/src/traffic_weaver/datasets/data/__init__.py b/src/traffic_weaver/datasets/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/traffic_weaver/datasets/data/sandvine/audio.csv b/src/traffic_weaver/datasets/data/sandvine/audio.csv
new file mode 100644
index 0000000..d73622d
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/audio.csv
@@ -0,0 +1,24 @@
+0, 0.0555
+1, 0.027
+2, 0.03
+3, 0.033
+4, 0.0375
+5, 0.075
+6, 0.1032
+7, 0.1032
+8, 0.1247
+9, 0.12255
+10, 0.1247
+11, 0.1333
+12, 0.096
+13, 0.096
+14, 0.1185
+15, 0.1215
+16, 0.1185
+17, 0.1245
+18, 0.111
+19, 0.099
+20, 0.0915
+21, 0.081
+22, 0.0825
+23, 0.069
diff --git a/src/traffic_weaver/datasets/data/sandvine/cloud.csv b/src/traffic_weaver/datasets/data/sandvine/cloud.csv
new file mode 100644
index 0000000..5f31f59
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/cloud.csv
@@ -0,0 +1,24 @@
+0, 0.10175
+1, 0.0495
+2, 0.055
+3, 0.0605
+4, 0.06875
+5, 0.1375
+6, 0.2208
+7, 0.2208
+8, 0.2668
+9, 0.2622
+10, 0.2668
+11, 0.2852
+12, 0.1792
+13, 0.1792
+14, 0.2212
+15, 0.2268
+16, 0.2212
+17, 0.2324
+18, 0.2035
+19, 0.1815
+20, 0.16775
+21, 0.1485
+22, 0.15125
+23, 0.1265
diff --git a/src/traffic_weaver/datasets/data/sandvine/file_sharing.csv b/src/traffic_weaver/datasets/data/sandvine/file_sharing.csv
new file mode 100644
index 0000000..948abb1
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/file_sharing.csv
@@ -0,0 +1,24 @@
+0, 0.0037
+1, 0.0018
+2, 0.002
+3, 0.0022
+4, 0.0025
+5, 0.005
+6, 0.0048
+7, 0.0048
+8, 0.0058
+9, 0.0057
+10, 0.0058
+11, 0.0062
+12, 0.0096
+13, 0.0096
+14, 0.01185
+15, 0.01215
+16, 0.01185
+17, 0.01245
+18, 0.0074
+19, 0.0066
+20, 0.0061
+21, 0.0054
+22, 0.0055
+23, 0.0046
diff --git a/src/traffic_weaver/datasets/data/sandvine/fixed_social_media.csv b/src/traffic_weaver/datasets/data/sandvine/fixed_social_media.csv
new file mode 100644
index 0000000..591f993
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/fixed_social_media.csv
@@ -0,0 +1,24 @@
+0, 1.9
+1, 1.45
+2, 1.25
+3, 1.05
+4, 1.1
+5, 1.2
+6, 1.4
+7, 1.75
+8, 2.15
+9, 2.45
+10, 2.55
+11, 2.5
+12, 2.65
+13, 2.65
+14, 2.6
+15, 2.5
+16, 2.45
+17, 2.4
+18, 2.45
+19, 2.45
+20, 2.65
+21, 2.7
+22, 2.6
+23, 2.3
diff --git a/src/traffic_weaver/datasets/data/sandvine/gaming.csv b/src/traffic_weaver/datasets/data/sandvine/gaming.csv
new file mode 100644
index 0000000..12d6b99
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/gaming.csv
@@ -0,0 +1,24 @@
+0, 0.02405
+1, 0.0117
+2, 0.013
+3, 0.0143
+4, 0.01625
+5, 0.0325
+6, 0.012
+7, 0.012
+8, 0.0145
+9, 0.01425
+10, 0.0145
+11, 0.0155
+12, 0.0544
+13, 0.0544
+14, 0.06715
+15, 0.06885
+16, 0.06715
+17, 0.07055
+18, 0.0481
+19, 0.0429
+20, 0.03965
+21, 0.0351
+22, 0.03575
+23, 0.0299
diff --git a/src/traffic_weaver/datasets/data/sandvine/marketplace.csv b/src/traffic_weaver/datasets/data/sandvine/marketplace.csv
new file mode 100644
index 0000000..f122212
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/marketplace.csv
@@ -0,0 +1,24 @@
+0, 0.03145
+1, 0.0153
+2, 0.017
+3, 0.0187
+4, 0.02125
+5, 0.0425
+6, 0.06
+7, 0.06
+8, 0.0725
+9, 0.07125
+10, 0.0725
+11, 0.0775
+12, 0.0576
+13, 0.0576
+14, 0.0711
+15, 0.0729
+16, 0.0711
+17, 0.0747
+18, 0.0629
+19, 0.0561
+20, 0.05185
+21, 0.0459
+22, 0.04675
+23, 0.0391
diff --git a/src/traffic_weaver/datasets/data/sandvine/measurements.csv b/src/traffic_weaver/datasets/data/sandvine/measurements.csv
new file mode 100644
index 0000000..4e223a4
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/measurements.csv
@@ -0,0 +1,24 @@
+0, 1.85
+1, 0.9
+2, 1
+3, 1.1
+4, 1.25
+5, 2.5
+6, 2.4
+7, 2.4
+8, 2.9
+9, 2.85
+10, 2.9
+11, 3.1
+12, 3.2
+13, 3.2
+14, 3.95
+15, 4.05
+16, 3.95
+17, 4.15
+18, 3.7
+19, 3.3
+20, 3.05
+21, 2.7
+22, 2.75
+23, 2.3
diff --git a/src/traffic_weaver/datasets/data/sandvine/messaging.csv b/src/traffic_weaver/datasets/data/sandvine/messaging.csv
new file mode 100644
index 0000000..ff66db2
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/messaging.csv
@@ -0,0 +1,24 @@
+0, 0.10545
+1, 0.0513
+2, 0.057
+3, 0.0627
+4, 0.07125
+5, 0.1425
+6, 0.1512
+7, 0.1512
+8, 0.1827
+9, 0.17955
+10, 0.1827
+11, 0.1953
+12, 0.1856
+13, 0.1856
+14, 0.2291
+15, 0.2349
+16, 0.2291
+17, 0.2407
+18, 0.2109
+19, 0.1881
+20, 0.17385
+21, 0.1539
+22, 0.15675
+23, 0.1311
diff --git a/src/traffic_weaver/datasets/data/sandvine/mobile_messaging.csv b/src/traffic_weaver/datasets/data/sandvine/mobile_messaging.csv
new file mode 100644
index 0000000..b458edc
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/mobile_messaging.csv
@@ -0,0 +1,24 @@
+0, 0.75
+1, 0.3
+2, 0.2
+3, 0.25
+4, 0.4
+5, 0.8
+6, 0.95
+7, 1.35
+8, 1.4
+9, 1.6
+10, 1.55
+11, 1.45
+12, 1.7
+13, 2.15
+14, 2.1
+15, 2.15
+16, 2.1
+17, 2.5
+18, 1.6
+19, 1.55
+20, 1.35
+21, 1.3
+22, 1.1
+23, 0.9
diff --git a/src/traffic_weaver/datasets/data/sandvine/mobile_social_media.csv b/src/traffic_weaver/datasets/data/sandvine/mobile_social_media.csv
new file mode 100644
index 0000000..942270c
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/mobile_social_media.csv
@@ -0,0 +1,24 @@
+0, 0.75
+1, 0.2
+2, 0.2
+3, 0.35
+4, 0.3
+5, 1
+6, 1.15
+7, 1.5
+8, 1.5
+9, 1.7
+10, 1.5
+11, 1.8
+12, 1.85
+13, 2.2
+14, 2.35
+15, 2.65
+16, 2.6
+17, 2.6
+18, 2.4
+19, 1.7
+20, 1.5
+21, 1.6
+22, 1.4
+23, 1.3
diff --git a/src/traffic_weaver/datasets/data/sandvine/mobile_video.csv b/src/traffic_weaver/datasets/data/sandvine/mobile_video.csv
new file mode 100644
index 0000000..8dba370
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/mobile_video.csv
@@ -0,0 +1,24 @@
+0, 1.5
+1, 0.85
+2, 0.95
+3, 1
+4, 1
+5, 1
+6, 1.25
+7, 1.3
+8, 1.9
+9, 1.9
+10, 1.7
+11, 1.7
+12, 2
+13, 1.75
+14, 2.05
+15, 2.4
+16, 2.3
+17, 2.55
+18, 2.5
+19, 2.2
+20, 2.1
+21, 2
+22, 2
+23, 1.7
diff --git a/src/traffic_weaver/datasets/data/sandvine/mobile_youtube.csv b/src/traffic_weaver/datasets/data/sandvine/mobile_youtube.csv
new file mode 100644
index 0000000..1c59eef
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/mobile_youtube.csv
@@ -0,0 +1,24 @@
+0, 1.45
+1, 0.75
+2, 0.8
+3, 0.9
+4, 0.85
+5, 0.8
+6, 1.3
+7, 2
+8, 1.95
+9, 1.4
+10, 1.5
+11, 1.55
+12, 1.85
+13, 1.85
+14, 2.3
+15, 2.25
+16, 2.15
+17, 1.9
+18, 2.35
+19, 2.05
+20, 1.8
+21, 1.65
+22, 1.95
+23, 1.55
diff --git a/src/traffic_weaver/datasets/data/sandvine/mobile_zoom.csv b/src/traffic_weaver/datasets/data/sandvine/mobile_zoom.csv
new file mode 100644
index 0000000..28e6903
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/mobile_zoom.csv
@@ -0,0 +1,24 @@
+0, 0.1
+1, 0.02
+2, 0.3
+3, 0.02
+4, 0.05
+5, 0.35
+6, 1.25
+7, 0.3
+8, 0.55
+9, 0.55
+10, 2.05
+11, 2.6
+12, 1.7
+13, 1.25
+14, 0.95
+15, 1.7
+16, 0.5
+17, 1.05
+18, 0.6
+19, 2.1
+20, 1.6
+21, 0.25
+22, 0.25
+23, 0.3
diff --git a/src/traffic_weaver/datasets/data/sandvine/snapchat.csv b/src/traffic_weaver/datasets/data/sandvine/snapchat.csv
new file mode 100644
index 0000000..89deda7
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/snapchat.csv
@@ -0,0 +1,24 @@
+0, 1.1
+1, 0.5
+2, 0.4
+3, 0.4
+4, 0.15
+5, 0.45
+6, 0.4
+7, 0.5
+8, 0.75
+9, 0.65
+10, 0.7
+11, 1.1
+12, 1.3
+13, 2.25
+14, 1.65
+15, 1.6
+16, 2.1
+17, 2.55
+18, 1.7
+19, 1.15
+20, 1.2
+21, 0.95
+22, 0.8
+23, 1.05
diff --git a/src/traffic_weaver/datasets/data/sandvine/social_networking.csv b/src/traffic_weaver/datasets/data/sandvine/social_networking.csv
new file mode 100644
index 0000000..2f4314e
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/social_networking.csv
@@ -0,0 +1,24 @@
+0, 0.481
+1, 0.234
+2, 0.26
+3, 0.286
+4, 0.325
+5, 0.65
+6, 0.696
+7, 0.696
+8, 0.841
+9, 0.8265
+10, 0.841
+11, 0.899
+12, 0.864
+13, 0.864
+14, 1.0665
+15, 1.0935
+16, 1.0665
+17, 1.1205
+18, 0.962
+19, 0.858
+20, 0.793
+21, 0.702
+22, 0.715
+23, 0.598
diff --git a/src/traffic_weaver/datasets/data/sandvine/tiktok.csv b/src/traffic_weaver/datasets/data/sandvine/tiktok.csv
new file mode 100644
index 0000000..1c33873
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/tiktok.csv
@@ -0,0 +1,24 @@
+0, 0.45
+1, 0.4
+2, 0.45
+3, 0.2
+4, 0.3
+5, 0.9
+6, 0.9
+7, 1.3
+8, 1.65
+9, 1.15
+10, 1.6
+11, 1.6
+12, 1.5
+13, 2.1
+14, 1.45
+15, 1.6
+16, 1.45
+17, 2.55
+18, 1.4
+19, 2.45
+20, 1.6
+21, 1
+22, 1.35
+23, 1.05
diff --git a/src/traffic_weaver/datasets/data/sandvine/video_streaming.csv b/src/traffic_weaver/datasets/data/sandvine/video_streaming.csv
new file mode 100644
index 0000000..848ca72
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/video_streaming.csv
@@ -0,0 +1,24 @@
+0, 0.60125
+1, 0.2925
+2, 0.325
+3, 0.3575
+4, 0.40625
+5, 0.8125
+6, 0.4608
+7, 0.4608
+8, 0.5568
+9, 0.5472
+10, 0.5568
+11, 0.5952
+12, 0.9152
+13, 0.9152
+14, 1.1297
+15, 1.1583
+16, 1.1297
+17, 1.1869
+18, 1.2025
+19, 1.0725
+20, 0.99125
+21, 0.8775
+22, 0.89375
+23, 0.7475
diff --git a/src/traffic_weaver/datasets/data/sandvine/vpn_and_security.csv b/src/traffic_weaver/datasets/data/sandvine/vpn_and_security.csv
new file mode 100644
index 0000000..fab6b1a
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/vpn_and_security.csv
@@ -0,0 +1,24 @@
+0, 0.06105
+1, 0.0297
+2, 0.033
+3, 0.0363
+4, 0.04125
+5, 0.0825
+6, 0.0816
+7, 0.0816
+8, 0.0986
+9, 0.0969
+10, 0.0986
+11, 0.1054
+12, 0.1088
+13, 0.1088
+14, 0.1343
+15, 0.1377
+16, 0.1343
+17, 0.1411
+18, 0.1221
+19, 0.1089
+20, 0.10065
+21, 0.0891
+22, 0.09075
+23, 0.0759
diff --git a/src/traffic_weaver/datasets/data/sandvine/web.csv b/src/traffic_weaver/datasets/data/sandvine/web.csv
new file mode 100644
index 0000000..00ad117
--- /dev/null
+++ b/src/traffic_weaver/datasets/data/sandvine/web.csv
@@ -0,0 +1,24 @@
+0, 0.3848
+1, 0.1872
+2, 0.208
+3, 0.2288
+4, 0.26
+5, 0.52
+6, 0.6096
+7, 0.6096
+8, 0.7366
+9, 0.7239
+10, 0.7366
+11, 0.7874
+12, 0.7296
+13, 0.7296
+14, 0.9006
+15, 0.9234
+16, 0.9006
+17, 0.9462
+18, 0.7696
+19, 0.6864
+20, 0.6344
+21, 0.5616
+22, 0.572
+23, 0.4784
diff --git a/src/traffic_weaver/datasets/data_description/__init__.py b/src/traffic_weaver/datasets/data_description/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/traffic_weaver/datasets/data_description/ams_ix.md b/src/traffic_weaver/datasets/data_description/ams_ix.md
new file mode 100644
index 0000000..e876b17
--- /dev/null
+++ b/src/traffic_weaver/datasets/data_description/ams_ix.md
@@ -0,0 +1,52 @@
+ams-ix.net. https://stats.ams-ix.net/index.html (accessed Sep. 21, 2024)
+
+## Files description
+
+* ams-ix - statistics showing total aggregate volume of all parties connected to AMS-IX.
+
+* ams-ix-isp - statistics showing total aggregate volume of all parties connected to the Public Peering (ISP) LAN at AMS-IX.
+
+* ams-ix-grx - statistics showing total aggregate volume of all parties connected to the GRX LAN at AMS-IX.
+
+* ams-ix-i-ipx, ams-ix-i-ipx-diameter - statistics showing total aggregate volume of all parties connected to the AMS-IX I-IPX and I-IPX Diameter LANs.
+
+* ams-ix-nawas-anti-ddos - statistics showing aggregate volume of clean, scrubbed traffic for all parties connected to the NaWas DDoS Scrubbing service at AMS-IX.
+
+## Units
+
+| name | y-axis |
+|---|---|
+| ams-ix-yearly-by-day_2024-09-21 | PByte per day |
+| ams-ix_daily_2024-09-21 | Tbps |
+| ams-ix_monthly_2024-09-21 | Tbps |
+| ams-ix_weekly_2024-09-21 | Tbps |
+| ams-ix_yearly_2024-09-21-input-proc | Tbps |
+| ams-ix-isp-yearly-by-day_2024-09-21 | PByte per day |
+| ams-ix-isp_daily_2024-09-21 | Tbps |
+| ams-ix-isp_monthly_2024-09-21 | Tbps |
+| ams-ix-isp_weekly_2024-09-21 | Tbps |
+| ams-ix-isp_yearly_2024-09-21-input | Tbps |
+| ams-ix-isp_yearly_2024-09-21-output | Tbps |
+| ams-ix-grx-yearly-by-day_2024-09-21 | PByte per day |
+| ams-ix-grx_yearly_2024-09-21-input | Tbps |
+| ams-ix-grx_yearly_2024-09-21-output | Tbps |
+| ams-ix-grx_daily_2024-09-21 | Gbps |
+| ams-ix-grx_monthly_2024-09-21 | Gbps |
+| ams-ix-grx_yearly_2024-09-21-input | Gbps |
+| ams-ix-grx_yearly_2024-09-21-output | Gbps |
+| ams-ix-i-ipx-yearly-by-day_2024-09-21 | TByte per day |
+| ams-ix-i-ipx_daily_2024-09-22 | Gbps |
+| ams-ix-i-ipx_monthly_2024-09-22 | Gbps |
+| ams-ix-i-ipx_weekly_2024-09-22 | Gbps |
+| ams-ix-i-ipx_yearly_2024-09-22-input | Gbps |
+| ams-ix-i-ipx_yearly_2024-09-22-output | Gbps |
+| ams-ix-i-ipx-diameter_daily_2024-09-22 | Mbps |
+| ams-ix-i-ipx-diameter_monthly_2024-09-22 | Mbps |
+| ams-ix-i-ipx-diameter_weekly_2024-09-22 | Mbps |
+| ams-ix-i-ipx-diameter_yearly_2024-09-22-input | Mbps |
+| ams-ix-i-ipx-diameter_yearly_2024-09-22-output | Mbps |
+| ams-ix-nawas-anti-ddos_daily_2024-09-22 | Gbps |
+| ams-ix-nawas-anti-ddos_monthly_2024-09-22 | Gbps |
+| ams-ix-nawas-anti-ddos_weekly_2024-09-22 | Gbps |
+| ams-ix-nawas-anti-ddos_yearly_2024-09-22-input | Gbps |
+| ams-ix-nawas-anti-ddos_yearly_2024-09-22-output | Gbps |
diff --git a/src/traffic_weaver/datasets/data_description/mix_it.md b/src/traffic_weaver/datasets/data_description/mix_it.md
new file mode 100644
index 0000000..1f5c484
--- /dev/null
+++ b/src/traffic_weaver/datasets/data_description/mix_it.md
@@ -0,0 +1,26 @@
+mix-it.net. https://www.mix-it.net/en/stats/ (accessed Sep. 4, 2024)
+
+## Files description
+
+* mix-it-bologna - statistics showing MIX Bologna traffic.
+
+* mix-it-milan - statistics showing MIX Milan traffic.
+
+* mix-it-palermo - statistics showing MIX Palermo traffic.
+
+## Units
+
+| name | y-axis |
+|---|---|
+| mix-it-bologna_daily_2024-09_04 | Gbps |
+| mix-it-bologna_monthly_2024-09_04 | Gbps |
+| mix-it-bologna_weekly_2024-09_04 | Gbps |
+| mix-it-bologna_yearly_2024-09_04 | Gbps |
+| mix-it-milan_daily_2024-09_04 | Tbps |
+| mix-it-milan_monthly_2024-09_04 | Tbps |
+| mix-it-milan_weekly_2024-09_04 | Tbps |
+| mix-it-milan_yearly_2024-09_04 | Tbps |
+| mix-it-palermo_daily_2024-09_04 | Gbps |
+| mix-it-palermo_monthly_2024-09_04 | Gbps |
+| mix-it-palermo_weekly_2024-09_04 | Gbps |
+| mix-it-palermo_yearly_2024-09_04 | Gbps |
diff --git a/src/traffic_weaver/datasets/data_description/sandvine.md b/src/traffic_weaver/datasets/data_description/sandvine.md
new file mode 100644
index 0000000..7753117
--- /dev/null
+++ b/src/traffic_weaver/datasets/data_description/sandvine.md
@@ -0,0 +1,2 @@
+Sandvine, "The Mobile Internet Phenomena Report", Sandvine, 2021
+https://www.sandvine.com/download-mobile-internet-phenomena-report-2021
diff --git a/src/traffic_weaver/datasets/example_datasets.json b/src/traffic_weaver/datasets/example_datasets.json
deleted file mode 100644
index fc68581..0000000
--- a/src/traffic_weaver/datasets/example_datasets.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "mobile_video": "[1.5,0.85,0.95,1,1,1,1.25,1.3,1.9,1.9,1.7,1.7,2,1.75,2.05,2.4,2.3,2.55,2.5,2.2,2.1,2,2,1.7]",
- "mobile_youtube": "[1.45,0.75,0.8,0.9,0.85,0.8,1.3,2,1.95,1.4,1.5,1.55,1.85,1.85,2.3,2.25,2.15,1.9,2.35,2.05,1.8,1.65,1.95,1.55]",
- "mobile_social_media": "[0.75,0.2,0.2,0.35,0.3,1,1.15,1.5,1.5,1.7,1.5,1.8,1.85,2.2,2.35,2.65,2.6,2.6,2.4,1.7,1.5,1.6,1.4,1.3]",
- "fixed_social_media": "[1.9,1.45,1.25,1.05,1.1,1.2,1.4,1.75,2.15,2.45,2.55,2.5,2.65,2.65,2.6,2.5,2.45,2.4,2.45,2.45,2.65,2.7,2.6,2.3]",
- "tiktok": "[0.45,0.4,0.45,0.2,0.3,0.9,0.9,1.3,1.65,1.15,1.6,1.6,1.5,2.1,1.45,1.6,1.45,2.55,1.4,2.45,1.6,1,1.35,1.05]",
- "snapchat": "[1.1,0.5,0.4,0.4,0.15,0.45,0.4,0.5,0.75,0.65,0.7,1.1,1.3,2.25,1.65,1.6,2.1,2.55,1.7,1.15,1.2,0.95,0.8,1.05]",
- "mobile_messaging": "[0.75,0.3,0.2,0.25,0.4,0.8,0.95,1.35,1.4,1.6,1.55,1.45,1.7,2.15,2.1,2.15,2.1,2.5,1.6,1.55,1.35,1.3,1.1,0.9]",
- "mobile_zoom": "[0.1,0.02,0.3,0.02,0.05,0.35,1.25,0.3,0.55,0.55,2.05,2.6,1.7,1.25,0.95,1.7,0.5,1.05,0.6,2.1,1.6,0.25,0.25,0.3]",
- "measurements": "[1.85,0.9,1,1.1,1.25,2.5,2.4,2.4,2.9,2.85,2.9,3.1,3.2,3.2,3.95,4.05,3.95,4.15,3.7,3.3,3.05,2.7,2.75,2.3]",
- "social_networking": "[0.481,0.234,0.26,0.286,0.325,0.65,0.696,0.696,0.841,0.8265,0.841,0.899,0.864,0.864,1.0665,1.0935,1.0665,1.1205,0.962,0.858,0.793,0.702,0.715,0.598]",
- "web": "[0.3848,0.1872,0.208,0.2288,0.26,0.52,0.6096,0.6096,0.7366,0.7239,0.7366,0.7874,0.7296,0.7296,0.9006,0.9234,0.9006,0.9462,0.7696,0.6864,0.6344,0.5616,0.572,0.4784]",
- "video_streaming": "[0.60125,0.2925,0.325,0.3575,0.40625,0.8125,0.4608,0.4608,0.5568,0.5472,0.5568,0.5952,0.9152,0.9152,1.1297,1.1583,1.1297,1.1869,1.2025,1.0725,0.99125,0.8775,0.89375,0.7475]",
- "cloud": "[0.10175,0.0495,0.055,0.0605,0.06875,0.1375,0.2208,0.2208,0.2668,0.2622,0.2668,0.2852,0.1792,0.1792,0.2212,0.2268,0.2212,0.2324,0.2035,0.1815,0.16775,0.1485,0.15125,0.1265]",
- "messaging": "[0.10545,0.0513,0.057,0.0627,0.07125,0.1425,0.1512,0.1512,0.1827,0.17955,0.1827,0.1953,0.1856,0.1856,0.2291,0.2349,0.2291,0.2407,0.2109,0.1881,0.17385,0.1539,0.15675,0.1311]",
- "audio": "[0.0555,0.027,0.03,0.033,0.0375,0.075,0.1032,0.1032,0.1247,0.12255,0.1247,0.1333,0.096,0.096,0.1185,0.1215,0.1185,0.1245,0.111,0.099,0.0915,0.081,0.0825,0.069]",
- "vpn_and_security": "[0.06105,0.0297,0.033,0.0363,0.04125,0.0825,0.0816,0.0816,0.0986,0.0969,0.0986,0.1054,0.1088,0.1088,0.1343,0.1377,0.1343,0.1411,0.1221,0.1089,0.10065,0.0891,0.09075,0.0759]",
- "marketplace": "[0.03145,0.0153,0.017,0.0187,0.02125,0.0425,0.06,0.06,0.0725,0.07125,0.0725,0.0775,0.0576,0.0576,0.0711,0.0729,0.0711,0.0747,0.0629,0.0561,0.05185,0.0459,0.04675,0.0391]",
- "file_sharing": "[0.0037,0.0018,0.002,0.0022,0.0025,0.005,0.0048,0.0048,0.0058,0.0057,0.0058,0.0062,0.0096,0.0096,0.01185,0.01215,0.01185,0.01245,0.0074,0.0066,0.0061,0.0054,0.0055,0.0046]",
- "gaming": "[0.02405,0.0117,0.013,0.0143,0.01625,0.0325,0.012,0.012,0.0145,0.01425,0.0145,0.0155,0.0544,0.0544,0.06715,0.06885,0.06715,0.07055,0.0481,0.0429,0.03965,0.0351,0.03575,0.0299]"
-}
diff --git a/src/traffic_weaver/interval.py b/src/traffic_weaver/interval.py
index 91b3f23..deecb18 100644
--- a/src/traffic_weaver/interval.py
+++ b/src/traffic_weaver/interval.py
@@ -7,7 +7,7 @@
import numpy as np
-from .array_utils import (
+from .sorted_array_utils import (
oversample_linspace,
oversample_piecewise_constant,
extend_linspace,
@@ -46,7 +46,7 @@ def __init__(self, a: Union[np.ndarray, List], n: int = 1):
>>> print(a[1])
1
>>> a[1, 2] = 15
- >>> a[1, 2]
+ >>> a[1, 2].item()
15
"""
diff --git a/src/traffic_weaver/match.py b/src/traffic_weaver/match.py
index f9b1e61..31638ed 100644
--- a/src/traffic_weaver/match.py
+++ b/src/traffic_weaver/match.py
@@ -1,25 +1,38 @@
+import warnings
+
import numpy as np
-from .array_utils import left_piecewise_integral
from .process import spline_smooth
+from .sorted_array_utils import find_closest_element_indices_to_values, integral, sum_over_indices
-def integral_matching_reference_stretch(x, y, x_ref, y_ref, alpha=1.0, s=None):
+def integral_matching_reference_stretch(x, y, x_ref, y_ref, fixed_points_in_x=None, fixed_points_indices_in_x=None,
+ fixed_points_finding_strategy: str = 'closest',
+ target_function_integral_method: str = 'trapezoid',
+ reference_function_integral_method: str = 'trapezoid', alpha=1.0, s=None):
"""Stretch function to match integrals in reference.
- Stretch function to match integrals piecewise constant function over the
- same domain.
+ Stretch function evaluated in '(x, y)' to match integral
+ of reference function '(x_ref, y_ref)' over the same range.
+ Fixed points can specify points of the target function which should not be moved.
+ By default, they match the points in 'x' that are closest to the points in the 'x_ref'.
+ The method may work not as expected when number of fixed points is higher than
+ the half of the number of points in the target function 'x'. If number of points in 'x_ref'
+ is higher than the number of points in 'x', provide 'fixed_points_in_x' or 'fixed_points_indices_in_x'
+ in order to allow the method to work correctly.
- .. image:: /_static/gfx/integral_matching_reference.pdf
+ 'target_function_integral_method' specifies the method to calculate integral of '(x, y)' function.
+ 'reference_function_integral_method' specifies the method to calculate integral of '(x_ref, y_ref)' function.
- Reference function is piecewise linear function that can contain only a
- subset of points in original function `x`.
- The target function is stretched according to the integral values and
- intervals provided by the reference function.
+ Use 'rectangle' integral calculation if target/reference functions are representing averaged value over the
+ interval.
+ Use 'trapezoidal' integral calculation if target/reference functions are representing sampled values over time.
+
+ .. image:: /_static/gfx/integral_matching_reference.pdf
Parameters
----------
- x: 1-D array-like of size n
+ x: 1-D array-like of size n
Independent variable in strictly increasing order.
y: 1-D array-like of size n
Dependent variable.
@@ -29,6 +42,24 @@ def integral_matching_reference_stretch(x, y, x_ref, y_ref, alpha=1.0, s=None):
of points of `x`.
y_ref: 1-D array-like of size m
Dependent variable of reference function.
+ fixed_points_in_x: array-like, optional
+ Points that should not be moved.
+ By default, they are the closest points in `x` to the points in `x_ref`.
+ fixed_points_indices_in_x: array-like, optional
+ Indices of points in `x` that should not be moved.
+ If set, fixed_points_in_x is set according to those points.
+ fixed_points_finding_strategy: str, default: 'closest'
+ Strategy to find fixed points, if fixed points are not specified.
+ Available options:
+ 'closest': closest element (lower or higher)
+ 'lower': closest lower or equal element
+ 'higher': closest higher or equal element
+ target_function_integral_method: str, default: 'trapezoid'
+ Method to calculate integral of target function.
+ Available options: 'trapezoid', 'rectangle'
+ reference_function_integral_method: str, default: 'trapezoid'
+ Method to calculate integral of reference function.
+ Available options: 'trapezoid', 'rectangle'
alpha: scalar, default: 1
Stretching exponent factor.
Scales how points are stretched if they are closer to the center point.
@@ -39,11 +70,6 @@ def integral_matching_reference_stretch(x, y, x_ref, y_ref, alpha=1.0, s=None):
A smoothing condition for spine smoothing.
If None, no smoothing is applied.
- Raises
- ------
- ValueError
- if `x_ref` contains some point that are not present in `x`.
-
Returns
-------
ndarray
@@ -59,26 +85,56 @@ def integral_matching_reference_stretch(x, y, x_ref, y_ref, alpha=1.0, s=None):
>>> y_ref = np.array([2.5, 2.5, 4, 3.5])
>>> x_ref = np.array([0, 1, 2, 3])
>>> integral_matching_reference_stretch(x, y, x_ref, y_ref, s=0.0)
- array([1. , 3.5, 2. , 2.5, 3. , 4.5, 4. ])
-
+ array([1. , 3.5, 2. , 4. , 3. , 4. , 4. ])
"""
- interval_point_indices = np.where(np.in1d(x, x_ref))[
- 0
- ] # get indices of elements in x array that are in x_ref array
- if len(interval_point_indices) != len(x_ref):
- raise ValueError("`x_ref` contains some points that are not in the `x`")
- integral_values = left_piecewise_integral(x_ref, y_ref)
- res_y = interval_integral_matching_stretch(
- x,
- y,
- integral_values=integral_values,
- interval_point_indices=interval_point_indices,
- alpha=alpha,
- )
+ x, y, x_ref, y_ref = np.asarray(x), np.asarray(y), np.asarray(x_ref), np.asarray(y_ref)
+
+ if fixed_points_in_x is not None:
+ if len(fixed_points_in_x) > len(x):
+ raise ValueError("Size of 'fixed_points_in_x' cannot be larger than the number of points in 'x'")
+ if fixed_points_indices_in_x is not None:
+ if len(fixed_points_indices_in_x) > len(x):
+ raise ValueError("Size of 'fixed_points_indices_in_x' cannot be larger than the number of points in 'x'")
+
+ if fixed_points_indices_in_x is not None:
+ fixed_points_indices_in_x = np.unique(fixed_points_indices_in_x)
+ fixed_points_in_x = x.take(fixed_points_indices_in_x)
+ fixed_points_in_x_ref = x_ref.take(
+ find_closest_element_indices_to_values(x_ref, fixed_points_in_x, strategy='closest'))
+ fixed_points_in_x_ref_indices = np.where(np.in1d(x_ref, fixed_points_in_x_ref))[0]
+ else:
+ if fixed_points_in_x is None:
+ fixed_points_in_x = x.take(
+ find_closest_element_indices_to_values(x, x_ref, strategy=fixed_points_finding_strategy))
+ fixed_points_in_x = np.unique(fixed_points_in_x)
+ fixed_points_in_x_ref_indices = np.arange(len(x_ref))
+ else:
+ fixed_points_in_x = np.asarray(fixed_points_in_x)
+ fixed_points_in_x = np.unique(fixed_points_in_x)
+ fixed_points_in_x_ref = x_ref.take(
+ find_closest_element_indices_to_values(x_ref, fixed_points_in_x, strategy='closest'))
+ fixed_points_in_x_ref_indices = np.where(np.in1d(x_ref, fixed_points_in_x_ref))[0]
+
+ # get indices of elements in x that should be fixed
+ fixed_points_indices_in_x = np.where(np.in1d(x, fixed_points_in_x))[0]
+
+ if len(fixed_points_in_x) >= (len(x) + 1) / 2:
+ warnings.warn("Integral matching may not work as expected when number of fixed points is higher than"
+ " the half of the number of points in the target function 'x'")
+
+ if len(fixed_points_indices_in_x) != len(fixed_points_in_x):
+ raise ValueError("some of the fixed points are not in the `x`")
+
+ integral_values = integral(x_ref, y_ref, reference_function_integral_method)
+ integral_values = sum_over_indices(integral_values, fixed_points_in_x_ref_indices)
+
+ res_y = _interval_integral_matching_stretch(x, y, integral_values=integral_values,
+ integral_method=target_function_integral_method,
+ fixed_points_indices_in_x=fixed_points_indices_in_x, alpha=alpha)
return res_y if s is None else spline_smooth(x, res_y, s)(x)
-def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None):
+def _integral_matching_stretch(x, y, integral_value=0, integral_method='trapezoid', dx=1.0, alpha=1.0, s=None):
r"""Stretches function y=f(x) to match integral value.
.. image:: /_static/gfx/integral_matching.pdf
@@ -86,8 +142,8 @@ def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None)
This method creates function :math:`z=g(x)` from :math:`y=f(x)` such that the
integral of :math:`g(x)` is equal to the provided integral value, and points
are transformed inversely proportionally to the distance from the
- function domain center. Function integral is numerically approximated using
- trapezoidal rule on provided points.
+ function domain center. Function integral is numerically approximated on provided points
+ according to the 'integral_method'.
Parameters
----------
@@ -98,6 +154,9 @@ def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None)
Dependent variable.
integral_value: float, default: 0
Target integral value.
+ integral_method: str, default='trapezoid'
+ Method to calculate integral of target function.
+ Available options: 'trapezoid', 'rectangle'
dx : scalar, optional
The spacing between sample points when `x` is None. By default, it is 1.
alpha: scalar, default: 1
@@ -143,7 +202,7 @@ def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None)
Each point is :math:`z_i` is shifted by :math:`\hat{y} \cdot w_i`
where :math:`\hat{y}` is a shift scaling factor to match desired integral value.
- Difference in integral between target function and current function
+ In trapezoid integral, difference in integral between target function and current function
is calculated as:
.. math::
@@ -155,6 +214,17 @@ def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None)
.. math::
\hat{y} = 2 \Delta P / \sum_{i=1}^N \left[(w_{i-1} + w_i) \Delta x_i \right]
+ In rectangle integral, difference in integral between target function and current function
+ is calculated as:
+
+ .. math::
+ \Delta P = \sum_{i=0}^{N-1} \left[w_i \hat{y} \cdot \Delta x_i\right]
+
+ Shift scaling factor :math:`\hat{y}` can be calculated as:
+
+ .. math::
+ \hat{y} = \Delta P / \sum_{i=0}^{N-1} \left[w_i \Delta x_i\right]
+
Next, if :math:`s` is given, created function is estimated with
spline function :math:`h(x)` which satisfies:
@@ -168,31 +238,43 @@ def integral_matching_stretch(x, y, integral_value=0, dx=1.0, alpha=1.0, s=None)
else:
x = np.array(x)
- current_integral = np.trapz(y, x)
+ if integral_method not in ['trapezoid', 'rectangle']:
+ raise ValueError("Unknown integral method")
+
+ current_integral = integral(x, y, method=integral_method).sum()
+
delta_p = integral_value - current_integral
x_n2 = (x[-1] + x[0]) / 2
delta_x = x[-1] - x[0]
delta_xi = np.diff(x)
- w = 1 - (2 * np.abs(x_n2 - x) / delta_x) ** alpha
- y_hat = 2 * delta_p / np.sum((w[1:] + w[:-1]) * delta_xi)
+ # if there are only two points - set weights to 1
+ if len(x) == 2:
+ w = np.array([1., 1.])
+ else:
+ w = 1 - (2 * np.abs(x_n2 - x) / delta_x) ** alpha
+
+ y_hat = 0
+ if integral_method == 'trapezoid':
+ y_hat = 2 * delta_p / np.sum((w[1:] + w[:-1]) * delta_xi)
+ elif integral_method == 'rectangle':
+ y_hat = delta_p / np.sum(w[:-1] * delta_xi)
res_y = y + y_hat * w
return res_y if s is None else spline_smooth(x, res_y, s)(x)
-def interval_integral_matching_stretch(
- x, y, dx=1.0, integral_values=None, interval_point_indices=None, alpha=1.0, s=None
-):
+def _interval_integral_matching_stretch(x, y, dx=1.0, integral_values=None, fixed_points_indices_in_x=None,
+ integral_method='trapezoid', alpha=1.0, s=None):
r"""Stretches function y=f(x) to match integral value in given intervals.
This method creates function :math:`z=g(x)` from :math:`y=f(x)` such that the
integral of :math:`g(x)` is equal to the provided corresponding `integral_values`
- in intervals given by `interval_point_indices`.
+ in intervals given by `fixed_points_indices_in_x`.
In each interval, points are transformed inversely proportionally to the distance
- from the interval center. Function integral is numerically approximated using
- trapezoidal rule on provided points.
+ from the interval center. Function integral is numerically approximated on provided points
+ according to the 'integral_method'.
Each period stretch is delegated to `integral_matching_stretch`.
@@ -209,12 +291,15 @@ def interval_integral_matching_stretch(
Target integral values.
By default, it is `[0] * (len(interval_points_indices) - 1)`.
If `interval_point_indices` are not specified, `ValueError` is raised.
- interval_point_indices: list[int] | ndarray, optional
+ fixed_points_indices_in_x: list[int] | ndarray, optional
Indices in `x` array specifying intervals over which function is
stretched to match corresponding `integral_values`.
By default, it is evenly spaced, i.e.,
`[0, len(y) / n, 2 len(y) / n, ..., len(y)]`.
If `integral_values` are not specified` `ValueError` is raised.
+ integral_method: str, default='trapezoid'
+ Method to calculate integral of target function.
+ Available options: 'trapezoid', 'rectangle'
alpha: scalar, default: 1
Stretching exponent factor.
Scales how points are stretched if they are closer to the center point.
@@ -237,24 +322,18 @@ def interval_integral_matching_stretch(
else:
x = np.asarray(x)
- if integral_values is None and interval_point_indices is None:
- raise ValueError(
- "integral_values and interval_points cannot be None at the same time"
- )
+ if integral_values is None and fixed_points_indices_in_x is None:
+ raise ValueError("integral_values and fixed_points_indices_in_x cannot be None at the same time")
if integral_values is None:
- integral_values = [0] * (len(interval_point_indices) - 1)
- if interval_point_indices is None:
- interval_point_indices = np.arange(
- 0, len(y) + 1, int(len(y) / len(integral_values))
- )
+ integral_values = [0] * (len(fixed_points_indices_in_x) - 1)
+ if fixed_points_indices_in_x is None:
+ fixed_points_indices_in_x = np.arange(0, len(y) + 1, int(len(y) / len(integral_values)))
y = np.array(y, dtype=float)
- for integral_value, start, end in zip(
- integral_values, interval_point_indices[:-1], interval_point_indices[1:]
- ):
+ for integral_value, start, end in zip(integral_values, fixed_points_indices_in_x[:-1],
+ fixed_points_indices_in_x[1:]):
end = end + 1
- y[start:end] = integral_matching_stretch(
- x[start:end], y[start:end], integral_value=integral_value, alpha=alpha
- )
+ y[start:end] = _integral_matching_stretch(x[start:end], y[start:end], integral_value=integral_value,
+ integral_method=integral_method, alpha=alpha)
return y if s is None else spline_smooth(x, y, s)(x)
diff --git a/src/traffic_weaver/process.py b/src/traffic_weaver/process.py
index 0133232..b275104 100644
--- a/src/traffic_weaver/process.py
+++ b/src/traffic_weaver/process.py
@@ -2,9 +2,89 @@
from typing import Callable, Tuple, Union, List
import numpy as np
-from scipy.interpolate import BSpline, splrep
+from scipy.interpolate import BSpline, splrep, CubicSpline
from traffic_weaver.interval import IntervalArray
+from traffic_weaver.sorted_array_utils import find_closest_lower_equal_element_indices_to_values
+
+
+def _piecewise_constant_interpolate(x, y, new_x, left=None):
+ """Piecewise constant filling for monotonically increasing sample points.
+
+ Returns the one-dimensional piecewise constant array with given discrete data points (x, y), evaluated at new_x.
+
+ Parameters
+ ----------
+ x: np.ndarray
+ The x-coordinates of the data points, must be increasing.
+ y: np.ndarray
+ The y-coordinates of the data points, same length as x.
+ new_x: np.ndarray
+ The x-coordinates at which to evaluate the interpolated values.
+ left: float, optional
+ Value to return for new_x < x[0], default is y[0].
+
+
+ Returns
+ -------
+ The interpolated values, same shape as new_x.
+ """
+ x = np.asarray(x)
+ y = np.asarray(y)
+ new_x = np.asarray(new_x)
+
+ new_y = np.zeros(len(new_x))
+
+ indices = find_closest_lower_equal_element_indices_to_values(x, new_x)
+
+ greater_equal_than_first_value_mask = new_x >= x[0]
+ lower_than_first_value_mask = new_x < x[0]
+
+ new_y[greater_equal_than_first_value_mask] = y[indices[greater_equal_than_first_value_mask]]
+ new_y[lower_than_first_value_mask] = left if left is not None else y[0]
+ return new_y
+
+
+def interpolate(x, y, new_x, method='linear', **kwargs):
+ """Interpolate function over new set of points.
+
+ Supports linear, cubic and spline interpolation.
+
+ Parameters
+ ----------
+ x: array-like
+ The x-coordinates of the data points, must be increasing.
+ y: array-like
+ The y-coordinates of the data points, same length as x.
+ new_x: array-like
+ New x-coordinates at which to evaluate the interpolated values.
+ method: str, default='linear'
+ Interpolation method. Supported methods are 'linear', 'constant', 'cubic' and
+ 'spline'.
+ kwargs: dict
+ Additional keyword arguments passed to the interpolation function.
+ For more details, see kwargs of numpy and scipy interpolation functions.
+
+ Returns
+ -------
+
+ See Also
+ --------
+ `https://numpy.org/doc/stable/reference/generated/numpy.interp.html
+ `_
+ `https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html
+ `_
+ `https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.splrep.html#scipy.interpolate.splrep
+ `_
+ """
+ if method == 'linear':
+ return np.interp(new_x, x, y, **kwargs)
+ if method == 'constant':
+ return _piecewise_constant_interpolate(x, y, new_x, **kwargs)
+ elif method == 'cubic':
+ return CubicSpline(x, y, **kwargs)(new_x)
+ elif method == 'spline':
+ return BSpline(*splrep(x, y, **kwargs))(new_x)
def repeat(x, y, repeats: int) -> tuple[np.ndarray, np.ndarray]:
diff --git a/src/traffic_weaver/oversample.py b/src/traffic_weaver/rfa.py
similarity index 86%
rename from src/traffic_weaver/oversample.py
rename to src/traffic_weaver/rfa.py
index 1c953fe..03cf607 100644
--- a/src/traffic_weaver/oversample.py
+++ b/src/traffic_weaver/rfa.py
@@ -1,4 +1,4 @@
-r"""Oversample average function with given recreation strategy.
+r"""Recreate from average function with given strategy.
"""
@@ -9,16 +9,13 @@
from .funfit import lin_fit, lin_exp_xy_fit, exp_lin_fit
from .interval import IntervalArray
-from .array_utils import (
- oversample_linspace,
- oversample_piecewise_constant,
-)
+from .sorted_array_utils import (oversample_linspace, oversample_piecewise_constant, )
-class AbstractOversample(ABC):
- r"""Abstract class for oversampling `n` times. a function `y` measured in `x`.
+class AbstractRFA(ABC):
+ r"""Abstract class for recreating from average (RFA) a function `y` measured in `x`.
- To perform oversampling, call `oversample()` method, which returns newly created
+ To perform recreate from average, call `rfa()` method, which returns newly created
points.
By default, `x` axis is oversampled with `n` linearly space values between each
@@ -27,7 +24,7 @@ class AbstractOversample(ABC):
By default, `y` axis is oversampled with `n` piecewise constant values between each
point. To change this behaviour, override `_initial_y_oversample` method.
- The `oversample()` method returns oversampled values for x and y.
+ The `rfa()` method returns recreated function values for x and y.
In the following derived classes, term 'interval' refers to the distance
between two consecutive observations (values in `y` based on `x`).
@@ -58,8 +55,8 @@ def __init__(self, x, y, n, **kwargs):
raise ValueError("n cannot be lower than 2.")
@abstractmethod
- def oversample(self):
- """Perform oversample on `x` and `y` variables.
+ def rfa(self):
+ """Recreate function from average on `x` and `y` variables.
Returns
-------
@@ -72,10 +69,7 @@ def oversample(self):
def _initial_oversample(self):
r"""Returns initially oversampled tuple"""
- return (
- self._initial_x_oversample(),
- self._initial_y_oversample(),
- )
+ return (self._initial_x_oversample(), self._initial_y_oversample(),)
def _initial_x_oversample(self):
r"""Initial oversample of `x` axis with linearly spaced function.
@@ -96,15 +90,15 @@ def _initial_y_oversample(self):
return oversample_piecewise_constant(self.y, num=self.n)
-class PiecewiseConstantOversample(AbstractOversample):
- r"""Oversample function using piecewise constant values."""
+class PiecewiseConstantRFA(AbstractRFA):
+ r"""Recreate function using piecewise constant values."""
- def oversample(self):
+ def rfa(self):
return self._initial_oversample()
-class FunctionOversample(AbstractOversample):
- r"""Oversample using created sampling function.
+class FunctionRFA(AbstractRFA):
+ r"""Recreate function using created sampling function.
Created sampling function should take `x` as an argument and return corresponding
`y` value.
@@ -129,23 +123,13 @@ class FunctionOversample(AbstractOversample):
Kwargs passed to `sampling_function_supplier`.
"""
- def __init__(
- self,
- x,
- y,
- n,
- sampling_function_supplier=None,
- sampling_function_supplier_kwargs=None,
- ):
+ def __init__(self, x, y, n, sampling_function_supplier=None, sampling_function_supplier_kwargs=None, ):
super().__init__(x, y, n)
self.sampling_function_supplier = sampling_function_supplier
self.sampling_function_supplier_kwargs = (
- sampling_function_supplier_kwargs
- if sampling_function_supplier_kwargs is not None
- else {}
- )
+ sampling_function_supplier_kwargs if sampling_function_supplier_kwargs is not None else {})
- def oversample(self):
+ def rfa(self):
function = self._get_sampling_function()
xs, ys = self._initial_oversample()
ys = [function(x) for x in xs]
@@ -158,32 +142,24 @@ def _get_sampling_function(self):
its corresponding dependent value f(x).
"""
if self.sampling_function_supplier:
- return self.sampling_function_supplier(
- self.x, self.y, **self.sampling_function_supplier_kwargs
- )
+ return self.sampling_function_supplier(self.x, self.y, **self.sampling_function_supplier_kwargs)
else:
raise ValueError("Sampling function not specified")
-class CubicSplineOversample(FunctionOversample):
- r"""Oversample function using cubic spline between given points."""
+class CubicSplineRFA(FunctionRFA):
+ r"""Recreate function using cubic spline between given points."""
- def __init__(
- self,
- x,
- y,
- n,
- sampling_function_supplier=lambda x, y: CubicSpline(x, y),
- ):
+ def __init__(self, x, y, n, sampling_function_supplier=lambda x, y: CubicSpline(x, y), ):
super().__init__(x, y, n, sampling_function_supplier)
-class IntervalOversample(AbstractOversample):
- r"""Abstraction for interval based oversampling classes."""
+class IntervalRFA(AbstractRFA):
+ r"""Abstraction for interval based function recreate from average classes."""
pass
-class LinearFixedOversample(AbstractOversample):
+class LinearFixedRFA(AbstractRFA):
r"""Linearly moves between points in fixed transition intervals.
.. image:: /_static/gfx/linear_fixed_oversample.pdf
@@ -271,7 +247,7 @@ def __init__(self, x, y, n, alpha=1.0, a=None):
self.a_l = int(self.a / 2)
self.a_r = self.a_l
- def oversample(self):
+ def rfa(self):
x, y = self._initial_oversample()
n = self.n
a_r = self.a_r
@@ -294,27 +270,17 @@ def oversample(self):
for k in range(1, x.nr_of_full_intervals() - 1):
# calculate transition points
y_0 = y[k, 0]
- z_0 = lin_fit(
- x[k, 0],
- (x[k, -a_r], y[k - 1, 0]),
- (x[k, a_l], y[k, 0]),
- )
- z_1 = lin_fit(
- x[k + 1, 0],
- (x[k, n - a_r], y_0),
- (x[k + 1, a_l], y[k + 1, 0]),
- )
+ z_0 = lin_fit(x[k, 0], (x[k, -a_r], y[k - 1, 0]), (x[k, a_l], y[k, 0]), )
+ z_1 = lin_fit(x[k + 1, 0], (x[k, n - a_r], y_0), (x[k + 1, a_l], y[k + 1, 0]), )
for i in range(0, self.a_l):
z[k, i] = lin_fit(x[k, i], (x[k, 0], z_0), (x[k, self.a_l], y_0))
for i in range(n - self.a_r + 1, n + 1):
- z[k, i] = lin_fit(
- x[k, i], (x[k, self.n - self.a_r], y_0), (x[k, self.n], z_1)
- )
+ z[k, i] = lin_fit(x[k, i], (x[k, self.n - self.a_r], y_0), (x[k, self.n], z_1))
return x.array[n:-n], z.array[n:-n]
-class LinearAdaptiveOversample(AbstractOversample):
+class LinearAdaptiveRFA(AbstractRFA):
r"""Linearly moves between points in adaptive transition intervals.
.. image:: /_static/gfx/linear_adaptive_oversample.pdf
@@ -478,7 +444,7 @@ def get_adaptive_transition_points(x, y, a, adaptive_smooth):
gammas.append(None)
else:
gamma = nom / denom
- gamma = gamma**adaptive_smooth
+ gamma = gamma ** adaptive_smooth
a_l = gamma * a / (1 + gamma)
a_r = a / (1 + gamma)
a_l = int(min(max(a_l, 1), a))
@@ -493,7 +459,7 @@ def get_adaptive_transition_points(x, y, a, adaptive_smooth):
gammas.extend([None])
return a_ls, a_rs, gammas
- def oversample(self):
+ def rfa(self):
x, y = self._initial_oversample()
n = self.n
@@ -510,9 +476,7 @@ def oversample(self):
y.extend_constant(direction='both')
z.extend_constant(direction='both')
- a_ls, a_rs, gammas = self.get_adaptive_transition_points(
- x, y, self.a, self.adaptive_smooth
- )
+ a_ls, a_rs, gammas = self.get_adaptive_transition_points(x, y, self.a, self.adaptive_smooth)
for k in range(1, x.nr_of_full_intervals() - 1):
y_0 = y[k, 0]
@@ -521,17 +485,11 @@ def oversample(self):
if a_rs[k - 1] == 0 and a_ls[k] == 0: # no left transition window
z_0 = y[k - 1, 0]
else:
- z_0 = lin_fit(
- x[k, 0], (x[k, -a_rs[k - 1]], y[k - 1, 0]), (x[k, a_ls[k]], y[k, 0])
- )
+ z_0 = lin_fit(x[k, 0], (x[k, -a_rs[k - 1]], y[k - 1, 0]), (x[k, a_ls[k]], y[k, 0]))
if a_rs[k] == 0 and a_ls[k + 1] == 0: # no right transition window
z_1 = y[k + 1]
else:
- z_1 = lin_fit(
- x[k + 1, 0],
- (x[k, n - a_rs[k]], y_0),
- (x[k + 1, a_ls[k + 1]], y[k + 1, 0]),
- )
+ z_1 = lin_fit(x[k + 1, 0], (x[k, n - a_rs[k]], y_0), (x[k + 1, a_ls[k + 1]], y[k + 1, 0]), )
# fit remaining points
for i in range(0, a_ls[k]):
@@ -542,7 +500,7 @@ def oversample(self):
return x.array[n:-n], z.array[n:-n]
-class ExpFixedOversample(AbstractOversample):
+class ExpFixedRFA(AbstractRFA):
r"""Moves between points in fixed transition intervals by combining linear and
exponential function.
@@ -665,7 +623,7 @@ def __init__(self, x, y, n, alpha=1.0, beta=0.5, a=None, exp=2.0):
self.b = int(beta * self.a_l)
self.exp = exp
- def oversample(self):
+ def rfa(self):
x, y = self._initial_oversample()
n = self.n
a_r = self.a_r
@@ -691,9 +649,7 @@ def oversample(self):
# calculate transition points
y_0 = y[k, 0]
z_0 = lin_fit(x[k, 0], (x[k, -a_r], y[k - 1, 0]), (x[k, a_l], y[k, 0]))
- z_1 = lin_fit(
- x[k + 1, 0], (x[k, n - a_r], y_0), (x[k + 1, a_l], y[k + 1, 0])
- )
+ z_1 = lin_fit(x[k + 1, 0], (x[k, n - a_r], y_0), (x[k + 1, a_l], y[k + 1, 0]))
z_0_lb = lin_fit(x[k, 0 + b], (x[k, 0], z_0), (x[k, a_l], y[k, 0]))
@@ -703,27 +659,17 @@ def oversample(self):
for i in range(0, b):
z[k, i] = lin_fit(x[k, i], (x[k, 0], z_0), (x[k, b], z_0_lb))
for i in range(b, a_l):
- z[k, i] = lin_exp_xy_fit(
- x[k, i],
- (x[k, b], z_0_lb),
- (x[k, a_l], y_0),
- alpha=exp,
- )
+ z[k, i] = lin_exp_xy_fit(x[k, i], (x[k, b], z_0_lb), (x[k, a_l], y_0), alpha=exp, )
for i in range(n - a_r, n - b):
- z[k, i] = exp_lin_fit(
- x[k, i],
- (x[k, n - a_r], y_0),
- (x[k, n - b], z_0_rb),
- alpha=exp,
- )
+ z[k, i] = exp_lin_fit(x[k, i], (x[k, n - a_r], y_0), (x[k, n - b], z_0_rb), alpha=exp, )
for i in range(n - b, n):
z[k, i] = lin_fit(x[k, i], (x[k, n - b], z_0_rb), (x[k, n], z_1))
return x.array[n:-n], z.array[n:-n]
-class ExpAdaptiveOversample(AbstractOversample):
+class ExpAdaptiveRFA(AbstractRFA):
r"""Moves between points in adaptive transition intervals by combining linear and
exponential function.
@@ -832,9 +778,7 @@ class ExpAdaptiveOversample(AbstractOversample):
"""
- def __init__(
- self, x, y, n, alpha=1.0, beta=0.5, a=None, adaptive_smooth=1.0, exp=2.0
- ):
+ def __init__(self, x, y, n, alpha=1.0, beta=0.5, a=None, adaptive_smooth=1.0, exp=2.0):
super().__init__(x, y, n)
if a is None:
a = alpha * self.n
@@ -845,7 +789,7 @@ def __init__(
self.adaptive_smooth = adaptive_smooth
self.exp = exp
- def oversample(self):
+ def rfa(self):
x, y = self._initial_oversample()
n = self.n
beta = self.beta
@@ -865,9 +809,7 @@ def oversample(self):
z.extend_constant(direction='both')
# get adaptive factors
- a_ls, a_rs, gammas = LinearAdaptiveOversample.get_adaptive_transition_points(
- x, y, self.a, self.adaptive_smooth
- )
+ a_ls, a_rs, gammas = LinearAdaptiveRFA.get_adaptive_transition_points(x, y, self.a, self.adaptive_smooth)
b_ls = [int(beta * a_l) for a_l in a_ls]
b_rs = [int(beta * a_r) for a_r in a_rs]
@@ -879,50 +821,30 @@ def oversample(self):
if a_rs[k - 1] == 0 and a_ls[k] == 0: # no left transition window
z_0 = y[k - 1, 0]
else:
- z_0 = lin_fit(
- x[k, 0], (x[k, -a_rs[k - 1]], y[k - 1, 0]), (x[k, a_ls[k]], y[k, 0])
- )
+ z_0 = lin_fit(x[k, 0], (x[k, -a_rs[k - 1]], y[k - 1, 0]), (x[k, a_ls[k]], y[k, 0]))
if a_rs[k] == 0 and a_ls[k + 1] == 0: # no right transition window
z_1 = y[k + 1]
else:
- z_1 = lin_fit(
- x[k + 1, 0],
- (x[k, n - a_rs[k]], y_0),
- (x[k + 1, a_ls[k + 1]], y[k + 1, 0]),
- )
+ z_1 = lin_fit(x[k + 1, 0], (x[k, n - a_rs[k]], y_0), (x[k + 1, a_ls[k + 1]], y[k + 1, 0]), )
if b_ls[k] == 0: # no left linear transition window
z_0_bl = z_0
else:
- z_0_bl = lin_fit(
- x[k, 0 + b_ls[k]], (x[k, 0], z_0), (x[k, a_ls[k]], y[k, 0])
- )
+ z_0_bl = lin_fit(x[k, 0 + b_ls[k]], (x[k, 0], z_0), (x[k, a_ls[k]], y[k, 0]))
if b_rs[k] == 0: # no right linear transition window
z_0_br = z_1
else:
- z_0_br = lin_fit(
- x[k, n - b_rs[k]], (x[k, n - a_rs[k]], y[k, 0]), (x[k + 1, 0], z_1)
- )
+ z_0_br = lin_fit(x[k, n - b_rs[k]], (x[k, n - a_rs[k]], y[k, 0]), (x[k + 1, 0], z_1))
# remaining points
for i in range(0, b_ls[k]):
z[k, i] = lin_fit(x[k, i], (x[k, 0], z_0), (x[k, b_ls[k]], z_0_bl))
for i in range(b_ls[k], a_ls[k]):
- z[k, i] = lin_exp_xy_fit(
- x[k, i],
- (x[k, b_ls[k]], z_0_bl),
- (x[k, a_ls[k]], y_0),
- alpha=exp,
- )
+ z[k, i] = lin_exp_xy_fit(x[k, i], (x[k, b_ls[k]], z_0_bl), (x[k, a_ls[k]], y_0), alpha=exp, )
for i in range(n - a_rs[k], n - b_rs[k]):
- z[k, i] = exp_lin_fit(
- x[k, i],
- (x[k, n - a_rs[k]], y_0),
- (x[k, n - b_rs[k]], z_0_br),
- alpha=exp,
- )
+ z[k, i] = exp_lin_fit(x[k, i], (x[k, n - a_rs[k]], y_0), (x[k, n - b_rs[k]], z_0_br), alpha=exp, )
for i in range(n - b_rs[k], n):
z[k, i] = lin_fit(x[k, i], (x[k, n - b_rs[k]], z_0_br), (x[k, n], z_1))
diff --git a/src/traffic_weaver/sorted_array_utils.py b/src/traffic_weaver/sorted_array_utils.py
new file mode 100644
index 0000000..adeb101
--- /dev/null
+++ b/src/traffic_weaver/sorted_array_utils.py
@@ -0,0 +1,510 @@
+r"""Array utilities.
+"""
+from typing import List, Union
+
+import numpy as np
+
+
+def append_one_sample(x: Union[np.ndarray, List], y: Union[np.ndarray, List], make_periodic=False):
+ r"""Add one sample to the end of time series.
+
+ Add one sample to `x` and `y` array. Newly added point `x_i` point is distant from
+ the last point of `x` same as the last from the one before the last point.
+ If `make_periodic` is False, newly added `y_i` point is the same as the last point
+ of `y`. If `make_periodic` is True, newly added point is the same as the first point
+ of `y`.
+
+ Parameters
+ ----------
+ x: 1-D array-like of size n
+ Independent variable in strictly increasing order.
+ y: 1-D array-like of size n
+ Dependent variable.
+ make_periodic: bool, default: False
+ If false, append the last `y` point to `y` array.
+ If true, append the first `y` point to `y` array.
+
+ Returns
+ -------
+ ndarray
+ x, independent variable.
+ ndarray
+ y, dependent variable.
+ """
+ x = np.asarray(x, dtype=np.float64)
+ y = np.asarray(y, dtype=np.float64)
+
+ x = np.append(x, 2 * x[-1] - x[-2])
+ if not make_periodic:
+ y = np.append(y, y[-1])
+ else:
+ y = np.append(y, y[0])
+ return x, y
+
+
+def oversample_linspace(a: np.ndarray, num: int):
+ r"""Oversample array using linspace between each consecutive pair of array elements.
+
+ E.g., Array [1, 2, 3] oversampled by 2 becomes [1, 1.5, 2, 2.5, 3].
+
+ If input array is of size `n`, then resulting array is of size `(n - 1) * num + 1`.
+
+ If `n` is lower than 2, the original array is returned.
+
+ Parameters
+ ----------
+ a: 1-D array
+ Input array to oversample.
+ num: int
+ Number of elements inserted between each pair of array elements. Larger or
+ equal to 2.
+
+ Returns
+ -------
+ ndarray
+ 1-D array containing `num` linspaced elements between each array elements' pair.
+ Its length is equal to `(len(a) - 1) * num + 1`
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from traffic_weaver.sorted_array_utils import oversample_linspace
+ >>> oversample_linspace(np.asarray([1, 2, 3]), 4).tolist()
+ [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]
+
+ """
+ if num < 2:
+ return a
+ a = np.asarray(a, dtype=float)
+ return np.append(np.linspace(a[:-1], a[1:], num=num + 1)[:-1].T.flatten(), a[-1])
+
+
+def oversample_piecewise_constant(a: np.ndarray, num: int):
+ r"""Oversample array using same left value between each consecutive pair of array
+ elements.
+
+ E.g., Array [1, 2, 3] oversampled by 2 becomes [1, 1, 2, 2, 3].
+
+ If input array is of size `n`, then resulting array is of size `(n - 1) * num + 1`.
+
+ If `n` is lower than 2, the original array is returned.
+
+ Parameters
+ ----------
+ a: 1-D array
+ Input array to oversample.
+ num: int
+ Number of elements inserted between each pair of array elements. Larger or
+ equal to 2.
+
+ Returns
+ -------
+ ndarray
+ 1-D array containing `num` elements between each array elements' pair.
+ Its length is equal to `(len(a) - 1) * num + 1`
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from traffic_weaver.sorted_array_utils import oversample_piecewise_constant
+ >>> oversample_piecewise_constant(np.asarray([1.0, 2.0, 3.0]), 4).tolist()
+ [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0]
+
+ """
+ if num < 2:
+ return a
+ a = np.asarray(a)
+ return a.repeat(num)[: -num + 1]
+
+
+def extend_linspace(a: np.ndarray, n: int, direction="both", lstart: float = None, rstop: float = None):
+ """Extends array using linspace with n elements.
+
+ Extends array `a` from left and/or right with `n` elements each side.
+
+ When extending to the left,
+ the starting value is `lstart` (inclusive) and ending value as `a[0]` (exclusive).
+ By default, `lstart` is `a[0] - (a[n] - a[0])`.
+
+ When extending to the right,
+ the starting value `a[-1]` (exclusive) and ending value is `rstop` (inclusive).
+ By default, `rstop` is `a[-1] + (a[-1] - a[-1 - n])`
+
+ `direction` determines whether to extend to `both`, `left` or `right`.
+ By default, it is 'both'.
+
+ Parameters
+ ----------
+ a: 1-D array
+ n: int
+ Number of elements to extend
+ direction: 'both', 'left' or 'right', default: 'both'
+ Direction in which array should be extended.
+ lstart: float, optional
+ Starting value of the left extension.
+ By default, it is `a[0] - (a[n] - a[0])`.
+ rstop: float, optional
+ Ending value of the right extension.
+ By default, it is `a[-1] + (a[-1] - a[-1 - n])`.
+
+ Returns
+ -------
+ ndarray
+ 1-D extended array.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from traffic_weaver.sorted_array_utils import extend_linspace
+ >>> a = np.array([1, 2, 3])
+ >>> extend_linspace(a, 2, direction='both').tolist()
+ [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+
+ >>> extend_linspace(a, 4, direction='right', rstop=4).tolist()
+ [1.0, 2.0, 3.0, 3.25, 3.5, 3.75, 4.0]
+
+ """
+ a = np.asarray(a, dtype=float)
+ if direction == "both" or direction == "left":
+ if lstart is None:
+ lstart = 2 * a[0] - a[n]
+ ext = np.linspace(lstart, a[0], n + 1)[:-1]
+ a = np.insert(a, 0, ext)
+
+ if direction == "both" or direction == "right":
+ if rstop is None:
+ rstop = 2 * a[-1] - a[-n - 1]
+ ext = np.linspace(a[-1], rstop, n + 1)[1:]
+ a = np.insert(a, len(a), ext)
+
+ return a
+
+
+def extend_constant(a: np.ndarray, n: int, direction="both"):
+ """Extends array with first/last value with n elements.
+
+ Extends array `a` from left and/or right with `n` elements each side.
+
+ When extending to the left, value `a[0]` is repeated.
+ When extending to the right, value `a[-1]` is repeated.
+
+ `direction` determines whether to extend to `both`, `left` or `right`.
+ By default, it is 'both'.
+
+ Parameters
+ ----------
+ a: 1-D array
+ n: int
+ Number of elements to extend
+    direction: 'both', 'left' or 'right', default: 'both'
+ Direction in which array should be extended.
+
+ Returns
+ -------
+ ndarray
+ 1-D extended array.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> from traffic_weaver.sorted_array_utils import extend_constant
+ >>> a = np.array([1, 2, 3])
+ >>> extend_constant(a, 2, direction='both').tolist()
+ [1, 1, 1, 2, 3, 3, 3]
+
+ """
+ a = np.asarray(a)
+ if direction == "both" or direction == "left":
+ a = np.insert(a, 0, [a[0]] * n)
+ if direction == "both" or direction == "right":
+ a = np.insert(a, len(a), [a[-1]] * n)
+ return a
+
+
+def rectangle_integral(x, y):
+ r"""Integral values between each pair of points using rectangle approx.
+
+ In particular, if function contains average values, then it corresponds to the
+ exact value of the integral.
+
+ Parameters
+ ----------
+ x: 1-D array-like of size n
+ Independent variable in strictly increasing order.
+ y: 1-D array-like of size n
+ Dependent variable.
+
+ Returns
+ -------
+ 1-D array-like of size n-1
+ Values of the integral.
+ """
+ d = np.diff(x)
+ return y[:-1] * d
+
+
+def trapezoid_integral(x, y):
+ """Calculates integral between each pair of points using trapezoidal rule.
+
+ Parameters
+ ----------
+ x: 1-D array-like of size n
+ Independent variable in strictly increasing order.
+ y: 1-D array-like of size n
+ Dependent variable.
+
+ Returns
+ -------
+ 1-D array-like of size n-1
+ Values of the integral.
+
+ """
+ return (y[:-1] + y[1:]) / 2 * np.diff(x)
+
+
+def integral(x, y, method: str = 'trapezoid'):
+ """Calculate integral y over range x according to provided method.
+
+ Parameters
+ ----------
+    x: 1-D array-like of size n
+ Independent variable in strictly increasing order.
+ y: 1-D array-like of size n
+ Dependent variable.
+ method: str, default: 'trapezoid'
+ Method to calculate integral of target function.
+ Available options: 'trapezoid', 'rectangle'
+ Returns
+ -------
+ 1-D array-like of size n-1
+ Values of the integral.
+ """
+ if method == 'trapezoid':
+ return trapezoid_integral(x, y)
+ elif method == 'rectangle':
+ return rectangle_integral(x, y)
+ raise ValueError("Unknown integral method")
+
+
+def find_closest_lower_equal_element_indices_to_values(x: Union[np.ndarray, list], lookup: Union[np.ndarray, list],
+ fill_not_valid: bool = True):
+ """Find indices of closest lower or equal element in x to each element in lookup.
+
+ Parameters
+ ----------
+ x: np.ndarray
+ Array of values to search in.
+ lookup: np.ndarray
+ Values to search for.
+ fill_not_valid: bool, default: True
+ If True, fill indices of lookup values that are lower than the first element
+ in 'x' with 0.
+ Returns
+ -------
+ np.ndarray
+ Array of indices of closest lower or equal element in x to each element in
+ lookup.
+ """
+ indices = np.zeros(len(lookup), dtype=np.int64)
+
+ x_it = iter(x)
+ x_val = next(x_it)
+ x_next_val = next(x_it, None)
+ x_idx = 0
+
+ lookup_it = iter(lookup)
+ lookup_val = next(lookup_it)
+ lookup_idx = 0
+
+ # lookup value lower than x
+ # shift lookup until it is higher equal than the first element in x
+ while lookup_val is not None and lookup_val < x_val:
+ indices[lookup_idx] = x_idx if fill_not_valid else -1
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+
+ # lookup value is higher than the first element in x
+ while lookup_val is not None:
+ # if lookup is higher than the next x
+ # move x to the right
+ while x_next_val is not None and x_next_val <= lookup_val:
+ x_next_val = next(x_it, None)
+ x_idx += 1
+ if x_next_val is None:
+ break
+ # lookup value is higher than the current x and lower than the next x
+ indices[lookup_idx] = x_idx
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+ return indices
+
+
+def find_closest_higher_equal_element_indices_to_values(x: Union[np.ndarray, list], lookup: Union[np.ndarray, list],
+ fill_not_valid: bool = True):
+ """Find indices of closest higher or equal element in x to each element in lookup.
+
+ Parameters
+ ----------
+ x: np.ndarray
+ Array of values to search in.
+ lookup: np.ndarray
+ Values to search for.
+ fill_not_valid: bool, default: True
+ If True, fill indices of lookup values that are higher than the last element
+ in 'x' with 'len(x) - 1'.
+
+ Returns
+ -------
+ np.ndarray
+ Array of indices of closest higher or equal element in x to each element in
+ lookup.
+ """
+ indices = np.zeros(len(lookup), dtype=np.int64)
+
+ x_it = iter(x)
+ x_val = next(x_it)
+ x_next_val = next(x_it, None)
+ x_idx = 0
+
+ lookup_it = iter(lookup)
+ lookup_val = next(lookup_it)
+ lookup_idx = 0
+
+ # lookup value lower than x
+ # shift lookup until it is higher than the first element in x
+ while lookup_val is not None and lookup_val <= x_val:
+ indices[lookup_idx] = x_idx
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+
+ # lookup value is higher than the first element in x
+ while lookup_val is not None:
+ # if lookup is higher than the next x
+ # move x to the right
+ while x_next_val is not None and x_next_val < lookup_val:
+ x_next_val = next(x_it, None)
+ x_idx += 1
+ if x_next_val is None:
+ break
+ # lookup value is higher than the current x and lower than the next x
+ if x_next_val is None:
+ indices[lookup_idx] = x_idx if fill_not_valid else len(x)
+ else:
+ indices[lookup_idx] = x_idx + 1
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+ return indices
+
+
+def find_closest_lower_or_higher_element_indices_to_values(x: Union[np.ndarray, list], lookup: Union[np.ndarray, list]):
+ """Find indices of closest element in x to each element in lookup.
+
+ Parameters
+ ----------
+ x: np.ndarray
+ Array of values to search in.
+ lookup: np.ndarray
+ Values to search for.
+
+ Returns
+ -------
+ np.ndarray
+ Array of indices of closest element in x to each element in lookup.
+ """
+ indices = np.zeros(len(lookup), dtype=np.int64)
+
+ x_it = iter(x)
+ x_val = next(x_it)
+ x_next_val = next(x_it, None)
+ x_idx = 0
+
+ lookup_it = iter(lookup)
+ lookup_val = next(lookup_it)
+ lookup_idx = 0
+
+ # lookup value lower than x
+ # shift lookup until it is higher than the first element in x
+ while lookup_val is not None and lookup_val <= x_val:
+ indices[lookup_idx] = x_idx
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+
+ # lookup value is higher than the first element in x
+ while lookup_val is not None:
+ # if lookup is higher than the next x
+ # move x to the right
+ while x_next_val is not None and x_next_val < lookup_val:
+ x_val = x_next_val
+ x_next_val = next(x_it, None)
+ x_idx += 1
+ if x_next_val is None:
+ break
+ # lookup value is higher than the last element in x
+ if x_next_val is None:
+ indices[lookup_idx] = x_idx
+ else:
+ # lookup value is higher than the current x and lower than the next x
+ # check which one is closer
+ if lookup_val - x_val <= x_next_val - lookup_val:
+ indices[lookup_idx] = x_idx
+ else:
+ indices[lookup_idx] = x_idx + 1
+ lookup_val = next(lookup_it, None)
+ lookup_idx += 1
+ return indices
+
+
+def find_closest_element_indices_to_values(x: Union[np.ndarray, list], lookup: Union[np.ndarray, list],
+ strategy: str = 'closest', fill_not_valid: bool = True):
+ """Find indices of closest element in x to each element in lookup according
+ to the strategy.
+
+ Parameters
+ ----------
+ x: np.ndarray
+ Array of values to search in.
+ lookup: np.ndarray
+ Values to search for.
+ strategy: str, default: 'closest'
+ Strategy to find the closest element.
+ 'closest': closest element (lower or higher)
+ 'lower': closest lower or equal element
+ 'higher': closest higher or equal element
+ fill_not_valid: bool, default: True
+ Used in case of 'lower' and 'higher' strategy.
+        If True, fill indices of lookup values that are lower than the first element
+ in 'x' with 'x[0]',
+ fill indices of lookup values that are higher than the last element
+ in 'x' with 'len(x) - 1'.
+
+ Returns
+ -------
+ np.ndarray
+        Array of indices of closest element in x to each element in lookup.
+
+ """
+ if strategy == 'closest':
+ return find_closest_lower_or_higher_element_indices_to_values(x, lookup)
+ elif strategy == 'lower':
+ return find_closest_lower_equal_element_indices_to_values(x, lookup, fill_not_valid)
+ elif strategy == 'higher':
+ return find_closest_higher_equal_element_indices_to_values(x, lookup, fill_not_valid)
+ raise ValueError("Unknown strategy")
+
+
+def sum_over_indices(a, indices):
+ """Sum values of array `a` over ranges defined by `indices`.
+
+ Parameters
+ ----------
+ a: array-like
+ Array of values.
+ indices: array-like of int
+ Array of indices defining ranges over which to sum values.
+ Returns
+ -------
+ Array of sums of values over ranges defined by `indices`.
+ """
+ a = np.asarray(a)
+ indices = np.asarray(indices)
+ return np.array([a[start:stop].sum() for start, stop in zip(indices[:-1], indices[1:])])
diff --git a/src/traffic_weaver/weaver.py b/src/traffic_weaver/weaver.py
index 4ad9048..35c4c10 100644
--- a/src/traffic_weaver/weaver.py
+++ b/src/traffic_weaver/weaver.py
@@ -1,8 +1,9 @@
import numpy as np
from .match import integral_matching_reference_stretch
-from .oversample import AbstractOversample, ExpAdaptiveOversample
-from .process import repeat, trend, spline_smooth, noise_gauss
+from .rfa import AbstractRFA, ExpAdaptiveRFA
+from .process import repeat, trend, spline_smooth, noise_gauss, interpolate
+from .sorted_array_utils import append_one_sample
class Weaver:
@@ -12,19 +13,26 @@ class Weaver:
----------
x: 1-D array-like of size n, optional
Independent variable in strictly increasing order.
+ If x is None, then x is a set of integers from 0 to `len(y) - 1`
y: 1-D array-like of size n
Dependent variable.
+    Raises
+ -------
+ ValueError
+ If `x` and `y` are not of the same length.
+
Examples
--------
>>> from traffic_weaver import Weaver
- >>> from traffic_weaver.array_utils import append_one_sample
- >>> from traffic_weaver.datasets import load_mobile_video
- >>> x, y = load_mobile_video()
- >>> x, y = append_one_sample(x, y, make_periodic=True)
+ >>> from traffic_weaver.sorted_array_utils import append_one_sample
+ >>> from traffic_weaver.datasets import load_sandvine_mobile_video
+ >>> data = load_sandvine_mobile_video()
+ >>> x, y = data.T
>>> wv = Weaver(x, y)
+ >>> _ = wv.append_one_sample(make_periodic=True)
>>> # chain some command
- >>> _ = wv.oversample(10).integral_match().smooth(s=0.2)
+ >>> _ = wv.recreate_from_average(10).integral_match().smooth(s=0.2)
>>> # at any moment get newly created and processed time series' points
>>> res_x, res_y = wv.get()
>>> # chain some other commands
@@ -39,6 +47,8 @@ class Weaver:
"""
def __init__(self, x, y):
+ if x is not None and len(x) != len(y):
+ raise ValueError("x and y should be of the same length")
if x is None:
self.x = np.arange(stop=len(y))
else:
@@ -48,6 +58,89 @@ def __init__(self, x, y):
self.original_x = self.x
self.original_y = self.y
+ self.x_scale = 1
+ self.y_scale = 1
+
+ def copy(self):
+ """Create a copy of the Weaver object.
+
+ Returns
+ -------
+ Weaver
+ """
+ wv = Weaver(self.original_x.copy(), self.original_y.copy())
+ wv.x = self.x.copy()
+ wv.y = self.y.copy()
+ wv.x_scale = self.x_scale
+ wv.y_scale = self.y_scale
+ return wv
+
+ @staticmethod
+ def from_2d_array(xy: np.ndarray):
+ """Create Weaver object from 2D array.
+
+ Parameters
+ ----------
+ xy: np.ndarray of shape (nr_of_samples, 2)
+ 2D array with each row representing one point in time series.
+ The first column is the x-variable and the second column is the y-variable.
+
+ Returns
+ -------
+ Weaver
+ Weaver object with x and y values from 2D array.
+
+ Raises
+ ------
+ ValueError
+ If `xy` is not a 2D array or does not have 2 columns
+
+ """
+ shape = xy.shape
+ if len(shape) != 2 or shape[1] != 2:
+ raise ValueError("xy should be 2D array with 2 columns")
+ return Weaver(xy[:, 0], xy[:, 1])
+
+ @staticmethod
+ def from_dataframe(df, x_col=0, y_col=1):
+ """Create Weaver object from DataFrame.
+
+ Parameters
+ ----------
+ df: pandas DataFrame
+ DataFrame with data.
+ x_col: int or str, default=0
+ Name of column with x values.
+ y_col: int or str, default=1
+ Name of column with y values.
+
+ Returns
+ -------
+ Weaver
+ Weaver object with x and y values from DataFrame.
+
+ """
+ return Weaver(df[x_col].values, df[y_col].values)
+
+ @staticmethod
+ def from_csv(file_name: str):
+ """Create Weaver object from CSV file.
+
+ CSV has to contain two columns without headers.
+ The first column contains 'x' values,
+ the second column contains 'y' values.
+
+ Parameters
+ ----------
+ file_name: str
+ Path to CSV file.
+ Returns
+ -------
+ Weaver
+ Weaver object from CSV file.
+ """
+ return Weaver.from_2d_array(np.loadtxt(file_name, delimiter=',', dtype=np.float64))
+
def get(self):
r"""Return function x,y tuple after performed processing."""
return self.x, self.y
@@ -62,22 +155,115 @@ def restore_original(self):
self.y = self.original_y
return self
- def oversample(
- self,
- n: int,
- oversample_class: type[AbstractOversample] = ExpAdaptiveOversample,
- **kwargs,
- ):
- r"""Oversample function using provided strategy.
+ def append_one_sample(self, make_periodic=False):
+ """Add one sample to the end of time series.
+
+ Add one sample to `x` and `y` array. Newly added point `x_i` point is distant
+ from
+ the last point of `x` same as the last from the one before the last point.
+ If `make_periodic` is False, newly added `y_i` point is the same as the last
+ point
+ of `y`. If `make_periodic` is True, newly added point is the same as the
+ first point
+ of `y`.
+
+ Parameters
+ ----------
+ make_periodic: bool, default: False
+ If false, append the last `y` point to `y` array.
+ If true, append the first `y` point to `y` array.
+
+ Returns
+ -------
+ self
+
+ See Also
+ --------
+        :func:`~traffic_weaver.sorted_array_utils.append_one_sample`
+ """
+ self.x, self.y = append_one_sample(self.x, self.y, make_periodic=make_periodic)
+ return self
+
+ def slice_by_index(self, start=0, stop=None, step=1):
+ if stop is None:
+ stop = len(self.x)
+ if start < 0:
+ raise ValueError("Start index should be non-negative")
+ if stop > len(self.x):
+ raise ValueError("Stop index should be less than length of x")
+ self.x = self.x[start:stop:step]
+ self.y = self.y[start:stop:step]
+ return self
+
+ def slice_by_value(self, start=None, stop=None, step=1):
+ if start is None:
+ start_idx = 0
+ else:
+ start_idx = np.where(self.x == start)[0][0]
+ if stop is None:
+ stop_idx = len(self.x)
+ else:
+ stop_idx = np.where(self.x == stop)[0][0]
+ if not start_idx:
+ raise ValueError("Start value not found in x")
+ if not stop_idx:
+ raise ValueError("Stop value not found in x")
+ return self.slice_by_index(start_idx, stop_idx, step)
+
+ def interpolate(self, n: int = None, new_x=None, method='linear', **kwargs):
+ """ Interpolate function.
+
+ For original time varying function sampled at different points use one of the
+ 'linear', 'cubic' or 'spline' interpolation methods.
+
+ For time varying function that is an averaged function over periods of time,
+ use 'constant' interpolation method.
+
+ Parameters
+ ----------
+ n: int
+ Number of fixed space samples in new function.
+ Ignored if `new_x` specified.
+ new_x: array-like
+ Points to where to evaluate interpolated function.
+ It overrides 'n' parameter. Range should be the same as original x.
+ method: str, default='linear'
+ Interpolation strategy. Supported strategies are 'linear',
+ 'constant', 'cubic' and 'spline'.
+ kwargs:
+ Additional parameters passed to interpolation function.
+ For more details see
+
+ Returns
+ -------
+ self
+
+ See Also
+ --------
+ :func:`~traffic_weaver.process.interpolate`
+ """
+ if new_x is None and n is None:
+ raise ValueError("Either n or new_x should be provided")
+ if new_x is None:
+ new_x = np.linspace(self.x[0], self.x[-1], n)
+ else:
+ if new_x[0] != self.x[0] or new_x[-1] != self.x[-1]:
+ raise ValueError("new_x should have the same range as x")
+ self.y = interpolate(self.x, self.y, new_x, method=method, **kwargs)
+ self.x = new_x
+ return self
+
+ def recreate_from_average(self, n: int, rfa_class: type[AbstractRFA] = ExpAdaptiveRFA, **kwargs, ):
+ r"""Recreate function from average function using provided strategy.
Parameters
----------
n: int
Number of samples between each point.
- oversample_class: subclass of AbstractOversample
- Oversample strategy.
+ rfa_class: subclass of AbstractRFA
+ Recreate from average strategy.
**kwargs
- Additional parameters passed to `oversample_class`.
+ Additional parameters passed to `rfa_class`.
Returns
-------
@@ -85,14 +271,13 @@ def oversample(
See Also
--------
- :func:`~traffic_weaver.oversample.AbstractOversample`
+        :func:`~traffic_weaver.rfa.AbstractRFA`
"""
- self.x, self.y = oversample_class(self.x, self.y, n, **kwargs).oversample()
+ self.x, self.y = rfa_class(self.x, self.y, n, **kwargs).rfa()
return self
def integral_match(self, **kwargs):
- r"""Match function integral to piecewise constant approximated integral of the
- original function.
+ r"""Match function integral to approximated integral of the original function.
Parameters
----------
@@ -107,9 +292,8 @@ def integral_match(self, **kwargs):
--------
:func:`~traffic_weaver.match.integral_matching_reference_stretch`
"""
- self.y = integral_matching_reference_stretch(
- self.x, self.y, self.original_x, self.original_y, **kwargs
- )
+ self.y = integral_matching_reference_stretch(self.x, self.y, self.original_x * self.x_scale,
+ self.original_y * self.y_scale, **kwargs)
return self
def noise(self, snr, **kwargs):
@@ -212,3 +396,15 @@ def to_function(self, s=0):
:func:`~traffic_weaver.process.spline_smooth`
"""
return spline_smooth(self.x, self.y, s=s)
+
+ def scale_x(self, scale):
+ r"""Scale x-axis."""
+ self.x_scale = self.x_scale * scale
+ self.x = self.x * scale
+ return self
+
+ def scale_y(self, scale):
+ r"""Scale y-axis."""
+ self.y_scale = self.y_scale * scale
+ self.y = self.y * scale
+ return self
diff --git a/tests/array_utils_test.py b/tests/array_utils_test.py
deleted file mode 100644
index a565bd3..0000000
--- a/tests/array_utils_test.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import pytest
-from numpy.testing import assert_array_equal
-
-from traffic_weaver.array_utils import (
- oversample_linspace,
- oversample_piecewise_constant,
- extend_linspace,
- extend_constant,
- append_one_sample,
-)
-
-
-@pytest.mark.parametrize(
- "x, y, make_periodic, expected_x, expected_y",
- [
- ([1, 2, 4], [1, 2, 3], False, [1, 2, 4, 6], [1, 2, 3, 3]),
- ([1, 2, 4], [1, 2, 3], True, [1, 2, 4, 6], [1, 2, 3, 1]),
- ],
-)
-def test_add_one_sample(x, y, make_periodic, expected_x, expected_y):
- x, y = append_one_sample(x, y, make_periodic)
- assert_array_equal(x, expected_x)
- assert_array_equal(y, expected_y)
-
-
-@pytest.mark.parametrize(
- "x, num, expected",
- [
- ([1], 2, 1),
- ([1, 1], 3, [1, 1, 1, 1]),
- ([1, 2, 3], 4, [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]),
- ([1, 2, 3], 1, [1, 2, 3]), # test when num <= 2
- ],
-)
-def test_oversample_linspace(x, num, expected):
- xs = oversample_linspace(x, num)
- assert_array_equal(xs, expected)
-
-
-@pytest.mark.parametrize(
- "x, num, expected",
- [
- ([1], 2, 1),
- ([1, 1], 3, [1, 1, 1, 1]),
- ([1, 2, 3], 4, [1, 1, 1, 1, 2, 2, 2, 2, 3]),
- ([1, 2, 3], 1, [1, 2, 3]),
- # test when num <= 2
- ],
-)
-def test_oversample_piecewise(x, num, expected):
- xs = oversample_piecewise_constant(x, num)
- assert_array_equal(xs, expected)
-
-
-@pytest.mark.parametrize(
- "x, num, direction, lstart, rstop, expected",
- [
- ([1, 2, 3], 2, "both", None, None, [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
- ([1, 2, 3], 4, "right", 0, 4, [1.0, 2.0, 3.0, 3.25, 3.5, 3.75, 4.0]),
- ([1, 2, 3], 4, "left", 0, 4, [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, 3.0]),
- ],
-)
-def test_extend_linspace(x, num, direction, lstart, rstop, expected):
- xs = extend_linspace(x, num, direction, lstart, rstop)
- assert_array_equal(xs, expected)
-
-
-@pytest.mark.parametrize(
- "x, num, direction, expected",
- [
- ([1, 2, 3], 2, "both", [1, 1, 1, 2, 3, 3, 3]),
- ([1, 2, 3], 4, "right", [1, 2, 3, 3, 3, 3, 3]),
- ([1, 2, 3], 4, "left", [1, 1, 1, 1, 1, 2, 3]),
- ],
-)
-def test_extend_constant(x, num, direction, expected):
- xs = extend_constant(x, num, direction)
- assert_array_equal(xs, expected)
diff --git a/tests/datasets_test.py b/tests/datasets_test.py
deleted file mode 100644
index f1941b6..0000000
--- a/tests/datasets_test.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from unittest import mock
-
-import pytest
-from traffic_weaver.datasets import load_mobile_video
-from numpy.testing import assert_array_equal
-
-
-@pytest.fixture
-def mocked_dataset_file():
- data = """
- {
- "mobile_video": "[1, 2, 3, 4]"
- }
- """
- return mock.mock_open(read_data=data)
-
-
-@pytest.fixture
-def expected_content():
- return [1, 2, 3, 4]
-
-
-def test_open_dataset_file_by_method(mocked_dataset_file, expected_content):
- with mock.patch('builtins.open', mocked_dataset_file):
- x, y = load_mobile_video()
- assert_array_equal(y, expected_content)
diff --git a/tests/match_test.py b/tests/match_test.py
index 35e9d05..d2571b8 100644
--- a/tests/match_test.py
+++ b/tests/match_test.py
@@ -3,11 +3,10 @@
from numpy.testing import assert_array_almost_equal
from pytest import approx
-from traffic_weaver.match import (integral_matching_stretch,
- interval_integral_matching_stretch,
+from traffic_weaver.match import (_integral_matching_stretch,
+ _interval_integral_matching_stretch,
integral_matching_reference_stretch, )
-from traffic_weaver.array_utils import \
- left_piecewise_integral
+from traffic_weaver.sorted_array_utils import rectangle_integral
@pytest.fixture
@@ -22,7 +21,7 @@ def x():
@pytest.mark.parametrize("expected_integral", [10, 23.25, 30])
def test_integral_matching_stretch(expected_integral, x, y):
- y2 = integral_matching_stretch(x, y, expected_integral)
+ y2 = _integral_matching_stretch(x, y, expected_integral)
stretched_integral = np.trapz(y2, x)
# import matplotlib.pyplot as plt
@@ -39,7 +38,7 @@ def test_integral_matching_stretch(expected_integral, x, y):
def test_integral_matching_stretch_with_missing_x(y):
expected_integral = 10
- y2 = integral_matching_stretch(None, y, integral_value=expected_integral)
+ y2 = _integral_matching_stretch(None, y, integral_value=expected_integral)
x = np.arange(len(y2))
stretched_integral = np.trapz(y2, x)
@@ -49,7 +48,7 @@ def test_integral_matching_stretch_with_missing_x(y):
def test_integral_matching_stretch_with_missing_expected_integral(x, y):
- y2 = integral_matching_stretch(x, y)
+ y2 = _integral_matching_stretch(x, y)
expected_integral = 0
@@ -81,11 +80,11 @@ def get_integrals_based_on_interval_points(x, y, interval_points):
def test_interval_integral_with_matching_stretch(
expected_integrals, interval_points, y, x
):
- y2 = interval_integral_matching_stretch(
+ y2 = _interval_integral_matching_stretch(
x,
y,
integral_values=expected_integrals,
- interval_point_indices=interval_points,
+ fixed_points_indices_in_x=interval_points,
)
# if no interval points given, they are created evenly
@@ -111,11 +110,11 @@ def test_interval_integral_matching_stretch_with_missing_x(y):
expected_integrals = [49.46, 25]
interval_points = [0, 3, 8]
- y2 = interval_integral_matching_stretch(
+ y2 = _interval_integral_matching_stretch(
None,
y,
integral_values=expected_integrals,
- interval_point_indices=interval_points,
+ fixed_points_indices_in_x=interval_points,
)
x = np.arange(len(y2))
@@ -130,8 +129,8 @@ def test_interval_integral_matching_stretch_with_missing_x(y):
def test_interval_integral_matching_stretch_with_missing_expected_integral(x, y):
interval_points = [0, 3, 10]
- y2 = interval_integral_matching_stretch(
- x, y, interval_point_indices=interval_points
+ y2 = _interval_integral_matching_stretch(
+ x, y, fixed_points_indices_in_x=interval_points
)
stretched_integrals = get_integrals_based_on_interval_points(x, y2, interval_points)
@@ -147,7 +146,7 @@ def test_fail_integral_matching_stretch_with_missing_expected_integral_and_inter
x, y
):
with pytest.raises(ValueError) as exc_info:
- interval_integral_matching_stretch(x, y)
+ _interval_integral_matching_stretch(x, y)
assert exc_info.type is ValueError
@@ -156,11 +155,11 @@ def test_integral_matching_reference_stretch(x, y):
x_ref = x[::2]
expected_y = [1.0, 0.26, 2.1, 0.8, 6.0, 8.0, 2.0, 1.5, 3.0, 2.0, 6.0]
- y2 = integral_matching_reference_stretch(x, y, x_ref, y_ref)
+ y2 = integral_matching_reference_stretch(x, y, x_ref, y_ref, reference_function_integral_method='rectangle')
assert_array_almost_equal(y2, expected_y, decimal=2)
- expected_integrals = left_piecewise_integral(
+ expected_integrals = rectangle_integral(
x_ref, y_ref
)
actual_integrals = get_integrals_based_on_interval_points(
diff --git a/tests/precess_test.py b/tests/precess_test.py
index 93838b7..ab587b0 100644
--- a/tests/precess_test.py
+++ b/tests/precess_test.py
@@ -2,7 +2,10 @@
import pytest
from numpy.ma.testutils import assert_array_equal, assert_array_almost_equal
-from traffic_weaver.process import repeat, trend, linear_trend, noise_gauss, average
+from traffic_weaver.process import (
+ repeat, trend, linear_trend, noise_gauss, average,
+ _piecewise_constant_interpolate,
+)
@pytest.fixture
@@ -18,7 +21,7 @@ def test_repeat(xy):
def test_trend(xy):
shift = [0, 1 / 16, 1 / 4, 9 / 16, 1]
- nx, ny = trend(xy[0], xy[1], lambda x: x**2)
+ nx, ny = trend(xy[0], xy[1], lambda x: x ** 2)
assert_array_equal(nx, xy[0])
assert_array_equal(ny, xy[1] + shift)
@@ -73,3 +76,15 @@ def test_average(xy):
expected_y = [2, 2.5, 2]
assert_array_equal(ax, expected_x)
assert_array_equal(ay, expected_y)
+
+
+@pytest.mark.parametrize(
+ "x, y, new_x, expected",
+ [
+ ([0, 1, 2], [2, 3, 4], [0.5, 1.0, 1.5, 2.0, 2.5], [2, 3, 3, 4, 4]),
+ ([2, 3, 4], [5, 4, 3], [0, 1, 2, 5, 8], [5, 5, 5, 3, 3]),
+ ],
+)
+def test_piecewise_constant_interpolate(x, y, new_x, expected):
+ new_y = _piecewise_constant_interpolate(x, y, new_x)
+ assert_array_equal(new_y, expected)
diff --git a/tests/oversample_test.py b/tests/rfa_test.py
similarity index 56%
rename from tests/oversample_test.py
rename to tests/rfa_test.py
index ebe080d..39933fc 100644
--- a/tests/oversample_test.py
+++ b/tests/rfa_test.py
@@ -2,13 +2,13 @@
import pytest
from numpy.ma.testutils import assert_array_approx_equal
-from traffic_weaver import (
- PiecewiseConstantOversample,
- CubicSplineOversample,
- LinearFixedOversample,
- LinearAdaptiveOversample,
- ExpFixedOversample,
- ExpAdaptiveOversample,
+from traffic_weaver.rfa import (
+ PiecewiseConstantRFA,
+ CubicSplineRFA,
+ LinearFixedRFA,
+ LinearAdaptiveRFA,
+ ExpFixedRFA,
+ ExpAdaptiveRFA,
)
@@ -17,38 +17,38 @@ def xy():
return np.arange(4), np.array([1, 3, 4, 1])
-def test_piecewise_constant_oversample(xy):
- ov_x, ov_y = PiecewiseConstantOversample(xy[0], xy[1], 4).oversample()
+def test_piecewise_constant_rfa(xy):
+ ov_x, ov_y = PiecewiseConstantRFA(xy[0], xy[1], 4).rfa()
expected = np.array([1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 1], dtype=np.float64)
assert_array_approx_equal(ov_y, expected)
-def test_cubic_spline__oversample(xy):
- ov_x, ov_y = CubicSplineOversample(xy[0], xy[1], 4).oversample()
+def test_cubic_spline_rfa(xy):
+ ov_x, ov_y = CubicSplineRFA(xy[0], xy[1], 4).rfa()
# check every forth if in place
ov_y = ov_y[::4]
expected = np.array([1, 3, 4, 1], dtype=np.float64)
assert_array_approx_equal(ov_y, expected)
-def test_linear_fixed_oversample(xy):
- ov_x, ov_y = LinearFixedOversample(xy[0], xy[1], 4).oversample()
+def test_linear_fixed_rfa(xy):
+ ov_x, ov_y = LinearFixedRFA(xy[0], xy[1], 4).rfa()
expected = np.array(
[1, 1, 1, 1.5, 2, 2.5, 3, 3.25, 3.5, 3.75, 4, 3.25, 2.5], dtype=np.float64
)
assert_array_approx_equal(ov_y, expected)
-def test_linear_adaptive_oversample(xy):
- ov_x, ov_y = LinearAdaptiveOversample(xy[0], xy[1], 4).oversample()
+def test_linear_adaptive_rfa(xy):
+ ov_x, ov_y = LinearAdaptiveRFA(xy[0], xy[1], 4).rfa()
expected = np.array(
[1, 1, 1, 1.66, 2.33, 3, 3, 3.2, 3.4, 3.6, 3.8, 4, 2.5], dtype=np.float64
)
assert_array_approx_equal(ov_y, expected, decimal=2)
-def test_exp_fixed_oversample(xy):
- ov_x, ov_y = ExpFixedOversample(xy[0], xy[1], 8).oversample()
+def test_exp_fixed_rfa(xy):
+ ov_x, ov_y = ExpFixedRFA(xy[0], xy[1], 8).rfa()
expected = np.array(
[1, 1, 1, 1, 1, 1.187, 1.5, 1.75]
+ [2.0, 2.25, 2.5, 2.812, 3, 3.093, 3.25, 3.375]
@@ -58,8 +58,8 @@ def test_exp_fixed_oversample(xy):
assert_array_approx_equal(ov_y, expected, decimal=2)
-def test_exp_adaptive_oversample(xy):
- ov_x, ov_y = ExpAdaptiveOversample(xy[0], xy[1], 8).oversample()
+def test_exp_adaptive_rfa(xy):
+ ov_x, ov_y = ExpAdaptiveRFA(xy[0], xy[1], 8).rfa()
print(ov_y)
expected = np.array(
[1, 1, 1, 1, 1, 1.25, 1.666, 2]
@@ -70,25 +70,25 @@ def test_exp_adaptive_oversample(xy):
assert_array_approx_equal(ov_y, expected, decimal=2)
-def test_fail_too_small_oversample(xy):
+def test_fail_too_small_rfa(xy):
with pytest.raises(ValueError):
- LinearFixedOversample(xy[0], xy[1], 1)
+ LinearFixedRFA(xy[0], xy[1], 1)
def test_setting_parameters(xy):
- ov = LinearFixedOversample(xy[0], xy[1], 12, alpha=0.0)
+ ov = LinearFixedRFA(xy[0], xy[1], 12, alpha=0.0)
assert ov.a == 2
- ov = LinearAdaptiveOversample(xy[0], xy[1], 12, alpha=0.0)
+ ov = LinearAdaptiveRFA(xy[0], xy[1], 12, alpha=0.0)
assert ov.a == 2
- ov = ExpFixedOversample(xy[0], xy[1], 12, alpha=0.0)
+ ov = ExpFixedRFA(xy[0], xy[1], 12, alpha=0.0)
assert ov.a == 2
- ov = ExpAdaptiveOversample(xy[0], xy[1], 12, alpha=0.0)
+ ov = ExpAdaptiveRFA(xy[0], xy[1], 12, alpha=0.0)
assert ov.a == 2
def test_special_cases_in_oversample():
# test 0 nominators and denominators
x, y = np.arange(5), np.array([1, 1, 1, 3, 3])
- LinearAdaptiveOversample(x, y, 4).oversample()
- ExpAdaptiveOversample(x, y, 4, beta=0).oversample()
+ LinearAdaptiveRFA(x, y, 4).rfa()
+ ExpAdaptiveRFA(x, y, 4, beta=0).rfa()
assert True
diff --git a/tests/sorted_array_utils_test.py b/tests/sorted_array_utils_test.py
new file mode 100644
index 0000000..b0f8e1a
--- /dev/null
+++ b/tests/sorted_array_utils_test.py
@@ -0,0 +1,88 @@
+import pytest
+from numpy.testing import assert_array_equal
+
+from traffic_weaver.sorted_array_utils import (oversample_linspace, oversample_piecewise_constant, extend_linspace,
+ extend_constant, append_one_sample,
+ find_closest_lower_equal_element_indices_to_values,
+ find_closest_higher_equal_element_indices_to_values,
+ find_closest_lower_or_higher_element_indices_to_values,
+ sum_over_indices, )
+
+
+@pytest.mark.parametrize("x, y, make_periodic, expected_x, expected_y",
+ [([1, 2, 4], [1, 2, 3], False, [1, 2, 4, 6], [1, 2, 3, 3]),
+ ([1, 2, 4], [1, 2, 3], True, [1, 2, 4, 6], [1, 2, 3, 1]), ], )
+def test_add_one_sample(x, y, make_periodic, expected_x, expected_y):
+ x, y = append_one_sample(x, y, make_periodic)
+ assert_array_equal(x, expected_x)
+ assert_array_equal(y, expected_y)
+
+
+@pytest.mark.parametrize("x, num, expected", [([1], 2, 1), ([1, 1], 3, [1, 1, 1, 1]),
+ ([1, 2, 3], 4, [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]),
+ ([1, 2, 3], 1, [1, 2, 3]), # test when num <= 2
+ ], )
+def test_oversample_linspace(x, num, expected):
+ xs = oversample_linspace(x, num)
+ assert_array_equal(xs, expected)
+
+
+@pytest.mark.parametrize("x, num, expected",
+ [([1], 2, 1), ([1, 1], 3, [1, 1, 1, 1]), ([1, 2, 3], 4, [1, 1, 1, 1, 2, 2, 2, 2, 3]),
+ ([1, 2, 3], 1, [1, 2, 3]), # test when num <= 2
+ ], )
+def test_oversample_piecewise(x, num, expected):
+ xs = oversample_piecewise_constant(x, num)
+ assert_array_equal(xs, expected)
+
+
+@pytest.mark.parametrize("x, num, direction, lstart, rstop, expected",
+ [([1, 2, 3], 2, "both", None, None, [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
+ ([1, 2, 3], 4, "right", 0, 4, [1.0, 2.0, 3.0, 3.25, 3.5, 3.75, 4.0]),
+ ([1, 2, 3], 4, "left", 0, 4, [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, 3.0]), ], )
+def test_extend_linspace(x, num, direction, lstart, rstop, expected):
+ xs = extend_linspace(x, num, direction, lstart, rstop)
+ assert_array_equal(xs, expected)
+
+
+@pytest.mark.parametrize("x, num, direction, expected",
+ [([1, 2, 3], 2, "both", [1, 1, 1, 2, 3, 3, 3]), ([1, 2, 3], 4, "right", [1, 2, 3, 3, 3, 3, 3]),
+ ([1, 2, 3], 4, "left", [1, 1, 1, 1, 1, 2, 3]), ], )
+def test_extend_constant(x, num, direction, expected):
+ xs = extend_constant(x, num, direction)
+ assert_array_equal(xs, expected)
+
+
+@pytest.mark.parametrize("a, values, expected", [([1, 2, 3], [1.1], [0]), ([1, 2, 3], [0.2, 0.9, 1.4], [0, 0, 0]),
+ ([1, 2, 3], [1, 2, 3], [0, 1, 2]), ([1, 2, 3], [2.9, 3.0], [1, 2]),
+ ([1, 2, 3], [2.8, 3.5, 4.0], [1, 2, 2]), ], )
+def test_find_closest_lower_equal_element_indices_to_values(a, values, expected):
+ indices = find_closest_lower_equal_element_indices_to_values(a, values)
+ assert_array_equal(indices, expected)
+
+
+@pytest.mark.parametrize("a, values, expected", [([1, 2, 3], [1.1], [1]), ([1, 2, 3], [0.2, 0.9, 1.4], [0, 0, 1]),
+ ([1, 2, 3], [1, 2, 3], [0, 1, 2]), ([1, 2, 3], [1.9, 3.0], [1, 2]),
+ ([1, 2, 3], [2.8, 3.5, 4.0], [2, 2, 2]), ], )
+def test_find_closest_higher_equal_element_indices_to_values(a, values, expected):
+ indices = find_closest_higher_equal_element_indices_to_values(a, values)
+ print(indices)
+ assert_array_equal(indices, expected)
+
+
+@pytest.mark.parametrize("a, values, expected", [([1, 2, 3], [1.1], [0]), ([1, 2, 3], [0.2, 0.9, 1.4], [0, 0, 0]),
+ ([1, 2, 3], [1, 2, 3], [0, 1, 2]),
+ ([1, 2, 3], [1.5, 2.5, 3.5], [0, 1, 2]),
+ ([1, 2, 3], [1.9, 3.0], [1, 2]),
+ ([1, 2, 3], [2.2, 3.5, 4.0], [1, 2, 2]), ], )
+def test_find_closest_element_lower_or_higher_indices_to_values(a, values, expected):
+ indices = find_closest_lower_or_higher_element_indices_to_values(a, values)
+ assert_array_equal(indices, expected)
+
+
+@pytest.mark.parametrize("a, indices, expected",
+ [([1, 2, 3, 4, 5], [0, 5], [15]), ([1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),
+ ([1, 2, 3, 4, 5], [0, 2, 4], [3, 7]), ])
+def test_sum_over_indices(a, indices, expected):
+ res = sum_over_indices(a, indices)
+ assert_array_equal(res, expected)
diff --git a/tests/weaver_test.py b/tests/weaver_test.py
index 6e4b06e..7343247 100644
--- a/tests/weaver_test.py
+++ b/tests/weaver_test.py
@@ -1,8 +1,10 @@
import numpy as np
+import pandas as pd
import pytest
from numpy.testing import assert_array_equal
-from traffic_weaver import Weaver, PiecewiseConstantOversample
+from traffic_weaver import Weaver
+from traffic_weaver.rfa import PiecewiseConstantRFA
@pytest.fixture
@@ -12,7 +14,7 @@ def xy():
@pytest.fixture
def expected_xy():
- return np.arange(0, 4.5, 0.5), np.array([1, 1, 3, 3, 4, 4, 1, 1, 2])
+ return np.arange(0, 4.5, 0.5), np.array([2, 2, 6, 6, 8, 8, 2, 2, 4])
@pytest.fixture
@@ -37,7 +39,9 @@ def mock_weaver_delegates(mocker, xy):
def test_weaver_chain(mock_weaver_delegates, xy, expected_xy):
weaver = Weaver(xy[0], xy[1])
- weaver.oversample(2, oversample_class=PiecewiseConstantOversample)
+ weaver.recreate_from_average(2, rfa_class=PiecewiseConstantRFA)
+ weaver.scale_x(1)
+ weaver.scale_y(2)
weaver.integral_match()
weaver.repeat(2)
weaver.trend(lambda x: 0)
@@ -57,3 +61,64 @@ def test_weaver_chain(mock_weaver_delegates, xy, expected_xy):
weaver.restore_original()
assert_array_equal(weaver.get_original()[0], weaver.get()[0])
assert_array_equal(weaver.get_original()[1], weaver.get()[1])
+
+
+def test_weaver_factories(xy):
+ weaver = Weaver(xy[0], xy[1])
+
+ weaver2 = Weaver.from_dataframe(pd.DataFrame({"x": xy[0], "y": xy[1]}), x_col='x', y_col='y')
+ weaver3 = Weaver.from_2d_array(np.column_stack(xy))
+ weaver4 = Weaver(x=None, y=xy[1])
+
+ assert_array_equal(weaver.get()[0], weaver2.get()[0])
+ assert_array_equal(weaver.get()[1], weaver2.get()[1])
+
+ assert_array_equal(weaver.get()[0], weaver3.get()[0])
+ assert_array_equal(weaver.get()[1], weaver3.get()[1])
+
+ assert_array_equal(weaver.get()[0], weaver4.get()[0])
+ assert_array_equal(weaver.get()[1], weaver4.get()[1])
+
+
+def test_raise_exception_on_wrong_input_dimension():
+ with pytest.raises(ValueError):
+ Weaver([1, 2], [1, 2, 3])
+
+ with pytest.raises(ValueError):
+ Weaver.from_2d_array(np.zeros((3, 2, 2)))
+
+
+def test_slice_by_index(xy):
+ weaver = Weaver(xy[0], xy[1])
+ weaver.slice_by_index(1, 4)
+ assert_array_equal(weaver.get()[0], np.array([1, 2, 3]))
+ assert_array_equal(weaver.get()[1], np.array([3, 4, 1]))
+
+ weaver = Weaver(xy[0], xy[1])
+ weaver.slice_by_index(step=2)
+ assert_array_equal(weaver.get()[0], np.array([0, 2, 4]))
+ assert_array_equal(weaver.get()[1], np.array([1, 4, 2]))
+
+
+def test_slice_by_value(xy):
+ weaver = Weaver(xy[0], xy[1])
+ weaver.slice_by_value(1, 4)
+ assert_array_equal(weaver.get()[0], np.array([1, 2, 3]))
+ assert_array_equal(weaver.get()[1], np.array([3, 4, 1]))
+
+
+def test_slice_by_index_out_of_bounds(xy):
+ weaver = Weaver(xy[0], xy[1])
+ with pytest.raises(ValueError):
+ weaver.slice_by_index(1, 10)
+ with pytest.raises(ValueError):
+ weaver.slice_by_index(-1, 4)
+
+
+def test_copy(xy):
+ weaver = Weaver(xy[0], xy[1])
+ weaver2 = weaver.copy()
+ assert_array_equal(weaver.get()[0], weaver2.get()[0])
+ assert_array_equal(weaver.get()[1], weaver2.get()[1])
+ weaver2.get()[0][0] = 100
+ assert weaver.get()[0][0] != weaver2.get()[0][0]