diff --git a/.changeset/brown-mugs-beam.md b/.changeset/brown-mugs-beam.md new file mode 100644 index 000000000000..3177614af569 --- /dev/null +++ b/.changeset/brown-mugs-beam.md @@ -0,0 +1,9 @@ +--- +"@ledgerhq/live-common": minor +--- + +Synchronized onboarding logic with: + +- Function to extract the device onboarding state from byte flags +- Polling mechanism to retrieve the device onboarding state +- Polling mechanism available as a react hook for LLM and LLD diff --git a/apps/cli/src/commands-index.ts b/apps/cli/src/commands-index.ts index 71e416d21ad0..d4914eee8963 100644 --- a/apps/cli/src/commands-index.ts +++ b/apps/cli/src/commands-index.ts @@ -41,6 +41,7 @@ import signMessage from "./commands/signMessage"; import speculosList from "./commands/speculosList"; import swap from "./commands/swap"; import sync from "./commands/sync"; +import synchronousOnboarding from "./commands/synchronousOnboarding"; import testDetectOpCollision from "./commands/testDetectOpCollision"; import testGetTrustedInputFromTxHash from "./commands/testGetTrustedInputFromTxHash"; import user from "./commands/user"; @@ -91,6 +92,7 @@ export default { speculosList, swap, sync, + synchronousOnboarding, testDetectOpCollision, testGetTrustedInputFromTxHash, user, diff --git a/apps/cli/src/commands/synchronousOnboarding.ts b/apps/cli/src/commands/synchronousOnboarding.ts new file mode 100644 index 000000000000..c4d9cec043de --- /dev/null +++ b/apps/cli/src/commands/synchronousOnboarding.ts @@ -0,0 +1,30 @@ +import { + getOnboardingStatePolling, + OnboardingStatePollingResult, +} from "@ledgerhq/live-common/lib/hw/getOnboardingStatePolling"; +import { Observable } from "rxjs"; +import { deviceOpt } from "../scan"; + +export default { + description: "track the onboarding status of your device", + args: [ + { + name: "pollingPeriodMs", + alias: "p", + desc: "polling period in milliseconds", + type: Number, + }, + deviceOpt, + ], + job: ({ + device, + pollingPeriodMs, + }: Partial<{ + device: string; + pollingPeriodMs: number; + }>): Observable => + getOnboardingStatePolling({ + deviceId: device ?? "", + pollingPeriodMs: pollingPeriodMs ?? 1000, + }), +}; diff --git a/libs/ledger-live-common/src/hw/extractOnboardingState.test.ts b/libs/ledger-live-common/src/hw/extractOnboardingState.test.ts new file mode 100644 index 000000000000..f5b98e02d098 --- /dev/null +++ b/libs/ledger-live-common/src/hw/extractOnboardingState.test.ts @@ -0,0 +1,251 @@ +import { + extractOnboardingState, + OnboardingStep, +} from "./extractOnboardingState"; + +describe("@hw/extractOnboardingState", () => { + describe("extractOnboardingState", () => { + describe("When the flag bytes are incorrect", () => { + it("should throw an error", () => { + const incompleteFlagsBytes = Buffer.from([0, 0]); + // DeviceExtractOnboardingStateError is not of type Error, + // so cannot check in toThrow(DeviceExtractOnboardingStateError) + expect(() => extractOnboardingState(incompleteFlagsBytes)).toThrow(); + }); + }); + + describe("When the device is onboarded", () => { + it("should return a device state that is onboarded", () => { + const flagsBytes = Buffer.from([1 << 2, 0, 0, 0]); + + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.isOnboarded).toBe(true); + }); + }); + + describe("When the device is in recovery mode", () => { + it("should return a device state that is in recovery mode", () => { + const flagsBytes = Buffer.from([1, 0, 0, 0]); + + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.isInRecoveryMode).toBe(true); + }); + }); + + describe("When the device is not onboarded and in normal mode", () => { + let flagsBytes: Buffer; + + beforeEach(() => { + flagsBytes = Buffer.from([0, 0, 0, 0]); + }); + + describe("and the user is on the welcome screen", () => { + beforeEach(() => { + flagsBytes[3] = 0; + }); + + it("should return an onboarding step that is set at the welcome screen", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.WelcomeScreen + ); + }); + }); + + describe("and the user is choosing what kind of setup they want", () => { + beforeEach(() => { + flagsBytes[3] = 1; + }); + + it("should return an onboarding step that is set at the setup choice", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.SetupChoice + ); + }); + }); + + describe("and the user is setting their pin", () => { + beforeEach(() => { + flagsBytes[3] = 2; + }); + + it("should return an onboarding step that is set at setting the pin", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.Pin + ); + }); + }); + + describe("and the user is generating a new seed", () => { + describe("and the seed phrase type is set to 24 words", () => { + beforeEach(() => { + // 24-words seed + flagsBytes[2] |= 0 << 5; + }); + + it("should return a device state with the correct seed phrase type", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.seedPhraseType).toBe("24-words"); + }); + + describe("and the user is writing the seed word i", () => { + beforeEach(() => { + flagsBytes[3] = 3; + }); + + it("should return an onboarding step that is set at writting the seed phrase", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.NewDevice + ); + }); + + it("should return a device state with the index of the current seed word being written", () => { + const byte3 = flagsBytes[2]; + for (let wordIndex = 0; wordIndex < 24; wordIndex++) { + flagsBytes[2] = byte3 | wordIndex; + + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentSeedWordIndex).toBe(wordIndex); + } + }); + }); + + describe("and the user is confirming the seed word i", () => { + beforeEach(() => { + flagsBytes[3] = 4; + }); + + it("should return an onboarding step that is set at confirming the seed phrase", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.NewDeviceConfirming + ); + }); + + it("should return a device state with the index of the current seed word being confirmed", () => { + const byte3 = flagsBytes[2]; + for (let wordIndex = 0; wordIndex < 24; wordIndex++) { + flagsBytes[2] = byte3 | wordIndex; + + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentSeedWordIndex).toBe(wordIndex); + } + }); + }); + }); + }); + + describe("and the user is recovering a seed", () => { + describe("and the seed phrase type is set to X words", () => { + it("should return a device state with the correct seed phrase type", () => { + const byte3 = flagsBytes[2]; + + // 24-words + flagsBytes[2] = byte3 | (0 << 5); + let onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.seedPhraseType).toBe("24-words"); + + // 18-words + flagsBytes[2] = byte3 | (1 << 5); + onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.seedPhraseType).toBe("18-words"); + + // 12-words + flagsBytes[2] = byte3 | (2 << 5); + onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.seedPhraseType).toBe("12-words"); + }); + + describe("and the user is confirming (seed recovery) the seed word i", () => { + beforeEach(() => { + // 24-words seed + flagsBytes[2] |= 0 << 5; + + flagsBytes[3] = 5; + }); + + it("should return an onboarding step that is set at confirming the restored seed phrase", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.RestoreSeed + ); + }); + + it("should return a device state with the index of the current seed word being confirmed", () => { + const byte3 = flagsBytes[2]; + for (let wordIndex = 0; wordIndex < 24; wordIndex++) { + flagsBytes[2] = byte3 | wordIndex; + + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentSeedWordIndex).toBe(wordIndex); + } + }); + }); + }); + }); + + describe("and the user is on the safety warning screen", () => { + beforeEach(() => { + flagsBytes[3] = 6; + }); + + it("should return an onboarding step that is set at the safety warning screen", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.SafetyWarning + ); + }); + }); + + describe("and the user finished the onboarding process", () => { + beforeEach(() => { + flagsBytes[3] = 7; + }); + + it("should return an onboarding step that is set at ready", () => { + const onboardingState = extractOnboardingState(flagsBytes); + + expect(onboardingState).not.toBeNull(); + expect(onboardingState?.currentOnboardingStep).toBe( + OnboardingStep.Ready + ); + }); + }); + }); + }); +}); diff --git a/libs/ledger-live-common/src/hw/extractOnboardingState.ts b/libs/ledger-live-common/src/hw/extractOnboardingState.ts new file mode 100644 index 000000000000..6959b8e22cfe --- /dev/null +++ b/libs/ledger-live-common/src/hw/extractOnboardingState.ts @@ -0,0 +1,97 @@ +import { DeviceExtractOnboardingStateError } from "@ledgerhq/errors"; +import { SeedPhraseType } from "../types/manager"; + +const onboardingFlagsBytesLength = 4; + +const onboardedMask = 0x04; +const inRecoveryModeMask = 0x01; +const seedPhraseTypeMask = 0x60; +const seedPhraseTypeFlagOffset = 5; +const currentSeedWordIndexMask = 0x1f; + +const fromBitsToSeedPhraseType = new Map([ + [0, SeedPhraseType.TwentyFour], + [1, SeedPhraseType.Eighteen], + [2, SeedPhraseType.Twelve], +]); + +export enum OnboardingStep { + WelcomeScreen = "WELCOME_SCREEN", + SetupChoice = "SETUP_CHOICE", + Pin = "PIN", + NewDevice = "NEW_DEVICE", // path "new device" & currentSeedWordIndex available + NewDeviceConfirming = "NEW_DEVICE_CONFIRMING", // path "new device" & currentSeedWordIndex available + RestoreSeed = "RESTORE_SEED", // path "restore seed" & currentSeedWordIndex available + SafetyWarning = "SAFETY WARNING", + Ready = "READY", +} + +const fromBitsToOnboardingStep = new Map([ + [0, OnboardingStep.WelcomeScreen], + [1, OnboardingStep.SetupChoice], + [2, OnboardingStep.Pin], + [3, OnboardingStep.NewDevice], + [4, OnboardingStep.NewDeviceConfirming], + [5, OnboardingStep.RestoreSeed], + [6, OnboardingStep.SafetyWarning], + [7, OnboardingStep.Ready], +]); + +export type OnboardingState = { + // Device not yet onboarded otherwise + isOnboarded: boolean; + // In normal mode otherwise + isInRecoveryMode: boolean; + + seedPhraseType: SeedPhraseType; + + currentOnboardingStep: OnboardingStep; + currentSeedWordIndex: number; +}; + +/** + * Extracts the onboarding state of the device + * @param flagsBytes Buffer of bytes of length onboardingFlagsBytesLength reprensenting the device state flags + * @returns An OnboardingState + */ +export const extractOnboardingState = (flagsBytes: Buffer): OnboardingState => { + if (!flagsBytes || flagsBytes.length < onboardingFlagsBytesLength) { + throw new DeviceExtractOnboardingStateError( + "Incorrect onboarding flags bytes" + ); + } + + const isOnboarded = Boolean(flagsBytes[0] & onboardedMask); + const isInRecoveryMode = Boolean(flagsBytes[0] & inRecoveryModeMask); + + const seedPhraseTypeBits = + (flagsBytes[2] & seedPhraseTypeMask) >> seedPhraseTypeFlagOffset; + const seedPhraseType = fromBitsToSeedPhraseType.get(seedPhraseTypeBits); + + if (!seedPhraseType) { + throw new DeviceExtractOnboardingStateError( + "Incorrect onboarding bits for the seed phrase type" + ); + } + + const currentOnboardingStepBits = flagsBytes[3]; + const currentOnboardingStep = fromBitsToOnboardingStep.get( + currentOnboardingStepBits + ); + + if (!currentOnboardingStep) { + throw new DeviceExtractOnboardingStateError( + "Incorrect onboarding bits for the current onboarding step" + ); + } + + const currentSeedWordIndex = flagsBytes[2] & currentSeedWordIndexMask; + + return { + isOnboarded, + isInRecoveryMode, + seedPhraseType, + currentOnboardingStep, + currentSeedWordIndex, + }; +}; diff --git a/libs/ledger-live-common/src/hw/getOnboardingStatePolling.test.ts b/libs/ledger-live-common/src/hw/getOnboardingStatePolling.test.ts new file mode 100644 index 000000000000..2762f88b88d1 --- /dev/null +++ b/libs/ledger-live-common/src/hw/getOnboardingStatePolling.test.ts @@ -0,0 +1,247 @@ +import { getOnboardingStatePolling } from "./getOnboardingStatePolling"; +import { from, Subscription, TimeoutError } from "rxjs"; +import * as rxjsOperators from "rxjs/operators"; +import { DeviceModelId } from "@ledgerhq/devices"; +import Transport from "@ledgerhq/hw-transport"; +import { + DeviceExtractOnboardingStateError, + DisconnectedDevice, +} from "@ledgerhq/errors"; +import { withDevice } from "./deviceAccess"; +import getVersion from "./getVersion"; +import { + extractOnboardingState, + OnboardingState, + OnboardingStep, +} from "./extractOnboardingState"; +import { SeedPhraseType } from "../types/manager"; + +jest.mock("./deviceAccess"); +jest.mock("./getVersion"); +jest.mock("./extractOnboardingState"); +jest.mock("@ledgerhq/hw-transport"); +jest.useFakeTimers(); + +const aDevice = { + deviceId: "DEVICE_ID_A", + deviceName: "DEVICE_NAME_A", + modelId: DeviceModelId.nanoFTS, + wired: false, +}; + +// As extractOnboardingState is mocked, the firmwareInfo +// returned by getVersion does not matter +const aFirmwareInfo = { + isBootloader: false, + rawVersion: "", + targetId: 0, + mcuVersion: "", + flags: Buffer.from([]), +}; + +const pollingPeriodMs = 1000; + +const mockedGetVersion = jest.mocked(getVersion); + +const mockedWithDevice = jest.mocked(withDevice); +mockedWithDevice.mockReturnValue((job) => from(job(new Transport()))); + +const mockedExtractOnboardingState = jest.mocked(extractOnboardingState); + +describe("getOnboardingStatePolling", () => { + let anOnboardingState: OnboardingState; + let onboardingStatePollingSubscription: Subscription | null; + + beforeEach(() => { + anOnboardingState = { + isOnboarded: false, + isInRecoveryMode: false, + seedPhraseType: SeedPhraseType.TwentyFour, + currentSeedWordIndex: 0, + currentOnboardingStep: OnboardingStep.NewDevice, + }; + }); + + afterEach(() => { + mockedGetVersion.mockClear(); + mockedExtractOnboardingState.mockClear(); + jest.clearAllTimers(); + onboardingStatePollingSubscription?.unsubscribe(); + }); + + describe("When a communication error occurs while fetching the device state", () => { + describe("and when the error is allowed and thrown before the defined timeout", () => { + it("should update the onboarding state to null and keep track of the allowed error", (done) => { + mockedGetVersion.mockRejectedValue( + new DisconnectedDevice("An allowed error") + ); + mockedExtractOnboardingState.mockReturnValue(anOnboardingState); + + const device = aDevice; + + getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (value) => { + expect(value.onboardingState).toBeNull(); + expect(value.allowedError).toBeInstanceOf(DisconnectedDevice); + done(); + }, + }); + + // The timeout is equal to pollingPeriodMs by default + jest.advanceTimersByTime(pollingPeriodMs - 1); + }); + }); + + describe("and when a timeout occurred before the error (or the fetch took too long)", () => { + it("should update the allowed error value to notify the consumer - default value for the timeout", (done) => { + mockedGetVersion.mockResolvedValue(aFirmwareInfo); + mockedExtractOnboardingState.mockReturnValue(anOnboardingState); + + const device = aDevice; + + getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (value) => { + expect(value.onboardingState).toBeNull(); + expect(value.allowedError).toBeInstanceOf(TimeoutError); + done(); + }, + }); + + // Waits more than the timeout + jest.advanceTimersByTime(pollingPeriodMs + 1); + }); + + it("should update the allowed error value to notify the consumer - timeout value set by the consumer", (done) => { + const fetchingTimeoutMs = pollingPeriodMs + 500; + mockedGetVersion.mockResolvedValue(aFirmwareInfo); + mockedExtractOnboardingState.mockReturnValue(anOnboardingState); + + const device = aDevice; + + getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + fetchingTimeoutMs, + }).subscribe({ + next: (value) => { + expect(value.onboardingState).toBeNull(); + expect(value.allowedError).toBeInstanceOf(TimeoutError); + done(); + }, + }); + + // Waits more than the timeout + jest.advanceTimersByTime(fetchingTimeoutMs + 1); + }); + }); + + describe("and when the error is fatal and thrown before the defined timeout", () => { + it("should notify the consumer that a unallowed error occurred", (done) => { + mockedGetVersion.mockRejectedValue(new Error("Unknown error")); + + const device = aDevice; + + getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + error: (error) => { + expect(error).toBeInstanceOf(Error); + expect(error?.message).toBe("Unknown error"); + done(); + }, + }); + + jest.advanceTimersByTime(pollingPeriodMs - 1); + }); + }); + }); + + describe("When the fetched device state is incorrect", () => { + it("should return a null onboarding state, and keep track of the extract error", (done) => { + mockedGetVersion.mockResolvedValue(aFirmwareInfo); + mockedExtractOnboardingState.mockImplementation(() => { + throw new DeviceExtractOnboardingStateError( + "Some incorrect device info" + ); + }); + + const device = aDevice; + + onboardingStatePollingSubscription = getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (value) => { + expect(value.onboardingState).toBeNull(); + expect(value.allowedError).toBeInstanceOf( + DeviceExtractOnboardingStateError + ); + done(); + }, + }); + + jest.advanceTimersByTime(pollingPeriodMs - 1); + }); + }); + + describe("When polling returns a correct device state", () => { + it("should return a correct onboarding state", (done) => { + mockedGetVersion.mockResolvedValue(aFirmwareInfo); + mockedExtractOnboardingState.mockReturnValue(anOnboardingState); + + const device = aDevice; + + onboardingStatePollingSubscription = getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (value) => { + expect(value.allowedError).toBeNull(); + expect(value.onboardingState).toEqual(anOnboardingState); + done(); + }, + error: (error) => { + done(error); + }, + }); + + jest.advanceTimersByTime(pollingPeriodMs - 1); + }); + + it("should poll a new onboarding state after the defined period of time", (done) => { + mockedGetVersion.mockResolvedValue(aFirmwareInfo); + mockedExtractOnboardingState.mockReturnValue(anOnboardingState); + + const device = aDevice; + + // Did not manage to test that the polling is repeated by using jest's fake timer + // and advanceTimersByTime method or equivalent. + // Hacky test: spy on the repeat operator to see if it has been called. + const spiedRepeat = jest.spyOn(rxjsOperators, "repeat"); + + onboardingStatePollingSubscription = getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (value) => { + expect(value.onboardingState).toEqual(anOnboardingState); + expect(value.allowedError).toBeNull(); + expect(spiedRepeat).toHaveBeenCalledTimes(1); + done(); + }, + error: (error) => { + done(error); + }, + }); + + jest.runOnlyPendingTimers(); + }); + }); +}); diff --git a/libs/ledger-live-common/src/hw/getOnboardingStatePolling.ts b/libs/ledger-live-common/src/hw/getOnboardingStatePolling.ts new file mode 100644 index 000000000000..2142a13cc55f --- /dev/null +++ b/libs/ledger-live-common/src/hw/getOnboardingStatePolling.ts @@ -0,0 +1,153 @@ +import { + from, + merge, + partition, + of, + throwError, + Observable, + TimeoutError, +} from "rxjs"; +import { map, catchError, repeat, first, timeout } from "rxjs/operators"; +import getVersion from "./getVersion"; +import { withDevice } from "./deviceAccess"; +import { + TransportStatusError, + DeviceOnboardingStatePollingError, + DeviceExtractOnboardingStateError, + DisconnectedDevice, + CantOpenDevice, +} from "@ledgerhq/errors"; +import { FirmwareInfo } from "../types/manager"; +import { + extractOnboardingState, + OnboardingState, +} from "./extractOnboardingState"; + +export type OnboardingStatePollingResult = { + onboardingState: OnboardingState | null; + allowedError: Error | null; +}; + +/** + * Polls the device onboarding state at a given frequency + * @param deviceId A device id + * @param pollingPeriodMs The period in ms after which the device onboarding state is fetched again + * @param fetchingTimeoutMs The time to wait while fetching for the device onboarding state before throwing an error, in ms + * @returns An Observable that polls the device onboarding state + */ +export const getOnboardingStatePolling = ({ + deviceId, + pollingPeriodMs, + fetchingTimeoutMs = pollingPeriodMs, +}: { + deviceId: string; + pollingPeriodMs: number; + fetchingTimeoutMs?: number; +}): Observable => { + let firstRun = true; + + const delayedOnceOnboardingStateObservable: Observable = + new Observable((subscriber) => { + const delayMs = firstRun ? 0 : pollingPeriodMs; + firstRun = false; + + const getOnboardingStateOnce = () => { + const firmwareInfoOrAllowedErrorObservable = withDevice(deviceId)((t) => + from(getVersion(t)) + ).pipe( + timeout(fetchingTimeoutMs), // Throws a TimeoutError + first(), + catchError((error: any) => { + if (isAllowedOnboardingStatePollingError(error)) { + // Pushes the error to the next step to be processed (no retry from the beginning) + return of(error); + } + + return throwError(error); + }) + ); + + // If an error is catched previously, and this error is "allowed", + // the value from the observable is not a FirmwareInfo but an Error + const [firmwareInfoObservable, allowedErrorObservable] = partition( + firmwareInfoOrAllowedErrorObservable, + // TS cannot infer correctly the value given to RxJS partition + (value: any) => Boolean(value?.flags) + ); + + const onboardingStateFromFirmwareInfoObservable = + firmwareInfoObservable.pipe( + map((firmwareInfo: FirmwareInfo) => { + let onboardingState: OnboardingState | null = null; + + try { + onboardingState = extractOnboardingState(firmwareInfo.flags); + } catch (error: any) { + if (error instanceof DeviceExtractOnboardingStateError) { + return { + onboardingState: null, + allowedError: error, + }; + } else { + return { + onboardingState: null, + allowedError: new DeviceOnboardingStatePollingError( + `SyncOnboarding: Unknown error while extracting the onboarding state ${ + error?.name ?? error + } ${error?.message}` + ), + }; + } + } + return { onboardingState, allowedError: null }; + }) + ); + + // Handles the case of an (allowed) Error value + const onboardingStateFromAllowedErrorObservable = + allowedErrorObservable.pipe( + map((allowedError: Error) => { + return { + onboardingState: null, + allowedError: allowedError, + }; + }) + ); + + return merge( + onboardingStateFromFirmwareInfoObservable, + onboardingStateFromAllowedErrorObservable + ); + }; + + // Delays the fetch of the onboarding state + setTimeout(() => { + getOnboardingStateOnce().subscribe({ + next: (value: OnboardingStatePollingResult) => { + subscriber.next(value); + }, + error: (error: any) => { + subscriber.error(error); + }, + complete: () => subscriber.complete(), + }); + }, delayMs); + }); + + return delayedOnceOnboardingStateObservable.pipe(repeat()); +}; + +export const isAllowedOnboardingStatePollingError = (error: Error): boolean => { + if ( + error && + // Timeout error is thrown by rxjs's timeout + (error instanceof TimeoutError || + error instanceof DisconnectedDevice || + error instanceof CantOpenDevice || + error instanceof TransportStatusError) + ) { + return true; + } + + return false; +}; diff --git a/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.test.ts b/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.test.ts new file mode 100644 index 000000000000..ebb3063c437a --- /dev/null +++ b/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.test.ts @@ -0,0 +1,304 @@ +import { timer, of } from "rxjs"; +import { map, delayWhen } from "rxjs/operators"; +import { renderHook, act } from "@testing-library/react-hooks"; +import { DeviceModelId } from "@ledgerhq/devices"; +import { DisconnectedDevice } from "@ledgerhq/errors"; +import { useOnboardingStatePolling } from "./useOnboardingStatePolling"; +import { + OnboardingState, + OnboardingStep, +} from "../../hw/extractOnboardingState"; +import { SeedPhraseType } from "../../types/manager"; +import { getOnboardingStatePolling } from "../../hw/getOnboardingStatePolling"; + +jest.mock("../../hw/getOnboardingStatePolling"); +jest.useFakeTimers(); + +const aDevice = { + deviceId: "DEVICE_ID_A", + deviceName: "DEVICE_NAME_A", + modelId: DeviceModelId.nanoFTS, + wired: false, +}; + +const pollingPeriodMs = 1000; + +const mockedGetOnboardingStatePolling = jest.mocked(getOnboardingStatePolling); + +describe("useOnboardingStatePolling", () => { + let anOnboardingState: OnboardingState; + let aSecondOnboardingState: OnboardingState; + + beforeEach(() => { + anOnboardingState = { + isOnboarded: false, + isInRecoveryMode: false, + seedPhraseType: SeedPhraseType.TwentyFour, + currentSeedWordIndex: 0, + currentOnboardingStep: OnboardingStep.NewDevice, + }; + + aSecondOnboardingState = { + ...anOnboardingState, + currentOnboardingStep: OnboardingStep.NewDeviceConfirming, + }; + }); + + afterEach(() => { + mockedGetOnboardingStatePolling.mockClear(); + }); + + describe("When polling returns a correct device state", () => { + beforeEach(() => { + mockedGetOnboardingStatePolling.mockReturnValue( + of( + { + onboardingState: { ...anOnboardingState }, + allowedError: null, + }, + { + onboardingState: { ...aSecondOnboardingState }, + allowedError: null, + } + ).pipe( + delayWhen((_, index) => { + // "delay" or "delayWhen" piped to a streaming source, for ex the "of" operator, will not block the next + // Observable to be streamed. They return an Observable that delays the emission of the source Observable, + // but do not create a delay in-between each emission. That's why the delay is increased by multiplying by "index". + // "concatMap" could have been used to wait for the previous Observable to complete, but + // the "index" arg given to "delayWhen" would always be 0 + return timer(index * pollingPeriodMs); + }) + ) + ); + }); + + it("should update the onboarding state returned to the consumer", async () => { + const device = aDevice; + + const { result } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs }) + ); + + await act(async () => { + jest.advanceTimersByTime(1); + }); + + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + }); + + it("should fetch again the state at a defined frequency and update (if new) the onboarding state returned to the consumer", async () => { + const device = aDevice; + + const { result } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs }) + ); + + await act(async () => { + jest.advanceTimersByTime(1); + }); + + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + // Next polling + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs); + }); + + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(aSecondOnboardingState); + }); + + describe("and when the hook consumer stops the polling", () => { + it("should stop the polling and stop fetching the device onboarding state", async () => { + const device = aDevice; + let stopPolling = false; + + const { result, rerender } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs, stopPolling }) + ); + + await act(async () => { + jest.advanceTimersByTime(1); + }); + + // Everything is normal on the first run + expect(mockedGetOnboardingStatePolling).toHaveBeenCalledTimes(1); + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + // The consumer stops the polling + stopPolling = true; + rerender({ device, pollingPeriodMs, stopPolling }); + + await act(async () => { + // Waits as long as we want + jest.advanceTimersByTime(10 * pollingPeriodMs); + }); + + // While the hook was rerendered, it did not call a new time getOnboardingStatePolling + expect(mockedGetOnboardingStatePolling).toHaveBeenCalledTimes(1); + // And the state should stay the same (and not be aSecondOnboardingState) + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + }); + }); + }); + + describe("When an allowed error occurs while polling the device state", () => { + beforeEach(() => { + mockedGetOnboardingStatePolling.mockReturnValue( + of( + { + onboardingState: { ...anOnboardingState }, + allowedError: null, + }, + { + onboardingState: null, + allowedError: new DisconnectedDevice("An allowed error"), + }, + { + onboardingState: { ...aSecondOnboardingState }, + allowedError: null, + } + ).pipe( + delayWhen((_, index) => { + return timer(index * pollingPeriodMs); + }) + ) + ); + }); + + it("should update the allowed error returned to the consumer, update the fatal error to null and keep the previous onboarding state", async () => { + const device = aDevice; + + const { result } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs }) + ); + + await act(async () => { + jest.advanceTimersByTime(1); + }); + + // Everything is ok on the first run + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs); + }); + + expect(result.current.allowedError).toBeInstanceOf(DisconnectedDevice); + expect(result.current.fatalError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + }); + + it("should be able to recover once the allowed error is fixed and the onboarding state is updated", async () => { + const device = aDevice; + + const { result } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs }) + ); + + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs + 1); + }); + + // Allowed error occured + expect(result.current.allowedError).toBeInstanceOf(DisconnectedDevice); + expect(result.current.fatalError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs); + }); + + // Everything is ok on the next run + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(aSecondOnboardingState); + }); + }); + + describe("When a (fatal) error is thrown while polling the device state", () => { + const anOnboardingStateThatShouldNeverBeReached = { + ...aSecondOnboardingState, + }; + + beforeEach(() => { + mockedGetOnboardingStatePolling.mockReturnValue( + of( + { + onboardingState: { ...anOnboardingState }, + allowedError: null, + }, + { + onboardingState: { ...anOnboardingState }, + allowedError: null, + }, + { + // It should never be reached + onboardingState: { ...anOnboardingStateThatShouldNeverBeReached }, + allowedError: null, + } + ).pipe( + delayWhen((_, index) => { + return timer(index * pollingPeriodMs); + }), + map((value, index) => { + // Throws an error the second time + if (index === 1) { + throw new Error("An unallowed error"); + } + return value; + }) + ) + ); + }); + + it("should update the fatal error returned to the consumer, update the allowed error to null, keep the previous onboarding state and stop the polling", async () => { + const device = aDevice; + + const { result } = renderHook(() => + useOnboardingStatePolling({ device, pollingPeriodMs }) + ); + + await act(async () => { + jest.advanceTimersByTime(1); + }); + + // Everything is ok on the first run + expect(result.current.fatalError).toBeNull(); + expect(result.current.allowedError).toBeNull(); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs); + }); + + // Fatal error on the second run + expect(result.current.allowedError).toBeNull(); + expect(result.current.fatalError).toBeInstanceOf(Error); + expect(result.current.onboardingState).toEqual(anOnboardingState); + + await act(async () => { + jest.advanceTimersByTime(pollingPeriodMs); + }); + + // The polling should have been stopped, and we never update the onboardingState + expect(result.current.allowedError).toBeNull(); + expect(result.current.fatalError).toBeInstanceOf(Error); + expect(result.current.onboardingState).not.toEqual( + anOnboardingStateThatShouldNeverBeReached + ); + }); + }); +}); diff --git a/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.ts b/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.ts new file mode 100644 index 000000000000..aa84c9d7fe8c --- /dev/null +++ b/libs/ledger-live-common/src/onboarding/hooks/useOnboardingStatePolling.ts @@ -0,0 +1,86 @@ +import { useState, useEffect } from "react"; +import { Subscription } from "rxjs"; +import type { Device } from "../../hw/actions/types"; +import { DeviceOnboardingStatePollingError } from "@ledgerhq/errors"; + +import type { OnboardingStatePollingResult } from "../../hw/getOnboardingStatePolling"; +import { getOnboardingStatePolling } from "../../hw/getOnboardingStatePolling"; +import { OnboardingState } from "../../hw/extractOnboardingState"; + +export type UseOnboardingStatePollingResult = OnboardingStatePollingResult & { + fatalError: Error | null; +}; + +/** + * Polls the current device onboarding state, and notify the hook consumer of + * any allowed errors and fatal errors + * @param device A Device object + * @param pollingPeriodMs The period in ms after which the device onboarding state is fetched again + * @param stopPolling Flag to stop or continue the polling + * @returns An object containing: + * - onboardingState: the device state during the onboarding + * - allowedError: any error that is allowed and does not stop the polling + * - fatalError: any error that is fatal and stops the polling + */ +export const useOnboardingStatePolling = ({ + device, + pollingPeriodMs, + stopPolling = false, +}: { + device: Device | null; + pollingPeriodMs: number; + stopPolling?: boolean; +}): UseOnboardingStatePollingResult => { + const [onboardingState, setOnboardingState] = + useState(null); + const [allowedError, setAllowedError] = useState(null); + const [fatalError, setFatalError] = useState(null); + + useEffect(() => { + let onboardingStatePollingSubscription: Subscription; + + // If stopPolling is updated and set to true, the useEffect hook will call its + // cleanup function (return) and the polling won't restart with the below condition + if (device && !stopPolling) { + onboardingStatePollingSubscription = getOnboardingStatePolling({ + deviceId: device.deviceId, + pollingPeriodMs, + }).subscribe({ + next: (onboardingStatePollingResult: OnboardingStatePollingResult) => { + if (onboardingStatePollingResult) { + setFatalError(null); + setAllowedError(onboardingStatePollingResult.allowedError); + + // Does not update the onboarding state if an allowed error occurred + if (!onboardingStatePollingResult.allowedError) { + setOnboardingState(onboardingStatePollingResult.onboardingState); + } + } + }, + error: (error) => { + setAllowedError(null); + setFatalError( + error instanceof Error + ? error + : new DeviceOnboardingStatePollingError( + `Error from: ${error?.name ?? error} ${error?.message}` + ) + ); + }, + }); + } + + return () => { + onboardingStatePollingSubscription?.unsubscribe(); + }; + }, [ + device, + pollingPeriodMs, + setOnboardingState, + setAllowedError, + setFatalError, + stopPolling, + ]); + + return { onboardingState, allowedError, fatalError }; +}; diff --git a/libs/ledger-live-common/src/types/manager.ts b/libs/ledger-live-common/src/types/manager.ts index 68f8f1f332c8..36edb237a255 100644 --- a/libs/ledger-live-common/src/types/manager.ts +++ b/libs/ledger-live-common/src/types/manager.ts @@ -68,6 +68,11 @@ export type McuVersion = { date_creation: string; date_last_modified: string; }; +export enum SeedPhraseType { + Twelve = "12-words", + Eighteen = "18-words", + TwentyFour = "24-words", +} export type FirmwareInfo = { isBootloader: boolean; rawVersion: string; // if SE seVersion, if BL blVersion diff --git a/libs/ledgerjs/packages/errors/src/index.ts b/libs/ledgerjs/packages/errors/src/index.ts index 455571a1ee4a..65e90db0b554 100644 --- a/libs/ledgerjs/packages/errors/src/index.ts +++ b/libs/ledgerjs/packages/errors/src/index.ts @@ -54,6 +54,12 @@ export const DisconnectedDevice = createCustomErrorClass("DisconnectedDevice"); export const DisconnectedDeviceDuringOperation = createCustomErrorClass( "DisconnectedDeviceDuringOperation" ); +export const DeviceExtractOnboardingStateError = createCustomErrorClass( + "DeviceExtractOnboardingStateError" +); +export const DeviceOnboardingStatePollingError = createCustomErrorClass( + "DeviceOnboardingStatePollingError" +); export const EnpointConfigError = createCustomErrorClass("EnpointConfig"); export const EthAppPleaseEnableContractData = createCustomErrorClass( "EthAppPleaseEnableContractData"