[msan] Handle horizontal add/subtract intrinsic by applying to shadow (#124159)

Horizontal add (hadd) and subtract (hsub) are currently handled heuristically
by `maybeHandleSimpleNomemIntrinsic()` (via `handleUnknownIntrinsic()`), which
computes the result shadow by bitwise OR'ing the shadows of the two operands.
This over-approximates the shadow for hadd/hsub, producing false positives.
For example, suppose the shadows of the two operands are 00000000 and
11111111 respectively. The expected shadow for the result is 00001111,
but `maybeHandleSimpleNomemIntrinsic()` would compute it as 11111111.

This patch instead handles horizontal add/subtract using
`handleIntrinsicByApplyingToShadow()` (from
#114490), which has no false
positives for hadd/hsub: if each pair of adjacent shadow values is zero
(fully initialized), the result will be zero (fully initialized). More
generally, it is precise for hadd/hsub whenever at least one of the two
adjacent shadow values in each pair is zero.
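
To make the difference concrete, here is a minimal standalone C++ sketch of
the example above (a toy model only, not MSan code: the `Vec4`, `hadd`, and
`orShadow` names are illustrative, plain arrays stand in for vectors, and each
lane holds that element's shadow, with 0 meaning fully initialized):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Vec4 = std::array<uint32_t, 4>;

// hadd(A, B) = [A0+A1, A2+A3, B0+B1, B2+B3]
static Vec4 hadd(const Vec4 &A, const Vec4 &B) {
  return {A[0] + A[1], A[2] + A[3], B[0] + B[1], B[2] + B[3]};
}

// Old heuristic: lane-wise OR of the two operand shadows.
static Vec4 orShadow(const Vec4 &SA, const Vec4 &SB) {
  return {SA[0] | SB[0], SA[1] | SB[1], SA[2] | SB[2], SA[3] | SB[3]};
}

int main() {
  Vec4 SA = {0, 0, 0, 0};         // first operand fully initialized
  Vec4 SB = {~0u, ~0u, ~0u, ~0u}; // second operand fully uninitialized

  Vec4 Heuristic = orShadow(SA, SB); // all four lanes poisoned
  Vec4 Precise = hadd(SA, SB);       // lanes 0-1 clean, lanes 2-3 poisoned

  // The OR heuristic poisons lane 0 even though it derives only from the
  // fully initialized first operand -- a false positive. Applying hadd to
  // the shadows keeps it clean, because both summed shadows are zero.
  assert(Heuristic[0] != 0 && Precise[0] == 0);
  return 0;
}
```

The new mapping is exact here because every pair being summed contains at
least one zero shadow.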

It does have some false negatives for hadd/hsub: if we add/subtract two
adjacent non-zero shadow values, some bits of the result may incorrectly
be zero. We consider this an acceptable tradeoff for performance. To
make shadow propagation precise, we want the equivalent of "horizontal
OR", but this is not available. Reducing horizontal OR to (permutation
plus bitwise OR) is left as an exercise for the reader.
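
Under the same toy model, a sketch of the false-negative case (again
illustrative, not MSan code): applying hsub to the shadows themselves can
cancel two equal non-zero adjacent shadows.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Vec4 = std::array<uint32_t, 4>;

// hsub(A, B) = [A0-A1, A2-A3, B0-B1, B2-B3]
static Vec4 hsub(const Vec4 &A, const Vec4 &B) {
  return {A[0] - A[1], A[2] - A[3], B[0] - B[1], B[2] - B[3]};
}

int main() {
  // Lanes 0 and 1 of the first operand carry the same non-zero shadow.
  Vec4 SA = {0xFF, 0xFF, 0, 0};
  Vec4 SB = {0, 0, 0, 0};

  Vec4 SOut = hsub(SA, SB);
  // Lane 0 of the result mixes two uninitialized inputs and should be
  // poisoned, but the subtraction cancels: 0xFF - 0xFF == 0.
  assert(SOut[0] == 0); // false negative, accepted for performance
  return 0;
}
```
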
thurstond authored Jan 24, 2025
1 parent 45d83ae commit 8ef171e
Showing 7 changed files with 132 additions and 61 deletions.
47 changes: 47 additions & 0 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3906,6 +3906,23 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
+  void handleAVXHorizontalAddSubIntrinsic(IntrinsicInst &I) {
+    // Approximation only:
+    //    output         = horizontal_add(A, B)
+    // => shadow[output] = horizontal_add(shadow[A], shadow[B])
+    //
+    // - If we add/subtract two adjacent zero (initialized) shadow values, the
+    //   result will always be zero, i.e., no false positives.
+    // - If we add/subtract two shadows, one of which is uninitialized, the
+    //   result will always be non-zero, i.e., no false negatives.
+    // - However, we can have false negatives if we subtract two non-zero
+    //   shadows of the same value (or do an addition that wraps to zero); we
+    //   consider this an acceptable tradeoff for performance.
+    //
+    // To make shadow propagation precise, we want the equivalent of
+    // "horizontal OR", but this is not available.
+    return handleIntrinsicByApplyingToShadow(I, /* trailingVerbatimArgs */ 0);
+  }
+
   /// Handle Arm NEON vector store intrinsics (vst{2,3,4}, vst1x_{2,3,4},
   /// and vst{2,3,4}lane).
   ///
@@ -4418,6 +4435,36 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVtestIntrinsic(I);
       break;
 
+    case Intrinsic::x86_sse3_hadd_ps:
+    case Intrinsic::x86_sse3_hadd_pd:
+    case Intrinsic::x86_ssse3_phadd_d:
+    case Intrinsic::x86_ssse3_phadd_d_128:
+    case Intrinsic::x86_ssse3_phadd_w:
+    case Intrinsic::x86_ssse3_phadd_w_128:
+    case Intrinsic::x86_ssse3_phadd_sw:
+    case Intrinsic::x86_ssse3_phadd_sw_128:
+    case Intrinsic::x86_avx_hadd_pd_256:
+    case Intrinsic::x86_avx_hadd_ps_256:
+    case Intrinsic::x86_avx2_phadd_d:
+    case Intrinsic::x86_avx2_phadd_w:
+    case Intrinsic::x86_avx2_phadd_sw:
+    case Intrinsic::x86_sse3_hsub_ps:
+    case Intrinsic::x86_sse3_hsub_pd:
+    case Intrinsic::x86_ssse3_phsub_d:
+    case Intrinsic::x86_ssse3_phsub_d_128:
+    case Intrinsic::x86_ssse3_phsub_w:
+    case Intrinsic::x86_ssse3_phsub_w_128:
+    case Intrinsic::x86_ssse3_phsub_sw:
+    case Intrinsic::x86_ssse3_phsub_sw_128:
+    case Intrinsic::x86_avx_hsub_pd_256:
+    case Intrinsic::x86_avx_hsub_ps_256:
+    case Intrinsic::x86_avx2_phsub_d:
+    case Intrinsic::x86_avx2_phsub_w:
+    case Intrinsic::x86_avx2_phsub_sw: {
+      handleAVXHorizontalAddSubIntrinsic(I);
+      break;
+    }
+
     case Intrinsic::fshl:
     case Intrinsic::fshr:
       handleFunnelShift(I);
36 changes: 24 additions & 12 deletions llvm/test/Instrumentation/MemorySanitizer/X86/avx-intrinsics-x86.ll
@@ -435,10 +435,13 @@ define <4 x double> @test_x86_avx_hadd_pd_256(<4 x double> %a0, <4 x double> %a1
 ; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i64>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <4 x i64> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A0:%.*]], <4 x double> [[A1:%.*]])
+; CHECK-NEXT: [[A0:%.*]] = bitcast <4 x i64> [[TMP1]] to <4 x double>
+; CHECK-NEXT: [[A1:%.*]] = bitcast <4 x i64> [[TMP2]] to <4 x double>
+; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A0]], <4 x double> [[A1]])
+; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <4 x double> [[RES]] to <4 x i64>
+; CHECK-NEXT: [[RES1:%.*]] = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> [[A2:%.*]], <4 x double> [[A3:%.*]])
 ; CHECK-NEXT: store <4 x i64> [[_MSPROP]], ptr @__msan_retval_tls, align 8
-; CHECK-NEXT: ret <4 x double> [[RES]]
+; CHECK-NEXT: ret <4 x double> [[RES1]]
 ;
   %res = call <4 x double> @llvm.x86.avx.hadd.pd.256(<4 x double> %a0, <4 x double> %a1) ; <<4 x double>> [#uses=1]
   ret <4 x double> %res
@@ -451,10 +454,13 @@ define <8 x float> @test_x86_avx_hadd_ps_256(<8 x float> %a0, <8 x float> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A0:%.*]], <8 x float> [[A1:%.*]])
+; CHECK-NEXT: [[A0:%.*]] = bitcast <8 x i32> [[TMP1]] to <8 x float>
+; CHECK-NEXT: [[A1:%.*]] = bitcast <8 x i32> [[TMP2]] to <8 x float>
+; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A0]], <8 x float> [[A1]])
+; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <8 x float> [[RES]] to <8 x i32>
+; CHECK-NEXT: [[RES1:%.*]] = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> [[A2:%.*]], <8 x float> [[A3:%.*]])
 ; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
-; CHECK-NEXT: ret <8 x float> [[RES]]
+; CHECK-NEXT: ret <8 x float> [[RES1]]
 ;
   %res = call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> %a0, <8 x float> %a1) ; <<8 x float>> [#uses=1]
   ret <8 x float> %res
@@ -467,10 +473,13 @@ define <4 x double> @test_x86_avx_hsub_pd_256(<4 x double> %a0, <4 x double> %a1
 ; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i64>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <4 x i64> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hsub.pd.256(<4 x double> [[A0:%.*]], <4 x double> [[A1:%.*]])
+; CHECK-NEXT: [[A0:%.*]] = bitcast <4 x i64> [[TMP1]] to <4 x double>
+; CHECK-NEXT: [[A1:%.*]] = bitcast <4 x i64> [[TMP2]] to <4 x double>
+; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.x86.avx.hsub.pd.256(<4 x double> [[A0]], <4 x double> [[A1]])
+; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <4 x double> [[RES]] to <4 x i64>
+; CHECK-NEXT: [[RES1:%.*]] = call <4 x double> @llvm.x86.avx.hsub.pd.256(<4 x double> [[A2:%.*]], <4 x double> [[A3:%.*]])
 ; CHECK-NEXT: store <4 x i64> [[_MSPROP]], ptr @__msan_retval_tls, align 8
-; CHECK-NEXT: ret <4 x double> [[RES]]
+; CHECK-NEXT: ret <4 x double> [[RES1]]
 ;
   %res = call <4 x double> @llvm.x86.avx.hsub.pd.256(<4 x double> %a0, <4 x double> %a1) ; <<4 x double>> [#uses=1]
   ret <4 x double> %res
@@ -483,10 +492,13 @@ define <8 x float> @test_x86_avx_hsub_ps_256(<8 x float> %a0, <8 x float> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hsub.ps.256(<8 x float> [[A0:%.*]], <8 x float> [[A1:%.*]])
+; CHECK-NEXT: [[A0:%.*]] = bitcast <8 x i32> [[TMP1]] to <8 x float>
+; CHECK-NEXT: [[A1:%.*]] = bitcast <8 x i32> [[TMP2]] to <8 x float>
+; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.hsub.ps.256(<8 x float> [[A0]], <8 x float> [[A1]])
+; CHECK-NEXT: [[_MSPROP:%.*]] = bitcast <8 x float> [[RES]] to <8 x i32>
+; CHECK-NEXT: [[RES1:%.*]] = call <8 x float> @llvm.x86.avx.hsub.ps.256(<8 x float> [[A2:%.*]], <8 x float> [[A3:%.*]])
 ; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
-; CHECK-NEXT: ret <8 x float> [[RES]]
+; CHECK-NEXT: ret <8 x float> [[RES1]]
 ;
   %res = call <8 x float> @llvm.x86.avx.hsub.ps.256(<8 x float> %a0, <8 x float> %a1) ; <<8 x float>> [#uses=1]
   ret <8 x float> %res
@@ -569,7 +569,7 @@ define <8 x i32> @test_x86_avx2_phadd_d(<8 x i32> %a0, <8 x i32> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <8 x i32> @llvm.x86.avx2.phadd.d(<8 x i32> [[TMP1]], <8 x i32> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.phadd.d(<8 x i32> [[A0:%.*]], <8 x i32> [[A1:%.*]])
 ; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <8 x i32> [[RES]]
@@ -585,7 +585,7 @@ define <16 x i16> @test_x86_avx2_phadd_sw(<16 x i16> %a0, <16 x i16> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i16> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <16 x i16> @llvm.x86.avx2.phadd.sw(<16 x i16> [[TMP1]], <16 x i16> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.phadd.sw(<16 x i16> [[A0:%.*]], <16 x i16> [[A1:%.*]])
 ; CHECK-NEXT: store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <16 x i16> [[RES]]
@@ -601,7 +601,7 @@ define <16 x i16> @test_x86_avx2_phadd_w(<16 x i16> %a0, <16 x i16> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i16> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <16 x i16> @llvm.x86.avx2.phadd.w(<16 x i16> [[TMP1]], <16 x i16> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.phadd.w(<16 x i16> [[A0:%.*]], <16 x i16> [[A1:%.*]])
 ; CHECK-NEXT: store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <16 x i16> [[RES]]
@@ -617,7 +617,7 @@ define <8 x i32> @test_x86_avx2_phsub_d(<8 x i32> %a0, <8 x i32> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <8 x i32> @llvm.x86.avx2.phsub.d(<8 x i32> [[TMP1]], <8 x i32> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.phsub.d(<8 x i32> [[A0:%.*]], <8 x i32> [[A1:%.*]])
 ; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <8 x i32> [[RES]]
@@ -633,7 +633,7 @@ define <16 x i16> @test_x86_avx2_phsub_sw(<16 x i16> %a0, <16 x i16> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i16> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <16 x i16> @llvm.x86.avx2.phsub.sw(<16 x i16> [[TMP1]], <16 x i16> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.phsub.sw(<16 x i16> [[A0:%.*]], <16 x i16> [[A1:%.*]])
 ; CHECK-NEXT: store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <16 x i16> [[RES]]
@@ -649,7 +649,7 @@ define <16 x i16> @test_x86_avx2_phsub_w(<16 x i16> %a0, <16 x i16> %a1) #0 {
 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i16> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[_MSPROP:%.*]] = call <16 x i16> @llvm.x86.avx2.phsub.w(<16 x i16> [[TMP1]], <16 x i16> [[TMP2]])
 ; CHECK-NEXT: [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.phsub.w(<16 x i16> [[A0:%.*]], <16 x i16> [[A1:%.*]])
 ; CHECK-NEXT: store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT: ret <16 x i16> [[RES]]
