From 65d9de8adfe71bb29b979e78dfde5b57d2a0160e Mon Sep 17 00:00:00 2001 From: kota-yata Date: Thu, 11 Apr 2024 20:42:18 +0900 Subject: [PATCH] test_get_notes_by_topic_ids passed --- birdxplorer/storage.py | 10 +++++++++- tests/test_storage.py | 7 ++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index ef3180b..bf18993 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -91,7 +91,15 @@ def get_notes( if created_at_to is not None: query = query.filter(NoteRecord.created_at <= created_at_to) if topic_ids is not None: - query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id.in_(topic_ids)) + # 同じトピックIDを持つノートを取得するためのサブクエリ + # とりあえずANDを実装 + subq = ( + select(NoteTopicAssociation.note_id) + .group_by(NoteTopicAssociation.note_id) + .having(func.array_agg(NoteTopicAssociation.topic_id) == topic_ids) + .subquery() + ) + query = query.join(subq, NoteRecord.note_id == subq.c.note_id) if post_ids is not None: query = query.filter(NoteRecord.post_id.in_(post_ids)) if language is not None: diff --git a/tests/test_storage.py b/tests/test_storage.py index 998be93..c60d0e6 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -86,9 +86,10 @@ def test_get_notes_by_topic_ids( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - topic_ids = [1, 2] - expected = [note for note in note_samples if note.topics == topic_ids] - actual = list(storage.get_notes(topic_ids=topic_ids)) + topics = note_samples[0].topics + topic_ids = [0] + expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id) + actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id) assert expected == actual