diff options
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 121 |
1 files changed, 106 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 9fc6f320e1d..12029248168 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2380,6 +2380,23 @@ Value *ViewOp::getDynamicOffset() { return nullptr; } +static LogicalResult verifyDynamicStrides(MemRefType memrefType, + ArrayRef<int64_t> strides) { + ArrayRef<int64_t> shape = memrefType.getShape(); + unsigned rank = memrefType.getRank(); + assert(rank == strides.size()); + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. + if (ShapedType::isDynamic(shape[i + 1])) + dynamicStrides = true; + // If stride at dim 'i' is not dynamic, return error. + if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) + return failure(); + } + return success(); +} + static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); auto viewType = op.getResult()->getType().cast<MemRefType>(); @@ -2396,7 +2413,7 @@ static LogicalResult verify(ViewOp op) { "type ") << baseType << " and view memref type " << viewType; - // Verify that the result memref type has a strided layout map. is strided + // Verify that the result memref type has a strided layout map. int64_t offset; llvm::SmallVector<int64_t, 4> strides; if (failed(getStridesAndOffset(viewType, strides, offset))) @@ -2413,20 +2430,9 @@ static LogicalResult verify(ViewOp op) { // Verify dynamic strides symbols were added to correct dimensions based // on dynamic sizes. - ArrayRef<int64_t> viewShape = viewType.getShape(); - unsigned viewRank = viewType.getRank(); - assert(viewRank == strides.size()); - bool dynamicStrides = false; - for (int i = viewRank - 2; i >= 0; --i) { - // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. - if (ShapedType::isDynamic(viewShape[i + 1])) - dynamicStrides = true; - // If stride at dim 'i' is not dynamic, return error. - if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) - return op.emitError("incorrect dynamic strides in view memref type ") - << viewType; - } - + if (failed(verifyDynamicStrides(viewType, strides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << viewType; return success(); } @@ -2544,6 +2550,91 @@ void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, } //===----------------------------------------------------------------------===// +// SubViewOp +//===----------------------------------------------------------------------===// + +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector<OpAsmParser::OperandType, 4> offsetsInfo; + SmallVector<OpAsmParser::OperandType, 4> sizesInfo; + SmallVector<OpAsmParser::OperandType, 4> stridesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + return failure( + parser.parseOperand(srcInfo) || + parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, SubViewOp op) { + p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; + p.printOperands(op.getDynamicOffsets()); + p << "]["; + p.printOperands(op.getDynamicSizes()); + p << "]["; + p.printOperands(op.getDynamicStrides()); + p << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); +} + +static LogicalResult verify(SubViewOp op) { + auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); + auto subViewType = op.getResult()->getType().cast<MemRefType>(); + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and subview memref type " << subViewType; + + // Verify that the base memref type has a strided layout map. + int64_t baseOffset; + llvm::SmallVector<int64_t, 4> baseStrides; + if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) + return op.emitError("base type ") << subViewType << " is not strided"; + + // Verify that the result memref type has a strided layout map. + int64_t subViewOffset; + llvm::SmallVector<int64_t, 4> subViewStrides; + if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) + return op.emitError("result type ") << subViewType << " is not strided"; + + unsigned memrefOperandCount = 1; + unsigned numDynamicOffsets = llvm::size(op.getDynamicOffsets()); + unsigned numDynamicSizes = llvm::size(op.getDynamicSizes()); + unsigned numDynamicStrides = llvm::size(op.getDynamicStrides()); + + // Verify that we have the correct number of operands for the result type. + if (op.getNumOperands() != memrefOperandCount + numDynamicOffsets + + numDynamicSizes + numDynamicStrides) + return op.emitError("incorrect number of operands for type ") + << subViewType; + + // Verify that the subview layout map has a dynamic offset. + if (subViewOffset != MemRefType::getDynamicStrideOrOffset()) + return op.emitError("subview memref layout map must specify a dynamic " + "offset for type ") + << subViewType; + + // Verify dynamic strides symbols were added to correct dimensions based + // on dynamic sizes. + if (failed(verifyDynamicStrides(subViewType, subViewStrides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << subViewType; + return success(); +} + +//===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// |

