Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Additional type hints for relations database class. #11205

Merged
merged 2 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11205.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints for the relations datastore.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ files =
synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/relations.py,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this file now also passes disallow-untyped-defs? From a quick glance it looks like it. I don't know if you want to mark that in mypy.ini though: as we've discussed, covering a broader range of files is the priority.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does, I hesitate to put individual files into that config though, it'll explode quickly. 🤷

synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
Expand Down
38 changes: 23 additions & 15 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

import logging
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import attr

from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_relations_for_event(
"""

where_clause = ["relates_to_id = ?"]
where_args = [event_id]
where_args: List[Union[str, int]] = [event_id]

if relation_type is not None:
where_clause.append("relation_type = ?")
Expand All @@ -80,8 +81,8 @@ async def get_relations_for_event(
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a shame. What does mypy think from_token is before this change---Tuple[Any, ...]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a shame. What does mypy think from_token is before this change---Tuple[Any, ...]?

Yes, it thinks it is Tuple[Any, ...]. I think this is python/mypy#5152, but I'm not 100% sure on that.

to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)

Expand All @@ -106,7 +107,9 @@ async def get_relations_for_event(
order,
)

def _get_recent_references_for_event_txn(txn):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])

last_topo_id = None
Expand Down Expand Up @@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event(
"""

where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args = [event_id, RelationTypes.ANNOTATION]
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]

if event_type:
where_clause.append("type = ?")
Expand All @@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event(
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)

Expand Down Expand Up @@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event(
having_clause=having_clause,
)

def _get_aggregation_groups_for_event_txn(txn):
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])

next_batch = None
Expand Down Expand Up @@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
LIMIT 1
"""

def _get_applicable_edit_txn(txn):
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
return None

edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
Expand All @@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn):
if not edit_id:
return None

return await self.get_event(edit_id, allow_none=True)
return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: this is the DB storage class debacle.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sorry about that!


@cached()
async def get_thread_summary(
Expand All @@ -283,7 +289,9 @@ async def get_thread_summary(
The number of items in the thread and the most recent response, if any.
"""

def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
sql = """
Expand Down Expand Up @@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
count = txn.fetchone()[0]
count = txn.fetchone()[0] # type: ignore[index]

return count, latest_event_id

Expand All @@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True)
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]

return count, latest_event

Expand Down Expand Up @@ -354,7 +362,7 @@ async def has_user_annotated_event(
LIMIT 1;
"""

def _get_if_user_has_annotated_event(txn):
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
Expand Down