Skip to content

Commit

Permalink
test_get_notes_by_topic_ids passed
Browse files Browse the repository at this point in the history
  • Loading branch information
kota-yata committed Apr 11, 2024
1 parent e00a175 commit 65d9de8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
10 changes: 9 additions & 1 deletion birdxplorer/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,15 @@ def get_notes(
if created_at_to is not None:
query = query.filter(NoteRecord.created_at <= created_at_to)
if topic_ids is not None:
query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id.in_(topic_ids))
# 同じトピックIDを持つノートを取得するためのサブクエリ
# とりあえずANDを実装
subq = (
select(NoteTopicAssociation.note_id)
.group_by(NoteTopicAssociation.note_id)
.having(func.array_agg(NoteTopicAssociation.topic_id) == topic_ids)
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if post_ids is not None:
query = query.filter(NoteRecord.post_id.in_(post_ids))
if language is not None:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def test_get_notes_by_topic_ids(
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
topic_ids = [1, 2]
expected = [note for note in note_samples if note.topics == topic_ids]
actual = list(storage.get_notes(topic_ids=topic_ids))
topics = note_samples[0].topics
topic_ids = [0]
expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id)
actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id)
assert expected == actual


Expand Down

0 comments on commit 65d9de8

Please sign in to comment.