diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index b384d5706c..5a1c90326a 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -303,6 +303,14 @@ impl NumericType { } } + const fn scalar(self) -> crate::Scalar { + match self { + NumericType::Scalar(scalar) + | NumericType::Vector { scalar, .. } + | NumericType::Matrix { scalar, .. } => scalar, + } + } + const fn with_scalar(self, scalar: crate::Scalar) -> Self { match self { NumericType::Scalar(_) => NumericType::Scalar(scalar), diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 2960e82f6e..097791ef33 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -223,6 +223,10 @@ impl Writer { self.get_type_id(lookup_ty) } + pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word { + self.get_type_id(LookupType::Local(local)) + } + pub(super) fn get_pointer_id( &mut self, handle: Handle, @@ -320,199 +324,27 @@ impl Writer { for (expr_handle, expr) in ir_function.expressions.iter() { match *expr { crate::Expression::Binary { op, left, right } => { - let expr_ty = info[expr_handle].ty.inner_with(&ir_module.types); - let Some(numeric_type) = NumericType::from_inner(expr_ty) else { - continue; - }; - match (op, expr_ty.scalar()) { - // Division and modulo are undefined behaviour when the dividend is the - // minimum representable value and the divisor is negative one, or when - // the divisor is zero. These wrapped functions override the divisor to - // one in these cases, matching the WGSL spec. - ( - crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo, - Some( - scalar @ crate::Scalar { - kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, - .. - }, - ), - ) => { - let return_type_id = self.get_expression_type_id(&info[expr_handle].ty); - let left_type_id = self.get_expression_type_id(&info[left].ty); - let right_type_id = self.get_expression_type_id(&info[right].ty); - let wrapped = WrappedFunction::BinaryOp { - op, - left_type_id, - right_type_id, - }; - let function_id = *match self.wrapped_functions.entry(wrapped) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => e.insert(self.id_gen.next()), - }; - if self.flags.contains(WriterFlags::DEBUG) { - let function_name = match op { - crate::BinaryOperator::Divide => "naga_div", - crate::BinaryOperator::Modulo => "naga_mod", - _ => unreachable!(), - }; - self.debugs - .push(Instruction::name(function_id, function_name)); + let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types); + if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) { + match (op, expr_ty.scalar().kind) { + // Division and modulo are undefined behaviour when the + // dividend is the minimum representable value and the divisor + // is negative one, or when the divisor is zero. These wrapped + // functions override the divisor to one in these cases, + // matching the WGSL spec. + ( + crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo, + crate::ScalarKind::Sint | crate::ScalarKind::Uint, + ) => { + self.write_wrapped_binary_op( + op, + expr_ty, + &info[left].ty, + &info[right].ty, + )?; } - let mut function = Function::default(); - - let function_type_id = self.get_function_type(LookupFunctionType { - parameter_type_ids: vec![left_type_id, right_type_id], - return_type_id, - }); - function.signature = Some(Instruction::function( - return_type_id, - function_id, - spirv::FunctionControl::empty(), - function_type_id, - )); - - let lhs_id = self.id_gen.next(); - let rhs_id = self.id_gen.next(); - if self.flags.contains(WriterFlags::DEBUG) { - self.debugs.push(Instruction::name(lhs_id, "lhs")); - self.debugs.push(Instruction::name(rhs_id, "rhs")); - } - let left_par = Instruction::function_parameter(left_type_id, lhs_id); - let right_par = Instruction::function_parameter(right_type_id, rhs_id); - for instruction in [left_par, right_par] { - function.parameters.push(FunctionArgument { - instruction, - handle_id: 0, - }); - } - - let label_id = self.id_gen.next(); - let mut block = Block::new(label_id); - - let bool_type = numeric_type.with_scalar(crate::Scalar::BOOL); - let bool_type_id = - self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type))); - - let maybe_splat_const = |writer: &mut Self, const_id| match numeric_type - { - NumericType::Scalar(_) => const_id, - NumericType::Vector { size, .. } => { - let constituent_ids = [const_id; crate::VectorSize::MAX]; - writer.get_constant_composite( - LookupType::Local(LocalType::Numeric(numeric_type)), - &constituent_ids[..size as usize], - ) - } - NumericType::Matrix { .. } => unreachable!(), - }; - - let const_zero_id = self.get_constant_scalar_with(0, scalar)?; - let composite_zero_id = maybe_splat_const(self, const_zero_id); - let rhs_eq_zero_id = self.id_gen.next(); - block.body.push(Instruction::binary( - spirv::Op::IEqual, - bool_type_id, - rhs_eq_zero_id, - rhs_id, - composite_zero_id, - )); - let divisor_selector_id = match scalar.kind { - crate::ScalarKind::Sint => { - let (const_min_id, const_neg_one_id) = match scalar.width { - 4 => Ok(( - self.get_constant_scalar(crate::Literal::I32(i32::MIN)), - self.get_constant_scalar(crate::Literal::I32(-1i32)), - )), - 8 => Ok(( - self.get_constant_scalar(crate::Literal::I64(i64::MIN)), - self.get_constant_scalar(crate::Literal::I64(-1i64)), - )), - _ => Err(Error::Validation("Unexpected scalar width")), - }?; - let composite_min_id = maybe_splat_const(self, const_min_id); - let composite_neg_one_id = - maybe_splat_const(self, const_neg_one_id); - - let lhs_eq_int_min_id = self.id_gen.next(); - block.body.push(Instruction::binary( - spirv::Op::IEqual, - bool_type_id, - lhs_eq_int_min_id, - lhs_id, - composite_min_id, - )); - let rhs_eq_neg_one_id = self.id_gen.next(); - block.body.push(Instruction::binary( - spirv::Op::IEqual, - bool_type_id, - rhs_eq_neg_one_id, - rhs_id, - composite_neg_one_id, - )); - let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next(); - block.body.push(Instruction::binary( - spirv::Op::LogicalAnd, - bool_type_id, - lhs_eq_int_min_and_rhs_eq_neg_one_id, - lhs_eq_int_min_id, - rhs_eq_neg_one_id, - )); - let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = - self.id_gen.next(); - block.body.push(Instruction::binary( - spirv::Op::LogicalOr, - bool_type_id, - rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id, - rhs_eq_zero_id, - lhs_eq_int_min_and_rhs_eq_neg_one_id, - )); - rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id - } - crate::ScalarKind::Uint => rhs_eq_zero_id, - _ => unreachable!(), - }; - - let const_one_id = self.get_constant_scalar_with(1, scalar)?; - let composite_one_id = maybe_splat_const(self, const_one_id); - let divisor_id = self.id_gen.next(); - block.body.push(Instruction::select( - right_type_id, - divisor_id, - divisor_selector_id, - composite_one_id, - rhs_id, - )); - let op = match (op, scalar.kind) { - (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => { - spirv::Op::SDiv - } - (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => { - spirv::Op::UDiv - } - (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => { - spirv::Op::SRem - } - (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => { - spirv::Op::UMod - } - _ => unreachable!(), - }; - let return_id = self.id_gen.next(); - block.body.push(Instruction::binary( - op, - return_type_id, - return_id, - lhs_id, - divisor_id, - )); - - function.consume(block, Instruction::return_value(return_id)); - function.to_words(&mut self.logical_layout.function_definitions); - Instruction::function_end() - .to_words(&mut self.logical_layout.function_definitions); + _ => {} } - _ => {} } } _ => {} @@ -522,6 +354,200 @@ impl Writer { Ok(()) } + /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics. + /// + /// Define a function that performs an integer division or modulo operation, + /// except that using a divisor of zero or causing signed overflow with a + /// divisor of -1 returns the numerator unchanged, rather than exhibiting + /// undefined behavior. + /// + /// Store the generated function's id in the [`wrapped_functions`] table. + /// + /// The operator `op` must be either [`Divide`] or [`Modulo`]. + /// + /// # Panics + /// + /// The `return_type`, `left_type` or `right_type` arguments must all be + /// integer scalars or vectors. If not, this function panics. + /// + /// [`wrapped_functions`]: Writer::wrapped_functions + /// [`Divide`]: crate::BinaryOperator::Divide + /// [`Modulo`]: crate::BinaryOperator::Modulo + fn write_wrapped_binary_op( + &mut self, + op: crate::BinaryOperator, + return_type: NumericType, + left_type: &TypeResolution, + right_type: &TypeResolution, + ) -> Result<(), Error> { + let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type)); + let left_type_id = self.get_expression_type_id(left_type); + let right_type_id = self.get_expression_type_id(right_type); + + // Check if we've already emitted this function. + let wrapped = WrappedFunction::BinaryOp { + op, + left_type_id, + right_type_id, + }; + let function_id = match self.wrapped_functions.entry(wrapped) { + Entry::Occupied(_) => return Ok(()), + Entry::Vacant(e) => *e.insert(self.id_gen.next()), + }; + + let scalar = return_type.scalar(); + + if self.flags.contains(WriterFlags::DEBUG) { + let function_name = match op { + crate::BinaryOperator::Divide => "naga_div", + crate::BinaryOperator::Modulo => "naga_mod", + _ => unreachable!(), + }; + self.debugs + .push(Instruction::name(function_id, function_name)); + } + let mut function = Function::default(); + + let function_type_id = self.get_function_type(LookupFunctionType { + parameter_type_ids: vec![left_type_id, right_type_id], + return_type_id, + }); + function.signature = Some(Instruction::function( + return_type_id, + function_id, + spirv::FunctionControl::empty(), + function_type_id, + )); + + let lhs_id = self.id_gen.next(); + let rhs_id = self.id_gen.next(); + if self.flags.contains(WriterFlags::DEBUG) { + self.debugs.push(Instruction::name(lhs_id, "lhs")); + self.debugs.push(Instruction::name(rhs_id, "rhs")); + } + let left_par = Instruction::function_parameter(left_type_id, lhs_id); + let right_par = Instruction::function_parameter(right_type_id, rhs_id); + for instruction in [left_par, right_par] { + function.parameters.push(FunctionArgument { + instruction, + handle_id: 0, + }); + } + + let label_id = self.id_gen.next(); + let mut block = Block::new(label_id); + + let bool_type = return_type.with_scalar(crate::Scalar::BOOL); + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type))); + + let maybe_splat_const = |writer: &mut Self, const_id| match return_type { + NumericType::Scalar(_) => const_id, + NumericType::Vector { size, .. } => { + let constituent_ids = [const_id; crate::VectorSize::MAX]; + writer.get_constant_composite( + LookupType::Local(LocalType::Numeric(return_type)), + &constituent_ids[..size as usize], + ) + } + NumericType::Matrix { .. } => unreachable!(), + }; + + let const_zero_id = self.get_constant_scalar_with(0, scalar)?; + let composite_zero_id = maybe_splat_const(self, const_zero_id); + let rhs_eq_zero_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + rhs_eq_zero_id, + rhs_id, + composite_zero_id, + )); + let divisor_selector_id = match scalar.kind { + crate::ScalarKind::Sint => { + let (const_min_id, const_neg_one_id) = match scalar.width { + 4 => Ok(( + self.get_constant_scalar(crate::Literal::I32(i32::MIN)), + self.get_constant_scalar(crate::Literal::I32(-1i32)), + )), + 8 => Ok(( + self.get_constant_scalar(crate::Literal::I64(i64::MIN)), + self.get_constant_scalar(crate::Literal::I64(-1i64)), + )), + _ => Err(Error::Validation("Unexpected scalar width")), + }?; + let composite_min_id = maybe_splat_const(self, const_min_id); + let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id); + + let lhs_eq_int_min_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + lhs_eq_int_min_id, + lhs_id, + composite_min_id, + )); + let rhs_eq_neg_one_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + rhs_eq_neg_one_id, + rhs_id, + composite_neg_one_id, + )); + let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + lhs_eq_int_min_and_rhs_eq_neg_one_id, + lhs_eq_int_min_id, + rhs_eq_neg_one_id, + )); + let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalOr, + bool_type_id, + rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id, + rhs_eq_zero_id, + lhs_eq_int_min_and_rhs_eq_neg_one_id, + )); + rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id + } + crate::ScalarKind::Uint => rhs_eq_zero_id, + _ => unreachable!(), + }; + + let const_one_id = self.get_constant_scalar_with(1, scalar)?; + let composite_one_id = maybe_splat_const(self, const_one_id); + let divisor_id = self.id_gen.next(); + block.body.push(Instruction::select( + right_type_id, + divisor_id, + divisor_selector_id, + composite_one_id, + rhs_id, + )); + let op = match (op, scalar.kind) { + (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv, + (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv, + (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem, + (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod, + _ => unreachable!(), + }; + let return_id = self.id_gen.next(); + block.body.push(Instruction::binary( + op, + return_type_id, + return_id, + lhs_id, + divisor_id, + )); + + function.consume(block, Instruction::return_value(return_id)); + function.to_words(&mut self.logical_layout.function_definitions); + Instruction::function_end().to_words(&mut self.logical_layout.function_definitions); + Ok(()) + } + fn write_function( &mut self, ir_function: &crate::Function, @@ -1138,7 +1164,7 @@ impl Writer { } LocalType::Image(image) => { let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type)); - let type_id = self.get_type_id(LookupType::Local(local_type)); + let type_id = self.get_localtype_id(local_type); Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format) } LocalType::Sampler => Instruction::type_sampler(id),