summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-13 12:21:42 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-13 16:47:26 -0800
commitb030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch)
treefa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib/IR
parent7b19d736172789ce8e5ca10ae6276302004533f0 (diff)
downloadbcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz
bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter. PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/Builders.cpp80
1 files changed, 61 insertions, 19 deletions
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8c54df4d55b..691b2ad99c4 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -18,12 +18,13 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
Builder::Builder(ModuleOp module) : context(module.getContext()) {}
@@ -339,27 +340,68 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
}
/// Attempts to fold the given operation and places new results within
-/// 'results'.
-void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+/// 'results'. Returns success if the operation was folded, failure otherwise.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult OpBuilder::tryFold(Operation *op,
+ SmallVectorImpl<Value *> &results) {
results.reserve(op->getNumResults());
- SmallVector<OpFoldResult, 4> foldResults;
-
- // Returns if the given fold result corresponds to a valid existing value.
- auto isValidValue = [](OpFoldResult result) {
- return result.dyn_cast<Value *>();
+ auto cleanupFailure = [&] {
+ results.assign(op->result_begin(), op->result_end());
+ return failure();
};
- // Check if the fold failed, or did not result in only existing values.
+ // If this operation is already a constant, there is nothing to do.
+ Attribute unused;
+ if (matchPattern(op, m_Constant(&unused)))
+ return cleanupFailure();
+
+ // Check to see if any operands to the operation is constant and whether
+ // the operation knows how to constant fold itself.
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
- if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
- !llvm::all_of(foldResults, isValidValue)) {
- // Simply return the existing operation results.
- results.assign(op->result_begin(), op->result_end());
- return;
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
+
+ // Try to fold the operation.
+ SmallVector<OpFoldResult, 4> foldResults;
+ if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
+ return cleanupFailure();
+
+ // A temporary builder used for creating constants during folding.
+ OpBuilder cstBuilder(context);
+ SmallVector<Operation *, 1> generatedConstants;
+
+ // Populate the results with the folded results.
+ Dialect *dialect = op->getDialect();
+ for (auto &it : llvm::enumerate(foldResults)) {
+ // Normal values get pushed back directly.
+ if (auto *value = it.value().dyn_cast<Value *>()) {
+ results.push_back(value);
+ continue;
+ }
+
+ // Otherwise, try to materialize a constant operation.
+ if (!dialect)
+ return cleanupFailure();
+
+ // Ask the dialect to materialize a constant operation for this value.
+ Attribute attr = it.value().get<Attribute>();
+ auto *constOp = dialect->materializeConstant(
+ cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc());
+ if (!constOp) {
+ // Erase any generated constants.
+ for (Operation *cst : generatedConstants)
+ cst->erase();
+ return cleanupFailure();
+ }
+ assert(matchPattern(constOp, m_Constant(&attr)));
+
+ generatedConstants.push_back(constOp);
+ results.push_back(constOp->getResult(0));
}
- // Populate the results with the folded results and remove the original op.
- llvm::transform(foldResults, std::back_inserter(results),
- [](OpFoldResult result) { return result.get<Value *>(); });
- op->erase();
+ // If we were successful, insert any generated constants.
+ for (Operation *cst : generatedConstants)
+ insert(cst);
+
+ return success();
}
OpenPOWER on IntegriCloud