diff options
| -rw-r--r-- | llvm/include/llvm/Target/TargetLowering.h | 15 | ||||
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 8 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 91 | ||||
| -rw-r--r-- | llvm/test/CodeGen/NVPTX/convergent-mir-call.ll | 27 |
6 files changed, 99 insertions, 56 deletions
diff --git a/llvm/include/llvm/Target/TargetLowering.h b/llvm/include/llvm/Target/TargetLowering.h index eb640529e0f..6abeb44a368 100644 --- a/llvm/include/llvm/Target/TargetLowering.h +++ b/llvm/include/llvm/Target/TargetLowering.h @@ -2348,6 +2348,7 @@ public: bool IsInReg : 1; bool DoesNotReturn : 1; bool IsReturnValueUsed : 1; + bool IsConvergent : 1; // IsTailCall should be modified by implementations of // TargetLowering::LowerCall that perform tail call conversions. @@ -2366,10 +2367,11 @@ public: SmallVector<ISD::InputArg, 32> Ins; CallLoweringInfo(SelectionDAG &DAG) - : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false), - IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true), - IsTailCall(false), NumFixedArgs(-1), CallConv(CallingConv::C), - DAG(DAG), CS(nullptr), IsPatchPoint(false) {} + : RetTy(nullptr), RetSExt(false), RetZExt(false), IsVarArg(false), + IsInReg(false), DoesNotReturn(false), IsReturnValueUsed(true), + IsConvergent(false), IsTailCall(false), NumFixedArgs(-1), + CallConv(CallingConv::C), DAG(DAG), CS(nullptr), IsPatchPoint(false) { + } CallLoweringInfo &setDebugLoc(SDLoc dl) { DL = dl; @@ -2441,6 +2443,11 @@ public: return *this; } + CallLoweringInfo &setConvergent(bool Value = true) { + IsConvergent = Value; + return *this; + } + CallLoweringInfo &setSExtResult(bool Value = true) { RetSExt = Value; return *this; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index d501a916ee8..a2e7eca127d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee, isTailCall = false; TargetLowering::CallLoweringInfo CLI(DAG); - CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot()) - .setCallee(RetTy, FTy, Callee, std::move(Args), CS) - .setTailCall(isTailCall); + CLI.setDebugLoc(getCurSDLoc()) + .setChain(getRoot()) + .setCallee(RetTy, FTy, Callee, std::move(Args), CS) + .setTailCall(isTailCall) + .setConvergent(CS.isConvergent()); std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB); if (Result.first.getNode()) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index c6263ca7317..592a269d1a0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::DeclareRetParam"; case NVPTXISD::PrintCall: return "NVPTXISD::PrintCall"; + case NVPTXISD::PrintConvergentCall: + return "NVPTXISD::PrintConvergentCall"; case NVPTXISD::PrintCallUni: return "NVPTXISD::PrintCallUni"; + case NVPTXISD::PrintConvergentCallUni: + return "NVPTXISD::PrintConvergentCallUni"; case NVPTXISD::LoadParam: return "NVPTXISD::LoadParam"; case NVPTXISD::LoadParamV2: @@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SDValue PrintCallOps[] = { Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag }; - Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall), - dl, PrintCallVTs, PrintCallOps); + // We model convergent calls as separate opcodes. + unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall; + if (CLI.IsConvergent) + Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni + : NVPTXISD::PrintConvergentCall; + Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps); InFlag = Chain.getValue(1); // Ops to print out the function name diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 60914c1d09b..735cd01ced6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -34,7 +34,9 @@ enum NodeType : unsigned { DeclareRet, DeclareScalarRet, PrintCall, + PrintConvergentCall, PrintCallUni, + PrintConvergentCallUni, CallArgBegin, CallArg, LastCallArg, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 51db8246c53..685d1b447b9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1701,9 +1701,15 @@ def LoadParamV4 : def PrintCall : SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCall : + SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def PrintCallUni : SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCallUni : + SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParam : SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; @@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> : []>; let isCall=1 in { - def PrintCallNoRetInst : NVPTXInst<(outs), (ins), - "call ", [(PrintCall (i32 0))]>; - def PrintCallRetInst1 : NVPTXInst<(outs), (ins), - "call (retval0), ", [(PrintCall (i32 1))]>; - def PrintCallRetInst2 : NVPTXInst<(outs), (ins), - "call (retval0, retval1), ", [(PrintCall (i32 2))]>; - def PrintCallRetInst3 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>; - def PrintCallRetInst4 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>; - def PrintCallRetInst5 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCall (i32 5))]>; - def PrintCallRetInst6 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCall (i32 6))]>; - def PrintCallRetInst7 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCall (i32 7))]>; - def PrintCallRetInst8 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCall (i32 8))]>; - - def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins), - "call.uni ", [(PrintCallUni (i32 0))]>; - def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins), - "call.uni (retval0), ", [(PrintCallUni (i32 1))]>; - def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>; - def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>; - def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>; - def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCallUni (i32 5))]>; - def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCallUni (i32 6))]>; - def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCallUni (i32 7))]>; - def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCallUni (i32 8))]>; + multiclass CALL<string OpcStr, SDNode OpNode> { + def PrintCallNoRetInst : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " "), [(OpNode (i32 0))]>; + def PrintCallRetInst1 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>; + def PrintCallRetInst2 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>; + def PrintCallRetInst3 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>; + def PrintCallRetInst4 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "), + [(OpNode (i32 4))]>; + def PrintCallRetInst5 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "), + [(OpNode (i32 5))]>; + def PrintCallRetInst6 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5), "), + [(OpNode (i32 6))]>; + def PrintCallRetInst7 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6), "), + [(OpNode (i32 7))]>; + def PrintCallRetInst8 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6, retval7), "), + [(OpNode (i32 8))]>; + } +} + +defm Call : CALL<"call", PrintCall>; +defm CallUni : CALL<"call.uni", PrintCallUni>; + +// Convergent call instructions. These are identical to regular calls, except +// they have the isConvergent bit set. +let isConvergent=1 in { + defm ConvergentCall : CALL<"call", PrintConvergentCall>; + defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>; } def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">; diff --git a/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll new file mode 100644 index 00000000000..18142450490 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/convergent-mir-call.ll @@ -0,0 +1,27 @@ +; RUN: llc -mtriple nvptx64-nvidia-cuda -stop-after machine-cp -o - < %s 2>&1 | FileCheck %s + +; Check that convergent calls are emitted using convergent MIR instructions, +; while non-convergent calls are not. + +target triple = "nvptx64-nvidia-cuda" + +declare void @conv() convergent +declare void @not_conv() + +define void @test(void ()* %f) { + ; CHECK: ConvergentCallUniPrintCall + ; CHECK-NEXT: @conv + call void @conv() + + ; CHECK: CallUniPrintCall + ; CHECK-NEXT: @not_conv + call void @not_conv() + + ; CHECK: ConvergentCallPrintCall + call void %f() convergent + + ; CHECK: CallPrintCall + call void %f() + + ret void +} |

