diff options
| author | Andy Davis <andydavis@google.com> | 2019-11-11 10:32:52 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-11 10:33:27 -0800 |
| commit | 5cf6e0ce7f03f9841675b1a9d44232540f3df5cc (patch) | |
| tree | ff2fd4639568dc7bdc4fa91132397c154a5e7d8d | |
| parent | e04d4bf865b01ec35ecfb98b34372a1dacd70266 (diff) | |
| download | bcm5719-llvm-5cf6e0ce7f03f9841675b1a9d44232540f3df5cc.tar.gz bcm5719-llvm-5cf6e0ce7f03f9841675b1a9d44232540f3df5cc.zip | |
Adds std.subview operation which takes dynamic offsets, sizes and strides and returns a memref type which represents sub/reduced-size view of its memref argument.
This operation is a companion operation to the std.view operation added as proposed in "Updates to the MLIR MemRefType" RFC.
PiperOrigin-RevId: 279766410
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.td | 85 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 14 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 21 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 121 | ||||
| -rw-r--r-- | mlir/test/IR/core-ops.mlir | 37 | ||||
| -rw-r--r-- | mlir/test/IR/invalid-ops.mlir | 67 |
10 files changed, 338 insertions, 53 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 710ffe1b63d..d7de15576d5 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1174,10 +1174,10 @@ def ViewOp : Std_Op<"view"> { let results = (outs AnyMemRef); let extraClassDeclaration = [{ - /// The result of a memref_shape_cast is always a memref. + /// The result of a view is always a memref. MemRefType getType() { return getResult()->getType().cast<MemRefType>(); } - /// Returns the dynamic offset for this shape cast operation if specified. + /// Returns the dynamic offset for this view operation if specified. /// Returns nullptr if no dynamic offset was specified. Value *getDynamicOffset(); @@ -1186,7 +1186,7 @@ def ViewOp : Std_Op<"view"> { return getDynamicOffset() == nullptr ? 1 : 2; } - /// Returns the dynamic sizes for this shape cast operation. + /// Returns the dynamic sizes for this view operation. operand_range getDynamicSizes() { return {operand_begin() + getDynamicSizesOperandStart(), operand_end()}; } @@ -1195,6 +1195,85 @@ def ViewOp : Std_Op<"view"> { let hasCanonicalizer = 1; } +def SubViewOp : Std_Op<"subview", [SameVariadicOperandSize]> { + let summary = "memref subview operation"; + let description = [{ + The "subview" operation converts a memref type to another memref type + which represents a reduced-size view of the original memref as specified by + the operation's offsets, sizes and strides arguments. + + The SubView operation supports the following arguments: + *) Memref: the "base" memref on which to create a "view" memref. + *) Offsets: memref-rank number of dynamic offsets into the "base" memref at + which to create the "view" memref. + *) Sizes: memref-rank dynamic size operands which specify the dynamic sizes + of the result "view" memref type. + *) Strides: memref-rank number of dynamic strides which are applied + multiplicatively to the base memref strides in each dimension. + + Example 1: + + %0 = alloc() : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> + + // Create a sub-view of "base" memref '%0' with offset arguments '%c0', + // dynamic sizes for each dimension, and stride arguments '%c1'. + %1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1] + : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to + memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)> + + Example 2: + + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)> + + // Create a sub-view of "base" memref '%0' with dynamic offsets, sizes, + // and strides. + // Note that dynamic offsets are represented by the linearized dynamic + // offset symbol 's0' in the subview memref layout map, and that the + // dynamic strides operands, after being applied to the base memref + // strides in each dimension, are represented in the view memref layout + // map as symbols 's1', 's2' and 's3'. + %1 = subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + memref<?x?x?xf32, + (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)> + } + }]; + + let arguments = (ins AnyMemRef:$source, Variadic<Index>:$offsets, + Variadic<Index>:$sizes, Variadic<Index>:$strides); + let results = (outs AnyMemRef); + + let extraClassDeclaration = [{ + /// The result of a subview is always a memref. + MemRefType getType() { return getResult()->getType().cast<MemRefType>(); } + + /// Returns the dynamic offsets for this subview operation. + operand_range getDynamicOffsets() { + return {operand_begin() + 1, operand_begin() + 1 + getType().getRank()}; + } + + /// Returns the operand starting position of the size operands. + unsigned getSizeOperandsStart() { return 1 + getType().getRank(); } + + /// Returns the dynamic sizes for this subview operation if specified. + operand_range getDynamicSizes() { + return {operand_begin() + getSizeOperandsStart(), + operand_begin() + getSizeOperandsStart() + getType().getRank()}; + } + + /// Returns the operand starting position of the size operands. + unsigned getStrideOperandsStart() { return 1 + 2 * getType().getRank(); } + + /// Returns the dynamic strides for this subview operation if specified. + operand_range getDynamicStrides() { + return {operand_begin() + getStrideOperandsStart(), + operand_begin() + getStrideOperandsStart() + getType().getRank()}; + } + }]; + + // TODO(andydavis) Add canonicalizer. +} + def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 9a160f534fe..06563ff96cf 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -376,10 +376,10 @@ void mlir::linalg::SubViewOp::build(Builder *b, OperationState &result, result.addAttributes(attrs); } -static void print(OpAsmPrinter &p, SubViewOp op) { +static void print(OpAsmPrinter &p, mlir::linalg::SubViewOp op) { p << op.getOperationName() << " " << *op.getOperand(0) << "["; auto ranges = op.getRanges(); - interleaveComma(ranges, p, [&p](const SubViewOp::Range &i) { + interleaveComma(ranges, p, [&p](const mlir::linalg::SubViewOp::Range &i) { p << *i.min << ", " << *i.max << ", " << *i.step; }); p << "]"; @@ -646,8 +646,9 @@ static LogicalResult verify(ConvOp op) { return success(); } -llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os, - SubViewOp::Range &range) { +llvm::raw_ostream & +mlir::linalg::operator<<(llvm::raw_ostream &os, + mlir::linalg::SubViewOp::Range &range) { return os << "range " << *range.min << ":" << *range.max << ":" << *range.step; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 82699545b3f..c6dffc3ab11 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -74,8 +74,9 @@ static llvm::cl::list<unsigned> clTileSizes( // a subset of the original loop ranges of `op`. // This is achieved by applying the `loopToOperandRangesMaps` permutation maps // to the `loopRanges` in order to obtain view ranges. -static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, - ArrayRef<SubViewOp::Range> loopRanges) { +static LinalgOp +cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, + ArrayRef<mlir::linalg::SubViewOp::Range> loopRanges) { auto maps = loopToOperandRangesMaps(op); SmallVector<Value *, 8> clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); @@ -87,7 +88,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, auto map = maps[idx]; LLVM_DEBUG(dbgs() << "map: " << map << "\n"); Value *view = en.value(); - SmallVector<SubViewOp::Range, 8> viewRanges(map.getNumResults()); + SmallVector<mlir::linalg::SubViewOp::Range, 8> viewRanges( + map.getNumResults()); for (auto en2 : llvm::enumerate(map.getResults())) { unsigned d = en2.index(); // loopToOperandRangesMaps are permutations-only. @@ -105,7 +107,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, subViewOperands.push_back(r.max); subViewOperands.push_back(r.step); } - clonedViews.push_back(b.create<SubViewOp>(loc, view, subViewOperands)); + clonedViews.push_back( + b.create<mlir::linalg::SubViewOp>(loc, view, subViewOperands)); } auto operands = getAssumedNonViewOperands(op); clonedViews.append(operands.begin(), operands.end()); @@ -150,7 +153,7 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, unsigned producerIdx, OperationFolder *folder) { - auto subView = dyn_cast_or_null<SubViewOp>( + auto subView = dyn_cast_or_null<mlir::linalg::SubViewOp>( consumer.getInput(consumerIdx)->getDefiningOp()); auto slice = dyn_cast_or_null<SliceOp>( consumer.getInput(consumerIdx)->getDefiningOp()); @@ -169,7 +172,7 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); - SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); + SmallVector<mlir::linalg::SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); // Iterate over dimensions identified by the producer map for `producerIdx`. // This defines a subset of the loop ranges that we need to complete later. @@ -189,9 +192,9 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, << "existing LoopRange: " << loopRanges[i] << "\n"); else { auto viewDim = getViewDefiningLoopRange(producer, i); - loopRanges[i] = SubViewOp::Range{constant_index(folder, 0), - dim(viewDim.view, viewDim.dimension), - constant_index(folder, 1)}; + loopRanges[i] = mlir::linalg::SubViewOp::Range{ + constant_index(folder, 0), dim(viewDim.view, viewDim.dimension), + constant_index(folder, 1)}; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } } @@ -283,7 +286,8 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( // Must be a subview or a slice to guarantee there are loops we can fuse // into. - auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp()); + auto subView = dyn_cast_or_null<mlir::linalg::SubViewOp>( + consumedView->getDefiningOp()); auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp()); if (!subView && !slice) { LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 6b51e039a5b..7a8bc7162af 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -487,11 +487,11 @@ public: /// A non-conversion rewrite pattern kicks in to convert SubViewOp into RangeOps /// and SliceOps. -class SubViewOpConversion : public OpRewritePattern<SubViewOp> { +class SubViewOpConversion : public OpRewritePattern<mlir::linalg::SubViewOp> { public: - using OpRewritePattern<SubViewOp>::OpRewritePattern; + using OpRewritePattern<mlir::linalg::SubViewOp>::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp op, + PatternMatchResult matchAndRewrite(mlir::linalg::SubViewOp op, PatternRewriter &rewriter) const override { auto *view = op.getView(); SmallVector<Value *, 8> ranges; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index a23e68dc8f3..3afee415405 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -89,7 +89,7 @@ static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) { // boundary tiles. For now this is done with an unconditional `fill` op followed // by a partial `copy` op. static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, - SubViewOp subView, + mlir::linalg::SubViewOp subView, bool dynamicBuffers, OperationFolder *folder) { auto zero = constant_index(folder, 0); @@ -135,7 +135,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, res.reserve(subViews.size()); DenseMap<Value *, PromotionInfo> promotionInfoMap; for (auto *v : subViews) { - SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + mlir::linalg::SubViewOp subView = + cast<mlir::linalg::SubViewOp>(v->getDefiningOp()); auto viewType = subView.getViewType(); // TODO(ntv): support more cases than just float. if (!viewType.getElementType().isa<FloatType>()) @@ -147,7 +148,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, } for (auto *v : subViews) { - SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + mlir::linalg::SubViewOp subView = + cast<mlir::linalg::SubViewOp>(v->getDefiningOp()); auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; @@ -165,7 +167,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; - copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView); + copy(cast<mlir::linalg::SubViewOp>(v->getDefiningOp()), + info->second.partialLocalView); } return res; } @@ -223,7 +226,8 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { // nothing. SetVector<Value *> subViews; for (auto it : op.getInputsAndOutputs()) - if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + if (auto sv = + dyn_cast_or_null<mlir::linalg::SubViewOp>(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index c1d9755f4df..b7a5740a387 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -65,7 +65,7 @@ static bool isZero(Value *v) { // avoiding affine map manipulations. // The returned ranges correspond to the loop ranges, in the proper order, that // are tiled and for which new loops will be created. -static SmallVector<SubViewOp::Range, 4> +static SmallVector<mlir::linalg::SubViewOp::Range, 4> makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, ArrayRef<Value *> allViewSizes, ArrayRef<Value *> allTileSizes, OperationFolder *folder) { @@ -83,10 +83,10 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, } // Create a new range with the applied tile sizes. - SmallVector<SubViewOp::Range, 4> res; + SmallVector<mlir::linalg::SubViewOp::Range, 4> res; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx], - tileSizes[idx]}); + res.push_back(mlir::linalg::SubViewOp::Range{ + constant_index(folder, 0), viewSizes[idx], tileSizes[idx]}); } return res; } @@ -182,13 +182,13 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } // Construct a new subview for the tile. - SmallVector<SubViewOp::Range, 4> subViewRangeOperands; + SmallVector<mlir::linalg::SubViewOp::Range, 4> subViewRangeOperands; subViewRangeOperands.reserve(rank * 3); for (unsigned r = 0; r < rank; ++r) { if (!isTiled(map.getSubMap({r}), tileSizes)) { - subViewRangeOperands.push_back( - SubViewOp::Range{constant_index(folder, 0), dim(view, r), - constant_index(folder, 1)}); + subViewRangeOperands.push_back(mlir::linalg::SubViewOp::Range{ + constant_index(folder, 0), dim(view, r), + constant_index(folder, 1)}); continue; } @@ -198,7 +198,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the slice view does not subsample, stepping occurs in the loop). subViewRangeOperands.push_back( - SubViewOp::Range{min, max, constant_index(folder, 1)}); + mlir::linalg::SubViewOp::Range{min, max, constant_index(folder, 1)}); } SmallVector<Value *, 12> subViewOperands; subViewOperands.reserve(subViewRangeOperands.size() * 3); @@ -207,7 +207,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, subViewOperands.push_back(r.max); subViewOperands.push_back(r.step); } - res.push_back(b.create<SubViewOp>(loc, view, subViewOperands)); + res.push_back( + b.create<mlir::linalg::SubViewOp>(loc, view, subViewOperands)); } // Traverse the mins/maxes and erase those that don't have uses left. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index dcd2e56a1ee..20cf1834698 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -56,8 +56,8 @@ mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, enter(body, /*prev=*/1); } -mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, - SubViewOp::Range range) { +mlir::edsc::LoopRangeBuilder::LoopRangeBuilder( + ValueHandle *iv, mlir::linalg::SubViewOp::Range range) { auto forOp = OperationHandle::createOp<ForOp>(range.min, range.max, range.step); *iv = ValueHandle(forOp.getInductionVar()); @@ -74,7 +74,8 @@ mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) { } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef<ValueHandle *> ivs, ArrayRef<SubViewOp::Range> ranges) { + ArrayRef<ValueHandle *> ivs, + ArrayRef<mlir::linalg::SubViewOp::Range> ranges) { loops.reserve(ranges.size()); for (unsigned i = 0, e = ranges.size(); i < e; ++i) { loops.emplace_back(ivs[i], ranges[i]); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 9fc6f320e1d..12029248168 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2380,6 +2380,23 @@ Value *ViewOp::getDynamicOffset() { return nullptr; } +static LogicalResult verifyDynamicStrides(MemRefType memrefType, + ArrayRef<int64_t> strides) { + ArrayRef<int64_t> shape = memrefType.getShape(); + unsigned rank = memrefType.getRank(); + assert(rank == strides.size()); + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. + if (ShapedType::isDynamic(shape[i + 1])) + dynamicStrides = true; + // If stride at dim 'i' is not dynamic, return error. + if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) + return failure(); + } + return success(); +} + static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); auto viewType = op.getResult()->getType().cast<MemRefType>(); @@ -2396,7 +2413,7 @@ static LogicalResult verify(ViewOp op) { "type ") << baseType << " and view memref type " << viewType; - // Verify that the result memref type has a strided layout map. is strided + // Verify that the result memref type has a strided layout map. int64_t offset; llvm::SmallVector<int64_t, 4> strides; if (failed(getStridesAndOffset(viewType, strides, offset))) @@ -2413,20 +2430,9 @@ static LogicalResult verify(ViewOp op) { // Verify dynamic strides symbols were added to correct dimensions based // on dynamic sizes. - ArrayRef<int64_t> viewShape = viewType.getShape(); - unsigned viewRank = viewType.getRank(); - assert(viewRank == strides.size()); - bool dynamicStrides = false; - for (int i = viewRank - 2; i >= 0; --i) { - // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. - if (ShapedType::isDynamic(viewShape[i + 1])) - dynamicStrides = true; - // If stride at dim 'i' is not dynamic, return error. - if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) - return op.emitError("incorrect dynamic strides in view memref type ") - << viewType; - } - + if (failed(verifyDynamicStrides(viewType, strides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << viewType; return success(); } @@ -2544,6 +2550,91 @@ void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, } //===----------------------------------------------------------------------===// +// SubViewOp +//===----------------------------------------------------------------------===// + +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector<OpAsmParser::OperandType, 4> offsetsInfo; + SmallVector<OpAsmParser::OperandType, 4> sizesInfo; + SmallVector<OpAsmParser::OperandType, 4> stridesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + return failure( + parser.parseOperand(srcInfo) || + parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, SubViewOp op) { + p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; + p.printOperands(op.getDynamicOffsets()); + p << "]["; + p.printOperands(op.getDynamicSizes()); + p << "]["; + p.printOperands(op.getDynamicStrides()); + p << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); +} + +static LogicalResult verify(SubViewOp op) { + auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); + auto subViewType = op.getResult()->getType().cast<MemRefType>(); + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and subview memref type " << subViewType; + + // Verify that the base memref type has a strided layout map. + int64_t baseOffset; + llvm::SmallVector<int64_t, 4> baseStrides; + if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) + return op.emitError("base type ") << subViewType << " is not strided"; + + // Verify that the result memref type has a strided layout map. + int64_t subViewOffset; + llvm::SmallVector<int64_t, 4> subViewStrides; + if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) + return op.emitError("result type ") << subViewType << " is not strided"; + + unsigned memrefOperandCount = 1; + unsigned numDynamicOffsets = llvm::size(op.getDynamicOffsets()); + unsigned numDynamicSizes = llvm::size(op.getDynamicSizes()); + unsigned numDynamicStrides = llvm::size(op.getDynamicStrides()); + + // Verify that we have the correct number of operands for the result type. + if (op.getNumOperands() != memrefOperandCount + numDynamicOffsets + + numDynamicSizes + numDynamicStrides) + return op.emitError("incorrect number of operands for type ") + << subViewType; + + // Verify that the subview layout map has a dynamic offset. + if (subViewOffset != MemRefType::getDynamicStrideOrOffset()) + return op.emitError("subview memref layout map must specify a dynamic " + "offset for type ") + << subViewType; + + // Verify dynamic strides symbols were added to correct dimensions based + // on dynamic sizes. + if (failed(verifyDynamicStrides(subViewType, subViewStrides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << subViewType; + return success(); +} + +//===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 252a13df102..96df40202a3 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -15,6 +15,15 @@ // CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0) // CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1) +// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2) +// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0) + +// CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = (d0)[s0] -> (d0 + s0) +// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = (d0)[s0, s1] -> (d0 * s1 + s0) + +// CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = (d0, d1) -> (d0 * 22 + d1) +// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0) + // CHECK-LABEL: func @func_with_ops(%arg0: f32) { func @func_with_ops(f32) { ^bb0(%a : f32): @@ -506,6 +515,34 @@ func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) { return } +// CHECK-LABEL: func @memref_subview(%arg0 +func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + + //%2 = alloc() : memref<64xf32, (d0) -> (d0)> + + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + // CHECK: std.subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP0]]> + %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<?x?x?xf32, + (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)> + + %2 = alloc()[%arg2] : memref<64xf32, (d0)[s0] -> (d0 + s0)> + // CHECK: std.subview %2[%c1][%arg0][%c1] : memref<64xf32, #[[BASE_MAP1]]> to memref<?xf32, #[[SUBVIEW_MAP1]]> + %3 = subview %2[%c1][%arg0][%c1] + : memref<64xf32, (d0)[s0] -> (d0 + s0)> to + memref<?xf32, (d0)[s0, s1] -> (d0 * s1 + s0)> + + %4 = alloc() : memref<64x22xf32, (d0, d1) -> (d0 * 22 + d1)> + // CHECK: std.subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0] : memref<64x22xf32, #[[BASE_MAP2]]> to memref<?x?xf32, #[[SUBVIEW_MAP2]]> + %5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0] + : memref<64x22xf32, (d0, d1) -> (d0 * 22 + d1)> to + memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> + return +} + // CHECK-LABEL: func @test_dimop(%arg0 func @test_dimop(%arg0: tensor<4x4x?xf32>) { // CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index ec38ecebb0e..9c1807807c3 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -976,3 +976,70 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { return } +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2), 2> + // expected-error@+1 {{different memory spaces}} + %1 = subview %0[][%arg2][] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2), 2> to + memref<8x?x4xf32, (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + // expected-error@+1 {{is not strided}} + %1 = subview %0[][%arg2][] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<8x?x4xf32, (d0, d1, d2)[s0] -> (d0 + s0, d1, d2)> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 + d1, d1 + d2, d2)> + // expected-error@+1 {{is not strided}} + %1 = subview %0[][%arg2][] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 + d1, d1 + d2, d2)> to + memref<8x?x4xf32, (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + // expected-error@+1 {{incorrect number of operands for type}} + %1 = subview %0[%arg0, %arg1][%arg2][] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<8x?x4xf32, (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + // expected-error@+1 {{incorrect dynamic strides in view memref type}} + %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<?x?x4xf32, (d0, d1, d2)[s0] -> (d0 * 64 + d1 * 4 + d2 + s0)> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> + %c0 = constant 0 : index + %c1 = constant 1 : index + // expected-error@+1 {{subview memref layout map must specify a dynamic offset}} + %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] + : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to + memref<?x?x?xf32, (d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)> + return +} |

