Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Device mismatch validation improvements #5841

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion wgpu-core/src/binding_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
hal_api::HalApi,
id::{BindGroupLayoutId, BufferId, SamplerId, TextureId, TextureViewId},
init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction},
resource::{Resource, ResourceInfo, ResourceType},
resource::{ParentDevice, Resource, ResourceInfo, ResourceType},
resource_log,
snatch::{SnatchGuard, Snatchable},
track::{BindGroupStates, UsageConflict},
Expand Down Expand Up @@ -518,6 +518,13 @@ impl<A: HalApi> Resource for BindGroupLayout<A> {
&self.label
}
}

impl<A: HalApi> ParentDevice<A> for BindGroupLayout<A> {
fn device(&self) -> &Arc<Device<A>> {
&self.device
}
}

impl<A: HalApi> BindGroupLayout<A> {
pub(crate) fn raw(&self) -> &A::BindGroupLayout {
self.raw.as_ref().unwrap()
Expand Down Expand Up @@ -751,6 +758,12 @@ impl<A: HalApi> Resource for PipelineLayout<A> {
}
}

impl<A: HalApi> ParentDevice<A> for PipelineLayout<A> {
fn device(&self) -> &Arc<Device<A>> {
&self.device
}
}

#[repr(C)]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -956,6 +969,12 @@ impl<A: HalApi> Resource for BindGroup<A> {
}
}

impl<A: HalApi> ParentDevice<A> for BindGroup<A> {
fn device(&self) -> &Arc<Device<A>> {
&self.device
}
}

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum GetBindGroupLayoutError {
Expand Down
8 changes: 7 additions & 1 deletion wgpu-core/src/command/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ use crate::{
id,
init_tracker::{BufferInitTrackerAction, MemoryInitKind, TextureInitTrackerAction},
pipeline::{PipelineFlags, RenderPipeline, VertexStep},
resource::{Buffer, Resource, ResourceInfo, ResourceType},
resource::{Buffer, ParentDevice, Resource, ResourceInfo, ResourceType},
resource_log,
snatch::SnatchGuard,
track::RenderBundleScope,
Expand Down Expand Up @@ -1104,6 +1104,12 @@ impl<A: HalApi> Resource for RenderBundle<A> {
}
}

impl<A: HalApi> ParentDevice<A> for RenderBundle<A> {
fn device(&self) -> &Arc<Device<A>> {
&self.device
}
}

/// A render bundle's current index buffer state.
///
/// [`RenderBundleEncoder::finish`] records the currently set index buffer here,
Expand Down
10 changes: 3 additions & 7 deletions wgpu-core/src/command/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
hal_api::HalApi,
id::{BufferId, CommandEncoderId, DeviceId, TextureId},
init_tracker::{MemoryInitKind, TextureInitRange},
resource::{Resource, Texture, TextureClearMode},
resource::{ParentDevice, Resource, Texture, TextureClearMode},
snatch::SnatchGuard,
track::{TextureSelector, TextureTracker},
};
Expand Down Expand Up @@ -104,9 +104,7 @@ impl Global {
.get(dst)
.map_err(|_| ClearError::InvalidBuffer(dst))?;

if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.same_device_as(cmd_buf.as_ref())?;

cmd_buf_data
.trackers
Expand Down Expand Up @@ -203,9 +201,7 @@ impl Global {
.get(dst)
.map_err(|_| ClearError::InvalidTexture(dst))?;

if dst_texture.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_texture.same_device_as(cmd_buf.as_ref())?;

// Check if subresource aspects are valid.
let clear_aspects =
Expand Down
54 changes: 13 additions & 41 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
hal_api::HalApi,
hal_label, id,
init_tracker::MemoryInitKind,
resource::{self, Resource},
resource::{self, ParentDevice, Resource},
snatch::SnatchGuard,
track::{Tracker, TrackerIndex, UsageConflict, UsageScope},
validation::{check_buffer_usage, MissingBufferUsageError},
Expand Down Expand Up @@ -53,11 +53,6 @@ pub struct ComputePass<A: HalApi> {
// Resource binding dedupe state.
current_bind_groups: BindGroupStateChange,
current_pipeline: StateChange<id::ComputePipelineId>,

/// The device that this pass is associated with.
///
/// Used for quick validation during recording.
device_id: id::DeviceId,
}

impl<A: HalApi> ComputePass<A> {
Expand All @@ -68,19 +63,13 @@ impl<A: HalApi> ComputePass<A> {
timestamp_writes,
} = desc;

let device_id = parent
.as_ref()
.map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id());

Self {
base: Some(BasePass::new(label)),
parent,
timestamp_writes,

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),

device_id,
}
}

Expand Down Expand Up @@ -361,11 +350,8 @@ impl Global {
);
};

if query_set.device.as_info().id() != cmd_buf.device.as_info().id() {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::WrongDeviceForTimestampWritesQuerySet),
);
if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) {
return (ComputePass::new(None, arc_desc), Some(e.into()));
}

Some(ArcComputePassTimestampWrites {
Expand Down Expand Up @@ -593,6 +579,8 @@ impl Global {
} => {
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());

bind_group.same_device_as(cmd_buf).map_pass_err(scope)?;

let max_bind_groups = cmd_buf.limits.max_bind_groups;
if index >= max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
Expand Down Expand Up @@ -658,6 +646,8 @@ impl Global {
let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

pipeline.same_device_as(cmd_buf).map_pass_err(scope)?;

state.pipeline = Some(pipeline_id);

let pipeline = tracker.compute_pipelines.insert_single(pipeline);
Expand Down Expand Up @@ -797,6 +787,8 @@ impl Global {
pipeline: state.pipeline,
};

buffer.same_device_as(cmd_buf).map_pass_err(scope)?;

state.is_ready().map_pass_err(scope)?;

device
Expand Down Expand Up @@ -890,6 +882,8 @@ impl Global {
} => {
let scope = PassErrorScope::WriteTimestamp;

query_set.same_device_as(cmd_buf).map_pass_err(scope)?;

device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
.map_pass_err(scope)?;
Expand All @@ -906,6 +900,8 @@ impl Global {
} => {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;

query_set.same_device_as(cmd_buf).map_pass_err(scope)?;

let query_set = tracker.query_sets.insert_single(query_set);

validate_and_begin_pipeline_statistics_query(
Expand Down Expand Up @@ -994,10 +990,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
.map_pass_err(scope)?;

if bind_group.device.as_info().id() != pass.device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

base.commands.push(ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets: offsets.len(),
Expand All @@ -1016,7 +1008,6 @@ impl Global {

let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

let device_id = pass.device_id;
let base = pass.base_mut(scope)?;
if redundant {
// Do redundant early-out **after** checking whether the pass is ended or not.
Expand All @@ -1031,10 +1022,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?;

if pipeline.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

Ok(())
Expand Down Expand Up @@ -1108,7 +1095,6 @@ impl Global {
indirect: true,
pipeline: pass.current_pipeline.last_state,
};
let device_id = pass.device_id;
let base = pass.base_mut(scope)?;

let buffer = hub
Expand All @@ -1118,10 +1104,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
.map_pass_err(scope)?;

if buffer.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

base.commands
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });

Expand Down Expand Up @@ -1185,7 +1167,6 @@ impl Global {
query_index: u32,
) -> Result<(), ComputePassError> {
let scope = PassErrorScope::WriteTimestamp;
let device_id = pass.device_id;
let base = pass.base_mut(scope)?;

let hub = A::hub(self);
Expand All @@ -1196,10 +1177,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;

if query_set.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
Expand All @@ -1215,7 +1192,6 @@ impl Global {
query_index: u32,
) -> Result<(), ComputePassError> {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
let device_id = pass.device_id;
let base = pass.base_mut(scope)?;

let hub = A::hub(self);
Expand All @@ -1226,10 +1202,6 @@ impl Global {
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;

if query_set.device.as_info().id() != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
Expand Down
11 changes: 7 additions & 4 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard;

use crate::init_tracker::BufferInitTrackerAction;
use crate::resource::{Resource, ResourceInfo, ResourceType};
use crate::resource::{ParentDevice, Resource, ResourceInfo, ResourceType};
use crate::track::{Tracker, UsageScope};
use crate::{api_log, global::Global, hal_api::HalApi, id, resource_log, Label};

Expand Down Expand Up @@ -541,6 +541,12 @@ impl<A: HalApi> Resource for CommandBuffer<A> {
}
}

impl<A: HalApi> ParentDevice<A> for CommandBuffer<A> {
fn device(&self) -> &Arc<Device<A>> {
&self.device
}
}

#[derive(Copy, Clone, Debug)]
pub struct BasePassRef<'a, C> {
pub label: Option<&'a str>,
Expand Down Expand Up @@ -633,11 +639,8 @@ pub enum CommandEncoderError {
Device(#[from] DeviceError),
#[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")]
Locked,

#[error("QuerySet provided for pass timestamp writes is invalid.")]
InvalidTimestampWritesQuerySetId,
#[error("QuerySet provided for pass timestamp writes that was created by a different device.")]
WrongDeviceForTimestampWritesQuerySet,
}

impl Global {
Expand Down
10 changes: 3 additions & 7 deletions wgpu-core/src/command/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
hal_api::HalApi,
id::{self, Id},
init_tracker::MemoryInitKind,
resource::{QuerySet, Resource},
resource::{ParentDevice, QuerySet},
storage::Storage,
Epoch, FastHashMap, Index,
};
Expand Down Expand Up @@ -405,19 +405,15 @@ impl Global {
.add_single(&*query_set_guard, query_set_id)
.ok_or(QueryError::InvalidQuerySet(query_set_id))?;

if query_set.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
query_set.same_device_as(cmd_buf.as_ref())?;

let (dst_buffer, dst_pending) = {
let buffer_guard = hub.buffers.read();
let dst_buffer = buffer_guard
.get(destination)
.map_err(|_| QueryError::InvalidBuffer(destination))?;

if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
dst_buffer.same_device_as(cmd_buf.as_ref())?;

tracker
.buffers
Expand Down
Loading
Loading