summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/Target/TargetLowering.h15
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp8
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp12
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.h2
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td91
-rw-r--r--llvm/test/CodeGen/NVPTX/convergent-mir-call.ll27
6 files changed, 99 insertions, 56 deletions
diff --git a/llvm/include/llvm/Target/TargetLowering.h b/llvm/include/llvm/Target/TargetLowering.h
index eb640529e0f..6abeb44a368 100644
--- a/llvm/include/llvm/Target/TargetLowering.h
+++ b/llvm/include/llvm/Target/TargetLowering.h
@@ -2348,6 +2348,7 @@ public:
bool IsInReg : 1;
bool DoesNotReturn : 1;
bool IsReturnValueUsed : 1;
+ bool IsConvergent : 1;
// IsTailCall should be modified by implementations of
// TargetLowering::LowerCall that perform tail call conversions.
@@ -2366,10 +2367,11 @@ public:
SmallVector<ISD::InputArg, 32> Ins;
CallLoweringInfo(SelectionDAG &DAG)
- : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
- IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
- IsTailCall(false), NumFixedArgs(-1), CallConv(CallingConv::C),
- DAG(DAG), CS(nullptr), IsPatchPoint(false) {}
+ : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false),
+ IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true),
+ IsConvergent(false), IsTailCall(false), NumFixedArgs(-1),
+ CallConv(CallingConv::C), DAG(DAG), CS(nullptr), IsPatchPoint(false) {
+ }
CallLoweringInfo &setDebugLoc(SDLoc dl) {
DL = dl;
@@ -2441,6 +2443,11 @@ public:
return *this;
}
+ CallLoweringInfo &setConvergent(bool Value = true) {
+ IsConvergent = Value;
+ return *this;
+ }
+
CallLoweringInfo &setSExtResult(bool Value = true) {
RetSExt = Value;
return *this;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d501a916ee8..a2e7eca127d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
isTailCall = false;
TargetLowering::CallLoweringInfo CLI(DAG);
- CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot())
- .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
- .setTailCall(isTailCall);
+ CLI.setDebugLoc(getCurSDLoc())
+ .setChain(getRoot())
+ .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
+ .setTailCall(isTailCall)
+ .setConvergent(CS.isConvergent());
std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
if (Result.first.getNode()) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c6263ca7317..592a269d1a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::DeclareRetParam";
case NVPTXISD::PrintCall:
return "NVPTXISD::PrintCall";
+ case NVPTXISD::PrintConvergentCall:
+ return "NVPTXISD::PrintConvergentCall";
case NVPTXISD::PrintCallUni:
return "NVPTXISD::PrintCallUni";
+ case NVPTXISD::PrintConvergentCallUni:
+ return "NVPTXISD::PrintConvergentCallUni";
case NVPTXISD::LoadParam:
return "NVPTXISD::LoadParam";
case NVPTXISD::LoadParamV2:
@@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue PrintCallOps[] = {
Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag
};
- Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
- dl, PrintCallVTs, PrintCallOps);
+ // We model convergent calls as separate opcodes.
+ unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall;
+ if (CLI.IsConvergent)
+ Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
+ : NVPTXISD::PrintConvergentCall;
+ Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
InFlag = Chain.getValue(1);
// Ops to print out the function name
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 60914c1d09b..735cd01ced6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -34,7 +34,9 @@ enum NodeType : unsigned {
DeclareRet,
DeclareScalarRet,
PrintCall,
+ PrintConvergentCall,
PrintCallUni,
+ PrintConvergentCallUni,
CallArgBegin,
CallArg,
LastCallArg,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 51db8246c53..685d1b447b9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1701,9 +1701,15 @@ def LoadParamV4 :
def PrintCall :
SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCall :
+ SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def PrintCallUni :
SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCallUni :
+ SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def StoreParam :
SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
@@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> :
[]>;
let isCall=1 in {
- def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
- "call ", [(PrintCall (i32 0))]>;
- def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
- "call (retval0), ", [(PrintCall (i32 1))]>;
- def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1), ", [(PrintCall (i32 2))]>;
- def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>;
- def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>;
- def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4), ",
- [(PrintCall (i32 5))]>;
- def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5), ",
- [(PrintCall (i32 6))]>;
- def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
- [(PrintCall (i32 7))]>;
- def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
- "retval7), ",
- [(PrintCall (i32 8))]>;
-
- def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins),
- "call.uni ", [(PrintCallUni (i32 0))]>;
- def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins),
- "call.uni (retval0), ", [(PrintCallUni (i32 1))]>;
- def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>;
- def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>;
- def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>;
- def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4), ",
- [(PrintCallUni (i32 5))]>;
- def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ",
- [(PrintCallUni (i32 6))]>;
- def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
- [(PrintCallUni (i32 7))]>;
- def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
- "retval7), ",
- [(PrintCallUni (i32 8))]>;
+ multiclass CALL<string OpcStr, SDNode OpNode> {
+ def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+ def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
+ def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
+ def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
+ def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
+ [(OpNode (i32 4))]>;
+ def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
+ [(OpNode (i32 5))]>;
+ def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5), "),
+ [(OpNode (i32 6))]>;
+ def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5, retval6), "),
+ [(OpNode (i32 7))]>;
+ def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5, retval6, retval7), "),
+ [(OpNode (i32 8))]>;
+ }
+}
+
+defm Call : CALL<"call", PrintCall>;
+defm CallUni : CALL<"call.uni", PrintCallUni>;
+
+// Convergent call instructions. These are identical to regular calls, except
+// they have the isConvergent bit set.
+let isConvergent=1 in {
+ defm ConvergentCall : CALL<"call", PrintConvergentCall>;
+ defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>;
}
def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">;
diff --git a/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll
new file mode 100644
index 00000000000..18142450490
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll
@@ -0,0 +1,27 @@
+; RUN: llc -mtriple nvptx64-nvidia-cuda -stop-after machine-cp -o - < %s 2>&1 | FileCheck %s
+
+; Check that convergent calls are emitted using convergent MIR instructions,
+; while non-convergent calls are not.
+
+target triple = "nvptx64-nvidia-cuda"
+
+declare void @conv() convergent
+declare void @not_conv()
+
+define void @test(void ()* %f) {
+ ; CHECK: ConvergentCallUniPrintCall
+ ; CHECK-NEXT: @conv
+ call void @conv()
+
+ ; CHECK: CallUniPrintCall
+ ; CHECK-NEXT: @not_conv
+ call void @not_conv()
+
+ ; CHECK: ConvergentCallPrintCall
+ call void %f() convergent
+
+ ; CHECK: CallPrintCall
+ call void %f()
+
+ ret void
+}
OpenPOWER on IntegriCloud