author | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000
committer | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000
commit | 0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch)
tree | 31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/lib/Transforms
parent | 6f635f90929da9545dd696071a829a1a42f84b30 (diff)
parent | 5b4a01d4a63cb66ab981e52548f940813393bf42 (diff)
download | bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/lib/Transforms')
28 files changed, 12273 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp new file mode 100644 index 00000000000..902f5c3adcb --- /dev/null +++ b/mlir/lib/Transforms/AffineDataCopyGeneration.cpp @@ -0,0 +1,268 @@ +//===- AffineDataCopyGeneration.cpp - Explicit memref copying pass ------*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to automatically promote accessed memref regions +// to buffers in a faster memory space that is explicitly managed, with the +// necessary data movement operations performed through either regular +// point-wise load/store's or DMAs. Such explicit copying (also referred to as +// array packing/unpacking in the literature), when done on arrays that exhibit +// reuse, results in near elimination of conflict misses, TLB misses, reduced +// use of hardware prefetch streams, and reduced false sharing. It is also +// necessary for hardware that explicitly managed levels in the memory +// hierarchy, and where DMAs may have to be used. This optimization is often +// performed on already tiled code. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include <algorithm> + +#define DEBUG_TYPE "affine-data-copy-generate" + +using namespace mlir; + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +static llvm::cl::opt<unsigned long long> clFastMemoryCapacity( + "affine-data-copy-generate-fast-mem-capacity", + llvm::cl::desc( + "Set fast memory space capacity in KiB (default: unlimited)"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<bool> + clDma("affine-data-copy-generate-dma", + llvm::cl::desc("Generate DMA instead of point-wise copy"), + llvm::cl::cat(clOptionsCategory), llvm::cl::init(true)); + +static llvm::cl::opt<unsigned> clFastMemorySpace( + "affine-data-copy-generate-fast-mem-space", llvm::cl::init(1), + llvm::cl::desc( + "Fast memory space identifier for copy generation (default: 1)"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<bool> clSkipNonUnitStrideLoop( + "affine-data-copy-generate-skip-non-unit-stride-loops", llvm::cl::Hidden, + llvm::cl::init(false), + llvm::cl::desc("Testing purposes: avoid non-unit stride loop choice depths " + "for copy placement"), + llvm::cl::cat(clOptionsCategory)); + +namespace { + +/// Replaces all loads and stores on memref's living in 'slowMemorySpace' by +/// introducing copy operations to transfer data into `fastMemorySpace` and +/// rewriting the original load's/store's to instead load/store from the +/// allocated fast memory buffers. Additional options specify the identifier +/// corresponding to the fast memory space and the amount of fast memory space +/// available. 
The pass traverses through the nesting structure, recursing to +/// inner levels if necessary to determine at what depth copies need to be +/// placed so that the allocated buffers fit within the memory capacity +/// provided. +// TODO(bondhugula): We currently can't generate copies correctly when stores +// are strided. Check for strided stores. +struct AffineDataCopyGeneration + : public FunctionPass<AffineDataCopyGeneration> { + explicit AffineDataCopyGeneration( + unsigned slowMemorySpace = 0, + unsigned fastMemorySpace = clFastMemorySpace, unsigned tagMemorySpace = 0, + int minDmaTransferSize = 1024, + uint64_t fastMemCapacityBytes = + (clFastMemoryCapacity.getNumOccurrences() > 0 + ? clFastMemoryCapacity * 1024 // cl-provided size is in KiB + : std::numeric_limits<uint64_t>::max()), + bool generateDma = clDma, + bool skipNonUnitStrideLoops = clSkipNonUnitStrideLoop) + : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace), + tagMemorySpace(tagMemorySpace), minDmaTransferSize(minDmaTransferSize), + fastMemCapacityBytes(fastMemCapacityBytes), generateDma(generateDma), + skipNonUnitStrideLoops(skipNonUnitStrideLoops) {} + + explicit AffineDataCopyGeneration(const AffineDataCopyGeneration &other) + : slowMemorySpace(other.slowMemorySpace), + fastMemorySpace(other.fastMemorySpace), + tagMemorySpace(other.tagMemorySpace), + minDmaTransferSize(other.minDmaTransferSize), + fastMemCapacityBytes(other.fastMemCapacityBytes), + generateDma(other.generateDma), + skipNonUnitStrideLoops(other.skipNonUnitStrideLoops) {} + + void runOnFunction() override; + LogicalResult runOnBlock(Block *block, DenseSet<Operation *> ©Nests); + + // Slow memory space associated with copies. + const unsigned slowMemorySpace; + // Fast memory space associated with copies. + unsigned fastMemorySpace; + // Memory space associated with DMA tags. + unsigned tagMemorySpace; + // Minimum DMA transfer size supported by the target in bytes. + const int minDmaTransferSize; + // Capacity of the faster memory space. + uint64_t fastMemCapacityBytes; + + // If set, generate DMA operations instead of read/write. + bool generateDma; + + // If set, ignore loops with steps other than 1. + bool skipNonUnitStrideLoops; + + // Constant zero index to avoid too many duplicates. + Value zeroIndex = nullptr; +}; + +} // end anonymous namespace + +/// Generates copies for memref's living in 'slowMemorySpace' into newly created +/// buffers in 'fastMemorySpace', and replaces memory operations to the former +/// by the latter. Only load op's handled for now. +/// TODO(bondhugula): extend this to store op's. +std::unique_ptr<OpPassBase<FuncOp>> mlir::createAffineDataCopyGenerationPass( + unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace, + int minDmaTransferSize, uint64_t fastMemCapacityBytes) { + return std::make_unique<AffineDataCopyGeneration>( + slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize, + fastMemCapacityBytes); +} + +/// Generate copies for this block. The block is partitioned into separate +/// ranges: each range is either a sequence of one or more operations starting +/// and ending with an affine load or store op, or just an affine.forop (which +/// could have other affine for op's nested within). 
+LogicalResult +AffineDataCopyGeneration::runOnBlock(Block *block, + DenseSet<Operation *> ©Nests) { + if (block->empty()) + return success(); + + AffineCopyOptions copyOptions = {generateDma, slowMemorySpace, + fastMemorySpace, tagMemorySpace, + fastMemCapacityBytes}; + + // Every affine.forop in the block starts and ends a block range for copying; + // in addition, a contiguous sequence of operations starting with a + // load/store op but not including any copy nests themselves is also + // identified as a copy block range. Straightline code (a contiguous chunk of + // operations excluding AffineForOp's) are always assumed to not exhaust + // memory. As a result, this approach is conservative in some cases at the + // moment; we do a check later and report an error with location info. + // TODO(bondhugula): An 'affine.if' operation is being treated similar to an + // operation. 'affine.if''s could have 'affine.for's in them; + // treat them separately. + + // Get to the first load, store, or for op (that is not a copy nest itself). + auto curBegin = + std::find_if(block->begin(), block->end(), [&](Operation &op) { + return (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) || + isa<AffineForOp>(op)) && + copyNests.count(&op) == 0; + }); + + // Create [begin, end) ranges. + auto it = curBegin; + while (it != block->end()) { + AffineForOp forOp; + // If you hit a non-copy for loop, we will split there. + if ((forOp = dyn_cast<AffineForOp>(&*it)) && copyNests.count(forOp) == 0) { + // Perform the copying up unti this 'for' op first. + affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/it, copyOptions, + copyNests); + + // Returns true if the footprint is known to exceed capacity. + auto exceedsCapacity = [&](AffineForOp forOp) { + Optional<int64_t> footprint = + getMemoryFootprintBytes(forOp, + /*memorySpace=*/0); + return (footprint.hasValue() && + static_cast<uint64_t>(footprint.getValue()) > + fastMemCapacityBytes); + }; + + // If the memory footprint of the 'affine.for' loop is higher than fast + // memory capacity (when provided), we recurse to copy at an inner level + // until we find a depth at which footprint fits in fast mem capacity. If + // the footprint can't be calculated, we assume for now it fits. Recurse + // inside if footprint for 'forOp' exceeds capacity, or when + // skipNonUnitStrideLoops is set and the step size is not one. + bool recurseInner = skipNonUnitStrideLoops ? forOp.getStep() != 1 + : exceedsCapacity(forOp); + if (recurseInner) { + // We'll recurse and do the copies at an inner level for 'forInst'. + // Recurse onto the body of this loop. + runOnBlock(forOp.getBody(), copyNests); + } else { + // We have enough capacity, i.e., copies will be computed for the + // portion of the block until 'it', and for 'it', which is 'forOp'. Note + // that for the latter, the copies are placed just before this loop (for + // incoming copies) and right after (for outgoing ones). + + // Inner loop copies have their own scope - we don't thus update + // consumed capacity. The footprint check above guarantees this inner + // loop's footprint fits. + affineDataCopyGenerate(/*begin=*/it, /*end=*/std::next(it), copyOptions, + copyNests); + } + // Get to the next load or store op after 'forOp'. 
+ curBegin = std::find_if(std::next(it), block->end(), [&](Operation &op) { + return (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) || + isa<AffineForOp>(op)) && + copyNests.count(&op) == 0; + }); + it = curBegin; + } else { + assert(copyNests.count(&*it) == 0 && + "all copy nests generated should have been skipped above"); + // We simply include this op in the current range and continue for more. + ++it; + } + } + + // Generate the copy for the final block range. + if (curBegin != block->end()) { + // Can't be a terminator because it would have been skipped above. + assert(!curBegin->isKnownTerminator() && "can't be a terminator"); + // Exclude the affine terminator - hence, the std::prev. + affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/std::prev(block->end()), + copyOptions, copyNests); + } + + return success(); +} + +void AffineDataCopyGeneration::runOnFunction() { + FuncOp f = getFunction(); + OpBuilder topBuilder(f.getBody()); + zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0); + + // Nests that are copy-in's or copy-out's; the root AffineForOps of those + // nests are stored herein. + DenseSet<Operation *> copyNests; + + // Clear recorded copy nests. + copyNests.clear(); + + for (auto &block : f) + runOnBlock(&block, copyNests); + + // Promote any single iteration loops in the copy nests. + for (auto nest : copyNests) { + nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); + } +} + +static PassRegistration<AffineDataCopyGeneration> + pass("affine-data-copy-generate", + "Generate explicit copying for memory operations"); diff --git a/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp new file mode 100644 index 00000000000..24ec2d7c70b --- /dev/null +++ b/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -0,0 +1,239 @@ +//===- AffineLoopInvariantCodeMotion.cpp - Code to perform loop fusion-----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop invariant code motion. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "licm" + +using namespace mlir; + +namespace { + +/// Loop invariant code motion (LICM) pass. +/// TODO(asabne) : The pass is missing zero-trip tests. +/// TODO(asabne) : Check for the presence of side effects before hoisting. +/// TODO: This code should be removed once the new LICM pass can handle its +/// uses. 
+struct LoopInvariantCodeMotion : public FunctionPass<LoopInvariantCodeMotion> { + void runOnFunction() override; + void runOnAffineForOp(AffineForOp forOp); +}; +} // end anonymous namespace + +static bool +checkInvarianceOfNestedIfOps(Operation *op, Value indVar, + SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist); +static bool isOpLoopInvariant(Operation &op, Value indVar, + SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist); + +static bool +areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar, + SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist); + +static bool isMemRefDereferencingOp(Operation &op) { + // TODO(asabne): Support DMA Ops. + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { + return true; + } + return false; +} + +// Returns true if the individual op is loop invariant. +bool isOpLoopInvariant(Operation &op, Value indVar, + SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist) { + LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); + + if (isa<AffineIfOp>(op)) { + if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) { + return false; + } + } else if (isa<AffineForOp>(op)) { + // If the body of a predicated region has a for loop, we don't hoist the + // 'affine.if'. + return false; + } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) { + // TODO(asabne): Support DMA ops. + return false; + } else if (!isa<ConstantOp>(op)) { + if (isMemRefDereferencingOp(op)) { + Value memref = isa<AffineLoadOp>(op) + ? cast<AffineLoadOp>(op).getMemRef() + : cast<AffineStoreOp>(op).getMemRef(); + for (auto *user : memref->getUsers()) { + // If this memref has a user that is a DMA, give up because these + // operations write to this memref. + if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) { + return false; + } + // If the memref used by the load/store is used in a store elsewhere in + // the loop nest, we do not hoist. Similarly, if the memref used in a + // load is also being stored too, we do not hoist the load. + if (isa<AffineStoreOp>(user) || + (isa<AffineLoadOp>(user) && isa<AffineStoreOp>(op))) { + if (&op != user) { + SmallVector<AffineForOp, 8> userIVs; + getLoopIVs(*user, &userIVs); + // Check that userIVs don't contain the for loop around the op. + if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) { + return false; + } + } + } + } + } + + // Insert this op in the defined ops list. + definedOps.insert(&op); + + if (op.getNumOperands() == 0 && !isa<AffineTerminatorOp>(op)) { + LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); + return false; + } + for (unsigned int i = 0; i < op.getNumOperands(); ++i) { + auto *operandSrc = op.getOperand(i)->getDefiningOp(); + + LLVM_DEBUG( + op.getOperand(i)->print(llvm::dbgs() << "\nIterating on operand\n")); + + // If the loop IV is the operand, this op isn't loop invariant. + if (indVar == op.getOperand(i)) { + LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n"); + return false; + } + + if (operandSrc != nullptr) { + LLVM_DEBUG(llvm::dbgs() + << *operandSrc << "\nIterating on operand src\n"); + + // If the value was defined in the loop (outside of the + // if/else region), and that operation itself wasn't meant to + // be hoisted, then mark this operation loop dependent. 
+ if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { + return false; + } + } + } + } + + // If no operand was loop variant, mark this op for motion. + opsToHoist.insert(&op); + return true; +} + +// Checks if all ops in a region (i.e. list of blocks) are loop invariant. +bool areAllOpsInTheBlockListInvariant( + Region &blockList, Value indVar, SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist) { + + for (auto &b : blockList) { + for (auto &op : b) { + if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + return false; + } + } + } + + return true; +} + +// Returns true if the affine.if op can be hoisted. +bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, + SmallPtrSetImpl<Operation *> &definedOps, + SmallPtrSetImpl<Operation *> &opsToHoist) { + assert(isa<AffineIfOp>(op)); + auto ifOp = cast<AffineIfOp>(op); + + if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, definedOps, + opsToHoist)) { + return false; + } + + if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, definedOps, + opsToHoist)) { + return false; + } + + return true; +} + +void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) { + auto *loopBody = forOp.getBody(); + auto indVar = forOp.getInductionVar(); + + SmallPtrSet<Operation *, 8> definedOps; + // This is the place where hoisted instructions would reside. + OpBuilder b(forOp.getOperation()); + + SmallPtrSet<Operation *, 8> opsToHoist; + SmallVector<Operation *, 8> opsToMove; + + for (auto &op : *loopBody) { + // We don't hoist for loops. + if (!isa<AffineForOp>(op)) { + if (!isa<AffineTerminatorOp>(op)) { + if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + opsToMove.push_back(&op); + } + } + } + } + + // For all instructions that we found to be invariant, place sequentially + // right before the for loop. + for (auto *op : opsToMove) { + op->moveBefore(forOp); + } + + LLVM_DEBUG(forOp.getOperation()->print(llvm::dbgs() << "Modified loop\n")); +} + +void LoopInvariantCodeMotion::runOnFunction() { + // Walk through all loops in a function in innermost-loop-first order. This + // way, we first LICM from the inner loop, and place the ops in + // the outer loop, which in turn can be further LICM'ed. 
+ getFunction().walk([&](AffineForOp op) { + LLVM_DEBUG(op.getOperation()->print(llvm::dbgs() << "\nOriginal loop\n")); + runOnAffineForOp(op); + }); +} + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createAffineLoopInvariantCodeMotionPass() { + return std::make_unique<LoopInvariantCodeMotion>(); +} + +static PassRegistration<LoopInvariantCodeMotion> + pass("affine-loop-invariant-code-motion", + "Hoist loop invariant instructions outside of the loop"); diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt new file mode 100644 index 00000000000..d6c5bd88f7f --- /dev/null +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -0,0 +1,38 @@ +add_subdirectory(Utils) + +add_llvm_library(MLIRTransforms + AffineDataCopyGeneration.cpp + AffineLoopInvariantCodeMotion.cpp + Canonicalizer.cpp + CSE.cpp + DialectConversion.cpp + Inliner.cpp + LoopCoalescing.cpp + LoopFusion.cpp + LoopInvariantCodeMotion.cpp + LoopTiling.cpp + LoopUnrollAndJam.cpp + LoopUnroll.cpp + MemRefDataFlowOpt.cpp + PipelineDataTransfer.cpp + SimplifyAffineStructures.cpp + StripDebugInfo.cpp + Vectorize.cpp + ViewOpGraph.cpp + ViewRegionGraph.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms + ) + +add_dependencies(MLIRTransforms + MLIRLoopLikeInterfaceIncGen + MLIRStandardOpsIncGen) +target_link_libraries(MLIRTransforms + MLIRAffineOps + MLIRAnalysis + MLIRLoopOps + MLIRPass + MLIRTransformUtils + MLIRVectorOps + ) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp new file mode 100644 index 00000000000..714fb1d0109 --- /dev/null +++ b/mlir/lib/Transforms/CSE.cpp @@ -0,0 +1,263 @@ +//===- CSE.cpp - Common Sub-expression Elimination ------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass performs a simple common sub-expression elimination +// algorithm on operations within a function. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Dominance.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/RecyclingAllocator.h" +#include <deque> +using namespace mlir; + +namespace { +// TODO(riverriddle) Handle commutative operations. 
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { + static unsigned getHashValue(const Operation *opC) { + auto *op = const_cast<Operation *>(opC); + // Hash the operations based upon their: + // - Operation Name + // - Attributes + // - Result Types + // - Operands + return hash_combine( + op->getName(), op->getAttrList().getDictionary(), + hash_combine_range(op->result_type_begin(), op->result_type_end()), + hash_combine_range(op->operand_begin(), op->operand_end())); + } + static bool isEqual(const Operation *lhsC, const Operation *rhsC) { + auto *lhs = const_cast<Operation *>(lhsC); + auto *rhs = const_cast<Operation *>(rhsC); + if (lhs == rhs) + return true; + if (lhs == getTombstoneKey() || lhs == getEmptyKey() || + rhs == getTombstoneKey() || rhs == getEmptyKey()) + return false; + + // Compare the operation name. + if (lhs->getName() != rhs->getName()) + return false; + // Check operand and result type counts. + if (lhs->getNumOperands() != rhs->getNumOperands() || + lhs->getNumResults() != rhs->getNumResults()) + return false; + // Compare attributes. + if (lhs->getAttrList() != rhs->getAttrList()) + return false; + // Compare operands. + if (!std::equal(lhs->operand_begin(), lhs->operand_end(), + rhs->operand_begin())) + return false; + // Compare result types. + return std::equal(lhs->result_type_begin(), lhs->result_type_end(), + rhs->result_type_begin()); + } +}; +} // end anonymous namespace + +namespace { +/// Simple common sub-expression elimination. +struct CSE : public OperationPass<CSE> { + CSE() = default; + CSE(const CSE &) {} + + /// Shared implementation of operation elimination and scoped map definitions. + using AllocatorTy = llvm::RecyclingAllocator< + llvm::BumpPtrAllocator, + llvm::ScopedHashTableVal<Operation *, Operation *>>; + using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *, + SimpleOperationInfo, AllocatorTy>; + + /// Represents a single entry in the depth first traversal of a CFG. + struct CFGStackNode { + CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node) + : scope(knownValues), node(node), childIterator(node->begin()), + processed(false) {} + + /// Scope for the known values. + ScopedMapTy::ScopeTy scope; + + DominanceInfoNode *node; + DominanceInfoNode::iterator childIterator; + + /// If this node has been fully processed yet or not. + bool processed; + }; + + /// Attempt to eliminate a redundant operation. Returns success if the + /// operation was marked for removal, failure otherwise. + LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op); + + void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo, + Block *bb); + void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo, + Region ®ion); + + void runOnOperation() override; + +private: + /// Operations marked as dead and to be erased. + std::vector<Operation *> opsToErase; + + /// Statistics for CSE. + Statistic numCSE{this, "num-cse'd", "Number of operations CSE'd"}; + Statistic numDCE{this, "num-dce'd", "Number of operations trivially DCE'd"}; +}; +} // end anonymous namespace + +/// Attempt to eliminate a redundant operation. +LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) { + // Don't simplify operations with nested blocks. We don't currently model + // equality comparisons correctly among other things. It is also unclear + // whether we would want to CSE such operations. 
+ if (op->getNumRegions() != 0) + return failure(); + + // TODO(riverriddle) We currently only eliminate non side-effecting + // operations. + if (!op->hasNoSideEffect()) + return failure(); + + // If the operation is already trivially dead just add it to the erase list. + if (op->use_empty()) { + opsToErase.push_back(op); + ++numDCE; + return success(); + } + + // Look for an existing definition for the operation. + if (auto *existing = knownValues.lookup(op)) { + // If we find one then replace all uses of the current operation with the + // existing one and mark it for deletion. + op->replaceAllUsesWith(existing); + opsToErase.push_back(op); + + // If the existing operation has an unknown location and the current + // operation doesn't, then set the existing op's location to that of the + // current op. + if (existing->getLoc().isa<UnknownLoc>() && + !op->getLoc().isa<UnknownLoc>()) { + existing->setLoc(op->getLoc()); + } + + ++numCSE; + return success(); + } + + // Otherwise, we add this operation to the known values map. + knownValues.insert(op, op); + return failure(); +} + +void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo, + Block *bb) { + for (auto &inst : *bb) { + // If the operation is simplified, we don't process any held regions. + if (succeeded(simplifyOperation(knownValues, &inst))) + continue; + + // If this operation is isolated above, we can't process nested regions with + // the given 'knownValues' map. This would cause the insertion of implicit + // captures in explicit capture only regions. + if (!inst.isRegistered() || inst.isKnownIsolatedFromAbove()) { + ScopedMapTy nestedKnownValues; + for (auto ®ion : inst.getRegions()) + simplifyRegion(nestedKnownValues, domInfo, region); + continue; + } + + // Otherwise, process nested regions normally. + for (auto ®ion : inst.getRegions()) + simplifyRegion(knownValues, domInfo, region); + } +} + +void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo, + Region ®ion) { + // If the region is empty there is nothing to do. + if (region.empty()) + return; + + // If the region only contains one block, then simplify it directly. + if (std::next(region.begin()) == region.end()) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(knownValues, domInfo, ®ion.front()); + return; + } + + // Note, deque is being used here because there was significant performance + // gains over vector when the container becomes very large due to the + // specific access patterns. If/when these performance issues are no + // longer a problem we can change this to vector. For more information see + // the llvm mailing list discussion on this: + // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html + std::deque<std::unique_ptr<CFGStackNode>> stack; + + // Process the nodes of the dom tree for this region. + stack.emplace_back(std::make_unique<CFGStackNode>( + knownValues, domInfo.getRootNode(®ion))); + + while (!stack.empty()) { + auto ¤tNode = stack.back(); + + // Check to see if we need to process this node. + if (!currentNode->processed) { + currentNode->processed = true; + simplifyBlock(knownValues, domInfo, currentNode->node->getBlock()); + } + + // Otherwise, check to see if we need to process a child node. 
+ if (currentNode->childIterator != currentNode->node->end()) { + auto *childNode = *(currentNode->childIterator++); + stack.emplace_back( + std::make_unique<CFGStackNode>(knownValues, childNode)); + } else { + // Finally, if the node and all of its children have been processed + // then we delete the node. + stack.pop_back(); + } + } +} + +void CSE::runOnOperation() { + /// A scoped hash table of defining operations within a region. + ScopedMapTy knownValues; + + DominanceInfo &domInfo = getAnalysis<DominanceInfo>(); + for (Region ®ion : getOperation()->getRegions()) + simplifyRegion(knownValues, domInfo, region); + + // If no operations were erased, then we mark all analyses as preserved. + if (opsToErase.empty()) + return markAllAnalysesPreserved(); + + /// Erase any operations that were marked as dead during simplification. + for (auto *op : opsToErase) + op->erase(); + opsToErase.clear(); + + // We currently don't remove region operations, so mark dominance as + // preserved. + markAnalysesPreserved<DominanceInfo, PostDominanceInfo>(); +} + +std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); } + +static PassRegistration<CSE> pass("cse", "Eliminate common sub-expressions"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp new file mode 100644 index 00000000000..5b3a1eb1cf3 --- /dev/null +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -0,0 +1,45 @@ +//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass converts operations into their canonical forms by +// folding constants, applying operation identity transformations etc. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" +using namespace mlir; + +namespace { +/// Canonicalize operations in nested regions. +struct Canonicalizer : public OperationPass<Canonicalizer> { + void runOnOperation() override { + OwningRewritePatternList patterns; + + // TODO: Instead of adding all known patterns from the whole system lazily + // add and cache the canonicalization patterns for ops we see in practice + // when building the worklist. For now, we just grab everything. + auto *context = &getContext(); + for (auto *op : context->getRegisteredOperations()) + op->getCanonicalizationPatterns(patterns, context); + + Operation *op = getOperation(); + applyPatternsGreedily(op->getRegions(), patterns); + } +}; +} // end anonymous namespace + +/// Create a Canonicalizer pass. +std::unique_ptr<Pass> mlir::createCanonicalizerPass() { + return std::make_unique<Canonicalizer>(); +} + +static PassRegistration<Canonicalizer> pass("canonicalize", + "Canonicalize operations"); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp new file mode 100644 index 00000000000..5f7fb7a68c9 --- /dev/null +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -0,0 +1,1846 @@ +//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace mlir::detail; + +#define DEBUG_TYPE "dialect-conversion" + +/// Recursively collect all of the operations to convert from within 'region'. +/// If 'target' is nonnull, operations that are recursively legal have their +/// regions pre-filtered to avoid considering them for legalization. +static LogicalResult +computeConversionSet(iterator_range<Region::iterator> region, + Location regionLoc, std::vector<Operation *> &toConvert, + ConversionTarget *target = nullptr) { + if (llvm::empty(region)) + return success(); + + // Traverse starting from the entry block. + SmallVector<Block *, 16> worklist(1, &*region.begin()); + DenseSet<Block *> visitedBlocks; + visitedBlocks.insert(worklist.front()); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + // Compute the conversion set of each of the nested operations. + for (Operation &op : *block) { + toConvert.emplace_back(&op); + + // Don't check this operation's children for conversion if the operation + // is recursively legal. + auto legalityInfo = target ? target->isLegal(&op) + : Optional<ConversionTarget::LegalOpDetails>(); + if (legalityInfo && legalityInfo->isRecursivelyLegal) + continue; + for (auto ®ion : op.getRegions()) + computeConversionSet(region.getBlocks(), region.getLoc(), toConvert, + target); + } + + // Recurse to children that haven't been visited. + for (Block *succ : block->getSuccessors()) + if (visitedBlocks.insert(succ).second) + worklist.push_back(succ); + } + + // Check that all blocks in the region were visited. + if (llvm::any_of(llvm::drop_begin(region, 1), + [&](Block &block) { return !visitedBlocks.count(&block); })) + return emitError(regionLoc, "unreachable blocks were not converted"); + return success(); +} + +//===----------------------------------------------------------------------===// +// Multi-Level Value Mapper +//===----------------------------------------------------------------------===// + +namespace { +/// This class wraps a BlockAndValueMapping to provide recursive lookup +/// functionality, i.e. we will traverse if the mapped value also has a mapping. +struct ConversionValueMapping { + /// Lookup a mapped value within the map. If a mapping for the provided value + /// does not exist then return the provided value. + Value lookupOrDefault(Value from) const; + + /// Map a value to the one provided. + void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } + + /// Drop the last mapping for the given value. + void erase(Value value) { mapping.erase(value); } + +private: + /// Current value mappings. + BlockAndValueMapping mapping; +}; +} // end anonymous namespace + +/// Lookup a mapped value within the map. If a mapping for the provided value +/// does not exist then return the provided value. +Value ConversionValueMapping::lookupOrDefault(Value from) const { + // If this value had a valid mapping, unmap that value as well in the case + // that it was also replaced. 
+ while (auto mappedValue = mapping.lookupOrNull(from)) + from = mappedValue; + return from; +} + +//===----------------------------------------------------------------------===// +// ArgConverter +//===----------------------------------------------------------------------===// +namespace { +/// This class provides a simple interface for converting the types of block +/// arguments. This is done by creating a new block that contains the new legal +/// types and extracting the block that contains the old illegal types to allow +/// for undoing pending rewrites in the case of failure. +struct ArgConverter { + ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) + : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), + rewriter(rewriter) {} + + /// This structure contains the information pertaining to an argument that has + /// been converted. + struct ConvertedArgInfo { + ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, + Value castValue = nullptr) + : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} + + /// The start index of in the new argument list that contains arguments that + /// replace the original. + unsigned newArgIdx; + + /// The number of arguments that replaced the original argument. + unsigned newArgSize; + + /// The cast value that was created to cast from the new arguments to the + /// old. This only used if 'newArgSize' > 1. + Value castValue; + }; + + /// This structure contains information pertaining to a block that has had its + /// signature converted. + struct ConvertedBlockInfo { + ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {} + + /// The original block that was requested to have its signature converted. + Block *origBlock; + + /// The conversion information for each of the arguments. The information is + /// None if the argument was dropped during conversion. + SmallVector<Optional<ConvertedArgInfo>, 1> argInfo; + }; + + /// Return if the signature of the given block has already been converted. + bool hasBeenConverted(Block *block) const { + return conversionInfo.count(block); + } + + //===--------------------------------------------------------------------===// + // Rewrite Application + //===--------------------------------------------------------------------===// + + /// Erase any rewrites registered for the blocks within the given operation + /// which is about to be removed. This merely drops the rewrites without + /// undoing them. + void notifyOpRemoved(Operation *op); + + /// Cleanup and undo any generated conversions for the arguments of block. + /// This method replaces the new block with the original, reverting the IR to + /// its original state. + void discardRewrites(Block *block); + + /// Fully replace uses of the old arguments with the new, materializing cast + /// operations as necessary. + // FIXME(riverriddle) The 'mapping' parameter is only necessary because the + // implementation of replaceUsesOfBlockArgument is buggy. + void applyRewrites(ConversionValueMapping &mapping); + + //===--------------------------------------------------------------------===// + // Conversion + //===--------------------------------------------------------------------===// + + /// Attempt to convert the signature of the given block, if successful a new + /// block is returned containing the new arguments. On failure, nullptr is + /// returned. + Block *convertSignature(Block *block, ConversionValueMapping &mapping); + + /// Apply the given signature conversion on the given block. 
The new block + /// containing the updated signature is returned. + Block *applySignatureConversion( + Block *block, TypeConverter::SignatureConversion &signatureConversion, + ConversionValueMapping &mapping); + + /// Insert a new conversion into the cache. + void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); + + /// A collection of blocks that have had their arguments converted. + llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; + + /// A mapping from valid regions, to those containing the original blocks of a + /// conversion. + DenseMap<Region *, std::unique_ptr<Region>> regionMapping; + + /// An instance of the unknown location that is used when materializing + /// conversions. + Location loc; + + /// The type converter to use when changing types. + TypeConverter *typeConverter; + + /// The pattern rewriter to use when materializing conversions. + PatternRewriter &rewriter; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Rewrite Application + +void ArgConverter::notifyOpRemoved(Operation *op) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // Drop any rewrites from within. + for (Operation &nestedOp : block) + if (nestedOp.getNumRegions()) + notifyOpRemoved(&nestedOp); + + // Check if this block was converted. + auto it = conversionInfo.find(&block); + if (it == conversionInfo.end()) + return; + + // Drop all uses of the original arguments and delete the original block. + Block *origBlock = it->second.origBlock; + for (BlockArgument arg : origBlock->getArguments()) + arg->dropAllUses(); + conversionInfo.erase(it); + } + } +} + +void ArgConverter::discardRewrites(Block *block) { + auto it = conversionInfo.find(block); + if (it == conversionInfo.end()) + return; + Block *origBlock = it->second.origBlock; + + // Drop all uses of the new block arguments and replace uses of the new block. + for (int i = block->getNumArguments() - 1; i >= 0; --i) + block->getArgument(i)->dropAllUses(); + block->replaceAllUsesWith(origBlock); + + // Move the operations back the original block and the delete the new block. + origBlock->getOperations().splice(origBlock->end(), block->getOperations()); + origBlock->moveBefore(block); + block->erase(); + + conversionInfo.erase(it); +} + +void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { + for (auto &info : conversionInfo) { + Block *newBlock = info.first; + ConvertedBlockInfo &blockInfo = info.second; + Block *origBlock = blockInfo.origBlock; + + // Process the remapping for each of the original arguments. + for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { + Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i]; + BlockArgument origArg = origBlock->getArgument(i); + + // Handle the case of a 1->0 value mapping. + if (!argInfo) { + // If a replacement value was given for this argument, use that to + // replace all uses. + auto argReplacementValue = mapping.lookupOrDefault(origArg); + if (argReplacementValue != origArg) { + origArg->replaceAllUsesWith(argReplacementValue); + continue; + } + // If there are any dangling uses then replace the argument with one + // generated by the type converter. This is necessary as the cast must + // persist in the IR after conversion. 
+ if (!origArg->use_empty()) { + rewriter.setInsertionPointToStart(newBlock); + auto *newOp = typeConverter->materializeConversion( + rewriter, origArg->getType(), llvm::None, loc); + origArg->replaceAllUsesWith(newOp->getResult(0)); + } + continue; + } + + // If mapping is 1-1, replace the remaining uses and drop the cast + // operation. + // FIXME(riverriddle) This should check that the result type and operand + // type are the same, otherwise it should force a conversion to be + // materialized. + if (argInfo->newArgSize == 1) { + origArg->replaceAllUsesWith( + mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx))); + continue; + } + + // Otherwise this is a 1->N value mapping. + Value castValue = argInfo->castValue; + assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); + + // If the argument is still used, replace it with the generated cast. + if (!origArg->use_empty()) + origArg->replaceAllUsesWith(mapping.lookupOrDefault(castValue)); + + // If all users of the cast were removed, we can drop it. Otherwise, keep + // the operation alive and let the user handle any remaining usages. + if (castValue->use_empty()) + castValue->getDefiningOp()->erase(); + } + } +} + +//===----------------------------------------------------------------------===// +// Conversion + +Block *ArgConverter::convertSignature(Block *block, + ConversionValueMapping &mapping) { + if (auto conversion = typeConverter->convertBlockSignature(block)) + return applySignatureConversion(block, *conversion, mapping); + return nullptr; +} + +Block *ArgConverter::applySignatureConversion( + Block *block, TypeConverter::SignatureConversion &signatureConversion, + ConversionValueMapping &mapping) { + // If no arguments are being changed or added, there is nothing to do. + unsigned origArgCount = block->getNumArguments(); + auto convertedTypes = signatureConversion.getConvertedTypes(); + if (origArgCount == 0 && convertedTypes.empty()) + return block; + + // Split the block at the beginning to get a new block to use for the updated + // signature. + Block *newBlock = block->splitBlock(block->begin()); + block->replaceAllUsesWith(newBlock); + + SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes)); + ArrayRef<Value> newArgs(newArgRange); + + // Remap each of the original arguments as determined by the signature + // conversion. + ConvertedBlockInfo info(block); + info.argInfo.resize(origArgCount); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newBlock); + for (unsigned i = 0; i != origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap) + continue; + BlockArgument origArg = block->getArgument(i); + + // If inputMap->replacementValue is not nullptr, then the argument is + // dropped and a replacement value is provided to be the remappedValue. + if (inputMap->replacementValue) { + assert(inputMap->size == 0 && + "invalid to provide a replacement value when the argument isn't " + "dropped"); + mapping.map(origArg, inputMap->replacementValue); + continue; + } + + // If this is a 1->1 mapping, then map the argument directly. + if (inputMap->size == 1) { + mapping.map(origArg, newArgs[inputMap->inputNo]); + info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size); + continue; + } + + // Otherwise, this is a 1->N mapping. Call into the provided type converter + // to pack the new values. 
+ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); + Operation *cast = typeConverter->materializeConversion( + rewriter, origArg->getType(), replArgs, loc); + assert(cast->getNumResults() == 1 && + cast->getNumOperands() == replArgs.size()); + mapping.map(origArg, cast->getResult(0)); + info.argInfo[i] = + ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); + } + + // Remove the original block from the region and return the new one. + insertConversion(newBlock, std::move(info)); + return newBlock; +} + +void ArgConverter::insertConversion(Block *newBlock, + ConvertedBlockInfo &&info) { + // Get a region to insert the old block. + Region *region = newBlock->getParent(); + std::unique_ptr<Region> &mappedRegion = regionMapping[region]; + if (!mappedRegion) + mappedRegion = std::make_unique<Region>(region->getParentOp()); + + // Move the original block to the mapped region and emplace the conversion. + mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), + info.origBlock->getIterator()); + conversionInfo.insert({newBlock, std::move(info)}); +} + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriterImpl +//===----------------------------------------------------------------------===// +namespace { +/// This class contains a snapshot of the current conversion rewriter state. +/// This is useful when saving and undoing a set of rewrites. +struct RewriterState { + RewriterState(unsigned numCreatedOps, unsigned numReplacements, + unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRootUpdates) + : numCreatedOps(numCreatedOps), numReplacements(numReplacements), + numBlockActions(numBlockActions), + numIgnoredOperations(numIgnoredOperations), + numRootUpdates(numRootUpdates) {} + + /// The current number of created operations. + unsigned numCreatedOps; + + /// The current number of replacements queued. + unsigned numReplacements; + + /// The current number of block actions performed. + unsigned numBlockActions; + + /// The current number of ignored operations. + unsigned numIgnoredOperations; + + /// The current number of operations that were updated in place. + unsigned numRootUpdates; +}; + +/// The state of an operation that was updated by a pattern in-place. This +/// contains all of the necessary information to reconstruct an operation that +/// was updated in place. +class OperationTransactionState { +public: + OperationTransactionState() = default; + OperationTransactionState(Operation *op) + : op(op), loc(op->getLoc()), attrs(op->getAttrList()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + /// Discard the transaction state and reset the state of the original + /// operation. + void resetOperation() const { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (auto it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + + /// Return the original operation of this state. + Operation *getOperation() const { return op; } + +private: + Operation *op; + LocationAttr loc; + NamedAttributeList attrs; + SmallVector<Value, 8> operands; + SmallVector<Block *, 2> successors; +}; +} // end anonymous namespace + +namespace mlir { +namespace detail { +struct ConversionPatternRewriterImpl { + /// This class represents one requested operation replacement via 'replaceOp'. 
+ struct OpReplacement { + OpReplacement() = default; + OpReplacement(Operation *op, ValueRange newValues) + : op(op), newValues(newValues.begin(), newValues.end()) {} + + Operation *op; + SmallVector<Value, 2> newValues; + }; + + /// The kind of the block action performed during the rewrite. Actions can be + /// undone if the conversion fails. + enum class BlockActionKind { Create, Move, Split, TypeConversion }; + + /// Original position of the given block in its parent region. We cannot use + /// a region iterator because it could have been invalidated by other region + /// operations since the position was stored. + struct BlockPosition { + Region *region; + Region::iterator::difference_type position; + }; + + /// The storage class for an undoable block action (one of BlockActionKind), + /// contains the information necessary to undo this action. + struct BlockAction { + static BlockAction getCreate(Block *block) { + return {BlockActionKind::Create, block, {}}; + } + static BlockAction getMove(Block *block, BlockPosition originalPos) { + return {BlockActionKind::Move, block, {originalPos}}; + } + static BlockAction getSplit(Block *block, Block *originalBlock) { + BlockAction action{BlockActionKind::Split, block, {}}; + action.originalBlock = originalBlock; + return action; + } + static BlockAction getTypeConversion(Block *block) { + return BlockAction{BlockActionKind::TypeConversion, block, {}}; + } + + // The action kind. + BlockActionKind kind; + + // A pointer to the block that was created by the action. + Block *block; + + union { + // In use if kind == BlockActionKind::Move and contains a pointer to the + // region that originally contained the block as well as the position of + // the block in that region. + BlockPosition originalPosition; + // In use if kind == BlockActionKind::Split and contains a pointer to the + // block that was split into two parts. + Block *originalBlock; + }; + }; + + ConversionPatternRewriterImpl(PatternRewriter &rewriter, + TypeConverter *converter) + : argConverter(converter, rewriter) {} + + /// Return the current state of the rewriter. + RewriterState getCurrentState(); + + /// Reset the state of the rewriter to a previously saved point. + void resetState(RewriterState state); + + /// Undo the block actions (motions, splits) one by one in reverse order until + /// "numActionsToKeep" actions remains. + void undoBlockActions(unsigned numActionsToKeep = 0); + + /// Cleanup and destroy any generated rewrite operations. This method is + /// invoked when the conversion process fails. + void discardRewrites(); + + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites(); + + /// Convert the signature of the given block. + LogicalResult convertBlockSignature(Block *block); + + /// Apply a signature conversion on the given region. + Block * + applySignatureConversion(Region *region, + TypeConverter::SignatureConversion &conversion); + + /// PatternRewriter hook for replacing the results of an operation. + void replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead); + + /// Notifies that a block was split. + void notifySplitBlock(Block *block, Block *continuation); + + /// Notifies that the blocks of a region are about to be moved. + void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, + Region::iterator before); + + /// Notifies that the blocks of a region were cloned into another. 
+ void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks, + Location origRegionLoc); + + /// Remap the given operands to those with potentially different types. + void remapValues(Operation::operand_range operands, + SmallVectorImpl<Value> &remapped); + + /// Returns true if the given operation is ignored, and does not need to be + /// converted. + bool isOpIgnored(Operation *op) const; + + /// Recursively marks the nested operations under 'op' as ignored. This + /// removes them from being considered for legalization. + void markNestedOpsIgnored(Operation *op); + + // Mapping between replaced values that differ in type. This happens when + // replacing a value with one of a different type. + ConversionValueMapping mapping; + + /// Utility used to convert block arguments. + ArgConverter argConverter; + + /// Ordered vector of all of the newly created operations during conversion. + std::vector<Operation *> createdOps; + + /// Ordered vector of any requested operation replacements. + SmallVector<OpReplacement, 4> replacements; + + /// Ordered list of block operations (creations, splits, motions). + SmallVector<BlockAction, 4> blockActions; + + /// A set of operations that have been erased/replaced/etc that should no + /// longer be considered for legalization. This is not meant to be an + /// exhaustive list of all operations, but the minimal set that can be used to + /// detect if a given operation should be `ignored`. For example, we may add + /// the operations that define non-empty regions to the set, but not any of + /// the others. This simplifies the amount of memory needed as we can query if + /// the parent operation was ignored. + llvm::SetVector<Operation *> ignoredOps; + + /// A transaction state for each of operations that were updated in-place. + SmallVector<OperationTransactionState, 4> rootUpdates; + +#ifndef NDEBUG + /// A set of operations that have pending updates. This tracking isn't + /// strictly necessary, and is thus only active during debug builds for extra + /// verification. + SmallPtrSet<Operation *, 1> pendingRootUpdates; +#endif +}; +} // end namespace detail +} // end namespace mlir + +RewriterState ConversionPatternRewriterImpl::getCurrentState() { + return RewriterState(createdOps.size(), replacements.size(), + blockActions.size(), ignoredOps.size(), + rootUpdates.size()); +} + +void ConversionPatternRewriterImpl::resetState(RewriterState state) { + // Reset any operations that were updated in place. + for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) + rootUpdates[i].resetOperation(); + rootUpdates.resize(state.numRootUpdates); + + // Undo any block actions. + undoBlockActions(state.numBlockActions); + + // Reset any replaced operations and undo any saved mappings. + for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) + for (auto result : repl.op->getResults()) + mapping.erase(result); + replacements.resize(state.numReplacements); + + // Pop all of the newly created operations. + while (createdOps.size() != state.numCreatedOps) { + createdOps.back()->erase(); + createdOps.pop_back(); + } + + // Pop all of the recorded ignored operations that are no longer valid. + while (ignoredOps.size() != state.numIgnoredOperations) + ignoredOps.pop_back(); +} + +void ConversionPatternRewriterImpl::undoBlockActions( + unsigned numActionsToKeep) { + for (auto &action : + llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { + switch (action.kind) { + // Delete the created block. 
+ case BlockActionKind::Create: { + // Unlink all of the operations within this block, they will be deleted + // separately. + auto &blockOps = action.block->getOperations(); + while (!blockOps.empty()) + blockOps.remove(blockOps.begin()); + action.block->dropAllDefinedValueUses(); + action.block->erase(); + break; + } + // Move the block back to its original position. + case BlockActionKind::Move: { + Region *originalRegion = action.originalPosition.region; + originalRegion->getBlocks().splice( + std::next(originalRegion->begin(), action.originalPosition.position), + action.block->getParent()->getBlocks(), action.block); + break; + } + // Merge back the block that was split out. + case BlockActionKind::Split: { + action.originalBlock->getOperations().splice( + action.originalBlock->end(), action.block->getOperations()); + action.block->dropAllDefinedValueUses(); + action.block->erase(); + break; + } + // Undo the type conversion. + case BlockActionKind::TypeConversion: { + argConverter.discardRewrites(action.block); + break; + } + } + } + blockActions.resize(numActionsToKeep); +} + +void ConversionPatternRewriterImpl::discardRewrites() { + // Reset any operations that were updated in place. + for (auto &state : rootUpdates) + state.resetOperation(); + + undoBlockActions(); + + // Remove any newly created ops. + for (auto *op : llvm::reverse(createdOps)) + op->erase(); +} + +void ConversionPatternRewriterImpl::applyRewrites() { + // Apply all of the rewrites replacements requested during conversion. + for (auto &repl : replacements) { + for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { + if (auto newValue = repl.newValues[i]) + repl.op->getResult(i)->replaceAllUsesWith( + mapping.lookupOrDefault(newValue)); + } + + // If this operation defines any regions, drop any pending argument + // rewrites. + if (argConverter.typeConverter && repl.op->getNumRegions()) + argConverter.notifyOpRemoved(repl.op); + } + + // In a second pass, erase all of the replaced operations in reverse. This + // allows processing nested operations before their parent region is + // destroyed. + for (auto &repl : llvm::reverse(replacements)) + repl.op->erase(); + + argConverter.applyRewrites(mapping); +} + +LogicalResult +ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { + // Check to see if this block should not be converted: + // * There is no type converter. + // * The block has already been converted. + // * This is an entry block, these are converted explicitly via patterns. + if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || + !block->getParent() || block->isEntryBlock()) + return success(); + + // Otherwise, try to convert the block signature. + Block *newBlock = argConverter.convertSignature(block, mapping); + if (newBlock) + blockActions.push_back(BlockAction::getTypeConversion(newBlock)); + return success(newBlock); +} + +Block *ConversionPatternRewriterImpl::applySignatureConversion( + Region *region, TypeConverter::SignatureConversion &conversion) { + if (!region->empty()) { + Block *newEntry = argConverter.applySignatureConversion( + ®ion->front(), conversion, mapping); + blockActions.push_back(BlockAction::getTypeConversion(newEntry)); + return newEntry; + } + return nullptr; +} + +void ConversionPatternRewriterImpl::replaceOp(Operation *op, + ValueRange newValues, + ValueRange valuesToRemoveIfDead) { + assert(newValues.size() == op->getNumResults()); + + // Create mappings for each of the new result values. 
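From a pattern author's point of view, the rollback machinery above only works if every IR mutation goes through the ConversionPatternRewriter, so that it is recorded in createdOps, replacements, or blockActions and can be undone by discardRewrites(). A minimal sketch of the intended usage inside some ConversionPattern::matchAndRewrite (not part of this patch; ConstantIndexOp from the standard dialect is used only as an example op):

  rewriter.setInsertionPoint(op);
  // Created through the rewriter, so it is recorded in createdOps and can be
  // erased again if the enclosing pattern ultimately fails.
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  // Recorded as an OpReplacement; uses are only rewritten in applyRewrites().
  rewriter.replaceOp(op, {zero});
  // In contrast, calling op->erase() or op->replaceAllUsesWith(...) directly
  // would bypass the undo log and could not be rolled back.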
+ for (unsigned i = 0, e = newValues.size(); i < e; ++i) + if (auto repl = newValues[i]) + mapping.map(op->getResult(i), repl); + + // Record the requested operation replacement. + replacements.emplace_back(op, newValues); + + /// Mark this operation as recursively ignored so that we don't need to + /// convert any nested operations. + markNestedOpsIgnored(op); +} + +void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, + Block *continuation) { + blockActions.push_back(BlockAction::getSplit(continuation, block)); +} + +void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( + Region ®ion, Region &parent, Region::iterator before) { + for (auto &pair : llvm::enumerate(region)) { + Block &block = pair.value(); + unsigned position = pair.index(); + blockActions.push_back(BlockAction::getMove(&block, {®ion, position})); + } +} + +void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( + iterator_range<Region::iterator> &blocks, Location origRegionLoc) { + for (Block &block : blocks) + blockActions.push_back(BlockAction::getCreate(&block)); + + // Compute the conversion set for the inlined region. + auto result = computeConversionSet(blocks, origRegionLoc, createdOps); + + // This original region has already had its conversion set computed, so there + // shouldn't be any new failures. + (void)result; + assert(succeeded(result) && "expected region to have no unreachable blocks"); +} + +void ConversionPatternRewriterImpl::remapValues( + Operation::operand_range operands, SmallVectorImpl<Value> &remapped) { + remapped.reserve(llvm::size(operands)); + for (Value operand : operands) + remapped.push_back(mapping.lookupOrDefault(operand)); +} + +bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { + // Check to see if this operation or its parent were ignored. + return ignoredOps.count(op) || ignoredOps.count(op->getParentOp()); +} + +void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { + // Walk this operation and collect nested operations that define non-empty + // regions. We mark such operations as 'ignored' so that we know we don't have + // to convert them, or their nested ops. + if (op->getNumRegions() == 0) + return; + op->walk([&](Operation *op) { + if (llvm::any_of(op->getRegions(), + [](Region ®ion) { return !region.empty(); })) + ignoredOps.insert(op); + }); +} + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter +//===----------------------------------------------------------------------===// + +ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, + TypeConverter *converter) + : PatternRewriter(ctx), + impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} +ConversionPatternRewriter::~ConversionPatternRewriter() {} + +/// PatternRewriter hook for replacing the results of an operation. +void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead) { + LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName() + << "\n"); + impl->replaceOp(op, newValues, valuesToRemoveIfDead); +} + +/// PatternRewriter hook for erasing a dead operation. The uses of this +/// operation *must* be made dead by the end of the conversion process, +/// otherwise an assert will be issued. 
+void ConversionPatternRewriter::eraseOp(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() + << "\n"); + SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); + impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); +} + +/// Apply a signature conversion to the entry block of the given region. +Block *ConversionPatternRewriter::applySignatureConversion( + Region *region, TypeConverter::SignatureConversion &conversion) { + return impl->applySignatureConversion(region, conversion); +} + +void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, + Value to) { + for (auto &u : from->getUses()) { + if (u.getOwner() == to->getDefiningOp()) + continue; + u.getOwner()->replaceUsesOfWith(from, to); + } + impl->mapping.map(impl->mapping.lookupOrDefault(from), to); +} + +/// Return the converted value that replaces 'key'. Return 'key' if there is +/// no such a converted value. +Value ConversionPatternRewriter::getRemappedValue(Value key) { + return impl->mapping.lookupOrDefault(key); +} + +/// PatternRewriter hook for splitting a block into two parts. +Block *ConversionPatternRewriter::splitBlock(Block *block, + Block::iterator before) { + auto *continuation = PatternRewriter::splitBlock(block, before); + impl->notifySplitBlock(block, continuation); + return continuation; +} + +/// PatternRewriter hook for merging a block into another. +void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, + ValueRange argValues) { + // TODO(riverriddle) This requires fixing the implementation of + // 'replaceUsesOfBlockArgument', which currently isn't undoable. + llvm_unreachable("block merging updates are currently not supported"); +} + +/// PatternRewriter hook for moving blocks out of a region. +void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, + Region &parent, + Region::iterator before) { + impl->notifyRegionIsBeingInlinedBefore(region, parent, before); + PatternRewriter::inlineRegionBefore(region, parent, before); +} + +/// PatternRewriter hook for cloning blocks of one region into another. +void ConversionPatternRewriter::cloneRegionBefore( + Region ®ion, Region &parent, Region::iterator before, + BlockAndValueMapping &mapping) { + if (region.empty()) + return; + PatternRewriter::cloneRegionBefore(region, parent, before, mapping); + + // Collect the range of the cloned blocks. + auto clonedBeginIt = mapping.lookup(®ion.front())->getIterator(); + auto clonedBlocks = llvm::make_range(clonedBeginIt, before); + impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc()); +} + +/// PatternRewriter hook for creating a new operation. +Operation *ConversionPatternRewriter::insert(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName() + << "\n"); + impl->createdOps.push_back(op); + return OpBuilder::insert(op); +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::startRootUpdate(Operation *op) { +#ifndef NDEBUG + impl->pendingRootUpdates.insert(op); +#endif + impl->rootUpdates.emplace_back(op); +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { + // There is nothing to do here, we only need to track the operation at the + // start of the update. 
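The public hooks above are what conversion patterns are expected to call. A small sketch of such a pattern, in the style of FuncOpSignatureConversion further below; MyUnaryOp and TargetUnaryOp are hypothetical ops used only for illustration:

  struct MyUnaryOpLowering : public OpConversionPattern<MyUnaryOp> {
    MyUnaryOpLowering(MLIRContext *ctx) : OpConversionPattern(ctx) {}

    PatternMatchResult
    matchAndRewrite(MyUnaryOp op, ArrayRef<Value> operands,
                    ConversionPatternRewriter &rewriter) const override {
      // 'operands' holds the operands already remapped by the driver.
      rewriter.replaceOpWithNewOp<TargetUnaryOp>(op, operands.front());
      return matchSuccess();
    }
  };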
+#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif + // Erase the last update for this operation. + auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; + auto &rootUpdates = impl->rootUpdates; + auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); + rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it)); +} + +/// Return a reference to the internal implementation. +detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { + return *impl; +} + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Attempt to match and rewrite the IR root at the specified operation. +PatternMatchResult +ConversionPattern::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + SmallVector<Value, 4> operands; + auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); + dialectRewriter.getImpl().remapValues(op->getOperands(), operands); + + // If this operation has no successors, invoke the rewrite directly. + if (op->getNumSuccessors() == 0) + return matchAndRewrite(op, operands, dialectRewriter); + + // Otherwise, we need to remap the successors. + SmallVector<Block *, 2> destinations; + destinations.reserve(op->getNumSuccessors()); + + SmallVector<ArrayRef<Value>, 2> operandsPerDestination; + unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); + for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { + destinations.push_back(op->getSuccessor(i)); + + // Lookup the successors operands. + unsigned n = op->getNumSuccessorOperands(i); + operandsPerDestination.push_back( + llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n)); + seen += n; + } + + // Rewrite the operation. + return matchAndRewrite( + op, + llvm::makeArrayRef(operands.data(), + operands.data() + firstSuccessorOperand), + destinations, operandsPerDestination, dialectRewriter); +} + +//===----------------------------------------------------------------------===// +// OperationLegalizer +//===----------------------------------------------------------------------===// + +namespace { +/// A set of rewrite patterns that can be used to legalize a given operation. +using LegalizationPatterns = SmallVector<RewritePattern *, 1>; + +/// This class defines a recursive operation legalizer. +class OperationLegalizer { +public: + using LegalizationAction = ConversionTarget::LegalizationAction; + + OperationLegalizer(ConversionTarget &targetInfo, + const OwningRewritePatternList &patterns) + : target(targetInfo) { + buildLegalizationGraph(patterns); + computeLegalizationGraphBenefit(); + } + + /// Returns if the given operation is known to be illegal on the target. + bool isIllegal(Operation *op) const; + + /// Attempt to legalize the given operation. Returns success if the operation + /// was legalized, failure otherwise. + LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); + + /// Returns the conversion target in use by the legalizer. 
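The start/finalize/cancel hooks here are normally not called directly; patterns go through PatternRewriter::updateRootInPlace, which brackets a callback with startRootUpdate and finalizeRootUpdate (as FuncOpSignatureConversion does later in this file). A minimal sketch, where "lowered" is a made-up attribute name:

  rewriter.updateRootInPlace(op, [&] {
    // Any in-place mutation of 'op' goes here; the pre-update state was
    // captured by startRootUpdate and can be restored on failure.
    op->setAttr("lowered", rewriter.getUnitAttr());
  });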
+ ConversionTarget &getTarget() { return target; } + +private: + /// Attempt to legalize the given operation by folding it. + LogicalResult legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter); + + /// Attempt to legalize the given operation by applying the provided pattern. + /// Returns success if the operation was legalized, failure otherwise. + LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, + ConversionPatternRewriter &rewriter); + + /// Build an optimistic legalization graph given the provided patterns. This + /// function populates 'legalizerPatterns' with the operations that are not + /// directly legal, but may be transitively legal for the current target given + /// the provided patterns. + void buildLegalizationGraph(const OwningRewritePatternList &patterns); + + /// Compute the benefit of each node within the computed legalization graph. + /// This orders the patterns within 'legalizerPatterns' based upon two + /// criteria: + /// 1) Prefer patterns that have the lowest legalization depth, i.e. + /// represent the more direct mapping to the target. + /// 2) When comparing patterns with the same legalization depth, prefer the + /// pattern with the highest PatternBenefit. This allows for users to + /// prefer specific legalizations over others. + void computeLegalizationGraphBenefit(); + + /// The current set of patterns that have been applied. + SmallPtrSet<RewritePattern *, 8> appliedPatterns; + + /// The set of legality information for operations transitively supported by + /// the target. + DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; + + /// The legalization information provided by the target. + ConversionTarget ⌖ +}; +} // namespace + +bool OperationLegalizer::isIllegal(Operation *op) const { + // Check if the target explicitly marked this operation as illegal. + return target.getOpAction(op->getName()) == LegalizationAction::Illegal; +} + +LogicalResult +OperationLegalizer::legalize(Operation *op, + ConversionPatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() + << "\n"); + + // Check if this operation is legal on the target. + if (auto legalityInfo = target.isLegal(op)) { + LLVM_DEBUG(llvm::dbgs() + << "-- Success : Operation marked legal by the target\n"); + // If this operation is recursively legal, mark its children as ignored so + // that we don't consider them for legalization. + if (legalityInfo->isRecursivelyLegal) { + LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation is recursively legal; " + "Skipping internals\n"); + rewriter.getImpl().markNestedOpsIgnored(op); + } + return success(); + } + + // Check to see if the operation is ignored and doesn't need to be converted. + if (rewriter.getImpl().isOpIgnored(op)) { + LLVM_DEBUG(llvm::dbgs() + << "-- Success : Operation marked ignored during conversion\n"); + return success(); + } + + // If the operation isn't legal, try to fold it in-place. + // TODO(riverriddle) Should we always try to do this, even if the op is + // already legal? + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n"); + return success(); + } + + // Otherwise, we need to apply a legalization pattern to this operation. 
+ auto it = legalizerPatterns.find(op->getName()); + if (it == legalizerPatterns.end()) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); + return failure(); + } + + // The patterns are sorted by expected benefit, so try to apply each in-order. + for (auto *pattern : it->second) + if (succeeded(legalizePattern(op, pattern, rewriter))) + return success(); + + LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n"); + return failure(); +} + +LogicalResult +OperationLegalizer::legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); + + // Try to fold the operation. + SmallVector<Value, 2> replacementValues; + rewriter.setInsertionPoint(op); + if (failed(rewriter.tryFold(op, replacementValues))) + return failure(); + + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + + // Recursively legalize any new constant operations. + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); + i != e; ++i) { + Operation *cstOp = rewriterImpl.createdOps[i]; + if (failed(legalize(cstOp, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '" + << cstOp->getName() << "' was illegal.\n"); + rewriterImpl.resetState(curState); + return failure(); + } + } + return success(); +} + +LogicalResult +OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, + ConversionPatternRewriter &rewriter) { + LLVM_DEBUG({ + llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; + interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); + llvm::dbgs() << ")'.\n"; + }); + + // Ensure that we don't cycle by not allowing the same pattern to be + // applied twice in the same recursion stack. + // TODO(riverriddle) We could eventually converge, but that requires more + // complicated analysis. + if (!appliedPatterns.insert(pattern).second) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n"); + return failure(); + } + + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); + auto cleanupFailure = [&] { + // Reset the rewriter state and pop this pattern. + rewriterImpl.resetState(curState); + appliedPatterns.erase(pattern); + return failure(); + }; + + // Try to rewrite with the given pattern. + rewriter.setInsertionPoint(op); + auto matchedPattern = pattern->matchAndRewrite(op, rewriter); +#ifndef NDEBUG + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#endif + + if (!matchedPattern) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); + return cleanupFailure(); + } + + // If the pattern moved or created any blocks, try to legalize their types. + // This ensures that the types of the block arguments are legal for the region + // they were moved into. + for (unsigned i = curState.numBlockActions, + e = rewriterImpl.blockActions.size(); + i != e; ++i) { + auto &action = rewriterImpl.blockActions[i]; + if (action.kind == + ConversionPatternRewriterImpl::BlockActionKind::TypeConversion) + continue; + + // Convert the block signature. 
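The control flow of legalize() above is driven entirely by how the ConversionTarget was configured. A sketch of a typical setup using the addLegal*/addIllegal*/addDynamicallyLegal* helpers declared in DialectConversion.h (MyDialect is a placeholder; ctx is an MLIRContext&, typeConverter a TypeConverter as sketched later):

  ConversionTarget target(ctx);
  // Standard ops are legal as-is and never reach the pattern driver.
  target.addLegalDialect<StandardOpsDialect>();
  // FuncOp is legal only once its signature uses converted types.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
  // Placeholder dialect that must be rewritten away entirely.
  target.addIllegalDialect<MyDialect>();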
+ if (failed(rewriterImpl.convertBlockSignature(action.block))) { + LLVM_DEBUG(llvm::dbgs() + << "-- FAIL: failed to convert types of moved block.\n"); + return cleanupFailure(); + } + } + + // Check all of the replacements to ensure that the pattern actually replaced + // the root operation. We also mark any other replaced ops as 'dead' so that + // we don't try to legalize them later. + bool replacedRoot = false; + for (unsigned i = curState.numReplacements, + e = rewriterImpl.replacements.size(); + i != e; ++i) { + Operation *replacedOp = rewriterImpl.replacements[i].op; + if (replacedOp == op) + replacedRoot = true; + else + rewriterImpl.ignoredOps.insert(replacedOp); + } + + // Check that the root was either updated or replace. + auto updatedRootInPlace = [&] { + return llvm::any_of( + llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), + [op](auto &state) { return state.getOperation() == op; }); + }; + (void)replacedRoot; + (void)updatedRootInPlace; + assert((replacedRoot || updatedRootInPlace()) && + "expected pattern to replace the root operation"); + + // Recursively legalize each of the operations updated in place. + for (unsigned i = curState.numRootUpdates, + e = rewriterImpl.rootUpdates.size(); + i != e; ++i) { + auto &state = rewriterImpl.rootUpdates[i]; + if (failed(legalize(state.getOperation(), rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '" + << op->getName() << "' was illegal.\n"); + return cleanupFailure(); + } + } + + // Recursively legalize each of the new operations. + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); + i != e; ++i) { + Operation *op = rewriterImpl.createdOps[i]; + if (failed(legalize(op, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation '" + << op->getName() << "' was illegal.\n"); + return cleanupFailure(); + } + } + + appliedPatterns.erase(pattern); + return success(); +} + +void OperationLegalizer::buildLegalizationGraph( + const OwningRewritePatternList &patterns) { + // A mapping between an operation and a set of operations that can be used to + // generate it. + DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; + // A mapping between an operation and any currently invalid patterns it has. + DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns; + // A worklist of patterns to consider for legality. + llvm::SetVector<RewritePattern *> patternWorklist; + + // Build the mapping from operations to the parent ops that may generate them. + for (auto &pattern : patterns) { + auto root = pattern->getRootKind(); + + // Skip operations that are always known to be legal. + if (target.getOpAction(root) == LegalizationAction::Legal) + continue; + + // Add this pattern to the invalid set for the root op and record this root + // as a parent for any generated operations. + invalidPatterns[root].insert(pattern.get()); + for (auto op : pattern->getGeneratedOps()) + parentOps[op].insert(root); + + // Add this pattern to the worklist. + patternWorklist.insert(pattern.get()); + } + + while (!patternWorklist.empty()) { + auto *pattern = patternWorklist.pop_back_val(); + + // Check to see if any of the generated operations are invalid. 
+ if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { + Optional<LegalizationAction> action = target.getOpAction(op); + return !legalizerPatterns.count(op) && + (!action || action == LegalizationAction::Illegal); + })) + continue; + + // Otherwise, if all of the generated operation are valid, this op is now + // legal so add all of the child patterns to the worklist. + legalizerPatterns[pattern->getRootKind()].push_back(pattern); + invalidPatterns[pattern->getRootKind()].erase(pattern); + + // Add any invalid patterns of the parent operations to see if they have now + // become legal. + for (auto op : parentOps[pattern->getRootKind()]) + patternWorklist.set_union(invalidPatterns[op]); + } +} + +void OperationLegalizer::computeLegalizationGraphBenefit() { + // The smallest pattern depth, when legalizing an operation. + DenseMap<OperationName, unsigned> minPatternDepth; + + // Compute the minimum legalization depth for a given operation. + std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) { + // Check for existing depth. + auto depthIt = minPatternDepth.find(op); + if (depthIt != minPatternDepth.end()) + return depthIt->second; + + // If a mapping for this operation does not exist, then this operation + // is always legal. Return 0 as the depth for a directly legal operation. + auto opPatternsIt = legalizerPatterns.find(op); + if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) + return 0u; + + // Initialize the depth to the maximum value. + unsigned minDepth = std::numeric_limits<unsigned>::max(); + + // Record this initial depth in case we encounter this op again when + // recursively computing the depth. + minPatternDepth.try_emplace(op, minDepth); + + // Compute the depth for each pattern used to legalize this operation. + SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth; + patternsByDepth.reserve(opPatternsIt->second.size()); + for (RewritePattern *pattern : opPatternsIt->second) { + unsigned depth = 0; + for (auto generatedOp : pattern->getGeneratedOps()) + depth = std::max(depth, computeDepth(generatedOp) + 1); + patternsByDepth.emplace_back(pattern, depth); + + // Update the min depth for this operation. + minDepth = std::min(minDepth, depth); + } + + // Update the pattern depth. + minPatternDepth[op] = minDepth; + + // If the operation only has one legalization pattern, there is no need to + // sort them. + if (patternsByDepth.size() == 1) + return minDepth; + + // Sort the patterns by those likely to be the most beneficial. + llvm::array_pod_sort( + patternsByDepth.begin(), patternsByDepth.end(), + [](const std::pair<RewritePattern *, unsigned> *lhs, + const std::pair<RewritePattern *, unsigned> *rhs) { + // First sort by the smaller pattern legalization depth. + if (lhs->second != rhs->second) + return llvm::array_pod_sort_comparator<unsigned>(&lhs->second, + &rhs->second); + + // Then sort by the larger pattern benefit. + auto lhsBenefit = lhs->first->getBenefit(); + auto rhsBenefit = rhs->first->getBenefit(); + return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit, + &lhsBenefit); + }); + + // Update the legalization pattern to use the new sorted list. + opPatternsIt->second.clear(); + for (auto &patternIt : patternsByDepth) + opPatternsIt->second.push_back(patternIt.first); + + return minDepth; + }; + + // For each operation that is transitively legal, compute a cost for it. 
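Because patterns of equal legalization depth are ordered by PatternBenefit, a pattern author can bias which lowering is tried first. A sketch, assuming both hypothetical patterns forward a benefit argument to their base constructor:

  OwningRewritePatternList patterns;
  // Same legalization depth, higher benefit: tried before the generic path.
  patterns.insert<FastPathLowering>(ctx, /*benefit=*/2);
  patterns.insert<GenericLowering>(ctx, /*benefit=*/1);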
+ for (auto &opIt : legalizerPatterns) + if (!minPatternDepth.count(opIt.first)) + computeDepth(opIt.first); +} + +//===----------------------------------------------------------------------===// +// OperationConverter +//===----------------------------------------------------------------------===// +namespace { +enum OpConversionMode { + // In this mode, the conversion will ignore failed conversions to allow + // illegal operations to co-exist in the IR. + Partial, + + // In this mode, all operations must be legal for the given target for the + // conversion to succeed. + Full, + + // In this mode, operations are analyzed for legality. No actual rewrites are + // applied to the operations on success. + Analysis, +}; + +// This class converts operations to a given conversion target via a set of +// rewrite patterns. The conversion behaves differently depending on the +// conversion mode. +struct OperationConverter { + explicit OperationConverter(ConversionTarget &target, + const OwningRewritePatternList &patterns, + OpConversionMode mode, + DenseSet<Operation *> *legalizableOps = nullptr) + : opLegalizer(target, patterns), mode(mode), + legalizableOps(legalizableOps) {} + + /// Converts the given operations to the conversion target. + LogicalResult convertOperations(ArrayRef<Operation *> ops, + TypeConverter *typeConverter); + +private: + /// Converts an operation with the given rewriter. + LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); + + /// Converts the type signatures of the blocks nested within 'op'. + LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, + Operation *op); + + /// The legalizer to use when converting operations. + OperationLegalizer opLegalizer; + + /// The conversion mode to use when legalizing operations. + OpConversionMode mode; + + /// A set of pre-existing operations that were found to be legalizable to the + /// target. This field is only used when mode == OpConversionMode::Analysis. + DenseSet<Operation *> *legalizableOps; +}; +} // end anonymous namespace + +LogicalResult +OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, + Operation *op) { + // Check to see if type signatures need to be converted. + if (!rewriter.getImpl().argConverter.typeConverter) + return success(); + + for (auto ®ion : op->getRegions()) { + for (auto &block : llvm::make_early_inc_range(region)) + if (failed(rewriter.getImpl().convertBlockSignature(&block))) + return failure(); + } + return success(); +} + +LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, + Operation *op) { + // Legalize the given operation. + if (failed(opLegalizer.legalize(op, rewriter))) { + // Handle the case of a failed conversion for each of the different modes. + /// Full conversions expect all operations to be converted. + if (mode == OpConversionMode::Full) + return op->emitError() + << "failed to legalize operation '" << op->getName() << "'"; + /// Partial conversions allow conversions to fail iff the operation was not + /// explicitly marked as illegal. + if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op)) + return op->emitError() + << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + } else { + /// Analysis conversions don't fail if any operations fail to legalize, + /// they are only interested in the operations that were successfully + /// legalized. 
+ if (mode == OpConversionMode::Analysis) + legalizableOps->insert(op); + + // If legalization succeeded, convert the types any of the blocks within + // this operation. + if (failed(convertBlockSignatures(rewriter, op))) + return failure(); + } + return success(); +} + +LogicalResult +OperationConverter::convertOperations(ArrayRef<Operation *> ops, + TypeConverter *typeConverter) { + if (ops.empty()) + return success(); + ConversionTarget &target = opLegalizer.getTarget(); + + /// Compute the set of operations and blocks to convert. + std::vector<Operation *> toConvert; + for (auto *op : ops) { + toConvert.emplace_back(op); + for (auto ®ion : op->getRegions()) + if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), + toConvert, &target))) + return failure(); + } + + // Convert each operation and discard rewrites on failure. + ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); + for (auto *op : toConvert) + if (failed(convert(rewriter, op))) + return rewriter.getImpl().discardRewrites(), failure(); + + // Otherwise, the body conversion succeeded. Apply rewrites if this is not an + // analysis conversion. + if (mode == OpConversionMode::Analysis) + rewriter.getImpl().discardRewrites(); + else + rewriter.getImpl().applyRewrites(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Type Conversion +//===----------------------------------------------------------------------===// + +/// Remap an input of the original signature with a new set of types. The +/// new types are appended to the new signature conversion. +void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, + ArrayRef<Type> types) { + assert(!types.empty() && "expected valid types"); + remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); + addInputs(types); +} + +/// Append new input types to the signature conversion, this should only be +/// used if the new types are not intended to remap an existing input. +void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { + assert(!types.empty() && + "1->0 type remappings don't need to be added explicitly"); + argTypes.append(types.begin(), types.end()); +} + +/// Remap an input of the original signature with a range of types in the +/// new signature. +void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, + unsigned newInputNo, + unsigned newInputCount) { + assert(!remappedInputs[origInputNo] && "input has already been remapped"); + assert(newInputCount != 0 && "expected valid input count"); + remappedInputs[origInputNo] = + InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; +} + +/// Remap an input of the original signature to another `replacementValue` +/// value. This would make the signature converter drop this argument. +void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, + Value replacementValue) { + assert(!remappedInputs[origInputNo] && "input has already been remapped"); + remappedInputs[origInputNo] = + InputMapping{origInputNo, /*size=*/0, replacementValue}; +} + +/// This hooks allows for converting a type. +LogicalResult TypeConverter::convertType(Type t, + SmallVectorImpl<Type> &results) { + if (auto newT = convertType(t)) { + results.push_back(newT); + return success(); + } + return failure(); +} + +/// Convert the given set of types, filling 'results' as necessary. This +/// returns failure if the conversion of any of the types fails, success +/// otherwise. 
+LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types, + SmallVectorImpl<Type> &results) { + for (auto type : types) + if (failed(convertType(type, results))) + return failure(); + return success(); +} + +/// Return true if the given type is legal for this type converter, i.e. the +/// type converts to itself. +bool TypeConverter::isLegal(Type type) { + SmallVector<Type, 1> results; + return succeeded(convertType(type, results)) && results.size() == 1 && + results.front() == type; +} + +/// Return true if the inputs and outputs of the given function type are +/// legal. +bool TypeConverter::isSignatureLegal(FunctionType funcType) { + return llvm::all_of( + llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()), + [this](Type type) { return isLegal(type); }); +} + +/// This hook allows for converting a specific argument of a signature. +LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, + SignatureConversion &result) { + // Try to convert the given input type. + SmallVector<Type, 1> convertedTypes; + if (failed(convertType(type, convertedTypes))) + return failure(); + + // If this argument is being dropped, there is nothing left to do. + if (convertedTypes.empty()) + return success(); + + // Otherwise, add the new inputs. + result.addInputs(inputNo, convertedTypes); + return success(); +} + +/// Create a default conversion pattern that rewrites the type signature of a +/// FuncOp. +namespace { +struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> { + FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + : OpConversionPattern(ctx), converter(converter) {} + + /// Hook for derived classes to implement combined matching and rewriting. + PatternMatchResult + matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + FunctionType type = funcOp.getType(); + + // Convert the original function arguments. + TypeConverter::SignatureConversion result(type.getNumInputs()); + for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) + if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) + return matchFailure(); + + // Convert the original function results. + SmallVector<Type, 1> convertedResults; + if (failed(converter.convertTypes(type.getResults(), convertedResults))) + return matchFailure(); + + // Update the function signature in-place. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(FunctionType::get(result.getConvertedTypes(), + convertedResults, funcOp.getContext())); + rewriter.applySignatureConversion(&funcOp.getBody(), result); + }); + return matchSuccess(); + } + + /// The type converter to use when rewriting the signature. + TypeConverter &converter; +}; +} // end anonymous namespace + +void mlir::populateFuncOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &converter) { + patterns.insert<FuncOpSignatureConversion>(ctx, converter); +} + +/// This function converts the type signature of the given block, by invoking +/// 'convertSignatureArg' for each argument. This function should return a valid +/// conversion for the signature on success, None otherwise. 
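A small sketch of a type converter that could drive the signature conversion machinery above: it narrows i64 to i32 and leaves every other type unchanged, so those types remain legal per isLegal(). The converter is illustrative only and assumes the one-argument convertType hook used above is overridable:

  struct NarrowingTypeConverter : public TypeConverter {
    Type convertType(Type t) override {
      if (t.isInteger(64))
        return IntegerType::get(32, t.getContext());
      return t; // Identity conversion: the type is already legal.
    }
  };

  // Hooking it up to the FuncOp signature pattern defined above:
  NarrowingTypeConverter typeConverter;
  OwningRewritePatternList patterns;
  populateFuncOpTypeConversionPattern(patterns, ctx, typeConverter);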
+auto TypeConverter::convertBlockSignature(Block *block) + -> Optional<SignatureConversion> { + SignatureConversion conversion(block->getNumArguments()); + for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) + if (failed(convertSignatureArg(i, block->getArgument(i)->getType(), + conversion))) + return llvm::None; + return conversion; +} + +//===----------------------------------------------------------------------===// +// ConversionTarget +//===----------------------------------------------------------------------===// + +/// Register a legality action for the given operation. +void ConversionTarget::setOpAction(OperationName op, + LegalizationAction action) { + legalOperations[op] = {action, /*isRecursivelyLegal=*/false}; +} + +/// Register a legality action for the given dialects. +void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, + LegalizationAction action) { + for (StringRef dialect : dialectNames) + legalDialects[dialect] = action; +} + +/// Get the legality action for the given operation. +auto ConversionTarget::getOpAction(OperationName op) const + -> Optional<LegalizationAction> { + Optional<LegalizationInfo> info = getOpInfo(op); + return info ? info->action : Optional<LegalizationAction>(); +} + +/// If the given operation instance is legal on this target, a structure +/// containing legality information is returned. If the operation is not legal, +/// None is returned. +auto ConversionTarget::isLegal(Operation *op) const + -> Optional<LegalOpDetails> { + Optional<LegalizationInfo> info = getOpInfo(op->getName()); + if (!info) + return llvm::None; + + // Returns true if this operation instance is known to be legal. + auto isOpLegal = [&] { + // Handle dynamic legality. + if (info->action == LegalizationAction::Dynamic) { + // Check for callbacks on the operation or dialect. + auto opFn = opLegalityFns.find(op->getName()); + if (opFn != opLegalityFns.end()) + return opFn->second(op); + auto dialectFn = dialectLegalityFns.find(op->getName().getDialect()); + if (dialectFn != dialectLegalityFns.end()) + return dialectFn->second(op); + + // Otherwise, invoke the hook on the derived instance. + return isDynamicallyLegal(op); + } + + // Otherwise, the operation is only legal if it was marked 'Legal'. + return info->action == LegalizationAction::Legal; + }; + if (!isOpLegal()) + return llvm::None; + + // This operation is legal, compute any additional legality information. + LegalOpDetails legalityDetails; + + if (info->isRecursivelyLegal) { + auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); + if (legalityFnIt != opRecursiveLegalityFns.end()) + legalityDetails.isRecursivelyLegal = legalityFnIt->second(op); + else + legalityDetails.isRecursivelyLegal = true; + } + return legalityDetails; +} + +/// Set the dynamic legality callback for the given operation. +void ConversionTarget::setLegalityCallback( + OperationName name, const DynamicLegalityCallbackFn &callback) { + assert(callback && "expected valid legality callback"); + opLegalityFns[name] = callback; +} + +/// Set the recursive legality callback for the given operation and mark the +/// operation as recursively legal. 
+void ConversionTarget::markOpRecursivelyLegal( + OperationName name, const DynamicLegalityCallbackFn &callback) { + auto infoIt = legalOperations.find(name); + assert(infoIt != legalOperations.end() && + infoIt->second.action != LegalizationAction::Illegal && + "expected operation to already be marked as legal"); + infoIt->second.isRecursivelyLegal = true; + if (callback) + opRecursiveLegalityFns[name] = callback; + else + opRecursiveLegalityFns.erase(name); +} + +/// Set the dynamic legality callback for the given dialects. +void ConversionTarget::setLegalityCallback( + ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { + assert(callback && "expected valid legality callback"); + for (StringRef dialect : dialects) + dialectLegalityFns[dialect] = callback; +} + +/// Get the legalization information for the given operation. +auto ConversionTarget::getOpInfo(OperationName op) const + -> Optional<LegalizationInfo> { + // Check for info for this specific operation. + auto it = legalOperations.find(op); + if (it != legalOperations.end()) + return it->second; + // Otherwise, default to checking on the parent dialect. + auto dialectIt = legalDialects.find(op.getDialect()); + if (dialectIt != legalDialects.end()) + return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false}; + return llvm::None; +} + +//===----------------------------------------------------------------------===// +// Op Conversion Entry Points +//===----------------------------------------------------------------------===// + +/// Apply a partial conversion on the given operations, and all nested +/// operations. This method converts as many operations to the target as +/// possible, ignoring operations that failed to legalize. +LogicalResult mlir::applyPartialConversion( + ArrayRef<Operation *> ops, ConversionTarget &target, + const OwningRewritePatternList &patterns, TypeConverter *converter) { + OperationConverter opConverter(target, patterns, OpConversionMode::Partial); + return opConverter.convertOperations(ops, converter); +} +LogicalResult +mlir::applyPartialConversion(Operation *op, ConversionTarget &target, + const OwningRewritePatternList &patterns, + TypeConverter *converter) { + return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, + converter); +} + +/// Apply a complete conversion on the given operations, and all nested +/// operations. This method will return failure if the conversion of any +/// operation fails. +LogicalResult +mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target, + const OwningRewritePatternList &patterns, + TypeConverter *converter) { + OperationConverter opConverter(target, patterns, OpConversionMode::Full); + return opConverter.convertOperations(ops, converter); +} +LogicalResult +mlir::applyFullConversion(Operation *op, ConversionTarget &target, + const OwningRewritePatternList &patterns, + TypeConverter *converter) { + return applyFullConversion(llvm::makeArrayRef(op), target, patterns, + converter); +} + +/// Apply an analysis conversion on the given operations, and all nested +/// operations. This method analyzes which operations would be successfully +/// converted to the target if a conversion was applied. All operations that +/// were found to be legalizable to the given 'target' are placed within the +/// provided 'convertedOps' set; note that no actual rewrites are applied to the +/// operations on success and only pre-existing operations are added to the set. 
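Putting the entry points together, a lowering driver might look roughly like the following sketch (op, target, patterns, and typeConverter are the objects set up in the earlier sketches, not definitions from this patch):

  // Partial conversion: unconverted ops may survive unless they were
  // explicitly marked illegal on the target.
  if (failed(applyPartialConversion(op, target, patterns, &typeConverter)))
    llvm::errs() << "partial conversion failed\n";

  // Analysis conversion: compute which ops *would* legalize, without
  // actually rewriting anything.
  DenseSet<Operation *> legalizable;
  (void)applyAnalysisConversion(op, target, patterns, legalizable,
                                &typeConverter);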
+LogicalResult mlir::applyAnalysisConversion( + ArrayRef<Operation *> ops, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet<Operation *> &convertedOps, TypeConverter *converter) { + OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, + &convertedOps); + return opConverter.convertOperations(ops, converter); +} +LogicalResult +mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet<Operation *> &convertedOps, + TypeConverter *converter) { + return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, + convertedOps, converter); +} diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp new file mode 100644 index 00000000000..b2cee7da083 --- /dev/null +++ b/mlir/lib/Transforms/Inliner.cpp @@ -0,0 +1,296 @@ +//===- Inliner.cpp - Pass to inline function calls ------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a basic inlining algorithm that operates bottom up over +// the Strongly Connect Components(SCCs) of the CallGraph. This enables a more +// incremental propagation of inlining decisions from the leafs to the roots of +// the callgraph. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/CallGraph.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Parallel.h" + +#define DEBUG_TYPE "inlining" + +using namespace mlir; + +static llvm::cl::opt<bool> disableCanonicalization( + "mlir-disable-inline-simplify", + llvm::cl::desc("Disable running simplifications during inlining"), + llvm::cl::ReallyHidden, llvm::cl::init(false)); + +static llvm::cl::opt<unsigned> maxInliningIterations( + "mlir-max-inline-iterations", + llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), + llvm::cl::ReallyHidden, llvm::cl::init(4)); + +//===----------------------------------------------------------------------===// +// CallGraph traversal +//===----------------------------------------------------------------------===// + +/// Run a given transformation over the SCCs of the callgraph in a bottom up +/// traversal. +static void runTransformOnCGSCCs( + const CallGraph &cg, + function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { + std::vector<CallGraphNode *> currentSCCVec; + auto cgi = llvm::scc_begin(&cg); + while (!cgi.isAtEnd()) { + // Copy the current SCC and increment so that the transformer can modify the + // SCC without invalidating our iterator. + currentSCCVec = *cgi; + ++cgi; + sccTransformer(currentSCCVec); + } +} + +namespace { +/// This struct represents a resolved call to a given callgraph node. Given that +/// the call does not actually contain a direct reference to the +/// Region(CallGraphNode) that it is dispatching to, we need to resolve them +/// explicitly. 
+struct ResolvedCall { + ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) + : call(call), targetNode(targetNode) {} + CallOpInterface call; + CallGraphNode *targetNode; +}; +} // end anonymous namespace + +/// Collect all of the callable operations within the given range of blocks. If +/// `traverseNestedCGNodes` is true, this will also collect call operations +/// inside of nested callgraph nodes. +static void collectCallOps(iterator_range<Region::iterator> blocks, + CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, + bool traverseNestedCGNodes) { + SmallVector<Block *, 8> worklist; + auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { + for (Block &block : blocks) + worklist.push_back(&block); + }; + + addToWorklist(blocks); + while (!worklist.empty()) { + for (Operation &op : *worklist.pop_back_val()) { + if (auto call = dyn_cast<CallOpInterface>(op)) { + CallGraphNode *node = + cg.resolveCallable(call.getCallableForCallee(), &op); + if (!node->isExternal()) + calls.emplace_back(call, node); + continue; + } + + // If this is not a call, traverse the nested regions. If + // `traverseNestedCGNodes` is false, then don't traverse nested call graph + // regions. + for (auto &nestedRegion : op.getRegions()) + if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) + addToWorklist(nestedRegion); + } + } +} + +//===----------------------------------------------------------------------===// +// Inliner +//===----------------------------------------------------------------------===// +namespace { +/// This class provides a specialization of the main inlining interface. +struct Inliner : public InlinerInterface { + Inliner(MLIRContext *context, CallGraph &cg) + : InlinerInterface(context), cg(cg) {} + + /// Process a set of blocks that have been inlined. This callback is invoked + /// *before* inlined terminator operations have been processed. + void + processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { + collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); + } + + /// The current set of call instructions to consider for inlining. + SmallVector<ResolvedCall, 8> calls; + + /// The callgraph being operated on. + CallGraph &cg; +}; +} // namespace + +/// Returns true if the given call should be inlined. +static bool shouldInline(ResolvedCall &resolvedCall) { + // Don't allow inlining terminator calls. We currently don't support this + // case. + if (resolvedCall.call.getOperation()->isKnownTerminator()) + return false; + + // Don't allow inlining if the target is an ancestor of the call. This + // prevents inlining recursively. + if (resolvedCall.targetNode->getCallableRegion()->isAncestor( + resolvedCall.call.getParentRegion())) + return false; + + // Otherwise, inline. + return true; +} + +/// Attempt to inline calls within the given scc. This function returns +/// success if any calls were inlined, failure otherwise. +static LogicalResult inlineCallsInSCC(Inliner &inliner, + ArrayRef<CallGraphNode *> currentSCC) { + CallGraph &cg = inliner.cg; + auto &calls = inliner.calls; + + // Collect all of the direct calls within the nodes of the current SCC. We + // don't traverse nested callgraph nodes, because they are handled separately + // likely within a different SCC. + for (auto *node : currentSCC) { + if (!node->isExternal()) + collectCallOps(*node->getCallableRegion(), cg, calls, + /*traverseNestedCGNodes=*/false); + } + if (calls.empty()) + return failure(); + + // Try to inline each of the call operations. 
Don't cache the end iterator + // here as more calls may be added during inlining. + bool inlinedAnyCalls = false; + for (unsigned i = 0; i != calls.size(); ++i) { + ResolvedCall &it = calls[i]; + LLVM_DEBUG({ + llvm::dbgs() << "* Considering inlining call: "; + it.call.dump(); + }); + if (!shouldInline(it)) + continue; + + CallOpInterface call = it.call; + Region *targetRegion = it.targetNode->getCallableRegion(); + LogicalResult inlineResult = inlineCall( + inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), + targetRegion); + if (failed(inlineResult)) + continue; + + // If the inlining was successful, then erase the call. + call.erase(); + inlinedAnyCalls = true; + } + calls.clear(); + return success(inlinedAnyCalls); +} + +/// Canonicalize the nodes within the given SCC with the given set of +/// canonicalization patterns. +static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC, + MLIRContext *context, + const OwningRewritePatternList &canonPatterns) { + // Collect the sets of nodes to canonicalize. + SmallVector<CallGraphNode *, 4> nodesToCanonicalize; + for (auto *node : currentSCC) { + // Don't canonicalize the external node, it has no valid callable region. + if (node->isExternal()) + continue; + + // Don't canonicalize nodes with children. Nodes with children + // require special handling as we may remove the node during + // canonicalization. In the future, we should be able to handle this + // case with proper node deletion tracking. + if (node->hasChildren()) + continue; + + // We also won't apply canonicalizations for nodes that are not + // isolated. This avoids potentially mutating the regions of nodes defined + // above, this is also a stipulation of the 'applyPatternsGreedily' driver. + auto *region = node->getCallableRegion(); + if (!region->getParentOp()->isKnownIsolatedFromAbove()) + continue; + nodesToCanonicalize.push_back(node); + } + if (nodesToCanonicalize.empty()) + return; + + // Canonicalize each of the nodes within the SCC in parallel. + // NOTE: This is simple now, because we don't enable canonicalizing nodes + // within children. When we remove this restriction, this logic will need to + // be reworked. + ParallelDiagnosticHandler canonicalizationHandler(context); + llvm::parallel::for_each_n( + llvm::parallel::par, /*Begin=*/size_t(0), + /*End=*/nodesToCanonicalize.size(), [&](size_t index) { + // Set the order for this thread so that diagnostics will be properly + // ordered. + canonicalizationHandler.setOrderIDForThread(index); + + // Apply the canonicalization patterns to this region. + auto *node = nodesToCanonicalize[index]; + applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); + + // Make sure to reset the order ID for the diagnostic handler, as this + // thread may be used in a different context. + canonicalizationHandler.eraseOrderIDForThread(); + }); +} + +/// Attempt to inline calls within the given scc, and run canonicalizations with +/// the given patterns, until a fixed point is reached. This allows for the +/// inlining of newly devirtualized calls. +static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC, + MLIRContext *context, + const OwningRewritePatternList &canonPatterns) { + // If we successfully inlined any calls, run some simplifications on the + // nodes of the scc. Continue attempting to inline until we reach a fixed + // point, or a maximum iteration count. We canonicalize here as it may + // devirtualize new calls, as well as give us a better cost model. 
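Note that inlineCall() only succeeds when the callee's dialect opts into inlining through an inliner interface from InliningUtils.h (included at the top of this file). A rough sketch of such an opt-in, with the caveat that the exact hook signatures should be checked against that header; the interface is registered in the dialect constructor via addInterfaces:

  struct MyDialectInlinerInterface : public DialectInlinerInterface {
    using DialectInlinerInterface::DialectInlinerInterface;

    // Permit inlining any op of this dialect into any region.
    bool isLegalToInline(Operation *op, Region *dest,
                         BlockAndValueMapping &valueMapping) const final {
      return true;
    }
    bool isLegalToInline(Region *dest, Region *src,
                         BlockAndValueMapping &valueMapping) const final {
      return true;
    }
  };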
+ unsigned iterationCount = 0; + while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { + // If we aren't allowing simplifications or the max iteration count was + // reached, then bail out early. + if (disableCanonicalization || ++iterationCount >= maxInliningIterations) + break; + canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); + } +} + +//===----------------------------------------------------------------------===// +// InlinerPass +//===----------------------------------------------------------------------===// + +// TODO(riverriddle) This pass should currently only be used for basic testing +// of inlining functionality. +namespace { +struct InlinerPass : public OperationPass<InlinerPass> { + void runOnOperation() override { + CallGraph &cg = getAnalysis<CallGraph>(); + auto *context = &getContext(); + + // Collect a set of canonicalization patterns to use when simplifying + // callable regions within an SCC. + OwningRewritePatternList canonPatterns; + for (auto *op : context->getRegisteredOperations()) + op->getCanonicalizationPatterns(canonPatterns, context); + + // Run the inline transform in post-order over the SCCs in the callgraph. + Inliner inliner(context, cg); + runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { + inlineSCC(inliner, scc, context, canonPatterns); + }); + } +}; +} // end anonymous namespace + +std::unique_ptr<Pass> mlir::createInlinerPass() { + return std::make_unique<InlinerPass>(); +} + +static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp new file mode 100644 index 00000000000..2aee688c6c1 --- /dev/null +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -0,0 +1,96 @@ +//===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/Support/Debug.h" + +#define PASS_NAME "loop-coalescing" +#define DEBUG_TYPE PASS_NAME + +using namespace mlir; + +namespace { +class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> { +public: + void runOnFunction() override { + FuncOp func = getFunction(); + + func.walk([](loop::ForOp op) { + // Ignore nested loops. + if (op.getParentOfType<loop::ForOp>()) + return; + + SmallVector<loop::ForOp, 4> loops; + getPerfectlyNestedLoops(loops, op); + LLVM_DEBUG(llvm::dbgs() + << "found a perfect nest of depth " << loops.size() << '\n'); + + // Look for a band of loops that can be coalesced, i.e. perfectly nested + // loops with bounds defined above some loop. + // 1. For each loop, find above which parent loop its operands are + // defined. 
+ SmallVector<unsigned, 4> operandsDefinedAbove(loops.size()); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + operandsDefinedAbove[i] = i; + for (unsigned j = 0; j < i; ++j) { + if (areValuesDefinedAbove(loops[i].getOperands(), + loops[j].region())) { + operandsDefinedAbove[i] = j; + break; + } + } + LLVM_DEBUG(llvm::dbgs() + << " bounds of loop " << i << " are known above depth " + << operandsDefinedAbove[i] << '\n'); + } + + // 2. Identify bands of loops such that the operands of all of them are + // defined above the first loop in the band. Traverse the nest bottom-up + // so that modifications don't invalidate the inner loops. + for (unsigned end = loops.size(); end > 0; --end) { + unsigned start = 0; + for (; start < end - 1; ++start) { + auto maxPos = + *std::max_element(std::next(operandsDefinedAbove.begin(), start), + std::next(operandsDefinedAbove.begin(), end)); + if (maxPos > start) + continue; + + assert(maxPos == start && + "expected loop bounds to be known at the start of the band"); + LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start + << " to " << end << '\n'); + + auto band = + llvm::makeMutableArrayRef(loops.data() + start, end - start); + coalesceLoops(band); + break; + } + // If a band was found and transformed, keep looking at the loops above + // the outermost transformed loop. + if (start != end - 1) + end = start + 1; + } + }); + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopCoalescingPass() { + return std::make_unique<LoopCoalescingPass>(); +} + +static PassRegistration<LoopCoalescingPass> + reg(PASS_NAME, + "coalesce nested loops with independent bounds into a single loop"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp new file mode 100644 index 00000000000..fcfc1d7ae52 --- /dev/null +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -0,0 +1,1979 @@ +//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop fusion. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopFusionUtils.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <iomanip> +#include <sstream> +#define DEBUG_TYPE "affine-loop-fusion" + +using llvm::SetVector; + +using namespace mlir; + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +/// Disables fusion profitability check and fuses if valid. Ignore any +/// additional (redundant) computation tolerance threshold +/// that would have prevented fusion. 
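+/// For example (illustrative): with -fusion-maximal set, a candidate fusion
+/// that would double the computation of the source loop nest is still
+/// performed when legal, whereas the default cost model would reject it for
+/// exceeding the compute tolerance.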
+static llvm::cl::opt<bool> + clMaximalLoopFusion("fusion-maximal", + llvm::cl::desc("Enables maximal loop fusion"), + llvm::cl::cat(clOptionsCategory)); + +/// A threshold in percent of additional computation allowed when fusing. +static llvm::cl::opt<double> clFusionAddlComputeTolerance( + "fusion-compute-tolerance", + llvm::cl::desc("Fractional increase in additional " + "computation tolerated while fusing"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<unsigned> clFusionFastMemorySpace( + "fusion-fast-mem-space", + llvm::cl::desc("Faster memory space number to promote fusion buffers to"), + llvm::cl::cat(clOptionsCategory)); + +// A local buffer of size less than or equal to this size is automatically +// promoted to fast memory after producer-consumer fusion. +static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold( + "fusion-local-buf-threshold", + llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast " + "memory space"), + llvm::cl::cat(clOptionsCategory)); + +namespace { + +/// Loop fusion pass. This pass currently supports a greedy fusion policy, +/// which fuses loop nests with single-writer/single-reader memref dependences +/// with the goal of improving locality. + +// TODO(andydavis) Support fusion of source loop nests which write to multiple +// memrefs, where each memref can have multiple users (if profitable). +// TODO(andydavis) Extend this pass to check for fusion preventing dependences, +// and add support for more general loop fusion algorithms. + +struct LoopFusion : public FunctionPass<LoopFusion> { + LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, + bool maximalFusion = false) + : localBufSizeThreshold(localBufSizeThreshold), + fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {} + + void runOnFunction() override; + + // Any local buffers smaller than this size (in bytes) will be created in + // `fastMemorySpace` if provided. + uint64_t localBufSizeThreshold; + Optional<unsigned> fastMemorySpace = None; + // If true, ignore any additional (redundant) computation tolerance threshold + // that would have prevented fusion. + bool maximalFusion; + + // The amount of additional computation that is tolerated while fusing + // pair-wise as a fraction of the total computation. + constexpr static double kComputeToleranceThreshold = 0.30f; +}; + +} // end anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createLoopFusionPass(unsigned fastMemorySpace, + uint64_t localBufSizeThreshold, bool maximalFusion) { + return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold, + maximalFusion); +} + +// TODO(b/117228571) Replace when this is modeled through side-effects/op traits +static bool isMemRefDereferencingOp(Operation &op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) || + isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) + return true; + return false; +} + +namespace { + +// LoopNestStateCollector walks loop nests and collects load and store +// operations, and whether or not an IfInst was encountered in the loop nest. 
+struct LoopNestStateCollector { + SmallVector<AffineForOp, 4> forOps; + SmallVector<Operation *, 4> loadOpInsts; + SmallVector<Operation *, 4> storeOpInsts; + bool hasNonForRegion = false; + + void collect(Operation *opToWalk) { + opToWalk->walk([&](Operation *op) { + if (isa<AffineForOp>(op)) + forOps.push_back(cast<AffineForOp>(op)); + else if (op->getNumRegions() != 0) + hasNonForRegion = true; + else if (isa<AffineLoadOp>(op)) + loadOpInsts.push_back(op); + else if (isa<AffineStoreOp>(op)) + storeOpInsts.push_back(op); + }); + } +}; + +// MemRefDependenceGraph is a graph data structure where graph nodes are +// top-level operations in a FuncOp which contain load/store ops, and edges +// are memref dependences between the nodes. +// TODO(andydavis) Add a more flexible dependence graph representation. +// TODO(andydavis) Add a depth parameter to dependence graph construction. +struct MemRefDependenceGraph { +public: + // Node represents a node in the graph. A Node is either an entire loop nest + // rooted at the top level which contains loads/stores, or a top level + // load/store. + struct Node { + // The unique identifier of this node in the graph. + unsigned id; + // The top-level statement which is (or contains) a load/store. + Operation *op; + // List of load operations. + SmallVector<Operation *, 4> loads; + // List of store op insts. + SmallVector<Operation *, 4> stores; + Node(unsigned id, Operation *op) : id(id), op(op) {} + + // Returns the load op count for 'memref'. + unsigned getLoadOpCount(Value memref) { + unsigned loadOpCount = 0; + for (auto *loadOpInst : loads) { + if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef()) + ++loadOpCount; + } + return loadOpCount; + } + + // Returns the store op count for 'memref'. + unsigned getStoreOpCount(Value memref) { + unsigned storeOpCount = 0; + for (auto *storeOpInst : stores) { + if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef()) + ++storeOpCount; + } + return storeOpCount; + } + + // Returns all store ops in 'storeOps' which access 'memref'. + void getStoreOpsForMemref(Value memref, + SmallVectorImpl<Operation *> *storeOps) { + for (auto *storeOpInst : stores) { + if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef()) + storeOps->push_back(storeOpInst); + } + } + + // Returns all load ops in 'loadOps' which access 'memref'. + void getLoadOpsForMemref(Value memref, + SmallVectorImpl<Operation *> *loadOps) { + for (auto *loadOpInst : loads) { + if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef()) + loadOps->push_back(loadOpInst); + } + } + + // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node + // has at least one load and store operation. + void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) { + llvm::SmallDenseSet<Value, 2> loadMemrefs; + for (auto *loadOpInst : loads) { + loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef()); + } + for (auto *storeOpInst : stores) { + auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef(); + if (loadMemrefs.count(memref) > 0) + loadAndStoreMemrefSet->insert(memref); + } + } + }; + + // Edge represents a data dependence between nodes in the graph. + struct Edge { + // The id of the node at the other end of the edge. + // If this edge is stored in Edge = Node.inEdges[i], then + // 'Node.inEdges[i].id' is the identifier of the source node of the edge. + // If this edge is stored in Edge = Node.outEdges[i], then + // 'Node.outEdges[i].id' is the identifier of the dest node of the edge. 
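+    // For example (illustrative): a dependence from node 3 to node 7 on
+    // memref '%m' is recorded twice: outEdges[3] contains {id: 7, value: %m}
+    // and inEdges[7] contains {id: 3, value: %m}.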
+ unsigned id; + // The SSA value on which this edge represents a dependence. + // If the value is a memref, then the dependence is between graph nodes + // which contain accesses to the same memref 'value'. If the value is a + // non-memref value, then the dependence is between a graph node which + // defines an SSA value and another graph node which uses the SSA value + // (e.g. a constant operation defining a value which is used inside a loop + // nest). + Value value; + }; + + // Map from node id to Node. + DenseMap<unsigned, Node> nodes; + // Map from node id to list of input edges. + DenseMap<unsigned, SmallVector<Edge, 2>> inEdges; + // Map from node id to list of output edges. + DenseMap<unsigned, SmallVector<Edge, 2>> outEdges; + // Map from memref to a count on the dependence edges associated with that + // memref. + DenseMap<Value, unsigned> memrefEdgeCount; + // The next unique identifier to use for newly created graph nodes. + unsigned nextNodeId = 0; + + MemRefDependenceGraph() {} + + // Initializes the dependence graph based on operations in 'f'. + // Returns true on success, false otherwise. + bool init(FuncOp f); + + // Returns the graph node for 'id'. + Node *getNode(unsigned id) { + auto it = nodes.find(id); + assert(it != nodes.end()); + return &it->second; + } + + // Returns the graph node for 'forOp'. + Node *getForOpNode(AffineForOp forOp) { + for (auto &idAndNode : nodes) + if (idAndNode.second.op == forOp.getOperation()) + return &idAndNode.second; + return nullptr; + } + + // Adds a node with 'op' to the graph and returns its unique identifier. + unsigned addNode(Operation *op) { + Node node(nextNodeId++, op); + nodes.insert({node.id, node}); + return node.id; + } + + // Remove node 'id' (and its associated edges) from graph. + void removeNode(unsigned id) { + // Remove each edge in 'inEdges[id]'. + if (inEdges.count(id) > 0) { + SmallVector<Edge, 2> oldInEdges = inEdges[id]; + for (auto &inEdge : oldInEdges) { + removeEdge(inEdge.id, id, inEdge.value); + } + } + // Remove each edge in 'outEdges[id]'. + if (outEdges.count(id) > 0) { + SmallVector<Edge, 2> oldOutEdges = outEdges[id]; + for (auto &outEdge : oldOutEdges) { + removeEdge(id, outEdge.id, outEdge.value); + } + } + // Erase remaining node state. + inEdges.erase(id); + outEdges.erase(id); + nodes.erase(id); + } + + // Returns true if node 'id' writes to any memref which escapes (or is an + // argument to) the function/block. Returns false otherwise. + bool writesToLiveInOrEscapingMemrefs(unsigned id) { + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { + auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef(); + auto *op = memref->getDefiningOp(); + // Return true if 'memref' is a block argument. + if (!op) + return true; + // Return true if any use of 'memref' escapes the function. + for (auto *user : memref->getUsers()) + if (!isMemRefDereferencingOp(*user)) + return true; + } + return false; + } + + // Returns the unique AffineStoreOp in `node` that meets all the following: + // *) store is the only one that writes to a function-local memref live out + // of `node`, + // *) store is not the source of a self-dependence on `node`. + // Otherwise, returns a null AffineStoreOp. + AffineStoreOp getUniqueOutgoingStore(Node *node) { + AffineStoreOp uniqueStore; + + // Return null if `node` doesn't have any outgoing edges. 
+ auto outEdgeIt = outEdges.find(node->id); + if (outEdgeIt == outEdges.end()) + return nullptr; + + const auto &nodeOutEdges = outEdgeIt->second; + for (auto *op : node->stores) { + auto storeOp = cast<AffineStoreOp>(op); + auto memref = storeOp.getMemRef(); + // Skip this store if there are no dependences on its memref. This means + // that store either: + // *) writes to a memref that is only read within the same loop nest + // (self-dependence edges are not represented in graph at the moment), + // *) writes to a function live out memref (function parameter), or + // *) is dead. + if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { + return (edge.value != memref); + })) + continue; + + if (uniqueStore) + // Found multiple stores to function-local live-out memrefs. + return nullptr; + // Found first store to function-local live-out memref. + uniqueStore = storeOp; + } + + return uniqueStore; + } + + // Returns true if node 'id' can be removed from the graph. Returns false + // otherwise. A node can be removed from the graph iff the following + // conditions are met: + // *) The node does not write to any memref which escapes (or is a + // function/block argument). + // *) The node has no successors in the dependence graph. + bool canRemoveNode(unsigned id) { + if (writesToLiveInOrEscapingMemrefs(id)) + return false; + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { + // Return false if there exist out edges from 'id' on 'memref'. + if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0) + return false; + } + return true; + } + + // Returns true iff there is an edge from node 'srcId' to node 'dstId' which + // is for 'value' if non-null, or for any value otherwise. Returns false + // otherwise. + bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) { + if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { + return false; + } + bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { + return edge.id == dstId && (!value || edge.value == value); + }); + bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { + return edge.id == srcId && (!value || edge.value == value); + }); + return hasOutEdge && hasInEdge; + } + + // Adds an edge from node 'srcId' to node 'dstId' for 'value'. + void addEdge(unsigned srcId, unsigned dstId, Value value) { + if (!hasEdge(srcId, dstId, value)) { + outEdges[srcId].push_back({dstId, value}); + inEdges[dstId].push_back({srcId, value}); + if (value->getType().isa<MemRefType>()) + memrefEdgeCount[value]++; + } + } + + // Removes an edge from node 'srcId' to node 'dstId' for 'value'. + void removeEdge(unsigned srcId, unsigned dstId, Value value) { + assert(inEdges.count(dstId) > 0); + assert(outEdges.count(srcId) > 0); + if (value->getType().isa<MemRefType>()) { + assert(memrefEdgeCount.count(value) > 0); + memrefEdgeCount[value]--; + } + // Remove 'srcId' from 'inEdges[dstId]'. + for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { + if ((*it).id == srcId && (*it).value == value) { + inEdges[dstId].erase(it); + break; + } + } + // Remove 'dstId' from 'outEdges[srcId]'. + for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { + if ((*it).id == dstId && (*it).value == value) { + outEdges[srcId].erase(it); + break; + } + } + } + + // Returns true if there is a path in the dependence graph from node 'srcId' + // to node 'dstId'. Returns false otherwise. 
+ bool hasDependencePath(unsigned srcId, unsigned dstId) { + // Worklist state is: <node-id, next-output-edge-index-to-visit> + SmallVector<std::pair<unsigned, unsigned>, 4> worklist; + worklist.push_back({srcId, 0}); + // Run DFS traversal to see if 'dstId' is reachable from 'srcId'. + while (!worklist.empty()) { + auto &idAndIndex = worklist.back(); + // Return true if we have reached 'dstId'. + if (idAndIndex.first == dstId) + return true; + // Pop and continue if node has no out edges, or if all out edges have + // already been visited. + if (outEdges.count(idAndIndex.first) == 0 || + idAndIndex.second == outEdges[idAndIndex.first].size()) { + worklist.pop_back(); + continue; + } + // Get graph edge to traverse. + Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; + // Increment next output edge index for 'idAndIndex'. + ++idAndIndex.second; + // Add node at 'edge.id' to worklist. + worklist.push_back({edge.id, 0}); + } + return false; + } + + // Returns the input edge count for node 'id' and 'memref' from src nodes + // which access 'memref' with a store operation. + unsigned getIncomingMemRefAccesses(unsigned id, Value memref) { + unsigned inEdgeCount = 0; + if (inEdges.count(id) > 0) + for (auto &inEdge : inEdges[id]) + if (inEdge.value == memref) { + Node *srcNode = getNode(inEdge.id); + // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' + if (srcNode->getStoreOpCount(memref) > 0) + ++inEdgeCount; + } + return inEdgeCount; + } + + // Returns the output edge count for node 'id' and 'memref' (if non-null), + // otherwise returns the total output edge count from node 'id'. + unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) { + unsigned outEdgeCount = 0; + if (outEdges.count(id) > 0) + for (auto &outEdge : outEdges[id]) + if (!memref || outEdge.value == memref) + ++outEdgeCount; + return outEdgeCount; + } + + // Computes and returns an insertion point operation, before which the + // the fused <srcId, dstId> loop nest can be inserted while preserving + // dependences. Returns nullptr if no such insertion point is found. + Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { + if (outEdges.count(srcId) == 0) + return getNode(dstId)->op; + + // Build set of insts in range (srcId, dstId) which depend on 'srcId'. + SmallPtrSet<Operation *, 2> srcDepInsts; + for (auto &outEdge : outEdges[srcId]) + if (outEdge.id != dstId) + srcDepInsts.insert(getNode(outEdge.id)->op); + + // Build set of insts in range (srcId, dstId) on which 'dstId' depends. + SmallPtrSet<Operation *, 2> dstDepInsts; + for (auto &inEdge : inEdges[dstId]) + if (inEdge.id != srcId) + dstDepInsts.insert(getNode(inEdge.id)->op); + + Operation *srcNodeInst = getNode(srcId)->op; + Operation *dstNodeInst = getNode(dstId)->op; + + // Computing insertion point: + // *) Walk all operation positions in Block operation list in the + // range (src, dst). For each operation 'op' visited in this search: + // *) Store in 'firstSrcDepPos' the first position where 'op' has a + // dependence edge from 'srcNode'. + // *) Store in 'lastDstDepPost' the last position where 'op' has a + // dependence edge to 'dstNode'. + // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the + // operation insertion point (or return null pointer if no such + // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos'). 
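+    // For example (illustrative): with three ops at positions 0, 1, 2 between
+    // 'srcNodeInst' and 'dstNodeInst', where only the op at position 2
+    // depends on 'srcNode' and only the op at position 0 is depended on by
+    // 'dstNode', then firstSrcDepPos = 2 and lastDstDepPos = 0; since 2 > 0,
+    // the fused nest can be inserted before the op at position 2.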
+ SmallVector<Operation *, 2> depInsts; + Optional<unsigned> firstSrcDepPos; + Optional<unsigned> lastDstDepPos; + unsigned pos = 0; + for (Block::iterator it = std::next(Block::iterator(srcNodeInst)); + it != Block::iterator(dstNodeInst); ++it) { + Operation *op = &(*it); + if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None) + firstSrcDepPos = pos; + if (dstDepInsts.count(op) > 0) + lastDstDepPos = pos; + depInsts.push_back(op); + ++pos; + } + + if (firstSrcDepPos.hasValue()) { + if (lastDstDepPos.hasValue()) { + if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) { + // No valid insertion point exists which preserves dependences. + return nullptr; + } + } + // Return the insertion point at 'firstSrcDepPos'. + return depInsts[firstSrcDepPos.getValue()]; + } + // No dependence targets in range (or only dst deps in range), return + // 'dstNodInst' insertion point. + return dstNodeInst; + } + + // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' + // has been replaced in node at 'dstId' by a private memref depending + // on the value of 'createPrivateMemRef'. + void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef, + bool createPrivateMemRef) { + // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. + if (inEdges.count(srcId) > 0) { + SmallVector<Edge, 2> oldInEdges = inEdges[srcId]; + for (auto &inEdge : oldInEdges) { + // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. + if (inEdge.value != oldMemRef) + addEdge(inEdge.id, dstId, inEdge.value); + } + } + // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. + if (outEdges.count(srcId) > 0) { + SmallVector<Edge, 2> oldOutEdges = outEdges[srcId]; + for (auto &outEdge : oldOutEdges) { + // Remove any out edges from 'srcId' to 'dstId' across memrefs. + if (outEdge.id == dstId) + removeEdge(srcId, outEdge.id, outEdge.value); + } + } + // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being + // replaced by a private memref). These edges could come from nodes + // other than 'srcId' which were removed in the previous step. + if (inEdges.count(dstId) > 0 && createPrivateMemRef) { + SmallVector<Edge, 2> oldInEdges = inEdges[dstId]; + for (auto &inEdge : oldInEdges) + if (inEdge.value == oldMemRef) + removeEdge(inEdge.id, dstId, inEdge.value); + } + } + + // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion + // of sibling node 'sidId' into node 'dstId'. + void updateEdges(unsigned sibId, unsigned dstId) { + // For each edge in 'inEdges[sibId]': + // *) Add new edge from source node 'inEdge.id' to 'dstNode'. + // *) Remove edge from source node 'inEdge.id' to 'sibNode'. + if (inEdges.count(sibId) > 0) { + SmallVector<Edge, 2> oldInEdges = inEdges[sibId]; + for (auto &inEdge : oldInEdges) { + addEdge(inEdge.id, dstId, inEdge.value); + removeEdge(inEdge.id, sibId, inEdge.value); + } + } + + // For each edge in 'outEdges[sibId]' to node 'id' + // *) Add new edge from 'dstId' to 'outEdge.id'. + // *) Remove edge from 'sibId' to 'outEdge.id'. + if (outEdges.count(sibId) > 0) { + SmallVector<Edge, 2> oldOutEdges = outEdges[sibId]; + for (auto &outEdge : oldOutEdges) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(sibId, outEdge.id, outEdge.value); + } + } + } + + // Adds ops in 'loads' and 'stores' to node at 'id'. 
+ void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads, + const SmallVectorImpl<Operation *> &stores) { + Node *node = getNode(id); + for (auto *loadOpInst : loads) + node->loads.push_back(loadOpInst); + for (auto *storeOpInst : stores) + node->stores.push_back(storeOpInst); + } + + void clearNodeLoadAndStores(unsigned id) { + Node *node = getNode(id); + node->loads.clear(); + node->stores.clear(); + } + + // Calls 'callback' for each input edge incident to node 'id' which carries a + // memref dependence. + void forEachMemRefInputEdge(unsigned id, + const std::function<void(Edge)> &callback) { + if (inEdges.count(id) > 0) + forEachMemRefEdge(inEdges[id], callback); + } + + // Calls 'callback' for each output edge from node 'id' which carries a + // memref dependence. + void forEachMemRefOutputEdge(unsigned id, + const std::function<void(Edge)> &callback) { + if (outEdges.count(id) > 0) + forEachMemRefEdge(outEdges[id], callback); + } + + // Calls 'callback' for each edge in 'edges' which carries a memref + // dependence. + void forEachMemRefEdge(ArrayRef<Edge> edges, + const std::function<void(Edge)> &callback) { + for (auto &edge : edges) { + // Skip if 'edge' is not a memref dependence edge. + if (!edge.value->getType().isa<MemRefType>()) + continue; + assert(nodes.count(edge.id) > 0); + // Skip if 'edge.id' is not a loop nest. + if (!isa<AffineForOp>(getNode(edge.id)->op)) + continue; + // Visit current input edge 'edge'. + callback(edge); + } + } + + void print(raw_ostream &os) const { + os << "\nMemRefDependenceGraph\n"; + os << "\nNodes:\n"; + for (auto &idAndNode : nodes) { + os << "Node: " << idAndNode.first << "\n"; + auto it = inEdges.find(idAndNode.first); + if (it != inEdges.end()) { + for (const auto &e : it->second) + os << " InEdge: " << e.id << " " << e.value << "\n"; + } + it = outEdges.find(idAndNode.first); + if (it != outEdges.end()) { + for (const auto &e : it->second) + os << " OutEdge: " << e.id << " " << e.value << "\n"; + } + } + } + void dump() const { print(llvm::errs()); } +}; + +} // end anonymous namespace + +// Initializes the data dependence graph by walking operations in 'f'. +// Assigns each node in the graph a node id based on program order in 'f'. +// TODO(andydavis) Add support for taking a Block arg to construct the +// dependence graph at a different depth. +bool MemRefDependenceGraph::init(FuncOp f) { + DenseMap<Value, SetVector<unsigned>> memrefAccesses; + + // TODO: support multi-block functions. + if (f.getBlocks().size() != 1) + return false; + + DenseMap<Operation *, unsigned> forToNodeMap; + for (auto &op : f.front()) { + if (auto forOp = dyn_cast<AffineForOp>(op)) { + // Create graph node 'id' to represent top-level 'forOp' and record + // all loads and store accesses it contains. + LoopNestStateCollector collector; + collector.collect(&op); + // Return false if a non 'affine.for' region was found (not currently + // supported). 
+ if (collector.hasNonForRegion) + return false; + Node node(nextNodeId++, &op); + for (auto *opInst : collector.loadOpInsts) { + node.loads.push_back(opInst); + auto memref = cast<AffineLoadOp>(opInst).getMemRef(); + memrefAccesses[memref].insert(node.id); + } + for (auto *opInst : collector.storeOpInsts) { + node.stores.push_back(opInst); + auto memref = cast<AffineStoreOp>(opInst).getMemRef(); + memrefAccesses[memref].insert(node.id); + } + forToNodeMap[&op] = node.id; + nodes.insert({node.id, node}); + } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { + // Create graph node for top-level load op. + Node node(nextNodeId++, &op); + node.loads.push_back(&op); + auto memref = cast<AffineLoadOp>(op).getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { + // Create graph node for top-level store op. + Node node(nextNodeId++, &op); + node.stores.push_back(&op); + auto memref = cast<AffineStoreOp>(op).getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (op.getNumRegions() != 0) { + // Return false if another region is found (not currently supported). + return false; + } else if (op.getNumResults() > 0 && !op.use_empty()) { + // Create graph node for top-level producer of SSA values, which + // could be used by loop nest nodes. + Node node(nextNodeId++, &op); + nodes.insert({node.id, node}); + } + } + + // Add dependence edges between nodes which produce SSA values and their + // users. + for (auto &idAndNode : nodes) { + const Node &node = idAndNode.second; + if (!node.loads.empty() || !node.stores.empty()) + continue; + auto *opInst = node.op; + for (auto value : opInst->getResults()) { + for (auto *user : value->getUsers()) { + SmallVector<AffineForOp, 4> loops; + getLoopIVs(*user, &loops); + if (loops.empty()) + continue; + assert(forToNodeMap.count(loops[0].getOperation()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()]; + addEdge(node.id, userLoopNestId, value); + } + } + } + + // Walk memref access lists and add graph edges between dependent nodes. + for (auto &memrefAndList : memrefAccesses) { + unsigned n = memrefAndList.second.size(); + for (unsigned i = 0; i < n; ++i) { + unsigned srcId = memrefAndList.second[i]; + bool srcHasStore = + getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; + for (unsigned j = i + 1; j < n; ++j) { + unsigned dstId = memrefAndList.second[j]; + bool dstHasStore = + getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; + if (srcHasStore || dstHasStore) + addEdge(srcId, dstId, memrefAndList.first); + } + } + } + return true; +} + +// Removes load operations from 'srcLoads' which operate on 'memref', and +// adds them to 'dstLoads'. +static void moveLoadsAccessingMemrefTo(Value memref, + SmallVectorImpl<Operation *> *srcLoads, + SmallVectorImpl<Operation *> *dstLoads) { + dstLoads->clear(); + SmallVector<Operation *, 4> srcLoadsToKeep; + for (auto *load : *srcLoads) { + if (cast<AffineLoadOp>(load).getMemRef() == memref) + dstLoads->push_back(load); + else + srcLoadsToKeep.push_back(load); + } + srcLoads->swap(srcLoadsToKeep); +} + +// Returns the innermost common loop depth for the set of operations in 'ops'. 
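+// For example (illustrative): if one op is nested under loops (%i, %j, %k)
+// and another under (%i, %j, %m), where the two nests share the same outer
+// %i and %j loops, the innermost common loop depth is 2.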
+static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) { + unsigned numOps = ops.size(); + assert(numOps > 0); + + std::vector<SmallVector<AffineForOp, 4>> loops(numOps); + unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); + for (unsigned i = 0; i < numOps; ++i) { + getLoopIVs(*ops[i], &loops[i]); + loopDepthLimit = + std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); + } + + unsigned loopDepth = 0; + for (unsigned d = 0; d < loopDepthLimit; ++d) { + unsigned i; + for (i = 1; i < numOps; ++i) { + if (loops[i - 1][d] != loops[i][d]) + break; + } + if (i != numOps) + break; + ++loopDepth; + } + return loopDepth; +} + +// Returns the maximum loop depth at which no dependences between 'loadOpInsts' +// and 'storeOpInsts' are satisfied. +static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts, + ArrayRef<Operation *> storeOpInsts) { + // Merge loads and stores into the same array. + SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); + ops.append(storeOpInsts.begin(), storeOpInsts.end()); + + // Compute the innermost common loop depth for loads and stores. + unsigned loopDepth = getInnermostCommonLoopDepth(ops); + + // Return common loop depth for loads if there are no store ops. + if (storeOpInsts.empty()) + return loopDepth; + + // Check dependences on all pairs of ops in 'ops' and store the minimum + // loop depth at which a dependence is satisfied. + for (unsigned i = 0, e = ops.size(); i < e; ++i) { + auto *srcOpInst = ops[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < e; ++j) { + auto *dstOpInst = ops[j]; + MemRefAccess dstAccess(dstOpInst); + + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + FlatAffineConstraints dependenceConstraints; + // TODO(andydavis) Cache dependence analysis results, check cache here. + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (hasDependence(result)) { + // Store minimum loop depth and break because we want the min 'd' at + // which there is a dependence. + loopDepth = std::min(loopDepth, d - 1); + break; + } + } + } + } + return loopDepth; +} + +// Sinks all sequential loops to the innermost levels (while preserving +// relative order among them) and moves all parallel loops to the +// outermost (while again preserving relative order among them). +// This can increase the loop depth at which we can fuse a slice, since we are +// pushing loop carried dependence to a greater depth in the loop nest. +static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { + assert(isa<AffineForOp>(node->op)); + AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op)); + node->op = newRootForOp.getOperation(); +} + +// TODO(mlir-team): improve/complete this when we have target data. 
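+// For example (illustrative): an f32 or i32 element yields 4 bytes, while a
+// vector<4xf32> element yields ceil(4 * 32 / 8) = 16 bytes.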
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast<VectorType>(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + +// Creates and returns a private (single-user) memref for fused loop rooted +// at 'forOp', with (potentially reduced) memref size based on the +// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. +// TODO(bondhugula): consider refactoring the common code from generateDma and +// this one. +static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, + unsigned dstLoopDepth, + Optional<unsigned> fastMemorySpace, + uint64_t localBufSizeThreshold) { + auto *forInst = forOp.getOperation(); + + // Create builder to insert alloc op just before 'forOp'. + OpBuilder b(forInst); + // Builder to create constants at the top level. + OpBuilder top(forInst->getParentOfType<FuncOp>().getBody()); + // Create new memref type based on slice bounds. + auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef(); + auto oldMemRefType = oldMemRef->getType().cast<MemRefType>(); + unsigned rank = oldMemRefType.getRank(); + + // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. + MemRefRegion region(srcStoreOpInst->getLoc()); + bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth)); + (void)validRegion; + assert(validRegion && "unexpected memref region failure"); + SmallVector<int64_t, 4> newShape; + std::vector<SmallVector<int64_t, 4>> lbs; + SmallVector<int64_t, 8> lbDivisors; + lbs.reserve(rank); + // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed + // by 'srcStoreOpInst' at depth 'dstLoopDepth'. + Optional<int64_t> numElements = + region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); + assert(numElements.hasValue() && + "non-constant number of elts in local buffer"); + + const FlatAffineConstraints *cst = region.getConstraints(); + // 'outerIVs' holds the values that this memory region is symbolic/parametric + // on; this would correspond to loop IVs surrounding the level at which the + // slice is being materialized. + SmallVector<Value, 8> outerIVs; + cst->getIdValues(rank, cst->getNumIds(), &outerIVs); + + // Build 'rank' AffineExprs from MemRefRegion 'lbs' + SmallVector<AffineExpr, 4> offsets; + offsets.reserve(rank); + for (unsigned d = 0; d < rank; ++d) { + assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); + + AffineExpr offset = top.getAffineConstantExpr(0); + for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { + offset = offset + lbs[d][j] * top.getAffineDimExpr(j); + } + assert(lbDivisors[d] > 0); + offset = + (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); + offsets.push_back(offset); + } + + // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed + // by 'srcStoreOpInst'. 
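+  // For example (illustrative): if 'oldMemRef' is a memref<64x64xf32> but the
+  // region written by 'srcStoreOpInst' under the slice is a 16x16 tile,
+  // 'newShape' is {16, 16} and the private buffer becomes memref<16x16xf32>,
+  // placed in 'fastMemorySpace' when a fast memory space was provided and
+  // 16 * 16 * 4 bytes is within 'localBufSizeThreshold'.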
+ uint64_t bufSize = + getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue(); + unsigned newMemSpace; + if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) { + newMemSpace = fastMemorySpace.getValue(); + } else { + newMemSpace = oldMemRefType.getMemorySpace(); + } + auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), + {}, newMemSpace); + // Gather alloc operands for the dynamic dimensions of the memref. + SmallVector<Value, 4> allocOperands; + unsigned dynamicDimCount = 0; + for (auto dimSize : oldMemRefType.getShape()) { + if (dimSize == -1) + allocOperands.push_back( + top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++)); + } + + // Create new private memref for fused loop 'forOp'. + // TODO(andydavis) Create/move alloc ops for private memrefs closer to their + // consumer loop nests to reduce their live range. Currently they are added + // at the beginning of the function, because loop nests can be reordered + // during the fusion pass. + Value newMemRef = + top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands); + + // Build an AffineMap to remap access functions based on lower bound offsets. + SmallVector<AffineExpr, 4> remapExprs; + remapExprs.reserve(rank); + unsigned zeroOffsetCount = 0; + for (unsigned i = 0; i < rank; i++) { + if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>()) + if (constExpr.getValue() == 0) + ++zeroOffsetCount; + auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); + + auto remapExpr = + simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); + remapExprs.push_back(remapExpr); + } + auto indexRemap = zeroOffsetCount == rank + ? AffineMap() + : AffineMap::get(outerIVs.size() + rank, 0, remapExprs); + // Replace all users of 'oldMemRef' with 'newMemRef'. + LogicalResult res = + replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, + /*extraOperands=*/outerIVs, + /*symbolOperands=*/{}, + /*domInstFilter=*/&*forOp.getBody()->begin()); + assert(succeeded(res) && + "replaceAllMemrefUsesWith should always succeed here"); + (void)res; + return newMemRef; +} + +// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' +// may write to multiple memrefs but it is required that only one of them, +// 'srcLiveOutStoreOp', has output edges. +// Returns true if 'dstNode's read/write region to 'memref' is a super set of +// 'srcNode's write region to 'memref' and 'srcId' has only one output edge. +// TODO(andydavis) Generalize this to handle more live in/out cases. +static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, + AffineStoreOp srcLiveOutStoreOp, + MemRefDependenceGraph *mdg) { + assert(srcLiveOutStoreOp && "Expected a valid store op"); + auto *dstNode = mdg->getNode(dstId); + Value memref = srcLiveOutStoreOp.getMemRef(); + // Return false if 'srcNode' has more than one output edge on 'memref'. + if (mdg->getOutEdgeCount(srcId, memref) > 1) + return false; + + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. + MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); + if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for source operation\n."); + return false; + } + SmallVector<int64_t, 4> srcShape; + // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. + // by 'srcStoreOp' at depth 'dstLoopDepth'. 
+ Optional<int64_t> srcNumElements = + srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); + if (!srcNumElements.hasValue()) + return false; + + // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'. + // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for + // each store op in 'dstStoreOps'). + SmallVector<Operation *, 2> dstStoreOps; + dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + SmallVector<Operation *, 2> dstLoadOps; + dstNode->getLoadOpsForMemref(memref, &dstLoadOps); + + auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0]; + MemRefRegion dstRegion(dstOpInst->getLoc()); + if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for dest operation\n."); + return false; + } + SmallVector<int64_t, 4> dstShape; + // Query 'dstRegion' for 'dstShape' and 'dstNumElements'. + // by 'dstOpInst' at depth 'dstLoopDepth'. + Optional<int64_t> dstNumElements = + dstRegion.getConstantBoundingSizeAndShape(&dstShape); + if (!dstNumElements.hasValue()) + return false; + + // Return false if write region is not a superset of 'srcNodes' write + // region to 'memref'. + // TODO(andydavis) Check the shape and lower bounds here too. + if (srcNumElements != dstNumElements) + return false; + return true; +} + +// Checks the profitability of fusing a backwards slice of the loop nest +// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. +// The argument 'srcStoreOpInst' is used to calculate the storage reduction on +// the memref being produced and consumed, which is an input to the cost model. +// For producer-consumer fusion, 'srcStoreOpInst' will be the same as +// 'srcOpInst', as we are slicing w.r.t to that producer. +// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which +// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' +// will be the unique store op in the src node, which will be used to check +// that the write region is the same after input-reuse fusion. +// Returns true if it is profitable to fuse the candidate loop nests. Returns +// false otherwise. `dstLoopDepth` is set to the most profitable depth at which +// to materialize the source loop nest slice. +// The profitability model executes the following steps: +// *) Computes the backward computation slice at 'srcOpInst'. This +// computation slice of the loop nest surrounding 'srcOpInst' is +// represented by modified src loop bounds in 'sliceState', which are +// functions of loop IVs in the loop nest surrounding 'srcOpInst'. +// *) Computes the cost of unfused src/dst loop nests (currently the cost of a +// loop nest is the total number of dynamic operation instances in the loop +// nest). +// *) Computes the cost of fusing a slice of the src loop nest into the dst +// loop nest at various values of dst loop depth, attempting to fuse +// the largest computation slice at the maximal dst loop depth (closest to +// the load) to minimize reuse distance and potentially enable subsequent +// load/store forwarding. +// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for +// the same memref as is written by 'srcOpInst', then the union of slice +// loop bounds is used to compute the slice and associated slice cost. +// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop +// nest, at which the src computation slice is inserted/fused. 
+// NOTE: We attempt to maximize the dst loop depth, but there are cases +// where a particular setting for 'dstLoopNest' might fuse an unsliced +// loop (within the src computation slice) at a depth which results in +// excessive recomputation (see unit tests for examples). +// *) Compares the total cost of the unfused loop nests to the min cost fused +// loop nest computed in the previous step, and returns true if the latter +// is lower. +static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, + ArrayRef<Operation *> dstLoadOpInsts, + ArrayRef<Operation *> dstStoreOpInsts, + ComputationSliceState *sliceState, + unsigned *dstLoopDepth, bool maximalFusion) { + LLVM_DEBUG({ + llvm::dbgs() << "Checking whether fusion is profitable between:\n"; + llvm::dbgs() << " " << *srcOpInst << " and \n"; + for (auto dstOpInst : dstLoadOpInsts) { + llvm::dbgs() << " " << *dstOpInst << "\n"; + }; + }); + + // Compute cost of sliced and unsliced src loop nest. + SmallVector<AffineForOp, 4> srcLoopIVs; + getLoopIVs(*srcOpInst, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + + // Walk src loop nest and collect stats. + LoopNestStats srcLoopNestStats; + if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) + return false; + + // Compute cost of dst loop nest. + SmallVector<AffineForOp, 4> dstLoopIVs; + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); + + LoopNestStats dstLoopNestStats; + if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) + return false; + + // Compute the maximum loop depth at which we can can insert the src slice + // and still satisfy dest loop nest dependences, for producer-consumer fusion. + unsigned maxDstLoopDepth = + (srcOpInst == srcStoreOpInst) + ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) + : dstLoopIVs.size(); + if (maxDstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n"); + return false; + } + + // Search for min cost value for 'dstLoopDepth'. At each value of + // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice + // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union + // of these bounds). Next the union slice bounds are used to calculate + // the cost of the slice and the cost of the slice inserted into the dst + // loop nest at 'dstLoopDepth'. + uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); + double maxStorageReduction = 0.0; + Optional<uint64_t> sliceMemEstimate = None; + + SmallVector<ComputationSliceState, 4> sliceStates; + sliceStates.resize(maxDstLoopDepth); + // The best loop depth at which to materialize the slice. + Optional<unsigned> bestDstLoopDepth = None; + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); + + // Compute src loop nest write region size. + MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); + if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for source operation\n."); + return false; + } + + Optional<int64_t> maybeSrcWriteRegionSizeBytes = + srcWriteRegion.getRegionSize(); + if (!maybeSrcWriteRegionSizeBytes.hasValue()) + return false; + int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); + + // Compute op instance count for the src loop nest. 
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats); + + // Evaluate all depth choices for materializing the slice in the destination + // loop nest. + for (unsigned i = maxDstLoopDepth; i >= 1; --i) { + // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. + if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, + /*loopDepth=*/i, + /*numCommonLoops=*/0, + /*isBackwardSlice=*/true, + &sliceStates[i - 1]))) { + LLVM_DEBUG(llvm::dbgs() + << "computeSliceUnion failed for loopDepth: " << i << "\n"); + continue; + } + + int64_t fusedLoopNestComputeCost; + if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], + dstLoopNestStats, &sliceStates[i - 1], + &fusedLoopNestComputeCost)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); + continue; + } + + double additionalComputeFraction = + fusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1; + + // Determine what the slice write MemRefRegion would be, if the src loop + // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop + // nest at loop depth 'i' + MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); + if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, + &sliceStates[i - 1]))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to compute slice write region at loopDepth: " << i + << "\n"); + continue; + } + + Optional<int64_t> maybeSliceWriteRegionSizeBytes = + sliceWriteRegion.getRegionSize(); + if (!maybeSliceWriteRegionSizeBytes.hasValue() || + maybeSliceWriteRegionSizeBytes.getValue() == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to get slice write region size at loopDepth: " << i + << "\n"); + continue; + } + int64_t sliceWriteRegionSizeBytes = + maybeSliceWriteRegionSizeBytes.getValue(); + + // If we are fusing for reuse, check that write regions remain the same. + // TODO(andydavis) Write region check should check sizes and offsets in + // each dimension, so that we are sure they are covering the same memref + // region. Also, move this out to a isMemRefRegionSuperSet helper function. + if (srcOpInst != srcStoreOpInst && + sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) + continue; + + double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) / + static_cast<double>(sliceWriteRegionSizeBytes); + + LLVM_DEBUG({ + std::stringstream msg; + msg << " evaluating fusion profitability at depth : " << i << "\n" + << std::fixed << std::setprecision(2) + << " additional compute fraction: " + << 100.0 * additionalComputeFraction << "%\n" + << " storage reduction factor: " << storageReduction << "x\n" + << " fused nest cost: " << fusedLoopNestComputeCost << "\n" + << " src write region size: " << srcWriteRegionSizeBytes << "\n" + << " slice write region size: " << sliceWriteRegionSizeBytes + << "\n"; + llvm::dbgs() << msg.str(); + }); + + double computeToleranceThreshold = + clFusionAddlComputeTolerance.getNumOccurrences() > 0 + ? clFusionAddlComputeTolerance + : LoopFusion::kComputeToleranceThreshold; + + // TODO(b/123247369): This is a placeholder cost model. + // Among all choices that add an acceptable amount of redundant computation + // (as per computeToleranceThreshold), we will simply pick the one that + // reduces the intermediary size the most. 
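+    // For example (illustrative): with the default 30% tolerance, a depth
+    // adding 20% redundant compute with an 8x storage reduction is preferred
+    // over a depth adding 5% redundant compute with only a 2x reduction,
+    // while a depth adding 50% redundant compute is rejected unless
+    // 'maximalFusion' is set.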
+ if ((storageReduction > maxStorageReduction) && + (maximalFusion || + (additionalComputeFraction < computeToleranceThreshold))) { + maxStorageReduction = storageReduction; + bestDstLoopDepth = i; + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + sliceMemEstimate = sliceWriteRegionSizeBytes; + } + } + + // A simple cost model: fuse if it reduces the memory footprint. If + // -maximal-fusion is set, fuse nevertheless. + + if (!maximalFusion && !bestDstLoopDepth.hasValue()) { + LLVM_DEBUG( + llvm::dbgs() + << "All fusion choices involve more than the threshold amount of " + "redundant computation; NOT fusing.\n"); + return false; + } + + if (!bestDstLoopDepth.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); + return false; + } + + // Set dstLoopDepth based on best values from search. + *dstLoopDepth = bestDstLoopDepth.getValue(); + + LLVM_DEBUG( + llvm::dbgs() << " LoopFusion fusion stats:" + << "\n best loop depth: " << bestDstLoopDepth + << "\n src loop nest compute cost: " << srcLoopNestCost + << "\n dst loop nest compute cost: " << dstLoopNestCost + << "\n fused loop nest compute cost: " + << minFusedLoopNestComputeCost << "\n"); + + auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); + + Optional<double> storageReduction = None; + + if (!maximalFusion) { + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG( + llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } + + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); + + assert(sliceMemEstimate.hasValue() && "expected value"); + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + << " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); + + if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; + } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); + } + + double additionalComputeFraction = + 100.0 * (minFusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1); + (void)additionalComputeFraction; + LLVM_DEBUG({ + std::stringstream msg; + msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " + << std::setprecision(2) << additionalComputeFraction + << "% redundant computation and a "; + msg << (storageReduction.hasValue() + ? std::to_string(storageReduction.getValue()) + : "<unknown>"); + msg << "% storage reduction.\n"; + llvm::dbgs() << msg.str(); + }); + + // Update return parameter 'sliceState' with 'bestSliceState'. + ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; + sliceState->lbs = bestSliceState->lbs; + sliceState->ubs = bestSliceState->ubs; + sliceState->lbOperands = bestSliceState->lbOperands; + sliceState->ubOperands = bestSliceState->ubOperands; + + // Canonicalize slice bound affine maps. 
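+  // For example (illustrative): a bound map such as (d0, d1) -> (d0 + 4)
+  // applied to operands (%i, %i) canonicalizes to (d0) -> (d0 + 4) applied to
+  // (%i), dropping the unused and duplicate operand.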
+ for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + if (sliceState->lbs[i] != AffineMap()) { + canonicalizeMapAndOperands(&sliceState->lbs[i], + &sliceState->lbOperands[i]); + } + if (sliceState->ubs[i] != AffineMap()) { + canonicalizeMapAndOperands(&sliceState->ubs[i], + &sliceState->ubOperands[i]); + } + } + return true; +} + +namespace { + +// GreedyFusion greedily fuses loop nests which have a producer/consumer or +// input-reuse relationship on a memref, with the goal of improving locality. +// +// The steps of the producer-consumer fusion algorithm are as follows: +// +// *) A worklist is initialized with node ids from the dependence graph. +// *) For each node id in the worklist: +// *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a +// candidate destination AffineForOp into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. +// *) For each LoadOp in 'dstLoadOps' do: +// *) Look up dependent loop nests which have a single store op to the same +// memref. +// *) Check if dependences would be violated by the fusion. +// *) Get a computation slice of 'srcLoopNest', which adjusts its loop +// bounds to be functions of 'dstLoopNest' IVs and symbols. +// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', +// at a loop depth determined by the cost model in 'isFusionProfitable'. +// *) Add the newly fused load/store operations to the state, +// and also add newly fused load ops to 'dstLoopOps' to be considered +// as fusion dst load ops in another iteration. +// *) Remove old src loop nest and its associated state. +// +// The steps of the input-reuse fusion algorithm are as follows: +// +// *) Initialize 'worklist' with node ids from the dependence graph. +// *) For each 'dstNode' in the worklist: +// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which +// loads from the same memref, but which has no dependence paths to/from. +// *) Get a computation slice of 'sibLoopNest', which adjusts its loop +// bounds to be functions of 'dstLoopNest' IVs and symbols. +// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', +// at a loop depth determined by the cost model in 'isFusionProfitable'. +// This function also checks that the memref write region of 'sibLoopNest', +// is preserved in the fused loop nest. +// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. +// +// Given a graph where top-level operations are vertices in the set 'V' and +// edges in the set 'E' are dependences between vertices, this algorithm +// takes O(V) time for initialization, and has runtime O(V + E). +// +// This greedy algorithm is not 'maximal' due to the current restriction of +// fusing along single producer consumer edges, but there is a TODO to fix this. +// +// TODO(andydavis) Experiment with other fusion policies. +struct GreedyFusion { +public: + // The data dependence graph to traverse during fusion. + MemRefDependenceGraph *mdg; + // Worklist of graph nodes visited during the fusion pass. + SmallVector<unsigned, 8> worklist; + // Set of graph nodes which are present on the worklist. + llvm::SmallDenseSet<unsigned, 16> worklistSet; + // Parameter for local buffer size threshold. + unsigned localBufSizeThreshold; + // Parameter for fast memory space. + Optional<unsigned> fastMemorySpace; + // If true, ignore any additional (redundant) computation tolerance threshold + // that would have prevented fusion. 
+ bool maximalFusion; + + using Node = MemRefDependenceGraph::Node; + + GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold, + Optional<unsigned> fastMemorySpace, bool maximalFusion) + : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold), + fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {} + + // Initializes 'worklist' with nodes from 'mdg' + void init() { + // TODO(andydavis) Add a priority queue for prioritizing nodes by different + // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). + worklist.clear(); + worklistSet.clear(); + for (auto &idAndNode : mdg->nodes) { + const Node &node = idAndNode.second; + worklist.push_back(node.id); + worklistSet.insert(node.id); + } + } + + // Run the GreedyFusion pass. + // *) First pass through the nodes fuses single-use producer nodes into their + // unique consumer. + // *) Second pass fuses sibling nodes which share no dependence edges. + // *) Third pass fuses any remaining producer nodes into their users. + void run() { + // TODO(andydavis) Run this repeatedly until a fixed-point is reached. + fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); + fuseSiblingNodes(); + fuseProducerConsumerNodes( + /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); + eraseUnusedMemRefAllocations(); + } + + void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { + init(); + while (!worklist.empty()) { + unsigned dstId = worklist.back(); + worklist.pop_back(); + worklistSet.erase(dstId); + + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(dstId) == 0) + continue; + // Get 'dstNode' into which to attempt fusion. + auto *dstNode = mdg->getNode(dstId); + // Skip if 'dstNode' is not a loop nest. + if (!isa<AffineForOp>(dstNode->op)) + continue; + // Sink sequential loops in 'dstNode' (and thus raise parallel loops) + // while preserving relative order. This can increase the maximum loop + // depth at which we can fuse a slice of a producer loop nest into a + // consumer loop nest. + sinkSequentialLoops(dstNode); + + SmallVector<Operation *, 4> loads = dstNode->loads; + SmallVector<Operation *, 4> dstLoadOpInsts; + DenseSet<Value> visitedMemrefs; + while (!loads.empty()) { + // Get memref of load on top of the stack. + auto memref = cast<AffineLoadOp>(loads.back()).getMemRef(); + if (visitedMemrefs.count(memref) > 0) + continue; + visitedMemrefs.insert(memref); + // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. + moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); + // Skip if no input edges along which to fuse. + if (mdg->inEdges.count(dstId) == 0) + continue; + // Iterate through in-edges for 'dstId' and src node id for any + // edges on 'memref'. + SmallVector<unsigned, 2> srcNodeIds; + for (auto &srcEdge : mdg->inEdges[dstId]) { + // Skip 'srcEdge' if not for 'memref'. + if (srcEdge.value != memref) + continue; + srcNodeIds.push_back(srcEdge.id); + } + for (unsigned srcId : srcNodeIds) { + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(srcId) == 0) + continue; + // Get 'srcNode' from which to attempt fusion into 'dstNode'. + auto *srcNode = mdg->getNode(srcId); + // Skip if 'srcNode' is not a loop nest. + if (!isa<AffineForOp>(srcNode->op)) + continue; + // Skip if 'srcNode' has more than one live-out store to a + // function-local memref. + // TODO(andydavis) Support more generic multi-output src loop nests + // fusion. 
+ auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); + if (!srcStoreOp) { + // Get the src store op at the deepest loop depth. + // We will use 'LoopFusionUtils::canFuseLoops' to check fusion + // feasibility for loops with multiple stores. + unsigned maxLoopDepth = 0; + for (auto *op : srcNode->stores) { + auto storeOp = cast<AffineStoreOp>(op); + if (storeOp.getMemRef() != memref) { + srcStoreOp = nullptr; + break; + } + unsigned loopDepth = getNestingDepth(*storeOp); + if (loopDepth > maxLoopDepth) { + maxLoopDepth = loopDepth; + srcStoreOp = storeOp; + } + } + if (!srcStoreOp) + continue; + } + + // Unique outgoing store found must write to 'memref' since 'memref' + // is the one that established the producer-consumer relationship + // between 'srcNode' and 'dstNode'. + assert(srcStoreOp.getMemRef() == memref && + "Found store to unexpected memref"); + + // Skip if 'srcNode' writes to any live in or escaping memrefs, + // and cannot be fused. + bool writesToLiveInOrOut = + mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); + if (writesToLiveInOrOut && + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) + continue; + + // Don't create a private memref if 'writesToLiveInOrOut'. + bool createPrivateMemref = !writesToLiveInOrOut; + // Don't create a private memref if 'srcNode' has in edges on + // 'memref', or if 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || + mdg->getOutEdgeCount(dstNode->id, memref) > 0) { + createPrivateMemref = false; + } + + // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. + if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) + continue; + + // Compute an operation list insertion point for the fused loop + // nest which preserves dependences. + Operation *insertPointInst = + mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); + if (insertPointInst == nullptr) + continue; + + // Compute the innermost common loop depth for dstNode loads/stores. + SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(), + dstNode->loads.end()); + dstOps.append(dstNode->stores.begin(), dstNode->stores.end()); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps); + // Check the feasibility of fusing src loop nest into dst loop nest + // at loop depths in range [1, dstLoopDepthTest]. + // TODO(andydavis) Use slice union computation and union of memref + // read/write regions to cost model and fusion. + bool canFuse = false; + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + ComputationSliceState sliceUnion; + FusionResult result = mlir::canFuseLoops( + cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op), + /*dstLoopDepth=*/i, &sliceUnion); + if (result.value == FusionResult::Success) + canFuse = true; + } + + // Skip if fusion is not feasible at all loop depths. + if (!canFuse) + continue; + + // Gather 'dstNode' store ops to 'memref'. + SmallVector<Operation *, 2> dstStoreOpInsts; + for (auto *storeOpInst : dstNode->stores) + if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref) + dstStoreOpInsts.push_back(storeOpInst); + + unsigned bestDstLoopDepth; + mlir::ComputationSliceState sliceState; + // Check if fusion would be profitable. + if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, + &bestDstLoopDepth, maximalFusion)) + continue; + + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. 
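+          // As a simplified illustration, fusing a producer such as
+          //   affine.for %i = 0 to 100 { affine.store %v, %A[%i] : memref<100xf32> }
+          // into a consumer
+          //   affine.for %j = 0 to 100 { %x = affine.load %A[%j] : memref<100xf32> }
+          // at depth 1 inserts a slice of the producer into the consumer's
+          // body that recomputes %A[%j] just before the load (names here are
+          // illustrative only).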
+ auto sliceLoopNest = mlir::insertBackwardComputationSlice( + srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); + if (sliceLoopNest) { + LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" + << *sliceLoopNest.getOperation() << "\n"); + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + if (insertPointInst != dstAffineForOp.getOperation()) { + dstAffineForOp.getOperation()->moveBefore(insertPointInst); + } + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); + + // Collect slice loop stats. + LoopNestStateCollector sliceCollector; + sliceCollector.collect(sliceLoopNest.getOperation()); + // Promote single iteration slice loops to single IV value. + for (auto forOp : sliceCollector.forOps) { + promoteIfSingleIteration(forOp); + } + if (createPrivateMemref) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector<Operation *, 4> storesForMemref; + for (auto *storeOpInst : sliceCollector.storeOpInsts) { + if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref) + storesForMemref.push_back(storeOpInst); + } + // TODO(andydavis) Use union of memref write regions to compute + // private memref footprint. + auto newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = + mdg->addNode(newMemRef->getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + } + + // Collect dst loop stats after memref privatization transformation. + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstAffineForOp.getOperation()); + + // Add new load ops to current Node load op list 'loads' to + // continue fusing based on new operands. + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef(); + if (visitedMemrefs.count(loadMemRef) == 0) + loads.push_back(loadOpInst); + } + + // Clear and add back loads and stores. + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and if it does not write to a memref which escapes the + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has + // been fused into 'dstNode' and write region of 'dstNode' covers + // the write region of 'srcNode', and 'srcNode' has no other users + // so it is safe to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { + mdg->removeNode(srcNode->id); + srcNode->op->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created + // new fusion opportunities. 
+ if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); + } + } + } + } + } + } + } + } + } + + // Visits each node in the graph, and for each node, attempts to fuse it with + // its sibling nodes (nodes which share a parent, but no dependence edges). + void fuseSiblingNodes() { + init(); + while (!worklist.empty()) { + unsigned dstId = worklist.back(); + worklist.pop_back(); + worklistSet.erase(dstId); + + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(dstId) == 0) + continue; + // Get 'dstNode' into which to attempt fusion. + auto *dstNode = mdg->getNode(dstId); + // Skip if 'dstNode' is not a loop nest. + if (!isa<AffineForOp>(dstNode->op)) + continue; + // Attempt to fuse 'dstNode' with its sibling nodes in the graph. + fuseWithSiblingNodes(dstNode); + } + } + + // Attempt to fuse 'dstNode' with sibling nodes in the graph. + void fuseWithSiblingNodes(Node *dstNode) { + DenseSet<unsigned> visitedSibNodeIds; + std::pair<unsigned, Value> idAndMemref; + while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { + unsigned sibId = idAndMemref.first; + Value memref = idAndMemref.second; + // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other + // stores to the same memref in 'sibNode' loop nest. + auto *sibNode = mdg->getNode(sibId); + // Compute an operation list insertion point for the fused loop + // nest which preserves dependences. + assert(sibNode->op->getBlock() == dstNode->op->getBlock()); + Operation *insertPointInst = + sibNode->op->isBeforeInBlock(dstNode->op) + ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id) + : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id); + if (insertPointInst == nullptr) + continue; + + // Check if fusion would be profitable and at what depth. + + // Get unique 'sibNode' load op to 'memref'. + SmallVector<Operation *, 2> sibLoadOpInsts; + sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts); + // Currently findSiblingNodeToFuse searches for siblings with one load. + assert(sibLoadOpInsts.size() == 1); + Operation *sibLoadOpInst = sibLoadOpInsts[0]; + assert(!sibNode->stores.empty()); + // TODO(andydavis) Choose the store which postdominates all other stores. + auto *sibStoreOpInst = sibNode->stores.back(); + + // Gather 'dstNode' load ops to 'memref'. + SmallVector<Operation *, 2> dstLoadOpInsts; + dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); + + // Gather 'dstNode' store ops to 'memref'. + SmallVector<Operation *, 2> dstStoreOpInsts; + dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); + + unsigned bestDstLoopDepth; + mlir::ComputationSliceState sliceState; + + // Check if fusion would be profitable. + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + dstStoreOpInsts, &sliceState, &bestDstLoopDepth, + maximalFusion)) + continue; + + // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. + auto sliceLoopNest = mlir::insertBackwardComputationSlice( + sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); + if (sliceLoopNest != nullptr) { + auto dstForInst = cast<AffineForOp>(dstNode->op); + // Update operation position of fused loop nest (if needed). 
+ if (insertPointInst != dstForInst.getOperation()) { + dstForInst.getOperation()->moveBefore(insertPointInst); + } + // Update data dependence graph state post fusion. + updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); + } + } + } + + // Searches function argument uses and the graph from 'dstNode' looking for a + // fusion candidate sibling node which shares no dependences with 'dstNode' + // but which loads from the same memref. Returns true and sets + // 'idAndMemrefToFuse' on success. Returns false otherwise. + bool findSiblingNodeToFuse(Node *dstNode, + DenseSet<unsigned> *visitedSibNodeIds, + std::pair<unsigned, Value> *idAndMemrefToFuse) { + // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse + // on 'memref'. + auto canFuseWithSibNode = [&](Node *sibNode, Value memref) { + // Skip if 'outEdge' is not a read-after-write dependence. + // TODO(andydavis) Remove restrict to single load op restriction. + if (sibNode->getLoadOpCount(memref) != 1) + return false; + // Skip if there exists a path of dependent edges between + // 'sibNode' and 'dstNode'. + if (mdg->hasDependencePath(sibNode->id, dstNode->id) || + mdg->hasDependencePath(dstNode->id, sibNode->id)) + return false; + // Skip sib node if it loads to (and stores from) the same memref on + // which it also has an input dependence edge. + DenseSet<Value> loadAndStoreMemrefSet; + sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); + if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) { + return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; + })) + return false; + + // Check that all stores are to the same memref. + DenseSet<Value> storeMemrefs; + for (auto *storeOpInst : sibNode->stores) { + storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef()); + } + if (storeMemrefs.size() != 1) + return false; + return true; + }; + + // Search for siblings which load the same memref function argument. + auto fn = dstNode->op->getParentOfType<FuncOp>(); + for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { + for (auto *user : fn.getArgument(i)->getUsers()) { + if (auto loadOp = dyn_cast<AffineLoadOp>(user)) { + // Gather loops surrounding 'use'. + SmallVector<AffineForOp, 4> loops; + getLoopIVs(*user, &loops); + // Skip 'use' if it is not within a loop nest. + if (loops.empty()) + continue; + Node *sibNode = mdg->getForOpNode(loops[0]); + assert(sibNode != nullptr); + // Skip 'use' if it not a sibling to 'dstNode'. + if (sibNode->id == dstNode->id) + continue; + // Skip 'use' if it has been visited. + if (visitedSibNodeIds->count(sibNode->id) > 0) + continue; + // Skip 'use' if it does not load from the same memref as 'dstNode'. + auto memref = loadOp.getMemRef(); + if (dstNode->getLoadOpCount(memref) == 0) + continue; + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, memref)) { + visitedSibNodeIds->insert(sibNode->id); + idAndMemrefToFuse->first = sibNode->id; + idAndMemrefToFuse->second = memref; + return true; + } + } + } + } + + // Search for siblings by following edges through an intermediate src node. + // Collect candidate 'dstNode' input edges in 'inEdges'. + SmallVector<MemRefDependenceGraph::Edge, 2> inEdges; + mdg->forEachMemRefInputEdge( + dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) { + // Add 'inEdge' if it is a read-after-write dependence. 
+ if (dstNode->getLoadOpCount(inEdge.value) > 0 && + mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0) + inEdges.push_back(inEdge); + }); + + // Search for sibling nodes to fuse by visiting output edges from each input + // edge in 'inEdges'. + for (auto &inEdge : inEdges) { + // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'. + SmallVector<MemRefDependenceGraph::Edge, 2> outEdges; + mdg->forEachMemRefOutputEdge( + inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) { + unsigned sibNodeId = outEdge.id; + if (visitedSibNodeIds->count(sibNodeId) > 0) + return; + // Skip output edge if not a sibling using the same memref. + if (outEdge.id == dstNode->id || outEdge.value != inEdge.value) + return; + auto *sibNode = mdg->getNode(sibNodeId); + if (!isa<AffineForOp>(sibNode->op)) + return; + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, outEdge.value)) { + // Add candidate 'outEdge' to sibling node. + outEdges.push_back(outEdge); + } + }); + + // Add first candidate if any were returned. + if (!outEdges.empty()) { + visitedSibNodeIds->insert(outEdges[0].id); + idAndMemrefToFuse->first = outEdges[0].id; + idAndMemrefToFuse->second = outEdges[0].value; + return true; + } + } + return false; + } + + void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode, + Node *dstNode) { + // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. + mdg->updateEdges(sibNode->id, dstNode->id); + + // Collect slice loop stats. + LoopNestStateCollector sliceCollector; + sliceCollector.collect(sliceLoopNest.getOperation()); + // Promote single iteration slice loops to single IV value. + for (auto forOp : sliceCollector.forOps) { + promoteIfSingleIteration(forOp); + } + + // Collect dst loop stats after memref privatization transformation. + auto dstForInst = cast<AffineForOp>(dstNode->op); + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstForInst.getOperation()); + // Clear and add back loads and stores + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old sibling loop nest if it no longer has outgoing dependence + // edges, and it does not write to a memref which escapes the + // function. + if (mdg->getOutEdgeCount(sibNode->id) == 0) { + mdg->removeNode(sibNode->id); + sibNode->op->erase(); + } + } + + // Clean up any allocs with no users. + void eraseUnusedMemRefAllocations() { + for (auto &pair : mdg->memrefEdgeCount) { + if (pair.second > 0) + continue; + auto memref = pair.first; + // Skip if there exist other uses (return operation or function calls). + if (!memref->use_empty()) + continue; + // Use list expected to match the dep graph info. + auto *op = memref->getDefiningOp(); + if (isa_and_nonnull<AllocOp>(op)) + op->erase(); + } + } +}; + +} // end anonymous namespace + +void LoopFusion::runOnFunction() { + // Override if a command line argument was provided. + if (clFusionFastMemorySpace.getNumOccurrences() > 0) { + fastMemorySpace = clFusionFastMemorySpace.getValue(); + } + + // Override if a command line argument was provided. 
+ if (clFusionLocalBufThreshold.getNumOccurrences() > 0) { + localBufSizeThreshold = clFusionLocalBufThreshold * 1024; + } + + if (clMaximalLoopFusion.getNumOccurrences() > 0) + maximalFusion = clMaximalLoopFusion; + + MemRefDependenceGraph g; + if (g.init(getFunction())) + GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion) + .run(); +} + +static PassRegistration<LoopFusion> pass("affine-loop-fusion", + "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp new file mode 100644 index 00000000000..fb3d0c0b45c --- /dev/null +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -0,0 +1,140 @@ +//===- LoopInvariantCodeMotion.cpp - Code to perform loop fusion-----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop invariant code motion. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopLikeInterface.h" +#include "mlir/Transforms/SideEffectsInterface.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "licm" + +using namespace mlir; + +namespace { + +using SideEffecting = SideEffectsInterface::SideEffecting; + +/// Loop invariant code motion (LICM) pass. +struct LoopInvariantCodeMotion : public OperationPass<LoopInvariantCodeMotion> { +public: + void runOnOperation() override; +}; + +// Checks whether the given op can be hoisted by checking that +// - the op and any of its contained operations do not depend on SSA values +// defined inside of the loop (by means of calling definedOutside). +// - the op has no side-effects. If sideEffecting is Never, sideeffects of this +// op and its nested ops are ignored. +static bool canBeHoisted(Operation *op, + function_ref<bool(Value)> definedOutside, + SideEffecting sideEffecting, + SideEffectsInterface &interface) { + // Check that dependencies are defined outside of loop. + if (!llvm::all_of(op->getOperands(), definedOutside)) + return false; + // Check whether this op is side-effect free. If we already know that there + // can be no side-effects because the surrounding op has claimed so, we can + // (and have to) skip this step. + auto thisOpIsSideEffecting = sideEffecting; + if (thisOpIsSideEffecting != SideEffecting::Never) { + thisOpIsSideEffecting = interface.isSideEffecting(op); + // If the op always has sideeffects, we cannot hoist. + if (thisOpIsSideEffecting == SideEffecting::Always) + return false; + } + // Recurse into the regions for this op and check whether the contained ops + // can be hoisted. 
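+  // For example, an op with regions (such as an 'affine.if') can only be
+  // hoisted as a single unit if its own operands are defined outside the
+  // loop and every op nested within its regions is itself hoistable.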
+ for (auto ®ion : op->getRegions()) { + for (auto &block : region.getBlocks()) { + for (auto &innerOp : block) { + if (innerOp.isKnownTerminator()) + continue; + if (!canBeHoisted(&innerOp, definedOutside, thisOpIsSideEffecting, + interface)) + return false; + } + } + } + return true; +} + +static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike, + SideEffectsInterface &interface) { + auto &loopBody = looplike.getLoopBody(); + + // We use two collections here as we need to preserve the order for insertion + // and this is easiest. + SmallPtrSet<Operation *, 8> willBeMovedSet; + SmallVector<Operation *, 8> opsToMove; + + // Helper to check whether an operation is loop invariant wrt. SSA properties. + auto isDefinedOutsideOfBody = [&](Value value) { + auto definingOp = value->getDefiningOp(); + return (definingOp && !!willBeMovedSet.count(definingOp)) || + looplike.isDefinedOutsideOfLoop(value); + }; + + // Do not use walk here, as we do not want to go into nested regions and hoist + // operations from there. These regions might have semantics unknown to this + // rewriting. If the nested regions are loops, they will have been processed. + for (auto &block : loopBody) { + for (auto &op : block.without_terminator()) { + if (canBeHoisted(&op, isDefinedOutsideOfBody, + mlir::SideEffectsDialectInterface::Recursive, + interface)) { + opsToMove.push_back(&op); + willBeMovedSet.insert(&op); + } + } + } + + // For all instructions that we found to be invariant, move outside of the + // loop. + auto result = looplike.moveOutOfLoop(opsToMove); + LLVM_DEBUG(looplike.print(llvm::dbgs() << "Modified loop\n")); + return result; +} + +} // end anonymous namespace + +void LoopInvariantCodeMotion::runOnOperation() { + SideEffectsInterface interface(&getContext()); + // Walk through all loops in a function in innermost-loop-first order. This + // way, we first LICM from the inner loop, and place the ops in + // the outer loop, which in turn can be further LICM'ed. + getOperation()->walk([&](Operation *op) { + if (auto looplike = dyn_cast<LoopLikeOpInterface>(op)) { + LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n")); + if (failed(moveLoopInvariantCode(looplike, interface))) + signalPassFailure(); + } + }); +} + +// Include the generated code for the loop-like interface here, as it otherwise +// has no compilation unit. This works as loop-invariant code motion is the +// only user of that interface. +#include "mlir/Transforms/LoopLikeInterface.cpp.inc" + +std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() { + return std::make_unique<LoopInvariantCodeMotion>(); +} + +static PassRegistration<LoopInvariantCodeMotion> + pass("loop-invariant-code-motion", + "Hoist loop invariant instructions outside of the loop"); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp new file mode 100644 index 00000000000..d3dc81760fc --- /dev/null +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -0,0 +1,402 @@ +//===- LoopTiling.cpp --- Loop tiling pass ------------------------------*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to tile loop nests. 
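+// As a rough illustration, tiling a 2-d nest over (%i, %j) with tile sizes
+// (32, 32) yields outer tile-space loops that step by 32 and inner
+// intra-tile loops that traverse a single 32x32 tile, with the original loop
+// body moved into the innermost intra-tile loop.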
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +using namespace mlir; + +#define DEBUG_TYPE "affine-loop-tile" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +static llvm::cl::opt<unsigned long long> + clCacheSizeKiB("tile-cache-size", + llvm::cl::desc("Set size of cache to tile for in KiB"), + llvm::cl::cat(clOptionsCategory)); + +// Tile size to use for all loops (overrides -tile-sizes if provided). +static llvm::cl::opt<unsigned> + clTileSize("tile-size", llvm::cl::desc("Use this tile size for all loops"), + llvm::cl::cat(clOptionsCategory)); + +// List of tile sizes. If any of them aren't provided, they are filled with +// clTileSize / kDefaultTileSize. +static llvm::cl::list<unsigned> clTileSizes( + "tile-sizes", + llvm::cl::desc( + "List of tile sizes for each perfect nest (overridden by -tile-size)"), + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); + +namespace { + +/// A pass to perform loop tiling on all suitable loop nests of a Function. +struct LoopTiling : public FunctionPass<LoopTiling> { + explicit LoopTiling(uint64_t cacheSizeBytes = kDefaultCacheMemCapacity, + bool avoidMaxMinBounds = true) + : cacheSizeBytes(cacheSizeBytes), avoidMaxMinBounds(avoidMaxMinBounds) {} + + void runOnFunction() override; + void getTileSizes(ArrayRef<AffineForOp> band, + SmallVectorImpl<unsigned> *tileSizes); + + // Default tile size if nothing is provided. + constexpr static unsigned kDefaultTileSize = 4; + constexpr static uint64_t kDefaultCacheMemCapacity = 512 * 1024UL; + + // Capacity of the cache to tile for. + uint64_t cacheSizeBytes; + // If true, tile sizes are set to avoid max/min in bounds if possible. + bool avoidMaxMinBounds; +}; + +} // end anonymous namespace + +/// Creates a pass to perform loop tiling on all suitable loop nests of a +/// Function. +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createLoopTilingPass(uint64_t cacheSizeBytes) { + return std::make_unique<LoopTiling>(cacheSizeBytes); +} + +// Move the loop body of AffineForOp 'src' from 'src' into the specified +// location in destination's body, ignoring the terminator. +static inline void moveLoopBody(AffineForOp src, AffineForOp dest, + Block::iterator loc) { + auto &insts = src.getBody()->getOperations(); + dest.getBody()->getOperations().splice(loc, insts, insts.begin(), + std::prev(insts.end())); +} + +// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's +// body. +static inline void moveLoopBody(AffineForOp src, AffineForOp dest) { + moveLoopBody(src, dest, dest.getBody()->begin()); +} + +/// Constructs and sets new loop bounds after tiling for the case of +/// hyper-rectangular index sets, where the bounds of one dimension do not +/// depend on other dimensions. Bounds of each dimension can thus be treated +/// independently, and deriving the new bounds is much simpler and faster +/// than for the case of tiling arbitrary polyhedral shapes. 
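+/// For example, an original loop with bounds [0, N) and tile size T becomes a
+/// tile-space loop with the same bounds but step T, and an intra-tile loop
+/// running from the tile-space IV d0 to min(d0 + T, N) (or simply d0 + T when
+/// T is known to divide the trip count).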
+static void +constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops, + MutableArrayRef<AffineForOp> newLoops, + ArrayRef<unsigned> tileSizes) { + assert(!origLoops.empty()); + assert(origLoops.size() == tileSizes.size()); + + OpBuilder b(origLoops[0].getOperation()); + unsigned width = origLoops.size(); + + // Bounds for tile space loops. + for (unsigned i = 0; i < width; i++) { + auto lbOperands = origLoops[i].getLowerBoundOperands(); + auto ubOperands = origLoops[i].getUpperBoundOperands(); + SmallVector<Value, 4> newLbOperands(lbOperands); + SmallVector<Value, 4> newUbOperands(ubOperands); + newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap()); + newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap()); + newLoops[i].setStep(tileSizes[i]); + } + // Bounds for intra-tile loops. + for (unsigned i = 0; i < width; i++) { + int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]); + auto mayBeConstantCount = getConstantTripCount(origLoops[i]); + // The lower bound is just the tile-space loop. + AffineMap lbMap = b.getDimIdentityMap(); + newLoops[width + i].setLowerBound( + /*operands=*/newLoops[i].getInductionVar(), lbMap); + + // Set the upper bound. + if (mayBeConstantCount.hasValue() && + mayBeConstantCount.getValue() < tileSizes[i]) { + // Trip count is less than tile size; upper bound is the trip count. + auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue()); + newLoops[width + i].setUpperBoundMap(ubMap); + } else if (largestDiv % tileSizes[i] != 0) { + // Intra-tile loop ii goes from i to min(i + tileSize, ub_i). + // Construct the upper bound map; the operands are the original operands + // with 'i' (tile-space loop) appended to it. The new upper bound map is + // the original one with an additional expression i + tileSize appended. + auto ub = origLoops[i].getUpperBound(); + SmallVector<Value, 4> ubOperands; + ubOperands.reserve(ub.getNumOperands() + 1); + auto origUbMap = ub.getMap(); + // Add dim operands from original upper bound. + for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) { + ubOperands.push_back(ub.getOperand(j)); + } + // Add dim operand for new loop upper bound. + ubOperands.push_back(newLoops[i].getInductionVar()); + // Add symbol operands from original upper bound. + for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) { + ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j)); + } + SmallVector<AffineExpr, 4> boundExprs; + boundExprs.reserve(1 + origUbMap.getNumResults()); + auto dim = b.getAffineDimExpr(origUbMap.getNumDims()); + // The new upper bound map is the original one with an additional + // expression i + tileSize appended. + boundExprs.push_back(dim + tileSizes[i]); + boundExprs.append(origUbMap.getResults().begin(), + origUbMap.getResults().end()); + auto ubMap = AffineMap::get(origUbMap.getNumDims() + 1, + origUbMap.getNumSymbols(), boundExprs); + newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap); + } else { + // No need of the min expression. + auto dim = b.getAffineDimExpr(0); + auto ubMap = AffineMap::get(1, 0, dim + tileSizes[i]); + newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap); + } + } +} + +/// Tiles the specified band of perfectly nested loops creating tile-space loops +/// and intra-tile loops. A band is a contiguous set of loops. +// TODO(bondhugula): handle non hyper-rectangular spaces. 
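+// For a band of depth 'width', 2 * width new loops are constructed: the
+// outer 'width' are tile-space loops and the inner 'width' are intra-tile
+// (point) loops. The original loop body is moved into the innermost point
+// loop, and all uses of the original IVs are replaced by the point-loop IVs.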
+LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band, + ArrayRef<unsigned> tileSizes) { + assert(!band.empty()); + assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes"); + + // Check if the supplied for op's are all successively nested. + for (unsigned i = 1, e = band.size(); i < e; i++) { + assert(band[i].getParentOp() == band[i - 1].getOperation()); + } + + auto origLoops = band; + + AffineForOp rootAffineForOp = origLoops[0]; + auto loc = rootAffineForOp.getLoc(); + // Note that width is at least one since band isn't empty. + unsigned width = band.size(); + + SmallVector<AffineForOp, 12> newLoops(2 * width); + AffineForOp innermostPointLoop; + + // The outermost among the loops as we add more.. + auto *topLoop = rootAffineForOp.getOperation(); + + // Add intra-tile (or point) loops. + for (unsigned i = 0; i < width; i++) { + OpBuilder b(topLoop); + // Loop bounds will be set later. + auto pointLoop = b.create<AffineForOp>(loc, 0, 0); + pointLoop.getBody()->getOperations().splice( + pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), + topLoop); + newLoops[2 * width - 1 - i] = pointLoop; + topLoop = pointLoop.getOperation(); + if (i == 0) + innermostPointLoop = pointLoop; + } + + // Add tile space loops; + for (unsigned i = width; i < 2 * width; i++) { + OpBuilder b(topLoop); + // Loop bounds will be set later. + auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0); + tileSpaceLoop.getBody()->getOperations().splice( + tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), + topLoop); + newLoops[2 * width - i - 1] = tileSpaceLoop; + topLoop = tileSpaceLoop.getOperation(); + } + + // Move the loop body of the original nest to the new one. + moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); + + SmallVector<Value, 8> origLoopIVs; + extractForInductionVars(band, &origLoopIVs); + SmallVector<Optional<Value>, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); + FlatAffineConstraints cst; + getIndexSet(band, &cst); + + if (!cst.isHyperRectangular(0, width)) { + rootAffineForOp.emitError("tiled code generation unimplemented for the " + "non-hyperrectangular case"); + return failure(); + } + + constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes); + // In this case, the point loop IVs just replace the original ones. + for (unsigned i = 0; i < width; i++) { + origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width].getInductionVar()); + } + + // Erase the old loop nest. + rootAffineForOp.erase(); + + return success(); +} + +// Identify valid and profitable bands of loops to tile. This is currently just +// a temporary placeholder to test the mechanics of tiled code generation. +// Returns all maximal outermost perfect loop nests to tile. +static void getTileableBands(FuncOp f, + std::vector<SmallVector<AffineForOp, 6>> *bands) { + // Get maximal perfect nest of 'affine.for' insts starting from root + // (inclusive). + auto getMaximalPerfectLoopNest = [&](AffineForOp root) { + SmallVector<AffineForOp, 6> band; + getPerfectlyNestedLoops(band, root); + bands->push_back(band); + }; + + for (auto &block : f) + for (auto &op : block) + if (auto forOp = dyn_cast<AffineForOp>(op)) + getMaximalPerfectLoopNest(forOp); +} + +// Reduce each tile size to the largest divisor of the corresponding trip count +// (if the trip count is known). 
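+// For example, a requested tile size of 32 for a loop with a known trip
+// count of 100 is reduced to 25, the largest divisor of 100 not exceeding
+// 32, so that the tiled bounds need no min/max expressions.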
+static void adjustToDivisorsOfTripCounts(ArrayRef<AffineForOp> band, + SmallVectorImpl<unsigned> *tileSizes) { + assert(band.size() == tileSizes->size() && "invalid tile size count"); + for (unsigned i = 0, e = band.size(); i < e; i++) { + unsigned &tSizeAdjusted = (*tileSizes)[i]; + auto mayConst = getConstantTripCount(band[i]); + if (!mayConst.hasValue()) + continue; + // Adjust the tile size to largest factor of the trip count less than + // tSize. + uint64_t constTripCount = mayConst.getValue(); + if (constTripCount > 1 && tSizeAdjusted > constTripCount / 2) + tSizeAdjusted = constTripCount / 2; + while (constTripCount % tSizeAdjusted != 0) + tSizeAdjusted--; + } +} + +// Returns tile sizes to use. Checks CL options; if none are specified, sets it +// based on a simple model that looks at the memory footprint and determines +// tile sizes assuming identity accesses / 1:1 tile size proportional footprint +// along each of the dimensions being tiled. +// TODO(mlir-team): evolve this model. Tile size determination is a large area +// to play with in general. +void LoopTiling::getTileSizes(ArrayRef<AffineForOp> band, + SmallVectorImpl<unsigned> *tileSizes) { + if (band.empty()) + return; + + tileSizes->resize(band.size()); + + // Use clTileSize for all loops if specified. + if (clTileSize.getNumOccurrences() > 0) { + std::fill(tileSizes->begin(), tileSizes->end(), clTileSize); + return; + } + + // Use clTileSizes and fill them with default tile size if it's short. + if (!clTileSizes.empty()) { + std::fill(tileSizes->begin(), tileSizes->end(), + LoopTiling::kDefaultTileSize); + std::copy(clTileSizes.begin(), + clTileSizes.begin() + std::min(clTileSizes.size(), band.size()), + tileSizes->begin()); + return; + } + + // The first loop in the band. + auto rootForOp = band[0]; + (void)rootForOp; + + // Obtain memory footprint and set tile sizes so that a tile fits in + // the cache size. This is an approximation with the assumption that the + // footprint increases with the tile size linearly in that dimension (i.e., + // assumes one-to-one access function). + auto fp = getMemoryFootprintBytes(band[0], 0); + if (!fp.hasValue()) { + // Fill with default tile sizes if footprint is unknown. + std::fill(tileSizes->begin(), tileSizes->end(), + LoopTiling::kDefaultTileSize); + if (avoidMaxMinBounds) + adjustToDivisorsOfTripCounts(band, tileSizes); + LLVM_DEBUG( + rootForOp.emitWarning("memory footprint unknown: using default tile " + "sizes adjusted to trip count divisors")); + return; + } + + // Check how many times larger the cache size is when compared to footprint. + uint64_t excessFactor = llvm::divideCeil(fp.getValue(), cacheSizeBytes); + if (excessFactor <= 1) { + // No need of any tiling - set tile size to 1. + std::fill(tileSizes->begin(), tileSizes->end(), 1); + return; + } + + // Divide all loops equally in an attempt to reduce footprint. + // TODO(bondhugula): this is approximate. Ideally, obtain reuse factor / + // profitability along each dimension and weight tile sizes based on that as + // one possible approach. Or compute a polynomial in tile sizes and solve for + // it. + + // For an n-d tileable band, compute n^th root of the excess. + unsigned tSize = + static_cast<unsigned>(floorl(std::pow(excessFactor, 1.0 / band.size()))); + // We'll keep a running product to determine the last tile size better. 
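+  // For example, with excessFactor = 20 and a 2-d band, tSize is
+  // floor(sqrt(20)) = 4; the first tile size becomes 4 and the last is set
+  // to 20 / 4 = 5 so that the product of the tile sizes covers the excess.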
+ unsigned cumulProductOfTileSizes = 1; + for (unsigned i = 0, e = band.size(); i < e; i++) { + if (i < e - 1) + (*tileSizes)[i] = tSize; + else + // Set last tile size to cover the balance. + (*tileSizes)[i] = std::max( + 1U, static_cast<unsigned>(excessFactor / cumulProductOfTileSizes)); + cumulProductOfTileSizes *= (*tileSizes)[i]; + } + if (avoidMaxMinBounds) + adjustToDivisorsOfTripCounts(band, tileSizes); +} + +void LoopTiling::runOnFunction() { + // Override cache size if provided on command line. + if (clCacheSizeKiB.getNumOccurrences() > 0) + cacheSizeBytes = clCacheSizeKiB * 1024; + + // Bands of loops to tile. + std::vector<SmallVector<AffineForOp, 6>> bands; + getTileableBands(getFunction(), &bands); + + for (auto &band : bands) { + // Set up tile sizes; fill missing tile sizes at the end with default tile + // size or clTileSize if one was provided. + SmallVector<unsigned, 6> tileSizes; + getTileSizes(band, &tileSizes); + if (llvm::DebugFlag) { + auto diag = band[0].emitRemark("using tile sizes ["); + for (auto tSize : tileSizes) + diag << tSize << " "; + diag << "]\n"; + } + if (failed(tileCodeGen(band, tileSizes))) + return signalPassFailure(); + } +} + +constexpr unsigned LoopTiling::kDefaultTileSize; +constexpr uint64_t LoopTiling::kDefaultCacheMemCapacity; + +static PassRegistration<LoopTiling> pass("affine-loop-tile", "Tile loop nests"); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp new file mode 100644 index 00000000000..e94c6c8b0bb --- /dev/null +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -0,0 +1,182 @@ +//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop unrolling. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +#define DEBUG_TYPE "affine-loop-unroll" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +// Loop unrolling factor. +static llvm::cl::opt<unsigned> clUnrollFactor( + "unroll-factor", + llvm::cl::desc("Use this unroll factor for all loops being unrolled"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<bool> clUnrollFull("unroll-full", + llvm::cl::desc("Fully unroll loops"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<unsigned> clUnrollNumRepetitions( + "unroll-num-reps", + llvm::cl::desc("Unroll innermost loops repeatedly this many times"), + llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::opt<unsigned> clUnrollFullThreshold( + "unroll-full-threshold", llvm::cl::Hidden, + llvm::cl::desc( + "Unroll all loops with trip count less than or equal to this"), + llvm::cl::cat(clOptionsCategory)); + +namespace { +/// Loop unrolling pass. 
Unrolls all innermost loops unless full unrolling and a +/// full unroll threshold was specified, in which case, fully unrolls all loops +/// with trip count less than the specified threshold. The latter is for testing +/// purposes, especially for testing outer loop unrolling. +struct LoopUnroll : public FunctionPass<LoopUnroll> { + const Optional<unsigned> unrollFactor; + const Optional<bool> unrollFull; + // Callback to obtain unroll factors; if this has a callable target, takes + // precedence over command-line argument or passed argument. + const std::function<unsigned(AffineForOp)> getUnrollFactor; + + explicit LoopUnroll( + Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, + const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) + : unrollFactor(unrollFactor), unrollFull(unrollFull), + getUnrollFactor(getUnrollFactor) {} + + void runOnFunction() override; + + /// Unroll this for op. Returns failure if nothing was done. + LogicalResult runOnAffineForOp(AffineForOp forOp); + + static const unsigned kDefaultUnrollFactor = 4; +}; +} // end anonymous namespace + +void LoopUnroll::runOnFunction() { + // Gathers all innermost loops through a post order pruned walk. + struct InnermostLoopGatherer { + // Store innermost loops as we walk. + std::vector<AffineForOp> loops; + + void walkPostOrder(FuncOp f) { + for (auto &b : f) + walkPostOrder(b.begin(), b.end()); + } + + bool walkPostOrder(Block::iterator Start, Block::iterator End) { + bool hasInnerLoops = false; + // We need to walk all elements since all innermost loops need to be + // gathered as opposed to determining whether this list has any inner + // loops or not. + while (Start != End) + hasInnerLoops |= walkPostOrder(&(*Start++)); + return hasInnerLoops; + } + bool walkPostOrder(Operation *opInst) { + bool hasInnerLoops = false; + for (auto ®ion : opInst->getRegions()) + for (auto &block : region) + hasInnerLoops |= walkPostOrder(block.begin(), block.end()); + if (isa<AffineForOp>(opInst)) { + if (!hasInnerLoops) + loops.push_back(cast<AffineForOp>(opInst)); + return true; + } + return hasInnerLoops; + } + }; + + if (clUnrollFull.getNumOccurrences() > 0 && + clUnrollFullThreshold.getNumOccurrences() > 0) { + // Store short loops as we walk. + std::vector<AffineForOp> loops; + + // Gathers all loops with trip count <= minTripCount. Do a post order walk + // so that loops are gathered from innermost to outermost (or else unrolling + // an outer one may delete gathered inner ones). + getFunction().walk([&](AffineForOp forOp) { + Optional<uint64_t> tripCount = getConstantTripCount(forOp); + if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) + loops.push_back(forOp); + }); + for (auto forOp : loops) + loopUnrollFull(forOp); + return; + } + + unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 + ? clUnrollNumRepetitions + : 1; + // If the call back is provided, we will recurse until no loops are found. + FuncOp func = getFunction(); + for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { + InnermostLoopGatherer ilg; + ilg.walkPostOrder(func); + auto &loops = ilg.loops; + if (loops.empty()) + break; + bool unrolled = false; + for (auto forOp : loops) + unrolled |= succeeded(runOnAffineForOp(forOp)); + if (!unrolled) + // Break out if nothing was unrolled. + break; + } +} + +/// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, +/// failure otherwise. The default unroll factor is 4. 
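+/// The unroll factor is chosen in the following order of precedence: the
+/// getUnrollFactor callback if provided, the factor passed to the pass, the
+/// -unroll-factor command-line option, full unrolling if it was requested,
+/// and finally the default factor of 4.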
+LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { + // Use the function callback if one was provided. + if (getUnrollFactor) { + return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); + } + // Unroll by the factor passed, if any. + if (unrollFactor.hasValue()) + return loopUnrollByFactor(forOp, unrollFactor.getValue()); + // Unroll by the command line factor if one was specified. + if (clUnrollFactor.getNumOccurrences() > 0) + return loopUnrollByFactor(forOp, clUnrollFactor); + // Unroll completely if full loop unroll was specified. + if (clUnrollFull.getNumOccurrences() > 0 || + (unrollFull.hasValue() && unrollFull.getValue())) + return loopUnrollFull(forOp); + + // Unroll by four otherwise. + return loopUnrollByFactor(forOp, kDefaultUnrollFactor); +} + +std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass( + int unrollFactor, int unrollFull, + const std::function<unsigned(AffineForOp)> &getUnrollFactor) { + return std::make_unique<LoopUnroll>( + unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), + unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); +} + +static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops"); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp new file mode 100644 index 00000000000..6c74d545497 --- /dev/null +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -0,0 +1,235 @@ +//===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop unroll and jam. Unroll and jam is a transformation +// that improves locality, in particular, register reuse, while also improving +// operation level parallelism. The example below shows what it does in nearly +// the general case. Loop unroll and jam currently works if the bounds of the +// loops inner to the loop being unroll-jammed do not depend on the latter. +// +// Before After unroll and jam of i by factor 2: +// +// for i, step = 2 +// for i S1(i); +// S1; S2(i); +// S2; S1(i+1); +// for j S2(i+1); +// S3; for j +// S4; S3(i, j); +// S5; S4(i, j); +// S6; S3(i+1, j) +// S4(i+1, j) +// S5(i); +// S6(i); +// S5(i+1); +// S6(i+1); +// +// Note: 'if/else' blocks are not jammed. So, if there are loops inside if +// op's, bodies of those loops will not be jammed. +//===----------------------------------------------------------------------===// +#include "mlir/Transforms/Passes.h" + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/CommandLine.h" + +using namespace mlir; + +#define DEBUG_TYPE "affine-loop-unroll-jam" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +// Loop unroll and jam factor. +static llvm::cl::opt<unsigned> + clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden, + llvm::cl::desc("Use this unroll jam factor for all loops" + " (default 4)"), + llvm::cl::cat(clOptionsCategory)); + +namespace { +/// Loop unroll jam pass. 
Currently, this just unroll jams the first +/// outer loop in a Function. +struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> { + Optional<unsigned> unrollJamFactor; + static const unsigned kDefaultUnrollJamFactor = 4; + + explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None) + : unrollJamFactor(unrollJamFactor) {} + + void runOnFunction() override; + LogicalResult runOnAffineForOp(AffineForOp forOp); +}; +} // end anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { + return std::make_unique<LoopUnrollAndJam>( + unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor)); +} + +void LoopUnrollAndJam::runOnFunction() { + // Currently, just the outermost loop from the first loop nest is + // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on + // any for operation. + auto &entryBlock = getFunction().front(); + if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front())) + runOnAffineForOp(forOp); +} + +/// Unroll and jam a 'affine.for' op. Default unroll jam factor is +/// kDefaultUnrollJamFactor. Return failure if nothing was done. +LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) { + // Unroll and jam by the factor that was passed if any. + if (unrollJamFactor.hasValue()) + return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); + // Otherwise, unroll jam by the command-line factor if one was specified. + if (clUnrollJamFactor.getNumOccurrences() > 0) + return loopUnrollJamByFactor(forOp, clUnrollJamFactor); + + // Unroll and jam by four otherwise. + return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); +} + +LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp, + uint64_t unrollJamFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); + + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollJamFactor) + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); +} + +/// Unrolls and jams this loop by the specified factor. +LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, + uint64_t unrollJamFactor) { + // Gathers all maximal sub-blocks of operations that do not themselves + // include a for op (a operation could have a descendant for op though + // in its tree). Ignore the block terminators. + struct JamBlockGatherer { + // Store iterators to the first and last op of each sub-block found. + std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; + + // This is a linear time walk. + void walk(Operation *op) { + for (auto ®ion : op->getRegions()) + for (auto &block : region) + walk(block); + } + void walk(Block &block) { + for (auto it = block.begin(), e = std::prev(block.end()); it != e;) { + auto subBlockStart = it; + while (it != e && !isa<AffineForOp>(&*it)) + ++it; + if (it != subBlockStart) + subBlocks.push_back({subBlockStart, std::prev(it)}); + // Process all for insts that appear next. 
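+        // E.g., for a block {S1; S2; affine.for ...; S3;}, the sub-block
+        // [S1, S2] has just been collected; the loop below recurses into the
+        // 'for' body, and [S3] is collected on the next iteration of the
+        // enclosing loop.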
+ while (it != e && isa<AffineForOp>(&*it)) + walk(&*it++); + } + } + }; + + assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); + + if (unrollJamFactor == 1) + return promoteIfSingleIteration(forOp); + + if (forOp.getBody()->empty() || + forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) + return failure(); + + // Loops where both lower and upper bounds are multi-result maps won't be + // unrolled (since the trip can't be expressed as an affine function in + // general). + // TODO(mlir-team): this may not be common, but we could support the case + // where the lower bound is a multi-result map and the ub is a single result + // one. + if (forOp.getLowerBoundMap().getNumResults() != 1) + return failure(); + + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); + // If the trip count is lower than the unroll jam factor, no unroll jam. + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollJamFactor) + return failure(); + + auto *forInst = forOp.getOperation(); + + // Gather all sub-blocks to jam upon the loop being unrolled. + JamBlockGatherer jbg; + jbg.walk(forInst); + auto &subBlocks = jbg.subBlocks; + + // Generate the cleanup loop if trip count isn't a multiple of + // unrollJamFactor. + if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) { + // Insert the cleanup loop right after 'forOp'. + OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); + auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst)); + // Adjust the lower bound of the cleanup loop; its upper bound is the same + // as the original loop's upper bound. + AffineMap cleanupMap; + SmallVector<Value, 4> cleanupOperands; + getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap, + &cleanupOperands, builder); + cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap); + + // Promote the cleanup loop if it has turned into a single iteration loop. + promoteIfSingleIteration(cleanupAffineForOp); + + // Adjust the upper bound of the original loop - it will be the same as the + // cleanup loop's lower bound. Its lower bound remains unchanged. + forOp.setUpperBound(cleanupOperands, cleanupMap); + } + + // Scale the step of loop being unroll-jammed by the unroll-jam factor. + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollJamFactor); + + auto forOpIV = forOp.getInductionVar(); + // Unroll and jam (appends unrollJamFactor - 1 additional copies). + for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { + // Operand map persists across all sub-blocks. + BlockAndValueMapping operandMapping; + for (auto &subBlock : subBlocks) { + // Builder to insert unroll-jammed bodies. Insert right at the end of + // sub-block. + OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); + + // If the induction variable is used, create a remapping to the value for + // this unrolled instance. + if (!forOpIV->use_empty()) { + // iv' = iv + i, i = 1 to unrollJamFactor-1. + auto d0 = builder.getAffineDimExpr(0); + auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); + auto ivUnroll = + builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); + } + // Clone the sub-block being unroll-jammed. + for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { + builder.clone(*it, operandMapping); + } + } + } + + // Promote the loop body up if this has turned into a single iteration loop. 
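+  // (E.g., when the constant trip count equals the unroll-jam factor, the
+  // scaled step covers the remaining iterations in a single pass.)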
+ promoteIfSingleIteration(forOp); + return success(); +} + +static PassRegistration<LoopUnrollAndJam> pass("affine-loop-unroll-jam", + "Unroll and jam loops"); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp new file mode 100644 index 00000000000..e2514e12cc7 --- /dev/null +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -0,0 +1,227 @@ +//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to forward memref stores to loads, thereby +// potentially getting rid of intermediate memref's entirely. +// TODO(mlir-team): In the future, similar techniques could be used to eliminate +// dead memref store's and perform more complex forwarding when support for +// SSA scalars live out of 'affine.for'/'affine.if' statements is available. +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/Dominance.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallPtrSet.h" +#include <algorithm> + +#define DEBUG_TYPE "memref-dataflow-opt" + +using namespace mlir; + +namespace { + +// The store to load forwarding relies on three conditions: +// +// 1) they need to have mathematically equivalent affine access functions +// (checked after full composition of load/store operands); this implies that +// they access the same single memref element for all iterations of the common +// surrounding loop, +// +// 2) the store op should dominate the load op, +// +// 3) among all op's that satisfy both (1) and (2), the one that postdominates +// all store op's that have a dependence into the load, is provably the last +// writer to the particular memref location being loaded at the load op, and its +// store value can be forwarded to the load. Note that the only dependences +// that are to be considered are those that are satisfied at the block* of the +// innermost common surrounding loop of the <store, load> being considered. +// +// (* A dependence being satisfied at a block: a dependence that is satisfied by +// virtue of the destination operation appearing textually / lexically after +// the source operation within the body of a 'affine.for' operation; thus, a +// dependence is always either satisfied by a loop or by a block). +// +// The above conditions are simple to check, sufficient, and powerful for most +// cases in practice - they are sufficient, but not necessary --- since they +// don't reason about loops that are guaranteed to execute at least once or +// multiple sources to forward from. +// +// TODO(mlir-team): more forwarding can be done when support for +// loop/conditional live-out SSA values is available. +// TODO(mlir-team): do general dead store elimination for memref's. This pass +// currently only eliminates the stores only if no other loads/uses (other +// than dealloc) remain. 
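+// As a simple illustration (memref and value names are hypothetical), in
+//   affine.for %i = 0 to 100 {
+//     affine.store %v, %A[%i] : memref<100xf32>
+//     %x = affine.load %A[%i] : memref<100xf32>
+//   }
+// the load is replaced by %v; if the only remaining uses of %A are stores and
+// a dealloc, the stores, the dealloc, and the alloc are erased as well.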
+// +struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> { + void runOnFunction() override; + + void forwardStoreToLoad(AffineLoadOp loadOp); + + // A list of memref's that are potentially dead / could be eliminated. + SmallPtrSet<Value, 4> memrefsToErase; + // Load op's whose results were replaced by those forwarded from stores. + SmallVector<Operation *, 8> loadOpsToErase; + + DominanceInfo *domInfo = nullptr; + PostDominanceInfo *postDomInfo = nullptr; +}; + +} // end anonymous namespace + +/// Creates a pass to perform optimizations relying on memref dataflow such as +/// store to load forwarding, elimination of dead stores, and dead allocs. +std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() { + return std::make_unique<MemRefDataFlowOpt>(); +} + +// This is a straightforward implementation not optimized for speed. Optimize +// if needed. +void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { + Operation *loadOpInst = loadOp.getOperation(); + + // First pass over the use list to get minimum number of surrounding + // loops common between the load op and the store op, with min taken across + // all store ops. + SmallVector<Operation *, 8> storeOps; + unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); + for (auto *user : loadOp.getMemRef()->getUsers()) { + auto storeOp = dyn_cast<AffineStoreOp>(user); + if (!storeOp) + continue; + auto *storeOpInst = storeOp.getOperation(); + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + minSurroundingLoops = std::min(nsLoops, minSurroundingLoops); + storeOps.push_back(storeOpInst); + } + + // The list of store op candidates for forwarding that satisfy conditions + // (1) and (2) above - they will be filtered later when checking (3). + SmallVector<Operation *, 8> fwdingCandidates; + + // Store ops that have a dependence into the load (even if they aren't + // forwarding candidates). Each forwarding candidate will be checked for a + // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. + SmallVector<Operation *, 8> depSrcStores; + + for (auto *storeOpInst : storeOps) { + MemRefAccess srcAccess(storeOpInst); + MemRefAccess destAccess(loadOpInst); + // Find stores that may be reaching the load. + FlatAffineConstraints dependenceConstraints; + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + unsigned d; + // Dependences at loop depth <= minSurroundingLoops do NOT matter. + for (d = nsLoops + 1; d > minSurroundingLoops; d--) { + DependenceResult result = checkMemrefAccessDependence( + srcAccess, destAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (hasDependence(result)) + break; + } + if (d == minSurroundingLoops) + continue; + + // Stores that *may* be reaching the load. + depSrcStores.push_back(storeOpInst); + + // 1. Check if the store and the load have mathematically equivalent + // affine access functions; this implies that they statically refer to the + // same single memref element. As an example this filters out cases like: + // store %A[%i0 + 1] + // load %A[%i0] + // store %A[%M] + // load %A[%N] + // Use the AffineValueMap difference based memref access equality checking. + if (srcAccess != destAccess) + continue; + + // 2. The store has to dominate the load op to be candidate. + if (!domInfo->dominates(storeOpInst, loadOpInst)) + continue; + + // We now have a candidate for forwarding. + fwdingCandidates.push_back(storeOpInst); + } + + // 3. 
Of all the store op's that meet the above criteria, the store that + // postdominates all 'depSrcStores' (if one exists) is the unique store + // providing the value to the load, i.e., provably the last writer to that + // memref loc. + // Note: this can be implemented in a cleaner way with postdominator tree + // traversals. Consider this for the future if needed. + Operation *lastWriteStoreOp = nullptr; + for (auto *storeOpInst : fwdingCandidates) { + if (llvm::all_of(depSrcStores, [&](Operation *depStore) { + return postDomInfo->postDominates(storeOpInst, depStore); + })) { + lastWriteStoreOp = storeOpInst; + break; + } + } + if (!lastWriteStoreOp) + return; + + // Perform the actual store to load forwarding. + Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore(); + loadOp.replaceAllUsesWith(storeVal); + // Record the memref for a later sweep to optimize away. + memrefsToErase.insert(loadOp.getMemRef()); + // Record this to erase later. + loadOpsToErase.push_back(loadOpInst); +} + +void MemRefDataFlowOpt::runOnFunction() { + // Only supports single block functions at the moment. + FuncOp f = getFunction(); + if (f.getBlocks().size() != 1) { + markAllAnalysesPreserved(); + return; + } + + domInfo = &getAnalysis<DominanceInfo>(); + postDomInfo = &getAnalysis<PostDominanceInfo>(); + + loadOpsToErase.clear(); + memrefsToErase.clear(); + + // Walk all load's and perform load/store forwarding. + f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); }); + + // Erase all load op's whose results were replaced with store fwd'ed ones. + for (auto *loadOp : loadOpsToErase) { + loadOp->erase(); + } + + // Check if the store fwd'ed memrefs are now left with only stores and can + // thus be completely deleted. Note: the canonicalize pass should be able + // to do this as well, but we'll do it here since we collected these anyway. + for (auto memref : memrefsToErase) { + // If the memref hasn't been alloc'ed in this function, skip. + Operation *defInst = memref->getDefiningOp(); + if (!defInst || !isa<AllocOp>(defInst)) + // TODO(mlir-team): if the memref was returned by a 'call' operation, we + // could still erase it if the call had no side-effects. + continue; + if (llvm::any_of(memref->getUsers(), [&](Operation *ownerInst) { + return (!isa<AffineStoreOp>(ownerInst) && !isa<DeallocOp>(ownerInst)); + })) + continue; + + // Erase all stores, the dealloc, and the alloc on the memref. + for (auto *user : llvm::make_early_inc_range(memref->getUsers())) + user->erase(); + defInst->erase(); + } +} + +static PassRegistration<MemRefDataFlowOpt> + pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs"); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp new file mode 100644 index 00000000000..dce02737064 --- /dev/null +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -0,0 +1,379 @@ +//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to pipeline data transfers. 
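At a high level (a schematic sketch, not IR from this patch, with the DMA operations abbreviated to comments), the pass takes a loop that transfers a tile into fast memory and then computes on it, and overlaps the transfer for the next iteration with the compute of the current one:

    affine.for %i = 0 to 8 {
      // affine.dma_start : bring tile %i into the fast-space buffer %buf
      // affine.dma_wait  : block until that transfer completes
      // compute on %buf
    }

    // After pipelining (schematic): %buf is double buffered and the body is
    // skewed so that, in the steady state, the DMA for iteration %i + 1 is in
    // flight while iteration %i is being computed.
    //   prologue:     start the DMA for iteration 0
    //   steady state: affine.for - wait for DMA %i, start DMA %i + 1, compute %i
    //   epilogue:     wait for and compute the final iteration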
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "affine-pipeline-data-transfer" + +using namespace mlir; + +namespace { + +struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> { + void runOnFunction() override; + void runOnAffineForOp(AffineForOp forOp); + + std::vector<AffineForOp> forOps; +}; + +} // end anonymous namespace + +/// Creates a pass to pipeline explicit movement of data across levels of the +/// memory hierarchy. +std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() { + return std::make_unique<PipelineDataTransfer>(); +} + +// Returns the position of the tag memref operand given a DMA operation. +// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are +// added. TODO(b/117228571) +static unsigned getTagMemRefPos(Operation &dmaInst) { + assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst)); + if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) { + return dmaStartOp.getTagMemRefOperandIndex(); + } + // First operand for a dma finish operation. + return 0; +} + +/// Doubles the buffer of the supplied memref on the specified 'affine.for' +/// operation by adding a leading dimension of size two to the memref. +/// Replaces all uses of the old memref by the new one while indexing the newly +/// added dimension by the loop IV of the specified 'affine.for' operation +/// modulo 2. Returns false if such a replacement cannot be performed. +static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { + auto *forBody = forOp.getBody(); + OpBuilder bInner(forBody, forBody->begin()); + + // Doubles the shape with a leading dimension extent of 2. + auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { + // Add the leading dimension in the shape for the double buffer. + ArrayRef<int64_t> oldShape = oldMemRefType.getShape(); + SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank()); + newShape[0] = 2; + std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); + auto newMemRefType = + MemRefType::get(newShape, oldMemRefType.getElementType(), {}, + oldMemRefType.getMemorySpace()); + return newMemRefType; + }; + + auto oldMemRefType = oldMemRef->getType().cast<MemRefType>(); + auto newMemRefType = doubleShape(oldMemRefType); + + // The double buffer is allocated right before 'forInst'. + auto *forInst = forOp.getOperation(); + OpBuilder bOuter(forInst); + // Put together alloc operands for any dynamic dimensions of the memref. + SmallVector<Value, 4> allocOperands; + unsigned dynamicDimCount = 0; + for (auto dimSize : oldMemRefType.getShape()) { + if (dimSize == -1) + allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef, + dynamicDimCount++)); + } + + // Create and place the alloc right before the 'affine.for' operation. + Value newMemRef = + bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands); + + // Create 'iv mod 2' value to index the leading dimension. 
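For example (a sketch with hypothetical shapes; the pass applies this to the fast-space DMA buffers, and the leading subscript is actually introduced through the affine.apply built here rather than written inline), doubleBuffer rewrites:

    %buf = alloc() : memref<128xf32, 1>
    affine.for %i = 0 to 128 {
      %v = affine.load %buf[%i] : memref<128xf32, 1>
      "foo.compute"(%v) : (f32) -> ()
    }

into

    %dbuf = alloc() : memref<2x128xf32, 1>
    affine.for %i = 0 to 128 {
      %v = affine.load %dbuf[%i mod 2, %i] : memref<2x128xf32, 1>
      "foo.compute"(%v) : (f32) -> ()
    }
    dealloc %dbuf : memref<2x128xf32, 1>

With a non-unit step the leading index is (%i floordiv step) mod 2, which is exactly the map constructed below.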
+ auto d0 = bInner.getAffineDimExpr(0); + int64_t step = forOp.getStep(); + auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, + {d0.floorDiv(step) % 2}); + auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap, + forOp.getInductionVar()); + + // replaceAllMemRefUsesWith will succeed unless the forOp body has + // non-dereferencing uses of the memref (dealloc's are fine though). + if (failed(replaceAllMemRefUsesWith( + oldMemRef, newMemRef, + /*extraIndices=*/{ivModTwoOp}, + /*indexRemap=*/AffineMap(), + /*extraOperands=*/{}, + /*symbolOperands=*/{}, + /*domInstFilter=*/&*forOp.getBody()->begin()))) { + LLVM_DEBUG( + forOp.emitError("memref replacement for double buffering failed")); + ivModTwoOp.erase(); + return false; + } + // Insert the dealloc op right after the for loop. + bOuter.setInsertionPointAfter(forInst); + bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef); + + return true; +} + +/// Returns success if the IR is in a valid state. +void PipelineDataTransfer::runOnFunction() { + // Do a post order walk so that inner loop DMAs are processed first. This is + // necessary since 'affine.for' operations nested within would otherwise + // become invalid (erased) when the outer loop is pipelined (the pipelined one + // gets deleted and replaced by a prologue, a new steady-state loop and an + // epilogue). + forOps.clear(); + getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); }); + for (auto forOp : forOps) + runOnAffineForOp(forOp); +} + +// Check if tags of the dma start op and dma wait op match. +static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) { + if (startOp.getTagMemRef() != waitOp.getTagMemRef()) + return false; + auto startIndices = startOp.getTagIndices(); + auto waitIndices = waitOp.getTagIndices(); + // Both of these have the same number of indices since they correspond to the + // same tag memref. + for (auto it = startIndices.begin(), wIt = waitIndices.begin(), + e = startIndices.end(); + it != e; ++it, ++wIt) { + // Keep it simple for now, just checking if indices match. + // TODO(mlir-team): this would in general need to check if there is no + // intervening write writing to the same tag location, i.e., memory last + // write/data flow analysis. This is however sufficient/powerful enough for + // now since the DMA generation pass or the input for it will always have + // start/wait with matching tags (same SSA operand indices). + if (*it != *wIt) + return false; + } + return true; +} + +// Identify matching DMA start/finish operations to overlap computation with. +static void findMatchingStartFinishInsts( + AffineForOp forOp, + SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) { + + // Collect outgoing DMA operations - needed to check for dependences below. + SmallVector<AffineDmaStartOp, 4> outgoingDmaOps; + for (auto &op : *forOp.getBody()) { + auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op); + if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) + outgoingDmaOps.push_back(dmaStartOp); + } + + SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts; + for (auto &op : *forOp.getBody()) { + // Collect DMA finish operations. + if (isa<AffineDmaWaitOp>(op)) { + dmaFinishInsts.push_back(&op); + continue; + } + auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op); + if (!dmaStartOp) + continue; + + // Only DMAs incoming into higher memory spaces are pipelined for now. + // TODO(bondhugula): handle outgoing DMA pipelining. 
+ if (!dmaStartOp.isDestMemorySpaceFaster()) + continue; + + // Check for dependence with outgoing DMAs. Doing this conservatively. + // TODO(andydavis,bondhugula): use the dependence analysis to check for + // dependences between an incoming and outgoing DMA in the same iteration. + auto it = outgoingDmaOps.begin(); + for (; it != outgoingDmaOps.end(); ++it) { + if (it->getDstMemRef() == dmaStartOp.getSrcMemRef()) + break; + } + if (it != outgoingDmaOps.end()) + continue; + + // We only double buffer if the buffer is not live out of loop. + auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos()); + bool escapingUses = false; + for (auto *user : memref->getUsers()) { + // We can double buffer regardless of dealloc's outside the loop. + if (isa<DeallocOp>(user)) + continue; + if (!forOp.getBody()->findAncestorOpInBlock(*user)) { + LLVM_DEBUG(llvm::dbgs() + << "can't pipeline: buffer is live out of loop\n";); + escapingUses = true; + break; + } + } + if (!escapingUses) + dmaStartInsts.push_back(&op); + } + + // For each start operation, we look for a matching finish operation. + for (auto *dmaStartInst : dmaStartInsts) { + for (auto *dmaFinishInst : dmaFinishInsts) { + if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst), + cast<AffineDmaWaitOp>(dmaFinishInst))) { + startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); + break; + } + } + } +} + +/// Overlap DMA transfers with computation in this loop. If successful, +/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// inserted right before where it was. +void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { + auto mayBeConstTripCount = getConstantTripCount(forOp); + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG( + forOp.emitRemark("won't pipeline due to unknown trip count loop")); + return; + } + + SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs; + findMatchingStartFinishInsts(forOp, startWaitPairs); + + if (startWaitPairs.empty()) { + LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n")); + return; + } + + // Double the buffers for the higher memory space memref's. + // Identify memref's to replace by scanning through all DMA start + // operations. A DMA start operation has two memref's - the one from the + // higher level of memory hierarchy is the one to double buffer. + // TODO(bondhugula): check whether double-buffering is even necessary. + // TODO(bondhugula): make this work with different layouts: assuming here that + // the dimension we are adding here for the double buffering is the outermost + // dimension. + for (auto &pair : startWaitPairs) { + auto *dmaStartInst = pair.first; + Value oldMemRef = dmaStartInst->getOperand( + cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos()); + if (!doubleBuffer(oldMemRef, forOp)) { + // Normally, double buffering should not fail because we already checked + // that there are no uses outside. + LLVM_DEBUG(llvm::dbgs() + << "double buffering failed for" << dmaStartInst << "\n";); + // IR still valid and semantically correct. + return; + } + // If the old memref has no more uses, remove its 'dead' alloc if it was + // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim' + // operation could have been used on it if it was dynamically shaped in + // order to create the double buffer above.) + // '-canonicalize' does this in a more general way, but we'll anyway do the + // simple/common case so that the output / test cases looks clear. 
+ if (auto *allocInst = oldMemRef->getDefiningOp()) { + if (oldMemRef->use_empty()) { + allocInst->erase(); + } else if (oldMemRef->hasOneUse()) { + if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) { + dealloc.erase(); + allocInst->erase(); + } + } + } + } + + // Double the buffers for tag memrefs. + for (auto &pair : startWaitPairs) { + auto *dmaFinishInst = pair.second; + Value oldTagMemRef = + dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); + if (!doubleBuffer(oldTagMemRef, forOp)) { + LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); + return; + } + // If the old tag has no uses or a single dealloc use, remove it. + // (canonicalization handles more complex cases). + if (auto *tagAllocInst = oldTagMemRef->getDefiningOp()) { + if (oldTagMemRef->use_empty()) { + tagAllocInst->erase(); + } else if (oldTagMemRef->hasOneUse()) { + if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef->user_begin())) { + dealloc.erase(); + tagAllocInst->erase(); + } + } + } + } + + // Double buffering would have invalidated all the old DMA start/wait insts. + startWaitPairs.clear(); + findMatchingStartFinishInsts(forOp, startWaitPairs); + + // Store shift for operation for later lookup for AffineApplyOp's. + DenseMap<Operation *, unsigned> instShiftMap; + for (auto &pair : startWaitPairs) { + auto *dmaStartInst = pair.first; + assert(isa<AffineDmaStartOp>(dmaStartInst)); + instShiftMap[dmaStartInst] = 0; + // Set shifts for DMA start op's affine operand computation slices to 0. + SmallVector<AffineApplyOp, 4> sliceOps; + mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); + if (!sliceOps.empty()) { + for (auto sliceOp : sliceOps) { + instShiftMap[sliceOp.getOperation()] = 0; + } + } else { + // If a slice wasn't created, the reachable affine.apply op's from its + // operands are the ones that go with it. + SmallVector<Operation *, 4> affineApplyInsts; + SmallVector<Value, 4> operands(dmaStartInst->getOperands()); + getReachableAffineApplyOps(operands, affineApplyInsts); + for (auto *op : affineApplyInsts) { + instShiftMap[op] = 0; + } + } + } + // Everything else (including compute ops and dma finish) are shifted by one. + for (auto &op : *forOp.getBody()) { + if (instShiftMap.find(&op) == instShiftMap.end()) { + instShiftMap[&op] = 1; + } + } + + // Get shifts stored in map. + std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size()); + unsigned s = 0; + for (auto &op : *forOp.getBody()) { + assert(instShiftMap.find(&op) != instShiftMap.end()); + shifts[s++] = instShiftMap[&op]; + + // Tagging operations with shifts for debugging purposes. + LLVM_DEBUG({ + OpBuilder b(&op); + op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); + }); + } + + if (!isInstwiseShiftValid(forOp, shifts)) { + // Violates dependences. 
+ LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); + return; + } + + if (failed(instBodySkew(forOp, shifts))) { + LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";); + return; + } +} + +static PassRegistration<PipelineDataTransfer> pass( + "affine-pipeline-data-transfer", + "Pipeline non-blocking data transfers between explicitly managed levels of " + "the memory hierarchy"); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp new file mode 100644 index 00000000000..217e06bc877 --- /dev/null +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -0,0 +1,108 @@ +//===- SimplifyAffineStructures.cpp ---------------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to simplify affine structures. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" + +#define DEBUG_TYPE "simplify-affine-structure" + +using namespace mlir; + +namespace { + +/// Simplifies affine maps and sets appearing in the operations of the Function. +/// This part is mainly to test the simplifyAffineExpr method. In addition, +/// all memrefs with non-trivial layout maps are converted to ones with trivial +/// identity layout ones. +struct SimplifyAffineStructures + : public FunctionPass<SimplifyAffineStructures> { + void runOnFunction() override; + + /// Utility to simplify an affine attribute and update its entry in the parent + /// operation if necessary. + template <typename AttributeT> + void simplifyAndUpdateAttribute(Operation *op, Identifier name, + AttributeT attr) { + auto &simplified = simplifiedAttributes[attr]; + if (simplified == attr) + return; + + // This is a newly encountered attribute. + if (!simplified) { + // Try to simplify the value of the attribute. + auto value = attr.getValue(); + auto simplifiedValue = simplify(value); + if (simplifiedValue == value) { + simplified = attr; + return; + } + simplified = AttributeT::get(simplifiedValue); + } + + // Simplification was successful, so update the attribute. + op->setAttr(name, simplified); + } + + /// Performs basic integer set simplifications. Checks if it's empty, and + /// replaces it with the canonical empty set if it is. + IntegerSet simplify(IntegerSet set) { + FlatAffineConstraints fac(set); + if (fac.isEmpty()) + return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(), + &getContext()); + return set; + } + + /// Performs basic affine map simplifications. 
+ AffineMap simplify(AffineMap map) { + MutableAffineMap mMap(map); + mMap.simplify(); + return mMap.getAffineMap(); + } + + DenseMap<Attribute, Attribute> simplifiedAttributes; +}; + +} // end anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimplifyAffineStructuresPass() { + return std::make_unique<SimplifyAffineStructures>(); +} + +void SimplifyAffineStructures::runOnFunction() { + auto func = getFunction(); + simplifiedAttributes.clear(); + func.walk([&](Operation *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) + simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); + else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) + simplifyAndUpdateAttribute(opInst, attr.first, setAttr); + } + }); + + // Turn memrefs' non-identity layouts maps into ones with identity. Collect + // alloc ops first and then process since normalizeMemRef replaces/erases ops + // during memref rewriting. + SmallVector<AllocOp, 4> allocOps; + func.walk([&](AllocOp op) { allocOps.push_back(op); }); + for (auto allocOp : allocOps) { + normalizeMemRef(allocOp); + } +} + +static PassRegistration<SimplifyAffineStructures> + pass("simplify-affine-structures", + "Simplify affine expressions in maps/sets and normalize memrefs"); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp new file mode 100644 index 00000000000..cdfc7fd7e41 --- /dev/null +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -0,0 +1,37 @@ +//===- StripDebugInfo.cpp - Pass to strip debug information ---------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +struct StripDebugInfo : public FunctionPass<StripDebugInfo> { + void runOnFunction() override; +}; +} // end anonymous namespace + +void StripDebugInfo::runOnFunction() { + FuncOp func = getFunction(); + auto unknownLoc = UnknownLoc::get(&getContext()); + + // Strip the debug info from the function and its operations. + func.setLoc(unknownLoc); + func.walk([&](Operation *op) { op->setLoc(unknownLoc); }); +} + +/// Creates a pass to strip debug information from a function. 
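The effect is easiest to see with debug info printing enabled (a sketch; the file name and positions are made up): an operation such as

    %c = constant 0 : index loc("input.mlir":4:10)

becomes, after -strip-debuginfo,

    %c = constant 0 : index loc(unknown)

and the function's own location is reset to unknown in the same way.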
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createStripDebugInfoPass() { + return std::make_unique<StripDebugInfo>(); +} + +static PassRegistration<StripDebugInfo> + pass("strip-debuginfo", "Strip debug info from functions and operations"); diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt new file mode 100644 index 00000000000..4e1dc5e4b4e --- /dev/null +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -0,0 +1,21 @@ +add_llvm_library(MLIRTransformUtils + FoldUtils.cpp + GreedyPatternRewriteDriver.cpp + InliningUtils.cpp + LoopFusionUtils.cpp + LoopUtils.cpp + RegionUtils.cpp + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms + ) + +add_dependencies(MLIRTransformUtils MLIRStandardOpsIncGen) +target_link_libraries(MLIRTransformUtils + MLIRAffineOps + MLIRAnalysis + MLIRLoopOps + MLIRPass + MLIRStandardOps + ) diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp new file mode 100644 index 00000000000..719c6fac731 --- /dev/null +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -0,0 +1,246 @@ +//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various operation fold utilities. These utilities are +// intended to be used by passes to unify and simply their logic. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/FoldUtils.h" + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; + +/// Given an operation, find the parent region that folded constants should be +/// inserted into. +static Region *getInsertionRegion( + DialectInterfaceCollection<OpFolderDialectInterface> &interfaces, + Operation *op) { + while (Region *region = op->getParentRegion()) { + // Insert in this region for any of the following scenarios: + // * The parent is unregistered, or is known to be isolated from above. + // * The parent is a top-level operation. + auto *parentOp = region->getParentOp(); + if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() || + !parentOp->getBlock()) + return region; + + // Otherwise, check if this region is a desired insertion region. + auto *interface = interfaces.getInterfaceFor(parentOp); + if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) + return region; + + // Traverse up the parent looking for an insertion region. + op = parentOp; + } + llvm_unreachable("expected valid insertion region"); +} + +/// A utility function used to materialize a constant for a given attribute and +/// type. On success, a valid constant value is returned. Otherwise, null is +/// returned +static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, + Attribute value, Type type, + Location loc) { + auto insertPt = builder.getInsertionPoint(); + (void)insertPt; + + // Ask the dialect to materialize a constant operation for this value. 
+ if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { + assert(insertPt == builder.getInsertionPoint()); + assert(matchPattern(constOp, m_Constant(&value))); + return constOp; + } + + // If the dialect is unable to materialize a constant, check to see if the + // standard constant can be used. + if (ConstantOp::isBuildableWith(value, type)) + return builder.create<ConstantOp>(loc, type, value); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// OperationFolder +//===----------------------------------------------------------------------===// + +LogicalResult OperationFolder::tryToFold( + Operation *op, function_ref<void(Operation *)> processGeneratedConstants, + function_ref<void(Operation *)> preReplaceAction) { + // If this is a unique'd constant, return failure as we know that it has + // already been folded. + if (referencedDialects.count(op)) + return failure(); + + // Try to fold the operation. + SmallVector<Value, 8> results; + if (failed(tryToFold(op, results, processGeneratedConstants))) + return failure(); + + // Constant folding succeeded. We will start replacing this op's uses and + // eventually erase this op. Invoke the callback provided by the caller to + // perform any pre-replacement action. + if (preReplaceAction) + preReplaceAction(op); + + // Check to see if the operation was just updated in place. + if (results.empty()) + return success(); + + // Otherwise, replace all of the result values and erase the operation. + for (unsigned i = 0, e = results.size(); i != e; ++i) + op->getResult(i)->replaceAllUsesWith(results[i]); + op->erase(); + return success(); +} + +/// Notifies that the given constant `op` should be remove from this +/// OperationFolder's internal bookkeeping. +void OperationFolder::notifyRemoval(Operation *op) { + // Check to see if this operation is uniqued within the folder. + auto it = referencedDialects.find(op); + if (it == referencedDialects.end()) + return; + + // Get the constant value for this operation, this is the value that was used + // to unique the operation internally. + Attribute constValue; + matchPattern(op, m_Constant(&constValue)); + assert(constValue); + + // Get the constant map that this operation was uniqued in. + auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)]; + + // Erase all of the references to this operation. + auto type = op->getResult(0)->getType(); + for (auto *dialect : it->second) + uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); + referencedDialects.erase(it); +} + +/// Tries to perform folding on the given `op`. If successful, populates +/// `results` with the results of the folding. +LogicalResult OperationFolder::tryToFold( + Operation *op, SmallVectorImpl<Value> &results, + function_ref<void(Operation *)> processGeneratedConstants) { + SmallVector<Attribute, 8> operandConstants; + SmallVector<OpFoldResult, 8> foldResults; + + // Check to see if any operands to the operation is constant and whether + // the operation knows how to constant fold itself. + operandConstants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); + + // If this is a commutative binary operation with a constant on the left + // side move it to the right side. 
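For instance (a sketch; %x is some non-constant i32 and %c5 is constant 5 : i32), with the standard dialect's addi:

    %0 = addi %c5, %x : i32    // operands are swapped below, so the fold hook
                               // only ever sees the constant on the right
    %1 = addi %c5, %c5 : i32   // folds completely: uses of %1 are replaced by a
                               // "constant 10 : i32" materialized in the entry
                               // block of the enclosing isolated-from-above
                               // region and uniqued for reuse by later folds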
+ if (operandConstants.size() == 2 && operandConstants[0] && + !operandConstants[1] && op->isCommutative()) { + std::swap(op->getOpOperand(0), op->getOpOperand(1)); + std::swap(operandConstants[0], operandConstants[1]); + } + + // Attempt to constant fold the operation. + if (failed(op->fold(operandConstants, foldResults))) + return failure(); + + // Check to see if the operation was just updated in place. + if (foldResults.empty()) + return success(); + assert(foldResults.size() == op->getNumResults()); + + // Create a builder to insert new operations into the entry block of the + // insertion region. + auto *insertRegion = getInsertionRegion(interfaces, op); + auto &entry = insertRegion->front(); + OpBuilder builder(&entry, entry.begin()); + + // Get the constant map for the insertion region of this operation. + auto &uniquedConstants = foldScopes[insertRegion]; + + // Create the result constants and replace the results. + auto *dialect = op->getDialect(); + for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { + assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); + + // Check if the result was an SSA value. + if (auto repl = foldResults[i].dyn_cast<Value>()) { + results.emplace_back(repl); + continue; + } + + // Check to see if there is a canonicalized version of this constant. + auto res = op->getResult(i); + Attribute attrRepl = foldResults[i].get<Attribute>(); + if (auto *constOp = + tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, + res->getType(), op->getLoc())) { + results.push_back(constOp->getResult(0)); + continue; + } + // If materialization fails, cleanup any operations generated for the + // previous results and return failure. + for (Operation &op : llvm::make_early_inc_range( + llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { + notifyRemoval(&op); + op.erase(); + } + return failure(); + } + + // Process any newly generated operations. + if (processGeneratedConstants) { + for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) + processGeneratedConstants(&*i); + } + + return success(); +} + +/// Try to get or create a new constant entry. On success this returns the +/// constant operation value, nullptr otherwise. +Operation *OperationFolder::tryGetOrCreateConstant( + ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, + Attribute value, Type type, Location loc) { + // Check if an existing mapping already exists. + auto constKey = std::make_tuple(dialect, value, type); + auto *&constInst = uniquedConstants[constKey]; + if (constInst) + return constInst; + + // If one doesn't exist, try to materialize one. + if (!(constInst = materializeConstant(dialect, builder, value, type, loc))) + return nullptr; + + // Check to see if the generated constant is in the expected dialect. + auto *newDialect = constInst->getDialect(); + if (newDialect == dialect) { + referencedDialects[constInst].push_back(dialect); + return constInst; + } + + // If it isn't, then we also need to make sure that the mapping for the new + // dialect is valid. + auto newKey = std::make_tuple(newDialect, value, type); + + // If an existing operation in the new dialect already exists, delete the + // materialized operation in favor of the existing one. + if (auto *existingOp = uniquedConstants.lookup(newKey)) { + constInst->erase(); + referencedDialects[existingOp].push_back(dialect); + return constInst = existingOp; + } + + // Otherwise, update the new dialect to the materialized operation. 
+ referencedDialects[constInst].assign({dialect, newDialect}); + auto newIt = uniquedConstants.insert({newKey, constInst}); + return newIt.first->second; +} diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp new file mode 100644 index 00000000000..1eb9c57639a --- /dev/null +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -0,0 +1,247 @@ +//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements mlir::applyPatternsGreedily. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +#define DEBUG_TYPE "pattern-matcher" + +static llvm::cl::opt<unsigned> maxPatternMatchIterations( + "mlir-max-pattern-match-iterations", + llvm::cl::desc("Max number of iterations scanning for pattern match"), + llvm::cl::init(10)); + +namespace { + +/// This is a worklist-driven driver for the PatternMatcher, which repeatedly +/// applies the locally optimal patterns in a roughly "bottom up" way. +class GreedyPatternRewriteDriver : public PatternRewriter { +public: + explicit GreedyPatternRewriteDriver(MLIRContext *ctx, + const OwningRewritePatternList &patterns) + : PatternRewriter(ctx), matcher(patterns), folder(ctx) { + worklist.reserve(64); + } + + /// Perform the rewrites. Return true if the rewrite converges in + /// `maxIterations`. + bool simplify(MutableArrayRef<Region> regions, int maxIterations); + + void addToWorklist(Operation *op) { + // Check to see if the worklist already contains this op. + if (worklistMap.count(op)) + return; + + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } + + Operation *popFromWorklist() { + auto *op = worklist.back(); + worklist.pop_back(); + + // This operation is no longer in the worklist, keep worklistMap up to date. + if (op) + worklistMap.erase(op); + return op; + } + + /// If the specified operation is in the worklist, remove it. If not, this is + /// a no-op. + void removeFromWorklist(Operation *op) { + auto it = worklistMap.find(op); + if (it != worklistMap.end()) { + assert(worklist[it->second] == op && "malformed worklist data structure"); + worklist[it->second] = nullptr; + worklistMap.erase(it); + } + } + + // These are hooks implemented for PatternRewriter. +protected: + // Implement the hook for inserting operations, and make sure that newly + // inserted ops are added to the worklist for processing. + Operation *insert(Operation *op) override { + addToWorklist(op); + return OpBuilder::insert(op); + } + + // If an operation is about to be removed, make sure it is not in our + // worklist anymore because we'd get dangling references to it. 
+ void notifyOperationRemoved(Operation *op) override { + addToWorklist(op->getOperands()); + op->walk([this](Operation *operation) { + removeFromWorklist(operation); + folder.notifyRemoval(operation); + }); + } + + // When the root of a pattern is about to be replaced, it can trigger + // simplifications to its users - make sure to add them to the worklist + // before the root is changed. + void notifyRootReplaced(Operation *op) override { + for (auto result : op->getResults()) + for (auto *user : result->getUsers()) + addToWorklist(user); + } + +private: + // Look over the provided operands for any defining operations that should + // be re-added to the worklist. This function should be called when an + // operation is modified or removed, as it may trigger further + // simplifications. + template <typename Operands> void addToWorklist(Operands &&operands) { + for (Value operand : operands) { + // If the use count of this operand is now < 2, we re-add the defining + // operation to the worklist. + // TODO(riverriddle) This is based on the fact that zero use operations + // may be deleted, and that single use values often have more + // canonicalization opportunities. + if (!operand->use_empty() && !operand->hasOneUse()) + continue; + if (auto *defInst = operand->getDefiningOp()) + addToWorklist(defInst); + } + } + + /// The low-level pattern matcher. + RewritePatternMatcher matcher; + + /// The worklist for this transformation keeps track of the operations that + /// need to be revisited, plus their index in the worklist. This allows us to + /// efficiently remove operations from the worklist when they are erased, even + /// if they aren't the root of a pattern. + std::vector<Operation *> worklist; + DenseMap<Operation *, unsigned> worklistMap; + + /// Non-pattern based folder for operations. + OperationFolder folder; +}; +} // end anonymous namespace + +/// Perform the rewrites. +bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions, + int maxIterations) { + // Add the given operation to the worklist. + auto collectOps = [this](Operation *op) { addToWorklist(op); }; + + bool changed = false; + int i = 0; + do { + // Add all nested operations to the worklist. + for (auto ®ion : regions) + region.walk(collectOps); + + // These are scratch vectors used in the folding loop below. + SmallVector<Value, 8> originalOperands, resultValues; + + changed = false; + while (!worklist.empty()) { + auto *op = popFromWorklist(); + + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; + + // If the operation has no side effects, and no users, then it is + // trivially dead - remove it. + if (op->hasNoSideEffect() && op->use_empty()) { + // Be careful to update bookkeeping. + notifyOperationRemoved(op); + op->erase(); + continue; + } + + // Collects all the operands and result uses of the given `op` into work + // list. Also remove `op` and nested ops from worklist. + originalOperands.assign(op->operand_begin(), op->operand_end()); + auto preReplaceAction = [&](Operation *op) { + // Add the operands to the worklist for visitation. + addToWorklist(originalOperands); + + // Add all the users of the result to the worklist so we make sure + // to revisit them. + for (auto result : op->getResults()) + for (auto *operand : result->getUsers()) + addToWorklist(operand); + + notifyOperationRemoved(op); + }; + + // Try to fold this op. 
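A small example of one round of this loop (a sketch; %x is a non-constant i32 and "foo.use" an arbitrary consumer):

    %c0 = constant 0 : i32
    %0 = addi %x, %c0 : i32
    %1 = muli %0, %0 : i32         // unused and side-effect free
    "foo.use"(%0) : (i32) -> ()

    // %1 is erased as trivially dead; the addi folds to %x (x + 0), so
    // "foo.use" is revisited and now reads %x directly; %c0 then loses its
    // last use and is erased on a later visit, leaving only:
    "foo.use"(%x) : (i32) -> ()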
+ if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) { + changed |= true; + continue; + } + + // Make sure that any new operations are inserted at this point. + setInsertionPoint(op); + + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do here. + changed |= matcher.matchAndRewrite(op, *this); + } + + // After applying patterns, make sure that the CFG of each of the regions is + // kept up to date. + changed |= succeeded(simplifyRegions(regions)); + } while (changed && ++i < maxIterations); + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. + return !changed; +} + +/// Rewrite the regions of the specified operation, which must be isolated from +/// above, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. Return true if no more patterns can be matched in +/// the result operation regions. +/// Note: This does not apply patterns to the top-level operation itself. +/// +bool mlir::applyPatternsGreedily(Operation *op, + const OwningRewritePatternList &patterns) { + return applyPatternsGreedily(op->getRegions(), patterns); +} + +/// Rewrite the given regions, which must be isolated from above. +bool mlir::applyPatternsGreedily(MutableArrayRef<Region> regions, + const OwningRewritePatternList &patterns) { + if (regions.empty()) + return true; + + // The top-level operation must be known to be isolated from above to + // prevent performing canonicalizations on operations defined at or above + // the region containing 'op'. + auto regionIsIsolated = [](Region ®ion) { + return region.getParentOp()->isKnownIsolatedFromAbove(); + }; + (void)regionIsIsolated; + assert(llvm::all_of(regions, regionIsIsolated) && + "patterns can only be applied to operations IsolatedFromAbove"); + + // Start the pattern driver. + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns); + bool converged = driver.simplify(regions, maxPatternMatchIterations); + LLVM_DEBUG(if (!converged) { + llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " + << maxPatternMatchIterations << " times"; + }); + return converged; +} diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp new file mode 100644 index 00000000000..1ac286c67fb --- /dev/null +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -0,0 +1,356 @@ +//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous inlining utilities. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/InliningUtils.h" + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "inlining" + +using namespace mlir; + +/// Remap locations from the inlined blocks with CallSiteLoc locations with the +/// provided caller location. 
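For example (hypothetical file names), an inlined operation keeps its original location but gains the caller's location as the call site:

    loc("callee.mlir":7:5)
    // becomes
    loc(callsite("callee.mlir":7:5 at "caller.mlir":12:3))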
+static void +remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, + Location callerLoc) { + DenseMap<Location, Location> mappedLocations; + auto remapOpLoc = [&](Operation *op) { + auto it = mappedLocations.find(op->getLoc()); + if (it == mappedLocations.end()) { + auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); + it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; + } + op->setLoc(it->second); + }; + for (auto &block : inlinedBlocks) + block.walk(remapOpLoc); +} + +static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, + BlockAndValueMapping &mapper) { + auto remapOperands = [&](Operation *op) { + for (auto &operand : op->getOpOperands()) + if (auto mappedOp = mapper.lookupOrNull(operand.get())) + operand.set(mappedOp); + }; + for (auto &block : inlinedBlocks) + block.walk(remapOperands); +} + +//===----------------------------------------------------------------------===// +// InlinerInterface +//===----------------------------------------------------------------------===// + +bool InlinerInterface::isLegalToInline( + Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { + // Regions can always be inlined into functions. + if (isa<FuncOp>(dest->getParentOp())) + return true; + + auto *handler = getInterfaceFor(dest->getParentOp()); + return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; +} + +bool InlinerInterface::isLegalToInline( + Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { + auto *handler = getInterfaceFor(op); + return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; +} + +bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { + auto *handler = getInterfaceFor(op); + return handler ? handler->shouldAnalyzeRecursively(op) : true; +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { + auto *handler = getInterfaceFor(op); + assert(handler && "expected valid dialect handler"); + handler->handleTerminator(op, newDest); +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void InlinerInterface::handleTerminator(Operation *op, + ArrayRef<Value> valuesToRepl) const { + auto *handler = getInterfaceFor(op); + assert(handler && "expected valid dialect handler"); + handler->handleTerminator(op, valuesToRepl); +} + +/// Utility to check that all of the operations within 'src' can be inlined. +static bool isLegalToInline(InlinerInterface &interface, Region *src, + Region *insertRegion, + BlockAndValueMapping &valueMapping) { + for (auto &block : *src) { + for (auto &op : block) { + // Check this operation. + if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { + LLVM_DEBUG({ + llvm::dbgs() << "* Illegal to inline because of op: "; + op.dump(); + }); + return false; + } + // Check any nested regions. 
+ if (interface.shouldAnalyzeRecursively(&op) && + llvm::any_of(op.getRegions(), [&](Region ®ion) { + return !isLegalToInline(interface, ®ion, insertRegion, + valueMapping); + })) + return false; + } + } + return true; +} + +//===----------------------------------------------------------------------===// +// Inline Methods +//===----------------------------------------------------------------------===// + +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + BlockAndValueMapping &mapper, + ArrayRef<Value> resultsToReplace, + Optional<Location> inlineLoc, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + + // Check that all of the region arguments have been mapped. + auto *srcEntryBlock = &src->front(); + if (llvm::any_of(srcEntryBlock->getArguments(), + [&](BlockArgument arg) { return !mapper.contains(arg); })) + return failure(); + + // The insertion point must be within a block. + Block *insertBlock = inlinePoint->getBlock(); + if (!insertBlock) + return failure(); + Region *insertRegion = insertBlock->getParent(); + + // Check that the operations within the source region are valid to inline. + if (!interface.isLegalToInline(insertRegion, src, mapper) || + !isLegalToInline(interface, src, insertRegion, mapper)) + return failure(); + + // Split the insertion block. + Block *postInsertBlock = + insertBlock->splitBlock(++inlinePoint->getIterator()); + + // Check to see if the region is being cloned, or moved inline. In either + // case, move the new blocks after the 'insertBlock' to improve IR + // readability. + if (shouldCloneInlinedRegion) + src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); + else + insertRegion->getBlocks().splice(postInsertBlock->getIterator(), + src->getBlocks(), src->begin(), + src->end()); + + // Get the range of newly inserted blocks. + auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), + postInsertBlock->getIterator()); + Block *firstNewBlock = &*newBlocks.begin(); + + // Remap the locations of the inlined operations if a valid source location + // was provided. + if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) + remapInlinedLocations(newBlocks, *inlineLoc); + + // If the blocks were moved in-place, make sure to remap any necessary + // operands. + if (!shouldCloneInlinedRegion) + remapInlinedOperands(newBlocks, mapper); + + // Process the newly inlined blocks. + interface.processInlinedBlocks(newBlocks); + + // Handle the case where only a single block was inlined. + if (std::next(newBlocks.begin()) == newBlocks.end()) { + // Have the interface handle the terminator of this block. + auto *firstBlockTerminator = firstNewBlock->getTerminator(); + interface.handleTerminator(firstBlockTerminator, resultsToReplace); + firstBlockTerminator->erase(); + + // Merge the post insert block into the cloned entry block. + firstNewBlock->getOperations().splice(firstNewBlock->end(), + postInsertBlock->getOperations()); + postInsertBlock->erase(); + } else { + // Otherwise, there were multiple blocks inlined. Add arguments to the post + // insertion block to represent the results to replace. + for (Value resultToRepl : resultsToReplace) { + resultToRepl->replaceAllUsesWith( + postInsertBlock->addArgument(resultToRepl->getType())); + } + + /// Handle the terminators for each of the new blocks. 
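An end-to-end illustration (a sketch; it assumes a dialect interface that, like the function/standard one, rewrites the callee's return into a branch): inlining a multi-block callee splits the caller block, redirects each callee terminator to the continuation block, and replaces the call result with the continuation block's new argument:

    func @callee(%c: i1) -> f32 {
      cond_br %c, ^t, ^f
    ^t:
      %a = constant 1.0 : f32
      return %a : f32
    ^f:
      %b = constant 2.0 : f32
      return %b : f32
    }
    func @caller(%p: i1) -> f32 {
      %r = call @callee(%p) : (i1) -> f32
      return %r : f32
    }

    // @caller after inlining the call:
    func @caller(%p: i1) -> f32 {
      cond_br %p, ^t, ^f
    ^t:
      %a = constant 1.0 : f32
      br ^post(%a : f32)
    ^f:
      %b = constant 2.0 : f32
      br ^post(%b : f32)
    ^post(%r: f32):
      return %r : f32
    }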
+ for (auto &newBlock : newBlocks) + interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); + } + + // Splice the instructions of the inlined entry block into the insert block. + insertBlock->getOperations().splice(insertBlock->end(), + firstNewBlock->getOperations()); + firstNewBlock->erase(); + return success(); +} + +/// This function is an overload of the above 'inlineRegion' that allows for +/// providing the set of operands ('inlinedOperands') that should be used +/// in-favor of the region arguments when inlining. +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + ArrayRef<Value> inlinedOperands, + ArrayRef<Value> resultsToReplace, + Optional<Location> inlineLoc, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + + auto *entryBlock = &src->front(); + if (inlinedOperands.size() != entryBlock->getNumArguments()) + return failure(); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { + // Verify that the types of the provided values match the function argument + // types. + BlockArgument regionArg = entryBlock->getArgument(i); + if (inlinedOperands[i]->getType() != regionArg->getType()) + return failure(); + mapper.map(regionArg, inlinedOperands[i]); + } + + // Call into the main region inliner function. + return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, + inlineLoc, shouldCloneInlinedRegion); +} + +/// Utility function used to generate a cast operation from the given interface, +/// or return nullptr if a cast could not be generated. +static Value materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl<Operation *> &castOps, + OpBuilder &castBuilder, Value arg, Type type, + Location conversionLoc) { + if (!interface) + return nullptr; + + // Check to see if the interface for the call can materialize a conversion. + Operation *castOp = interface->materializeCallConversion(castBuilder, arg, + type, conversionLoc); + if (!castOp) + return nullptr; + castOps.push_back(castOp); + + // Ensure that the generated cast is correct. + assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && + castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); + return castOp->getResult(0); +} + +/// This function inlines a given region, 'src', of a callable operation, +/// 'callable', into the location defined by the given call operation. This +/// function returns failure if inlining is not possible, success otherwise. On +/// failure, no changes are made to the module. 'shouldCloneInlinedRegion' +/// corresponds to whether the source region should be cloned into the 'call' or +/// spliced directly. +LogicalResult mlir::inlineCall(InlinerInterface &interface, + CallOpInterface call, + CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + auto *entryBlock = &src->front(); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + + // Make sure that the number of arguments and results matchup between the call + // and the region. 
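When the call's types do not line up exactly with the callee region, the interface gets a chance to bridge them (a sketch; "mydialect.cast" is a made-up op standing in for whatever materializeCallConversion returns):

    // call operand         : %arg of type tensor<*xf32>
    // callee block argument: tensor<4xf32>
    %cast = "mydialect.cast"(%arg) : (tensor<*xf32>) -> tensor<4xf32>

It is %cast that gets mapped to the block argument; if any required cast cannot be materialized, the casts built so far are undone and the inlining is abandoned without changing the IR.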
+ SmallVector<Value, 8> callOperands(call.getArgOperands()); + SmallVector<Value, 8> callResults(call.getOperation()->getResults()); + if (callOperands.size() != entryBlock->getNumArguments() || + callResults.size() != callableResultTypes.size()) + return failure(); + + // A set of cast operations generated to matchup the signature of the region + // with the signature of the call. + SmallVector<Operation *, 4> castOps; + castOps.reserve(callOperands.size() + callResults.size()); + + // Functor used to cleanup generated state on failure. + auto cleanupState = [&] { + for (auto *op : castOps) { + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + } + return failure(); + }; + + // Builder used for any conversion operations that need to be materialized. + OpBuilder castBuilder(call); + Location castLoc = call.getLoc(); + auto *callInterface = interface.getInterfaceFor(call.getDialect()); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { + BlockArgument regionArg = entryBlock->getArgument(i); + Value operand = callOperands[i]; + + // If the call operand doesn't match the expected region argument, try to + // generate a cast. + Type regionArgType = regionArg->getType(); + if (operand->getType() != regionArgType) { + if (!(operand = materializeConversion(callInterface, castOps, castBuilder, + operand, regionArgType, castLoc))) + return cleanupState(); + } + mapper.map(regionArg, operand); + } + + // Ensure that the resultant values of the call, match the callable. + castBuilder.setInsertionPointAfter(call); + for (unsigned i = 0, e = callResults.size(); i != e; ++i) { + Value callResult = callResults[i]; + if (callResult->getType() == callableResultTypes[i]) + continue; + + // Generate a conversion that will produce the original type, so that the IR + // is still valid after the original call gets replaced. + Value castResult = + materializeConversion(callInterface, castOps, castBuilder, callResult, + callResult->getType(), castLoc); + if (!castResult) + return cleanupState(); + callResult->replaceAllUsesWith(castResult); + castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); + } + + // Attempt to inline the call. + if (failed(inlineRegion(interface, src, call, mapper, callResults, + call.getLoc(), shouldCloneInlinedRegion))) + return cleanupState(); + return success(); +} diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp new file mode 100644 index 00000000000..b0d9fdf5fd8 --- /dev/null +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -0,0 +1,480 @@ +//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop fusion transformation utility functions. 
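As a preview of what these utilities compute (a sketch with hypothetical memrefs %a, %b, %m), an op sitting between two loop nests that reads a memref written by the first nest is exactly the kind of dependence the helpers below detect, since it constrains where a fused nest may legally be placed:

    affine.for %i = 0 to 10 {                  // opA: writes %m
      %v = affine.load %a[%i] : memref<10xf32>
      affine.store %v, %m[%i] : memref<10xf32>
    }
    %x = affine.load %m[0] : memref<10xf32>    // depends on opA's stores
    "foo.use"(%x) : (f32) -> ()
    affine.for %j = 0 to 10 {                  // opB
      %w = affine.load %m[%j] : memref<10xf32>
      affine.store %w, %b[%j] : memref<10xf32>
    }

Here getFirstDependentOpInRange(opA, opB) returns the standalone load of %m.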
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopFusionUtils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "loop-fusion-utils"
+
+using namespace mlir;
+
+// Gathers all load and store memref accesses in 'opA' into 'values', where
+// 'values[memref] == true' for each store operation.
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+                                          DenseMap<Value, bool> &values) {
+  opA->walk([&](Operation *op) {
+    if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+      if (values.count(loadOp.getMemRef()) == 0)
+        values[loadOp.getMemRef()] = false;
+    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+      values[storeOp.getMemRef()] = true;
+    }
+  });
+}
+
+// Returns true if 'op' is a load or store operation that accesses a memref
+// in 'values', and at least one of the accesses is a store operation.
+// Returns false otherwise.
+static bool isDependentLoadOrStoreOp(Operation *op,
+                                     DenseMap<Value, bool> &values) {
+  if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+    return values.count(loadOp.getMemRef()) > 0 &&
+           values[loadOp.getMemRef()] == true;
+  } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+    return values.count(storeOp.getMemRef()) > 0;
+  }
+  return false;
+}
+
+// Returns the first operation in range ('opA', 'opB') which has a data
+// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
+static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
+  // Record memref values from all loads/stores in loop nest rooted at 'opA'.
+  // Map from memref value to bool which is true if store, false otherwise.
+  DenseMap<Value, bool> values;
+  getLoadAndStoreMemRefAccesses(opA, values);
+
+  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
+  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
+  // and at least one of the accesses is a store).
+  Operation *firstDepOp = nullptr;
+  for (Block::iterator it = std::next(Block::iterator(opA));
+       it != Block::iterator(opB); ++it) {
+    Operation *opX = &(*it);
+    opX->walk([&](Operation *op) {
+      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
+        firstDepOp = opX;
+    });
+    if (firstDepOp)
+      break;
+  }
+  return firstDepOp;
+}
+
+// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
+// exists a data dependence from 'opX' to 'opB'.
+// Returns 'nullptr' if no dependence exists.
+static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
+  // Record memref values from all loads/stores in loop nest rooted at 'opB'.
+  // Map from memref value to bool which is true if store, false otherwise.
+  DenseMap<Value, bool> values;
+  getLoadAndStoreMemRefAccesses(opB, values);
+
+  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
+  // check if there is a data dependence from 'opX' to 'opB':
+  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
+  //    is a store. 
+ // *) 'opX' produces an SSA Value which is used by 'opB'. + Operation *lastDepOp = nullptr; + for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); + it != Block::reverse_iterator(opA); ++it) { + Operation *opX = &(*it); + opX->walk([&](Operation *op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { + if (isDependentLoadOrStoreOp(op, values)) { + lastDepOp = opX; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + for (auto value : op->getResults()) { + for (auto user : value->getUsers()) { + SmallVector<AffineForOp, 4> loops; + // Check if any loop in loop nest surrounding 'user' is 'opB'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { + lastDepOp = opX; + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + if (lastDepOp) + break; + } + return lastDepOp; +} + +// Computes and returns an insertion point operation, before which the +// the fused <srcForOp, dstForOp> loop nest can be inserted while preserving +// dependences. Returns nullptr if no such insertion point is found. +static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, + AffineForOp dstForOp) { + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + auto *firstDepOpA = + getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + auto *lastDepOpB = + getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + // Block: + // ... + // |-- opA + // | ... + // | lastDepOpB --| + // | ... | + // |-> firstDepOpA | + // ... | + // opB <--------- + // + // Valid insertion point range: (lastDepOpB, firstDepOpA) + // + if (firstDepOpA != nullptr) { + if (lastDepOpB != nullptr) { + if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) + // No valid insertion point exists which preserves dependences. + return nullptr; + } + // Return insertion point in valid range closest to 'opB'. + // TODO(andydavis) Consider other insertion points in valid range. + return firstDepOpA; + } + // No dependences from 'opA' to operation in range ('opA', 'opB'), return + // 'opB' insertion point. + return forOpB.getOperation(); +} + +// Gathers all load and store ops in loop nest rooted at 'forOp' into +// 'loadAndStoreOps'. +static bool +gatherLoadsAndStores(AffineForOp forOp, + SmallVectorImpl<Operation *> &loadAndStoreOps) { + bool hasIfOp = false; + forOp.walk([&](Operation *op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) + loadAndStoreOps.push_back(op); + else if (isa<AffineIfOp>(op)) + hasIfOp = true; + }); + return !hasIfOp; +} + +// TODO(andydavis) Prevent fusion of loop nests with side-effecting operations. +FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice) { + // Return 'failure' if 'dstLoopDepth == 0'. + if (dstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); + return FusionResult::FailPrecondition; + } + // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. 
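// Illustrative sketch (shapes and names are hypothetical): a typical
// producer/consumer pair analyzed here, where %A is written by the first
// loop and read by the second:
//
//   affine.for %i = 0 to 100 {
//     %v = affine.load %B[%i] : memref<100xf32>
//     affine.store %v, %A[%i] : memref<100xf32>
//   }
//   affine.for %j = 0 to 100 {
//     %w = affine.load %A[%j] : memref<100xf32>
//     "use"(%w) : (f32) -> ()
//   }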
+ auto *block = srcForOp.getOperation()->getBlock(); + if (block != dstForOp.getOperation()->getBlock()) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); + return FusionResult::FailPrecondition; + } + + // Return 'failure' if no valid insertion point for fused loop nest in 'block' + // exists which would preserve dependences. + if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); + return FusionResult::FailBlockDependence; + } + + // Check if 'srcForOp' precedes 'dstForOp' in 'block'. + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + // 'forOpA' executes before 'forOpB' in 'block'. + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. + SmallVector<Operation *, 4> opsA; + if (!gatherLoadsAndStores(forOpA, opsA)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. + SmallVector<Operation *, 4> opsB; + if (!gatherLoadsAndStores(forOpB, opsB)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. + unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( + *srcForOp.getOperation(), *dstForOp.getOperation()); + + // Compute union of computation slices computed between all pairs of ops + // from 'forOpA' and 'forOpB'. + if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops, + isSrcForOpBeforeDstForOp, srcSlice))) { + LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + return FusionResult::FailPrecondition; + } + + return FusionResult::Success; +} + +/// Collect loop nest statistics (eg. loop trip count and operation count) +/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, +/// returns false otherwise. +bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { + auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { + auto *childForOp = forOp.getOperation(); + auto *parentForOp = forOp.getParentOp(); + if (!llvm::isa<FuncOp>(parentForOp)) { + if (!isa<AffineForOp>(parentForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); + return WalkResult::interrupt(); + } + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentForOp].push_back(forOp); + } + + // Record the number of op operations in the body of 'forOp'. + unsigned count = 0; + stats->opCountMap[childForOp] = 0; + for (auto &op : *forOp.getBody()) { + if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op)) + ++count; + } + stats->opCountMap[childForOp] = count; + + // Record trip count for 'forOp'. Set flag if trip count is not + // constant. + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); + if (!maybeConstTripCount.hasValue()) { + // Currently only constant trip count loop nests are supported. 
+ LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); + return WalkResult::interrupt(); + } + + stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + +// Computes the total cost of the loop nest rooted at 'forOp'. +// Currently, the total cost is computed by counting the total operation +// instance count (i.e. total number of operations in the loop bodyloop +// operation count * loop trip count) for the entire loop nest. +// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops +// specified in the map when computing the total op instance count. +// NOTEs: 1) This is used to compute the cost of computation slices, which are +// sliced along the iteration dimension, and thus reduce the trip count. +// If 'computeCostMap' is non-null, the total op count for forOps specified +// in the map is increased (not overridden) by adding the op count from the +// map to the existing op count for the for loop. This is done before +// multiplying by the loop's trip count, and is used to model the cost of +// inserting a sliced loop nest of known cost into the loop's body. +// 2) This is also used to compute the cost of fusing a slice of some loop nest +// within another loop. +static int64_t getComputeCostHelper( + Operation *forOp, LoopNestStats &stats, + llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, + DenseMap<Operation *, int64_t> *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forOp' body, + // minus terminator op which is a no-op. + int64_t opCount = stats.opCountMap[forOp] - 1; + if (stats.loopMap.count(forOp) > 0) { + for (auto childForOp : stats.loopMap[forOp]) { + opCount += getComputeCostHelper(childForOp.getOperation(), stats, + tripCountOverrideMap, computeCostMap); + } + } + // Add in additional op instances from slice (if specified in map). + if (computeCostMap != nullptr) { + auto it = computeCostMap->find(forOp); + if (it != computeCostMap->end()) { + opCount += it->second; + } + } + // Override trip count (if specified in map). + int64_t tripCount = stats.tripCountMap[forOp]; + if (tripCountOverrideMap != nullptr) { + auto it = tripCountOverrideMap->find(forOp); + if (it != tripCountOverrideMap->end()) { + tripCount = it->second; + } + } + // Returns the total number of dynamic instances of operations in loop body. + return tripCount * opCount; +} + +// TODO(andydavis,b/126426796): extend this to handle multiple result maps. +static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) { + assert(lbMap.getNumResults() == 1 && "expected single result bound map"); + assert(ubMap.getNumResults() == 1 && "expected single result bound map"); + assert(lbMap.getNumDims() == ubMap.getNumDims()); + assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); + AffineExpr lbExpr(lbMap.getResult(0)); + AffineExpr ubExpr(ubMap.getResult(0)); + auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), + lbMap.getNumSymbols()); + auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>(); + if (!cExpr) + return None; + return cExpr.getValue(); +} + +// Return the number of iterations in the given slice. 
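// For instance (illustrative), a slice whose loops have constant trip counts
// 4 and 16 spans 4 * 16 = 64 iterations. Together with getComputeCostHelper
// above, a nest such as
//   affine.for %i = 0 to 100 { affine.for %j = 0 to 50 { <2 ops> } }
// has a compute cost of 100 * 50 * 2 = 10000 dynamic operation instances
// (loop terminators excluded).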
+static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + +// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop +// nest surrounding represented by slice loop bounds in 'slice'. +// Returns true on success, false otherwise (if a non-constant trip count +// was encountered). +// TODO(andydavis) Make this work with non-unit step loops. +static bool buildSliceTripCountMap( + ComputationSliceState *slice, + llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) { + unsigned numSrcLoopIVs = slice->ivs.size(); + // Populate map from AffineForOp -> trip count + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]); + auto *op = forOp.getOperation(); + AffineMap lbMap = slice->lbs[i]; + AffineMap ubMap = slice->ubs[i]; + if (lbMap == AffineMap() || ubMap == AffineMap()) { + // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. + if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { + (*tripCountMap)[op] = + forOp.getConstantUpperBound() - forOp.getConstantLowerBound(); + continue; + } + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); + if (maybeConstTripCount.hasValue()) { + (*tripCountMap)[op] = maybeConstTripCount.getValue(); + continue; + } + return false; + } + Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap); + // Slice bounds are created with a constant ub - lb difference. + if (!tripCount.hasValue()) + return false; + (*tripCountMap)[op] = tripCount.getValue(); + } + return true; +} + +/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. +/// Currently, the total cost is computed by counting the total operation +/// instance count (i.e. total number of operations in the loop body * loop +/// trip count) for the entire loop nest. +int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { + return getComputeCostHelper(forOp.getOperation(), stats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); +} + +/// Computes and returns in 'computeCost', the total compute cost of fusing the +/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, +/// the total cost is computed by counting the total operation instance count +/// (i.e. total number of operations in the loop body * loop trip count) for +/// the entire loop nest. +bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, + AffineForOp dstForOp, LoopNestStats &dstStats, + ComputationSliceState *slice, + int64_t *computeCost) { + llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; + DenseMap<Operation *, int64_t> computeCostMap; + + // Build trip count map for computation slice. + if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) + return false; + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + assert(sliceIterationCount > 0); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + auto *insertPointParent = slice->insertPoint->getParentOp(); + + // The store and loads to this memref will disappear. + // TODO(andydavis) Add load coalescing to memref data flow opt pass. 
+ if (storeLoadFwdGuaranteed) { + // Subtract from operation count the loads/store we expect load/store + // forwarding to remove. + unsigned storeCount = 0; + llvm::SmallDenseSet<Value, 4> storeMemrefs; + srcForOp.walk([&](Operation *op) { + if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { + storeMemrefs.insert(storeOp.getMemRef()); + ++storeCount; + } + }); + // Subtract out any store ops in single-iteration src slice loop nest. + if (storeCount > 0) + computeCostMap[insertPointParent] = -storeCount; + // Subtract out any load users of 'storeMemrefs' nested below + // 'insertPointParent'. + for (auto value : storeMemrefs) { + for (auto *user : value->getUsers()) { + if (auto loadOp = dyn_cast<AffineLoadOp>(user)) { + SmallVector<AffineForOp, 4> loops; + // Check if any loop in loop nest surrounding 'user' is + // 'insertPointParent'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { + if (auto forOp = + dyn_cast_or_null<AffineForOp>(user->getParentOp())) { + if (computeCostMap.count(forOp) == 0) + computeCostMap[forOp] = 0; + computeCostMap[forOp] -= 1; + } + } + } + } + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = getComputeCostHelper( + srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); + + // Compute cost of fusion for this depth. + computeCostMap[insertPointParent] = sliceComputeCost; + + *computeCost = + getComputeCostHelper(dstForOp.getOperation(), dstStats, + /*tripCountOverrideMap=*/nullptr, &computeCostMap); + return true; +} diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp new file mode 100644 index 00000000000..0fece54132a --- /dev/null +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -0,0 +1,1779 @@ +//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous loop transformation routines. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/LoopUtils.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Function.h" +#include "mlir/Transforms/RegionUtils.h" +#include "mlir/Transforms/Utils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "LoopUtils" + +using namespace mlir; +using llvm::SetVector; +using llvm::SmallMapVector; + +/// Computes the cleanup loop lower bound of the loop being unrolled with +/// the specified unroll factor; this bound will also be upper bound of the main +/// part of the unrolled loop. Computes the bound as an AffineMap with its +/// operands or a null map when the trip count can't be expressed as an affine +/// expression. 
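/// For example (illustrative): for 'affine.for %i = 0 to 9' unrolled by a
/// factor of 2, the trip count is 9 and the cleanup lower bound evaluates to
/// 0 + (9 - 9 mod 2) = 8; the unrolled main loop then covers [0, 8) with
/// step 2 and the cleanup loop covers the remaining iteration [8, 9).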
+void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, + AffineMap *map, + SmallVectorImpl<Value> *operands, + OpBuilder &b) { + auto lbMap = forOp.getLowerBoundMap(); + + // Single result lower bound map only. + if (lbMap.getNumResults() != 1) { + *map = AffineMap(); + return; + } + + AffineMap tripCountMap; + SmallVector<Value, 4> tripCountOperands; + buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands); + + // Sometimes the trip count cannot be expressed as an affine expression. + if (!tripCountMap) { + *map = AffineMap(); + return; + } + + unsigned step = forOp.getStep(); + auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap, + forOp.getLowerBoundOperands()); + + // For each upper bound expr, get the range. + // Eg: affine.for %i = lb to min (ub1, ub2), + // where tripCountExprs yield (tr1, tr2), we create affine.apply's: + // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all + // these affine.apply's make up the cleanup loop lower bound. + SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults()); + SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults()); + for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) { + auto tripCountExpr = tripCountMap.getResult(i); + bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step; + auto bumpMap = AffineMap::get(tripCountMap.getNumDims(), + tripCountMap.getNumSymbols(), bumpExprs[i]); + bumpValues[i] = + b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands); + } + + SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults()); + for (unsigned i = 0, e = bumpExprs.size(); i < e; i++) + newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1); + + operands->clear(); + operands->push_back(lb); + operands->append(bumpValues.begin(), bumpValues.end()); + *map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs); + // Simplify the map + operands. + fullyComposeAffineMapAndOperands(map, operands); + *map = simplifyAffineMap(*map); + canonicalizeMapAndOperands(map, operands); + // Remove any affine.apply's that became dead from the simplification above. + for (auto v : bumpValues) { + if (v->use_empty()) { + v->getDefiningOp()->erase(); + } + } + if (lb.use_empty()) + lb.erase(); +} + +/// Promotes the loop body of a forOp to its containing block if the forOp +/// was known to have a single iteration. +// TODO(bondhugula): extend this for arbitrary affine bounds. +LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { + Optional<uint64_t> tripCount = getConstantTripCount(forOp); + if (!tripCount.hasValue() || tripCount.getValue() != 1) + return failure(); + + // TODO(mlir-team): there is no builder for a max. + if (forOp.getLowerBoundMap().getNumResults() != 1) + return failure(); + + // Replaces all IV uses to its single iteration value. + auto iv = forOp.getInductionVar(); + Operation *op = forOp.getOperation(); + if (!iv->use_empty()) { + if (forOp.hasConstantLowerBound()) { + OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody()); + auto constOp = topBuilder.create<ConstantIndexOp>( + forOp.getLoc(), forOp.getConstantLowerBound()); + iv->replaceAllUsesWith(constOp); + } else { + AffineBound lb = forOp.getLowerBound(); + SmallVector<Value, 4> lbOperands(lb.operand_begin(), lb.operand_end()); + OpBuilder builder(op->getBlock(), Block::iterator(op)); + if (lb.getMap() == builder.getDimIdentityMap()) { + // No need of generating an affine.apply. 
+ iv->replaceAllUsesWith(lbOperands[0]); + } else { + auto affineApplyOp = builder.create<AffineApplyOp>( + op->getLoc(), lb.getMap(), lbOperands); + iv->replaceAllUsesWith(affineApplyOp); + } + } + } + // Move the loop body operations, except for terminator, to the loop's + // containing block. + auto *block = op->getBlock(); + forOp.getBody()->getOperations().back().erase(); + block->getOperations().splice(Block::iterator(op), + forOp.getBody()->getOperations()); + forOp.erase(); + return success(); +} + +/// Promotes all single iteration for op's in the FuncOp, i.e., moves +/// their body into the containing Block. +void mlir::promoteSingleIterationLoops(FuncOp f) { + // Gathers all innermost loops through a post order pruned walk. + f.walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); +} + +/// Generates a 'affine.for' op with the specified lower and upper bounds +/// while generating the right IV remappings for the shifted operations. The +/// operation blocks that go into the loop are specified in instGroupQueue +/// starting from the specified offset, and in that order; the first element of +/// the pair specifies the shift applied to that group of operations; note +/// that the shift is multiplied by the loop step before being applied. Returns +/// nullptr if the generated loop simplifies to a single iteration one. +static AffineForOp +generateLoop(AffineMap lbMap, AffineMap ubMap, + const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> + &instGroupQueue, + unsigned offset, AffineForOp srcForInst, OpBuilder b) { + SmallVector<Value, 4> lbOperands(srcForInst.getLowerBoundOperands()); + SmallVector<Value, 4> ubOperands(srcForInst.getUpperBoundOperands()); + + assert(lbMap.getNumInputs() == lbOperands.size()); + assert(ubMap.getNumInputs() == ubOperands.size()); + + auto loopChunk = + b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands, + ubMap, srcForInst.getStep()); + auto loopChunkIV = loopChunk.getInductionVar(); + auto srcIV = srcForInst.getInductionVar(); + + BlockAndValueMapping operandMap; + + OpBuilder bodyBuilder = loopChunk.getBodyBuilder(); + for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end(); + it != e; ++it) { + uint64_t shift = it->first; + auto insts = it->second; + // All 'same shift' operations get added with their operands being + // remapped to results of cloned operations, and their IV used remapped. + // Generate the remapping if the shift is not zero: remappedIV = newIV - + // shift. + if (!srcIV->use_empty() && shift != 0) { + auto ivRemap = bodyBuilder.create<AffineApplyOp>( + srcForInst.getLoc(), + bodyBuilder.getSingleDimShiftAffineMap( + -static_cast<int64_t>(srcForInst.getStep() * shift)), + loopChunkIV); + operandMap.map(srcIV, ivRemap); + } else { + operandMap.map(srcIV, loopChunkIV); + } + for (auto *op : insts) { + if (!isa<AffineTerminatorOp>(op)) + bodyBuilder.clone(*op, operandMap); + } + }; + if (succeeded(promoteIfSingleIteration(loopChunk))) + return AffineForOp(); + return loopChunk; +} + +/// Skew the operations in the body of a 'affine.for' operation with the +/// specified operation-wise shifts. The shifts are with respect to the +/// original execution order, and are multiplied by the loop 'step' before being +/// applied. A shift of zero for each operation will lead to no change. 
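// Illustrative example (schematic; 'S1' and 'S2' are placeholders): with a
// two-operation body and shifts [0, 1],
//   affine.for %i = 0 to N { S1(%i); S2(%i) }
// is rewritten so that, modulo the generated prologue and epilogue loops, a
// steady-state loop executes { S1(%i); S2(%i - 1) }, the classic
// software-pipelining shape that lets S2 of one iteration overlap S1 of the
// next.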
+// The skewing of operations with respect to one another can be used for +// example to allow overlap of asynchronous operations (such as DMA +// communication) with computation, or just relative shifting of operations +// for better register reuse, locality or parallelism. As such, the shifts are +// typically expected to be at most of the order of the number of operations. +// This method should not be used as a substitute for loop distribution/fission. +// This method uses an algorithm// in time linear in the number of operations +// in the body of the for loop - (using the 'sweep line' paradigm). This method +// asserts preservation of SSA dominance. A check for that as well as that for +// memory-based dependence preservation check rests with the users of this +// method. +LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts, + bool unrollPrologueEpilogue) { + if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) + return success(); + + // If the trip counts aren't constant, we would need versioning and + // conditional guards (or context information to prevent such versioning). The + // better way to pipeline for such loops is to first tile them and extract + // constant trip count "full tiles" before applying this. + auto mayBeConstTripCount = getConstantTripCount(forOp); + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled")); + return success(); + } + uint64_t tripCount = mayBeConstTripCount.getValue(); + + assert(isInstwiseShiftValid(forOp, shifts) && + "shifts will lead to an invalid transformation\n"); + + int64_t step = forOp.getStep(); + + unsigned numChildInsts = forOp.getBody()->getOperations().size(); + + // Do a linear time (counting) sort for the shifts. + uint64_t maxShift = 0; + for (unsigned i = 0; i < numChildInsts; i++) { + maxShift = std::max(maxShift, shifts[i]); + } + // Such large shifts are not the typical use case. + if (maxShift >= numChildInsts) { + forOp.emitWarning("not shifting because shifts are unrealistically large"); + return success(); + } + + // An array of operation groups sorted by shift amount; each group has all + // operations with the same shift in the order in which they appear in the + // body of the 'affine.for' op. + std::vector<std::vector<Operation *>> sortedInstGroups(maxShift + 1); + unsigned pos = 0; + for (auto &op : *forOp.getBody()) { + auto shift = shifts[pos++]; + sortedInstGroups[shift].push_back(&op); + } + + // Unless the shifts have a specific pattern (which actually would be the + // common use case), prologue and epilogue are not meaningfully defined. + // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first + // loop generated as the prologue and the last as epilogue and unroll these + // fully. + AffineForOp prologue; + AffineForOp epilogue; + + // Do a sweep over the sorted shifts while storing open groups in a + // vector, and generating loop portions as necessary during the sweep. A block + // of operations is paired with its shift. + std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> instGroupQueue; + + auto origLbMap = forOp.getLowerBoundMap(); + uint64_t lbShift = 0; + OpBuilder b(forOp.getOperation()); + for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { + // If nothing is shifted by d, continue. 
+ if (sortedInstGroups[d].empty()) + continue; + if (!instGroupQueue.empty()) { + assert(d >= 1 && + "Queue expected to be empty when the first block is found"); + // The interval for which the loop needs to be generated here is: + // [lbShift, min(lbShift + tripCount, d)) and the body of the + // loop needs to have all operations in instQueue in that order. + AffineForOp res; + if (lbShift + tripCount * step < d * step) { + res = generateLoop( + b.getShiftedAffineMap(origLbMap, lbShift), + b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), + instGroupQueue, 0, forOp, b); + // Entire loop for the queued op groups generated, empty it. + instGroupQueue.clear(); + lbShift += tripCount * step; + } else { + res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), + b.getShiftedAffineMap(origLbMap, d), instGroupQueue, + 0, forOp, b); + lbShift = d * step; + } + if (!prologue && res) + prologue = res; + epilogue = res; + } else { + // Start of first interval. + lbShift = d * step; + } + // Augment the list of operations that get into the current open interval. + instGroupQueue.push_back({d, sortedInstGroups[d]}); + } + + // Those operations groups left in the queue now need to be processed (FIFO) + // and their loops completed. + for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) { + uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; + epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), + b.getShiftedAffineMap(origLbMap, ubShift), + instGroupQueue, i, forOp, b); + lbShift = ubShift; + if (!prologue) + prologue = epilogue; + } + + // Erase the original for op. + forOp.erase(); + + if (unrollPrologueEpilogue && prologue) + loopUnrollFull(prologue); + if (unrollPrologueEpilogue && !epilogue && + epilogue.getOperation() != prologue.getOperation()) + loopUnrollFull(epilogue); + + return success(); +} + +// Collect perfectly nested loops starting from `rootForOps`. Loops are +// perfectly nested if each loop is the first and only non-terminator operation +// in the parent loop. Collect at most `maxLoops` loops and append them to +// `forOps`. +template <typename T> +void getPerfectlyNestedLoopsImpl( + SmallVectorImpl<T> &forOps, T rootForOp, + unsigned maxLoops = std::numeric_limits<unsigned>::max()) { + for (unsigned i = 0; i < maxLoops; ++i) { + forOps.push_back(rootForOp); + Block &body = rootForOp.region().front(); + if (body.begin() != std::prev(body.end(), 2)) + return; + + rootForOp = dyn_cast<T>(&body.front()); + if (!rootForOp) + return; + } +} + +/// Get perfectly nested sequence of loops starting at root of loop nest +/// (the first op being another AffineFor, and the second op - a terminator). +/// A loop is perfectly nested iff: the first op in the loop's body is another +/// AffineForOp, and the second op is a terminator). +void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops, + AffineForOp root) { + getPerfectlyNestedLoopsImpl(nestedLoops, root); +} + +void mlir::getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops, + loop::ForOp root) { + getPerfectlyNestedLoopsImpl(nestedLoops, root); +} + +/// Unrolls this loop completely. 
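/// For example (illustrative), fully unrolling
///   affine.for %i = 0 to 3 {
///     "S"(%i) : (index) -> ()
///   }
/// conceptually yields the three body instances at %i = 0, 1 and 2 with the
/// loop removed (in the raw IR the remapped induction values appear as
/// constant/affine.apply ops until canonicalization cleans them up).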
+LogicalResult mlir::loopUnrollFull(AffineForOp forOp) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); + if (mayBeConstantTripCount.hasValue()) { + uint64_t tripCount = mayBeConstantTripCount.getValue(); + if (tripCount == 1) { + return promoteIfSingleIteration(forOp); + } + return loopUnrollByFactor(forOp, tripCount); + } + return failure(); +} + +/// Unrolls and jams this loop by the specified factor or by the trip count (if +/// constant) whichever is lower. +LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp, + uint64_t unrollFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); + + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollFactor) + return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forOp, unrollFactor); +} + +/// Unrolls this loop by the specified factor. Returns success if the loop +/// is successfully unrolled. +LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, + uint64_t unrollFactor) { + assert(unrollFactor >= 1 && "unroll factor should be >= 1"); + + if (unrollFactor == 1) + return promoteIfSingleIteration(forOp); + + if (forOp.getBody()->empty() || + forOp.getBody()->begin() == std::prev(forOp.getBody()->end())) + return failure(); + + // Loops where the lower bound is a max expression isn't supported for + // unrolling since the trip count can be expressed as an affine function when + // both the lower bound and the upper bound are multi-result maps. However, + // one meaningful way to do such unrolling would be to specialize the loop for + // the 'hotspot' case and unroll that hotspot. + if (forOp.getLowerBoundMap().getNumResults() != 1) + return failure(); + + // If the trip count is lower than the unroll factor, no unrolled body. + // TODO(bondhugula): option to specify cleanup loop unrolling. + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollFactor) + return failure(); + + // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. + Operation *op = forOp.getOperation(); + if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { + OpBuilder builder(op->getBlock(), ++Block::iterator(op)); + auto cleanupForInst = cast<AffineForOp>(builder.clone(*op)); + AffineMap cleanupMap; + SmallVector<Value, 4> cleanupOperands; + getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands, + builder); + assert(cleanupMap && + "cleanup loop lower bound map for single result lower bound maps " + "can always be determined"); + cleanupForInst.setLowerBound(cleanupOperands, cleanupMap); + // Promote the loop body up if this has turned into a single iteration loop. + promoteIfSingleIteration(cleanupForInst); + + // Adjust upper bound of the original loop; this is the same as the lower + // bound of the cleanup loop. + forOp.setUpperBound(cleanupOperands, cleanupMap); + } + + // Scale the step of loop being unrolled by unroll factor. + int64_t step = forOp.getStep(); + forOp.setStep(step * unrollFactor); + + // Builder to insert unrolled bodies just before the terminator of the body of + // 'forOp'. + OpBuilder builder = forOp.getBodyBuilder(); + + // Keep a pointer to the last non-terminator operation in the original block + // so that we know what to clone (since we are doing this in-place). 
+ Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2); + + // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). + auto forOpIV = forOp.getInductionVar(); + for (unsigned i = 1; i < unrollFactor; i++) { + BlockAndValueMapping operandMap; + + // If the induction variable is used, create a remapping to the value for + // this unrolled instance. + if (!forOpIV->use_empty()) { + // iv' = iv + 1/2/3...unrollFactor-1; + auto d0 = builder.getAffineDimExpr(0); + auto bumpMap = AffineMap::get(1, 0, {d0 + i * step}); + auto ivUnroll = + builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV); + operandMap.map(forOpIV, ivUnroll); + } + + // Clone the original body of 'forOp'. + for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd); + it++) { + builder.clone(*it, operandMap); + } + } + + // Promote the loop body up if this has turned into a single iteration loop. + promoteIfSingleIteration(forOp); + return success(); +} + +/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is +/// nested within 'forOpA' as the only non-terminator operation in its block. +void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { + auto *forOpAInst = forOpA.getOperation(); + + assert(&*forOpA.getBody()->begin() == forOpB.getOperation()); + auto &forOpABody = forOpA.getBody()->getOperations(); + auto &forOpBBody = forOpB.getBody()->getOperations(); + + // 1) Splice forOpA's non-terminator operations (which is just forOpB) just + // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's + // body containing only the terminator. + forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst), + forOpABody, forOpABody.begin(), + std::prev(forOpABody.end())); + // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's + // body (this leaves forOpB's body containing only the terminator). + forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(), + std::prev(forOpBBody.end())); + // 3) Splice forOpA into the beginning of forOpB's body. + forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(), + Block::iterator(forOpAInst)); +} + +// Checks each dependence component against the permutation to see if the +// desired loop interchange would violate dependences by making the +// dependence component lexicographically negative. +static bool checkLoopInterchangeDependences( + const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec, + ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) { + // Invert permutation map. + unsigned maxLoopDepth = loops.size(); + SmallVector<unsigned, 4> loopPermMapInv; + loopPermMapInv.resize(maxLoopDepth); + for (unsigned i = 0; i < maxLoopDepth; ++i) + loopPermMapInv[loopPermMap[i]] = i; + + // Check each dependence component against the permutation to see if the + // desired loop interchange permutation would make the dependence vectors + // lexicographically negative. + // Example 1: [-1, 1][0, 0] + // Example 2: [0, 0][-1, 1] + for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { + const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; + assert(depComps.size() >= maxLoopDepth); + // Check if the first non-zero dependence component is positive. + // This iterates through loops in the desired order. 
+ for (unsigned j = 0; j < maxLoopDepth; ++j) { + unsigned permIndex = loopPermMapInv[j]; + assert(depComps[permIndex].lb.hasValue()); + int64_t depCompLb = depComps[permIndex].lb.getValue(); + if (depCompLb > 0) + break; + if (depCompLb < 0) + return false; + } + } + return true; +} + +/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly +/// nested sequence of loops in 'loops' would violate dependences. +bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops, + ArrayRef<unsigned> loopPermMap) { + // Gather dependence components for dependences between all ops in loop nest + // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. + assert(loopPermMap.size() == loops.size()); + unsigned maxLoopDepth = loops.size(); + std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; + getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); + return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap); +} + +/// Performs a sequence of loop interchanges of loops in perfectly nested +/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'. +unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops, + ArrayRef<unsigned> loopPermMap) { + Optional<unsigned> loopNestRootIndex; + for (int i = loops.size() - 1; i >= 0; --i) { + int permIndex = static_cast<int>(loopPermMap[i]); + // Store the index of the for loop which will be the new loop nest root. + if (permIndex == 0) + loopNestRootIndex = i; + if (permIndex > i) { + // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest. + sinkLoop(loops[i], permIndex - i); + } + } + assert(loopNestRootIndex.hasValue()); + return loopNestRootIndex.getValue(); +} + +// Sinks all sequential loops to the innermost levels (while preserving +// relative order among them) and moves all parallel loops to the +// outermost (while again preserving relative order among them). +AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) { + SmallVector<AffineForOp, 4> loops; + getPerfectlyNestedLoops(loops, forOp); + if (loops.size() < 2) + return forOp; + + // Gather dependence components for dependences between all ops in loop nest + // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. + unsigned maxLoopDepth = loops.size(); + std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; + getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); + + // Mark loops as either parallel or sequential. + SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true); + for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { + SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; + assert(depComps.size() >= maxLoopDepth); + for (unsigned j = 0; j < maxLoopDepth; ++j) { + DependenceComponent &depComp = depComps[j]; + assert(depComp.lb.hasValue() && depComp.ub.hasValue()); + if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0) + isParallelLoop[j] = false; + } + } + + // Count the number of parallel loops. + unsigned numParallelLoops = 0; + for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i) + if (isParallelLoop[i]) + ++numParallelLoops; + + // Compute permutation of loops that sinks sequential loops (and thus raises + // parallel loops) while preserving relative order. 
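// Illustrative example (schematic): if the %i loop carries a dependence and
// the %j loop does not, the permutation computed below turns
//   affine.for %i = 0 to 10 {
//     affine.for %j = 0 to 20 { ... }
//   }
// into
//   affine.for %j = 0 to 20 {
//     affine.for %i = 0 to 10 { ... }
//   }
// which is accepted only because no dependence component becomes
// lexicographically negative under this permutation.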
+ SmallVector<unsigned, 4> loopPermMap(maxLoopDepth); + unsigned nextSequentialLoop = numParallelLoops; + unsigned nextParallelLoop = 0; + for (unsigned i = 0; i < maxLoopDepth; ++i) { + if (isParallelLoop[i]) { + loopPermMap[i] = nextParallelLoop++; + } else { + loopPermMap[i] = nextSequentialLoop++; + } + } + + // Check if permutation 'loopPermMap' would violate dependences. + if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap)) + return forOp; + // Perform loop interchange according to permutation 'loopPermMap'. + unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap); + return loops[loopNestRootIndex]; +} + +/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels +/// deeper in the loop nest. +void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) { + for (unsigned i = 0; i < loopDepth; ++i) { + AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front()); + interchangeLoops(forOp, nextForOp); + } +} + +// Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the +// lower (resp. upper) loop bound. When called for both the lower and upper +// bounds, the resulting IR resembles: +// +// ```mlir +// affine.for %i = max (`iv, ...) to min (`iv` + `offset`) { +// ... +// } +// ``` +static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map, + SmallVector<Value, 4> *operands, + int64_t offset = 0) { + auto bounds = llvm::to_vector<4>(map->getResults()); + bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset); + operands->insert(operands->begin() + map->getNumDims(), iv); + *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds); + canonicalizeMapAndOperands(map, operands); +} + +// Stripmines `forOp` by `factor` and sinks it under each of the `targets`. +// Stripmine-sink is a primitive building block for generalized tiling of +// imperfectly nested loops. +// This transformation is purely mechanical and does not check legality, +// profitability or even structural correctness. It is the user's +// responsibility to specify `targets` that are dominated by `forOp`. +// Returns the new AffineForOps, one per `targets`, nested immediately under +// each of the `targets`. +static SmallVector<AffineForOp, 8> +stripmineSink(AffineForOp forOp, uint64_t factor, + ArrayRef<AffineForOp> targets) { + auto originalStep = forOp.getStep(); + auto scaledStep = originalStep * factor; + forOp.setStep(scaledStep); + + auto *op = forOp.getOperation(); + OpBuilder b(op->getBlock(), ++Block::iterator(op)); + + // Lower-bound map creation. + auto lbMap = forOp.getLowerBoundMap(); + SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); + augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands); + + // Upper-bound map creation. + auto ubMap = forOp.getUpperBoundMap(); + SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); + augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands, + /*offset=*/scaledStep); + + auto iv = forOp.getInductionVar(); + SmallVector<AffineForOp, 8> innerLoops; + for (auto t : targets) { + // Insert newForOp before the terminator of `t`. + OpBuilder b = t.getBodyBuilder(); + auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap, + ubOperands, ubMap, originalStep); + auto begin = t.getBody()->begin(); + // Skip terminator and `newForOp` which is just before the terminator. 
+ auto nOps = t.getBody()->getOperations().size() - 2; + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + t.getBody()->getOperations(), begin, std::next(begin, nOps)); + replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(), + newForOp.region()); + innerLoops.push_back(newForOp); + } + + return innerLoops; +} + +static Loops stripmineSink(loop::ForOp forOp, Value factor, + ArrayRef<loop::ForOp> targets) { + auto originalStep = forOp.step(); + auto iv = forOp.getInductionVar(); + + OpBuilder b(forOp); + forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor)); + + Loops innerLoops; + for (auto t : targets) { + // Save information for splicing ops out of t when done + auto begin = t.getBody()->begin(); + auto nOps = t.getBody()->getOperations().size(); + + // Insert newForOp before the terminator of `t`. + OpBuilder b(t.getBodyBuilder()); + Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step()); + Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt, + forOp.upperBound(), stepped); + Value ub = + b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped); + + // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. + auto newForOp = b.create<loop::ForOp>(t.getLoc(), iv, ub, originalStep); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + t.getBody()->getOperations(), begin, std::next(begin, nOps - 1)); + replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(), + newForOp.region()); + + innerLoops.push_back(newForOp); + } + + return innerLoops; +} + +// Stripmines a `forOp` by `factor` and sinks it under a single `target`. +// Returns the new AffineForOps, nested immediately under `target`. +template <typename ForType, typename SizeType> +static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) { + // TODO(ntv): Use cheap structural assertions that targets are nested under + // forOp and that targets are not nested under each other when DominanceInfo + // exposes the capability. It seems overkill to construct a whole function + // dominance tree at this point. 
+ auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target}); + assert(res.size() == 1 && "Expected 1 inner forOp"); + return res[0]; +} + +template <typename ForType, typename SizeType> +static SmallVector<SmallVector<ForType, 8>, 8> +tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, + ArrayRef<ForType> targets) { + SmallVector<SmallVector<ForType, 8>, 8> res; + SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end()); + for (auto it : llvm::zip(forOps, sizes)) { + auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets); + res.push_back(step); + currentTargets = step; + } + return res; +} + +SmallVector<SmallVector<AffineForOp, 8>, 8> +mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes, + ArrayRef<AffineForOp> targets) { + return tileImpl(forOps, sizes, targets); +} + +SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps, + ArrayRef<Value> sizes, + ArrayRef<loop::ForOp> targets) { + return tileImpl(forOps, sizes, targets); +} + +template <typename ForType, typename SizeType> +static SmallVector<ForType, 8> +tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) { + SmallVector<ForType, 8> res; + for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) { + assert(loops.size() == 1); + res.push_back(loops[0]); + } + return res; +} + +SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps, + ArrayRef<uint64_t> sizes, + AffineForOp target) { + return tileImpl(forOps, sizes, target); +} + +Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes, + loop::ForOp target) { + return tileImpl(forOps, sizes, target); +} + +Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes) { + // Collect perfectly nested loops. If more size values provided than nested + // loops available, truncate `sizes`. + SmallVector<loop::ForOp, 4> forOps; + forOps.reserve(sizes.size()); + getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); + if (forOps.size() < sizes.size()) + sizes = sizes.take_front(forOps.size()); + + return ::tile(forOps, sizes, forOps.back()); +} + +// Build the IR that performs ceil division of a positive value by a constant: +// ceildiv(a, B) = divis(a + (B-1), B) +// where divis is rounding-to-zero division. +static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, + int64_t divisor) { + assert(divisor > 0 && "expected positive divisor"); + assert(dividend->getType().isIndex() && "expected index-typed value"); + + Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1); + Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor); + Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst); + return builder.create<SignedDivIOp>(loc, sum, divisorCst); +} + +// Build the IR that performs ceil division of a positive value by another +// positive value: +// ceildiv(a, b) = divis(a + (b - 1), b) +// where divis is rounding-to-zero division. +static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, + Value divisor) { + assert(dividend->getType().isIndex() && "expected index-typed value"); + + Value cstOne = builder.create<ConstantIndexOp>(loc, 1); + Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne); + Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne); + return builder.create<SignedDivIOp>(loc, sum, divisor); +} + +// Hoist the ops within `outer` that appear before `inner`. 
+// Such ops include the ops that have been introduced by parametric tiling. +// Ops that come from triangular loops (i.e. that belong to the program slice +// rooted at `outer`) and ops that have side effects cannot be hoisted. +// Return failure when any op fails to hoist. +static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) { + SetVector<Operation *> forwardSlice; + getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) { + return op != inner.getOperation(); + }); + LogicalResult status = success(); + SmallVector<Operation *, 8> toHoist; + for (auto &op : outer.getBody()->getOperations()) { + // Stop when encountering the inner loop. + if (&op == inner.getOperation()) + break; + // Skip over non-hoistable ops. + if (forwardSlice.count(&op) > 0) { + status = failure(); + continue; + } + // Skip loop::ForOp, these are not considered a failure. + if (op.getNumRegions() > 0) + continue; + // Skip other ops with regions. + if (op.getNumRegions() > 0) { + status = failure(); + continue; + } + // Skip if op has side effects. + // TODO(ntv): loads to immutable memory regions are ok. + if (!op.hasNoSideEffect()) { + status = failure(); + continue; + } + toHoist.push_back(&op); + } + auto *outerForOp = outer.getOperation(); + for (auto *op : toHoist) + op->moveBefore(outerForOp); + return status; +} + +// Traverse the interTile and intraTile loops and try to hoist ops such that +// bands of perfectly nested loops are isolated. +// Return failure if either perfect interTile or perfect intraTile bands cannot +// be formed. +static LogicalResult tryIsolateBands(const TileLoops &tileLoops) { + LogicalResult status = success(); + auto &interTile = tileLoops.first; + auto &intraTile = tileLoops.second; + auto size = interTile.size(); + assert(size == intraTile.size()); + if (size <= 1) + return success(); + for (unsigned s = 1; s < size; ++s) + status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s]) + : failure(); + for (unsigned s = 1; s < size; ++s) + status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s]) + : failure(); + return status; +} + +TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp, + ArrayRef<int64_t> sizes) { + // Collect perfectly nested loops. If more size values provided than nested + // loops available, truncate `sizes`. + SmallVector<loop::ForOp, 4> forOps; + forOps.reserve(sizes.size()); + getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); + if (forOps.size() < sizes.size()) + sizes = sizes.take_front(forOps.size()); + + // Compute the tile sizes such that i-th outer loop executes size[i] + // iterations. Given that the loop current executes + // numIterations = ceildiv((upperBound - lowerBound), step) + // iterations, we need to tile with size ceildiv(numIterations, size[i]). + SmallVector<Value, 4> tileSizes; + tileSizes.reserve(sizes.size()); + for (unsigned i = 0, e = sizes.size(); i < e; ++i) { + assert(sizes[i] > 0 && "expected strictly positive size for strip-mining"); + + auto forOp = forOps[i]; + OpBuilder builder(forOp); + auto loc = forOp.getLoc(); + Value diff = + builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound()); + Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step()); + Value iterationsPerBlock = + ceilDivPositive(builder, loc, numIterations, sizes[i]); + tileSizes.push_back(iterationsPerBlock); + } + + // Call parametric tiling with the given sizes. 
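// Illustrative example (schematic; %ub and the constants are placeholders):
// extractFixedOuterLoops(root, {2}) on
//   loop.for %i = %c0 to %ub step %c1 { ... }
// strip-mines %i so that the outer loop runs (roughly) 2 iterations:
//   loop.for %ii = %c0 to %ub step %bigStep {
//     loop.for %i = %ii to min(%ii + %bigStep, %ub) step %c1 { ... }
//   }
// where %bigStep = ceildiv(%ub - %c0, 2) is materialized with the
// ceilDivPositive helpers above, and the 'min' is emitted as cmpi + select.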
+ auto intraTile = tile(forOps, tileSizes, forOps.back()); + TileLoops tileLoops = std::make_pair(forOps, intraTile); + + // TODO(ntv, zinenko) for now we just ignore the result of band isolation. + // In the future, mapping decisions may be impacted by the ability to + // isolate perfectly nested bands. + tryIsolateBands(tileLoops); + + return tileLoops; +} + +// Replaces all uses of `orig` with `replacement` except if the user is listed +// in `exceptions`. +static void +replaceAllUsesExcept(Value orig, Value replacement, + const SmallPtrSetImpl<Operation *> &exceptions) { + for (auto &use : llvm::make_early_inc_range(orig->getUses())) { + if (exceptions.count(use.getOwner()) == 0) + use.set(replacement); + } +} + +// Transform a loop with a strictly positive step +// for %i = %lb to %ub step %s +// into a 0-based loop with step 1 +// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 { +// %i = %ii * %s + %lb +// Insert the induction variable remapping in the body of `inner`, which is +// expected to be either `loop` or another loop perfectly nested under `loop`. +// Insert the definition of new bounds immediate before `outer`, which is +// expected to be either `loop` or its parent in the loop nest. +static void normalizeLoop(loop::ForOp loop, loop::ForOp outer, + loop::ForOp inner) { + OpBuilder builder(outer); + Location loc = loop.getLoc(); + + // Check if the loop is already known to have a constant zero lower bound or + // a constant one step. + bool isZeroBased = false; + if (auto ubCst = + dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound()->getDefiningOp())) + isZeroBased = ubCst.getValue() == 0; + + bool isStepOne = false; + if (auto stepCst = + dyn_cast_or_null<ConstantIndexOp>(loop.step()->getDefiningOp())) + isStepOne = stepCst.getValue() == 1; + + if (isZeroBased && isStepOne) + return; + + // Compute the number of iterations the loop executes: ceildiv(ub - lb, step) + // assuming the step is strictly positive. Update the bounds and the step + // of the loop to go from 0 to the number of iterations, if necessary. + // TODO(zinenko): introduce support for negative steps or emit dynamic asserts + // on step positivity, whatever gets implemented first. + Value diff = + builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound()); + Value numIterations = ceilDivPositive(builder, loc, diff, loop.step()); + loop.setUpperBound(numIterations); + + Value lb = loop.lowerBound(); + if (!isZeroBased) { + Value cst0 = builder.create<ConstantIndexOp>(loc, 0); + loop.setLowerBound(cst0); + } + + Value step = loop.step(); + if (!isStepOne) { + Value cst1 = builder.create<ConstantIndexOp>(loc, 1); + loop.setStep(cst1); + } + + // Insert code computing the value of the original loop induction variable + // from the "normalized" one. + builder.setInsertionPointToStart(inner.getBody()); + Value scaled = + isStepOne ? loop.getInductionVar() + : builder.create<MulIOp>(loc, loop.getInductionVar(), step); + Value shifted = + isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb); + + SmallPtrSet<Operation *, 2> preserve{scaled->getDefiningOp(), + shifted->getDefiningOp()}; + replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve); +} + +void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) { + if (loops.size() < 2) + return; + + loop::ForOp innermost = loops.back(); + loop::ForOp outermost = loops.front(); + + // 1. Make sure all loops iterate from 0 to upperBound with step 1. This + // allows the following code to assume upperBound is the number of iterations. 
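// Illustrative example (schematic): coalescing a normalized 2-d nest
//   loop.for %i = %c0 to %c4 step %c1 {
//     loop.for %j = %c0 to %c8 step %c1 { ... }
//   }
// produces a single loop over 4 * 8 = 32 iterations,
//   loop.for %iv = %c0 to %c32 step %c1 { ... }
// in whose body the original IVs are recovered as %i = %iv / 8 and
// %j = %iv mod 8 (emitted with the SignedDivIOp/SignedRemIOp ops used below).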
+ for (auto loop : loops) + normalizeLoop(loop, outermost, innermost); + + // 2. Emit code computing the upper bound of the coalesced loop as product + // of the number of iterations of all loops. + OpBuilder builder(outermost); + Location loc = outermost.getLoc(); + Value upperBound = outermost.upperBound(); + for (auto loop : loops.drop_front()) + upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound()); + outermost.setUpperBound(upperBound); + + builder.setInsertionPointToStart(outermost.getBody()); + + // 3. Remap induction variables. For each original loop, the value of the + // induction variable can be obtained by dividing the induction variable of + // the linearized loop by the total number of iterations of the loops nested + // in it modulo the number of iterations in this loop (remove the values + // related to the outer loops): + // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i. + // Compute these iteratively from the innermost loop by creating a "running + // quotient" of division by the range. + Value previous = outermost.getInductionVar(); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + unsigned idx = loops.size() - i - 1; + if (i != 0) + previous = builder.create<SignedDivIOp>(loc, previous, + loops[idx + 1].upperBound()); + + Value iv = (i == e - 1) ? previous + : builder.create<SignedRemIOp>( + loc, previous, loops[idx].upperBound()); + replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv, + loops.back().region()); + } + + // 4. Move the operations from the innermost just above the second-outermost + // loop, delete the extra terminator and the second-outermost loop. + loop::ForOp second = loops[1]; + innermost.getBody()->back().erase(); + outermost.getBody()->getOperations().splice( + Block::iterator(second.getOperation()), + innermost.getBody()->getOperations()); + second.erase(); +} + +void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId, + ArrayRef<Value> numProcessors) { + assert(processorId.size() == numProcessors.size()); + if (processorId.empty()) + return; + + OpBuilder b(forOp); + Location loc(forOp.getLoc()); + Value mul = processorId.front(); + for (unsigned i = 1, e = processorId.size(); i < e; ++i) + mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]), + processorId[i]); + Value lb = b.create<AddIOp>(loc, forOp.lowerBound(), + b.create<MulIOp>(loc, forOp.step(), mul)); + forOp.setLowerBound(lb); + + Value step = forOp.step(); + for (auto numProcs : numProcessors) + step = b.create<MulIOp>(loc, step, numProcs); + forOp.setStep(step); +} + +/// Given a memref region, determine the lowest depth at which transfers can be +/// placed for it, and return the corresponding block, start and end positions +/// in the block for placing incoming (read) and outgoing (write) copies +/// respectively. The lowest depth depends on whether the region being accessed +/// is hoistable with respect to one or more immediately surrounding loops. 
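+/// For example (an illustration, not a specific case handled specially below):
+/// if the region is invariant with respect to the two innermost loops of a
+/// three-deep nest but varies with the outermost IV, the incoming copy is
+/// placed just before the second-outermost loop and the outgoing copy just
+/// after it, i.e. both inside the outermost loop.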
+static void +findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, + Block::iterator &begin, Block::iterator &end, + Block **copyPlacementBlock, + Block::iterator *copyInPlacementStart, + Block::iterator *copyOutPlacementStart) { + const auto *cst = region.getConstraints(); + SmallVector<Value, 4> symbols; + cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols); + + SmallVector<AffineForOp, 4> enclosingFors; + getLoopIVs(*block.begin(), &enclosingFors); + // Walk up loop parents till we find an IV on which this region is + // symbolic/variant. + auto it = enclosingFors.rbegin(); + for (auto e = enclosingFors.rend(); it != e; ++it) { + // TODO(bondhugula): also need to be checking this for regions symbols that + // aren't loop IVs, whether we are within their resp. defs' dominance scope. + if (llvm::is_contained(symbols, it->getInductionVar())) + break; + } + + if (it != enclosingFors.rbegin()) { + auto lastInvariantIV = *std::prev(it); + *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation()); + *copyOutPlacementStart = std::next(*copyInPlacementStart); + *copyPlacementBlock = lastInvariantIV.getOperation()->getBlock(); + } else { + *copyInPlacementStart = begin; + *copyOutPlacementStart = end; + *copyPlacementBlock = █ + } +} + +// Info comprising stride and number of elements transferred every stride. +struct StrideInfo { + int64_t stride; + int64_t numEltPerStride; +}; + +/// Returns striding information for a copy/transfer of this region with +/// potentially multiple striding levels from outermost to innermost. For an +/// n-dimensional region, there can be at most n-1 levels of striding +/// successively nested. +// TODO(bondhugula): make this work with non-identity layout maps. +static void getMultiLevelStrides(const MemRefRegion ®ion, + ArrayRef<int64_t> bufferShape, + SmallVectorImpl<StrideInfo> *strideInfos) { + if (bufferShape.size() <= 1) + return; + + int64_t numEltPerStride = 1; + int64_t stride = 1; + for (int d = bufferShape.size() - 1; d >= 1; d--) { + int64_t dimSize = region.memref->getType().cast<MemRefType>().getDimSize(d); + stride *= dimSize; + numEltPerStride *= bufferShape[d]; + // A stride is needed only if the region has a shorter extent than the + // memref along the dimension *and* has an extent greater than one along the + // next major dimension. + if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) { + strideInfos->push_back({stride, numEltPerStride}); + } + } +} + +/// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and +/// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart' +/// holds the lower coordinates of the region in the original memref to copy +/// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in. +static AffineForOp generatePointWiseCopy(Location loc, Value memref, + Value fastMemRef, + AffineMap memAffineMap, + ArrayRef<Value> memIndicesStart, + ArrayRef<int64_t> fastBufferShape, + bool isCopyOut, OpBuilder b) { + assert(!memIndicesStart.empty() && "only 1-d or more memrefs"); + + // The copy-in nest is generated as follows as an example for a 2-d region: + // for x = ... + // for y = ... 
+ // fast_buf[x][y] = buf[mem_x + x][mem_y + y] + + SmallVector<Value, 4> fastBufIndices, memIndices; + AffineForOp copyNestRoot; + for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) { + auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]); + if (d == 0) + copyNestRoot = forOp; + b = forOp.getBodyBuilder(); + fastBufIndices.push_back(forOp.getInductionVar()); + + Value memBase = + (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims())) + ? memIndicesStart[d] + : b.create<AffineApplyOp>( + loc, + AffineMap::get(memAffineMap.getNumDims(), + memAffineMap.getNumSymbols(), + memAffineMap.getResult(d)), + memIndicesStart); + + // Construct the subscript for the slow memref being copied. + auto memIndex = b.create<AffineApplyOp>( + loc, + AffineMap::get(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)), + ValueRange({memBase, forOp.getInductionVar()})); + memIndices.push_back(memIndex); + } + + if (!isCopyOut) { + // Copy in. + auto load = b.create<AffineLoadOp>(loc, memref, memIndices); + b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufIndices); + return copyNestRoot; + } + + // Copy out. + auto load = b.create<AffineLoadOp>(loc, fastMemRef, fastBufIndices); + b.create<AffineStoreOp>(loc, load, memref, memIndices); + return copyNestRoot; +} + +static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED +emitRemarkForBlock(Block &block) { + return block.getParentOp()->emitRemark(); +} + +/// Creates a buffer in the faster memory space for the specified memref region; +/// generates a copy from the lower memory space to this one, and replaces all +/// loads/stores in the block range [`begin', `end') of `block' to load/store +/// from that buffer. Returns failure if copies could not be generated due to +/// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` +/// in copyPlacementBlock specify the insertion points where the incoming copies +/// and outgoing copies, respectively, should be inserted (the insertion happens +/// right before the insertion point). Since `begin` can itself be invalidated +/// due to the memref rewriting done from this method, the output argument +/// `nBegin` is set to its replacement (set to `begin` if no invalidation +/// happens). Since outgoing copies could have been inserted at `end`, the +/// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the +/// size of the fast buffer allocated. +static LogicalResult generateCopy( + const MemRefRegion ®ion, Block *block, Block::iterator begin, + Block::iterator end, Block *copyPlacementBlock, + Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart, + AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap, + DenseSet<Operation *> ©Nests, uint64_t *sizeInBytes, + Block::iterator *nBegin, Block::iterator *nEnd) { + *nBegin = begin; + *nEnd = end; + + FuncOp f = begin->getParentOfType<FuncOp>(); + OpBuilder topBuilder(f.getBody()); + Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0); + + if (begin == end) + return success(); + + // Is the copy out point at the end of the block where we are doing + // explicit copying. + bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart); + + // Copies for read regions are going to be inserted at 'begin'. + OpBuilder prologue(copyPlacementBlock, copyInPlacementStart); + // Copies for write regions are going to be inserted at 'end'. + OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart); + OpBuilder &b = region.isWrite() ? 
epilogue : prologue; + + // Builder to create constants at the top level. + auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>(); + OpBuilder top(func.getBody()); + + auto loc = region.loc; + auto memref = region.memref; + auto memRefType = memref->getType().cast<MemRefType>(); + + auto layoutMaps = memRefType.getAffineMaps(); + if (layoutMaps.size() > 1 || + (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + return failure(); + } + + // Indices to use for the copying. + // Indices for the original memref being copied from/to. + SmallVector<Value, 4> memIndices; + // Indices for the faster buffer being copied into/from. + SmallVector<Value, 4> bufIndices; + + unsigned rank = memRefType.getRank(); + SmallVector<int64_t, 4> fastBufferShape; + + // Compute the extents of the buffer. + std::vector<SmallVector<int64_t, 4>> lbs; + SmallVector<int64_t, 8> lbDivisors; + lbs.reserve(rank); + Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape( + &fastBufferShape, &lbs, &lbDivisors); + if (!numElements.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); + return failure(); + } + + if (numElements.getValue() == 0) { + LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n"); + *sizeInBytes = 0; + return success(); + } + + const FlatAffineConstraints *cst = region.getConstraints(); + // 'regionSymbols' hold values that this memory region is symbolic/parametric + // on; these typically include loop IVs surrounding the level at which the + // copy generation is being done or other valid symbols in MLIR. + SmallVector<Value, 8> regionSymbols; + cst->getIdValues(rank, cst->getNumIds(), ®ionSymbols); + + // Construct the index expressions for the fast memory buffer. The index + // expression for a particular dimension of the fast buffer is obtained by + // subtracting out the lower bound on the original memref's data region + // along the corresponding dimension. + + // Index start offsets for faster memory buffer relative to the original. + SmallVector<AffineExpr, 4> offsets; + offsets.reserve(rank); + for (unsigned d = 0; d < rank; d++) { + assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); + + AffineExpr offset = top.getAffineConstantExpr(0); + for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { + offset = offset + lbs[d][j] * top.getAffineDimExpr(j); + } + assert(lbDivisors[d] > 0); + offset = + (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); + + // Set copy start location for this dimension in the lower memory space + // memref. + if (auto caf = offset.dyn_cast<AffineConstantExpr>()) { + auto indexVal = caf.getValue(); + if (indexVal == 0) { + memIndices.push_back(zeroIndex); + } else { + memIndices.push_back( + top.create<ConstantIndexOp>(loc, indexVal).getResult()); + } + } else { + // The coordinate for the start location is just the lower bound along the + // corresponding dimension on the memory region (stored in 'offset'). + auto map = AffineMap::get( + cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset); + memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols)); + } + // The fast buffer is copied into at location zero; addressing is relative. + bufIndices.push_back(zeroIndex); + + // Record the offsets since they are needed to remap the memory accesses of + // the original memref further below. 
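+    // As an illustration (hypothetical bounds): if the region's lower bound
+    // along this dimension is 32 * %i and %i is the only value in
+    // `regionSymbols`, then `offset` is the expression 32 * d0 and the copy
+    // start index in the original memref becomes
+    // affine.apply (d0) -> (d0 * 32)(%i).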
+ offsets.push_back(offset); + } + + // The faster memory space buffer. + Value fastMemRef; + + // Check if a buffer was already created. + bool existingBuf = fastBufferMap.count(memref) > 0; + if (!existingBuf) { + AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank); + auto fastMemRefType = + MemRefType::get(fastBufferShape, memRefType.getElementType(), + fastBufferLayout, copyOptions.fastMemorySpace); + + // Create the fast memory space buffer just before the 'affine.for' + // operation. + fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult(); + // Record it. + fastBufferMap[memref] = fastMemRef; + // fastMemRefType is a constant shaped memref. + *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue(); + LLVM_DEBUG(emitRemarkForBlock(*block) + << "Creating fast buffer of type " << fastMemRefType + << " and size " << llvm::divideCeil(*sizeInBytes, 1024) + << " KiB\n"); + } else { + // Reuse the one already created. + fastMemRef = fastBufferMap[memref]; + *sizeInBytes = 0; + } + + auto numElementsSSA = + top.create<ConstantIndexOp>(loc, numElements.getValue()); + + SmallVector<StrideInfo, 4> strideInfos; + getMultiLevelStrides(region, fastBufferShape, &strideInfos); + + // TODO(bondhugula): use all stride levels once DmaStartOp is extended for + // multi-level strides. + if (strideInfos.size() > 1) { + LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n"); + return failure(); + } + + Value stride = nullptr; + Value numEltPerStride = nullptr; + if (!strideInfos.empty()) { + stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride); + numEltPerStride = + top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride); + } + + // Record the last operation where we want the memref replacement to end. We + // later do the memref replacement only in [begin, postDomFilter] so + // that the original memref's used in the data movement code themselves don't + // get replaced. + auto postDomFilter = std::prev(end); + + // Create fully composed affine maps for each memref. + auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size()); + fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices); + auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size()); + fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices); + + if (!copyOptions.generateDma) { + // Point-wise copy generation. + auto copyNest = generatePointWiseCopy(loc, memref, fastMemRef, memAffineMap, + memIndices, fastBufferShape, + /*isCopyOut=*/region.isWrite(), b); + + // Record this so that we can skip it from yet another copy. + copyNests.insert(copyNest); + + // Since new ops are being appended (for copy out's), adjust the end to + // mark end of block range being processed if necessary. + if (region.isWrite() && isCopyOutAtEndOfBlock) + *nEnd = Block::iterator(copyNest.getOperation()); + } else { + // DMA generation. + // Create a tag (single element 1-d memref) for the DMA. + auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, + copyOptions.tagMemorySpace); + auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType); + + SmallVector<Value, 4> tagIndices({zeroIndex}); + auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); + fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); + if (!region.isWrite()) { + // DMA non-blocking read from original buffer to fast buffer. 
+      b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
+                                 fastMemRef, bufAffineMap, bufIndices,
+                                 tagMemRef, tagAffineMap, tagIndices,
+                                 numElementsSSA, stride, numEltPerStride);
+    } else {
+      // DMA non-blocking write from fast buffer to the original memref.
+      auto op = b.create<AffineDmaStartOp>(
+          loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
+          memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
+          stride, numEltPerStride);
+      // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
+      // end to mark end of block range being processed.
+      if (isCopyOutAtEndOfBlock)
+        *nEnd = Block::iterator(op.getOperation());
+    }
+
+    // Matching DMA wait to block on completion; tag always has a 0 index.
+    b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
+                              numElementsSSA);
+
+    // Generate dealloc for the tag.
+    auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
+    if (*nEnd == end && isCopyOutAtEndOfBlock)
+      // Since new ops are being appended (for outgoing DMAs), adjust the end to
+      // mark end of range of the original.
+      *nEnd = Block::iterator(tagDeallocOp.getOperation());
+  }
+
+  // Generate dealloc for the buffer.
+  if (!existingBuf) {
+    auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
+    // When generating pointwise copies, `nEnd' has to be set to deallocOp on
+    // the fast buffer (since it marks the new end insertion point).
+    if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
+      *nEnd = Block::iterator(bufDeallocOp.getOperation());
+  }
+
+  // Replace all uses of the old memref with the faster one while remapping
+  // access indices (subtracting out lower bound offsets for each dimension).
+  // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
+  // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
+  // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
+  // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
+  // d2, d3 correspond to the original indices (%i, %j).
+  SmallVector<AffineExpr, 4> remapExprs;
+  remapExprs.reserve(rank);
+  for (unsigned i = 0; i < rank; i++) {
+    // The starting operands of indexRemap will be regionSymbols (the symbols on
+    // which the memref region is parametric); then those corresponding to
+    // the memref's original indices follow.
+    auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
+    remapExprs.push_back(dimExpr - offsets[i]);
+  }
+  auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs);
+
+  // Record the begin since it may be invalidated by memref replacement.
+  Block::iterator prevOfBegin;
+  bool isBeginAtStartOfBlock = (begin == block->begin());
+  if (!isBeginAtStartOfBlock)
+    prevOfBegin = std::prev(begin);
+
+  // *Only* those uses within the range [begin, end) of 'block' are replaced.
+  replaceAllMemRefUsesWith(memref, fastMemRef,
+                           /*extraIndices=*/{}, indexRemap,
+                           /*extraOperands=*/regionSymbols,
+                           /*symbolOperands=*/{},
+                           /*domInstFilter=*/&*begin,
+                           /*postDomInstFilter=*/&*postDomFilter);
+
+  *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
+
+  return success();
+}
+
+/// Construct the memref region to just include the entire memref. Returns false
+/// for dynamic shaped memrefs for now. `numParamLoopIVs` is the number of
+/// enclosing loop IVs of opInst (starting from the outermost) that the region
+/// is parametric on.
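+/// For example (an illustrative shape): for an access to a
+/// memref<256x1024xf32>, the region is simply 0 <= d0 <= 255, 0 <= d1 <= 1023,
+/// with the first `numParamLoopIVs` enclosing loop IVs attached as the
+/// region's symbols.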
+static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
+                                  MemRefRegion *region) {
+  unsigned rank;
+  if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
+    rank = loadOp.getMemRefType().getRank();
+    region->memref = loadOp.getMemRef();
+    region->setWrite(false);
+  } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
+    rank = storeOp.getMemRefType().getRank();
+    region->memref = storeOp.getMemRef();
+    region->setWrite(true);
+  } else {
+    assert(false && "expected load or store op");
+    return false;
+  }
+  auto memRefType = region->memref->getType().cast<MemRefType>();
+  if (!memRefType.hasStaticShape())
+    return false;
+
+  auto *regionCst = region->getConstraints();
+
+  // Just get the first numSymbols IVs, which the memref region is parametric
+  // on.
+  SmallVector<AffineForOp, 4> ivs;
+  getLoopIVs(*opInst, &ivs);
+  ivs.resize(numParamLoopIVs);
+  SmallVector<Value, 4> symbols;
+  extractForInductionVars(ivs, &symbols);
+  regionCst->reset(rank, numParamLoopIVs, 0);
+  regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
+
+  // Memref dim sizes provide the bounds.
+  for (unsigned d = 0; d < rank; d++) {
+    auto dimSize = memRefType.getDimSize(d);
+    assert(dimSize > 0 && "filtered dynamic shapes above");
+    regionCst->addConstantLowerBound(d, 0);
+    regionCst->addConstantUpperBound(d, dimSize - 1);
+  }
+  return true;
+}
+
+/// Generates copies for a contiguous sequence of operations in `block` in the
+/// iterator range [`begin', `end'), where `end' can't be past the terminator of
+/// the block (since additional operations are potentially inserted right before
+/// `end'). Returns the total size of the fast buffers used.
+// Since we generate alloc's and dealloc's for all fast buffers (before and
+// after the range of operations resp.), all of the fast memory capacity is
+// assumed to be available for processing this block range.
+uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
+                                      Block::iterator end,
+                                      const AffineCopyOptions &copyOptions,
+                                      DenseSet<Operation *> &copyNests) {
+  if (begin == end)
+    return 0;
+
+  assert(begin->getBlock() == std::prev(end)->getBlock() &&
+         "Inconsistent block begin/end args");
+  assert(end != end->getBlock()->end() && "end can't be the block terminator");
+
+  Block *block = begin->getBlock();
+
+  // Copies will be generated for this depth, i.e., symbolic in all loops
+  // surrounding this block range.
+  unsigned copyDepth = getNestingDepth(*begin);
+
+  LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
+                          << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
+
+  // List of memory regions to copy for. We need a map vector to have a
+  // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
+  // since the alloc's for example are identical except for the SSA id.
+  SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
+  SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
+
+  // Map from original memref's to the fast buffers that their accesses are
+  // replaced with.
+  DenseMap<Value, Value> fastBufferMap;
+
+  // To check for errors when walking the block.
+  bool error = false;
+
+  // Walk this range of operations to gather all memory regions.
+  block->walk(begin, end, [&](Operation *opInst) {
+    // Gather regions to allocate to buffers in faster memory space.
+ if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) { + if ((loadOp.getMemRefType().getMemorySpace() != + copyOptions.slowMemorySpace)) + return; + } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) { + if (storeOp.getMemRefType().getMemorySpace() != + copyOptions.slowMemorySpace) + return; + } else { + // Neither load nor a store op. + return; + } + + // Compute the MemRefRegion accessed. + auto region = std::make_unique<MemRefRegion>(opInst->getLoc()); + if (failed(region->compute(opInst, copyDepth))) { + LLVM_DEBUG(llvm::dbgs() + << "Error obtaining memory region: semi-affine maps?\n"); + LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); + if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { + LLVM_DEBUG( + opInst->emitError("non-constant memref sizes not yet supported")); + error = true; + return; + } + } + + // Each memref has a single buffer associated with it irrespective of how + // many load's and store's happen on it. + // TODO(bondhugula): in the future, when regions don't intersect and satisfy + // other properties (based on load/store regions), we could consider + // multiple buffers per memref. + + // Add to the appropriate region if it's not already in it, or take a + // bounding box union with the existing one if it's already in there. + // Note that a memref may have both read and write regions - so update the + // region in the other list if one exists (write in case of read and vice + // versa) since there is a single bounding box for a memref across all reads + // and writes that happen on it. + + // Attempts to update; returns true if 'region' exists in targetRegions. + auto updateRegion = + [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> + &targetRegions) { + auto it = targetRegions.find(region->memref); + if (it == targetRegions.end()) + return false; + + // Perform a union with the existing region. + if (failed(it->second->unionBoundingBox(*region))) { + LLVM_DEBUG(llvm::dbgs() + << "Memory region bounding box failed; " + "over-approximating to the entire memref\n"); + // If the union fails, we will overapproximate. + if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { + LLVM_DEBUG(opInst->emitError( + "non-constant memref sizes not yet supported")); + error = true; + return true; + } + it->second->getConstraints()->clearAndCopyFrom( + *region->getConstraints()); + } else { + // Union was computed and stored in 'it->second': copy to 'region'. + region->getConstraints()->clearAndCopyFrom( + *it->second->getConstraints()); + } + return true; + }; + + bool existsInRead = updateRegion(readRegions); + if (error) + return; + bool existsInWrite = updateRegion(writeRegions); + if (error) + return; + + // Finally add it to the region list. + if (region->isWrite() && !existsInWrite) { + writeRegions[region->memref] = std::move(region); + } else if (!region->isWrite() && !existsInRead) { + readRegions[region->memref] = std::move(region); + } + }); + + if (error) { + begin->emitError( + "copy generation failed for one or more memref's in this block\n"); + return 0; + } + + uint64_t totalCopyBuffersSizeInBytes = 0; + bool ret = true; + auto processRegions = + [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> + ®ions) { + for (const auto ®ionEntry : regions) { + // For each region, hoist copy in/out past all hoistable + // 'affine.for's. 
+ Block::iterator copyInPlacementStart, copyOutPlacementStart; + Block *copyPlacementBlock; + findHighestBlockForPlacement( + *regionEntry.second, *block, begin, end, ©PlacementBlock, + ©InPlacementStart, ©OutPlacementStart); + + uint64_t sizeInBytes; + Block::iterator nBegin, nEnd; + LogicalResult iRet = generateCopy( + *regionEntry.second, block, begin, end, copyPlacementBlock, + copyInPlacementStart, copyOutPlacementStart, copyOptions, + fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd); + if (succeeded(iRet)) { + // begin/end could have been invalidated, and need update. + begin = nBegin; + end = nEnd; + totalCopyBuffersSizeInBytes += sizeInBytes; + } + ret = ret & succeeded(iRet); + } + }; + processRegions(readRegions); + processRegions(writeRegions); + + if (!ret) { + begin->emitError( + "copy generation failed for one or more memref's in this block\n"); + return totalCopyBuffersSizeInBytes; + } + + // For a range of operations, a note will be emitted at the caller. + AffineForOp forOp; + uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024); + if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) { + forOp.emitRemark() + << sizeInKib + << " KiB of copy buffers in fast memory space for this block\n"; + } + + if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) { + StringRef str = "Total size of all copy buffers' for this block " + "exceeds fast memory capacity\n"; + block->getParentOp()->emitError(str); + } + + return totalCopyBuffersSizeInBytes; +} diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp new file mode 100644 index 00000000000..ca26074f288 --- /dev/null +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -0,0 +1,348 @@ +//===- RegionUtils.cpp - Region-related transformation utilities ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/RegionUtils.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/RegionGraphTraits.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; + +void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, + Region ®ion) { + for (auto &use : llvm::make_early_inc_range(orig->getUses())) { + if (region.isAncestor(use.getOwner()->getParentRegion())) + use.set(replacement); + } +} + +void mlir::visitUsedValuesDefinedAbove( + Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { + assert(limit.isAncestor(®ion) && + "expected isolation limit to be an ancestor of the given region"); + + // Collect proper ancestors of `limit` upfront to avoid traversing the region + // tree for every value. + SmallPtrSet<Region *, 4> properAncestors; + for (auto *reg = limit.getParentRegion(); reg != nullptr; + reg = reg->getParentRegion()) { + properAncestors.insert(reg); + } + + region.walk([callback, &properAncestors](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) + // Callback on values defined in a proper ancestor of region. 
+ if (properAncestors.count(operand.get()->getParentRegion())) + callback(&operand); + }); +} + +void mlir::visitUsedValuesDefinedAbove( + MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { + for (Region ®ion : regions) + visitUsedValuesDefinedAbove(region, region, callback); +} + +void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, + llvm::SetVector<Value> &values) { + visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { + values.insert(operand->get()); + }); +} + +void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, + llvm::SetVector<Value> &values) { + for (Region ®ion : regions) + getUsedValuesDefinedAbove(region, region, values); +} + +//===----------------------------------------------------------------------===// +// Unreachable Block Elimination +//===----------------------------------------------------------------------===// + +/// Erase the unreachable blocks within the provided regions. Returns success +/// if any blocks were erased, failure otherwise. +// TODO: We could likely merge this with the DCE algorithm below. +static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) { + // Set of blocks found to be reachable within a given region. + llvm::df_iterator_default_set<Block *, 16> reachable; + // If any blocks were found to be dead. + bool erasedDeadBlocks = false; + + SmallVector<Region *, 1> worklist; + worklist.reserve(regions.size()); + for (Region ®ion : regions) + worklist.push_back(®ion); + while (!worklist.empty()) { + Region *region = worklist.pop_back_val(); + if (region->empty()) + continue; + + // If this is a single block region, just collect the nested regions. + if (std::next(region->begin()) == region->end()) { + for (Operation &op : region->front()) + for (Region ®ion : op.getRegions()) + worklist.push_back(®ion); + continue; + } + + // Mark all reachable blocks. + reachable.clear(); + for (Block *block : depth_first_ext(®ion->front(), reachable)) + (void)block /* Mark all reachable blocks */; + + // Collect all of the dead blocks and push the live regions onto the + // worklist. + for (Block &block : llvm::make_early_inc_range(*region)) { + if (!reachable.count(&block)) { + block.dropAllDefinedValueUses(); + block.erase(); + erasedDeadBlocks = true; + continue; + } + + // Walk any regions within this block. + for (Operation &op : block) + for (Region ®ion : op.getRegions()) + worklist.push_back(®ion); + } + } + + return success(erasedDeadBlocks); +} + +//===----------------------------------------------------------------------===// +// Dead Code Elimination +//===----------------------------------------------------------------------===// + +namespace { +/// Data structure used to track which values have already been proved live. +/// +/// Because Operation's can have multiple results, this data structure tracks +/// liveness for both Value's and Operation's to avoid having to look through +/// all Operation results when analyzing a use. +/// +/// This data structure essentially tracks the dataflow lattice. +/// The set of values/ops proved live increases monotonically to a fixed-point. +class LiveMap { +public: + /// Value methods. + bool wasProvenLive(Value value) { return liveValues.count(value); } + void setProvedLive(Value value) { + changed |= liveValues.insert(value).second; + } + + /// Operation methods. 
+ bool wasProvenLive(Operation *op) { return liveOps.count(op); } + void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } + + /// Methods for tracking if we have reached a fixed-point. + void resetChanged() { changed = false; } + bool hasChanged() { return changed; } + +private: + bool changed = false; + DenseSet<Value> liveValues; + DenseSet<Operation *> liveOps; +}; +} // namespace + +static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { + Operation *owner = use.getOwner(); + unsigned operandIndex = use.getOperandNumber(); + // This pass generally treats all uses of an op as live if the op itself is + // considered live. However, for successor operands to terminators we need a + // finer-grained notion where we deduce liveness for operands individually. + // The reason for this is easiest to think about in terms of a classical phi + // node based SSA IR, where each successor operand is really an operand to a + // *separate* phi node, rather than all operands to the branch itself as with + // the block argument representation that MLIR uses. + // + // And similarly, because each successor operand is really an operand to a phi + // node, rather than to the terminator op itself, a terminator op can't e.g. + // "print" the value of a successor operand. + if (owner->isKnownTerminator()) { + if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) + return !liveMap.wasProvenLive(*arg); + return false; + } + return false; +} + +static void processValue(Value value, LiveMap &liveMap) { + bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) { + if (isUseSpeciallyKnownDead(use, liveMap)) + return false; + return liveMap.wasProvenLive(use.getOwner()); + }); + if (provedLive) + liveMap.setProvedLive(value); +} + +static bool isOpIntrinsicallyLive(Operation *op) { + // This pass doesn't modify the CFG, so terminators are never deleted. + if (!op->isKnownNonTerminator()) + return true; + // If the op has a side effect, we treat it as live. + if (!op->hasNoSideEffect()) + return true; + return false; +} + +static void propagateLiveness(Region ®ion, LiveMap &liveMap); +static void propagateLiveness(Operation *op, LiveMap &liveMap) { + // All Value's are either a block argument or an op result. + // We call processValue on those cases. + + // Recurse on any regions the op has. + for (Region ®ion : op->getRegions()) + propagateLiveness(region, liveMap); + + // Process the op itself. + if (isOpIntrinsicallyLive(op)) { + liveMap.setProvedLive(op); + return; + } + for (Value value : op->getResults()) + processValue(value, liveMap); + bool provedLive = llvm::any_of(op->getResults(), [&](Value value) { + return liveMap.wasProvenLive(value); + }); + if (provedLive) + liveMap.setProvedLive(op); +} + +static void propagateLiveness(Region ®ion, LiveMap &liveMap) { + if (region.empty()) + return; + + for (Block *block : llvm::post_order(®ion.front())) { + // We process block arguments after the ops in the block, to promote + // faster convergence to a fixed point (we try to visit uses before defs). 
+    for (Operation &op : llvm::reverse(block->getOperations()))
+      propagateLiveness(&op, liveMap);
+    for (Value value : block->getArguments())
+      processValue(value, liveMap);
+  }
+}
+
+static void eraseTerminatorSuccessorOperands(Operation *terminator,
+                                             LiveMap &liveMap) {
+  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
+       succI < succE; succI++) {
+    // Iterating successors in reverse is not strictly needed, since we
+    // aren't erasing any successors. But it is slightly more efficient
+    // since it will promote later operands of the terminator being erased
+    // first, reducing the quadratic-ness.
+    unsigned succ = succE - succI - 1;
+    for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ);
+         argI < argE; argI++) {
+      // Iterating args in reverse is needed for correctness, to avoid
+      // shifting later args when earlier args are erased.
+      unsigned arg = argE - argI - 1;
+      Value value = terminator->getSuccessor(succ)->getArgument(arg);
+      if (!liveMap.wasProvenLive(value)) {
+        terminator->eraseSuccessorOperand(succ, arg);
+      }
+    }
+  }
+}
+
+static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
+                                    LiveMap &liveMap) {
+  bool erasedAnything = false;
+  for (Region &region : regions) {
+    if (region.empty())
+      continue;
+
+    // We do the deletion in an order that deletes all uses before deleting
+    // defs.
+    // MLIR's SSA structural invariants guarantee that except for block
+    // arguments, the use-def graph is acyclic, so this is possible with a
+    // single walk of ops and then a final pass to clean up block arguments.
+    //
+    // To do this, we visit ops in an order that visits domtree children
+    // before domtree parents. A CFG post-order (with reverse iteration with a
+    // block) satisfies that without needing an explicit domtree calculation.
+    for (Block *block : llvm::post_order(&region.front())) {
+      eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
+      for (Operation &childOp :
+           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
+        erasedAnything |=
+            succeeded(deleteDeadness(childOp.getRegions(), liveMap));
+        if (!liveMap.wasProvenLive(&childOp)) {
+          erasedAnything = true;
+          childOp.erase();
+        }
+      }
+    }
+    // Delete block arguments.
+    // The entry block has an unknown contract with its enclosing block, so
+    // skip it.
+    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
+      // Iterate in reverse to avoid shifting later arguments when deleting
+      // earlier arguments.
+      for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
+        if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
+          block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
+          erasedAnything = true;
+        }
+    }
+  }
+  return success(erasedAnything);
+}
+
+// This function performs a simple dead code elimination algorithm over the
+// given regions.
+//
+// The overall goal is to prove that Values are dead, which allows deleting ops
+// and block arguments.
+//
+// This uses an optimistic algorithm that assumes everything is dead until
+// proved otherwise, allowing it to delete recursively dead cycles.
+//
+// This is a simple fixed-point dataflow analysis algorithm on a lattice
+// {Dead,Alive}. Because liveness flows backward, we generally try to
+// iterate everything backward to speed up convergence to the fixed-point. This
+// allows for being able to delete recursively dead cycles of the use-def graph,
+// including block arguments.
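+//
+// As a concrete (hypothetical) example of such a cycle: a block argument that
+// is only used to compute the value passed back to that same argument by a
+// branch is proven dead here, and both the computation and the argument get
+// deleted, which a pessimistic (assume-live) approach could not do.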
+//
+// This function returns success if any operations or arguments were deleted,
+// failure otherwise.
+static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
+  LiveMap liveMap;
+  do {
+    liveMap.resetChanged();
+
+    for (Region &region : regions)
+      propagateLiveness(region, liveMap);
+  } while (liveMap.hasChanged());
+
+  return deleteDeadness(regions, liveMap);
+}
+
+//===----------------------------------------------------------------------===//
+// Region Simplification
+//===----------------------------------------------------------------------===//
+
+/// Run a set of structural simplifications over the given regions. This
+/// includes transformations like unreachable block elimination, dead argument
+/// elimination, as well as some other DCE. This function returns success if any
+/// of the regions were simplified, failure otherwise.
+LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
+  LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions);
+  LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions);
+  return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs));
+}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644
index 00000000000..a6629183dee
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -0,0 +1,469 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(Operation &op) {
+  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+      isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
+    return true;
+  return false;
+}
+
+/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
+static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
+  return TypeSwitch<Operation *, NamedAttribute>(op)
+      .Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp,
+            AffineDmaWaitOp>(
+          [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
+}
+
+// Perform the replacement in `op`.
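+// The new access indices are computed as
+//   indexRemap(extraOperands, oldMap(oldMapOperands), symbolOperands)
+// with `extraIndices` prepended to the result, so that the final operand list
+// matches the rank of `newMemRef` (see the assertions below).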
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, + Operation *op, + ArrayRef<Value> extraIndices, + AffineMap indexRemap, + ArrayRef<Value> extraOperands, + ArrayRef<Value> symbolOperands) { + unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); + (void)newMemRefRank; // unused in opt mode + unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank(); + (void)oldMemRefRank; // unused in opt mode + if (indexRemap) { + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbolic operand count mismatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); + assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); + } else { + assert(oldMemRefRank + extraIndices.size() == newMemRefRank); + } + + // Assert same elemental type. + assert(oldMemRef->getType().cast<MemRefType>().getElementType() == + newMemRef->getType().cast<MemRefType>().getElementType()); + + if (!isMemRefDereferencingOp(*op)) + // Failure: memref used in a non-dereferencing context (potentially + // escapes); no replacement in these cases. + return failure(); + + SmallVector<unsigned, 2> usePositions; + for (const auto &opEntry : llvm::enumerate(op->getOperands())) { + if (opEntry.value() == oldMemRef) + usePositions.push_back(opEntry.index()); + } + + // If memref doesn't appear, nothing to do. + if (usePositions.empty()) + return success(); + + if (usePositions.size() > 1) { + // TODO(mlir-team): extend it for this case when needed (rare). + assert(false && "multiple dereferencing uses in a single op not supported"); + return failure(); + } + + unsigned memRefOperandPos = usePositions.front(); + + OpBuilder builder(op); + NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); + AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue(); + unsigned oldMapNumInputs = oldMap.getNumInputs(); + SmallVector<Value, 4> oldMapOperands( + op->operand_begin() + memRefOperandPos + 1, + op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); + + // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. + SmallVector<Value, 4> oldMemRefOperands; + SmallVector<Value, 4> affineApplyOps; + oldMemRefOperands.reserve(oldMemRefRank); + if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { + for (auto resultExpr : oldMap.getResults()) { + auto singleResMap = AffineMap::get(oldMap.getNumDims(), + oldMap.getNumSymbols(), resultExpr); + auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, + oldMapOperands); + oldMemRefOperands.push_back(afOp); + affineApplyOps.push_back(afOp); + } + } else { + oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); + } + + // Construct new indices as a remap of the old ones if a remapping has been + // provided. The indices of a memref come right after it, i.e., + // at position memRefOperandPos + 1. + SmallVector<Value, 4> remapOperands; + remapOperands.reserve(extraOperands.size() + oldMemRefRank + + symbolOperands.size()); + remapOperands.append(extraOperands.begin(), extraOperands.end()); + remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + remapOperands.append(symbolOperands.begin(), symbolOperands.end()); + + SmallVector<Value, 4> remapOutputs; + remapOutputs.reserve(oldMemRefRank); + + if (indexRemap && + indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { + // Remapped indices. 
+ for (auto resultExpr : indexRemap.getResults()) { + auto singleResMap = AffineMap::get( + indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); + auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, + remapOperands); + remapOutputs.push_back(afOp); + affineApplyOps.push_back(afOp); + } + } else { + // No remapping specified. + remapOutputs.append(remapOperands.begin(), remapOperands.end()); + } + + SmallVector<Value, 4> newMapOperands; + newMapOperands.reserve(newMemRefRank); + + // Prepend 'extraIndices' in 'newMapOperands'. + for (auto extraIndex : extraIndices) { + assert(extraIndex->getDefiningOp()->getNumResults() == 1 && + "single result op's expected to generate these indices"); + assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && + "invalid memory op index"); + newMapOperands.push_back(extraIndex); + } + + // Append 'remapOutputs' to 'newMapOperands'. + newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); + + // Create new fully composed AffineMap for new op to be created. + assert(newMapOperands.size() == newMemRefRank); + auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); + // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here. + fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); + newMap = simplifyAffineMap(newMap); + canonicalizeMapAndOperands(&newMap, &newMapOperands); + // Remove any affine.apply's that became dead as a result of composition. + for (auto value : affineApplyOps) + if (value->use_empty()) + value->getDefiningOp()->erase(); + + // Construct the new operation using this memref. + OperationState state(op->getLoc(), op->getName()); + state.setOperandListToResizable(op->hasResizableOperandsList()); + state.operands.reserve(op->getNumOperands() + extraIndices.size()); + // Insert the non-memref operands. + state.operands.append(op->operand_begin(), + op->operand_begin() + memRefOperandPos); + // Insert the new memref value. + state.operands.push_back(newMemRef); + + // Insert the new memref map operands. + state.operands.append(newMapOperands.begin(), newMapOperands.end()); + + // Insert the remaining operands unmodified. + state.operands.append(op->operand_begin() + memRefOperandPos + 1 + + oldMapNumInputs, + op->operand_end()); + + // Result types don't change. Both memref's are of the same elemental type. + state.types.reserve(op->getNumResults()); + for (auto result : op->getResults()) + state.types.push_back(result->getType()); + + // Add attribute for 'newMap', other Attributes do not change. + auto newMapAttr = AffineMapAttr::get(newMap); + for (auto namedAttr : op->getAttrs()) { + if (namedAttr.first == oldMapAttrPair.first) { + state.attributes.push_back({namedAttr.first, newMapAttr}); + } else { + state.attributes.push_back(namedAttr); + } + } + + // Create the new operation. 
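+  // At this point, `state.operands` holds, in order: the operands preceding
+  // the memref, `newMemRef` itself, `newMapOperands`, and the remaining
+  // original operands, while `state.attributes` carries the updated map
+  // attribute; everything else is unchanged.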
+  auto *repOp = builder.createOperation(state);
+  op->replaceAllUsesWith(repOp);
+  op->erase();
+
+  return success();
+}
+
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
+                                             ArrayRef<Value> extraIndices,
+                                             AffineMap indexRemap,
+                                             ArrayRef<Value> extraOperands,
+                                             ArrayRef<Value> symbolOperands,
+                                             Operation *domInstFilter,
+                                             Operation *postDomInstFilter) {
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+  (void)newMemRefRank; // unused in opt mode
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+  (void)oldMemRefRank;
+  if (indexRemap) {
+    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+           "symbol operand count mismatch");
+    assert(indexRemap.getNumInputs() ==
+           extraOperands.size() + oldMemRefRank + symbolOperands.size());
+    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+         newMemRef->getType().cast<MemRefType>().getElementType());
+
+  std::unique_ptr<DominanceInfo> domInfo;
+  std::unique_ptr<PostDominanceInfo> postDomInfo;
+  if (domInstFilter)
+    domInfo = std::make_unique<DominanceInfo>(
+        domInstFilter->getParentOfType<FuncOp>());
+
+  if (postDomInstFilter)
+    postDomInfo = std::make_unique<PostDominanceInfo>(
+        postDomInstFilter->getParentOfType<FuncOp>());
+
+  // Walk all uses of old memref; collect ops to perform replacement. We use a
+  // DenseSet since an operation could potentially have multiple uses of a
+  // memref (although rare), and the replacement later is going to erase ops.
+  DenseSet<Operation *> opsToReplace;
+  for (auto *op : oldMemRef->getUsers()) {
+    // Skip this use if it's not dominated by domInstFilter.
+    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
+      continue;
+
+    // Skip this use if it's not post-dominated by postDomInstFilter.
+    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
+      continue;
+
+    // Skip dealloc's - no replacement is necessary, and a memref replacement
+    // at other uses doesn't hurt these dealloc's.
+    if (isa<DeallocOp>(op))
+      continue;
+
+    // Check if the memref was used in a non-dereferencing context. It is fine
+    // for the memref to be used in a non-dereferencing way outside of the
+    // region where this replacement is happening.
+    if (!isMemRefDereferencingOp(*op))
+      // Failure: memref used in a non-dereferencing op (potentially escapes);
+      // no replacement in these cases.
+      return failure();
+
+    // We'll first collect and then replace --- since replacement erases the op
+    // that has the use, and that op could be postDomFilter or domFilter itself!
+    opsToReplace.insert(op);
+  }
+
+  for (auto *op : opsToReplace) {
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
+                                        indexRemap, extraOperands,
+                                        symbolOperands)))
+      llvm_unreachable("memref replacement guaranteed to succeed here");
+  }
+
+  return success();
+}
+
+/// Given an operation, inserts one or more single result affine
+/// apply operations, results of which are exclusively used by this operation.
+/// The operands of these newly created affine apply ops are
+/// guaranteed to be loop iterators or terminal symbols of a function.
+///
+/// Before
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   "compute"(%idx)
+///
+/// After
+///
+/// affine.for %i = 0 to #map(%N)
+///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "send"(%idx, %A, ...)
+///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
+///   "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (for eg.
+/// different shifts/delays).
+///
+/// Leaves `sliceOps` empty if none of opInst's operands were the result of an
+/// affine.apply (and thus there was no affine computation slice to create), or
+/// if all the affine.apply op's supplying operands to this opInst had no uses
+/// other than this opInst; otherwise, returns the affine.apply operations
+/// created in the output argument `sliceOps`.
+void mlir::createAffineComputationSlice(
+    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
+  // Collect all operands that are results of affine apply ops.
+  SmallVector<Value, 4> subOperands;
+  subOperands.reserve(opInst->getNumOperands());
+  for (auto operand : opInst->getOperands())
+    if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
+      subOperands.push_back(operand);
+
+  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+  SmallVector<Operation *, 4> affineApplyOps;
+  getReachableAffineApplyOps(subOperands, affineApplyOps);
+  // Skip transforming if there are no affine maps to compose.
+  if (affineApplyOps.empty())
+    return;
+
+  // Check if all uses of the affine apply op's lie only in this op, in
+  // which case there would be nothing to do.
+  bool localized = true;
+  for (auto *op : affineApplyOps) {
+    for (auto result : op->getResults()) {
+      for (auto *user : result->getUsers()) {
+        if (user != opInst) {
+          localized = false;
+          break;
+        }
+      }
+    }
+  }
+  if (localized)
+    return;
+
+  OpBuilder builder(opInst);
+  SmallVector<Value, 4> composedOpOperands(subOperands);
+  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
+  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
+
+  // Create an affine.apply for each of the map results.
+  sliceOps->reserve(composedMap.getNumResults());
+  for (auto resultExpr : composedMap.getResults()) {
+    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
+                                       composedMap.getNumSymbols(), resultExpr);
+    sliceOps->push_back(builder.create<AffineApplyOp>(
+        opInst->getLoc(), singleResMap, composedOpOperands));
+  }
+
+  // Construct the new operands that include the results from the composed
+  // affine apply op above instead of existing ones (subOperands). So, they
+  // differ from opInst's operands only for those operands in 'subOperands', for
+  // which they will be replaced by the corresponding one from 'sliceOps'.
+  SmallVector<Value, 4> newOperands(opInst->getOperands());
+  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+    // Replace the subOperands from among the new operands.
+    unsigned j, f;
+    for (j = 0, f = subOperands.size(); j < f; j++) {
+      if (newOperands[i] == subOperands[j])
+        break;
+    }
+    if (j < subOperands.size()) {
+      newOperands[i] = (*sliceOps)[j];
+    }
+  }
+  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
+    opInst->setOperand(idx, newOperands[idx]);
+  }
+}
+
+// TODO: Currently works for static memrefs with a single layout map.
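+// For example (an illustrative case): an alloc of
+// memref<64x128xf32, (d0, d1) -> (d1, d0)> is rewritten into an alloc of
+// memref<128x64xf32> with an identity layout, and each dereferencing use is
+// remapped through the old layout map, so a load at [%i, %j] becomes a load at
+// [%j, %i] on the new buffer.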
+LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { + MemRefType memrefType = allocOp.getType(); + unsigned rank = memrefType.getRank(); + if (rank == 0) + return success(); + + auto layoutMaps = memrefType.getAffineMaps(); + OpBuilder b(allocOp); + if (layoutMaps.size() != 1) + return failure(); + + AffineMap layoutMap = layoutMaps.front(); + + // Nothing to do for identity layout maps. + if (layoutMap == b.getMultiDimIdentityMap(rank)) + return success(); + + // We don't do any checks for one-to-one'ness; we assume that it is + // one-to-one. + + // TODO: Only for static memref's for now. + if (memrefType.getNumDynamicDims() > 0) + return failure(); + + // We have a single map that is not an identity map. Create a new memref with + // the right shape and an identity layout map. + auto shape = memrefType.getShape(); + FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands()); + for (unsigned d = 0; d < rank; ++d) { + fac.addConstantLowerBound(d, 0); + fac.addConstantUpperBound(d, shape[d] - 1); + } + + // We compose this map with the original index (logical) space to derive the + // upper bounds for the new index space. + unsigned newRank = layoutMap.getNumResults(); + if (failed(fac.composeMatchingMap(layoutMap))) + // TODO: semi-affine maps. + return failure(); + + // Project out the old data dimensions. + fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds()); + SmallVector<int64_t, 4> newShape(newRank); + for (unsigned d = 0; d < newRank; ++d) { + // The lower bound for the shape is always zero. + auto ubConst = fac.getConstantUpperBound(d); + // For a static memref and an affine map with no symbols, this is always + // bounded. + assert(ubConst.hasValue() && "should always have an upper bound"); + if (ubConst.getValue() < 0) + // This is due to an invalid map that maps to a negative space. + return failure(); + newShape[d] = ubConst.getValue() + 1; + } + + auto oldMemRef = allocOp.getResult(); + SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands()); + + auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(), + b.getMultiDimIdentityMap(newRank)); + auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType); + + // Replace all uses of the old memref. + if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, + /*extraIndices=*/{}, + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/symbolOperands))) { + // If it failed (due to escapes for example), bail out. + newAlloc.erase(); + return failure(); + } + // Replace any uses of the original alloc op and erase it. All remaining uses + // have to be dealloc's; RAMUW above would've failed otherwise. + assert(std::all_of(oldMemRef->user_begin(), oldMemRef->user_end(), + [](Operation *op) { return isa<DeallocOp>(op); })); + oldMemRef->replaceAllUsesWith(newAlloc); + allocOp.erase(); + return success(); +} diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp new file mode 100644 index 00000000000..6b2b3e1ee7e --- /dev/null +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -0,0 +1,1292 @@ +//===- Vectorize.cpp - Vectorize Pass Impl --------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements vectorization of loops, operations and data types to +// a target-independent, n-D super-vector abstraction. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/NestedMatcher.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/Utils.h" +#include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +/// +/// Implements a high-level vectorization strategy on a Function. +/// The abstraction used is that of super-vectors, which provide a single, +/// compact, representation in the vector types, information that is expected +/// to reduce the impact of the phase ordering problem +/// +/// Vector granularity: +/// =================== +/// This pass is designed to perform vectorization at a super-vector +/// granularity. A super-vector is loosely defined as a vector type that is a +/// multiple of a "good" vector size so the HW can efficiently implement a set +/// of high-level primitives. Multiple is understood along any dimension; e.g. +/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a +/// vector<8xf32> HW vector. Note that a "good vector size so the HW can +/// efficiently implement a set of high-level primitives" is not necessarily an +/// integer multiple of actual hardware registers. We leave details of this +/// distinction unspecified for now. +/// +/// Some may prefer the terminology a "tile of HW vectors". In this case, one +/// should note that super-vectors implement an "always full tile" abstraction. +/// They guarantee no partial-tile separation is necessary by relying on a +/// high-level copy-reshape abstraction that we call vector.transfer. This +/// copy-reshape operations is also responsible for performing layout +/// transposition if necessary. In the general case this will require a scoped +/// allocation in some notional local memory. +/// +/// Whatever the mental model one prefers to use for this abstraction, the key +/// point is that we burn into a single, compact, representation in the vector +/// types, information that is expected to reduce the impact of the phase +/// ordering problem. Indeed, a vector type conveys information that: +/// 1. the associated loops have dependency semantics that do not prevent +/// vectorization; +/// 2. the associate loops have been sliced in chunks of static sizes that are +/// compatible with vector sizes (i.e. similar to unroll-and-jam); +/// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by +/// the +/// vector type and no vectorization hampering transformations can be +/// applied to them anymore; +/// 4. 
+///    the underlying memrefs are accessed in some notional contiguous way
+///    that allows loading into vectors with some amount of spatial locality;
+/// In other words, super-vectorization provides a level of separation of
+/// concerns by way of opacity to subsequent passes. This has the effect of
+/// encapsulating and propagating vectorization constraints down the list of
+/// passes until we are ready to lower further.
+///
+/// For a particular target, a notion of minimal n-d vector size will be
+/// specified and vectorization targets a multiple of those. In the following
+/// paragraph, let "k ." represent "a multiple of", to be understood as a
+/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
+/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
+///
+/// Some non-exhaustive notable super-vector sizes of interest include:
+///   - CPU: vector<k . HW_vector_size>,
+///          vector<k' . core_count x k . HW_vector_size>,
+///          vector<socket_count x k' . core_count x k . HW_vector_size>;
+///   - GPU: vector<k . warp_size>,
+///          vector<k . warp_size x float2>,
+///          vector<k . warp_size x float4>,
+///          vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
+///
+/// Loops and operations are emitted that operate on those super-vector shapes.
+/// Subsequent lowering passes will materialize to actual HW vector sizes.
+/// These passes are expected to be (gradually) more target-specific.
+///
+/// At a high level, a vectorized load in a loop will resemble:
+/// ```mlir
+///   affine.for %i = ? to ? step ? {
+///     %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
+///   }
+/// ```
+/// It is the responsibility of the implementation of vector.transfer_read to
+/// materialize vector registers from the original scalar memrefs. A later
+/// (more target-dependent) lowering pass will materialize to actual HW vector
+/// sizes. This lowering may occur at different times:
+///   1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
+///      DmaWaitOp + vectorized operations for data transformations and
+///      shuffle; thus opening opportunities for unrolling and pipelining. This
+///      is an instance of library call "whiteboxing"; or
+///   2. later in a target-specific lowering pass or hand-written library call;
+///      achieving full separation of concerns. This is an instance of a
+///      library call; or
+///   3. a mix of both, e.g. based on a model.
+/// In the future, these operations will expose a contract to constrain the
+/// search on vectorization patterns and sizes.
+///
+/// Occurrence of super-vectorization in the compiler flow:
+/// =======================================================
+/// This is an active area of investigation. We start with 2 remarks to
+/// position super-vectorization in the context of existing ongoing work: LLVM
+/// VPLAN and LLVM SLP Vectorizer.
+///
+/// LLVM VPLAN:
+/// -----------
+/// The astute reader may have noticed that in the limit, super-vectorization
+/// can be applied at a similar time and with similar objectives as VPLAN.
+/// For instance, in the case of a traditional, polyhedral compilation flow
+/// (e.g. the PPCG project uses ISL to provide dependence analysis,
+/// multi-level (scheduling + tiling), lifting footprint to fast memory,
+/// communication synthesis, mapping, register optimizations) and before
+/// unrolling.
+/// When vectorization is applied at this *late* level in a typical polyhedral
+/// flow, and is instantiated with actual hardware vector sizes,
+/// super-vectorization is expected to match (or subsume) the type of patterns
+/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
+/// is higher level and our implementation should be significantly simpler.
+/// Also note that in this mode, recursive patterns are probably a bit of
+/// overkill, although it is reasonable to expect that mixing a bit of outer
+/// loop and inner loop vectorization + unrolling will provide interesting
+/// choices to MLIR.
+///
+/// LLVM SLP Vectorizer:
+/// --------------------
+/// Super-vectorization, however, is not meant to be usable in a similar
+/// fashion to the SLP vectorizer. The main difference lies in the information
+/// that both vectorizers use: super-vectorization examines contiguity of
+/// memory references along fastest varying dimensions and loops with recursive
+/// nested patterns capturing imperfectly-nested loop nests; the SLP
+/// vectorizer, on the other hand, performs flat pattern matching inside a
+/// single unrolled loop body and stitches together pieces of load and store
+/// operations into full 1-D vectors. We envision that the SLP vectorizer is a
+/// good way to capture innermost loop, control-flow dependent patterns that
+/// super-vectorization may not be able to capture easily. In other words,
+/// super-vectorization does not aim at replacing the SLP vectorizer and the
+/// two solutions are complementary.
+///
+/// Ongoing investigations:
+/// -----------------------
+/// We discuss the following *early* places where super-vectorization is
+/// applicable and touch on the expected benefits and risks. We list the
+/// opportunities in the context of the traditional polyhedral compiler flow
+/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
+/// we expect to experiment with super-vectorization:
+/// 1. Right after language lowering to MLIR: this is the earliest time where
+///    super-vectorization is expected to be applied. At this level, all the
+///    language/user/library-level annotations are available and can be fully
+///    exploited. Examples include loop-type annotations (such as parallel,
+///    reduction, scan, dependence distance vector, vectorizable) as well as
+///    memory access annotations (such as non-aliasing writes guaranteed, or
+///    indirect accesses that are permutations by construction), or the fact
+///    that a particular operation is prescribed atomic by the user. At this
+///    level, anything that enriches what dependence analysis can do should be
+///    aggressively exploited. At this level we are close to having explicit
+///    vector types in the language, except we do not impose that burden on the
+///    programmer/library: we derive information from scalar code +
+///    annotations.
+/// 2. After dependence analysis and before polyhedral scheduling: the
+///    information that supports vectorization does not need to be supplied by
+///    a higher level of abstraction. Traditional dependence analysis is
+///    available in MLIR and will be used to drive vectorization and cost
+///    models.
+///
+/// Let's pause here and remark that applying super-vectorization as described
+/// in 1. and 2. presents clear opportunities and risks:
+///   - the opportunity is that vectorization is burned in the type system and
+///     is protected from the adverse effects of loop scheduling, tiling, loop
+///     interchange and all passes downstream.
+///     Provided that subsequent passes are able to operate on vector types,
+///     the vector shapes, associated loop iterator properties, alignment, and
+///     contiguity of fastest varying dimensions are preserved until we lower
+///     the super-vector types. We expect this to significantly rein in the
+///     adverse effects of phase ordering.
+///   - the risks are that a. all passes after super-vectorization have to
+///     work on elemental vector types (not that this is always true wherever
+///     vectorization is applied) and b. that imposing vectorization
+///     constraints too early may be overall detrimental to loop fusion, tiling
+///     and other transformations because the dependence distances are
+///     coarsened when operating on elemental vector types. For this reason,
+///     the pattern profitability analysis should include a component that also
+///     captures the maximal amount of fusion available under a particular
+///     pattern. This is still at the stage of rough ideas but in this context,
+///     search is our friend, as the Tensor Comprehensions and auto-TVM
+///     contributions demonstrated previously.
+/// The bottom line is that we do not yet have good answers for the above but
+/// we aim at making it easy to answer such questions.
+///
+/// Back to our listing, the last places where early super-vectorization makes
+/// sense are:
+/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are
+///    known to improve locality, parallelism and be configurable (e.g.
+///    max-fuse, smart-fuse, etc.). They can also have adverse effects on
+///    contiguity properties that are required for vectorization, but the
+///    vector.transfer copy-reshape-pad-transpose abstraction is expected to
+///    help recapture these properties.
+/// 4. right after polyhedral-style scheduling+tiling;
+/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
+///    probably the most promising places because applying tiling achieves a
+///    separation of concerns that allows rescheduling to worry less about
+///    locality and more about parallelism and distribution (e.g. min-fuse).
+///
+/// At these levels the risk-reward looks different: on one hand we probably
+/// lost a good deal of language/user/library-level annotation; on the other
+/// hand we gained parallelism and locality through scheduling and tiling.
+/// However, we probably want to ensure tiling is compatible with the
+/// full-tile-only abstraction used in super-vectorization, or suffer the
+/// consequences. It is too early to place bets on what will win but we expect
+/// super-vectorization to be the right abstraction to allow exploring at all
+/// these levels. And again, search is our friend.
+///
+/// Lastly, we mention it again here:
+/// 6. as an MLIR-based alternative to VPLAN.
+///
+/// Lowering, unrolling, pipelining:
+/// ================================
+/// TODO(ntv): point to the proper places.
+///
+/// Algorithm:
+/// ==========
+/// The algorithm proceeds in a few steps:
+/// 1. defining super-vectorization patterns and matching them on the tree of
+///    AffineForOp. A super-vectorization pattern is defined as a recursive
+///    data structure that matches and captures nested, imperfectly-nested
+///    loops that have a. conformable loop annotations attached (e.g. parallel,
+///    reduction, vectorizable, ...) as well as b. all contiguous load/store
+///    operations along a specified minor dimension (not necessarily the
+///    fastest varying);
+/// 2. analyzing those patterns for profitability (TODO(ntv): and
+///    interference);
+/// 3. Then, for each pattern in order:
+///    a. applying iterative rewriting of the loop and the load operations in
+///       DFS postorder. Rewriting is implemented by coarsening the loops and
+///       turning load operations into opaque vector.transfer_read ops;
+///    b. keeping track of the load operations encountered as "roots" and the
+///       store operations as "terminals";
+///    c. traversing the use-def chains starting from the roots and iteratively
+///       propagating vectorized values. Scalar values that are encountered
+///       during this process must come from outside the scope of the current
+///       pattern (TODO(ntv): enforce this and generalize). Such a scalar value
+///       is vectorized only if it is a constant (into a vector splat). The
+///       non-constant case is not supported for now and results in the pattern
+///       failing to vectorize;
+///    d. performing a second traversal on the terminals (store ops) to rewrite
+///       the scalar value they write to memory into vector form. If the scalar
+///       value has been vectorized previously, we simply replace it by its
+///       vector form. Otherwise, if the scalar value is a constant, it is
+///       vectorized into a splat. In all other cases, vectorization for the
+///       pattern currently fails.
+///    e. if everything under the root AffineForOp in the current pattern
+///       vectorizes properly, we commit that loop to the IR. Otherwise we
+///       discard it and restore a previously cloned version of the loop.
+///       Thanks to the recursive scoping nature of matchers and captured
+///       patterns, this is transparently achieved by a simple RAII
+///       implementation.
+///    f. vectorization is applied on the next pattern in the list. Because
+///       pattern interference avoidance is not yet implemented and we do not
+///       support further vectorizing an already vectorized load, we need to
+///       re-verify that the pattern is still vectorizable. This is expected to
+///       make cost models more difficult to write and is subject to
+///       improvement in the future.
+///
+/// Points c. and d. above are worth additional comment. In most passes that
+/// do not change the type of operands, it is usually preferred to eagerly
+/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
+/// because during the use-def chain traversal, all the operands of an
+/// operation must be available in vector form. Trying to propagate eagerly
+/// makes the IR temporarily invalid and results in errors such as:
+///   `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
+///   operands and results
+///   %s5 = addf %a5, %b5 : f32`
+///
+/// Lastly, we show a minimal example for which use-def chains rooted in load /
+/// vector.transfer_read are not enough. This is what motivated splitting
+/// terminal processing out of the use-def chains starting from loads. In the
+/// following snippet, there is simply no load:
+/// ```mlir
+/// func @fill(%A : memref<128xf32>) -> () {
+///   %f1 = constant 1.0 : f32
+///   affine.for %i0 = 0 to 32 {
+///     affine.store %f1, %A[%i0] : memref<128xf32, 0>
+///   }
+///   return
+/// }
+/// ```
+///
+/// Choice of loop transformation to support the algorithm:
+/// =======================================================
+/// The choice of loop transformation to apply for coarsening vectorized loops
+/// is still subject to exploratory tradeoffs.
In particular, say we want to +/// vectorize by a factor 128, we want to transform the following input: +/// ```mlir +/// affine.for %i = %M to %N { +/// %a = affine.load %A[%i] : memref<?xf32> +/// } +/// ``` +/// +/// Traditionally, one would vectorize late (after scheduling, tiling, +/// memory promotion etc) say after stripmining (and potentially unrolling in +/// the case of LLVM's SLP vectorizer): +/// ```mlir +/// affine.for %i = floor(%M, 128) to ceil(%N, 128) { +/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { +/// %a = affine.load %A[%ii] : memref<?xf32> +/// } +/// } +/// ``` +/// +/// Instead, we seek to vectorize early and freeze vector types before +/// scheduling, so we want to generate a pattern that resembles: +/// ```mlir +/// affine.for %i = ? to ? step ? { +/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> +/// } +/// ``` +/// +/// i. simply dividing the lower / upper bounds by 128 creates issues +/// when representing expressions such as ii + 1 because now we only +/// have access to original values that have been divided. Additional +/// information is needed to specify accesses at below-128 granularity; +/// ii. another alternative is to coarsen the loop step but this may have +/// consequences on dependence analysis and fusability of loops: fusable +/// loops probably need to have the same step (because we don't want to +/// stripmine/unroll to enable fusion). +/// As a consequence, we choose to represent the coarsening using the loop +/// step for now and reevaluate in the future. Note that we can renormalize +/// loop steps later if/when we have evidence that they are problematic. +/// +/// For the simple strawman example above, vectorizing for a 1-D vector +/// abstraction of size 128 returns code similar to: +/// ```mlir +/// affine.for %i = %M to %N step 128 { +/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> +/// } +/// ``` +/// +/// Unsupported cases, extensions, and work in progress (help welcome :-) ): +/// ======================================================================== +/// 1. lowering to concrete vector types for various HW; +/// 2. reduction support; +/// 3. non-effecting padding during vector.transfer_read and filter during +/// vector.transfer_write; +/// 4. misalignment support vector.transfer_read / vector.transfer_write +/// (hopefully without read-modify-writes); +/// 5. control-flow support; +/// 6. cost-models, heuristics and search; +/// 7. Op implementation, extensions and implication on memref views; +/// 8. many TODOs left around. 
+/// +/// Examples: +/// ========= +/// Consider the following Function: +/// ```mlir +/// func @vector_add_2d(%M : index, %N : index) -> f32 { +/// %A = alloc (%M, %N) : memref<?x?xf32, 0> +/// %B = alloc (%M, %N) : memref<?x?xf32, 0> +/// %C = alloc (%M, %N) : memref<?x?xf32, 0> +/// %f1 = constant 1.0 : f32 +/// %f2 = constant 2.0 : f32 +/// affine.for %i0 = 0 to %M { +/// affine.for %i1 = 0 to %N { +/// // non-scoped %f1 +/// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0> +/// } +/// } +/// affine.for %i2 = 0 to %M { +/// affine.for %i3 = 0 to %N { +/// // non-scoped %f2 +/// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0> +/// } +/// } +/// affine.for %i4 = 0 to %M { +/// affine.for %i5 = 0 to %N { +/// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0> +/// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0> +/// %s5 = addf %a5, %b5 : f32 +/// // non-scoped %f1 +/// %s6 = addf %s5, %f1 : f32 +/// // non-scoped %f2 +/// %s7 = addf %s5, %f2 : f32 +/// // diamond dependency. +/// %s8 = addf %s7, %s6 : f32 +/// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0> +/// } +/// } +/// %c7 = constant 7 : index +/// %c42 = constant 42 : index +/// %res = load %C[%c7, %c42] : memref<?x?xf32, 0> +/// return %res : f32 +/// } +/// ``` +/// +/// The -affine-vectorize pass with the following arguments: +/// ``` +/// -affine-vectorize -virtual-vector-size 256 --test-fastest-varying=0 +/// ``` +/// +/// produces this standard innermost-loop vectorized code: +/// ```mlir +/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { +/// %0 = alloc(%arg0, %arg1) : memref<?x?xf32> +/// %1 = alloc(%arg0, %arg1) : memref<?x?xf32> +/// %2 = alloc(%arg0, %arg1) : memref<?x?xf32> +/// %cst = constant 1.0 : f32 +/// %cst_0 = constant 2.0 : f32 +/// affine.for %i0 = 0 to %arg0 { +/// affine.for %i1 = 0 to %arg1 step 256 { +/// %cst_1 = constant dense<vector<256xf32>, 1.0> : +/// vector<256xf32> +/// vector.transfer_write %cst_1, %0[%i0, %i1] : +/// vector<256xf32>, memref<?x?xf32> +/// } +/// } +/// affine.for %i2 = 0 to %arg0 { +/// affine.for %i3 = 0 to %arg1 step 256 { +/// %cst_2 = constant dense<vector<256xf32>, 2.0> : +/// vector<256xf32> +/// vector.transfer_write %cst_2, %1[%i2, %i3] : +/// vector<256xf32>, memref<?x?xf32> +/// } +/// } +/// affine.for %i4 = 0 to %arg0 { +/// affine.for %i5 = 0 to %arg1 step 256 { +/// %3 = vector.transfer_read %0[%i4, %i5] : +/// memref<?x?xf32>, vector<256xf32> +/// %4 = vector.transfer_read %1[%i4, %i5] : +/// memref<?x?xf32>, vector<256xf32> +/// %5 = addf %3, %4 : vector<256xf32> +/// %cst_3 = constant dense<vector<256xf32>, 1.0> : +/// vector<256xf32> +/// %6 = addf %5, %cst_3 : vector<256xf32> +/// %cst_4 = constant dense<vector<256xf32>, 2.0> : +/// vector<256xf32> +/// %7 = addf %5, %cst_4 : vector<256xf32> +/// %8 = addf %7, %6 : vector<256xf32> +/// vector.transfer_write %8, %2[%i4, %i5] : +/// vector<256xf32>, memref<?x?xf32> +/// } +/// } +/// %c7 = constant 7 : index +/// %c42 = constant 42 : index +/// %9 = load %2[%c7, %c42] : memref<?x?xf32> +/// return %9 : f32 +/// } +/// ``` +/// +/// The -affine-vectorize pass with the following arguments: +/// ``` +/// -affine-vectorize -virtual-vector-size 32 -virtual-vector-size 256 +/// --test-fastest-varying=1 --test-fastest-varying=0 +/// ``` +/// +/// produces this more interesting mixed outer-innermost-loop vectorized code: +/// ```mlir +/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { +/// %0 = alloc(%arg0, %arg1) : memref<?x?xf32> +/// %1 = alloc(%arg0, %arg1) : memref<?x?xf32> 
+/// %2 = alloc(%arg0, %arg1) : memref<?x?xf32> +/// %cst = constant 1.0 : f32 +/// %cst_0 = constant 2.0 : f32 +/// affine.for %i0 = 0 to %arg0 step 32 { +/// affine.for %i1 = 0 to %arg1 step 256 { +/// %cst_1 = constant dense<vector<32x256xf32>, 1.0> : +/// vector<32x256xf32> +/// vector.transfer_write %cst_1, %0[%i0, %i1] : +/// vector<32x256xf32>, memref<?x?xf32> +/// } +/// } +/// affine.for %i2 = 0 to %arg0 step 32 { +/// affine.for %i3 = 0 to %arg1 step 256 { +/// %cst_2 = constant dense<vector<32x256xf32>, 2.0> : +/// vector<32x256xf32> +/// vector.transfer_write %cst_2, %1[%i2, %i3] : +/// vector<32x256xf32>, memref<?x?xf32> +/// } +/// } +/// affine.for %i4 = 0 to %arg0 step 32 { +/// affine.for %i5 = 0 to %arg1 step 256 { +/// %3 = vector.transfer_read %0[%i4, %i5] : +/// memref<?x?xf32> vector<32x256xf32> +/// %4 = vector.transfer_read %1[%i4, %i5] : +/// memref<?x?xf32>, vector<32x256xf32> +/// %5 = addf %3, %4 : vector<32x256xf32> +/// %cst_3 = constant dense<vector<32x256xf32>, 1.0> : +/// vector<32x256xf32> +/// %6 = addf %5, %cst_3 : vector<32x256xf32> +/// %cst_4 = constant dense<vector<32x256xf32>, 2.0> : +/// vector<32x256xf32> +/// %7 = addf %5, %cst_4 : vector<32x256xf32> +/// %8 = addf %7, %6 : vector<32x256xf32> +/// vector.transfer_write %8, %2[%i4, %i5] : +/// vector<32x256xf32>, memref<?x?xf32> +/// } +/// } +/// %c7 = constant 7 : index +/// %c42 = constant 42 : index +/// %9 = load %2[%c7, %c42] : memref<?x?xf32> +/// return %9 : f32 +/// } +/// ``` +/// +/// Of course, much more intricate n-D imperfectly-nested patterns can be +/// vectorized too and specified in a fully declarative fashion. + +#define DEBUG_TYPE "early-vect" + +using functional::makePtrDynCaster; +using functional::map; +using llvm::dbgs; +using llvm::SetVector; + +static llvm::cl::OptionCategory clOptionsCategory("vectorize options"); + +static llvm::cl::list<int> clVirtualVectorSize( + "virtual-vector-size", + llvm::cl::desc("Specify an n-D virtual vector size for vectorization"), + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); + +static llvm::cl::list<int> clFastestVaryingPattern( + "test-fastest-varying", + llvm::cl::desc( + "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory" + " dimensions to match. See defaultPatterns in Vectorize.cpp for a" + " description and examples. This is used for testing purposes"), + llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory)); + +/// Forward declaration. +static FilterFunctionType +isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, + int fastestVaryingMemRefDimension); + +/// Creates a vectorization pattern from the command line arguments. +/// Up to 3-D patterns are supported. +/// If the command line argument requests a pattern of higher order, returns an +/// empty pattern list which will conservatively result in no vectorization. +static std::vector<NestedPattern> +makePatterns(const DenseSet<Operation *> ¶llelLoops, int vectorRank, + ArrayRef<int64_t> fastestVaryingPattern) { + using matcher::For; + int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0]; + int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1]; + int64_t d2 = fastestVaryingPattern.size() < 3 ? 
-1 : fastestVaryingPattern[2]; + switch (vectorRank) { + case 1: + return {For(isVectorizableLoopPtrFactory(parallelLoops, d0))}; + case 2: + return {For(isVectorizableLoopPtrFactory(parallelLoops, d0), + For(isVectorizableLoopPtrFactory(parallelLoops, d1)))}; + case 3: + return {For(isVectorizableLoopPtrFactory(parallelLoops, d0), + For(isVectorizableLoopPtrFactory(parallelLoops, d1), + For(isVectorizableLoopPtrFactory(parallelLoops, d2))))}; + default: { + return std::vector<NestedPattern>(); + } + } +} + +static NestedPattern &vectorTransferPattern() { + static auto pattern = matcher::Op([](Operation &op) { + return isa<vector::TransferReadOp>(op) || isa<vector::TransferWriteOp>(op); + }); + return pattern; +} + +namespace { + +/// Base state for the vectorize pass. +/// Command line arguments are preempted by non-empty pass arguments. +struct Vectorize : public FunctionPass<Vectorize> { + Vectorize(); + Vectorize(ArrayRef<int64_t> virtualVectorSize); + void runOnFunction() override; + + // The virtual vector size that we vectorize to. + SmallVector<int64_t, 4> vectorSizes; + // Optionally, the fixed mapping from loop to fastest varying MemRef dimension + // for all the MemRefs within a loop pattern: + // the index represents the loop depth, the value represents the k^th + // fastest varying memory dimension. + // This is voluntarily restrictive and is meant to precisely target a + // particular loop/op pair, for testing purposes. + SmallVector<int64_t, 4> fastestVaryingPattern; +}; + +} // end anonymous namespace + +Vectorize::Vectorize() + : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()), + fastestVaryingPattern(clFastestVaryingPattern.begin(), + clFastestVaryingPattern.end()) {} + +Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() { + if (!virtualVectorSize.empty()) { + this->vectorSizes.assign(virtualVectorSize.begin(), + virtualVectorSize.end()); + } +} + +/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. +///////// +namespace { + +struct VectorizationStrategy { + SmallVector<int64_t, 8> vectorSizes; + DenseMap<Operation *, unsigned> loopToVectorDim; +}; + +} // end anonymous namespace + +static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern, + unsigned patternDepth, + VectorizationStrategy *strategy) { + assert(patternDepth > depthInPattern && + "patternDepth is greater than depthInPattern"); + if (patternDepth - depthInPattern > strategy->vectorSizes.size()) { + // Don't vectorize this loop + return; + } + strategy->loopToVectorDim[loop] = + strategy->vectorSizes.size() - (patternDepth - depthInPattern); +} + +/// Implements a simple strawman strategy for vectorization. +/// Given a matched pattern `matches` of depth `patternDepth`, this strategy +/// greedily assigns the fastest varying dimension ** of the vector ** to the +/// innermost loop in the pattern. +/// When coupled with a pattern that looks for the fastest varying dimension in +/// load/store MemRefs, this creates a generic vectorization strategy that works +/// for any loop in a hierarchy (outermost, innermost or intermediate). +/// +/// TODO(ntv): In the future we should additionally increase the power of the +/// profitability analysis along 3 directions: +/// 1. account for loop extents (both static and parametric + annotations); +/// 2. account for data layout permutations; +/// 3. account for impact of vectorization on maximal loop fusion. 
+/// Then we can quantify the above to build a cost model and search over +/// strategies. +static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches, + unsigned depthInPattern, + unsigned patternDepth, + VectorizationStrategy *strategy) { + for (auto m : matches) { + if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, + patternDepth, strategy))) { + return failure(); + } + vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern, + patternDepth, strategy); + } + return success(); +} + +///// end TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate ///// + +namespace { + +struct VectorizationState { + /// Adds an entry of pre/post vectorization operations in the state. + void registerReplacement(Operation *key, Operation *value); + /// When the current vectorization pattern is successful, this erases the + /// operations that were marked for erasure in the proper order and resets + /// the internal state for the next pattern. + void finishVectorizationPattern(); + + // In-order tracking of original Operation that have been vectorized. + // Erase in reverse order. + SmallVector<Operation *, 16> toErase; + // Set of Operation that have been vectorized (the values in the + // vectorizationMap for hashed access). The vectorizedSet is used in + // particular to filter the operations that have already been vectorized by + // this pattern, when iterating over nested loops in this pattern. + DenseSet<Operation *> vectorizedSet; + // Map of old scalar Operation to new vectorized Operation. + DenseMap<Operation *, Operation *> vectorizationMap; + // Map of old scalar Value to new vectorized Value. + DenseMap<Value, Value> replacementMap; + // The strategy drives which loop to vectorize by which amount. + const VectorizationStrategy *strategy; + // Use-def roots. These represent the starting points for the worklist in the + // vectorizeNonTerminals function. They consist of the subset of load + // operations that have been vectorized. They can be retrieved from + // `vectorizationMap` but it is convenient to keep track of them in a separate + // data structure. + DenseSet<Operation *> roots; + // Terminal operations for the worklist in the vectorizeNonTerminals + // function. They consist of the subset of store operations that have been + // vectorized. They can be retrieved from `vectorizationMap` but it is + // convenient to keep track of them in a separate data structure. Since they + // do not necessarily belong to use-def chains starting from loads (e.g + // storing a constant), we need to handle them in a post-pass. + DenseSet<Operation *> terminals; + // Checks that the type of `op` is AffineStoreOp and adds it to the terminals + // set. + void registerTerminal(Operation *op); + // Folder used to factor out constant creation. 
+ OperationFolder *folder; + +private: + void registerReplacement(Value key, Value value); +}; + +} // end namespace + +void VectorizationState::registerReplacement(Operation *key, Operation *value) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); + LLVM_DEBUG(key->print(dbgs())); + LLVM_DEBUG(dbgs() << " into "); + LLVM_DEBUG(value->print(dbgs())); + assert(key->getNumResults() == 1 && "already registered"); + assert(value->getNumResults() == 1 && "already registered"); + assert(vectorizedSet.count(value) == 0 && "already registered"); + assert(vectorizationMap.count(key) == 0 && "already registered"); + toErase.push_back(key); + vectorizedSet.insert(value); + vectorizationMap.insert(std::make_pair(key, value)); + registerReplacement(key->getResult(0), value->getResult(0)); + if (isa<AffineLoadOp>(key)) { + assert(roots.count(key) == 0 && "root was already inserted previously"); + roots.insert(key); + } +} + +void VectorizationState::registerTerminal(Operation *op) { + assert(isa<AffineStoreOp>(op) && "terminal must be a AffineStoreOp"); + assert(terminals.count(op) == 0 && + "terminal was already inserted previously"); + terminals.insert(op); +} + +void VectorizationState::finishVectorizationPattern() { + while (!toErase.empty()) { + auto *op = toErase.pop_back_val(); + LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: "); + LLVM_DEBUG(op->print(dbgs())); + op->erase(); + } +} + +void VectorizationState::registerReplacement(Value key, Value value) { + assert(replacementMap.count(key) == 0 && "replacement already registered"); + replacementMap.insert(std::make_pair(key, value)); +} + +// Apply 'map' with 'mapOperands' returning resulting values in 'results'. +static void computeMemoryOpIndices(Operation *op, AffineMap map, + ValueRange mapOperands, + SmallVectorImpl<Value> &results) { + OpBuilder builder(op); + for (auto resultExpr : map.getResults()) { + auto singleResMap = + AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); + auto afOp = + builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands); + results.push_back(afOp); + } +} + +////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. //// + +/// Handles the vectorization of load and store MLIR operations. +/// +/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call. +/// They are vectorized immediately. The resulting vector.transfer_read is +/// immediately registered to replace all uses of the AffineLoadOp in this +/// pattern's scope. +/// +/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They +/// need to be vectorized late once all the use-def chains have been traversed. +/// Additionally, they may have ssa-values operands which come from outside the +/// scope of the current pattern. +/// Such special cases force us to delay the vectorization of the stores until +/// the last step. Here we merely register the store operation. +template <typename LoadOrStoreOpPointer> +static LogicalResult vectorizeRootOrTerminal(Value iv, + LoadOrStoreOpPointer memoryOp, + VectorizationState *state) { + auto memRefType = memoryOp.getMemRef()->getType().template cast<MemRefType>(); + + auto elementType = memRefType.getElementType(); + // TODO(ntv): ponder whether we want to further vectorize a vector value. 
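+  // For now, only scalar element types that are valid vector element types are
+  // handled; a value whose type is already a vector is not re-vectorized (the
+  // assert just below enforces this).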
+ assert(VectorType::isValidElementType(elementType) && + "Not a valid vector element type"); + auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); + + // Materialize a MemRef with 1 vector. + auto *opInst = memoryOp.getOperation(); + // For now, vector.transfers must be aligned, operate only on indices with an + // identity subset of AffineMap and do not change layout. + // TODO(ntv): increase the expressiveness power of vector.transfer operations + // as needed by various targets. + if (auto load = dyn_cast<AffineLoadOp>(opInst)) { + OpBuilder b(opInst); + ValueRange mapOperands = load.getMapOperands(); + SmallVector<Value, 8> indices; + indices.reserve(load.getMemRefType().getRank()); + if (load.getAffineMap() != + b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { + computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices); + } else { + indices.append(mapOperands.begin(), mapOperands.end()); + } + auto permutationMap = + makePermutationMap(opInst, indices, state->strategy->loopToVectorDim); + if (!permutationMap) + return LogicalResult::Failure; + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); + LLVM_DEBUG(permutationMap.print(dbgs())); + auto transfer = b.create<vector::TransferReadOp>( + opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices, + AffineMapAttr::get(permutationMap), + // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 + state->folder->create<ConstantFloatOp>(b, opInst->getLoc(), + APFloat(0.0f), b.getF32Type())); + state->registerReplacement(opInst, transfer.getOperation()); + } else { + state->registerTerminal(opInst); + } + return success(); +} +/// end TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. /// + +/// Coarsens the loops bounds and transforms all remaining load and store +/// operations into the appropriate vector.transfer. +static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, + VectorizationState *state) { + using namespace functional; + loop.setStep(step); + + FilterFunctionType notVectorizedThisPattern = [state](Operation &op) { + if (!matcher::isLoadOrStore(op)) { + return false; + } + return state->vectorizationMap.count(&op) == 0 && + state->vectorizedSet.count(&op) == 0 && + state->roots.count(&op) == 0 && state->terminals.count(&op) == 0; + }; + auto loadAndStores = matcher::Op(notVectorizedThisPattern); + SmallVector<NestedMatch, 8> loadAndStoresMatches; + loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); + for (auto ls : loadAndStoresMatches) { + auto *opInst = ls.getMatchedOperation(); + auto load = dyn_cast<AffineLoadOp>(opInst); + auto store = dyn_cast<AffineStoreOp>(opInst); + LLVM_DEBUG(opInst->print(dbgs())); + LogicalResult result = + load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) + : vectorizeRootOrTerminal(loop.getInductionVar(), store, state); + if (failed(result)) { + return failure(); + } + } + return success(); +} + +/// Returns a FilterFunctionType that can be used in NestedPattern to match a +/// loop whose underlying load/store accesses are either invariant or all +// varying along the `fastestVaryingMemRefDimension`. 
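+/// For instance (an illustrative restatement of the above): with a dimension
+/// of 0, only parallel loops whose accesses are invariant or vary along the
+/// fastest varying memref dimension are matched; a dimension of -1 matches any
+/// vectorizable parallel loop regardless of which memref dimension varies.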
+static FilterFunctionType +isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, + int fastestVaryingMemRefDimension) { + return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) { + auto loop = cast<AffineForOp>(forOp); + auto parallelIt = parallelLoops.find(loop); + if (parallelIt == parallelLoops.end()) + return false; + int memRefDim = -1; + auto vectorizableBody = + isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern()); + if (!vectorizableBody) + return false; + return memRefDim == -1 || fastestVaryingMemRefDimension == -1 || + memRefDim == fastestVaryingMemRefDimension; + }; +} + +/// Apply vectorization of `loop` according to `state`. This is only triggered +/// if all vectorizations in `childrenMatches` have already succeeded +/// recursively in DFS post-order. +static LogicalResult +vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, + VectorizationState *state) { + auto *loopInst = oneMatch.getMatchedOperation(); + auto loop = cast<AffineForOp>(loopInst); + auto childrenMatches = oneMatch.getMatchedChildren(); + + // 1. DFS postorder recursion, if any of my children fails, I fail too. + for (auto m : childrenMatches) { + if (failed(vectorizeLoopsAndLoadsRecursively(m, state))) { + return failure(); + } + } + + // 2. This loop may have been omitted from vectorization for various reasons + // (e.g. due to the performance model or pattern depth > vector size). + auto it = state->strategy->loopToVectorDim.find(loopInst); + if (it == state->strategy->loopToVectorDim.end()) { + return success(); + } + + // 3. Actual post-order transformation. + auto vectorDim = it->second; + assert(vectorDim < state->strategy->vectorSizes.size() && + "vector dim overflow"); + // a. get actual vector size + auto vectorSize = state->strategy->vectorSizes[vectorDim]; + // b. loop transformation for early vectorization is still subject to + // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.: + // | ub -> ub + // | step -> step * vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize + << " : "); + LLVM_DEBUG(loopInst->print(dbgs())); + return vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state); +} + +/// Tries to transform a scalar constant into a vector splat of that constant. +/// Returns the vectorized splat operation if the constant is a valid vector +/// element type. +/// If `type` is not a valid vector type or if the scalar constant is not a +/// valid vector element type, returns nullptr. +static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) { + if (!type || !type.isa<VectorType>() || + !VectorType::isValidElementType(constant.getType())) { + return nullptr; + } + OpBuilder b(op); + Location loc = op->getLoc(); + auto vectorType = type.cast<VectorType>(); + auto attr = DenseElementsAttr::get(vectorType, constant.getValue()); + auto *constantOpInst = constant.getOperation(); + + OperationState state(loc, constantOpInst->getName().getStringRef(), {}, + {vectorType}, {b.getNamedAttr("value", attr)}); + + return b.createOperation(state)->getResult(0); +} + +/// Tries to vectorize a given operand `op` of Operation `op` during +/// def-chain propagation or during terminal vectorization, by applying the +/// following logic: +/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized +/// useby -def propagation), `op` is already in the proper vector form; +/// 2. otherwise, the `op` may be in some other vector form that fails to +/// vectorize atm (i.e. 
broadcasting required), returns nullptr to indicate +/// failure; +/// 3. if the `op` is a constant, returns the vectorized form of the constant; +/// 4. non-constant scalars are currently non-vectorizable, in particular to +/// guard against vectorizing an index which may be loop-variant and needs +/// special handling. +/// +/// In particular this logic captures some of the use cases where definitions +/// that are not scoped under the current pattern are needed to vectorize. +/// One such example is top level function constants that need to be splatted. +/// +/// Returns an operand that has been vectorized to match `state`'s strategy if +/// vectorization is possible with the above logic. Returns nullptr otherwise. +/// +/// TODO(ntv): handle more complex cases. +static Value vectorizeOperand(Value operand, Operation *op, + VectorizationState *state) { + LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); + LLVM_DEBUG(operand->print(dbgs())); + // 1. If this value has already been vectorized this round, we are done. + if (state->vectorizedSet.count(operand->getDefiningOp()) > 0) { + LLVM_DEBUG(dbgs() << " -> already vector operand"); + return operand; + } + // 1.b. Delayed on-demand replacement of a use. + // Note that we cannot just call replaceAllUsesWith because it may result + // in ops with mixed types, for ops whose operands have not all yet + // been vectorized. This would be invalid IR. + auto it = state->replacementMap.find(operand); + if (it != state->replacementMap.end()) { + auto res = it->second; + LLVM_DEBUG(dbgs() << "-> delayed replacement by: "); + LLVM_DEBUG(res->print(dbgs())); + return res; + } + // 2. TODO(ntv): broadcast needed. + if (operand->getType().isa<VectorType>()) { + LLVM_DEBUG(dbgs() << "-> non-vectorizable"); + return nullptr; + } + // 3. vectorize constant. + if (auto constant = dyn_cast<ConstantOp>(operand->getDefiningOp())) { + return vectorizeConstant( + op, constant, + VectorType::get(state->strategy->vectorSizes, operand->getType())); + } + // 4. currently non-vectorizable. + LLVM_DEBUG(dbgs() << "-> non-vectorizable"); + LLVM_DEBUG(operand->print(dbgs())); + return nullptr; +} + +/// Encodes Operation-specific behavior for vectorization. In general we assume +/// that all operands of an op must be vectorized but this is not always true. +/// In the future, it would be nice to have a trait that describes how a +/// particular operation vectorizes. For now we implement the case distinction +/// here. +/// Returns a vectorized form of an operation or nullptr if vectorization fails. +// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. +// Maybe some Ops are not vectorizable or require some tricky logic, we cannot +// do one-off logic here; ideally it would be TableGen'd. +static Operation *vectorizeOneOperation(Operation *opInst, + VectorizationState *state) { + // Sanity checks. 
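+  // Only AffineStoreOp terminals and plain scalar operations are expected to
+  // reach this point: loads were already rewritten as roots and vector.transfer
+  // ops are already in vector form, as the asserts below check.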
+ assert(!isa<AffineLoadOp>(opInst) && + "all loads must have already been fully vectorized independently"); + assert(!isa<vector::TransferReadOp>(opInst) && + "vector.transfer_read cannot be further vectorized"); + assert(!isa<vector::TransferWriteOp>(opInst) && + "vector.transfer_write cannot be further vectorized"); + + if (auto store = dyn_cast<AffineStoreOp>(opInst)) { + OpBuilder b(opInst); + auto memRef = store.getMemRef(); + auto value = store.getValueToStore(); + auto vectorValue = vectorizeOperand(value, opInst, state); + + ValueRange mapOperands = store.getMapOperands(); + SmallVector<Value, 8> indices; + indices.reserve(store.getMemRefType().getRank()); + if (store.getAffineMap() != + b.getMultiDimIdentityMap(store.getMemRefType().getRank())) { + computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands, + indices); + } else { + indices.append(mapOperands.begin(), mapOperands.end()); + } + + auto permutationMap = + makePermutationMap(opInst, indices, state->strategy->loopToVectorDim); + if (!permutationMap) + return nullptr; + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); + LLVM_DEBUG(permutationMap.print(dbgs())); + auto transfer = b.create<vector::TransferWriteOp>( + opInst->getLoc(), vectorValue, memRef, indices, + AffineMapAttr::get(permutationMap)); + auto *res = transfer.getOperation(); + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); + // "Terminals" (i.e. AffineStoreOps) are erased on the spot. + opInst->erase(); + return res; + } + if (opInst->getNumRegions() != 0) + return nullptr; + + SmallVector<Type, 8> vectorTypes; + for (auto v : opInst->getResults()) { + vectorTypes.push_back( + VectorType::get(state->strategy->vectorSizes, v->getType())); + } + SmallVector<Value, 8> vectorOperands; + for (auto v : opInst->getOperands()) { + vectorOperands.push_back(vectorizeOperand(v, opInst, state)); + } + // Check whether a single operand is null. If so, vectorization failed. + bool success = llvm::all_of(vectorOperands, [](Value op) { return op; }); + if (!success) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize"); + return nullptr; + } + + // Create a clone of the op with the proper operands and return types. + // TODO(ntv): The following assumes there is always an op with a fixed + // name that works both in scalar mode and vector mode. + // TODO(ntv): Is it worth considering an Operation.clone operation which + // changes the type so we can promote an Operation with less boilerplate? + OpBuilder b(opInst); + OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(), + vectorOperands, vectorTypes, opInst->getAttrs(), + /*successors=*/{}, + /*regions=*/{}, opInst->hasResizableOperandsList()); + return b.createOperation(newOp); +} + +/// Iterates over the forward slice from the loads in the vectorization pattern +/// and rewrites them using their vectorized counterpart by: +/// 1. Create the forward slice starting from the loads in the vectorization +/// pattern. +/// 2. Topologically sorts the forward slice. +/// 3. For each operation in the slice, create the vector form of this +/// operation, replacing each operand by a replacement operands retrieved from +/// replacementMap. If any such replacement is missing, vectorization fails. +static LogicalResult vectorizeNonTerminals(VectorizationState *state) { + // 1. create initial worklist with the uses of the roots. 
+ SetVector<Operation *> worklist; + // Note: state->roots have already been vectorized and must not be vectorized + // again. This fits `getForwardSlice` which does not insert `op` in the + // result. + // Note: we have to exclude terminals because some of their defs may not be + // nested under the vectorization pattern (e.g. constants defined in an + // encompassing scope). + // TODO(ntv): Use a backward slice for terminals, avoid special casing and + // merge implementations. + for (auto *op : state->roots) { + getForwardSlice(op, &worklist, [state](Operation *op) { + return state->terminals.count(op) == 0; // propagate if not terminal + }); + } + // We merged multiple slices, topological order may not hold anymore. + worklist = topologicalSort(worklist); + + for (unsigned i = 0; i < worklist.size(); ++i) { + auto *op = worklist[i]; + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: "); + LLVM_DEBUG(op->print(dbgs())); + + // Create vector form of the operation. + // Insert it just before op, on success register op as replaced. + auto *vectorizedInst = vectorizeOneOperation(op, state); + if (!vectorizedInst) { + return failure(); + } + + // 3. Register replacement for future uses in the scope. + // Note that we cannot just call replaceAllUsesWith because it may + // result in ops with mixed types, for ops whose operands have not all + // yet been vectorized. This would be invalid IR. + state->registerReplacement(op, vectorizedInst); + } + return success(); +} + +/// Vectorization is a recursive procedure where anything below can fail. +/// The root match thus needs to maintain a clone for handling failure. +/// Each root may succeed independently but will otherwise clean after itself if +/// anything below it fails. +static LogicalResult vectorizeRootMatch(NestedMatch m, + VectorizationStrategy *strategy) { + auto loop = cast<AffineForOp>(m.getMatchedOperation()); + OperationFolder folder(loop.getContext()); + VectorizationState state; + state.strategy = strategy; + state.folder = &folder; + + // Since patterns are recursive, they can very well intersect. + // Since we do not want a fully greedy strategy in general, we decouple + // pattern matching, from profitability analysis, from application. + // As a consequence we must check that each root pattern is still + // vectorizable. If a pattern is not vectorizable anymore, we just skip it. + // TODO(ntv): implement a non-greedy profitability analysis that keeps only + // non-intersecting patterns. + if (!isVectorizableLoopBody(loop, vectorTransferPattern())) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); + return failure(); + } + + /// Sets up error handling for this root loop. This is how the root match + /// maintains a clone for handling failure and restores the proper state via + /// RAII. + auto *loopInst = loop.getOperation(); + OpBuilder builder(loopInst); + auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst)); + struct Guard { + LogicalResult failure() { + loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar()); + loop.erase(); + return mlir::failure(); + } + LogicalResult success() { + clonedLoop.erase(); + return mlir::success(); + } + AffineForOp loop; + AffineForOp clonedLoop; + } guard{loop, clonedLoop}; + + ////////////////////////////////////////////////////////////////////////////// + // Start vectorizing. + // From now on, any error triggers the scope guard above. + ////////////////////////////////////////////////////////////////////////////// + // 1. 
Vectorize all the loops matched by the pattern, recursively. + // This also vectorizes the roots (AffineLoadOp) as well as registers the + // terminals (AffineStoreOp) for post-processing vectorization (we need to + // wait for all use-def chains into them to be vectorized first). + if (failed(vectorizeLoopsAndLoadsRecursively(m, &state))) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop"); + return guard.failure(); + } + + // 2. Vectorize operations reached by use-def chains from root except the + // terminals (store operations) that need to be post-processed separately. + // TODO(ntv): add more as we expand. + if (failed(vectorizeNonTerminals(&state))) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals"); + return guard.failure(); + } + + // 3. Post-process terminals. + // Note: we have to post-process terminals because some of their defs may not + // be nested under the vectorization pattern (e.g. constants defined in an + // encompassing scope). + // TODO(ntv): Use a backward slice for terminals, avoid special casing and + // merge implementations. + for (auto *op : state.terminals) { + if (!vectorizeOneOperation(op, &state)) { // nullptr == failure + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals"); + return guard.failure(); + } + } + + // 4. Finish this vectorization pattern. + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); + state.finishVectorizationPattern(); + return guard.success(); +} + +/// Applies vectorization to the current Function by searching over a bunch of +/// predetermined patterns. +void Vectorize::runOnFunction() { + FuncOp f = getFunction(); + if (!fastestVaryingPattern.empty() && + fastestVaryingPattern.size() != vectorSizes.size()) { + f.emitRemark("Fastest varying pattern specified with different size than " + "the vector size."); + return signalPassFailure(); + } + + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + + DenseSet<Operation *> parallelLoops; + f.walk([¶llelLoops](AffineForOp loop) { + if (isLoopParallel(loop)) + parallelLoops.insert(loop); + }); + + for (auto &pat : + makePatterns(parallelLoops, vectorSizes.size(), fastestVaryingPattern)) { + LLVM_DEBUG(dbgs() << "\n******************************************"); + LLVM_DEBUG(dbgs() << "\n******************************************"); + LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); + LLVM_DEBUG(f.print(dbgs())); + unsigned patternDepth = pat.getDepth(); + + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); + // Iterate over all the top-level matches and vectorize eagerly. + // This automatically prunes intersecting matches. + for (auto m : matches) { + VectorizationStrategy strategy; + // TODO(ntv): depending on profitability, elect to reduce the vector size. + strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end()); + if (failed(analyzeProfitability(m.getMatchedChildren(), 1, patternDepth, + &strategy))) { + continue; + } + vectorizeLoopIfProfitable(m.getMatchedOperation(), 0, patternDepth, + &strategy); + // TODO(ntv): if pattern does not apply, report it; alter the + // cost/benefit. + vectorizeRootMatch(m, &strategy); + // TODO(ntv): some diagnostics if failure to vectorize occurs. 
+ } + } + LLVM_DEBUG(dbgs() << "\n"); +} + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createVectorizePass(ArrayRef<int64_t> virtualVectorSize) { + return std::make_unique<Vectorize>(virtualVectorSize); +} + +static PassRegistration<Vectorize> + pass("affine-vectorize", + "Vectorize to a target independent n-D vector abstraction"); diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp new file mode 100644 index 00000000000..508c547a52b --- /dev/null +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -0,0 +1,170 @@ +//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/ViewOpGraph.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/Support/CommandLine.h" + +static llvm::cl::opt<int> elideIfLarger( + "print-op-graph-elide-if-larger", + llvm::cl::desc("Upper limit to emit elements attribute rather than elide"), + llvm::cl::init(16)); + +using namespace mlir; + +namespace llvm { + +// Specialize GraphTraits to treat Block as a graph of Operations as nodes and +// uses as edges. +template <> struct GraphTraits<Block *> { + using GraphType = Block *; + using NodeRef = Operation *; + + using ChildIteratorType = UseIterator; + static ChildIteratorType child_begin(NodeRef n) { + return ChildIteratorType(n); + } + static ChildIteratorType child_end(NodeRef n) { + return ChildIteratorType(n, /*end=*/true); + } + + // Operation's destructor is private so use Operation* instead and use + // mapped iterator. + static Operation *AddressOf(Operation &op) { return &op; } + using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>; + static nodes_iterator nodes_begin(Block *b) { + return nodes_iterator(b->begin(), &AddressOf); + } + static nodes_iterator nodes_end(Block *b) { + return nodes_iterator(b->end(), &AddressOf); + } +}; + +// Specialize DOTGraphTraits to produce more readable output. +template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits { + using DefaultDOTGraphTraits::DefaultDOTGraphTraits; + static std::string getNodeLabel(Operation *op, Block *); +}; + +std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) { + // Reuse the print output for the node labels. + std::string ostr; + raw_string_ostream os(ostr); + os << op->getName() << "\n"; + + if (!op->getLoc().isa<UnknownLoc>()) { + os << op->getLoc() << "\n"; + } + + // Print resultant types + interleaveComma(op->getResultTypes(), os); + os << "\n"; + + for (auto attr : op->getAttrs()) { + os << '\n' << attr.first << ": "; + // Always emit splat attributes. + if (attr.second.isa<SplatElementsAttr>()) { + attr.second.print(os); + continue; + } + + // Elide "big" elements attributes. + auto elements = attr.second.dyn_cast<ElementsAttr>(); + if (elements && elements.getNumElements() > elideIfLarger) { + os << std::string(elements.getType().getRank(), '[') << "..." 
+         << std::string(elements.getType().getRank(), ']') << " : "
+         << elements.getType();
+      continue;
+    }
+
+    auto array = attr.second.dyn_cast<ArrayAttr>();
+    if (array && static_cast<int64_t>(array.size()) > elideIfLarger) {
+      os << "[...]";
+      continue;
+    }
+
+    // Print all other attributes.
+    attr.second.print(os);
+  }
+  return os.str();
+}
+
+} // end namespace llvm
+
+namespace {
+// PrintOpPass is a simple pass that writes a graph per function.
+// Note: this is a module pass only to avoid interleaving on the same ostream
+// due to multi-threading over functions.
+struct PrintOpPass : public ModulePass<PrintOpPass> {
+  explicit PrintOpPass(raw_ostream &os = llvm::errs(),
+                       bool short_names = false, const Twine &title = "")
+      : os(os), title(title.str()), short_names(short_names) {}
+
+  std::string getOpName(Operation &op) {
+    auto symbolAttr =
+        op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+    if (symbolAttr)
+      return symbolAttr.getValue();
+    ++unnamedOpCtr;
+    return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+  }
+
+  // Print all the ops in a module.
+  void processModule(ModuleOp module) {
+    for (Operation &op : module) {
+      // Modules may actually be nested, recurse on nesting.
+      if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
+        processModule(nestedModule);
+        continue;
+      }
+      auto opName = getOpName(op);
+      for (Region &region : op.getRegions()) {
+        for (auto indexed_block : llvm::enumerate(region)) {
+          // Suffix the block number if there is more than one block.
+          auto blockName = region.getBlocks().size() == 1
+                               ? ""
+                               : ("__" + llvm::utostr(indexed_block.index()));
+          llvm::WriteGraph(os, &indexed_block.value(), short_names,
+                           Twine(title) + opName + blockName);
+        }
+      }
+    }
+  }
+
+  void runOnModule() override { processModule(getModule()); }
+
+private:
+  raw_ostream &os;
+  std::string title;
+  int unnamedOpCtr = 0;
+  bool short_names;
+};
+} // namespace
+
+void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
+                     const Twine &title, llvm::GraphProgram::Name program) {
+  llvm::ViewGraph(&block, name, shortNames, title, program);
+}
+
+raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
+                              const Twine &title) {
+  return llvm::WriteGraph(os, &block, shortNames, title);
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
+                             const Twine &title) {
+  return std::make_unique<PrintOpPass>(os, shortNames, title);
+}
+
+static PassRegistration<PrintOpPass> pass("print-op-graph",
+                                          "Print op graph per region");
diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp
new file mode 100644
index 00000000000..77111087d07
--- /dev/null
+++ b/mlir/lib/Transforms/ViewRegionGraph.cpp
@@ -0,0 +1,85 @@
+//===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ViewRegionGraph.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace llvm {
+
+// Specialize DOTGraphTraits to produce more readable output.
+template <> struct DOTGraphTraits<Region *> : public DefaultDOTGraphTraits {
+  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
+
+  static std::string getNodeLabel(Block *Block, Region *);
+};
+
+std::string DOTGraphTraits<Region *>::getNodeLabel(Block *Block, Region *) {
+  // Reuse the print output for the node labels.
+  std::string outStreamStr;
+  raw_string_ostream os(outStreamStr);
+  Block->print(os);
+  std::string &outStr = os.str();
+
+  if (outStr[0] == '\n')
+    outStr.erase(outStr.begin());
+
+  // Process string output to left-justify the block.
+  for (unsigned i = 0; i != outStr.length(); ++i) {
+    if (outStr[i] == '\n') {
+      outStr[i] = '\\';
+      outStr.insert(outStr.begin() + i + 1, 'l');
+    }
+  }
+
+  return outStr;
+}
+
+} // end namespace llvm
+
+void mlir::viewGraph(Region &region, const Twine &name, bool shortNames,
+                     const Twine &title, llvm::GraphProgram::Name program) {
+  llvm::ViewGraph(&region, name, shortNames, title, program);
+}
+
+raw_ostream &mlir::writeGraph(raw_ostream &os, Region &region, bool shortNames,
+                              const Twine &title) {
+  return llvm::WriteGraph(os, &region, shortNames, title);
+}
+
+void mlir::Region::viewGraph(const Twine &regionName) {
+  ::mlir::viewGraph(*this, regionName);
+}
+void mlir::Region::viewGraph() { viewGraph("region"); }
+
+namespace {
+struct PrintCFGPass : public FunctionPass<PrintCFGPass> {
+  PrintCFGPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
+               const Twine &title = "")
+      : os(os), shortNames(shortNames), title(title.str()) {}
+  void runOnFunction() override {
+    mlir::writeGraph(os, getFunction().getBody(), shortNames, title);
+  }
+
+private:
+  raw_ostream &os;
+  bool shortNames;
+  std::string title;
+};
+} // namespace
+
+std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
+mlir::createPrintCFGGraphPass(raw_ostream &os, bool shortNames,
+                              const Twine &title) {
+  return std::make_unique<PrintCFGPass>(os, shortNames, title);
+}
+
+static PassRegistration<PrintCFGPass> pass("print-cfg-graph",
+                                           "Print CFG graph per Function");
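For reference, the graph writers defined above can also be driven directly from C++ rather than only through the registered passes. A minimal sketch, assuming the FuncOp API of this snapshot; the helper name dumpFunctionGraphs and the choice of llvm::errs() are illustrative, not part of this commit:

#include "mlir/IR/Function.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "mlir/Transforms/ViewRegionGraph.h"
#include "llvm/Support/raw_ostream.h"

// Sketch: emit Graphviz DOT for one function using the entry points above.
static void dumpFunctionGraphs(mlir::FuncOp f) {
  // CFG of the function body (Region overload from ViewRegionGraph.cpp).
  mlir::writeGraph(llvm::errs(), f.getBody(), /*shortNames=*/false,
                   /*title=*/f.getName());
  // Op/use-def graph of the entry block (Block overload from ViewOpGraph.cpp).
  if (!f.getBody().empty())
    mlir::writeGraph(llvm::errs(), f.getBody().front(), /*shortNames=*/false,
                     /*title=*/f.getName());
}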
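The factory functions above (createVectorizePass, createPrintOpGraphPass, createPrintCFGGraphPass) can likewise be assembled into a pipeline programmatically. A minimal sketch, assuming the PassManager and nest<> API at this snapshot; the module argument, the vector size {8}, the output stream, and the titles are placeholders only:

#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "mlir/Transforms/ViewRegionGraph.h"
#include "llvm/Support/raw_ostream.h"

// Sketch: print the op graph of the module, vectorize each function, then
// print each function's CFG.
static mlir::LogicalResult runExamplePipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Module-level op graph printer defined in ViewOpGraph.cpp.
  pm.addPass(mlir::createPrintOpGraphPass(llvm::errs(), /*shortNames=*/false,
                                          /*title=*/"ops"));
  // Function-level passes are nested under the module pipeline.
  mlir::OpPassManager &funcPM = pm.nest<mlir::FuncOp>();
  funcPM.addPass(mlir::createVectorizePass(/*virtualVectorSize=*/{8}));
  funcPM.addPass(mlir::createPrintCFGGraphPass(llvm::errs(),
                                               /*shortNames=*/false,
                                               /*title=*/"cfg"));
  return pm.run(module);
}

The same passes are exposed as mlir-opt flags through the PassRegistration entries above: -affine-vectorize, -print-op-graph, and -print-cfg-graph.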