From 0fce61e40535a4f1b3b05fdd9da60f9218250c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petar=20Penovi=C4=87?= Date: Tue, 3 Dec 2024 19:31:06 +0100 Subject: [PATCH] feat: enable base fetch override (#1279) --- __tests__/rpcChannel.test.ts | 31 +++++++++++++++++-------- __tests__/rpcProvider.test.ts | 8 +++++++ __tests__/utils/batch.test.ts | 2 ++ src/channel/rpc_0_6.ts | 36 +++++++++++++++++------------ src/channel/rpc_0_7.ts | 36 +++++++++++++++++------------ src/types/provider/configuration.ts | 1 + src/utils/batch/index.ts | 8 +++++-- 7 files changed, 81 insertions(+), 41 deletions(-) diff --git a/__tests__/rpcChannel.test.ts b/__tests__/rpcChannel.test.ts index 8d0a0d01a..471486ac8 100644 --- a/__tests__/rpcChannel.test.ts +++ b/__tests__/rpcChannel.test.ts @@ -1,23 +1,29 @@ -import { LibraryError, RPC07, RpcError } from '../src'; +import { LibraryError, RPC06, RPC07, RpcError } from '../src'; import { createBlockForDevnet, getTestProvider } from './config/fixtures'; import { initializeMatcher } from './config/schema'; -describe('RPC 0.7.0', () => { - const rpcProvider = getTestProvider(false); - const channel = rpcProvider.channel as RPC07.RpcChannel; +describe('RpcChannel', () => { + const { nodeUrl } = getTestProvider(false).channel; + const channel07 = new RPC07.RpcChannel({ nodeUrl }); initializeMatcher(expect); beforeAll(async () => { await createBlockForDevnet(); }); - test('getBlockWithReceipts', async () => { - const response = await channel.getBlockWithReceipts('latest'); - expect(response).toMatchSchemaRef('BlockWithTxReceipts'); + test('baseFetch override', async () => { + const baseFetch = jest.fn(); + const fetchChannel06 = new RPC06.RpcChannel({ nodeUrl, baseFetch }); + const fetchChannel07 = new RPC07.RpcChannel({ nodeUrl, baseFetch }); + (fetchChannel06.fetch as any)(); + expect(baseFetch).toHaveBeenCalledTimes(1); + baseFetch.mockClear(); + (fetchChannel07.fetch as any)(); + expect(baseFetch).toHaveBeenCalledTimes(1); }); test('RPC error handling', async () => { - const fetchSpy = jest.spyOn(channel, 'fetch'); + const fetchSpy = jest.spyOn(channel07, 'fetch'); fetchSpy.mockResolvedValue({ json: async () => ({ jsonrpc: '2.0', @@ -32,7 +38,7 @@ describe('RPC 0.7.0', () => { expect.assertions(3); try { // @ts-expect-error - await channel.fetchEndpoint('starknet_chainId'); + await channel07.fetchEndpoint('starknet_chainId'); } catch (error) { expect(error).toBeInstanceOf(LibraryError); expect(error).toBeInstanceOf(RpcError); @@ -40,4 +46,11 @@ describe('RPC 0.7.0', () => { } fetchSpy.mockRestore(); }); + + describe('RPC 0.7.0', () => { + test('getBlockWithReceipts', async () => { + const response = await channel07.getBlockWithReceipts('latest'); + expect(response).toMatchSchemaRef('BlockWithTxReceipts'); + }); + }); }); diff --git a/__tests__/rpcProvider.test.ts b/__tests__/rpcProvider.test.ts index c2968e995..76dbd1b4f 100644 --- a/__tests__/rpcProvider.test.ts +++ b/__tests__/rpcProvider.test.ts @@ -47,6 +47,14 @@ describeIfRpc('RPCProvider', () => { await createBlockForDevnet(); }); + test('baseFetch override', async () => { + const { nodeUrl } = rpcProvider.channel; + const baseFetch = jest.fn(); + const fetchProvider = new RpcProvider({ nodeUrl, baseFetch }); + (fetchProvider.fetch as any)(); + expect(baseFetch.mock.calls.length).toBe(1); + }); + test('instantiate from rpcProvider', () => { const newInsRPCProvider = new RpcProvider(); diff --git a/__tests__/utils/batch.test.ts b/__tests__/utils/batch.test.ts index fbd78f7e3..a60ff862a 100644 --- a/__tests__/utils/batch.test.ts +++ b/__tests__/utils/batch.test.ts @@ -1,3 +1,4 @@ +import fetch from '../../src/utils/fetchPonyfill'; import { BatchClient } from '../../src/utils/batch'; import { createBlockForDevnet, getTestProvider } from '../config/fixtures'; import { initializeMatcher } from '../config/schema'; @@ -9,6 +10,7 @@ describe('Batch Client', () => { nodeUrl: provider.channel.nodeUrl, headers: provider.channel.headers, interval: 0, + baseFetch: fetch, }); initializeMatcher(expect); diff --git a/src/channel/rpc_0_6.ts b/src/channel/rpc_0_6.ts index f9bc221e3..f8386d59a 100644 --- a/src/channel/rpc_0_6.ts +++ b/src/channel/rpc_0_6.ts @@ -42,33 +42,36 @@ export class RpcChannel { public headers: object; - readonly retries: number; - public requestId: number; readonly blockIdentifier: BlockIdentifier; + readonly retries: number; + + readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed + private chainId?: StarknetChainId; private specVersion?: string; private transactionRetryIntervalFallback?: number; - readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed - private batchClient?: BatchClient; + private baseFetch: NonNullable; + constructor(optionsOrProvider?: RpcProviderOptions) { const { - nodeUrl, - retries, - headers, + baseFetch, + batch, blockIdentifier, chainId, + headers, + nodeUrl, + retries, specVersion, - waitMode, transactionRetryIntervalFallback, - batch, + waitMode, } = optionsOrProvider || {}; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); @@ -77,20 +80,23 @@ export class RpcChannel { } else { this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default); } - this.retries = retries || defaultOptions.retries; - this.headers = { ...defaultOptions.headers, ...headers }; - this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier; + this.baseFetch = baseFetch ?? fetch; + this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier; this.chainId = chainId; + this.headers = { ...defaultOptions.headers, ...headers }; + this.retries = retries ?? defaultOptions.retries; this.specVersion = specVersion; - this.waitMode = waitMode || false; - this.requestId = 0; this.transactionRetryIntervalFallback = transactionRetryIntervalFallback; + this.waitMode = waitMode ?? false; + + this.requestId = 0; if (typeof batch === 'number') { this.batchClient = new BatchClient({ nodeUrl: this.nodeUrl, headers: this.headers, interval: batch, + baseFetch: this.baseFetch, }); } } @@ -110,7 +116,7 @@ export class RpcChannel { method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return this.baseFetch(this.nodeUrl, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, diff --git a/src/channel/rpc_0_7.ts b/src/channel/rpc_0_7.ts index 1918fc880..54c5fb030 100644 --- a/src/channel/rpc_0_7.ts +++ b/src/channel/rpc_0_7.ts @@ -42,33 +42,36 @@ export class RpcChannel { public headers: object; - readonly retries: number; - public requestId: number; readonly blockIdentifier: BlockIdentifier; + readonly retries: number; + + readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed + private chainId?: StarknetChainId; private specVersion?: string; private transactionRetryIntervalFallback?: number; - readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed - private batchClient?: BatchClient; + private baseFetch: NonNullable; + constructor(optionsOrProvider?: RpcProviderOptions) { const { - nodeUrl, - retries, - headers, + baseFetch, + batch, blockIdentifier, chainId, + headers, + nodeUrl, + retries, specVersion, - waitMode, transactionRetryIntervalFallback, - batch, + waitMode, } = optionsOrProvider || {}; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); @@ -77,20 +80,23 @@ export class RpcChannel { } else { this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default); } - this.retries = retries || defaultOptions.retries; - this.headers = { ...defaultOptions.headers, ...headers }; - this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier; + this.baseFetch = baseFetch ?? fetch; + this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier; this.chainId = chainId; + this.headers = { ...defaultOptions.headers, ...headers }; + this.retries = retries ?? defaultOptions.retries; this.specVersion = specVersion; - this.waitMode = waitMode || false; - this.requestId = 0; this.transactionRetryIntervalFallback = transactionRetryIntervalFallback; + this.waitMode = waitMode ?? false; + + this.requestId = 0; if (typeof batch === 'number') { this.batchClient = new BatchClient({ nodeUrl: this.nodeUrl, headers: this.headers, interval: batch, + baseFetch: this.baseFetch, }); } } @@ -110,7 +116,7 @@ export class RpcChannel { method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return this.baseFetch(this.nodeUrl, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, diff --git a/src/types/provider/configuration.ts b/src/types/provider/configuration.ts index 7076ea2c8..d5be9624d 100644 --- a/src/types/provider/configuration.ts +++ b/src/types/provider/configuration.ts @@ -13,6 +13,7 @@ export type RpcProviderOptions = { specVersion?: string; default?: boolean; waitMode?: boolean; + baseFetch?: WindowOrWorkerGlobalScope['fetch']; feeMarginPercentage?: { l1BoundMaxAmount: number; l1BoundMaxPricePerUnit: number; diff --git a/src/utils/batch/index.ts b/src/utils/batch/index.ts index f8130c83a..263ff0710 100644 --- a/src/utils/batch/index.ts +++ b/src/utils/batch/index.ts @@ -1,11 +1,12 @@ import { stringify } from '../json'; -import { RPC } from '../../types'; +import { RPC, RpcProviderOptions } from '../../types'; import { JRPC } from '../../types/api'; export type BatchClientOptions = { nodeUrl: string; headers: object; interval: number; + baseFetch: NonNullable; }; export class BatchClient { @@ -27,10 +28,13 @@ export class BatchClient { private delayPromiseResolve?: () => void; + private baseFetch: BatchClientOptions['baseFetch']; + constructor(options: BatchClientOptions) { this.nodeUrl = options.nodeUrl; this.headers = options.headers; this.interval = options.interval; + this.baseFetch = options.baseFetch; } private async wait(): Promise { @@ -77,7 +81,7 @@ export class BatchClient { } private async sendBatch(requests: JRPC.RequestBody[]) { - const raw = await fetch(this.nodeUrl, { + const raw = await this.baseFetch(this.nodeUrl, { method: 'POST', body: stringify(requests), headers: this.headers as Record,