Skip to content

Commit

Permalink
feat: enable base fetch override (starknet-io#1279)
Browse files Browse the repository at this point in the history
  • Loading branch information
penovicp authored Dec 3, 2024
1 parent 394d6e3 commit 0fce61e
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 41 deletions.
31 changes: 22 additions & 9 deletions __tests__/rpcChannel.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import { LibraryError, RPC07, RpcError } from '../src';
import { LibraryError, RPC06, RPC07, RpcError } from '../src';
import { createBlockForDevnet, getTestProvider } from './config/fixtures';
import { initializeMatcher } from './config/schema';

describe('RPC 0.7.0', () => {
const rpcProvider = getTestProvider(false);
const channel = rpcProvider.channel as RPC07.RpcChannel;
describe('RpcChannel', () => {
const { nodeUrl } = getTestProvider(false).channel;
const channel07 = new RPC07.RpcChannel({ nodeUrl });
initializeMatcher(expect);

beforeAll(async () => {
await createBlockForDevnet();
});

test('getBlockWithReceipts', async () => {
const response = await channel.getBlockWithReceipts('latest');
expect(response).toMatchSchemaRef('BlockWithTxReceipts');
test('baseFetch override', async () => {
const baseFetch = jest.fn();
const fetchChannel06 = new RPC06.RpcChannel({ nodeUrl, baseFetch });
const fetchChannel07 = new RPC07.RpcChannel({ nodeUrl, baseFetch });
(fetchChannel06.fetch as any)();
expect(baseFetch).toHaveBeenCalledTimes(1);
baseFetch.mockClear();
(fetchChannel07.fetch as any)();
expect(baseFetch).toHaveBeenCalledTimes(1);
});

test('RPC error handling', async () => {
const fetchSpy = jest.spyOn(channel, 'fetch');
const fetchSpy = jest.spyOn(channel07, 'fetch');
fetchSpy.mockResolvedValue({
json: async () => ({
jsonrpc: '2.0',
Expand All @@ -32,12 +38,19 @@ describe('RPC 0.7.0', () => {
expect.assertions(3);
try {
// @ts-expect-error
await channel.fetchEndpoint('starknet_chainId');
await channel07.fetchEndpoint('starknet_chainId');
} catch (error) {
expect(error).toBeInstanceOf(LibraryError);
expect(error).toBeInstanceOf(RpcError);
expect((error as RpcError).isType('BLOCK_NOT_FOUND')).toBe(true);
}
fetchSpy.mockRestore();
});

describe('RPC 0.7.0', () => {
test('getBlockWithReceipts', async () => {
const response = await channel07.getBlockWithReceipts('latest');
expect(response).toMatchSchemaRef('BlockWithTxReceipts');
});
});
});
8 changes: 8 additions & 0 deletions __tests__/rpcProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ describeIfRpc('RPCProvider', () => {
await createBlockForDevnet();
});

test('baseFetch override', async () => {
const { nodeUrl } = rpcProvider.channel;
const baseFetch = jest.fn();
const fetchProvider = new RpcProvider({ nodeUrl, baseFetch });
(fetchProvider.fetch as any)();
expect(baseFetch.mock.calls.length).toBe(1);
});

test('instantiate from rpcProvider', () => {
const newInsRPCProvider = new RpcProvider();

Expand Down
2 changes: 2 additions & 0 deletions __tests__/utils/batch.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fetch from '../../src/utils/fetchPonyfill';
import { BatchClient } from '../../src/utils/batch';
import { createBlockForDevnet, getTestProvider } from '../config/fixtures';
import { initializeMatcher } from '../config/schema';
Expand All @@ -9,6 +10,7 @@ describe('Batch Client', () => {
nodeUrl: provider.channel.nodeUrl,
headers: provider.channel.headers,
interval: 0,
baseFetch: fetch,
});

initializeMatcher(expect);
Expand Down
36 changes: 21 additions & 15 deletions src/channel/rpc_0_6.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,36 @@ export class RpcChannel {

public headers: object;

readonly retries: number;

public requestId: number;

readonly blockIdentifier: BlockIdentifier;

readonly retries: number;

readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed

private chainId?: StarknetChainId;

private specVersion?: string;

private transactionRetryIntervalFallback?: number;

readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed

private batchClient?: BatchClient;

private baseFetch: NonNullable<RpcProviderOptions['baseFetch']>;

constructor(optionsOrProvider?: RpcProviderOptions) {
const {
nodeUrl,
retries,
headers,
baseFetch,
batch,
blockIdentifier,
chainId,
headers,
nodeUrl,
retries,
specVersion,
waitMode,
transactionRetryIntervalFallback,
batch,
waitMode,
} = optionsOrProvider || {};
if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) {
this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default);
Expand All @@ -77,20 +80,23 @@ export class RpcChannel {
} else {
this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default);
}
this.retries = retries || defaultOptions.retries;
this.headers = { ...defaultOptions.headers, ...headers };
this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier;
this.baseFetch = baseFetch ?? fetch;
this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier;
this.chainId = chainId;
this.headers = { ...defaultOptions.headers, ...headers };
this.retries = retries ?? defaultOptions.retries;
this.specVersion = specVersion;
this.waitMode = waitMode || false;
this.requestId = 0;
this.transactionRetryIntervalFallback = transactionRetryIntervalFallback;
this.waitMode = waitMode ?? false;

this.requestId = 0;

if (typeof batch === 'number') {
this.batchClient = new BatchClient({
nodeUrl: this.nodeUrl,
headers: this.headers,
interval: batch,
baseFetch: this.baseFetch,
});
}
}
Expand All @@ -110,7 +116,7 @@ export class RpcChannel {
method,
...(params && { params }),
};
return fetch(this.nodeUrl, {
return this.baseFetch(this.nodeUrl, {
method: 'POST',
body: stringify(rpcRequestBody),
headers: this.headers as Record<string, string>,
Expand Down
36 changes: 21 additions & 15 deletions src/channel/rpc_0_7.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,36 @@ export class RpcChannel {

public headers: object;

readonly retries: number;

public requestId: number;

readonly blockIdentifier: BlockIdentifier;

readonly retries: number;

readonly waitMode: boolean; // behave like web2 rpc and return when tx is processed

private chainId?: StarknetChainId;

private specVersion?: string;

private transactionRetryIntervalFallback?: number;

readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed

private batchClient?: BatchClient;

private baseFetch: NonNullable<RpcProviderOptions['baseFetch']>;

constructor(optionsOrProvider?: RpcProviderOptions) {
const {
nodeUrl,
retries,
headers,
baseFetch,
batch,
blockIdentifier,
chainId,
headers,
nodeUrl,
retries,
specVersion,
waitMode,
transactionRetryIntervalFallback,
batch,
waitMode,
} = optionsOrProvider || {};
if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) {
this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default);
Expand All @@ -77,20 +80,23 @@ export class RpcChannel {
} else {
this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default);
}
this.retries = retries || defaultOptions.retries;
this.headers = { ...defaultOptions.headers, ...headers };
this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier;
this.baseFetch = baseFetch ?? fetch;
this.blockIdentifier = blockIdentifier ?? defaultOptions.blockIdentifier;
this.chainId = chainId;
this.headers = { ...defaultOptions.headers, ...headers };
this.retries = retries ?? defaultOptions.retries;
this.specVersion = specVersion;
this.waitMode = waitMode || false;
this.requestId = 0;
this.transactionRetryIntervalFallback = transactionRetryIntervalFallback;
this.waitMode = waitMode ?? false;

this.requestId = 0;

if (typeof batch === 'number') {
this.batchClient = new BatchClient({
nodeUrl: this.nodeUrl,
headers: this.headers,
interval: batch,
baseFetch: this.baseFetch,
});
}
}
Expand All @@ -110,7 +116,7 @@ export class RpcChannel {
method,
...(params && { params }),
};
return fetch(this.nodeUrl, {
return this.baseFetch(this.nodeUrl, {
method: 'POST',
body: stringify(rpcRequestBody),
headers: this.headers as Record<string, string>,
Expand Down
1 change: 1 addition & 0 deletions src/types/provider/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export type RpcProviderOptions = {
specVersion?: string;
default?: boolean;
waitMode?: boolean;
baseFetch?: WindowOrWorkerGlobalScope['fetch'];
feeMarginPercentage?: {
l1BoundMaxAmount: number;
l1BoundMaxPricePerUnit: number;
Expand Down
8 changes: 6 additions & 2 deletions src/utils/batch/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { stringify } from '../json';
import { RPC } from '../../types';
import { RPC, RpcProviderOptions } from '../../types';
import { JRPC } from '../../types/api';

export type BatchClientOptions = {
nodeUrl: string;
headers: object;
interval: number;
baseFetch: NonNullable<RpcProviderOptions['baseFetch']>;
};

export class BatchClient {
Expand All @@ -27,10 +28,13 @@ export class BatchClient {

private delayPromiseResolve?: () => void;

private baseFetch: BatchClientOptions['baseFetch'];

constructor(options: BatchClientOptions) {
this.nodeUrl = options.nodeUrl;
this.headers = options.headers;
this.interval = options.interval;
this.baseFetch = options.baseFetch;
}

private async wait(): Promise<void> {
Expand Down Expand Up @@ -77,7 +81,7 @@ export class BatchClient {
}

private async sendBatch(requests: JRPC.RequestBody[]) {
const raw = await fetch(this.nodeUrl, {
const raw = await this.baseFetch(this.nodeUrl, {
method: 'POST',
body: stringify(requests),
headers: this.headers as Record<string, string>,
Expand Down

0 comments on commit 0fce61e

Please sign in to comment.