From 6e6f1bec1ba6c315972a20e168632783227115bf Mon Sep 17 00:00:00 2001 From: Bob Farrell Date: Thu, 21 Nov 2024 22:09:15 +0000 Subject: [PATCH] Closes #108: Anti-CSRF middleware Add to middleware in app's `src/main.zig`: ```zig pub const jetzig_options = struct { pub const middleware: []const type = &.{ jetzig.middleware.AntiCsrfMiddleware, }; }; ``` CSRF token available in Zmpl templates: ``` {{context.authenticityToken()}} ``` or render a hidden form element: ``` {{context.authenticityFormElement()}} ``` The following HTML requests are rejected (403 Forbidden) if the submitted query param does not match the value stored in the encrypted session (added automatically when the token is generated for a template value): * POST * PUT * PATCH * DELETE JSON requests are not impacted - users should either disable JSON endpoints or implement a different authentication method to protect them. --- .github/workflows/CI.yml | 2 +- build.zig | 21 ++- build.zig.zon | 4 +- demo/src/app/views/anti_csrf.zig | 77 ++++++++++ demo/src/app/views/anti_csrf/index.zmpl | 8 ++ demo/src/app/views/anti_csrf/post.zmpl | 5 + demo/src/main.zig | 17 ++- src/Routes.zig | 4 +- src/commands/routes.zig | 4 +- src/compile_static_routes.zig | 11 +- src/jetzig.zig | 10 +- src/jetzig/App.zig | 6 +- src/jetzig/TemplateContext.zig | 25 ++++ src/jetzig/callbacks.zig | 112 +++++++++++++++ src/jetzig/config.zig | 24 +++- src/jetzig/development_static.zig | 12 ++ src/jetzig/http.zig | 7 +- src/jetzig/http/Cookies.zig | 142 +++++++++++-------- src/jetzig/http/Headers.zig | 7 +- src/jetzig/http/Request.zig | 40 +++++- src/jetzig/http/Server.zig | 93 +++++++++--- src/jetzig/mail/Job.zig | 4 +- src/jetzig/middleware.zig | 1 + src/jetzig/middleware/AntiCsrfMiddleware.zig | 73 ++++++++++ src/jetzig/middleware/AuthMiddleware.zig | 15 +- src/jetzig/testing/App.zig | 96 +++++++++++-- src/jetzig/util.zig | 12 +- src/jetzig/views/Route.zig | 2 + 28 files changed, 696 insertions(+), 138 deletions(-) create mode 100644 demo/src/app/views/anti_csrf.zig create mode 100644 demo/src/app/views/anti_csrf/index.zmpl create mode 100644 demo/src/app/views/anti_csrf/post.zmpl create mode 100644 src/jetzig/TemplateContext.zig create mode 100644 src/jetzig/callbacks.zig create mode 100644 src/jetzig/development_static.zig create mode 100644 src/jetzig/middleware/AntiCsrfMiddleware.zig diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 54f3b92..2b86562 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -45,7 +45,7 @@ jobs: - name: Run App Tests run: | cd demo - zig build jetzig:test + zig build -Denvironment=testing jetzig:test - name: Build artifacts if: ${{ matrix.os == 'ubuntu-latest' }} diff --git a/build.zig b/build.zig index 0f7d2dc..9be5667 100644 --- a/build.zig +++ b/build.zig @@ -109,6 +109,7 @@ pub fn build(b: *std.Build) !void { main_tests.root_module.addImport("smtp", smtp_client_dep.module("smtp_client")); const test_build_options = b.addOptions(); test_build_options.addOption(Environment, "environment", .testing); + test_build_options.addOption(bool, "build_static", true); const run_main_tests = b.addRunArtifact(main_tests); main_tests.root_module.addOptions("build_options", test_build_options); @@ -137,6 +138,11 @@ pub fn jetzigInit(b: *std.Build, exe: *std.Build.Step.Compile, options: JetzigIn "environment", "Jetzig server environment.", ) orelse .development; + const build_static = b.option( + bool, + "build_static", + "Pre-render static routes. [default: false in development, true in testing/production]", + ) orelse (environment != .development); const jetzig_dep = b.dependency( "jetzig", @@ -164,6 +170,7 @@ pub fn jetzigInit(b: *std.Build, exe: *std.Build.Step.Compile, options: JetzigIn const build_options = b.addOptions(); build_options.addOption(Environment, "environment", environment); + build_options.addOption(bool, "build_static", build_static); jetzig_module.addOptions("build_options", build_options); exe.root_module.addImport("jetzig", jetzig_module); @@ -253,15 +260,23 @@ pub fn jetzigInit(b: *std.Build, exe: *std.Build.Step.Compile, options: JetzigIn exe_static_routes.root_module.addImport("zmpl", zmpl_module); const markdown_fragments_write_files = b.addWriteFiles(); - const path = markdown_fragments_write_files.add("markdown_fragments.zig", try generateMarkdownFragments(b)); + const path = markdown_fragments_write_files.add( + "markdown_fragments.zig", + try generateMarkdownFragments(b), + ); const markdown_fragments_module = b.createModule(.{ .root_source_file = path }); exe_static_routes.root_module.addImport("markdown_fragments", markdown_fragments_module); const run_static_routes_cmd = b.addRunArtifact(exe_static_routes); const static_outputs_path = run_static_routes_cmd.addOutputFileArg("static.zig"); - const static_module = b.createModule(.{ .root_source_file = static_outputs_path }); - exe.root_module.addImport("static", static_module); + const static_module = if (build_static) + b.createModule(.{ .root_source_file = static_outputs_path }) + else + b.createModule(.{ + .root_source_file = jetzig_dep.builder.path("src/jetzig/development_static.zig"), + }); + exe.root_module.addImport("static", static_module); run_static_routes_cmd.expectExitCode(0); const run_tests_file_cmd = b.addRunArtifact(exe_routes_file); diff --git a/build.zig.zon b/build.zig.zon index f7e96fc..9c77fbd 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -7,8 +7,8 @@ .hash = "1220d0e8734628fd910a73146e804d10a3269e3e7d065de6bb0e3e88d5ba234eb163", }, .zmpl = .{ - .url = "https://github.com/jetzig-framework/zmpl/archive/25b91d030b992631d319adde1cf01baecd9f3934.tar.gz", - .hash = "12208dd5a4bf0c6c7efc4e9f37a5d8ed80d6004d5680176d1fc2114bfa593e927baf", + .url = "https://github.com/jetzig-framework/zmpl/archive/af75c8b842c3957eb97b4fc4bc49c7b2243968fa.tar.gz", + .hash = "1220ecac93d295dafd2f034a86f0979f6108d40e5ea1a39e3a2b9977c35147cac684", }, .jetkv = .{ .url = "https://github.com/jetzig-framework/jetkv/archive/2b1130a48979ea2871c8cf6ca89c38b1e7062839.tar.gz", diff --git a/demo/src/app/views/anti_csrf.zig b/demo/src/app/views/anti_csrf.zig new file mode 100644 index 0000000..e6f98ff --- /dev/null +++ b/demo/src/app/views/anti_csrf.zig @@ -0,0 +1,77 @@ +const std = @import("std"); +const jetzig = @import("jetzig"); + +pub const layout = "application"; + +pub const actions = .{ + .before = .{jetzig.middleware.AntiCsrfMiddleware}, +}; + +pub fn post(request: *jetzig.Request) !jetzig.View { + var root = try request.data(.object); + + const Params = struct { spam: []const u8 }; + const params = try request.expectParams(Params) orelse { + return request.fail(.unprocessable_entity); + }; + + try root.put("spam", params.spam); + + return request.render(.created); +} + +pub fn index(request: *jetzig.Request) !jetzig.View { + return request.render(.ok); +} + +test "post with missing token" { + var app = try jetzig.testing.app(std.testing.allocator, @import("routes")); + defer app.deinit(); + + const response = try app.request(.POST, "/anti_csrf", .{}); + try response.expectStatus(.forbidden); +} + +test "post with invalid token" { + var app = try jetzig.testing.app(std.testing.allocator, @import("routes")); + defer app.deinit(); + + const response = try app.request(.POST, "/anti_csrf", .{}); + try response.expectStatus(.forbidden); +} + +test "post with valid token but missing expected params" { + var app = try jetzig.testing.app(std.testing.allocator, @import("routes")); + defer app.deinit(); + + _ = try app.request(.GET, "/anti_csrf", .{}); + const token = app.session.getT(.string, jetzig.authenticity_token_name).?; + const response = try app.request( + .POST, + "/anti_csrf", + .{ .params = .{ ._jetzig_authenticity_token = token } }, + ); + try response.expectStatus(.unprocessable_entity); +} + +test "post with valid token and expected params" { + var app = try jetzig.testing.app(std.testing.allocator, @import("routes")); + defer app.deinit(); + + _ = try app.request(.GET, "/anti_csrf", .{}); + const token = app.session.getT(.string, jetzig.authenticity_token_name).?; + const response = try app.request( + .POST, + "/anti_csrf", + .{ .params = .{ ._jetzig_authenticity_token = token, .spam = "Spam" } }, + ); + try response.expectStatus(.created); +} + +test "index" { + var app = try jetzig.testing.app(std.testing.allocator, @import("routes")); + defer app.deinit(); + + const response = try app.request(.GET, "/anti_csrf", .{}); + try response.expectStatus(.ok); +} diff --git a/demo/src/app/views/anti_csrf/index.zmpl b/demo/src/app/views/anti_csrf/index.zmpl new file mode 100644 index 0000000..4f89eaa --- /dev/null +++ b/demo/src/app/views/anti_csrf/index.zmpl @@ -0,0 +1,8 @@ +
+ {{context.authenticityFormElement()}} + + + + + +
diff --git a/demo/src/app/views/anti_csrf/post.zmpl b/demo/src/app/views/anti_csrf/post.zmpl new file mode 100644 index 0000000..7a1b222 --- /dev/null +++ b/demo/src/app/views/anti_csrf/post.zmpl @@ -0,0 +1,5 @@ +

Spam Submitted Successfully

+ +

Spam:

+ +
{{$.spam}}
diff --git a/demo/src/main.zig b/demo/src/main.zig index 336851c..3360536 100644 --- a/demo/src/main.zig +++ b/demo/src/main.zig @@ -12,6 +12,8 @@ pub const jetzig_options = struct { /// Middleware chain. Add any custom middleware here, or use middleware provided in /// `jetzig.middleware` (e.g. `jetzig.middleware.HtmxMiddleware`). pub const middleware: []const type = &.{ + // jetzig.middleware.AuthMiddleware, + // jetzig.middleware.AntiCsrfMiddleware, // jetzig.middleware.HtmxMiddleware, // jetzig.middleware.CompressionMiddleware, // @import("app/middleware/DemoMiddleware.zig"), @@ -79,13 +81,16 @@ pub const jetzig_options = struct { pub const Schema = @import("Schema"); /// HTTP cookie configuration - pub const cookies: jetzig.http.Cookies.CookieOptions = .{ - .domain = switch (jetzig.environment) { - .development => "localhost", - .testing => "localhost", - .production => "www.example.com", + pub const cookies: jetzig.http.Cookies.CookieOptions = switch (jetzig.environment) { + .development, .testing => .{ + .domain = "localhost", + .path = "/", + }, + .production => .{ + .same_site = true, + .secure = true, + .http_only = true, }, - .path = "/", }; /// Key-value store options. Set backend to `.file` to use a file-based store. diff --git a/src/Routes.zig b/src/Routes.zig index eadedb2..5751e8c 100644 --- a/src/Routes.zig +++ b/src/Routes.zig @@ -280,6 +280,8 @@ fn writeRoute(self: *Routes, writer: std.ArrayList(u8).Writer, route: Function) \\ .static = {4s}, \\ .uri_path = "{5s}", \\ .template = "{6s}", + \\ .before_callbacks = jetzig.callbacks.beforeCallbacks(@import("{7s}")), + \\ .after_callbacks = jetzig.callbacks.afterCallbacks(@import("{7s}")), \\ .layout = if (@hasDecl(@import("{7s}"), "layout")) @import("{7s}").layout else null, \\ .json_params = &[_][]const u8 {{ {8s} }}, \\ .formats = if (@hasDecl(@import("{7s}"), "formats")) @import("{7s}").formats else null, @@ -389,7 +391,7 @@ fn generateRoutesForView(self: *Routes, dir: std.fs.Dir, path: []const u8) !Rout for (capture.args, 0..) |arg, arg_index| { if (std.mem.eql(u8, try arg.typeBasename(), "StaticRequest")) { - capture.static = true; + capture.static = jetzig.build_options.build_static; capture.legacy = arg_index + 1 < capture.args.len; try static_routes.append(capture.*); } else if (std.mem.eql(u8, try arg.typeBasename(), "Request")) { diff --git a/src/commands/routes.zig b/src/commands/routes.zig index d5a5599..c70b79d 100644 --- a/src/commands/routes.zig +++ b/src/commands/routes.zig @@ -13,7 +13,7 @@ pub fn main() !void { log("Jetzig Routes:", .{}); - const environment = jetzig.Environment.init(allocator, .{ .silent = true }); + const environment = try jetzig.Environment.init(allocator, .{ .silent = true }); const initHook: ?*const fn (*jetzig.App) anyerror!void = if (@hasDecl(app, "init")) app.init else null; inline for (routes.routes) |route| max_uri_path_len = @max(route.uri_path.len + 5, max_uri_path_len); @@ -44,7 +44,7 @@ pub fn main() !void { } var jetzig_app = jetzig.App{ - .environment = environment, + .env = environment, .allocator = allocator, .custom_routes = std.ArrayList(jetzig.views.Route).init(allocator), .initHook = initHook, diff --git a/src/compile_static_routes.zig b/src/compile_static_routes.zig index 0ac298e..1f45ec5 100644 --- a/src/compile_static_routes.zig +++ b/src/compile_static_routes.zig @@ -147,7 +147,7 @@ fn renderMarkdown( if (zmpl.findPrefixed("views", prefixed_name)) |layout| { view.data.content = .{ .data = content }; - return try layout.render(view.data); + return try layout.render(view.data, jetzig.TemplateContext, .{}, .{}); } else { std.debug.print("Unknown layout: {s}\n", .{layout_name}); return content; @@ -170,13 +170,18 @@ fn renderZmplTemplate( defer allocator.free(prefixed_name); if (zmpl.findPrefixed("views", prefixed_name)) |layout| { - return try template.renderWithOptions(view.data, .{ .layout = layout }); + return try template.render( + view.data, + jetzig.TemplateContext, + .{}, + .{ .layout = layout }, + ); } else { std.debug.print("Unknown layout: {s}\n", .{layout_name}); return try allocator.dupe(u8, ""); } } else { - return try template.render(view.data); + return try template.render(view.data, jetzig.TemplateContext, .{}, .{}); } } else return null; } diff --git a/src/jetzig.zig b/src/jetzig.zig index 231f149..de8af9e 100644 --- a/src/jetzig.zig +++ b/src/jetzig.zig @@ -22,11 +22,15 @@ pub const database = @import("jetzig/database.zig"); pub const testing = @import("jetzig/testing.zig"); pub const config = @import("jetzig/config.zig"); pub const auth = @import("jetzig/auth.zig"); +pub const callbacks = @import("jetzig/callbacks.zig"); +pub const TemplateContext = @import("jetzig/TemplateContext.zig"); pub const DateTime = jetcommon.types.DateTime; pub const Time = jetcommon.types.Time; pub const Date = jetcommon.types.Date; +pub const authenticity_token_name = config.get([]const u8, "authenticity_token_name"); + pub const build_options = @import("build_options"); pub const environment = std.enums.nameCast(Environment.EnvironmentName, build_options.environment); @@ -46,6 +50,9 @@ pub const Request = http.Request; /// requests. pub const StaticRequest = http.StaticRequest; +/// An HTTP response generated during request processing. +pub const Response = http.Response; + /// Generic, JSON-compatible data type. Provides `Value` which in turn provides `Object`, /// `Array`, `String`, `Integer`, `Float`, `Boolean`, and `NullType`. pub const Data = data.Data; @@ -78,7 +85,8 @@ pub const Logger = loggers.Logger; pub const root = @import("root"); pub const Global = if (@hasDecl(root, "Global")) root.Global else DefaultGlobal; -pub const DefaultGlobal = struct { __jetzig_default: bool }; +pub const DefaultGlobal = struct { comptime __jetzig_default: bool = true }; +pub const default_global = DefaultGlobal{}; pub const initHook: ?*const fn (*App) anyerror!void = if (@hasDecl(root, "init")) root.init else null; diff --git a/src/jetzig/App.zig b/src/jetzig/App.zig index 474b1fd..2274790 100644 --- a/src/jetzig/App.zig +++ b/src/jetzig/App.zig @@ -17,8 +17,8 @@ pub fn deinit(self: *const App) void { @constCast(self).custom_routes.deinit(); } -// Not used yet, but allows us to add new options to `start()` without breaking -// backward-compatibility. +/// Specify a global value accessible as `request.server.global`. +/// Must specify type by defining `pub const Global` in your app's `src/main.zig`. const AppOptions = struct { global: *anyopaque = undefined, }; @@ -228,6 +228,8 @@ pub fn createRoutes( .template = const_route.template, .json_params = const_route.json_params, .formats = const_route.formats, + .before_callbacks = const_route.before_callbacks, + .after_callbacks = const_route.after_callbacks, }; try var_route.initParams(allocator); diff --git a/src/jetzig/TemplateContext.zig b/src/jetzig/TemplateContext.zig new file mode 100644 index 0000000..321e527 --- /dev/null +++ b/src/jetzig/TemplateContext.zig @@ -0,0 +1,25 @@ +const std = @import("std"); + +pub const http = @import("http.zig"); +pub const config = @import("config.zig"); + +/// Context available in every Zmpl template as `context`. +pub const TemplateContext = @This(); + +request: ?*http.Request = null, + +pub fn authenticityToken(self: TemplateContext) !?[]const u8 { + return if (self.request) |request| + try request.authenticityToken() + else + null; +} + +pub fn authenticityFormElement(self: TemplateContext) !?[]const u8 { + return if (self.request) |request| blk: { + const token = try request.authenticityToken(); + break :blk try std.fmt.allocPrint(request.allocator, + \\ + , .{ config.get([]const u8, "authenticity_token_name"), token }); + } else null; +} diff --git a/src/jetzig/callbacks.zig b/src/jetzig/callbacks.zig new file mode 100644 index 0000000..6339990 --- /dev/null +++ b/src/jetzig/callbacks.zig @@ -0,0 +1,112 @@ +const std = @import("std"); + +const jetzig = @import("../jetzig.zig"); + +pub const BeforeCallback = *const fn ( + *jetzig.http.Request, + jetzig.views.Route, +) anyerror!void; + +pub const AfterCallback = *const fn ( + *jetzig.http.Request, + *jetzig.http.Response, + jetzig.views.Route, +) anyerror!void; + +pub const Context = enum { before, after }; + +pub fn beforeCallbacks(view: type) []const BeforeCallback { + comptime { + return buildCallbacks(.before, view); + } +} + +pub fn afterCallbacks(view: type) []const AfterCallback { + comptime { + return buildCallbacks(.after, view); + } +} + +fn buildCallbacks(comptime context: Context, view: type) switch (context) { + .before => []const BeforeCallback, + .after => []const AfterCallback, +} { + comptime { + if (!@hasDecl(view, "actions")) return &.{}; + if (!@hasField(@TypeOf(view.actions), @tagName(context))) return &.{}; + + var size: usize = 0; + for (@field(view.actions, @tagName(context))) |module| { + if (isCallback(context, module)) { + size += 1; + } else { + @compileError(std.fmt.comptimePrint( + "`{0s}` callbacks must be either a function `{1s}` or a type that defines " ++ + "`pub const {0s}Render`. Found: `{2s}`", + .{ + @tagName(context), + switch (context) { + .before => @typeName(BeforeCallback), + .after => @typeName(AfterCallback), + }, + if (@TypeOf(module) == type) + @typeName(module) + else + @typeName(@TypeOf(&module)), + }, + )); + } + } + + var callbacks: [size]switch (context) { + .before => BeforeCallback, + .after => AfterCallback, + } = undefined; + var index: usize = 0; + for (@field(view.actions, @tagName(context))) |module| { + if (!isCallback(context, module)) continue; + + callbacks[index] = if (@TypeOf(module) == type) + @field(module, @tagName(context) ++ "Render") + else + &module; + + index += 1; + } + + const final = callbacks; + return &final; + } +} + +fn isCallback(comptime context: Context, comptime module: anytype) bool { + comptime { + if (@typeInfo(@TypeOf(module)) == .@"fn") { + const expected = switch (context) { + .before => BeforeCallback, + .after => AfterCallback, + }; + + const info = @typeInfo(@TypeOf(module)).@"fn"; + + const actual_params = info.params; + const expected_params = @typeInfo(@typeInfo(expected).pointer.child).@"fn".params; + + if (actual_params.len != expected_params.len) return false; + + for (actual_params, expected_params) |actual_param, expected_param| { + if (actual_param.type != expected_param.type) return false; + } + + if (@typeInfo(info.return_type.?) != .error_union) return false; + if (@typeInfo(info.return_type.?).error_union.payload != void) return false; + + return true; + } + + return if (@TypeOf(module) == type and @hasDecl(module, @tagName(context) ++ "Render")) + true + else + false; + } +} diff --git a/src/jetzig/config.zig b/src/jetzig/config.zig index cce32c3..470b8d8 100644 --- a/src/jetzig/config.zig +++ b/src/jetzig/config.zig @@ -17,6 +17,9 @@ pub const jobs = @import("jobs.zig"); pub const mail = @import("mail.zig"); pub const kv = @import("kv.zig"); pub const db = @import("database.zig"); +pub const Environment = @import("Environment.zig"); +pub const environment = std.enums.nameCast(Environment.EnvironmentName, build_options.environment); +pub const build_options = @import("build_options"); const root = @import("root"); @@ -149,11 +152,26 @@ pub const smtp: mail.SMTPConfig = .{ }; /// HTTP cookie configuration -pub const cookies: http.Cookies.CookieOptions = .{ - .domain = "localhost", - .path = "/", +pub const cookies: http.Cookies.CookieOptions = switch (environment) { + .development, .testing => .{ + .domain = "localhost", + .path = "/", + }, + .production => .{ + .secure = true, + .http_only = true, + .same_site = true, + .path = "/", + }, }; +/// Override the default anti-CSRF authenticity token name that is stored in the encrypted +/// session. This value is also used by `context.authenticityFormElement()` to render an HTML +/// element: the element's `name` attribute is set to this value. +pub const authenticity_token_name: []const u8 = "_jetzig_authenticity_token"; + +/// When using `AuthMiddleware`, set this value to override the default JetQuery model name that +/// maps the users table. pub const auth: @import("auth.zig").AuthOptions = .{ .user_model = "User", }; diff --git a/src/jetzig/development_static.zig b/src/jetzig/development_static.zig new file mode 100644 index 0000000..0d6892b --- /dev/null +++ b/src/jetzig/development_static.zig @@ -0,0 +1,12 @@ +pub const compiled = [_]Compiled{}; + +const StaticOutput = struct { + json: ?[]const u8 = null, + html: ?[]const u8 = null, + params: ?[]const u8, +}; + +const Compiled = struct { + route_id: []const u8, + output: StaticOutput, +}; diff --git a/src/jetzig/http.zig b/src/jetzig/http.zig index 3426606..5137903 100644 --- a/src/jetzig/http.zig +++ b/src/jetzig/http.zig @@ -1,9 +1,14 @@ const std = @import("std"); const builtin = @import("builtin"); +pub const build_options = @import("build_options"); + pub const Server = @import("http/Server.zig"); pub const Request = @import("http/Request.zig"); -pub const StaticRequest = @import("http/StaticRequest.zig"); +pub const StaticRequest = if (build_options.environment == .development) + Request +else + @import("http/StaticRequest.zig"); pub const Response = @import("http/Response.zig"); pub const Session = @import("http/Session.zig"); pub const Cookies = @import("http/Cookies.zig"); diff --git a/src/jetzig/http/Cookies.zig b/src/jetzig/http/Cookies.zig index 736715e..214bb79 100644 --- a/src/jetzig/http/Cookies.zig +++ b/src/jetzig/http/Cookies.zig @@ -8,18 +8,18 @@ cookies: std.StringArrayHashMap(*Cookie), modified: bool = false, arena: std.heap.ArenaAllocator, -const Self = @This(); +const Cookies = @This(); const SameSite = enum { strict, lax, none }; pub const CookieOptions = struct { - domain: []const u8 = "localhost", + domain: ?[]const u8 = "localhost", path: []const u8 = "/", - same_site: ?SameSite = null, secure: bool = false, - expires: ?i64 = null, // if used, set to time in seconds to be added to std.time.timestamp() http_only: bool = false, - max_age: ?i64 = null, partitioned: bool = false, + same_site: ?SameSite = null, + expires: ?i64 = null, // if used, set to time in seconds to be added to std.time.timestamp() + max_age: ?i64 = null, }; const cookie_options = jetzig.config.get(CookieOptions, "cookies"); @@ -27,43 +27,50 @@ const cookie_options = jetzig.config.get(CookieOptions, "cookies"); pub const Cookie = struct { name: []const u8, value: []const u8, - domain: ?[]const u8 = null, - path: ?[]const u8 = null, - same_site: ?SameSite = null, - secure: ?bool = null, - expires: ?i64 = null, // if used, set to time in seconds to be added to std.time.timestamp() - http_only: ?bool = null, - max_age: ?i64 = null, - partitioned: ?bool = null, + secure: bool = cookie_options.secure, + http_only: bool = cookie_options.http_only, + partitioned: bool = cookie_options.partitioned, + domain: ?[]const u8 = cookie_options.domain, + path: ?[]const u8 = cookie_options.path, + same_site: ?SameSite = cookie_options.same_site, + // if used, set to time in seconds to be added to std.time.timestamp() + expires: ?i64 = cookie_options.expires, + max_age: ?i64 = cookie_options.max_age, /// Build a cookie string. pub fn bufPrint(self: Cookie, buf: *[4096]u8) ![]const u8 { - var options = cookie_options; - inline for (std.meta.fields(CookieOptions)) |field| { - @field(options, field.name) = @field(self, field.name) orelse @field(cookie_options, field.name); - } - - // secure is required if samesite is set to none - const require_secure = if (options.same_site) |same_site| same_site == .none else false; - var stream = std.io.fixedBufferStream(buf); const writer = stream.writer(); + try writer.print("{}", .{self}); + return stream.getWritten(); + } + + /// Build a cookie string. + pub fn format(self: Cookie, _: anytype, _: anytype, writer: anytype) !void { + // secure is required if samesite is set to none + const require_secure = if (self.same_site) |same_site| same_site == .none else false; - try writer.print("{s}={s}; path={s}; domain={s};", .{ + try writer.print("{s}={s}; path={s};", .{ self.name, self.value, - options.path, - options.domain, + self.path orelse "/", }); - if (options.same_site) |same_site| try writer.print(" SameSite={s};", .{@tagName(same_site)}); - if (options.secure or require_secure) try writer.writeAll(" Secure;"); - if (options.expires) |expires| try writer.print(" Expires={d};", .{std.time.timestamp() + expires}); - if (options.max_age) |max_age| try writer.print(" Max-Age={d};", .{max_age}); - if (options.http_only) try writer.writeAll(" HttpOnly;"); - if (options.partitioned) try writer.writeAll(" Partitioned;"); - - return stream.getWritten(); + if (self.domain) |domain| try writer.print(" domain={s};", .{domain}); + if (self.same_site) |same_site| try writer.print( + " SameSite={s};", + .{@tagName(same_site)}, + ); + if (self.secure or require_secure) try writer.writeAll(" Secure;"); + if (self.expires) |expires| { + const seconds = std.time.timestamp() + expires; + const timestamp = try jetzig.jetcommon.DateTime.fromUnix(seconds, .seconds); + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#expiresdate + try timestamp.strftime(writer, " Expires=%a, %d %h %Y %H:%M:%S GMT;"); + } + if (self.max_age) |max_age| try writer.print(" Max-Age={d};", .{max_age}); + if (self.http_only) try writer.writeAll(" HttpOnly;"); + if (self.partitioned) try writer.writeAll(" Partitioned;"); } pub fn applyFlag(self: *Cookie, allocator: std.mem.Allocator, flag: Flag) !void { @@ -80,7 +87,7 @@ pub const Cookie = struct { } }; -pub fn init(allocator: std.mem.Allocator, cookie_string: []const u8) Self { +pub fn init(allocator: std.mem.Allocator, cookie_string: []const u8) Cookies { return .{ .allocator = allocator, .cookie_string = cookie_string, @@ -89,7 +96,7 @@ pub fn init(allocator: std.mem.Allocator, cookie_string: []const u8) Self { }; } -pub fn deinit(self: *Self) void { +pub fn deinit(self: *Cookies) void { var it = self.cookies.iterator(); while (it.next()) |item| { self.allocator.free(item.key_ptr.*); @@ -100,11 +107,11 @@ pub fn deinit(self: *Self) void { self.arena.deinit(); } -pub fn get(self: *Self, key: []const u8) ?*Cookie { +pub fn get(self: *Cookies, key: []const u8) ?*Cookie { return self.cookies.get(key); } -pub fn put(self: *Self, cookie: Cookie) !void { +pub fn put(self: *Cookies, cookie: Cookie) !void { self.modified = true; if (self.cookies.fetchSwapRemove(cookie.name)) |entry| { @@ -125,8 +132,12 @@ pub const HeaderIterator = struct { cookies_iterator: std.StringArrayHashMap(*Cookie).Iterator, buf: *[4096]u8, - pub fn init(allocator: std.mem.Allocator, cookies: *Self, buf: *[4096]u8) HeaderIterator { - return .{ .allocator = allocator, .cookies_iterator = cookies.cookies.iterator(), .buf = buf }; + pub fn init(allocator: std.mem.Allocator, cookies: *Cookies, buf: *[4096]u8) HeaderIterator { + return .{ + .allocator = allocator, + .cookies_iterator = cookies.cookies.iterator(), + .buf = buf, + }; } pub fn next(self: *HeaderIterator) !?[]const u8 { @@ -139,14 +150,14 @@ pub const HeaderIterator = struct { } }; -pub fn headerIterator(self: *Self, buf: *[4096]u8) HeaderIterator { +pub fn headerIterator(self: *Cookies, buf: *[4096]u8) HeaderIterator { return HeaderIterator.init(self.allocator, self, buf); } // https://datatracker.ietf.org/doc/html/rfc6265#section-4.2.1 // cookie-header = "Cookie:" OWS cookie-string OWS // cookie-string = cookie-pair *( ";" SP cookie-pair ) -pub fn parse(self: *Self) !void { +pub fn parse(self: *Cookies) !void { var key_buf = std.ArrayList(u8).init(self.allocator); var value_buf = std.ArrayList(u8).init(self.allocator); var key_terminated = false; @@ -202,6 +213,13 @@ pub fn parse(self: *Self) !void { for (cookie_buf.items) |cookie| try self.put(cookie); } +pub fn format(self: Cookies, _: anytype, _: anytype, writer: anytype) !void { + var it = self.cookies.iterator(); + while (it.next()) |entry| { + try writer.print("{}; ", .{entry.value_ptr.*}); + } +} + const Flag = union(enum) { domain: []const u8, path: []const u8, @@ -250,7 +268,7 @@ fn parseFlag(key: []const u8, value: []const u8) ?Flag { test "basic cookie string" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux;"); defer cookies.deinit(); try cookies.parse(); try std.testing.expectEqualStrings("bar", cookies.get("foo").?.value); @@ -259,14 +277,14 @@ test "basic cookie string" { test "empty cookie string" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, ""); + var cookies = Cookies.init(allocator, ""); defer cookies.deinit(); try cookies.parse(); } test "cookie string with irregular spaces" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo= bar; baz= qux;"); + var cookies = Cookies.init(allocator, "foo= bar; baz= qux;"); defer cookies.deinit(); try cookies.parse(); try std.testing.expectEqualStrings("bar", cookies.get("foo").?.value); @@ -280,7 +298,7 @@ test "headerIterator" { const writer = buf.writer(); - var cookies = Self.init(allocator, "foo=bar; baz=qux;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux;"); defer cookies.deinit(); try cookies.parse(); @@ -300,7 +318,7 @@ test "headerIterator" { test "modified" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux;"); defer cookies.deinit(); try cookies.parse(); @@ -312,7 +330,7 @@ test "modified" { test "domain=example.com" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Domain=example.com;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Domain=example.com;"); defer cookies.deinit(); try cookies.parse(); @@ -322,7 +340,7 @@ test "domain=example.com" { test "path=/example_path" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Path=/example_path;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Path=/example_path;"); defer cookies.deinit(); try cookies.parse(); @@ -332,7 +350,7 @@ test "path=/example_path" { test "SameSite=lax" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; SameSite=lax;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; SameSite=lax;"); defer cookies.deinit(); try cookies.parse(); @@ -342,7 +360,7 @@ test "SameSite=lax" { test "SameSite=none" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; SameSite=none;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; SameSite=none;"); defer cookies.deinit(); try cookies.parse(); @@ -352,7 +370,7 @@ test "SameSite=none" { test "SameSite=strict" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; SameSite=strict;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; SameSite=strict;"); defer cookies.deinit(); try cookies.parse(); @@ -362,27 +380,27 @@ test "SameSite=strict" { test "Secure" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Secure;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Secure;"); defer cookies.deinit(); try cookies.parse(); const cookie = cookies.get("foo").?; - try std.testing.expect(cookie.secure.?); + try std.testing.expect(cookie.secure); } test "Partitioned" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Partitioned;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Partitioned;"); defer cookies.deinit(); try cookies.parse(); const cookie = cookies.get("foo").?; - try std.testing.expect(cookie.partitioned.?); + try std.testing.expect(cookie.partitioned); } test "Max-Age" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Max-Age=123123123;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Max-Age=123123123;"); defer cookies.deinit(); try cookies.parse(); @@ -392,7 +410,7 @@ test "Max-Age" { test "Expires" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux; Expires=123123123;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux; Expires=123123123;"); defer cookies.deinit(); try cookies.parse(); @@ -402,17 +420,17 @@ test "Expires" { test "default flags" { const allocator = std.testing.allocator; - var cookies = Self.init(allocator, "foo=bar; baz=qux;"); + var cookies = Cookies.init(allocator, "foo=bar; baz=qux;"); defer cookies.deinit(); try cookies.parse(); const cookie = cookies.get("foo").?; - try std.testing.expect(cookie.domain == null); - try std.testing.expect(cookie.path == null); + try std.testing.expect(cookie.secure == false); + try std.testing.expect(cookie.partitioned == false); + try std.testing.expect(cookie.http_only == false); try std.testing.expect(cookie.same_site == null); - try std.testing.expect(cookie.secure == null); + try std.testing.expectEqualStrings(cookie.domain.?, "localhost"); + try std.testing.expectEqualStrings(cookie.path.?, "/"); try std.testing.expect(cookie.expires == null); - try std.testing.expect(cookie.http_only == null); try std.testing.expect(cookie.max_age == null); - try std.testing.expect(cookie.partitioned == null); } diff --git a/src/jetzig/http/Headers.zig b/src/jetzig/http/Headers.zig index c721642..83978e3 100644 --- a/src/jetzig/http/Headers.zig +++ b/src/jetzig/http/Headers.zig @@ -46,10 +46,9 @@ pub fn getAll(self: Headers, name: []const u8) []const []const u8 { var headers = std.ArrayList([]const u8).init(self.allocator); for (self.httpz_headers.keys, 0..) |key, index| { - var buf: [max_bytes_header_name]u8 = undefined; - const lower = std.ascii.lowerString(&buf, name); - - if (std.mem.eql(u8, lower, key)) headers.append(self.httpz_headers.values[index]) catch @panic("OOM"); + if (std.ascii.eqlIgnoreCase(name, key)) { + headers.append(self.httpz_headers.values[index]) catch @panic("OOM"); + } } return headers.toOwnedSlice() catch @panic("OOM"); } diff --git a/src/jetzig/http/Request.zig b/src/jetzig/http/Request.zig index 4a78395..b864913 100644 --- a/src/jetzig/http/Request.zig +++ b/src/jetzig/http/Request.zig @@ -10,6 +10,7 @@ const default_content_type = "text/html"; pub const Method = enum { DELETE, GET, PATCH, POST, HEAD, PUT, CONNECT, OPTIONS, TRACE }; pub const Modifier = enum { edit, new }; pub const Format = enum { HTML, JSON, UNKNOWN }; +pub const Protocol = enum { http, https }; allocator: std.mem.Allocator, path: jetzig.http.Path, @@ -34,6 +35,8 @@ response_started: bool = false, dynamic_assigned_template: ?[]const u8 = null, layout: ?[]const u8 = null, layout_disabled: bool = false, +// TODO: Squash rendered/redirected/failed into +// `state: enum { initial, rendered, redirected, failed }` rendered: bool = false, redirected: bool = false, failed: bool = false, @@ -249,7 +252,10 @@ pub fn middleware( unreachable; } -const RedirectState = struct { location: []const u8, status_code: jetzig.http.status_codes.StatusCode }; +const RedirectState = struct { + location: []const u8, + status_code: jetzig.http.status_codes.StatusCode, +}; pub fn renderRedirect(self: *Request, state: RedirectState) !void { self.response_data.reset(); @@ -275,7 +281,12 @@ pub fn renderRedirect(self: *Request, state: RedirectState) !void { .HTML, .UNKNOWN => if (maybe_template) |template| blk: { try view.data.addConst("jetzig_view", view.data.string("internal")); try view.data.addConst("jetzig_action", view.data.string(@tagName(state.status_code))); - break :blk try template.render(self.response_data); + break :blk try template.render( + self.response_data, + jetzig.TemplateContext, + .{ .request = self }, + .{}, + ); } else try std.fmt.allocPrint(self.allocator, "Redirecting to {s}", .{state.location}), .JSON => blk: { break :blk try std.json.stringifyAlloc( @@ -486,6 +497,31 @@ pub fn session(self: *Request) !*jetzig.http.Session { return local_session; } +/// Return the anti-CSRF token cookie value. If no cookie exist, create it. +pub fn authenticityToken(self: *Request) ![]const u8 { + var local_session = try self.session(); + + return local_session.getT(.string, jetzig.authenticity_token_name) orelse blk: { + const token = try jetzig.util.generateSecret(self.allocator, 32); + try local_session.put(jetzig.authenticity_token_name, token); + break :blk local_session.getT(.string, jetzig.authenticity_token_name).?; + }; +} + +pub fn resourceId(self: *const Request) ![]const u8 { + return self.path.resource_id; +} + +/// Return the protocol used to serve the current request by detecting the `X-Forwarded-Proto` +/// header. +pub fn protocol(self: *const Request) Protocol { + return if (self.headers.get("x-forwarded-proto")) |x_forwarded_proto| + if (std.ascii.eqlIgnoreCase(x_forwarded_proto, "https")) .https else .http + else + // TODO: Extend login when we support serving HTTPS directly. + .http; +} + /// Create a new Job. Receives a job name which must resolve to `src/app/jobs/.zig` /// Call `Job.put(...)` to set job params. /// Call `Job.background()` to run the job outside of the request/response flow. diff --git a/src/jetzig/http/Server.zig b/src/jetzig/http/Server.zig index 5ac06be..819067a 100644 --- a/src/jetzig/http/Server.zig +++ b/src/jetzig/http/Server.zig @@ -14,7 +14,6 @@ custom_routes: []jetzig.views.Route, job_definitions: []const jetzig.JobDefinition, mailer_definitions: []const jetzig.MailerDefinition, mime_map: *jetzig.http.mime.MimeMap, -std_net_server: std.net.Server = undefined, initialized: bool = false, store: *jetzig.kv.Store, job_queue: *jetzig.kv.Store, @@ -57,7 +56,6 @@ pub fn init( } pub fn deinit(self: *Server) void { - if (self.initialized) self.std_net_server.deinit(); self.allocator.free(self.env.secret); self.allocator.free(self.env.bind); } @@ -187,21 +185,45 @@ fn renderResponse(self: *Server, request: *jetzig.http.Request) !void { } else unreachable; // In future a MiddlewareRoute might provide a render function etc. } - const route = self.matchCustomRoute(request) orelse try self.matchRoute(request, false); + const maybe_route = self.matchCustomRoute(request) orelse try self.matchRoute(request, false); - if (route) |capture| { - if (!capture.validateFormat(request)) { + if (maybe_route) |route| { + if (!route.validateFormat(request)) { return request.setResponse(try self.renderNotFound(request), .{}); } } + if (maybe_route) |route| { + for (route.before_callbacks) |callback| { + try callback(request, route); + if (request.rendered_view) |view| { + if (request.failed) { + request.setResponse(try self.renderError(request, view.status_code), .{}); + } else if (request.rendered) { + // TODO: Allow callbacks to set content + } + return; + } + if (request.redirect_state) |state| { + try request.renderRedirect(state); + return; + } + } + } + switch (request.requestFormat()) { - .HTML => try self.renderHTML(request, route), - .JSON => try self.renderJSON(request, route), - .UNKNOWN => try self.renderHTML(request, route), + .HTML => try self.renderHTML(request, maybe_route), + .JSON => try self.renderJSON(request, maybe_route), + .UNKNOWN => try self.renderHTML(request, maybe_route), + } + + if (maybe_route) |route| { + for (route.after_callbacks) |callback| { + try callback(request, request.response, route); + } } - if (request.redirect_state) |state| return try request.renderRedirect(state); + if (request.redirect_state) |state| try request.renderRedirect(state); } fn renderStatic(resource: StaticResource, request: *jetzig.http.Request) !void { @@ -355,6 +377,8 @@ fn renderTemplateWithLayout( ) ![]const u8 { try addTemplateConstants(view, route); + const template_context = jetzig.TemplateContext{ .request = request }; + if (request.getLayout(route)) |layout_name| { // TODO: Allow user to configure layouts directory other than src/app/views/layouts/ const prefixed_name = try std.mem.concat( @@ -365,23 +389,37 @@ fn renderTemplateWithLayout( defer self.allocator.free(prefixed_name); if (zmpl.findPrefixed("views", prefixed_name)) |layout| { - return try template.renderWithOptions(view.data, .{ .layout = layout }); + return try template.render( + view.data, + jetzig.TemplateContext, + template_context, + .{ .layout = layout }, + ); } else { try self.logger.WARN("Unknown layout: {s}", .{layout_name}); - return try template.render(view.data); + return try template.render( + view.data, + jetzig.TemplateContext, + template_context, + .{}, + ); } - } else return try template.render(view.data); + } else return try template.render( + view.data, + jetzig.TemplateContext, + template_context, + .{}, + ); } fn addTemplateConstants(view: jetzig.views.View, route: jetzig.views.Route) !void { - try view.data.addConst("jetzig_view", view.data.string(route.view_name)); - const action = switch (route.action) { .custom => route.name, else => |tag| @tagName(tag), }; try view.data.addConst("jetzig_action", view.data.string(action)); + try view.data.addConst("jetzig_view", view.data.string(route.view_name)); } fn isBadRequest(err: anyerror) bool { @@ -481,7 +519,15 @@ fn renderErrorView( .HTML, .UNKNOWN => { if (zmpl.findPrefixed("views", route.template)) |template| { try addTemplateConstants(view, route.*); - return .{ .view = view, .content = try template.render(request.response_data) }; + return .{ + .view = view, + .content = try template.render( + request.response_data, + jetzig.TemplateContext, + .{ .request = request }, + .{}, + ), + }; } }, .JSON => return .{ .view = view, .content = try request.response_data.toJson() }, @@ -573,19 +619,26 @@ fn matchMiddlewareRoute(request: *const jetzig.http.Request) ?jetzig.middleware. fn matchRoute(self: *Server, request: *jetzig.http.Request, static: bool) !?jetzig.views.Route { for (self.routes) |route| { // .index routes always take precedence. - if (route.static == static and route.action == .index and try request.match(route.*)) { - return route.*; + if (route.action == .index and try request.match(route.*)) { + if (!jetzig.build_options.build_static) return route.*; + if (route.static == static) return route.*; } } for (self.routes) |route| { - if (route.static == static and try request.match(route.*)) return route.*; + if (try request.match(route.*)) { + if (!jetzig.build_options.build_static) return route.*; + if (route.static == static) return route.*; + } } return null; } -const StaticResource = struct { content: []const u8, mime_type: []const u8 = "application/octet-stream" }; +const StaticResource = struct { + content: []const u8, + mime_type: []const u8 = "application/octet-stream", +}; fn matchStaticResource(self: *Server, request: *jetzig.http.Request) !?StaticResource { // TODO: Map public and static routes at launch to avoid accessing the file system when @@ -682,7 +735,7 @@ fn matchStaticContent(self: *Server, request: *jetzig.http.Request) !?[]const u8 } pub fn decodeStaticParams(self: *Server) !void { - if (!@hasDecl(jetzig.root, "static")) return; + if (comptime !@hasDecl(jetzig.root, "static")) return; // Store decoded static params (i.e. declared in views) for faster comparison at request time. var decoded = std.ArrayList(*jetzig.data.Value).init(self.allocator); diff --git a/src/jetzig/mail/Job.zig b/src/jetzig/mail/Job.zig index f1585bb..e4b7c23 100644 --- a/src/jetzig/mail/Job.zig +++ b/src/jetzig/mail/Job.zig @@ -134,7 +134,7 @@ fn defaultHtml( try data.addConst("jetzig_view", data.string("")); try data.addConst("jetzig_action", data.string("")); return if (jetzig.zmpl.findPrefixed("mailers", mailer.html_template)) |template| - try template.render(&data) + try template.render(&data, jetzig.TemplateContext, .{}, .{}) else null; } @@ -152,7 +152,7 @@ fn defaultText( try data.addConst("jetzig_view", data.string("")); try data.addConst("jetzig_action", data.string("")); return if (jetzig.zmpl.findPrefixed("mailers", mailer.text_template)) |template| - try template.render(&data) + try template.render(&data, jetzig.TemplateContext, .{}, .{}) else null; } diff --git a/src/jetzig/middleware.zig b/src/jetzig/middleware.zig index e3e5b20..4be3640 100644 --- a/src/jetzig/middleware.zig +++ b/src/jetzig/middleware.zig @@ -4,6 +4,7 @@ const jetzig = @import("../jetzig.zig"); pub const HtmxMiddleware = @import("middleware/HtmxMiddleware.zig"); pub const CompressionMiddleware = @import("middleware/CompressionMiddleware.zig"); pub const AuthMiddleware = @import("middleware/AuthMiddleware.zig"); +pub const AntiCsrfMiddleware = @import("middleware/AntiCsrfMiddleware.zig"); const RouteOptions = struct { content: ?[]const u8 = null, diff --git a/src/jetzig/middleware/AntiCsrfMiddleware.zig b/src/jetzig/middleware/AntiCsrfMiddleware.zig new file mode 100644 index 0000000..c8d2eb0 --- /dev/null +++ b/src/jetzig/middleware/AntiCsrfMiddleware.zig @@ -0,0 +1,73 @@ +const std = @import("std"); +const jetzig = @import("../../jetzig.zig"); + +pub const middleware_name = "anti_csrf"; + +const TokenParams = @Type(.{ + .@"struct" = .{ + .layout = .auto, + .is_tuple = false, + .decls = &.{}, + .fields = &.{.{ + .name = jetzig.authenticity_token_name ++ "", + .type = []const u8, + .is_comptime = false, + .default_value = null, + .alignment = @alignOf([]const u8), + }}, + }, +}); + +pub fn afterRequest(request: *jetzig.http.Request) !void { + try verifyCsrfToken(request); +} + +pub fn beforeRender(request: *jetzig.http.Request, route: jetzig.views.Route) !void { + _ = route; + try verifyCsrfToken(request); +} + +fn logFailure(request: *jetzig.http.Request) !void { + _ = request.fail(.forbidden); + try request.server.logger.DEBUG("Anti-CSRF token validation failed. Request aborted.", .{}); +} + +fn verifyCsrfToken(request: *jetzig.http.Request) !void { + switch (request.method) { + .DELETE, .PATCH, .PUT, .POST => {}, + else => return, + } + + switch (request.requestFormat()) { + .HTML, .UNKNOWN => {}, + // We do not authenticate JSON requests. Users must implement their own authentication + // system or disable JSON endpoints that should be protected. + .JSON => return, + } + + const session = try request.session(); + + if (session.getT(.string, jetzig.authenticity_token_name)) |token| { + const params = try request.expectParams(TokenParams) orelse { + return logFailure(request); + }; + + if (token.len != 32 or @field(params, jetzig.authenticity_token_name).len != 32) { + return try logFailure(request); + } + + var actual: [32]u8 = undefined; + var expected: [32]u8 = undefined; + + @memcpy(&actual, token[0..32]); + @memcpy(&expected, @field(params, jetzig.authenticity_token_name)[0..32]); + + const valid = std.crypto.timing_safe.eql([32]u8, expected, actual); + + if (!valid) { + return try logFailure(request); + } + } else { + return try logFailure(request); + } +} diff --git a/src/jetzig/middleware/AuthMiddleware.zig b/src/jetzig/middleware/AuthMiddleware.zig index 22f8b35..f0d5019 100644 --- a/src/jetzig/middleware/AuthMiddleware.zig +++ b/src/jetzig/middleware/AuthMiddleware.zig @@ -11,11 +11,11 @@ const user_model = jetzig.config.get(jetzig.auth.AuthOptions, "auth").user_model /// they can also be modified. user: ?@TypeOf(jetzig.database.Query(user_model).find(0)).ResultType, -const Self = @This(); +const AuthMiddleware = @This(); /// Initialize middleware. -pub fn init(request: *jetzig.http.Request) !*Self { - const middleware = try request.allocator.create(Self); +pub fn init(request: *jetzig.http.Request) !*AuthMiddleware { + const middleware = try request.allocator.create(AuthMiddleware); middleware.* = .{ .user = null }; return middleware; } @@ -31,8 +31,11 @@ const map = std.StaticStringMap(void).initComptime(.{ /// /// User ID is accessible from a request: /// ```zig -/// -pub fn afterRequest(self: *Self, request: *jetzig.http.Request) !void { +/// if (request.middleware(.auth).user) |user| { +/// try request.server.log(.DEBUG, "{}", .{user.id}); +/// } +/// ``` +pub fn afterRequest(self: *AuthMiddleware, request: *jetzig.http.Request) !void { if (request.path.extension) |extension| { if (map.get(extension) == null) return; } @@ -47,6 +50,6 @@ pub fn afterRequest(self: *Self, request: *jetzig.http.Request) !void { /// Invoked after `afterRequest` is called, use this function to do any clean-up. /// Note that `request.allocator` is an arena allocator, so any allocations are automatically /// done before the next request starts processing. -pub fn deinit(self: *Self, request: *jetzig.http.Request) void { +pub fn deinit(self: *AuthMiddleware, request: *jetzig.http.Request) void { request.allocator.destroy(self); } diff --git a/src/jetzig/testing/App.zig b/src/jetzig/testing/App.zig index 9b99e42..ef4d2b7 100644 --- a/src/jetzig/testing/App.zig +++ b/src/jetzig/testing/App.zig @@ -15,6 +15,8 @@ multipart_boundary: ?[]const u8 = null, logger: jetzig.loggers.Logger, server: Server, repo: *jetzig.database.Repo, +cookies: *jetzig.http.Cookies, +session: *jetzig.http.Session, const Server = struct { logger: jetzig.loggers.Logger }; @@ -47,6 +49,13 @@ pub fn init(allocator: std.mem.Allocator, routes_module: type) !App { const app = try alloc.create(App); const repo = try alloc.create(jetzig.database.Repo); + const cookies = try alloc.create(jetzig.http.Cookies); + cookies.* = jetzig.http.Cookies.init(alloc, ""); + try cookies.parse(); + + const session = try alloc.create(jetzig.http.Session); + session.* = jetzig.http.Session.init(alloc, cookies, jetzig.testing.secret); + app.* = App{ .arena = arena, .allocator = allocator, @@ -57,6 +66,8 @@ pub fn init(allocator: std.mem.Allocator, routes_module: type) !App { .logger = logger, .server = .{ .logger = logger }, .repo = repo, + .cookies = cookies, + .session = session, }; repo.* = try jetzig.database.repo(alloc, app.*); @@ -149,30 +160,61 @@ pub fn request( try server.decodeStaticParams(); var buf: [1024]u8 = undefined; - var httpz_request = try stubbedRequest(allocator, &buf, method, path, self.multipart_boundary, options); + var httpz_request = try stubbedRequest( + allocator, + &buf, + method, + path, + self.multipart_boundary, + options, + self.cookies, + ); var httpz_response = try stubbedResponse(allocator); + try server.processNextRequest(&httpz_request, &httpz_response); - var headers = std.ArrayList(jetzig.testing.TestResponse.Header).init(self.arena.allocator()); + + { + const cookies = try allocator.create(jetzig.http.Cookies); + cookies.* = jetzig.http.Cookies.init(allocator, ""); + try cookies.parse(); + self.cookies = cookies; + } + + var headers = std.ArrayList(jetzig.testing.TestResponse.Header).init(allocator); for (0..httpz_response.headers.len) |index| { + const key = httpz_response.headers.keys[index]; + const value = httpz_response.headers.values[index]; + try headers.append(.{ - .name = try self.arena.allocator().dupe(u8, httpz_response.headers.keys[index]), - .value = try self.arena.allocator().dupe(u8, httpz_response.headers.values[index]), + .name = try allocator.dupe(u8, key), + .value = try allocator.dupe(u8, httpz_response.headers.values[index]), }); + + if (std.ascii.eqlIgnoreCase(key, "set-cookie")) { + // FIXME: We only expect one set-cookie header at the moment. + const cookies = try allocator.create(jetzig.http.Cookies); + cookies.* = jetzig.http.Cookies.init(allocator, value); + self.cookies = cookies; + try self.cookies.parse(); + } } + var data = jetzig.data.Data.init(allocator); defer data.deinit(); - var jobs = std.ArrayList(jetzig.testing.TestResponse.Job).init(self.arena.allocator()); + var jobs = std.ArrayList(jetzig.testing.TestResponse.Job).init(allocator); while (try self.job_queue.popFirst(&data, "__jetzig_jobs")) |value| { if (value.getT(.string, "__jetzig_job_name")) |job_name| try jobs.append(.{ - .name = try self.arena.allocator().dupe(u8, job_name), + .name = try allocator.dupe(u8, job_name), }); } + try self.initSession(); + return .{ - .allocator = self.arena.allocator(), + .allocator = allocator, .status = httpz_response.status, - .body = try self.arena.allocator().dupe(u8, httpz_response.body), + .body = try allocator.dupe(u8, httpz_response.body), .headers = try headers.toOwnedSlice(), .jobs = try jobs.toOwnedSlice(), }; @@ -189,6 +231,16 @@ pub fn params(self: App, args: anytype) []Param { return array.toOwnedSlice() catch @panic("OOM"); } +pub fn initSession(self: *App) !void { + const allocator = self.arena.allocator(); + + var local_session = try allocator.create(jetzig.http.Session); + local_session.* = jetzig.http.Session.init(allocator, self.cookies, jetzig.testing.secret); + try local_session.parse(); + + self.session = local_session; +} + /// Encode an arbitrary struct to a JSON string for use as a request body. pub fn json(self: App, args: anytype) []const u8 { const allocator = self.arena.allocator(); @@ -245,15 +297,29 @@ fn stubbedRequest( comptime path: []const u8, multipart_boundary: ?[]const u8, options: RequestOptions, + maybe_cookies: ?*const jetzig.http.Cookies, ) !httpz.Request { // TODO: Use httpz.testing var request_headers = try keyValue(allocator, 32); for (options.headers) |header| request_headers.add(header.name, header.value); + + if (maybe_cookies) |cookies| { + var cookie_buf = std.ArrayList(u8).init(allocator); + const cookie_writer = cookie_buf.writer(); + try cookie_writer.print("{}", .{cookies}); + const cookie = try cookie_buf.toOwnedSlice(); + request_headers.add("cookie", cookie); + } + if (options.json != null) { request_headers.add("accept", "application/json"); request_headers.add("content-type", "application/json"); } else if (multipart_boundary) |boundary| { - const header = try std.mem.concat(allocator, u8, &.{ "multipart/form-data; boundary=", boundary }); + const header = try std.mem.concat( + allocator, + u8, + &.{ "multipart/form-data; boundary=", boundary }, + ); request_headers.add("content-type", header); } @@ -358,7 +424,10 @@ fn buildOptions(allocator: std.mem.Allocator, app: *const App, args: anytype) !R } return .{ - .headers = if (@hasField(@TypeOf(args), "headers")) try buildHeaders(allocator, args.headers) else &.{}, + .headers = if (@hasField(@TypeOf(args), "headers")) + try buildHeaders(allocator, args.headers) + else + &.{}, .json = if (@hasField(@TypeOf(args), "json")) app.json(args.json) else null, .params = if (@hasField(@TypeOf(args), "params")) app.params(args.params) else null, .body = if (@hasField(@TypeOf(args), "body")) args.body else null, @@ -368,7 +437,12 @@ fn buildOptions(allocator: std.mem.Allocator, app: *const App, args: anytype) !R fn buildHeaders(allocator: std.mem.Allocator, args: anytype) ![]const jetzig.testing.TestResponse.Header { var headers = std.ArrayList(jetzig.testing.TestResponse.Header).init(allocator); inline for (std.meta.fields(@TypeOf(args))) |field| { - try headers.append(jetzig.testing.TestResponse.Header{ .name = field.name, .value = @field(args, field.name) }); + try headers.append( + jetzig.testing.TestResponse.Header{ + .name = field.name, + .value = @field(args, field.name), + }, + ); } return try headers.toOwnedSlice(); } diff --git a/src/jetzig/util.zig b/src/jetzig/util.zig index 0391b3f..0fc1a00 100644 --- a/src/jetzig/util.zig +++ b/src/jetzig/util.zig @@ -67,25 +67,25 @@ pub inline fn unquote(input: []const u8) []const u8 { } /// Generate a secure random string of `len` characters (for cryptographic purposes). -pub fn generateSecret(allocator: std.mem.Allocator, comptime len: u10) ![]const u8 { +pub fn generateSecret(allocator: std.mem.Allocator, len: u10) ![]const u8 { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - var secret: [len]u8 = undefined; + const secret = try allocator.alloc(u8, len); for (0..len) |index| { - secret[index] = chars[std.crypto.random.intRangeAtMost(u8, 0, chars.len)]; + secret[index] = chars[std.crypto.random.intRangeAtMost(u8, 0, chars.len - 1)]; } - return try allocator.dupe(u8, &secret); + return secret; } pub fn generateRandomString(buf: []u8) []const u8 { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; for (0..buf.len) |index| { - buf[index] = chars[std.crypto.random.intRangeAtMost(u8, 0, chars.len)]; + buf[index] = chars[std.crypto.random.intRangeAtMost(u8, 0, chars.len - 1)]; } - return buf; + return buf[0..]; } /// Calculate a duration from a given start time (in nanoseconds) to the current time. diff --git a/src/jetzig/views/Route.zig b/src/jetzig/views/Route.zig index 50a89a7..851dacf 100644 --- a/src/jetzig/views/Route.zig +++ b/src/jetzig/views/Route.zig @@ -56,6 +56,8 @@ json_params: []const []const u8, params: std.ArrayList(*jetzig.data.Data) = undefined, id: []const u8, formats: ?Formats = null, +before_callbacks: []const jetzig.callbacks.BeforeCallback = &.{}, +after_callbacks: []const jetzig.callbacks.AfterCallback = &.{}, /// Initializes a route's static params on server launch. Converts static params (JSON strings) /// to `jetzig.data.Data` values. Memory is owned by caller (`App.start()`).