Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move ShaderCache shader defs into PipelineCache #7903

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions crates/bevy_core_pipeline/src/blit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,18 @@ pub struct BlitPipelineKey {
impl SpecializedRenderPipeline for BlitPipeline {
type Key = BlitPipelineKey;

fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
fn specialize(
&self,
key: Self::Key,
shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("blit pipeline".into()),
layout: vec![self.texture_bind_group.clone()],
vertex: fullscreen_shader_vertex_state(),
vertex: fullscreen_shader_vertex_state(shader_defs.clone()),
fragment: Some(FragmentState {
shader: BLIT_SHADER_HANDLE.typed(),
shader_defs: vec![],
shader_defs,
entry_point: "fs_main".into(),
targets: vec![Some(ColorTargetState {
format: key.texture_format,
Expand Down
10 changes: 6 additions & 4 deletions crates/bevy_core_pipeline/src/bloom/downsampling_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ impl FromWorld for BloomDownsamplingPipeline {
impl SpecializedRenderPipeline for BloomDownsamplingPipeline {
type Key = BloomDownsamplingPipelineKeys;

fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
fn specialize(
&self,
key: Self::Key,
mut shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor {
let layout = vec![self.bind_group_layout.clone()];

let entry_point = if key.first_downsample {
Expand All @@ -108,8 +112,6 @@ impl SpecializedRenderPipeline for BloomDownsamplingPipeline {
"downsample".into()
};

let mut shader_defs = vec![];

if key.first_downsample {
shader_defs.push("FIRST_DOWNSAMPLE".into());
}
Expand All @@ -128,7 +130,7 @@ impl SpecializedRenderPipeline for BloomDownsamplingPipeline {
.into(),
),
layout,
vertex: fullscreen_shader_vertex_state(),
vertex: fullscreen_shader_vertex_state(shader_defs.clone()),
fragment: Some(FragmentState {
shader: BLOOM_SHADER_HANDLE.typed::<Shader>(),
shader_defs,
Expand Down
10 changes: 7 additions & 3 deletions crates/bevy_core_pipeline/src/bloom/upsampling_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ impl FromWorld for BloomUpsamplingPipeline {
impl SpecializedRenderPipeline for BloomUpsamplingPipeline {
type Key = BloomUpsamplingPipelineKeys;

fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
fn specialize(
&self,
key: Self::Key,
shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor {
let texture_format = if key.final_pipeline {
ViewTarget::TEXTURE_FORMAT_HDR
} else {
Expand Down Expand Up @@ -116,10 +120,10 @@ impl SpecializedRenderPipeline for BloomUpsamplingPipeline {
RenderPipelineDescriptor {
label: Some("bloom_upsampling_pipeline".into()),
layout: vec![self.bind_group_layout.clone()],
vertex: fullscreen_shader_vertex_state(),
vertex: fullscreen_shader_vertex_state(shader_defs.clone()),
fragment: Some(FragmentState {
shader: BLOOM_SHADER_HANDLE.typed::<Shader>(),
shader_defs: vec![],
shader_defs,
entry_point: "upsample".into(),
targets: vec![Some(ColorTargetState {
format: texture_format,
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_core_pipeline/src/fullscreen_vertex_shader/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use bevy_asset::HandleUntyped;
use bevy_reflect::TypeUuid;
use bevy_render::{prelude::Shader, render_resource::VertexState};
use bevy_render::{
prelude::Shader,
render_resource::{ShaderDefVal, VertexState},
};

pub const FULLSCREEN_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 7837534426033940724);
Expand All @@ -16,10 +19,10 @@ pub const FULLSCREEN_SHADER_HANDLE: HandleUntyped =
/// ```
/// from the vertex shader.
/// The draw call should render one triangle: `render_pass.draw(0..3, 0..1);`
pub fn fullscreen_shader_vertex_state() -> VertexState {
pub fn fullscreen_shader_vertex_state(shader_defs: Vec<ShaderDefVal>) -> VertexState {
VertexState {
shader: FULLSCREEN_SHADER_HANDLE.typed(),
shader_defs: Vec::new(),
shader_defs,
entry_point: "fullscreen_vertex_shader".into(),
buffers: Vec::new(),
}
Expand Down
19 changes: 13 additions & 6 deletions crates/bevy_core_pipeline/src/fxaa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,24 @@ pub struct FxaaPipelineKey {
impl SpecializedRenderPipeline for FxaaPipeline {
type Key = FxaaPipelineKey;

fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
fn specialize(
&self,
key: Self::Key,
shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
label: Some("fxaa".into()),
layout: vec![self.texture_bind_group.clone()],
vertex: fullscreen_shader_vertex_state(),
vertex: fullscreen_shader_vertex_state(shader_defs.clone()),
fragment: Some(FragmentState {
shader: FXAA_SHADER_HANDLE.typed(),
shader_defs: vec![
format!("EDGE_THRESH_{}", key.edge_threshold.get_str()).into(),
format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()).into(),
],
shader_defs: shader_defs
.into_iter()
.chain([
format!("EDGE_THRESH_{}", key.edge_threshold.get_str()).into(),
format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()).into(),
])
.collect(),
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {
format: key.texture_format,
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_core_pipeline/src/tonemapping/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,11 @@ pub struct TonemappingPipelineKey {
impl SpecializedRenderPipeline for TonemappingPipeline {
type Key = TonemappingPipelineKey;

fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let mut shader_defs = Vec::new();
fn specialize(
&self,
key: Self::Key,
mut shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor {
if let DebandDither::Enabled = key.deband_dither {
shader_defs.push("DEBAND_DITHER".into());
}
Expand All @@ -208,7 +211,7 @@ impl SpecializedRenderPipeline for TonemappingPipeline {
RenderPipelineDescriptor {
label: Some("tonemapping pipeline".into()),
layout: vec![self.texture_bind_group.clone()],
vertex: fullscreen_shader_vertex_state(),
vertex: fullscreen_shader_vertex_state(shader_defs.clone()),
fragment: Some(FragmentState {
shader: TONEMAPPING_SHADER_HANDLE.typed(),
shader_defs,
Expand Down
5 changes: 3 additions & 2 deletions crates/bevy_gizmos/src/pipeline_2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl SpecializedMeshPipeline for GizmoLinePipeline {
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let vertex_buffer_layout = layout.get_layout(&[
Mesh::ATTRIBUTE_POSITION.at_shader_location(0),
Expand All @@ -57,12 +58,12 @@ impl SpecializedMeshPipeline for GizmoLinePipeline {
vertex: VertexState {
shader: self.shader.clone_weak(),
entry_point: "vertex".into(),
shader_defs: vec![],
shader_defs: shader_defs.clone(),
buffers: vec![vertex_buffer_layout],
},
fragment: Some(FragmentState {
shader: self.shader.clone_weak(),
shader_defs: vec![],
shader_defs,
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {
format,
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_gizmos/src/pipeline_3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ impl SpecializedMeshPipeline for GizmoPipeline {
&self,
(depth_test, key): Self::Key,
layout: &MeshVertexBufferLayout,
mut shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut shader_defs = Vec::new();
shader_defs.push("GIZMO_LINES_3D".into());
shader_defs.push(ShaderDefVal::Int(
"MAX_DIRECTIONAL_LIGHTS".to_string(),
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_pbr/src/material.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use bevy_render::{
},
render_resource::{
AsBindGroup, AsBindGroupError, BindGroup, BindGroupLayout, OwnedBindingResource,
PipelineCache, RenderPipelineDescriptor, Shader, ShaderRef, SpecializedMeshPipeline,
SpecializedMeshPipelineError, SpecializedMeshPipelines,
PipelineCache, RenderPipelineDescriptor, Shader, ShaderDefVal, ShaderRef,
SpecializedMeshPipeline, SpecializedMeshPipelineError, SpecializedMeshPipelines,
},
renderer::RenderDevice,
texture::FallbackImage,
Expand Down Expand Up @@ -292,8 +292,11 @@ where
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut descriptor = self.mesh_pipeline.specialize(key.mesh_key, layout)?;
let mut descriptor = self
.mesh_pipeline
.specialize(key.mesh_key, layout, shader_defs)?;
if let Some(vertex_shader) = &self.vertex_shader {
descriptor.vertex.shader = vertex_shader.clone();
}
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/prepass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ where
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
mut shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut bind_group_layout = vec![self.view_layout.clone()];
let mut shader_defs = Vec::new();
let mut vertex_attributes = Vec::new();

// NOTE: Eventually, it would be nice to only add this when the shaders are overloaded by the Material.
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_pbr/src/render/mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,8 @@ impl SpecializedMeshPipeline for MeshPipeline {
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
mut shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut shader_defs = Vec::new();
let mut vertex_attributes = Vec::new();

if layout.contains(Mesh::ATTRIBUTE_POSITION) {
Expand Down
4 changes: 3 additions & 1 deletion crates/bevy_pbr/src/wireframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use bevy_ecs::{prelude::*, reflect::ReflectComponent};
use bevy_reflect::std_traits::ReflectDefault;
use bevy_reflect::{Reflect, TypeUuid};
use bevy_render::extract_component::{ExtractComponent, ExtractComponentPlugin};
use bevy_render::render_resource::ShaderDefVal;
use bevy_render::Render;
use bevy_render::{
extract_resource::{ExtractResource, ExtractResourcePlugin},
Expand Down Expand Up @@ -86,8 +87,9 @@ impl SpecializedMeshPipeline for WireframePipeline {
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
let mut descriptor = self.mesh_pipeline.specialize(key, layout, shader_defs)?;
descriptor.vertex.shader = self.shader.clone_weak();
descriptor.fragment.as_mut().unwrap().shader = self.shader.clone_weak();
descriptor.primitive.polygon_mode = PolygonMode::Line;
Expand Down
34 changes: 21 additions & 13 deletions crates/bevy_render/src/render_resource/pipeline_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,25 +193,13 @@ impl ShaderCache {
let module = match data.processed_shaders.entry(shader_defs.to_vec()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
let mut shader_defs = shader_defs.to_vec();
#[cfg(feature = "webgl")]
{
shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
}

shader_defs.push(ShaderDefVal::UInt(
String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
render_device.limits().max_storage_buffers_per_shader_stage,
));

debug!(
"processing shader {:?}, with shader defs {:?}",
handle, shader_defs
);
let processed = self.processor.process(
shader,
&shader_defs,
shader_defs,
&self.shaders,
&self.import_path_shaders,
)?;
Expand Down Expand Up @@ -311,6 +299,7 @@ impl ShaderCache {
}

type LayoutCacheKey = (Vec<BindGroupLayoutId>, Vec<PushConstantRange>);

#[derive(Default)]
struct LayoutCache {
layouts: HashMap<LayoutCacheKey, ErasedPipelineLayout>,
Expand Down Expand Up @@ -362,22 +351,41 @@ pub struct PipelineCache {
pipelines: Vec<CachedPipeline>,
waiting_pipelines: HashSet<CachedPipelineId>,
new_pipelines: Mutex<Vec<CachedPipeline>>,
base_shader_defs: Vec<ShaderDefVal>,
}

impl PipelineCache {
pub fn pipelines(&self) -> impl Iterator<Item = &CachedPipeline> {
self.pipelines.iter()
}

pub fn base_shader_defs(&self) -> Vec<ShaderDefVal> {
self.base_shader_defs.clone()
}

/// Create a new pipeline cache associated with the given render device.
pub fn new(device: RenderDevice) -> Self {
let mut base_shader_defs = Vec::new();

#[cfg(feature = "webgl")]
{
base_shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
base_shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
}

base_shader_defs.push(ShaderDefVal::UInt(
String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
device.limits().max_storage_buffers_per_shader_stage,
));

Self {
device,
layout_cache: default(),
shader_cache: default(),
waiting_pipelines: default(),
new_pipelines: default(),
pipelines: default(),
base_shader_defs,
}
}

Expand Down
21 changes: 16 additions & 5 deletions crates/bevy_render/src/render_resource/pipeline_specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ use bevy_utils::{
use std::{fmt::Debug, hash::Hash};
use thiserror::Error;

use super::ShaderDefVal;

pub trait SpecializedRenderPipeline {
type Key: Clone + Hash + PartialEq + Eq;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor;
fn specialize(
&self,
key: Self::Key,
shader_defs: Vec<ShaderDefVal>,
) -> RenderPipelineDescriptor;
}

#[derive(Resource)]
Expand All @@ -38,15 +44,19 @@ impl<S: SpecializedRenderPipeline> SpecializedRenderPipelines<S> {
key: S::Key,
) -> CachedRenderPipelineId {
*self.cache.entry(key.clone()).or_insert_with(|| {
let descriptor = specialize_pipeline.specialize(key);
let descriptor = specialize_pipeline.specialize(key, cache.base_shader_defs());
cache.queue_render_pipeline(descriptor)
})
}
}

pub trait SpecializedComputePipeline {
type Key: Clone + Hash + PartialEq + Eq;
fn specialize(&self, key: Self::Key) -> ComputePipelineDescriptor;
fn specialize(
&self,
key: Self::Key,
shader_defs: Vec<ShaderDefVal>,
) -> ComputePipelineDescriptor;
}

#[derive(Resource)]
Expand All @@ -68,7 +78,7 @@ impl<S: SpecializedComputePipeline> SpecializedComputePipelines<S> {
key: S::Key,
) -> CachedComputePipelineId {
*self.cache.entry(key.clone()).or_insert_with(|| {
let descriptor = specialize_pipeline.specialize(key);
let descriptor = specialize_pipeline.specialize(key, cache.base_shader_defs());
cache.queue_compute_pipeline(descriptor)
})
}
Expand All @@ -80,6 +90,7 @@ pub trait SpecializedMeshPipeline {
&self,
key: Self::Key,
layout: &MeshVertexBufferLayout,
shader_defs: Vec<ShaderDefVal>,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError>;
}

Expand Down Expand Up @@ -115,7 +126,7 @@ impl<S: SpecializedMeshPipeline> SpecializedMeshPipelines<S> {
Entry::Occupied(entry) => Ok(*entry.into_mut()),
Entry::Vacant(entry) => {
let descriptor = specialize_pipeline
.specialize(key.clone(), layout)
.specialize(key.clone(), layout, cache.base_shader_defs())
.map_err(|mut err| {
{
let SpecializedMeshPipelineError::MissingVertexAttribute(err) =
Expand Down
Loading