//===- Transforms.cpp - Implementation of the linalg Transformations ------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements analyses and transformations for the linalg dialect. // //===----------------------------------------------------------------------===// #include "linalg4/Transforms.h" #include "linalg3/Intrinsics.h" #include "linalg3/TensorOps.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/LoopUtils.h" using llvm::ArrayRef; using llvm::SmallVector; using namespace mlir; using namespace mlir::edsc; using namespace linalg; using namespace linalg::intrinsics; llvm::Optional> linalg::writeAsTiledLoops(Operation *op, ArrayRef tileSizes) { auto loops = writeAsLoops(op); if (loops.hasValue()) return mlir::tile(*loops, tileSizes, loops->back()); return llvm::None; } void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef tileSizes) { f.walk([tileSizes](Operation *op) { if (writeAsTiledLoops(op, tileSizes).hasValue()) op->erase(); }); } static bool isZeroIndex(Value *v) { return isa_and_nonnull(v->getDefiningOp()) && cast(v->getDefiningOp()).getValue() == 0; } template static llvm::SmallVector makeTiledRanges(TensorContractionBase &contraction, ArrayRef allRanges, llvm::ArrayRef ivs, llvm::ArrayRef tileSizes) { assert(ivs.size() == tileSizes.size()); if (ivs.empty()) return RangeParts(allRanges).makeRanges(); auto *op = static_cast(&contraction); RangeParts result(allRanges.size()); RangeParts rangeParts(allRanges); for (auto map : op->loopsToOperandRangeMaps()) { // 1. Take the first ivs results of the map, the other ones are not composed // but merely copied over. assert(map.getNumSymbols() == 0); MLIRContext *context = ScopedContext::getContext(); unsigned numParallel = op->getNumParallelDims(); unsigned numReduction = op->getNumReductionDims(); if (ivs.size() < numParallel + numReduction) { // Inject zeros in positions that are not tiled. SmallVector dimReplacements(numParallel + numReduction); for (unsigned i = 0, e = numParallel + numReduction; i < e; ++i) { dimReplacements[i] = (i < ivs.size()) ? getAffineDimExpr(i, context) : getAffineConstantExpr(0, context); } map = map.replaceDimsAndSymbols(dimReplacements, {}, ivs.size(), 0); } // 2. Apply the rewritten map to the ranges. unsigned numDims = map.getNumDims(); for (auto en : llvm::enumerate(map.getResults())) { auto index = en.index(); auto expr = en.value(); AffineMap exprMap = AffineMap::get(numDims, 0, expr); ValueHandle offset(makeFoldedComposedAffineApply(exprMap, ivs)); // Offset is normally a function of loop induction variables. // If it is 0, it must come from a dimension that was not tiled. if (isZeroIndex(offset)) { result.mins.push_back(rangeParts.mins[index]); result.maxes.push_back(rangeParts.maxes[index]); continue; } ValueHandle step(makeFoldedComposedAffineApply(exprMap, tileSizes)); ValueHandle min(rangeParts.mins[index]); using edsc::op::operator+; result.mins.push_back(min + offset); // Ideally this should be: // `min(rangeParts.max, rangeParts.min + offset + step)` // but that breaks the current limitations of the affine dialect. result.maxes.push_back(min + offset + step); } } // Note that for the purpose of tiled ranges and views, the steps do not // change in our representation. result.steps = rangeParts.steps; return result.makeRanges(); } template static SmallVector makeTiledViews(linalg::TensorContractionBase &contraction, ArrayRef ivs, ArrayRef tileSizes) { auto tiledRanges = makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes); SmallVector res; unsigned currentRange = 0; for (auto *in : contraction.getInputsAndOutputs()) { unsigned runningSliceDim = 0; auto *runningSlice = in; assert(runningSlice->getType().template isa()); for (unsigned d = 0, e = getViewRank(runningSlice); d < e; ++d) { auto *r = tiledRanges[currentRange++]; runningSlice = slice(runningSlice, r, runningSliceDim++).getValue(); } res.push_back(runningSlice); } return res; } template static SmallVector writeContractionAsTiledViews(TensorContractionBase &contraction, ArrayRef tileSizes) { assert(tileSizes.size() <= contraction.getNumParallelDims() + contraction.getNumReductionDims()); auto *op = static_cast(&contraction); mlir::OpBuilder builder(op->getOperation()); ScopedContext scope(builder, op->getLoc()); SmallVector ivs(tileSizes.size()); auto pivs = IndexHandle::makeIndexHandlePointers(ivs); // clang-format off using linalg::common::LoopNestRangeBuilder; auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction), getRanges(contraction), tileSizes); linalg::common::LoopNestRangeBuilder(pivs, ranges)( [&contraction, &tileSizes, &ivs]() { SmallVector ivValues(ivs.begin(), ivs.end()); auto views = makeTiledViews(contraction, ivValues, tileSizes); ScopedContext::getBuilder().create( ScopedContext::getLocation(), views); }); // clang-format on SmallVector res; res.reserve(ivs.size()); for (auto iv : ivs) res.push_back(getForInductionVarOwner(iv.getValue())); return res; } llvm::Optional> linalg::writeAsTiledViews(Operation *op, ArrayRef tileSizes) { if (auto matmulOp = dyn_cast(op)) { return writeContractionAsTiledViews(matmulOp, tileSizes); } else if (auto matvecOp = dyn_cast(op)) { return writeContractionAsTiledViews(matvecOp, tileSizes); } else if (auto dotOp = dyn_cast(op)) { return writeContractionAsTiledViews(dotOp, tileSizes); } return llvm::None; } void linalg::lowerToTiledViews(mlir::Function f, ArrayRef tileSizes) { f.walk([tileSizes](Operation *op) { if (auto matmulOp = dyn_cast(op)) { writeAsTiledViews(matmulOp, tileSizes); } else if (auto matvecOp = dyn_cast(op)) { writeAsTiledViews(matvecOp, tileSizes); } else if (auto dotOp = dyn_cast(op)) { writeAsTiledViews(dotOp, tileSizes); } else { return; } op->erase(); }); }