summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.td50
-rw-r--r--mlir/include/mlir/IR/Builders.h2
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp3
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp123
-rw-r--r--mlir/lib/IR/Builders.cpp8
5 files changed, 77 insertions, 109 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index e2731acf47f..70cf3bb7775 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
-def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
+def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let summary = "memref subview operation";
let description = [{
The "subview" operation converts a memref type to another memref type
@@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
// TODO(b/144779634, ravishankarm) : Use different arguments for
// offsets, sizes and strides.
- let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets,
- I32Attr:$num_sizes, I32Attr:$num_strides,
- Variadic<Index>:$operands);
+ let arguments = (ins
+ AnyMemRef:$source,
+ Variadic<Index>:$offsets,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides,
+ I32ElementsAttr:$operand_segment_sizes
+ );
let results = (outs AnyMemRef);
- let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *source, "
- "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
- "ArrayRef<Value *> strides, Type resultType = Type(), "
- "ArrayRef<NamedAttribute> attrs = {}">,
+ let builders = [
OpBuilder<
- "Builder *builder, OperationState &result, Type resultType, Value *source">,
+ "Builder *b, OperationState &result, Value *source, "
+ "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
+ "ArrayRef<Value *> strides, Type resultType = Type(), "
+ "ArrayRef<NamedAttribute> attrs = {}">,
OpBuilder<
- "Builder *builder, OperationState &result, Type resultType, Value *source, "
- "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, "
- "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
- "ArrayRef<Value *> strides">];
+ "Builder *builder, OperationState &result, "
+ "Type resultType, Value *source">
+ ];
let extraClassDeclaration = [{
/// Returns the type of the base memref operand.
@@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
/// Returns as integer value the number of offset operands.
- int64_t getNumOffsets() {
- return num_offsets().getSExtValue();
- }
+ int64_t getNumOffsets() { return llvm::size(offsets()); }
/// Returns as integer value the number of size operands.
- int64_t getNumSizes() {
- return num_sizes().getSExtValue();
- }
+ int64_t getNumSizes() { return llvm::size(sizes()); }
/// Returns as integer value the number of stride operands.
- int64_t getNumStrides() {
- return num_strides().getSExtValue();
- }
-
- /// Returns the dynamic offsets for this subview operation.
- operand_range getDynamicOffsets();
+ int64_t getNumStrides() { return llvm::size(strides()); }
/// Returns the dynamic sizes for this subview operation if specified.
- operand_range getDynamicSizes();
-
- /// Returns the dynamic strides for this subview operation if specified.
- operand_range getDynamicStrides();
+ operand_range getDynamicSizes() { return sizes(); }
// Auxiliary range data structure and helper function that unpacks the
// offset, size and stride operands of the SubViewOp into a list of triples.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 01ad38cfc11..c5ed7b16b56 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -120,6 +120,8 @@ public:
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
+ DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
+
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index ae2b7837c40..d226766a3fc 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<SubViewOp>(op);
- SubViewOpOperandAdaptor adaptor(operands);
// TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
// having multiple variadic operands where each operand can have different
// number of entries, clean all of this up.
@@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
return matchFailure();
// Create the descriptor.
- MemRefDescriptor sourceMemRef(adaptor.source());
+ MemRefDescriptor sourceMemRef(operands.front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 0bf562337a9..31431be5054 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Fold dim to the size argument of a SubViewOp.
auto memref = memrefOrTensor()->getDefiningOp();
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
- auto sizes = subview.getDynamicSizes();
+ auto sizes = subview.sizes();
if (!sizes.empty())
return *(sizes.begin() + getIndex());
}
@@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) {
memRefType.getMemorySpace());
}
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
- Value *source, unsigned num_offsets,
- unsigned num_sizes, unsigned num_strides,
- ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
- ArrayRef<Value *> strides) {
- SmallVector<Value *, 8> operands;
- operands.reserve(num_offsets + num_sizes + num_strides);
- operands.append(offsets.begin(), offsets.end());
- operands.append(sizes.begin(), sizes.end());
- operands.append(strides.begin(), strides.end());
- build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets),
- b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides),
- operands);
-}
-
void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source,
ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
ArrayRef<Value *> strides, Type resultType,
ArrayRef<NamedAttribute> attrs) {
if (!resultType)
resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
- build(b, result, resultType, source, offsets.size(), sizes.size(),
- strides.size(), offsets, sizes, strides);
+ auto segmentAttr = b->getI32VectorAttr(
+ {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
+ static_cast<int32_t>(strides.size())});
+ build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
result.addAttributes(attrs);
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
Value *source) {
- build(b, result, resultType, source, 0, 0, 0, {}, {}, {});
+ build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+ resultType);
}
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
@@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
return failure();
}
+
auto builder = parser.getBuilder();
- result.addAttribute("num_offsets",
- builder.getI32IntegerAttr(offsetsInfo.size()));
- result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size()));
- result.addAttribute("num_strides",
- builder.getI32IntegerAttr(stridesInfo.size()));
+ result.addAttribute(
+ SubViewOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
+ static_cast<int32_t>(sizesInfo.size()),
+ static_cast<int32_t>(stridesInfo.size())}));
return failure(
parser.parseOptionalAttrDict(result.attributes) ||
@@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, SubViewOp op) {
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
- p.printOperands(op.getDynamicOffsets());
+ p.printOperands(op.offsets());
p << "][";
- p.printOperands(op.getDynamicSizes());
+ p.printOperands(op.sizes());
p << "][";
- p.printOperands(op.getDynamicStrides());
+ p.printOperands(op.strides());
p << ']';
- SmallVector<StringRef, 3> elidedAttrs = {"num_offsets", "num_sizes",
- "num_strides"};
+
+ SmallVector<StringRef, 1> elidedAttrs = {
+ SubViewOp::getOperandSegmentSizeAttr()};
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
@@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) {
}
// Verify that if the shape of the subview type is static, then sizes are not
- // dynamic values, and viceversa.
+ // dynamic values, and vice versa.
if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
(op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
return op.emitError("invalid to specify dynamic sizes when subview result "
"type is statically shaped and viceversa");
}
+
+ // Verify that if dynamic sizes are specified, then the result memref type
+ // have full dynamic dimensions.
if (op.getNumSizes() > 0) {
- // Verify that non if the shape values of the result type are static.
if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
return dim != ShapedType::kDynamicSize;
})) {
@@ -2758,9 +2750,8 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
unsigned rank = getType().getRank();
res.reserve(rank);
for (unsigned i = 0; i < rank; ++i)
- res.emplace_back(Range{*(getDynamicOffsets().begin() + i),
- *(getDynamicSizes().begin() + i),
- *(getDynamicStrides().begin() + i)});
+ res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
+ *(strides().begin() + i)});
return res;
}
@@ -2792,13 +2783,13 @@ public:
// Follow all or nothing approach for shapes for now. If all the operands
// for sizes are constants then fold it into the type of the result memref.
if (subViewType.hasStaticShape() ||
- llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) {
+ llvm::any_of(subViewOp.sizes(), [](Value *operand) {
return !matchPattern(operand, m_ConstantIndex());
})) {
return matchFailure();
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
- for (auto size : enumerate(subViewOp.getDynamicSizes())) {
+ for (auto size : enumerate(subViewOp.sizes())) {
auto defOp = size.value()->getDefiningOp();
assert(defOp);
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
@@ -2808,12 +2799,12 @@ public:
subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(),
- llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef<Value *>(),
- llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+ llvm::to_vector<4>(subViewOp.offsets()), ArrayRef<Value *>(),
+ llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
- llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp,
- newSubViewOp, subViewOp.getType());
+ llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp,
+ subViewOp.getType());
return matchSuccess();
}
};
@@ -2839,14 +2830,14 @@ public:
failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
- llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) {
+ llvm::any_of(subViewOp.strides(), [](Value *stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
}
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
- for (auto stride : enumerate(subViewOp.getDynamicStrides())) {
+ for (auto stride : enumerate(subViewOp.strides())) {
auto defOp = stride.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[stride.index()] > 0);
@@ -2858,15 +2849,15 @@ public:
MemRefType newMemRefType =
MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
layoutMap, subViewType.getMemorySpace());
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(),
- llvm::to_vector<4>(subViewOp.getDynamicOffsets()),
- llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef<Value *>(),
- newMemRefType);
+ auto newSubViewOp =
+ rewriter.create<SubViewOp>(subViewOp.getLoc(), subViewOp.source(),
+ llvm::to_vector<4>(subViewOp.offsets()),
+ llvm::to_vector<4>(subViewOp.sizes()),
+ ArrayRef<Value *>(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
- llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp,
- newSubViewOp, subViewOp.getType());
+ llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp,
+ subViewOp.getType());
return matchSuccess();
}
};
@@ -2893,14 +2884,14 @@ public:
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) {
+ llvm::any_of(subViewOp.offsets(), [](Value *stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
}
auto staticOffset = baseOffset;
- for (auto offset : enumerate(subViewOp.getDynamicOffsets())) {
+ for (auto offset : enumerate(subViewOp.offsets())) {
auto defOp = offset.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[offset.index()] > 0);
@@ -2915,39 +2906,17 @@ public:
layoutMap, subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(),
- llvm::to_vector<4>(subViewOp.getDynamicSizes()),
- llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+ llvm::to_vector<4>(subViewOp.sizes()),
+ llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
- llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp,
- newSubViewOp, subViewOp.getType());
+ llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp,
+ subViewOp.getType());
return matchSuccess();
}
};
} // end anonymous namespace
-SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
- auto numOffsets = getNumOffsets();
- assert(getNumOperands() >= numOffsets + 1);
- return {operand_begin() + 1, operand_begin() + 1 + numOffsets};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicSizes() {
- auto numSizes = getNumSizes();
- auto numOffsets = getNumOffsets();
- assert(getNumOperands() >= numSizes + numOffsets + 1);
- return {operand_begin() + 1 + numOffsets,
- operand_begin() + 1 + numOffsets + numSizes};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicStrides() {
- auto numSizes = getNumSizes();
- auto numOffsets = getNumOffsets();
- auto numStrides = getNumStrides();
- assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1);
- return {operand_begin() + (1 + numOffsets + numSizes),
- operand_begin() + (1 + numOffsets + numSizes + numStrides)};
-}
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index afdeefd023c..4d6cd3550ca 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
+DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
+ return DenseElementsAttr::get(
+ VectorType::get(static_cast<int64_t>(values.size()),
+ getIntegerType(32)),
+ values)
+ .cast<DenseIntElementsAttr>();
+}
+
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
}
OpenPOWER on IntegriCloud