path: root/mlir/lib/Transforms
author     Mehdi Amini <aminim@google.com>   2019-12-24 02:47:41 +0000
committer  Mehdi Amini <aminim@google.com>   2019-12-24 02:47:41 +0000
commit     0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch)
tree       31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/lib/Transforms
parent     6f635f90929da9545dd696071a829a1a42f84b30 (diff)
parent     5b4a01d4a63cb66ab981e52548f940813393bf42 (diff)
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--  mlir/lib/Transforms/AffineDataCopyGeneration.cpp          268
-rw-r--r--  mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp     239
-rw-r--r--  mlir/lib/Transforms/CMakeLists.txt                         38
-rw-r--r--  mlir/lib/Transforms/CSE.cpp                               263
-rw-r--r--  mlir/lib/Transforms/Canonicalizer.cpp                      45
-rw-r--r--  mlir/lib/Transforms/DialectConversion.cpp                1846
-rw-r--r--  mlir/lib/Transforms/Inliner.cpp                           296
-rw-r--r--  mlir/lib/Transforms/LoopCoalescing.cpp                     96
-rw-r--r--  mlir/lib/Transforms/LoopFusion.cpp                       1979
-rw-r--r--  mlir/lib/Transforms/LoopInvariantCodeMotion.cpp           140
-rw-r--r--  mlir/lib/Transforms/LoopTiling.cpp                        402
-rw-r--r--  mlir/lib/Transforms/LoopUnroll.cpp                        182
-rw-r--r--  mlir/lib/Transforms/LoopUnrollAndJam.cpp                  235
-rw-r--r--  mlir/lib/Transforms/MemRefDataFlowOpt.cpp                 227
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp              379
-rw-r--r--  mlir/lib/Transforms/SimplifyAffineStructures.cpp          108
-rw-r--r--  mlir/lib/Transforms/StripDebugInfo.cpp                     37
-rw-r--r--  mlir/lib/Transforms/Utils/CMakeLists.txt                   21
-rw-r--r--  mlir/lib/Transforms/Utils/FoldUtils.cpp                   246
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp  247
-rw-r--r--  mlir/lib/Transforms/Utils/InliningUtils.cpp               356
-rw-r--r--  mlir/lib/Transforms/Utils/LoopFusionUtils.cpp             480
-rw-r--r--  mlir/lib/Transforms/Utils/LoopUtils.cpp                  1779
-rw-r--r--  mlir/lib/Transforms/Utils/RegionUtils.cpp                 348
-rw-r--r--  mlir/lib/Transforms/Utils/Utils.cpp                       469
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp                        1292
-rw-r--r--  mlir/lib/Transforms/ViewOpGraph.cpp                       170
-rw-r--r--  mlir/lib/Transforms/ViewRegionGraph.cpp                    85
28 files changed, 12273 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
new file mode 100644
index 00000000000..902f5c3adcb
--- /dev/null
+++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
@@ -0,0 +1,268 @@
+//===- AffineDataCopyGeneration.cpp - Explicit memref copying pass ------*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to automatically promote accessed memref regions
+// to buffers in a faster memory space that is explicitly managed, with the
+// necessary data movement operations performed through either regular
+// point-wise loads/stores or DMAs. Such explicit copying (also referred to as
+// array packing/unpacking in the literature), when done on arrays that exhibit
+// reuse, results in near elimination of conflict misses and TLB misses, and in
+// reduced use of hardware prefetch streams and reduced false sharing. It is
+// also necessary for hardware with explicitly managed levels in the memory
+// hierarchy, where DMAs may have to be used. This optimization is often
+// performed on already tiled code.
+//
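+// As a rough illustration only (a sketch, not verbatim output of this pass),
+// a loop nest reading a memref in the default (slow) memory space is rewritten
+// to work out of a buffer allocated in the fast memory space, with copy loops
+// or DMAs (depending on the options below) moving the data in and out:
+//
+//     affine.for %i = 0 to 256 {
+//       %v = affine.load %A[%i] : memref<256xf32>          // slow space 0
+//       "use"(%v) : (f32) -> ()
+//     }
+//
+// becomes, approximately:
+//
+//     %fast = alloc() : memref<256xf32, 1>                  // fast space 1
+//     <copy-in loop or affine.dma_start from %A into %fast>
+//     affine.for %i = 0 to 256 {
+//       %v = affine.load %fast[%i] : memref<256xf32, 1>
+//       "use"(%v) : (f32) -> ()
+//     }
+//     <copy-out, if %A is also stored to>
+//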
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "affine-data-copy-generate"
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<unsigned long long> clFastMemoryCapacity(
+ "affine-data-copy-generate-fast-mem-capacity",
+ llvm::cl::desc(
+ "Set fast memory space capacity in KiB (default: unlimited)"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool>
+ clDma("affine-data-copy-generate-dma",
+ llvm::cl::desc("Generate DMA instead of point-wise copy"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(true));
+
+static llvm::cl::opt<unsigned> clFastMemorySpace(
+ "affine-data-copy-generate-fast-mem-space", llvm::cl::init(1),
+ llvm::cl::desc(
+ "Fast memory space identifier for copy generation (default: 1)"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool> clSkipNonUnitStrideLoop(
+ "affine-data-copy-generate-skip-non-unit-stride-loops", llvm::cl::Hidden,
+ llvm::cl::init(false),
+ llvm::cl::desc("Testing purposes: avoid non-unit stride loop choice depths "
+ "for copy placement"),
+ llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// Replaces all loads and stores on memrefs living in 'slowMemorySpace' by
+/// introducing copy operations to transfer data into `fastMemorySpace` and
+/// rewriting the original loads/stores to instead load from / store to the
+/// allocated fast memory buffers. Additional options specify the identifier
+/// corresponding to the fast memory space and the amount of fast memory space
+/// available. The pass traverses the nesting structure, recursing to inner
+/// levels if necessary to determine at what depth copies need to be placed so
+/// that the allocated buffers fit within the memory capacity provided.
+// TODO(bondhugula): We currently can't generate copies correctly when stores
+// are strided. Check for strided stores.
+struct AffineDataCopyGeneration
+ : public FunctionPass<AffineDataCopyGeneration> {
+ explicit AffineDataCopyGeneration(
+ unsigned slowMemorySpace = 0,
+ unsigned fastMemorySpace = clFastMemorySpace, unsigned tagMemorySpace = 0,
+ int minDmaTransferSize = 1024,
+ uint64_t fastMemCapacityBytes =
+ (clFastMemoryCapacity.getNumOccurrences() > 0
+ ? clFastMemoryCapacity * 1024 // cl-provided size is in KiB
+ : std::numeric_limits<uint64_t>::max()),
+ bool generateDma = clDma,
+ bool skipNonUnitStrideLoops = clSkipNonUnitStrideLoop)
+ : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace),
+ tagMemorySpace(tagMemorySpace), minDmaTransferSize(minDmaTransferSize),
+ fastMemCapacityBytes(fastMemCapacityBytes), generateDma(generateDma),
+ skipNonUnitStrideLoops(skipNonUnitStrideLoops) {}
+
+ explicit AffineDataCopyGeneration(const AffineDataCopyGeneration &other)
+ : slowMemorySpace(other.slowMemorySpace),
+ fastMemorySpace(other.fastMemorySpace),
+ tagMemorySpace(other.tagMemorySpace),
+ minDmaTransferSize(other.minDmaTransferSize),
+ fastMemCapacityBytes(other.fastMemCapacityBytes),
+ generateDma(other.generateDma),
+ skipNonUnitStrideLoops(other.skipNonUnitStrideLoops) {}
+
+ void runOnFunction() override;
+ LogicalResult runOnBlock(Block *block, DenseSet<Operation *> &copyNests);
+
+ // Slow memory space associated with copies.
+ const unsigned slowMemorySpace;
+ // Fast memory space associated with copies.
+ unsigned fastMemorySpace;
+ // Memory space associated with DMA tags.
+ unsigned tagMemorySpace;
+ // Minimum DMA transfer size supported by the target in bytes.
+ const int minDmaTransferSize;
+ // Capacity of the faster memory space.
+ uint64_t fastMemCapacityBytes;
+
+ // If set, generate DMA operations instead of read/write.
+ bool generateDma;
+
+ // If set, ignore loops with steps other than 1.
+ bool skipNonUnitStrideLoops;
+
+ // Constant zero index to avoid too many duplicates.
+ Value zeroIndex = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Generates copies for memrefs living in 'slowMemorySpace' into newly created
+/// buffers in 'fastMemorySpace', and replaces memory operations to the former
+/// by the latter. Only load ops are handled for now.
+/// TODO(bondhugula): extend this to store ops.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createAffineDataCopyGenerationPass(
+ unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace,
+ int minDmaTransferSize, uint64_t fastMemCapacityBytes) {
+ return std::make_unique<AffineDataCopyGeneration>(
+ slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize,
+ fastMemCapacityBytes);
+}
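+
+// A minimal sketch of using the factory above, assuming the PassManager API
+// available in this snapshot; the enclosing driver and module setup are
+// hypothetical and shown only for illustration:
+//
+//   #include "mlir/Pass/PassManager.h"
+//
+//   static LogicalResult runCopyGeneration(ModuleOp module) {
+//     PassManager pm(module.getContext());
+//     OpPassManager &funcPM = pm.nest<FuncOp>();
+//     funcPM.addPass(createAffineDataCopyGenerationPass(
+//         /*slowMemorySpace=*/0, /*fastMemorySpace=*/1, /*tagMemorySpace=*/0,
+//         /*minDmaTransferSize=*/1024, /*fastMemCapacityBytes=*/32 * 1024));
+//     return pm.run(module);
+//   }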
+
+/// Generate copies for this block. The block is partitioned into separate
+/// ranges: each range is either a sequence of one or more operations starting
+/// and ending with an affine load or store op, or just an affine.for op (which
+/// could have other affine.for ops nested within).
+LogicalResult
+AffineDataCopyGeneration::runOnBlock(Block *block,
+ DenseSet<Operation *> &copyNests) {
+ if (block->empty())
+ return success();
+
+ AffineCopyOptions copyOptions = {generateDma, slowMemorySpace,
+ fastMemorySpace, tagMemorySpace,
+ fastMemCapacityBytes};
+
+  // Every affine.for op in the block starts and ends a block range for
+  // copying; in addition, a contiguous sequence of operations starting with a
+  // load/store op but not including any copy nests themselves is also
+  // identified as a copy block range. Straightline code (a contiguous chunk of
+  // operations excluding AffineForOp's) is always assumed to not exhaust
+  // memory. As a result, this approach is conservative in some cases at the
+  // moment; we do a check later and report an error with location info.
+  // TODO(bondhugula): An 'affine.if' operation is currently treated like any
+  // other operation. 'affine.if's could have 'affine.for's in them;
+  // treat them separately.
+
+ // Get to the first load, store, or for op (that is not a copy nest itself).
+ auto curBegin =
+ std::find_if(block->begin(), block->end(), [&](Operation &op) {
+ return (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ isa<AffineForOp>(op)) &&
+ copyNests.count(&op) == 0;
+ });
+
+ // Create [begin, end) ranges.
+ auto it = curBegin;
+ while (it != block->end()) {
+ AffineForOp forOp;
+    // If we hit a non-copy 'for' loop, we split the range there.
+ if ((forOp = dyn_cast<AffineForOp>(&*it)) && copyNests.count(forOp) == 0) {
+      // Perform the copying up until this 'for' op first.
+ affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/it, copyOptions,
+ copyNests);
+
+ // Returns true if the footprint is known to exceed capacity.
+ auto exceedsCapacity = [&](AffineForOp forOp) {
+ Optional<int64_t> footprint =
+ getMemoryFootprintBytes(forOp,
+ /*memorySpace=*/0);
+ return (footprint.hasValue() &&
+ static_cast<uint64_t>(footprint.getValue()) >
+ fastMemCapacityBytes);
+ };
+
+ // If the memory footprint of the 'affine.for' loop is higher than fast
+ // memory capacity (when provided), we recurse to copy at an inner level
+ // until we find a depth at which footprint fits in fast mem capacity. If
+ // the footprint can't be calculated, we assume for now it fits. Recurse
+ // inside if footprint for 'forOp' exceeds capacity, or when
+ // skipNonUnitStrideLoops is set and the step size is not one.
+ bool recurseInner = skipNonUnitStrideLoops ? forOp.getStep() != 1
+ : exceedsCapacity(forOp);
+ if (recurseInner) {
+        // We'll recurse and do the copies at an inner level for 'forOp'.
+ // Recurse onto the body of this loop.
+ runOnBlock(forOp.getBody(), copyNests);
+ } else {
+ // We have enough capacity, i.e., copies will be computed for the
+ // portion of the block until 'it', and for 'it', which is 'forOp'. Note
+ // that for the latter, the copies are placed just before this loop (for
+ // incoming copies) and right after (for outgoing ones).
+
+        // Inner loop copies have their own scope, so we don't update the
+        // consumed capacity here. The footprint check above guarantees this
+        // inner loop's footprint fits.
+ affineDataCopyGenerate(/*begin=*/it, /*end=*/std::next(it), copyOptions,
+ copyNests);
+ }
+ // Get to the next load or store op after 'forOp'.
+ curBegin = std::find_if(std::next(it), block->end(), [&](Operation &op) {
+ return (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ isa<AffineForOp>(op)) &&
+ copyNests.count(&op) == 0;
+ });
+ it = curBegin;
+ } else {
+ assert(copyNests.count(&*it) == 0 &&
+ "all copy nests generated should have been skipped above");
+ // We simply include this op in the current range and continue for more.
+ ++it;
+ }
+ }
+
+ // Generate the copy for the final block range.
+ if (curBegin != block->end()) {
+ // Can't be a terminator because it would have been skipped above.
+ assert(!curBegin->isKnownTerminator() && "can't be a terminator");
+ // Exclude the affine terminator - hence, the std::prev.
+ affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/std::prev(block->end()),
+ copyOptions, copyNests);
+ }
+
+ return success();
+}
+
+void AffineDataCopyGeneration::runOnFunction() {
+ FuncOp f = getFunction();
+ OpBuilder topBuilder(f.getBody());
+ zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
+
+  // Nests that are copy-ins or copy-outs; the root AffineForOps of those
+  // nests are stored herein.
+ DenseSet<Operation *> copyNests;
+
+ // Clear recorded copy nests.
+ copyNests.clear();
+
+ for (auto &block : f)
+ runOnBlock(&block, copyNests);
+
+ // Promote any single iteration loops in the copy nests.
+ for (auto nest : copyNests) {
+ nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
+ }
+}
+
+static PassRegistration<AffineDataCopyGeneration>
+ pass("affine-data-copy-generate",
+ "Generate explicit copying for memory operations");
diff --git a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp
new file mode 100644
index 00000000000..24ec2d7c70b
--- /dev/null
+++ b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -0,0 +1,239 @@
+//===- AffineLoopInvariantCodeMotion.cpp - Loop invariant code motion -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop invariant code motion.
+//
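+// As a rough illustration (a sketch, not verbatim pass output), an operation
+// whose operands are all defined outside an 'affine.for', and which does not
+// dereference a memref that is written inside the loop, is moved out of it:
+//
+//     affine.for %i = 0 to 10 {
+//       %v = addf %a, %b : f32              // loop invariant
+//       affine.store %v, %m[%i] : memref<10xf32>
+//     }
+//
+// becomes:
+//
+//     %v = addf %a, %b : f32
+//     affine.for %i = 0 to 10 {
+//       affine.store %v, %m[%i] : memref<10xf32>
+//     }
+//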
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "licm"
+
+using namespace mlir;
+
+namespace {
+
+/// Loop invariant code motion (LICM) pass.
+/// TODO(asabne) : The pass is missing zero-trip tests.
+/// TODO(asabne) : Check for the presence of side effects before hoisting.
+/// TODO: This code should be removed once the new LICM pass can handle its
+/// uses.
+struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> {
+ void runOnFunction() override;
+ void runOnAffineForOp(AffineForOp forOp);
+};
+} // end anonymous namespace
+
+static bool
+checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
+ SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist);
+static bool isOpLoopInvariant(Operation &op, Value indVar,
+ SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist);
+
+static bool
+areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
+ SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist);
+
+static bool isMemRefDereferencingOp(Operation &op) {
+ // TODO(asabne): Support DMA Ops.
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
+ return true;
+ }
+ return false;
+}
+
+// Returns true if the individual op is loop invariant.
+bool isOpLoopInvariant(Operation &op, Value indVar,
+ SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist) {
+ LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
+
+ if (isa<AffineIfOp>(op)) {
+ if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) {
+ return false;
+ }
+ } else if (isa<AffineForOp>(op)) {
+ // If the body of a predicated region has a for loop, we don't hoist the
+ // 'affine.if'.
+ return false;
+ } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+ // TODO(asabne): Support DMA ops.
+ return false;
+ } else if (!isa<ConstantOp>(op)) {
+ if (isMemRefDereferencingOp(op)) {
+ Value memref = isa<AffineLoadOp>(op)
+ ? cast<AffineLoadOp>(op).getMemRef()
+ : cast<AffineStoreOp>(op).getMemRef();
+ for (auto *user : memref->getUsers()) {
+ // If this memref has a user that is a DMA, give up because these
+ // operations write to this memref.
+        if (isa<AffineDmaStartOp>(user) || isa<AffineDmaWaitOp>(user)) {
+ return false;
+ }
+        // If the memref used by the load/store is used in a store elsewhere in
+        // the loop nest, we do not hoist. Similarly, if the memref used in a
+        // load is also being stored to, we do not hoist the load.
+ if (isa<AffineStoreOp>(user) ||
+ (isa<AffineLoadOp>(user) && isa<AffineStoreOp>(op))) {
+ if (&op != user) {
+ SmallVector<AffineForOp, 8> userIVs;
+ getLoopIVs(*user, &userIVs);
+ // Check that userIVs don't contain the for loop around the op.
+ if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) {
+ return false;
+ }
+ }
+ }
+ }
+ }
+
+ // Insert this op in the defined ops list.
+ definedOps.insert(&op);
+
+ if (op.getNumOperands() == 0 && !isa<AffineTerminatorOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n");
+ return false;
+ }
+ for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
+ auto *operandSrc = op.getOperand(i)->getDefiningOp();
+
+ LLVM_DEBUG(
+ op.getOperand(i)->print(llvm::dbgs() << "\nIterating on operand\n"));
+
+ // If the loop IV is the operand, this op isn't loop invariant.
+ if (indVar == op.getOperand(i)) {
+ LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n");
+ return false;
+ }
+
+ if (operandSrc != nullptr) {
+ LLVM_DEBUG(llvm::dbgs()
+ << *operandSrc << "\nIterating on operand src\n");
+
+ // If the value was defined in the loop (outside of the
+ // if/else region), and that operation itself wasn't meant to
+ // be hoisted, then mark this operation loop dependent.
+ if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) {
+ return false;
+ }
+ }
+ }
+ }
+
+ // If no operand was loop variant, mark this op for motion.
+ opsToHoist.insert(&op);
+ return true;
+}
+
+// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
+bool areAllOpsInTheBlockListInvariant(
+ Region &blockList, Value indVar, SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist) {
+
+ for (auto &b : blockList) {
+ for (auto &op : b) {
+ if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+// Returns true if the affine.if op can be hoisted.
+bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
+ SmallPtrSetImpl<Operation *> &definedOps,
+ SmallPtrSetImpl<Operation *> &opsToHoist) {
+ assert(isa<AffineIfOp>(op));
+ auto ifOp = cast<AffineIfOp>(op);
+
+ if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, definedOps,
+ opsToHoist)) {
+ return false;
+ }
+
+ if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, definedOps,
+ opsToHoist)) {
+ return false;
+ }
+
+ return true;
+}
+
+void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
+ auto *loopBody = forOp.getBody();
+ auto indVar = forOp.getInductionVar();
+
+ SmallPtrSet<Operation *, 8> definedOps;
+ // This is the place where hoisted instructions would reside.
+ OpBuilder b(forOp.getOperation());
+
+ SmallPtrSet<Operation *, 8> opsToHoist;
+ SmallVector<Operation *, 8> opsToMove;
+
+ for (auto &op : *loopBody) {
+ // We don't hoist for loops.
+ if (!isa<AffineForOp>(op)) {
+ if (!isa<AffineTerminatorOp>(op)) {
+ if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) {
+ opsToMove.push_back(&op);
+ }
+ }
+ }
+ }
+
+  // For all instructions that we found to be invariant, place them
+  // sequentially right before the for loop.
+ for (auto *op : opsToMove) {
+ op->moveBefore(forOp);
+ }
+
+ LLVM_DEBUG(forOp.getOperation()->print(llvm::dbgs() << "Modified loop\n"));
+}
+
+void LoopInvariantCodeMotion::runOnFunction() {
+ // Walk through all loops in a function in innermost-loop-first order. This
+ // way, we first LICM from the inner loop, and place the ops in
+ // the outer loop, which in turn can be further LICM'ed.
+ getFunction().walk([&](AffineForOp op) {
+ LLVM_DEBUG(op.getOperation()->print(llvm::dbgs() << "\nOriginal loop\n"));
+ runOnAffineForOp(op);
+ });
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createAffineLoopInvariantCodeMotionPass() {
+ return std::make_unique<LoopInvariantCodeMotion>();
+}
+
+static PassRegistration<LoopInvariantCodeMotion>
+ pass("affine-loop-invariant-code-motion",
+ "Hoist loop invariant instructions outside of the loop");
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000..d6c5bd88f7f
--- /dev/null
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,38 @@
+add_subdirectory(Utils)
+
+add_llvm_library(MLIRTransforms
+ AffineDataCopyGeneration.cpp
+ AffineLoopInvariantCodeMotion.cpp
+ Canonicalizer.cpp
+ CSE.cpp
+ DialectConversion.cpp
+ Inliner.cpp
+ LoopCoalescing.cpp
+ LoopFusion.cpp
+ LoopInvariantCodeMotion.cpp
+ LoopTiling.cpp
+ LoopUnrollAndJam.cpp
+ LoopUnroll.cpp
+ MemRefDataFlowOpt.cpp
+ PipelineDataTransfer.cpp
+ SimplifyAffineStructures.cpp
+ StripDebugInfo.cpp
+ Vectorize.cpp
+ ViewOpGraph.cpp
+ ViewRegionGraph.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+ )
+
+add_dependencies(MLIRTransforms
+ MLIRLoopLikeInterfaceIncGen
+ MLIRStandardOpsIncGen)
+target_link_libraries(MLIRTransforms
+ MLIRAffineOps
+ MLIRAnalysis
+ MLIRLoopOps
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorOps
+ )
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
new file mode 100644
index 00000000000..714fb1d0109
--- /dev/null
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -0,0 +1,263 @@
+//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass performs a simple common sub-expression elimination
+// algorithm on operations within a function.
+//
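+// As a rough illustration (a sketch, not verbatim pass output), two identical
+// side-effect-free operations with the same operands are merged, with uses of
+// the second redirected to the first:
+//
+//     %0 = addf %a, %b : f32
+//     %1 = addf %a, %b : f32
+//     "use"(%0, %1) : (f32, f32) -> ()
+//
+// becomes:
+//
+//     %0 = addf %a, %b : f32
+//     "use"(%0, %0) : (f32, f32) -> ()
+//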
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+using namespace mlir;
+
+namespace {
+// TODO(riverriddle) Handle commutative operations.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+ static unsigned getHashValue(const Operation *opC) {
+ auto *op = const_cast<Operation *>(opC);
+ // Hash the operations based upon their:
+ // - Operation Name
+ // - Attributes
+ // - Result Types
+ // - Operands
+ return hash_combine(
+ op->getName(), op->getAttrList().getDictionary(),
+ hash_combine_range(op->result_type_begin(), op->result_type_end()),
+ hash_combine_range(op->operand_begin(), op->operand_end()));
+ }
+ static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+ auto *lhs = const_cast<Operation *>(lhsC);
+ auto *rhs = const_cast<Operation *>(rhsC);
+ if (lhs == rhs)
+ return true;
+ if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+ rhs == getTombstoneKey() || rhs == getEmptyKey())
+ return false;
+
+ // Compare the operation name.
+ if (lhs->getName() != rhs->getName())
+ return false;
+ // Check operand and result type counts.
+ if (lhs->getNumOperands() != rhs->getNumOperands() ||
+ lhs->getNumResults() != rhs->getNumResults())
+ return false;
+ // Compare attributes.
+ if (lhs->getAttrList() != rhs->getAttrList())
+ return false;
+ // Compare operands.
+ if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
+ rhs->operand_begin()))
+ return false;
+ // Compare result types.
+ return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
+ rhs->result_type_begin());
+ }
+};
+} // end anonymous namespace
+
+namespace {
+/// Simple common sub-expression elimination.
+struct CSE : public OperationPass<CSE> {
+ CSE() = default;
+ CSE(const CSE &) {}
+
+ /// Shared implementation of operation elimination and scoped map definitions.
+ using AllocatorTy = llvm::RecyclingAllocator<
+ llvm::BumpPtrAllocator,
+ llvm::ScopedHashTableVal<Operation *, Operation *>>;
+ using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+ SimpleOperationInfo, AllocatorTy>;
+
+ /// Represents a single entry in the depth first traversal of a CFG.
+ struct CFGStackNode {
+ CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+ : scope(knownValues), node(node), childIterator(node->begin()),
+ processed(false) {}
+
+ /// Scope for the known values.
+ ScopedMapTy::ScopeTy scope;
+
+ DominanceInfoNode *node;
+ DominanceInfoNode::iterator childIterator;
+
+    /// Whether this node has been fully processed.
+ bool processed;
+ };
+
+ /// Attempt to eliminate a redundant operation. Returns success if the
+ /// operation was marked for removal, failure otherwise.
+ LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op);
+
+ void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+ Block *bb);
+ void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+ Region &region);
+
+ void runOnOperation() override;
+
+private:
+ /// Operations marked as dead and to be erased.
+ std::vector<Operation *> opsToErase;
+
+ /// Statistics for CSE.
+ Statistic numCSE{this, "num-cse'd", "Number of operations CSE'd"};
+ Statistic numDCE{this, "num-dce'd", "Number of operations trivially DCE'd"};
+};
+} // end anonymous namespace
+
+/// Attempt to eliminate a redundant operation.
+LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
+ // Don't simplify operations with nested blocks. We don't currently model
+ // equality comparisons correctly among other things. It is also unclear
+ // whether we would want to CSE such operations.
+ if (op->getNumRegions() != 0)
+ return failure();
+
+ // TODO(riverriddle) We currently only eliminate non side-effecting
+ // operations.
+ if (!op->hasNoSideEffect())
+ return failure();
+
+ // If the operation is already trivially dead just add it to the erase list.
+ if (op->use_empty()) {
+ opsToErase.push_back(op);
+ ++numDCE;
+ return success();
+ }
+
+ // Look for an existing definition for the operation.
+ if (auto *existing = knownValues.lookup(op)) {
+ // If we find one then replace all uses of the current operation with the
+ // existing one and mark it for deletion.
+ op->replaceAllUsesWith(existing);
+ opsToErase.push_back(op);
+
+ // If the existing operation has an unknown location and the current
+ // operation doesn't, then set the existing op's location to that of the
+ // current op.
+ if (existing->getLoc().isa<UnknownLoc>() &&
+ !op->getLoc().isa<UnknownLoc>()) {
+ existing->setLoc(op->getLoc());
+ }
+
+ ++numCSE;
+ return success();
+ }
+
+ // Otherwise, we add this operation to the known values map.
+ knownValues.insert(op, op);
+ return failure();
+}
+
+void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+ Block *bb) {
+ for (auto &inst : *bb) {
+ // If the operation is simplified, we don't process any held regions.
+ if (succeeded(simplifyOperation(knownValues, &inst)))
+ continue;
+
+    // If this operation is unregistered or known to be isolated from above, we
+    // can't process nested regions with the given 'knownValues' map. This
+    // would cause the insertion of implicit captures in explicit-capture-only
+    // regions.
+ if (!inst.isRegistered() || inst.isKnownIsolatedFromAbove()) {
+ ScopedMapTy nestedKnownValues;
+ for (auto &region : inst.getRegions())
+ simplifyRegion(nestedKnownValues, domInfo, region);
+ continue;
+ }
+
+ // Otherwise, process nested regions normally.
+ for (auto &region : inst.getRegions())
+ simplifyRegion(knownValues, domInfo, region);
+ }
+}
+
+void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+ Region &region) {
+ // If the region is empty there is nothing to do.
+ if (region.empty())
+ return;
+
+ // If the region only contains one block, then simplify it directly.
+ if (std::next(region.begin()) == region.end()) {
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(knownValues, domInfo, &region.front());
+ return;
+ }
+
+ // Note, deque is being used here because there was significant performance
+ // gains over vector when the container becomes very large due to the
+ // specific access patterns. If/when these performance issues are no
+ // longer a problem we can change this to vector. For more information see
+ // the llvm mailing list discussion on this:
+ // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+ std::deque<std::unique_ptr<CFGStackNode>> stack;
+
+ // Process the nodes of the dom tree for this region.
+ stack.emplace_back(std::make_unique<CFGStackNode>(
+ knownValues, domInfo.getRootNode(&region)));
+
+ while (!stack.empty()) {
+ auto &currentNode = stack.back();
+
+ // Check to see if we need to process this node.
+ if (!currentNode->processed) {
+ currentNode->processed = true;
+ simplifyBlock(knownValues, domInfo, currentNode->node->getBlock());
+ }
+
+ // Otherwise, check to see if we need to process a child node.
+ if (currentNode->childIterator != currentNode->node->end()) {
+ auto *childNode = *(currentNode->childIterator++);
+ stack.emplace_back(
+ std::make_unique<CFGStackNode>(knownValues, childNode));
+ } else {
+      // Finally, if the node and all of its children have been processed,
+      // pop it off the stack.
+ stack.pop_back();
+ }
+ }
+}
+
+void CSE::runOnOperation() {
+ /// A scoped hash table of defining operations within a region.
+ ScopedMapTy knownValues;
+
+ DominanceInfo &domInfo = getAnalysis<DominanceInfo>();
+ for (Region &region : getOperation()->getRegions())
+ simplifyRegion(knownValues, domInfo, region);
+
+ // If no operations were erased, then we mark all analyses as preserved.
+ if (opsToErase.empty())
+ return markAllAnalysesPreserved();
+
+ /// Erase any operations that were marked as dead during simplification.
+ for (auto *op : opsToErase)
+ op->erase();
+ opsToErase.clear();
+
+ // We currently don't remove region operations, so mark dominance as
+ // preserved.
+ markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
+}
+
+std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }
+
+static PassRegistration<CSE> pass("cse", "Eliminate common sub-expressions");
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
new file mode 100644
index 00000000000..5b3a1eb1cf3
--- /dev/null
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -0,0 +1,45 @@
+//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass converts operations into their canonical forms by
+// folding constants, applying operation identity transformations etc.
+//
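+// As a rough illustration (a sketch, not verbatim pass output), identity
+// simplifications such as adding zero are folded away:
+//
+//     %c0 = constant 0 : i32
+//     %0 = addi %x, %c0 : i32
+//     "use"(%0) : (i32) -> ()
+//
+// becomes (uses of %0 are replaced with %x, and dead ops are removed):
+//
+//     "use"(%x) : (i32) -> ()
+//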
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+using namespace mlir;
+
+namespace {
+/// Canonicalize operations in nested regions.
+struct Canonicalizer : public OperationPass<Canonicalizer> {
+ void runOnOperation() override {
+ OwningRewritePatternList patterns;
+
+    // TODO: Instead of adding all known patterns from the whole system,
+    // lazily add and cache the canonicalization patterns for ops we see in
+    // practice when building the worklist. For now, we just grab everything.
+ auto *context = &getContext();
+ for (auto *op : context->getRegisteredOperations())
+ op->getCanonicalizationPatterns(patterns, context);
+
+ Operation *op = getOperation();
+ applyPatternsGreedily(op->getRegions(), patterns);
+ }
+};
+} // end anonymous namespace
+
+/// Create a Canonicalizer pass.
+std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
+ return std::make_unique<Canonicalizer>();
+}
+
+static PassRegistration<Canonicalizer> pass("canonicalize",
+ "Canonicalize operations");
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
new file mode 100644
index 00000000000..5f7fb7a68c9
--- /dev/null
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -0,0 +1,1846 @@
+//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+#define DEBUG_TYPE "dialect-conversion"
+
+/// Recursively collect all of the operations to convert from within 'region'.
+/// If 'target' is nonnull, operations that are recursively legal have their
+/// regions pre-filtered to avoid considering them for legalization.
+static LogicalResult
+computeConversionSet(iterator_range<Region::iterator> region,
+ Location regionLoc, std::vector<Operation *> &toConvert,
+ ConversionTarget *target = nullptr) {
+ if (llvm::empty(region))
+ return success();
+
+ // Traverse starting from the entry block.
+ SmallVector<Block *, 16> worklist(1, &*region.begin());
+ DenseSet<Block *> visitedBlocks;
+ visitedBlocks.insert(worklist.front());
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+
+ // Compute the conversion set of each of the nested operations.
+ for (Operation &op : *block) {
+ toConvert.emplace_back(&op);
+
+ // Don't check this operation's children for conversion if the operation
+ // is recursively legal.
+ auto legalityInfo = target ? target->isLegal(&op)
+ : Optional<ConversionTarget::LegalOpDetails>();
+ if (legalityInfo && legalityInfo->isRecursivelyLegal)
+ continue;
+ for (auto &region : op.getRegions())
+ computeConversionSet(region.getBlocks(), region.getLoc(), toConvert,
+ target);
+ }
+
+ // Recurse to children that haven't been visited.
+ for (Block *succ : block->getSuccessors())
+ if (visitedBlocks.insert(succ).second)
+ worklist.push_back(succ);
+ }
+
+ // Check that all blocks in the region were visited.
+ if (llvm::any_of(llvm::drop_begin(region, 1),
+ [&](Block &block) { return !visitedBlocks.count(&block); }))
+ return emitError(regionLoc, "unreachable blocks were not converted");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Multi-Level Value Mapper
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class wraps a BlockAndValueMapping to provide recursive lookup
+/// functionality, i.e. we will traverse if the mapped value also has a mapping.
+struct ConversionValueMapping {
+ /// Lookup a mapped value within the map. If a mapping for the provided value
+ /// does not exist then return the provided value.
+ Value lookupOrDefault(Value from) const;
+
+ /// Map a value to the one provided.
+ void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
+
+ /// Drop the last mapping for the given value.
+ void erase(Value value) { mapping.erase(value); }
+
+private:
+ /// Current value mappings.
+ BlockAndValueMapping mapping;
+};
+} // end anonymous namespace
+
+/// Lookup a mapped value within the map. If a mapping for the provided value
+/// does not exist then return the provided value.
+Value ConversionValueMapping::lookupOrDefault(Value from) const {
+  // Walk the chain of mappings: the mapped value may itself have been
+  // replaced, so keep following until an unmapped value is reached.
+ while (auto mappedValue = mapping.lookupOrNull(from))
+ from = mappedValue;
+ return from;
+}
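+
+// For example (hypothetical values), if %a was replaced by %b and %b was later
+// itself replaced by %c, then lookupOrDefault(%a) follows the chain
+// %a -> %b -> %c and returns %c; a value with no mapping is returned
+// unchanged.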
+
+//===----------------------------------------------------------------------===//
+// ArgConverter
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class provides a simple interface for converting the types of block
+/// arguments. This is done by creating a new block that contains the new legal
+/// types and extracting the block that contains the old illegal types to allow
+/// for undoing pending rewrites in the case of failure.
+struct ArgConverter {
+ ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
+ : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+ rewriter(rewriter) {}
+
+ /// This structure contains the information pertaining to an argument that has
+ /// been converted.
+ struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+    /// The start index in the new argument list of the arguments that replace
+    /// the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+    /// The cast value that was created to cast from the new arguments to the
+    /// old. This is only used if 'newArgSize' > 1.
+ Value castValue;
+ };
+
+ /// This structure contains information pertaining to a block that has had its
+ /// signature converted.
+ struct ConvertedBlockInfo {
+ ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {}
+
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// None if the argument was dropped during conversion.
+ SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
+ };
+
+ /// Return if the signature of the given block has already been converted.
+ bool hasBeenConverted(Block *block) const {
+ return conversionInfo.count(block);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Rewrite Application
+ //===--------------------------------------------------------------------===//
+
+ /// Erase any rewrites registered for the blocks within the given operation
+ /// which is about to be removed. This merely drops the rewrites without
+ /// undoing them.
+ void notifyOpRemoved(Operation *op);
+
+ /// Cleanup and undo any generated conversions for the arguments of block.
+ /// This method replaces the new block with the original, reverting the IR to
+ /// its original state.
+ void discardRewrites(Block *block);
+
+ /// Fully replace uses of the old arguments with the new, materializing cast
+ /// operations as necessary.
+ // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
+ // implementation of replaceUsesOfBlockArgument is buggy.
+ void applyRewrites(ConversionValueMapping &mapping);
+
+ //===--------------------------------------------------------------------===//
+ // Conversion
+ //===--------------------------------------------------------------------===//
+
+ /// Attempt to convert the signature of the given block, if successful a new
+ /// block is returned containing the new arguments. On failure, nullptr is
+ /// returned.
+ Block *convertSignature(Block *block, ConversionValueMapping &mapping);
+
+ /// Apply the given signature conversion on the given block. The new block
+ /// containing the updated signature is returned.
+ Block *applySignatureConversion(
+ Block *block, TypeConverter::SignatureConversion &signatureConversion,
+ ConversionValueMapping &mapping);
+
+ /// Insert a new conversion into the cache.
+ void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
+
+ /// A collection of blocks that have had their arguments converted.
+ llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
+
+  /// A mapping from valid regions to those containing the original blocks of a
+  /// conversion.
+ DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
+
+ /// An instance of the unknown location that is used when materializing
+ /// conversions.
+ Location loc;
+
+ /// The type converter to use when changing types.
+ TypeConverter *typeConverter;
+
+ /// The pattern rewriter to use when materializing conversions.
+ PatternRewriter &rewriter;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Rewrite Application
+
+void ArgConverter::notifyOpRemoved(Operation *op) {
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ // Drop any rewrites from within.
+ for (Operation &nestedOp : block)
+ if (nestedOp.getNumRegions())
+ notifyOpRemoved(&nestedOp);
+
+ // Check if this block was converted.
+ auto it = conversionInfo.find(&block);
+ if (it == conversionInfo.end())
+ return;
+
+ // Drop all uses of the original arguments and delete the original block.
+ Block *origBlock = it->second.origBlock;
+ for (BlockArgument arg : origBlock->getArguments())
+ arg->dropAllUses();
+ conversionInfo.erase(it);
+ }
+ }
+}
+
+void ArgConverter::discardRewrites(Block *block) {
+ auto it = conversionInfo.find(block);
+ if (it == conversionInfo.end())
+ return;
+ Block *origBlock = it->second.origBlock;
+
+  // Drop all uses of the new block arguments and replace all uses of the new
+  // block with the original one.
+ for (int i = block->getNumArguments() - 1; i >= 0; --i)
+ block->getArgument(i)->dropAllUses();
+ block->replaceAllUsesWith(origBlock);
+
+  // Move the operations back to the original block and then delete the new
+  // block.
+ origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+ origBlock->moveBefore(block);
+ block->erase();
+
+ conversionInfo.erase(it);
+}
+
+void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
+ for (auto &info : conversionInfo) {
+ Block *newBlock = info.first;
+ ConvertedBlockInfo &blockInfo = info.second;
+ Block *origBlock = blockInfo.origBlock;
+
+ // Process the remapping for each of the original arguments.
+ for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+ Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
+ BlockArgument origArg = origBlock->getArgument(i);
+
+ // Handle the case of a 1->0 value mapping.
+ if (!argInfo) {
+ // If a replacement value was given for this argument, use that to
+ // replace all uses.
+ auto argReplacementValue = mapping.lookupOrDefault(origArg);
+ if (argReplacementValue != origArg) {
+ origArg->replaceAllUsesWith(argReplacementValue);
+ continue;
+ }
+ // If there are any dangling uses then replace the argument with one
+ // generated by the type converter. This is necessary as the cast must
+ // persist in the IR after conversion.
+ if (!origArg->use_empty()) {
+ rewriter.setInsertionPointToStart(newBlock);
+ auto *newOp = typeConverter->materializeConversion(
+ rewriter, origArg->getType(), llvm::None, loc);
+ origArg->replaceAllUsesWith(newOp->getResult(0));
+ }
+ continue;
+ }
+
+ // If mapping is 1-1, replace the remaining uses and drop the cast
+ // operation.
+ // FIXME(riverriddle) This should check that the result type and operand
+ // type are the same, otherwise it should force a conversion to be
+ // materialized.
+ if (argInfo->newArgSize == 1) {
+ origArg->replaceAllUsesWith(
+ mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
+ continue;
+ }
+
+ // Otherwise this is a 1->N value mapping.
+ Value castValue = argInfo->castValue;
+ assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
+
+ // If the argument is still used, replace it with the generated cast.
+ if (!origArg->use_empty())
+ origArg->replaceAllUsesWith(mapping.lookupOrDefault(castValue));
+
+ // If all users of the cast were removed, we can drop it. Otherwise, keep
+ // the operation alive and let the user handle any remaining usages.
+ if (castValue->use_empty())
+ castValue->getDefiningOp()->erase();
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion
+
+Block *ArgConverter::convertSignature(Block *block,
+ ConversionValueMapping &mapping) {
+ if (auto conversion = typeConverter->convertBlockSignature(block))
+ return applySignatureConversion(block, *conversion, mapping);
+ return nullptr;
+}
+
+Block *ArgConverter::applySignatureConversion(
+ Block *block, TypeConverter::SignatureConversion &signatureConversion,
+ ConversionValueMapping &mapping) {
+ // If no arguments are being changed or added, there is nothing to do.
+ unsigned origArgCount = block->getNumArguments();
+ auto convertedTypes = signatureConversion.getConvertedTypes();
+ if (origArgCount == 0 && convertedTypes.empty())
+ return block;
+
+ // Split the block at the beginning to get a new block to use for the updated
+ // signature.
+ Block *newBlock = block->splitBlock(block->begin());
+ block->replaceAllUsesWith(newBlock);
+
+ SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
+ ArrayRef<Value> newArgs(newArgRange);
+
+ // Remap each of the original arguments as determined by the signature
+ // conversion.
+ ConvertedBlockInfo info(block);
+ info.argInfo.resize(origArgCount);
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(newBlock);
+ for (unsigned i = 0; i != origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap)
+ continue;
+ BlockArgument origArg = block->getArgument(i);
+
+ // If inputMap->replacementValue is not nullptr, then the argument is
+ // dropped and a replacement value is provided to be the remappedValue.
+ if (inputMap->replacementValue) {
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, inputMap->replacementValue);
+ continue;
+ }
+
+ // If this is a 1->1 mapping, then map the argument directly.
+ if (inputMap->size == 1) {
+ mapping.map(origArg, newArgs[inputMap->inputNo]);
+ info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
+ continue;
+ }
+
+ // Otherwise, this is a 1->N mapping. Call into the provided type converter
+ // to pack the new values.
+ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ Operation *cast = typeConverter->materializeConversion(
+ rewriter, origArg->getType(), replArgs, loc);
+ assert(cast->getNumResults() == 1 &&
+ cast->getNumOperands() == replArgs.size());
+ mapping.map(origArg, cast->getResult(0));
+ info.argInfo[i] =
+ ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
+ }
+
+ // Remove the original block from the region and return the new one.
+ insertConversion(newBlock, std::move(info));
+ return newBlock;
+}
+
+void ArgConverter::insertConversion(Block *newBlock,
+ ConvertedBlockInfo &&info) {
+ // Get a region to insert the old block.
+ Region *region = newBlock->getParent();
+ std::unique_ptr<Region> &mappedRegion = regionMapping[region];
+ if (!mappedRegion)
+ mappedRegion = std::make_unique<Region>(region->getParentOp());
+
+ // Move the original block to the mapped region and emplace the conversion.
+ mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
+ info.origBlock->getIterator());
+ conversionInfo.insert({newBlock, std::move(info)});
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriterImpl
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class contains a snapshot of the current conversion rewriter state.
+/// This is useful when saving and undoing a set of rewrites.
+struct RewriterState {
+ RewriterState(unsigned numCreatedOps, unsigned numReplacements,
+ unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numBlockActions(numBlockActions),
+ numIgnoredOperations(numIgnoredOperations),
+ numRootUpdates(numRootUpdates) {}
+
+ /// The current number of created operations.
+ unsigned numCreatedOps;
+
+ /// The current number of replacements queued.
+ unsigned numReplacements;
+
+ /// The current number of block actions performed.
+ unsigned numBlockActions;
+
+ /// The current number of ignored operations.
+ unsigned numIgnoredOperations;
+
+ /// The current number of operations that were updated in place.
+ unsigned numRootUpdates;
+};
+
+/// The state of an operation that was updated by a pattern in-place. This
+/// contains all of the necessary information to reconstruct an operation that
+/// was updated in place.
+class OperationTransactionState {
+public:
+ OperationTransactionState() = default;
+ OperationTransactionState(Operation *op)
+ : op(op), loc(op->getLoc()), attrs(op->getAttrList()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void resetOperation() const {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (auto it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+ /// Return the original operation of this state.
+ Operation *getOperation() const { return op; }
+
+private:
+ Operation *op;
+ LocationAttr loc;
+ NamedAttributeList attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
+ /// This class represents one requested operation replacement via 'replaceOp'.
+ struct OpReplacement {
+ OpReplacement() = default;
+ OpReplacement(Operation *op, ValueRange newValues)
+ : op(op), newValues(newValues.begin(), newValues.end()) {}
+
+ Operation *op;
+ SmallVector<Value, 2> newValues;
+ };
+
+ /// The kind of the block action performed during the rewrite. Actions can be
+ /// undone if the conversion fails.
+ enum class BlockActionKind { Create, Move, Split, TypeConversion };
+
+ /// Original position of the given block in its parent region. We cannot use
+ /// a region iterator because it could have been invalidated by other region
+ /// operations since the position was stored.
+ struct BlockPosition {
+ Region *region;
+ Region::iterator::difference_type position;
+ };
+
+ /// The storage class for an undoable block action (one of BlockActionKind),
+ /// contains the information necessary to undo this action.
+ struct BlockAction {
+ static BlockAction getCreate(Block *block) {
+ return {BlockActionKind::Create, block, {}};
+ }
+ static BlockAction getMove(Block *block, BlockPosition originalPos) {
+ return {BlockActionKind::Move, block, {originalPos}};
+ }
+ static BlockAction getSplit(Block *block, Block *originalBlock) {
+ BlockAction action{BlockActionKind::Split, block, {}};
+ action.originalBlock = originalBlock;
+ return action;
+ }
+ static BlockAction getTypeConversion(Block *block) {
+ return BlockAction{BlockActionKind::TypeConversion, block, {}};
+ }
+
+ // The action kind.
+ BlockActionKind kind;
+
+ // A pointer to the block that was created by the action.
+ Block *block;
+
+ union {
+ // In use if kind == BlockActionKind::Move and contains a pointer to the
+ // region that originally contained the block as well as the position of
+ // the block in that region.
+ BlockPosition originalPosition;
+ // In use if kind == BlockActionKind::Split and contains a pointer to the
+ // block that was split into two parts.
+ Block *originalBlock;
+ };
+ };
+
+ ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+ TypeConverter *converter)
+ : argConverter(converter, rewriter) {}
+
+ /// Return the current state of the rewriter.
+ RewriterState getCurrentState();
+
+ /// Reset the state of the rewriter to a previously saved point.
+ void resetState(RewriterState state);
+
+  /// Undo the block actions (motions, splits) one by one in reverse order
+  /// until "numActionsToKeep" actions remain.
+ void undoBlockActions(unsigned numActionsToKeep = 0);
+
+ /// Cleanup and destroy any generated rewrite operations. This method is
+ /// invoked when the conversion process fails.
+ void discardRewrites();
+
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
+ /// Convert the signature of the given block.
+ LogicalResult convertBlockSignature(Block *block);
+
+ /// Apply a signature conversion on the given region.
+ Block *
+ applySignatureConversion(Region *region,
+ TypeConverter::SignatureConversion &conversion);
+
+ /// PatternRewriter hook for replacing the results of an operation.
+ void replaceOp(Operation *op, ValueRange newValues,
+ ValueRange valuesToRemoveIfDead);
+
+ /// Notifies that a block was split.
+ void notifySplitBlock(Block *block, Block *continuation);
+
+ /// Notifies that the blocks of a region are about to be moved.
+ void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
+ Region::iterator before);
+
+ /// Notifies that the blocks of a region were cloned into another.
+ void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
+ Location origRegionLoc);
+
+ /// Remap the given operands to those with potentially different types.
+ void remapValues(Operation::operand_range operands,
+ SmallVectorImpl<Value> &remapped);
+
+ /// Returns true if the given operation is ignored, and does not need to be
+ /// converted.
+ bool isOpIgnored(Operation *op) const;
+
+ /// Recursively marks the nested operations under 'op' as ignored. This
+ /// removes them from being considered for legalization.
+ void markNestedOpsIgnored(Operation *op);
+
+ // Mapping between replaced values that differ in type. This happens when
+ // replacing a value with one of a different type.
+ ConversionValueMapping mapping;
+
+ /// Utility used to convert block arguments.
+ ArgConverter argConverter;
+
+ /// Ordered vector of all of the newly created operations during conversion.
+ std::vector<Operation *> createdOps;
+
+ /// Ordered vector of any requested operation replacements.
+ SmallVector<OpReplacement, 4> replacements;
+
+ /// Ordered list of block operations (creations, splits, motions).
+ SmallVector<BlockAction, 4> blockActions;
+
+ /// A set of operations that have been erased/replaced/etc that should no
+ /// longer be considered for legalization. This is not meant to be an
+ /// exhaustive list of all operations, but the minimal set that can be used to
+ /// detect if a given operation should be `ignored`. For example, we may add
+ /// the operations that define non-empty regions to the set, but not any of
+ /// the others. This simplifies the amount of memory needed as we can query if
+ /// the parent operation was ignored.
+ llvm::SetVector<Operation *> ignoredOps;
+
+ /// A transaction state for each of operations that were updated in-place.
+ SmallVector<OperationTransactionState, 4> rootUpdates;
+
+#ifndef NDEBUG
+ /// A set of operations that have pending updates. This tracking isn't
+ /// strictly necessary, and is thus only active during debug builds for extra
+ /// verification.
+ SmallPtrSet<Operation *, 1> pendingRootUpdates;
+#endif
+};
+} // end namespace detail
+} // end namespace mlir
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+ return RewriterState(createdOps.size(), replacements.size(),
+ blockActions.size(), ignoredOps.size(),
+ rootUpdates.size());
+}
+
+void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+ // Reset any operations that were updated in place.
+ for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
+ rootUpdates[i].resetOperation();
+ rootUpdates.resize(state.numRootUpdates);
+
+ // Undo any block actions.
+ undoBlockActions(state.numBlockActions);
+
+ // Reset any replaced operations and undo any saved mappings.
+ for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+ for (auto result : repl.op->getResults())
+ mapping.erase(result);
+ replacements.resize(state.numReplacements);
+
+ // Pop all of the newly created operations.
+ while (createdOps.size() != state.numCreatedOps) {
+ createdOps.back()->erase();
+ createdOps.pop_back();
+ }
+
+ // Pop all of the recorded ignored operations that are no longer valid.
+ while (ignoredOps.size() != state.numIgnoredOperations)
+ ignoredOps.pop_back();
+}
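
A note on the technique: getCurrentState/resetState above implement a checkpoint-and-rollback scheme in which the saved state is nothing more than the recorded sizes of the rewriter's undo logs, and rollback truncates each log back to its checkpointed size while undoing the logged side effects. The following standalone sketch illustrates the same idiom; all names are illustrative and none of this is MLIR API.

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Checkpoint/rollback sketch (illustrative only, not MLIR API): the saved
// "state" is merely the size of each undo log, and rollback pops entries
// until the logs shrink back to the checkpointed sizes.
struct Checkpoint {
  size_t numCreated, numRenamed;
};

struct TinyRewriter {
  std::vector<std::string> created;                          // creations, in order
  std::vector<std::pair<std::string, std::string>> renames;  // (old, new) pairs

  Checkpoint getCurrentState() const {
    return {created.size(), renames.size()};
  }

  void resetState(Checkpoint c) {
    // Undo the most recent actions first, mirroring the real rewriter.
    while (renames.size() != c.numRenamed)
      renames.pop_back();
    while (created.size() != c.numCreated)
      created.pop_back();
  }
};

int main() {
  TinyRewriter rw;
  rw.created.push_back("op0");
  Checkpoint cp = rw.getCurrentState();
  rw.created.push_back("op1");
  rw.renames.push_back({"a", "b"});
  rw.resetState(cp); // the speculative changes are rolled back
  assert(rw.created.size() == 1 && rw.renames.empty());
  return 0;
}
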
+
+void ConversionPatternRewriterImpl::undoBlockActions(
+ unsigned numActionsToKeep) {
+ for (auto &action :
+ llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
+ switch (action.kind) {
+ // Delete the created block.
+ case BlockActionKind::Create: {
+ // Unlink all of the operations within this block, they will be deleted
+ // separately.
+ auto &blockOps = action.block->getOperations();
+ while (!blockOps.empty())
+ blockOps.remove(blockOps.begin());
+ action.block->dropAllDefinedValueUses();
+ action.block->erase();
+ break;
+ }
+ // Move the block back to its original position.
+ case BlockActionKind::Move: {
+ Region *originalRegion = action.originalPosition.region;
+ originalRegion->getBlocks().splice(
+ std::next(originalRegion->begin(), action.originalPosition.position),
+ action.block->getParent()->getBlocks(), action.block);
+ break;
+ }
+ // Merge back the block that was split out.
+ case BlockActionKind::Split: {
+ action.originalBlock->getOperations().splice(
+ action.originalBlock->end(), action.block->getOperations());
+ action.block->dropAllDefinedValueUses();
+ action.block->erase();
+ break;
+ }
+ // Undo the type conversion.
+ case BlockActionKind::TypeConversion: {
+ argConverter.discardRewrites(action.block);
+ break;
+ }
+ }
+ }
+ blockActions.resize(numActionsToKeep);
+}
+
+void ConversionPatternRewriterImpl::discardRewrites() {
+ // Reset any operations that were updated in place.
+ for (auto &state : rootUpdates)
+ state.resetOperation();
+
+ undoBlockActions();
+
+ // Remove any newly created ops.
+ for (auto *op : llvm::reverse(createdOps))
+ op->erase();
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+ // Apply all of the rewrite replacements requested during conversion.
+ for (auto &repl : replacements) {
+ for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
+ if (auto newValue = repl.newValues[i])
+ repl.op->getResult(i)->replaceAllUsesWith(
+ mapping.lookupOrDefault(newValue));
+ }
+
+ // If this operation defines any regions, drop any pending argument
+ // rewrites.
+ if (argConverter.typeConverter && repl.op->getNumRegions())
+ argConverter.notifyOpRemoved(repl.op);
+ }
+
+ // In a second pass, erase all of the replaced operations in reverse. This
+ // allows processing nested operations before their parent region is
+ // destroyed.
+ for (auto &repl : llvm::reverse(replacements))
+ repl.op->erase();
+
+ argConverter.applyRewrites(mapping);
+}
+
+LogicalResult
+ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
+ // Check to see if this block should not be converted:
+ // * There is no type converter.
+ // * The block has already been converted.
+ // * This is an entry block, these are converted explicitly via patterns.
+ if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
+ !block->getParent() || block->isEntryBlock())
+ return success();
+
+ // Otherwise, try to convert the block signature.
+ Block *newBlock = argConverter.convertSignature(block, mapping);
+ if (newBlock)
+ blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+ return success(newBlock);
+}
+
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
+ Region *region, TypeConverter::SignatureConversion &conversion) {
+ if (!region->empty()) {
+ Block *newEntry = argConverter.applySignatureConversion(
+ &region->front(), conversion, mapping);
+ blockActions.push_back(BlockAction::getTypeConversion(newEntry));
+ return newEntry;
+ }
+ return nullptr;
+}
+
+void ConversionPatternRewriterImpl::replaceOp(Operation *op,
+ ValueRange newValues,
+ ValueRange valuesToRemoveIfDead) {
+ assert(newValues.size() == op->getNumResults());
+
+ // Create mappings for each of the new result values.
+ for (unsigned i = 0, e = newValues.size(); i < e; ++i)
+ if (auto repl = newValues[i])
+ mapping.map(op->getResult(i), repl);
+
+ // Record the requested operation replacement.
+ replacements.emplace_back(op, newValues);
+
+ /// Mark this operation as recursively ignored so that we don't need to
+ /// convert any nested operations.
+ markNestedOpsIgnored(op);
+}
+
+void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
+ Block *continuation) {
+ blockActions.push_back(BlockAction::getSplit(continuation, block));
+}
+
+void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
+ Region &region, Region &parent, Region::iterator before) {
+ for (auto &pair : llvm::enumerate(region)) {
+ Block &block = pair.value();
+ unsigned position = pair.index();
+ blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
+ }
+}
+
+void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
+ iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
+ for (Block &block : blocks)
+ blockActions.push_back(BlockAction::getCreate(&block));
+
+ // Compute the conversion set for the inlined region.
+ auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
+
+ // This original region has already had its conversion set computed, so there
+ // shouldn't be any new failures.
+ (void)result;
+ assert(succeeded(result) && "expected region to have no unreachable blocks");
+}
+
+void ConversionPatternRewriterImpl::remapValues(
+ Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
+ remapped.reserve(llvm::size(operands));
+ for (Value operand : operands)
+ remapped.push_back(mapping.lookupOrDefault(operand));
+}
+
+bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
+ // Check to see if this operation or its parent were ignored.
+ return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
+}
+
+void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
+ // Walk this operation and collect nested operations that define non-empty
+ // regions. We mark such operations as 'ignored' so that we know we don't have
+ // to convert them, or their nested ops.
+ if (op->getNumRegions() == 0)
+ return;
+ op->walk([&](Operation *op) {
+ if (llvm::any_of(op->getRegions(),
+ [](Region &region) { return !region.empty(); }))
+ ignoredOps.insert(op);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
+ TypeConverter *converter)
+ : PatternRewriter(ctx),
+ impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+ConversionPatternRewriter::~ConversionPatternRewriter() {}
+
+/// PatternRewriter hook for replacing the results of an operation.
+void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
+ ValueRange valuesToRemoveIfDead) {
+ LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
+ << "\n");
+ impl->replaceOp(op, newValues, valuesToRemoveIfDead);
+}
+
+/// PatternRewriter hook for erasing a dead operation. The uses of this
+/// operation *must* be made dead by the end of the conversion process,
+/// otherwise an assert will be issued.
+void ConversionPatternRewriter::eraseOp(Operation *op) {
+ LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName()
+ << "\n");
+ SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
+ impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None);
+}
+
+/// Apply a signature conversion to the entry block of the given region.
+Block *ConversionPatternRewriter::applySignatureConversion(
+ Region *region, TypeConverter::SignatureConversion &conversion) {
+ return impl->applySignatureConversion(region, conversion);
+}
+
+void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
+ Value to) {
+ for (auto &u : from->getUses()) {
+ if (u.getOwner() == to->getDefiningOp())
+ continue;
+ u.getOwner()->replaceUsesOfWith(from, to);
+ }
+ impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+}
+
+/// Return the converted value that replaces 'key'. Return 'key' if there is
+ /// no such converted value.
+Value ConversionPatternRewriter::getRemappedValue(Value key) {
+ return impl->mapping.lookupOrDefault(key);
+}
+
+/// PatternRewriter hook for splitting a block into two parts.
+Block *ConversionPatternRewriter::splitBlock(Block *block,
+ Block::iterator before) {
+ auto *continuation = PatternRewriter::splitBlock(block, before);
+ impl->notifySplitBlock(block, continuation);
+ return continuation;
+}
+
+/// PatternRewriter hook for merging a block into another.
+void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
+ ValueRange argValues) {
+ // TODO(riverriddle) This requires fixing the implementation of
+ // 'replaceUsesOfBlockArgument', which currently isn't undoable.
+ llvm_unreachable("block merging updates are currently not supported");
+}
+
+/// PatternRewriter hook for moving blocks out of a region.
+void ConversionPatternRewriter::inlineRegionBefore(Region &region,
+ Region &parent,
+ Region::iterator before) {
+ impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
+ PatternRewriter::inlineRegionBefore(region, parent, before);
+}
+
+/// PatternRewriter hook for cloning blocks of one region into another.
+void ConversionPatternRewriter::cloneRegionBefore(
+ Region &region, Region &parent, Region::iterator before,
+ BlockAndValueMapping &mapping) {
+ if (region.empty())
+ return;
+ PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
+
+ // Collect the range of the cloned blocks.
+ auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
+ auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
+ impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
+}
+
+/// PatternRewriter hook for creating a new operation.
+Operation *ConversionPatternRewriter::insert(Operation *op) {
+ LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName()
+ << "\n");
+ impl->createdOps.push_back(op);
+ return OpBuilder::insert(op);
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ impl->pendingRootUpdates.insert(op);
+#endif
+ impl->rootUpdates.emplace_back(op);
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+ // There is nothing to do here, we only need to track the operation at the
+ // start of the update.
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+ // Erase the last update for this operation.
+ auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
+ auto &rootUpdates = impl->rootUpdates;
+ auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
+ assert(it != rootUpdates.rend() && "no root update started on op");
+ rootUpdates.erase(rootUpdates.begin() + (std::prev(rootUpdates.rend()) - it));
+}
+
+/// Return a reference to the internal implementation.
+detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
+ return *impl;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+/// Attempt to match and rewrite the IR root at the specified operation.
+PatternMatchResult
+ConversionPattern::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ SmallVector<Value, 4> operands;
+ auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
+ dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
+
+ // If this operation has no successors, invoke the rewrite directly.
+ if (op->getNumSuccessors() == 0)
+ return matchAndRewrite(op, operands, dialectRewriter);
+
+ // Otherwise, we need to remap the successors.
+ SmallVector<Block *, 2> destinations;
+ destinations.reserve(op->getNumSuccessors());
+
+ SmallVector<ArrayRef<Value>, 2> operandsPerDestination;
+ unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
+ for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
+ destinations.push_back(op->getSuccessor(i));
+
+ // Lookup the successors operands.
+ unsigned n = op->getNumSuccessorOperands(i);
+ operandsPerDestination.push_back(
+ llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
+ seen += n;
+ }
+
+ // Rewrite the operation.
+ return matchAndRewrite(
+ op,
+ llvm::makeArrayRef(operands.data(),
+ operands.data() + firstSuccessorOperand),
+ destinations, operandsPerDestination, dialectRewriter);
+}
+
+//===----------------------------------------------------------------------===//
+// OperationLegalizer
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A set of rewrite patterns that can be used to legalize a given operation.
+using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
+
+/// This class defines a recursive operation legalizer.
+class OperationLegalizer {
+public:
+ using LegalizationAction = ConversionTarget::LegalizationAction;
+
+ OperationLegalizer(ConversionTarget &targetInfo,
+ const OwningRewritePatternList &patterns)
+ : target(targetInfo) {
+ buildLegalizationGraph(patterns);
+ computeLegalizationGraphBenefit();
+ }
+
+ /// Returns true if the given operation is known to be illegal on the target.
+ bool isIllegal(Operation *op) const;
+
+ /// Attempt to legalize the given operation. Returns success if the operation
+ /// was legalized, failure otherwise.
+ LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+
+ /// Returns the conversion target in use by the legalizer.
+ ConversionTarget &getTarget() { return target; }
+
+private:
+ /// Attempt to legalize the given operation by folding it.
+ LogicalResult legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter);
+
+ /// Attempt to legalize the given operation by applying the provided pattern.
+ /// Returns success if the operation was legalized, failure otherwise.
+ LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
+ ConversionPatternRewriter &rewriter);
+
+ /// Build an optimistic legalization graph given the provided patterns. This
+ /// function populates 'legalizerPatterns' with the operations that are not
+ /// directly legal, but may be transitively legal for the current target given
+ /// the provided patterns.
+ void buildLegalizationGraph(const OwningRewritePatternList &patterns);
+
+ /// Compute the benefit of each node within the computed legalization graph.
+ /// This orders the patterns within 'legalizerPatterns' based upon two
+ /// criteria:
+ /// 1) Prefer patterns that have the lowest legalization depth, i.e.
+ /// represent the more direct mapping to the target.
+ /// 2) When comparing patterns with the same legalization depth, prefer the
+ /// pattern with the highest PatternBenefit. This allows for users to
+ /// prefer specific legalizations over others.
+ void computeLegalizationGraphBenefit();
+
+ /// The current set of patterns that have been applied.
+ SmallPtrSet<RewritePattern *, 8> appliedPatterns;
+
+ /// The set of legality information for operations transitively supported by
+ /// the target.
+ DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+
+ /// The legalization information provided by the target.
+ ConversionTarget &target;
+};
+} // namespace
+
+bool OperationLegalizer::isIllegal(Operation *op) const {
+ // Check if the target explicitly marked this operation as illegal.
+ return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
+}
+
+LogicalResult
+OperationLegalizer::legalize(Operation *op,
+ ConversionPatternRewriter &rewriter) {
+ LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
+ << "\n");
+
+ // Check if this operation is legal on the target.
+ if (auto legalityInfo = target.isLegal(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "-- Success : Operation marked legal by the target\n");
+ // If this operation is recursively legal, mark its children as ignored so
+ // that we don't consider them for legalization.
+ if (legalityInfo->isRecursivelyLegal) {
+ LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation is recursively legal; "
+ "Skipping internals\n");
+ rewriter.getImpl().markNestedOpsIgnored(op);
+ }
+ return success();
+ }
+
+ // Check to see if the operation is ignored and doesn't need to be converted.
+ if (rewriter.getImpl().isOpIgnored(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "-- Success : Operation marked ignored during conversion\n");
+ return success();
+ }
+
+ // If the operation isn't legal, try to fold it in-place.
+ // TODO(riverriddle) Should we always try to do this, even if the op is
+ // already legal?
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n");
+ return success();
+ }
+
+ // Otherwise, we need to apply a legalization pattern to this operation.
+ auto it = legalizerPatterns.find(op->getName());
+ if (it == legalizerPatterns.end()) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
+ return failure();
+ }
+
+ // The patterns are sorted by expected benefit, so try to apply each in-order.
+ for (auto *pattern : it->second)
+ if (succeeded(legalizePattern(op, pattern, rewriter)))
+ return success();
+
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
+ return failure();
+}
+
+LogicalResult
+OperationLegalizer::legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
+
+ // Try to fold the operation.
+ SmallVector<Value, 2> replacementValues;
+ rewriter.setInsertionPoint(op);
+ if (failed(rewriter.tryFold(op, replacementValues)))
+ return failure();
+
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
+ // Recursively legalize any new constant operations.
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+ i != e; ++i) {
+ Operation *cstOp = rewriterImpl.createdOps[i];
+ if (failed(legalize(cstOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '"
+ << cstOp->getName() << "' was illegal.\n");
+ rewriterImpl.resetState(curState);
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult
+OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
+ ConversionPatternRewriter &rewriter) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
+ interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
+ llvm::dbgs() << ")'.\n";
+ });
+
+ // Ensure that we don't cycle by not allowing the same pattern to be
+ // applied twice in the same recursion stack.
+ // TODO(riverriddle) We could eventually converge, but that requires more
+ // complicated analysis.
+ if (!appliedPatterns.insert(pattern).second) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
+ return failure();
+ }
+
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
+ auto cleanupFailure = [&] {
+ // Reset the rewriter state and pop this pattern.
+ rewriterImpl.resetState(curState);
+ appliedPatterns.erase(pattern);
+ return failure();
+ };
+
+ // Try to rewrite with the given pattern.
+ rewriter.setInsertionPoint(op);
+ auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+#ifndef NDEBUG
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#endif
+
+ if (!matchedPattern) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
+ return cleanupFailure();
+ }
+
+ // If the pattern moved or created any blocks, try to legalize their types.
+ // This ensures that the types of the block arguments are legal for the region
+ // they were moved into.
+ for (unsigned i = curState.numBlockActions,
+ e = rewriterImpl.blockActions.size();
+ i != e; ++i) {
+ auto &action = rewriterImpl.blockActions[i];
+ if (action.kind ==
+ ConversionPatternRewriterImpl::BlockActionKind::TypeConversion)
+ continue;
+
+ // Convert the block signature.
+ if (failed(rewriterImpl.convertBlockSignature(action.block))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "-- FAIL: failed to convert types of moved block.\n");
+ return cleanupFailure();
+ }
+ }
+
+ // Check all of the replacements to ensure that the pattern actually replaced
+ // the root operation. We also mark any other replaced ops as 'dead' so that
+ // we don't try to legalize them later.
+ bool replacedRoot = false;
+ for (unsigned i = curState.numReplacements,
+ e = rewriterImpl.replacements.size();
+ i != e; ++i) {
+ Operation *replacedOp = rewriterImpl.replacements[i].op;
+ if (replacedOp == op)
+ replacedRoot = true;
+ else
+ rewriterImpl.ignoredOps.insert(replacedOp);
+ }
+
+ // Check that the root was either updated or replaced.
+ auto updatedRootInPlace = [&] {
+ return llvm::any_of(
+ llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+ [op](auto &state) { return state.getOperation() == op; });
+ };
+ (void)replacedRoot;
+ (void)updatedRootInPlace;
+ assert((replacedRoot || updatedRootInPlace()) &&
+ "expected pattern to replace the root operation");
+
+ // Recursively legalize each of the operations updated in place.
+ for (unsigned i = curState.numRootUpdates,
+ e = rewriterImpl.rootUpdates.size();
+ i != e; ++i) {
+ auto &state = rewriterImpl.rootUpdates[i];
+ if (failed(legalize(state.getOperation(), rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
+ << op->getName() << "' was illegal.\n");
+ return cleanupFailure();
+ }
+ }
+
+ // Recursively legalize each of the new operations.
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+ i != e; ++i) {
+ Operation *op = rewriterImpl.createdOps[i];
+ if (failed(legalize(op, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation '"
+ << op->getName() << "' was illegal.\n");
+ return cleanupFailure();
+ }
+ }
+
+ appliedPatterns.erase(pattern);
+ return success();
+}
+
+void OperationLegalizer::buildLegalizationGraph(
+ const OwningRewritePatternList &patterns) {
+ // A mapping between an operation and a set of operations that can be used to
+ // generate it.
+ DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
+ // A mapping between an operation and any currently invalid patterns it has.
+ DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
+ // A worklist of patterns to consider for legality.
+ llvm::SetVector<RewritePattern *> patternWorklist;
+
+ // Build the mapping from operations to the parent ops that may generate them.
+ for (auto &pattern : patterns) {
+ auto root = pattern->getRootKind();
+
+ // Skip operations that are always known to be legal.
+ if (target.getOpAction(root) == LegalizationAction::Legal)
+ continue;
+
+ // Add this pattern to the invalid set for the root op and record this root
+ // as a parent for any generated operations.
+ invalidPatterns[root].insert(pattern.get());
+ for (auto op : pattern->getGeneratedOps())
+ parentOps[op].insert(root);
+
+ // Add this pattern to the worklist.
+ patternWorklist.insert(pattern.get());
+ }
+
+ while (!patternWorklist.empty()) {
+ auto *pattern = patternWorklist.pop_back_val();
+
+ // Check to see if any of the generated operations are invalid.
+ if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
+ Optional<LegalizationAction> action = target.getOpAction(op);
+ return !legalizerPatterns.count(op) &&
+ (!action || action == LegalizationAction::Illegal);
+ }))
+ continue;
+
+ // Otherwise, if all of the generated operations are valid, this op is now
+ // legal so add all of the child patterns to the worklist.
+ legalizerPatterns[pattern->getRootKind()].push_back(pattern);
+ invalidPatterns[pattern->getRootKind()].erase(pattern);
+
+ // Add any invalid patterns of the parent operations to see if they have now
+ // become legal.
+ for (auto op : parentOps[pattern->getRootKind()])
+ patternWorklist.set_union(invalidPatterns[op]);
+ }
+}
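
In other words, buildLegalizationGraph runs a worklist to a fixed point: a pattern becomes usable once every operation it generates is either directly legal or already has a usable pattern of its own, and a newly usable pattern can unlock further patterns rooted at its parent operations. The standalone toy model below captures that propagation with a plain fixed-point loop; the types and names are illustrative and unrelated to the MLIR classes.

#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy model of transitive legalizability (illustrative only): a pattern
// root -> {generated...} is usable once every generated op is legal or has a
// usable pattern itself. Returns the set of ops with at least one usable
// pattern.
struct ToyPattern {
  std::string root;
  std::vector<std::string> generated;
};

std::set<std::string>
computeLegalizable(const std::vector<ToyPattern> &patterns,
                   const std::set<std::string> &legalOps) {
  std::set<std::string> legalizable;
  bool changed = true;
  // Iterate to a fixed point; the real implementation instead keeps a worklist
  // keyed on parent operations to avoid re-scanning every pattern.
  while (changed) {
    changed = false;
    for (const ToyPattern &p : patterns) {
      if (legalizable.count(p.root))
        continue;
      bool allGeneratedOk = true;
      for (const std::string &g : p.generated)
        if (!legalOps.count(g) && !legalizable.count(g))
          allGeneratedOk = false;
      if (allGeneratedOk) {
        legalizable.insert(p.root);
        changed = true;
      }
    }
  }
  return legalizable;
}

int main() {
  // "a" lowers to "b", and "b" lowers to the legal op "c": both "a" and "b"
  // become transitively legalizable.
  std::vector<ToyPattern> patterns = {{"a", {"b"}}, {"b", {"c"}}};
  std::set<std::string> legal = {"c"};
  std::set<std::string> result = computeLegalizable(patterns, legal);
  assert(result.count("a") && result.count("b"));
  return 0;
}
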
+
+void OperationLegalizer::computeLegalizationGraphBenefit() {
+ // The smallest pattern depth, when legalizing an operation.
+ DenseMap<OperationName, unsigned> minPatternDepth;
+
+ // Compute the minimum legalization depth for a given operation.
+ std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
+ // Check for existing depth.
+ auto depthIt = minPatternDepth.find(op);
+ if (depthIt != minPatternDepth.end())
+ return depthIt->second;
+
+ // If a mapping for this operation does not exist, then this operation
+ // is always legal. Return 0 as the depth for a directly legal operation.
+ auto opPatternsIt = legalizerPatterns.find(op);
+ if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
+ return 0u;
+
+ // Initialize the depth to the maximum value.
+ unsigned minDepth = std::numeric_limits<unsigned>::max();
+
+ // Record this initial depth in case we encounter this op again when
+ // recursively computing the depth.
+ minPatternDepth.try_emplace(op, minDepth);
+
+ // Compute the depth for each pattern used to legalize this operation.
+ SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth;
+ patternsByDepth.reserve(opPatternsIt->second.size());
+ for (RewritePattern *pattern : opPatternsIt->second) {
+ unsigned depth = 0;
+ for (auto generatedOp : pattern->getGeneratedOps())
+ depth = std::max(depth, computeDepth(generatedOp) + 1);
+ patternsByDepth.emplace_back(pattern, depth);
+
+ // Update the min depth for this operation.
+ minDepth = std::min(minDepth, depth);
+ }
+
+ // Update the pattern depth.
+ minPatternDepth[op] = minDepth;
+
+ // If the operation only has one legalization pattern, there is no need to
+ // sort them.
+ if (patternsByDepth.size() == 1)
+ return minDepth;
+
+ // Sort the patterns by those likely to be the most beneficial.
+ llvm::array_pod_sort(
+ patternsByDepth.begin(), patternsByDepth.end(),
+ [](const std::pair<RewritePattern *, unsigned> *lhs,
+ const std::pair<RewritePattern *, unsigned> *rhs) {
+ // First sort by the smaller pattern legalization depth.
+ if (lhs->second != rhs->second)
+ return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
+ &rhs->second);
+
+ // Then sort by the larger pattern benefit.
+ auto lhsBenefit = lhs->first->getBenefit();
+ auto rhsBenefit = rhs->first->getBenefit();
+ return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
+ &lhsBenefit);
+ });
+
+ // Update the legalization pattern to use the new sorted list.
+ opPatternsIt->second.clear();
+ for (auto &patternIt : patternsByDepth)
+ opPatternsIt->second.push_back(patternIt.first);
+
+ return minDepth;
+ };
+
+ // For each operation that is transitively legal, compute a cost for it.
+ for (auto &opIt : legalizerPatterns)
+ if (!minPatternDepth.count(opIt.first))
+ computeDepth(opIt.first);
+}
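
The resulting order is a two-key sort: ascending legalization depth first, then descending pattern benefit as the tie-breaker. Below is a minimal standalone sketch of that comparison, using placeholder types rather than the RewritePattern/PatternBenefit classes.

#include <algorithm>
#include <cassert>
#include <vector>

// Placeholder stand-in for a (pattern, depth) entry: `benefit` models
// PatternBenefit and `depth` the computed legalization depth.
struct RankedPattern {
  int id;
  unsigned depth;
  unsigned benefit;
};

// Ordering used by the legalizer: smaller depth wins, and on equal depth the
// larger benefit wins.
void sortByExpectedBenefit(std::vector<RankedPattern> &patterns) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const RankedPattern &lhs, const RankedPattern &rhs) {
                     if (lhs.depth != rhs.depth)
                       return lhs.depth < rhs.depth;
                     return lhs.benefit > rhs.benefit;
                   });
}

int main() {
  std::vector<RankedPattern> patterns = {
      {0, /*depth=*/2, /*benefit=*/10}, {1, 1, 1}, {2, 1, 5}};
  sortByExpectedBenefit(patterns);
  // The depth-1 entries come first, with the higher benefit breaking the tie.
  assert(patterns[0].id == 2 && patterns[1].id == 1 && patterns[2].id == 0);
  return 0;
}
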
+
+//===----------------------------------------------------------------------===//
+// OperationConverter
+//===----------------------------------------------------------------------===//
+namespace {
+enum OpConversionMode {
+ // In this mode, the conversion will ignore failed conversions to allow
+ // illegal operations to co-exist in the IR.
+ Partial,
+
+ // In this mode, all operations must be legal for the given target for the
+ // conversion to succeed.
+ Full,
+
+ // In this mode, operations are analyzed for legality. No actual rewrites are
+ // applied to the operations on success.
+ Analysis,
+};
+
+// This class converts operations to a given conversion target via a set of
+// rewrite patterns. The conversion behaves differently depending on the
+// conversion mode.
+struct OperationConverter {
+ explicit OperationConverter(ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ OpConversionMode mode,
+ DenseSet<Operation *> *legalizableOps = nullptr)
+ : opLegalizer(target, patterns), mode(mode),
+ legalizableOps(legalizableOps) {}
+
+ /// Converts the given operations to the conversion target.
+ LogicalResult convertOperations(ArrayRef<Operation *> ops,
+ TypeConverter *typeConverter);
+
+private:
+ /// Converts an operation with the given rewriter.
+ LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+
+ /// Converts the type signatures of the blocks nested within 'op'.
+ LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
+ Operation *op);
+
+ /// The legalizer to use when converting operations.
+ OperationLegalizer opLegalizer;
+
+ /// The conversion mode to use when legalizing operations.
+ OpConversionMode mode;
+
+ /// A set of pre-existing operations that were found to be legalizable to the
+ /// target. This field is only used when mode == OpConversionMode::Analysis.
+ DenseSet<Operation *> *legalizableOps;
+};
+} // end anonymous namespace
+
+LogicalResult
+OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
+ Operation *op) {
+ // Check to see if type signatures need to be converted.
+ if (!rewriter.getImpl().argConverter.typeConverter)
+ return success();
+
+ for (auto &region : op->getRegions()) {
+ for (auto &block : llvm::make_early_inc_range(region))
+ if (failed(rewriter.getImpl().convertBlockSignature(&block)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
+ Operation *op) {
+ // Legalize the given operation.
+ if (failed(opLegalizer.legalize(op, rewriter))) {
+ // Handle the case of a failed conversion for each of the different modes.
+ /// Full conversions expect all operations to be converted.
+ if (mode == OpConversionMode::Full)
+ return op->emitError()
+ << "failed to legalize operation '" << op->getName() << "'";
+ /// Partial conversions allow conversions to fail iff the operation was not
+ /// explicitly marked as illegal.
+ if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
+ return op->emitError()
+ << "failed to legalize operation '" << op->getName()
+ << "' that was explicitly marked illegal";
+ } else {
+ /// Analysis conversions don't fail if any operations fail to legalize,
+ /// they are only interested in the operations that were successfully
+ /// legalized.
+ if (mode == OpConversionMode::Analysis)
+ legalizableOps->insert(op);
+
+ // If legalization succeeded, convert the types of any of the blocks within
+ // this operation.
+ if (failed(convertBlockSignatures(rewriter, op)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult
+OperationConverter::convertOperations(ArrayRef<Operation *> ops,
+ TypeConverter *typeConverter) {
+ if (ops.empty())
+ return success();
+ ConversionTarget &target = opLegalizer.getTarget();
+
+ /// Compute the set of operations and blocks to convert.
+ std::vector<Operation *> toConvert;
+ for (auto *op : ops) {
+ toConvert.emplace_back(op);
+ for (auto &region : op->getRegions())
+ if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
+ toConvert, &target)))
+ return failure();
+ }
+
+ // Convert each operation and discard rewrites on failure.
+ ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
+ for (auto *op : toConvert)
+ if (failed(convert(rewriter, op)))
+ return rewriter.getImpl().discardRewrites(), failure();
+
+ // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
+ // analysis conversion.
+ if (mode == OpConversionMode::Analysis)
+ rewriter.getImpl().discardRewrites();
+ else
+ rewriter.getImpl().applyRewrites();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+/// Remap an input of the original signature with a new set of types. The
+/// new types are appended to the new signature conversion.
+void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
+ ArrayRef<Type> types) {
+ assert(!types.empty() && "expected valid types");
+ remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
+ addInputs(types);
+}
+
+/// Append new input types to the signature conversion. This should only be
+/// used if the new types are not intended to remap an existing input.
+void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
+ assert(!types.empty() &&
+ "1->0 type remappings don't need to be added explicitly");
+ argTypes.append(types.begin(), types.end());
+}
+
+/// Remap an input of the original signature with a range of types in the
+/// new signature.
+void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
+ unsigned newInputNo,
+ unsigned newInputCount) {
+ assert(!remappedInputs[origInputNo] && "input has already been remapped");
+ assert(newInputCount != 0 && "expected valid input count");
+ remappedInputs[origInputNo] =
+ InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
+}
+
+/// Remap an input of the original signature to the given `replacementValue`.
+/// This causes the signature converter to drop this argument.
+void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
+ Value replacementValue) {
+ assert(!remappedInputs[origInputNo] && "input has already been remapped");
+ remappedInputs[origInputNo] =
+ InputMapping{origInputNo, /*size=*/0, replacementValue};
+}
+
+/// This hook allows for converting a type.
+LogicalResult TypeConverter::convertType(Type t,
+ SmallVectorImpl<Type> &results) {
+ if (auto newT = convertType(t)) {
+ results.push_back(newT);
+ return success();
+ }
+ return failure();
+}
+
+/// Convert the given set of types, filling 'results' as necessary. This
+/// returns failure if the conversion of any of the types fails, success
+/// otherwise.
+LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
+ SmallVectorImpl<Type> &results) {
+ for (auto type : types)
+ if (failed(convertType(type, results)))
+ return failure();
+ return success();
+}
+
+/// Return true if the given type is legal for this type converter, i.e. the
+/// type converts to itself.
+bool TypeConverter::isLegal(Type type) {
+ SmallVector<Type, 1> results;
+ return succeeded(convertType(type, results)) && results.size() == 1 &&
+ results.front() == type;
+}
+
+/// Return true if the inputs and outputs of the given function type are
+/// legal.
+bool TypeConverter::isSignatureLegal(FunctionType funcType) {
+ return llvm::all_of(
+ llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
+ [this](Type type) { return isLegal(type); });
+}
+
+/// This hook allows for converting a specific argument of a signature.
+LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+ SignatureConversion &result) {
+ // Try to convert the given input type.
+ SmallVector<Type, 1> convertedTypes;
+ if (failed(convertType(type, convertedTypes)))
+ return failure();
+
+ // If this argument is being dropped, there is nothing left to do.
+ if (convertedTypes.empty())
+ return success();
+
+ // Otherwise, add the new inputs.
+ result.addInputs(inputNo, convertedTypes);
+ return success();
+}
+
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FuncOp.
+namespace {
+struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
+ FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+ : OpConversionPattern(ctx), converter(converter) {}
+
+ /// Hook for derived classes to implement combined matching and rewriting.
+ PatternMatchResult
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FunctionType type = funcOp.getType();
+
+ // Convert the original function arguments.
+ TypeConverter::SignatureConversion result(type.getNumInputs());
+ for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+ if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
+ return matchFailure();
+
+ // Convert the original function results.
+ SmallVector<Type, 1> convertedResults;
+ if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+ return matchFailure();
+
+ // Update the function signature in-place.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(FunctionType::get(result.getConvertedTypes(),
+ convertedResults, funcOp.getContext()));
+ rewriter.applySignatureConversion(&funcOp.getBody(), result);
+ });
+ return matchSuccess();
+ }
+
+ /// The type converter to use when rewriting the signature.
+ TypeConverter &converter;
+};
+} // end anonymous namespace
+
+void mlir::populateFuncOpTypeConversionPattern(
+ OwningRewritePatternList &patterns, MLIRContext *ctx,
+ TypeConverter &converter) {
+ patterns.insert<FuncOpSignatureConversion>(ctx, converter);
+}
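
As a usage sketch of the pattern above (hedged: it assumes the single-argument convertType(Type) seen earlier in this file is the virtual 1-1 conversion hook, and the i64-for-index mapping is purely illustrative), a client-side type converter could be wired up as follows:

#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hedged sketch, not part of this commit: assumes convertType(Type) is the
// virtual 1-1 hook invoked by convertType(Type, SmallVectorImpl<Type> &).
struct IndexToI64Converter : public TypeConverter {
  // Map `index` to i64 and leave every other type unchanged.
  Type convertType(Type t) override {
    if (t.isa<IndexType>())
      return IntegerType::get(/*width=*/64, t.getContext());
    return t;
  }
};

// Wire the converter into the FuncOp signature conversion pattern defined
// above; the resulting list is then handed to a conversion entry point
// together with a ConversionTarget.
static void addFuncTypeConversion(OwningRewritePatternList &patterns,
                                  MLIRContext *ctx,
                                  IndexToI64Converter &converter) {
  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
}
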
+
+/// This function converts the type signature of the given block by invoking
+/// 'convertSignatureArg' for each argument. It returns a valid conversion for
+/// the signature on success, or None otherwise.
+auto TypeConverter::convertBlockSignature(Block *block)
+ -> Optional<SignatureConversion> {
+ SignatureConversion conversion(block->getNumArguments());
+ for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
+ if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
+ conversion)))
+ return llvm::None;
+ return conversion;
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionTarget
+//===----------------------------------------------------------------------===//
+
+/// Register a legality action for the given operation.
+void ConversionTarget::setOpAction(OperationName op,
+ LegalizationAction action) {
+ legalOperations[op] = {action, /*isRecursivelyLegal=*/false};
+}
+
+/// Register a legality action for the given dialects.
+void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
+ LegalizationAction action) {
+ for (StringRef dialect : dialectNames)
+ legalDialects[dialect] = action;
+}
+
+/// Get the legality action for the given operation.
+auto ConversionTarget::getOpAction(OperationName op) const
+ -> Optional<LegalizationAction> {
+ Optional<LegalizationInfo> info = getOpInfo(op);
+ return info ? info->action : Optional<LegalizationAction>();
+}
+
+/// If the given operation instance is legal on this target, a structure
+/// containing legality information is returned. If the operation is not legal,
+/// None is returned.
+auto ConversionTarget::isLegal(Operation *op) const
+ -> Optional<LegalOpDetails> {
+ Optional<LegalizationInfo> info = getOpInfo(op->getName());
+ if (!info)
+ return llvm::None;
+
+ // Returns true if this operation instance is known to be legal.
+ auto isOpLegal = [&] {
+ // Handle dynamic legality.
+ if (info->action == LegalizationAction::Dynamic) {
+ // Check for callbacks on the operation or dialect.
+ auto opFn = opLegalityFns.find(op->getName());
+ if (opFn != opLegalityFns.end())
+ return opFn->second(op);
+ auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
+ if (dialectFn != dialectLegalityFns.end())
+ return dialectFn->second(op);
+
+ // Otherwise, invoke the hook on the derived instance.
+ return isDynamicallyLegal(op);
+ }
+
+ // Otherwise, the operation is only legal if it was marked 'Legal'.
+ return info->action == LegalizationAction::Legal;
+ };
+ if (!isOpLegal())
+ return llvm::None;
+
+ // This operation is legal, compute any additional legality information.
+ LegalOpDetails legalityDetails;
+
+ if (info->isRecursivelyLegal) {
+ auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
+ if (legalityFnIt != opRecursiveLegalityFns.end())
+ legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
+ else
+ legalityDetails.isRecursivelyLegal = true;
+ }
+ return legalityDetails;
+}
+
+/// Set the dynamic legality callback for the given operation.
+void ConversionTarget::setLegalityCallback(
+ OperationName name, const DynamicLegalityCallbackFn &callback) {
+ assert(callback && "expected valid legality callback");
+ opLegalityFns[name] = callback;
+}
+
+/// Set the recursive legality callback for the given operation and mark the
+/// operation as recursively legal.
+void ConversionTarget::markOpRecursivelyLegal(
+ OperationName name, const DynamicLegalityCallbackFn &callback) {
+ auto infoIt = legalOperations.find(name);
+ assert(infoIt != legalOperations.end() &&
+ infoIt->second.action != LegalizationAction::Illegal &&
+ "expected operation to already be marked as legal");
+ infoIt->second.isRecursivelyLegal = true;
+ if (callback)
+ opRecursiveLegalityFns[name] = callback;
+ else
+ opRecursiveLegalityFns.erase(name);
+}
+
+/// Set the dynamic legality callback for the given dialects.
+void ConversionTarget::setLegalityCallback(
+ ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
+ assert(callback && "expected valid legality callback");
+ for (StringRef dialect : dialects)
+ dialectLegalityFns[dialect] = callback;
+}
+
+/// Get the legalization information for the given operation.
+auto ConversionTarget::getOpInfo(OperationName op) const
+ -> Optional<LegalizationInfo> {
+ // Check for info for this specific operation.
+ auto it = legalOperations.find(op);
+ if (it != legalOperations.end())
+ return it->second;
+ // Otherwise, default to checking on the parent dialect.
+ auto dialectIt = legalDialects.find(op.getDialect());
+ if (dialectIt != legalDialects.end())
+ return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false};
+ return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Conversion Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Apply a partial conversion on the given operations, and all nested
+/// operations. This method converts as many operations to the target as
+/// possible, ignoring operations that failed to legalize.
+LogicalResult mlir::applyPartialConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const OwningRewritePatternList &patterns, TypeConverter *converter) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
+ return opConverter.convertOperations(ops, converter);
+}
+LogicalResult
+mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ TypeConverter *converter) {
+ return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
+ converter);
+}
+
+/// Apply a complete conversion on the given operations, and all nested
+/// operations. This method will return failure if the conversion of any
+/// operation fails.
+LogicalResult
+mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ TypeConverter *converter) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+ return opConverter.convertOperations(ops, converter);
+}
+LogicalResult
+mlir::applyFullConversion(Operation *op, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ TypeConverter *converter) {
+ return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
+ converter);
+}
+
+/// Apply an analysis conversion on the given operations, and all nested
+/// operations. This method analyzes which operations would be successfully
+/// converted to the target if a conversion was applied. All operations that
+/// were found to be legalizable to the given 'target' are placed within the
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+LogicalResult mlir::applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+ &convertedOps);
+ return opConverter.convertOperations(ops, converter);
+}
+LogicalResult
+mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
+ const OwningRewritePatternList &patterns,
+ DenseSet<Operation *> &convertedOps,
+ TypeConverter *converter) {
+ return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
+ convertedOps, converter);
+}
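
Putting the entry points together, a hedged end-to-end sketch of a partial conversion might look like the following; MyTargetDialect, MyIllegalOp, and populateMyLoweringPatterns are placeholders for a concrete lowering, and the addLegalDialect/addIllegalOp calls are the usual ConversionTarget convenience helpers declared in the DialectConversion header rather than in this file.

// Hedged sketch, not part of this commit. MyTargetDialect, MyIllegalOp, and
// populateMyLoweringPatterns are placeholders for a concrete lowering.
static LogicalResult runMyPartialLowering(ModuleOp module, MLIRContext *ctx) {
  ConversionTarget target(*ctx);
  target.addLegalDialect<MyTargetDialect>(); // ops of this dialect stay as-is
  target.addIllegalOp<MyIllegalOp>();        // must be rewritten away

  OwningRewritePatternList patterns;
  populateMyLoweringPatterns(patterns, ctx); // placeholder pattern population

  // Unlisted operations that fail to legalize are simply left in place.
  return applyPartialConversion(module.getOperation(), target, patterns,
                                /*converter=*/nullptr);
}
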
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
new file mode 100644
index 00000000000..b2cee7da083
--- /dev/null
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -0,0 +1,296 @@
+//===- Inliner.cpp - Pass to inline function calls ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a basic inlining algorithm that operates bottom-up over
+// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
+// more incremental propagation of inlining decisions from the leaves to the
+// roots of the callgraph.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/CallGraph.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Parallel.h"
+
+#define DEBUG_TYPE "inlining"
+
+using namespace mlir;
+
+static llvm::cl::opt<bool> disableCanonicalization(
+ "mlir-disable-inline-simplify",
+ llvm::cl::desc("Disable running simplifications during inlining"),
+ llvm::cl::ReallyHidden, llvm::cl::init(false));
+
+static llvm::cl::opt<unsigned> maxInliningIterations(
+ "mlir-max-inline-iterations",
+ llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
+ llvm::cl::ReallyHidden, llvm::cl::init(4));
+
+//===----------------------------------------------------------------------===//
+// CallGraph traversal
+//===----------------------------------------------------------------------===//
+
+/// Run a given transformation over the SCCs of the callgraph in a bottom up
+/// traversal.
+static void runTransformOnCGSCCs(
+ const CallGraph &cg,
+ function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
+ std::vector<CallGraphNode *> currentSCCVec;
+ auto cgi = llvm::scc_begin(&cg);
+ while (!cgi.isAtEnd()) {
+ // Copy the current SCC and increment so that the transformer can modify the
+ // SCC without invalidating our iterator.
+ currentSCCVec = *cgi;
+ ++cgi;
+ sccTransformer(currentSCCVec);
+ }
+}
+
+namespace {
+/// This struct represents a resolved call to a given callgraph node. Given that
+/// the call does not actually contain a direct reference to the
+/// Region (CallGraphNode) that it is dispatching to, we need to resolve it
+/// explicitly.
+struct ResolvedCall {
+ ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
+ : call(call), targetNode(targetNode) {}
+ CallOpInterface call;
+ CallGraphNode *targetNode;
+};
+} // end anonymous namespace
+
+/// Collect all of the call operations within the given range of blocks. If
+/// `traverseNestedCGNodes` is true, this will also collect call operations
+/// inside of nested callgraph nodes.
+static void collectCallOps(iterator_range<Region::iterator> blocks,
+ CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
+ bool traverseNestedCGNodes) {
+ SmallVector<Block *, 8> worklist;
+ auto addToWorklist = [&](iterator_range<Region::iterator> blocks) {
+ for (Block &block : blocks)
+ worklist.push_back(&block);
+ };
+
+ addToWorklist(blocks);
+ while (!worklist.empty()) {
+ for (Operation &op : *worklist.pop_back_val()) {
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ CallGraphNode *node =
+ cg.resolveCallable(call.getCallableForCallee(), &op);
+ if (!node->isExternal())
+ calls.emplace_back(call, node);
+ continue;
+ }
+
+ // If this is not a call, traverse the nested regions. If
+ // `traverseNestedCGNodes` is false, then don't traverse nested call graph
+ // regions.
+ for (auto &nestedRegion : op.getRegions())
+ if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
+ addToWorklist(nestedRegion);
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Inliner
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class provides a specialization of the main inlining interface.
+struct Inliner : public InlinerInterface {
+ Inliner(MLIRContext *context, CallGraph &cg)
+ : InlinerInterface(context), cg(cg) {}
+
+ /// Process a set of blocks that have been inlined. This callback is invoked
+ /// *before* inlined terminator operations have been processed.
+ void
+ processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
+ collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
+ }
+
+ /// The current set of call instructions to consider for inlining.
+ SmallVector<ResolvedCall, 8> calls;
+
+ /// The callgraph being operated on.
+ CallGraph &cg;
+};
+} // namespace
+
+/// Returns true if the given call should be inlined.
+static bool shouldInline(ResolvedCall &resolvedCall) {
+ // Don't allow inlining terminator calls. We currently don't support this
+ // case.
+ if (resolvedCall.call.getOperation()->isKnownTerminator())
+ return false;
+
+ // Don't allow inlining if the target is an ancestor of the call. This
+ // prevents inlining recursively.
+ if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
+ resolvedCall.call.getParentRegion()))
+ return false;
+
+ // Otherwise, inline.
+ return true;
+}
+
+/// Attempt to inline calls within the given scc. This function returns
+/// success if any calls were inlined, failure otherwise.
+static LogicalResult inlineCallsInSCC(Inliner &inliner,
+ ArrayRef<CallGraphNode *> currentSCC) {
+ CallGraph &cg = inliner.cg;
+ auto &calls = inliner.calls;
+
+ // Collect all of the direct calls within the nodes of the current SCC. We
+ // don't traverse nested callgraph nodes, because they are handled separately,
+ // likely within a different SCC.
+ for (auto *node : currentSCC) {
+ if (!node->isExternal())
+ collectCallOps(*node->getCallableRegion(), cg, calls,
+ /*traverseNestedCGNodes=*/false);
+ }
+ if (calls.empty())
+ return failure();
+
+ // Try to inline each of the call operations. Don't cache the end iterator
+ // here as more calls may be added during inlining.
+ bool inlinedAnyCalls = false;
+ for (unsigned i = 0; i != calls.size(); ++i) {
+ ResolvedCall &it = calls[i];
+ LLVM_DEBUG({
+ llvm::dbgs() << "* Considering inlining call: ";
+ it.call.dump();
+ });
+ if (!shouldInline(it))
+ continue;
+
+ CallOpInterface call = it.call;
+ Region *targetRegion = it.targetNode->getCallableRegion();
+ LogicalResult inlineResult = inlineCall(
+ inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
+ targetRegion);
+ if (failed(inlineResult))
+ continue;
+
+ // If the inlining was successful, then erase the call.
+ call.erase();
+ inlinedAnyCalls = true;
+ }
+ calls.clear();
+ return success(inlinedAnyCalls);
+}
+
+/// Canonicalize the nodes within the given SCC with the given set of
+/// canonicalization patterns.
+static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
+ MLIRContext *context,
+ const OwningRewritePatternList &canonPatterns) {
+ // Collect the sets of nodes to canonicalize.
+ SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
+ for (auto *node : currentSCC) {
+ // Don't canonicalize the external node, it has no valid callable region.
+ if (node->isExternal())
+ continue;
+
+ // Don't canonicalize nodes with children. Nodes with children
+ // require special handling as we may remove the node during
+ // canonicalization. In the future, we should be able to handle this
+ // case with proper node deletion tracking.
+ if (node->hasChildren())
+ continue;
+
+ // We also won't apply canonicalizations for nodes that are not
+ // isolated. This avoids potentially mutating the regions of nodes defined
+ // above; this is also a stipulation of the 'applyPatternsGreedily' driver.
+ auto *region = node->getCallableRegion();
+ if (!region->getParentOp()->isKnownIsolatedFromAbove())
+ continue;
+ nodesToCanonicalize.push_back(node);
+ }
+ if (nodesToCanonicalize.empty())
+ return;
+
+ // Canonicalize each of the nodes within the SCC in parallel.
+ // NOTE: This is simple now, because we don't enable canonicalizing nodes
+ // within children. When we remove this restriction, this logic will need to
+ // be reworked.
+ ParallelDiagnosticHandler canonicalizationHandler(context);
+ llvm::parallel::for_each_n(
+ llvm::parallel::par, /*Begin=*/size_t(0),
+ /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
+ // Set the order for this thread so that diagnostics will be properly
+ // ordered.
+ canonicalizationHandler.setOrderIDForThread(index);
+
+ // Apply the canonicalization patterns to this region.
+ auto *node = nodesToCanonicalize[index];
+ applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);
+
+ // Make sure to reset the order ID for the diagnostic handler, as this
+ // thread may be used in a different context.
+ canonicalizationHandler.eraseOrderIDForThread();
+ });
+}
+
+/// Attempt to inline calls within the given scc, and run canonicalizations with
+/// the given patterns, until a fixed point is reached. This allows for the
+/// inlining of newly devirtualized calls.
+static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
+ MLIRContext *context,
+ const OwningRewritePatternList &canonPatterns) {
+ // If we successfully inlined any calls, run some simplifications on the
+ // nodes of the scc. Continue attempting to inline until we reach a fixed
+ // point, or a maximum iteration count. We canonicalize here as it may
+ // devirtualize new calls, as well as give us a better cost model.
+ unsigned iterationCount = 0;
+ while (succeeded(inlineCallsInSCC(inliner, currentSCC))) {
+ // If we aren't allowing simplifications or the max iteration count was
+ // reached, then bail out early.
+ if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
+ break;
+ canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// InlinerPass
+//===----------------------------------------------------------------------===//
+
+// TODO(riverriddle) This pass should currently only be used for basic testing
+// of inlining functionality.
+namespace {
+struct InlinerPass : public OperationPass<InlinerPass> {
+ void runOnOperation() override {
+ CallGraph &cg = getAnalysis<CallGraph>();
+ auto *context = &getContext();
+
+ // Collect a set of canonicalization patterns to use when simplifying
+ // callable regions within an SCC.
+ OwningRewritePatternList canonPatterns;
+ for (auto *op : context->getRegisteredOperations())
+ op->getCanonicalizationPatterns(canonPatterns, context);
+
+ // Run the inline transform in post-order over the SCCs in the callgraph.
+ Inliner inliner(context, cg);
+ runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
+ inlineSCC(inliner, scc, context, canonPatterns);
+ });
+ }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createInlinerPass() {
+ return std::make_unique<InlinerPass>();
+}
+
+static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp
new file mode 100644
index 00000000000..2aee688c6c1
--- /dev/null
+++ b/mlir/lib/Transforms/LoopCoalescing.cpp
@@ -0,0 +1,96 @@
+//===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define PASS_NAME "loop-coalescing"
+#define DEBUG_TYPE PASS_NAME
+
+using namespace mlir;
+
+namespace {
+class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> {
+public:
+ void runOnFunction() override {
+ FuncOp func = getFunction();
+
+ func.walk([](loop::ForOp op) {
+ // Ignore nested loops.
+ if (op.getParentOfType<loop::ForOp>())
+ return;
+
+ SmallVector<loop::ForOp, 4> loops;
+ getPerfectlyNestedLoops(loops, op);
+ LLVM_DEBUG(llvm::dbgs()
+ << "found a perfect nest of depth " << loops.size() << '\n');
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
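+      // For example (schematic 'loop.for' syntax), in
+      //   loop.for %i = %lb0 to %ub0 step %s0 {
+      //     loop.for %j = %lb1 to %ub1 step %s1 { ... }
+      //   }
+      // the band spans both loops when all bounds are defined above the outer
+      // loop; if the inner bounds were computed from %i, the inner loop could
+      // not join the band.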
+ // 1. For each loop, find above which parent loop its operands are
+ // defined.
+ SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+ for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ if (areValuesDefinedAbove(loops[i].getOperands(),
+ loops[j].region())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << " bounds of loop " << i << " are known above depth "
+ << operandsDefinedAbove[i] << '\n');
+ }
+
+ // 2. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+
+ assert(maxPos == start &&
+ "expected loop bounds to be known at the start of the band");
+ LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start
+ << " to " << end << '\n');
+
+ auto band =
+ llvm::makeMutableArrayRef(loops.data() + start, end - start);
+ coalesceLoops(band);
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopCoalescingPass() {
+ return std::make_unique<LoopCoalescingPass>();
+}
+
+static PassRegistration<LoopCoalescingPass>
+ reg(PASS_NAME,
+ "coalesce nested loops with independent bounds into a single loop");
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
new file mode 100644
index 00000000000..fcfc1d7ae52
--- /dev/null
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -0,0 +1,1979 @@
+//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopFusionUtils.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iomanip>
+#include <sstream>
+#define DEBUG_TYPE "affine-loop-fusion"
+
+using llvm::SetVector;
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+/// Disables the fusion profitability check and fuses if valid. Ignores any
+/// additional (redundant) computation tolerance threshold
+/// that would have prevented fusion.
+static llvm::cl::opt<bool>
+ clMaximalLoopFusion("fusion-maximal",
+ llvm::cl::desc("Enables maximal loop fusion"),
+ llvm::cl::cat(clOptionsCategory));
+
+/// A threshold in percent of additional computation allowed when fusing.
+static llvm::cl::opt<double> clFusionAddlComputeTolerance(
+ "fusion-compute-tolerance",
+ llvm::cl::desc("Fractional increase in additional "
+ "computation tolerated while fusing"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
+ "fusion-fast-mem-space",
+ llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
+ llvm::cl::cat(clOptionsCategory));
+
+// A local buffer of size less than or equal to this threshold is automatically
+// promoted to fast memory after producer-consumer fusion.
+static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold(
+ "fusion-local-buf-threshold",
+ llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast "
+ "memory space"),
+ llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// Loop fusion pass. This pass currently supports a greedy fusion policy,
+/// which fuses loop nests with single-writer/single-reader memref dependences
+/// with the goal of improving locality.
+
+// TODO(andydavis) Support fusion of source loop nests which write to multiple
+// memrefs, where each memref can have multiple users (if profitable).
+// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
+// and add support for more general loop fusion algorithms.
+
+struct LoopFusion : public FunctionPass<LoopFusion> {
+ LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
+ bool maximalFusion = false)
+ : localBufSizeThreshold(localBufSizeThreshold),
+ fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
+
+ void runOnFunction() override;
+
+ // Any local buffers smaller than this size (in bytes) will be created in
+ // `fastMemorySpace` if provided.
+ uint64_t localBufSizeThreshold;
+ Optional<unsigned> fastMemorySpace = None;
+ // If true, ignore any additional (redundant) computation tolerance threshold
+ // that would have prevented fusion.
+ bool maximalFusion;
+
+ // The amount of additional computation that is tolerated while fusing
+ // pair-wise as a fraction of the total computation.
+ constexpr static double kComputeToleranceThreshold = 0.30f;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopFusionPass(unsigned fastMemorySpace,
+ uint64_t localBufSizeThreshold, bool maximalFusion) {
+ return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
+ maximalFusion);
+}
+
+// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
+static bool isMemRefDereferencingOp(Operation &op) {
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
+ return true;
+ return false;
+}
+
+namespace {
+
+// LoopNestStateCollector walks loop nests and collects load and store
+// operations, and whether or not an IfInst was encountered in the loop nest.
+struct LoopNestStateCollector {
+ SmallVector<AffineForOp, 4> forOps;
+ SmallVector<Operation *, 4> loadOpInsts;
+ SmallVector<Operation *, 4> storeOpInsts;
+ bool hasNonForRegion = false;
+
+ void collect(Operation *opToWalk) {
+ opToWalk->walk([&](Operation *op) {
+ if (isa<AffineForOp>(op))
+ forOps.push_back(cast<AffineForOp>(op));
+ else if (op->getNumRegions() != 0)
+ hasNonForRegion = true;
+ else if (isa<AffineLoadOp>(op))
+ loadOpInsts.push_back(op);
+ else if (isa<AffineStoreOp>(op))
+ storeOpInsts.push_back(op);
+ });
+ }
+};
+
+// MemRefDependenceGraph is a graph data structure where graph nodes are
+// top-level operations in a FuncOp which contain load/store ops, and edges
+// are memref dependences between the nodes.
+// TODO(andydavis) Add a more flexible dependence graph representation.
+// TODO(andydavis) Add a depth parameter to dependence graph construction.
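+//
+// Illustrative example (schematic): a function with two top-level affine.for
+// nests, the first storing to %A and the second loading from %A, yields two
+// graph nodes connected by one edge whose 'value' is %A.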
+struct MemRefDependenceGraph {
+public:
+ // Node represents a node in the graph. A Node is either an entire loop nest
+ // rooted at the top level which contains loads/stores, or a top level
+ // load/store.
+ struct Node {
+ // The unique identifier of this node in the graph.
+ unsigned id;
+    // The top-level operation which is (or contains) a load/store.
+ Operation *op;
+ // List of load operations.
+ SmallVector<Operation *, 4> loads;
+ // List of store op insts.
+ SmallVector<Operation *, 4> stores;
+ Node(unsigned id, Operation *op) : id(id), op(op) {}
+
+ // Returns the load op count for 'memref'.
+ unsigned getLoadOpCount(Value memref) {
+ unsigned loadOpCount = 0;
+ for (auto *loadOpInst : loads) {
+ if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ ++loadOpCount;
+ }
+ return loadOpCount;
+ }
+
+ // Returns the store op count for 'memref'.
+ unsigned getStoreOpCount(Value memref) {
+ unsigned storeOpCount = 0;
+ for (auto *storeOpInst : stores) {
+ if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ ++storeOpCount;
+ }
+ return storeOpCount;
+ }
+
+    // Collects in 'storeOps' the store ops which access 'memref'.
+ void getStoreOpsForMemref(Value memref,
+ SmallVectorImpl<Operation *> *storeOps) {
+ for (auto *storeOpInst : stores) {
+ if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ storeOps->push_back(storeOpInst);
+ }
+ }
+
+    // Collects in 'loadOps' the load ops which access 'memref'.
+ void getLoadOpsForMemref(Value memref,
+ SmallVectorImpl<Operation *> *loadOps) {
+ for (auto *loadOpInst : loads) {
+ if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ loadOps->push_back(loadOpInst);
+ }
+ }
+
+    // Collects in 'loadAndStoreMemrefSet' the memrefs for which this node
+    // has at least one load and one store operation.
+ void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
+ llvm::SmallDenseSet<Value, 2> loadMemrefs;
+ for (auto *loadOpInst : loads) {
+ loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+ }
+ for (auto *storeOpInst : stores) {
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ if (loadMemrefs.count(memref) > 0)
+ loadAndStoreMemrefSet->insert(memref);
+ }
+ }
+ };
+
+ // Edge represents a data dependence between nodes in the graph.
+ struct Edge {
+ // The id of the node at the other end of the edge.
+ // If this edge is stored in Edge = Node.inEdges[i], then
+ // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
+ // If this edge is stored in Edge = Node.outEdges[i], then
+ // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
+ unsigned id;
+ // The SSA value on which this edge represents a dependence.
+ // If the value is a memref, then the dependence is between graph nodes
+ // which contain accesses to the same memref 'value'. If the value is a
+ // non-memref value, then the dependence is between a graph node which
+ // defines an SSA value and another graph node which uses the SSA value
+ // (e.g. a constant operation defining a value which is used inside a loop
+ // nest).
+ Value value;
+ };
+
+ // Map from node id to Node.
+ DenseMap<unsigned, Node> nodes;
+ // Map from node id to list of input edges.
+ DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
+ // Map from node id to list of output edges.
+ DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
+ // Map from memref to a count on the dependence edges associated with that
+ // memref.
+ DenseMap<Value, unsigned> memrefEdgeCount;
+ // The next unique identifier to use for newly created graph nodes.
+ unsigned nextNodeId = 0;
+
+ MemRefDependenceGraph() {}
+
+ // Initializes the dependence graph based on operations in 'f'.
+ // Returns true on success, false otherwise.
+ bool init(FuncOp f);
+
+ // Returns the graph node for 'id'.
+ Node *getNode(unsigned id) {
+ auto it = nodes.find(id);
+ assert(it != nodes.end());
+ return &it->second;
+ }
+
+ // Returns the graph node for 'forOp'.
+ Node *getForOpNode(AffineForOp forOp) {
+ for (auto &idAndNode : nodes)
+ if (idAndNode.second.op == forOp.getOperation())
+ return &idAndNode.second;
+ return nullptr;
+ }
+
+ // Adds a node with 'op' to the graph and returns its unique identifier.
+ unsigned addNode(Operation *op) {
+ Node node(nextNodeId++, op);
+ nodes.insert({node.id, node});
+ return node.id;
+ }
+
+ // Remove node 'id' (and its associated edges) from graph.
+ void removeNode(unsigned id) {
+ // Remove each edge in 'inEdges[id]'.
+ if (inEdges.count(id) > 0) {
+ SmallVector<Edge, 2> oldInEdges = inEdges[id];
+ for (auto &inEdge : oldInEdges) {
+ removeEdge(inEdge.id, id, inEdge.value);
+ }
+ }
+ // Remove each edge in 'outEdges[id]'.
+ if (outEdges.count(id) > 0) {
+ SmallVector<Edge, 2> oldOutEdges = outEdges[id];
+ for (auto &outEdge : oldOutEdges) {
+ removeEdge(id, outEdge.id, outEdge.value);
+ }
+ }
+ // Erase remaining node state.
+ inEdges.erase(id);
+ outEdges.erase(id);
+ nodes.erase(id);
+ }
+
+ // Returns true if node 'id' writes to any memref which escapes (or is an
+ // argument to) the function/block. Returns false otherwise.
+ bool writesToLiveInOrEscapingMemrefs(unsigned id) {
+ Node *node = getNode(id);
+ for (auto *storeOpInst : node->stores) {
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto *op = memref->getDefiningOp();
+ // Return true if 'memref' is a block argument.
+ if (!op)
+ return true;
+ // Return true if any use of 'memref' escapes the function.
+ for (auto *user : memref->getUsers())
+ if (!isMemRefDereferencingOp(*user))
+ return true;
+ }
+ return false;
+ }
+
+ // Returns the unique AffineStoreOp in `node` that meets all the following:
+ // *) store is the only one that writes to a function-local memref live out
+ // of `node`,
+ // *) store is not the source of a self-dependence on `node`.
+ // Otherwise, returns a null AffineStoreOp.
+ AffineStoreOp getUniqueOutgoingStore(Node *node) {
+ AffineStoreOp uniqueStore;
+
+ // Return null if `node` doesn't have any outgoing edges.
+ auto outEdgeIt = outEdges.find(node->id);
+ if (outEdgeIt == outEdges.end())
+ return nullptr;
+
+ const auto &nodeOutEdges = outEdgeIt->second;
+ for (auto *op : node->stores) {
+ auto storeOp = cast<AffineStoreOp>(op);
+ auto memref = storeOp.getMemRef();
+ // Skip this store if there are no dependences on its memref. This means
+ // that store either:
+ // *) writes to a memref that is only read within the same loop nest
+ // (self-dependence edges are not represented in graph at the moment),
+ // *) writes to a function live out memref (function parameter), or
+ // *) is dead.
+ if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+ return (edge.value != memref);
+ }))
+ continue;
+
+ if (uniqueStore)
+ // Found multiple stores to function-local live-out memrefs.
+ return nullptr;
+ // Found first store to function-local live-out memref.
+ uniqueStore = storeOp;
+ }
+
+ return uniqueStore;
+ }
+
+ // Returns true if node 'id' can be removed from the graph. Returns false
+ // otherwise. A node can be removed from the graph iff the following
+ // conditions are met:
+ // *) The node does not write to any memref which escapes (or is a
+ // function/block argument).
+ // *) The node has no successors in the dependence graph.
+ bool canRemoveNode(unsigned id) {
+ if (writesToLiveInOrEscapingMemrefs(id))
+ return false;
+ Node *node = getNode(id);
+ for (auto *storeOpInst : node->stores) {
+ // Return false if there exist out edges from 'id' on 'memref'.
+ if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+ return false;
+ }
+ return true;
+ }
+
+ // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
+ // is for 'value' if non-null, or for any value otherwise. Returns false
+ // otherwise.
+ bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
+ if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+ return false;
+ }
+ bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+ return edge.id == dstId && (!value || edge.value == value);
+ });
+ bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+ return edge.id == srcId && (!value || edge.value == value);
+ });
+ return hasOutEdge && hasInEdge;
+ }
+
+ // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
+ void addEdge(unsigned srcId, unsigned dstId, Value value) {
+ if (!hasEdge(srcId, dstId, value)) {
+ outEdges[srcId].push_back({dstId, value});
+ inEdges[dstId].push_back({srcId, value});
+ if (value->getType().isa<MemRefType>())
+ memrefEdgeCount[value]++;
+ }
+ }
+
+ // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
+ void removeEdge(unsigned srcId, unsigned dstId, Value value) {
+ assert(inEdges.count(dstId) > 0);
+ assert(outEdges.count(srcId) > 0);
+ if (value->getType().isa<MemRefType>()) {
+ assert(memrefEdgeCount.count(value) > 0);
+ memrefEdgeCount[value]--;
+ }
+ // Remove 'srcId' from 'inEdges[dstId]'.
+ for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
+ if ((*it).id == srcId && (*it).value == value) {
+ inEdges[dstId].erase(it);
+ break;
+ }
+ }
+ // Remove 'dstId' from 'outEdges[srcId]'.
+ for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
+ if ((*it).id == dstId && (*it).value == value) {
+ outEdges[srcId].erase(it);
+ break;
+ }
+ }
+ }
+
+ // Returns true if there is a path in the dependence graph from node 'srcId'
+ // to node 'dstId'. Returns false otherwise.
+ bool hasDependencePath(unsigned srcId, unsigned dstId) {
+ // Worklist state is: <node-id, next-output-edge-index-to-visit>
+ SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
+ worklist.push_back({srcId, 0});
+ // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
+ while (!worklist.empty()) {
+ auto &idAndIndex = worklist.back();
+ // Return true if we have reached 'dstId'.
+ if (idAndIndex.first == dstId)
+ return true;
+ // Pop and continue if node has no out edges, or if all out edges have
+ // already been visited.
+ if (outEdges.count(idAndIndex.first) == 0 ||
+ idAndIndex.second == outEdges[idAndIndex.first].size()) {
+ worklist.pop_back();
+ continue;
+ }
+ // Get graph edge to traverse.
+ Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+ // Increment next output edge index for 'idAndIndex'.
+ ++idAndIndex.second;
+ // Add node at 'edge.id' to worklist.
+ worklist.push_back({edge.id, 0});
+ }
+ return false;
+ }
+
+ // Returns the input edge count for node 'id' and 'memref' from src nodes
+ // which access 'memref' with a store operation.
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
+ unsigned inEdgeCount = 0;
+ if (inEdges.count(id) > 0)
+ for (auto &inEdge : inEdges[id])
+ if (inEdge.value == memref) {
+ Node *srcNode = getNode(inEdge.id);
+ // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+ if (srcNode->getStoreOpCount(memref) > 0)
+ ++inEdgeCount;
+ }
+ return inEdgeCount;
+ }
+
+ // Returns the output edge count for node 'id' and 'memref' (if non-null),
+ // otherwise returns the total output edge count from node 'id'.
+ unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
+ unsigned outEdgeCount = 0;
+ if (outEdges.count(id) > 0)
+ for (auto &outEdge : outEdges[id])
+ if (!memref || outEdge.value == memref)
+ ++outEdgeCount;
+ return outEdgeCount;
+ }
+
+  // Computes and returns an insertion point operation, before which the
+  // fused <srcId, dstId> loop nest can be inserted while preserving
+ // dependences. Returns nullptr if no such insertion point is found.
+ Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
+ if (outEdges.count(srcId) == 0)
+ return getNode(dstId)->op;
+
+ // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
+ SmallPtrSet<Operation *, 2> srcDepInsts;
+ for (auto &outEdge : outEdges[srcId])
+ if (outEdge.id != dstId)
+ srcDepInsts.insert(getNode(outEdge.id)->op);
+
+ // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
+ SmallPtrSet<Operation *, 2> dstDepInsts;
+ for (auto &inEdge : inEdges[dstId])
+ if (inEdge.id != srcId)
+ dstDepInsts.insert(getNode(inEdge.id)->op);
+
+ Operation *srcNodeInst = getNode(srcId)->op;
+ Operation *dstNodeInst = getNode(dstId)->op;
+
+ // Computing insertion point:
+ // *) Walk all operation positions in Block operation list in the
+ // range (src, dst). For each operation 'op' visited in this search:
+ // *) Store in 'firstSrcDepPos' the first position where 'op' has a
+ // dependence edge from 'srcNode'.
+    //    *) Store in 'lastDstDepPos' the last position where 'op' has a
+    //       dependence edge to 'dstNode'.
+    //    *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
+ // operation insertion point (or return null pointer if no such
+ // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
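+    // Illustrative example (hypothetical positions): if the ops strictly
+    // between 'srcNodeInst' and 'dstNodeInst' are [a, b, c], where only 'c'
+    // depends on 'srcNode' (firstSrcDepPos = 2) and 'dstNode' depends only on
+    // 'a' (lastDstDepPos = 0), then 2 > 0 and the fused nest is inserted
+    // before 'c'.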
+ SmallVector<Operation *, 2> depInsts;
+ Optional<unsigned> firstSrcDepPos;
+ Optional<unsigned> lastDstDepPos;
+ unsigned pos = 0;
+ for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
+ it != Block::iterator(dstNodeInst); ++it) {
+ Operation *op = &(*it);
+ if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
+ firstSrcDepPos = pos;
+ if (dstDepInsts.count(op) > 0)
+ lastDstDepPos = pos;
+ depInsts.push_back(op);
+ ++pos;
+ }
+
+ if (firstSrcDepPos.hasValue()) {
+ if (lastDstDepPos.hasValue()) {
+ if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
+ // No valid insertion point exists which preserves dependences.
+ return nullptr;
+ }
+ }
+ // Return the insertion point at 'firstSrcDepPos'.
+ return depInsts[firstSrcDepPos.getValue()];
+ }
+ // No dependence targets in range (or only dst deps in range), return
+    // 'dstNodeInst' insertion point.
+ return dstNodeInst;
+ }
+
+ // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
+ // has been replaced in node at 'dstId' by a private memref depending
+ // on the value of 'createPrivateMemRef'.
+ void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
+ bool createPrivateMemRef) {
+    // For each edge in 'inEdges[srcId]': add a new edge remapped to 'dstId'.
+ if (inEdges.count(srcId) > 0) {
+ SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
+ for (auto &inEdge : oldInEdges) {
+ // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
+ if (inEdge.value != oldMemRef)
+ addEdge(inEdge.id, dstId, inEdge.value);
+ }
+ }
+ // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
+ if (outEdges.count(srcId) > 0) {
+ SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
+ for (auto &outEdge : oldOutEdges) {
+ // Remove any out edges from 'srcId' to 'dstId' across memrefs.
+ if (outEdge.id == dstId)
+ removeEdge(srcId, outEdge.id, outEdge.value);
+ }
+ }
+ // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
+ // replaced by a private memref). These edges could come from nodes
+ // other than 'srcId' which were removed in the previous step.
+ if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
+ SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
+ for (auto &inEdge : oldInEdges)
+ if (inEdge.value == oldMemRef)
+ removeEdge(inEdge.id, dstId, inEdge.value);
+ }
+ }
+
+ // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+  // of sibling node 'sibId' into node 'dstId'.
+ void updateEdges(unsigned sibId, unsigned dstId) {
+ // For each edge in 'inEdges[sibId]':
+ // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
+ // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
+ if (inEdges.count(sibId) > 0) {
+ SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
+ for (auto &inEdge : oldInEdges) {
+ addEdge(inEdge.id, dstId, inEdge.value);
+ removeEdge(inEdge.id, sibId, inEdge.value);
+ }
+ }
+
+    // For each edge in 'outEdges[sibId]':
+ // *) Add new edge from 'dstId' to 'outEdge.id'.
+ // *) Remove edge from 'sibId' to 'outEdge.id'.
+ if (outEdges.count(sibId) > 0) {
+ SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
+ for (auto &outEdge : oldOutEdges) {
+ addEdge(dstId, outEdge.id, outEdge.value);
+ removeEdge(sibId, outEdge.id, outEdge.value);
+ }
+ }
+ }
+
+ // Adds ops in 'loads' and 'stores' to node at 'id'.
+ void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
+ const SmallVectorImpl<Operation *> &stores) {
+ Node *node = getNode(id);
+ for (auto *loadOpInst : loads)
+ node->loads.push_back(loadOpInst);
+ for (auto *storeOpInst : stores)
+ node->stores.push_back(storeOpInst);
+ }
+
+ void clearNodeLoadAndStores(unsigned id) {
+ Node *node = getNode(id);
+ node->loads.clear();
+ node->stores.clear();
+ }
+
+ // Calls 'callback' for each input edge incident to node 'id' which carries a
+ // memref dependence.
+ void forEachMemRefInputEdge(unsigned id,
+ const std::function<void(Edge)> &callback) {
+ if (inEdges.count(id) > 0)
+ forEachMemRefEdge(inEdges[id], callback);
+ }
+
+ // Calls 'callback' for each output edge from node 'id' which carries a
+ // memref dependence.
+ void forEachMemRefOutputEdge(unsigned id,
+ const std::function<void(Edge)> &callback) {
+ if (outEdges.count(id) > 0)
+ forEachMemRefEdge(outEdges[id], callback);
+ }
+
+ // Calls 'callback' for each edge in 'edges' which carries a memref
+ // dependence.
+ void forEachMemRefEdge(ArrayRef<Edge> edges,
+ const std::function<void(Edge)> &callback) {
+ for (auto &edge : edges) {
+ // Skip if 'edge' is not a memref dependence edge.
+ if (!edge.value->getType().isa<MemRefType>())
+ continue;
+ assert(nodes.count(edge.id) > 0);
+ // Skip if 'edge.id' is not a loop nest.
+ if (!isa<AffineForOp>(getNode(edge.id)->op))
+ continue;
+ // Visit current input edge 'edge'.
+ callback(edge);
+ }
+ }
+
+ void print(raw_ostream &os) const {
+ os << "\nMemRefDependenceGraph\n";
+ os << "\nNodes:\n";
+ for (auto &idAndNode : nodes) {
+ os << "Node: " << idAndNode.first << "\n";
+ auto it = inEdges.find(idAndNode.first);
+ if (it != inEdges.end()) {
+ for (const auto &e : it->second)
+ os << " InEdge: " << e.id << " " << e.value << "\n";
+ }
+ it = outEdges.find(idAndNode.first);
+ if (it != outEdges.end()) {
+ for (const auto &e : it->second)
+ os << " OutEdge: " << e.id << " " << e.value << "\n";
+ }
+ }
+ }
+ void dump() const { print(llvm::errs()); }
+};
+
+} // end anonymous namespace
+
+// Initializes the data dependence graph by walking operations in 'f'.
+// Assigns each node in the graph a node id based on program order in 'f'.
+// TODO(andydavis) Add support for taking a Block arg to construct the
+// dependence graph at a different depth.
+bool MemRefDependenceGraph::init(FuncOp f) {
+ DenseMap<Value, SetVector<unsigned>> memrefAccesses;
+
+ // TODO: support multi-block functions.
+ if (f.getBlocks().size() != 1)
+ return false;
+
+ DenseMap<Operation *, unsigned> forToNodeMap;
+ for (auto &op : f.front()) {
+ if (auto forOp = dyn_cast<AffineForOp>(op)) {
+ // Create graph node 'id' to represent top-level 'forOp' and record
+ // all loads and store accesses it contains.
+ LoopNestStateCollector collector;
+ collector.collect(&op);
+ // Return false if a non 'affine.for' region was found (not currently
+ // supported).
+ if (collector.hasNonForRegion)
+ return false;
+ Node node(nextNodeId++, &op);
+ for (auto *opInst : collector.loadOpInsts) {
+ node.loads.push_back(opInst);
+ auto memref = cast<AffineLoadOp>(opInst).getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ }
+ for (auto *opInst : collector.storeOpInsts) {
+ node.stores.push_back(opInst);
+ auto memref = cast<AffineStoreOp>(opInst).getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ }
+ forToNodeMap[&op] = node.id;
+ nodes.insert({node.id, node});
+ } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+ // Create graph node for top-level load op.
+ Node node(nextNodeId++, &op);
+ node.loads.push_back(&op);
+ auto memref = cast<AffineLoadOp>(op).getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ nodes.insert({node.id, node});
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ // Create graph node for top-level store op.
+ Node node(nextNodeId++, &op);
+ node.stores.push_back(&op);
+ auto memref = cast<AffineStoreOp>(op).getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ nodes.insert({node.id, node});
+ } else if (op.getNumRegions() != 0) {
+ // Return false if another region is found (not currently supported).
+ return false;
+ } else if (op.getNumResults() > 0 && !op.use_empty()) {
+ // Create graph node for top-level producer of SSA values, which
+ // could be used by loop nest nodes.
+ Node node(nextNodeId++, &op);
+ nodes.insert({node.id, node});
+ }
+ }
+
+ // Add dependence edges between nodes which produce SSA values and their
+ // users.
+ for (auto &idAndNode : nodes) {
+ const Node &node = idAndNode.second;
+ if (!node.loads.empty() || !node.stores.empty())
+ continue;
+ auto *opInst = node.op;
+ for (auto value : opInst->getResults()) {
+ for (auto *user : value->getUsers()) {
+ SmallVector<AffineForOp, 4> loops;
+ getLoopIVs(*user, &loops);
+ if (loops.empty())
+ continue;
+ assert(forToNodeMap.count(loops[0].getOperation()) > 0);
+ unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
+ addEdge(node.id, userLoopNestId, value);
+ }
+ }
+ }
+
+ // Walk memref access lists and add graph edges between dependent nodes.
+ for (auto &memrefAndList : memrefAccesses) {
+ unsigned n = memrefAndList.second.size();
+ for (unsigned i = 0; i < n; ++i) {
+ unsigned srcId = memrefAndList.second[i];
+ bool srcHasStore =
+ getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
+ for (unsigned j = i + 1; j < n; ++j) {
+ unsigned dstId = memrefAndList.second[j];
+ bool dstHasStore =
+ getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
+ if (srcHasStore || dstHasStore)
+ addEdge(srcId, dstId, memrefAndList.first);
+ }
+ }
+ }
+ return true;
+}
+
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void moveLoadsAccessingMemrefTo(Value memref,
+ SmallVectorImpl<Operation *> *srcLoads,
+ SmallVectorImpl<Operation *> *dstLoads) {
+ dstLoads->clear();
+ SmallVector<Operation *, 4> srcLoadsToKeep;
+ for (auto *load : *srcLoads) {
+ if (cast<AffineLoadOp>(load).getMemRef() == memref)
+ dstLoads->push_back(load);
+ else
+ srcLoadsToKeep.push_back(load);
+ }
+ srcLoads->swap(srcLoadsToKeep);
+}
+
+// Returns the innermost common loop depth for the set of operations in 'ops'.
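+// For example (hypothetical nests): if one op is surrounded by loops (%i, %j)
+// and another by (%i, %k), the innermost common loop depth is 1.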
+static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
+ unsigned numOps = ops.size();
+ assert(numOps > 0);
+
+ std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
+ unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+ for (unsigned i = 0; i < numOps; ++i) {
+ getLoopIVs(*ops[i], &loops[i]);
+ loopDepthLimit =
+ std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
+ }
+
+ unsigned loopDepth = 0;
+ for (unsigned d = 0; d < loopDepthLimit; ++d) {
+ unsigned i;
+ for (i = 1; i < numOps; ++i) {
+ if (loops[i - 1][d] != loops[i][d])
+ break;
+ }
+ if (i != numOps)
+ break;
+ ++loopDepth;
+ }
+ return loopDepth;
+}
+
+// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
+// and 'storeOpInsts' are satisfied.
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
+ ArrayRef<Operation *> storeOpInsts) {
+ // Merge loads and stores into the same array.
+ SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
+ ops.append(storeOpInsts.begin(), storeOpInsts.end());
+
+ // Compute the innermost common loop depth for loads and stores.
+ unsigned loopDepth = getInnermostCommonLoopDepth(ops);
+
+ // Return common loop depth for loads if there are no store ops.
+ if (storeOpInsts.empty())
+ return loopDepth;
+
+ // Check dependences on all pairs of ops in 'ops' and store the minimum
+ // loop depth at which a dependence is satisfied.
+ for (unsigned i = 0, e = ops.size(); i < e; ++i) {
+ auto *srcOpInst = ops[i];
+ MemRefAccess srcAccess(srcOpInst);
+ for (unsigned j = 0; j < e; ++j) {
+ auto *dstOpInst = ops[j];
+ MemRefAccess dstAccess(dstOpInst);
+
+ unsigned numCommonLoops =
+ getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
+ for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+ FlatAffineConstraints dependenceConstraints;
+ // TODO(andydavis) Cache dependence analysis results, check cache here.
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, d, &dependenceConstraints,
+ /*dependenceComponents=*/nullptr);
+ if (hasDependence(result)) {
+ // Store minimum loop depth and break because we want the min 'd' at
+ // which there is a dependence.
+ loopDepth = std::min(loopDepth, d - 1);
+ break;
+ }
+ }
+ }
+ }
+ return loopDepth;
+}
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the
+// outermost (while again preserving relative order among them).
+// This can increase the loop depth at which we can fuse a slice, since we are
+// pushing loop-carried dependences to a greater depth in the loop nest.
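+// For example (schematic): if the outer loop of a 2-d nest carries a
+// dependence while the inner loop is parallel, interchanging them allows a
+// slice to be fused at depth 2 rather than depth 1.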
+static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
+ assert(isa<AffineForOp>(node->op));
+ AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
+ node->op = newRootForOp.getOperation();
+}
+
+// TODO(mlir-team): improve/complete this when we have target data.
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+ auto elementType = memRefType.getElementType();
+
+ unsigned sizeInBits;
+ if (elementType.isIntOrFloat()) {
+ sizeInBits = elementType.getIntOrFloatBitWidth();
+ } else {
+ auto vectorType = elementType.cast<VectorType>();
+ sizeInBits =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ }
+ return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Creates and returns a private (single-user) memref for fused loop rooted
+// at 'forOp', with (potentially reduced) memref size based on the
+// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
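+// For example (schematic): if the source nest writes only a 64x64 tile of a
+// memref<256x256xf32> that the fused consumer reads back, the private buffer
+// can be shrunk to memref<64x64xf32>, and is placed in 'fastMemorySpace' when
+// its size fits under 'localBufSizeThreshold'.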
+// TODO(bondhugula): consider refactoring the common code from generateDma and
+// this one.
+static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+ unsigned dstLoopDepth,
+ Optional<unsigned> fastMemorySpace,
+ uint64_t localBufSizeThreshold) {
+ auto *forInst = forOp.getOperation();
+
+ // Create builder to insert alloc op just before 'forOp'.
+ OpBuilder b(forInst);
+ // Builder to create constants at the top level.
+ OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
+ // Create new memref type based on slice bounds.
+ auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+ auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+ unsigned rank = oldMemRefType.getRank();
+
+ // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
+ MemRefRegion region(srcStoreOpInst->getLoc());
+ bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+ (void)validRegion;
+ assert(validRegion && "unexpected memref region failure");
+ SmallVector<int64_t, 4> newShape;
+ std::vector<SmallVector<int64_t, 4>> lbs;
+ SmallVector<int64_t, 8> lbDivisors;
+ lbs.reserve(rank);
+ // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
+ // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+ Optional<int64_t> numElements =
+ region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
+ assert(numElements.hasValue() &&
+ "non-constant number of elts in local buffer");
+
+ const FlatAffineConstraints *cst = region.getConstraints();
+ // 'outerIVs' holds the values that this memory region is symbolic/parametric
+ // on; this would correspond to loop IVs surrounding the level at which the
+ // slice is being materialized.
+ SmallVector<Value, 8> outerIVs;
+ cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
+
+ // Build 'rank' AffineExprs from MemRefRegion 'lbs'
+ SmallVector<AffineExpr, 4> offsets;
+ offsets.reserve(rank);
+ for (unsigned d = 0; d < rank; ++d) {
+ assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
+
+ AffineExpr offset = top.getAffineConstantExpr(0);
+ for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
+ offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
+ }
+ assert(lbDivisors[d] > 0);
+ offset =
+ (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
+ offsets.push_back(offset);
+ }
+
+ // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
+ // by 'srcStoreOpInst'.
+ uint64_t bufSize =
+ getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
+ unsigned newMemSpace;
+ if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
+ newMemSpace = fastMemorySpace.getValue();
+ } else {
+ newMemSpace = oldMemRefType.getMemorySpace();
+ }
+ auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
+ {}, newMemSpace);
+ // Gather alloc operands for the dynamic dimensions of the memref.
+ SmallVector<Value, 4> allocOperands;
+ unsigned dynamicDimCount = 0;
+ for (auto dimSize : oldMemRefType.getShape()) {
+ if (dimSize == -1)
+ allocOperands.push_back(
+ top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
+ }
+
+ // Create new private memref for fused loop 'forOp'.
+ // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
+ // consumer loop nests to reduce their live range. Currently they are added
+ // at the beginning of the function, because loop nests can be reordered
+ // during the fusion pass.
+ Value newMemRef =
+ top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
+
+ // Build an AffineMap to remap access functions based on lower bound offsets.
+ SmallVector<AffineExpr, 4> remapExprs;
+ remapExprs.reserve(rank);
+ unsigned zeroOffsetCount = 0;
+ for (unsigned i = 0; i < rank; i++) {
+ if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
+ if (constExpr.getValue() == 0)
+ ++zeroOffsetCount;
+ auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
+
+ auto remapExpr =
+ simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
+ remapExprs.push_back(remapExpr);
+ }
+ auto indexRemap = zeroOffsetCount == rank
+ ? AffineMap()
+ : AffineMap::get(outerIVs.size() + rank, 0, remapExprs);
+ // Replace all users of 'oldMemRef' with 'newMemRef'.
+ LogicalResult res =
+ replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
+ /*extraOperands=*/outerIVs,
+ /*symbolOperands=*/{},
+ /*domInstFilter=*/&*forOp.getBody()->begin());
+ assert(succeeded(res) &&
+ "replaceAllMemrefUsesWith should always succeed here");
+ (void)res;
+ return newMemRef;
+}
+
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', has output edges.
+// Returns true if 'dstNode's read/write region to 'memref' is a superset of
+// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
+// TODO(andydavis) Generalize this to handle more live in/out cases.
+static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+ AffineStoreOp srcLiveOutStoreOp,
+ MemRefDependenceGraph *mdg) {
+ assert(srcLiveOutStoreOp && "Expected a valid store op");
+ auto *dstNode = mdg->getNode(dstId);
+ Value memref = srcLiveOutStoreOp.getMemRef();
+ // Return false if 'srcNode' has more than one output edge on 'memref'.
+ if (mdg->getOutEdgeCount(srcId, memref) > 1)
+ return false;
+
+ // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+ MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+ if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
+ LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation.\n");
+ return false;
+ }
+ SmallVector<int64_t, 4> srcShape;
+  // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements' written to
+  // by 'srcStoreOp'.
+ Optional<int64_t> srcNumElements =
+ srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+ if (!srcNumElements.hasValue())
+ return false;
+
+ // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
+ // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
+ // each store op in 'dstStoreOps').
+ SmallVector<Operation *, 2> dstStoreOps;
+ dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+ SmallVector<Operation *, 2> dstLoadOps;
+ dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
+
+ auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
+ MemRefRegion dstRegion(dstOpInst->getLoc());
+ if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
+ LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for dest operation.\n");
+ return false;
+ }
+ SmallVector<int64_t, 4> dstShape;
+  // Query 'dstRegion' for 'dstShape' and 'dstNumElements' accessed
+  // by 'dstOpInst'.
+ Optional<int64_t> dstNumElements =
+ dstRegion.getConstantBoundingSizeAndShape(&dstShape);
+ if (!dstNumElements.hasValue())
+ return false;
+
+  // Return false if the write region is not a superset of 'srcNode's write
+ // region to 'memref'.
+ // TODO(andydavis) Check the shape and lower bounds here too.
+ if (srcNumElements != dstNumElements)
+ return false;
+ return true;
+}
+
+// Checks the profitability of fusing a backwards slice of the loop nest
+// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
+// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
+// the memref being produced and consumed, which is an input to the cost model.
+// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
+// 'srcOpInst', as we are slicing w.r.t. that producer.
+// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
+// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
+// will be the unique store op in the src node, which will be used to check
+// that the write region is the same after input-reuse fusion.
+// Returns true if it is profitable to fuse the candidate loop nests. Returns
+// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
+// to materialize the source loop nest slice.
+// The profitability model executes the following steps:
+// *) Computes the backward computation slice at 'srcOpInst'. This
+// computation slice of the loop nest surrounding 'srcOpInst' is
+// represented by modified src loop bounds in 'sliceState', which are
+// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
+// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
+// loop nest is the total number of dynamic operation instances in the loop
+// nest).
+// *) Computes the cost of fusing a slice of the src loop nest into the dst
+// loop nest at various values of dst loop depth, attempting to fuse
+// the largest computation slice at the maximal dst loop depth (closest to
+// the load) to minimize reuse distance and potentially enable subsequent
+// load/store forwarding.
+// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
+// the same memref as is written by 'srcOpInst', then the union of slice
+// loop bounds is used to compute the slice and associated slice cost.
+// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
+// nest, at which the src computation slice is inserted/fused.
+// NOTE: We attempt to maximize the dst loop depth, but there are cases
+//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
+// loop (within the src computation slice) at a depth which results in
+// excessive recomputation (see unit tests for examples).
+// *) Compares the total cost of the unfused loop nests to the min cost fused
+// loop nest computed in the previous step, and returns true if the latter
+// is lower.
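+// Illustrative numbers (hypothetical): with srcLoopNestCost = 100,
+// dstLoopNestCost = 100 and fusedLoopNestComputeCost = 220 at some depth, the
+// additional compute fraction is 220 / (100 + 100) - 1 = 0.10 (10% redundant
+// computation), which is within the default 30% tolerance; among all such
+// acceptable depths, the one with the largest storage reduction is chosen.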
+static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+ ArrayRef<Operation *> dstLoadOpInsts,
+ ArrayRef<Operation *> dstStoreOpInsts,
+ ComputationSliceState *sliceState,
+ unsigned *dstLoopDepth, bool maximalFusion) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Checking whether fusion is profitable between:\n";
+ llvm::dbgs() << " " << *srcOpInst << " and \n";
+ for (auto dstOpInst : dstLoadOpInsts) {
+ llvm::dbgs() << " " << *dstOpInst << "\n";
+    }
+ });
+
+ // Compute cost of sliced and unsliced src loop nest.
+ SmallVector<AffineForOp, 4> srcLoopIVs;
+ getLoopIVs(*srcOpInst, &srcLoopIVs);
+ unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+ // Walk src loop nest and collect stats.
+ LoopNestStats srcLoopNestStats;
+ if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+ return false;
+
+ // Compute cost of dst loop nest.
+ SmallVector<AffineForOp, 4> dstLoopIVs;
+ getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+
+ LoopNestStats dstLoopNestStats;
+ if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
+ return false;
+
+  // Compute the maximum loop depth at which we can insert the src slice
+ // and still satisfy dest loop nest dependences, for producer-consumer fusion.
+ unsigned maxDstLoopDepth =
+ (srcOpInst == srcStoreOpInst)
+ ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
+ : dstLoopIVs.size();
+ if (maxDstLoopDepth == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0.\n");
+ return false;
+ }
+
+ // Search for min cost value for 'dstLoopDepth'. At each value of
+ // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+  // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the union
+ // of these bounds). Next the union slice bounds are used to calculate
+ // the cost of the slice and the cost of the slice inserted into the dst
+ // loop nest at 'dstLoopDepth'.
+ uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
+ double maxStorageReduction = 0.0;
+ Optional<uint64_t> sliceMemEstimate = None;
+
+ SmallVector<ComputationSliceState, 4> sliceStates;
+ sliceStates.resize(maxDstLoopDepth);
+ // The best loop depth at which to materialize the slice.
+ Optional<unsigned> bestDstLoopDepth = None;
+
+ // Compute op instance count for the src loop nest without iteration slicing.
+ uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+
+ // Compute src loop nest write region size.
+ MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
+ if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+ LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation.\n");
+ return false;
+ }
+
+ Optional<int64_t> maybeSrcWriteRegionSizeBytes =
+ srcWriteRegion.getRegionSize();
+ if (!maybeSrcWriteRegionSizeBytes.hasValue())
+ return false;
+ int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
+
+  // Compute op instance count for the dst loop nest.
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
+
+ // Evaluate all depth choices for materializing the slice in the destination
+ // loop nest.
+ for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
+ // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
+ if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
+ /*loopDepth=*/i,
+ /*numCommonLoops=*/0,
+ /*isBackwardSlice=*/true,
+ &sliceStates[i - 1]))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "computeSliceUnion failed for loopDepth: " << i << "\n");
+ continue;
+ }
+
+ int64_t fusedLoopNestComputeCost;
+ if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+ dstLoopNestStats, &sliceStates[i - 1],
+ &fusedLoopNestComputeCost)) {
+      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n");
+ continue;
+ }
+
+ double additionalComputeFraction =
+ fusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1;
+
+ // Determine what the slice write MemRefRegion would be, if the src loop
+ // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
+    // nest at loop depth 'i'.
+ MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
+ if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
+ &sliceStates[i - 1]))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Failed to compute slice write region at loopDepth: " << i
+ << "\n");
+ continue;
+ }
+
+ Optional<int64_t> maybeSliceWriteRegionSizeBytes =
+ sliceWriteRegion.getRegionSize();
+ if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
+ maybeSliceWriteRegionSizeBytes.getValue() == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Failed to get slice write region size at loopDepth: " << i
+ << "\n");
+ continue;
+ }
+ int64_t sliceWriteRegionSizeBytes =
+ maybeSliceWriteRegionSizeBytes.getValue();
+
+ // If we are fusing for reuse, check that write regions remain the same.
+ // TODO(andydavis) Write region check should check sizes and offsets in
+ // each dimension, so that we are sure they are covering the same memref
+ // region. Also, move this out to a isMemRefRegionSuperSet helper function.
+ if (srcOpInst != srcStoreOpInst &&
+ sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
+ continue;
+
+ double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
+ static_cast<double>(sliceWriteRegionSizeBytes);
+
+ LLVM_DEBUG({
+ std::stringstream msg;
+ msg << " evaluating fusion profitability at depth : " << i << "\n"
+ << std::fixed << std::setprecision(2)
+ << " additional compute fraction: "
+ << 100.0 * additionalComputeFraction << "%\n"
+ << " storage reduction factor: " << storageReduction << "x\n"
+ << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
+ << " src write region size: " << srcWriteRegionSizeBytes << "\n"
+ << " slice write region size: " << sliceWriteRegionSizeBytes
+ << "\n";
+ llvm::dbgs() << msg.str();
+ });
+
+ double computeToleranceThreshold =
+ clFusionAddlComputeTolerance.getNumOccurrences() > 0
+ ? clFusionAddlComputeTolerance
+ : LoopFusion::kComputeToleranceThreshold;
+
+ // TODO(b/123247369): This is a placeholder cost model.
+ // Among all choices that add an acceptable amount of redundant computation
+ // (as per computeToleranceThreshold), we will simply pick the one that
+ // reduces the intermediary size the most.
+ if ((storageReduction > maxStorageReduction) &&
+ (maximalFusion ||
+ (additionalComputeFraction < computeToleranceThreshold))) {
+ maxStorageReduction = storageReduction;
+ bestDstLoopDepth = i;
+ minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+ sliceMemEstimate = sliceWriteRegionSizeBytes;
+ }
+ }
+
+ // A simple cost model: fuse if it reduces the memory footprint. If
+  // 'fusion-maximal' is set, fuse nevertheless.
+
+ if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "All fusion choices involve more than the threshold amount of "
+ "redundant computation; NOT fusing.\n");
+ return false;
+ }
+
+ if (!bestDstLoopDepth.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+ return false;
+ }
+
+ // Set dstLoopDepth based on best values from search.
+ *dstLoopDepth = bestDstLoopDepth.getValue();
+
+ LLVM_DEBUG(
+ llvm::dbgs() << " LoopFusion fusion stats:"
+ << "\n best loop depth: " << bestDstLoopDepth
+ << "\n src loop nest compute cost: " << srcLoopNestCost
+ << "\n dst loop nest compute cost: " << dstLoopNestCost
+ << "\n fused loop nest compute cost: "
+ << minFusedLoopNestComputeCost << "\n");
+
+ auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+
+ Optional<double> storageReduction = None;
+
+ if (!maximalFusion) {
+ if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ return false;
+ }
+
+ auto srcMemSizeVal = srcMemSize.getValue();
+ auto dstMemSizeVal = dstMemSize.getValue();
+
+ assert(sliceMemEstimate.hasValue() && "expected value");
+ auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+
+ LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
+ << " dst mem: " << dstMemSizeVal << "\n"
+ << " fused mem: " << fusedMem << "\n"
+ << " slice mem: " << sliceMemEstimate << "\n");
+
+ if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ return false;
+ }
+ storageReduction =
+ 100.0 *
+ (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+ }
+
+ double additionalComputeFraction =
+ 100.0 * (minFusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1);
+ (void)additionalComputeFraction;
+ LLVM_DEBUG({
+ std::stringstream msg;
+ msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
+ << std::setprecision(2) << additionalComputeFraction
+ << "% redundant computation and a ";
+ msg << (storageReduction.hasValue()
+ ? std::to_string(storageReduction.getValue())
+ : "<unknown>");
+ msg << "% storage reduction.\n";
+ llvm::dbgs() << msg.str();
+ });
+
+ // Update return parameter 'sliceState' with 'bestSliceState'.
+ ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
+ sliceState->lbs = bestSliceState->lbs;
+ sliceState->ubs = bestSliceState->ubs;
+ sliceState->lbOperands = bestSliceState->lbOperands;
+ sliceState->ubOperands = bestSliceState->ubOperands;
+
+ // Canonicalize slice bound affine maps.
+ for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+ if (sliceState->lbs[i] != AffineMap()) {
+ canonicalizeMapAndOperands(&sliceState->lbs[i],
+ &sliceState->lbOperands[i]);
+ }
+ if (sliceState->ubs[i] != AffineMap()) {
+ canonicalizeMapAndOperands(&sliceState->ubs[i],
+ &sliceState->ubOperands[i]);
+ }
+ }
+ return true;
+}
+
+namespace {
+
+// GreedyFusion greedily fuses loop nests which have a producer/consumer or
+// input-reuse relationship on a memref, with the goal of improving locality.
+//
+// The steps of the producer-consumer fusion algorithm are as follows:
+//
+// *) A worklist is initialized with node ids from the dependence graph.
+// *) For each node id in the worklist:
+//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
+// candidate destination AffineForOp into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
+// *) For each LoadOp in 'dstLoadOps' do:
+// *) Look up dependent loop nests which have a single store op to the same
+// memref.
+// *) Check if dependences would be violated by the fusion.
+// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
+// bounds to be functions of 'dstLoopNest' IVs and symbols.
+// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
+// at a loop depth determined by the cost model in 'isFusionProfitable'.
+// *) Add the newly fused load/store operations to the state,
+//         and also add newly fused load ops to 'dstLoadOps' to be considered
+// as fusion dst load ops in another iteration.
+// *) Remove old src loop nest and its associated state.
+//
+// The steps of the input-reuse fusion algorithm are as follows:
+//
+// *) Initialize 'worklist' with node ids from the dependence graph.
+// *) For each 'dstNode' in the worklist:
+// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
+// loads from the same memref, but which has no dependence paths to/from.
+// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
+// bounds to be functions of 'dstLoopNest' IVs and symbols.
+// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
+// at a loop depth determined by the cost model in 'isFusionProfitable'.
+//      This function also checks that the memref write region of 'sibLoopNest'
+//      is preserved in the fused loop nest.
+// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
+//
+// Given a graph where top-level operations are vertices in the set 'V' and
+// edges in the set 'E' are dependences between vertices, this algorithm
+// takes O(V) time for initialization, and has runtime O(V + E).
+//
+// This greedy algorithm is not 'maximal' due to the current restriction of
+// fusing along single producer-consumer edges, but there is a TODO to fix this.
+//
+// TODO(andydavis) Experiment with other fusion policies.
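+//
+// For illustration only (hypothetical input, simplified affine syntax), a
+// producer-consumer pair such as
+//
+//   affine.for %i = 0 to 100 {
+//     affine.store %v, %A[%i] : memref<100xf32>   // producer
+//   }
+//   affine.for %i = 0 to 100 {
+//     %x = affine.load %A[%i] : memref<100xf32>   // consumer
+//   }
+//
+// is rewritten so that a slice of the producer computing the value needed by
+// iteration %i of the consumer is inserted into the consumer loop; when %A is
+// not otherwise live, it is additionally privatized to a small local buffer.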
+struct GreedyFusion {
+public:
+ // The data dependence graph to traverse during fusion.
+ MemRefDependenceGraph *mdg;
+ // Worklist of graph nodes visited during the fusion pass.
+ SmallVector<unsigned, 8> worklist;
+ // Set of graph nodes which are present on the worklist.
+ llvm::SmallDenseSet<unsigned, 16> worklistSet;
+ // Parameter for local buffer size threshold.
+ unsigned localBufSizeThreshold;
+ // Parameter for fast memory space.
+ Optional<unsigned> fastMemorySpace;
+ // If true, ignore any additional (redundant) computation tolerance threshold
+ // that would have prevented fusion.
+ bool maximalFusion;
+
+ using Node = MemRefDependenceGraph::Node;
+
+ GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
+ Optional<unsigned> fastMemorySpace, bool maximalFusion)
+ : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
+ fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
+
+ // Initializes 'worklist' with nodes from 'mdg'
+ void init() {
+ // TODO(andydavis) Add a priority queue for prioritizing nodes by different
+ // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
+ worklist.clear();
+ worklistSet.clear();
+ for (auto &idAndNode : mdg->nodes) {
+ const Node &node = idAndNode.second;
+ worklist.push_back(node.id);
+ worklistSet.insert(node.id);
+ }
+ }
+
+ // Run the GreedyFusion pass.
+ // *) First pass through the nodes fuses single-use producer nodes into their
+ // unique consumer.
+ // *) Second pass fuses sibling nodes which share no dependence edges.
+ // *) Third pass fuses any remaining producer nodes into their users.
+ void run() {
+ // TODO(andydavis) Run this repeatedly until a fixed-point is reached.
+ fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
+ fuseSiblingNodes();
+ fuseProducerConsumerNodes(
+ /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
+ eraseUnusedMemRefAllocations();
+ }
+
+ void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+ init();
+ while (!worklist.empty()) {
+ unsigned dstId = worklist.back();
+ worklist.pop_back();
+ worklistSet.erase(dstId);
+
+ // Skip if this node was removed (fused into another node).
+ if (mdg->nodes.count(dstId) == 0)
+ continue;
+ // Get 'dstNode' into which to attempt fusion.
+ auto *dstNode = mdg->getNode(dstId);
+ // Skip if 'dstNode' is not a loop nest.
+ if (!isa<AffineForOp>(dstNode->op))
+ continue;
+ // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
+ // while preserving relative order. This can increase the maximum loop
+ // depth at which we can fuse a slice of a producer loop nest into a
+ // consumer loop nest.
+ sinkSequentialLoops(dstNode);
+
+ SmallVector<Operation *, 4> loads = dstNode->loads;
+ SmallVector<Operation *, 4> dstLoadOpInsts;
+ DenseSet<Value> visitedMemrefs;
+ while (!loads.empty()) {
+ // Get memref of load on top of the stack.
+ auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+ if (visitedMemrefs.count(memref) > 0)
+ continue;
+ visitedMemrefs.insert(memref);
+ // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+ moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
+ // Skip if no input edges along which to fuse.
+ if (mdg->inEdges.count(dstId) == 0)
+ continue;
+        // Iterate through in-edges for 'dstId' and collect the src node ids
+        // of any edges on 'memref'.
+ SmallVector<unsigned, 2> srcNodeIds;
+ for (auto &srcEdge : mdg->inEdges[dstId]) {
+ // Skip 'srcEdge' if not for 'memref'.
+ if (srcEdge.value != memref)
+ continue;
+ srcNodeIds.push_back(srcEdge.id);
+ }
+ for (unsigned srcId : srcNodeIds) {
+ // Skip if this node was removed (fused into another node).
+ if (mdg->nodes.count(srcId) == 0)
+ continue;
+ // Get 'srcNode' from which to attempt fusion into 'dstNode'.
+ auto *srcNode = mdg->getNode(srcId);
+ // Skip if 'srcNode' is not a loop nest.
+ if (!isa<AffineForOp>(srcNode->op))
+ continue;
+ // Skip if 'srcNode' has more than one live-out store to a
+ // function-local memref.
+ // TODO(andydavis) Support more generic multi-output src loop nests
+ // fusion.
+ auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+ if (!srcStoreOp) {
+ // Get the src store op at the deepest loop depth.
+ // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+ // feasibility for loops with multiple stores.
+ unsigned maxLoopDepth = 0;
+ for (auto *op : srcNode->stores) {
+ auto storeOp = cast<AffineStoreOp>(op);
+ if (storeOp.getMemRef() != memref) {
+ srcStoreOp = nullptr;
+ break;
+ }
+ unsigned loopDepth = getNestingDepth(*storeOp);
+ if (loopDepth > maxLoopDepth) {
+ maxLoopDepth = loopDepth;
+ srcStoreOp = storeOp;
+ }
+ }
+ if (!srcStoreOp)
+ continue;
+ }
+
+ // Unique outgoing store found must write to 'memref' since 'memref'
+ // is the one that established the producer-consumer relationship
+ // between 'srcNode' and 'dstNode'.
+ assert(srcStoreOp.getMemRef() == memref &&
+ "Found store to unexpected memref");
+
+ // Skip if 'srcNode' writes to any live in or escaping memrefs,
+ // and cannot be fused.
+ bool writesToLiveInOrOut =
+ mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+ if (writesToLiveInOrOut &&
+ !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
+ continue;
+
+ // Don't create a private memref if 'writesToLiveInOrOut'.
+ bool createPrivateMemref = !writesToLiveInOrOut;
+ // Don't create a private memref if 'srcNode' has in edges on
+ // 'memref', or if 'dstNode' has out edges on 'memref'.
+ if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+ mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+ createPrivateMemref = false;
+ }
+
+ // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
+ if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
+ continue;
+
+ // Compute an operation list insertion point for the fused loop
+ // nest which preserves dependences.
+ Operation *insertPointInst =
+ mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
+ if (insertPointInst == nullptr)
+ continue;
+
+ // Compute the innermost common loop depth for dstNode loads/stores.
+ SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
+ dstNode->loads.end());
+ dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
+ unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+ // Check the feasibility of fusing src loop nest into dst loop nest
+ // at loop depths in range [1, dstLoopDepthTest].
+ // TODO(andydavis) Use slice union computation and union of memref
+ // read/write regions to cost model and fusion.
+ bool canFuse = false;
+ for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+ ComputationSliceState sliceUnion;
+ FusionResult result = mlir::canFuseLoops(
+ cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
+ /*dstLoopDepth=*/i, &sliceUnion);
+ if (result.value == FusionResult::Success)
+ canFuse = true;
+ }
+
+ // Skip if fusion is not feasible at all loop depths.
+ if (!canFuse)
+ continue;
+
+ // Gather 'dstNode' store ops to 'memref'.
+ SmallVector<Operation *, 2> dstStoreOpInsts;
+ for (auto *storeOpInst : dstNode->stores)
+ if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ dstStoreOpInsts.push_back(storeOpInst);
+
+ unsigned bestDstLoopDepth;
+ mlir::ComputationSliceState sliceState;
+ // Check if fusion would be profitable.
+ if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+ dstStoreOpInsts, &sliceState,
+ &bestDstLoopDepth, maximalFusion))
+ continue;
+
+ // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
+ auto sliceLoopNest = mlir::insertBackwardComputationSlice(
+ srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+ if (sliceLoopNest) {
+ LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
+ << *sliceLoopNest.getOperation() << "\n");
+ // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+ if (insertPointInst != dstAffineForOp.getOperation()) {
+ dstAffineForOp.getOperation()->moveBefore(insertPointInst);
+ }
+ // Update edges between 'srcNode' and 'dstNode'.
+ mdg->updateEdges(srcNode->id, dstNode->id, memref,
+ createPrivateMemref);
+
+ // Collect slice loop stats.
+ LoopNestStateCollector sliceCollector;
+ sliceCollector.collect(sliceLoopNest.getOperation());
+ // Promote single iteration slice loops to single IV value.
+ for (auto forOp : sliceCollector.forOps) {
+ promoteIfSingleIteration(forOp);
+ }
+ if (createPrivateMemref) {
+ // Create private memref for 'memref' in 'dstAffineForOp'.
+ SmallVector<Operation *, 4> storesForMemref;
+ for (auto *storeOpInst : sliceCollector.storeOpInsts) {
+ if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ storesForMemref.push_back(storeOpInst);
+ }
+ // TODO(andydavis) Use union of memref write regions to compute
+ // private memref footprint.
+ auto newMemRef = createPrivateMemRef(
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+ fastMemorySpace, localBufSizeThreshold);
+ visitedMemrefs.insert(newMemRef);
+ // Create new node in dependence graph for 'newMemRef' alloc op.
+ unsigned newMemRefNodeId =
+ mdg->addNode(newMemRef->getDefiningOp());
+ // Add edge from 'newMemRef' node to dstNode.
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+ }
+
+ // Collect dst loop stats after memref privatization transformation.
+ LoopNestStateCollector dstLoopCollector;
+ dstLoopCollector.collect(dstAffineForOp.getOperation());
+
+ // Add new load ops to current Node load op list 'loads' to
+ // continue fusing based on new operands.
+ for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+ auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+ if (visitedMemrefs.count(loadMemRef) == 0)
+ loads.push_back(loadOpInst);
+ }
+
+ // Clear and add back loads and stores.
+ mdg->clearNodeLoadAndStores(dstNode->id);
+ mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+ dstLoopCollector.storeOpInsts);
+ // Remove old src loop nest if it no longer has outgoing dependence
+ // edges, and if it does not write to a memref which escapes the
+ // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
+ // been fused into 'dstNode' and write region of 'dstNode' covers
+ // the write region of 'srcNode', and 'srcNode' has no other users
+ // so it is safe to remove.
+ if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+ mdg->removeNode(srcNode->id);
+ srcNode->op->erase();
+ } else {
+          // Add remaining users of 'memref' back on the worklist (if not
+          // already there), as its replacement with a local/private memref
+          // has reduced dependences on 'memref', which may have created
+          // new fusion opportunities.
+ if (mdg->outEdges.count(srcNode->id) > 0) {
+ SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+ mdg->outEdges[srcNode->id];
+ for (auto &outEdge : oldOutEdges) {
+ if (outEdge.value == memref &&
+ worklistSet.count(outEdge.id) == 0) {
+ worklist.push_back(outEdge.id);
+ worklistSet.insert(outEdge.id);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Visits each node in the graph, and for each node, attempts to fuse it with
+ // its sibling nodes (nodes which share a parent, but no dependence edges).
+ void fuseSiblingNodes() {
+ init();
+ while (!worklist.empty()) {
+ unsigned dstId = worklist.back();
+ worklist.pop_back();
+ worklistSet.erase(dstId);
+
+ // Skip if this node was removed (fused into another node).
+ if (mdg->nodes.count(dstId) == 0)
+ continue;
+ // Get 'dstNode' into which to attempt fusion.
+ auto *dstNode = mdg->getNode(dstId);
+ // Skip if 'dstNode' is not a loop nest.
+ if (!isa<AffineForOp>(dstNode->op))
+ continue;
+ // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
+ fuseWithSiblingNodes(dstNode);
+ }
+ }
+
+ // Attempt to fuse 'dstNode' with sibling nodes in the graph.
+ void fuseWithSiblingNodes(Node *dstNode) {
+ DenseSet<unsigned> visitedSibNodeIds;
+ std::pair<unsigned, Value> idAndMemref;
+ while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
+ unsigned sibId = idAndMemref.first;
+ Value memref = idAndMemref.second;
+ // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
+ // stores to the same memref in 'sibNode' loop nest.
+ auto *sibNode = mdg->getNode(sibId);
+ // Compute an operation list insertion point for the fused loop
+ // nest which preserves dependences.
+ assert(sibNode->op->getBlock() == dstNode->op->getBlock());
+ Operation *insertPointInst =
+ sibNode->op->isBeforeInBlock(dstNode->op)
+ ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
+ : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
+ if (insertPointInst == nullptr)
+ continue;
+
+ // Check if fusion would be profitable and at what depth.
+
+ // Get unique 'sibNode' load op to 'memref'.
+ SmallVector<Operation *, 2> sibLoadOpInsts;
+ sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
+ // Currently findSiblingNodeToFuse searches for siblings with one load.
+ assert(sibLoadOpInsts.size() == 1);
+ Operation *sibLoadOpInst = sibLoadOpInsts[0];
+ assert(!sibNode->stores.empty());
+ // TODO(andydavis) Choose the store which postdominates all other stores.
+ auto *sibStoreOpInst = sibNode->stores.back();
+
+ // Gather 'dstNode' load ops to 'memref'.
+ SmallVector<Operation *, 2> dstLoadOpInsts;
+ dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
+
+ // Gather 'dstNode' store ops to 'memref'.
+ SmallVector<Operation *, 2> dstStoreOpInsts;
+ dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
+
+ unsigned bestDstLoopDepth;
+ mlir::ComputationSliceState sliceState;
+
+ // Check if fusion would be profitable.
+ if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
+ dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
+ maximalFusion))
+ continue;
+
+ // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
+ auto sliceLoopNest = mlir::insertBackwardComputationSlice(
+ sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
+ if (sliceLoopNest != nullptr) {
+ auto dstForInst = cast<AffineForOp>(dstNode->op);
+ // Update operation position of fused loop nest (if needed).
+ if (insertPointInst != dstForInst.getOperation()) {
+ dstForInst.getOperation()->moveBefore(insertPointInst);
+ }
+ // Update data dependence graph state post fusion.
+ updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
+ }
+ }
+ }
+
+ // Searches function argument uses and the graph from 'dstNode' looking for a
+ // fusion candidate sibling node which shares no dependences with 'dstNode'
+ // but which loads from the same memref. Returns true and sets
+ // 'idAndMemrefToFuse' on success. Returns false otherwise.
+ bool findSiblingNodeToFuse(Node *dstNode,
+ DenseSet<unsigned> *visitedSibNodeIds,
+ std::pair<unsigned, Value> *idAndMemrefToFuse) {
+ // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
+ // on 'memref'.
+ auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
+      // Skip if 'sibNode' does not load from 'memref' via exactly one load op.
+      // TODO(andydavis) Remove the single load op restriction.
+ if (sibNode->getLoadOpCount(memref) != 1)
+ return false;
+ // Skip if there exists a path of dependent edges between
+ // 'sibNode' and 'dstNode'.
+ if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
+ mdg->hasDependencePath(dstNode->id, sibNode->id))
+ return false;
+      // Skip sib node if it loads from (and stores to) the same memref on
+ // which it also has an input dependence edge.
+ DenseSet<Value> loadAndStoreMemrefSet;
+ sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
+ if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
+ return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
+ }))
+ return false;
+
+ // Check that all stores are to the same memref.
+ DenseSet<Value> storeMemrefs;
+ for (auto *storeOpInst : sibNode->stores) {
+ storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+ }
+ if (storeMemrefs.size() != 1)
+ return false;
+ return true;
+ };
+
+ // Search for siblings which load the same memref function argument.
+ auto fn = dstNode->op->getParentOfType<FuncOp>();
+ for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+ for (auto *user : fn.getArgument(i)->getUsers()) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+ // Gather loops surrounding 'use'.
+ SmallVector<AffineForOp, 4> loops;
+ getLoopIVs(*user, &loops);
+ // Skip 'use' if it is not within a loop nest.
+ if (loops.empty())
+ continue;
+ Node *sibNode = mdg->getForOpNode(loops[0]);
+ assert(sibNode != nullptr);
+          // Skip 'use' if it is not a sibling to 'dstNode'.
+ if (sibNode->id == dstNode->id)
+ continue;
+ // Skip 'use' if it has been visited.
+ if (visitedSibNodeIds->count(sibNode->id) > 0)
+ continue;
+ // Skip 'use' if it does not load from the same memref as 'dstNode'.
+ auto memref = loadOp.getMemRef();
+ if (dstNode->getLoadOpCount(memref) == 0)
+ continue;
+ // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
+ if (canFuseWithSibNode(sibNode, memref)) {
+ visitedSibNodeIds->insert(sibNode->id);
+ idAndMemrefToFuse->first = sibNode->id;
+ idAndMemrefToFuse->second = memref;
+ return true;
+ }
+ }
+ }
+ }
+
+ // Search for siblings by following edges through an intermediate src node.
+ // Collect candidate 'dstNode' input edges in 'inEdges'.
+ SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
+ mdg->forEachMemRefInputEdge(
+ dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
+ // Add 'inEdge' if it is a read-after-write dependence.
+ if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
+ mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
+ inEdges.push_back(inEdge);
+ });
+
+ // Search for sibling nodes to fuse by visiting output edges from each input
+ // edge in 'inEdges'.
+ for (auto &inEdge : inEdges) {
+ // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
+ SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
+ mdg->forEachMemRefOutputEdge(
+ inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
+ unsigned sibNodeId = outEdge.id;
+ if (visitedSibNodeIds->count(sibNodeId) > 0)
+ return;
+ // Skip output edge if not a sibling using the same memref.
+ if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
+ return;
+ auto *sibNode = mdg->getNode(sibNodeId);
+ if (!isa<AffineForOp>(sibNode->op))
+ return;
+ // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
+ if (canFuseWithSibNode(sibNode, outEdge.value)) {
+ // Add candidate 'outEdge' to sibling node.
+ outEdges.push_back(outEdge);
+ }
+ });
+
+ // Add first candidate if any were returned.
+ if (!outEdges.empty()) {
+ visitedSibNodeIds->insert(outEdges[0].id);
+ idAndMemrefToFuse->first = outEdges[0].id;
+ idAndMemrefToFuse->second = outEdges[0].value;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
+ Node *dstNode) {
+ // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
+ mdg->updateEdges(sibNode->id, dstNode->id);
+
+ // Collect slice loop stats.
+ LoopNestStateCollector sliceCollector;
+ sliceCollector.collect(sliceLoopNest.getOperation());
+ // Promote single iteration slice loops to single IV value.
+ for (auto forOp : sliceCollector.forOps) {
+ promoteIfSingleIteration(forOp);
+ }
+
+ // Collect dst loop stats after memref privatization transformation.
+ auto dstForInst = cast<AffineForOp>(dstNode->op);
+ LoopNestStateCollector dstLoopCollector;
+ dstLoopCollector.collect(dstForInst.getOperation());
+ // Clear and add back loads and stores
+ mdg->clearNodeLoadAndStores(dstNode->id);
+ mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
+ dstLoopCollector.storeOpInsts);
+ // Remove old sibling loop nest if it no longer has outgoing dependence
+ // edges, and it does not write to a memref which escapes the
+ // function.
+ if (mdg->getOutEdgeCount(sibNode->id) == 0) {
+ mdg->removeNode(sibNode->id);
+ sibNode->op->erase();
+ }
+ }
+
+ // Clean up any allocs with no users.
+ void eraseUnusedMemRefAllocations() {
+ for (auto &pair : mdg->memrefEdgeCount) {
+ if (pair.second > 0)
+ continue;
+ auto memref = pair.first;
+ // Skip if there exist other uses (return operation or function calls).
+ if (!memref->use_empty())
+ continue;
+ // Use list expected to match the dep graph info.
+ auto *op = memref->getDefiningOp();
+ if (isa_and_nonnull<AllocOp>(op))
+ op->erase();
+ }
+ }
+};
+
+} // end anonymous namespace
+
+void LoopFusion::runOnFunction() {
+ // Override if a command line argument was provided.
+ if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
+ fastMemorySpace = clFusionFastMemorySpace.getValue();
+ }
+
+ // Override if a command line argument was provided.
+ if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
+ localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
+ }
+
+ if (clMaximalLoopFusion.getNumOccurrences() > 0)
+ maximalFusion = clMaximalLoopFusion;
+
+ MemRefDependenceGraph g;
+ if (g.init(getFunction()))
+ GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion)
+ .run();
+}
+
+static PassRegistration<LoopFusion> pass("affine-loop-fusion",
+ "Fuse loop nests");
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
new file mode 100644
index 00000000000..fb3d0c0b45c
--- /dev/null
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -0,0 +1,140 @@
+//===- LoopInvariantCodeMotion.cpp - Loop invariant code motion -----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop invariant code motion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopLikeInterface.h"
+#include "mlir/Transforms/SideEffectsInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "licm"
+
+using namespace mlir;
+
+namespace {
+
+using SideEffecting = SideEffectsInterface::SideEffecting;
+
+/// Loop invariant code motion (LICM) pass.
+struct LoopInvariantCodeMotion : public OperationPass<LoopInvariantCodeMotion> {
+public:
+ void runOnOperation() override;
+};
+
+// Checks whether the given op can be hoisted by checking that
+// - the op and any of its contained operations do not depend on SSA values
+// defined inside of the loop (by means of calling definedOutside).
+// - the op has no side effects. If sideEffecting is Never, side effects of
+//   this op and its nested ops are ignored.
+static bool canBeHoisted(Operation *op,
+ function_ref<bool(Value)> definedOutside,
+ SideEffecting sideEffecting,
+ SideEffectsInterface &interface) {
+ // Check that dependencies are defined outside of loop.
+ if (!llvm::all_of(op->getOperands(), definedOutside))
+ return false;
+ // Check whether this op is side-effect free. If we already know that there
+ // can be no side-effects because the surrounding op has claimed so, we can
+ // (and have to) skip this step.
+ auto thisOpIsSideEffecting = sideEffecting;
+ if (thisOpIsSideEffecting != SideEffecting::Never) {
+ thisOpIsSideEffecting = interface.isSideEffecting(op);
+    // If the op always has side effects, we cannot hoist.
+ if (thisOpIsSideEffecting == SideEffecting::Always)
+ return false;
+ }
+ // Recurse into the regions for this op and check whether the contained ops
+ // can be hoisted.
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region.getBlocks()) {
+ for (auto &innerOp : block) {
+ if (innerOp.isKnownTerminator())
+ continue;
+ if (!canBeHoisted(&innerOp, definedOutside, thisOpIsSideEffecting,
+ interface))
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
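+// Illustrative example (hypothetical IR, simplified syntax): in
+//
+//   loop.for %i = %lb to %ub step %c1 {
+//     %t = addf %a, %b : f32              // %a, %b defined outside the loop
+//     store %t, %buf[%i] : memref<?xf32>
+//   }
+//
+// the 'addf' uses only values defined outside the loop and has no side
+// effects, so canBeHoisted returns true for it and moveLoopInvariantCode
+// hoists it above the loop; the 'store' uses %i and is left in place.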
+static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike,
+ SideEffectsInterface &interface) {
+ auto &loopBody = looplike.getLoopBody();
+
+ // We use two collections here as we need to preserve the order for insertion
+ // and this is easiest.
+ SmallPtrSet<Operation *, 8> willBeMovedSet;
+ SmallVector<Operation *, 8> opsToMove;
+
+ // Helper to check whether an operation is loop invariant wrt. SSA properties.
+ auto isDefinedOutsideOfBody = [&](Value value) {
+ auto definingOp = value->getDefiningOp();
+ return (definingOp && !!willBeMovedSet.count(definingOp)) ||
+ looplike.isDefinedOutsideOfLoop(value);
+ };
+
+ // Do not use walk here, as we do not want to go into nested regions and hoist
+ // operations from there. These regions might have semantics unknown to this
+ // rewriting. If the nested regions are loops, they will have been processed.
+ for (auto &block : loopBody) {
+ for (auto &op : block.without_terminator()) {
+ if (canBeHoisted(&op, isDefinedOutsideOfBody,
+ mlir::SideEffectsDialectInterface::Recursive,
+ interface)) {
+ opsToMove.push_back(&op);
+ willBeMovedSet.insert(&op);
+ }
+ }
+ }
+
+  // For all instructions that we found to be invariant, move them outside of
+  // the loop.
+ auto result = looplike.moveOutOfLoop(opsToMove);
+ LLVM_DEBUG(looplike.print(llvm::dbgs() << "Modified loop\n"));
+ return result;
+}
+
+} // end anonymous namespace
+
+void LoopInvariantCodeMotion::runOnOperation() {
+ SideEffectsInterface interface(&getContext());
+  // Walk through all loops in a function in innermost-loop-first order. This
+  // way, we first hoist ops out of the innermost loops into their enclosing
+  // loops, from which they can in turn be hoisted further.
+ getOperation()->walk([&](Operation *op) {
+ if (auto looplike = dyn_cast<LoopLikeOpInterface>(op)) {
+ LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
+ if (failed(moveLoopInvariantCode(looplike, interface)))
+ signalPassFailure();
+ }
+ });
+}
+
+// Include the generated code for the loop-like interface here, as it otherwise
+// has no compilation unit. This works as loop-invariant code motion is the
+// only user of that interface.
+#include "mlir/Transforms/LoopLikeInterface.cpp.inc"
+
+std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
+ return std::make_unique<LoopInvariantCodeMotion>();
+}
+
+static PassRegistration<LoopInvariantCodeMotion>
+ pass("loop-invariant-code-motion",
+ "Hoist loop invariant instructions outside of the loop");
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
new file mode 100644
index 00000000000..d3dc81760fc
--- /dev/null
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -0,0 +1,402 @@
+//===- LoopTiling.cpp --- Loop tiling pass ------------------------------*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to tile loop nests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-tile"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+static llvm::cl::opt<unsigned long long>
+ clCacheSizeKiB("tile-cache-size",
+ llvm::cl::desc("Set size of cache to tile for in KiB"),
+ llvm::cl::cat(clOptionsCategory));
+
+// Tile size to use for all loops (overrides -tile-sizes if provided).
+static llvm::cl::opt<unsigned>
+ clTileSize("tile-size", llvm::cl::desc("Use this tile size for all loops"),
+ llvm::cl::cat(clOptionsCategory));
+
+// List of tile sizes. If fewer sizes than loops in a band are provided, the
+// missing ones are filled with clTileSize / kDefaultTileSize.
+static llvm::cl::list<unsigned> clTileSizes(
+ "tile-sizes",
+ llvm::cl::desc(
+ "List of tile sizes for each perfect nest (overridden by -tile-size)"),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+namespace {
+
+/// A pass to perform loop tiling on all suitable loop nests of a Function.
+struct LoopTiling : public FunctionPass<LoopTiling> {
+ explicit LoopTiling(uint64_t cacheSizeBytes = kDefaultCacheMemCapacity,
+ bool avoidMaxMinBounds = true)
+ : cacheSizeBytes(cacheSizeBytes), avoidMaxMinBounds(avoidMaxMinBounds) {}
+
+ void runOnFunction() override;
+ void getTileSizes(ArrayRef<AffineForOp> band,
+ SmallVectorImpl<unsigned> *tileSizes);
+
+ // Default tile size if nothing is provided.
+ constexpr static unsigned kDefaultTileSize = 4;
+ constexpr static uint64_t kDefaultCacheMemCapacity = 512 * 1024UL;
+
+ // Capacity of the cache to tile for.
+ uint64_t cacheSizeBytes;
+ // If true, tile sizes are set to avoid max/min in bounds if possible.
+ bool avoidMaxMinBounds;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to perform loop tiling on all suitable loop nests of a
+/// Function.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopTilingPass(uint64_t cacheSizeBytes) {
+ return std::make_unique<LoopTiling>(cacheSizeBytes);
+}
+
+// Move the loop body of AffineForOp 'src' from 'src' into the specified
+// location in destination's body, ignoring the terminator.
+static inline void moveLoopBody(AffineForOp src, AffineForOp dest,
+ Block::iterator loc) {
+ auto &insts = src.getBody()->getOperations();
+ dest.getBody()->getOperations().splice(loc, insts, insts.begin(),
+ std::prev(insts.end()));
+}
+
+// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's
+// body.
+static inline void moveLoopBody(AffineForOp src, AffineForOp dest) {
+ moveLoopBody(src, dest, dest.getBody()->begin());
+}
+
+/// Constructs and sets new loop bounds after tiling for the case of
+/// hyper-rectangular index sets, where the bounds of one dimension do not
+/// depend on other dimensions. Bounds of each dimension can thus be treated
+/// independently, and deriving the new bounds is much simpler and faster
+/// than for the case of tiling arbitrary polyhedral shapes.
+static void
+constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
+ MutableArrayRef<AffineForOp> newLoops,
+ ArrayRef<unsigned> tileSizes) {
+ assert(!origLoops.empty());
+ assert(origLoops.size() == tileSizes.size());
+
+ OpBuilder b(origLoops[0].getOperation());
+ unsigned width = origLoops.size();
+
+ // Bounds for tile space loops.
+ for (unsigned i = 0; i < width; i++) {
+ auto lbOperands = origLoops[i].getLowerBoundOperands();
+ auto ubOperands = origLoops[i].getUpperBoundOperands();
+ SmallVector<Value, 4> newLbOperands(lbOperands);
+ SmallVector<Value, 4> newUbOperands(ubOperands);
+ newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
+ newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
+ newLoops[i].setStep(tileSizes[i]);
+ }
+ // Bounds for intra-tile loops.
+ for (unsigned i = 0; i < width; i++) {
+ int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
+ auto mayBeConstantCount = getConstantTripCount(origLoops[i]);
+ // The lower bound is just the tile-space loop.
+ AffineMap lbMap = b.getDimIdentityMap();
+ newLoops[width + i].setLowerBound(
+ /*operands=*/newLoops[i].getInductionVar(), lbMap);
+
+ // Set the upper bound.
+ if (mayBeConstantCount.hasValue() &&
+ mayBeConstantCount.getValue() < tileSizes[i]) {
+ // Trip count is less than tile size; upper bound is the trip count.
+ auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
+ newLoops[width + i].setUpperBoundMap(ubMap);
+ } else if (largestDiv % tileSizes[i] != 0) {
+ // Intra-tile loop ii goes from i to min(i + tileSize, ub_i).
+ // Construct the upper bound map; the operands are the original operands
+ // with 'i' (tile-space loop) appended to it. The new upper bound map is
+ // the original one with an additional expression i + tileSize appended.
+ auto ub = origLoops[i].getUpperBound();
+ SmallVector<Value, 4> ubOperands;
+ ubOperands.reserve(ub.getNumOperands() + 1);
+ auto origUbMap = ub.getMap();
+ // Add dim operands from original upper bound.
+ for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) {
+ ubOperands.push_back(ub.getOperand(j));
+ }
+ // Add dim operand for new loop upper bound.
+ ubOperands.push_back(newLoops[i].getInductionVar());
+ // Add symbol operands from original upper bound.
+ for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) {
+ ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+ }
+ SmallVector<AffineExpr, 4> boundExprs;
+ boundExprs.reserve(1 + origUbMap.getNumResults());
+ auto dim = b.getAffineDimExpr(origUbMap.getNumDims());
+ // The new upper bound map is the original one with an additional
+ // expression i + tileSize appended.
+ boundExprs.push_back(dim + tileSizes[i]);
+ boundExprs.append(origUbMap.getResults().begin(),
+ origUbMap.getResults().end());
+ auto ubMap = AffineMap::get(origUbMap.getNumDims() + 1,
+ origUbMap.getNumSymbols(), boundExprs);
+ newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
+ } else {
+ // No need of the min expression.
+ auto dim = b.getAffineDimExpr(0);
+ auto ubMap = AffineMap::get(1, 0, dim + tileSizes[i]);
+ newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
+ }
+ }
+}
+
+/// Tiles the specified band of perfectly nested loops creating tile-space loops
+/// and intra-tile loops. A band is a contiguous set of loops.
+// TODO(bondhugula): handle non hyper-rectangular spaces.
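+//
+// For illustration only (hypothetical loop, simplified syntax): tiling
+// "affine.for %i = 0 to 256" with a tile size of 32 produces a tile-space loop
+// "affine.for %ii = 0 to 256 step 32" enclosing an intra-tile loop running %i
+// from %ii to %ii + 32 over the original loop body.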
+LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
+ ArrayRef<unsigned> tileSizes) {
+ assert(!band.empty());
+ assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes");
+
+ // Check if the supplied for op's are all successively nested.
+ for (unsigned i = 1, e = band.size(); i < e; i++) {
+ assert(band[i].getParentOp() == band[i - 1].getOperation());
+ }
+
+ auto origLoops = band;
+
+ AffineForOp rootAffineForOp = origLoops[0];
+ auto loc = rootAffineForOp.getLoc();
+ // Note that width is at least one since band isn't empty.
+ unsigned width = band.size();
+
+ SmallVector<AffineForOp, 12> newLoops(2 * width);
+ AffineForOp innermostPointLoop;
+
+  // The outermost loop so far; updated as we wrap the nest in new loops.
+ auto *topLoop = rootAffineForOp.getOperation();
+
+ // Add intra-tile (or point) loops.
+ for (unsigned i = 0; i < width; i++) {
+ OpBuilder b(topLoop);
+ // Loop bounds will be set later.
+ auto pointLoop = b.create<AffineForOp>(loc, 0, 0);
+ pointLoop.getBody()->getOperations().splice(
+ pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+ topLoop);
+ newLoops[2 * width - 1 - i] = pointLoop;
+ topLoop = pointLoop.getOperation();
+ if (i == 0)
+ innermostPointLoop = pointLoop;
+ }
+
+  // Add tile space loops.
+ for (unsigned i = width; i < 2 * width; i++) {
+ OpBuilder b(topLoop);
+ // Loop bounds will be set later.
+ auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+ tileSpaceLoop.getBody()->getOperations().splice(
+ tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
+ topLoop);
+ newLoops[2 * width - i - 1] = tileSpaceLoop;
+ topLoop = tileSpaceLoop.getOperation();
+ }
+
+ // Move the loop body of the original nest to the new one.
+ moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
+
+ SmallVector<Value, 8> origLoopIVs;
+ extractForInductionVars(band, &origLoopIVs);
+ SmallVector<Optional<Value>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
+ FlatAffineConstraints cst;
+ getIndexSet(band, &cst);
+
+ if (!cst.isHyperRectangular(0, width)) {
+ rootAffineForOp.emitError("tiled code generation unimplemented for the "
+ "non-hyperrectangular case");
+ return failure();
+ }
+
+ constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
+ // In this case, the point loop IVs just replace the original ones.
+ for (unsigned i = 0; i < width; i++) {
+ origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width].getInductionVar());
+ }
+
+ // Erase the old loop nest.
+ rootAffineForOp.erase();
+
+ return success();
+}
+
+// Identify valid and profitable bands of loops to tile. This is currently just
+// a temporary placeholder to test the mechanics of tiled code generation.
+// Returns all maximal outermost perfect loop nests to tile.
+static void getTileableBands(FuncOp f,
+ std::vector<SmallVector<AffineForOp, 6>> *bands) {
+ // Get maximal perfect nest of 'affine.for' insts starting from root
+ // (inclusive).
+ auto getMaximalPerfectLoopNest = [&](AffineForOp root) {
+ SmallVector<AffineForOp, 6> band;
+ getPerfectlyNestedLoops(band, root);
+ bands->push_back(band);
+ };
+
+ for (auto &block : f)
+ for (auto &op : block)
+ if (auto forOp = dyn_cast<AffineForOp>(op))
+ getMaximalPerfectLoopNest(forOp);
+}
+
+// Reduces each tile size to the largest divisor of the corresponding trip
+// count that does not exceed the requested size (when the trip count is
+// known).
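+// For example (hypothetical numbers): with a constant trip count of 12 and a
+// requested tile size of 5, the size is lowered to 4, the largest divisor of
+// 12 not exceeding 5.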
+static void adjustToDivisorsOfTripCounts(ArrayRef<AffineForOp> band,
+ SmallVectorImpl<unsigned> *tileSizes) {
+ assert(band.size() == tileSizes->size() && "invalid tile size count");
+ for (unsigned i = 0, e = band.size(); i < e; i++) {
+ unsigned &tSizeAdjusted = (*tileSizes)[i];
+ auto mayConst = getConstantTripCount(band[i]);
+ if (!mayConst.hasValue())
+ continue;
+    // Adjust the tile size to the largest factor of the trip count that is
+    // no larger than tSize.
+ uint64_t constTripCount = mayConst.getValue();
+ if (constTripCount > 1 && tSizeAdjusted > constTripCount / 2)
+ tSizeAdjusted = constTripCount / 2;
+ while (constTripCount % tSizeAdjusted != 0)
+ tSizeAdjusted--;
+ }
+}
+
+// Returns the tile sizes to use. Checks CL options; if none are specified,
+// sets them based on a simple model that looks at the memory footprint and
+// determines tile sizes assuming identity accesses, i.e., a footprint that is
+// proportional to the tile size along each of the dimensions being tiled.
+// TODO(mlir-team): evolve this model. Tile size determination is a large area
+// to play with in general.
+void LoopTiling::getTileSizes(ArrayRef<AffineForOp> band,
+ SmallVectorImpl<unsigned> *tileSizes) {
+ if (band.empty())
+ return;
+
+ tileSizes->resize(band.size());
+
+ // Use clTileSize for all loops if specified.
+ if (clTileSize.getNumOccurrences() > 0) {
+ std::fill(tileSizes->begin(), tileSizes->end(), clTileSize);
+ return;
+ }
+
+  // Use clTileSizes, padding with the default tile size if the list is short.
+ if (!clTileSizes.empty()) {
+ std::fill(tileSizes->begin(), tileSizes->end(),
+ LoopTiling::kDefaultTileSize);
+ std::copy(clTileSizes.begin(),
+ clTileSizes.begin() + std::min(clTileSizes.size(), band.size()),
+ tileSizes->begin());
+ return;
+ }
+
+ // The first loop in the band.
+ auto rootForOp = band[0];
+ (void)rootForOp;
+
+ // Obtain memory footprint and set tile sizes so that a tile fits in
+ // the cache size. This is an approximation with the assumption that the
+ // footprint increases with the tile size linearly in that dimension (i.e.,
+ // assumes one-to-one access function).
+ auto fp = getMemoryFootprintBytes(band[0], 0);
+ if (!fp.hasValue()) {
+ // Fill with default tile sizes if footprint is unknown.
+ std::fill(tileSizes->begin(), tileSizes->end(),
+ LoopTiling::kDefaultTileSize);
+ if (avoidMaxMinBounds)
+ adjustToDivisorsOfTripCounts(band, tileSizes);
+ LLVM_DEBUG(
+ rootForOp.emitWarning("memory footprint unknown: using default tile "
+ "sizes adjusted to trip count divisors"));
+ return;
+ }
+
+ // Check how many times larger the cache size is when compared to footprint.
+ uint64_t excessFactor = llvm::divideCeil(fp.getValue(), cacheSizeBytes);
+ if (excessFactor <= 1) {
+ // No need of any tiling - set tile size to 1.
+ std::fill(tileSizes->begin(), tileSizes->end(), 1);
+ return;
+ }
+
+ // Divide all loops equally in an attempt to reduce footprint.
+ // TODO(bondhugula): this is approximate. Ideally, obtain reuse factor /
+ // profitability along each dimension and weight tile sizes based on that as
+ // one possible approach. Or compute a polynomial in tile sizes and solve for
+ // it.
+
+ // For an n-d tileable band, compute n^th root of the excess.
+ unsigned tSize =
+ static_cast<unsigned>(floorl(std::pow(excessFactor, 1.0 / band.size())));
+ // We'll keep a running product to determine the last tile size better.
+ unsigned cumulProductOfTileSizes = 1;
+ for (unsigned i = 0, e = band.size(); i < e; i++) {
+ if (i < e - 1)
+ (*tileSizes)[i] = tSize;
+ else
+ // Set last tile size to cover the balance.
+ (*tileSizes)[i] = std::max(
+ 1U, static_cast<unsigned>(excessFactor / cumulProductOfTileSizes));
+ cumulProductOfTileSizes *= (*tileSizes)[i];
+ }
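+  // Illustrative arithmetic (hypothetical numbers): for a 2-d band with
+  // excessFactor = 32, tSize = floor(sqrt(32)) = 5, so the first tile size is
+  // set to 5 and the last one to max(1, 32 / 5) = 6 to cover the balance.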
+ if (avoidMaxMinBounds)
+ adjustToDivisorsOfTripCounts(band, tileSizes);
+}
+
+void LoopTiling::runOnFunction() {
+ // Override cache size if provided on command line.
+ if (clCacheSizeKiB.getNumOccurrences() > 0)
+ cacheSizeBytes = clCacheSizeKiB * 1024;
+
+ // Bands of loops to tile.
+ std::vector<SmallVector<AffineForOp, 6>> bands;
+ getTileableBands(getFunction(), &bands);
+
+ for (auto &band : bands) {
+ // Set up tile sizes; fill missing tile sizes at the end with default tile
+ // size or clTileSize if one was provided.
+ SmallVector<unsigned, 6> tileSizes;
+ getTileSizes(band, &tileSizes);
+ if (llvm::DebugFlag) {
+ auto diag = band[0].emitRemark("using tile sizes [");
+ for (auto tSize : tileSizes)
+ diag << tSize << " ";
+ diag << "]\n";
+ }
+ if (failed(tileCodeGen(band, tileSizes)))
+ return signalPassFailure();
+ }
+}
+
+constexpr unsigned LoopTiling::kDefaultTileSize;
+constexpr uint64_t LoopTiling::kDefaultCacheMemCapacity;
+
+static PassRegistration<LoopTiling> pass("affine-loop-tile", "Tile loop nests");
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
new file mode 100644
index 00000000000..e94c6c8b0bb
--- /dev/null
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -0,0 +1,182 @@
+//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop unrolling.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-unroll"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+// Loop unrolling factor.
+static llvm::cl::opt<unsigned> clUnrollFactor(
+ "unroll-factor",
+ llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<bool> clUnrollFull("unroll-full",
+ llvm::cl::desc("Fully unroll loops"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
+ "unroll-num-reps",
+ llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
+ llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::opt<unsigned> clUnrollFullThreshold(
+ "unroll-full-threshold", llvm::cl::Hidden,
+ llvm::cl::desc(
+ "Unroll all loops with trip count less than or equal to this"),
+ llvm::cl::cat(clOptionsCategory));
+
+namespace {
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold were specified, in which case it fully unrolls all
+/// loops with trip count less than or equal to the specified threshold. The
+/// latter is for testing purposes, especially for testing outer loop unrolling.
+struct LoopUnroll : public FunctionPass<LoopUnroll> {
+ const Optional<unsigned> unrollFactor;
+ const Optional<bool> unrollFull;
+ // Callback to obtain unroll factors; if this has a callable target, takes
+ // precedence over command-line argument or passed argument.
+ const std::function<unsigned(AffineForOp)> getUnrollFactor;
+
+ explicit LoopUnroll(
+ Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
+ const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
+ : unrollFactor(unrollFactor), unrollFull(unrollFull),
+ getUnrollFactor(getUnrollFactor) {}
+
+ void runOnFunction() override;
+
+ /// Unroll this for op. Returns failure if nothing was done.
+ LogicalResult runOnAffineForOp(AffineForOp forOp);
+
+ static const unsigned kDefaultUnrollFactor = 4;
+};
+} // end anonymous namespace
+
+void LoopUnroll::runOnFunction() {
+ // Gathers all innermost loops through a post order pruned walk.
+ struct InnermostLoopGatherer {
+ // Store innermost loops as we walk.
+ std::vector<AffineForOp> loops;
+
+ void walkPostOrder(FuncOp f) {
+ for (auto &b : f)
+ walkPostOrder(b.begin(), b.end());
+ }
+
+ bool walkPostOrder(Block::iterator Start, Block::iterator End) {
+ bool hasInnerLoops = false;
+      // We need to walk all elements since all innermost loops need to be
+      // gathered, as opposed to just determining whether this range has any
+      // inner loops.
+ while (Start != End)
+ hasInnerLoops |= walkPostOrder(&(*Start++));
+ return hasInnerLoops;
+ }
+ bool walkPostOrder(Operation *opInst) {
+ bool hasInnerLoops = false;
+ for (auto &region : opInst->getRegions())
+ for (auto &block : region)
+ hasInnerLoops |= walkPostOrder(block.begin(), block.end());
+ if (isa<AffineForOp>(opInst)) {
+ if (!hasInnerLoops)
+ loops.push_back(cast<AffineForOp>(opInst));
+ return true;
+ }
+ return hasInnerLoops;
+ }
+ };
+
+ if (clUnrollFull.getNumOccurrences() > 0 &&
+ clUnrollFullThreshold.getNumOccurrences() > 0) {
+ // Store short loops as we walk.
+ std::vector<AffineForOp> loops;
+
+    // Gathers all loops with trip count <= clUnrollFullThreshold. Do a post
+    // order walk so that loops are gathered from innermost to outermost (or
+    // else unrolling an outer one may delete gathered inner ones).
+ getFunction().walk([&](AffineForOp forOp) {
+ Optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
+ loops.push_back(forOp);
+ });
+ for (auto forOp : loops)
+ loopUnrollFull(forOp);
+ return;
+ }
+
+ unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
+ ? clUnrollNumRepetitions
+ : 1;
+  // If the callback is provided, we will keep iterating until no more loops
+  // are found.
+ FuncOp func = getFunction();
+ for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
+ InnermostLoopGatherer ilg;
+ ilg.walkPostOrder(func);
+ auto &loops = ilg.loops;
+ if (loops.empty())
+ break;
+ bool unrolled = false;
+ for (auto forOp : loops)
+ unrolled |= succeeded(runOnAffineForOp(forOp));
+ if (!unrolled)
+ // Break out if nothing was unrolled.
+ break;
+ }
+}
+
+/// Unrolls an 'affine.for' op. Returns success if the loop was unrolled,
+/// failure otherwise. The default unroll factor is 4.
+LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
+ // Use the function callback if one was provided.
+ if (getUnrollFactor) {
+ return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
+ }
+ // Unroll by the factor passed, if any.
+ if (unrollFactor.hasValue())
+ return loopUnrollByFactor(forOp, unrollFactor.getValue());
+ // Unroll by the command line factor if one was specified.
+ if (clUnrollFactor.getNumOccurrences() > 0)
+ return loopUnrollByFactor(forOp, clUnrollFactor);
+ // Unroll completely if full loop unroll was specified.
+ if (clUnrollFull.getNumOccurrences() > 0 ||
+ (unrollFull.hasValue() && unrollFull.getValue()))
+ return loopUnrollFull(forOp);
+
+ // Unroll by four otherwise.
+ return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
+}
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass(
+ int unrollFactor, int unrollFull,
+ const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
+ return std::make_unique<LoopUnroll>(
+ unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
+ unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
+}
+
+static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops");
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
new file mode 100644
index 00000000000..6c74d545497
--- /dev/null
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -0,0 +1,235 @@
+//===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop unroll and jam. Unroll and jam is a transformation
+// that improves locality, in particular, register reuse, while also improving
+// operation level parallelism. The example below shows what it does in nearly
+// the general case. Loop unroll and jam currently works if the bounds of the
+// loops inner to the loop being unroll-jammed do not depend on the latter.
+//
+// Before After unroll and jam of i by factor 2:
+//
+// for i, step = 2
+// for i S1(i);
+// S1; S2(i);
+// S2; S1(i+1);
+// for j S2(i+1);
+// S3; for j
+// S4; S3(i, j);
+// S5; S4(i, j);
+// S6; S3(i+1, j)
+// S4(i+1, j)
+// S5(i);
+// S6(i);
+// S5(i+1);
+// S6(i+1);
+//
+// Note: 'if/else' blocks are not jammed. So, if there are loops inside 'if'
+// ops, the bodies of those loops will not be jammed.
+//===----------------------------------------------------------------------===//
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "affine-loop-unroll-jam"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+
+// Loop unroll and jam factor.
+static llvm::cl::opt<unsigned>
+ clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden,
+ llvm::cl::desc("Use this unroll jam factor for all loops"
+ " (default 4)"),
+ llvm::cl::cat(clOptionsCategory));
+
+namespace {
+/// Loop unroll and jam pass. Currently, this just unroll-jams the outermost
+/// loop of the first loop nest in a Function.
+struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> {
+ Optional<unsigned> unrollJamFactor;
+ static const unsigned kDefaultUnrollJamFactor = 4;
+
+ explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
+ : unrollJamFactor(unrollJamFactor) {}
+
+ void runOnFunction() override;
+ LogicalResult runOnAffineForOp(AffineForOp forOp);
+};
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
+ return std::make_unique<LoopUnrollAndJam>(
+ unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
+}
+
+void LoopUnrollAndJam::runOnFunction() {
+ // Currently, just the outermost loop from the first loop nest is
+ // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
+ // any for operation.
+ auto &entryBlock = getFunction().front();
+ if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front()))
+ runOnAffineForOp(forOp);
+}
+
+/// Unroll and jam an 'affine.for' op. The default unroll jam factor is
+/// kDefaultUnrollJamFactor. Returns failure if nothing was done.
+LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) {
+ // Unroll and jam by the factor that was passed if any.
+ if (unrollJamFactor.hasValue())
+ return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
+ // Otherwise, unroll jam by the command-line factor if one was specified.
+ if (clUnrollJamFactor.getNumOccurrences() > 0)
+ return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
+
+ // Unroll and jam by four otherwise.
+ return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
+}
+
+LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
+ uint64_t unrollJamFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollJamFactor)
+ return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
+ return loopUnrollJamByFactor(forOp, unrollJamFactor);
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
+ uint64_t unrollJamFactor) {
+  // Gathers all maximal sub-blocks of operations that do not themselves
+  // include a for op (an operation could still have a descendant for op
+  // in its tree). Ignores the block terminators.
+ struct JamBlockGatherer {
+ // Store iterators to the first and last op of each sub-block found.
+ std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+ // This is a linear time walk.
+ void walk(Operation *op) {
+ for (auto &region : op->getRegions())
+ for (auto &block : region)
+ walk(block);
+ }
+ void walk(Block &block) {
+ for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+ auto subBlockStart = it;
+ while (it != e && !isa<AffineForOp>(&*it))
+ ++it;
+ if (it != subBlockStart)
+ subBlocks.push_back({subBlockStart, std::prev(it)});
+ // Process all for insts that appear next.
+ while (it != e && isa<AffineForOp>(&*it))
+ walk(&*it++);
+ }
+ }
+ };
+
+ assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
+
+ if (unrollJamFactor == 1)
+ return promoteIfSingleIteration(forOp);
+
+ if (forOp.getBody()->empty() ||
+ forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+ return failure();
+
+ // Loops where both lower and upper bounds are multi-result maps won't be
+  // unrolled (since the trip count can't be expressed as an affine function in
+ // general).
+ // TODO(mlir-team): this may not be common, but we could support the case
+ // where the lower bound is a multi-result map and the ub is a single result
+ // one.
+ if (forOp.getLowerBoundMap().getNumResults() != 1)
+ return failure();
+
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ // If the trip count is lower than the unroll jam factor, no unroll jam.
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollJamFactor)
+ return failure();
+
+ auto *forInst = forOp.getOperation();
+
+ // Gather all sub-blocks to jam upon the loop being unrolled.
+ JamBlockGatherer jbg;
+ jbg.walk(forInst);
+ auto &subBlocks = jbg.subBlocks;
+
+ // Generate the cleanup loop if trip count isn't a multiple of
+ // unrollJamFactor.
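+  // For example (schematic), a loop over [0, 98) unroll-jammed by a factor of
+  // 4 gets a main loop covering [0, 96) with a scaled step and a cleanup loop
+  // covering the remaining iterations in [96, 98).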
+ if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
+ // Insert the cleanup loop right after 'forOp'.
+ OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst)));
+ auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
+ // Adjust the lower bound of the cleanup loop; its upper bound is the same
+ // as the original loop's upper bound.
+ AffineMap cleanupMap;
+ SmallVector<Value, 4> cleanupOperands;
+ getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap,
+ &cleanupOperands, builder);
+ cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);
+
+ // Promote the cleanup loop if it has turned into a single iteration loop.
+ promoteIfSingleIteration(cleanupAffineForOp);
+
+ // Adjust the upper bound of the original loop - it will be the same as the
+ // cleanup loop's lower bound. Its lower bound remains unchanged.
+ forOp.setUpperBound(cleanupOperands, cleanupMap);
+ }
+
+ // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+ int64_t step = forOp.getStep();
+ forOp.setStep(step * unrollJamFactor);
+
+ auto forOpIV = forOp.getInductionVar();
+ // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ // Operand map persists across all sub-blocks.
+ BlockAndValueMapping operandMapping;
+ for (auto &subBlock : subBlocks) {
+ // Builder to insert unroll-jammed bodies. Insert right at the end of
+ // sub-block.
+ OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forOpIV->use_empty()) {
+ // iv' = iv + i, i = 1 to unrollJamFactor-1.
+ auto d0 = builder.getAffineDimExpr(0);
+ auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
+ auto ivUnroll =
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
+ operandMapping.map(forOpIV, ivUnroll);
+ }
+ // Clone the sub-block being unroll-jammed.
+ for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
+ builder.clone(*it, operandMapping);
+ }
+ }
+ }
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(forOp);
+ return success();
+}
+
+static PassRegistration<LoopUnrollAndJam> pass("affine-loop-unroll-jam",
+ "Unroll and jam loops");
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
new file mode 100644
index 00000000000..e2514e12cc7
--- /dev/null
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -0,0 +1,227 @@
+//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to forward memref stores to loads, thereby
+// potentially getting rid of intermediate memref's entirely.
+// TODO(mlir-team): In the future, similar techniques could be used to eliminate
+// dead memref store's and perform more complex forwarding when support for
+// SSA scalars live out of 'affine.for'/'affine.if' statements is available.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "memref-dataflow-opt"
+
+using namespace mlir;
+
+namespace {
+
+// The store to load forwarding relies on three conditions:
+//
+// 1) they need to have mathematically equivalent affine access functions
+// (checked after full composition of load/store operands); this implies that
+// they access the same single memref element for all iterations of the common
+// surrounding loop,
+//
+// 2) the store op should dominate the load op,
+//
+// 3) among all op's that satisfy both (1) and (2), the one that postdominates
+// all store op's that have a dependence into the load, is provably the last
+// writer to the particular memref location being loaded at the load op, and its
+// store value can be forwarded to the load. Note that the only dependences
+// that are to be considered are those that are satisfied at the block* of the
+// innermost common surrounding loop of the <store, load> being considered.
+//
+// (* A dependence being satisfied at a block: a dependence that is satisfied by
+// virtue of the destination operation appearing textually / lexically after
+// the source operation within the body of a 'affine.for' operation; thus, a
+// dependence is always either satisfied by a loop or by a block).
+//
+// The above conditions are simple to check, sufficient, and powerful for most
+// cases in practice; they are sufficient, but not necessary, since they
+// don't reason about loops that are guaranteed to execute at least once or
+// multiple sources to forward from.
+//
+// TODO(mlir-team): more forwarding can be done when support for
+// loop/conditional live-out SSA values is available.
+// TODO(mlir-team): do general dead store elimination for memref's. This pass
+// currently eliminates stores only if no other loads/uses (other
+// than dealloc) remain.
+//
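+// As a schematic example of the basic case handled here:
+//
+//   affine.for %i = 0 to 100 {
+//     affine.store %v, %A[%i] : memref<100xf32>
+//     %u = affine.load %A[%i] : memref<100xf32>
+//     "use"(%u) : (f32) -> ()
+//   }
+//
+// The store and load have identical access functions and the store dominates
+// the load, so all uses of %u can be replaced by %v and the load erased; if
+// %A is then used only by stores and deallocs, the buffer itself is removed.
+//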
+struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> {
+ void runOnFunction() override;
+
+ void forwardStoreToLoad(AffineLoadOp loadOp);
+
+ // A list of memref's that are potentially dead / could be eliminated.
+ SmallPtrSet<Value, 4> memrefsToErase;
+ // Load op's whose results were replaced by those forwarded from stores.
+ SmallVector<Operation *, 8> loadOpsToErase;
+
+ DominanceInfo *domInfo = nullptr;
+ PostDominanceInfo *postDomInfo = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to perform optimizations relying on memref dataflow such as
+/// store to load forwarding, elimination of dead stores, and dead allocs.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() {
+ return std::make_unique<MemRefDataFlowOpt>();
+}
+
+// This is a straightforward implementation not optimized for speed. Optimize
+// if needed.
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+ Operation *loadOpInst = loadOp.getOperation();
+
+  // First pass over the use list to get the minimum number of surrounding
+  // loops common between the load op and the store op, with the minimum
+  // taken across all store ops.
+ SmallVector<Operation *, 8> storeOps;
+ unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
+ for (auto *user : loadOp.getMemRef()->getUsers()) {
+ auto storeOp = dyn_cast<AffineStoreOp>(user);
+ if (!storeOp)
+ continue;
+ auto *storeOpInst = storeOp.getOperation();
+ unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+ minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+ storeOps.push_back(storeOpInst);
+ }
+
+ // The list of store op candidates for forwarding that satisfy conditions
+ // (1) and (2) above - they will be filtered later when checking (3).
+ SmallVector<Operation *, 8> fwdingCandidates;
+
+ // Store ops that have a dependence into the load (even if they aren't
+ // forwarding candidates). Each forwarding candidate will be checked for a
+ // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
+ SmallVector<Operation *, 8> depSrcStores;
+
+ for (auto *storeOpInst : storeOps) {
+ MemRefAccess srcAccess(storeOpInst);
+ MemRefAccess destAccess(loadOpInst);
+ // Find stores that may be reaching the load.
+ FlatAffineConstraints dependenceConstraints;
+ unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+ unsigned d;
+ // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+ for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, destAccess, d, &dependenceConstraints,
+ /*dependenceComponents=*/nullptr);
+ if (hasDependence(result))
+ break;
+ }
+ if (d == minSurroundingLoops)
+ continue;
+
+ // Stores that *may* be reaching the load.
+ depSrcStores.push_back(storeOpInst);
+
+ // 1. Check if the store and the load have mathematically equivalent
+ // affine access functions; this implies that they statically refer to the
+ // same single memref element. As an example this filters out cases like:
+ // store %A[%i0 + 1]
+ // load %A[%i0]
+ // store %A[%M]
+ // load %A[%N]
+ // Use the AffineValueMap difference based memref access equality checking.
+ if (srcAccess != destAccess)
+ continue;
+
+    // 2. The store has to dominate the load op to be a candidate.
+ if (!domInfo->dominates(storeOpInst, loadOpInst))
+ continue;
+
+ // We now have a candidate for forwarding.
+ fwdingCandidates.push_back(storeOpInst);
+ }
+
+ // 3. Of all the store op's that meet the above criteria, the store that
+ // postdominates all 'depSrcStores' (if one exists) is the unique store
+ // providing the value to the load, i.e., provably the last writer to that
+ // memref loc.
+ // Note: this can be implemented in a cleaner way with postdominator tree
+ // traversals. Consider this for the future if needed.
+ Operation *lastWriteStoreOp = nullptr;
+ for (auto *storeOpInst : fwdingCandidates) {
+ if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
+ return postDomInfo->postDominates(storeOpInst, depStore);
+ })) {
+ lastWriteStoreOp = storeOpInst;
+ break;
+ }
+ }
+ if (!lastWriteStoreOp)
+ return;
+
+ // Perform the actual store to load forwarding.
+ Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
+ loadOp.replaceAllUsesWith(storeVal);
+ // Record the memref for a later sweep to optimize away.
+ memrefsToErase.insert(loadOp.getMemRef());
+ // Record this to erase later.
+ loadOpsToErase.push_back(loadOpInst);
+}
+
+void MemRefDataFlowOpt::runOnFunction() {
+ // Only supports single block functions at the moment.
+ FuncOp f = getFunction();
+ if (f.getBlocks().size() != 1) {
+ markAllAnalysesPreserved();
+ return;
+ }
+
+ domInfo = &getAnalysis<DominanceInfo>();
+ postDomInfo = &getAnalysis<PostDominanceInfo>();
+
+ loadOpsToErase.clear();
+ memrefsToErase.clear();
+
+ // Walk all load's and perform load/store forwarding.
+ f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+
+ // Erase all load op's whose results were replaced with store fwd'ed ones.
+ for (auto *loadOp : loadOpsToErase) {
+ loadOp->erase();
+ }
+
+ // Check if the store fwd'ed memrefs are now left with only stores and can
+ // thus be completely deleted. Note: the canonicalize pass should be able
+ // to do this as well, but we'll do it here since we collected these anyway.
+ for (auto memref : memrefsToErase) {
+ // If the memref hasn't been alloc'ed in this function, skip.
+ Operation *defInst = memref->getDefiningOp();
+ if (!defInst || !isa<AllocOp>(defInst))
+ // TODO(mlir-team): if the memref was returned by a 'call' operation, we
+ // could still erase it if the call had no side-effects.
+ continue;
+ if (llvm::any_of(memref->getUsers(), [&](Operation *ownerInst) {
+ return (!isa<AffineStoreOp>(ownerInst) && !isa<DeallocOp>(ownerInst));
+ }))
+ continue;
+
+ // Erase all stores, the dealloc, and the alloc on the memref.
+ for (auto *user : llvm::make_early_inc_range(memref->getUsers()))
+ user->erase();
+ defInst->erase();
+ }
+}
+
+static PassRegistration<MemRefDataFlowOpt>
+ pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs");
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
new file mode 100644
index 00000000000..dce02737064
--- /dev/null
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -0,0 +1,379 @@
+//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to pipeline data transfers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "affine-pipeline-data-transfer"
+
+using namespace mlir;
+
+namespace {
+
+struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
+ void runOnFunction() override;
+ void runOnAffineForOp(AffineForOp forOp);
+
+ std::vector<AffineForOp> forOps;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to pipeline explicit movement of data across levels of the
+/// memory hierarchy.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() {
+ return std::make_unique<PipelineDataTransfer>();
+}
+
+// Returns the position of the tag memref operand given a DMA operation.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getTagMemRefPos(Operation &dmaInst) {
+ assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
+ if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
+ return dmaStartOp.getTagMemRefOperandIndex();
+ }
+ // First operand for a dma finish operation.
+ return 0;
+}
+
+/// Doubles the buffer of the supplied memref on the specified 'affine.for'
+/// operation by adding a leading dimension of size two to the memref.
+/// Replaces all uses of the old memref by the new one while indexing the newly
+/// added dimension by the loop IV of the specified 'affine.for' operation
+/// modulo 2. Returns false if such a replacement cannot be performed.
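+/// For example (schematic), double buffering a memref<256xf32> allocates a
+/// memref<2x256xf32> and rewrites an access %buf[%idx] inside the loop into
+/// %buf_db[(%i floordiv step) mod 2, %idx], where %i is the loop IV.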
+static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
+ auto *forBody = forOp.getBody();
+ OpBuilder bInner(forBody, forBody->begin());
+
+ // Doubles the shape with a leading dimension extent of 2.
+ auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
+ // Add the leading dimension in the shape for the double buffer.
+ ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
+ SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
+ newShape[0] = 2;
+ std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
+ auto newMemRefType =
+ MemRefType::get(newShape, oldMemRefType.getElementType(), {},
+ oldMemRefType.getMemorySpace());
+ return newMemRefType;
+ };
+
+ auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+ auto newMemRefType = doubleShape(oldMemRefType);
+
+ // The double buffer is allocated right before 'forInst'.
+ auto *forInst = forOp.getOperation();
+ OpBuilder bOuter(forInst);
+ // Put together alloc operands for any dynamic dimensions of the memref.
+ SmallVector<Value, 4> allocOperands;
+ unsigned dynamicDimCount = 0;
+ for (auto dimSize : oldMemRefType.getShape()) {
+ if (dimSize == -1)
+ allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
+ dynamicDimCount++));
+ }
+
+ // Create and place the alloc right before the 'affine.for' operation.
+ Value newMemRef =
+ bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+
+ // Create 'iv mod 2' value to index the leading dimension.
+ auto d0 = bInner.getAffineDimExpr(0);
+ int64_t step = forOp.getStep();
+ auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+ {d0.floorDiv(step) % 2});
+ auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
+ forOp.getInductionVar());
+
+ // replaceAllMemRefUsesWith will succeed unless the forOp body has
+ // non-dereferencing uses of the memref (dealloc's are fine though).
+ if (failed(replaceAllMemRefUsesWith(
+ oldMemRef, newMemRef,
+ /*extraIndices=*/{ivModTwoOp},
+ /*indexRemap=*/AffineMap(),
+ /*extraOperands=*/{},
+ /*symbolOperands=*/{},
+ /*domInstFilter=*/&*forOp.getBody()->begin()))) {
+ LLVM_DEBUG(
+ forOp.emitError("memref replacement for double buffering failed"));
+ ivModTwoOp.erase();
+ return false;
+ }
+ // Insert the dealloc op right after the for loop.
+ bOuter.setInsertionPointAfter(forInst);
+ bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);
+
+ return true;
+}
+
+/// Pipelines data transfers in each eligible 'affine.for' op of the function.
+void PipelineDataTransfer::runOnFunction() {
+ // Do a post order walk so that inner loop DMAs are processed first. This is
+ // necessary since 'affine.for' operations nested within would otherwise
+ // become invalid (erased) when the outer loop is pipelined (the pipelined one
+ // gets deleted and replaced by a prologue, a new steady-state loop and an
+ // epilogue).
+ forOps.clear();
+ getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
+ for (auto forOp : forOps)
+ runOnAffineForOp(forOp);
+}
+
+// Check if tags of the dma start op and dma wait op match.
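+// For example (schematic), "affine.dma_start %src[%i], %buf[%i], %tag[%c0],
+// %n" matches "affine.dma_wait %tag[%c0], %n" since both reference %tag[%c0].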
+static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
+ if (startOp.getTagMemRef() != waitOp.getTagMemRef())
+ return false;
+ auto startIndices = startOp.getTagIndices();
+ auto waitIndices = waitOp.getTagIndices();
+ // Both of these have the same number of indices since they correspond to the
+ // same tag memref.
+ for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+ e = startIndices.end();
+ it != e; ++it, ++wIt) {
+ // Keep it simple for now, just checking if indices match.
+ // TODO(mlir-team): this would in general need to check if there is no
+ // intervening write writing to the same tag location, i.e., memory last
+ // write/data flow analysis. This is however sufficient/powerful enough for
+ // now since the DMA generation pass or the input for it will always have
+ // start/wait with matching tags (same SSA operand indices).
+ if (*it != *wIt)
+ return false;
+ }
+ return true;
+}
+
+// Identify matching DMA start/finish operations to overlap computation with.
+static void findMatchingStartFinishInsts(
+ AffineForOp forOp,
+ SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
+
+ // Collect outgoing DMA operations - needed to check for dependences below.
+ SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
+ for (auto &op : *forOp.getBody()) {
+ auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+ if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
+ outgoingDmaOps.push_back(dmaStartOp);
+ }
+
+ SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
+ for (auto &op : *forOp.getBody()) {
+ // Collect DMA finish operations.
+ if (isa<AffineDmaWaitOp>(op)) {
+ dmaFinishInsts.push_back(&op);
+ continue;
+ }
+ auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+ if (!dmaStartOp)
+ continue;
+
+ // Only DMAs incoming into higher memory spaces are pipelined for now.
+ // TODO(bondhugula): handle outgoing DMA pipelining.
+ if (!dmaStartOp.isDestMemorySpaceFaster())
+ continue;
+
+ // Check for dependence with outgoing DMAs. Doing this conservatively.
+ // TODO(andydavis,bondhugula): use the dependence analysis to check for
+ // dependences between an incoming and outgoing DMA in the same iteration.
+ auto it = outgoingDmaOps.begin();
+ for (; it != outgoingDmaOps.end(); ++it) {
+ if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
+ break;
+ }
+ if (it != outgoingDmaOps.end())
+ continue;
+
+    // We only double buffer if the buffer is not live out of the loop.
+ auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
+ bool escapingUses = false;
+ for (auto *user : memref->getUsers()) {
+ // We can double buffer regardless of dealloc's outside the loop.
+ if (isa<DeallocOp>(user))
+ continue;
+ if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "can't pipeline: buffer is live out of loop\n";);
+ escapingUses = true;
+ break;
+ }
+ }
+ if (!escapingUses)
+ dmaStartInsts.push_back(&op);
+ }
+
+ // For each start operation, we look for a matching finish operation.
+ for (auto *dmaStartInst : dmaStartInsts) {
+ for (auto *dmaFinishInst : dmaFinishInsts) {
+ if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
+ cast<AffineDmaWaitOp>(dmaFinishInst))) {
+ startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
+ break;
+ }
+ }
+ }
+}
+
+/// Overlap DMA transfers with computation in this loop. If successful,
+/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
+/// are
+/// inserted right before where it was.
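+///
+/// Schematically, the DMA start ops (and the affine.apply ops feeding their
+/// operands) are assigned a shift of 0 while the DMA waits and the compute
+/// ops are assigned a shift of 1, so that after body skewing the transfer for
+/// iteration i overlaps with the wait/compute of iteration i - 1, with a
+/// prologue issuing the first transfer and an epilogue draining the last
+/// wait/compute.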
+void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
+ auto mayBeConstTripCount = getConstantTripCount(forOp);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(
+ forOp.emitRemark("won't pipeline due to unknown trip count loop"));
+ return;
+ }
+
+ SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
+ findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+ if (startWaitPairs.empty()) {
+ LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
+ return;
+ }
+
+ // Double the buffers for the higher memory space memref's.
+ // Identify memref's to replace by scanning through all DMA start
+ // operations. A DMA start operation has two memref's - the one from the
+  // higher level of the memory hierarchy is the one to double buffer.
+ // TODO(bondhugula): check whether double-buffering is even necessary.
+ // TODO(bondhugula): make this work with different layouts: assuming here that
+ // the dimension we are adding here for the double buffering is the outermost
+ // dimension.
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartInst = pair.first;
+ Value oldMemRef = dmaStartInst->getOperand(
+ cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
+ if (!doubleBuffer(oldMemRef, forOp)) {
+ // Normally, double buffering should not fail because we already checked
+ // that there are no uses outside.
+ LLVM_DEBUG(llvm::dbgs()
+                 << "double buffering failed for " << *dmaStartInst << "\n";);
+ // IR still valid and semantically correct.
+ return;
+ }
+ // If the old memref has no more uses, remove its 'dead' alloc if it was
+ // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
+ // operation could have been used on it if it was dynamically shaped in
+ // order to create the double buffer above.)
+ // '-canonicalize' does this in a more general way, but we'll anyway do the
+    // simple/common case so that the output / test cases look clear.
+ if (auto *allocInst = oldMemRef->getDefiningOp()) {
+ if (oldMemRef->use_empty()) {
+ allocInst->erase();
+ } else if (oldMemRef->hasOneUse()) {
+ if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
+ dealloc.erase();
+ allocInst->erase();
+ }
+ }
+ }
+ }
+
+ // Double the buffers for tag memrefs.
+ for (auto &pair : startWaitPairs) {
+ auto *dmaFinishInst = pair.second;
+ Value oldTagMemRef =
+ dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+ if (!doubleBuffer(oldTagMemRef, forOp)) {
+ LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+ return;
+ }
+ // If the old tag has no uses or a single dealloc use, remove it.
+ // (canonicalization handles more complex cases).
+ if (auto *tagAllocInst = oldTagMemRef->getDefiningOp()) {
+ if (oldTagMemRef->use_empty()) {
+ tagAllocInst->erase();
+ } else if (oldTagMemRef->hasOneUse()) {
+ if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef->user_begin())) {
+ dealloc.erase();
+ tagAllocInst->erase();
+ }
+ }
+ }
+ }
+
+ // Double buffering would have invalidated all the old DMA start/wait insts.
+ startWaitPairs.clear();
+ findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+  // Store the shift for each operation, for later lookup when handling
+  // AffineApplyOp's.
+ DenseMap<Operation *, unsigned> instShiftMap;
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartInst = pair.first;
+ assert(isa<AffineDmaStartOp>(dmaStartInst));
+ instShiftMap[dmaStartInst] = 0;
+ // Set shifts for DMA start op's affine operand computation slices to 0.
+ SmallVector<AffineApplyOp, 4> sliceOps;
+ mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
+ if (!sliceOps.empty()) {
+ for (auto sliceOp : sliceOps) {
+ instShiftMap[sliceOp.getOperation()] = 0;
+ }
+ } else {
+ // If a slice wasn't created, the reachable affine.apply op's from its
+ // operands are the ones that go with it.
+ SmallVector<Operation *, 4> affineApplyInsts;
+ SmallVector<Value, 4> operands(dmaStartInst->getOperands());
+ getReachableAffineApplyOps(operands, affineApplyInsts);
+ for (auto *op : affineApplyInsts) {
+ instShiftMap[op] = 0;
+ }
+ }
+ }
+  // Everything else (including compute ops and dma finish) is shifted by one.
+ for (auto &op : *forOp.getBody()) {
+ if (instShiftMap.find(&op) == instShiftMap.end()) {
+ instShiftMap[&op] = 1;
+ }
+ }
+
+ // Get shifts stored in map.
+ std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
+ unsigned s = 0;
+ for (auto &op : *forOp.getBody()) {
+ assert(instShiftMap.find(&op) != instShiftMap.end());
+ shifts[s++] = instShiftMap[&op];
+
+ // Tagging operations with shifts for debugging purposes.
+ LLVM_DEBUG({
+ OpBuilder b(&op);
+ op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
+ });
+ }
+
+ if (!isInstwiseShiftValid(forOp, shifts)) {
+ // Violates dependences.
+ LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
+ return;
+ }
+
+ if (failed(instBodySkew(forOp, shifts))) {
+ LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
+ return;
+ }
+}
+
+static PassRegistration<PipelineDataTransfer> pass(
+ "affine-pipeline-data-transfer",
+ "Pipeline non-blocking data transfers between explicitly managed levels of "
+ "the memory hierarchy");
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
new file mode 100644
index 00000000000..217e06bc877
--- /dev/null
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -0,0 +1,108 @@
+//===- SimplifyAffineStructures.cpp ---------------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to simplify affine structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#define DEBUG_TYPE "simplify-affine-structure"
+
+using namespace mlir;
+
+namespace {
+
+/// Simplifies affine maps and sets appearing in the operations of the Function.
+/// This part is mainly to test the simplifyAffineExpr method. In addition,
+/// all memrefs with non-trivial layout maps are converted to ones with a
+/// trivial identity layout.
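+/// For instance (illustrative), an alloc of
+/// memref<64xf32, (d0) -> (d0 floordiv 4, d0 mod 4)> would be rewritten as an
+/// alloc of memref<16x4xf32> with an identity layout.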
+struct SimplifyAffineStructures
+ : public FunctionPass<SimplifyAffineStructures> {
+ void runOnFunction() override;
+
+ /// Utility to simplify an affine attribute and update its entry in the parent
+ /// operation if necessary.
+ template <typename AttributeT>
+ void simplifyAndUpdateAttribute(Operation *op, Identifier name,
+ AttributeT attr) {
+ auto &simplified = simplifiedAttributes[attr];
+ if (simplified == attr)
+ return;
+
+ // This is a newly encountered attribute.
+ if (!simplified) {
+ // Try to simplify the value of the attribute.
+ auto value = attr.getValue();
+ auto simplifiedValue = simplify(value);
+ if (simplifiedValue == value) {
+ simplified = attr;
+ return;
+ }
+ simplified = AttributeT::get(simplifiedValue);
+ }
+
+ // Simplification was successful, so update the attribute.
+ op->setAttr(name, simplified);
+ }
+
+  /// Performs basic integer set simplifications. Checks whether the set is
+  /// empty and, if so, replaces it with the canonical empty set.
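+  /// For example, the set (d0) : (d0 >= 0, -d0 - 1 >= 0) has no solutions and
+  /// would be replaced by the canonical empty set.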
+ IntegerSet simplify(IntegerSet set) {
+ FlatAffineConstraints fac(set);
+ if (fac.isEmpty())
+ return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
+ &getContext());
+ return set;
+ }
+
+ /// Performs basic affine map simplifications.
+ AffineMap simplify(AffineMap map) {
+ MutableAffineMap mMap(map);
+ mMap.simplify();
+ return mMap.getAffineMap();
+ }
+
+ DenseMap<Attribute, Attribute> simplifiedAttributes;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimplifyAffineStructuresPass() {
+ return std::make_unique<SimplifyAffineStructures>();
+}
+
+void SimplifyAffineStructures::runOnFunction() {
+ auto func = getFunction();
+ simplifiedAttributes.clear();
+ func.walk([&](Operation *opInst) {
+ for (auto attr : opInst->getAttrs()) {
+ if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
+ simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
+ else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
+ simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
+ }
+ });
+
+  // Turn memrefs' non-identity layout maps into identity ones. Collect
+ // alloc ops first and then process since normalizeMemRef replaces/erases ops
+ // during memref rewriting.
+ SmallVector<AllocOp, 4> allocOps;
+ func.walk([&](AllocOp op) { allocOps.push_back(op); });
+ for (auto allocOp : allocOps) {
+ normalizeMemRef(allocOp);
+ }
+}
+
+static PassRegistration<SimplifyAffineStructures>
+ pass("simplify-affine-structures",
+ "Simplify affine expressions in maps/sets and normalize memrefs");
diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp
new file mode 100644
index 00000000000..cdfc7fd7e41
--- /dev/null
+++ b/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -0,0 +1,37 @@
+//===- StripDebugInfo.cpp - Pass to strip debug information ---------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
+ void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void StripDebugInfo::runOnFunction() {
+ FuncOp func = getFunction();
+ auto unknownLoc = UnknownLoc::get(&getContext());
+
+ // Strip the debug info from the function and its operations.
+ func.setLoc(unknownLoc);
+ func.walk([&](Operation *op) { op->setLoc(unknownLoc); });
+}
+
+/// Creates a pass to strip debug information from a function.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createStripDebugInfoPass() {
+ return std::make_unique<StripDebugInfo>();
+}
+
+static PassRegistration<StripDebugInfo>
+ pass("strip-debuginfo", "Strip debug info from functions and operations");
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
new file mode 100644
index 00000000000..4e1dc5e4b4e
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_llvm_library(MLIRTransformUtils
+ FoldUtils.cpp
+ GreedyPatternRewriteDriver.cpp
+ InliningUtils.cpp
+ LoopFusionUtils.cpp
+ LoopUtils.cpp
+ RegionUtils.cpp
+ Utils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+ )
+
+add_dependencies(MLIRTransformUtils MLIRStandardOpsIncGen)
+target_link_libraries(MLIRTransformUtils
+ MLIRAffineOps
+ MLIRAnalysis
+ MLIRLoopOps
+ MLIRPass
+ MLIRStandardOps
+ )
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
new file mode 100644
index 00000000000..719c6fac731
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -0,0 +1,246 @@
+//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various operation fold utilities. These utilities are
+// intended to be used by passes to unify and simplify their logic.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+/// Given an operation, find the parent region that folded constants should be
+/// inserted into.
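+/// For example, a constant folded from within the body of an 'affine.for'
+/// nested in a function is typically materialized in the function's entry
+/// block, since FuncOp is isolated from above.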
+static Region *getInsertionRegion(
+ DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
+ Operation *op) {
+ while (Region *region = op->getParentRegion()) {
+ // Insert in this region for any of the following scenarios:
+ // * The parent is unregistered, or is known to be isolated from above.
+ // * The parent is a top-level operation.
+ auto *parentOp = region->getParentOp();
+ if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
+ !parentOp->getBlock())
+ return region;
+
+ // Otherwise, check if this region is a desired insertion region.
+ auto *interface = interfaces.getInterfaceFor(parentOp);
+    if (interface && interface->shouldMaterializeInto(region))
+ return region;
+
+ // Traverse up the parent looking for an insertion region.
+ op = parentOp;
+ }
+ llvm_unreachable("expected valid insertion region");
+}
+
+/// A utility function used to materialize a constant for a given attribute and
+/// type. On success, a valid constant value is returned. Otherwise, null is
+/// returned.
+static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ auto insertPt = builder.getInsertionPoint();
+ (void)insertPt;
+
+ // Ask the dialect to materialize a constant operation for this value.
+ if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
+ assert(insertPt == builder.getInsertionPoint());
+ assert(matchPattern(constOp, m_Constant(&value)));
+ return constOp;
+ }
+
+ // If the dialect is unable to materialize a constant, check to see if the
+ // standard constant can be used.
+ if (ConstantOp::isBuildableWith(value, type))
+ return builder.create<ConstantOp>(loc, type, value);
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// OperationFolder
+//===----------------------------------------------------------------------===//
+
+LogicalResult OperationFolder::tryToFold(
+ Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
+ function_ref<void(Operation *)> preReplaceAction) {
+ // If this is a unique'd constant, return failure as we know that it has
+ // already been folded.
+ if (referencedDialects.count(op))
+ return failure();
+
+ // Try to fold the operation.
+ SmallVector<Value, 8> results;
+ if (failed(tryToFold(op, results, processGeneratedConstants)))
+ return failure();
+
+ // Constant folding succeeded. We will start replacing this op's uses and
+ // eventually erase this op. Invoke the callback provided by the caller to
+ // perform any pre-replacement action.
+ if (preReplaceAction)
+ preReplaceAction(op);
+
+ // Check to see if the operation was just updated in place.
+ if (results.empty())
+ return success();
+
+ // Otherwise, replace all of the result values and erase the operation.
+ for (unsigned i = 0, e = results.size(); i != e; ++i)
+ op->getResult(i)->replaceAllUsesWith(results[i]);
+ op->erase();
+ return success();
+}
+
+/// Notifies that the given constant `op` should be removed from this
+/// OperationFolder's internal bookkeeping.
+void OperationFolder::notifyRemoval(Operation *op) {
+ // Check to see if this operation is uniqued within the folder.
+ auto it = referencedDialects.find(op);
+ if (it == referencedDialects.end())
+ return;
+
+ // Get the constant value for this operation, this is the value that was used
+ // to unique the operation internally.
+ Attribute constValue;
+ matchPattern(op, m_Constant(&constValue));
+ assert(constValue);
+
+ // Get the constant map that this operation was uniqued in.
+ auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
+
+ // Erase all of the references to this operation.
+ auto type = op->getResult(0)->getType();
+ for (auto *dialect : it->second)
+ uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
+ referencedDialects.erase(it);
+}
+
+/// Tries to perform folding on the given `op`. If successful, populates
+/// `results` with the results of the folding.
+LogicalResult OperationFolder::tryToFold(
+ Operation *op, SmallVectorImpl<Value> &results,
+ function_ref<void(Operation *)> processGeneratedConstants) {
+ SmallVector<Attribute, 8> operandConstants;
+ SmallVector<OpFoldResult, 8> foldResults;
+
+  // Check to see if any operands to the operation are constant and whether
+ // the operation knows how to constant fold itself.
+ operandConstants.assign(op->getNumOperands(), Attribute());
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
+
+ // If this is a commutative binary operation with a constant on the left
+ // side move it to the right side.
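+  // For example (illustrative), "addi %c1, %x" has its operands swapped here
+  // so that the constant is uniformly on the right, i.e. "addi %x, %c1".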
+ if (operandConstants.size() == 2 && operandConstants[0] &&
+ !operandConstants[1] && op->isCommutative()) {
+ std::swap(op->getOpOperand(0), op->getOpOperand(1));
+ std::swap(operandConstants[0], operandConstants[1]);
+ }
+
+ // Attempt to constant fold the operation.
+ if (failed(op->fold(operandConstants, foldResults)))
+ return failure();
+
+ // Check to see if the operation was just updated in place.
+ if (foldResults.empty())
+ return success();
+ assert(foldResults.size() == op->getNumResults());
+
+ // Create a builder to insert new operations into the entry block of the
+ // insertion region.
+ auto *insertRegion = getInsertionRegion(interfaces, op);
+ auto &entry = insertRegion->front();
+ OpBuilder builder(&entry, entry.begin());
+
+ // Get the constant map for the insertion region of this operation.
+ auto &uniquedConstants = foldScopes[insertRegion];
+
+ // Create the result constants and replace the results.
+ auto *dialect = op->getDialect();
+ for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+ assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
+
+ // Check if the result was an SSA value.
+ if (auto repl = foldResults[i].dyn_cast<Value>()) {
+ results.emplace_back(repl);
+ continue;
+ }
+
+ // Check to see if there is a canonicalized version of this constant.
+ auto res = op->getResult(i);
+ Attribute attrRepl = foldResults[i].get<Attribute>();
+ if (auto *constOp =
+ tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
+ res->getType(), op->getLoc())) {
+ results.push_back(constOp->getResult(0));
+ continue;
+ }
+    // If materialization fails, clean up any operations generated for the
+ // previous results and return failure.
+ for (Operation &op : llvm::make_early_inc_range(
+ llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
+ notifyRemoval(&op);
+ op.erase();
+ }
+ return failure();
+ }
+
+ // Process any newly generated operations.
+ if (processGeneratedConstants) {
+ for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
+ processGeneratedConstants(&*i);
+ }
+
+ return success();
+}
+
+/// Try to get or create a new constant entry. On success this returns the
+/// constant operation value, nullptr otherwise.
+Operation *OperationFolder::tryGetOrCreateConstant(
+ ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
+ Attribute value, Type type, Location loc) {
+ // Check if an existing mapping already exists.
+ auto constKey = std::make_tuple(dialect, value, type);
+ auto *&constInst = uniquedConstants[constKey];
+ if (constInst)
+ return constInst;
+
+ // If one doesn't exist, try to materialize one.
+ if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+ return nullptr;
+
+ // Check to see if the generated constant is in the expected dialect.
+ auto *newDialect = constInst->getDialect();
+ if (newDialect == dialect) {
+ referencedDialects[constInst].push_back(dialect);
+ return constInst;
+ }
+
+ // If it isn't, then we also need to make sure that the mapping for the new
+ // dialect is valid.
+ auto newKey = std::make_tuple(newDialect, value, type);
+
+ // If an existing operation in the new dialect already exists, delete the
+ // materialized operation in favor of the existing one.
+ if (auto *existingOp = uniquedConstants.lookup(newKey)) {
+ constInst->erase();
+ referencedDialects[existingOp].push_back(dialect);
+ return constInst = existingOp;
+ }
+
+ // Otherwise, update the new dialect to the materialized operation.
+ referencedDialects[constInst].assign({dialect, newDialect});
+ auto newIt = uniquedConstants.insert({newKey, constInst});
+ return newIt.first->second;
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
new file mode 100644
index 00000000000..1eb9c57639a
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -0,0 +1,247 @@
+//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements mlir::applyPatternsGreedily.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "pattern-matcher"
+
+static llvm::cl::opt<unsigned> maxPatternMatchIterations(
+ "mlir-max-pattern-match-iterations",
+ llvm::cl::desc("Max number of iterations scanning for pattern match"),
+ llvm::cl::init(10));
+
+namespace {
+
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver : public PatternRewriter {
+public:
+ explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+ const OwningRewritePatternList &patterns)
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+ worklist.reserve(64);
+ }
+
+ /// Perform the rewrites. Return true if the rewrite converges in
+ /// `maxIterations`.
+ bool simplify(MutableArrayRef<Region> regions, int maxIterations);
+
+ void addToWorklist(Operation *op) {
+ // Check to see if the worklist already contains this op.
+ if (worklistMap.count(op))
+ return;
+
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+ }
+
+ Operation *popFromWorklist() {
+ auto *op = worklist.back();
+ worklist.pop_back();
+
+ // This operation is no longer in the worklist, keep worklistMap up to date.
+ if (op)
+ worklistMap.erase(op);
+ return op;
+ }
+
+ /// If the specified operation is in the worklist, remove it. If not, this is
+ /// a no-op.
+ void removeFromWorklist(Operation *op) {
+ auto it = worklistMap.find(op);
+ if (it != worklistMap.end()) {
+ assert(worklist[it->second] == op && "malformed worklist data structure");
+ worklist[it->second] = nullptr;
+ worklistMap.erase(it);
+ }
+ }
+
+ // These are hooks implemented for PatternRewriter.
+protected:
+ // Implement the hook for inserting operations, and make sure that newly
+ // inserted ops are added to the worklist for processing.
+ Operation *insert(Operation *op) override {
+ addToWorklist(op);
+ return OpBuilder::insert(op);
+ }
+
+ // If an operation is about to be removed, make sure it is not in our
+ // worklist anymore because we'd get dangling references to it.
+ void notifyOperationRemoved(Operation *op) override {
+ addToWorklist(op->getOperands());
+ op->walk([this](Operation *operation) {
+ removeFromWorklist(operation);
+ folder.notifyRemoval(operation);
+ });
+ }
+
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyRootReplaced(Operation *op) override {
+ for (auto result : op->getResults())
+ for (auto *user : result->getUsers())
+ addToWorklist(user);
+ }
+
+private:
+ // Look over the provided operands for any defining operations that should
+ // be re-added to the worklist. This function should be called when an
+ // operation is modified or removed, as it may trigger further
+ // simplifications.
+ template <typename Operands> void addToWorklist(Operands &&operands) {
+ for (Value operand : operands) {
+ // If the use count of this operand is now < 2, we re-add the defining
+ // operation to the worklist.
+ // TODO(riverriddle) This is based on the fact that zero use operations
+ // may be deleted, and that single use values often have more
+ // canonicalization opportunities.
+ if (!operand->use_empty() && !operand->hasOneUse())
+ continue;
+ if (auto *defInst = operand->getDefiningOp())
+ addToWorklist(defInst);
+ }
+ }
+
+ /// The low-level pattern matcher.
+ RewritePatternMatcher matcher;
+
+ /// The worklist for this transformation keeps track of the operations that
+ /// need to be revisited, plus their index in the worklist. This allows us to
+ /// efficiently remove operations from the worklist when they are erased, even
+ /// if they aren't the root of a pattern.
+ std::vector<Operation *> worklist;
+ DenseMap<Operation *, unsigned> worklistMap;
+
+ /// Non-pattern based folder for operations.
+ OperationFolder folder;
+};
+} // end anonymous namespace
+
+/// Perform the rewrites.
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
+ int maxIterations) {
+ // Add the given operation to the worklist.
+ auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
+ bool changed = false;
+ int i = 0;
+ do {
+ // Add all nested operations to the worklist.
+ for (auto &region : regions)
+ region.walk(collectOps);
+
+ // These are scratch vectors used in the folding loop below.
+ SmallVector<Value, 8> originalOperands, resultValues;
+
+ changed = false;
+ while (!worklist.empty()) {
+ auto *op = popFromWorklist();
+
+ // Nulls get added to the worklist when operations are removed, ignore
+ // them.
+ if (op == nullptr)
+ continue;
+
+ // If the operation has no side effects, and no users, then it is
+ // trivially dead - remove it.
+ if (op->hasNoSideEffect() && op->use_empty()) {
+ // Be careful to update bookkeeping.
+ notifyOperationRemoved(op);
+ op->erase();
+ continue;
+ }
+
+      // Collect all of the operands and result users of the given `op` into
+      // the worklist. Also remove `op` and nested ops from the worklist.
+ originalOperands.assign(op->operand_begin(), op->operand_end());
+ auto preReplaceAction = [&](Operation *op) {
+ // Add the operands to the worklist for visitation.
+ addToWorklist(originalOperands);
+
+ // Add all the users of the result to the worklist so we make sure
+ // to revisit them.
+ for (auto result : op->getResults())
+ for (auto *operand : result->getUsers())
+ addToWorklist(operand);
+
+ notifyOperationRemoved(op);
+ };
+
+ // Try to fold this op.
+ if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
+ changed |= true;
+ continue;
+ }
+
+ // Make sure that any new operations are inserted at this point.
+ setInsertionPoint(op);
+
+ // Try to match one of the patterns. The rewriter is automatically
+ // notified of any necessary changes, so there is nothing else to do here.
+ changed |= matcher.matchAndRewrite(op, *this);
+ }
+
+ // After applying patterns, make sure that the CFG of each of the regions is
+ // kept up to date.
+ changed |= succeeded(simplifyRegions(regions));
+ } while (changed && ++i < maxIterations);
+  // The rewrite has converged if the last iteration made no changes.
+ return !changed;
+}
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return true if no more patterns can be matched in
+/// the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself.
+///
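+/// A typical (illustrative) use from within a function pass, assuming
+/// 'MyPattern' is a RewritePattern defined elsewhere:
+///
+///   OwningRewritePatternList patterns;
+///   patterns.insert<MyPattern>(&getContext());
+///   applyPatternsGreedily(getFunction().getOperation(), patterns);
+///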
+bool mlir::applyPatternsGreedily(Operation *op,
+ const OwningRewritePatternList &patterns) {
+ return applyPatternsGreedily(op->getRegions(), patterns);
+}
+
+/// Rewrite the given regions, which must be isolated from above.
+bool mlir::applyPatternsGreedily(MutableArrayRef<Region> regions,
+ const OwningRewritePatternList &patterns) {
+ if (regions.empty())
+ return true;
+
+ // The top-level operation must be known to be isolated from above to
+ // prevent performing canonicalizations on operations defined at or above
+ // the region containing 'op'.
+ auto regionIsIsolated = [](Region &region) {
+ return region.getParentOp()->isKnownIsolatedFromAbove();
+ };
+ (void)regionIsIsolated;
+ assert(llvm::all_of(regions, regionIsIsolated) &&
+ "patterns can only be applied to operations IsolatedFromAbove");
+
+ // Start the pattern driver.
+ GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
+ bool converged = driver.simplify(regions, maxPatternMatchIterations);
+ LLVM_DEBUG(if (!converged) {
+ llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+ << maxPatternMatchIterations << " times";
+ });
+ return converged;
+}
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
new file mode 100644
index 00000000000..1ac286c67fb
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -0,0 +1,356 @@
+//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous inlining utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/InliningUtils.h"
+
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "inlining"
+
+using namespace mlir;
+
+/// Remap the locations of the inlined blocks to CallSiteLoc locations that
+/// reference the provided caller location.
+static void
+remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
+ Location callerLoc) {
+ DenseMap<Location, Location> mappedLocations;
+ auto remapOpLoc = [&](Operation *op) {
+ auto it = mappedLocations.find(op->getLoc());
+ if (it == mappedLocations.end()) {
+ auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
+ it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
+ }
+ op->setLoc(it->second);
+ };
+ for (auto &block : inlinedBlocks)
+ block.walk(remapOpLoc);
+}
+
+static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
+ BlockAndValueMapping &mapper) {
+ auto remapOperands = [&](Operation *op) {
+ for (auto &operand : op->getOpOperands())
+ if (auto mappedOp = mapper.lookupOrNull(operand.get()))
+ operand.set(mappedOp);
+ };
+ for (auto &block : inlinedBlocks)
+ block.walk(remapOperands);
+}
+
+//===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+bool InlinerInterface::isLegalToInline(
+ Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
+ // Regions can always be inlined into functions.
+ if (isa<FuncOp>(dest->getParentOp()))
+ return true;
+
+ auto *handler = getInterfaceFor(dest->getParentOp());
+ return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
+}
+
+bool InlinerInterface::isLegalToInline(
+ Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
+ auto *handler = getInterfaceFor(op);
+ return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
+}
+
+bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
+ auto *handler = getInterfaceFor(op);
+ return handler ? handler->shouldAnalyzeRecursively(op) : true;
+}
+
+/// Handle the given inlined terminator by replacing it with a new operation
+/// as necessary.
+void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
+ auto *handler = getInterfaceFor(op);
+ assert(handler && "expected valid dialect handler");
+ handler->handleTerminator(op, newDest);
+}
+
+/// Handle the given inlined terminator by replacing it with a new operation
+/// as necessary.
+void InlinerInterface::handleTerminator(Operation *op,
+ ArrayRef<Value> valuesToRepl) const {
+ auto *handler = getInterfaceFor(op);
+ assert(handler && "expected valid dialect handler");
+ handler->handleTerminator(op, valuesToRepl);
+}
+
+/// Utility to check that all of the operations within 'src' can be inlined.
+static bool isLegalToInline(InlinerInterface &interface, Region *src,
+ Region *insertRegion,
+ BlockAndValueMapping &valueMapping) {
+ for (auto &block : *src) {
+ for (auto &op : block) {
+ // Check this operation.
+ if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "* Illegal to inline because of op: ";
+ op.dump();
+ });
+ return false;
+ }
+ // Check any nested regions.
+ if (interface.shouldAnalyzeRecursively(&op) &&
+ llvm::any_of(op.getRegions(), [&](Region &region) {
+ return !isLegalToInline(interface, &region, insertRegion,
+ valueMapping);
+ }))
+ return false;
+ }
+ }
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Inline Methods
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint,
+ BlockAndValueMapping &mapper,
+ ArrayRef<Value> resultsToReplace,
+ Optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ // We expect the region to have at least one block.
+ if (src->empty())
+ return failure();
+
+ // Check that all of the region arguments have been mapped.
+ auto *srcEntryBlock = &src->front();
+ if (llvm::any_of(srcEntryBlock->getArguments(),
+ [&](BlockArgument arg) { return !mapper.contains(arg); }))
+ return failure();
+
+ // The insertion point must be within a block.
+ Block *insertBlock = inlinePoint->getBlock();
+ if (!insertBlock)
+ return failure();
+ Region *insertRegion = insertBlock->getParent();
+
+ // Check that the operations within the source region are valid to inline.
+ if (!interface.isLegalToInline(insertRegion, src, mapper) ||
+ !isLegalToInline(interface, src, insertRegion, mapper))
+ return failure();
+
+ // Split the insertion block.
+ Block *postInsertBlock =
+ insertBlock->splitBlock(++inlinePoint->getIterator());
+
+ // Check to see if the region is being cloned, or moved inline. In either
+ // case, move the new blocks after the 'insertBlock' to improve IR
+ // readability.
+ if (shouldCloneInlinedRegion)
+ src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
+ else
+ insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
+ src->getBlocks(), src->begin(),
+ src->end());
+
+ // Get the range of newly inserted blocks.
+ auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
+ postInsertBlock->getIterator());
+ Block *firstNewBlock = &*newBlocks.begin();
+
+ // Remap the locations of the inlined operations if a valid source location
+ // was provided.
+ if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
+ remapInlinedLocations(newBlocks, *inlineLoc);
+
+ // If the blocks were moved in-place, make sure to remap any necessary
+ // operands.
+ if (!shouldCloneInlinedRegion)
+ remapInlinedOperands(newBlocks, mapper);
+
+ // Process the newly inlined blocks.
+ interface.processInlinedBlocks(newBlocks);
+
+ // Handle the case where only a single block was inlined.
+ if (std::next(newBlocks.begin()) == newBlocks.end()) {
+ // Have the interface handle the terminator of this block.
+ auto *firstBlockTerminator = firstNewBlock->getTerminator();
+ interface.handleTerminator(firstBlockTerminator, resultsToReplace);
+ firstBlockTerminator->erase();
+
+ // Merge the post insert block into the cloned entry block.
+ firstNewBlock->getOperations().splice(firstNewBlock->end(),
+ postInsertBlock->getOperations());
+ postInsertBlock->erase();
+ } else {
+ // Otherwise, there were multiple blocks inlined. Add arguments to the post
+ // insertion block to represent the results to replace.
+ for (Value resultToRepl : resultsToReplace) {
+ resultToRepl->replaceAllUsesWith(
+ postInsertBlock->addArgument(resultToRepl->getType()));
+ }
+
+ /// Handle the terminators for each of the new blocks.
+ for (auto &newBlock : newBlocks)
+ interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
+ }
+
+ // Splice the instructions of the inlined entry block into the insert block.
+ insertBlock->getOperations().splice(insertBlock->end(),
+ firstNewBlock->getOperations());
+ firstNewBlock->erase();
+ return success();
+}
+
+/// This function is an overload of the above 'inlineRegion' that allows for
+/// providing the set of operands ('inlinedOperands') that should be used
+/// in favor of the region arguments when inlining.
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint,
+ ArrayRef<Value> inlinedOperands,
+ ArrayRef<Value> resultsToReplace,
+ Optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ // We expect the region to have at least one block.
+ if (src->empty())
+ return failure();
+
+ auto *entryBlock = &src->front();
+ if (inlinedOperands.size() != entryBlock->getNumArguments())
+ return failure();
+
+ // Map the provided call operands to the arguments of the region.
+ BlockAndValueMapping mapper;
+ for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
+ // Verify that the types of the provided values match the function argument
+ // types.
+ BlockArgument regionArg = entryBlock->getArgument(i);
+ if (inlinedOperands[i]->getType() != regionArg->getType())
+ return failure();
+ mapper.map(regionArg, inlinedOperands[i]);
+ }
+
+ // Call into the main region inliner function.
+ return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
+ inlineLoc, shouldCloneInlinedRegion);
+}
+
+/// Utility function used to generate a cast operation from the given interface,
+/// or return nullptr if a cast could not be generated.
+static Value materializeConversion(const DialectInlinerInterface *interface,
+ SmallVectorImpl<Operation *> &castOps,
+ OpBuilder &castBuilder, Value arg, Type type,
+ Location conversionLoc) {
+ if (!interface)
+ return nullptr;
+
+ // Check to see if the interface for the call can materialize a conversion.
+ Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
+ type, conversionLoc);
+ if (!castOp)
+ return nullptr;
+ castOps.push_back(castOp);
+
+ // Ensure that the generated cast is correct.
+ assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
+ castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
+ return castOp->getResult(0);
+}
+
+/// This function inlines a given region, 'src', of a callable operation,
+/// 'callable', into the location defined by the given call operation. This
+/// function returns failure if inlining is not possible, success otherwise. On
+/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
+/// corresponds to whether the source region should be cloned into the 'call' or
+/// spliced directly.
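+///
+/// As an illustration (the callee, value names and types below are made up), a
+/// successful inlineCall of @callee's region into its call site:
+///
+/// ```mlir
+///   func @callee(%arg0: f32) -> f32 {
+///     %0 = addf %arg0, %arg0 : f32
+///     return %0 : f32
+///   }
+///   ...
+///   %r = call @callee(%x) : (f32) -> f32
+/// ```
+///
+/// schematically leaves the inlined body in place of the call, with all uses
+/// of %r rewired to the inlined addf result; erasing the now-dead call is left
+/// to the caller of this function:
+///
+/// ```mlir
+///   %1 = addf %x, %x : f32
+///   // uses of %r now use %1
+/// ```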
+LogicalResult mlir::inlineCall(InlinerInterface &interface,
+ CallOpInterface call,
+ CallableOpInterface callable, Region *src,
+ bool shouldCloneInlinedRegion) {
+ // We expect the region to have at least one block.
+ if (src->empty())
+ return failure();
+ auto *entryBlock = &src->front();
+ ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
+
+  // Make sure that the number of arguments and results match up between the
+  // call and the region.
+ SmallVector<Value, 8> callOperands(call.getArgOperands());
+ SmallVector<Value, 8> callResults(call.getOperation()->getResults());
+ if (callOperands.size() != entryBlock->getNumArguments() ||
+ callResults.size() != callableResultTypes.size())
+ return failure();
+
+  // A set of cast operations generated to match up the signature of the region
+ // with the signature of the call.
+ SmallVector<Operation *, 4> castOps;
+ castOps.reserve(callOperands.size() + callResults.size());
+
+ // Functor used to cleanup generated state on failure.
+ auto cleanupState = [&] {
+ for (auto *op : castOps) {
+ op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
+ op->erase();
+ }
+ return failure();
+ };
+
+ // Builder used for any conversion operations that need to be materialized.
+ OpBuilder castBuilder(call);
+ Location castLoc = call.getLoc();
+ auto *callInterface = interface.getInterfaceFor(call.getDialect());
+
+ // Map the provided call operands to the arguments of the region.
+ BlockAndValueMapping mapper;
+ for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
+ BlockArgument regionArg = entryBlock->getArgument(i);
+ Value operand = callOperands[i];
+
+ // If the call operand doesn't match the expected region argument, try to
+ // generate a cast.
+ Type regionArgType = regionArg->getType();
+ if (operand->getType() != regionArgType) {
+ if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
+ operand, regionArgType, castLoc)))
+ return cleanupState();
+ }
+ mapper.map(regionArg, operand);
+ }
+
+  // Ensure that the resultant values of the call match the callable.
+ castBuilder.setInsertionPointAfter(call);
+ for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
+ Value callResult = callResults[i];
+ if (callResult->getType() == callableResultTypes[i])
+ continue;
+
+ // Generate a conversion that will produce the original type, so that the IR
+ // is still valid after the original call gets replaced.
+ Value castResult =
+ materializeConversion(callInterface, castOps, castBuilder, callResult,
+ callResult->getType(), castLoc);
+ if (!castResult)
+ return cleanupState();
+ callResult->replaceAllUsesWith(castResult);
+ castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
+ }
+
+ // Attempt to inline the call.
+ if (failed(inlineRegion(interface, src, call, mapper, callResults,
+ call.getLoc(), shouldCloneInlinedRegion)))
+ return cleanupState();
+ return success();
+}
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
new file mode 100644
index 00000000000..b0d9fdf5fd8
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -0,0 +1,480 @@
+//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion transformation utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopFusionUtils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "loop-fusion-utils"
+
+using namespace mlir;
+
+// Gathers all load and store memref accesses in 'opA' into 'values', where
+// 'values[memref] == true' for each store operation.
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+ DenseMap<Value, bool> &values) {
+ opA->walk([&](Operation *op) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+ if (values.count(loadOp.getMemRef()) == 0)
+ values[loadOp.getMemRef()] = false;
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ values[storeOp.getMemRef()] = true;
+ }
+ });
+}
+
+// Returns true if 'op' is a load or store operation which accesses a memref
+// recorded in 'values', where at least one of the accesses to that memref
+// (either the recorded one or 'op' itself) is a store. Returns false otherwise.
+static bool isDependentLoadOrStoreOp(Operation *op,
+ DenseMap<Value, bool> &values) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+ return values.count(loadOp.getMemRef()) > 0 &&
+ values[loadOp.getMemRef()] == true;
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ return values.count(storeOp.getMemRef()) > 0;
+ }
+ return false;
+}
+
+// Returns the first operation in range ('opA', 'opB') which has a data
+// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
+static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
+ // Record memref values from all loads/store in loop nest rooted at 'opA'.
+ // Map from memref value to bool which is true if store, false otherwise.
+ DenseMap<Value, bool> values;
+ getLoadAndStoreMemRefAccesses(opA, values);
+
+ // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
+ // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
+ // and at least one of the accesses is a store).
+ Operation *firstDepOp = nullptr;
+ for (Block::iterator it = std::next(Block::iterator(opA));
+ it != Block::iterator(opB); ++it) {
+ Operation *opX = &(*it);
+ opX->walk([&](Operation *op) {
+ if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
+ firstDepOp = opX;
+ });
+ if (firstDepOp)
+ break;
+ }
+ return firstDepOp;
+}
+
+// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
+// exists a data dependence from 'opX' to 'opB'.
+// Returns 'nullptr' if no dependence exists.
+static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
+ // Record memref values from all loads/store in loop nest rooted at 'opB'.
+ // Map from memref value to bool which is true if store, false otherwise.
+ DenseMap<Value, bool> values;
+ getLoadAndStoreMemRefAccesses(opB, values);
+
+ // For each 'opX' in block in range ('opA', 'opB') in reverse order,
+ // check if there is a data dependence from 'opX' to 'opB':
+ // *) 'opX' and 'opB' access the same memref and at least one of the accesses
+ // is a store.
+ // *) 'opX' produces an SSA Value which is used by 'opB'.
+ Operation *lastDepOp = nullptr;
+ for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
+ it != Block::reverse_iterator(opA); ++it) {
+ Operation *opX = &(*it);
+ opX->walk([&](Operation *op) {
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
+ if (isDependentLoadOrStoreOp(op, values)) {
+ lastDepOp = opX;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ }
+ for (auto value : op->getResults()) {
+ for (auto user : value->getUsers()) {
+ SmallVector<AffineForOp, 4> loops;
+ // Check if any loop in loop nest surrounding 'user' is 'opB'.
+ getLoopIVs(*user, &loops);
+ if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
+ lastDepOp = opX;
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (lastDepOp)
+ break;
+ }
+ return lastDepOp;
+}
+
+// Computes and returns an insertion point operation, before which the fused
+// <srcForOp, dstForOp> loop nest can be inserted while preserving
+// dependences. Returns nullptr if no such insertion point is found.
+static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
+ AffineForOp dstForOp) {
+ bool isSrcForOpBeforeDstForOp =
+ srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
+ auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
+ auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
+
+ auto *firstDepOpA =
+ getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+ auto *lastDepOpB =
+ getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+ // Block:
+ // ...
+ // |-- opA
+ // | ...
+ // | lastDepOpB --|
+ // | ... |
+ // |-> firstDepOpA |
+ // ... |
+ // opB <---------
+ //
+ // Valid insertion point range: (lastDepOpB, firstDepOpA)
+ //
+ if (firstDepOpA != nullptr) {
+ if (lastDepOpB != nullptr) {
+ if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
+ // No valid insertion point exists which preserves dependences.
+ return nullptr;
+ }
+ // Return insertion point in valid range closest to 'opB'.
+ // TODO(andydavis) Consider other insertion points in valid range.
+ return firstDepOpA;
+ }
+ // No dependences from 'opA' to operation in range ('opA', 'opB'), return
+ // 'opB' insertion point.
+ return forOpB.getOperation();
+}
+
+// Gathers all load and store ops in loop nest rooted at 'forOp' into
+// 'loadAndStoreOps'.
+static bool
+gatherLoadsAndStores(AffineForOp forOp,
+ SmallVectorImpl<Operation *> &loadAndStoreOps) {
+ bool hasIfOp = false;
+ forOp.walk([&](Operation *op) {
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ loadAndStoreOps.push_back(op);
+ else if (isa<AffineIfOp>(op))
+ hasIfOp = true;
+ });
+ return !hasIfOp;
+}
+
+// TODO(andydavis) Prevent fusion of loop nests with side-effecting operations.
+FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
+ unsigned dstLoopDepth,
+ ComputationSliceState *srcSlice) {
+ // Return 'failure' if 'dstLoopDepth == 0'.
+ if (dstLoopDepth == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+ return FusionResult::FailPrecondition;
+ }
+ // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
+ auto *block = srcForOp.getOperation()->getBlock();
+ if (block != dstForOp.getOperation()->getBlock()) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Return 'failure' if no valid insertion point for fused loop nest in 'block'
+ // exists which would preserve dependences.
+ if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
+ return FusionResult::FailBlockDependence;
+ }
+
+ // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
+ bool isSrcForOpBeforeDstForOp =
+ srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
+ // 'forOpA' executes before 'forOpB' in 'block'.
+ auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
+ auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
+
+  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
+  // 'block'.
+  SmallVector<Operation *, 4> opsA;
+  if (!gatherLoadsAndStores(forOpA, opsA)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n");
+ return FusionResult::FailPrecondition;
+ }
+
+  // Gather all loads and stores from 'forOpB', which succeeds 'forOpA' in
+  // 'block'.
+  SmallVector<Operation *, 4> opsB;
+  if (!gatherLoadsAndStores(forOpB, opsB)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
+ unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
+ *srcForOp.getOperation(), *dstForOp.getOperation());
+
+ // Compute union of computation slices computed between all pairs of ops
+ // from 'forOpA' and 'forOpB'.
+ if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
+ isSrcForOpBeforeDstForOp, srcSlice))) {
+ LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ return FusionResult::FailPrecondition;
+ }
+
+ return FusionResult::Success;
+}
+
+/// Collect loop nest statistics (e.g. loop trip count and operation count)
+/// in 'stats' for the loop nest rooted at 'forOp'. Returns true on success,
+/// false otherwise.
+bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
+ auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
+ auto *childForOp = forOp.getOperation();
+ auto *parentForOp = forOp.getParentOp();
+ if (!llvm::isa<FuncOp>(parentForOp)) {
+ if (!isa<AffineForOp>(parentForOp)) {
+ LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+ return WalkResult::interrupt();
+ }
+ // Add mapping to 'forOp' from its parent AffineForOp.
+ stats->loopMap[parentForOp].push_back(forOp);
+ }
+
+    // Record the number of non-loop, non-if operations in the body of 'forOp'.
+ unsigned count = 0;
+ stats->opCountMap[childForOp] = 0;
+ for (auto &op : *forOp.getBody()) {
+ if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+ ++count;
+ }
+ stats->opCountMap[childForOp] = count;
+
+    // Record the trip count for 'forOp'; interrupt the walk if it is not
+    // constant.
+ Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+ if (!maybeConstTripCount.hasValue()) {
+ // Currently only constant trip count loop nests are supported.
+ LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+ return WalkResult::interrupt();
+ }
+
+ stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+}
+
+// Computes the total cost of the loop nest rooted at 'forOp'.
+// Currently, the total cost is computed by counting the total operation
+// instance count (i.e. total number of operations in the loop body * loop
+// trip count) for the entire loop nest.
+// If 'tripCountOverrideMap' is non-null, it overrides the trip count for loops
+// specified in the map when computing the total op instance count.
+// NOTE 1: This is used to compute the cost of computation slices, which are
+// sliced along the iteration dimension, and thus reduce the trip count.
+// If 'computeCostMap' is non-null, the total op count for forOps specified
+// in the map is increased (not overridden) by adding the op count from the
+// map to the existing op count for the for loop. This is done before
+// multiplying by the loop's trip count, and is used to model the cost of
+// inserting a sliced loop nest of known cost into the loop's body.
+// NOTE 2: This is also used to compute the cost of fusing a slice of some loop
+// nest within another loop.
+static int64_t getComputeCostHelper(
+ Operation *forOp, LoopNestStats &stats,
+ llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
+ DenseMap<Operation *, int64_t> *computeCostMap) {
+  // 'opCount' is the total number of operations in one iteration of the body
+  // of 'forOp', minus the terminator op, which is a no-op.
+ int64_t opCount = stats.opCountMap[forOp] - 1;
+ if (stats.loopMap.count(forOp) > 0) {
+ for (auto childForOp : stats.loopMap[forOp]) {
+ opCount += getComputeCostHelper(childForOp.getOperation(), stats,
+ tripCountOverrideMap, computeCostMap);
+ }
+ }
+ // Add in additional op instances from slice (if specified in map).
+ if (computeCostMap != nullptr) {
+ auto it = computeCostMap->find(forOp);
+ if (it != computeCostMap->end()) {
+ opCount += it->second;
+ }
+ }
+ // Override trip count (if specified in map).
+ int64_t tripCount = stats.tripCountMap[forOp];
+ if (tripCountOverrideMap != nullptr) {
+ auto it = tripCountOverrideMap->find(forOp);
+ if (it != tripCountOverrideMap->end()) {
+ tripCount = it->second;
+ }
+ }
+ // Returns the total number of dynamic instances of operations in loop body.
+ return tripCount * opCount;
+}
+
+// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+ assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+ assert(ubMap.getNumResults() == 1 && "expected single result bound map");
+ assert(lbMap.getNumDims() == ubMap.getNumDims());
+ assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+ AffineExpr lbExpr(lbMap.getResult(0));
+ AffineExpr ubExpr(ubMap.getResult(0));
+ auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+ lbMap.getNumSymbols());
+ auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+ if (!cExpr)
+ return None;
+ return cExpr.getValue();
+}
+
+// Return the number of iterations in the given slice.
+static uint64_t getSliceIterationCount(
+ const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
+ uint64_t iterCount = 1;
+ for (const auto &count : sliceTripCountMap) {
+ iterCount *= count.second;
+ }
+ return iterCount;
+}
+
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for each
+// loop in the source loop nest associated with 'slice->ivs', using the slice
+// loop bounds in 'slice' where available. Returns true on success, false
+// otherwise (if a non-constant trip count was encountered).
+// TODO(andydavis) Make this work with non-unit step loops.
+static bool buildSliceTripCountMap(
+ ComputationSliceState *slice,
+ llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
+ unsigned numSrcLoopIVs = slice->ivs.size();
+ // Populate map from AffineForOp -> trip count
+ for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+ AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
+ auto *op = forOp.getOperation();
+ AffineMap lbMap = slice->lbs[i];
+ AffineMap ubMap = slice->ubs[i];
+ if (lbMap == AffineMap() || ubMap == AffineMap()) {
+ // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
+ if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
+ (*tripCountMap)[op] =
+ forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
+ continue;
+ }
+ Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+ if (maybeConstTripCount.hasValue()) {
+ (*tripCountMap)[op] = maybeConstTripCount.getValue();
+ continue;
+ }
+ return false;
+ }
+ Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
+ // Slice bounds are created with a constant ub - lb difference.
+ if (!tripCount.hasValue())
+ return false;
+ (*tripCountMap)[op] = tripCount.getValue();
+ }
+ return true;
+}
+
+/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
+/// Currently, the total cost is computed by counting the total operation
+/// instance count (i.e. total number of operations in the loop body * loop
+/// trip count) for the entire loop nest.
+int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
+ return getComputeCostHelper(forOp.getOperation(), stats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
+}
+
+/// Computes and returns in 'computeCost' the total compute cost of fusing the
+/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
+/// the total cost is computed by counting the total operation instance count
+/// (i.e. total number of operations in the loop body * loop trip count) for
+/// the entire loop nest.
+bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
+ AffineForOp dstForOp, LoopNestStats &dstStats,
+ ComputationSliceState *slice,
+ int64_t *computeCost) {
+ llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+ DenseMap<Operation *, int64_t> computeCostMap;
+
+ // Build trip count map for computation slice.
+ if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
+ return false;
+ // Checks whether a store to load forwarding will happen.
+ int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+ assert(sliceIterationCount > 0);
+ bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+ auto *insertPointParent = slice->insertPoint->getParentOp();
+
+ // The store and loads to this memref will disappear.
+ // TODO(andydavis) Add load coalescing to memref data flow opt pass.
+ if (storeLoadFwdGuaranteed) {
+    // Subtract from the operation count the loads/stores we expect load/store
+    // forwarding to remove.
+ unsigned storeCount = 0;
+ llvm::SmallDenseSet<Value, 4> storeMemrefs;
+ srcForOp.walk([&](Operation *op) {
+ if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ storeMemrefs.insert(storeOp.getMemRef());
+ ++storeCount;
+ }
+ });
+ // Subtract out any store ops in single-iteration src slice loop nest.
+ if (storeCount > 0)
+ computeCostMap[insertPointParent] = -storeCount;
+ // Subtract out any load users of 'storeMemrefs' nested below
+ // 'insertPointParent'.
+ for (auto value : storeMemrefs) {
+ for (auto *user : value->getUsers()) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+ SmallVector<AffineForOp, 4> loops;
+ // Check if any loop in loop nest surrounding 'user' is
+ // 'insertPointParent'.
+ getLoopIVs(*user, &loops);
+ if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
+ if (auto forOp =
+ dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
+ if (computeCostMap.count(forOp) == 0)
+ computeCostMap[forOp] = 0;
+ computeCostMap[forOp] -= 1;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Compute op instance count for the src loop nest with iteration slicing.
+ int64_t sliceComputeCost = getComputeCostHelper(
+ srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+
+ // Compute cost of fusion for this depth.
+ computeCostMap[insertPointParent] = sliceComputeCost;
+
+ *computeCost =
+ getComputeCostHelper(dstForOp.getOperation(), dstStats,
+ /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+ return true;
+}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
new file mode 100644
index 00000000000..0fece54132a
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -0,0 +1,1779 @@
+//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopUtils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "LoopUtils"
+
+using namespace mlir;
+using llvm::SetVector;
+using llvm::SmallMapVector;
+
+/// Computes the cleanup loop lower bound of the loop being unrolled with
+/// the specified unroll factor; this bound will also be the upper bound of the
+/// main part of the unrolled loop. Computes the bound as an AffineMap with its
+/// operands or a null map when the trip count can't be expressed as an affine
+/// expression.
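+///
+/// For example (bounds made up): for `affine.for %i = 0 to 17` unrolled by a
+/// factor of 4, the trip count is 17, so the cleanup lower bound computed here
+/// is lb + (17 - 17 mod 4) * step = 0 + 16 = 16; the main unrolled loop then
+/// covers [0, 16) and the cleanup loop covers [16, 17).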
+void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
+ AffineMap *map,
+ SmallVectorImpl<Value> *operands,
+ OpBuilder &b) {
+ auto lbMap = forOp.getLowerBoundMap();
+
+ // Single result lower bound map only.
+ if (lbMap.getNumResults() != 1) {
+ *map = AffineMap();
+ return;
+ }
+
+ AffineMap tripCountMap;
+ SmallVector<Value, 4> tripCountOperands;
+ buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
+
+ // Sometimes the trip count cannot be expressed as an affine expression.
+ if (!tripCountMap) {
+ *map = AffineMap();
+ return;
+ }
+
+ unsigned step = forOp.getStep();
+ auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
+ forOp.getLowerBoundOperands());
+
+ // For each upper bound expr, get the range.
+ // Eg: affine.for %i = lb to min (ub1, ub2),
+ // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
+ // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
+ // these affine.apply's make up the cleanup loop lower bound.
+ SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
+ SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
+ for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
+ auto tripCountExpr = tripCountMap.getResult(i);
+ bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
+ auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
+ tripCountMap.getNumSymbols(), bumpExprs[i]);
+ bumpValues[i] =
+ b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
+ }
+
+ SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
+ for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
+ newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
+
+ operands->clear();
+ operands->push_back(lb);
+ operands->append(bumpValues.begin(), bumpValues.end());
+ *map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs);
+ // Simplify the map + operands.
+ fullyComposeAffineMapAndOperands(map, operands);
+ *map = simplifyAffineMap(*map);
+ canonicalizeMapAndOperands(map, operands);
+ // Remove any affine.apply's that became dead from the simplification above.
+ for (auto v : bumpValues) {
+ if (v->use_empty()) {
+ v->getDefiningOp()->erase();
+ }
+ }
+ if (lb.use_empty())
+ lb.erase();
+}
+
+/// Promotes the loop body of a forOp to its containing block if the forOp
+/// is known to have a single iteration.
+// TODO(bondhugula): extend this for arbitrary affine bounds.
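+//
+// For illustration (op names made up), a single-iteration loop such as
+//
+//   affine.for %i = 0 to 1 {
+//     "foo.use"(%i) : (index) -> ()
+//   }
+//
+// is schematically rewritten into its containing block as
+//
+//   %c0 = constant 0 : index   // created at the top of the enclosing function
+//   "foo.use"(%c0) : (index) -> ()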
+LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
+ Optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ if (!tripCount.hasValue() || tripCount.getValue() != 1)
+ return failure();
+
+ // TODO(mlir-team): there is no builder for a max.
+ if (forOp.getLowerBoundMap().getNumResults() != 1)
+ return failure();
+
+  // Replace all uses of the IV with its single iteration value.
+ auto iv = forOp.getInductionVar();
+ Operation *op = forOp.getOperation();
+ if (!iv->use_empty()) {
+ if (forOp.hasConstantLowerBound()) {
+ OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
+ auto constOp = topBuilder.create<ConstantIndexOp>(
+ forOp.getLoc(), forOp.getConstantLowerBound());
+ iv->replaceAllUsesWith(constOp);
+ } else {
+ AffineBound lb = forOp.getLowerBound();
+ SmallVector<Value, 4> lbOperands(lb.operand_begin(), lb.operand_end());
+ OpBuilder builder(op->getBlock(), Block::iterator(op));
+ if (lb.getMap() == builder.getDimIdentityMap()) {
+        // No need to generate an affine.apply.
+ iv->replaceAllUsesWith(lbOperands[0]);
+ } else {
+ auto affineApplyOp = builder.create<AffineApplyOp>(
+ op->getLoc(), lb.getMap(), lbOperands);
+ iv->replaceAllUsesWith(affineApplyOp);
+ }
+ }
+ }
+ // Move the loop body operations, except for terminator, to the loop's
+ // containing block.
+ auto *block = op->getBlock();
+ forOp.getBody()->getOperations().back().erase();
+ block->getOperations().splice(Block::iterator(op),
+ forOp.getBody()->getOperations());
+ forOp.erase();
+ return success();
+}
+
+/// Promotes all single-iteration 'for' ops in the FuncOp, i.e., moves
+/// their body into the containing Block.
+void mlir::promoteSingleIterationLoops(FuncOp f) {
+ // Gathers all innermost loops through a post order pruned walk.
+ f.walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
+}
+
+/// Generates an 'affine.for' op with the specified lower and upper bounds
+/// while generating the right IV remappings for the shifted operations. The
+/// operation blocks that go into the loop are specified in instGroupQueue
+/// starting from the specified offset, and in that order; the first element of
+/// the pair specifies the shift applied to that group of operations; note
+/// that the shift is multiplied by the loop step before being applied. Returns
+/// nullptr if the generated loop simplifies to a single iteration one.
+static AffineForOp
+generateLoop(AffineMap lbMap, AffineMap ubMap,
+ const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>>
+ &instGroupQueue,
+ unsigned offset, AffineForOp srcForInst, OpBuilder b) {
+ SmallVector<Value, 4> lbOperands(srcForInst.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(srcForInst.getUpperBoundOperands());
+
+ assert(lbMap.getNumInputs() == lbOperands.size());
+ assert(ubMap.getNumInputs() == ubOperands.size());
+
+ auto loopChunk =
+ b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands,
+ ubMap, srcForInst.getStep());
+ auto loopChunkIV = loopChunk.getInductionVar();
+ auto srcIV = srcForInst.getInductionVar();
+
+ BlockAndValueMapping operandMap;
+
+ OpBuilder bodyBuilder = loopChunk.getBodyBuilder();
+ for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
+ it != e; ++it) {
+ uint64_t shift = it->first;
+ auto insts = it->second;
+    // All 'same shift' operations get added with their operands being
+    // remapped to results of cloned operations, and any use of their IV
+    // remapped. Generate the remapping if the shift is not zero:
+    // remappedIV = newIV - shift.
+ if (!srcIV->use_empty() && shift != 0) {
+ auto ivRemap = bodyBuilder.create<AffineApplyOp>(
+ srcForInst.getLoc(),
+ bodyBuilder.getSingleDimShiftAffineMap(
+ -static_cast<int64_t>(srcForInst.getStep() * shift)),
+ loopChunkIV);
+ operandMap.map(srcIV, ivRemap);
+ } else {
+ operandMap.map(srcIV, loopChunkIV);
+ }
+ for (auto *op : insts) {
+ if (!isa<AffineTerminatorOp>(op))
+ bodyBuilder.clone(*op, operandMap);
+ }
+ };
+ if (succeeded(promoteIfSingleIteration(loopChunk)))
+ return AffineForOp();
+ return loopChunk;
+}
+
+/// Skew the operations in the body of an 'affine.for' operation with the
+/// specified operation-wise shifts. The shifts are with respect to the
+/// original execution order, and are multiplied by the loop 'step' before being
+/// applied. A shift of zero for each operation will lead to no change.
+// The skewing of operations with respect to one another can be used for
+// example to allow overlap of asynchronous operations (such as DMA
+// communication) with computation, or just relative shifting of operations
+// for better register reuse, locality or parallelism. As such, the shifts are
+// typically expected to be at most of the order of the number of operations.
+// This method should not be used as a substitute for loop distribution/fission.
+// This method uses an algorithm that runs in time linear in the number of
+// operations in the body of the for loop (using the 'sweep line' paradigm).
+// This method asserts preservation of SSA dominance. A check for that, as well
+// as for memory-based dependence preservation, rests with the users of this
+// method.
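+//
+// For illustration (op names and shifts made up): with a shift of 0 for
+// "foo.a" and 1 for "foo.b" in
+//
+// ```mlir
+// affine.for %i = 0 to 8 {
+//   "foo.a"(%i) : (index) -> ()
+//   "foo.b"(%i) : (index) -> ()
+// }
+// ```
+//
+// "foo.b" is delayed by one iteration relative to "foo.a". Schematically, the
+// result is a one-iteration prologue running only "foo.a", a steady-state loop
+// running "foo.a" of iteration %i together with "foo.b" of iteration %i - 1,
+// and a one-iteration epilogue running only "foo.b".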
+LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
+ bool unrollPrologueEpilogue) {
+ if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+ return success();
+
+ // If the trip counts aren't constant, we would need versioning and
+ // conditional guards (or context information to prevent such versioning). The
+ // better way to pipeline for such loops is to first tile them and extract
+ // constant trip count "full tiles" before applying this.
+ auto mayBeConstTripCount = getConstantTripCount(forOp);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
+ return success();
+ }
+ uint64_t tripCount = mayBeConstTripCount.getValue();
+
+ assert(isInstwiseShiftValid(forOp, shifts) &&
+ "shifts will lead to an invalid transformation\n");
+
+ int64_t step = forOp.getStep();
+
+ unsigned numChildInsts = forOp.getBody()->getOperations().size();
+
+ // Do a linear time (counting) sort for the shifts.
+ uint64_t maxShift = 0;
+ for (unsigned i = 0; i < numChildInsts; i++) {
+ maxShift = std::max(maxShift, shifts[i]);
+ }
+ // Such large shifts are not the typical use case.
+ if (maxShift >= numChildInsts) {
+ forOp.emitWarning("not shifting because shifts are unrealistically large");
+ return success();
+ }
+
+ // An array of operation groups sorted by shift amount; each group has all
+ // operations with the same shift in the order in which they appear in the
+ // body of the 'affine.for' op.
+ std::vector<std::vector<Operation *>> sortedInstGroups(maxShift + 1);
+ unsigned pos = 0;
+ for (auto &op : *forOp.getBody()) {
+ auto shift = shifts[pos++];
+ sortedInstGroups[shift].push_back(&op);
+ }
+
+ // Unless the shifts have a specific pattern (which actually would be the
+ // common use case), prologue and epilogue are not meaningfully defined.
+ // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
+ // loop generated as the prologue and the last as epilogue and unroll these
+ // fully.
+ AffineForOp prologue;
+ AffineForOp epilogue;
+
+ // Do a sweep over the sorted shifts while storing open groups in a
+ // vector, and generating loop portions as necessary during the sweep. A block
+ // of operations is paired with its shift.
+ std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> instGroupQueue;
+
+ auto origLbMap = forOp.getLowerBoundMap();
+ uint64_t lbShift = 0;
+ OpBuilder b(forOp.getOperation());
+ for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
+ // If nothing is shifted by d, continue.
+ if (sortedInstGroups[d].empty())
+ continue;
+ if (!instGroupQueue.empty()) {
+ assert(d >= 1 &&
+ "Queue expected to be empty when the first block is found");
+ // The interval for which the loop needs to be generated here is:
+ // [lbShift, min(lbShift + tripCount, d)) and the body of the
+ // loop needs to have all operations in instQueue in that order.
+ AffineForOp res;
+ if (lbShift + tripCount * step < d * step) {
+ res = generateLoop(
+ b.getShiftedAffineMap(origLbMap, lbShift),
+ b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
+ instGroupQueue, 0, forOp, b);
+ // Entire loop for the queued op groups generated, empty it.
+ instGroupQueue.clear();
+ lbShift += tripCount * step;
+ } else {
+ res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+ b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
+ 0, forOp, b);
+ lbShift = d * step;
+ }
+ if (!prologue && res)
+ prologue = res;
+ epilogue = res;
+ } else {
+ // Start of first interval.
+ lbShift = d * step;
+ }
+ // Augment the list of operations that get into the current open interval.
+ instGroupQueue.push_back({d, sortedInstGroups[d]});
+ }
+
+  // Those operation groups left in the queue now need to be processed (FIFO)
+ // and their loops completed.
+ for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
+ uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
+ epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+ b.getShiftedAffineMap(origLbMap, ubShift),
+ instGroupQueue, i, forOp, b);
+ lbShift = ubShift;
+ if (!prologue)
+ prologue = epilogue;
+ }
+
+ // Erase the original for op.
+ forOp.erase();
+
+ if (unrollPrologueEpilogue && prologue)
+ loopUnrollFull(prologue);
+  if (unrollPrologueEpilogue && epilogue &&
+      epilogue.getOperation() != prologue.getOperation())
+    loopUnrollFull(epilogue);
+
+ return success();
+}
+
+// Collect perfectly nested loops starting from `rootForOps`. Loops are
+// perfectly nested if each loop is the first and only non-terminator operation
+// in the parent loop. Collect at most `maxLoops` loops and append them to
+// `forOps`.
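+//
+// For example (bounds and op names made up), the following nest is perfectly
+// nested:
+//
+// ```mlir
+// affine.for %i = 0 to 8 {
+//   affine.for %j = 0 to 16 {
+//     "foo.op"(%i, %j) : (index, index) -> ()
+//   }
+// }
+// ```
+//
+// whereas it would not be if any other non-terminator operation appeared in
+// the body of the %i loop alongside the %j loop.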
+template <typename T>
+void getPerfectlyNestedLoopsImpl(
+ SmallVectorImpl<T> &forOps, T rootForOp,
+ unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
+ for (unsigned i = 0; i < maxLoops; ++i) {
+ forOps.push_back(rootForOp);
+ Block &body = rootForOp.region().front();
+ if (body.begin() != std::prev(body.end(), 2))
+ return;
+
+ rootForOp = dyn_cast<T>(&body.front());
+ if (!rootForOp)
+ return;
+ }
+}
+
+/// Get the perfectly nested sequence of loops starting at the root of the loop
+/// nest. A loop is perfectly nested iff the first op in its body is another
+/// AffineForOp and the second op is a terminator.
+void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
+ AffineForOp root) {
+ getPerfectlyNestedLoopsImpl(nestedLoops, root);
+}
+
+void mlir::getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops,
+ loop::ForOp root) {
+ getPerfectlyNestedLoopsImpl(nestedLoops, root);
+}
+
+/// Unrolls this loop completely.
+LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ if (mayBeConstantTripCount.hasValue()) {
+ uint64_t tripCount = mayBeConstantTripCount.getValue();
+ if (tripCount == 1) {
+ return promoteIfSingleIteration(forOp);
+ }
+ return loopUnrollByFactor(forOp, tripCount);
+ }
+ return failure();
+}
+
+/// Unrolls this loop by the specified factor or by the trip count (if
+/// constant), whichever is lower.
+LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
+ uint64_t unrollFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollFactor)
+ return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
+ return loopUnrollByFactor(forOp, unrollFactor);
+}
+
+/// Unrolls this loop by the specified factor. Returns success if the loop
+/// is successfully unrolled.
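+///
+/// For illustration (bounds and op names made up), unrolling
+///
+/// ```mlir
+/// affine.for %i = 0 to 8 {
+///   "foo.op"(%i) : (index) -> ()
+/// }
+/// ```
+///
+/// by a factor of 4 schematically yields a loop with a scaled step whose body
+/// holds four copies of the original body, the induction variable of copy k
+/// being remapped to %i + k via affine.apply:
+///
+/// ```mlir
+/// affine.for %i = 0 to 8 step 4 {
+///   "foo.op"(%i) : (index) -> ()
+///   // ... three more copies using %i + 1, %i + 2 and %i + 3 ...
+/// }
+/// ```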
+LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
+ uint64_t unrollFactor) {
+ assert(unrollFactor >= 1 && "unroll factor should be >= 1");
+
+ if (unrollFactor == 1)
+ return promoteIfSingleIteration(forOp);
+
+ if (forOp.getBody()->empty() ||
+ forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
+ return failure();
+
+  // Loops where the lower bound is a max expression aren't supported for
+  // unrolling since the trip count can't be expressed as an affine function in
+  // general when both the lower bound and the upper bound are multi-result
+  // maps. However, one meaningful way to do such unrolling would be to
+  // specialize the loop for the 'hotspot' case and unroll that hotspot.
+ if (forOp.getLowerBoundMap().getNumResults() != 1)
+ return failure();
+
+ // If the trip count is lower than the unroll factor, no unrolled body.
+ // TODO(bondhugula): option to specify cleanup loop unrolling.
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollFactor)
+ return failure();
+
+ // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+ Operation *op = forOp.getOperation();
+ if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
+ OpBuilder builder(op->getBlock(), ++Block::iterator(op));
+ auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
+ AffineMap cleanupMap;
+ SmallVector<Value, 4> cleanupOperands;
+ getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
+ builder);
+ assert(cleanupMap &&
+ "cleanup loop lower bound map for single result lower bound maps "
+ "can always be determined");
+ cleanupForInst.setLowerBound(cleanupOperands, cleanupMap);
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(cleanupForInst);
+
+ // Adjust upper bound of the original loop; this is the same as the lower
+ // bound of the cleanup loop.
+ forOp.setUpperBound(cleanupOperands, cleanupMap);
+ }
+
+ // Scale the step of loop being unrolled by unroll factor.
+ int64_t step = forOp.getStep();
+ forOp.setStep(step * unrollFactor);
+
+ // Builder to insert unrolled bodies just before the terminator of the body of
+ // 'forOp'.
+ OpBuilder builder = forOp.getBodyBuilder();
+
+ // Keep a pointer to the last non-terminator operation in the original block
+ // so that we know what to clone (since we are doing this in-place).
+ Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2);
+
+ // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
+ auto forOpIV = forOp.getInductionVar();
+ for (unsigned i = 1; i < unrollFactor; i++) {
+ BlockAndValueMapping operandMap;
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forOpIV->use_empty()) {
+ // iv' = iv + 1/2/3...unrollFactor-1;
+ auto d0 = builder.getAffineDimExpr(0);
+ auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
+ auto ivUnroll =
+ builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
+ operandMap.map(forOpIV, ivUnroll);
+ }
+
+ // Clone the original body of 'forOp'.
+ for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd);
+ it++) {
+ builder.clone(*it, operandMap);
+ }
+ }
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(forOp);
+ return success();
+}
+
+/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
+/// nested within 'forOpA' as the only non-terminator operation in its block.
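+///
+/// For example (bounds and op names made up), interchanging the %i and %j
+/// loops of
+///
+/// ```mlir
+/// affine.for %i = 0 to 8 {
+///   affine.for %j = 0 to 16 {
+///     "foo.op"(%i, %j) : (index, index) -> ()
+///   }
+/// }
+/// ```
+///
+/// yields
+///
+/// ```mlir
+/// affine.for %j = 0 to 16 {
+///   affine.for %i = 0 to 8 {
+///     "foo.op"(%i, %j) : (index, index) -> ()
+///   }
+/// }
+/// ```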
+void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
+ auto *forOpAInst = forOpA.getOperation();
+
+ assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
+ auto &forOpABody = forOpA.getBody()->getOperations();
+ auto &forOpBBody = forOpB.getBody()->getOperations();
+
+ // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
+ // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's
+ // body containing only the terminator.
+ forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst),
+ forOpABody, forOpABody.begin(),
+ std::prev(forOpABody.end()));
+ // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
+ // body (this leaves forOpB's body containing only the terminator).
+ forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
+ std::prev(forOpBBody.end()));
+ // 3) Splice forOpA into the beginning of forOpB's body.
+ forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(),
+ Block::iterator(forOpAInst));
+}
+
+// Checks each dependence component against the permutation to see if the
+// desired loop interchange would violate dependences by making the
+// dependence component lexicographically negative.
+static bool checkLoopInterchangeDependences(
+ const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
+ ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
+ // Invert permutation map.
+ unsigned maxLoopDepth = loops.size();
+ SmallVector<unsigned, 4> loopPermMapInv;
+ loopPermMapInv.resize(maxLoopDepth);
+ for (unsigned i = 0; i < maxLoopDepth; ++i)
+ loopPermMapInv[loopPermMap[i]] = i;
+
+ // Check each dependence component against the permutation to see if the
+ // desired loop interchange permutation would make the dependence vectors
+ // lexicographically negative.
+ // Example 1: [-1, 1][0, 0]
+ // Example 2: [0, 0][-1, 1]
+ for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
+ const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
+ assert(depComps.size() >= maxLoopDepth);
+ // Check if the first non-zero dependence component is positive.
+ // This iterates through loops in the desired order.
+ for (unsigned j = 0; j < maxLoopDepth; ++j) {
+ unsigned permIndex = loopPermMapInv[j];
+ assert(depComps[permIndex].lb.hasValue());
+ int64_t depCompLb = depComps[permIndex].lb.getValue();
+ if (depCompLb > 0)
+ break;
+ if (depCompLb < 0)
+ return false;
+ }
+ }
+ return true;
+}
+
+/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
+/// nested sequence of loops in 'loops' would violate dependences.
+bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
+ ArrayRef<unsigned> loopPermMap) {
+ // Gather dependence components for dependences between all ops in loop nest
+ // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
+ assert(loopPermMap.size() == loops.size());
+ unsigned maxLoopDepth = loops.size();
+ std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
+ getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
+ return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
+}
+
+/// Performs a sequence of loop interchanges of loops in perfectly nested
+/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'.
+unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops,
+ ArrayRef<unsigned> loopPermMap) {
+ Optional<unsigned> loopNestRootIndex;
+ for (int i = loops.size() - 1; i >= 0; --i) {
+ int permIndex = static_cast<int>(loopPermMap[i]);
+ // Store the index of the for loop which will be the new loop nest root.
+ if (permIndex == 0)
+ loopNestRootIndex = i;
+ if (permIndex > i) {
+ // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
+ sinkLoop(loops[i], permIndex - i);
+ }
+ }
+ assert(loopNestRootIndex.hasValue());
+ return loopNestRootIndex.getValue();
+}
+
+// Sinks all sequential loops to the innermost levels (while preserving
+// relative order among them) and moves all parallel loops to the outermost
+// levels (while again preserving relative order among them).
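+//
+// For example (bounds and dependences made up): if the %i loop below carries a
+// dependence while the %j loop is parallel,
+//
+//   affine.for %i = 1 to 8 {     // sequential
+//     affine.for %j = 0 to 16 {  // parallel
+//       ...
+//     }
+//   }
+//
+// then, dependences permitting, the permutation computed here makes the
+// parallel %j loop outermost and sinks the sequential %i loop inside it.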
+AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
+ SmallVector<AffineForOp, 4> loops;
+ getPerfectlyNestedLoops(loops, forOp);
+ if (loops.size() < 2)
+ return forOp;
+
+ // Gather dependence components for dependences between all ops in loop nest
+ // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
+ unsigned maxLoopDepth = loops.size();
+ std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
+ getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
+
+ // Mark loops as either parallel or sequential.
+ SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
+ for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
+ SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
+ assert(depComps.size() >= maxLoopDepth);
+ for (unsigned j = 0; j < maxLoopDepth; ++j) {
+ DependenceComponent &depComp = depComps[j];
+ assert(depComp.lb.hasValue() && depComp.ub.hasValue());
+ if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
+ isParallelLoop[j] = false;
+ }
+ }
+
+ // Count the number of parallel loops.
+ unsigned numParallelLoops = 0;
+ for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
+ if (isParallelLoop[i])
+ ++numParallelLoops;
+
+ // Compute permutation of loops that sinks sequential loops (and thus raises
+ // parallel loops) while preserving relative order.
+ SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
+ unsigned nextSequentialLoop = numParallelLoops;
+ unsigned nextParallelLoop = 0;
+ for (unsigned i = 0; i < maxLoopDepth; ++i) {
+ if (isParallelLoop[i]) {
+ loopPermMap[i] = nextParallelLoop++;
+ } else {
+ loopPermMap[i] = nextSequentialLoop++;
+ }
+ }
+
+ // Check if permutation 'loopPermMap' would violate dependences.
+ if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
+ return forOp;
+ // Perform loop interchange according to permutation 'loopPermMap'.
+ unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap);
+ return loops[loopNestRootIndex];
+}
+
+/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels
+/// deeper in the loop nest.
+void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
+ for (unsigned i = 0; i < loopDepth; ++i) {
+ AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
+ interchangeLoops(forOp, nextForOp);
+ }
+}
+
+// Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
+// lower (resp. upper) loop bound. When called for both the lower and upper
+// bounds, the resulting IR resembles:
+//
+// ```mlir
+// affine.for %i = max (`iv`, ...) to min (`iv` + `offset`) {
+// ...
+// }
+// ```
+static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
+ SmallVector<Value, 4> *operands,
+ int64_t offset = 0) {
+ auto bounds = llvm::to_vector<4>(map->getResults());
+ bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
+ operands->insert(operands->begin() + map->getNumDims(), iv);
+ *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds);
+ canonicalizeMapAndOperands(map, operands);
+}
+
+// Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
+// Stripmine-sink is a primitive building block for generalized tiling of
+// imperfectly nested loops.
+// This transformation is purely mechanical and does not check legality,
+// profitability or even structural correctness. It is the user's
+// responsibility to specify `targets` that are dominated by `forOp`.
+// Returns the new AffineForOps, one per `targets`, nested immediately under
+// each of the `targets`.
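+//
+// For illustration (bounds and op names made up): stripmining the %i loop by a
+// factor of 4 and sinking it under the %j loop nested inside it turns
+//
+// ```mlir
+// affine.for %i = 0 to 16 {
+//   affine.for %j = 0 to 8 {
+//     "foo.op"(%i, %j) : (index, index) -> ()
+//   }
+// }
+// ```
+//
+// schematically (the actual bounds are expressed with affine maps) into
+//
+// ```mlir
+// affine.for %i = 0 to 16 step 4 {
+//   affine.for %j = 0 to 8 {
+//     affine.for %ii = max (0, %i) to min (16, %i + 4) {
+//       "foo.op"(%ii, %j) : (index, index) -> ()
+//     }
+//   }
+// }
+// ```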
+static SmallVector<AffineForOp, 8>
+stripmineSink(AffineForOp forOp, uint64_t factor,
+ ArrayRef<AffineForOp> targets) {
+ auto originalStep = forOp.getStep();
+ auto scaledStep = originalStep * factor;
+ forOp.setStep(scaledStep);
+
+ auto *op = forOp.getOperation();
+ OpBuilder b(op->getBlock(), ++Block::iterator(op));
+
+ // Lower-bound map creation.
+ auto lbMap = forOp.getLowerBoundMap();
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
+ augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
+
+ // Upper-bound map creation.
+ auto ubMap = forOp.getUpperBoundMap();
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
+ augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
+ /*offset=*/scaledStep);
+
+ auto iv = forOp.getInductionVar();
+ SmallVector<AffineForOp, 8> innerLoops;
+ for (auto t : targets) {
+ // Insert newForOp before the terminator of `t`.
+ OpBuilder b = t.getBodyBuilder();
+ auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, originalStep);
+ auto begin = t.getBody()->begin();
+ // Skip terminator and `newForOp` which is just before the terminator.
+ auto nOps = t.getBody()->getOperations().size() - 2;
+ newForOp.getBody()->getOperations().splice(
+ newForOp.getBody()->getOperations().begin(),
+ t.getBody()->getOperations(), begin, std::next(begin, nOps));
+ replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
+ newForOp.region());
+ innerLoops.push_back(newForOp);
+ }
+
+ return innerLoops;
+}
+
+static Loops stripmineSink(loop::ForOp forOp, Value factor,
+ ArrayRef<loop::ForOp> targets) {
+ auto originalStep = forOp.step();
+ auto iv = forOp.getInductionVar();
+
+ OpBuilder b(forOp);
+ forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));
+
+ Loops innerLoops;
+ for (auto t : targets) {
+ // Save information for splicing ops out of t when done
+ auto begin = t.getBody()->begin();
+ auto nOps = t.getBody()->getOperations().size();
+
+ // Insert newForOp before the terminator of `t`.
+ OpBuilder b(t.getBodyBuilder());
+ Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
+ Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
+ forOp.upperBound(), stepped);
+ Value ub =
+ b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
+
+ // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
+ auto newForOp = b.create<loop::ForOp>(t.getLoc(), iv, ub, originalStep);
+ newForOp.getBody()->getOperations().splice(
+ newForOp.getBody()->getOperations().begin(),
+ t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
+ replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
+ newForOp.region());
+
+ innerLoops.push_back(newForOp);
+ }
+
+ return innerLoops;
+}
+
+// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
+// Returns the new AffineForOps, nested immediately under `target`.
+template <typename ForType, typename SizeType>
+static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) {
+ // TODO(ntv): Use cheap structural assertions that targets are nested under
+ // forOp and that targets are not nested under each other when DominanceInfo
+ // exposes the capability. It seems overkill to construct a whole function
+ // dominance tree at this point.
+ auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target});
+ assert(res.size() == 1 && "Expected 1 inner forOp");
+ return res[0];
+}
+
+template <typename ForType, typename SizeType>
+static SmallVector<SmallVector<ForType, 8>, 8>
+tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes,
+ ArrayRef<ForType> targets) {
+ SmallVector<SmallVector<ForType, 8>, 8> res;
+ SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end());
+ for (auto it : llvm::zip(forOps, sizes)) {
+ auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
+ res.push_back(step);
+ currentTargets = step;
+ }
+ return res;
+}
+
+SmallVector<SmallVector<AffineForOp, 8>, 8>
+mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
+ ArrayRef<AffineForOp> targets) {
+ return tileImpl(forOps, sizes, targets);
+}
+
+SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps,
+ ArrayRef<Value> sizes,
+ ArrayRef<loop::ForOp> targets) {
+ return tileImpl(forOps, sizes, targets);
+}
+
+template <typename ForType, typename SizeType>
+static SmallVector<ForType, 8>
+tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) {
+ SmallVector<ForType, 8> res;
+ for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) {
+ assert(loops.size() == 1);
+ res.push_back(loops[0]);
+ }
+ return res;
+}
+
+SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
+ ArrayRef<uint64_t> sizes,
+ AffineForOp target) {
+ return tileImpl(forOps, sizes, target);
+}
+
+Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
+ loop::ForOp target) {
+ return tileImpl(forOps, sizes, target);
+}
+
+Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes) {
+  // Collect perfectly nested loops. If more size values are provided than
+  // there are nested loops, truncate `sizes`.
+ SmallVector<loop::ForOp, 4> forOps;
+ forOps.reserve(sizes.size());
+ getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
+ if (forOps.size() < sizes.size())
+ sizes = sizes.take_front(forOps.size());
+
+ return ::tile(forOps, sizes, forOps.back());
+}
+
+// Build the IR that performs ceil division of a positive value by a constant:
+// ceildiv(a, B) = divis(a + (B-1), B)
+// where divis is rounding-to-zero division.
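+// For example, ceildiv(10, 4) = divis(10 + 3, 4) = divis(13, 4) = 3.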
+static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
+ int64_t divisor) {
+ assert(divisor > 0 && "expected positive divisor");
+ assert(dividend->getType().isIndex() && "expected index-typed value");
+
+ Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
+ Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
+ Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
+ return builder.create<SignedDivIOp>(loc, sum, divisorCst);
+}
+
+// Build the IR that performs ceil division of a positive value by another
+// positive value:
+// ceildiv(a, b) = divis(a + (b - 1), b)
+// where divis is rounding-to-zero division.
+static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
+ Value divisor) {
+ assert(dividend->getType().isIndex() && "expected index-typed value");
+
+ Value cstOne = builder.create<ConstantIndexOp>(loc, 1);
+ Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
+ Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
+ return builder.create<SignedDivIOp>(loc, sum, divisor);
+}
+
+// Hoist the ops within `outer` that appear before `inner`.
+// Such ops include the ops that have been introduced by parametric tiling.
+// Ops that come from triangular loops (i.e. that belong to the program slice
+// rooted at `outer`) and ops that have side effects cannot be hoisted.
+// Return failure when any op fails to hoist.
+static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) {
+ SetVector<Operation *> forwardSlice;
+ getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
+ return op != inner.getOperation();
+ });
+ LogicalResult status = success();
+ SmallVector<Operation *, 8> toHoist;
+ for (auto &op : outer.getBody()->getOperations()) {
+ // Stop when encountering the inner loop.
+ if (&op == inner.getOperation())
+ break;
+ // Skip over non-hoistable ops.
+ if (forwardSlice.count(&op) > 0) {
+ status = failure();
+ continue;
+ }
+    // Skip loop::ForOp, these are not considered a failure.
+    if (isa<loop::ForOp>(op))
+      continue;
+ // Skip other ops with regions.
+ if (op.getNumRegions() > 0) {
+ status = failure();
+ continue;
+ }
+ // Skip if op has side effects.
+ // TODO(ntv): loads to immutable memory regions are ok.
+ if (!op.hasNoSideEffect()) {
+ status = failure();
+ continue;
+ }
+ toHoist.push_back(&op);
+ }
+ auto *outerForOp = outer.getOperation();
+ for (auto *op : toHoist)
+ op->moveBefore(outerForOp);
+ return status;
+}
+
+// Traverse the interTile and intraTile loops and try to hoist ops such that
+// bands of perfectly nested loops are isolated.
+// Return failure if either perfect interTile or perfect intraTile bands cannot
+// be formed.
+static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
+ LogicalResult status = success();
+ auto &interTile = tileLoops.first;
+ auto &intraTile = tileLoops.second;
+ auto size = interTile.size();
+ assert(size == intraTile.size());
+ if (size <= 1)
+ return success();
+ for (unsigned s = 1; s < size; ++s)
+ status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
+ : failure();
+ for (unsigned s = 1; s < size; ++s)
+ status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
+ : failure();
+ return status;
+}
+
+TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp,
+ ArrayRef<int64_t> sizes) {
+  // Collect perfectly nested loops. If more size values are provided than
+  // there are nested loops, truncate `sizes`.
+ SmallVector<loop::ForOp, 4> forOps;
+ forOps.reserve(sizes.size());
+ getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
+ if (forOps.size() < sizes.size())
+ sizes = sizes.take_front(forOps.size());
+
+ // Compute the tile sizes such that i-th outer loop executes size[i]
+  // iterations. Given that the loop currently executes
+ // numIterations = ceildiv((upperBound - lowerBound), step)
+ // iterations, we need to tile with size ceildiv(numIterations, size[i]).
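+  // As an illustrative example (values chosen arbitrarily): a loop from 0 to
+  // 100 with step 1 has numIterations = 100; requesting size[i] = 4 yields a
+  // tile size of ceildiv(100, 4) = 25, so the i-th outer loop runs 4 times.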
+ SmallVector<Value, 4> tileSizes;
+ tileSizes.reserve(sizes.size());
+ for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
+ assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
+
+ auto forOp = forOps[i];
+ OpBuilder builder(forOp);
+ auto loc = forOp.getLoc();
+ Value diff =
+ builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
+ Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
+ Value iterationsPerBlock =
+ ceilDivPositive(builder, loc, numIterations, sizes[i]);
+ tileSizes.push_back(iterationsPerBlock);
+ }
+
+ // Call parametric tiling with the given sizes.
+ auto intraTile = tile(forOps, tileSizes, forOps.back());
+ TileLoops tileLoops = std::make_pair(forOps, intraTile);
+
+ // TODO(ntv, zinenko) for now we just ignore the result of band isolation.
+ // In the future, mapping decisions may be impacted by the ability to
+ // isolate perfectly nested bands.
+ tryIsolateBands(tileLoops);
+
+ return tileLoops;
+}
+
+// Replaces all uses of `orig` with `replacement` except if the user is listed
+// in `exceptions`.
+static void
+replaceAllUsesExcept(Value orig, Value replacement,
+ const SmallPtrSetImpl<Operation *> &exceptions) {
+ for (auto &use : llvm::make_early_inc_range(orig->getUses())) {
+ if (exceptions.count(use.getOwner()) == 0)
+ use.set(replacement);
+ }
+}
+
+// Transform a loop with a strictly positive step
+// for %i = %lb to %ub step %s
+// into a 0-based loop with step 1
+// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+// %i = %ii * %s + %lb
+// Insert the induction variable remapping in the body of `inner`, which is
+// expected to be either `loop` or another loop perfectly nested under `loop`.
+// Insert the definition of new bounds immediately before `outer`, which is
+// expected to be either `loop` or its parent in the loop nest.
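+// As an illustrative example (values chosen arbitrarily), the loop
+//   for %i = 4 to 20 step 3
+// executes ceildiv(20 - 4, 3) = 6 iterations and becomes
+//   for %ii = 0 to 6 step 1 {
+//     %i = %ii * 3 + 4
+//     ...
+//   }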
+static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
+ loop::ForOp inner) {
+ OpBuilder builder(outer);
+ Location loc = loop.getLoc();
+
+ // Check if the loop is already known to have a constant zero lower bound or
+ // a constant one step.
+ bool isZeroBased = false;
+  if (auto lbCst =
+          dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound()->getDefiningOp()))
+    isZeroBased = lbCst.getValue() == 0;
+
+ bool isStepOne = false;
+ if (auto stepCst =
+ dyn_cast_or_null<ConstantIndexOp>(loop.step()->getDefiningOp()))
+ isStepOne = stepCst.getValue() == 1;
+
+ if (isZeroBased && isStepOne)
+ return;
+
+ // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
+ // assuming the step is strictly positive. Update the bounds and the step
+ // of the loop to go from 0 to the number of iterations, if necessary.
+ // TODO(zinenko): introduce support for negative steps or emit dynamic asserts
+ // on step positivity, whatever gets implemented first.
+ Value diff =
+ builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
+ Value numIterations = ceilDivPositive(builder, loc, diff, loop.step());
+ loop.setUpperBound(numIterations);
+
+ Value lb = loop.lowerBound();
+ if (!isZeroBased) {
+ Value cst0 = builder.create<ConstantIndexOp>(loc, 0);
+ loop.setLowerBound(cst0);
+ }
+
+ Value step = loop.step();
+ if (!isStepOne) {
+ Value cst1 = builder.create<ConstantIndexOp>(loc, 1);
+ loop.setStep(cst1);
+ }
+
+ // Insert code computing the value of the original loop induction variable
+ // from the "normalized" one.
+ builder.setInsertionPointToStart(inner.getBody());
+ Value scaled =
+ isStepOne ? loop.getInductionVar()
+ : builder.create<MulIOp>(loc, loop.getInductionVar(), step);
+ Value shifted =
+ isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);
+
+ SmallPtrSet<Operation *, 2> preserve{scaled->getDefiningOp(),
+ shifted->getDefiningOp()};
+ replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve);
+}
+
+void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
+ if (loops.size() < 2)
+ return;
+
+ loop::ForOp innermost = loops.back();
+ loop::ForOp outermost = loops.front();
+
+ // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
+ // allows the following code to assume upperBound is the number of iterations.
+ for (auto loop : loops)
+ normalizeLoop(loop, outermost, innermost);
+
+ // 2. Emit code computing the upper bound of the coalesced loop as product
+ // of the number of iterations of all loops.
+ OpBuilder builder(outermost);
+ Location loc = outermost.getLoc();
+ Value upperBound = outermost.upperBound();
+ for (auto loop : loops.drop_front())
+ upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
+ outermost.setUpperBound(upperBound);
+
+ builder.setInsertionPointToStart(outermost.getBody());
+
+ // 3. Remap induction variables. For each original loop, the value of the
+ // induction variable can be obtained by dividing the induction variable of
+ // the linearized loop by the total number of iterations of the loops nested
+ // in it modulo the number of iterations in this loop (remove the values
+ // related to the outer loops):
+ // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+ // Compute these iteratively from the innermost loop by creating a "running
+ // quotient" of division by the range.
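+  // As an illustrative sketch, for two (already normalized) loops with trip
+  // counts N (outer) and M (inner), the coalesced loop runs iv_linear over
+  // [0, N * M) and the original IVs are recovered as
+  //   iv_inner = iv_linear mod M
+  //   iv_outer = floordiv(iv_linear, M)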
+ Value previous = outermost.getInductionVar();
+ for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+ unsigned idx = loops.size() - i - 1;
+ if (i != 0)
+ previous = builder.create<SignedDivIOp>(loc, previous,
+ loops[idx + 1].upperBound());
+
+ Value iv = (i == e - 1) ? previous
+ : builder.create<SignedRemIOp>(
+ loc, previous, loops[idx].upperBound());
+ replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
+ loops.back().region());
+ }
+
+ // 4. Move the operations from the innermost just above the second-outermost
+ // loop, delete the extra terminator and the second-outermost loop.
+ loop::ForOp second = loops[1];
+ innermost.getBody()->back().erase();
+ outermost.getBody()->getOperations().splice(
+ Block::iterator(second.getOperation()),
+ innermost.getBody()->getOperations());
+ second.erase();
+}
+
+void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
+ ArrayRef<Value> numProcessors) {
+ assert(processorId.size() == numProcessors.size());
+ if (processorId.empty())
+ return;
+
+ OpBuilder b(forOp);
+ Location loc(forOp.getLoc());
+ Value mul = processorId.front();
+ for (unsigned i = 1, e = processorId.size(); i < e; ++i)
+ mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
+ processorId[i]);
+ Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
+ b.create<MulIOp>(loc, forOp.step(), mul));
+ forOp.setLowerBound(lb);
+
+ Value step = forOp.step();
+ for (auto numProcs : numProcessors)
+ step = b.create<MulIOp>(loc, step, numProcs);
+ forOp.setStep(step);
+}
+
+/// Given a memref region, determine the lowest depth at which transfers can be
+/// placed for it, and return the corresponding block, start and end positions
+/// in the block for placing incoming (read) and outgoing (write) copies
+/// respectively. The lowest depth depends on whether the region being accessed
+/// is hoistable with respect to one or more immediately surrounding loops.
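+/// For example (illustrative): if the region accessed under loops %i and %j is
+/// invariant in %j (i.e., not symbolic in %j's IV), the incoming copy can be
+/// placed just before the %j loop and the outgoing copy just after it, rather
+/// than inside it.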
+static void
+findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
+ Block::iterator &begin, Block::iterator &end,
+ Block **copyPlacementBlock,
+ Block::iterator *copyInPlacementStart,
+ Block::iterator *copyOutPlacementStart) {
+ const auto *cst = region.getConstraints();
+ SmallVector<Value, 4> symbols;
+ cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
+
+ SmallVector<AffineForOp, 4> enclosingFors;
+ getLoopIVs(*block.begin(), &enclosingFors);
+ // Walk up loop parents till we find an IV on which this region is
+ // symbolic/variant.
+ auto it = enclosingFors.rbegin();
+ for (auto e = enclosingFors.rend(); it != e; ++it) {
+ // TODO(bondhugula): also need to be checking this for regions symbols that
+ // aren't loop IVs, whether we are within their resp. defs' dominance scope.
+ if (llvm::is_contained(symbols, it->getInductionVar()))
+ break;
+ }
+
+ if (it != enclosingFors.rbegin()) {
+ auto lastInvariantIV = *std::prev(it);
+ *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
+ *copyOutPlacementStart = std::next(*copyInPlacementStart);
+ *copyPlacementBlock = lastInvariantIV.getOperation()->getBlock();
+ } else {
+ *copyInPlacementStart = begin;
+ *copyOutPlacementStart = end;
+ *copyPlacementBlock = &block;
+ }
+}
+
+// Info comprising stride and number of elements transferred every stride.
+struct StrideInfo {
+ int64_t stride;
+ int64_t numEltPerStride;
+};
+
+/// Returns striding information for a copy/transfer of this region with
+/// potentially multiple striding levels from outermost to innermost. For an
+/// n-dimensional region, there can be at most n-1 levels of striding
+/// successively nested.
+// TODO(bondhugula): make this work with non-identity layout maps.
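+// For example (hypothetical shapes): for a region of shape [64, 32] over a
+// memref<64x128xf32>, a single stride level is recorded, with a stride of 128
+// elements and 32 contiguous elements transferred per stride.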
+static void getMultiLevelStrides(const MemRefRegion &region,
+ ArrayRef<int64_t> bufferShape,
+ SmallVectorImpl<StrideInfo> *strideInfos) {
+ if (bufferShape.size() <= 1)
+ return;
+
+ int64_t numEltPerStride = 1;
+ int64_t stride = 1;
+ for (int d = bufferShape.size() - 1; d >= 1; d--) {
+ int64_t dimSize = region.memref->getType().cast<MemRefType>().getDimSize(d);
+ stride *= dimSize;
+ numEltPerStride *= bufferShape[d];
+ // A stride is needed only if the region has a shorter extent than the
+ // memref along the dimension *and* has an extent greater than one along the
+ // next major dimension.
+ if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
+ strideInfos->push_back({stride, numEltPerStride});
+ }
+ }
+}
+
+/// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
+/// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart'
+/// holds the lower coordinates of the region in the original memref to copy
+/// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in.
+static AffineForOp generatePointWiseCopy(Location loc, Value memref,
+ Value fastMemRef,
+ AffineMap memAffineMap,
+ ArrayRef<Value> memIndicesStart,
+ ArrayRef<int64_t> fastBufferShape,
+ bool isCopyOut, OpBuilder b) {
+ assert(!memIndicesStart.empty() && "only 1-d or more memrefs");
+
+ // The copy-in nest is generated as follows as an example for a 2-d region:
+ // for x = ...
+ // for y = ...
+ // fast_buf[x][y] = buf[mem_x + x][mem_y + y]
+
+ SmallVector<Value, 4> fastBufIndices, memIndices;
+ AffineForOp copyNestRoot;
+ for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) {
+ auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]);
+ if (d == 0)
+ copyNestRoot = forOp;
+ b = forOp.getBodyBuilder();
+ fastBufIndices.push_back(forOp.getInductionVar());
+
+ Value memBase =
+ (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims()))
+ ? memIndicesStart[d]
+ : b.create<AffineApplyOp>(
+ loc,
+ AffineMap::get(memAffineMap.getNumDims(),
+ memAffineMap.getNumSymbols(),
+ memAffineMap.getResult(d)),
+ memIndicesStart);
+
+ // Construct the subscript for the slow memref being copied.
+ auto memIndex = b.create<AffineApplyOp>(
+ loc,
+ AffineMap::get(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)),
+ ValueRange({memBase, forOp.getInductionVar()}));
+ memIndices.push_back(memIndex);
+ }
+
+ if (!isCopyOut) {
+ // Copy in.
+ auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
+ b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufIndices);
+ return copyNestRoot;
+ }
+
+ // Copy out.
+ auto load = b.create<AffineLoadOp>(loc, fastMemRef, fastBufIndices);
+ b.create<AffineStoreOp>(loc, load, memref, memIndices);
+ return copyNestRoot;
+}
+
+static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
+emitRemarkForBlock(Block &block) {
+ return block.getParentOp()->emitRemark();
+}
+
+/// Creates a buffer in the faster memory space for the specified memref region;
+/// generates a copy from the lower memory space to this one, and replaces all
+/// loads/stores in the block range [`begin', `end') of `block' to load/store
+/// from that buffer. Returns failure if copies could not be generated due to
+/// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
+/// in copyPlacementBlock specify the insertion points where the incoming copies
+/// and outgoing copies, respectively, should be inserted (the insertion happens
+/// right before the insertion point). Since `begin` can itself be invalidated
+/// due to the memref rewriting done from this method, the output argument
+/// `nBegin` is set to its replacement (set to `begin` if no invalidation
+/// happens). Since outgoing copies could have been inserted at `end`, the
+/// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
+/// size of the fast buffer allocated.
+static LogicalResult generateCopy(
+ const MemRefRegion &region, Block *block, Block::iterator begin,
+ Block::iterator end, Block *copyPlacementBlock,
+ Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
+ AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
+ DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
+ Block::iterator *nBegin, Block::iterator *nEnd) {
+ *nBegin = begin;
+ *nEnd = end;
+
+ FuncOp f = begin->getParentOfType<FuncOp>();
+ OpBuilder topBuilder(f.getBody());
+ Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
+
+ if (begin == end)
+ return success();
+
+  // Check whether the copy-out point is at the end of the block where we are
+  // doing explicit copying.
+ bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
+
+ // Copies for read regions are going to be inserted at 'begin'.
+ OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
+ // Copies for write regions are going to be inserted at 'end'.
+ OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
+ OpBuilder &b = region.isWrite() ? epilogue : prologue;
+
+ // Builder to create constants at the top level.
+ auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
+ OpBuilder top(func.getBody());
+
+ auto loc = region.loc;
+ auto memref = region.memref;
+ auto memRefType = memref->getType().cast<MemRefType>();
+
+ auto layoutMaps = memRefType.getAffineMaps();
+ if (layoutMaps.size() > 1 ||
+ (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ return failure();
+ }
+
+ // Indices to use for the copying.
+ // Indices for the original memref being copied from/to.
+ SmallVector<Value, 4> memIndices;
+ // Indices for the faster buffer being copied into/from.
+ SmallVector<Value, 4> bufIndices;
+
+ unsigned rank = memRefType.getRank();
+ SmallVector<int64_t, 4> fastBufferShape;
+
+ // Compute the extents of the buffer.
+ std::vector<SmallVector<int64_t, 4>> lbs;
+ SmallVector<int64_t, 8> lbDivisors;
+ lbs.reserve(rank);
+ Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
+ &fastBufferShape, &lbs, &lbDivisors);
+ if (!numElements.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
+ return failure();
+ }
+
+ if (numElements.getValue() == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
+ *sizeInBytes = 0;
+ return success();
+ }
+
+ const FlatAffineConstraints *cst = region.getConstraints();
+ // 'regionSymbols' hold values that this memory region is symbolic/parametric
+ // on; these typically include loop IVs surrounding the level at which the
+ // copy generation is being done or other valid symbols in MLIR.
+ SmallVector<Value, 8> regionSymbols;
+ cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
+
+ // Construct the index expressions for the fast memory buffer. The index
+ // expression for a particular dimension of the fast buffer is obtained by
+ // subtracting out the lower bound on the original memref's data region
+ // along the corresponding dimension.
+
+ // Index start offsets for faster memory buffer relative to the original.
+ SmallVector<AffineExpr, 4> offsets;
+ offsets.reserve(rank);
+ for (unsigned d = 0; d < rank; d++) {
+ assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
+
+ AffineExpr offset = top.getAffineConstantExpr(0);
+ for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
+ offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
+ }
+ assert(lbDivisors[d] > 0);
+ offset =
+ (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
+
+ // Set copy start location for this dimension in the lower memory space
+ // memref.
+ if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
+ auto indexVal = caf.getValue();
+ if (indexVal == 0) {
+ memIndices.push_back(zeroIndex);
+ } else {
+ memIndices.push_back(
+ top.create<ConstantIndexOp>(loc, indexVal).getResult());
+ }
+ } else {
+ // The coordinate for the start location is just the lower bound along the
+ // corresponding dimension on the memory region (stored in 'offset').
+ auto map = AffineMap::get(
+ cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
+ memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
+ }
+ // The fast buffer is copied into at location zero; addressing is relative.
+ bufIndices.push_back(zeroIndex);
+
+ // Record the offsets since they are needed to remap the memory accesses of
+ // the original memref further below.
+ offsets.push_back(offset);
+ }
+
+ // The faster memory space buffer.
+ Value fastMemRef;
+
+ // Check if a buffer was already created.
+ bool existingBuf = fastBufferMap.count(memref) > 0;
+ if (!existingBuf) {
+ AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
+ auto fastMemRefType =
+ MemRefType::get(fastBufferShape, memRefType.getElementType(),
+ fastBufferLayout, copyOptions.fastMemorySpace);
+
+ // Create the fast memory space buffer just before the 'affine.for'
+ // operation.
+ fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult();
+ // Record it.
+ fastBufferMap[memref] = fastMemRef;
+ // fastMemRefType is a constant shaped memref.
+ *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
+ LLVM_DEBUG(emitRemarkForBlock(*block)
+ << "Creating fast buffer of type " << fastMemRefType
+ << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
+ << " KiB\n");
+ } else {
+ // Reuse the one already created.
+ fastMemRef = fastBufferMap[memref];
+ *sizeInBytes = 0;
+ }
+
+ auto numElementsSSA =
+ top.create<ConstantIndexOp>(loc, numElements.getValue());
+
+ SmallVector<StrideInfo, 4> strideInfos;
+ getMultiLevelStrides(region, fastBufferShape, &strideInfos);
+
+ // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
+ // multi-level strides.
+ if (strideInfos.size() > 1) {
+ LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+ return failure();
+ }
+
+ Value stride = nullptr;
+ Value numEltPerStride = nullptr;
+ if (!strideInfos.empty()) {
+ stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
+ numEltPerStride =
+ top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
+ }
+
+ // Record the last operation where we want the memref replacement to end. We
+ // later do the memref replacement only in [begin, postDomFilter] so
+ // that the original memref's used in the data movement code themselves don't
+ // get replaced.
+ auto postDomFilter = std::prev(end);
+
+ // Create fully composed affine maps for each memref.
+ auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
+ fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
+ auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
+ fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
+
+ if (!copyOptions.generateDma) {
+ // Point-wise copy generation.
+ auto copyNest = generatePointWiseCopy(loc, memref, fastMemRef, memAffineMap,
+ memIndices, fastBufferShape,
+ /*isCopyOut=*/region.isWrite(), b);
+
+ // Record this so that we can skip it from yet another copy.
+ copyNests.insert(copyNest);
+
+ // Since new ops are being appended (for copy out's), adjust the end to
+ // mark end of block range being processed if necessary.
+ if (region.isWrite() && isCopyOutAtEndOfBlock)
+ *nEnd = Block::iterator(copyNest.getOperation());
+ } else {
+ // DMA generation.
+ // Create a tag (single element 1-d memref) for the DMA.
+ auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
+ copyOptions.tagMemorySpace);
+ auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
+
+ SmallVector<Value, 4> tagIndices({zeroIndex});
+ auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
+ fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
+ if (!region.isWrite()) {
+ // DMA non-blocking read from original buffer to fast buffer.
+ b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
+ fastMemRef, bufAffineMap, bufIndices,
+ tagMemRef, tagAffineMap, tagIndices,
+ numElementsSSA, stride, numEltPerStride);
+ } else {
+ // DMA non-blocking write from fast buffer to the original memref.
+ auto op = b.create<AffineDmaStartOp>(
+ loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
+ memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
+ stride, numEltPerStride);
+ // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
+ // end to mark end of block range being processed.
+ if (isCopyOutAtEndOfBlock)
+ *nEnd = Block::iterator(op.getOperation());
+ }
+
+ // Matching DMA wait to block on completion; tag always has a 0 index.
+ b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
+ numElementsSSA);
+
+ // Generate dealloc for the tag.
+ auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
+ if (*nEnd == end && isCopyOutAtEndOfBlock)
+ // Since new ops are being appended (for outgoing DMAs), adjust the end to
+ // mark end of range of the original.
+ *nEnd = Block::iterator(tagDeallocOp.getOperation());
+ }
+
+ // Generate dealloc for the buffer.
+ if (!existingBuf) {
+ auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
+ // When generating pointwise copies, `nEnd' has to be set to deallocOp on
+ // the fast buffer (since it marks the new end insertion point).
+ if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
+ *nEnd = Block::iterator(bufDeallocOp.getOperation());
+ }
+
+ // Replace all uses of the old memref with the faster one while remapping
+ // access indices (subtracting out lower bound offsets for each dimension).
+ // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
+ // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
+ // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
+ // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
+ // d2, d3 correspond to the original indices (%i, %j).
+ SmallVector<AffineExpr, 4> remapExprs;
+ remapExprs.reserve(rank);
+ for (unsigned i = 0; i < rank; i++) {
+ // The starting operands of indexRemap will be regionSymbols (the symbols on
+ // which the memref region is parametric); then those corresponding to
+ // the memref's original indices follow.
+ auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
+ remapExprs.push_back(dimExpr - offsets[i]);
+ }
+ auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs);
+
+ // Record the begin since it may be invalidated by memref replacement.
+ Block::iterator prevOfBegin;
+ bool isBeginAtStartOfBlock = (begin == block->begin());
+ if (!isBeginAtStartOfBlock)
+ prevOfBegin = std::prev(begin);
+
+ // *Only* those uses within the range [begin, end) of 'block' are replaced.
+ replaceAllMemRefUsesWith(memref, fastMemRef,
+ /*extraIndices=*/{}, indexRemap,
+ /*extraOperands=*/regionSymbols,
+ /*symbolOperands=*/{},
+ /*domInstFilter=*/&*begin,
+ /*postDomInstFilter=*/&*postDomFilter);
+
+ *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
+
+ return success();
+}
+
+/// Construct the memref region to just include the entire memref. Returns
+/// false for dynamically shaped memrefs for now. `numParamLoopIVs` is the
+/// number of enclosing loop IVs of opInst (starting from the outermost) that
+/// the region is parametric on.
+static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
+ MemRefRegion *region) {
+ unsigned rank;
+ if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+ rank = loadOp.getMemRefType().getRank();
+ region->memref = loadOp.getMemRef();
+ region->setWrite(false);
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+ rank = storeOp.getMemRefType().getRank();
+ region->memref = storeOp.getMemRef();
+ region->setWrite(true);
+ } else {
+ assert(false && "expected load or store op");
+ return false;
+ }
+ auto memRefType = region->memref->getType().cast<MemRefType>();
+ if (!memRefType.hasStaticShape())
+ return false;
+
+ auto *regionCst = region->getConstraints();
+
+ // Just get the first numSymbols IVs, which the memref region is parametric
+ // on.
+ SmallVector<AffineForOp, 4> ivs;
+ getLoopIVs(*opInst, &ivs);
+ ivs.resize(numParamLoopIVs);
+ SmallVector<Value, 4> symbols;
+ extractForInductionVars(ivs, &symbols);
+ regionCst->reset(rank, numParamLoopIVs, 0);
+ regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
+
+ // Memref dim sizes provide the bounds.
+ for (unsigned d = 0; d < rank; d++) {
+ auto dimSize = memRefType.getDimSize(d);
+ assert(dimSize > 0 && "filtered dynamic shapes above");
+ regionCst->addConstantLowerBound(d, 0);
+ regionCst->addConstantUpperBound(d, dimSize - 1);
+ }
+ return true;
+}
+
+/// Generates copies for a contiguous sequence of operations in `block` in the
+/// iterator range [`begin', `end'), where `end' can't be past the terminator of
+/// the block (since additional operations are potentially inserted right before
+/// `end'). Returns the total size of the fast buffers used.
+// Since we generate alloc's and dealloc's for all fast buffers (before and
+// after the range of operations resp.), all of the fast memory capacity is
+// assumed to be available for processing this block range.
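+// A minimal usage sketch (the option field names below are taken from the
+// uses in this file; the 32 KiB budget is an arbitrary illustrative value):
+//   AffineCopyOptions options;
+//   options.generateDma = false;             // point-wise copy loops, no DMAs
+//   options.slowMemorySpace = 0;
+//   options.fastMemorySpace = 1;
+//   options.tagMemorySpace = 0;
+//   options.fastMemCapacityBytes = 32 * 1024;
+//   DenseSet<Operation *> copyNests;
+//   affineDataCopyGenerate(block->begin(), std::prev(block->end()), options,
+//                          copyNests);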
+uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
+ Block::iterator end,
+ const AffineCopyOptions &copyOptions,
+ DenseSet<Operation *> &copyNests) {
+ if (begin == end)
+ return 0;
+
+ assert(begin->getBlock() == std::prev(end)->getBlock() &&
+ "Inconsistent block begin/end args");
+ assert(end != end->getBlock()->end() && "end can't be the block terminator");
+
+ Block *block = begin->getBlock();
+
+ // Copies will be generated for this depth, i.e., symbolic in all loops
+  // surrounding this block range.
+ unsigned copyDepth = getNestingDepth(*begin);
+
+ LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
+ << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
+
+ // List of memory regions to copy for. We need a map vector to have a
+ // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
+ // since the alloc's for example are identical except for the SSA id.
+ SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
+ SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
+
+ // Map from original memref's to the fast buffers that their accesses are
+ // replaced with.
+ DenseMap<Value, Value> fastBufferMap;
+
+ // To check for errors when walking the block.
+ bool error = false;
+
+ // Walk this range of operations to gather all memory regions.
+ block->walk(begin, end, [&](Operation *opInst) {
+ // Gather regions to allocate to buffers in faster memory space.
+ if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+      if (loadOp.getMemRefType().getMemorySpace() !=
+          copyOptions.slowMemorySpace)
+ return;
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+ if (storeOp.getMemRefType().getMemorySpace() !=
+ copyOptions.slowMemorySpace)
+ return;
+ } else {
+ // Neither load nor a store op.
+ return;
+ }
+
+ // Compute the MemRefRegion accessed.
+ auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
+ if (failed(region->compute(opInst, copyDepth))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Error obtaining memory region: semi-affine maps?\n");
+ LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
+ if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
+ LLVM_DEBUG(
+ opInst->emitError("non-constant memref sizes not yet supported"));
+ error = true;
+ return;
+ }
+ }
+
+ // Each memref has a single buffer associated with it irrespective of how
+ // many load's and store's happen on it.
+ // TODO(bondhugula): in the future, when regions don't intersect and satisfy
+ // other properties (based on load/store regions), we could consider
+ // multiple buffers per memref.
+
+ // Add to the appropriate region if it's not already in it, or take a
+ // bounding box union with the existing one if it's already in there.
+ // Note that a memref may have both read and write regions - so update the
+ // region in the other list if one exists (write in case of read and vice
+ // versa) since there is a single bounding box for a memref across all reads
+ // and writes that happen on it.
+
+ // Attempts to update; returns true if 'region' exists in targetRegions.
+ auto updateRegion =
+ [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
+ &targetRegions) {
+ auto it = targetRegions.find(region->memref);
+ if (it == targetRegions.end())
+ return false;
+
+ // Perform a union with the existing region.
+ if (failed(it->second->unionBoundingBox(*region))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Memory region bounding box failed; "
+ "over-approximating to the entire memref\n");
+ // If the union fails, we will overapproximate.
+ if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
+ LLVM_DEBUG(opInst->emitError(
+ "non-constant memref sizes not yet supported"));
+ error = true;
+ return true;
+ }
+ it->second->getConstraints()->clearAndCopyFrom(
+ *region->getConstraints());
+ } else {
+ // Union was computed and stored in 'it->second': copy to 'region'.
+ region->getConstraints()->clearAndCopyFrom(
+ *it->second->getConstraints());
+ }
+ return true;
+ };
+
+ bool existsInRead = updateRegion(readRegions);
+ if (error)
+ return;
+ bool existsInWrite = updateRegion(writeRegions);
+ if (error)
+ return;
+
+ // Finally add it to the region list.
+ if (region->isWrite() && !existsInWrite) {
+ writeRegions[region->memref] = std::move(region);
+ } else if (!region->isWrite() && !existsInRead) {
+ readRegions[region->memref] = std::move(region);
+ }
+ });
+
+ if (error) {
+ begin->emitError(
+ "copy generation failed for one or more memref's in this block\n");
+ return 0;
+ }
+
+ uint64_t totalCopyBuffersSizeInBytes = 0;
+ bool ret = true;
+ auto processRegions =
+ [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
+ &regions) {
+ for (const auto &regionEntry : regions) {
+ // For each region, hoist copy in/out past all hoistable
+ // 'affine.for's.
+ Block::iterator copyInPlacementStart, copyOutPlacementStart;
+ Block *copyPlacementBlock;
+ findHighestBlockForPlacement(
+ *regionEntry.second, *block, begin, end, &copyPlacementBlock,
+ &copyInPlacementStart, &copyOutPlacementStart);
+
+ uint64_t sizeInBytes;
+ Block::iterator nBegin, nEnd;
+ LogicalResult iRet = generateCopy(
+ *regionEntry.second, block, begin, end, copyPlacementBlock,
+ copyInPlacementStart, copyOutPlacementStart, copyOptions,
+ fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
+ if (succeeded(iRet)) {
+ // begin/end could have been invalidated, and need update.
+ begin = nBegin;
+ end = nEnd;
+ totalCopyBuffersSizeInBytes += sizeInBytes;
+ }
+ ret = ret & succeeded(iRet);
+ }
+ };
+ processRegions(readRegions);
+ processRegions(writeRegions);
+
+ if (!ret) {
+ begin->emitError(
+ "copy generation failed for one or more memref's in this block\n");
+ return totalCopyBuffersSizeInBytes;
+ }
+
+ // For a range of operations, a note will be emitted at the caller.
+ AffineForOp forOp;
+ uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024);
+ if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
+ forOp.emitRemark()
+ << sizeInKib
+ << " KiB of copy buffers in fast memory space for this block\n";
+ }
+
+ if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
+    StringRef str = "Total size of all copy buffers for this block "
+ "exceeds fast memory capacity\n";
+ block->getParentOp()->emitError(str);
+ }
+
+ return totalCopyBuffersSizeInBytes;
+}
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
new file mode 100644
index 00000000000..ca26074f288
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -0,0 +1,348 @@
+//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+
+void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
+ Region &region) {
+ for (auto &use : llvm::make_early_inc_range(orig->getUses())) {
+ if (region.isAncestor(use.getOwner()->getParentRegion()))
+ use.set(replacement);
+ }
+}
+
+void mlir::visitUsedValuesDefinedAbove(
+ Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
+ assert(limit.isAncestor(&region) &&
+ "expected isolation limit to be an ancestor of the given region");
+
+ // Collect proper ancestors of `limit` upfront to avoid traversing the region
+ // tree for every value.
+ SmallPtrSet<Region *, 4> properAncestors;
+ for (auto *reg = limit.getParentRegion(); reg != nullptr;
+ reg = reg->getParentRegion()) {
+ properAncestors.insert(reg);
+ }
+
+ region.walk([callback, &properAncestors](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands())
+ // Callback on values defined in a proper ancestor of region.
+ if (properAncestors.count(operand.get()->getParentRegion()))
+ callback(&operand);
+ });
+}
+
+void mlir::visitUsedValuesDefinedAbove(
+ MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
+ for (Region &region : regions)
+ visitUsedValuesDefinedAbove(region, region, callback);
+}
+
+void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
+ llvm::SetVector<Value> &values) {
+ visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
+ values.insert(operand->get());
+ });
+}
+
+void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
+ llvm::SetVector<Value> &values) {
+ for (Region &region : regions)
+ getUsedValuesDefinedAbove(region, region, values);
+}
+
+//===----------------------------------------------------------------------===//
+// Unreachable Block Elimination
+//===----------------------------------------------------------------------===//
+
+/// Erase the unreachable blocks within the provided regions. Returns success
+/// if any blocks were erased, failure otherwise.
+// TODO: We could likely merge this with the DCE algorithm below.
+static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
+ // Set of blocks found to be reachable within a given region.
+ llvm::df_iterator_default_set<Block *, 16> reachable;
+ // If any blocks were found to be dead.
+ bool erasedDeadBlocks = false;
+
+ SmallVector<Region *, 1> worklist;
+ worklist.reserve(regions.size());
+ for (Region &region : regions)
+ worklist.push_back(&region);
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+ if (region->empty())
+ continue;
+
+ // If this is a single block region, just collect the nested regions.
+ if (std::next(region->begin()) == region->end()) {
+ for (Operation &op : region->front())
+ for (Region &region : op.getRegions())
+ worklist.push_back(&region);
+ continue;
+ }
+
+ // Mark all reachable blocks.
+ reachable.clear();
+ for (Block *block : depth_first_ext(&region->front(), reachable))
+ (void)block /* Mark all reachable blocks */;
+
+ // Collect all of the dead blocks and push the live regions onto the
+ // worklist.
+ for (Block &block : llvm::make_early_inc_range(*region)) {
+ if (!reachable.count(&block)) {
+ block.dropAllDefinedValueUses();
+ block.erase();
+ erasedDeadBlocks = true;
+ continue;
+ }
+
+ // Walk any regions within this block.
+ for (Operation &op : block)
+ for (Region &region : op.getRegions())
+ worklist.push_back(&region);
+ }
+ }
+
+ return success(erasedDeadBlocks);
+}
+
+//===----------------------------------------------------------------------===//
+// Dead Code Elimination
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Data structure used to track which values have already been proved live.
+///
+/// Because Operation's can have multiple results, this data structure tracks
+/// liveness for both Value's and Operation's to avoid having to look through
+/// all Operation results when analyzing a use.
+///
+/// This data structure essentially tracks the dataflow lattice.
+/// The set of values/ops proved live increases monotonically to a fixed-point.
+class LiveMap {
+public:
+ /// Value methods.
+ bool wasProvenLive(Value value) { return liveValues.count(value); }
+ void setProvedLive(Value value) {
+ changed |= liveValues.insert(value).second;
+ }
+
+ /// Operation methods.
+ bool wasProvenLive(Operation *op) { return liveOps.count(op); }
+ void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
+
+ /// Methods for tracking if we have reached a fixed-point.
+ void resetChanged() { changed = false; }
+ bool hasChanged() { return changed; }
+
+private:
+ bool changed = false;
+ DenseSet<Value> liveValues;
+ DenseSet<Operation *> liveOps;
+};
+} // namespace
+
+static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
+ Operation *owner = use.getOwner();
+ unsigned operandIndex = use.getOperandNumber();
+ // This pass generally treats all uses of an op as live if the op itself is
+ // considered live. However, for successor operands to terminators we need a
+ // finer-grained notion where we deduce liveness for operands individually.
+ // The reason for this is easiest to think about in terms of a classical phi
+ // node based SSA IR, where each successor operand is really an operand to a
+ // *separate* phi node, rather than all operands to the branch itself as with
+ // the block argument representation that MLIR uses.
+ //
+ // And similarly, because each successor operand is really an operand to a phi
+ // node, rather than to the terminator op itself, a terminator op can't e.g.
+ // "print" the value of a successor operand.
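+  // For example (illustrative IR): in
+  //   br ^bb1(%x : i32)
+  // the use of %x feeds ^bb1's block argument; if that argument is never
+  // proven live, this use alone does not keep %x alive.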
+ if (owner->isKnownTerminator()) {
+ if (auto arg = owner->getSuccessorBlockArgument(operandIndex))
+ return !liveMap.wasProvenLive(*arg);
+ return false;
+ }
+ return false;
+}
+
+static void processValue(Value value, LiveMap &liveMap) {
+ bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) {
+ if (isUseSpeciallyKnownDead(use, liveMap))
+ return false;
+ return liveMap.wasProvenLive(use.getOwner());
+ });
+ if (provedLive)
+ liveMap.setProvedLive(value);
+}
+
+static bool isOpIntrinsicallyLive(Operation *op) {
+ // This pass doesn't modify the CFG, so terminators are never deleted.
+ if (!op->isKnownNonTerminator())
+ return true;
+ // If the op has a side effect, we treat it as live.
+ if (!op->hasNoSideEffect())
+ return true;
+ return false;
+}
+
+static void propagateLiveness(Region &region, LiveMap &liveMap);
+static void propagateLiveness(Operation *op, LiveMap &liveMap) {
+ // All Value's are either a block argument or an op result.
+ // We call processValue on those cases.
+
+ // Recurse on any regions the op has.
+ for (Region &region : op->getRegions())
+ propagateLiveness(region, liveMap);
+
+ // Process the op itself.
+ if (isOpIntrinsicallyLive(op)) {
+ liveMap.setProvedLive(op);
+ return;
+ }
+ for (Value value : op->getResults())
+ processValue(value, liveMap);
+ bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
+ return liveMap.wasProvenLive(value);
+ });
+ if (provedLive)
+ liveMap.setProvedLive(op);
+}
+
+static void propagateLiveness(Region &region, LiveMap &liveMap) {
+ if (region.empty())
+ return;
+
+ for (Block *block : llvm::post_order(&region.front())) {
+ // We process block arguments after the ops in the block, to promote
+ // faster convergence to a fixed point (we try to visit uses before defs).
+ for (Operation &op : llvm::reverse(block->getOperations()))
+ propagateLiveness(&op, liveMap);
+ for (Value value : block->getArguments())
+ processValue(value, liveMap);
+ }
+}
+
+static void eraseTerminatorSuccessorOperands(Operation *terminator,
+ LiveMap &liveMap) {
+ for (unsigned succI = 0, succE = terminator->getNumSuccessors();
+ succI < succE; succI++) {
+ // Iterating successors in reverse is not strictly needed, since we
+ // aren't erasing any successors. But it is slightly more efficient
+ // since it will promote later operands of the terminator being erased
+ // first, reducing the quadratic-ness.
+ unsigned succ = succE - succI - 1;
+ for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ);
+ argI < argE; argI++) {
+ // Iterating args in reverse is needed for correctness, to avoid
+ // shifting later args when earlier args are erased.
+ unsigned arg = argE - argI - 1;
+ Value value = terminator->getSuccessor(succ)->getArgument(arg);
+ if (!liveMap.wasProvenLive(value)) {
+ terminator->eraseSuccessorOperand(succ, arg);
+ }
+ }
+ }
+}
+
+static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
+ LiveMap &liveMap) {
+ bool erasedAnything = false;
+ for (Region &region : regions) {
+ if (region.empty())
+ continue;
+
+ // We do the deletion in an order that deletes all uses before deleting
+ // defs.
+ // MLIR's SSA structural invariants guarantee that except for block
+ // arguments, the use-def graph is acyclic, so this is possible with a
+ // single walk of ops and then a final pass to clean up block arguments.
+ //
+ // To do this, we visit ops in an order that visits domtree children
+ // before domtree parents. A CFG post-order (with reverse iteration with a
+ // block) satisfies that without needing an explicit domtree calculation.
+ for (Block *block : llvm::post_order(&region.front())) {
+ eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
+ for (Operation &childOp :
+ llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
+ erasedAnything |=
+ succeeded(deleteDeadness(childOp.getRegions(), liveMap));
+ if (!liveMap.wasProvenLive(&childOp)) {
+ erasedAnything = true;
+ childOp.erase();
+ }
+ }
+ }
+ // Delete block arguments.
+    // The entry block has an unknown contract with its enclosing operation, so
+ // skip it.
+ for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
+ // Iterate in reverse to avoid shifting later arguments when deleting
+ // earlier arguments.
+ for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
+ if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
+ block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
+ erasedAnything = true;
+ }
+ }
+ }
+ return success(erasedAnything);
+}
+
+// This function performs a simple dead code elimination algorithm over the
+// given regions.
+//
+// The overall goal is to prove that Values are dead, which allows deleting ops
+// and block arguments.
+//
+// This uses an optimistic algorithm that assumes everything is dead until
+// proved otherwise, allowing it to delete recursively dead cycles.
+//
+// This is a simple fixed-point dataflow analysis algorithm on a lattice
+// {Dead,Alive}. Because liveness flows backward, we generally try to
+// iterate everything backward to speed up convergence to the fixed-point. This
+// allows for being able to delete recursively dead cycles of the use-def graph,
+// including block arguments.
+//
+// This function returns success if any operations or arguments were deleted,
+// failure otherwise.
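+// For example (illustrative): two blocks whose arguments are only passed back
+// and forth to each other through their branch operands form a recursively
+// dead cycle; since everything starts out assumed dead, neither argument is
+// ever proven live and both can be deleted.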
+static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
+
+ LiveMap liveMap;
+ do {
+ liveMap.resetChanged();
+
+ for (Region &region : regions)
+ propagateLiveness(region, liveMap);
+ } while (liveMap.hasChanged());
+
+ return deleteDeadness(regions, liveMap);
+}
+
+//===----------------------------------------------------------------------===//
+// Region Simplification
+//===----------------------------------------------------------------------===//
+
+/// Run a set of structural simplifications over the given regions. This
+/// includes transformations like unreachable block elimination, dead argument
+/// elimination, as well as some other DCE. This function returns success if any
+/// of the regions were simplified, failure otherwise.
+LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
+ LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
+ LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
+ return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
+}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644
index 00000000000..a6629183dee
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -0,0 +1,469 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(Operation &op) {
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
+ return true;
+ return false;
+}
+
+/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
+static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
+ return TypeSwitch<Operation *, NamedAttribute>(op)
+ .Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp,
+ AffineDmaWaitOp>(
+ [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
+}
+
+// Perform the replacement in `op`.
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
+ Operation *op,
+ ArrayRef<Value> extraIndices,
+ AffineMap indexRemap,
+ ArrayRef<Value> extraOperands,
+ ArrayRef<Value> symbolOperands) {
+ unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+ (void)newMemRefRank; // unused in opt mode
+ unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+ (void)oldMemRefRank; // unused in opt mode
+ if (indexRemap) {
+ assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+ "symbolic operand count mismatch");
+ assert(indexRemap.getNumInputs() ==
+ extraOperands.size() + oldMemRefRank + symbolOperands.size());
+ assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+ newMemRef->getType().cast<MemRefType>().getElementType());
+
+ if (!isMemRefDereferencingOp(*op))
+ // Failure: memref used in a non-dereferencing context (potentially
+ // escapes); no replacement in these cases.
+ return failure();
+
+ SmallVector<unsigned, 2> usePositions;
+ for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
+ if (opEntry.value() == oldMemRef)
+ usePositions.push_back(opEntry.index());
+ }
+
+ // If memref doesn't appear, nothing to do.
+ if (usePositions.empty())
+ return success();
+
+ if (usePositions.size() > 1) {
+ // TODO(mlir-team): extend it for this case when needed (rare).
+ assert(false && "multiple dereferencing uses in a single op not supported");
+ return failure();
+ }
+
+ unsigned memRefOperandPos = usePositions.front();
+
+ OpBuilder builder(op);
+ NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+ AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
+ unsigned oldMapNumInputs = oldMap.getNumInputs();
+ SmallVector<Value, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+
+ // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
+ SmallVector<Value, 4> oldMemRefOperands;
+ SmallVector<Value, 4> affineApplyOps;
+ oldMemRefOperands.reserve(oldMemRefRank);
+ if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+ for (auto resultExpr : oldMap.getResults()) {
+ auto singleResMap = AffineMap::get(oldMap.getNumDims(),
+ oldMap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ oldMapOperands);
+ oldMemRefOperands.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+ }
+
+ // Construct new indices as a remap of the old ones if a remapping has been
+ // provided. The indices of a memref come right after it, i.e.,
+ // at position memRefOperandPos + 1.
+ SmallVector<Value, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+ SmallVector<Value, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Remapped indices.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = AffineMap::get(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // No remapping specified.
+ remapOutputs.append(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // Prepend 'extraIndices' in 'newMapOperands'.
+ for (auto extraIndex : extraIndices) {
+ assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
+ "single result op's expected to generate these indices");
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Append 'remapOutputs' to 'newMapOperands'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // Create new fully composed AffineMap for new op to be created.
+ assert(newMapOperands.size() == newMemRefRank);
+ auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
+ // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
+ fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
+ newMap = simplifyAffineMap(newMap);
+ canonicalizeMapAndOperands(&newMap, &newMapOperands);
+ // Remove any affine.apply's that became dead as a result of composition.
+ for (auto value : affineApplyOps)
+ if (value->use_empty())
+ value->getDefiningOp()->erase();
+
+ // Construct the new operation using this memref.
+ OperationState state(op->getLoc(), op->getName());
+ state.setOperandListToResizable(op->hasResizableOperandsList());
+ state.operands.reserve(op->getNumOperands() + extraIndices.size());
+ // Insert the non-memref operands.
+ state.operands.append(op->operand_begin(),
+ op->operand_begin() + memRefOperandPos);
+ // Insert the new memref value.
+ state.operands.push_back(newMemRef);
+
+ // Insert the new memref map operands.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ // Insert the remaining operands unmodified.
+ state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
+ oldMapNumInputs,
+ op->operand_end());
+
+ // Result types don't change. Both memref's are of the same elemental type.
+ state.types.reserve(op->getNumResults());
+ for (auto result : op->getResults())
+ state.types.push_back(result->getType());
+
+ // Add attribute for 'newMap', other Attributes do not change.
+ auto newMapAttr = AffineMapAttr::get(newMap);
+ for (auto namedAttr : op->getAttrs()) {
+ if (namedAttr.first == oldMapAttrPair.first) {
+ state.attributes.push_back({namedAttr.first, newMapAttr});
+ } else {
+ state.attributes.push_back(namedAttr);
+ }
+ }
+
+ // Create the new operation.
+ auto *repOp = builder.createOperation(state);
+ op->replaceAllUsesWith(repOp);
+ op->erase();
+
+ return success();
+}
+
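+// Replaces all dereferencing uses of `oldMemRef` by `newMemRef`, restricted
+// (when the corresponding filter is non-null) to uses dominated by
+// `domInstFilter` and/or post-dominated by `postDomInstFilter`. Dealloc's are
+// skipped; any other non-dereferencing use makes the whole replacement fail.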
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
+ ArrayRef<Value> extraIndices,
+ AffineMap indexRemap,
+ ArrayRef<Value> extraOperands,
+ ArrayRef<Value> symbolOperands,
+ Operation *domInstFilter,
+ Operation *postDomInstFilter) {
+ unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+ (void)newMemRefRank; // unused in opt mode
+ unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+ (void)oldMemRefRank;
+ if (indexRemap) {
+ assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+ "symbol operand count mismatch");
+ assert(indexRemap.getNumInputs() ==
+ extraOperands.size() + oldMemRefRank + symbolOperands.size());
+ assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+ newMemRef->getType().cast<MemRefType>().getElementType());
+
+ std::unique_ptr<DominanceInfo> domInfo;
+ std::unique_ptr<PostDominanceInfo> postDomInfo;
+ if (domInstFilter)
+ domInfo = std::make_unique<DominanceInfo>(
+ domInstFilter->getParentOfType<FuncOp>());
+
+ if (postDomInstFilter)
+ postDomInfo = std::make_unique<PostDominanceInfo>(
+ postDomInstFilter->getParentOfType<FuncOp>());
+
+ // Walk all uses of old memref; collect ops to perform replacement. We use a
+ // DenseSet since an operation could potentially have multiple uses of a
+ // memref (although rare), and the replacement later is going to erase ops.
+ DenseSet<Operation *> opsToReplace;
+ for (auto *op : oldMemRef->getUsers()) {
+ // Skip this use if it's not dominated by domInstFilter.
+ if (domInstFilter && !domInfo->dominates(domInstFilter, op))
+ continue;
+
+ // Skip this use if it's not post-dominated by postDomInstFilter.
+ if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
+ continue;
+
+ // Skip dealloc's - no replacement is necessary, and a memref replacement
+ // at other uses doesn't hurt these dealloc's.
+ if (isa<DeallocOp>(op))
+ continue;
+
+ // Check if the memref was used in a non-dereferencing context. It is fine
+ // for the memref to be used in a non-dereferencing way outside of the
+ // region where this replacement is happening.
+ if (!isMemRefDereferencingOp(*op))
+ // Failure: memref used in a non-dereferencing op (potentially escapes);
+ // no replacement in these cases.
+ return failure();
+
+ // We'll first collect and then replace --- since replacement erases the op
+ // that has the use, and that op could be postDomFilter or domFilter itself!
+ opsToReplace.insert(op);
+ }
+
+ for (auto *op : opsToReplace) {
+ if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
+ indexRemap, extraOperands,
+ symbolOperands)))
+ llvm_unreachable("memref replacement guaranteed to succeed here");
+ }
+
+ return success();
+}
+
+/// Given an operation, inserts one or more single result affine apply
+/// operations, results of which are exclusively used by this operation.
+/// The operands of these newly created affine apply ops are guaranteed to be
+/// loop iterators or terminal symbols of a function.
+///
+/// Before
+///
+/// affine.for %i = 0 to #map(%N)
+/// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// "compute"(%idx)
+///
+/// After
+///
+/// affine.for %i = 0 to #map(%N)
+/// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
+/// "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (for eg.
+/// different shifts/delays).
+///
+/// Leaves `sliceOps` empty either if none of opInst's operands were the result
+/// of an affine.apply (and thus there was no affine computation slice to
+/// create), or if all the affine.apply ops supplying operands to this opInst
+/// had no uses other than this opInst; otherwise returns the affine.apply
+/// operations created in the output argument `sliceOps`.
+void mlir::createAffineComputationSlice(
+ Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
+ // Collect all operands that are results of affine apply ops.
+ SmallVector<Value, 4> subOperands;
+ subOperands.reserve(opInst->getNumOperands());
+ for (auto operand : opInst->getOperands())
+ if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
+ subOperands.push_back(operand);
+
+ // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+ SmallVector<Operation *, 4> affineApplyOps;
+ getReachableAffineApplyOps(subOperands, affineApplyOps);
+ // Skip transforming if there are no affine maps to compose.
+ if (affineApplyOps.empty())
+ return;
+
+  // Check if all uses of the affine apply ops lie only in this op, in which
+  // case there would be nothing to do.
+ bool localized = true;
+ for (auto *op : affineApplyOps) {
+ for (auto result : op->getResults()) {
+ for (auto *user : result->getUsers()) {
+ if (user != opInst) {
+ localized = false;
+ break;
+ }
+ }
+ }
+ }
+ if (localized)
+ return;
+
+ OpBuilder builder(opInst);
+ SmallVector<Value, 4> composedOpOperands(subOperands);
+ auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
+ fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
+
+ // Create an affine.apply for each of the map results.
+ sliceOps->reserve(composedMap.getNumResults());
+ for (auto resultExpr : composedMap.getResults()) {
+ auto singleResMap = AffineMap::get(composedMap.getNumDims(),
+ composedMap.getNumSymbols(), resultExpr);
+ sliceOps->push_back(builder.create<AffineApplyOp>(
+ opInst->getLoc(), singleResMap, composedOpOperands));
+ }
+
+ // Construct the new operands that include the results from the composed
+ // affine apply op above instead of existing ones (subOperands). So, they
+ // differ from opInst's operands only for those operands in 'subOperands', for
+ // which they will be replaced by the corresponding one from 'sliceOps'.
+ SmallVector<Value, 4> newOperands(opInst->getOperands());
+ for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+ // Replace the subOperands from among the new operands.
+ unsigned j, f;
+ for (j = 0, f = subOperands.size(); j < f; j++) {
+ if (newOperands[i] == subOperands[j])
+ break;
+ }
+ if (j < subOperands.size()) {
+ newOperands[i] = (*sliceOps)[j];
+ }
+ }
+ for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
+ opInst->setOperand(idx, newOperands[idx]);
+ }
+}
+
+// TODO: Currently works for static memrefs with a single layout map.
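+//
+// Turns the memref allocated by `allocOp`, which has a non-identity layout
+// map, into a memref with an identity layout and a correspondingly transformed
+// shape, rewriting all dereferencing uses. Illustrative sketch (names are
+// hypothetical): an allocation
+//   %A = alloc() : memref<64xf32, (d0) -> (d0 floordiv 4, d0 mod 4)>
+// is replaced by
+//   %A_new = alloc() : memref<16x4xf32>
+// and a use such as affine.load %A[%i] is remapped through the old layout map
+// into an access to %A_new[%i floordiv 4, %i mod 4].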
+LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
+ MemRefType memrefType = allocOp.getType();
+ unsigned rank = memrefType.getRank();
+ if (rank == 0)
+ return success();
+
+ auto layoutMaps = memrefType.getAffineMaps();
+ OpBuilder b(allocOp);
+ if (layoutMaps.size() != 1)
+ return failure();
+
+ AffineMap layoutMap = layoutMaps.front();
+
+ // Nothing to do for identity layout maps.
+ if (layoutMap == b.getMultiDimIdentityMap(rank))
+ return success();
+
+  // We don't check that the layout map is one-to-one; we assume that it is.
+
+ // TODO: Only for static memref's for now.
+ if (memrefType.getNumDynamicDims() > 0)
+ return failure();
+
+ // We have a single map that is not an identity map. Create a new memref with
+ // the right shape and an identity layout map.
+ auto shape = memrefType.getShape();
+ FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
+ for (unsigned d = 0; d < rank; ++d) {
+ fac.addConstantLowerBound(d, 0);
+ fac.addConstantUpperBound(d, shape[d] - 1);
+ }
+
+ // We compose this map with the original index (logical) space to derive the
+ // upper bounds for the new index space.
+ unsigned newRank = layoutMap.getNumResults();
+ if (failed(fac.composeMatchingMap(layoutMap)))
+ // TODO: semi-affine maps.
+ return failure();
+
+ // Project out the old data dimensions.
+ fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
+ SmallVector<int64_t, 4> newShape(newRank);
+ for (unsigned d = 0; d < newRank; ++d) {
+ // The lower bound for the shape is always zero.
+ auto ubConst = fac.getConstantUpperBound(d);
+ // For a static memref and an affine map with no symbols, this is always
+ // bounded.
+ assert(ubConst.hasValue() && "should always have an upper bound");
+ if (ubConst.getValue() < 0)
+ // This is due to an invalid map that maps to a negative space.
+ return failure();
+ newShape[d] = ubConst.getValue() + 1;
+ }
+
+ auto oldMemRef = allocOp.getResult();
+ SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+
+ auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
+ b.getMultiDimIdentityMap(newRank));
+ auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
+
+ // Replace all uses of the old memref.
+ if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
+ /*extraIndices=*/{},
+ /*indexRemap=*/layoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/symbolOperands))) {
+ // If it failed (due to escapes for example), bail out.
+ newAlloc.erase();
+ return failure();
+ }
+  // Replace any uses of the original alloc op and erase it. All remaining
+  // uses have to be dealloc's; replaceAllMemRefUsesWith above would have
+  // failed otherwise.
+ assert(std::all_of(oldMemRef->user_begin(), oldMemRef->user_end(),
+ [](Operation *op) { return isa<DeallocOp>(op); }));
+ oldMemRef->replaceAllUsesWith(newAlloc);
+ allocOp.erase();
+ return success();
+}
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
new file mode 100644
index 00000000000..6b2b3e1ee7e
--- /dev/null
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -0,0 +1,1292 @@
+//===- Vectorize.cpp - Vectorize Pass Impl --------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements vectorization of loops, operations and data types to
+// a target-independent, n-D super-vector abstraction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+///
+/// Implements a high-level vectorization strategy on a Function.
+/// The abstraction used is that of super-vectors, which provide a single,
+/// compact representation, in the vector types, of information that is
+/// expected to reduce the impact of the phase ordering problem.
+///
+/// Vector granularity:
+/// ===================
+/// This pass is designed to perform vectorization at a super-vector
+/// granularity. A super-vector is loosely defined as a vector type that is a
+/// multiple of a "good" vector size so the HW can efficiently implement a set
+/// of high-level primitives. Multiple is understood along any dimension; e.g.
+/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a
+/// vector<8xf32> HW vector. Note that a "good vector size so the HW can
+/// efficiently implement a set of high-level primitives" is not necessarily an
+/// integer multiple of actual hardware registers. We leave details of this
+/// distinction unspecified for now.
+///
+/// Some may prefer the terminology a "tile of HW vectors". In this case, one
+/// should note that super-vectors implement an "always full tile" abstraction.
+/// They guarantee no partial-tile separation is necessary by relying on a
+/// high-level copy-reshape abstraction that we call vector.transfer. This
+/// copy-reshape operations is also responsible for performing layout
+/// transposition if necessary. In the general case this will require a scoped
+/// allocation in some notional local memory.
+///
+/// Whatever the mental model one prefers to use for this abstraction, the key
+/// point is that we burn into a single, compact, representation in the vector
+/// types, information that is expected to reduce the impact of the phase
+/// ordering problem. Indeed, a vector type conveys information that:
+/// 1. the associated loops have dependency semantics that do not prevent
+/// vectorization;
+///   2. the associated loops have been sliced in chunks of static sizes that
+///      are compatible with vector sizes (i.e. similar to unroll-and-jam);
+///   3. the inner loops, in the unroll-and-jam analogy of 2, are captured by
+///      the vector type and no vectorization hampering transformations can be
+///      applied to them anymore;
+/// 4. the underlying memrefs are accessed in some notional contiguous way
+/// that allows loading into vectors with some amount of spatial locality;
+/// In other words, super-vectorization provides a level of separation of
+/// concern by way of opacity to subsequent passes. This has the effect of
+/// encapsulating and propagating vectorization constraints down the list of
+/// passes until we are ready to lower further.
+///
+/// For a particular target, a notion of minimal n-d vector size will be
+/// specified and vectorization targets a multiple of those. In the following
+/// paragraph, let "k ." represent "a multiple of", to be understood as a
+/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
+/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
+///
+/// Some non-exhaustive notable super-vector sizes of interest include:
+/// - CPU: vector<k . HW_vector_size>,
+/// vector<k' . core_count x k . HW_vector_size>,
+/// vector<socket_count x k' . core_count x k . HW_vector_size>;
+/// - GPU: vector<k . warp_size>,
+/// vector<k . warp_size x float2>,
+/// vector<k . warp_size x float4>,
+///          vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
+///
+/// Loops and operations are emitted that operate on those super-vector shapes.
+/// Subsequent lowering passes will materialize to actual HW vector sizes. These
+/// passes are expected to be (gradually) more target-specific.
+///
+/// At a high level, a vectorized load in a loop will resemble:
+/// ```mlir
+/// affine.for %i = ? to ? step ? {
+/// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
+/// }
+/// ```
+/// It is the responsibility of the implementation of vector.transfer_read to
+/// materialize vector registers from the original scalar memrefs. A later (more
+/// target-dependent) lowering pass will materialize to actual HW vector sizes.
+/// This lowering may occur at different times:
+/// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
+/// DmaWaitOp + vectorized operations for data transformations and shuffle;
+/// thus opening opportunities for unrolling and pipelining. This is an
+/// instance of library call "whiteboxing"; or
+///   2. later in a target-specific lowering pass or hand-written library
+/// call; achieving full separation of concerns. This is an instance of
+/// library call; or
+/// 3. a mix of both, e.g. based on a model.
+/// In the future, these operations will expose a contract to constrain the
+/// search on vectorization patterns and sizes.
+///
+/// Occurrence of super-vectorization in the compiler flow:
+/// =======================================================
+/// This is an active area of investigation. We start with 2 remarks to position
+/// super-vectorization in the context of existing ongoing work: LLVM VPLAN
+/// and LLVM SLP Vectorizer.
+///
+/// LLVM VPLAN:
+/// -----------
+/// The astute reader may have noticed that in the limit, super-vectorization
+/// can be applied at a similar time and with objectives similar to those of
+/// VPLAN.
+/// For instance, this would come after a traditional polyhedral compilation
+/// flow (e.g. the PPCG project uses ISL to provide dependence analysis,
+/// multi-level scheduling + tiling, lifting of footprints to fast memory,
+/// communication synthesis, mapping and register optimizations) and before
+/// unrolling. When vectorization is applied at this *late* level in a typical
+/// polyhedral flow, and is instantiated with actual hardware vector sizes,
+/// super-vectorization is expected to match (or subsume) the type of patterns
+/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
+/// is higher level and our implementation should be significantly simpler. Also
+/// note that in this mode, recursive patterns are probably a bit of an overkill
+/// although it is reasonable to expect that mixing a bit of outer loop and
+/// inner loop vectorization + unrolling will provide interesting choices to
+/// MLIR.
+///
+/// LLVM SLP Vectorizer:
+/// --------------------
+/// Super-vectorization however is not meant to be usable in a similar fashion
+/// to the SLP vectorizer. The main difference lies in the information that
+/// both vectorizers use: super-vectorization examines contiguity of memory
+/// references along fastest varying dimensions and loops with recursive nested
+/// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on
+/// the other hand, performs flat pattern matching inside a single unrolled loop
+/// body and stitches together pieces of load and store operations into full
+/// 1-D vectors. We envision that the SLP vectorizer is a good way to capture
+/// innermost loop, control-flow dependent patterns that super-vectorization may
+/// not be able to capture easily. In other words, super-vectorization does not
+/// aim at replacing the SLP vectorizer and the two solutions are complementary.
+///
+/// Ongoing investigations:
+/// -----------------------
+/// We discuss the following *early* places where super-vectorization is
+/// applicable and touch on the expected benefits and risks. We list the
+/// opportunities in the context of the traditional polyhedral compiler flow
+/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
+/// we expect to experiment with super-vectorization:
+/// 1. Right after language lowering to MLIR: this is the earliest time where
+/// super-vectorization is expected to be applied. At this level, all the
+/// language/user/library-level annotations are available and can be fully
+/// exploited. Examples include loop-type annotations (such as parallel,
+/// reduction, scan, dependence distance vector, vectorizable) as well as
+///    memory access annotations (such as non-aliasing writes guaranteed,
+///    indirect accesses that are permutations by construction), or that a
+///    particular operation is prescribed atomic by the user. At this
+/// level, anything that enriches what dependence analysis can do should be
+/// aggressively exploited. At this level we are close to having explicit
+/// vector types in the language, except we do not impose that burden on the
+/// programmer/library: we derive information from scalar code + annotations.
+/// 2. After dependence analysis and before polyhedral scheduling: the
+/// information that supports vectorization does not need to be supplied by a
+/// higher level of abstraction. Traditional dependence analysis is available
+/// in MLIR and will be used to drive vectorization and cost models.
+///
+/// Let's pause here and remark that applying super-vectorization as described
+/// in 1. and 2. presents clear opportunities and risks:
+/// - the opportunity is that vectorization is burned in the type system and
+/// is protected from the adverse effect of loop scheduling, tiling, loop
+/// interchange and all passes downstream. Provided that subsequent passes are
+/// able to operate on vector types; the vector shapes, associated loop
+/// iterator properties, alignment, and contiguity of fastest varying
+/// dimensions are preserved until we lower the super-vector types. We expect
+///   this to significantly rein in the adverse effects of phase ordering.
+/// - the risks are that a. all passes after super-vectorization have to work
+/// on elemental vector types (not that this is always true, wherever
+/// vectorization is applied) and b. that imposing vectorization constraints
+/// too early may be overall detrimental to loop fusion, tiling and other
+/// transformations because the dependence distances are coarsened when
+/// operating on elemental vector types. For this reason, the pattern
+/// profitability analysis should include a component that also captures the
+/// maximal amount of fusion available under a particular pattern. This is
+/// still at the stage of rough ideas but in this context, search is our
+/// friend as the Tensor Comprehensions and auto-TVM contributions
+/// demonstrated previously.
+/// Bottom-line is we do not yet have good answers for the above but aim at
+/// making it easy to answer such questions.
+///
+/// Back to our listing, the last places where early super-vectorization makes
+/// sense are:
+/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known
+/// to improve locality, parallelism and be configurable (e.g. max-fuse,
+/// smart-fuse etc). They can also have adverse effects on contiguity
+/// properties that are required for vectorization but the vector.transfer
+/// copy-reshape-pad-transpose abstraction is expected to help recapture
+/// these properties.
+/// 4. right after polyhedral-style scheduling+tiling;
+/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
+/// probably the most promising places because applying tiling achieves a
+/// separation of concerns that allows rescheduling to worry less about
+/// locality and more about parallelism and distribution (e.g. min-fuse).
+///
+/// At these levels the risk-reward looks different: on one hand we probably
+/// lost a good deal of language/user/library-level annotation; on the other
+/// hand we gained parallelism and locality through scheduling and tiling.
+/// However we probably want to ensure tiling is compatible with the
+/// full-tile-only abstraction used in super-vectorization or suffer the
+/// consequences. It is too early to place bets on what will win but we expect
+/// super-vectorization to be the right abstraction to allow exploring at all
+/// these levels. And again, search is our friend.
+///
+/// Lastly, we mention it again here:
+///   6. as an MLIR-based alternative to VPLAN.
+///
+/// Lowering, unrolling, pipelining:
+/// ================================
+/// TODO(ntv): point to the proper places.
+///
+/// Algorithm:
+/// ==========
+/// The algorithm proceeds in a few steps:
+/// 1. defining super-vectorization patterns and matching them on the tree of
+/// AffineForOp. A super-vectorization pattern is defined as a recursive
+/// data structures that matches and captures nested, imperfectly-nested
+/// loops that have a. conformable loop annotations attached (e.g. parallel,
+/// reduction, vectorizable, ...) as well as b. all contiguous load/store
+/// operations along a specified minor dimension (not necessarily the
+///      fastest varying);
+/// 2. analyzing those patterns for profitability (TODO(ntv): and
+/// interference);
+/// 3. Then, for each pattern in order:
+/// a. applying iterative rewriting of the loop and the load operations in
+/// DFS postorder. Rewriting is implemented by coarsening the loops and
+/// turning load operations into opaque vector.transfer_read ops;
+/// b. keeping track of the load operations encountered as "roots" and the
+/// store operations as "terminals";
+/// c. traversing the use-def chains starting from the roots and iteratively
+/// propagating vectorized values. Scalar values that are encountered
+/// during this process must come from outside the scope of the current
+/// pattern (TODO(ntv): enforce this and generalize). Such a scalar value
+/// is vectorized only if it is a constant (into a vector splat). The
+/// non-constant case is not supported for now and results in the pattern
+/// failing to vectorize;
+/// d. performing a second traversal on the terminals (store ops) to
+///        rewrite the scalar value they write to memory into vector form.
+/// If the scalar value has been vectorized previously, we simply replace
+/// it by its vector form. Otherwise, if the scalar value is a constant,
+/// it is vectorized into a splat. In all other cases, vectorization for
+/// the pattern currently fails.
+/// e. if everything under the root AffineForOp in the current pattern
+/// vectorizes properly, we commit that loop to the IR. Otherwise we
+/// discard it and restore a previously cloned version of the loop. Thanks
+/// to the recursive scoping nature of matchers and captured patterns,
+/// this is transparently achieved by a simple RAII implementation.
+/// f. vectorization is applied on the next pattern in the list. Because
+/// pattern interference avoidance is not yet implemented and that we do
+/// not support further vectorizing an already vector load we need to
+/// re-verify that the pattern is still vectorizable. This is expected to
+/// make cost models more difficult to write and is subject to improvement
+/// in the future.
+///
+/// Points c. and d. above are worth additional comment. In most passes that
+/// do not change the type of operands, it is usually preferred to eagerly
+/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
+/// because during the use-def chain traversal, all the operands of an operation
+/// must be available in vector form. Trying to propagate eagerly makes the IR
+/// temporarily invalid and results in errors such as:
+/// `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
+/// operands and results
+/// %s5 = addf %a5, %b5 : f32`
+///
+/// Lastly, we show a minimal example for which use-def chains rooted in load /
+/// vector.transfer_read are not enough. This is what motivated splitting
+/// terminal processing out of the use-def chains starting from loads. In the
+/// following snippet, there is simply no load:
+/// ```mlir
+/// func @fill(%A : memref<128xf32>) -> () {
+/// %f1 = constant 1.0 : f32
+/// affine.for %i0 = 0 to 32 {
+/// affine.store %f1, %A[%i0] : memref<128xf32, 0>
+/// }
+/// return
+/// }
+/// ```
+///
+/// Choice of loop transformation to support the algorithm:
+/// =======================================================
+/// The choice of loop transformation to apply for coarsening vectorized loops
+/// is still subject to exploratory tradeoffs. In particular, say we want to
+/// vectorize by a factor 128, we want to transform the following input:
+/// ```mlir
+/// affine.for %i = %M to %N {
+/// %a = affine.load %A[%i] : memref<?xf32>
+/// }
+/// ```
+///
+/// Traditionally, one would vectorize late (after scheduling, tiling,
+/// memory promotion etc) say after stripmining (and potentially unrolling in
+/// the case of LLVM's SLP vectorizer):
+/// ```mlir
+/// affine.for %i = floor(%M, 128) to ceil(%N, 128) {
+/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
+/// %a = affine.load %A[%ii] : memref<?xf32>
+/// }
+/// }
+/// ```
+///
+/// Instead, we seek to vectorize early and freeze vector types before
+/// scheduling, so we want to generate a pattern that resembles:
+/// ```mlir
+/// affine.for %i = ? to ? step ? {
+/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
+/// }
+/// ```
+///
+/// Two ways of encoding this coarsening were considered:
+/// i. simply dividing the lower / upper bounds by 128 creates issues
+/// when representing expressions such as ii + 1 because now we only
+/// have access to original values that have been divided. Additional
+/// information is needed to specify accesses at below-128 granularity;
+/// ii. another alternative is to coarsen the loop step but this may have
+/// consequences on dependence analysis and fusability of loops: fusable
+/// loops probably need to have the same step (because we don't want to
+/// stripmine/unroll to enable fusion).
+/// As a consequence, we choose to represent the coarsening using the loop
+/// step for now and reevaluate in the future. Note that we can renormalize
+/// loop steps later if/when we have evidence that they are problematic.
+///
+/// For the simple strawman example above, vectorizing for a 1-D vector
+/// abstraction of size 128 returns code similar to:
+/// ```mlir
+/// affine.for %i = %M to %N step 128 {
+/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
+/// }
+/// ```
+///
+/// Unsupported cases, extensions, and work in progress (help welcome :-) ):
+/// ========================================================================
+/// 1. lowering to concrete vector types for various HW;
+/// 2. reduction support;
+/// 3. non-effecting padding during vector.transfer_read and filter during
+/// vector.transfer_write;
+/// 4. misalignment support vector.transfer_read / vector.transfer_write
+/// (hopefully without read-modify-writes);
+/// 5. control-flow support;
+/// 6. cost-models, heuristics and search;
+/// 7. Op implementation, extensions and implication on memref views;
+/// 8. many TODOs left around.
+///
+/// Examples:
+/// =========
+/// Consider the following Function:
+/// ```mlir
+/// func @vector_add_2d(%M : index, %N : index) -> f32 {
+/// %A = alloc (%M, %N) : memref<?x?xf32, 0>
+/// %B = alloc (%M, %N) : memref<?x?xf32, 0>
+/// %C = alloc (%M, %N) : memref<?x?xf32, 0>
+/// %f1 = constant 1.0 : f32
+/// %f2 = constant 2.0 : f32
+/// affine.for %i0 = 0 to %M {
+/// affine.for %i1 = 0 to %N {
+/// // non-scoped %f1
+/// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
+/// }
+/// }
+/// affine.for %i2 = 0 to %M {
+/// affine.for %i3 = 0 to %N {
+/// // non-scoped %f2
+/// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
+/// }
+/// }
+/// affine.for %i4 = 0 to %M {
+/// affine.for %i5 = 0 to %N {
+/// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0>
+/// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0>
+/// %s5 = addf %a5, %b5 : f32
+/// // non-scoped %f1
+/// %s6 = addf %s5, %f1 : f32
+/// // non-scoped %f2
+/// %s7 = addf %s5, %f2 : f32
+/// // diamond dependency.
+/// %s8 = addf %s7, %s6 : f32
+/// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
+/// }
+/// }
+/// %c7 = constant 7 : index
+/// %c42 = constant 42 : index
+/// %res = load %C[%c7, %c42] : memref<?x?xf32, 0>
+/// return %res : f32
+/// }
+/// ```
+///
+/// The -affine-vectorize pass with the following arguments:
+/// ```
+/// -affine-vectorize -virtual-vector-size 256 --test-fastest-varying=0
+/// ```
+///
+/// produces this standard innermost-loop vectorized code:
+/// ```mlir
+/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
+/// %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %cst = constant 1.0 : f32
+/// %cst_0 = constant 2.0 : f32
+/// affine.for %i0 = 0 to %arg0 {
+/// affine.for %i1 = 0 to %arg1 step 256 {
+/// %cst_1 = constant dense<vector<256xf32>, 1.0> :
+/// vector<256xf32>
+/// vector.transfer_write %cst_1, %0[%i0, %i1] :
+/// vector<256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// affine.for %i2 = 0 to %arg0 {
+/// affine.for %i3 = 0 to %arg1 step 256 {
+/// %cst_2 = constant dense<vector<256xf32>, 2.0> :
+/// vector<256xf32>
+/// vector.transfer_write %cst_2, %1[%i2, %i3] :
+/// vector<256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// affine.for %i4 = 0 to %arg0 {
+/// affine.for %i5 = 0 to %arg1 step 256 {
+/// %3 = vector.transfer_read %0[%i4, %i5] :
+/// memref<?x?xf32>, vector<256xf32>
+/// %4 = vector.transfer_read %1[%i4, %i5] :
+/// memref<?x?xf32>, vector<256xf32>
+/// %5 = addf %3, %4 : vector<256xf32>
+/// %cst_3 = constant dense<vector<256xf32>, 1.0> :
+/// vector<256xf32>
+/// %6 = addf %5, %cst_3 : vector<256xf32>
+/// %cst_4 = constant dense<vector<256xf32>, 2.0> :
+/// vector<256xf32>
+/// %7 = addf %5, %cst_4 : vector<256xf32>
+/// %8 = addf %7, %6 : vector<256xf32>
+/// vector.transfer_write %8, %2[%i4, %i5] :
+/// vector<256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// %c7 = constant 7 : index
+/// %c42 = constant 42 : index
+/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
+/// return %9 : f32
+/// }
+/// ```
+///
+/// The -affine-vectorize pass with the following arguments:
+/// ```
+/// -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256
+/// --test-fastest-varying=1 --test-fastest-varying=0
+/// ```
+///
+/// produces this more interesting mixed outer-innermost-loop vectorized code:
+/// ```mlir
+/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
+/// %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+/// %cst = constant 1.0 : f32
+/// %cst_0 = constant 2.0 : f32
+/// affine.for %i0 = 0 to %arg0 step 32 {
+/// affine.for %i1 = 0 to %arg1 step 256 {
+/// %cst_1 = constant dense<vector<32x256xf32>, 1.0> :
+/// vector<32x256xf32>
+/// vector.transfer_write %cst_1, %0[%i0, %i1] :
+/// vector<32x256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// affine.for %i2 = 0 to %arg0 step 32 {
+/// affine.for %i3 = 0 to %arg1 step 256 {
+/// %cst_2 = constant dense<vector<32x256xf32>, 2.0> :
+/// vector<32x256xf32>
+/// vector.transfer_write %cst_2, %1[%i2, %i3] :
+/// vector<32x256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// affine.for %i4 = 0 to %arg0 step 32 {
+/// affine.for %i5 = 0 to %arg1 step 256 {
+/// %3 = vector.transfer_read %0[%i4, %i5] :
+///        memref<?x?xf32>, vector<32x256xf32>
+/// %4 = vector.transfer_read %1[%i4, %i5] :
+/// memref<?x?xf32>, vector<32x256xf32>
+/// %5 = addf %3, %4 : vector<32x256xf32>
+/// %cst_3 = constant dense<vector<32x256xf32>, 1.0> :
+/// vector<32x256xf32>
+/// %6 = addf %5, %cst_3 : vector<32x256xf32>
+/// %cst_4 = constant dense<vector<32x256xf32>, 2.0> :
+/// vector<32x256xf32>
+/// %7 = addf %5, %cst_4 : vector<32x256xf32>
+/// %8 = addf %7, %6 : vector<32x256xf32>
+/// vector.transfer_write %8, %2[%i4, %i5] :
+/// vector<32x256xf32>, memref<?x?xf32>
+/// }
+/// }
+/// %c7 = constant 7 : index
+/// %c42 = constant 42 : index
+/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
+/// return %9 : f32
+/// }
+/// ```
+///
+/// Of course, much more intricate n-D imperfectly-nested patterns can be
+/// vectorized too and specified in a fully declarative fashion.
+
+#define DEBUG_TYPE "early-vect"
+
+using functional::makePtrDynCaster;
+using functional::map;
+using llvm::dbgs;
+using llvm::SetVector;
+
+static llvm::cl::OptionCategory clOptionsCategory("vectorize options");
+
+static llvm::cl::list<int> clVirtualVectorSize(
+ "virtual-vector-size",
+ llvm::cl::desc("Specify an n-D virtual vector size for vectorization"),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+static llvm::cl::list<int> clFastestVaryingPattern(
+ "test-fastest-varying",
+ llvm::cl::desc(
+ "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory"
+ " dimensions to match. See defaultPatterns in Vectorize.cpp for a"
+ " description and examples. This is used for testing purposes"),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
+
+/// Forward declaration.
+static FilterFunctionType
+isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
+ int fastestVaryingMemRefDimension);
+
+/// Creates a vectorization pattern from the command line arguments.
+/// Up to 3-D patterns are supported.
+/// If the command line argument requests a pattern of higher order, returns an
+/// empty pattern list which will conservatively result in no vectorization.
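+///
+/// For example (an illustrative sketch, not a test): with vectorRank = 2 and
+/// fastestVaryingPattern = {1, 0}, the single returned pattern matches 2-deep
+/// nests of vectorizable loops from `parallelLoops` in which the outer loop's
+/// accesses are invariant or vary along memref dimension 1 and the inner
+/// loop's along memref dimension 0.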
+static std::vector<NestedPattern>
+makePatterns(const DenseSet<Operation *> &parallelLoops, int vectorRank,
+ ArrayRef<int64_t> fastestVaryingPattern) {
+ using matcher::For;
+ int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
+ int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
+ int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
+ switch (vectorRank) {
+ case 1:
+ return {For(isVectorizableLoopPtrFactory(parallelLoops, d0))};
+ case 2:
+ return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+ For(isVectorizableLoopPtrFactory(parallelLoops, d1)))};
+ case 3:
+ return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+ For(isVectorizableLoopPtrFactory(parallelLoops, d1),
+ For(isVectorizableLoopPtrFactory(parallelLoops, d2))))};
+ default: {
+ return std::vector<NestedPattern>();
+ }
+ }
+}
+
+static NestedPattern &vectorTransferPattern() {
+ static auto pattern = matcher::Op([](Operation &op) {
+ return isa<vector::TransferReadOp>(op) || isa<vector::TransferWriteOp>(op);
+ });
+ return pattern;
+}
+
+namespace {
+
+/// Base state for the vectorize pass.
+/// Command line arguments are preempted by non-empty pass arguments.
+struct Vectorize : public FunctionPass<Vectorize> {
+ Vectorize();
+ Vectorize(ArrayRef<int64_t> virtualVectorSize);
+ void runOnFunction() override;
+
+ // The virtual vector size that we vectorize to.
+ SmallVector<int64_t, 4> vectorSizes;
+ // Optionally, the fixed mapping from loop to fastest varying MemRef dimension
+ // for all the MemRefs within a loop pattern:
+ // the index represents the loop depth, the value represents the k^th
+ // fastest varying memory dimension.
+ // This is voluntarily restrictive and is meant to precisely target a
+ // particular loop/op pair, for testing purposes.
+ SmallVector<int64_t, 4> fastestVaryingPattern;
+};
+
+} // end anonymous namespace
+
+Vectorize::Vectorize()
+ : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()),
+ fastestVaryingPattern(clFastestVaryingPattern.begin(),
+ clFastestVaryingPattern.end()) {}
+
+Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
+ if (!virtualVectorSize.empty()) {
+ this->vectorSizes.assign(virtualVectorSize.begin(),
+ virtualVectorSize.end());
+ }
+}
+
+/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate.
+/////////
+namespace {
+
+struct VectorizationStrategy {
+ SmallVector<int64_t, 8> vectorSizes;
+ DenseMap<Operation *, unsigned> loopToVectorDim;
+};
+
+} // end anonymous namespace
+
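+/// Assigns to `loop` the vector dimension it will be vectorized along, unless
+/// the loop is nested too far out for the requested vector rank.
+/// Worked example (a sketch): with vectorSizes = {32, 256} and
+/// patternDepth = 2, the outermost matched loop (depthInPattern = 0) maps to
+/// vector dimension 0 and the innermost one (depthInPattern = 1) to vector
+/// dimension 1; with a single vector size and patternDepth = 2, only the
+/// innermost loop is assigned (to dimension 0) and the outer one is skipped.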
+static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
+ unsigned patternDepth,
+ VectorizationStrategy *strategy) {
+ assert(patternDepth > depthInPattern &&
+ "patternDepth is greater than depthInPattern");
+ if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
+ // Don't vectorize this loop
+ return;
+ }
+ strategy->loopToVectorDim[loop] =
+ strategy->vectorSizes.size() - (patternDepth - depthInPattern);
+}
+
+/// Implements a simple strawman strategy for vectorization.
+/// Given a matched pattern `matches` of depth `patternDepth`, this strategy
+/// greedily assigns the fastest varying dimension ** of the vector ** to the
+/// innermost loop in the pattern.
+/// When coupled with a pattern that looks for the fastest varying dimension in
+/// load/store MemRefs, this creates a generic vectorization strategy that works
+/// for any loop in a hierarchy (outermost, innermost or intermediate).
+///
+/// TODO(ntv): In the future we should additionally increase the power of the
+/// profitability analysis along 3 directions:
+/// 1. account for loop extents (both static and parametric + annotations);
+/// 2. account for data layout permutations;
+/// 3. account for impact of vectorization on maximal loop fusion.
+/// Then we can quantify the above to build a cost model and search over
+/// strategies.
+static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
+ unsigned depthInPattern,
+ unsigned patternDepth,
+ VectorizationStrategy *strategy) {
+ for (auto m : matches) {
+ if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
+ patternDepth, strategy))) {
+ return failure();
+ }
+ vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern,
+ patternDepth, strategy);
+ }
+ return success();
+}
+
+///// end TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate /////
+
+namespace {
+
+struct VectorizationState {
+ /// Adds an entry of pre/post vectorization operations in the state.
+ void registerReplacement(Operation *key, Operation *value);
+ /// When the current vectorization pattern is successful, this erases the
+ /// operations that were marked for erasure in the proper order and resets
+ /// the internal state for the next pattern.
+ void finishVectorizationPattern();
+
+ // In-order tracking of original Operation that have been vectorized.
+ // Erase in reverse order.
+ SmallVector<Operation *, 16> toErase;
+ // Set of Operation that have been vectorized (the values in the
+ // vectorizationMap for hashed access). The vectorizedSet is used in
+ // particular to filter the operations that have already been vectorized by
+ // this pattern, when iterating over nested loops in this pattern.
+ DenseSet<Operation *> vectorizedSet;
+ // Map of old scalar Operation to new vectorized Operation.
+ DenseMap<Operation *, Operation *> vectorizationMap;
+ // Map of old scalar Value to new vectorized Value.
+ DenseMap<Value, Value> replacementMap;
+ // The strategy drives which loop to vectorize by which amount.
+ const VectorizationStrategy *strategy;
+ // Use-def roots. These represent the starting points for the worklist in the
+ // vectorizeNonTerminals function. They consist of the subset of load
+ // operations that have been vectorized. They can be retrieved from
+ // `vectorizationMap` but it is convenient to keep track of them in a separate
+ // data structure.
+ DenseSet<Operation *> roots;
+ // Terminal operations for the worklist in the vectorizeNonTerminals
+ // function. They consist of the subset of store operations that have been
+ // vectorized. They can be retrieved from `vectorizationMap` but it is
+ // convenient to keep track of them in a separate data structure. Since they
+ // do not necessarily belong to use-def chains starting from loads (e.g
+ // storing a constant), we need to handle them in a post-pass.
+ DenseSet<Operation *> terminals;
+ // Checks that the type of `op` is AffineStoreOp and adds it to the terminals
+ // set.
+ void registerTerminal(Operation *op);
+ // Folder used to factor out constant creation.
+ OperationFolder *folder;
+
+private:
+ void registerReplacement(Value key, Value value);
+};
+
+} // end namespace
+
+void VectorizationState::registerReplacement(Operation *key, Operation *value) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
+ LLVM_DEBUG(key->print(dbgs()));
+ LLVM_DEBUG(dbgs() << " into ");
+ LLVM_DEBUG(value->print(dbgs()));
+  assert(key->getNumResults() == 1 && "expected single result op");
+  assert(value->getNumResults() == 1 && "expected single result op");
+ assert(vectorizedSet.count(value) == 0 && "already registered");
+ assert(vectorizationMap.count(key) == 0 && "already registered");
+ toErase.push_back(key);
+ vectorizedSet.insert(value);
+ vectorizationMap.insert(std::make_pair(key, value));
+ registerReplacement(key->getResult(0), value->getResult(0));
+ if (isa<AffineLoadOp>(key)) {
+ assert(roots.count(key) == 0 && "root was already inserted previously");
+ roots.insert(key);
+ }
+}
+
+void VectorizationState::registerTerminal(Operation *op) {
+  assert(isa<AffineStoreOp>(op) && "terminal must be an AffineStoreOp");
+ assert(terminals.count(op) == 0 &&
+ "terminal was already inserted previously");
+ terminals.insert(op);
+}
+
+void VectorizationState::finishVectorizationPattern() {
+ while (!toErase.empty()) {
+ auto *op = toErase.pop_back_val();
+ LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
+ LLVM_DEBUG(op->print(dbgs()));
+ op->erase();
+ }
+}
+
+void VectorizationState::registerReplacement(Value key, Value value) {
+ assert(replacementMap.count(key) == 0 && "replacement already registered");
+ replacementMap.insert(std::make_pair(key, value));
+}
+
+// Apply 'map' to 'mapOperands', materializing each result expression as a
+// single-result AffineApplyOp, and append the created values to 'results'.
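+//
+// Illustrative sketch (operand names are hypothetical): for
+// map = (d0, d1) -> (d0 + d1, d0 mod 2) and mapOperands = (%i, %j), this
+// creates two ops and appends their results to 'results':
+//   %r0 = affine.apply (d0, d1) -> (d0 + d1)(%i, %j)
+//   %r1 = affine.apply (d0, d1) -> (d0 mod 2)(%i, %j)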
+static void computeMemoryOpIndices(Operation *op, AffineMap map,
+ ValueRange mapOperands,
+ SmallVectorImpl<Value> &results) {
+ OpBuilder builder(op);
+ for (auto resultExpr : map.getResults()) {
+ auto singleResMap =
+ AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
+ auto afOp =
+ builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands);
+ results.push_back(afOp);
+ }
+}
+
+////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
+
+/// Handles the vectorization of load and store MLIR operations.
+///
+/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call.
+/// They are vectorized immediately. The resulting vector.transfer_read is
+/// immediately registered to replace all uses of the AffineLoadOp in this
+/// pattern's scope.
+///
+/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They
+/// need to be vectorized late once all the use-def chains have been traversed.
+/// Additionally, they may have SSA value operands which come from outside the
+/// scope of the current pattern.
+/// Such special cases force us to delay the vectorization of the stores until
+/// the last step. Here we merely register the store operation.
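+///
+/// Illustrative sketch (in the spirit of the examples at the top of this
+/// file, names hypothetical): for a strategy with vectorSizes = {128}, an
+/// AffineLoadOp
+///   %a = affine.load %A[%i] : memref<?xf32>
+/// is replaced by something resembling
+///   %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
+/// while an AffineStoreOp is only registered as a terminal at this point.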
+template <typename LoadOrStoreOpPointer>
+static LogicalResult vectorizeRootOrTerminal(Value iv,
+ LoadOrStoreOpPointer memoryOp,
+ VectorizationState *state) {
+ auto memRefType = memoryOp.getMemRef()->getType().template cast<MemRefType>();
+
+ auto elementType = memRefType.getElementType();
+ // TODO(ntv): ponder whether we want to further vectorize a vector value.
+ assert(VectorType::isValidElementType(elementType) &&
+ "Not a valid vector element type");
+ auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
+
+ // Materialize a MemRef with 1 vector.
+ auto *opInst = memoryOp.getOperation();
+ // For now, vector.transfers must be aligned, operate only on indices with an
+ // identity subset of AffineMap and do not change layout.
+ // TODO(ntv): increase the expressiveness power of vector.transfer operations
+ // as needed by various targets.
+ if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
+ OpBuilder b(opInst);
+ ValueRange mapOperands = load.getMapOperands();
+ SmallVector<Value, 8> indices;
+ indices.reserve(load.getMemRefType().getRank());
+ if (load.getAffineMap() !=
+ b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
+ computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
+ } else {
+ indices.append(mapOperands.begin(), mapOperands.end());
+ }
+ auto permutationMap =
+ makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+ if (!permutationMap)
+      return failure();
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+ LLVM_DEBUG(permutationMap.print(dbgs()));
+ auto transfer = b.create<vector::TransferReadOp>(
+ opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices,
+ AffineMapAttr::get(permutationMap),
+ // TODO(b/144455320) add a proper padding value, not just 0.0 : f32
+ state->folder->create<ConstantFloatOp>(b, opInst->getLoc(),
+ APFloat(0.0f), b.getF32Type()));
+ state->registerReplacement(opInst, transfer.getOperation());
+ } else {
+ state->registerTerminal(opInst);
+ }
+ return success();
+}
+/// end TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ///
+
+/// Sets the loop step to the coarsened `step` and transforms all remaining
+/// load and store operations into the appropriate vector.transfer.
+static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
+ VectorizationState *state) {
+ using namespace functional;
+ loop.setStep(step);
+
+ FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {
+ if (!matcher::isLoadOrStore(op)) {
+ return false;
+ }
+ return state->vectorizationMap.count(&op) == 0 &&
+ state->vectorizedSet.count(&op) == 0 &&
+ state->roots.count(&op) == 0 && state->terminals.count(&op) == 0;
+ };
+ auto loadAndStores = matcher::Op(notVectorizedThisPattern);
+ SmallVector<NestedMatch, 8> loadAndStoresMatches;
+ loadAndStores.match(loop.getOperation(), &loadAndStoresMatches);
+ for (auto ls : loadAndStoresMatches) {
+ auto *opInst = ls.getMatchedOperation();
+ auto load = dyn_cast<AffineLoadOp>(opInst);
+ auto store = dyn_cast<AffineStoreOp>(opInst);
+ LLVM_DEBUG(opInst->print(dbgs()));
+ LogicalResult result =
+ load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state)
+ : vectorizeRootOrTerminal(loop.getInductionVar(), store, state);
+ if (failed(result)) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+/// Returns a FilterFunctionType that can be used in NestedPattern to match a
+/// loop whose underlying load/store accesses are either invariant or all
+/// varying along the `fastestVaryingMemRefDimension`.
+static FilterFunctionType
+isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
+ int fastestVaryingMemRefDimension) {
+ return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
+ auto loop = cast<AffineForOp>(forOp);
+ auto parallelIt = parallelLoops.find(loop);
+ if (parallelIt == parallelLoops.end())
+ return false;
+ int memRefDim = -1;
+ auto vectorizableBody =
+ isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern());
+ if (!vectorizableBody)
+ return false;
+ return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
+ memRefDim == fastestVaryingMemRefDimension;
+ };
+}
+
+/// Applies vectorization to the loop matched by `oneMatch` according to
+/// `state`. This only happens once all the vectorizations of its
+/// `childrenMatches` have already succeeded recursively, in DFS post-order.
+static LogicalResult
+vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
+ VectorizationState *state) {
+ auto *loopInst = oneMatch.getMatchedOperation();
+ auto loop = cast<AffineForOp>(loopInst);
+ auto childrenMatches = oneMatch.getMatchedChildren();
+
+ // 1. DFS postorder recursion, if any of my children fails, I fail too.
+ for (auto m : childrenMatches) {
+ if (failed(vectorizeLoopsAndLoadsRecursively(m, state))) {
+ return failure();
+ }
+ }
+
+ // 2. This loop may have been omitted from vectorization for various reasons
+ // (e.g. due to the performance model or pattern depth > vector size).
+ auto it = state->strategy->loopToVectorDim.find(loopInst);
+ if (it == state->strategy->loopToVectorDim.end()) {
+ return success();
+ }
+
+ // 3. Actual post-order transformation.
+ auto vectorDim = it->second;
+ assert(vectorDim < state->strategy->vectorSizes.size() &&
+ "vector dim overflow");
+ // a. get actual vector size
+ auto vectorSize = state->strategy->vectorSizes[vectorDim];
+ // b. loop transformation for early vectorization is still subject to
+ // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
+ // | ub -> ub
+ // | step -> step * vectorSize
+ LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
+ << " : ");
+ LLVM_DEBUG(loopInst->print(dbgs()));
+ return vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state);
+}
+
+/// Tries to transform a scalar constant into a vector splat of that constant.
+/// Returns the vectorized splat operation if the constant is a valid vector
+/// element type.
+/// If `type` is not a valid vector type or if the scalar constant is not a
+/// valid vector element type, returns nullptr.
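+///
+/// For instance (a sketch, names hypothetical), with type = vector<256xf32>,
+/// the scalar %cst = constant 1.0 : f32 becomes a constant carrying the splat
+/// attribute dense<1.0> : vector<256xf32> as its value.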
+static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
+ if (!type || !type.isa<VectorType>() ||
+ !VectorType::isValidElementType(constant.getType())) {
+ return nullptr;
+ }
+ OpBuilder b(op);
+ Location loc = op->getLoc();
+ auto vectorType = type.cast<VectorType>();
+ auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
+ auto *constantOpInst = constant.getOperation();
+
+ OperationState state(loc, constantOpInst->getName().getStringRef(), {},
+ {vectorType}, {b.getNamedAttr("value", attr)});
+
+ return b.createOperation(state)->getResult(0);
+}
+
+/// Tries to vectorize a given `operand` of Operation `op` during def-chain
+/// propagation or during terminal vectorization, by applying the following
+/// logic:
+/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized
+/// by use-def propagation), `operand` is already in the proper vector form;
+/// 2. otherwise, `operand` may be in some other vector form that fails to
+/// vectorize at the moment (e.g. a broadcast would be required); returns
+/// nullptr to indicate failure;
+/// 3. if `operand` is a constant, returns the vectorized form of the constant;
+/// 4. non-constant scalars are currently non-vectorizable, in particular to
+/// guard against vectorizing an index which may be loop-variant and needs
+/// special handling.
+///
+/// In particular this logic captures some of the use cases where definitions
+/// that are not scoped under the current pattern are needed to vectorize.
+/// One such example is top level function constants that need to be splatted.
+///
+/// Returns an operand that has been vectorized to match `state`'s strategy if
+/// vectorization is possible with the above logic. Returns nullptr otherwise.
+///
+/// TODO(ntv): handle more complex cases.
+static Value vectorizeOperand(Value operand, Operation *op,
+ VectorizationState *state) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
+ LLVM_DEBUG(operand->print(dbgs()));
+ // 1. If this value has already been vectorized this round, we are done.
+ if (state->vectorizedSet.count(operand->getDefiningOp()) > 0) {
+ LLVM_DEBUG(dbgs() << " -> already vector operand");
+ return operand;
+ }
+ // 1.b. Delayed on-demand replacement of a use.
+ // Note that we cannot just call replaceAllUsesWith because it may result
+ // in ops with mixed types, for ops whose operands have not all yet
+ // been vectorized. This would be invalid IR.
+ auto it = state->replacementMap.find(operand);
+ if (it != state->replacementMap.end()) {
+ auto res = it->second;
+ LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
+ LLVM_DEBUG(res->print(dbgs()));
+ return res;
+ }
+ // 2. TODO(ntv): broadcast needed.
+ if (operand->getType().isa<VectorType>()) {
+ LLVM_DEBUG(dbgs() << "-> non-vectorizable");
+ return nullptr;
+ }
+ // 3. vectorize constant.
+ if (auto constant = dyn_cast<ConstantOp>(operand->getDefiningOp())) {
+ return vectorizeConstant(
+ op, constant,
+ VectorType::get(state->strategy->vectorSizes, operand->getType()));
+ }
+ // 4. currently non-vectorizable.
+ LLVM_DEBUG(dbgs() << "-> non-vectorizable");
+ LLVM_DEBUG(operand->print(dbgs()));
+ return nullptr;
+}
+
+/// Encodes Operation-specific behavior for vectorization. In general we assume
+/// that all operands of an op must be vectorized but this is not always true.
+/// In the future, it would be nice to have a trait that describes how a
+/// particular operation vectorizes. For now we implement the case distinction
+/// here.
+/// Returns a vectorized form of an operation or nullptr if vectorization fails.
+// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
+// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
+// do one-off logic here; ideally it would be TableGen'd.
+static Operation *vectorizeOneOperation(Operation *opInst,
+ VectorizationState *state) {
+ // Sanity checks.
+ assert(!isa<AffineLoadOp>(opInst) &&
+ "all loads must have already been fully vectorized independently");
+ assert(!isa<vector::TransferReadOp>(opInst) &&
+ "vector.transfer_read cannot be further vectorized");
+ assert(!isa<vector::TransferWriteOp>(opInst) &&
+ "vector.transfer_write cannot be further vectorized");
+
+ if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
+ OpBuilder b(opInst);
+ auto memRef = store.getMemRef();
+ auto value = store.getValueToStore();
+ auto vectorValue = vectorizeOperand(value, opInst, state);
+
+ ValueRange mapOperands = store.getMapOperands();
+ SmallVector<Value, 8> indices;
+ indices.reserve(store.getMemRefType().getRank());
+ if (store.getAffineMap() !=
+ b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
+ computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
+ indices);
+ } else {
+ indices.append(mapOperands.begin(), mapOperands.end());
+ }
+
+ auto permutationMap =
+ makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
+ if (!permutationMap)
+ return nullptr;
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+ LLVM_DEBUG(permutationMap.print(dbgs()));
+ auto transfer = b.create<vector::TransferWriteOp>(
+ opInst->getLoc(), vectorValue, memRef, indices,
+ AffineMapAttr::get(permutationMap));
+ auto *res = transfer.getOperation();
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
+ // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
+ opInst->erase();
+ return res;
+ }
+ if (opInst->getNumRegions() != 0)
+ return nullptr;
+
+ SmallVector<Type, 8> vectorTypes;
+ for (auto v : opInst->getResults()) {
+ vectorTypes.push_back(
+ VectorType::get(state->strategy->vectorSizes, v->getType()));
+ }
+ SmallVector<Value, 8> vectorOperands;
+ for (auto v : opInst->getOperands()) {
+ vectorOperands.push_back(vectorizeOperand(v, opInst, state));
+ }
+ // Check whether any operand is null. If so, vectorization failed.
+ bool success = llvm::all_of(vectorOperands, [](Value op) { return op; });
+ if (!success) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
+ return nullptr;
+ }
+
+ // Create a clone of the op with the proper operands and return types.
+ // TODO(ntv): The following assumes there is always an op with a fixed
+ // name that works both in scalar mode and vector mode.
+ // TODO(ntv): Is it worth considering an Operation.clone operation which
+ // changes the type so we can promote an Operation with less boilerplate?
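+ // For example (assuming an illustrative 1-D strategy of width 8), a scalar
+ // `%c = addf %a, %b : f32` whose operands were vectorized above is recreated
+ // as `%vc = addf %va, %vb : vector<8xf32>`, keeping the original op name and
+ // attributes.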
+ OpBuilder b(opInst);
+ OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
+ vectorOperands, vectorTypes, opInst->getAttrs(),
+ /*successors=*/{},
+ /*regions=*/{}, opInst->hasResizableOperandsList());
+ return b.createOperation(newOp);
+}
+
+/// Iterates over the forward slice from the loads in the vectorization pattern
+/// and rewrites them using their vectorized counterpart by:
+/// 1. Creating the forward slice starting from the loads in the vectorization
+/// pattern.
+/// 2. Topologically sorting the forward slice.
+/// 3. For each operation in the slice, creating the vector form of the
+/// operation and replacing each of its operands by the replacement retrieved
+/// from replacementMap. If any such replacement is missing, vectorization
+/// fails.
+static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
+ // 1. create initial worklist with the uses of the roots.
+ SetVector<Operation *> worklist;
+ // Note: state->roots have already been vectorized and must not be vectorized
+ // again. This fits `getForwardSlice` which does not insert `op` in the
+ // result.
+ // Note: we have to exclude terminals because some of their defs may not be
+ // nested under the vectorization pattern (e.g. constants defined in an
+ // encompassing scope).
+ // TODO(ntv): Use a backward slice for terminals, avoid special casing and
+ // merge implementations.
+ for (auto *op : state->roots) {
+ getForwardSlice(op, &worklist, [state](Operation *op) {
+ return state->terminals.count(op) == 0; // propagate if not terminal
+ });
+ }
+ // We merged multiple slices; topological order may not hold anymore.
+ worklist = topologicalSort(worklist);
+
+ for (unsigned i = 0; i < worklist.size(); ++i) {
+ auto *op = worklist[i];
+ LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
+ LLVM_DEBUG(op->print(dbgs()));
+
+ // Create vector form of the operation.
+ // Insert it just before op, on success register op as replaced.
+ auto *vectorizedInst = vectorizeOneOperation(op, state);
+ if (!vectorizedInst) {
+ return failure();
+ }
+
+ // 3. Register replacement for future uses in the scope.
+ // Note that we cannot just call replaceAllUsesWith because it may
+ // result in ops with mixed types, for ops whose operands have not all
+ // yet been vectorized. This would be invalid IR.
+ state->registerReplacement(op, vectorizedInst);
+ }
+ return success();
+}
+
+/// Vectorization is a recursive procedure where anything below can fail.
+/// The root match thus needs to maintain a clone for handling failure.
+/// Each root may succeed independently, but will otherwise clean up after
+/// itself if anything below it fails.
+static LogicalResult vectorizeRootMatch(NestedMatch m,
+ VectorizationStrategy *strategy) {
+ auto loop = cast<AffineForOp>(m.getMatchedOperation());
+ OperationFolder folder(loop.getContext());
+ VectorizationState state;
+ state.strategy = strategy;
+ state.folder = &folder;
+
+ // Since patterns are recursive, they can very well intersect.
+ // Because we do not want a fully greedy strategy in general, we decouple
+ // pattern matching from profitability analysis and from application.
+ // As a consequence we must check that each root pattern is still
+ // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
+ // TODO(ntv): implement a non-greedy profitability analysis that keeps only
+ // non-intersecting patterns.
+ if (!isVectorizableLoopBody(loop, vectorTransferPattern())) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
+ return failure();
+ }
+
+ // Set up error handling for this root loop: keep a clone of the original
+ // loop so that, on failure, all uses of the induction variable are redirected
+ // to the clone and the partially vectorized loop is erased.
+ auto *loopInst = loop.getOperation();
+ OpBuilder builder(loopInst);
+ auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
+ struct Guard {
+ LogicalResult failure() {
+ loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar());
+ loop.erase();
+ return mlir::failure();
+ }
+ LogicalResult success() {
+ clonedLoop.erase();
+ return mlir::success();
+ }
+ AffineForOp loop;
+ AffineForOp clonedLoop;
+ } guard{loop, clonedLoop};
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Start vectorizing.
+ // From now on, any error triggers the scope guard above.
+ //////////////////////////////////////////////////////////////////////////////
+ // 1. Vectorize all the loops matched by the pattern, recursively.
+ // This also vectorizes the roots (AffineLoadOp) as well as registers the
+ // terminals (AffineStoreOp) for post-processing vectorization (we need to
+ // wait for all use-def chains into them to be vectorized first).
+ if (failed(vectorizeLoopsAndLoadsRecursively(m, &state))) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop");
+ return guard.failure();
+ }
+
+ // 2. Vectorize operations reached by use-def chains from root except the
+ // terminals (store operations) that need to be post-processed separately.
+ // TODO(ntv): add more as we expand.
+ if (failed(vectorizeNonTerminals(&state))) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals");
+ return guard.failure();
+ }
+
+ // 3. Post-process terminals.
+ // Note: we have to post-process terminals because some of their defs may not
+ // be nested under the vectorization pattern (e.g. constants defined in an
+ // encompassing scope).
+ // TODO(ntv): Use a backward slice for terminals, avoid special casing and
+ // merge implementations.
+ for (auto *op : state.terminals) {
+ if (!vectorizeOneOperation(op, &state)) { // nullptr == failure
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals");
+ return guard.failure();
+ }
+ }
+
+ // 4. Finish this vectorization pattern.
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
+ state.finishVectorizationPattern();
+ return guard.success();
+}
+
+/// Applies vectorization to the current Function by searching over a bunch of
+/// predetermined patterns.
+void Vectorize::runOnFunction() {
+ FuncOp f = getFunction();
+ if (!fastestVaryingPattern.empty() &&
+ fastestVaryingPattern.size() != vectorSizes.size()) {
+ f.emitRemark("Fastest varying pattern specified with different size than "
+ "the vector size.");
+ return signalPassFailure();
+ }
+
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
+ DenseSet<Operation *> parallelLoops;
+ f.walk([&parallelLoops](AffineForOp loop) {
+ if (isLoopParallel(loop))
+ parallelLoops.insert(loop);
+ });
+
+ for (auto &pat :
+ makePatterns(parallelLoops, vectorSizes.size(), fastestVaryingPattern)) {
+ LLVM_DEBUG(dbgs() << "\n******************************************");
+ LLVM_DEBUG(dbgs() << "\n******************************************");
+ LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
+ LLVM_DEBUG(f.print(dbgs()));
+ unsigned patternDepth = pat.getDepth();
+
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
+ // Iterate over all the top-level matches and vectorize eagerly.
+ // This automatically prunes intersecting matches.
+ for (auto m : matches) {
+ VectorizationStrategy strategy;
+ // TODO(ntv): depending on profitability, elect to reduce the vector size.
+ strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end());
+ if (failed(analyzeProfitability(m.getMatchedChildren(), 1, patternDepth,
+ &strategy))) {
+ continue;
+ }
+ vectorizeLoopIfProfitable(m.getMatchedOperation(), 0, patternDepth,
+ &strategy);
+ // TODO(ntv): if pattern does not apply, report it; alter the
+ // cost/benefit.
+ vectorizeRootMatch(m, &strategy);
+ // TODO(ntv): some diagnostics if failure to vectorize occurs.
+ }
+ }
+ LLVM_DEBUG(dbgs() << "\n");
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
+ return std::make_unique<Vectorize>(virtualVectorSize);
+}
+
+static PassRegistration<Vectorize>
+ pass("affine-vectorize",
+ "Vectorize to a target independent n-D vector abstraction");
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
new file mode 100644
index 00000000000..508c547a52b
--- /dev/null
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -0,0 +1,170 @@
+//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ViewOpGraph.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<int> elideIfLarger(
+ "print-op-graph-elide-if-larger",
+ llvm::cl::desc("Upper limit to emit elements attribute rather than elide"),
+ llvm::cl::init(16));
+
+using namespace mlir;
+
+namespace llvm {
+
+// Specialize GraphTraits to treat Block as a graph with Operations as nodes
+// and uses as edges.
+template <> struct GraphTraits<Block *> {
+ using GraphType = Block *;
+ using NodeRef = Operation *;
+
+ using ChildIteratorType = UseIterator;
+ static ChildIteratorType child_begin(NodeRef n) {
+ return ChildIteratorType(n);
+ }
+ static ChildIteratorType child_end(NodeRef n) {
+ return ChildIteratorType(n, /*end=*/true);
+ }
+
+ // Operation's destructor is private so use Operation* instead and use
+ // mapped iterator.
+ static Operation *AddressOf(Operation &op) { return &op; }
+ using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
+ static nodes_iterator nodes_begin(Block *b) {
+ return nodes_iterator(b->begin(), &AddressOf);
+ }
+ static nodes_iterator nodes_end(Block *b) {
+ return nodes_iterator(b->end(), &AddressOf);
+ }
+};
+
+// Specialize DOTGraphTraits to produce more readable output.
+template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
+ using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
+ static std::string getNodeLabel(Operation *op, Block *);
+};
+
+std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
+ // Reuse the print output for the node labels.
+ std::string ostr;
+ raw_string_ostream os(ostr);
+ os << op->getName() << "\n";
+
+ if (!op->getLoc().isa<UnknownLoc>()) {
+ os << op->getLoc() << "\n";
+ }
+
+ // Print result types
+ interleaveComma(op->getResultTypes(), os);
+ os << "\n";
+
+ for (auto attr : op->getAttrs()) {
+ os << '\n' << attr.first << ": ";
+ // Always emit splat attributes.
+ if (attr.second.isa<SplatElementsAttr>()) {
+ attr.second.print(os);
+ continue;
+ }
+
+ // Elide "big" elements attributes.
+ auto elements = attr.second.dyn_cast<ElementsAttr>();
+ if (elements && elements.getNumElements() > elideIfLarger) {
+ os << std::string(elements.getType().getRank(), '[') << "..."
+ << std::string(elements.getType().getRank(), ']') << " : "
+ << elements.getType();
+ continue;
+ }
+
+ auto array = attr.second.dyn_cast<ArrayAttr>();
+ if (array && static_cast<int64_t>(array.size()) > elideIfLarger) {
+ os << "[...]";
+ continue;
+ }
+
+ // Print all other attributes.
+ attr.second.print(os);
+ }
+ return os.str();
+}
+
+} // end namespace llvm
+
+namespace {
+// PrintOpPass is a simple pass that writes a graph per function.
+// Note: this is a module pass only to avoid interleaving on the same ostream
+// due to multi-threading over functions.
+struct PrintOpPass : public ModulePass<PrintOpPass> {
+ explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false,
+ const Twine &title = "")
+ : os(os), title(title.str()), short_names(short_names) {}
+
+ std::string getOpName(Operation &op) {
+ auto symbolAttr =
+ op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+ if (symbolAttr)
+ return symbolAttr.getValue();
+ ++unnamedOpCtr;
+ return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+ }
+
+ // Print all the ops in a module.
+ void processModule(ModuleOp module) {
+ for (Operation &op : module) {
+ // Modules may actually be nested, recurse on nesting.
+ if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
+ processModule(nestedModule);
+ continue;
+ }
+ auto opName = getOpName(op);
+ for (Region &region : op.getRegions()) {
+ for (auto indexed_block : llvm::enumerate(region)) {
+ // Suffix the block number if there is more than one block.
+ auto blockName = region.getBlocks().size() == 1
+ ? ""
+ : ("__" + llvm::utostr(indexed_block.index()));
+ llvm::WriteGraph(os, &indexed_block.value(), short_names,
+ Twine(title) + opName + blockName);
+ }
+ }
+ }
+ }
+
+ void runOnModule() override { processModule(getModule()); }
+
+private:
+ raw_ostream &os;
+ std::string title;
+ int unnamedOpCtr = 0;
+ bool short_names;
+};
+} // namespace
+
+void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
+ const Twine &title, llvm::GraphProgram::Name program) {
+ llvm::ViewGraph(&block, name, shortNames, title, program);
+}
+
+raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
+ const Twine &title) {
+ return llvm::WriteGraph(os, &block, shortNames, title);
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
+ const Twine &title) {
+ return std::make_unique<PrintOpPass>(os, shortNames, title);
+}
+
+static PassRegistration<PrintOpPass> pass("print-op-graph",
+ "Print op graph per region");
diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp
new file mode 100644
index 00000000000..77111087d07
--- /dev/null
+++ b/mlir/lib/Transforms/ViewRegionGraph.cpp
@@ -0,0 +1,85 @@
+//===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ViewRegionGraph.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace llvm {
+
+// Specialize DOTGraphTraits to produce more readable output.
+template <> struct DOTGraphTraits<Region *> : public DefaultDOTGraphTraits {
+ using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
+
+ static std::string getNodeLabel(Block *Block, Region *);
+};
+
+std::string DOTGraphTraits<Region *>::getNodeLabel(Block *Block, Region *) {
+ // Reuse the print output for the node labels.
+ std::string outStreamStr;
+ raw_string_ostream os(outStreamStr);
+ Block->print(os);
+ std::string &outStr = os.str();
+
+ if (outStr[0] == '\n')
+ outStr.erase(outStr.begin());
+
+ // Process string output to left justify the block.
+ for (unsigned i = 0; i != outStr.length(); ++i) {
+ if (outStr[i] == '\n') {
+ outStr[i] = '\\';
+ outStr.insert(outStr.begin() + i + 1, 'l');
+ }
+ }
+
+ return outStr;
+}
+
+} // end namespace llvm
+
+void mlir::viewGraph(Region &region, const Twine &name, bool shortNames,
+ const Twine &title, llvm::GraphProgram::Name program) {
+ llvm::ViewGraph(&region, name, shortNames, title, program);
+}
+
+raw_ostream &mlir::writeGraph(raw_ostream &os, Region &region, bool shortNames,
+ const Twine &title) {
+ return llvm::WriteGraph(os, &region, shortNames, title);
+}
+
+void mlir::Region::viewGraph(const Twine &regionName) {
+ ::mlir::viewGraph(*this, regionName);
+}
+void mlir::Region::viewGraph() { viewGraph("region"); }
+
+namespace {
+struct PrintCFGPass : public FunctionPass<PrintCFGPass> {
+ PrintCFGPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
+ const Twine &title = "")
+ : os(os), shortNames(shortNames), title(title.str()) {}
+ void runOnFunction() override {
+ mlir::writeGraph(os, getFunction().getBody(), shortNames, title);
+ }
+
+private:
+ raw_ostream &os;
+ bool shortNames;
+ std::string title;
+};
+} // namespace
+
+std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
+mlir::createPrintCFGGraphPass(raw_ostream &os, bool shortNames,
+ const Twine &title) {
+ return std::make_unique<PrintCFGPass>(os, shortNames, title);
+}
+
+static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
+ "Print CFG graph per Function");