diff options
Diffstat (limited to 'mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp')
-rw-r--r-- | mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp new file mode 100644 index 00000000000..517a1f07530 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -0,0 +1,104 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { +public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast<ShapeInference>(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa<RankedTensorType>(); + }); + } +}; +} // end anonymous namespace + +/// Create a Shape Inference pass. +std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() { + return std::make_unique<ShapeInferencePass>(); +} |