summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp8
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp12
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.h2
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td91
4 files changed, 61 insertions, 52 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d501a916ee8..a2e7eca127d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
isTailCall = false;
TargetLowering::CallLoweringInfo CLI(DAG);
- CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot())
- .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
- .setTailCall(isTailCall);
+ CLI.setDebugLoc(getCurSDLoc())
+ .setChain(getRoot())
+ .setCallee(RetTy, FTy, Callee, std::move(Args), CS)
+ .setTailCall(isTailCall)
+ .setConvergent(CS.isConvergent());
std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
if (Result.first.getNode()) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c6263ca7317..592a269d1a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::DeclareRetParam";
case NVPTXISD::PrintCall:
return "NVPTXISD::PrintCall";
+ case NVPTXISD::PrintConvergentCall:
+ return "NVPTXISD::PrintConvergentCall";
case NVPTXISD::PrintCallUni:
return "NVPTXISD::PrintCallUni";
+ case NVPTXISD::PrintConvergentCallUni:
+ return "NVPTXISD::PrintConvergentCallUni";
case NVPTXISD::LoadParam:
return "NVPTXISD::LoadParam";
case NVPTXISD::LoadParamV2:
@@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue PrintCallOps[] = {
Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag
};
- Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
- dl, PrintCallVTs, PrintCallOps);
+ // We model convergent calls as separate opcodes.
+ unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall;
+ if (CLI.IsConvergent)
+ Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
+ : NVPTXISD::PrintConvergentCall;
+ Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
InFlag = Chain.getValue(1);
// Ops to print out the function name
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 60914c1d09b..735cd01ced6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -34,7 +34,9 @@ enum NodeType : unsigned {
DeclareRet,
DeclareScalarRet,
PrintCall,
+ PrintConvergentCall,
PrintCallUni,
+ PrintConvergentCallUni,
CallArgBegin,
CallArg,
LastCallArg,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 51db8246c53..685d1b447b9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1701,9 +1701,15 @@ def LoadParamV4 :
def PrintCall :
SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCall :
+ SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def PrintCallUni :
SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def PrintConvergentCallUni :
+ SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def StoreParam :
SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
@@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> :
[]>;
let isCall=1 in {
- def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
- "call ", [(PrintCall (i32 0))]>;
- def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
- "call (retval0), ", [(PrintCall (i32 1))]>;
- def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1), ", [(PrintCall (i32 2))]>;
- def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>;
- def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>;
- def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4), ",
- [(PrintCall (i32 5))]>;
- def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5), ",
- [(PrintCall (i32 6))]>;
- def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
- [(PrintCall (i32 7))]>;
- def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
- "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
- "retval7), ",
- [(PrintCall (i32 8))]>;
-
- def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins),
- "call.uni ", [(PrintCallUni (i32 0))]>;
- def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins),
- "call.uni (retval0), ", [(PrintCallUni (i32 1))]>;
- def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>;
- def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>;
- def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>;
- def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4), ",
- [(PrintCallUni (i32 5))]>;
- def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ",
- [(PrintCallUni (i32 6))]>;
- def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ",
- [(PrintCallUni (i32 7))]>;
- def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins),
- "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, "
- "retval7), ",
- [(PrintCallUni (i32 8))]>;
+ multiclass CALL<string OpcStr, SDNode OpNode> {
+ def PrintCallNoRetInst : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " "), [(OpNode (i32 0))]>;
+ def PrintCallRetInst1 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>;
+ def PrintCallRetInst2 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>;
+ def PrintCallRetInst3 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>;
+ def PrintCallRetInst4 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "),
+ [(OpNode (i32 4))]>;
+ def PrintCallRetInst5 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "),
+ [(OpNode (i32 5))]>;
+ def PrintCallRetInst6 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5), "),
+ [(OpNode (i32 6))]>;
+ def PrintCallRetInst7 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5, retval6), "),
+ [(OpNode (i32 7))]>;
+ def PrintCallRetInst8 : NVPTXInst<(outs), (ins),
+ !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, "
+ "retval5, retval6, retval7), "),
+ [(OpNode (i32 8))]>;
+ }
+}
+
+defm Call : CALL<"call", PrintCall>;
+defm CallUni : CALL<"call.uni", PrintCallUni>;
+
+// Convergent call instructions. These are identical to regular calls, except
+// they have the isConvergent bit set.
+let isConvergent=1 in {
+ defm ConvergentCall : CALL<"call", PrintConvergentCall>;
+ defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>;
}
def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">;
OpenPOWER on IntegriCloud