diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index 732c152dcf..fce70e27b4 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -8,7 +8,7 @@ use crate::{ hal_api::HalApi, id::{BindGroupLayoutId, BufferId, SamplerId, TextureId, TextureViewId}, init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction}, - resource::{Resource, ResourceInfo, ResourceType}, + resource::{ParentDevice, Resource, ResourceInfo, ResourceType}, resource_log, snatch::{SnatchGuard, Snatchable}, track::{BindGroupStates, UsageConflict}, @@ -518,6 +518,13 @@ impl Resource for BindGroupLayout { &self.label } } + +impl ParentDevice for BindGroupLayout { + fn device(&self) -> &Arc> { + &self.device + } +} + impl BindGroupLayout { pub(crate) fn raw(&self) -> &A::BindGroupLayout { self.raw.as_ref().unwrap() @@ -751,6 +758,12 @@ impl Resource for PipelineLayout { } } +impl ParentDevice for PipelineLayout { + fn device(&self) -> &Arc> { + &self.device + } +} + #[repr(C)] #[derive(Clone, Debug, Hash, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -956,6 +969,12 @@ impl Resource for BindGroup { } } +impl ParentDevice for BindGroup { + fn device(&self) -> &Arc> { + &self.device + } +} + #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum GetBindGroupLayoutError { diff --git a/wgpu-core/src/command/bundle.rs b/wgpu-core/src/command/bundle.rs index 8701a0cb81..1a7733e657 100644 --- a/wgpu-core/src/command/bundle.rs +++ b/wgpu-core/src/command/bundle.rs @@ -97,7 +97,7 @@ use crate::{ id, init_tracker::{BufferInitTrackerAction, MemoryInitKind, TextureInitTrackerAction}, pipeline::{PipelineFlags, RenderPipeline, VertexStep}, - resource::{Buffer, Resource, ResourceInfo, ResourceType}, + resource::{Buffer, ParentDevice, Resource, ResourceInfo, ResourceType}, resource_log, snatch::SnatchGuard, track::RenderBundleScope, @@ -1104,6 +1104,12 @@ impl Resource for RenderBundle { } } +impl ParentDevice for RenderBundle { + fn device(&self) -> &Arc> { + &self.device + } +} + /// A render bundle's current index buffer state. /// /// [`RenderBundleEncoder::finish`] records the currently set index buffer here, diff --git a/wgpu-core/src/command/clear.rs b/wgpu-core/src/command/clear.rs index 9ef0f24d47..80167d2c2f 100644 --- a/wgpu-core/src/command/clear.rs +++ b/wgpu-core/src/command/clear.rs @@ -11,7 +11,7 @@ use crate::{ hal_api::HalApi, id::{BufferId, CommandEncoderId, DeviceId, TextureId}, init_tracker::{MemoryInitKind, TextureInitRange}, - resource::{Resource, Texture, TextureClearMode}, + resource::{ParentDevice, Resource, Texture, TextureClearMode}, snatch::SnatchGuard, track::{TextureSelector, TextureTracker}, }; @@ -104,9 +104,7 @@ impl Global { .get(dst) .map_err(|_| ClearError::InvalidBuffer(dst))?; - if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.same_device_as(cmd_buf.as_ref())?; cmd_buf_data .trackers @@ -203,9 +201,7 @@ impl Global { .get(dst) .map_err(|_| ClearError::InvalidTexture(dst))?; - if dst_texture.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_texture.same_device_as(cmd_buf.as_ref())?; // Check if subresource aspects are valid. let clear_aspects = diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index acbff0a030..3e19caf513 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -15,7 +15,7 @@ use crate::{ hal_api::HalApi, hal_label, id, init_tracker::MemoryInitKind, - resource::{self, Resource}, + resource::{self, ParentDevice, Resource}, snatch::SnatchGuard, track::{Tracker, TrackerIndex, UsageConflict, UsageScope}, validation::{check_buffer_usage, MissingBufferUsageError}, @@ -53,11 +53,6 @@ pub struct ComputePass { // Resource binding dedupe state. current_bind_groups: BindGroupStateChange, current_pipeline: StateChange, - - /// The device that this pass is associated with. - /// - /// Used for quick validation during recording. - device_id: id::DeviceId, } impl ComputePass { @@ -68,10 +63,6 @@ impl ComputePass { timestamp_writes, } = desc; - let device_id = parent - .as_ref() - .map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id()); - Self { base: Some(BasePass::new(label)), parent, @@ -79,8 +70,6 @@ impl ComputePass { current_bind_groups: BindGroupStateChange::new(), current_pipeline: StateChange::new(), - - device_id, } } @@ -361,11 +350,8 @@ impl Global { ); }; - if query_set.device.as_info().id() != cmd_buf.device.as_info().id() { - return ( - ComputePass::new(None, arc_desc), - Some(CommandEncoderError::WrongDeviceForTimestampWritesQuerySet), - ); + if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) { + return (ComputePass::new(None, arc_desc), Some(e.into())); } Some(ArcComputePassTimestampWrites { @@ -593,6 +579,8 @@ impl Global { } => { let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); + bind_group.same_device_as(cmd_buf).map_pass_err(scope)?; + let max_bind_groups = cmd_buf.limits.max_bind_groups; if index >= max_bind_groups { return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { @@ -658,6 +646,8 @@ impl Global { let pipeline_id = pipeline.as_info().id(); let scope = PassErrorScope::SetPipelineCompute(pipeline_id); + pipeline.same_device_as(cmd_buf).map_pass_err(scope)?; + state.pipeline = Some(pipeline_id); let pipeline = tracker.compute_pipelines.insert_single(pipeline); @@ -797,6 +787,8 @@ impl Global { pipeline: state.pipeline, }; + buffer.same_device_as(cmd_buf).map_pass_err(scope)?; + state.is_ready().map_pass_err(scope)?; device @@ -890,6 +882,8 @@ impl Global { } => { let scope = PassErrorScope::WriteTimestamp; + query_set.same_device_as(cmd_buf).map_pass_err(scope)?; + device .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .map_pass_err(scope)?; @@ -906,6 +900,8 @@ impl Global { } => { let scope = PassErrorScope::BeginPipelineStatisticsQuery; + query_set.same_device_as(cmd_buf).map_pass_err(scope)?; + let query_set = tracker.query_sets.insert_single(query_set); validate_and_begin_pipeline_statistics_query( @@ -994,10 +990,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBindGroup(index)) .map_pass_err(scope)?; - if bind_group.device.as_info().id() != pass.device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), @@ -1016,7 +1008,6 @@ impl Global { let scope = PassErrorScope::SetPipelineCompute(pipeline_id); - let device_id = pass.device_id; let base = pass.base_mut(scope)?; if redundant { // Do redundant early-out **after** checking whether the pass is ended or not. @@ -1031,10 +1022,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id)) .map_pass_err(scope)?; - if pipeline.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::SetPipeline(pipeline)); Ok(()) @@ -1108,7 +1095,6 @@ impl Global { indirect: true, pipeline: pass.current_pipeline.last_state, }; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let buffer = hub @@ -1118,10 +1104,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id)) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); @@ -1185,7 +1167,6 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::WriteTimestamp; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1196,10 +1177,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; - if query_set.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands.push(ArcComputeCommand::WriteTimestamp { query_set, query_index, @@ -1215,7 +1192,6 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::BeginPipelineStatisticsQuery; - let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1226,10 +1202,6 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; - if query_set.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } - base.commands .push(ArcComputeCommand::BeginPipelineStatisticsQuery { query_set, diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 874e207a27..9c8bb71701 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -29,7 +29,7 @@ use crate::lock::{rank, Mutex}; use crate::snatch::SnatchGuard; use crate::init_tracker::BufferInitTrackerAction; -use crate::resource::{Resource, ResourceInfo, ResourceType}; +use crate::resource::{ParentDevice, Resource, ResourceInfo, ResourceType}; use crate::track::{Tracker, UsageScope}; use crate::{api_log, global::Global, hal_api::HalApi, id, resource_log, Label}; @@ -541,6 +541,12 @@ impl Resource for CommandBuffer { } } +impl ParentDevice for CommandBuffer { + fn device(&self) -> &Arc> { + &self.device + } +} + #[derive(Copy, Clone, Debug)] pub struct BasePassRef<'a, C> { pub label: Option<&'a str>, @@ -633,11 +639,8 @@ pub enum CommandEncoderError { Device(#[from] DeviceError), #[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")] Locked, - #[error("QuerySet provided for pass timestamp writes is invalid.")] InvalidTimestampWritesQuerySetId, - #[error("QuerySet provided for pass timestamp writes that was created by a different device.")] - WrongDeviceForTimestampWritesQuerySet, } impl Global { diff --git a/wgpu-core/src/command/query.rs b/wgpu-core/src/command/query.rs index bd4f9e991d..96831c1d16 100644 --- a/wgpu-core/src/command/query.rs +++ b/wgpu-core/src/command/query.rs @@ -9,7 +9,7 @@ use crate::{ hal_api::HalApi, id::{self, Id}, init_tracker::MemoryInitKind, - resource::{QuerySet, Resource}, + resource::{ParentDevice, QuerySet}, storage::Storage, Epoch, FastHashMap, Index, }; @@ -405,9 +405,7 @@ impl Global { .add_single(&*query_set_guard, query_set_id) .ok_or(QueryError::InvalidQuerySet(query_set_id))?; - if query_set.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + query_set.same_device_as(cmd_buf.as_ref())?; let (dst_buffer, dst_pending) = { let buffer_guard = hub.buffers.read(); @@ -415,9 +413,7 @@ impl Global { .get(destination) .map_err(|_| QueryError::InvalidBuffer(destination))?; - if dst_buffer.device.as_info().id() != cmd_buf.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.same_device_as(cmd_buf.as_ref())?; tracker .buffers diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index defd6a608b..5f0dde1db9 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -25,7 +25,7 @@ use crate::{ hal_label, id, init_tracker::{MemoryInitKind, TextureInitRange, TextureInitTrackerAction}, pipeline::{self, PipelineFlags}, - resource::{QuerySet, Texture, TextureView, TextureViewNotRenderableReason}, + resource::{ParentDevice, QuerySet, Texture, TextureView, TextureViewNotRenderableReason}, storage::Storage, track::{TextureSelector, Tracker, UsageConflict, UsageScope}, validation::{ @@ -1476,9 +1476,9 @@ impl Global { .ok_or(RenderCommandError::InvalidBindGroup(bind_group_id)) .map_pass_err(scope)?; - if bind_group.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + bind_group + .same_device_as(cmd_buf.as_ref()) + .map_pass_err(scope)?; bind_group .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits) @@ -1544,9 +1544,9 @@ impl Global { .ok_or(RenderCommandError::InvalidPipeline(pipeline_id)) .map_pass_err(scope)?; - if pipeline.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + pipeline + .same_device_as(cmd_buf.as_ref()) + .map_pass_err(scope)?; info.context .check_compatible( @@ -1673,9 +1673,9 @@ impl Global { .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDEX) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + buffer + .same_device_as(cmd_buf.as_ref()) + .map_pass_err(scope)?; check_buffer_usage(buffer_id, buffer.usage, BufferUsages::INDEX) .map_pass_err(scope)?; @@ -1726,9 +1726,9 @@ impl Global { .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::VERTEX) .map_pass_err(scope)?; - if buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + buffer + .same_device_as(cmd_buf.as_ref()) + .map_pass_err(scope)?; let max_vertex_buffers = device.limits.max_vertex_buffers; if slot >= max_vertex_buffers { @@ -2333,9 +2333,9 @@ impl Global { .ok_or(RenderCommandError::InvalidRenderBundle(bundle_id)) .map_pass_err(scope)?; - if bundle.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice).map_pass_err(scope); - } + bundle + .same_device_as(cmd_buf.as_ref()) + .map_pass_err(scope)?; info.context .check_compatible( diff --git a/wgpu-core/src/command/transfer.rs b/wgpu-core/src/command/transfer.rs index 6c70739009..d04510f836 100644 --- a/wgpu-core/src/command/transfer.rs +++ b/wgpu-core/src/command/transfer.rs @@ -13,7 +13,7 @@ use crate::{ has_copy_partial_init_tracker_coverage, MemoryInitKind, TextureInitRange, TextureInitTrackerAction, }, - resource::{Resource, Texture, TextureErrorDimension}, + resource::{ParentDevice, Resource, Texture, TextureErrorDimension}, snatch::SnatchGuard, track::{TextureSelector, Tracker}, }; @@ -602,9 +602,7 @@ impl Global { .get(source) .map_err(|_| TransferError::InvalidBuffer(source))?; - if src_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_buffer.same_device_as(cmd_buf.as_ref())?; cmd_buf_data .trackers @@ -628,9 +626,7 @@ impl Global { .get(destination) .map_err(|_| TransferError::InvalidBuffer(destination))?; - if dst_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.same_device_as(cmd_buf.as_ref())?; cmd_buf_data .trackers @@ -781,9 +777,7 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(destination.texture))?; - if dst_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_texture.same_device_as(cmd_buf.as_ref())?; let (hal_copy_size, array_layer_count) = validate_texture_copy_range( destination, @@ -816,9 +810,7 @@ impl Global { .get(source.buffer) .map_err(|_| TransferError::InvalidBuffer(source.buffer))?; - if src_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_buffer.same_device_as(cmd_buf.as_ref())?; tracker .buffers @@ -951,9 +943,7 @@ impl Global { .get(source.texture) .map_err(|_| TransferError::InvalidTexture(source.texture))?; - if src_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_texture.same_device_as(cmd_buf.as_ref())?; let (hal_copy_size, array_layer_count) = validate_texture_copy_range(source, &src_texture.desc, CopySide::Source, copy_size)?; @@ -1007,9 +997,7 @@ impl Global { .get(destination.buffer) .map_err(|_| TransferError::InvalidBuffer(destination.buffer))?; - if dst_buffer.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst_buffer.same_device_as(cmd_buf.as_ref())?; tracker .buffers @@ -1139,12 +1127,8 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(source.texture))?; - if src_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } - if dst_texture.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + src_texture.same_device_as(cmd_buf.as_ref())?; + dst_texture.same_device_as(cmd_buf.as_ref())?; // src and dst texture format must be copy-compatible // https://gpuweb.github.io/gpuweb/#copy-compatible diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index d6133f4383..aca648a78b 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -15,7 +15,6 @@ use crate::{ pipeline, present, resource::{ self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError, - Resource, }, validation::check_buffer_usage, Label, LabelHelpers as _, @@ -1125,10 +1124,6 @@ impl Global { Err(..) => break 'error binding_model::CreateBindGroupError::InvalidLayout, }; - if bind_group_layout.device.as_info().id() != device.as_info().id() { - break 'error DeviceError::WrongDevice.into(); - } - let bind_group = match device.create_bind_group(&bind_group_layout, desc, hub) { Ok(bind_group) => bind_group, Err(e) => break 'error e, diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index e52f611f8b..f44764f94b 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -3,7 +3,9 @@ use crate::{ hal_api::HalApi, hub::Hub, id::{BindGroupLayoutId, PipelineLayoutId}, - resource::{Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation}, + resource::{ + Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, ResourceErrorIdent, + }, snatch::SnatchGuard, Label, DOWNLEVEL_ERROR_MESSAGE, }; @@ -382,6 +384,30 @@ fn map_buffer( #[error("Device is invalid")] pub struct InvalidDevice; +#[derive(Clone, Debug)] +pub struct DeviceMismatch { + pub(super) res: ResourceErrorIdent, + pub(super) res_device: ResourceErrorIdent, + pub(super) target: Option, + pub(super) target_device: ResourceErrorIdent, +} + +impl std::fmt::Display for DeviceMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "{} of {} doesn't match {}", + self.res_device, self.res, self.target_device + )?; + if let Some(target) = self.target.as_ref() { + write!(f, " of {target}")?; + } + Ok(()) + } +} + +impl std::error::Error for DeviceMismatch {} + #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum DeviceError { @@ -395,8 +421,8 @@ pub enum DeviceError { ResourceCreationFailed, #[error("QueueId is invalid")] InvalidQueueId, - #[error("Attempt to use a resource with a different device from the one that created it")] - WrongDevice, + #[error(transparent)] + DeviceMismatch(#[from] Box), } impl From for DeviceError { diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs index 33af483882..5724b9447e 100644 --- a/wgpu-core/src/device/queue.rs +++ b/wgpu-core/src/device/queue.rs @@ -12,12 +12,12 @@ use crate::{ global::Global, hal_api::HalApi, hal_label, - id::{self, DeviceId, QueueId}, + id::{self, QueueId}, init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange}, lock::{rank, Mutex, RwLockWriteGuard}, resource::{ - Buffer, BufferAccessError, BufferMapState, DestroyedBuffer, DestroyedTexture, Resource, - ResourceInfo, ResourceType, StagingBuffer, Texture, TextureInner, + Buffer, BufferAccessError, BufferMapState, DestroyedBuffer, DestroyedTexture, ParentDevice, + Resource, ResourceInfo, ResourceType, StagingBuffer, Texture, TextureInner, }, resource_log, track, FastHashMap, SubmissionIndex, }; @@ -53,6 +53,12 @@ impl Resource for Queue { } } +impl ParentDevice for Queue { + fn device(&self) -> &Arc> { + self.device.as_ref().unwrap() + } +} + impl Drop for Queue { fn drop(&mut self) { let queue = self.raw.take().unwrap(); @@ -352,15 +358,6 @@ pub struct InvalidQueue; #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum QueueWriteError { - #[error( - "Device of queue ({:?}) does not match device of write recipient ({:?})", - queue_device_id, - target_device_id - )] - DeviceMismatch { - queue_device_id: DeviceId, - target_device_id: DeviceId, - }, #[error(transparent)] Queue(#[from] DeviceError), #[error(transparent)] @@ -405,13 +402,10 @@ impl Global { let hub = A::hub(self); - let buffer_device_id = hub + let buffer = hub .buffers .get(buffer_id) - .map_err(|_| TransferError::InvalidBuffer(buffer_id))? - .device - .as_info() - .id(); + .map_err(|_| TransferError::InvalidBuffer(buffer_id))?; let queue = hub .queues @@ -420,15 +414,7 @@ impl Global { let device = queue.device.as_ref().unwrap(); - { - let queue_device_id = device.as_info().id(); - if buffer_device_id != queue_device_id { - return Err(QueueWriteError::DeviceMismatch { - queue_device_id, - target_device_id: buffer_device_id, - }); - } - } + buffer.same_device_as(queue.as_ref())?; let data_size = data.len() as wgt::BufferAddress; @@ -469,6 +455,7 @@ impl Global { } let result = self.queue_write_staging_buffer_impl( + &queue, device, pending_writes, &staging_buffer, @@ -543,6 +530,7 @@ impl Global { } let result = self.queue_write_staging_buffer_impl( + &queue, device, pending_writes, &staging_buffer, @@ -607,7 +595,8 @@ impl Global { fn queue_write_staging_buffer_impl( &self, - device: &Device, + queue: &Arc>, + device: &Arc>, pending_writes: &mut PendingWrites, staging_buffer: &StagingBuffer, buffer_id: id::BufferId, @@ -632,9 +621,7 @@ impl Global { .get(&snatch_guard) .ok_or(TransferError::InvalidBuffer(buffer_id))?; - if dst.device.as_info().id() != device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + dst.same_device_as(queue.as_ref())?; let src_buffer_size = staging_buffer.size; self.queue_validate_write_buffer_impl(&dst, buffer_id, buffer_offset, src_buffer_size)?; @@ -717,9 +704,7 @@ impl Global { .get(destination.texture) .map_err(|_| TransferError::InvalidTexture(destination.texture))?; - if dst.device.as_info().id().into_queue_id() != queue_id { - return Err(DeviceError::WrongDevice.into()); - } + dst.same_device_as(queue.as_ref())?; if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) { return Err( @@ -1200,9 +1185,7 @@ impl Global { Err(_) => continue, }; - if cmdbuf.device.as_info().id().into_queue_id() != queue_id { - return Err(DeviceError::WrongDevice.into()); - } + cmdbuf.same_device_as(queue.as_ref())?; #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index f4702bc915..d44d98b962 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -24,8 +24,8 @@ use crate::{ pool::ResourcePool, registry::Registry, resource::{ - self, Buffer, QuerySet, Resource, ResourceInfo, ResourceType, Sampler, Texture, - TextureView, TextureViewNotRenderableReason, + self, Buffer, ParentDevice, QuerySet, Resource, ResourceInfo, ResourceType, Sampler, + Texture, TextureView, TextureViewNotRenderableReason, }, resource_log, snatch::{SnatchGuard, SnatchLock, Snatchable}, @@ -1837,6 +1837,7 @@ impl Device { } pub(crate) fn create_buffer_binding<'a>( + self: &Arc, bb: &binding_model::BufferBinding, binding: u32, decl: &wgt::BindGroupLayoutEntry, @@ -1846,7 +1847,6 @@ impl Device { used: &mut BindGroupStates, storage: &'a Storage>, limits: &wgt::Limits, - device_id: id::Id, snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1898,9 +1898,7 @@ impl Device { .add_single(storage, bb.buffer_id, internal_use) .ok_or(Error::InvalidBuffer(bb.buffer_id))?; - if buffer.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice.into()); - } + buffer.same_device(self)?; check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?; let raw_buffer = buffer @@ -1981,10 +1979,10 @@ impl Device { } fn create_sampler_binding<'a>( + self: &Arc, used: &BindGroupStates, storage: &'a Storage>, id: id::Id, - device_id: id::Id, ) -> Result<&'a Sampler, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1993,9 +1991,7 @@ impl Device { .add_single(storage, id) .ok_or(Error::InvalidSampler(id))?; - if sampler.device.as_info().id() != device_id { - return Err(DeviceError::WrongDevice.into()); - } + sampler.same_device(self)?; Ok(sampler) } @@ -2017,9 +2013,7 @@ impl Device { .add_single(storage, id) .ok_or(Error::InvalidTextureView(id))?; - if view.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + view.same_device(self)?; let (pub_usage, internal_use) = self.texture_use_parameters( binding, @@ -2038,9 +2032,7 @@ impl Device { texture_id, ))?; - if texture.device.as_info().id() != view.device.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + texture.same_device_as(view)?; check_texture_usage(texture.desc.usage, pub_usage)?; @@ -2073,6 +2065,9 @@ impl Device { hub: &Hub, ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::{BindingResource as Br, CreateBindGroupError as Error}; + + layout.same_device(self)?; + { // Check that the number of entries in the descriptor matches // the number of entries in the layout. @@ -2113,7 +2108,7 @@ impl Device { .ok_or(Error::MissingBindingDeclaration(binding))?; let (res_index, count) = match entry.resource { Br::Buffer(ref bb) => { - let bb = Self::create_buffer_binding( + let bb = self.create_buffer_binding( bb, binding, decl, @@ -2123,7 +2118,6 @@ impl Device { &mut used, &*buffer_guard, &self.limits, - self.as_info().id(), &snatch_guard, )?; @@ -2137,7 +2131,7 @@ impl Device { let res_index = hal_buffers.len(); for bb in bindings_array.iter() { - let bb = Self::create_buffer_binding( + let bb = self.create_buffer_binding( bb, binding, decl, @@ -2147,7 +2141,6 @@ impl Device { &mut used, &*buffer_guard, &self.limits, - self.as_info().id(), &snatch_guard, )?; hal_buffers.push(bb); @@ -2156,12 +2149,7 @@ impl Device { } Br::Sampler(id) => match decl.ty { wgt::BindingType::Sampler(ty) => { - let sampler = Self::create_sampler_binding( - &used, - &sampler_guard, - id, - self.as_info().id(), - )?; + let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; let (allowed_filtering, allowed_comparison) = match ty { wgt::SamplerBindingType::Filtering => (None, false), @@ -2203,12 +2191,7 @@ impl Device { let res_index = hal_samplers.len(); for &id in bindings_array.iter() { - let sampler = Self::create_sampler_binding( - &used, - &sampler_guard, - id, - self.as_info().id(), - )?; + let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; hal_samplers.push(sampler.raw()); } @@ -2537,9 +2520,7 @@ impl Device { // Validate total resource counts and check for a matching device for bgl in &bind_group_layouts { - if bgl.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + bgl.same_device(self)?; count_validator.merge(&bgl.binding_count_validator); } @@ -2647,9 +2628,7 @@ impl Device { .get(desc.stage.module) .map_err(|_| validation::StageError::InvalidModule)?; - if shader_module.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + shader_module.same_device(self)?; // Get the pipeline layout from the desc if it is provided. let pipeline_layout = match desc.layout { @@ -2659,9 +2638,7 @@ impl Device { .get(pipeline_layout_id) .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; - if pipeline_layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + pipeline_layout.same_device(self)?; Some(pipeline_layout) } @@ -2723,9 +2700,7 @@ impl Device { break 'cache None; }; - if cache.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + cache.same_device(self)?; Some(cache) }; @@ -3103,9 +3078,7 @@ impl Device { .get(pipeline_layout_id) .map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?; - if pipeline_layout.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + pipeline_layout.same_device(self)?; Some(pipeline_layout) } @@ -3140,9 +3113,7 @@ impl Device { error: validation::StageError::InvalidModule, } })?; - if vertex_shader_module.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + vertex_shader_module.same_device(self)?; let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; @@ -3334,9 +3305,7 @@ impl Device { break 'cache None; }; - if cache.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + cache.same_device(self)?; Some(cache) }; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index f3e7dbacb2..78cf3d567c 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -7,7 +7,7 @@ use crate::{ device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures, RenderPassContext}, hal_api::HalApi, id::{PipelineCacheId, PipelineLayoutId, ShaderModuleId}, - resource::{Resource, ResourceInfo, ResourceType}, + resource::{ParentDevice, Resource, ResourceInfo, ResourceType}, resource_log, validation, Label, }; use arrayvec::ArrayVec; @@ -90,6 +90,12 @@ impl Resource for ShaderModule { } } +impl ParentDevice for ShaderModule { + fn device(&self) -> &Arc> { + &self.device + } +} + impl ShaderModule { pub(crate) fn raw(&self) -> &A::ShaderModule { self.raw.as_ref().unwrap() @@ -258,6 +264,12 @@ impl Resource for ComputePipeline { } } +impl ParentDevice for ComputePipeline { + fn device(&self) -> &Arc> { + &self.device + } +} + impl ComputePipeline { pub(crate) fn raw(&self) -> &A::ComputePipeline { self.raw.as_ref().unwrap() @@ -326,6 +338,12 @@ impl Resource for PipelineCache { } } +impl ParentDevice for PipelineCache { + fn device(&self) -> &Arc> { + &self.device + } +} + /// Describes how the vertex buffer is interpreted. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -585,6 +603,12 @@ impl Resource for RenderPipeline { } } +impl ParentDevice for RenderPipeline { + fn device(&self) -> &Arc> { + &self.device + } +} + impl RenderPipeline { pub(crate) fn raw(&self) -> &A::RenderPipeline { self.raw.as_ref().unwrap() diff --git a/wgpu-core/src/pipeline_cache.rs b/wgpu-core/src/pipeline_cache.rs index d098cdafcf..b88fc21dda 100644 --- a/wgpu-core/src/pipeline_cache.rs +++ b/wgpu-core/src/pipeline_cache.rs @@ -16,7 +16,7 @@ pub enum PipelineCacheValidationError { #[error("The pipeline cacha data was out of date and so cannot be safely used")] Outdated, #[error("The cache data was created for a different device")] - WrongDevice, + DeviceMismatch, #[error("Pipeline cacha data was created for a future version of wgpu")] Unsupported, } @@ -26,7 +26,7 @@ impl PipelineCacheValidationError { /// That is, is there a mistake in user code interacting with the cache pub fn was_avoidable(&self) -> bool { match self { - PipelineCacheValidationError::WrongDevice => true, + PipelineCacheValidationError::DeviceMismatch => true, PipelineCacheValidationError::Truncated | PipelineCacheValidationError::Unsupported | PipelineCacheValidationError::Extended @@ -57,10 +57,10 @@ pub fn validate_pipeline_cache<'d>( return Err(PipelineCacheValidationError::Outdated); } if header.backend != adapter.backend as u8 { - return Err(PipelineCacheValidationError::WrongDevice); + return Err(PipelineCacheValidationError::DeviceMismatch); } if header.adapter_key != adapter_key { - return Err(PipelineCacheValidationError::WrongDevice); + return Err(PipelineCacheValidationError::DeviceMismatch); } if header.validation_key != validation_key { // If the validation key is wrong, that means that this device has changed @@ -420,7 +420,7 @@ mod tests { ]; let cache = cache.into_iter().flatten().collect::>(); let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); - assert_eq!(validation_result, Err(E::WrongDevice)); + assert_eq!(validation_result, Err(E::DeviceMismatch)); } #[test] fn wrong_adapter() { @@ -436,7 +436,7 @@ mod tests { ]; let cache = cache.into_iter().flatten().collect::>(); let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); - assert_eq!(validation_result, Err(E::WrongDevice)); + assert_eq!(validation_result, Err(E::DeviceMismatch)); } #[test] fn wrong_validation() { diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 9ae275615a..f45095d6df 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -3,8 +3,8 @@ use crate::device::trace; use crate::{ binding_model::BindGroup, device::{ - queue, resource::DeferredDestroy, BufferMapPendingClosure, Device, DeviceError, HostMap, - MissingDownlevelFlags, MissingFeatures, + queue, resource::DeferredDestroy, BufferMapPendingClosure, Device, DeviceError, + DeviceMismatch, HostMap, MissingDownlevelFlags, MissingFeatures, }, global::Global, hal_api::HalApi, @@ -143,6 +143,48 @@ impl ResourceInfo { } } +#[derive(Clone, Debug)] +pub struct ResourceErrorIdent { + r#type: ResourceType, + label: String, +} + +impl std::fmt::Display for ResourceErrorIdent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "{} with '{}' label", self.r#type, self.label) + } +} + +pub(crate) trait ParentDevice: Resource { + fn device(&self) -> &Arc>; + + fn same_device_as>(&self, other: &O) -> Result<(), DeviceError> { + Arc::ptr_eq(self.device(), other.device()) + .then_some(()) + .ok_or_else(|| { + DeviceError::DeviceMismatch(Box::new(DeviceMismatch { + res: self.error_ident(), + res_device: self.device().error_ident(), + target: Some(other.error_ident()), + target_device: other.device().error_ident(), + })) + }) + } + + fn same_device(&self, device: &Arc>) -> Result<(), DeviceError> { + Arc::ptr_eq(self.device(), device) + .then_some(()) + .ok_or_else(|| { + DeviceError::DeviceMismatch(Box::new(DeviceMismatch { + res: self.error_ident(), + res_device: self.device().error_ident(), + target: None, + target_device: device.error_ident(), + })) + }) + } +} + pub(crate) type ResourceType = &'static str; pub(crate) trait Resource: 'static + Sized + WasmNotSendSync { @@ -169,6 +211,12 @@ pub(crate) trait Resource: 'static + Sized + WasmNotSendSync { fn is_equal(&self, other: &Self) -> bool { self.as_info().id().unzip() == other.as_info().id().unzip() } + fn error_ident(&self) -> ResourceErrorIdent { + ResourceErrorIdent { + r#type: Self::TYPE, + label: self.label().to_owned(), + } + } } /// The status code provided to the buffer mapping callback. @@ -627,6 +675,12 @@ impl Resource for Buffer { } } +impl ParentDevice for Buffer { + fn device(&self) -> &Arc> { + &self.device + } +} + /// A buffer that has been marked as destroyed and is staged for actual deletion soon. #[derive(Debug)] pub struct DestroyedBuffer { @@ -731,6 +785,12 @@ impl Resource for StagingBuffer { } } +impl ParentDevice for StagingBuffer { + fn device(&self) -> &Arc> { + &self.device + } +} + pub type TextureDescriptor<'a> = wgt::TextureDescriptor, Vec>; #[derive(Debug)] @@ -1245,6 +1305,12 @@ impl Resource for Texture { } } +impl ParentDevice for Texture { + fn device(&self) -> &Arc> { + &self.device + } +} + impl Borrow for Texture { fn borrow(&self) -> &TextureSelector { &self.full_range @@ -1410,6 +1476,12 @@ impl Resource for TextureView { } } +impl ParentDevice for TextureView { + fn device(&self) -> &Arc> { + &self.device + } +} + /// Describes a [`Sampler`] #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -1531,6 +1603,12 @@ impl Resource for Sampler { } } +impl ParentDevice for Sampler { + fn device(&self) -> &Arc> { + &self.device + } +} + #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum CreateQuerySetError { @@ -1571,6 +1649,12 @@ impl Drop for QuerySet { } } +impl ParentDevice for QuerySet { + fn device(&self) -> &Arc> { + &self.device + } +} + impl Resource for QuerySet { const TYPE: ResourceType = "QuerySet";