diff --git a/matrix_content_scanner/crypto.py b/matrix_content_scanner/crypto.py new file mode 100644 index 0000000..58f9495 --- /dev/null +++ b/matrix_content_scanner/crypto.py @@ -0,0 +1,91 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +from typing import TYPE_CHECKING + +from olm.pk import PkDecryption, PkDecryptionError, PkMessage + +from matrix_content_scanner.utils.constants import ErrCode +from matrix_content_scanner.utils.errors import ContentScannerRestError +from matrix_content_scanner.utils.types import JsonDict + +if TYPE_CHECKING: + from matrix_content_scanner.mcs import MatrixContentScanner + + +logger = logging.getLogger(__name__) + + +class CryptoHandler: + """Handler for handling Olm-encrypted request bodies.""" + + def __init__(self, mcs: "MatrixContentScanner") -> None: + key = mcs.config.crypto.pickle_key + path = mcs.config.crypto.pickle_path + try: + # Try reading the pickle from disk. + with open(path, "r") as fp: + pickle = fp.read() + + # Create a PkDecryption object with the content and key. + self._decryptor: PkDecryption = PkDecryption.from_pickle( + pickle=pickle.encode("ascii"), + passphrase=key, + ) + + logger.info("Loaded Olm pickle from %s", path) + except FileNotFoundError: + # If the pickle file doesn't exist, try creating it. + self._decryptor = PkDecryption() + pickle_bytes = self._decryptor.pickle(passphrase=key) + + logger.info( + "Olm pickle not found, generating one and saving it at %s", path + ) + + with open(path, "w+") as fp: + fp.write(pickle_bytes.decode("ascii")) + + self.public_key = self._decryptor.public_key + + def decrypt_body(self, ciphertext: str, mac: str, ephemeral: str) -> JsonDict: + """Decrypts an Olm-encrypted body. + + Args: + ciphertext: The encrypted body's ciphertext. + mac: The encrypted body's MAC. + ephemeral: The encrypted body's ephemeral key. + + Returns: + The decrypted body, parsed as JSON. + """ + try: + decrypted = self._decryptor.decrypt( + message=PkMessage( + ephemeral_key=ephemeral, + mac=mac, + ciphertext=ciphertext, + ) + ) + except PkDecryptionError as e: + logger.error("Failed to decrypt encrypted body: %s", e) + raise ContentScannerRestError( + http_status=400, + reason=ErrCode.FAILED_TO_DECRYPT, + info=str(e), + ) + + # We know that `decrypted` parses as a JsonDict. + return json.loads(decrypted) # type: ignore[no-any-return] diff --git a/matrix_content_scanner/mcs.py b/matrix_content_scanner/mcs.py index 57d8333..6934698 100644 --- a/matrix_content_scanner/mcs.py +++ b/matrix_content_scanner/mcs.py @@ -24,6 +24,7 @@ from matrix_content_scanner import logutils from matrix_content_scanner.config import MatrixContentScannerConfig +from matrix_content_scanner.crypto import CryptoHandler from matrix_content_scanner.httpserver import HTTPServer from matrix_content_scanner.scanner.file_downloader import FileDownloader from matrix_content_scanner.scanner.scanner import Scanner @@ -58,6 +59,10 @@ def file_downloader(self) -> FileDownloader: def scanner(self) -> Scanner: return Scanner(self) + @cached_property + def crypto_handler(self) -> CryptoHandler: + return CryptoHandler(self) + def start(self) -> None: """Start the HTTP server and start the reactor.""" setup_logging() diff --git a/tests/test_crypto.py b/tests/test_crypto.py new file mode 100644 index 0000000..e5d224f --- /dev/null +++ b/tests/test_crypto.py @@ -0,0 +1,43 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import unittest + +from olm.pk import PkEncryption + +from tests.testutils import get_content_scanner + + +class CryptoHandlerTestCase(unittest.TestCase): + def setUp(self) -> None: + self.crypto_handler = get_content_scanner().crypto_handler + + def test_decrypt(self) -> None: + """Tests that an Olm-encrypted payload is successfully decrypted.""" + payload = {"foo": "bar"} + + # Encrypt the payload with PkEncryption. + pke = PkEncryption(self.crypto_handler.public_key) + encrypted = pke.encrypt(json.dumps(payload)) + + # Decrypt the payload with the crypto handler. + decrypted = self.crypto_handler.decrypt_body( + encrypted.ciphertext, + encrypted.mac, + encrypted.ephemeral_key, + ) + + # Check that the decrypted payload is the same as the original one before + # encryption. + self.assertEqual(decrypted, payload)