Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reusable PBR shader types/bindings/functions #3969

Closed
wants to merge 9 commits into from
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ name = "scene"
path = "examples/scene/scene.rs"

# Shaders
[[example]]
name = "array_texture"
path = "examples/shader/array_texture.rs"

[[example]]
name = "shader_defs"
path = "examples/shader/shader_defs.rs"
Expand Down
55 changes: 55 additions & 0 deletions assets/shaders/array_texture.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#import bevy_pbr::mesh_view_types
#import bevy_pbr::mesh_view_bindings
#import bevy_pbr::mesh_types
#import bevy_pbr::mesh_bindings
// NOTE: Bindings must come before functions that use them!
#import bevy_pbr::mesh_functions

[[group(1), binding(0)]]
var my_array_texture: texture_2d_array<f32>;
[[group(1), binding(1)]]
var my_array_texture_sampler: sampler;

struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
};

struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
[[location(0)]] position: vec4<f32>;
};

[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
var out: VertexOutput;
out.clip_position = mesh_model_position_to_clip(vec4<f32>(vertex.position, 1.0));
out.position = out.clip_position;
return out;
}

struct FragmentInput {
[[location(0)]] clip_position: vec4<f32>;
};

[[stage(fragment)]]
fn fragment(in: FragmentInput) -> [[location(0)]] vec4<f32> {
// Screen-space coordinates determine which layer of the array texture we sample.
let ss = in.clip_position.xy / in.clip_position.w;
var layer: f32 = 0.0;
if (ss.x > 0.0 && ss.y > 0.0) {
layer = 0.0;
} else if (ss.x < 0.0 && ss.y > 0.0) {
layer = 1.0;
} else if (ss.x > 0.0 && ss.y < 0.0) {
layer = 2.0;
} else {
layer = 3.0;
}

// Convert to texture coordinates.
let uv = (ss + vec2<f32>(1.0)) / 2.0;

return textureSampleLevel(my_array_texture, my_array_texture_sampler, uv, i32(layer), 0.0);
}
12 changes: 7 additions & 5 deletions assets/shaders/shader_defs.wgsl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_struct
#import bevy_pbr::mesh_view_types
#import bevy_pbr::mesh_view_bindings
#import bevy_pbr::mesh_types

[[group(1), binding(0)]]
var<uniform> mesh: Mesh;

// NOTE: Bindings must come before functions that use them!
#import bevy_pbr::mesh_functions

struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
Expand All @@ -16,10 +20,8 @@ struct VertexOutput {

[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.model * vec4<f32>(vertex.position, 1.0);

var out: VertexOutput;
out.clip_position = view.view_proj * world_position;
out.clip_position = mesh_model_position_to_clip(vec4<f32>(vertex.position, 1.0));
return out;
}

Expand Down
21 changes: 21 additions & 0 deletions crates/bevy_pbr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ use bevy_transform::TransformSystem;

pub const PBR_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 4805239651767701046);
pub const PBR_TYPES_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 14465578778686805602);
pub const PBR_BINDINGS_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1501114814264999179);
pub const PBR_FUNCTIONS_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1167493567156271479);
pub const SHADOW_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1836745567947005696);

Expand All @@ -62,6 +68,21 @@ impl Plugin for PbrPlugin {
PBR_SHADER_HANDLE,
Shader::from_wgsl(include_str!("render/pbr.wgsl")),
);
shaders.set_untracked(
PBR_TYPES_HANDLE,
Shader::from_wgsl(include_str!("render/pbr_types.wgsl"))
.with_import_path("bevy_pbr::pbr_types"),
);
shaders.set_untracked(
PBR_BINDINGS_HANDLE,
Shader::from_wgsl(include_str!("render/pbr_bindings.wgsl"))
.with_import_path("bevy_pbr::pbr_bindings"),
);
shaders.set_untracked(
PBR_FUNCTIONS_HANDLE,
Shader::from_wgsl(include_str!("render/pbr_functions.wgsl"))
.with_import_path("bevy_pbr::pbr_functions"),
);
shaders.set_untracked(
SHADOW_SHADER_HANDLE,
Shader::from_wgsl(include_str!("render/depth.wgsl")),
Expand Down
14 changes: 6 additions & 8 deletions crates/bevy_pbr/src/render/depth.wgsl
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
#import bevy_pbr::mesh_struct
#import bevy_pbr::mesh_view_types
#import bevy_pbr::mesh_types

// NOTE: Keep in sync with pbr.wgsl
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;

[[group(1), binding(0)]]
var<uniform> mesh: Mesh;

// NOTE: Bindings must come before functions that use them!
#import bevy_pbr::mesh_functions

struct Vertex {
[[location(0)]] position: vec3<f32>;
};
Expand All @@ -23,6 +21,6 @@ struct VertexOutput {
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
var out: VertexOutput;
out.clip_position = view.view_proj * mesh.model * vec4<f32>(vertex.position, 1.0);
out.clip_position = mesh_model_position_to_clip(vec4<f32>(vertex.position, 1.0));
return out;
}
37 changes: 29 additions & 8 deletions crates/bevy_pbr/src/render/mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ use bevy_transform::components::GlobalTransform;
#[derive(Default)]
pub struct MeshRenderPlugin;

pub const MESH_VIEW_BIND_GROUP_HANDLE: HandleUntyped =
pub const MESH_VIEW_TYPES_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 6944437233335238185);
pub const MESH_VIEW_BINDINGS_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 9076678235888822571);
pub const MESH_STRUCT_HANDLE: HandleUntyped =
pub const MESH_TYPES_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 2506024101911992377);
pub const MESH_BINDINGS_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 17763658410392053870);
pub const MESH_FUNCTIONS_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 8157763673499264335);
pub const MESH_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 3252377289100772450);

Expand All @@ -41,14 +47,29 @@ impl Plugin for MeshRenderPlugin {
Shader::from_wgsl(include_str!("mesh.wgsl")),
);
shaders.set_untracked(
MESH_STRUCT_HANDLE,
Shader::from_wgsl(include_str!("mesh_struct.wgsl"))
.with_import_path("bevy_pbr::mesh_struct"),
MESH_VIEW_TYPES_HANDLE,
Shader::from_wgsl(include_str!("mesh_view_types.wgsl"))
.with_import_path("bevy_pbr::mesh_view_types"),
);
shaders.set_untracked(
MESH_VIEW_BIND_GROUP_HANDLE,
Shader::from_wgsl(include_str!("mesh_view_bind_group.wgsl"))
.with_import_path("bevy_pbr::mesh_view_bind_group"),
MESH_VIEW_BINDINGS_HANDLE,
Shader::from_wgsl(include_str!("mesh_view_bindings.wgsl"))
.with_import_path("bevy_pbr::mesh_view_bindings"),
);
shaders.set_untracked(
MESH_TYPES_HANDLE,
Shader::from_wgsl(include_str!("mesh_types.wgsl"))
.with_import_path("bevy_pbr::mesh_types"),
);
shaders.set_untracked(
MESH_BINDINGS_HANDLE,
Shader::from_wgsl(include_str!("mesh_bindings.wgsl"))
.with_import_path("bevy_pbr::mesh_bindings"),
);
shaders.set_untracked(
MESH_FUNCTIONS_HANDLE,
Shader::from_wgsl(include_str!("mesh_functions.wgsl"))
.with_import_path("bevy_pbr::mesh_functions"),
);

app.add_plugin(UniformComponentPlugin::<MeshUniform>::default());
Expand Down
32 changes: 11 additions & 21 deletions crates/bevy_pbr/src/render/mesh.wgsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_struct
#import bevy_pbr::mesh_view_types
#import bevy_pbr::mesh_view_bindings
#import bevy_pbr::mesh_types
#import bevy_pbr::mesh_bindings
// NOTE: Bindings must come before functions that use them!
#import bevy_pbr::mesh_functions

struct Vertex {
[[location(0)]] position: vec3<f32>;
Expand All @@ -20,31 +24,17 @@ struct VertexOutput {
#endif
};

[[group(2), binding(0)]]
var<uniform> mesh: Mesh;

[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.model * vec4<f32>(vertex.position, 1.0);
let world_position = mesh_model_position_to_world(vec4<f32>(vertex.position, 1.0));

var out: VertexOutput;
out.uv = vertex.uv;
out.world_position = world_position;
out.clip_position = view.view_proj * world_position;
out.world_normal = mat3x3<f32>(
mesh.inverse_transpose_model[0].xyz,
mesh.inverse_transpose_model[1].xyz,
mesh.inverse_transpose_model[2].xyz
) * vertex.normal;
out.clip_position = mesh_world_position_to_clip(world_position);
out.world_normal = mesh_model_normal_to_world(vertex.normal);
#ifdef VERTEX_TANGENTS
out.world_tangent = vec4<f32>(
mat3x3<f32>(
mesh.model[0].xyz,
mesh.model[1].xyz,
mesh.model[2].xyz
) * vertex.tangent.xyz,
vertex.tangent.w
);
out.world_tangent = mesh_model_tangent_to_world(vertex.tangent);
#endif
return out;
}
Expand All @@ -62,4 +52,4 @@ struct FragmentInput {
[[stage(fragment)]]
fn fragment(in: FragmentInput) -> [[location(0)]] vec4<f32> {
return vec4<f32>(1.0, 0.0, 1.0, 1.0);
}
}
2 changes: 2 additions & 0 deletions crates/bevy_pbr/src/render/mesh_bindings.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[[group(2), binding(0)]]
var<uniform> mesh: Mesh;
34 changes: 34 additions & 0 deletions crates/bevy_pbr/src/render/mesh_functions.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
fn mesh_model_position_to_world(vertex_position: vec4<f32>) -> vec4<f32> {
superdump marked this conversation as resolved.
Show resolved Hide resolved
return mesh.model * vertex_position;
}

fn mesh_world_position_to_clip(world_position: vec4<f32>) -> vec4<f32> {
return view.view_proj * world_position;
}

// NOTE: The intermediate world_position assignment is important
// for precision purposes when using the 'equals' depth comparison
// function.
fn mesh_model_position_to_clip(vertex_position: vec4<f32>) -> vec4<f32> {
let world_position = mesh_model_position_to_world(vertex_position);
return mesh_world_position_to_clip(world_position);
}

fn mesh_model_normal_to_world(vertex_normal: vec3<f32>) -> vec3<f32> {
return mat3x3<f32>(
mesh.inverse_transpose_model[0].xyz,
mesh.inverse_transpose_model[1].xyz,
mesh.inverse_transpose_model[2].xyz
) * vertex_normal;
}

fn mesh_model_tangent_to_world(vertex_tangent: vec4<f32>) -> vec4<f32> {
return vec4<f32>(
mat3x3<f32>(
mesh.model[0].xyz,
mesh.model[1].xyz,
mesh.model[2].xyz
) * vertex_tangent.xyz,
vertex_tangent.w
);
}
28 changes: 28 additions & 0 deletions crates/bevy_pbr/src/render/mesh_view_bindings.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(0), binding(1)]]
var<uniform> lights: Lights;
#ifdef NO_ARRAY_TEXTURES_SUPPORT
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube;
#else
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube_array;
#endif
[[group(0), binding(3)]]
var point_shadow_textures_sampler: sampler_comparison;
#ifdef NO_ARRAY_TEXTURES_SUPPORT
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d;
#else
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d_array;
#endif
[[group(0), binding(5)]]
var directional_shadow_textures_sampler: sampler_comparison;
[[group(0), binding(6)]]
var<uniform> point_lights: PointLights;
[[group(0), binding(7)]]
var<uniform> cluster_light_index_lists: ClusterLightIndexLists;
[[group(0), binding(8)]]
var<uniform> cluster_offsets_and_counts: ClusterOffsetsAndCounts;
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,3 @@ struct ClusterOffsetsAndCounts {
// and an 8-bit count of the number of lights in the low 8 bits
data: array<vec4<u32>, 1024u>;
};

[[group(0), binding(0)]]
var<uniform> view: View;
[[group(0), binding(1)]]
var<uniform> lights: Lights;
#ifdef NO_ARRAY_TEXTURES_SUPPORT
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube;
#else
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube_array;
#endif
[[group(0), binding(3)]]
var point_shadow_textures_sampler: sampler_comparison;
#ifdef NO_ARRAY_TEXTURES_SUPPORT
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d;
#else
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d_array;
#endif
[[group(0), binding(5)]]
var directional_shadow_textures_sampler: sampler_comparison;
[[group(0), binding(6)]]
var<uniform> point_lights: PointLights;
[[group(0), binding(7)]]
var<uniform> cluster_light_index_lists: ClusterLightIndexLists;
[[group(0), binding(8)]]
var<uniform> cluster_offsets_and_counts: ClusterOffsetsAndCounts;
Loading