From c1457bd9f877ec0618b7769731b95ba97b54b830 Mon Sep 17 00:00:00 2001 From: Robin Holzinger Date: Sun, 22 Sep 2024 04:26:16 +0200 Subject: [PATCH] Update experiments --- .../arxiv/compare_trigger_policies/run.py | 192 +++++++- .../huffpost/compare_trigger_policies/run.py | 190 +++++++- .../yearbook/compare_trigger_policies/run.py | 446 +++++++++--------- .../schema/pipeline/trigger/cost/cost.py | 2 +- .../pipeline/trigger/drift/alibi_detect.py | 51 +- ...avoidablemissclassification_costtrigger.py | 4 + .../internal/triggers/costtrigger.py | 2 +- .../internal/triggers/drift/detector/alibi.py | 58 ++- 8 files changed, 652 insertions(+), 293 deletions(-) diff --git a/experiments/arxiv/compare_trigger_policies/run.py b/experiments/arxiv/compare_trigger_policies/run.py index 843a6b6d8..bbc4f984e 100644 --- a/experiments/arxiv/compare_trigger_policies/run.py +++ b/experiments/arxiv/compare_trigger_policies/run.py @@ -9,7 +9,9 @@ EvalHandlerConfig, ModynPipelineConfig, ) +from modyn.config.schema.pipeline.evaluation.config import EvalDataConfig from modyn.config.schema.pipeline.evaluation.handler import EvalHandlerExecutionTime +from modyn.config.schema.pipeline.evaluation.metrics import AccuracyMetricConfig from modyn.config.schema.pipeline.evaluation.strategy.between_two_triggers import ( BetweenTwoTriggersEvalStrategyConfig, ) @@ -19,11 +21,26 @@ from modyn.config.schema.pipeline.evaluation.strategy.slicing import ( SlicingEvalStrategyConfig, ) +from modyn.config.schema.pipeline.trigger.drift.alibi_detect import AlibiDetectMmdDriftMetric +from modyn.config.schema.pipeline.trigger.drift.config import DataDriftTriggerConfig +from modyn.config.schema.pipeline.trigger.drift.criterion import ( + DynamicQuantileThresholdCriterion, + DynamicRollingAverageThresholdCriterion, +) +from modyn.config.schema.pipeline.trigger.drift.detection_window.time_ import TimeWindowingStrategy +from modyn.config.schema.pipeline.trigger.performance.criterion import StaticNumberAvoidableMisclassificationCriterion +from modyn.config.schema.pipeline.trigger.performance.performance import PerformanceTriggerConfig, PerformanceTriggerEvaluationConfig from modyn.config.schema.pipeline.trigger.simple.data_amount import DataAmountTriggerConfig from modyn.config.schema.pipeline.trigger.simple.time import TimeTriggerConfig from modyn.utils.utils import SECONDS_PER_UNIT from modynclient.config.schema.client_config import ModynClientConfig, Supervisor + +from .pipeline_config import ( + arxiv_bytes_parser_function, + arxiv_evaluation_transformer_function, +) + _FIRST_TIMESTAMP = int(pd.to_datetime("1995-01-01").timestamp()) _LAST_TIMESTAMP = int(pd.to_datetime("2024-07-01").timestamp()) @@ -105,40 +122,171 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: # 1X: Baselines with PERIODIC_EVAL_INTERVAL, executed with cautious # # parallelism and post factum evaluation (bottlenecking) # # -------------------------------------------------------------------------------- # - # time baselines - 10: Experiment( - name="arxiv-baseline-time", + # # time baselines + # 10: Experiment( + # name="arxiv-baseline-time", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # time_triggers={ + # schedule: TimeTriggerConfig(every=schedule, start_timestamp=_FIRST_TIMESTAMP) + # for schedule in reversed(["26w", "10y"]) + # # 0: "1y", "2y", "5y" + # # 1: "26w", "10y" + # }, + # gpu_device="cuda:2", 
+ # ), + # # data amount baselines + # 11: Experiment( + # name="arxiv-baseline-dataamount", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # data_amount_triggers={ + # f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) + # for num_samples in reversed([25_000, 50_000]) + # # 2: 100_000, 500_000, 1_000_000 + # # 3: 25_000, 50_000 + # }, + # gpu_device="cuda:3", + # ), + # -------------------------------------------------------------------------------- # + # 2X: Drift triggers # + # -------------------------------------------------------------------------------- # + # TODO + # Dynamic threshold drift + 21: Experiment( + name="arxiv-datadrift-dynamic", eval_handlers=( construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + construct_between_trigger_eval_handler("manual") ), - time_triggers={ - schedule: TimeTriggerConfig(every=schedule, start_timestamp=_FIRST_TIMESTAMP) - for schedule in reversed(["26w", "10y"]) - # 0: "1y", "2y", "5y" - # 1: "26w", "10y" + drift_detection_triggers={ + f"{criterion_name}_int{detection_interval}_win{window_size}": DataDriftTriggerConfig( + evaluation_interval_data_points=detection_interval, + windowing_strategy=TimeWindowingStrategy( + # overlap has no affect acc. to offline exploration + limit_ref=window_size, + limit_cur=window_size, + allow_overlap=False, + ), + # frist 200k of 2mio samples are warmup + warmup_intervals=200_000 // detection_interval, + # triggering every 3 years during the warmup phase seems reasonable. + warmup_policy=TimeTriggerConfig(every="2y", start_timestamp=_FIRST_TIMESTAMP), + # 5k samples are enough for drift detection + sample_size=5_000, + metrics={"mmd": AlibiDetectMmdDriftMetric(decision_criterion=criterion, device="gpu")}, + ) + # multiprocessing across gpus + for detection_interval in [20_000] + for window_size in ["1y"] # dataset specific + for decision_window_size in [15] # TODO: check + for criterion_name, criterion in ( + { + f"mmd-quant-{quantile}-{decision_window_size}": DynamicQuantileThresholdCriterion( + window_size=decision_window_size, quantile=quantile + ) + for quantile in [0.05, 0.10] # TODO: 0.15, 0.3 + # 0: 0.05 + # 1: 0.1 + # 2: + # 3: + } + | + { + f"mmd-rollavg-{deviation}-{decision_window_size}": DynamicRollingAverageThresholdCriterion( + window_size=decision_window_size, deviation=deviation, absolute=False + ) + for deviation in [0.5, 1.0, 2.0] # TODO: 0.05, 0.2, + # 0: + # 1: + # 2: 0.5 + # 3: 1.0, 2.0 + } + ).items() }, - gpu_device="cuda:2", + gpu_device="cuda:0", ), - # data amount baselines - 11: Experiment( - name="arxiv-baseline-dataamount", + # -------------------------------------------------------------------------------- # + # 3X: Performance triggers # + # -------------------------------------------------------------------------------- # + 30: Experiment( + name="arxiv-performancetrigger", eval_handlers=( construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + construct_between_trigger_eval_handler("manual") ), - data_amount_triggers={ - f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) - for num_samples in reversed([25_000, 50_000]) - # 2: 100_000, 500_000, 1_000_000 - # 3: 25_000, 50_000 + performance_triggers={ + f"{criterion_name}-int{detection_interval}y": PerformanceTriggerConfig( + evaluation_interval_data_points=detection_interval, + 
data_density_window_size=20, # performed well for drift, only used for #avoidable misclass + performance_triggers_window_size=20, # performed well for drift, only used for #avoidable misclass + warmup_intervals=200_000 // detection_interval, # first 200k of 2mio samples are warmup + # triggering every 3 years during the warmup phase seems reasonable. + warmup_policy=TimeTriggerConfig(every="2y", start_timestamp=_FIRST_TIMESTAMP), + evaluation=PerformanceTriggerEvaluationConfig( + device="cuda:2", + dataset=EvalDataConfig( + dataset_id="arxiv_kaggle_train", # optional: extra holdout split + bytes_parser_function=arxiv_bytes_parser_function, + batch_size=512, + dataloader_workers=1, + metrics=[ + AccuracyMetricConfig(evaluation_transformer_function=arxiv_evaluation_transformer_function), + ], + ), + ), + mode="hindsight", + forecasting_method="ridge_regression", + decision_criteria={criterion_name: criterion}, + ) + for detection_interval in [20_000] + for criterion_name, criterion in ( + # { + # f"static-{perf_threshold}": StaticPerformanceThresholdCriterion( + # metric="Accuracy", metric_threshold=perf_threshold + # ) + # for perf_threshold in [0.45, 0.5, 0.55, 0.6] + # } + # | + # { + # f"dynamic-quant-{quantile}-{decision_window_size}": DynamicQuantilePerformanceThresholdCriterion( + # metric="Accuracy", + # quantile=quantile, + # window_size=decision_window_size, + # ) + # for quantile in [0.05, 0.15, 0.3] + # for decision_window_size in [15, 30] + # } + # | + # { + # f"dynamic-rollavg-{deviation}-{decision_window_size}": DynamicRollingAveragePerformanceThresholdCriterion( + # metric="Accuracy", + # deviation=deviation, + # absolute=False, + # window_size=decision_window_size, + # ) + # for deviation in reversed([0.1, 0.2, 0.3]) + # for decision_window_size in [15, 30] + # } + # | + { + f"num_misclass-{num_misclassifications}-exp-{expected_accuracy}-red-{allow_reduction}-": StaticNumberAvoidableMisclassificationCriterion( + expected_accuracy=expected_accuracy, + allow_reduction=allow_reduction, + avoidable_misclassification_threshold=num_misclassifications, + ) + for num_misclassifications in reversed([10000]) # 1000, 2000, 5000, 7500, 10000 + for expected_accuracy in [0.5, 0.55, 0.6] + for allow_reduction in [False] + } + ).items() }, - gpu_device="cuda:3", + gpu_device="cuda:2", ), - # -------------------------------------------------------------------------------- # - # 2X: Drift triggers # - # -------------------------------------------------------------------------------- # - # TODO } diff --git a/experiments/huffpost/compare_trigger_policies/run.py b/experiments/huffpost/compare_trigger_policies/run.py index e54b8c8b3..11e4ec8f6 100644 --- a/experiments/huffpost/compare_trigger_policies/run.py +++ b/experiments/huffpost/compare_trigger_policies/run.py @@ -8,7 +8,9 @@ EvalHandlerConfig, ModynPipelineConfig, ) +from modyn.config.schema.pipeline.evaluation.config import EvalDataConfig from modyn.config.schema.pipeline.evaluation.handler import EvalHandlerExecutionTime +from modyn.config.schema.pipeline.evaluation.metrics import AccuracyMetricConfig from modyn.config.schema.pipeline.evaluation.strategy.between_two_triggers import ( BetweenTwoTriggersEvalStrategyConfig, ) @@ -18,11 +20,27 @@ from modyn.config.schema.pipeline.evaluation.strategy.slicing import ( SlicingEvalStrategyConfig, ) -from modyn.config.schema.pipeline.trigger.simple.data_amount import DataAmountTriggerConfig +from modyn.config.schema.pipeline.trigger.drift.alibi_detect import AlibiDetectMmdDriftMetric +from 
modyn.config.schema.pipeline.trigger.drift.config import DataDriftTriggerConfig +from modyn.config.schema.pipeline.trigger.drift.criterion import ( + DynamicQuantileThresholdCriterion, +) +from modyn.config.schema.pipeline.trigger.drift.detection_window.time_ import TimeWindowingStrategy +from modyn.config.schema.pipeline.trigger.performance.criterion import ( + StaticNumberAvoidableMisclassificationCriterion, +) +from modyn.config.schema.pipeline.trigger.performance.performance import ( + PerformanceTriggerConfig, + PerformanceTriggerEvaluationConfig, +) from modyn.config.schema.pipeline.trigger.simple.time import TimeTriggerConfig from modynclient.config.schema.client_config import ModynClientConfig, Supervisor -from .pipeline_config import gen_pipeline_config +from .pipeline_config import ( + gen_pipeline_config, + hp_bytes_parser_function, + hp_evaluation_transformer_function, +) _FIRST_TIMESTAMP = int(pd.to_datetime("2012-01-28").timestamp()) _LAST_TIMESTAMP = int(pd.to_datetime("2022-09-24").timestamp()) # last: dummy @@ -107,40 +125,166 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: # 1X: Baselines with PERIODIC_EVAL_INTERVAL, executed with cautious # # parallelism and post factum evaluation (bottlenecking) # # -------------------------------------------------------------------------------- # - # time baselines - 10: Experiment( - name="hp-baseline-time", + # # time baselines + # 10: Experiment( + # name="hp-baseline-time", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # time_triggers={ + # schedule: TimeTriggerConfig(every=schedule, start_timestamp=_FIRST_TIMESTAMP) + # for schedule in (["13w", "4y"]) # reversed + # # 0: "26w", "1y", "2y" + # # 1: "13w", "4y" + # }, + # gpu_device="cuda:2", + # ), + # # data amount baselines + # 11: Experiment( + # name="hp-baseline-dataamount", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # data_amount_triggers={ + # f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) + # for num_samples in ([5_000, 80_000]) + # # 2: 10_000, 20_000, 40_000 + # # 3: 5_000, 80_000 + # }, + # gpu_device="cuda:3", + # ), + # -------------------------------------------------------------------------------- # + # 2X: Drift triggers # + # -------------------------------------------------------------------------------- # + # TODO: rerun huffpost with different eval set + 21: Experiment( + name="hp-datadrift-dynamic", eval_handlers=( construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + construct_between_trigger_eval_handler("manual") ), - time_triggers={ - schedule: TimeTriggerConfig(every=schedule, start_timestamp=_FIRST_TIMESTAMP) - for schedule in (["26w", "1y", "2y"]) # reversed - # 0: "26w", "1y", "2y" - # 1: "13w", "4y" + drift_detection_triggers={ + f"{criterion_name}_int{detection_interval}_win{window_size}": DataDriftTriggerConfig( + evaluation_interval_data_points=detection_interval, + windowing_strategy=TimeWindowingStrategy( + # overlap has no affect acc. 
to offline exploration + limit_ref=window_size, + limit_cur=window_size, + allow_overlap=False, + ), + # first 200k of 2mio samples are warmup + warmup_intervals=30_000 // detection_interval, + # triggering every 3 years during the warmup phase seems reasonable. + warmup_policy=TimeTriggerConfig(every="40w", start_timestamp=_FIRST_TIMESTAMP), + # 5k samples are enough for drift detection + sample_size=5_000, + metrics={"mmd": AlibiDetectMmdDriftMetric(decision_criterion=criterion, device="gpu")}, + ) + # multiprocessing across gpus + for detection_interval in [1500] + for window_size in ["1y"] # dataset specific + for decision_window_size in [15, 30] # TODO: check + for criterion_name, criterion in ( + { + f"mmd-quant-{quantile}-{decision_window_size}": DynamicQuantileThresholdCriterion( + window_size=decision_window_size, quantile=quantile + ) + for quantile in [0.02, 0.05, 0.10, 0.15] # TODO: 0.3 + } + # | + # { + # f"mmd-rollavg-{deviation}-{decision_window_size}": DynamicRollingAverageThresholdCriterion( + # window_size=decision_window_size, deviation=deviation, absolute=False + # ) + # for deviation in [0.5, 1.0, 2.0, 5.0] # TODO: 0.05, 0.2, + # # 0: + # # 1: + # # 2: 0.5 + # # 3: 1.0, 2.0 + # } + ).items() }, gpu_device="cuda:0", ), - # data amount baselines - 11: Experiment( - name="hp-baseline-dataamount", + # -------------------------------------------------------------------------------- # + # 3X: Performance triggers # + # -------------------------------------------------------------------------------- # + 30: Experiment( + name="hp-performancetrigger", eval_handlers=( construct_periodic_eval_handlers(intervals=PERIODIC_EVAL_INTERVAL, execution_time="manual") + construct_between_trigger_eval_handler("manual") ), - data_amount_triggers={ - f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) - for num_samples in ([10_000, 20_000, 40_000]) - # 2: 10_000, 20_000, 40_000 - # 3: 5_000, 80_000 + performance_triggers={ + f"{criterion_name}-int{detection_interval}y": PerformanceTriggerConfig( + evaluation_interval_data_points=detection_interval, + data_density_window_size=20, # performed well for drift, only used for #avoidable misclass + performance_triggers_window_size=20, # performed well for drift, only used for #avoidable misclass + warmup_intervals=30_000 // detection_interval, # first 200k of 2mio samples are warmup + # triggering every 3 years during the warmup phase seems reasonable. 
+ warmup_policy=TimeTriggerConfig(every="40w", start_timestamp=_FIRST_TIMESTAMP), + evaluation=PerformanceTriggerEvaluationConfig( + device="cuda:2", + dataset=EvalDataConfig( + dataset_id="huffpost_kaggle_train", # optional: extra holdout split + bytes_parser_function=hp_bytes_parser_function, + batch_size=512, + dataloader_workers=1, + metrics=[ + AccuracyMetricConfig(evaluation_transformer_function=hp_evaluation_transformer_function), + ], + ), + ), + mode="hindsight", + forecasting_method="ridge_regression", + decision_criteria={criterion_name: criterion}, + ) + for detection_interval in [1500] + for criterion_name, criterion in ( + # { + # f"static-{perf_threshold}": StaticPerformanceThresholdCriterion( + # metric="Accuracy", metric_threshold=perf_threshold + # ) + # for perf_threshold in [0.45, 0.5, 0.55, 0.6] + # } + # | + # { + # f"dynamic-quant-{quantile}-{decision_window_size}": DynamicQuantilePerformanceThresholdCriterion( + # metric="Accuracy", + # quantile=quantile, + # window_size=decision_window_size, + # ) + # for quantile in [0.05, 0.15, 0.3] + # for decision_window_size in [15, 30] + # } + # | + # { + # f"dynamic-rollavg-{deviation}-{decision_window_size}": DynamicRollingAveragePerformanceThresholdCriterion( + # metric="Accuracy", + # deviation=deviation, + # absolute=False, + # window_size=decision_window_size, + # ) + # for deviation in reversed([0.1, 0.2, 0.3]) + # for decision_window_size in [15, 30] + # } + # | + { + f"num_misclass-{num_misclassifications}-exp-{expected_accuracy}-red-{allow_reduction}-": StaticNumberAvoidableMisclassificationCriterion( + expected_accuracy=expected_accuracy, + allow_reduction=allow_reduction, + avoidable_misclassification_threshold=num_misclassifications, + ) + for num_misclassifications in reversed([10000]) # 1000, 2000, 5000, 7500, 10000 + for expected_accuracy in [0.5, 0.55, 0.6] + for allow_reduction in [False] # TODO: test with [False] + } + ).items() }, - gpu_device="cuda:1", + gpu_device="cuda:2", ), - # -------------------------------------------------------------------------------- # - # 2X: Drift triggers # - # -------------------------------------------------------------------------------- # - # TODO } diff --git a/experiments/yearbook/compare_trigger_policies/run.py b/experiments/yearbook/compare_trigger_policies/run.py index a64289ec0..1a4d4db80 100644 --- a/experiments/yearbook/compare_trigger_policies/run.py +++ b/experiments/yearbook/compare_trigger_policies/run.py @@ -21,33 +21,14 @@ from modyn.config.schema.pipeline.evaluation.strategy.slicing import ( SlicingEvalStrategyConfig, ) -from modyn.config.schema.pipeline.trigger import DataDriftTriggerConfig from modyn.config.schema.pipeline.trigger.cost.cost import ( AvoidableMisclassificationCostTriggerConfig, DataIncorporationLatencyCostTriggerConfig, ) -from modyn.config.schema.pipeline.trigger.drift.alibi_detect import ( - AlibiDetectMmdDriftMetric, -) -from modyn.config.schema.pipeline.trigger.drift.criterion import ( - DynamicRollingAverageThresholdCriterion, -) -from modyn.config.schema.pipeline.trigger.drift.detection_window.time_ import ( - TimeWindowingStrategy, -) -from modyn.config.schema.pipeline.trigger.ensemble import ( - AtLeastNEnsembleStrategy, - EnsembleTriggerConfig, -) -from modyn.config.schema.pipeline.trigger.performance.criterion import ( - DynamicQuantilePerformanceThresholdCriterion, - StaticNumberAvoidableMisclassificationCriterion, - StaticPerformanceThresholdCriterion, -) from modyn.config.schema.pipeline.trigger.performance.performance import 
( - PerformanceTriggerConfig, PerformanceTriggerEvaluationConfig, ) +from modyn.config.schema.pipeline.trigger.simple.data_amount import DataAmountTriggerConfig from modyn.utils.utils import SECONDS_PER_UNIT from modynclient.config.schema.client_config import ModynClientConfig, Supervisor @@ -209,22 +190,23 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: # gpu_device="cuda:1", # ), # # data amount baselines - # 11: Experiment( - # name="yb-baseline-dataamount", - # eval_handlers=( - # construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") + - # construct_between_trigger_eval_handler("manual") - # ), - # data_amount_triggers={ - # f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) - # for num_samples in ([250, 500, 1_000, 2_500, 5_000, 10_000, 15_000, 30_000]) - # }, - # gpu_device="cuda:2", - # ), + 11: Experiment( + name="yb-baseline-dataamount", + eval_handlers=( + construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") + + construct_between_trigger_eval_handler("manual") + ), + data_amount_triggers={ + f"{num_samples}": DataAmountTriggerConfig(num_samples=num_samples) + # for num_samples in ([250, 500, 1_000, 2_500, 5_000, 10_000, 15_000, 30_000]) + # for num_samples in ([1_000, 2_500, 5_000, 10_000]) + for num_samples in ([250, 500, 15_000, 30_000]) + }, + gpu_device="cuda:2", + ), # -------------------------------------------------------------------------------- # # 2X: Drift triggers # # -------------------------------------------------------------------------------- # - # TODO: check if these experiments also have all values for other handlers # Static threshold drift # 20: Experiment( # name="yb-datadrift-static", @@ -269,118 +251,136 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: # }, # gpu_device="cuda:3", # ), - # Dynamic threshold drift - 21: Experiment( - name="yb-datadrift-dynamic", - eval_handlers=( - construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") - + construct_between_trigger_eval_handler("manual") - ), - drift_detection_triggers={ - f"{criterion_name}_int{detection_interval}_win{window_size}": DataDriftTriggerConfig( - evaluation_interval_data_points=detection_interval, - windowing_strategy=TimeWindowingStrategy( - # overlap has no affect acc. to offline exploration - limit_ref=window_size, - limit_cur=window_size, - allow_overlap=False, - ), - # with 30k samples and 84 years, 10y are roughly 30000/84*10=3500 samples - # hence, if we want ~10 years of warmup, to 3500/detection_interval warmup intervals - warmup_intervals=3500 // detection_interval, - # triggering every 3 years during the warmup phase seems reasonable. - warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), - # 5k samples are enough for drift detection, in yearbook we won't accumulate that many anyway - sample_size=5_000, - metrics={"mmd": AlibiDetectMmdDriftMetric(decision_criterion=criterion, device="gpu")}, - ) - # multiprocessing across gpus - # TODO: 0: 100 - # TODO: 1: 250 - # TODO: 2: 500 - for detection_interval in [100] # 100, 250, 500 - for window_size in ["4d"] # dataset specific, best acc. 
to offline exploraion and static drift experiments - for decision_window_size in [30] # 10, 20, - for criterion_name, criterion in ( - # { - # f"mmd-perc-{percentile}-{decision_window_size}": DynamicPercentileThresholdCriterion( - # window_size=decision_window_size, percentile=percentile - # ) - # for percentile in [0.05, 0.15, 0.3] - # } - # | - { - f"mmd-rollavg-{deviation}-{decision_window_size}": DynamicRollingAverageThresholdCriterion( - window_size=decision_window_size, deviation=deviation, absolute=False - ) - for deviation in [0.05, 0.2, 0.5, 1.0, 2.0] - } - ).items() - }, - gpu_device="cuda:3", - ), - # -------------------------------------------------------------------------------- # - # 3X: Performance triggers # - # -------------------------------------------------------------------------------- # - 30: Experiment( - name="yb-performancetrigger", - eval_handlers=( - construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") - + construct_between_trigger_eval_handler("manual") - ), - performance_triggers={ - f"{criterion_name}-int{detection_interval}y": PerformanceTriggerConfig( - evaluation_interval_data_points=detection_interval, - data_density_window_size=20, - performance_triggers_window_size=20, - warmup_intervals=10, # TODO: link to window_size - warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), - evaluation=PerformanceTriggerEvaluationConfig( - device="cuda:2", - dataset=EvalDataConfig( - dataset_id="yearbook_train", - bytes_parser_function=yb_bytes_parser_function, - batch_size=512, - dataloader_workers=1, - metrics=[ - AccuracyMetricConfig(evaluation_transformer_function=yb_evaluation_transformer_function), - ], - ), - ), - mode="hindsight", - forecasting_method="ridge_regression", - decision_criteria={criterion_name: criterion}, - ) - for detection_interval in [100, 250, 500] - for criterion_name, criterion in ( - { - f"static-{perf_threshold}": StaticPerformanceThresholdCriterion( # TODO: check if bug is fixed - metric="Accuracy", metric_threshold=perf_threshold - ) - for perf_threshold in [0.7, 0.75, 0.8, 0.85, 0.9, 0.95] - } - | { - f"dynamic-{deviation}": DynamicQuantilePerformanceThresholdCriterion( # TODO: check if bug is fixed - metric="Accuracy", - deviation=deviation, - absolute=False, - ) - for deviation in [0.025, 0.05, 0.1, 0.2, 0.3] - } - # TODO: dynamic rolling average - | { - f"num_misclass-{num_misclassifications}-{allow_reduction}-": StaticNumberAvoidableMisclassificationCriterion( - expected_accuracy=0.95, # TODO: variable - allow_reduction=allow_reduction, - avoidable_misclassification_threshold=num_misclassifications, - ) # TODO: avg / quantile - for num_misclassifications in [100, 200, 500, 1000, 2000, 5000] - for allow_reduction in [True, False] - } - ).items() - }, - gpu_device="cuda:2", - ), + # # Dynamic threshold drift + # 21: Experiment( + # name="yb-datadrift-dynamic", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # drift_detection_triggers={ + # f"{criterion_name}_int{detection_interval}_win{window_size}": DataDriftTriggerConfig( + # evaluation_interval_data_points=detection_interval, + # windowing_strategy=TimeWindowingStrategy( + # # overlap has no affect acc. 
to offline exploration + # limit_ref=window_size, + # limit_cur=window_size, + # allow_overlap=False, + # ), + # # with 30k samples and 84 years, 10y are roughly 30000/84*10=3500 samples + # # hence, if we want ~10 years of warmup, to 3500/detection_interval warmup intervals + # warmup_intervals=3500 // detection_interval, + # # triggering every 3 years during the warmup phase seems reasonable. + # warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), + # # 5k samples are enough for drift detection, in yearbook we won't accumulate that many anyway + # sample_size=5_000, + # metrics={"mmd": AlibiDetectMmdDriftMetric(decision_criterion=criterion, device="gpu")}, + # ) + # # multiprocessing across gpus + # for detection_interval in reversed([100, 250, 500]) + # for window_size in ["4d"] # dataset specific, best acc. to offline exploraion and static drift experiments + # for decision_window_size in [10, 20, 30] + # # cuda:1: 10 + # # cuda:2: 20 + # # cuda:3: 30 + # for criterion_name, criterion in ( + # { + # f"mmd-quant-{quantile}-{decision_window_size}": DynamicQuantileThresholdCriterion( + # window_size=decision_window_size, quantile=quantile + # ) + # for quantile in [0.05, 0.1, 0.15, 0.3] + # } + # | + # { + # f"mmd-rollavg-{deviation}-{decision_window_size}": DynamicRollingAverageThresholdCriterion( + # window_size=decision_window_size, deviation=deviation, absolute=False + # ) + # for deviation in [0.05, 0.2, 0.5, 1.0, 2.0] + # } + # ).items() + # }, + # gpu_device="cuda:0", + # ), + # # -------------------------------------------------------------------------------- # + # # 3X: Performance triggers # + # # -------------------------------------------------------------------------------- # + # 30: Experiment( + # name="yb-performancetrigger", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # performance_triggers={ + # f"{criterion_name}-int{detection_interval}y": PerformanceTriggerConfig( + # evaluation_interval_data_points=detection_interval, + # data_density_window_size=20, # performed well for drift, only used for #avoidable misclass + # performance_triggers_window_size=20, # performed well for drift, only used for #avoidable misclass + # warmup_intervals=3500 // detection_interval, # same as in drift case + # warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), + # evaluation=PerformanceTriggerEvaluationConfig( + # device="cuda:0", + # dataset=EvalDataConfig( + # dataset_id="yearbook_train", # optional: extra holdout split + # bytes_parser_function=yb_bytes_parser_function, + # batch_size=512, + # dataloader_workers=1, + # metrics=[ + # AccuracyMetricConfig(evaluation_transformer_function=yb_evaluation_transformer_function), + # ], + # ), + # ), + # mode="hindsight", + # forecasting_method="ridge_regression", + # decision_criteria={criterion_name: criterion}, + # ) + # # for detection_interval in [100, 250, 500] + # for detection_interval in [100] + # # cuda1: 100 + # # cuda2: 250 + # # cuda3: 500 + # # cuda0: 100, 250, 500 - num_misscl - 100, 200, 500 + # for criterion_name, criterion in ( + # # { + # # f"static-{perf_threshold}": StaticPerformanceThresholdCriterion( + # # metric="Accuracy", metric_threshold=perf_threshold + # # ) + # # for perf_threshold in [0.7, 0.75, 0.8, 0.85, 0.875, 0.9, 0.925, 0.95] + # # } + # # | { + # # f"dynamic-quant-{quantile}-{decision_window_size}": 
DynamicQuantilePerformanceThresholdCriterion( + # # metric="Accuracy", + # # quantile=quantile, + # # window_size=decision_window_size, + # # ) + # # for quantile in [0.05, 0.15, 0.3] + # # for decision_window_size in [10, 20, 30] + # # } + # # | + # { + # f"dynamic-rollavg-{deviation}-{decision_window_size}": DynamicRollingAveragePerformanceThresholdCriterion( + # metric="Accuracy", + # deviation=deviation, + # absolute=False, + # window_size=decision_window_size, + # ) + # for deviation in reversed([0.05, 0.1, 0.2, 0.3]) # TODO: delete: 0.025 + # for decision_window_size in [10, 20, 30] + # } + # # | + # # { + # # f"num_misclass-{num_misclassifications}-exp-{expected_accuracy}-red-{allow_reduction}-": StaticNumberAvoidableMisclassificationCriterion( + # # expected_accuracy=expected_accuracy, + # # allow_reduction=allow_reduction, + # # avoidable_misclassification_threshold=num_misclassifications, + # # ) + # # for num_misclassifications in reversed([100, 200, 500]) # TODO: 100, 200, 500, 1000, 2000, 5000 + # # for expected_accuracy in [0.85, 0.9, 0.95] # TODO last successful: yearbook_performancetrigger_num_misclass-200-exp-0.85-red-False--int500y --> mind the reversed + # # for allow_reduction in [True, False] + # # } + # ).items() + # }, + # gpu_device="cuda:0", + # ), # -------------------------------------------------------------------------------- # # 4X: Cost aware triggers # # -------------------------------------------------------------------------------- # @@ -397,10 +397,10 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: cost_tracking_window_size=20, incorporation_delay_per_training_second=exchange_rate, ) - for interval in [100, 250, 500, 1_000] + for interval in reversed([100, 250, 500, 1_000]) for exchange_rate in [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0] }, - gpu_device="cuda:3", + gpu_device="cuda:0", ), # avoidable misclassfication integration trigger 41: Experiment( @@ -410,34 +410,38 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: + construct_between_trigger_eval_handler("manual") ), cost_triggers={ - f"int{interval}_exch{exchange_rate}_red{allow_reduction}": AvoidableMisclassificationCostTriggerConfig( + f"int{detection_interval}_exch{exchange_rate}_red{allow_reduction}": AvoidableMisclassificationCostTriggerConfig( # cost trigger params - expected_accuracy=0.9, + expected_accuracy=0.9, # assumed to work out ask it worked well for performance triggers cost_tracking_window_size=50, avoidable_misclassification_latency_per_training_second=exchange_rate, # performance trigger params - evaluation_interval_data_points=interval, + evaluation_interval_data_points=detection_interval, data_density_window_size=20, performance_triggers_window_size=20, - warmup_intervals=10, # TODO: link to window_size + warmup_intervals=3500 // detection_interval, # same as in drift case warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), evaluation=PerformanceTriggerEvaluationConfig( - device="cuda:2", + device="cuda:1", dataset=EvalDataConfig( dataset_id="yearbook_train", bytes_parser_function=yb_bytes_parser_function, - batch_size=512, # TODO: lower + batch_size=512, dataloader_workers=1, metrics=[ AccuracyMetricConfig(evaluation_transformer_function=yb_evaluation_transformer_function), ], ), ), - mode="hindsight", # TODO: lookahead + mode="hindsight", forecasting_method="ridge_regression", ) - for interval in [100, 250, 500, 1_000] - for exchange_rate in [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0] + # for 
detection_interval in [100, 250, 500] + for detection_interval in [500] + # cuda:1 - 100 + # cuda:2 - 250 + # cuda:3 - 500 + for exchange_rate in [1_000_000_000] for allow_reduction in [True, False] }, gpu_device="cuda:1", @@ -445,87 +449,57 @@ def construct_pipelines(experiment: Experiment) -> list[ModynPipelineConfig]: # -------------------------------------------------------------------------------- # # 5X: Ensemble triggers # # -------------------------------------------------------------------------------- # - # with best working previous triggers - 51: Experiment( - name="yb-ensemble", - eval_handlers=( - construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") - + construct_between_trigger_eval_handler("manual") - ), - ensemble_triggers={ - "ensemble1": EnsembleTriggerConfig( - subtriggers={ - "drift1": DataDriftTriggerConfig( - evaluation_interval_data_points=500, - windowing_strategy=TimeWindowingStrategy(limit_ref="4d", limit_cur="4d"), - warmup_intervals=10, - warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), - metrics={ - "mmd": AlibiDetectMmdDriftMetric( - device="gpu", - decision_criterion=DynamicRollingAverageThresholdCriterion( - deviation=0.1, absolute=False, window_size=15 - ), - ) - }, - ), - "perf1": PerformanceTriggerConfig( - evaluation_interval_data_points=500, - data_density_window_size=20, - performance_triggers_window_size=20, - evaluation=PerformanceTriggerEvaluationConfig( - device="cuda:0", - dataset=EvalDataConfig( - dataset_id="yearbook_train", - bytes_parser_function=yb_bytes_parser_function, - batch_size=64, - dataloader_workers=1, - metrics=[ - AccuracyMetricConfig( - evaluation_transformer_function=yb_evaluation_transformer_function - ), - ], - ), - ), - mode="hindsight", # TODO: lookahead - forecasting_method="ridge_regression", - decision_criteria={ - "static-0.8": StaticPerformanceThresholdCriterion(metric="Accuracy", metric_threshold=0.8) - }, - ), - }, - ensemble_strategy=AtLeastNEnsembleStrategy(n=1), - ) - }, - gpu_device="cuda:0", - ), - # ----------------------------- Evaluation intervals ----------------------------- # - # 30: Experiment( - # name="yb-drift-interval-cost", - # eval_handlers=[ - # construct_slicing_eval_handler(), - # construct_between_trigger_eval_handler(), - # ], - # time_triggers={}, - # data_amount_triggers={}, - # drift_detection_triggers={ - # f"detection_interval_{detection_interval}": DataDriftTriggerConfig( - # evaluation_interval_data_points=detection_interval, - # windowing_strategy=TimeWindowingStrategy( - # limit_ref="4d", limit_cur="4d", - # ), - # warmup_intervals=10, - # warmup_policy=TimeTriggerConfig( - # every="3d", start_timestamp=_FIRST_TIMESTAMP - # ), - # metrics={ - # "mmd": AlibiDetectMmdDriftMetric( - # decision_criterion=DynamicThresholdCriterion(window_size=10), - # device="gpu", - # ) - # } + # # with best working previous triggers + # 51: Experiment( + # name="yb-ensemble", + # eval_handlers=( + # construct_periodic_eval_handlers(intervals=BEST_PERIODIC_EVAL_INTERVAL, execution_time="manual") + # + construct_between_trigger_eval_handler("manual") + # ), + # ensemble_triggers={ + # "ensemble1": EnsembleTriggerConfig( + # subtriggers={ + # "drift1": DataDriftTriggerConfig( + # evaluation_interval_data_points=500, + # windowing_strategy=TimeWindowingStrategy(limit_ref="4d", limit_cur="4d"), + # warmup_intervals=10, + # warmup_policy=TimeTriggerConfig(every="3d", start_timestamp=_FIRST_TIMESTAMP), + # metrics={ + # "mmd": 
AlibiDetectMmdDriftMetric( + # device="gpu", + # decision_criterion=DynamicRollingAverageThresholdCriterion( + # deviation=0.1, absolute=False, window_size=15 + # ), + # ) + # }, + # ), + # "perf1": PerformanceTriggerConfig( + # evaluation_interval_data_points=500, + # data_density_window_size=20, + # performance_triggers_window_size=20, + # evaluation=PerformanceTriggerEvaluationConfig( + # device="cuda:0", + # dataset=EvalDataConfig( + # dataset_id="yearbook_train", + # bytes_parser_function=yb_bytes_parser_function, + # batch_size=64, + # dataloader_workers=1, + # metrics=[ + # AccuracyMetricConfig( + # evaluation_transformer_function=yb_evaluation_transformer_function + # ), + # ], + # ), + # ), + # mode="hindsight", # TODO: lookahead + # forecasting_method="ridge_regression", + # decision_criteria={ + # "static-0.8": StaticPerformanceThresholdCriterion(metric="Accuracy", metric_threshold=0.8) + # }, + # ), + # }, + # ensemble_strategy=AtLeastNEnsembleStrategy(n=1), # ) - # for detection_interval in [100, 200, 500, 1_000, 2_500, 5_000, 10_000, 15_000] # }, # gpu_device="cuda:0", # ), diff --git a/modyn/config/schema/pipeline/trigger/cost/cost.py b/modyn/config/schema/pipeline/trigger/cost/cost.py index ea478cf63..53d9f5639 100644 --- a/modyn/config/schema/pipeline/trigger/cost/cost.py +++ b/modyn/config/schema/pipeline/trigger/cost/cost.py @@ -97,7 +97,7 @@ class AvoidableMisclassificationCostTriggerConfig( `PerformanceTriggerConfig`. """ - id: Literal["AvoidableMisclassificationCostTrigger"] = Field("AvoidableMisclassificationCost") + id: Literal["AvoidableMisclassificationCostTrigger"] = Field("AvoidableMisclassificationCostTrigger") # Conversion rate between budget (training time) and regret metric (misclassifications) avoidable_misclassification_latency_per_training_second: float = Field( diff --git a/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py b/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py index b0cecb15f..bd05cf35f 100644 --- a/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py +++ b/modyn/config/schema/pipeline/trigger/drift/alibi_detect.py @@ -14,9 +14,13 @@ class _AlibiDetectBaseDriftMetric(BaseMetric): - p_val: float = Field(0.05, description="The p-value threshold for the drift detection.") + p_val: float = Field( + 0.05, description="The p-value threshold for the drift detection." + ) x_ref_preprocessed: bool = Field(False) - preprocessor: AlibiDetectNLPreprocessor | None = Field(None, description="Preprocessor function.") + preprocessor: AlibiDetectNLPreprocessor | None = Field( + None, description="Preprocessor function." 
+ ) class AlibiDetectDeviceMixin(ModynBaseModel): @@ -69,8 +73,12 @@ def validate_threshold_permutations(self) -> "AlibiDetectMmdDriftMetric": return self -class AlibiDetectClassifierDriftMetric(_AlibiDetectBaseDriftMetric, AlibiDetectDeviceMixin): - id: Literal["AlibiDetectClassifierDriftMetric"] = Field("AlibiDetectClassifierDriftMetric") +class AlibiDetectClassifierDriftMetric( + _AlibiDetectBaseDriftMetric, AlibiDetectDeviceMixin +): + id: Literal["AlibiDetectClassifierDriftMetric"] = Field( + "AlibiDetectClassifierDriftMetric" + ) classifier_id: str = Field( description="The model to use for classifications; has to be registered in alibi_detector.py" ) @@ -84,11 +92,42 @@ class AlibiDetectKSDriftMetric( id: Literal["AlibiDetectKSDriftMetric"] = Field("AlibiDetectKSDriftMetric") -class AlibiDetectCVMDriftMetric(_AlibiDetectBaseDriftMetric, _AlibiDetectCorrectionMixin): +class AlibiDetectCVMDriftMetric( + _AlibiDetectBaseDriftMetric, _AlibiDetectCorrectionMixin +): id: Literal["AlibiDetectCVMDriftMetric"] = Field("AlibiDetectCVMDriftMetric") +class AlibiDetectLSDDDriftMetric( + _AlibiDetectBaseDriftMetric, _AlibiDetectCorrectionMixin, AlibiDetectDeviceMixin +): + id: Literal["AlibiDetectLSDDDriftMetric"] = Field("AlibiDetectLSDDDriftMetric") + + +class AlibiDetectFETDriftMetric( + _AlibiDetectBaseDriftMetric, + _AlibiDetectCorrectionMixin, + _AlibiDetectAlternativeMixin, +): + id: Literal["AlibiDetectFETDriftMetric"] = Field("AlibiDetectFETDriftMetric") + n_features: int | None = Field(None) + + +class AlibiDetectChiSquareDriftMetric( + _AlibiDetectBaseDriftMetric, _AlibiDetectCorrectionMixin +): + id: Literal["AlibiDetectChiSquareDriftMetric"] = Field( + "AlibiDetectChiSquareDriftMetric" + ) + n_features: int | None = Field(None) + + AlibiDetectDriftMetric = Annotated[ - AlibiDetectMmdDriftMetric | AlibiDetectKSDriftMetric | AlibiDetectCVMDriftMetric, + AlibiDetectMmdDriftMetric + | AlibiDetectKSDriftMetric + | AlibiDetectCVMDriftMetric + | AlibiDetectLSDDDriftMetric + | AlibiDetectFETDriftMetric + | AlibiDetectChiSquareDriftMetric, Field(discriminator="id"), ] diff --git a/modyn/supervisor/internal/triggers/avoidablemissclassification_costtrigger.py b/modyn/supervisor/internal/triggers/avoidablemissclassification_costtrigger.py index 51131266a..580982871 100644 --- a/modyn/supervisor/internal/triggers/avoidablemissclassification_costtrigger.py +++ b/modyn/supervisor/internal/triggers/avoidablemissclassification_costtrigger.py @@ -67,6 +67,10 @@ def _compute_regret_metric( """Compute the regret metric for the current state of the trigger.""" self.data_density.inform_data(batch) + + if not self._triggered_once: + return 0.0, {} # we don't have a model to evaluate yet + model_id, num_samples, num_misclassifications, evaluation_scores = self._run_evaluation(interval_data=batch) self.performance_tracker.inform_evaluation( diff --git a/modyn/supervisor/internal/triggers/costtrigger.py b/modyn/supervisor/internal/triggers/costtrigger.py index 697cb4725..1b6c211ae 100644 --- a/modyn/supervisor/internal/triggers/costtrigger.py +++ b/modyn/supervisor/internal/triggers/costtrigger.py @@ -59,7 +59,7 @@ def _evaluate_batch( traintime_estimate = -1.0 regret_metric, regret_log = self._compute_regret_metric(batch, batch_start, batch_duration) - regret_in_traintime_unit = regret_metric * self.config.conversion_factor + regret_in_traintime_unit = regret_metric / self.config.conversion_factor # --------------------------------------------- Trigger Decision --------------------------------------------- # 
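Note on the `*` -> `/` change in costtrigger.py above: it reads as a unit fix, assuming (as the field names suggest) that the conversion factor (`avoidable_misclassification_latency_per_training_second`, or `incorporation_delay_per_training_second` for the incorporation-latency variant) expresses regret units per second of training, so turning an observed regret into its training-time equivalent means dividing by it. A minimal self-contained sketch of that conversion under this assumption; the helper below is illustrative only, not Modyn's API:

def regret_to_traintime_seconds(regret: float, regret_per_training_second: float) -> float:
    # [regret units] / ([regret units] / [training second]) = [training seconds]
    return regret / regret_per_training_second

# Example: 500 avoidable misclassifications at an exchange rate of 0.1
# misclassifications per training second are "worth" 5000 seconds of training,
# which the cost trigger then weighs against the estimated retraining time.
assert regret_to_traintime_seconds(500, 0.1) == 5_000.0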
diff --git a/modyn/supervisor/internal/triggers/drift/detector/alibi.py b/modyn/supervisor/internal/triggers/drift/detector/alibi.py index 239b554f0..9468aaa5c 100644 --- a/modyn/supervisor/internal/triggers/drift/detector/alibi.py +++ b/modyn/supervisor/internal/triggers/drift/detector/alibi.py @@ -20,22 +20,36 @@ MetricResult, ) from modyn.config.schema.pipeline.trigger.drift.alibi_detect import ( + AlibiDetectChiSquareDriftMetric, AlibiDetectClassifierDriftMetric, AlibiDetectCVMDriftMetric, + AlibiDetectFETDriftMetric, AlibiDetectKSDriftMetric, + AlibiDetectLSDDDriftMetric, ) from modyn.supervisor.internal.triggers.drift.classifier_models import ( alibi_classifier_models, ) from modyn.supervisor.internal.triggers.drift.detector.drift import DriftDetector -_AlibiMetrics = MMDDrift | ClassifierDrift | ChiSquareDrift | CVMDrift | FETDrift | KSDrift | LSDDDrift | MMDDrift +_AlibiMetrics = ( + MMDDrift + | ClassifierDrift + | ChiSquareDrift + | CVMDrift + | FETDrift + | KSDrift + | LSDDDrift + | MMDDrift +) class AlibiDriftDetector(DriftDetector): def __init__(self, metrics_config: dict[str, AlibiDetectDriftMetric]): alibi_metrics_config = { - metric_ref: config for metric_ref, config in metrics_config.items() if config.id.startswith("AlibiDetect") + metric_ref: config + for metric_ref, config in metrics_config.items() + if config.id.startswith("AlibiDetect") } super().__init__(alibi_metrics_config) @@ -109,7 +123,9 @@ def _alibi_detect_metric_factory(config: AlibiDetectDriftMetric, embeddings_ref: kwargs = {} if config.preprocessor: - kwargs.update({"preprocess_fn": config.preprocessor.gen_preprocess_fn(config.device)}) + kwargs.update( + {"preprocess_fn": config.preprocessor.gen_preprocess_fn(config.device)} + ) if isinstance(config, AlibiDetectMmdDriftMetric): assert kernel is not None @@ -154,4 +170,38 @@ def _alibi_detect_metric_factory(config: AlibiDetectDriftMetric, embeddings_ref: **kwargs, ) - raise NotImplementedError(f"Metric {config.id} is not supported in AlibiDetectDriftMetric.") + if isinstance(config, AlibiDetectLSDDDriftMetric): + return LSDDDrift( + x_ref=embeddings_ref, + backend="pytorch", + n_permutations=config.num_permutations or 1, + p_val=config.p_val, + correction=config.correction, + x_ref_preprocessed=config.x_ref_preprocessed, + device=config.device, + **kwargs, + ) + + if isinstance(config, AlibiDetectFETDriftMetric): + return FETDrift( + x_ref=embeddings_ref, + p_val=config.p_val, + correction=config.correction, + x_ref_preprocessed=config.x_ref_preprocessed, + n_features=config.n_features, + **kwargs, + ) + + if isinstance(config, AlibiDetectChiSquareDriftMetric): + return ChiSquareDrift( + x_ref=embeddings_ref, + p_val=config.p_val, + correction=config.correction, + x_ref_preprocessed=config.x_ref_preprocessed, + n_features=config.n_features, + **kwargs, + ) + + raise NotImplementedError( + f"Metric {config.id} is not supported in AlibiDetectDriftMetric." + )
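For context on how the new `AlibiDetectLSDDDriftMetric`, `AlibiDetectFETDriftMetric`, and `AlibiDetectChiSquareDriftMetric` configs plug in: the `AlibiDetectDriftMetric` union discriminates on the `id` literal, so a raw config dict is parsed into the matching class, and `_alibi_detect_metric_factory` then dispatches on the concrete type to build the corresponding alibi-detect detector. A minimal sketch of that pattern with toy classes (not the real Modyn schema), assuming pydantic v2:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class MmdMetric(BaseModel):
    id: Literal["Mmd"] = "Mmd"
    device: str = "cpu"


class LsddMetric(BaseModel):
    id: Literal["Lsdd"] = "Lsdd"
    device: str = "cpu"


# The `id` field acts as the discriminator, mirroring AlibiDetectDriftMetric.
DriftMetric = Annotated[MmdMetric | LsddMetric, Field(discriminator="id")]


def build_detector_name(config: MmdMetric | LsddMetric) -> str:
    # Mirrors the isinstance dispatch in _alibi_detect_metric_factory:
    # every config class needs exactly one matching branch.
    if isinstance(config, MmdMetric):
        return "MMDDrift"
    if isinstance(config, LsddMetric):
        return "LSDDDrift"
    raise NotImplementedError(f"{config.id} is not supported.")


# A raw dict (e.g. from a pipeline YAML) is routed to the right class by `id`.
metric = TypeAdapter(DriftMetric).validate_python({"id": "Lsdd", "device": "cuda:0"})
assert isinstance(metric, LsddMetric)
assert build_detector_name(metric) == "LSDDDrift"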