diff --git a/src/bindings.d.ts b/src/bindings.d.ts
index 99efbe1..60e7905 100644
--- a/src/bindings.d.ts
+++ b/src/bindings.d.ts
@@ -1,7 +1,47 @@
-import type { R2Bucket } from '@cloudflare/workers-types'
+import type { Link } from 'multiformats/link'
+import type { Context } from '@web3-storage/gateway-lib'
+import type { CARLink } from 'cardex/api'
+import type { R2Bucket, KVNamespace } from '@cloudflare/workers-types'
+import type { MemoryBudget } from './lib/mem-budget'
+import type { CID } from '@web3-storage/gateway-lib/handlers'
+
+export {}
 
 export interface Environment {
   DEBUG: string
   CARPARK: R2Bucket
   CONTENT_CLAIMS_SERVICE_URL?: string
+  RATE_LIMITS_SERVICE_URL?: string
+  ACCOUNTING_SERVICE_URL: string
+  MY_RATE_LIMITER: RateLimit
+  AUTH_TOKEN_METADATA: KVNamespace
+}
+
+/** The part of a request passed to the rate limit and accounting services: only its headers. */
+export type GetCIDRequestData = Pick<Request, 'headers'>
+
+export type GetCIDRequestOptions = GetCIDRequestData
+
+export interface RateLimitsService {
+  /** Resolves to one of the numeric codes of `RATE_LIMIT_EXCEEDED` (src/constants.js). */
+  check: (cid: CID, options: GetCIDRequestOptions) => Promise<number>
+}
+
+export interface TokenMetadata {
+  locationClaim?: unknown // TODO: figure out the right type to use for this - we probably need it for the private data case to verify auth
+  invalid?: boolean
 }
+
+export interface RateLimits {
+  create: ({ env }: { env: Environment }) => RateLimitsService
+}
+
+export interface AccountingService {
+  record: (cid: CID, options: GetCIDRequestOptions) => Promise<void>
+  getTokenMetadata: (token: string) => Promise<TokenMetadata | null>
+}
+
+export interface Accounting {
+  create: ({ serviceURL }: { serviceURL?: string }) => AccountingService
+}
+
diff --git a/src/constants.js b/src/constants.js
index 0ebcf5d..df030f9 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -1 +1,8 @@
 export const CAR_CODE = 0x0202
+
+/** Result codes returned by rate limit checks. */
+export const RATE_LIMIT_EXCEEDED = {
+  NO: 0,
+  YES: 1,
+  MAYBE: 2
+}
diff --git a/src/index.js b/src/index.js
index 9ab9dc8..71affaf 100644
--- a/src/index.js
+++ b/src/index.js
@@ -18,7 +18,8 @@ import {
 import {
   withContentClaimsDagula,
   withVersionHeader,
-  withCarBlockHandler
+  withCarBlockHandler,
+  withRateLimits
 } from './middleware.js'
 
 /**
@@ -34,6 +35,7 @@ export default {
   fetch (request, env, ctx) {
     console.log(request.method, request.url)
     const middleware = composeMiddleware(
+      withRateLimits,
       withCdnCache,
       withContext,
       withCorsHeaders,
diff --git a/src/middleware.js b/src/middleware.js
index b7a42a7..78ccf4b 100644
--- a/src/middleware.js
+++ b/src/middleware.js
@@ -4,7 +4,7 @@ import { HttpError } from '@web3-storage/gateway-lib/util'
 import * as BatchingFetcher from '@web3-storage/blob-fetcher/fetcher/batching'
 import * as ContentClaimsLocator from '@web3-storage/blob-fetcher/locator/content-claims'
 import { version } from '../package.json'
-import { CAR_CODE } from './constants.js'
+import { CAR_CODE, RATE_LIMIT_EXCEEDED } from './constants.js'
 import { handleCarBlock } from './handlers/car-block.js'
 
 /**
@@ -15,6 +15,161 @@ import { handleCarBlock } from './handlers/car-block.js'
  * @typedef {import('@web3-storage/gateway-lib').UnixfsContext} UnixfsContext
  */
 
+/**
+ * Parses the JSON string form of token metadata.
+ *
+ * @param {string} s
+ * @returns {import('./bindings.js').TokenMetadata}
+ */
+function deserializeTokenMetadata(s) {
+  // TODO should this be dag-json?
+  return JSON.parse(s)
+}
+
+/**
+ * Converts token metadata to the JSON string form stored in the KV cache.
+ *
+ * @param {import('./bindings.js').TokenMetadata} m
+ * @returns {string}
+ */
+function serializeTokenMetadata(m) {
+  // TODO should this be dag-json?
+  return JSON.stringify(m)
+}
+
+/**
+ * Applies the per-CID rate limit using the MY_RATE_LIMITER binding.
+ *
+ * @param {Environment} env
+ * @param {import('@web3-storage/gateway-lib/handlers').CID} cid
+ * @returns {Promise<number>} RATE_LIMIT_EXCEEDED.NO or RATE_LIMIT_EXCEEDED.YES
+ */
+async function checkRateLimitForCID(env, cid) {
+  const rateLimitResponse = await env.MY_RATE_LIMITER.limit({ key: cid.toString() })
+  if (rateLimitResponse.success) {
+    return RATE_LIMIT_EXCEEDED.NO
+  }
+  console.log(`limiting CID ${cid}`)
+  return RATE_LIMIT_EXCEEDED.YES
+}
+
+/**
+ * Looks up metadata for an auth token, checking the AUTH_TOKEN_METADATA KV
+ * cache first and falling back to the accounting service on a miss.
+ *
+ * @param {Environment} env
+ * @param {string} authToken
+ * @returns {Promise<import('./bindings.js').TokenMetadata | null>} null when
+ * neither the cache nor the accounting service has metadata for the token
+ */
+async function getTokenMetadata(env, authToken) {
+  const cachedValue = await env.AUTH_TOKEN_METADATA.get(authToken)
+  // TODO: we should implement an SWR pattern here - record an expiry in the metadata and if the expiry has passed, re-validate the cache after
+  // returning the value
+  if (cachedValue) {
+    return deserializeTokenMetadata(cachedValue)
+  }
+
+  const accounting = Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL })
+  const tokenMetadata = await accounting.getTokenMetadata(authToken)
+  if (!tokenMetadata) {
+    return null
+  }
+  await env.AUTH_TOKEN_METADATA.put(authToken, serializeTokenMetadata(tokenMetadata))
+  return tokenMetadata
+}
+
+/**
+ * @type {import('./bindings.js').RateLimits}
+ */
+const RateLimits = {
+  create: ({ env }) => ({
+    check: async (cid, options) => {
+      const authToken = await getAuthorizationTokenFromRequest(options)
+      if (authToken) {
+        // the token is a credential, so it must never be written to the logs
+        console.log('found auth token, looking for content commitment')
+        const tokenMetadata = await getTokenMetadata(env, authToken)
+        if (tokenMetadata && !tokenMetadata.invalid) {
+          // TODO at some point we should enforce user configurable rate limits and origin matching
+          // but for now we just serve all valid token requests
+          return RATE_LIMIT_EXCEEDED.NO
+        }
+        // the token is either known to be invalid or has no metadata yet, so
+        // token based requests use the normal CID rate limit until the data propagates
+      }
+      // no usable token, use normal rate limit
+      return checkRateLimitForCID(env, cid)
+    }
+  })
+}
+
+/**
+ * @type {import('./bindings.js').Accounting}
+ */
+const Accounting = {
+  create: ({ serviceURL }) => ({
+    record: async (cid, options) => {
+      console.log(`using ${serviceURL} to record a GET for ${cid} with options`, options)
+    },
+
+    getTokenMetadata: async () => {
+      // TODO I think this needs to check the content claims service (?) for any claims relevant to this token
+      // TODO do we have a plan for this? need to ask Hannah if the indexing service covers this?
+      return null
+    }
+  })
+}
+
+/**
+ * Reads the raw value of the request's Authorization header.
+ *
+ * @param {Pick<Request, 'headers'>} request
+ * @returns {Promise<string | null>} null when the header is absent
+ */
+async function getAuthorizationTokenFromRequest(request) {
+  // TODO this is probably wrong
+  return request.headers.get('Authorization')
+}
+
+/**
+ * Middleware that rejects a request with a 429 when the rate limit check for
+ * its data CID reports the limit as exceeded, and otherwise hands the request
+ * to the accounting service and delegates to the next middleware.
+ *
+ * NOTE(review): this reads `ctx.dataCid`, so it presumably has to be composed
+ * after whichever middleware populates it - confirm the order in index.js.
+ *
+ * @type {import('@web3-storage/gateway-lib').Middleware}
+ */
+export function withRateLimits(handler) {
+  return async (request, env, ctx) => {
+    const { dataCid } = ctx
+    if (!dataCid) {
+      // fail with a clear message rather than a TypeError from `cid.toString()`
+      throw new Error('withRateLimits: ctx.dataCid is not set')
+    }
+
+    const rateLimits = RateLimits.create({ env })
+    const isRateLimitExceeded = await rateLimits.check(dataCid, request)
+    if (isRateLimitExceeded === RATE_LIMIT_EXCEEDED.YES) {
+      // TODO should we record this?
+      throw new HttpError('Too Many Requests', { status: 429 })
+    }
+
+    const accounting = Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL })
+    // "fire and forget": the response is ignored, but the promise is handed to
+    // waitUntil so it can settle after the response is returned, and a failure
+    // is logged instead of surfacing as an unhandled rejection
+    ctx.waitUntil(
+      accounting.record(dataCid, request).catch(err => {
+        console.error('failed to record request with the accounting service', err)
+      })
+    )
+    return handler(request, env, ctx)
+  }
+}
+
 /**
  * Middleware that will serve CAR files if a CAR codec is found in the path
  * CID. If the CID is not a CAR CID it delegates to the next middleware.
diff --git a/wrangler.toml b/wrangler.toml
index 6eb9a3f..25cefa0 100644
--- a/wrangler.toml
+++ b/wrangler.toml
@@ -11,6 +11,24 @@ CONTENT_CLAIMS_SERVICE_URL = "https://dev.claims.web3.storage"
 [build]
 command = "npm run build:debug"
 
+[[unsafe.bindings]]
+# TODO BEFORE MERGE: update this binding to work in all environments - it is defined this way for convenience for now.
+name = "MY_RATE_LIMITER"
+type = "ratelimit"
+# An identifier you define, that is unique to your Cloudflare account.
+# Must be an integer (written here as a quoted string).
+namespace_id = "0"
+
+# Limit: the number of tokens allowed within a given period in a single
+# Cloudflare location.
+# Period: the duration of the period, in seconds. Must be either 10 or 60.
+simple = { limit = 100, period = 60 }
+
+[[kv_namespaces]]
+# TODO BEFORE MERGE: update this namespace to work in all environments - it is defined this way for convenience for now.
+binding = "AUTH_TOKEN_METADATA"
+id = "f848730e45d94f17bcaf3b6d0915da40"
+
 # PROD!
 [env.production]
 account_id = "fffa4b4363a7e5250af8357087263b3a"