From 68fb5cc701b8bc2769da28212daa640256e08645 Mon Sep 17 00:00:00 2001 From: Thomas Belin Date: Thu, 20 Jun 2024 17:52:24 +0200 Subject: [PATCH] fix: Restore send as external message when payload is big --- .../message/MessageService.test.ts | 67 ++++++++++++++++++- .../conversation/message/MessageService.ts | 46 ++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/packages/core/src/conversation/message/MessageService.test.ts b/packages/core/src/conversation/message/MessageService.test.ts index 92bc54ec48..1026ee4d02 100644 --- a/packages/core/src/conversation/message/MessageService.test.ts +++ b/packages/core/src/conversation/message/MessageService.test.ts @@ -72,11 +72,14 @@ function generateRecipients(users: TestUser[]): QualifiedUserClients { }, {}); } -function fakeEncrypt(_: unknown, recipients: QualifiedUserClients): Promise<{payloads: QualifiedOTRRecipients}> { +function fakeEncrypt( + plainText: Uint8Array, + recipients: QualifiedUserClients, +): Promise<{payloads: QualifiedOTRRecipients}> { const encryptedPayload = Object.entries(recipients).reduce((acc, [domain, users]) => { acc[domain] = Object.entries(users).reduce((userClients, [userId, clients]) => { userClients[userId] = clients.reduce((payloads, client) => { - payloads[client] = new Uint8Array(); + payloads[client] = plainText; return payloads; }, {} as any); return userClients; @@ -129,6 +132,66 @@ describe('MessageService', () => { expect(result).toEqual({...baseMessageSendingStatus, failed: undefined}); }); + it('encrypts a message individually for each device if message content is small', async () => { + const [messageService, {apiClient}] = await buildMessageService(); + + const postOTRMessageSpy = jest + .spyOn(apiClient.api.conversation, 'postOTRMessage') + .mockResolvedValue(baseMessageSendingStatus); + + const shortText = 'a'.repeat(100); + const result = await messageService.sendMessage( + clientId, + generateRecipients(generateUsers(30, 10)), + createMessage(shortText), + {conversationId}, + ); + expect(postOTRMessageSpy).toHaveBeenCalledWith(conversationId.id, conversationId.domain, expect.any(Object)); + const payload = postOTRMessageSpy.mock.calls[0][2]; + // We do not expect any assetData + expect(payload.blob.length).toBe(0); + + payload.recipients.forEach(recipient => { + recipient.entries?.forEach(entry => { + entry.clients?.forEach(client => { + // Every client payload should have the text encrypted for them + expect(client.text.length).toBeGreaterThan(shortText.length); + }); + }); + }); + expect(result).toEqual({...baseMessageSendingStatus, failed: undefined}); + }); + + it('sends a message as external if the content is too big', async () => { + const [messageService, {apiClient}] = await buildMessageService(); + + const postOTRMessageSpy = jest + .spyOn(apiClient.api.conversation, 'postOTRMessage') + .mockResolvedValue(baseMessageSendingStatus); + + const longText = 'a'.repeat(10000); + const result = await messageService.sendMessage( + clientId, + generateRecipients(generateUsers(30, 10)), + createMessage(longText), + {conversationId}, + ); + expect(postOTRMessageSpy).toHaveBeenCalledWith(conversationId.id, conversationId.domain, expect.any(Object)); + const payload = postOTRMessageSpy.mock.calls[0][2]; + // We expect the actual encrypted payload to be in the blob field of the payload + expect(payload.blob.length).toBeGreaterThanOrEqual(longText.length); + + payload.recipients.forEach(recipient => { + recipient.entries?.forEach(entry => { + entry.clients?.forEach(client => { + // Every client payload should have a very short text + expect(client.text.length).toBeLessThanOrEqual(200); + }); + }); + }); + expect(result).toEqual({...baseMessageSendingStatus, failed: undefined}); + }); + it('should send regular to conversation', async () => { const [messageService, {apiClient}] = await buildMessageService(); diff --git a/packages/core/src/conversation/message/MessageService.ts b/packages/core/src/conversation/message/MessageService.ts index 703e2e249f..6810f791f2 100644 --- a/packages/core/src/conversation/message/MessageService.ts +++ b/packages/core/src/conversation/message/MessageService.ts @@ -27,11 +27,15 @@ import {StatusCodes as HTTP_STATUS} from 'http-status-codes'; import Long from 'long'; import {APIClient} from '@wireapp/api-client'; +import {GenericMessage} from '@wireapp/protocol-messaging'; +import {createId} from './MessageBuilder'; import {flattenUserMap} from './UserClientsUtil'; +import {encryptAsset} from '../../cryptography/AssetCryptography'; import type {EncryptionResult, ProteusService} from '../../messagingProtocols/proteus'; import {isQualifiedIdArray} from '../../util'; +import {GenericMessageType} from '../GenericMessageType'; type ClientMismatchError = AxiosError; @@ -65,10 +69,14 @@ export class MessageService { onClientMismatch?: (mismatch: MessageSendingStatus) => void | boolean | Promise; } = {}, ): Promise { - const encryptionResults = await this.proteusService.encrypt(plainText, recipients); + const {text, assetData} = this.shouldSendAsExternal(plainText, recipients) + ? await this.generateExternalPayload(plainText) + : {text: plainText, assetData: options.assetData}; + + const encryptionResults = await this.proteusService.encrypt(text, recipients); const send = async ({payloads, unknowns, failed}: EncryptionResult): Promise => { - const result = await this.sendOtrMessage(sendingClientId, payloads, options); + const result = await this.sendOtrMessage(sendingClientId, payloads, {...options, assetData}); const extras = {failed, deleted: unknowns ?? {}}; return deepmerge(result, extras) as MessageSendingStatus & {failed?: QualifiedId[]}; }; @@ -89,6 +97,40 @@ export class MessageService { } } + private async generateExternalPayload(plainText: Uint8Array): Promise<{text: Uint8Array; assetData: Uint8Array}> { + const asset = await encryptAsset({plainText}); + const {cipherText, keyBytes, sha256} = asset; + + const externalMessage = { + otrKey: new Uint8Array(keyBytes), + sha256: new Uint8Array(sha256), + }; + + const genericMessage = GenericMessage.create({ + [GenericMessageType.EXTERNAL]: externalMessage, + messageId: createId(), + }); + + return {text: GenericMessage.encode(genericMessage).finish(), assetData: cipherText}; + } + + private shouldSendAsExternal( + plainText: Uint8Array, + preKeyBundles: QualifiedUserClients | QualifiedUserPreKeyBundleMap, + ): boolean { + const EXTERNAL_MESSAGE_THRESHOLD_BYTES = 200 * 1024; + + let clientCount = 0; + for (const user in preKeyBundles) { + clientCount += Object.keys(preKeyBundles[user]).length; + } + + const messageInBytes = new Uint8Array(plainText).length; + const estimatedPayloadInBytes = clientCount * messageInBytes; + + return estimatedPayloadInBytes > EXTERNAL_MESSAGE_THRESHOLD_BYTES; + } + private async sendOtrMessage( sendingClientId: string, recipients: QualifiedOTRRecipients,