Skip to content

Commit 606dc02

Browse files
committed
refactor(connect): use improved message types
1 parent e849826 commit 606dc02

File tree

8 files changed

+38
-49
lines changed

8 files changed

+38
-49
lines changed

packages/connect/src/device/DeviceCommands.ts

+23-31
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,16 @@ import { initLog } from '../utils/debug';
1818
import * as hdnodeUtils from '../utils/hdnodeUtils';
1919
import { getScriptType, getSerializedPath, isTaprootPath, toHardened } from '../utils/pathUtils';
2020

21-
type MessageType = Messages.MessageType;
22-
type MessageKey = keyof MessageType;
23-
type TypedPayload<T extends MessageKey> = {
24-
type: T;
25-
message: MessageType[T];
26-
};
27-
type TypedCallResponseMap = {
28-
[K in keyof MessageType]: TypedPayload<K>;
29-
};
30-
type DefaultPayloadMessage = TypedCallResponseMap[keyof MessageType];
21+
type TypedCall = Messages.TypedCall;
22+
23+
export type { TypedCall };
3124

3225
const logger = initLog('DeviceCommands');
3326

34-
const assertType = (res: DefaultPayloadMessage, resType: MessageKey | MessageKey[]) => {
27+
const assertType = (
28+
res: Messages.MessageResponse,
29+
resType: Messages.MessageKey | Messages.MessageKey[],
30+
) => {
3531
const splitResTypes = Array.isArray(resType) ? resType : resType.split('|');
3632
if (!splitResTypes.includes(res.type)) {
3733
throw ERRORS.TypedError(
@@ -295,10 +291,10 @@ export class DeviceCommands {
295291
}
296292

297293
// Sends an async message to the opened device.
298-
private async call(
299-
type: MessageKey,
300-
msg: DefaultPayloadMessage['message'] = {},
301-
): Promise<DefaultPayloadMessage> {
294+
private async call<T extends Messages.MessageKey>(
295+
type: T,
296+
msg: Messages.MessagePayload<T>,
297+
): Promise<Messages.MessageResponse> {
302298
logger.debug('Sending', type, filterForLog(type, msg));
303299

304300
this.callPromise = this.transport.call({
@@ -327,32 +323,30 @@ export class DeviceCommands {
327323
filterForLog(res.payload.type, res.payload.message),
328324
);
329325

330-
// TODO: https://github.com/trezor/trezor-suite/issues/5301
331-
// @ts-expect-error
332326
return res.payload;
333327
}
334328

335-
typedCall<T extends MessageKey, R extends MessageKey[]>(
329+
typedCall<T extends Messages.MessageKey, R extends Messages.MessageKey[]>(
336330
type: T,
337331
resType: R,
338-
msg?: MessageType[T],
339-
): Promise<TypedCallResponseMap[R[number]]>;
340-
typedCall<T extends MessageKey, R extends MessageKey>(
332+
msg?: Messages.MessagePayload<T>,
333+
): Promise<Messages.MessageResponse<R[number]>>;
334+
typedCall<T extends Messages.MessageKey, R extends Messages.MessageKey>(
341335
type: T,
342336
resType: R,
343-
msg?: MessageType[T],
344-
): Promise<TypedPayload<R>>;
337+
msg?: Messages.MessagePayload<T>,
338+
): Promise<Messages.MessageResponse<R>>;
345339
async typedCall(
346-
type: MessageKey,
347-
resType: MessageKey | MessageKey[],
348-
msg?: DefaultPayloadMessage['message'],
340+
type: Messages.MessageKey,
341+
resType: Messages.MessageKey | Messages.MessageKey[],
342+
msg: Messages.MessagePayload = {},
349343
) {
350344
if (this.disposed) {
351345
throw ERRORS.TypedError('Runtime', 'typedCall: DeviceCommands already disposed');
352346
}
353347
// Assert message type
354348
// msg is allowed to be undefined for some calls, in that case the schema is an empty object
355-
Assert(Messages.MessageType.properties[type], msg ?? {});
349+
Assert(Messages.MessageType.properties[type], msg);
356350
const response = await this._commonCall(type, msg);
357351
try {
358352
assertType(response, resType);
@@ -380,7 +374,7 @@ export class DeviceCommands {
380374
return response;
381375
}
382376

383-
async _commonCall(type: MessageKey, msg?: DefaultPayloadMessage['message']) {
377+
async _commonCall<T extends Messages.MessageKey>(type: T, msg: Messages.MessagePayload<T>) {
384378
if (this.disposed) {
385379
throw ERRORS.TypedError('Runtime', 'typedCall: DeviceCommands already disposed');
386380
}
@@ -389,7 +383,7 @@ export class DeviceCommands {
389383
return this._filterCommonTypes(resp);
390384
}
391385

392-
_filterCommonTypes(res: DefaultPayloadMessage): Promise<DefaultPayloadMessage> {
386+
_filterCommonTypes(res: Messages.MessageResponse): Promise<Messages.MessageResponse> {
393387
this.device.clearCancelableAction();
394388

395389
if (res.type === 'Failure') {
@@ -589,5 +583,3 @@ export class DeviceCommands {
589583
}
590584
}
591585
}
592-
593-
export type TypedCall = DeviceCommands['typedCall'];

packages/connect/src/device/prompts.ts

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ export const cancelPrompt = (device: Device, expectResponse = true) => {
3535
return expectResponse ? device.transport.call(cancelArgs) : device.transport.send(cancelArgs);
3636
};
3737

38+
const extractMessage = (payload?: Messages.MessageResponse) =>
39+
(payload && 'message' in payload.message && payload.message.message) || '';
40+
3841
const prompt = <E extends PromptEvents>(event: E, { device, ...rest }: DeviceEventArgs<E>) =>
3942
// return non nullable first arg of PromptCallback<E>
4043
new Promise<PromptReturnType<E>>(resolve => {
@@ -43,8 +46,8 @@ const prompt = <E extends PromptEvents>(event: E, { device, ...rest }: DeviceEve
4346
response.success
4447
? resolve({
4548
success: false,
46-
error: error || (response.payload?.message.message as string),
47-
message: response.payload?.message.message as string,
49+
error: error || extractMessage(response.payload),
50+
message: extractMessage(response.payload),
4851
isTransportError: !response.success,
4952
})
5053
: resolve({

packages/protobuf/src/decode.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { Field, Message as MessageType, Type } from 'protobufjs/light';
22

3+
import type { MessageResponse } from './messages';
34
import { createMessageFromType, isPrimitiveField } from './utils';
45

56
const transform = (field: Field, value: any) => {
@@ -84,5 +85,5 @@ export const decodeMessage = (
8485
const { Message, messageName } = createMessageFromType(messages, messageType);
8586
const message = decode(Message, data);
8687

87-
return { messageName, message };
88+
return { type: messageName, message } as MessageResponse;
8889
};

packages/protobuf/src/index.ts

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ export const { parseConfigure, decodeMessage, encodeMessage } = (() => {
1717
return { parseConfigure: parse, decodeMessage: decode, encodeMessage: encode };
1818
})();
1919

20-
export * from './types';
2120
export * as Messages from './messages';
2221
export { loadDefinitions } from './load-definitions';
2322
export * as MessagesSchema from './messages-schema';

packages/protobuf/src/types.ts

-6
This file was deleted.

packages/protobuf/src/utils.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import * as protobuf from 'protobufjs/light';
44

5-
import type { MessageFromTrezor } from './types';
5+
import type { MessageKey } from './messages';
66

77
const primitiveTypes = [
88
'bool',
@@ -58,13 +58,13 @@ export const createMessageFromType = (messages: protobuf.Root, messageType: numb
5858

5959
return {
6060
Message,
61-
messageName: messageType as MessageFromTrezor['type'],
61+
messageName: messageType as MessageKey,
6262
};
6363
}
6464

6565
const messageTypes = messages.lookupEnum('MessageType');
6666

67-
const messageName = messageTypes.valuesById[messageType] as MessageFromTrezor['type'];
67+
const messageName = messageTypes.valuesById[messageType] as MessageKey;
6868

6969
const Message = messages.lookupType(messageName);
7070

packages/transport/src/transports/abstract.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { MessageFromTrezor, loadDefinitions, parseConfigure } from '@trezor/protobuf';
1+
import { Messages, loadDefinitions, parseConfigure } from '@trezor/protobuf';
22
import { PROTOCOL_MALFORMED, TransportProtocol } from '@trezor/protocol';
33
import { ScheduleActionParams, ScheduledAction, TypedEmitter, scheduleAction } from '@trezor/utils';
44

@@ -260,7 +260,7 @@ export abstract class AbstractTransport extends TransportEmitter {
260260
session: Session;
261261
protocol?: TransportProtocol;
262262
} & AbortableParam,
263-
): AsyncResultWithTypedError<MessageFromTrezor, ReadWriteError>;
263+
): AsyncResultWithTypedError<Messages.MessageResponse, ReadWriteError>;
264264

265265
/**
266266
* Send and read after that
@@ -272,7 +272,7 @@ export abstract class AbstractTransport extends TransportEmitter {
272272
data: Record<string, unknown>;
273273
protocol?: TransportProtocol;
274274
} & AbortableParam,
275-
): AsyncResultWithTypedError<MessageFromTrezor, ReadWriteError>;
275+
): AsyncResultWithTypedError<Messages.MessageResponse, ReadWriteError>;
276276

277277
/**
278278
* Stop transport = remove all listeners + try to release session + cancel all requests

packages/transport/src/utils/receive.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ export async function receiveAndParse<T extends () => ReturnType<AbstractApi['re
4444
if (!readResult.success) return readResult;
4545

4646
const { messageType, payload } = readResult.payload;
47-
const { messageName, message } = decodeMessage(messages, messageType, payload);
47+
const message = decodeMessage(messages, messageType, payload);
4848

49-
return success({ message, type: messageName });
49+
return success(message);
5050
}

0 commit comments

Comments
 (0)