//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements the linalg dialect Fusion pass. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "linalg-fusion" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using llvm::dbgs; /// Implements a simple high-level fusion pass of linalg library operations. /// /// In each block, linalg ops are processed in reverse textual order. /// Given a linalg op, fusion occurs by: /// 1. tiling the op by a given multi-dimensional tile size; /// 2. inspecting the linalg ops that write into the views read by the op in /// step 1. This uses the SSA value of the views to determine producer- /// consumer dependences: only identical SSA views are considered for /// fusion at this point; /// 3. greedily fuse the producing linalg ops into the consuming loop tiles; /// 4. inspect the fused ops and determine whether they have other remaining /// LinalgOp uses. If not, then erase the original producing linalg op. /// /// More advanced use cases, analyses as well as profitability heuristics are /// left for future work. static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); static llvm::cl::list clTileSizes( "linalg-fusion-tile-sizes", llvm::cl::desc( "Tile sizes by which to tile linalg operations during linalg fusion"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)); // Return a cloned version of `op` that operates on `loopRanges`, assumed to be // a subset of the original loop ranges of `op`. // This is achieved by applying the `loopToOperandRangesMaps` permutation maps // to the `loopRanges` in order to obtain view ranges. static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges, OperationFolder &state) { ScopedContext scope(b, loc); auto maps = loopToOperandRangesMaps(op); SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "map: " << map << "\n"); Value *view = en.value(); SmallVector viewRanges(map.getNumResults()); for (auto en2 : llvm::enumerate(map.getResults())) { unsigned d = en2.index(); // loopToOperandRangesMaps are permutations-only. unsigned loopPos = en2.value().cast().getPosition(); viewRanges[d] = loopRanges[loopPos]; LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() << "\t" << "loopPos: " << loopPos << "\t" << viewRanges[d]); } // TODO(ntv) opportunities for folding/CSE here rather than build new IR. clonedViews.push_back(b.create(loc, view, viewRanges)); } auto operands = getAssumedNonViewOperands(op); clonedViews.append(operands.begin(), operands.end()); return op.clone(b, loc, clonedViews); } struct ViewDimension { Value *view; unsigned dimension; }; static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { auto maps = loopToOperandRangesMaps(op); SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. SmallVector ios(op.getInputsAndOutputs()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); Value *view = en.value(); SmallVector viewRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { if (loopDepth == en2.value().cast().getPosition()) { LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth << "\n"); LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view << "\n"); return ViewDimension{view, static_cast(en2.index())}; } } } llvm_unreachable("Expect to be able to extract a view defining loop range"); } static Optional fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, LinalgOp tiledConsumer, OperationFolder &state) { auto maybeConsumerIdx = consumer.getIndexOfInput(producedView); if (!maybeConsumerIdx.hasValue()) return llvm::None; unsigned consumerIdx = maybeConsumerIdx.getValue(); auto maybeProducerIdx = producer.getIndexOfOutput(producedView); if (!maybeProducerIdx.hasValue()) return llvm::None; unsigned producerIdx = maybeProducerIdx.getValue(); // If the view is the same between consumer and tiledConsumer, this means we // don't have loops and the producer cannot be fused at this level. if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx)) return llvm::None; auto tiledConsumerSubView = dyn_cast_or_null( tiledConsumer.getInput(consumerIdx)->getDefiningOp()); // If we don't have a slice, this also means we don't have loops and the // producer cannot be fused at this level. if (!tiledConsumerSubView) return llvm::None; // loopToOperandRangesMaps are permutations-only by construction: // we can always identify a data dimension with a (at least one) loop // dimension. AffineMap producerMap = loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: " << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n"); LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx << ", producer map: " << producerMap << "\n"); unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); // Iterate over dimensions identified by the producer map for `producerIdx`. // This defines a subset of the loop ranges that we need to complete later. for (auto en : llvm::enumerate(producerMap.getResults())) { unsigned posInProducerLoop = en.value().cast().getPosition(); loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index()); } OpBuilder b(tiledConsumer.getOperation()); auto loc = tiledConsumer.getLoc(); // Iterate over all dimensions. For the dimensions not identified by the // producer map for `producerIdx`, we need to explicitly compute the view that // defines the loop ranges using the `producer`. for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { if (loopRanges[i].min) LLVM_DEBUG(llvm::dbgs() << "existing LoopRange: " << loopRanges[i] << "\n"); else { auto viewDim = getViewDefiningLoopRange(producer, i); loopRanges[i] = SubViewOp::Range{ state.create(b, loc, 0), linalg::intrinsics::dim(viewDim.view, viewDim.dimension), state.create(b, loc, 1)}; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } } return cloneWithLoopRanges(b, loc, producer, loopRanges, state); } // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, LinalgOp consumer) { // If a producer has multiple outputs, the analysis needs to take the tiling // of other outputs into account. if (producer.getNumOutputs() != 1) return false; // Until subview analysis is available, same SSA value is required for fusion. if (producer.getOutput(0) != readView) return false; // No control-flow divergence supported. Only straightline op fusion allowed. // TODO(ntv) allow fusion when a dominance relation exists. if (producer.getOperation()->getBlock() != consumer.getOperation()->getBlock()) return false; return true; } static void fuseLinalgOps(FuncOp f, ArrayRef tileSizes) { OperationFolder state(f.getContext()); DenseSet eraseSet; LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); // 1. Record the linalg ops so we can traverse them in reverse order. SmallVector linalgOps; f.walk([&](LinalgOp op) { linalgOps.push_back(op.getOperation()); }); // 2. Setup the dependences graph, aliases are populated lazily. Aliases aliases; LinalgDependenceGraph G(aliases, linalgOps); // 2. For each original linalg op (in reverse order to allow chained // fusions). for (auto *op : llvm::reverse(linalgOps)) { auto consumer = cast(op); LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op); // 3. If marked for erasure, it has already been fused. Skip fusing op. if (eraseSet.count(op) > 0) { LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip."); continue; } // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op. auto tiledOp = tileLinalgOp(op, tileSizes, state); if (!tiledOp) { LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip."); continue; } // 5. For now, we only fuse RAW dependences. SmallVector fusedProducers; SmallVector fusedViews; for (auto dependence : G.getDependencesInto( consumer, LinalgDependenceGraph::DependenceType::RAW)) { auto producer = cast(dependence.dependentOpView.op); LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" << *producer.getOperation() << "\n"); // a. For now we require fusion on identical SSA values, this allows us to // not worry about partial writes etc. // TODO(ntv) support more elaborate fusion with non identical SSA values. auto *view = dependence.indexingView; if (view != dependence.dependentOpView.view) { LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip."); continue; } // b. Make some simple structural checks that alleviate the need for more // complex analyses. if (!isStructurallyFusableProducer(producer, view, op)) { LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); continue; } // c. Check for fusion-preventing write that would violate dependences. // `view` is a producer write that cannot bypass any other write or read. bool preventFusion = false; for (auto *op : G.findCoveringDependences(producer, consumer)) if (eraseSet.count(op) == 0) { preventFusion = true; LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op); break; } if (preventFusion) continue; // 6. Try to fuse `producer` just before `tiledOp`. LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n")); auto tOp = tiledOp->op; OpBuilder builder(tOp.getOperation()); ScopedContext scope(builder, tOp.getLoc()); LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n"); auto maybeFusedProducer = fuse(view, producer, op, tOp, state); if (!maybeFusedProducer) { LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip."); continue; } fusedProducers.push_back(producer.getOperation()); fusedViews.push_back(view); } // 7. If no fusion occurred, or a drop the outer tiled loop which undoes // everything we did. if (fusedProducers.empty()) { tiledOp->loops[0].erase(); continue; } eraseSet.insert(op); eraseSet.insert(fusedProducers.begin(), fusedProducers.end()); } for (auto *op : eraseSet) op->erase(); LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); } namespace { struct LinalgFusionPass : public FunctionPass { LinalgFusionPass() = default; LinalgFusionPass(ArrayRef sizes); void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); } SmallVector tileSizes; }; } // namespace LinalgFusionPass::LinalgFusionPass(ArrayRef sizes) : LinalgFusionPass() { if (!sizes.empty()) this->tileSizes.assign(sizes.begin(), sizes.end()); } std::unique_ptr> mlir::linalg::createLinalgFusionPass(ArrayRef tileSizes) { return std::make_unique(tileSizes); } static PassRegistration pass("linalg-fusion", "Fuse operations in the linalg dialect", [] { auto pass = std::make_unique(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); return pass; });