//===- LinalgOps.cpp - Implementation of the linalg operations ------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the Linalg operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; ///////////////////// Operations defined with Tablegen ///////////////////////// // For such operations that do not correspond to library calls (i.e. defined in // LinalgOps.td), we define an overloaded `print` function and a // parse`className` function. 
//===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// template static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { auto attrNames = op.linalgTraitAttrNames(); llvm::StringSet<> linalgTraitAttrsSet; linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; for (auto attr : op.getAttrs()) if (linalgTraitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); p << op.getOperationName() << " " << dictAttr << " " << op.getOperands(); if (!op.region().empty()) p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) p << " -> " << outputTensorTypes; } static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } static void print(OpAsmPrinter &p, IndexedGenericOp op) { printGenericOp(p, op); } static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { SmallVector operandsInfo, regionOperandsInfo; DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. // The name is unimportant as we will overwrite result.attributes. // The core linalg traits must contain the information necessary to pass the // verifier. if (parser.parseAttribute(dictAttr, "_", result.attributes) || parser.parseOperandList(operandsInfo)) return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); Region ®ion = *result.addRegion(); SmallVector operandTypes, regionTypes; // Optional attributes may be added. // Either Optional getFunAttrName() attribute or region must be specified. 
if (!dictAttr.get(getFunAttrName()) && parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes)) return failure(); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(operandTypes)) return failure(); // Generic ops may specify that a subset of its outputs are tensors. Such // outputs are specified in the result type. SmallVector tensorResultTypes; if (parser.parseOptionalArrowTypeList(tensorResultTypes)) return failure(); if (!tensorResultTypes.empty()) result.addTypes(tensorResultTypes); return parser.resolveOperands(operandsInfo, operandTypes, parser.getCurrentLocation(), result.operands); } template static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { auto nOperands = op.getNumOperands(); if (block.getNumArguments() != nOperands) return op.emitOpError("expected number of block arguments to match number " "of operands"); // Note: the number and type of yield values are checked in the YieldOp. auto nInputViews = op.getNumInputs(); for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") << "operand: " << viewType; } return success(); } template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); auto nOperands = op.getNumOperands(); if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( "expected number of block arguments to match number of operands + " "number of loops"); // Note: the number and type of yield values are checked in the YieldOp. 
for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") << (i + 1) << " to be an index"; for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") << (memrefArgIndex + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") << "operand: " << viewType; } return success(); } template static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); template LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { auto res = verifyFuncArgs(op, funType); if (failed(res)) return res; auto nInputs = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); // linalg.generic output element types are exactly the function results. for (unsigned idx = 0; idx < nOutputs; ++idx) { ShapedType shapedType = op.getShapedType(nInputs + idx); if (funType.getResult(idx) != shapedType.getElementType()) return op.emitOpError("expected function result ") << (idx + 1) << " of the same type as elemental type " << shapedType.getElementType() << " of output " << (idx + 1); } return success(); } template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { auto nOperands = op.getNumOperands(); if (funType.getNumInputs() != nOperands) return op.emitOpError( "expected function arguments to match number of operands"); if (funType.getNumResults() != op.getNumOutputs()) return op.emitOpError("expected function results(") << funType.getNumResults() << ") to match number of outputs(" << op.getNumOutputs() << ")"; // linalg.generic operands element types are exactly the first function // arguments. 
for (unsigned idx = 0; idx < nOperands; ++idx) { ShapedType shapedType = op.getShapedType(idx); if (funType.getInput(idx) != shapedType.getElementType()) return op.emitOpError("expected function argument ") << (idx + 1) << " of the same type as elemental type " << shapedType.getElementType() << " of operand " << (idx + 1); } return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); auto nOutputs = op.getNumOutputs(); auto nOperands = op.getNumOperands(); if (funType.getNumInputs() != nOperands + nLoops) return op.emitOpError("expected function arguments to match number of " "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( "expected function results to match number of outputs"); for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) return op.emitOpError("expected function argument ") << (i + 1) << " to be an index"; // linalg.generic operands element types are exactly the first function // arguments. for (unsigned idx = 0; idx < nOperands; ++idx) { ShapedType shapedType = op.getShapedType(idx); if (funType.getInput(idx + nLoops) != shapedType.getElementType()) return op.emitOpError("expected function argument ") << (idx + nLoops + 1) << " of the same type as elemental type " << shapedType.getElementType() << " of input " << (idx + 1); } return success(); } template static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); if (nInputsAndOutputBuffers != llvm::size(op.views())) return op.emitOpError("expected exactly ") << nInputsAndOutputBuffers << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); auto funType = funOp ? 
funOp.getType() : FunctionType(); if (!region.empty()) { if (region.getBlocks().size() != 1) return op.emitOpError("expected region with 1 block"); if (failed(verifyBlockArgs(op, region.getBlocks().front()))) return failure(); } else { if (!funOp || !funOp.getType()) return op.emitOpError( "expected function attribute to refer to a defined symbol"); if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } SmallVector indexingMaps; indexingMaps.reserve(op.indexing_maps().size()); for (auto en : llvm::enumerate(op.indexing_maps())) { auto idx = en.index(); auto m = en.value().template cast().getValue(); indexingMaps.push_back(m); // Save reference to map for further checks. auto view = (idx < nInputViews) ? op.getInputShapedType(idx) : op.getOutputShapedType(idx - nInputViews); if (m.getNumSymbols() != 0) return op.emitOpError("expected indexing_map #") << idx << " to have no symbols"; if (m.getNumDims() != nLoops) return op.emitOpError("expected indexing_map #") << idx << " to have " << nLoops << " dim(s) to match the number of loops"; if (m.getNumResults() == 1 && view.getRank() == 0) { auto cst = m.getResult(0).template dyn_cast(); if (!cst || cst.getValue() != 0) return op.emitOpError("expected indexing_map #") << idx << " to be 0 to match 0-D view: " << view; } if (m.getNumResults() != view.getRank()) return op.emitOpError("expected indexing_map #") << idx << " results to match view rank: " << view; } auto concatMap = concatAffineMaps(indexingMaps); auto aggregateMap = inversePermutation(concatMap); if (!aggregateMap) return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); return success(); } static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// // RangeOp 
//===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RangeOp op) { p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":" << op.step(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getResult().getType(); } static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) { SmallVector rangeInfo(3); RangeType type; auto indexTy = parser.getBuilder().getIndexType(); return failure(parser.parseOperand(rangeInfo[0]) || parser.parseColon() || parser.parseOperand(rangeInfo[1]) || parser.parseColon() || parser.parseOperand(rangeInfo[2]) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands(rangeInfo, indexTy, result.operands) || parser.addTypeToList(type, result.types)); } //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// /// Return true if the reassociation specification is valid, false otherwise. /// When false, the `invalidIndex` integer pointer is optionally filled with the /// index of the offending reassociation map. static bool isReassociationValid(ArrayRef reassociation, int *invalidIndex = nullptr) { if (reassociation.empty()) return true; unsigned nDims = reassociation[0].getNumDims(); unsigned nextExpectedDim = 0; for (auto it : llvm::enumerate(reassociation)) { auto m = it.value(); if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { if (invalidIndex) *invalidIndex = it.index(); return false; } for (auto e : m.getResults()) { auto d = e.dyn_cast(); if (!d || d.getPosition() != nextExpectedDim++) { if (invalidIndex) *invalidIndex = it.index(); return false; } } } if (nextExpectedDim != nDims) { if (invalidIndex) *invalidIndex = reassociation.size() - 1; return false; } return true; } /// Detect whether memref dims [dim, dim + extent) can be reshaped without /// copies. 
static bool isReshapableDimBand(unsigned dim, unsigned extent, ArrayRef sizes, ArrayRef strides) { assert(sizes.size() == strides.size() && "mismatched ranks"); // off by 1 indexing to avoid out of bounds // V for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { // Only bands of static shapes are reshapable. This is due to the fact that // there is no relation between dynamic sizes and dynamic strides: we do not // have enough information to know whether a "-1" size corresponds to the // proper symbol in the AffineExpr of a stride. if (ShapedType::isDynamic(sizes[dim + 1])) return false; // TODO(ntv) Refine this by passing the proper nDims and nSymbols so we can // simplify on the fly and catch more reshapable cases. if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) return false; } return true; } /// Compute the MemRefType obtained by applying the `reassociation` (which is /// expected to be valid) to `type`. /// If `type` is Contiguous MemRefType, this always produce a contiguous /// MemRefType. static MemRefType computeReshapeCollapsedType(MemRefType type, ArrayRef reassociation) { auto sizes = type.getShape(); AffineExpr offset; SmallVector strides; auto status = getStridesAndOffset(type, strides, offset); (void)status; assert(succeeded(status) && "expected strided memref"); SmallVector newSizes; newSizes.reserve(reassociation.size()); SmallVector newStrides; newStrides.reserve(reassociation.size()); // Use the fact that reassociation is valid to simplify the logic: only use // each map's rank. 
assert(isReassociationValid(reassociation) && "invalid reassociation"); unsigned currentDim = 0; for (AffineMap m : reassociation) { unsigned dim = m.getNumResults(); int64_t size = 1; AffineExpr stride = strides[currentDim + dim - 1]; if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { size = ShapedType::kDynamicSize; stride = AffineExpr(); } else { for (unsigned d = 0; d < dim; ++d) size *= sizes[currentDim + d]; } newSizes.push_back(size); newStrides.push_back(stride); currentDim += dim; } // Early-exit: if `type` is contiguous, the result must be contiguous. if (canonicalizeStridedLayout(type).getAffineMaps().empty()) return MemRefType::get(newSizes, type.getElementType(), {}); // Convert back to int64_t because we don't have enough information to create // new strided layouts from AffineExpr only. This corresponds to a case where // copies may be necessary. int64_t intOffset = ShapedType::kDynamicStrideOrOffset; if (auto o = offset.dyn_cast()) intOffset = o.getValue(); SmallVector intStrides; intStrides.reserve(strides.size()); for (auto stride : newStrides) { if (auto cst = stride.dyn_cast_or_null()) intStrides.push_back(cst.getValue()); else intStrides.push_back(ShapedType::kDynamicStrideOrOffset); } auto layout = makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); return canonicalizeStridedLayout( MemRefType::get(newSizes, type.getElementType(), {layout})); } /// Helper functions assert Attribute of the proper type in attr and returns the /// corresponding vector. /// TODO(rridle,ntv) this should be evolved into a generic /// `getRangeOfType(ArrayAttr attrs)` that does not copy. 
static SmallVector getAffineMaps(ArrayAttr attrs) { return functional::map( [](Attribute a) { return a.cast().getValue(); }, attrs); } template unsigned getMaxPosOfType(ArrayRef> exprArrays) { unsigned pos = 0; for (auto exprs : exprArrays) { for (auto expr : exprs) { expr.walk([&pos](AffineExpr e) { if (auto d = e.dyn_cast()) pos = std::max(pos, d.getPosition()); }); } } return pos; } static SmallVector getSymbolLessAffineMaps(ArrayRef> reassociation) { unsigned maxDim = getMaxPosOfType(reassociation); assert(getMaxPosOfType(reassociation) == 0 && "Expected symbol-less expressions"); SmallVector maps; maps.reserve(reassociation.size()); for (auto exprs : reassociation) maps.push_back(AffineMap::get(maxDim + 1, 0, exprs)); return maps; } void mlir::linalg::ReshapeOp::build( Builder *b, OperationState &result, Value view, ArrayRef> reassociation, ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); auto memRefType = view.getType().cast(); auto resultType = computeReshapeCollapsedType(memRefType, maps); build(b, result, resultType, view, attrs); result.addAttribute(ReshapeOp::getReassociationAttrName(), b->getAffineMapArrayAttr(maps)); } void mlir::linalg::ReshapeOp::build( Builder *b, OperationState &result, Type resultType, Value view, ArrayRef> reassociation, ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); build(b, result, resultType, view, attrs); result.addAttribute(ReshapeOp::getReassociationAttrName(), b->getAffineMapArrayAttr(maps)); } static void print(OpAsmPrinter &p, ReshapeOp op) { p << op.getOperationName() << " " << op.view() << " " << op.reassociation(); p.printOptionalAttrDict(op.getAttrs(), {ReshapeOp::getReassociationAttrName()}); p << " : " << op.getViewType() << " into " << op.getResult().getType(); } static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType view; ArrayAttr reassociation; MemRefType type, resultType; return failure(parser.parseOperand(view) || 
parser.parseAttribute(reassociation, ReshapeOp::getReassociationAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.parseKeywordType("into", resultType) || parser.resolveOperand(view, type, result.operands) || parser.addTypeToList(resultType, result.types)); } static LogicalResult verify(ReshapeOp op) { MemRefType expandedType = op.getViewType(); MemRefType collapsedType = op.getResult().getType().cast(); unsigned expandedRank = expandedType.getRank(); unsigned collapsedRank = collapsedType.getRank(); bool isCollapse = expandedRank > collapsedRank; if (!isCollapse) { std::swap(expandedRank, collapsedRank); std::swap(expandedType, collapsedType); } if (expandedRank == 0 || collapsedRank == 0) return op.emitOpError("expected non-zero memref ranks"); if (expandedRank == collapsedRank) return op.emitOpError("expected to collapse or expand dims"); if (collapsedRank != op.reassociation().size()) return op.emitOpError("expected rank of the collapsed view(") << collapsedRank << ") to be the number of reassociation maps(" << op.reassociation().size() << ")"; auto maps = getAffineMaps(op.reassociation()); for (auto it : llvm::enumerate(maps)) if (it.value().getNumDims() != expandedRank) return op.emitOpError("expected reassociation map #") << it.index() << " of same rank as expanded memref(" << expandedRank << "), but got " << it.value().getNumDims(); int invalidIdx = 0; if (!isReassociationValid(maps, &invalidIdx)) return op.emitOpError("expected reassociation map #") << invalidIdx << " to be valid and contiguous"; MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); if (collapsedType != expectedType) return op.emitOpError("expected collapsed type to be ") << expectedType << ", but got " << collapsedType; return success(); } //===----------------------------------------------------------------------===// // SliceOp 
//===----------------------------------------------------------------------===// void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, Value base, ValueRange indexings) { result.addOperands(base); result.addOperands(indexings); auto memRefType = base.getType().cast(); int64_t offset; SmallVector strides; auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && strides.size() == indexings.size()); (void)res; unsigned rank = memRefType.getRank(); // TODO(ntv): propagate static size and stride information when available. SmallVector sizes(rank, -1); // -1 encodes dynamic size. Type elementType = memRefType.getElementType(); result.addTypes({MemRefType::get( sizes, elementType, {makeStridedLinearLayoutMap(strides, offset, b->getContext())}, memRefType.getMemorySpace())}); } static void print(OpAsmPrinter &p, SliceOp op) { auto indexings = op.indexings(); p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings << "] "; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getBaseViewType(); if (!indexings.empty()) p << ", " << op.indexings().getTypes(); p << ", " << op.getType(); } static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType baseInfo; SmallVector operands; SmallVector types; if (parser.parseOperand(baseInfo) || parser.parseOperandList(operands, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(types)) return failure(); if (types.size() < 2) return parser.emitError(parser.getCurrentLocation(), "expected at least input and result view types"); ArrayRef indexingTypes = ArrayRef(types).drop_front().drop_back(); return failure( parser.resolveOperand(baseInfo, types.front(), result.operands) || (!operands.empty() && parser.resolveOperands(operands, indexingTypes, operands.front().location, result.operands)) || parser.addTypeToList(types.back(), result.types)); } static LogicalResult 
verify(SliceOp op) { unsigned rank = op.getBaseViewRank(); if (rank != llvm::size(op.indexings())) return op.emitOpError("expected ") << rank << " indexings, got " << llvm::size(op.indexings()); unsigned index = 0; for (auto indexing : op.indexings()) { if (indexing.getType().isa()) --rank; ++index; } if (op.getRank() != rank) return op.emitOpError() << "expected rank of the view(" << op.getRank() << ") to be the number of ranges(" << rank << ")"; return success(); } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result, Value view, AffineMapAttr permutation, ArrayRef attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); auto memRefType = view.getType().cast(); auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); // Compute permuted sizes. SmallVector sizes(rank, 0); for (auto en : llvm::enumerate(permutationMap.getResults())) sizes[en.index()] = originalSizes[en.value().cast().getPosition()]; // Compute permuted strides. int64_t offset; SmallVector strides; auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && strides.size() == static_cast(rank)); (void)res; auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext()); map = permutationMap ? map.compose(permutationMap) : map; // Compute result type. 
auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map, memRefType.getMemorySpace()); build(b, result, resultType, view, attrs); result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); } static void print(OpAsmPrinter &p, TransposeOp op) { p << op.getOperationName() << " " << op.view() << " " << op.permutation(); p.printOptionalAttrDict(op.getAttrs(), {TransposeOp::getPermutationAttrName()}); p << " : " << op.view().getType(); } static ParseResult parseTransposeOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType view; AffineMap permutation; MemRefType type; if (parser.parseOperand(view) || parser.parseAffineMap(permutation) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(view, type, result.operands) || parser.addTypeToList(type, result.types)) return failure(); result.addAttribute(TransposeOp::getPermutationAttrName(), AffineMapAttr::get(permutation)); return success(); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, YieldOp op) { p << op.getOperationName(); if (op.getNumOperands() > 0) p << ' ' << op.getOperands(); p.printOptionalAttrDict(op.getAttrs()); if (op.getNumOperands() > 0) p << " : " << op.getOperandTypes(); } static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; llvm::SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || parser.parseOptionalAttrDict(result.attributes) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } template static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. 
auto nOutputs = genericOp.getNumOutputs(); if (op.getNumOperands() != nOutputs) return op.emitOpError("expected number of yield values (") << nOutputs << ") to match the number of operands of the enclosing " << "linalg.generic op (" << op.getNumOperands() << ")"; for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) return op.emitOpError("type of yield operand ") << (i + 1) << " (" << op.getOperand(i).getType() << ") doesn't match " << "the element type of the enclosing linalg.generic op (" << elementType << ")"; } return success(); } static LogicalResult verify(YieldOp op) { auto *parentOp = op.getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); auto genericOp = dyn_cast(parentOp); if (genericOp) return verifyYield(op, genericOp); auto indexedGenericOp = dyn_cast(parentOp); if (indexedGenericOp) return verifyYield(op, indexedGenericOp); return op.emitOpError("expected '") << GenericOp::getOperationName() << "' or '" << IndexedGenericOp::getOperationName() << "' parent op"; } /////// Operations corresponding to library calls defined with Tablegen //////// // For such operations correspond to library calls (i.e. defined in // LinalgStructuredOps.td), we define an overloaded `print` function and a // parse`className` function. // A LinalgStructuredOp prints as: // // ```mlir // concrete_op_name (ssa-inputs, ssa-outputs) : view-types // ``` // // for example: // // ``` // linalg.matmul(%0, %1, %2) : // memref, // memref, // memref // ``` // // Where %0, %1 and %2 are ssa-values of type MemRefType with strides. 
static void printLinalgStructuredOp(OpAsmPrinter &p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); p << op->getName().getStringRef() << "(" << op->getOperands() << ")"; p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperandTypes(); } static ParseResult parseLinalgStructuredOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; SmallVector types; return failure( parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(types) || parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands)); } static LogicalResult verify(FillOp op) { auto viewType = op.getOutputShapedType(0); auto fillType = op.value().getType(); if (viewType.getElementType() != fillType) return op.emitOpError("expects fill type to match view elemental type"); return success(); } static LogicalResult verify(CopyOp op) { auto outputViewType = op.getOutputShapedType(0); auto inputViewType = op.getInputShapedType(0); if (inputViewType.getElementType() != outputViewType.getElementType()) return op.emitOpError("expects views of the same type"); if (inputViewType.getRank() != outputViewType.getRank()) return op.emitOpError("expects views of the same rank"); auto rank = op.getNumParallelLoops(); auto inputPermutationMap = op.inputPermutation(); if (inputPermutationMap) { if (inputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional input_permutation map of rank ") << rank; if (!inputPermutationMap->isPermutation()) return op.emitOpError( "expects optional input_permutation map to be a permutation"); } auto outputPermutationMap = op.outputPermutation(); if (outputPermutationMap) { if (outputPermutationMap->getNumInputs() != rank) return op.emitOpError("expects optional output_permutation map of rank ") << rank; if (!outputPermutationMap->isPermutation()) return op.emitOpError( "expects optional output_permutation map to be a 
permutation"); } if (rank == 0 && inputPermutationMap) return op.emitOpError("expected no input permutation when rank == 0"); if (rank == 0 && outputPermutationMap) return op.emitOpError("expected no output permutation when rank == 0"); return success(); } static LogicalResult verifyStrideOrDilation(ConvOp op, ArrayRef attrs, bool isStride) { auto strideOrDilation = isStride ? "stride" : "dilation"; if (attrs.size() != op.getNumWindowLoops()) return op.emitOpError("expects num ") << strideOrDilation << "s equal to number of window dimensions: " << attrs.size() << " vs " << op.getNumWindowLoops(); return success(); } static LogicalResult verify(ConvOp op) { auto oType = op.output().getType().cast(); auto fType = op.filter().getType().cast(); auto iType = op.input().getType().cast(); if (oType.getElementType() != iType.getElementType() || oType.getElementType() != fType.getElementType()) return op.emitOpError("expects memref elemental types to match"); if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank()) return op.emitOpError("expects memref ranks to match"); if (auto strides = op.strides()) { if (failed( verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) return failure(); } if (auto dilations = op.dilations()) { if (failed(verifyStrideOrDilation(op, dilations->getValue(), /*isStride=*/false))) return failure(); } return success(); } namespace mlir { namespace linalg { #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" } // namespace linalg } // namespace mlir static AffineMap extractOrIdentityMap(Optional maybeMap, unsigned rank, MLIRContext *context) { if (maybeMap) return maybeMap.getValue(); if (rank == 0) return AffineMap(); return AffineMap::getMultiDimIdentityMap(rank, context); } // Returns `num` AffineDimExpr dimensions at positions 
// [curIdx, curIdx + num)
// and increments `curIdx` to `curIdx + num`.
// NOTE(review): in this copy of the file the template argument lists (on
// SmallVector, ArrayRef, dyn_cast, isa, cast) are not visible -- confirm the
// exact types against the upstream file.
static SmallVector makeAffineDimExprs(unsigned num, unsigned &curIdx,
                                      MLIRContext *context) {
  SmallVector res;
  res.reserve(num);
  // `curIdx` is taken by reference: the post-increment below is what advances
  // the caller's running dimension index.
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(curIdx++, context));
  return res;
}

// Builds, for every position `i`, the expression
//   op.getStride(i) * a[i] + op.getDilation(i) * b[i]
// and returns the resulting list. `a` and `b` must have the same size
// (asserted).
static SmallVector weightedConvInputIndex(ConvOp op, ArrayRef a, ArrayRef b) {
  assert(a.size() == b.size());
  SmallVector res;
  res.reserve(a.size());
  for (unsigned i = 0, e = a.size(); i < e; ++i) {
    res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]);
  }
  return res;
}

// Returns a new vector holding the elements of `a` followed by the elements
// of `b`.
static SmallVector concat(ArrayRef a, ArrayRef b) {
  SmallVector res;
  res.reserve(a.size() + b.size());
  res.assign(a.begin(), a.end());
  res.append(b.begin(), b.end());
  return res;
}

// Note: both functions below would completely disappear with a simple tensor
// kernel language.
//
// Ideally this should all be Tablegen'd but there is no good story for
// AffineMap for now.
//
// Returns the list of affine maps built for `op`. Each branch below handles
// one kind of op and builds its maps directly; the last two branches simply
// forward the indexing maps stored on the op. Reaching the end without a
// matching branch hits llvm_unreachable.
// NOTE(review): the op class tested by each dyn_cast/isa is not visible in
// this copy; the per-branch comments and variable names are the only hint.
SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) {
  MLIRContext *context = op->getContext();
  if (auto copyOp = dyn_cast(op)) {
    // I(input_perm(ivs)) -> O(output_perm(ivs))
    // Two maps: one from the input permutation and one from the output
    // permutation, each passed through extractOrIdentityMap together with
    // the rank of the corresponding shaped type.
    auto maybeInputMap = copyOp.inputPermutation();
    auto maybeOutputMap = copyOp.outputPermutation();
    unsigned inputRank = copyOp.getInputShapedType(0).getRank();
    unsigned outputRank = copyOp.getOutputShapedType(0).getRank();
    return SmallVector{
        extractOrIdentityMap(maybeInputMap, inputRank, context),
        extractOrIdentityMap(maybeOutputMap, outputRank, context)};
  }
  if (auto fillOp = dyn_cast(op)) {
    // filling_value -> O(ivs)
    // A single map, built by extractOrIdentityMap from llvm::None and the
    // number of parallel loops.
    unsigned rank = fillOp.getNumParallelLoops();
    return SmallVector{
        extractOrIdentityMap(llvm::None, rank, context)};
  }
  // Dimension expressions d0, d1, d2 shared by the three branches below.
  auto i = getAffineDimExpr(0, context);
  auto j = getAffineDimExpr(1, context);
  auto k = getAffineDimExpr(2, context);
  if (isa(op))
    // A(r_i) * B(r_i) -> C()
    // The third map is a default-constructed AffineMap.
    return SmallVector{AffineMap::get(1, 0, {i}),
                       AffineMap::get(1, 0, {i}), AffineMap()};
  if (isa(op))
    // A(i, r_j) * B(r_j) -> C(i)
    return SmallVector{AffineMap::get(2, 0, {i, j}),
                       AffineMap::get(2, 0, {j}),
                       AffineMap::get(2, 0, {i})};
  if (isa(op))
    // A(i, r_k) * B(r_k, j) -> C(i, j)
    return SmallVector{AffineMap::get(3, 0, {i, k}),
                       AffineMap::get(3, 0, {k, j}),
                       AffineMap::get(3, 0, {i, j})};
  if (auto convOp = dyn_cast(op)) {
    // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
    //   O(b, x0, ..., xN-1, k)
    // for N equal to `nWin`.
    auto nWin = convOp.getNumWindowLoops();
    assert(nWin > 0 && "expected at least one window dimension");
    // Running dimension index; every makeAffineDimExprs call below advances
    // it, so after the last call it equals the total number of dimensions
    // and is used as the dim count of the three maps.
    unsigned idx = 0;
    // In the following, AffineDimExprs are indexed in loop order:
    //   [ b, xs, k,           q,                     zs]
    //     parallels     non-window reductions     windows
    //
    // Parallel dims are exactly the dimensions indexing `output`:
    //     output[b, x[0], ..., x[N-1], k]; i.e.
    //  * batch dimensions (bs with #bs = 1 for now)
    //  * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks)
    //  * output filter dimensions (ks with #ks = 1 for now)
    auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context);
    auto xs = makeAffineDimExprs(nWin, idx, context);
    auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx,
                                 context);
    // Non-window reduction dims: the `q` part of sum_{z[0], ..., z[N-1], q}
    auto qs = makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx,
                                 context);
    // Window reduction dims: the `z` part of sum_{z[0], ..., z[N-1], q}
    auto zs = makeAffineDimExprs(nWin, idx, context);
    // Construct the weighted-sum expressions stride * x + dilation * z.
    auto ws = weightedConvInputIndex(convOp, xs, zs);
    return SmallVector{
        // filter[z[0], ..., z[N-1], q, k]
        AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
        // input[b,
        //       x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
        //       q]
        AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
        // output[b, x[0], ..., x[N-1], k]
        AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
  } else if (auto genericOp = dyn_cast(op)) {
    // Forward the op's own indexing maps, one per input/output.
    SmallVector res;
    unsigned nViews = genericOp.getNumInputsAndOutputs();
    res.reserve(nViews);
    for (unsigned i = 0, e = nViews; i < e; ++i) {
      res.push_back(genericOp.getIndexingMap(i));
    }
    return res;
  } else if (auto indexedGenericOp = dyn_cast(op)) {
    // Same as the previous branch: forward the op's own indexing maps.
    SmallVector res;
    unsigned nViews = indexedGenericOp.getNumInputsAndOutputs();
    res.reserve(nViews);
    for (unsigned i = 0, e = nViews; i < e; ++i)
      res.push_back(indexedGenericOp.getIndexingMap(i));
    return res;
  }
  llvm_unreachable("Missing loopToOperandRangesMaps for op");
}

// Appends a mangled spelling of `t` to `ss`:
//  * first branch (`memref`): "view", then every shape entry followed by "x"
//    (negative entries are written as "sx" -- presumably dynamic sizes),
//    then the mangled element type;
//  * second branch (`vec`): "vector", then the shape entries separated by
//    "x" (no trailing "x"), then the mangled element type;
//  * int / index / float types: the type streamed as-is.
// Any other type hits llvm_unreachable.
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = t.dyn_cast()) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    appendMangledType(ss, memref.getElementType());
  } else if (auto vec = t.dyn_cast()) {
    ss << "vector";
    interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    appendMangledType(ss, vec.getElementType());
  } else if (t.isIntOrIndexOrFloat()) {
    ss << t;
  } else {
    llvm_unreachable("Invalid type for linalg library name mangling");
  }
}

// Builds a library call name for `op`: the operation name with every '.'
// replaced by '_', followed by "_" and the mangled operand types (see
// appendMangledType) separated by "_".
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa(op));
  std::string name(op->getName().getStringRef().str());
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  // The stream is constructed on `name`; the suffix is written through it.
  llvm::raw_string_ostream ss(name);
  ss << "_";
  auto types = op->getOperandTypes();
  interleave(
      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
      [&]() { ss << "_"; });
  return ss.str();
}

// Wraps every map returned by loopToOperandRangesMaps(op) in an
// AffineMapAttr and returns them as an ArrayAttr in the op's context.
static ArrayAttr getIndexingMaps(Operation *op) {
  LinalgOp linalgOp = cast(op);
  SmallVector maps;
  maps.reserve(linalgOp.getNumInputsAndOutputs());
  for (AffineMap map : loopToOperandRangesMaps(op))
    maps.push_back(AffineMapAttr::get(map));
  return ArrayAttr::get(maps, op->getContext());
}

// The `indexing_maps` accessors below all forward to getIndexingMaps on the
// underlying operation.
ArrayAttr mlir::linalg::ConvOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::CopyOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::DotOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::FillOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}