//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg3/Transforms.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;

void linalg::composeSliceOps(mlir::Function *f) {
  f->walk<SliceOp>([](SliceOp sliceOp) {
    auto *sliceResult = sliceOp.getResult();
    auto viewOp = emitAndReturnFullyComposedView(sliceResult);
    sliceResult->replaceAllUsesWith(viewOp.getResult());
    sliceOp.erase();
  });
}

void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
      matmulOp.writeAsFinerGrainTensorContraction();
    } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
      matvecOp.writeAsFinerGrainTensorContraction();
    } else {
      return;
    }
    op->erase();
  });
}

// Folding eagerly is necessary to abide by affine.for static step requirement.
// Returns nullptr if folding is not trivially feasible.
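// For instance (illustrative): a single-result map such as (d0, d1) -> (d1)
// applied to operands (%i, %j) folds directly to %j, and a constant expression
// such as () -> (8) folds to constant_index(8); any other expression is not
// folded here and is materialized as an affine.apply by the caller below.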
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
  assert(map.getNumResults() == 1 && "single result map expected");
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return operands[dim.getPosition()];
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return operands[map.getNumDims() + sym.getPosition()];
  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
    return constant_index(cst.getValue());
  return nullptr;
}

Value *linalg::makeFoldedComposedAffineApply(AffineMap map,
                                             ArrayRef<Value *> operandsRef) {
  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  if (auto *v = tryFold(map, operands)) {
    return v;
  }
  auto *b = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  return b->create<AffineApplyOp>(loc, map, operands).getResult();
}

linalg::RangeParts::RangeParts(unsigned reserved) {
  mins.reserve(reserved);
  maxes.reserve(reserved);
  steps.reserve(reserved);
}

static SmallVector<Value *, 4>
extractFromRanges(ArrayRef<Value *> ranges,
                  std::function<Value *(RangeOp)> extract) {
  SmallVector<Value *, 4> res;
  res.reserve(ranges.size());
  for (auto *v : ranges) {
    auto r = cast<RangeOp>(v->getDefiningOp());
    res.push_back(extract(r));
  }
  return res;
}

linalg::RangeParts::RangeParts(ArrayRef<Value *> ranges)
    : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
      maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
      steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}

SmallVector<Value *, 4> linalg::RangeParts::makeRanges() {
  SmallVector<Value *, 4> res;
  res.reserve(mins.size());
  for (auto z : llvm::zip(mins, maxes, steps)) {
    res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
  }
  return res;
}

static RangeParts makeGenericRangeParts(AffineMap map,
                                        ArrayRef<Value *> ranges) {
  assert(map.getNumInputs() == ranges.size());
  unsigned numDims = map.getNumDims();
  assert(map.getNumSymbols() == 0);
  assert(map.getRangeSizes().empty());

  RangeParts res(map.getNumResults());
  RangeParts rangeParts(ranges);
  for (auto expr : map.getResults()) {
    AffineMap map = AffineMap::get(numDims, 0, expr, {});
    res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
    res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
    res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
  }
  return res;
}

SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
                                          ArrayRef<Value *> ranges) {
  return makeGenericRangeParts(map, ranges).makeRanges();
}

SmallVector<Value *, 4>
linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
                              ArrayRef<Value *> ranges,
                              ArrayRef<Value *> tileSizes) {
  RangeParts res = makeGenericRangeParts(operandRangesToLoopMaps, ranges);
  if (tileSizes.empty())
    return res.makeRanges();
  SmallVector<Value *, 4> tiledSteps;
  for (auto z : llvm::zip(res.steps, tileSizes)) {
    auto *step = std::get<0>(z);
    auto tileSize = std::get<1>(z);
    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
    auto tileSizeValue =
        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
    assert(stepValue > 0);
    tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
  }
  res.steps = tiledSteps;
  return res.makeRanges();
}
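// Emits the loop nest that implements a single tensor contraction op.
// Illustrative sketch (assuming a matmul-like contraction): the parallel
// dimensions are materialized as an outer nest of affine.for loops and the
// reduction dimensions as an inner nest, with the scalar body delegated to
// the op's emitScalarImplementation.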
template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeContractionAsLoops(ContractionOp contraction) {
  FuncBuilder builder(contraction.getOperation());
  ScopedContext scope(builder, contraction.getLoc());
  auto allRanges = getRanges(contraction);
  auto loopRanges =
      makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);

  SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
  SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
  assert(loopRanges.size() == pivs.size() + rivs.size());

  // clang-format off
  using linalg::common::LoopNestRangeBuilder;
  ArrayRef<Value *> ranges(loopRanges);
  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
    LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
      [&contraction, &parallelIvs, &reductionIvs] {
        SmallVector<mlir::Value *, 4> parallel(
            parallelIvs.begin(), parallelIvs.end());
        SmallVector<mlir::Value *, 4> reduction(
            reductionIvs.begin(), reductionIvs.end());
        contraction.emitScalarImplementation(parallel, reduction);
    });
  });
  // clang-format on

  // Return the AffineForOp for better compositionality (e.g. tiling).
  SmallVector<mlir::AffineForOp, 4> loops;
  loops.reserve(pivs.size() + rivs.size());
  for (auto iv : parallelIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));
  for (auto iv : reductionIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));

  return loops;
}

llvm::Optional<SmallVector<mlir::AffineForOp, 4>>
linalg::writeAsLoops(Operation *op) {
  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
    return writeContractionAsLoops(matmulOp);
  } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
    return writeContractionAsLoops(matvecOp);
  } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
    return writeContractionAsLoops(dotOp);
  }
  return llvm::None;
}

void linalg::lowerToLoops(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (writeAsLoops(op))
      op->erase();
  });
}

/// Emits and returns the standard load and store ops from the view indexings.
/// If the indexing is of index type, use it as an index to the load/store.
/// If the indexing is a range, use range.min + indexing as an index to the
/// load/store.
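/// For instance (illustrative): a view dimension indexed by a range whose min
/// is %m, combined with a load/store index %i, produces the operand %m + %i,
/// while a view dimension indexed by a plain index value is forwarded as-is.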
template <typename LoadOrStoreOp>
static SmallVector<Value *, 8>
emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
  unsigned storeDim = 0;
  SmallVector<Value *, 8> operands;
  for (auto *indexing : viewOp.getIndexings()) {
    if (indexing->getType().isa<IndexType>()) {
      operands.push_back(indexing);
      continue;
    }
    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
    ValueHandle min(range.getMin());
    Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
    using edsc::op::operator+;
    operands.push_back(min + ValueHandle(storeIndex));
  }
  return operands;
}

namespace {

/// Rewriting linalg::LoadOp and linalg::StoreOp to mlir::LoadOp and
/// mlir::StoreOp requires finding the proper indexing in the supporting
/// MemRef. This is most easily achieved by calling
/// emitAndReturnFullyComposedView to fold away all the SliceOp.
template <typename LoadOrStoreOpTy>
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
  using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;

  /// Performs the rewrite.
  PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
                                     PatternRewriter &rewriter) const override;
};

struct LowerLinalgLoadStorePass
    : public FunctionPass<LowerLinalgLoadStorePass> {
  void runOnFunction() {
    OwningRewritePatternList patterns;
    auto *context = &getContext();
    patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context));
    patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context));
    applyPatternsGreedily(getFunction(), std::move(patterns));
  }
};

template <>
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
                                          PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(load.getView()->getDefiningOp());
  FuncBuilder builder(load);
  ScopedContext scope(builder, load.getLoc());
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(load, view);
  rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
  return matchSuccess();
}

template <>
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
                                           PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(store.getView()->getDefiningOp());
  FuncBuilder builder(store);
  ScopedContext scope(builder, store.getLoc());
  auto *valueToStore = store.getValueToStore();
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(store, view);
  rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
                                             operands);
  return matchSuccess();
}

} // namespace

FunctionPassBase *linalg::createLowerLinalgLoadStorePass() {
  return new LowerLinalgLoadStorePass();
}
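// Illustrative effect of the patterns above (IR syntax approximate to this
// tutorial revision):
//   %0 = linalg.load %view[%i, %j] : !linalg.view<?x?xf32>
// becomes, once the view/slice chain is fully composed onto its supporting
// memref,
//   %0 = load %memref[%a, %b] : memref<?x?xf32>
// where %a and %b are the sums of each indexing range's min and the
// corresponding load index, as computed by emitAndReturnLoadStoreOperands.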