Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add support for olm fallback keys #8312

Merged
merged 14 commits into from
Oct 6, 2020
1 change: 1 addition & 0 deletions changelog.d/8312.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
1 change: 1 addition & 0 deletions scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
}


Expand Down
16 changes: 16 additions & 0 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ async def upload_keys_for_user(self, user_id, device_id, keys):
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
if fallback_keys and isinstance(fallback_keys, dict):
uhoreg marked this conversation as resolved.
Show resolved Hide resolved
log_kv(
{
"message": "Updating fallback_keys for device.",
"user_id": user_id,
"device_id": device_id,
}
)
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
elif fallback_keys:
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
else:
log_kv(
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)

# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
Expand Down
8 changes: 8 additions & 0 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ class SyncResult:
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
groups: Group updates, if any
"""

Expand All @@ -213,6 +215,7 @@ class SyncResult:
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict)
device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult])

def __bool__(self) -> bool:
Expand Down Expand Up @@ -1014,10 +1017,14 @@ async def generate_sync_result(
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
unused_fallback_key_types = [] # type: List[str]
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
)

logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
Expand All @@ -1041,6 +1048,7 @@ async def generate_sync_result(
device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)

Expand Down
1 change: 1 addition & 0 deletions synapse/rest/client/v2_alpha/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
clokep marked this conversation as resolved.
Show resolved Hide resolved
"next_batch": await sync_result.next_batch.to_string(self.store),
}

Expand Down
100 changes: 99 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,57 @@ def _count_e2e_one_time_keys(txn):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)

async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
"""Set the user's e2e fallback keys.

Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
Comment on lines +381 to +383
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to fix this FIXME before merging?

Should this loop be in a transaction?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the FIXME is more of a "nice to have" rather than a requirement, so I don't think it needs to be fixed.

Also, I don't think the loop needs to be in a transaction (all the additions are independent), but I can do that if it you think it's a good idea.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a constraint on the database for user + device + algorithm. Does that handle this or does this really only expect a single algorithm per user?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should handle it. It would just do a series of upserts, so would just end up with the last fallback key and dropping all the others. A client shouldn't expect that the server would store all the keys if it gives multiple keys per algorithm, so this seems like a reasonable result if we don't throw an error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this fixme is really an optimization?

for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
)

@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
) -> List[str]:
"""Returns the fallback key types that have an unused key.

Args:
user_id: the user whose keys are being queried
device_id: the device whose keys are being queried

Returns:
a list of key types
"""
return await self.db_pool.simple_select_onecol(
"e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
retcol="algorithm",
desc="get_e2e_unused_fallback_key_types",
)

async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
Expand Down Expand Up @@ -701,15 +752,37 @@ def _claim_e2e_one_time_keys(txn):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
fallback_sql = (
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {}
delete = []
used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn:
otk_row = txn.fetchone()
if otk_row is not None:
key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
else:
# no one-time key available, so see if there's a fallback
# key
txn.execute(fallback_sql, (user_id, device_id, algorithm))
fallback_row = txn.fetchone()
if fallback_row is not None:
key_id, key_json, used = fallback_row
device_result[algorithm + ":" + key_id] = key_json
if not used:
used_fallbacks.append(
(user_id, device_id, algorithm, key_id)
)

# drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
Expand All @@ -726,6 +799,23 @@ def _claim_e2e_one_time_keys(txn):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
# mark fallback keys as used
for user_id, device_id, algorithm, key_id in used_fallbacks:
self.db_pool.simple_update_txn(
uhoreg marked this conversation as resolved.
Show resolved Hide resolved
txn,
"e2e_fallback_keys_json",
{
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
{"used": True},
clokep marked this conversation as resolved.
Show resolved Hide resolved
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

return result

return await self.db_pool.runInteraction(
Expand Down Expand Up @@ -754,6 +844,14 @@ def delete_e2e_keys_by_device_txn(txn):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
Expand Down
24 changes: 24 additions & 0 deletions synapse/storage/databases/main/schema/delta/58/11fallback.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
clokep marked this conversation as resolved.
Show resolved Hide resolved
user_id TEXT NOT NULL, -- The user this fallback key is for.
device_id TEXT NOT NULL, -- The device this fallback key is for.
clokep marked this conversation as resolved.
Show resolved Hide resolved
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
key_json TEXT NOT NULL, -- The key as a JSON blob.
used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
clokep marked this conversation as resolved.
Show resolved Hide resolved
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
);
65 changes: 65 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,71 @@ def test_claim_one_time_key(self):
},
)

@defer.inlineCallbacks
def test_fallback_key(self):
clokep marked this conversation as resolved.
Show resolved Hide resolved
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
otk = {"alg1:k2": "key2"}

yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key},
)
)

# claiming an OTK when no OTKs are available should return the fallback
# key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# claiming an OTK again should return the same fallback key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)

res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)

res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

@defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
Expand Down