diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 91 |
4 files changed, 61 insertions, 52 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index d501a916ee8..a2e7eca127d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5562,9 +5562,11 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee, isTailCall = false; TargetLowering::CallLoweringInfo CLI(DAG); - CLI.setDebugLoc(getCurSDLoc()).setChain(getRoot()) - .setCallee(RetTy, FTy, Callee, std::move(Args), CS) - .setTailCall(isTailCall); + CLI.setDebugLoc(getCurSDLoc()) + .setChain(getRoot()) + .setCallee(RetTy, FTy, Callee, std::move(Args), CS) + .setTailCall(isTailCall) + .setConvergent(CS.isConvergent()); std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB); if (Result.first.getNode()) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index c6263ca7317..592a269d1a0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -314,8 +314,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::DeclareRetParam"; case NVPTXISD::PrintCall: return "NVPTXISD::PrintCall"; + case NVPTXISD::PrintConvergentCall: + return "NVPTXISD::PrintConvergentCall"; case NVPTXISD::PrintCallUni: return "NVPTXISD::PrintCallUni"; + case NVPTXISD::PrintConvergentCallUni: + return "NVPTXISD::PrintConvergentCallUni"; case NVPTXISD::LoadParam: return "NVPTXISD::LoadParam"; case NVPTXISD::LoadParamV2: @@ -1439,8 +1443,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SDValue PrintCallOps[] = { Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, dl, MVT::i32), InFlag }; - Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall), - dl, PrintCallVTs, PrintCallOps); + // We model convergent calls as separate opcodes. + unsigned Opcode = Func ? NVPTXISD::PrintCallUni : NVPTXISD::PrintCall; + if (CLI.IsConvergent) + Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni + : NVPTXISD::PrintConvergentCall; + Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps); InFlag = Chain.getValue(1); // Ops to print out the function name diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 60914c1d09b..735cd01ced6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -34,7 +34,9 @@ enum NodeType : unsigned { DeclareRet, DeclareScalarRet, PrintCall, + PrintConvergentCall, PrintCallUni, + PrintConvergentCallUni, CallArgBegin, CallArg, LastCallArg, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 51db8246c53..685d1b447b9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1701,9 +1701,15 @@ def LoadParamV4 : def PrintCall : SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCall : + SDNode<"NVPTXISD::PrintConvergentCall", SDTPrintCallProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def PrintCallUni : SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def PrintConvergentCallUni : + SDNode<"NVPTXISD::PrintConvergentCallUni", SDTPrintCallUniProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParam : SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; @@ -1821,53 +1827,44 @@ class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> : []>; let isCall=1 in { - def PrintCallNoRetInst : NVPTXInst<(outs), (ins), - "call ", [(PrintCall (i32 0))]>; - def PrintCallRetInst1 : NVPTXInst<(outs), (ins), - "call (retval0), ", [(PrintCall (i32 1))]>; - def PrintCallRetInst2 : NVPTXInst<(outs), (ins), - "call (retval0, retval1), ", [(PrintCall (i32 2))]>; - def PrintCallRetInst3 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2), ", [(PrintCall (i32 3))]>; - def PrintCallRetInst4 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3), ", [(PrintCall (i32 4))]>; - def PrintCallRetInst5 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCall (i32 5))]>; - def PrintCallRetInst6 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCall (i32 6))]>; - def PrintCallRetInst7 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCall (i32 7))]>; - def PrintCallRetInst8 : NVPTXInst<(outs), (ins), - "call (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCall (i32 8))]>; - - def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins), - "call.uni ", [(PrintCallUni (i32 0))]>; - def PrintCallUniRetInst1 : NVPTXInst<(outs), (ins), - "call.uni (retval0), ", [(PrintCallUni (i32 1))]>; - def PrintCallUniRetInst2 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1), ", [(PrintCallUni (i32 2))]>; - def PrintCallUniRetInst3 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2), ", [(PrintCallUni (i32 3))]>; - def PrintCallUniRetInst4 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3), ", [(PrintCallUni (i32 4))]>; - def PrintCallUniRetInst5 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4), ", - [(PrintCallUni (i32 5))]>; - def PrintCallUniRetInst6 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5), ", - [(PrintCallUni (i32 6))]>; - def PrintCallUniRetInst7 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6), ", - [(PrintCallUni (i32 7))]>; - def PrintCallUniRetInst8 : NVPTXInst<(outs), (ins), - "call.uni (retval0, retval1, retval2, retval3, retval4, retval5, retval6, " - "retval7), ", - [(PrintCallUni (i32 8))]>; + multiclass CALL<string OpcStr, SDNode OpNode> { + def PrintCallNoRetInst : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " "), [(OpNode (i32 0))]>; + def PrintCallRetInst1 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0), "), [(OpNode (i32 1))]>; + def PrintCallRetInst2 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1), "), [(OpNode (i32 2))]>; + def PrintCallRetInst3 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2), "), [(OpNode (i32 3))]>; + def PrintCallRetInst4 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3), "), + [(OpNode (i32 4))]>; + def PrintCallRetInst5 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4), "), + [(OpNode (i32 5))]>; + def PrintCallRetInst6 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5), "), + [(OpNode (i32 6))]>; + def PrintCallRetInst7 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6), "), + [(OpNode (i32 7))]>; + def PrintCallRetInst8 : NVPTXInst<(outs), (ins), + !strconcat(OpcStr, " (retval0, retval1, retval2, retval3, retval4, " + "retval5, retval6, retval7), "), + [(OpNode (i32 8))]>; + } +} + +defm Call : CALL<"call", PrintCall>; +defm CallUni : CALL<"call.uni", PrintCallUni>; + +// Convergent call instructions. These are identical to regular calls, except +// they have the isConvergent bit set. +let isConvergent=1 in { + defm ConvergentCall : CALL<"call", PrintConvergentCall>; + defm ConvergentCallUni : CALL<"call.uni", PrintConvergentCallUni>; } def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">; |