-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Additional type hints for relations database class. #11205
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Improve type hints for the relations datastore. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,13 +13,14 @@ | |
# limitations under the License. | ||
|
||
import logging | ||
from typing import Optional, Tuple | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import attr | ||
|
||
from synapse.api.constants import RelationTypes | ||
from synapse.events import EventBase | ||
from synapse.storage._base import SQLBaseStore | ||
from synapse.storage.database import LoggingTransaction | ||
from synapse.storage.databases.main.stream import generate_pagination_where_clause | ||
from synapse.storage.relations import ( | ||
AggregationPaginationToken, | ||
|
@@ -63,7 +64,7 @@ async def get_relations_for_event( | |
""" | ||
|
||
where_clause = ["relates_to_id = ?"] | ||
where_args = [event_id] | ||
where_args: List[Union[str, int]] = [event_id] | ||
|
||
if relation_type is not None: | ||
where_clause.append("relation_type = ?") | ||
|
@@ -80,8 +81,8 @@ async def get_relations_for_event( | |
pagination_clause = generate_pagination_where_clause( | ||
direction=direction, | ||
column_names=("topological_ordering", "stream_ordering"), | ||
from_token=attr.astuple(from_token) if from_token else None, | ||
to_token=attr.astuple(to_token) if to_token else None, | ||
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a shame. What does mypy think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, it thinks it is |
||
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] | ||
engine=self.database_engine, | ||
) | ||
|
||
|
@@ -106,7 +107,9 @@ async def get_relations_for_event( | |
order, | ||
) | ||
|
||
def _get_recent_references_for_event_txn(txn): | ||
def _get_recent_references_for_event_txn( | ||
txn: LoggingTransaction, | ||
) -> PaginationChunk: | ||
txn.execute(sql, where_args + [limit + 1]) | ||
|
||
last_topo_id = None | ||
|
@@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event( | |
""" | ||
|
||
where_clause = ["relates_to_id = ?", "relation_type = ?"] | ||
where_args = [event_id, RelationTypes.ANNOTATION] | ||
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] | ||
|
||
if event_type: | ||
where_clause.append("type = ?") | ||
|
@@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event( | |
having_clause = generate_pagination_where_clause( | ||
direction=direction, | ||
column_names=("COUNT(*)", "MAX(stream_ordering)"), | ||
from_token=attr.astuple(from_token) if from_token else None, | ||
to_token=attr.astuple(to_token) if to_token else None, | ||
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] | ||
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] | ||
engine=self.database_engine, | ||
) | ||
|
||
|
@@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event( | |
having_clause=having_clause, | ||
) | ||
|
||
def _get_aggregation_groups_for_event_txn(txn): | ||
def _get_aggregation_groups_for_event_txn( | ||
txn: LoggingTransaction, | ||
) -> PaginationChunk: | ||
txn.execute(sql, where_args + [limit + 1]) | ||
|
||
next_batch = None | ||
|
@@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: | |
LIMIT 1 | ||
""" | ||
|
||
def _get_applicable_edit_txn(txn): | ||
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: | ||
txn.execute(sql, (event_id, RelationTypes.REPLACE)) | ||
row = txn.fetchone() | ||
if row: | ||
return row[0] | ||
return None | ||
|
||
edit_id = await self.db_pool.runInteraction( | ||
"get_applicable_edit", _get_applicable_edit_txn | ||
|
@@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn): | |
if not edit_id: | ||
return None | ||
|
||
return await self.get_event(edit_id, allow_none=True) | ||
return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note to self: this is the DB storage class debacle. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, sorry about that! |
||
|
||
@cached() | ||
async def get_thread_summary( | ||
|
@@ -283,7 +289,9 @@ async def get_thread_summary( | |
The number of items in the thread and the most recent response, if any. | ||
""" | ||
|
||
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: | ||
def _get_thread_summary_txn( | ||
txn: LoggingTransaction, | ||
) -> Tuple[int, Optional[str]]: | ||
# Fetch the count of threaded events and the latest event ID. | ||
# TODO Should this only allow m.room.message events. | ||
sql = """ | ||
|
@@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: | |
AND relation_type = ? | ||
""" | ||
txn.execute(sql, (event_id, RelationTypes.THREAD)) | ||
count = txn.fetchone()[0] | ||
count = txn.fetchone()[0] # type: ignore[index] | ||
|
||
return count, latest_event_id | ||
|
||
|
@@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: | |
|
||
latest_event = None | ||
if latest_event_id: | ||
latest_event = await self.get_event(latest_event_id, allow_none=True) | ||
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined] | ||
|
||
return count, latest_event | ||
|
||
|
@@ -354,7 +362,7 @@ async def has_user_annotated_event( | |
LIMIT 1; | ||
""" | ||
|
||
def _get_if_user_has_annotated_event(txn): | ||
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: | ||
txn.execute( | ||
sql, | ||
( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this file now also passes
disallow-untyped-defs
? From a quick glance it looks like it. I don't know if you want to mark that in mypy.ini though: as we've discussed, covering a broader range of files is the priority.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it does, I hesitate to put individual files into that config though, it'll explode quickly. 🤷