summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Transforms/ConstantFoldUtils.h9
-rw-r--r--mlir/lib/Transforms/TestConstantFold.cpp12
-rw-r--r--mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp18
-rw-r--r--mlir/test/Transforms/constant-fold.mlir4
4 files changed, 17 insertions, 26 deletions
diff --git a/mlir/include/mlir/Transforms/ConstantFoldUtils.h b/mlir/include/mlir/Transforms/ConstantFoldUtils.h
index 325f6be5699..d2309a3af69 100644
--- a/mlir/include/mlir/Transforms/ConstantFoldUtils.h
+++ b/mlir/include/mlir/Transforms/ConstantFoldUtils.h
@@ -42,14 +42,12 @@ class ConstantFoldHelper {
public:
/// Constructs an instance for managing constants in the given function `f`.
/// Constants tracked by this instance will be moved to the entry block of
- /// `f`. If `insertAtHead` is true, the insertion always happen at the very
- /// top of the entry block; otherwise, the insertion happens after the last
- /// one of consecutive constant ops at the beginning of the entry block.
+ /// `f`. The insertion always happens at the very top of the entry block.
///
/// This instance does not proactively walk the operations inside `f`;
/// instead, users must invoke the following methods to manually handle each
/// operation of interest.
- ConstantFoldHelper(Function *f, bool insertAtHead = true);
+ ConstantFoldHelper(Function *f);
/// Tries to perform constant folding on the given `op`, including unifying
/// deplicated constants. If successful, calls `preReplaceAction` (if
@@ -82,9 +80,6 @@ private:
/// The function where we are managing constant.
Function *function;
- /// Whether to always insert constants at the very top of the entry block.
- bool isInsertAtHead;
-
/// This map keeps track of uniqued constants.
DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
};
diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp
index 60407cd49d1..0990d7a73f6 100644
--- a/mlir/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/lib/Transforms/TestConstantFold.cpp
@@ -61,10 +61,18 @@ void TestConstantFold::runOnFunction() {
opsToErase.clear();
auto &f = getFunction();
+ ConstantFoldHelper helper(&f);
- ConstantFoldHelper helper(&f, /*insertAtHead=*/false);
+ // Collect and fold the operations within the function.
+ SmallVector<Operation *, 8> ops;
+ f.walk([&](Operation *op) { ops.push_back(op); });
- f.walk([&](Operation *op) { foldOperation(op, helper); });
+ // Fold the constants in reverse so that the last generated constants from
+ // folding are at the beginning. This creates somewhat of a linear ordering to
+ // the newly generated constants that matches the operation order and improves
+ // the readability of test cases.
+ for (Operation *op : llvm::reverse(ops))
+ foldOperation(op, helper);
// At this point, these operations are dead, remove them.
for (auto *op : opsToErase) {
diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp
index 5908ec251d0..fc8209be872 100644
--- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp
@@ -29,8 +29,7 @@
using namespace mlir;
-ConstantFoldHelper::ConstantFoldHelper(Function *f, bool insertAtHead)
- : function(f), isInsertAtHead(insertAtHead) {}
+ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {}
bool ConstantFoldHelper::tryToConstantFold(
Operation *op, std::function<void(Operation *)> preReplaceAction) {
@@ -146,18 +145,7 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) {
}
void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
+ // Insert at the very top of the entry block.
auto &entryBB = function->front();
- if (isInsertAtHead || entryBB.empty()) {
- // Insert at the very top of the entry block.
- op->moveBefore(&entryBB, entryBB.begin());
- } else {
- // TODO: This is only used by TestConstantFold and not very clean. We should
- // figure out a better way to work around this.
-
- // Move to be ahead of the first non-constant op.
- auto it = entryBB.begin();
- while (it != entryBB.end() && it->isa<ConstantOp>())
- ++it;
- op->moveBefore(&entryBB, it);
- }
+ op->moveBefore(&entryBB, entryBB.begin());
}
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 257323ba37d..c9a8903a17b 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -265,8 +265,8 @@ func @dim(%x : tensor<8x4xf32>) -> index {
func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%c42 = constant 42 : i32
%cm1 = constant -1 : i32
-// CHECK-NEXT: %false = constant 0 : i1
-// CHECK-NEXT: %true = constant 1 : i1
+// CHECK-DAG: %false = constant 0 : i1
+// CHECK-DAG: %true = constant 1 : i1
// CHECK-NEXT: return %false,
%0 = cmpi "eq", %c42, %cm1 : i32
// CHECK-SAME: %true,
OpenPOWER on IntegriCloud