diff --git a/timely_beliefs/beliefs/classes.py b/timely_beliefs/beliefs/classes.py index 72e6a734..668da7ff 100644 --- a/timely_beliefs/beliefs/classes.py +++ b/timely_beliefs/beliefs/classes.py @@ -377,6 +377,7 @@ def search_session( # noqa: C901 # todo: remove after removing deprecated argu :param custom_join_targets: additional join targets, to accommodate filters that rely on other targets (e.g. subclasses) :returns: a multi-index DataFrame with all relevant beliefs """ + source_class = cls.source.property.mapper.class_ # todo: deprecate the 'event_before' argument in favor of 'event_ends_before' (announced v1.4.1) event_ends_before = tb_utils.replace_deprecated_argument( @@ -515,7 +516,14 @@ def apply_belief_timing_filters(q): return q # Main query - q = select(cls).filter(cls.sensor_id == sensor.id) + q = select( + cls.event_start, + cls.belief_horizon, + cls.source_id, + cls.cumulative_probability, + cls.event_value, + ).filter(cls.sensor_id == sensor.id) + # Apply event time filter if not pd.isnull(event_starts_after): q = q.filter(cls.event_start >= event_starts_after) @@ -543,8 +551,7 @@ def apply_belief_timing_filters(q): # Apply source filter if source is not None: sources: list = [source] if not isinstance(source, list) else source - source_cls = sources[0].__class__ - q = q.join(source_cls).filter(cls.source_id.in_([s.id for s in sources])) + q = q.join(source_class).filter(cls.source_id.in_([s.id for s in sources])) # Apply most recent beliefs filter most_recent_beliefs_only_incompatible_criteria = ( @@ -596,8 +603,30 @@ def apply_belief_timing_filters(q): ) # Build our DataFrame of beliefs - beliefs = session.scalars(q).all() - df = BeliefsDataFrame(sensor=sensor, beliefs=beliefs) + df = pd.DataFrame(session.execute(q)) + if df.empty: + return BeliefsDataFrame(sensor=sensor) + df.columns = [ + "event_start", + "belief_horizon", + "source_id", + "cumulative_probability", + "event_value", + ] + + # Fill in sources + if source is None: + source_ids = df["source_id"].unique().tolist() + sources = session.scalars( + select(source_class).filter(source_class.id.in_(source_ids)) + ).all() + source_map = {source.id: source for source in sources} + df["source_id"] = df["source_id"].map(source_map) + df = df.rename(columns={"source_id": "source"}) + + # Build our BeliefsDataFrame + df = BeliefsDataFrame(df, sensor=sensor) + df = df.convert_index_from_belief_horizon_to_time() # Actually filter by belief time if beliefs_after is not None: