diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-12-02 07:51:27 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-02 07:52:00 -0800 |
| commit | 0d22a3fdc87cb8e96a73cb427c6621c405c4674e (patch) | |
| tree | dc5483c28c0da664deae0cadbc1e08f720a2ffb4 | |
| parent | 4231de7897442f7423dae1e8b7fffdd1a69d5b58 (diff) | |
| download | bcm5719-llvm-0d22a3fdc87cb8e96a73cb427c6621c405c4674e.tar.gz bcm5719-llvm-0d22a3fdc87cb8e96a73cb427c6621c405c4674e.zip | |
NFC: Update std.subview op to use AttrSizedOperandSegments
This turns a few manually written helper methods into auto-generated ones.
PiperOrigin-RevId: 283339617
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.td | 50 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Builders.h | 2 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 123 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 8 |
5 files changed, 77 insertions, 109 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index e2731acf47f..70cf3bb7775 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { let hasCanonicalizer = 1; } -def SubViewOp : Std_Op<"subview", [NoSideEffect]> { +def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type @@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { // TODO(b/144779634, ravishankarm) : Use different arguments for // offsets, sizes and strides. - let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets, - I32Attr:$num_sizes, I32Attr:$num_strides, - Variadic<Index>:$operands); + let arguments = (ins + AnyMemRef:$source, + Variadic<Index>:$offsets, + Variadic<Index>:$sizes, + Variadic<Index>:$strides, + I32ElementsAttr:$operand_segment_sizes + ); let results = (outs AnyMemRef); - let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *source, " - "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, " - "ArrayRef<Value *> strides, Type resultType = Type(), " - "ArrayRef<NamedAttribute> attrs = {}">, + let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source">, + "Builder *b, OperationState &result, Value *source, " + "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, " + "ArrayRef<Value *> strides, Type resultType = Type(), " + "ArrayRef<NamedAttribute> attrs = {}">, OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source, " - "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, " - "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, " - "ArrayRef<Value *> strides">]; + "Builder *builder, OperationState &result, " + "Type resultType, Value *source"> + ]; let extraClassDeclaration = [{ /// Returns the type of the base memref operand. @@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { MemRefType getType() { return getResult()->getType().cast<MemRefType>(); } /// Returns as integer value the number of offset operands. - int64_t getNumOffsets() { - return num_offsets().getSExtValue(); - } + int64_t getNumOffsets() { return llvm::size(offsets()); } /// Returns as integer value the number of size operands. - int64_t getNumSizes() { - return num_sizes().getSExtValue(); - } + int64_t getNumSizes() { return llvm::size(sizes()); } /// Returns as integer value the number of stride operands. - int64_t getNumStrides() { - return num_strides().getSExtValue(); - } - - /// Returns the dynamic offsets for this subview operation. - operand_range getDynamicOffsets(); + int64_t getNumStrides() { return llvm::size(strides()); } /// Returns the dynamic sizes for this subview operation if specified. - operand_range getDynamicSizes(); - - /// Returns the dynamic strides for this subview operation if specified. - operand_range getDynamicStrides(); + operand_range getDynamicSizes() { return sizes(); } // Auxiliary range data structure and helper function that unpacks the // offset, size and stride operands of the SubViewOp into a list of triples. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 01ad38cfc11..c5ed7b16b56 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -120,6 +120,8 @@ public: IntegerAttr getI32IntegerAttr(int32_t value); IntegerAttr getI64IntegerAttr(int64_t value); + DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); + ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index ae2b7837c40..d226766a3fc 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast<SubViewOp>(op); - SubViewOpOperandAdaptor adaptor(operands); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. @@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { return matchFailure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(adaptor.source()); + MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 0bf562337a9..31431be5054 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { // Fold dim to the size argument of a SubViewOp. auto memref = memrefOrTensor()->getDefiningOp(); if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) { - auto sizes = subview.getDynamicSizes(); + auto sizes = subview.sizes(); if (!sizes.empty()) return *(sizes.begin() + getIndex()); } @@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value *source, unsigned num_offsets, - unsigned num_sizes, unsigned num_strides, - ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, - ArrayRef<Value *> strides) { - SmallVector<Value *, 8> operands; - operands.reserve(num_offsets + num_sizes + num_strides); - operands.append(offsets.begin(), offsets.end()); - operands.append(sizes.begin(), sizes.end()); - operands.append(strides.begin(), strides.end()); - build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets), - b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides), - operands); -} - void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, ArrayRef<Value *> strides, Type resultType, ArrayRef<NamedAttribute> attrs) { if (!resultType) resultType = inferSubViewResultType(source->getType().cast<MemRefType>()); - build(b, result, resultType, source, offsets.size(), sizes.size(), - strides.size(), offsets, sizes, strides); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()), + static_cast<int32_t>(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); result.addAttributes(attrs); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, Value *source) { - build(b, result, resultType, source, 0, 0, 0, {}, {}, {}); + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { @@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); } + auto builder = parser.getBuilder(); - result.addAttribute("num_offsets", - builder.getI32IntegerAttr(offsetsInfo.size())); - result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size())); - result.addAttribute("num_strides", - builder.getI32IntegerAttr(stridesInfo.size())); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()), + static_cast<int32_t>(sizesInfo.size()), + static_cast<int32_t>(stridesInfo.size())})); return failure( parser.parseOptionalAttrDict(result.attributes) || @@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, SubViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.getDynamicOffsets()); + p.printOperands(op.offsets()); p << "]["; - p.printOperands(op.getDynamicSizes()); + p.printOperands(op.sizes()); p << "]["; - p.printOperands(op.getDynamicStrides()); + p.printOperands(op.strides()); p << ']'; - SmallVector<StringRef, 3> elidedAttrs = {"num_offsets", "num_sizes", - "num_strides"}; + + SmallVector<StringRef, 1> elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) { } // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and viceversa. + // dynamic values, and vice versa. if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { return op.emitError("invalid to specify dynamic sizes when subview result " "type is statically shaped and viceversa"); } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. if (op.getNumSizes() > 0) { - // Verify that non if the shape values of the result type are static. if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { return dim != ShapedType::kDynamicSize; })) { @@ -2758,9 +2750,8 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() { unsigned rank = getType().getRank(); res.reserve(rank); for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(getDynamicOffsets().begin() + i), - *(getDynamicSizes().begin() + i), - *(getDynamicStrides().begin() + i)}); + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); return res; } @@ -2792,13 +2783,13 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) { + llvm::any_of(subViewOp.sizes(), [](Value *operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes()); - for (auto size : enumerate(subViewOp.getDynamicSizes())) { + for (auto size : enumerate(subViewOp.sizes())) { auto defOp = size.value()->getDefiningOp(); assert(defOp); staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue(); @@ -2808,12 +2799,12 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef<Value *>(), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.offsets()), ArrayRef<Value *>(), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2839,14 +2830,14 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) { + llvm::any_of(subViewOp.strides(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides()); - for (auto stride : enumerate(subViewOp.getDynamicStrides())) { + for (auto stride : enumerate(subViewOp.strides())) { auto defOp = stride.value()->getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); @@ -2858,15 +2849,15 @@ public: MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); - auto newSubViewOp = rewriter.create<SubViewOp>( - subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef<Value *>(), - newMemRefType); + auto newSubViewOp = + rewriter.create<SubViewOp>(subViewOp.getLoc(), subViewOp.source(), + llvm::to_vector<4>(subViewOp.offsets()), + llvm::to_vector<4>(subViewOp.sizes()), + ArrayRef<Value *>(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2893,14 +2884,14 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) { + llvm::any_of(subViewOp.offsets(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } auto staticOffset = baseOffset; - for (auto offset : enumerate(subViewOp.getDynamicOffsets())) { + for (auto offset : enumerate(subViewOp.offsets())) { auto defOp = offset.value()->getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); @@ -2915,39 +2906,17 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.sizes()), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace -SubViewOp::operand_range SubViewOp::getDynamicOffsets() { - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numOffsets + 1); - return {operand_begin() + 1, operand_begin() + 1 + numOffsets}; -} - -SubViewOp::operand_range SubViewOp::getDynamicSizes() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numSizes + numOffsets + 1); - return {operand_begin() + 1 + numOffsets, - operand_begin() + 1 + numOffsets + numSizes}; -} - -SubViewOp::operand_range SubViewOp::getDynamicStrides() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - auto numStrides = getNumStrides(); - assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1); - return {operand_begin() + (1 + numOffsets + numSizes), - operand_begin() + (1 + numOffsets + numSizes + numStrides)}; -} void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index afdeefd023c..4d6cd3550ca 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) { return IntegerAttr::get(getIntegerType(64), APInt(64, value)); } +DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) { + return DenseElementsAttr::get( + VectorType::get(static_cast<int64_t>(values.size()), + getIntegerType(32)), + values) + .cast<DenseIntElementsAttr>(); +} + IntegerAttr Builder::getI32IntegerAttr(int32_t value) { return IntegerAttr::get(getIntegerType(32), APInt(32, value)); } |

