Skip to content

Commit

Permalink
Fire passthrough-websocket-connect rule event on passthrough WebSockets
Browse files Browse the repository at this point in the history
  • Loading branch information
pimterry committed Jan 29, 2025
1 parent bb61c2c commit c9c4438
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 16 deletions.
33 changes: 24 additions & 9 deletions src/rules/websockets/websocket-handlers.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import * as _ from 'lodash';
import net = require('net');
import * as url from 'url';
import * as tls from 'tls';
import * as http from 'http';
import * as fs from 'fs/promises';
import * as WebSocket from 'ws';

import {
Expand All @@ -12,10 +10,11 @@ import {
deserializeProxyConfig
} from "../../serialization/serialization";

import { OngoingRequest, RawHeaders } from "../../types";
import { Headers, OngoingRequest, RawHeaders } from "../../types";

import {
CloseConnectionHandler,
RequestHandlerOptions,
ResetConnectionHandler,
TimeoutHandler
} from '../requests/request-handlers';
Expand Down Expand Up @@ -60,7 +59,9 @@ export interface WebSocketHandler extends WebSocketHandlerDefinition {
// The raw socket on which we'll be communicating
socket: net.Socket,
// Initial data received
head: Buffer
head: Buffer,
// Other general handler options
options: RequestHandlerOptions
): Promise<void>;
}

Expand Down Expand Up @@ -219,7 +220,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
return this._trustedCACertificates;
}

async handle(req: OngoingRequest, socket: net.Socket, head: Buffer) {
async handle(req: OngoingRequest, socket: net.Socket, head: Buffer, options: RequestHandlerOptions) {
this.initializeWsServer();

let { protocol, hostname, port, path } = url.parse(req.url!);
Expand Down Expand Up @@ -266,7 +267,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
hostHeader[1] = updateHostHeader;
} // Otherwise: falsey means don't touch it.

await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
} else if (!hostname) { // No hostname in URL means transparent proxy, so use Host header
const hostHeader = req.headers[hostHeaderName];
[ hostname, port ] = hostHeader!.split(':');
Expand All @@ -280,14 +281,14 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
}

const wsUrl = `${protocol}://${hostname}${port ? ':' + port : ''}${path}`;
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
} else {
// Connect directly according to the specified URL
const wsUrl = `${
protocol!.replace('http', 'ws')
}//${hostname}${port ? ':' + port : ''}${path}`;

await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head, options);
}
}

Expand All @@ -296,7 +297,8 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
req: http.IncomingMessage,
rawHeaders: RawHeaders,
incomingSocket: net.Socket,
head: Buffer
head: Buffer,
options: RequestHandlerOptions
) {
const parsedUrl = url.parse(wsUrl);

Expand Down Expand Up @@ -370,6 +372,19 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
...caConfig
} as WebSocket.ClientOptions & { lookup: any, maxPayload: number });

if (options.emitEventCallback) {
const upstreamReq = (upstreamWebSocket as any as { _req: http.ClientRequest })._req;
options.emitEventCallback('passthrough-websocket-connect', {
method: upstreamReq.method,
protocol: upstreamReq.protocol.replace(/:$/, ''),
hostname: upstreamReq.host,
port: effectivePort.toString(),
path: upstreamReq.path,
rawHeaders: objectHeadersToRaw(upstreamReq.getHeaders() as Headers),
subprotocols: filteredSubprotocols
});
}

upstreamWebSocket.once('open', () => {
// Used in the subprotocol selection handler during the upgrade:
(req as InterceptedWebSocketRequest).upstreamWebSocketProtocol = upstreamWebSocket.protocol || false;
Expand Down
24 changes: 20 additions & 4 deletions src/rules/websockets/websocket-rule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ export interface WebSocketRule extends Explainable {

// We don't extend the main interfaces for these, because MockRules are not Serializable
matches(request: OngoingRequest): MaybePromise<boolean>;
handle(request: OngoingRequest, response: net.Socket, head: Buffer, record: boolean): Promise<void>;
handle(
request: OngoingRequest,
response: net.Socket,
head: Buffer,
options: {
record: boolean,
emitEventCallback?: (type: string, event: unknown) => void
}
): Promise<void>;
isComplete(): boolean | null;
}

Expand Down Expand Up @@ -71,14 +79,22 @@ export class WebSocketRule implements WebSocketRule {
return matchers.matchesAll(request, this.matchers);
}

handle(req: OngoingRequest, res: net.Socket, head: Buffer, record: boolean): Promise<void> {
handle(
req: OngoingRequest,
res: net.Socket,
head: Buffer,
options: {
record: boolean,
emitEventCallback?: (type: string, event: unknown) => void
}
): Promise<void> {
let handlerPromise = (async () => { // Catch (a)sync errors
return this.handler.handle(req as OngoingRequest & http.IncomingMessage, res, head);
return this.handler.handle(req as OngoingRequest & http.IncomingMessage, res, head, options);
})();

// Requests are added to rule.requests as soon as they start being handled,
// as promises, which resolve only when the response & request body is complete.
if (record) {
if (options.record) {
this.requests.push(
Promise.race([
// When the handler resolves, the request is completed:
Expand Down
13 changes: 11 additions & 2 deletions src/server/mockttp-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,24 @@ export class MockttpServer extends AbstractMockttp implements Mockttp {
let nextRule = await nextRulePromise;
if (nextRule) {
if (this.debug) console.log(`Websocket matched rule: ${nextRule.explain()}`);
await nextRule.handle(request, socket, head, this.recordTraffic);
await nextRule.handle(request, socket, head, {
record: this.recordTraffic,
emitEventCallback: (this.eventEmitter.listenerCount('rule-event') !== 0)
? (type, event) => this.announceRuleEventAsync(request.id, nextRule!.id, type, event)
: undefined
});
} else {
// Unmatched requests get passed through untouched automatically. This exists for
// historical/backward-compat reasons, to match the initial WS implementation, and
// will probably be removed to match handleRequest in future.
await this.defaultWsHandler.handle(
request as OngoingRequest & http.IncomingMessage,
socket,
head
head,
{ emitEventCallback: (this.eventEmitter.listenerCount('rule-event') !== 0)
? (type, event) => this.announceRuleEventAsync(request.id, nextRule!.id, type, event)
: undefined
}
);
}
} catch (e) {
Expand Down
66 changes: 65 additions & 1 deletion test/integration/subscriptions/rule-events.spec.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import * as _ from 'lodash';
import * as WebSocket from 'isomorphic-ws';

import {
getLocal,
RawHeaders,
RuleEvent
} from "../../..";
import {
delay,
expect,
fetch
fetch,
isNode
} from "../../test-utils";

describe("Rule event susbcriptions", () => {
Expand Down Expand Up @@ -168,4 +171,65 @@ describe("Rule event susbcriptions", () => {
expect(responseBodyEvent.rawBody.toString('utf8')).to.equal('Original response body');
});

it("should fire for proxied websockets", async () => {
await remoteServer.forAnyWebSocket().thenPassivelyListen();
const forwardingRule = await server.forAnyWebSocket().thenForwardTo(remoteServer.url);

const ruleEvents: RuleEvent<any>[] = [];
await server.on('rule-event', (e) => ruleEvents.push(e));

const ws = new WebSocket(`ws://localhost:${server.port}`);
const downstreamWsKey = isNode
? (ws as any)._req.getHeaders()['sec-websocket-key']
: undefined;

await new Promise<void>((resolve, reject) => {
ws.addEventListener('open', () => {
resolve();
ws.close();
});
ws.addEventListener('error', reject);
});

await delay(100);

expect(ruleEvents.length).to.equal(1);

const requestId = (await forwardingRule.getSeenRequests())[0].id;
ruleEvents.forEach((event) => {
expect(event.ruleId).to.equal(forwardingRule.id);
expect(event.requestId).to.equal(requestId);
});

expect(ruleEvents.map(e => e.eventType)).to.deep.equal([
'passthrough-websocket-connect'
]);

const connectEvent = ruleEvents[0].eventData;
expect(_.omit(connectEvent, 'rawHeaders')).to.deep.equal({
method: 'GET',
protocol: 'http',
hostname: 'localhost',
// This reports the *modified* port, not the original:
port: remoteServer.port.toString(),
path: '/',
subprotocols: []
});

// This reports the *modified* header, not the original:
expect(connectEvent.rawHeaders).to.deep.include(['host', `localhost:${remoteServer.port}`]);
expect(connectEvent.rawHeaders).to.deep.include(['sec-websocket-version', '13']);
expect(connectEvent.rawHeaders).to.deep.include(['sec-websocket-extensions', 'permessage-deflate; client_max_window_bits']);
expect(connectEvent.rawHeaders).to.deep.include(['connection', 'Upgrade']);
expect(connectEvent.rawHeaders).to.deep.include(['upgrade', 'websocket']);

// Make sure we want to see the upstream WS key, not the downstream one
const upstreamWsKey = (connectEvent.rawHeaders as RawHeaders)
.find(([key]) => key.toLowerCase() === 'sec-websocket-key')!;
expect(upstreamWsKey[1]).to.not.equal(downstreamWsKey);
});

// For now, we only support transformation of websocket URLs in forwarding, and nothing
// else, so initial conn params are the only passthrough data that's useful to expose.

});

0 comments on commit c9c4438

Please sign in to comment.