diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-13 12:21:42 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-13 16:47:26 -0800 |
| commit | b030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch) | |
| tree | fa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib/IR | |
| parent | 7b19d736172789ce8e5ca10ae6276302004533f0 (diff) | |
| download | bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip | |
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter.
PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 80 |
1 files changed, 61 insertions, 19 deletions
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8c54df4d55b..691b2ad99c4 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -18,12 +18,13 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/Functional.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; Builder::Builder(ModuleOp module) : context(module.getContext()) {} @@ -339,27 +340,68 @@ Operation *OpBuilder::createOperation(const OperationState &state) { } /// Attempts to fold the given operation and places new results within -/// 'results'. -void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) { +/// 'results'. Returns success if the operation was folded, failure otherwise. +/// Note: This function does not erase the operation on a successful fold. +LogicalResult OpBuilder::tryFold(Operation *op, + SmallVectorImpl<Value *> &results) { results.reserve(op->getNumResults()); - SmallVector<OpFoldResult, 4> foldResults; - - // Returns if the given fold result corresponds to a valid existing value. - auto isValidValue = [](OpFoldResult result) { - return result.dyn_cast<Value *>(); + auto cleanupFailure = [&] { + results.assign(op->result_begin(), op->result_end()); + return failure(); }; - // Check if the fold failed, or did not result in only existing values. + // If this operation is already a constant, there is nothing to do. + Attribute unused; + if (matchPattern(op, m_Constant(&unused))) + return cleanupFailure(); + + // Check to see if any operands to the operation is constant and whether + // the operation knows how to constant fold itself. SmallVector<Attribute, 4> constOperands(op->getNumOperands()); - if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() || - !llvm::all_of(foldResults, isValidValue)) { - // Simply return the existing operation results. - results.assign(op->result_begin(), op->result_end()); - return; + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constOperands[i])); + + // Try to fold the operation. + SmallVector<OpFoldResult, 4> foldResults; + if (failed(op->fold(constOperands, foldResults)) || foldResults.empty()) + return cleanupFailure(); + + // A temporary builder used for creating constants during folding. + OpBuilder cstBuilder(context); + SmallVector<Operation *, 1> generatedConstants; + + // Populate the results with the folded results. + Dialect *dialect = op->getDialect(); + for (auto &it : llvm::enumerate(foldResults)) { + // Normal values get pushed back directly. + if (auto *value = it.value().dyn_cast<Value *>()) { + results.push_back(value); + continue; + } + + // Otherwise, try to materialize a constant operation. + if (!dialect) + return cleanupFailure(); + + // Ask the dialect to materialize a constant operation for this value. + Attribute attr = it.value().get<Attribute>(); + auto *constOp = dialect->materializeConstant( + cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc()); + if (!constOp) { + // Erase any generated constants. + for (Operation *cst : generatedConstants) + cst->erase(); + return cleanupFailure(); + } + assert(matchPattern(constOp, m_Constant(&attr))); + + generatedConstants.push_back(constOp); + results.push_back(constOp->getResult(0)); } - // Populate the results with the folded results and remove the original op. - llvm::transform(foldResults, std::back_inserter(results), - [](OpFoldResult result) { return result.get<Value *>(); }); - op->erase(); + // If we were successful, insert any generated constants. + for (Operation *cst : generatedConstants) + insert(cst); + + return success(); } |

