diff --git a/timely_beliefs/beliefs/classes.py b/timely_beliefs/beliefs/classes.py index 9715d1be..dfee9393 100644 --- a/timely_beliefs/beliefs/classes.py +++ b/timely_beliefs/beliefs/classes.py @@ -28,6 +28,7 @@ func, select, ) +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.hybrid import hybrid_method, hybrid_property from sqlalchemy.orm import Session, backref, declarative_mixin, relationship @@ -251,7 +252,7 @@ def add_to_session( beliefs_data_frame: "BeliefsDataFrame", expunge_session: bool = False, allow_overwrite: bool = False, - bulk_save_objects: bool = False, + bulk_save_objects: bool = True, commit_transaction: bool = False, ): """Add a BeliefsDataFrame as timed beliefs to a database session. @@ -274,23 +275,53 @@ def add_to_session( if False, you can still add other data to the session and commit it all within an atomic transaction """ + if beliefs_data_frame.empty: + return # Belief timing is stored as the belief horizon rather than as the belief time - belief_records = ( - beliefs_data_frame.convert_index_from_belief_time_to_horizon() - .reset_index() - .to_dict("records") + beliefs_data_frame = ( + beliefs_data_frame.convert_index_from_belief_time_to_horizon().reset_index() ) - beliefs = [cls(sensor=beliefs_data_frame.sensor, **d) for d in belief_records] + beliefs = [ + cls(sensor=beliefs_data_frame.sensor, **d) + for d in beliefs_data_frame.to_dict("records") + ] + if expunge_session: session.expunge_all() - if not allow_overwrite: - if bulk_save_objects: - session.bulk_save_objects(beliefs) + + if bulk_save_objects: + # serialize source and sensor + beliefs_data_frame["source_id"] = beliefs_data_frame["source"].apply( + lambda x: x.id + ) + beliefs_data_frame["sensor_id"] = beliefs_data_frame.sensor.id + beliefs_data_frame = beliefs_data_frame.drop(columns=["source"]) + + smt = insert(cls).values(beliefs_data_frame.to_dict("records")) + + if allow_overwrite: + smt = smt.on_conflict_do_update( + index_elements=[ + "event_start", + "belief_horizon", + "source_id", + "sensor_id", + "cumulative_probability", + ], + set_=dict(event_value=smt.excluded.event_value), + ) else: - session.add_all(beliefs) + smt = smt.on_conflict_do_nothing() + + session.execute(smt) + else: - for belief in beliefs: - session.merge(belief) + if allow_overwrite: + for belief in beliefs: + session.merge(belief) + else: + session.add_all(beliefs) + if commit_transaction: session.commit() diff --git a/timely_beliefs/tests/test_belief_persistence.py b/timely_beliefs/tests/test_belief_persistence.py index da4eb8d2..5bdce7de 100644 --- a/timely_beliefs/tests/test_belief_persistence.py +++ b/timely_beliefs/tests/test_belief_persistence.py @@ -49,9 +49,31 @@ def test_adding_to_session( assert len(bdf) == len(new_bdf) -@pytest.mark.parametrize("bulk_save_objects", [False, True]) -def test_fail_adding_to_session( - bulk_save_objects: bool, +def test_adding_to_session_succeeds( + time_slot_sensor: DBSensor, + rolling_day_ahead_beliefs_about_time_slot_events, +): + + # Retrieve some data from the database + bdf = DBTimedBelief.search_session( + session=session, + sensor=time_slot_sensor, + ) + + # Attempting to save the same data should not fail, even if we expunge everything from the session + try: + DBTimedBelief.add_to_session( + session, + bdf, + expunge_session=True, + bulk_save_objects=True, + commit_transaction=True, + ) + except IntegrityError as exception: + raise pytest.fail("DID RAISE {0}".format(exception)) + + +def test_adding_to_session_fails( time_slot_sensor: DBSensor, rolling_day_ahead_beliefs_about_time_slot_events, ): @@ -68,6 +90,6 @@ def test_fail_adding_to_session( session, bdf, expunge_session=True, - bulk_save_objects=bulk_save_objects, + bulk_save_objects=False, commit_transaction=True, )