Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Restore send as external message when payload is big #6310

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions packages/core/src/conversation/message/MessageService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ function generateRecipients(users: TestUser[]): QualifiedUserClients {
}, {});
}

function fakeEncrypt(_: unknown, recipients: QualifiedUserClients): Promise<{payloads: QualifiedOTRRecipients}> {
function fakeEncrypt(
plainText: Uint8Array,
recipients: QualifiedUserClients,
): Promise<{payloads: QualifiedOTRRecipients}> {
const encryptedPayload = Object.entries(recipients).reduce((acc, [domain, users]) => {
acc[domain] = Object.entries(users).reduce((userClients, [userId, clients]) => {
userClients[userId] = clients.reduce((payloads, client) => {
payloads[client] = new Uint8Array();
payloads[client] = plainText;
return payloads;
}, {} as any);
return userClients;
Expand Down Expand Up @@ -129,6 +132,66 @@ describe('MessageService', () => {
expect(result).toEqual({...baseMessageSendingStatus, failed: undefined});
});

it('encrypts a message individually for each device if message content is small', async () => {
const [messageService, {apiClient}] = await buildMessageService();

const postOTRMessageSpy = jest
.spyOn(apiClient.api.conversation, 'postOTRMessage')
.mockResolvedValue(baseMessageSendingStatus);

const shortText = 'a'.repeat(100);
const result = await messageService.sendMessage(
clientId,
generateRecipients(generateUsers(30, 10)),
createMessage(shortText),
{conversationId},
);
expect(postOTRMessageSpy).toHaveBeenCalledWith(conversationId.id, conversationId.domain, expect.any(Object));
const payload = postOTRMessageSpy.mock.calls[0][2];
// We do not expect any assetData
expect(payload.blob.length).toBe(0);

payload.recipients.forEach(recipient => {
recipient.entries?.forEach(entry => {
entry.clients?.forEach(client => {
// Every client payload should have the text encrypted for them
expect(client.text.length).toBeGreaterThan(shortText.length);
});
});
});
expect(result).toEqual({...baseMessageSendingStatus, failed: undefined});
});

it('sends a message as external if the content is too big', async () => {
const [messageService, {apiClient}] = await buildMessageService();

const postOTRMessageSpy = jest
.spyOn(apiClient.api.conversation, 'postOTRMessage')
.mockResolvedValue(baseMessageSendingStatus);

const longText = 'a'.repeat(10000);
const result = await messageService.sendMessage(
clientId,
generateRecipients(generateUsers(30, 10)),
createMessage(longText),
{conversationId},
);
expect(postOTRMessageSpy).toHaveBeenCalledWith(conversationId.id, conversationId.domain, expect.any(Object));
const payload = postOTRMessageSpy.mock.calls[0][2];
// We expect the actual encrypted payload to be in the blob field of the payload
expect(payload.blob.length).toBeGreaterThanOrEqual(longText.length);

payload.recipients.forEach(recipient => {
recipient.entries?.forEach(entry => {
entry.clients?.forEach(client => {
// Every client payload should have a very short text
expect(client.text.length).toBeLessThanOrEqual(200);
});
});
});
expect(result).toEqual({...baseMessageSendingStatus, failed: undefined});
});

it('should send regular to conversation', async () => {
const [messageService, {apiClient}] = await buildMessageService();

Expand Down
46 changes: 44 additions & 2 deletions packages/core/src/conversation/message/MessageService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ import {StatusCodes as HTTP_STATUS} from 'http-status-codes';
import Long from 'long';

import {APIClient} from '@wireapp/api-client';
import {GenericMessage} from '@wireapp/protocol-messaging';

import {createId} from './MessageBuilder';
import {flattenUserMap} from './UserClientsUtil';

import {encryptAsset} from '../../cryptography/AssetCryptography';
import type {EncryptionResult, ProteusService} from '../../messagingProtocols/proteus';
import {isQualifiedIdArray} from '../../util';
import {GenericMessageType} from '../GenericMessageType';

type ClientMismatchError = AxiosError<MessageSendingStatus>;

Expand Down Expand Up @@ -65,10 +69,14 @@ export class MessageService {
onClientMismatch?: (mismatch: MessageSendingStatus) => void | boolean | Promise<boolean>;
} = {},
): Promise<MessageSendingStatus & {canceled?: boolean; failed?: QualifiedId[]}> {
const encryptionResults = await this.proteusService.encrypt(plainText, recipients);
const {text, assetData} = this.shouldSendAsExternal(plainText, recipients)
? await this.generateExternalPayload(plainText)
: {text: plainText, assetData: options.assetData};

const encryptionResults = await this.proteusService.encrypt(text, recipients);

const send = async ({payloads, unknowns, failed}: EncryptionResult): Promise<MessageSendingStatus> => {
const result = await this.sendOtrMessage(sendingClientId, payloads, options);
const result = await this.sendOtrMessage(sendingClientId, payloads, {...options, assetData});
const extras = {failed, deleted: unknowns ?? {}};
return deepmerge(result, extras) as MessageSendingStatus & {failed?: QualifiedId[]};
};
Expand All @@ -89,6 +97,40 @@ export class MessageService {
}
}

private async generateExternalPayload(plainText: Uint8Array): Promise<{text: Uint8Array; assetData: Uint8Array}> {
const asset = await encryptAsset({plainText});
const {cipherText, keyBytes, sha256} = asset;

const externalMessage = {
otrKey: new Uint8Array(keyBytes),
sha256: new Uint8Array(sha256),
};

const genericMessage = GenericMessage.create({
[GenericMessageType.EXTERNAL]: externalMessage,
messageId: createId(),
});

return {text: GenericMessage.encode(genericMessage).finish(), assetData: cipherText};
}

private shouldSendAsExternal(
plainText: Uint8Array,
preKeyBundles: QualifiedUserClients | QualifiedUserPreKeyBundleMap,
): boolean {
const EXTERNAL_MESSAGE_THRESHOLD_BYTES = 200 * 1024;

let clientCount = 0;
for (const user in preKeyBundles) {
clientCount += Object.keys(preKeyBundles[user]).length;
}

const messageInBytes = new Uint8Array(plainText).length;
const estimatedPayloadInBytes = clientCount * messageInBytes;

return estimatedPayloadInBytes > EXTERNAL_MESSAGE_THRESHOLD_BYTES;
}

private async sendOtrMessage(
sendingClientId: string,
recipients: QualifiedOTRRecipients,
Expand Down
Loading