summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/StandardOps
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-11-15 10:16:33 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-15 10:17:31 -0800
commita0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0 (patch)
tree7abd9adc08ea28818c0fde00974faea2544e20ce /mlir/lib/Dialect/StandardOps
parent9d7039b001d6f454d0b6712e0ae31b1d0019adb8 (diff)
downloadbcm5719-llvm-a0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0.tar.gz
bcm5719-llvm-a0986bf43d8287175b93e8dd6e2da2d1dc7ac7f0.zip
NFC: Convert CmpIPredicate in StandardOps to use EnumAttr
This turns several hand-written functions to auto-generated ones. PiperOrigin-RevId: 280684326
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp89
1 files changed, 19 insertions, 70 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index bf0cb75b8bc..c4abee3858e 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -35,6 +35,9 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) {
// CmpIOp
//===----------------------------------------------------------------------===//
-// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
-static inline const char *const *getCmpIPredicateNames() {
- static const char *predicateNames[]{
- /*EQ*/ "eq",
- /*NE*/ "ne",
- /*SLT*/ "slt",
- /*SLE*/ "sle",
- /*SGT*/ "sgt",
- /*SGE*/ "sge",
- /*ULT*/ "ult",
- /*ULE*/ "ule",
- /*UGT*/ "ugt",
- /*UGE*/ "uge",
- };
- static_assert(std::extent<decltype(predicateNames)>::value ==
- (size_t)CmpIPredicate::NumPredicates,
- "wrong number of predicate names");
- return predicateNames;
-}
-
-// Returns a value of the predicate corresponding to the given mnemonic.
-// Returns NumPredicates (one-past-end) if there is no such mnemonic.
-CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
- return llvm::StringSwitch<CmpIPredicate>(name)
- .Case("eq", CmpIPredicate::EQ)
- .Case("ne", CmpIPredicate::NE)
- .Case("slt", CmpIPredicate::SLT)
- .Case("sle", CmpIPredicate::SLE)
- .Case("sgt", CmpIPredicate::SGT)
- .Case("sge", CmpIPredicate::SGE)
- .Case("ult", CmpIPredicate::ULT)
- .Case("ule", CmpIPredicate::ULE)
- .Case("ugt", CmpIPredicate::UGT)
- .Case("uge", CmpIPredicate::UGE)
- .Default(CmpIPredicate::NumPredicates);
-}
-
static void buildCmpIOp(Builder *build, OperationState &result,
CmpIPredicate predicate, Value *lhs, Value *rhs) {
result.addOperands({lhs, rhs});
@@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
// Rewrite string attribute to an enum value.
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
- auto predicate = CmpIOp::getPredicateByName(predicateName);
- if (predicate == CmpIPredicate::NumPredicates)
+ Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
+ if (!predicate.hasValue())
return parser.emitError(parser.getNameLoc())
<< "unknown comparison predicate \"" << predicateName << "\"";
@@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
- attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+ attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
result.attributes = attrs;
result.addTypes({i1Type});
@@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, CmpIOp op) {
p << "cmpi ";
+ Builder b(op.getContext());
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
- assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
- predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
- "unknown predicate index");
- Builder b(op.getContext());
- auto predicateStringAttr =
- b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
- p.printAttribute(predicateStringAttr);
+ p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
+ << '"';
p << ", ";
p.printOperand(op.lhs());
@@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
p << " : " << op.lhs()->getType();
}
-static LogicalResult verify(CmpIOp op) {
- auto predicateAttr =
- op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
- if (!predicateAttr)
- return op.emitOpError("requires an integer attribute named 'predicate'");
- auto predicate = predicateAttr.getInt();
- if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
- predicate >= (int64_t)CmpIPredicate::NumPredicates)
- return op.emitOpError("'predicate' attribute value out of range");
-
- return success();
-}
-
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs) {
switch (predicate) {
- case CmpIPredicate::EQ:
+ case CmpIPredicate::eq:
return lhs.eq(rhs);
- case CmpIPredicate::NE:
+ case CmpIPredicate::ne:
return lhs.ne(rhs);
- case CmpIPredicate::SLT:
+ case CmpIPredicate::slt:
return lhs.slt(rhs);
- case CmpIPredicate::SLE:
+ case CmpIPredicate::sle:
return lhs.sle(rhs);
- case CmpIPredicate::SGT:
+ case CmpIPredicate::sgt:
return lhs.sgt(rhs);
- case CmpIPredicate::SGE:
+ case CmpIPredicate::sge:
return lhs.sge(rhs);
- case CmpIPredicate::ULT:
+ case CmpIPredicate::ult:
return lhs.ult(rhs);
- case CmpIPredicate::ULE:
+ case CmpIPredicate::ule:
return lhs.ule(rhs);
- case CmpIPredicate::UGT:
+ case CmpIPredicate::ugt:
return lhs.ugt(rhs);
- case CmpIPredicate::UGE:
+ case CmpIPredicate::uge:
return lhs.uge(rhs);
default:
llvm_unreachable("unknown comparison predicate");
OpenPOWER on IntegriCloud