Skip to content

Commit

Permalink
[ARM] Avoid clobbering byval arguments when passing to tail-calls
Browse files Browse the repository at this point in the history
When passing byval arguments to tail-calls, we need to store them into
the stack memory in which this the caller received it's arguments. If
any of the outgoing arguments are forwarded from incoming byval
arguments, then the source of the copy is from the same stack memory.
This can result in the copy corrupting a value which is still to be
read.

The fix is to first make a copy of the outgoing byval arguments in local
stack space, and then copy them to their final location. This fixes the
correctness issue, but results in extra copying, which could be
optimised.
  • Loading branch information
ostannard committed Sep 27, 2024
1 parent ec047cd commit e504e79
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 28 deletions.
53 changes: 51 additions & 2 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2380,6 +2380,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

MachineFunction &MF = DAG.getMachineFunction();
ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>();
MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
MachineFunction::CallSiteInfo CSInfo;
bool isStructRet = (Outs.empty()) ? false : Outs[0].Flags.isSRet();
bool isThisReturn = false;
Expand Down Expand Up @@ -2492,6 +2493,45 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
RegsToPassVector RegsToPass;
SmallVector<SDValue, 8> MemOpChains;

// If we are doing a tail-call, any byval arguments will be written to stack
// space which was used for incoming arguments. If any the values being used
// are incoming byval arguments to this function, then they might be
// overwritten by the stores of the outgoing arguments. To avoid this, we
// need to make a temporary copy of them in local stack space, then copy back
// to the argument area.
// TODO This could be optimised to skip byvals which are already being copied
// from local stack space, or which are copied from the incoming stack at the
// exact same location.
DenseMap<unsigned, SDValue> ByValTemporaries;
SDValue ByValTempChain;
if (isTailCall) {
for (unsigned ArgIdx = 0, e = OutVals.size(); ArgIdx != e; ++ArgIdx) {
SDValue Arg = OutVals[ArgIdx];
ISD::ArgFlagsTy Flags = Outs[ArgIdx].Flags;

if (Flags.isByVal()) {
int FrameIdx = MFI.CreateStackObject(
Flags.getByValSize(), Flags.getNonZeroByValAlign(), false);
SDValue Dst =
DAG.getFrameIndex(FrameIdx, getPointerTy(DAG.getDataLayout()));

SDValue SizeNode = DAG.getConstant(Flags.getByValSize(), dl, MVT::i32);
SDValue AlignNode =
DAG.getConstant(Flags.getNonZeroByValAlign().value(), dl, MVT::i32);

SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue Ops[] = { Chain, Dst, Arg, SizeNode, AlignNode};
MemOpChains.push_back(DAG.getNode(ARMISD::COPY_STRUCT_BYVAL, dl, VTs,
Ops));
ByValTemporaries[ArgIdx] = Dst;
}
}
if (!MemOpChains.empty()) {
ByValTempChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MemOpChains);
MemOpChains.clear();
}
}

// During a tail call, stores to the argument area must happen after all of
// the function's incoming arguments have been loaded because they may alias.
// This is done by folding in a TokenFactor from LowerFormalArguments, but
Expand Down Expand Up @@ -2529,6 +2569,9 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

if (isTailCall && VA.isMemLoc() && !AfterFormalArgLoads) {
Chain = DAG.getStackArgumentTokenFactor(Chain);
if (ByValTempChain)
Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chain,
ByValTempChain);
AfterFormalArgLoads = true;
}

Expand Down Expand Up @@ -2600,6 +2643,12 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned ByValArgsCount = CCInfo.getInRegsParamsCount();
unsigned CurByValIdx = CCInfo.getInRegsParamsProcessed();

SDValue ByValSrc;
if (ByValTemporaries.contains(realArgIdx))
ByValSrc = ByValTemporaries[realArgIdx];
else
ByValSrc = Arg;

if (CurByValIdx < ByValArgsCount) {

unsigned RegBegin, RegEnd;
Expand All @@ -2610,7 +2659,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned int i, j;
for (i = 0, j = RegBegin; j < RegEnd; i++, j++) {
SDValue Const = DAG.getConstant(4*i, dl, MVT::i32);
SDValue AddArg = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, Const);
SDValue AddArg = DAG.getNode(ISD::ADD, dl, PtrVT, ByValSrc, Const);
SDValue Load =
DAG.getLoad(PtrVT, dl, Chain, AddArg, MachinePointerInfo(),
DAG.InferPtrAlign(AddArg));
Expand All @@ -2632,7 +2681,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
std::tie(Dst, DstInfo) =
computeAddrForCallArg(dl, DAG, VA, StackPtr, isTailCall, SPDiff);
SDValue SrcOffset = DAG.getIntPtrConstant(4*offset, dl);
SDValue Src = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, SrcOffset);
SDValue Src = DAG.getNode(ISD::ADD, dl, PtrVT, ByValSrc, SrcOffset);
SDValue SizeNode = DAG.getConstant(Flags.getByValSize() - 4*offset, dl,
MVT::i32);
SDValue AlignNode =
Expand Down
161 changes: 135 additions & 26 deletions llvm/test/CodeGen/ARM/musttail.ll
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,28 @@ define void @large_caller(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
; CHECK-NEXT: add r12, sp, #8
; CHECK-NEXT: add lr, sp, #24
; CHECK-NEXT: .pad #20
; CHECK-NEXT: sub sp, sp, #20
; CHECK-NEXT: add r12, sp, #28
; CHECK-NEXT: add lr, sp, #44
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: add r12, sp, #8
; CHECK-NEXT: add r12, r12, #16
; CHECK-NEXT: add r0, sp, #28
; CHECK-NEXT: mov r1, sp
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: add r12, r1, #16
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r4, [r12], #4
; CHECK-NEXT: str r4, [lr], #4
; CHECK-NEXT: add sp, sp, #20
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
Expand All @@ -163,17 +178,30 @@ define void @large_caller_check_regs(%twenty_bytes* byval(%twenty_bytes) align 4
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
; CHECK-NEXT: add r12, sp, #8
; CHECK-NEXT: add lr, sp, #24
; CHECK-NEXT: .pad #20
; CHECK-NEXT: sub sp, sp, #20
; CHECK-NEXT: add r12, sp, #28
; CHECK-NEXT: add lr, sp, #44
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: @APP
; CHECK-NEXT: @NO_APP
; CHECK-NEXT: add r3, sp, #8
; CHECK-NEXT: add r0, sp, #8
; CHECK-NEXT: add r12, r0, #16
; CHECK-NEXT: ldm r3, {r0, r1, r2, r3}
; CHECK-NEXT: add r0, sp, #28
; CHECK-NEXT: mov r1, sp
; CHECK-NEXT: add r12, r1, #16
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r4, [r12], #4
; CHECK-NEXT: str r4, [lr], #4
; CHECK-NEXT: add sp, sp, #20
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
Expand All @@ -190,30 +218,44 @@ entry:
define void @large_caller_new_value(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
; CHECK-LABEL: large_caller_new_value:
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: .pad #36
; CHECK-NEXT: sub sp, sp, #36
; CHECK-NEXT: add r12, sp, #20
; CHECK-NEXT: .pad #16
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
; CHECK-NEXT: .pad #40
; CHECK-NEXT: sub sp, sp, #40
; CHECK-NEXT: add r12, sp, #48
; CHECK-NEXT: add lr, sp, #64
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: mov r0, #4
; CHECK-NEXT: add r1, sp, #36
; CHECK-NEXT: str r0, [sp, #16]
; CHECK-NEXT: mov r1, sp
; CHECK-NEXT: str r0, [sp, #36]
; CHECK-NEXT: mov r0, #3
; CHECK-NEXT: str r0, [sp, #12]
; CHECK-NEXT: str r0, [sp, #32]
; CHECK-NEXT: mov r0, #2
; CHECK-NEXT: str r0, [sp, #8]
; CHECK-NEXT: str r0, [sp, #28]
; CHECK-NEXT: mov r0, #1
; CHECK-NEXT: str r0, [sp, #4]
; CHECK-NEXT: str r0, [sp, #24]
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: str r0, [sp]
; CHECK-NEXT: mov r0, sp
; CHECK-NEXT: add r0, r0, #16
; CHECK-NEXT: mov r3, #3
; CHECK-NEXT: str r0, [sp, #20]
; CHECK-NEXT: add r0, sp, #20
; CHECK-NEXT: add r12, r1, #16
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r1, #1
; CHECK-NEXT: mov r2, #2
; CHECK-NEXT: add sp, sp, #36
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r4, [r12], #4
; CHECK-NEXT: str r4, [lr], #4
; CHECK-NEXT: add sp, sp, #40
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
entry:
%y = alloca %twenty_bytes, align 4
Expand All @@ -229,3 +271,70 @@ entry:
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %y)
ret void
}

; TODO: This test swaps the two byval arguments, but does so without copying to
; a temporary location first, so the first copy overwrites the memory which
; will be ready by the second.
declare void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4, %twenty_bytes* byval(%twenty_bytes) align 4)
define void @swap_byvals(%twenty_bytes* byval(%twenty_bytes) align 4 %a, %twenty_bytes* byval(%twenty_bytes) align 4 %b) {
; CHECK-LABEL: swap_byvals:
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: .pad #16
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, r5, r11, lr}
; CHECK-NEXT: push {r4, r5, r11, lr}
; CHECK-NEXT: .pad #40
; CHECK-NEXT: sub sp, sp, #40
; CHECK-NEXT: add r12, sp, #56
; CHECK-NEXT: add lr, sp, #20
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: add r0, sp, #56
; CHECK-NEXT: mov r12, sp
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: mov r2, r12
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: add r3, sp, #20
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: add r4, sp, #76
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: add r0, sp, #76
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: mov r2, lr
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldr r1, [r0], #4
; CHECK-NEXT: str r1, [r2], #4
; CHECK-NEXT: ldm r3, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r5, [r12], #4
; CHECK-NEXT: str r5, [r4], #4
; CHECK-NEXT: ldr r5, [r12], #4
; CHECK-NEXT: str r5, [r4], #4
; CHECK-NEXT: ldr r5, [r12], #4
; CHECK-NEXT: str r5, [r4], #4
; CHECK-NEXT: ldr r5, [r12], #4
; CHECK-NEXT: str r5, [r4], #4
; CHECK-NEXT: ldr r5, [r12], #4
; CHECK-NEXT: str r5, [r4], #4
; CHECK-NEXT: add r5, lr, #16
; CHECK-NEXT: add r12, sp, #72
; CHECK-NEXT: ldr r4, [r5], #4
; CHECK-NEXT: str r4, [r12], #4
; CHECK-NEXT: add sp, sp, #40
; CHECK-NEXT: pop {r4, r5, r11, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b two_byvals_callee
entry:
musttail call void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %b, %twenty_bytes* byval(%twenty_bytes) align 4 %a)
ret void
}

0 comments on commit e504e79

Please sign in to comment.