summaryrefslogtreecommitdiffstats
path: root/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp')
-rw-r--r--mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp104
1 files changed, 104 insertions, 0 deletions
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
new file mode 100644
index 00000000000..517a1f07530
--- /dev/null
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -0,0 +1,104 @@
+//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a Function level pass performing interprocedural
+// propagation of array shapes through function specialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
+#include "toy/ShapeInferenceInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "shape-inference"
+
+using namespace mlir;
+using namespace toy;
+
+/// Include the auto-generated definitions for the shape inference interfaces.
+#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+
+namespace {
+/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
+/// shape inference.
+///
+/// Algorithm:
+///
+/// 1) Build a worklist containing all the operations that return a
+/// dynamically shaped tensor: these are the operations that need shape
+/// inference.
+/// 2) Iterate on the worklist:
+/// a) find an operation to process: the next ready operation in the
+/// worklist has all of its arguments non-generic,
+/// b) if no operation is found, break out of the loop,
+/// c) remove the operation from the worklist,
+/// d) infer the shape of its output from the argument types.
+/// 3) If the worklist is empty, the algorithm succeeded.
+///
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
+public:
+ void runOnFunction() override {
+ auto f = getFunction();
+
+ // Populate the worklist with the operations that need shape inference:
+ // these are operations that return a dynamic shape.
+ llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
+ f.walk([&](mlir::Operation *op) {
+ if (returnsDynamicShape(op))
+ opWorklist.insert(op);
+ });
+
+ // Iterate on the operations in the worklist until all operations have been
+ // inferred or no change happened (fix point).
+ while (!opWorklist.empty()) {
+ // Find the next operation ready for inference, that is an operation
+ // with all operands already resolved (non-generic).
+ auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+ if (nextop == opWorklist.end())
+ break;
+
+ Operation *op = *nextop;
+ opWorklist.erase(op);
+
+ // Ask the operation to infer its output shapes.
+ LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+ if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
+ shapeOp.inferShapes();
+ } else {
+ op->emitError("unable to infer shape of operation without shape "
+ "inference interface");
+ return signalPassFailure();
+ }
+ }
+
+ // If the operation worklist isn't empty, this indicates a failure.
+ if (!opWorklist.empty()) {
+ f.emitError("Shape inference failed, ")
+ << opWorklist.size() << " operations couldn't be inferred\n";
+ signalPassFailure();
+ }
+ }
+
+ /// A utility method that returns if the given operation has a dynamically
+ /// shaped result.
+ static bool returnsDynamicShape(Operation *op) {
+ return llvm::any_of(op->getResultTypes(), [](Type resultType) {
+ return !resultType.isa<RankedTensorType>();
+ });
+ }
+};
+} // end anonymous namespace
+
+/// Create a Shape Inference pass.
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
+ return std::make_unique<ShapeInferencePass>();
+}
OpenPOWER on IntegriCloud