Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate websocket request #7317

Merged
merged 17 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/config/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1749,12 +1749,23 @@ Open Vitest UI (WIP)

### api

- **Type:** `boolean | number`
- **Type:** `boolean | number | ApiConfig`
- **Default:** `false`
- **CLI:** `--api`, `--api.port`, `--api.host`, `--api.strictPort`

Listen to port and serve API. When set to true, the default port is 51204

### api.allowedHosts
hi-ogawa marked this conversation as resolved.
Show resolved Hide resolved

- **Type:** `string[] | true`
- **Default:** `[]`

The hostnames that Vitest API is allowed to respond to. `localhost` and domains under `.localhost` and all IP addresses are allowed by default. When using HTTPS, this check is skipped.

If a string starts with `.`, it will allow that hostname without the `.` and all subdomains under the hostname. For example, `.example.com` will allow `example.com`, `foo.example.com`, and `foo.bar.example.com`.

If set to `true`, the server is allowed to respond to requests for any hosts. This is not recommended as it will be vulnerable to DNS rebinding attacks.

### browser <Badge type="warning">experimental</Badge> {#browser}

- **Default:** `{ enabled: false }`
Expand Down
2 changes: 1 addition & 1 deletion packages/browser/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const RPC_ID
const METHOD = getBrowserState().method
export const ENTRY_URL = `${
location.protocol === 'https:' ? 'wss:' : 'ws:'
}//${HOST}/__vitest_browser_api__?type=${PAGE_TYPE}&rpcId=${RPC_ID}&sessionId=${getBrowserState().sessionId}&projectName=${getBrowserState().config.name || ''}&method=${METHOD}`
}//${HOST}/__vitest_browser_api__?type=${PAGE_TYPE}&rpcId=${RPC_ID}&sessionId=${getBrowserState().sessionId}&projectName=${getBrowserState().config.name || ''}&method=${METHOD}&token=${(window as any).VITEST_API_TOKEN}`

let setCancel = (_: CancelReason) => {}
export const onCancel = new Promise<CancelReason>((resolve) => {
Expand Down
1 change: 1 addition & 0 deletions packages/browser/src/client/public/esm-client-injector.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
method: { __VITEST_METHOD__ },
providedContext: { __VITEST_PROVIDED_CONTEXT__ },
};
window.VITEST_API_TOKEN = { __VITEST_API_TOKEN__ };

const config = __vitest_browser_runner__.config;

Expand Down
7 changes: 6 additions & 1 deletion packages/browser/src/node/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { ServerMockResolver } from '@vitest/mocker/node'
import { createBirpc } from 'birpc'
import { parse, stringify } from 'flatted'
import { dirname } from 'pathe'
import { createDebugger, isFileServingAllowed } from 'vitest/node'
import { createDebugger, isFileServingAllowed, isWebsocketRequestAllowed } from 'vitest/node'
import { WebSocketServer } from 'ws'

const debug = createDebugger('vitest:browser:api')
Expand All @@ -33,6 +33,11 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject) {
return
}

if (!isWebsocketRequestAllowed(vitest.config, vite.config, request)) {
socket.destroy()
return
}

const type = searchParams.get('type')
const rpcId = searchParams.get('rpcId')
const sessionId = searchParams.get('sessionId')
Expand Down
1 change: 1 addition & 0 deletions packages/browser/src/node/serverOrchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export async function resolveOrchestrator(
__VITEST_SESSION_ID__: JSON.stringify(sessionId),
__VITEST_TESTER_ID__: '"none"',
__VITEST_PROVIDED_CONTEXT__: '{}',
__VITEST_API_TOKEN__: JSON.stringify(globalServer.vitest.config.api.token),
})

// disable CSP for the orchestrator as we are the ones controlling it
Expand Down
1 change: 1 addition & 0 deletions packages/browser/src/node/serverTester.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export async function resolveTester(
__VITEST_SESSION_ID__: JSON.stringify(sessionId),
__VITEST_TESTER_ID__: JSON.stringify(crypto.randomUUID()),
__VITEST_PROVIDED_CONTEXT__: JSON.stringify(stringify(project.getProvidedContext())),
__VITEST_API_TOKEN__: JSON.stringify(globalServer.vitest.config.api.token),
})

const testerHtml = typeof browserProject.testerHtml === 'string'
Expand Down
2 changes: 1 addition & 1 deletion packages/ui/client/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ export const PORT = import.meta.hot && !browserState ? '51204' : location.port
export const HOST = [location.hostname, PORT].filter(Boolean).join(':')
export const ENTRY_URL = `${
location.protocol === 'https:' ? 'wss:' : 'ws:'
}//${HOST}/__vitest_api__`
}//${HOST}/__vitest_api__?token=${(window as any).VITEST_API_TOKEN}`
export const isReport = !!window.METADATA_PATH
export const BASE_PATH = isReport ? import.meta.env.BASE_URL : __BASE_PATH__
22 changes: 22 additions & 0 deletions packages/ui/node/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { Plugin } from 'vite'
import type { Vitest } from 'vitest/node'
import fs from 'node:fs'
import { fileURLToPath } from 'node:url'
import { toArray } from '@vitest/utils'
import { basename, resolve } from 'pathe'
Expand Down Expand Up @@ -52,6 +53,27 @@ export default (ctx: Vitest): Plugin => {
}

const clientDist = resolve(fileURLToPath(import.meta.url), '../client')
const clientIndexHtml = fs.readFileSync(resolve(clientDist, 'index.html'), 'utf-8')

// serve index.html with api token
server.middlewares.use((req, res, next) => {
if (req.url) {
const url = new URL(req.url, 'http://localhost')
if (url.pathname === base) {
const html = clientIndexHtml.replace(
'<!-- !LOAD_METADATA! -->',
`<script>window.VITEST_API_TOKEN = ${JSON.stringify(ctx.config.api.token)}</script>`,
)
res.setHeader('Cache-Control', 'no-cache, max-age=0, must-revalidate')
res.setHeader('Content-Type', 'text/html; charset=utf-8')
res.write(html)
res.end()
return
}
}
next()
})

server.middlewares.use(
base,
sirv(clientDist, {
Expand Down
1 change: 1 addition & 0 deletions packages/vitest/rollup.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const external = [
'node:os',
'node:stream',
'node:vm',
'node:http',
'inspector',
'vite-node/source-map',
'vite-node/client',
Expand Down
163 changes: 163 additions & 0 deletions packages/vitest/src/api/hostCheck.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import type { IncomingMessage } from 'node:http'
import type { ResolvedConfig } from 'vite'
import type { ResolvedConfig as VitestResolvedConfig } from '../node/types/config'
import crypto from 'node:crypto'
import net from 'node:net'

// based on
// https://github.com/vitejs/vite/blob/9654348258eaa0883171533a2b74b4e2825f5fb6/packages/vite/src/node/server/middlewares/hostCheck.ts

const isFileOrExtensionProtocolRE = /^(?:file|.+-extension):/i

function getAdditionalAllowedHosts(
resolvedServerOptions: Pick<ResolvedConfig['server'], 'host' | 'hmr' | 'origin'>,
resolvedPreviewOptions: Pick<ResolvedConfig['preview'], 'host'>,
): string[] {
const list = []

// allow host option by default as that indicates that the user is
// expecting Vite to respond on that host
if (
typeof resolvedServerOptions.host === 'string'
&& resolvedServerOptions.host
) {
list.push(resolvedServerOptions.host)
}
if (
typeof resolvedServerOptions.hmr === 'object'
&& resolvedServerOptions.hmr.host
) {
list.push(resolvedServerOptions.hmr.host)
}
if (
typeof resolvedPreviewOptions.host === 'string'
&& resolvedPreviewOptions.host
) {
list.push(resolvedPreviewOptions.host)
}

// allow server origin by default as that indicates that the user is
// expecting Vite to respond on that host
if (resolvedServerOptions.origin) {
try {
const serverOriginUrl = new URL(resolvedServerOptions.origin)
list.push(serverOriginUrl.hostname)
}
catch {}
}

return list
}

// Based on webpack-dev-server's `checkHeader` function: https://github.com/webpack/webpack-dev-server/blob/v5.2.0/lib/Server.js#L3086
// https://github.com/webpack/webpack-dev-server/blob/v5.2.0/LICENSE
function isHostAllowedWithoutCache(
allowedHosts: string[],
additionalAllowedHosts: string[],
host: string,
): boolean {
if (isFileOrExtensionProtocolRE.test(host)) {
return true
}

// We don't care about malformed Host headers,
// because we only need to consider browser requests.
// Non-browser clients can send any value they want anyway.
//
// `Host = uri-host [ ":" port ]`
const trimmedHost = host.trim()

// IPv6
if (trimmedHost[0] === '[') {
const endIpv6 = trimmedHost.indexOf(']')
if (endIpv6 < 0) {
return false
}
// DNS rebinding attacks does not happen with IP addresses
return net.isIP(trimmedHost.slice(1, endIpv6)) === 6
}

// uri-host does not include ":" unless IPv6 address
const colonPos = trimmedHost.indexOf(':')
const hostname
= colonPos === -1 ? trimmedHost : trimmedHost.slice(0, colonPos)

// DNS rebinding attacks does not happen with IP addresses
if (net.isIP(hostname) === 4) {
return true
}

// allow localhost and .localhost by default as they always resolve to the loopback address
// https://datatracker.ietf.org/doc/html/rfc6761#section-6.3
if (hostname === 'localhost' || hostname.endsWith('.localhost')) {
return true
}

for (const additionalAllowedHost of additionalAllowedHosts) {
if (additionalAllowedHost === hostname) {
return true
}
}

for (const allowedHost of allowedHosts) {
if (allowedHost === hostname) {
return true
}

// allow all subdomains of it
// e.g. `.foo.example` will allow `foo.example`, `*.foo.example`, `*.*.foo.example`, etc
if (
allowedHost[0] === '.'
&& (allowedHost.slice(1) === hostname || hostname.endsWith(allowedHost))
) {
return true
}
}

return false
}

/**
* @param vitestConfig
* @param viteConfig resolved config
* @param host the value of host header. See [RFC 9110 7.2](https://datatracker.ietf.org/doc/html/rfc9110#name-host-and-authority).
*/
function isHostAllowed(vitestConfig: VitestResolvedConfig, viteConfig: ResolvedConfig, host: string): boolean {
const apiAllowedHosts = vitestConfig.api.allowedHosts ?? []
if (apiAllowedHosts === true) {
return true
}
// Vitest only validates websocket upgrade request, so caching won't probably matter.
return isHostAllowedWithoutCache(
apiAllowedHosts,
getAdditionalAllowedHosts(viteConfig.server, viteConfig.preview),
host,
)
}

export function isWebsocketRequestAllowed(vitestConfig: VitestResolvedConfig, viteConfig: ResolvedConfig, req: IncomingMessage): boolean {
const url = new URL(req.url ?? '', 'http://localhost')

// validate token. token is injected in ui/tester/orchestrator html, which is cross origin proteced.
try {
const token = url.searchParams.get('token')
if (!token || !crypto.timingSafeEqual(
Buffer.from(token),
Buffer.from(vitestConfig.api.token),
)) {
return false
}
}
catch {
// an error is thrown when the length is incorrect
return false
}

// host check to prevent DNS rebinding attacks
// (websocket upgrade request cannot be http2 even on `wss`, so `host` header is guaranteed.)
if (!req.headers.host || !isHostAllowed(vitestConfig, viteConfig, req.headers.host)) {
return false
}

return true
}
9 changes: 8 additions & 1 deletion packages/vitest/src/api/setup.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { File, TaskResultPack } from '@vitest/runner'

import type { IncomingMessage } from 'node:http'
import type { ViteDevServer } from 'vite'
import type { WebSocket } from 'ws'
import type { Vitest } from '../node/core'
Expand All @@ -21,6 +22,7 @@ import { API_PATH } from '../constants'
import { getModuleGraph } from '../utils/graph'
import { stringifyReplace } from '../utils/serialization'
import { parseErrorStacktrace } from '../utils/source-map'
import { isWebsocketRequestAllowed } from './hostCheck'

export function setup(ctx: Vitest, _server?: ViteDevServer) {
const wss = new WebSocketServer({ noServer: true })
Expand All @@ -29,7 +31,7 @@ export function setup(ctx: Vitest, _server?: ViteDevServer) {

const server = _server || ctx.server

server.httpServer?.on('upgrade', (request, socket, head) => {
server.httpServer?.on('upgrade', (request: IncomingMessage, socket, head) => {
if (!request.url) {
return
}
Expand All @@ -39,6 +41,11 @@ export function setup(ctx: Vitest, _server?: ViteDevServer) {
return
}

if (!isWebsocketRequestAllowed(ctx.config, server.config, request)) {
socket.destroy()
return
}

wss.handleUpgrade(request, socket, head, (ws) => {
wss.emit('connection', ws, request)
setupClient(ws)
Expand Down
1 change: 1 addition & 0 deletions packages/vitest/src/node/cli/cli-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const apiConfig: (port: number) => CLIOptions<ApiConfig> = (port: number) => ({
'Set to true to exit if port is already in use, instead of automatically trying the next available port',
},
middlewareMode: null,
allowedHosts: null,
})

const poolThreadsCommands: CLIOptions<ThreadsOptions & WorkerContextOptions> = {
Expand Down
4 changes: 3 additions & 1 deletion packages/vitest/src/node/config/resolveConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
} from '../types/config'
import type { BaseCoverageOptions, CoverageReporterWithOptions } from '../types/coverage'
import type { BuiltinPool, ForksOptions, PoolOptions, ThreadsOptions } from '../types/pool-options'
import crypto from 'node:crypto'
import { toArray } from '@vitest/utils'
import { resolveModule } from 'local-pkg'
import { normalize, relative, resolve } from 'pathe'
Expand Down Expand Up @@ -629,7 +630,8 @@ export function resolveConfig(
}

// the server has been created, we don't need to override vite.server options
resolved.api = resolveApiServerConfig(options, defaultPort)
const api = resolveApiServerConfig(options, defaultPort)
resolved.api = { ...api, token: crypto.randomUUID() }

if (options.related) {
resolved.related = toArray(options.related).map(file =>
Expand Down
6 changes: 4 additions & 2 deletions packages/vitest/src/node/types/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ export type CSSModuleScopeStrategy = 'stable' | 'scoped' | 'non-scoped'
export type ApiConfig = Pick<
ServerOptions,
'port' | 'strictPort' | 'host' | 'middlewareMode'
>
> & {
allowedHosts?: string[] | true
}

export type { EnvironmentOptions, HappyDOMOptions, JSDOMOptions }

Expand Down Expand Up @@ -1013,7 +1015,7 @@ export interface ResolvedConfig

defines: Record<string, any>

api?: ApiConfig
api: ApiConfig & { token: string }
cliExclude?: string[]

benchmark?: Required<
Expand Down
1 change: 1 addition & 0 deletions packages/vitest/src/public/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { TestModule as _TestFile } from '../node/reporters/reported-tasks'

export const version = Vitest.version

export { isWebsocketRequestAllowed } from '../api/hostCheck'
export { parseCLI } from '../node/cli/cac'
export type { CliParseOptions } from '../node/cli/cac'
export { startVitest } from '../node/cli/cli-api'
Expand Down
2 changes: 2 additions & 0 deletions test/config/test/override.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ describe('correctly defines api flag', () => {
expect(c.server.config.server.middlewareMode).toBe(true)
expect(c.config.api).toEqual({
middlewareMode: true,
token: expect.any(String),
})
})

Expand All @@ -262,6 +263,7 @@ describe('correctly defines api flag', () => {
expect(c.server.config.server.port).toBe(4321)
expect(c.config.api).toEqual({
port: 4321,
token: expect.any(String),
})
})
})
Expand Down
Loading