summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-13 12:21:42 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-13 16:47:26 -0800
commitb030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch)
treefa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib
parent7b19d736172789ce8e5ca10ae6276302004533f0 (diff)
downloadbcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz
bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter. PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp8
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp8
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp10
-rw-r--r--mlir/lib/IR/Builders.cpp80
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp43
5 files changed, 128 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 96a1a68889c..22d4ec10dd0 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -99,6 +99,14 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
/// A utility function to check if a given region is attached to a function.
static bool isFunctionRegion(Region *region) {
return llvm::isa<FuncOp>(region->getParentOp());
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 531be29666a..713546fc40d 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -163,6 +163,14 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
addInterfaces<StdInlinerInterface>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index a2345fe1c40..ae5579d9e3d 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -40,7 +40,7 @@ using namespace mlir::vector;
// VectorOpsDialect
//===----------------------------------------------------------------------===//
-mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
@@ -48,6 +48,14 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8c54df4d55b..691b2ad99c4 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -18,12 +18,13 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
Builder::Builder(ModuleOp module) : context(module.getContext()) {}
@@ -339,27 +340,68 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
}
/// Attempts to fold the given operation and places new results within
-/// 'results'.
-void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+/// 'results'. Returns success if the operation was folded, failure otherwise.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult OpBuilder::tryFold(Operation *op,
+ SmallVectorImpl<Value *> &results) {
results.reserve(op->getNumResults());
- SmallVector<OpFoldResult, 4> foldResults;
-
- // Returns if the given fold result corresponds to a valid existing value.
- auto isValidValue = [](OpFoldResult result) {
- return result.dyn_cast<Value *>();
+ auto cleanupFailure = [&] {
+ results.assign(op->result_begin(), op->result_end());
+ return failure();
};
- // Check if the fold failed, or did not result in only existing values.
+ // If this operation is already a constant, there is nothing to do.
+ Attribute unused;
+ if (matchPattern(op, m_Constant(&unused)))
+ return cleanupFailure();
+
+ // Check to see if any operands to the operation is constant and whether
+ // the operation knows how to constant fold itself.
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
- if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
- !llvm::all_of(foldResults, isValidValue)) {
- // Simply return the existing operation results.
- results.assign(op->result_begin(), op->result_end());
- return;
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
+
+ // Try to fold the operation.
+ SmallVector<OpFoldResult, 4> foldResults;
+ if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
+ return cleanupFailure();
+
+ // A temporary builder used for creating constants during folding.
+ OpBuilder cstBuilder(context);
+ SmallVector<Operation *, 1> generatedConstants;
+
+ // Populate the results with the folded results.
+ Dialect *dialect = op->getDialect();
+ for (auto &it : llvm::enumerate(foldResults)) {
+ // Normal values get pushed back directly.
+ if (auto *value = it.value().dyn_cast<Value *>()) {
+ results.push_back(value);
+ continue;
+ }
+
+ // Otherwise, try to materialize a constant operation.
+ if (!dialect)
+ return cleanupFailure();
+
+ // Ask the dialect to materialize a constant operation for this value.
+ Attribute attr = it.value().get<Attribute>();
+ auto *constOp = dialect->materializeConstant(
+ cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc());
+ if (!constOp) {
+ // Erase any generated constants.
+ for (Operation *cst : generatedConstants)
+ cst->erase();
+ return cleanupFailure();
+ }
+ assert(matchPattern(constOp, m_Constant(&attr)));
+
+ generatedConstants.push_back(constOp);
+ results.push_back(constOp->getResult(0));
}
- // Populate the results with the folded results and remove the original op.
- llvm::transform(foldResults, std::back_inserter(results),
- [](OpFoldResult result) { return result.get<Value *>(); });
- op->erase();
+ // If we were successful, insert any generated constants.
+ for (Operation *cst : generatedConstants)
+ insert(cst);
+
+ return success();
}
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ea4ad681693..ac13bc2ba5b 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -25,7 +25,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
@@ -938,6 +937,10 @@ public:
ConversionTarget &getTarget() { return target; }
private:
+ /// Attempt to legalize the given operation by folding it.
+ LogicalResult legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter);
+
/// Attempt to legalize the given operation by applying the provided pattern.
/// Returns success if the operation was legalized, failure otherwise.
LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
@@ -1003,6 +1006,14 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation isn't legal, try to fold it in-place.
+ // TODO(riverriddle) Should we always try to do this, even if the op is
+ // already legal?
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n");
+ return success();
+ }
+
// Otherwise, we need to apply a legalization pattern to this operation.
auto it = legalizerPatterns.find(op->getName());
if (it == legalizerPatterns.end()) {
@@ -1020,6 +1031,36 @@ OperationLegalizer::legalize(Operation *op,
}
LogicalResult
+OperationLegalizer::legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
+
+ // Try to fold the operation.
+ SmallVector<Value *, 2> replacementValues;
+ rewriter.setInsertionPoint(op);
+ if (failed(rewriter.tryFold(op, replacementValues)))
+ return failure();
+
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
+ // Recursively legalize any new constant operations.
+ for (unsigned i = curState.numCreatedOperations,
+ e = rewriterImpl.createdOps.size();
+ i != e; ++i) {
+ Operation *cstOp = rewriterImpl.createdOps[i];
+ if (failed(legalize(cstOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '"
+ << cstOp->getName() << "' was illegal.\n");
+ rewriterImpl.resetState(curState);
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
OpenPOWER on IntegriCloud