summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/StandardOps/StandardOps.h2
-rw-r--r--mlir/lib/StandardOps/StandardOps.cpp44
-rw-r--r--mlir/test/Transforms/constant-fold.mlir26
3 files changed, 72 insertions, 0 deletions
diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h
index 4d8867ffd4d..bb368bd2ef8 100644
--- a/mlir/include/mlir/StandardOps/StandardOps.h
+++ b/mlir/include/mlir/StandardOps/StandardOps.h
@@ -199,6 +199,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
+ Attribute constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) const;
private:
friend class OperationInst;
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index f0ce38aadde..3a12a9c4b8d 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -571,6 +571,50 @@ bool CmpIOp::verify() const {
return false;
}
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
+// comparison predicates.
+static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
+ const APInt &rhs) {
+ switch (predicate) {
+ case CmpIPredicate::EQ:
+ return lhs.eq(rhs);
+ case CmpIPredicate::NE:
+ return lhs.ne(rhs);
+ case CmpIPredicate::SLT:
+ return lhs.slt(rhs);
+ case CmpIPredicate::SLE:
+ return lhs.sle(rhs);
+ case CmpIPredicate::SGT:
+ return lhs.sgt(rhs);
+ case CmpIPredicate::SGE:
+ return lhs.sge(rhs);
+ case CmpIPredicate::ULT:
+ return lhs.ult(rhs);
+ case CmpIPredicate::ULE:
+ return lhs.ule(rhs);
+ case CmpIPredicate::UGT:
+ return lhs.ugt(rhs);
+ case CmpIPredicate::UGE:
+ return lhs.uge(rhs);
+ default:
+ llvm_unreachable("unknown comparison predicate");
+ }
+}
+
+// Constant folding hook for comparisons.
+Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) const {
+ assert(operands.size() == 2 && "cmpi takes two arguments");
+
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs || !rhs)
+ return {};
+
+ auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+ return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+}
+
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index be65cc44d97..78f65c17b52 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -227,3 +227,29 @@ func @dim(%x : tensor<8x4xf32>) -> index {
return %0 : index
}
+// CHECK-LABEL: func @cmpi
+func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+ %c42 = constant 42 : i32
+ %cm1 = constant -1 : i32
+// CHECK-NEXT: %false = constant 0 : i1
+ %0 = cmpi "eq", %c42, %cm1 : i32
+// CHECK-NEXT: %true = constant 1 : i1
+ %1 = cmpi "ne", %c42, %cm1 : i32
+// CHECK-NEXT: %false_0 = constant 0 : i1
+ %2 = cmpi "slt", %c42, %cm1 : i32
+// CHECK-NEXT: %false_1 = constant 0 : i1
+ %3 = cmpi "sle", %c42, %cm1 : i32
+// CHECK-NEXT: %true_2 = constant 1 : i1
+ %4 = cmpi "sgt", %c42, %cm1 : i32
+// CHECK-NEXT: %true_3 = constant 1 : i1
+ %5 = cmpi "sge", %c42, %cm1 : i32
+// CHECK-NEXT: %true_4 = constant 1 : i1
+ %6 = cmpi "ult", %c42, %cm1 : i32
+// CHECK-NEXT: %true_5 = constant 1 : i1
+ %7 = cmpi "ule", %c42, %cm1 : i32
+// CHECK-NEXT: %false_6 = constant 0 : i1
+ %8 = cmpi "ugt", %c42, %cm1 : i32
+// CHECK-NEXT: %false_7 = constant 0 : i1
+ %9 = cmpi "uge", %c42, %cm1 : i32
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
OpenPOWER on IntegriCloud