diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Transforms/ConstantFoldUtils.h | 9 | ||||
| -rw-r--r-- | mlir/lib/Transforms/TestConstantFold.cpp | 12 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp | 18 | ||||
| -rw-r--r-- | mlir/test/Transforms/constant-fold.mlir | 4 |
4 files changed, 17 insertions, 26 deletions
diff --git a/mlir/include/mlir/Transforms/ConstantFoldUtils.h b/mlir/include/mlir/Transforms/ConstantFoldUtils.h index 325f6be5699..d2309a3af69 100644 --- a/mlir/include/mlir/Transforms/ConstantFoldUtils.h +++ b/mlir/include/mlir/Transforms/ConstantFoldUtils.h @@ -42,14 +42,12 @@ class ConstantFoldHelper { public: /// Constructs an instance for managing constants in the given function `f`. /// Constants tracked by this instance will be moved to the entry block of - /// `f`. If `insertAtHead` is true, the insertion always happen at the very - /// top of the entry block; otherwise, the insertion happens after the last - /// one of consecutive constant ops at the beginning of the entry block. + /// `f`. The insertion always happens at the very top of the entry block. /// /// This instance does not proactively walk the operations inside `f`; /// instead, users must invoke the following methods to manually handle each /// operation of interest. - ConstantFoldHelper(Function *f, bool insertAtHead = true); + ConstantFoldHelper(Function *f); /// Tries to perform constant folding on the given `op`, including unifying /// deplicated constants. If successful, calls `preReplaceAction` (if @@ -82,9 +80,6 @@ private: /// The function where we are managing constant. Function *function; - /// Whether to always insert constants at the very top of the entry block. - bool isInsertAtHead; - /// This map keeps track of uniqued constants. DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants; }; diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp index 60407cd49d1..0990d7a73f6 100644 --- a/mlir/lib/Transforms/TestConstantFold.cpp +++ b/mlir/lib/Transforms/TestConstantFold.cpp @@ -61,10 +61,18 @@ void TestConstantFold::runOnFunction() { opsToErase.clear(); auto &f = getFunction(); + ConstantFoldHelper helper(&f); - ConstantFoldHelper helper(&f, /*insertAtHead=*/false); + // Collect and fold the operations within the function. + SmallVector<Operation *, 8> ops; + f.walk([&](Operation *op) { ops.push_back(op); }); - f.walk([&](Operation *op) { foldOperation(op, helper); }); + // Fold the constants in reverse so that the last generated constants from + // folding are at the beginning. This creates somewhat of a linear ordering to + // the newly generated constants that matches the operation order and improves + // the readability of test cases. + for (Operation *op : llvm::reverse(ops)) + foldOperation(op, helper); // At this point, these operations are dead, remove them. for (auto *op : opsToErase) { diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp index 5908ec251d0..fc8209be872 100644 --- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp @@ -29,8 +29,7 @@ using namespace mlir; -ConstantFoldHelper::ConstantFoldHelper(Function *f, bool insertAtHead) - : function(f), isInsertAtHead(insertAtHead) {} +ConstantFoldHelper::ConstantFoldHelper(Function *f) : function(f) {} bool ConstantFoldHelper::tryToConstantFold( Operation *op, std::function<void(Operation *)> preReplaceAction) { @@ -146,18 +145,7 @@ bool ConstantFoldHelper::tryToUnify(Operation *op) { } void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) { + // Insert at the very top of the entry block. auto &entryBB = function->front(); - if (isInsertAtHead || entryBB.empty()) { - // Insert at the very top of the entry block. - op->moveBefore(&entryBB, entryBB.begin()); - } else { - // TODO: This is only used by TestConstantFold and not very clean. We should - // figure out a better way to work around this. - - // Move to be ahead of the first non-constant op. - auto it = entryBB.begin(); - while (it != entryBB.end() && it->isa<ConstantOp>()) - ++it; - op->moveBefore(&entryBB, it); - } + op->moveBefore(&entryBB, entryBB.begin()); } diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 257323ba37d..c9a8903a17b 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -265,8 +265,8 @@ func @dim(%x : tensor<8x4xf32>) -> index { func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) { %c42 = constant 42 : i32 %cm1 = constant -1 : i32 -// CHECK-NEXT: %false = constant 0 : i1 -// CHECK-NEXT: %true = constant 1 : i1 +// CHECK-DAG: %false = constant 0 : i1 +// CHECK-DAG: %true = constant 1 : i1 // CHECK-NEXT: return %false, %0 = cmpi "eq", %c42, %cm1 : i32 // CHECK-SAME: %true, |

