diff options
Diffstat (limited to 'llvm/utils/TableGen/GlobalISelEmitter.cpp')
-rw-r--r-- | llvm/utils/TableGen/GlobalISelEmitter.cpp | 51 |
1 files changed, 35 insertions, 16 deletions
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp index 238a50a94c6..0da1b86021a 100644 --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -33,6 +33,7 @@ #include "CodeGenDAGPatterns.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/LowLevelType.h" #include "llvm/CodeGen/MachineValueType.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" @@ -58,22 +59,38 @@ static cl::opt<bool> WarnOnSkippedPatterns( //===- Helper functions ---------------------------------------------------===// +/// This class stands in for LLT wherever we want to tablegen-erate an +/// equivalent at compiler run-time. +class LLTCodeGen { +private: + LLT Ty; + +public: + LLTCodeGen(const LLT &Ty) : Ty(Ty) {} + + void emitCxxConstructorCall(raw_ostream &OS) const { + if (Ty.isScalar()) { + OS << "LLT::scalar(" << Ty.getSizeInBits() << ")"; + return; + } + if (Ty.isVector()) { + OS << "LLT::vector(" << Ty.getNumElements() << ", " << Ty.getSizeInBits() + << ")"; + return; + } + llvm_unreachable("Unhandled LLT"); + } +}; + /// Convert an MVT to an equivalent LLT if possible, or the invalid LLT() for /// MVTs that don't map cleanly to an LLT (e.g., iPTR, *any, ...). -static Optional<std::string> MVTToLLT(MVT::SimpleValueType SVT) { - std::string TyStr; - raw_string_ostream OS(TyStr); +static Optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT) { MVT VT(SVT); - if (VT.isVector() && VT.getVectorNumElements() != 1) { - OS << "LLT::vector(" << VT.getVectorNumElements() << ", " - << VT.getScalarSizeInBits() << ")"; - } else if (VT.isInteger() || VT.isFloatingPoint()) { - OS << "LLT::scalar(" << VT.getSizeInBits() << ")"; - } else { - return None; - } - OS.flush(); - return TyStr; + if (VT.isVector() && VT.getVectorNumElements() != 1) + return LLTCodeGen(LLT::vector(VT.getVectorNumElements(), VT.getScalarSizeInBits())); + if (VT.isInteger() || VT.isFloatingPoint()) + return LLTCodeGen(LLT::scalar(VT.getSizeInBits())); + return None; } static bool isTrivialOperatorNode(const TreePatternNode *N) { @@ -167,10 +184,10 @@ public: /// Generates code to check that an operand is a particular LLT. class LLTOperandMatcher : public OperandPredicateMatcher { protected: - std::string Ty; + LLTCodeGen Ty; public: - LLTOperandMatcher(std::string Ty) + LLTOperandMatcher(const LLTCodeGen &Ty) : OperandPredicateMatcher(OPM_LLT), Ty(Ty) {} static bool classof(const OperandPredicateMatcher *P) { @@ -179,7 +196,9 @@ public: void emitCxxPredicateExpr(raw_ostream &OS, StringRef OperandExpr) const override { - OS << "MRI.getType(" << OperandExpr << ".getReg()) == (" << Ty << ")"; + OS << "MRI.getType(" << OperandExpr << ".getReg()) == ("; + Ty.emitCxxConstructorCall(OS); + OS << ")"; } }; |