diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-11-15 10:16:33 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-15 10:17:31 -0800 |
| commit | a0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0 (patch) | |
| tree | 7abd9adc08ea28818c0fde00974faea2544e20ce /mlir/lib/Dialect/StandardOps | |
| parent | 9d7039b001d6f454d0b6712e0ae31b1d0019adb8 (diff) | |
| download | bcm5719-llvm-a0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0.tar.gz bcm5719-llvm-a0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0.zip | |
NFC: Convert CmpIPredicate in StandardOps to use EnumAttr
This turns several hand-written functions to auto-generated ones.
PiperOrigin-RevId: 280684326
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 89 |
1 files changed, 19 insertions, 70 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index bf0cb75b8bc..c4abee3858e 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -35,6 +35,9 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc" + using namespace mlir; //===----------------------------------------------------------------------===// @@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) { // CmpIOp //===----------------------------------------------------------------------===// -// Returns an array of mnemonics for CmpIPredicates indexed by values thereof. -static inline const char *const *getCmpIPredicateNames() { - static const char *predicateNames[]{ - /*EQ*/ "eq", - /*NE*/ "ne", - /*SLT*/ "slt", - /*SLE*/ "sle", - /*SGT*/ "sgt", - /*SGE*/ "sge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UGT*/ "ugt", - /*UGE*/ "uge", - }; - static_assert(std::extent<decltype(predicateNames)>::value == - (size_t)CmpIPredicate::NumPredicates, - "wrong number of predicate names"); - return predicateNames; -} - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. -CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch<CmpIPredicate>(name) - .Case("eq", CmpIPredicate::EQ) - .Case("ne", CmpIPredicate::NE) - .Case("slt", CmpIPredicate::SLT) - .Case("sle", CmpIPredicate::SLE) - .Case("sgt", CmpIPredicate::SGT) - .Case("sge", CmpIPredicate::SGE) - .Case("ult", CmpIPredicate::ULT) - .Case("ule", CmpIPredicate::ULE) - .Case("ugt", CmpIPredicate::UGT) - .Case("uge", CmpIPredicate::UGE) - .Default(CmpIPredicate::NumPredicates); -} - static void buildCmpIOp(Builder *build, OperationState &result, CmpIPredicate predicate, Value *lhs, Value *rhs) { result.addOperands({lhs, rhs}); @@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); - auto predicate = CmpIOp::getPredicateByName(predicateName); - if (predicate == CmpIPredicate::NumPredicates) + Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName); + if (!predicate.hasValue()) return parser.emitError(parser.getNameLoc()) << "unknown comparison predicate \"" << predicateName << "\""; @@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); - attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); + attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate)); result.attributes = attrs; result.addTypes({i1Type}); @@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, CmpIOp op) { p << "cmpi "; + Builder b(op.getContext()); auto predicateValue = op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) && - predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) && - "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpIPredicateNames()[predicateValue]); - p.printAttribute(predicateStringAttr); + p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue)) + << '"'; p << ", "; p.printOperand(op.lhs()); @@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) { p << " : " << op.lhs()->getType(); } -static LogicalResult verify(CmpIOp op) { - auto predicateAttr = - op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()); - if (!predicateAttr) - return op.emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpIPredicate::FirstValidValue || - predicate >= (int64_t)CmpIPredicate::NumPredicates) - return op.emitOpError("'predicate' attribute value out of range"); - - return success(); -} - // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { switch (predicate) { - case CmpIPredicate::EQ: + case CmpIPredicate::eq: return lhs.eq(rhs); - case CmpIPredicate::NE: + case CmpIPredicate::ne: return lhs.ne(rhs); - case CmpIPredicate::SLT: + case CmpIPredicate::slt: return lhs.slt(rhs); - case CmpIPredicate::SLE: + case CmpIPredicate::sle: return lhs.sle(rhs); - case CmpIPredicate::SGT: + case CmpIPredicate::sgt: return lhs.sgt(rhs); - case CmpIPredicate::SGE: + case CmpIPredicate::sge: return lhs.sge(rhs); - case CmpIPredicate::ULT: + case CmpIPredicate::ult: return lhs.ult(rhs); - case CmpIPredicate::ULE: + case CmpIPredicate::ule: return lhs.ule(rhs); - case CmpIPredicate::UGT: + case CmpIPredicate::ugt: return lhs.ugt(rhs); - case CmpIPredicate::UGE: + case CmpIPredicate::uge: return lhs.uge(rhs); default: llvm_unreachable("unknown comparison predicate"); |

