summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
authorChris Lattner <clattner@google.com>2018-10-25 22:04:35 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 13:40:35 -0700
commit967d934180dd7f691132e1e625aad29db788a5f1 (patch)
treea65f6917c39ac257521d0d92dc9411c0c44c45fd /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
parent988ce3387f61174d763edff5014d8c1d19627d35 (diff)
downloadbcm5719-llvm-967d934180dd7f691132e1e625aad29db788a5f1.tar.gz
bcm5719-llvm-967d934180dd7f691132e1e625aad29db788a5f1.zip
Fix two issues:
1) We incorrectly reassociated non-reassociative operations like subi, causing miscompilations. 2) When constant folding, we didn't add users of the new constant back to the worklist for reprocessing, causing us to miss some cases (pointed out by Uday). The code for tensorflow/mlir#2 is gross, but I'll add the new APIs in a followup patch. PiperOrigin-RevId: 218803984
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp28
1 files changed, 25 insertions, 3 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index ebad9e20316..30034b6fce5 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -202,6 +202,28 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
else
cstValue = rewriter.create<ConstantOp>(
op->getLoc(), resultConstants[i], res->getType());
+
+ // Add all the users of the result to the worklist so we make sure to
+ // revisit them.
+ //
+ // TODO: This is super gross. SSAValue use iterators should have an
+ // "owner" that can be downcasted to operation and other things. This
+ // will require a rejiggering of the class hierarchies.
+ if (auto *stmt = dyn_cast<OperationStmt>(op)) {
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &operand : stmt->getResult(i)->getUses()) {
+ if (auto *op = dyn_cast<OperationStmt>(operand.getOwner()))
+ addToWorklist(op);
+ }
+ } else {
+ auto *inst = cast<OperationInst>(op);
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &operand : inst->getResult(i)->getUses()) {
+ if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
+ addToWorklist(op);
+ }
+ }
+
res->replaceAllUsesWith(cstValue);
}
@@ -210,10 +232,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
continue;
}
- // If this is an associative binary operation with a constant on the LHS,
- // move it to the right side.
+ // If this is a commutative binary operation with a constant on the left
+ // side move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
- !operandConstants[1]) {
+ !operandConstants[1] && op->isCommutative()) {
auto *newLHS = op->getOperand(1);
op->setOperand(1, op->getOperand(0));
op->setOperand(0, newLHS);
OpenPOWER on IntegriCloud