diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0081fb82..6776cb76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,10 +10,8 @@ repos: hooks: - id: flake8 name: flake8 (code linting) - language_version: python3.9 - repo: https://github.com/psf/black rev: 22.12.0 # New version tags can be found here: https://github.com/psf/black/tags hooks: - id: black name: black (code formatting) - language_version: python3.9 diff --git a/timely_beliefs/beliefs/classes.py b/timely_beliefs/beliefs/classes.py index 69484cc9..c21c4268 100644 --- a/timely_beliefs/beliefs/classes.py +++ b/timely_beliefs/beliefs/classes.py @@ -398,6 +398,36 @@ def search_session( # noqa: C901 get_bounds=True, ) + def apply_event_timing_filters(q): + """Apply filters that concern the event time. + + This covers the event_starts_after, event_ends_after, event_starts_before and event_ends_before criteria. + """ + if not pd.isnull(event_starts_after): + q = q.filter(cls.event_start >= event_starts_after) + if not pd.isnull(event_ends_after): + if sensor.event_resolution == timedelta(0): + # inclusive + q = q.filter(cls.event_start >= event_ends_after) + else: + # exclusive + q = q.filter( + cls.event_start > event_ends_after - sensor.event_resolution + ) + if not pd.isnull(event_starts_before): + if sensor.event_resolution == timedelta(0): + # inclusive + q = q.filter(cls.event_start <= event_starts_before) + else: + # exclusive + q = q.filter(cls.event_start < event_starts_before) + if not pd.isnull(event_ends_before): + q = q.filter( + cls.event_start <= event_ends_before - sensor.event_resolution + ) + + return q + def apply_belief_timing_filters(q): """Apply filters that concern the belief timing. 
@@ -411,8 +441,8 @@ def apply_belief_timing_filters(q): knowledge_horizon_min, timedelta.min ): q = q.filter( - cls.event_start - >= beliefs_after + cls.belief_horizon + knowledge_horizon_min + cls.event_start - cls.belief_horizon + >= beliefs_after + knowledge_horizon_min ) if not pd.isnull( beliefs_before @@ -420,8 +450,8 @@ def apply_belief_timing_filters(q): knowledge_horizon_max, timedelta.max ): q = q.filter( - cls.event_start - <= beliefs_before + cls.belief_horizon + knowledge_horizon_max + cls.event_start - cls.belief_horizon + <= beliefs_before + knowledge_horizon_max ) # Apply belief horizon filter @@ -448,28 +478,7 @@ def apply_belief_timing_filters(q): cls.event_value, ).filter(cls.sensor_id == sensor.id) - # Apply event time filter - if not pd.isnull(event_starts_after): - q = q.filter(cls.event_start >= event_starts_after) - if not pd.isnull(event_ends_after): - if sensor.event_resolution == timedelta(0): - # inclusive - q = q.filter(cls.event_start >= event_ends_after) - else: - # exclusive - q = q.filter( - cls.event_start + sensor.event_resolution > event_ends_after - ) - if not pd.isnull(event_starts_before): - if sensor.event_resolution == timedelta(0): - # inclusive - q = q.filter(cls.event_start <= event_starts_before) - else: - # exclusive - q = q.filter(cls.event_start < event_starts_before) - if not pd.isnull(event_ends_before): - q = q.filter(cls.event_start + sensor.event_resolution <= event_ends_before) - + q = apply_event_timing_filters(q) q = apply_belief_timing_filters(q) # Apply source filter @@ -477,7 +486,7 @@ def apply_belief_timing_filters(q): sources: list = [source] if not isinstance(source, list) else source q = q.join(source_class).filter(cls.source_id.in_([s.id for s in sources])) - # Apply most recent beliefs filter + # Apply most recent beliefs filter as subquery most_recent_beliefs_only_incompatible_criteria = ( beliefs_before is not None or beliefs_after is not None ) and sensor.knowledge_horizon_fnc not in 
(ex_ante.__name__, ex_post.__name__) @@ -490,10 +499,12 @@ def apply_belief_timing_filters(q): cls.source_id, func.min(cls.belief_horizon).label("most_recent_belief_horizon"), ) - # Apply belief timing filters to the subquery, too, before taking the minimum horizon + # Apply event and belief timing filters to the subquery, too, + # before taking the minimum horizon (the former is crucial for speed) + subq = apply_event_timing_filters(subq) + subq = apply_belief_timing_filters(subq) subq = ( - apply_belief_timing_filters(subq) - .filter(cls.sensor_id == sensor.id) + subq.filter(cls.sensor_id == sensor.id) .group_by(cls.event_start, cls.source_id) .subquery() ) @@ -506,7 +517,7 @@ def apply_belief_timing_filters(q): ), ) - # Apply most recent events filter + # Apply most recent events filter as subquery if most_recent_events_only: subq_most_recent_events = ( select(