diff --git a/common/database_helpers.py b/common/database_helpers.py index 5d6d0bf5..fa8a98a8 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -533,28 +533,33 @@ def get_facility_cycles_for_instrument_count(instrument_id): return len(get_facility_cycles_for_instrument(instrument_id)) -def get_investigations_for_instrument_in_facility_cycle(instrument_id, facility_cycle_id): +class InstrumentInCycleInvestigationsQuery(ReadQuery): + def __init__(self, instrument_id, facility_cycle_id): + super().__init__(INVESTIGATION) + self.instrument_id = instrument_id + self.facility_cycle_id = facility_cycle_id + self._get_date_filters() + + def _get_date_filters(self): """ - Given an instrument id and facility cycle id, get investigations that use the given instrument in the given cycle - :param instrument_id: The id of the instrument - :param facility_cycle_id: the ID of the facility cycle - :return: The investigations + Sets the date filters to be applied to the query """ - session = session_manager.get_icat_db_session() + self.start_date_filter = WhereFilter("STARTDATE", self._get_facility_cycle()["STARTDATE"], "gte") + self.end_date_filter = WhereFilter("ENDDATE", self._get_facility_cycle()["ENDDATE"], "lte") + def _get_facility_cycle(self): + """ + Given a facility cycle_id and instrument_id, get the facility cycle that has an investigation using the given + instrument in the cycle + :return: The facility cycle + """ + facility_cycles = get_facility_cycles_for_instrument(self.instrument_id, filters=[]) try: - facility_cycles = get_facility_cycles_for_instrument(instrument_id) - for i in facility_cycles: - session.add(i) - try: - facility_cycle = [x for x in facility_cycles if x.ID == facility_cycle_id][0] + # The ID value gets previously converted to str + return [cycle for cycle in facility_cycles if cycle["ID"] == str(self.facility_cycle_id)][0] except IndexError: raise MissingRecordError() - investigations = session.query(INVESTIGATION).filter(INVESTIGATION.STARTDATE >= facility_cycle.STARTDATE, - INVESTIGATION.ENDDATE <= facility_cycle.ENDDATE).all() - if len(investigations) == 0: - raise MissingRecordError() - return investigations + finally: session.close()