Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add support for MSC2732: olm fallback keys (#8312)
Browse files Browse the repository at this point in the history
  • Loading branch information
uhoreg authored Oct 6, 2020
1 parent a024461 commit 3cd78bb
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog.d/8312.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
1 change: 1 addition & 0 deletions scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
}


Expand Down
16 changes: 16 additions & 0 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ async def upload_keys_for_user(self, user_id, device_id, keys):
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
if fallback_keys and isinstance(fallback_keys, dict):
log_kv(
{
"message": "Updating fallback_keys for device.",
"user_id": user_id,
"device_id": device_id,
}
)
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
elif fallback_keys:
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
else:
log_kv(
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)

# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
Expand Down
8 changes: 8 additions & 0 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ class SyncResult:
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
groups: Group updates, if any
"""

Expand All @@ -213,6 +215,7 @@ class SyncResult:
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict)
device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult])

def __bool__(self) -> bool:
Expand Down Expand Up @@ -1014,10 +1017,14 @@ async def generate_sync_result(
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
unused_fallback_key_types = [] # type: List[str]
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
)

logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
Expand All @@ -1041,6 +1048,7 @@ async def generate_sync_result(
device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)

Expand Down
1 change: 1 addition & 0 deletions synapse/rest/client/v2_alpha/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
"next_batch": await sync_result.next_batch.to_string(self.store),
}

Expand Down
100 changes: 99 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,57 @@ def _count_e2e_one_time_keys(txn):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)

async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
"""Set the user's e2e fallback keys.
Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
)

@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
) -> List[str]:
"""Returns the fallback key types that have an unused key.
Args:
user_id: the user whose keys are being queried
device_id: the device whose keys are being queried
Returns:
a list of key types
"""
return await self.db_pool.simple_select_onecol(
"e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
retcol="algorithm",
desc="get_e2e_unused_fallback_key_types",
)

async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
Expand Down Expand Up @@ -701,15 +752,37 @@ def _claim_e2e_one_time_keys(txn):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
fallback_sql = (
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {}
delete = []
used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn:
otk_row = txn.fetchone()
if otk_row is not None:
key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
else:
# no one-time key available, so see if there's a fallback
# key
txn.execute(fallback_sql, (user_id, device_id, algorithm))
fallback_row = txn.fetchone()
if fallback_row is not None:
key_id, key_json, used = fallback_row
device_result[algorithm + ":" + key_id] = key_json
if not used:
used_fallbacks.append(
(user_id, device_id, algorithm, key_id)
)

# drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
Expand All @@ -726,6 +799,23 @@ def _claim_e2e_one_time_keys(txn):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
# mark fallback keys as used
for user_id, device_id, algorithm, key_id in used_fallbacks:
self.db_pool.simple_update_txn(
txn,
"e2e_fallback_keys_json",
{
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
{"used": True},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

return result

return await self.db_pool.runInteraction(
Expand Down Expand Up @@ -754,6 +844,14 @@ def delete_e2e_keys_by_device_txn(txn):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
Expand Down
24 changes: 24 additions & 0 deletions synapse/storage/databases/main/schema/delta/58/11fallback.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
user_id TEXT NOT NULL, -- The user this fallback key is for.
device_id TEXT NOT NULL, -- The device this fallback key is for.
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
key_json TEXT NOT NULL, -- The key as a JSON blob.
used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
);
65 changes: 65 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,71 @@ def test_claim_one_time_key(self):
},
)

@defer.inlineCallbacks
def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
otk = {"alg1:k2": "key2"}

yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key},
)
)

# claiming an OTK when no OTKs are available should return the fallback
# key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# claiming an OTK again should return the same fallback key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)

res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)

res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

@defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
Expand Down

0 comments on commit 3cd78bb

Please sign in to comment.