diff options
| author | Chris Lattner <clattner@google.com> | 2018-10-25 22:04:35 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:40:35 -0700 |
| commit | 967d934180dd7f691132e1e625aad29db788a5f1 (patch) | |
| tree | a65f6917c39ac257521d0d92dc9411c0c44c45fd /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | |
| parent | 988ce3387f61174d763edff5014d8c1d19627d35 (diff) | |
| download | bcm5719-llvm-967d934180dd7f691132e1e625aad29db788a5f1.tar.gz bcm5719-llvm-967d934180dd7f691132e1e625aad29db788a5f1.zip | |
Fix two issues:
1) We incorrectly reassociated non-reassociative operations like subi, causing
miscompilations.
2) When constant folding, we didn't add users of the new constant back to the
worklist for reprocessing, causing us to miss some cases (pointed out by
Uday).
The code for tensorflow/mlir#2 is gross, but I'll add the new APIs in a followup patch.
PiperOrigin-RevId: 218803984
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index ebad9e20316..30034b6fce5 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -202,6 +202,28 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, else cstValue = rewriter.create<ConstantOp>( op->getLoc(), resultConstants[i], res->getType()); + + // Add all the users of the result to the worklist so we make sure to + // revisit them. + // + // TODO: This is super gross. SSAValue use iterators should have an + // "owner" that can be downcasted to operation and other things. This + // will require a rejiggering of the class hierarchies. + if (auto *stmt = dyn_cast<OperationStmt>(op)) { + // TODO: Add a result->getUsers() iterator. + for (auto &operand : stmt->getResult(i)->getUses()) { + if (auto *op = dyn_cast<OperationStmt>(operand.getOwner())) + addToWorklist(op); + } + } else { + auto *inst = cast<OperationInst>(op); + // TODO: Add a result->getUsers() iterator. + for (auto &operand : inst->getResult(i)->getUses()) { + if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) + addToWorklist(op); + } + } + res->replaceAllUsesWith(cstValue); } @@ -210,10 +232,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, continue; } - // If this is an associative binary operation with a constant on the LHS, - // move it to the right side. + // If this is a commutative binary operation with a constant on the left + // side move it to the right side. if (operandConstants.size() == 2 && operandConstants[0] && - !operandConstants[1]) { + !operandConstants[1] && op->isCommutative()) { auto *newLHS = op->getOperand(1); op->setOperand(1, op->getOperand(0)); op->setOperand(0, newLHS); |

