//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for transforming Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

using llvm::dbgs;
using llvm::SetVector;

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
    StringRef linalgMarker, ArrayRef<unsigned> permutation) {
  assert(permutation.empty() || permutation.size() == sizes.size());
  auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
  if (!tileRes)
    return failure();
  tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getStringAttr(linalgMarker));
  return success();
}

LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
  auto tileRes = tileLinalgOperation(rewriter, op, sizes);
  if (!tileRes)
    return failure();
  tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getStringAttr(linalgMarker));

  Aliases aliases;
  auto G = LinalgDependenceGraph::buildDependenceGraph(
      aliases, op->getParentOfType<FuncOp>());
  SmallVector<Operation *, 8> originalProducers;
  for (auto operandIdx : operandIndicesToFuse) {
    auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G);
    if (!fusionRes) {
      // Linalg fusion requires tiled loops to even determine whether it is
      // possible to fuse. As a consequence, the pattern may fail even though a
      // tiled version of op has already been introduced.
      // So we need to remove the tiled version ourselves in case of failure.
      // Another possibility is to ensure the constraints on the pattern
      // guarantee that fusion will occur and just assert here. As we develop
      // more complex patterns we can choose what is best.
      rewriter.eraseOp(tileRes->loops[0]);
      return failure();
    }
    fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker,
                                     rewriter.getStringAttr(linalgMarker));
    originalProducers.push_back(fusionRes->originalProducer);
  }

  // The originalProducers can now be safely erased. This is similar to
  // SSA-value use-def but in the world of buffer + structured ops.
  for (auto *originalProducer : originalProducers)
    rewriter.eraseOp(originalProducer);
  return success();
}
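
// Walks the RAW dependences flowing into `consumerOp` on `consumedView` and
// reports whether the last write into that view comes from an op accepted by
// `isaOpType` (the type check supplied by the templated isProducedByOpOfType
// wrapper).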
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
    Operation *consumerOp, Value consumedView,
    function_ref<bool(Operation *)> isaOpType) {
  LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
  if (!consumer)
    return false;
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
  if (!maybeConsumerIndex)
    return false;

  Aliases aliases;
  auto G = LinalgDependenceGraph::buildDependenceGraph(
      aliases, consumer.getParentOfType<FuncOp>());
  for (auto dependence : G.getDependencesInto(
           consumer, LinalgDependenceGraph::DependenceType::RAW)) {
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
    if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
      continue;
    if (isaOpType(dependence.dependentOpView.op))
      return true;
  }
  return false;
}

//============================================================================//
// Precondition and transformation for vectorization of Linalg generic ops.
//============================================================================//
static bool hasMultiplyAddBody(linalg::GenericOp op) {
  auto &r = op.region();
  if (r.empty())
    return false;
  if (r.getBlocks().size() != 1)
    return false;
  auto &ops = r.front().getOperations();
  if (ops.size() != 3)
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.front().getArgument(0));
  auto b = m_Val(r.front().getArgument(1));
  auto c = m_Val(r.front().getArgument(2));
  // TODO(ntv) Update this detection once we have matcher support for
  // specifying that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
         pattern3.match(&ops.back()) || pattern4.match(&ops.back());
}

// TODO(ntv) should be Tablegen'd from a single source that generates the op
// itself.
static bool isMatmul(linalg::GenericOp genericOp) {
  auto *ctx = genericOp.getContext();
  auto m = getAffineDimExpr(0, ctx);
  auto n = getAffineDimExpr(1, ctx);
  auto k = getAffineDimExpr(2, ctx);
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
  auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx);
  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
         genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
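
// For reference, the indexing maps built in isMatmul correspond to the
// canonical matmul indexing with iteration space (d0, d1, d2) = (m, n, k):
//   A: (m, n, k) -> (m, k)
//   B: (m, n, k) -> (k, n)
//   C: (m, n, k) -> (m, n)
// i.e. C(m, n) += A(m, k) * B(k, n), with the multiply-add body checked
// separately by hasMultiplyAddBody.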
LogicalResult
mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
  // TODO(ntv): This is in fact much more general than just vectorization for
  // matmul ops.
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || !isMatmul(genericOp))
    return failure();

  // TODO(ntv): non-identity layout.
  auto isStaticMemRefWithIdentityLayout = [](Value v) {
    auto m = v.getType().dyn_cast<MemRefType>();
    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
      return false;
    return true;
  };
  if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
                    isStaticMemRefWithIdentityLayout))
    return failure();
  return success();
}

SmallVector<Value, 0>
mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
                                       Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
                    << *op << ":\n");

  assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) &&
         "DRR failure case must be a precondition");

  auto genericOp = cast<linalg::GenericOp>(op);
  assert(genericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  edsc::ScopedContext scope(rewriter, op->getLoc());
  using edsc::intrinsics::std_load;
  using edsc::intrinsics::std_store;
  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
  auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
  auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
  auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
  auto vC = std_load(vectorMemRefC);
  auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
                              genericOp.iterator_types());
  std_store(vRes, vectorMemRefC);
  return {};
}

//============================================================================//
// Precondition and transformation for permutation of Linalg generic ops.
//============================================================================//
LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition(
    Operation *op, ArrayRef<unsigned> permutation) {
  if (permutation.empty())
    return failure();
  // Transformation applies to generic ops only.
  if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
    return failure();
  LinalgOp linOp = cast<LinalgOp>(op);
  // Transformation applies to buffers only.
  if (!linOp.hasBufferSemantics())
    return failure();
  return success();
}

SmallVector<Value, 0>
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<unsigned> permutation,
                                     StringRef linalgMarker) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op
                    << ":\n");

  assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) &&
         "DRR failure case must be a precondition");

  auto linOp = cast<LinalgOp>(op);
  auto permutationMap = inversePermutation(
      AffineMap::getPermutationMap(permutation, rewriter.getContext()));
  SmallVector<AffineMap, 4> newIndexingMap;
  auto indexingMaps = linOp.indexing_maps().getValue();
  for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
    AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
        permutationMap);
    newIndexingMap.push_back(m);
  }
  auto itTypes = linOp.iterator_types().getValue();
  SmallVector<Attribute, 4> itTypesVector;
  for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
    itTypesVector.push_back(itTypes[i]);
  applyPermutationToVector(itTypesVector, permutation);
  op->setAttr(getIndexingMapsAttrName(),
              rewriter.getAffineMapArrayAttr(newIndexingMap));
  op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
  op->setAttr(LinalgTransforms::kLinalgTransformMarker,
              rewriter.getStringAttr(linalgMarker));
  linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
  return {};
}
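
// Illustrative note: the intent of permuteGenericLinalgOp is to reorder the
// iteration space of a generic op without changing what it computes. For a
// matmul-like op, permuting the (m, n, k) loops rewrites each indexing map
// through the inverse permutation and reorders iterator_types to match, so
// only the loop order induced when lowering the op changes.
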
//============================================================================//
// Precondition and transformation for Linalg subview promotion.
//============================================================================//
LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
  LinalgOp linOp = dyn_cast<LinalgOp>(op);
  // Transformation applies to buffers only.
  if (!linOp || !linOp.hasBufferSemantics())
    return failure();
  if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
        return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
      }))
    return failure();
  return success();
}

SmallVector<Value, 0>
mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
                                      Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
                    << *op << ":\n");

  assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
         "DRR failure case must be a precondition");

  LinalgOp linOp = cast<LinalgOp>(op);
  assert(linOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SetVector<Value> subViews;
  for (auto it : linOp.getInputsAndOutputBuffers())
    if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
      subViews.insert(sv);
  if (!subViews.empty()) {
    promoteSubViewOperands(rewriter, linOp, subViews);
    return {};
  }
  llvm_unreachable("DRR failure case must be a precondition");
}