diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-12-02 07:51:27 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-02 07:52:00 -0800 |
| commit | 0d22a3fdc87cb8e96a73cb427c6621c405c4674e (patch) | |
| tree | dc5483c28c0da664deae0cadbc1e08f720a2ffb4 /mlir/lib/Dialect/StandardOps | |
| parent | 4231de7897442f7423dae1e8b7fffdd1a69d5b58 (diff) | |
| download | bcm5719-llvm-0d22a3fdc87cb8e96a73cb427c6621c405c4674e.tar.gz bcm5719-llvm-0d22a3fdc87cb8e96a73cb427c6621c405c4674e.zip | |
NFC: Update std.subview op to use AttrSizedOperandSegments
This turns a few manually written helper methods into auto-generated ones.
PiperOrigin-RevId: 283339617
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 123 |
1 files changed, 46 insertions, 77 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 0bf562337a9..31431be5054 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { // Fold dim to the size argument of a SubViewOp. auto memref = memrefOrTensor()->getDefiningOp(); if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) { - auto sizes = subview.getDynamicSizes(); + auto sizes = subview.sizes(); if (!sizes.empty()) return *(sizes.begin() + getIndex()); } @@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value *source, unsigned num_offsets, - unsigned num_sizes, unsigned num_strides, - ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, - ArrayRef<Value *> strides) { - SmallVector<Value *, 8> operands; - operands.reserve(num_offsets + num_sizes + num_strides); - operands.append(offsets.begin(), offsets.end()); - operands.append(sizes.begin(), sizes.end()); - operands.append(strides.begin(), strides.end()); - build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets), - b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides), - operands); -} - void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, ArrayRef<Value *> strides, Type resultType, ArrayRef<NamedAttribute> attrs) { if (!resultType) resultType = inferSubViewResultType(source->getType().cast<MemRefType>()); - build(b, result, resultType, source, offsets.size(), sizes.size(), - strides.size(), offsets, sizes, strides); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()), + static_cast<int32_t>(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); result.addAttributes(attrs); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, Value *source) { - build(b, result, resultType, source, 0, 0, 0, {}, {}, {}); + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { @@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); } + auto builder = parser.getBuilder(); - result.addAttribute("num_offsets", - builder.getI32IntegerAttr(offsetsInfo.size())); - result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size())); - result.addAttribute("num_strides", - builder.getI32IntegerAttr(stridesInfo.size())); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()), + static_cast<int32_t>(sizesInfo.size()), + static_cast<int32_t>(stridesInfo.size())})); return failure( parser.parseOptionalAttrDict(result.attributes) || @@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, SubViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.getDynamicOffsets()); + p.printOperands(op.offsets()); p << "]["; - p.printOperands(op.getDynamicSizes()); + p.printOperands(op.sizes()); p << "]["; - p.printOperands(op.getDynamicStrides()); + p.printOperands(op.strides()); p << ']'; - SmallVector<StringRef, 3> elidedAttrs = {"num_offsets", "num_sizes", - "num_strides"}; + + SmallVector<StringRef, 1> elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) { } // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and viceversa. + // dynamic values, and vice versa. if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { return op.emitError("invalid to specify dynamic sizes when subview result " "type is statically shaped and viceversa"); } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. if (op.getNumSizes() > 0) { - // Verify that non if the shape values of the result type are static. if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { return dim != ShapedType::kDynamicSize; })) { @@ -2758,9 +2750,8 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() { unsigned rank = getType().getRank(); res.reserve(rank); for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(getDynamicOffsets().begin() + i), - *(getDynamicSizes().begin() + i), - *(getDynamicStrides().begin() + i)}); + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); return res; } @@ -2792,13 +2783,13 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) { + llvm::any_of(subViewOp.sizes(), [](Value *operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes()); - for (auto size : enumerate(subViewOp.getDynamicSizes())) { + for (auto size : enumerate(subViewOp.sizes())) { auto defOp = size.value()->getDefiningOp(); assert(defOp); staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue(); @@ -2808,12 +2799,12 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef<Value *>(), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.offsets()), ArrayRef<Value *>(), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2839,14 +2830,14 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) { + llvm::any_of(subViewOp.strides(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides()); - for (auto stride : enumerate(subViewOp.getDynamicStrides())) { + for (auto stride : enumerate(subViewOp.strides())) { auto defOp = stride.value()->getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); @@ -2858,15 +2849,15 @@ public: MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); - auto newSubViewOp = rewriter.create<SubViewOp>( - subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef<Value *>(), - newMemRefType); + auto newSubViewOp = + rewriter.create<SubViewOp>(subViewOp.getLoc(), subViewOp.source(), + llvm::to_vector<4>(subViewOp.offsets()), + llvm::to_vector<4>(subViewOp.sizes()), + ArrayRef<Value *>(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2893,14 +2884,14 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) { + llvm::any_of(subViewOp.offsets(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } auto staticOffset = baseOffset; - for (auto offset : enumerate(subViewOp.getDynamicOffsets())) { + for (auto offset : enumerate(subViewOp.offsets())) { auto defOp = offset.value()->getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); @@ -2915,39 +2906,17 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.sizes()), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace -SubViewOp::operand_range SubViewOp::getDynamicOffsets() { - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numOffsets + 1); - return {operand_begin() + 1, operand_begin() + 1 + numOffsets}; -} - -SubViewOp::operand_range SubViewOp::getDynamicSizes() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numSizes + numOffsets + 1); - return {operand_begin() + 1 + numOffsets, - operand_begin() + 1 + numOffsets + numSizes}; -} - -SubViewOp::operand_range SubViewOp::getDynamicStrides() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - auto numStrides = getNumStrides(); - assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1); - return {operand_begin() + (1 + numOffsets + numSizes), - operand_begin() + (1 + numOffsets + numSizes + numStrides)}; -} void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { |

