diff --git a/crates/encoding/src/config.rs b/crates/encoding/src/config.rs index b6ef8857a..a6a40db86 100644 --- a/crates/encoding/src/config.rs +++ b/crates/encoding/src/config.rs @@ -234,6 +234,7 @@ impl WorkgroupCounts { path_tag_wgs }; let draw_object_wgs = (n_draw_objects + PATH_BBOX_WG - 1) / PATH_BBOX_WG; + let draw_monoid_wgs = draw_object_wgs.min(PATH_BBOX_WG); let flatten_wgs = (n_path_tags + FLATTEN_WG - 1) / FLATTEN_WG; let clip_reduce_wgs = n_clips.saturating_sub(1) / CLIP_REDUCE_WG; let clip_wgs = (n_clips + CLIP_REDUCE_WG - 1) / CLIP_REDUCE_WG; @@ -248,8 +249,8 @@ impl WorkgroupCounts { path_scan: (path_tag_wgs, 1, 1), bbox_clear: (draw_object_wgs, 1, 1), flatten: (flatten_wgs, 1, 1), - draw_reduce: (draw_object_wgs, 1, 1), - draw_leaf: (draw_object_wgs, 1, 1), + draw_reduce: (draw_monoid_wgs, 1, 1), + draw_leaf: (draw_monoid_wgs, 1, 1), clip_reduce: (clip_reduce_wgs, 1, 1), clip_leaf: (clip_wgs, 1, 1), binning: (draw_object_wgs, 1, 1), @@ -364,8 +365,9 @@ impl BufferSizes { let path_reduced_scan = BufferSize::new(path_tag_wgs); let path_monoids = BufferSize::new(path_tag_wgs * PATH_REDUCE_WG); let path_bboxes = BufferSize::new(n_paths); - let draw_object_wgs = workgroups.draw_reduce.0; - let draw_reduced = BufferSize::new(draw_object_wgs); + let binning_wgs = workgroups.binning.0; + let draw_monoid_wgs = workgroups.draw_reduce.0; + let draw_reduced = BufferSize::new(draw_monoid_wgs); let draw_monoids = BufferSize::new(n_draw_objects); let info = BufferSize::new(layout.bin_data_start); let clip_inps = BufferSize::new(n_clips); @@ -375,7 +377,7 @@ impl BufferSizes { let draw_bboxes = BufferSize::new(n_paths); let bump_alloc = BufferSize::new(1); let indirect_count = BufferSize::new(1); - let bin_headers = BufferSize::new(draw_object_wgs * 256); + let bin_headers = BufferSize::new(binning_wgs * 256); let n_paths_aligned = align_up(n_paths, 256); let paths = BufferSize::new(n_paths_aligned); diff --git a/examples/scenes/src/test_scenes.rs b/examples/scenes/src/test_scenes.rs index 02c31ac35..e03792a1a 100644 --- a/examples/scenes/src/test_scenes.rs +++ b/examples/scenes/src/test_scenes.rs @@ -2,7 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::{ExampleScene, SceneConfig, SceneParams, SceneSet}; -use vello::kurbo::{Affine, BezPath, Cap, Ellipse, Join, PathEl, Point, Rect, Shape, Stroke, Vec2}; +use vello::kurbo::{ + Affine, BezPath, Cap, Circle, Ellipse, Join, PathEl, Point, Rect, Shape, Stroke, Vec2, +}; use vello::peniko::*; use vello::*; @@ -60,6 +62,7 @@ pub fn test_scenes() -> SceneSet { scene!(longpathdash(Cap::Butt), "longpathdash (butt caps)", false), scene!(longpathdash(Cap::Round), "longpathdash (round caps)", false), scene!(crate::mmark::MMark::new(80_000), "mmark", false), + scene!(many_draw_objects), ]; SceneSet { scenes } @@ -1520,6 +1523,22 @@ fn make_diamond(cx: f64, cy: f64) -> [PathEl; 5] { ] } +fn many_draw_objects(scene: &mut Scene, params: &mut SceneParams) { + const N_WIDE: usize = 300; + const N_HIGH: usize = 300; + const SCENE_WIDTH: f64 = 2000.0; + const SCENE_HEIGHT: f64 = 1500.0; + params.resolution = Some((SCENE_WIDTH, SCENE_HEIGHT).into()); + for j in 0..N_HIGH { + let y = (j as f64 + 0.5) * (SCENE_HEIGHT / N_HIGH as f64); + for i in 0..N_WIDE { + let x = (i as f64 + 0.5) * (SCENE_WIDTH / N_WIDE as f64); + let c = Circle::new((x, y), 3.0); + scene.fill(Fill::NonZero, Affine::IDENTITY, Color::YELLOW, None, &c); + } + } +} + fn splash_screen(scene: &mut Scene, params: &mut SceneParams) { let strings = [ "Vello test", diff --git a/shader/coarse.wgsl b/shader/coarse.wgsl index 6caa35a13..758244c84 100644 --- a/shader/coarse.wgsl +++ b/shader/coarse.wgsl @@ -236,6 +236,7 @@ fn main( if wr_ix - rd_ix >= N_TILE || (wr_ix >= ready_ix && partition_ix >= n_partitions) { break; } + workgroupBarrier(); } // At this point, sh_drawobj_ix[0.. wr_ix - rd_ix] contains merged binning results. var tag = DRAWTAG_NOP; diff --git a/shader/draw_leaf.wgsl b/shader/draw_leaf.wgsl index 6f1a2e6c2..e041a5148 100644 --- a/shader/draw_leaf.wgsl +++ b/shader/draw_leaf.wgsl @@ -51,11 +51,9 @@ var sh_scratch: array; @compute @workgroup_size(256) fn main( - @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, ) { - let ix = global_id.x; // Reduce prefix of workgroups up to this one var agg = draw_monoid_identity(); if local_id.x < wg_id.x { @@ -74,184 +72,199 @@ fn main( // Two barriers can be eliminated if we use separate shared arrays // for prefix and intra-workgroup prefix sum. workgroupBarrier(); - var m = sh_scratch[0]; - workgroupBarrier(); - let tag_word = read_draw_tag_from_scene(ix); - agg = map_draw_tag(tag_word); - sh_scratch[local_id.x] = agg; - for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { + var prefix = sh_scratch[0]; + + // This is the same division of work as draw_reduce. + let num_blocks_total = (config.n_drawobj + WG_SIZE - 1u) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + let first_block = n_blocks_base * wg_id.x + min(wg_id.x, remainder); + let n_blocks = n_blocks_base + u32(wg_id.x < remainder); + var ix = first_block * WG_SIZE + local_id.x; + let ix_end = ix + n_blocks * WG_SIZE; + while ix != ix_end { + let tag_word = read_draw_tag_from_scene(ix); + agg = map_draw_tag(tag_word); workgroupBarrier(); - if local_id.x >= 1u << i { - let other = sh_scratch[local_id.x - (1u << i)]; - agg = combine_draw_monoid(agg, other); + sh_scratch[local_id.x] = agg; + for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { + workgroupBarrier(); + if local_id.x >= 1u << i { + let other = sh_scratch[local_id.x - (1u << i)]; + agg = combine_draw_monoid(agg, other); + } + workgroupBarrier(); + sh_scratch[local_id.x] = agg; } + var m = prefix; workgroupBarrier(); - sh_scratch[local_id.x] = agg; - } - workgroupBarrier(); - if local_id.x > 0u { - m = combine_draw_monoid(m, sh_scratch[local_id.x - 1u]); - } - // m now contains exclusive prefix sum of draw monoid - if ix < config.n_drawobj { - draw_monoid[ix] = m; - } - let dd = config.drawdata_base + m.scene_offset; - let di = m.info_offset; - if tag_word == DRAWTAG_FILL_COLOR || tag_word == DRAWTAG_FILL_LIN_GRADIENT || - tag_word == DRAWTAG_FILL_RAD_GRADIENT || tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || - tag_word == DRAWTAG_FILL_IMAGE || tag_word == DRAWTAG_BEGIN_CLIP - { - let bbox = path_bbox[m.path_ix]; - // TODO: bbox is mostly yagni here, sort that out. Maybe clips? - // let x0 = f32(bbox.x0); - // let y0 = f32(bbox.y0); - // let x1 = f32(bbox.x1); - // let y1 = f32(bbox.y1); - // let bbox_f = vec4(x0, y0, x1, y1); - var transform = Transform(); - let draw_flags = bbox.draw_flags; - if tag_word == DRAWTAG_FILL_LIN_GRADIENT || tag_word == DRAWTAG_FILL_RAD_GRADIENT || - tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || tag_word == DRAWTAG_FILL_IMAGE - { - transform = read_transform(config.transform_base, bbox.trans_ix); + if local_id.x > 0u { + m = combine_draw_monoid(m, sh_scratch[local_id.x - 1u]); } - switch tag_word { - case DRAWTAG_FILL_COLOR: { - info[di] = draw_flags; - } - case DRAWTAG_FILL_LIN_GRADIENT: { - info[di] = draw_flags; - var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); - p0 = transform_apply(transform, p0); - p1 = transform_apply(transform, p1); - let dxy = p1 - p0; - let scale = 1.0 / dot(dxy, dxy); - let line_xy = dxy * scale; - let line_c = -dot(p0, line_xy); - info[di + 1u] = bitcast(line_xy.x); - info[di + 2u] = bitcast(line_xy.y); - info[di + 3u] = bitcast(line_c); + // m now contains exclusive prefix sum of draw monoid + if ix < config.n_drawobj { + draw_monoid[ix] = m; + } + let dd = config.drawdata_base + m.scene_offset; + let di = m.info_offset; + if tag_word == DRAWTAG_FILL_COLOR || tag_word == DRAWTAG_FILL_LIN_GRADIENT || + tag_word == DRAWTAG_FILL_RAD_GRADIENT || tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || + tag_word == DRAWTAG_FILL_IMAGE || tag_word == DRAWTAG_BEGIN_CLIP + { + let bbox = path_bbox[m.path_ix]; + // TODO: bbox is mostly yagni here, sort that out. Maybe clips? + // let x0 = f32(bbox.x0); + // let y0 = f32(bbox.y0); + // let x1 = f32(bbox.x1); + // let y1 = f32(bbox.y1); + // let bbox_f = vec4(x0, y0, x1, y1); + var transform = Transform(); + let draw_flags = bbox.draw_flags; + if tag_word == DRAWTAG_FILL_LIN_GRADIENT || tag_word == DRAWTAG_FILL_RAD_GRADIENT || + tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || tag_word == DRAWTAG_FILL_IMAGE + { + transform = read_transform(config.transform_base, bbox.trans_ix); } - case DRAWTAG_FILL_RAD_GRADIENT: { - // Two-point conical gradient implementation based - // on the algorithm at - // This epsilon matches what Skia uses - let GRADIENT_EPSILON = 1.0 / f32(1u << 12u); - info[di] = draw_flags; - var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); - var r0 = bitcast(scene[dd + 5u]); - var r1 = bitcast(scene[dd + 6u]); - let user_to_gradient = transform_inverse(transform); - // Output variables - var xform = Transform(); - var focal_x = 0.0; - var radius = 0.0; - var kind = 0u; - var flags = 0u; - if abs(r0 - r1) <= GRADIENT_EPSILON { - // When the radii are the same, emit a strip gradient - kind = RAD_GRAD_KIND_STRIP; - let scaled = r0 / distance(p0, p1); - xform = transform_mul( - two_point_to_unit_line(p0, p1), - user_to_gradient - ); - radius = scaled * scaled; - } else { - // Assume a two point conical gradient unless the centers - // are equal. - kind = RAD_GRAD_KIND_CONE; - if all(p0 == p1) { - kind = RAD_GRAD_KIND_CIRCULAR; - // Nudge p0 a bit to avoid denormals. - p0 += GRADIENT_EPSILON; - } - if r1 == 0.0 { - // If r1 == 0.0, swap the points and radii - flags |= RAD_GRAD_SWAPPED; - let tmp_p = p0; - p0 = p1; - p1 = tmp_p; - let tmp_r = r0; - r0 = r1; - r1 = tmp_r; - } - focal_x = r0 / (r0 - r1); - let cf = (1.0 - focal_x) * p0 + focal_x * p1; - radius = r1 / (distance(cf, p1)); - let user_to_unit_line = transform_mul( - two_point_to_unit_line(cf, p1), - user_to_gradient - ); - var user_to_scaled = user_to_unit_line; - // When r == 1.0, focal point is on circle - if abs(radius - 1.0) <= GRADIENT_EPSILON { - kind = RAD_GRAD_KIND_FOCAL_ON_CIRCLE; - let scale = 0.5 * abs(1.0 - focal_x); - user_to_scaled = transform_mul( - Transform(vec4(scale, 0.0, 0.0, scale), vec2(0.0)), - user_to_unit_line + switch tag_word { + case DRAWTAG_FILL_COLOR: { + info[di] = draw_flags; + } + case DRAWTAG_FILL_LIN_GRADIENT: { + info[di] = draw_flags; + var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); + p0 = transform_apply(transform, p0); + p1 = transform_apply(transform, p1); + let dxy = p1 - p0; + let scale = 1.0 / dot(dxy, dxy); + let line_xy = dxy * scale; + let line_c = -dot(p0, line_xy); + info[di + 1u] = bitcast(line_xy.x); + info[di + 2u] = bitcast(line_xy.y); + info[di + 3u] = bitcast(line_c); + } + case DRAWTAG_FILL_RAD_GRADIENT: { + // Two-point conical gradient implementation based + // on the algorithm at + // This epsilon matches what Skia uses + let GRADIENT_EPSILON = 1.0 / f32(1u << 12u); + info[di] = draw_flags; + var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); + var r0 = bitcast(scene[dd + 5u]); + var r1 = bitcast(scene[dd + 6u]); + let user_to_gradient = transform_inverse(transform); + // Output variables + var xform = Transform(); + var focal_x = 0.0; + var radius = 0.0; + var kind = 0u; + var flags = 0u; + if abs(r0 - r1) <= GRADIENT_EPSILON { + // When the radii are the same, emit a strip gradient + kind = RAD_GRAD_KIND_STRIP; + let scaled = r0 / distance(p0, p1); + xform = transform_mul( + two_point_to_unit_line(p0, p1), + user_to_gradient ); + radius = scaled * scaled; } else { - let a = radius * radius - 1.0; - let scale_ratio = abs(1.0 - focal_x) / a; - let scale_x = radius * scale_ratio; - let scale_y = sqrt(abs(a)) * scale_ratio; - user_to_scaled = transform_mul( - Transform(vec4(scale_x, 0.0, 0.0, scale_y), vec2(0.0)), - user_to_unit_line + // Assume a two point conical gradient unless the centers + // are equal. + kind = RAD_GRAD_KIND_CONE; + if all(p0 == p1) { + kind = RAD_GRAD_KIND_CIRCULAR; + // Nudge p0 a bit to avoid denormals. + p0 += GRADIENT_EPSILON; + } + if r1 == 0.0 { + // If r1 == 0.0, swap the points and radii + flags |= RAD_GRAD_SWAPPED; + let tmp_p = p0; + p0 = p1; + p1 = tmp_p; + let tmp_r = r0; + r0 = r1; + r1 = tmp_r; + } + focal_x = r0 / (r0 - r1); + let cf = (1.0 - focal_x) * p0 + focal_x * p1; + radius = r1 / (distance(cf, p1)); + let user_to_unit_line = transform_mul( + two_point_to_unit_line(cf, p1), + user_to_gradient ); + var user_to_scaled = user_to_unit_line; + // When r == 1.0, focal point is on circle + if abs(radius - 1.0) <= GRADIENT_EPSILON { + kind = RAD_GRAD_KIND_FOCAL_ON_CIRCLE; + let scale = 0.5 * abs(1.0 - focal_x); + user_to_scaled = transform_mul( + Transform(vec4(scale, 0.0, 0.0, scale), vec2(0.0)), + user_to_unit_line + ); + } else { + let a = radius * radius - 1.0; + let scale_ratio = abs(1.0 - focal_x) / a; + let scale_x = radius * scale_ratio; + let scale_y = sqrt(abs(a)) * scale_ratio; + user_to_scaled = transform_mul( + Transform(vec4(scale_x, 0.0, 0.0, scale_y), vec2(0.0)), + user_to_unit_line + ); + } + xform = user_to_scaled; } - xform = user_to_scaled; + info[di + 1u] = bitcast(xform.matrx.x); + info[di + 2u] = bitcast(xform.matrx.y); + info[di + 3u] = bitcast(xform.matrx.z); + info[di + 4u] = bitcast(xform.matrx.w); + info[di + 5u] = bitcast(xform.translate.x); + info[di + 6u] = bitcast(xform.translate.y); + info[di + 7u] = bitcast(focal_x); + info[di + 8u] = bitcast(radius); + info[di + 9u] = bitcast((flags << 3u) | kind); } - info[di + 1u] = bitcast(xform.matrx.x); - info[di + 2u] = bitcast(xform.matrx.y); - info[di + 3u] = bitcast(xform.matrx.z); - info[di + 4u] = bitcast(xform.matrx.w); - info[di + 5u] = bitcast(xform.translate.x); - info[di + 6u] = bitcast(xform.translate.y); - info[di + 7u] = bitcast(focal_x); - info[di + 8u] = bitcast(radius); - info[di + 9u] = bitcast((flags << 3u) | kind); - } - case DRAWTAG_FILL_SWEEP_GRADIENT: { - info[di] = draw_flags; - let p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - let xform = transform_mul(transform, Transform(vec4(1.0, 0.0, 0.0, 1.0), p0)); - let inv = transform_inverse(xform); - info[di + 1u] = bitcast(inv.matrx.x); - info[di + 2u] = bitcast(inv.matrx.y); - info[di + 3u] = bitcast(inv.matrx.z); - info[di + 4u] = bitcast(inv.matrx.w); - info[di + 5u] = bitcast(inv.translate.x); - info[di + 6u] = bitcast(inv.translate.y); - info[di + 7u] = scene[dd + 3u]; - info[di + 8u] = scene[dd + 4u]; - } - case DRAWTAG_FILL_IMAGE: { - info[di] = draw_flags; - let inv = transform_inverse(transform); - info[di + 1u] = bitcast(inv.matrx.x); - info[di + 2u] = bitcast(inv.matrx.y); - info[di + 3u] = bitcast(inv.matrx.z); - info[di + 4u] = bitcast(inv.matrx.w); - info[di + 5u] = bitcast(inv.translate.x); - info[di + 6u] = bitcast(inv.translate.y); - info[di + 7u] = scene[dd]; - info[di + 8u] = scene[dd + 1u]; + case DRAWTAG_FILL_SWEEP_GRADIENT: { + info[di] = draw_flags; + let p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + let xform = transform_mul(transform, Transform(vec4(1.0, 0.0, 0.0, 1.0), p0)); + let inv = transform_inverse(xform); + info[di + 1u] = bitcast(inv.matrx.x); + info[di + 2u] = bitcast(inv.matrx.y); + info[di + 3u] = bitcast(inv.matrx.z); + info[di + 4u] = bitcast(inv.matrx.w); + info[di + 5u] = bitcast(inv.translate.x); + info[di + 6u] = bitcast(inv.translate.y); + info[di + 7u] = scene[dd + 3u]; + info[di + 8u] = scene[dd + 4u]; + } + case DRAWTAG_FILL_IMAGE: { + info[di] = draw_flags; + let inv = transform_inverse(transform); + info[di + 1u] = bitcast(inv.matrx.x); + info[di + 2u] = bitcast(inv.matrx.y); + info[di + 3u] = bitcast(inv.matrx.z); + info[di + 4u] = bitcast(inv.matrx.w); + info[di + 5u] = bitcast(inv.translate.x); + info[di + 6u] = bitcast(inv.translate.y); + info[di + 7u] = scene[dd]; + info[di + 8u] = scene[dd + 1u]; + } + default: {} } - default: {} } - } - if tag_word == DRAWTAG_BEGIN_CLIP || tag_word == DRAWTAG_END_CLIP { - var path_ix = ~ix; - if tag_word == DRAWTAG_BEGIN_CLIP { - path_ix = m.path_ix; + if tag_word == DRAWTAG_BEGIN_CLIP || tag_word == DRAWTAG_END_CLIP { + var path_ix = ~ix; + if tag_word == DRAWTAG_BEGIN_CLIP { + path_ix = m.path_ix; + } + clip_inp[m.clip_ix] = ClipInp(ix, i32(path_ix)); } - clip_inp[m.clip_ix] = ClipInp(ix, i32(path_ix)); + ix += WG_SIZE; + // break here on end to save monoid aggregation? + prefix = combine_draw_monoid(prefix, sh_scratch[WG_SIZE - 1u]); } } diff --git a/shader/draw_reduce.wgsl b/shader/draw_reduce.wgsl index a39758bd3..7a12b8188 100644 --- a/shader/draw_reduce.wgsl +++ b/shader/draw_reduce.wgsl @@ -13,7 +13,7 @@ var scene: array; @group(0) @binding(2) var reduced: array; -let WG_SIZE = 256u; +const WG_SIZE = 256u; var sh_scratch: array; @@ -21,12 +21,24 @@ var sh_scratch: array; @compute @workgroup_size(256) fn main( - @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, ) { - let ix = global_id.x; - let tag_word = read_draw_tag_from_scene(ix); - var agg = map_draw_tag(tag_word); + let num_blocks_total = (config.n_drawobj + (WG_SIZE - 1u)) / WG_SIZE; + // When the number of blocks exceeds the workgroup size, divide + // the work evenly so each workgroup handles n_blocks / wg, with + // the low workgroups doing one more each to handle the remainder. + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + let first_block = n_blocks_base * wg_id.x + min(wg_id.x, remainder); + let n_blocks = n_blocks_base + u32(wg_id.x < remainder); + var block_index = first_block * WG_SIZE + local_id.x; + var agg = draw_monoid_identity(); + for (var i = 0u; i < n_blocks; i++) { + let tag_word = read_draw_tag_from_scene(block_index); + agg = combine_draw_monoid(agg, map_draw_tag(tag_word)); + block_index += WG_SIZE; + } sh_scratch[local_id.x] = agg; for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { workgroupBarrier(); @@ -38,6 +50,6 @@ fn main( sh_scratch[local_id.x] = agg; } if local_id.x == 0u { - reduced[ix >> firstTrailingBit(WG_SIZE)] = agg; + reduced[wg_id.x] = agg; } } diff --git a/src/cpu_shader/draw_leaf.rs b/src/cpu_shader/draw_leaf.rs index df86f6c3b..c1a958c1e 100644 --- a/src/cpu_shader/draw_leaf.rs +++ b/src/cpu_shader/draw_leaf.rs @@ -23,11 +23,16 @@ fn draw_leaf_main( info: &mut [u32], clip_inp: &mut [Clip], ) { + let num_blocks_total = (config.layout.n_draw_objects as usize + (WG_SIZE - 1)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; let mut prefix = DrawMonoid::default(); - for i in 0..n_wg { + for i in 0..n_wg as usize { + let first_block = n_blocks_base * i + i.min(remainder); + let n_blocks = n_blocks_base + (i < remainder) as usize; let mut m = prefix; - for j in 0..WG_SIZE { - let ix = i * WG_SIZE as u32 + j as u32; + for j in 0..WG_SIZE * n_blocks { + let ix = (first_block * WG_SIZE) as u32 + j as u32; let tag_raw = read_draw_tag_from_scene(config, scene, ix); let tag_word = DrawTag(tag_raw); // store exclusive prefix sum @@ -185,7 +190,7 @@ fn draw_leaf_main( } m = m_next; } - prefix = prefix.combine(&reduced[i as usize]); + prefix = prefix.combine(&reduced[i]); } } diff --git a/src/cpu_shader/draw_reduce.rs b/src/cpu_shader/draw_reduce.rs index 24e15a134..bc2104efc 100644 --- a/src/cpu_shader/draw_reduce.rs +++ b/src/cpu_shader/draw_reduce.rs @@ -10,14 +10,19 @@ use super::util::read_draw_tag_from_scene; const WG_SIZE: usize = 256; fn draw_reduce_main(n_wg: u32, config: &ConfigUniform, scene: &[u32], reduced: &mut [DrawMonoid]) { - for i in 0..n_wg { + let num_blocks_total = (config.layout.n_draw_objects as usize + (WG_SIZE - 1)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + for i in 0..n_wg as usize { + let first_block = n_blocks_base * i + i.min(remainder); + let n_blocks = n_blocks_base + (i < remainder) as usize; let mut m = DrawMonoid::default(); - for j in 0..WG_SIZE { - let ix = i * WG_SIZE as u32 + j as u32; + for j in 0..WG_SIZE * n_blocks { + let ix = (first_block * WG_SIZE) as u32 + j as u32; let tag = read_draw_tag_from_scene(config, scene, ix); m = m.combine(&DrawMonoid::new(DrawTag(tag))); } - reduced[i as usize] = m; + reduced[i] = m; } }