diff options
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 72 |
1 files changed, 42 insertions, 30 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 60002649a21..9fc6f320e1d 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2346,29 +2346,40 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { Type srcType, dstType; return failure( parser.parseOperand(srcInfo) || - parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } static void print(OpAsmPrinter &p, ViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.getDynamicSizes()); - p << "]["; auto *dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); + p << "]["; + p.printOperands(op.getDynamicSizes()); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } +Value *ViewOp::getDynamicOffset() { + int64_t offset; + llvm::SmallVector<int64_t, 4> strides; + auto result = + succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); + assert(result); + if (result && offset == MemRefType::getDynamicStrideOrOffset()) + return getOperand(1); + return nullptr; +} + static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); auto viewType = op.getResult()->getType().cast<MemRefType>(); @@ -2438,13 +2449,37 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { return matchFailure(); auto map = memrefType.getAffineMaps()[0]; + // Get offset from old memref view type 'memRefType'. + int64_t oldOffset; + llvm::SmallVector<int64_t, 4> oldStrides; + if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) + return matchFailure(); + + SmallVector<Value *, 4> newOperands; + SmallVector<Value *, 4> droppedOperands; + + // Fold dynamic offset operand if it is produced by a constant. + auto *dynamicOffset = viewOp.getDynamicOffset(); + int64_t newOffset = oldOffset; + unsigned dynamicOffsetOperandCount = 0; + if (dynamicOffset != nullptr) { + auto *defOp = dynamicOffset->getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { + // Dynamic offset will be folded into the map. + newOffset = constantIndexOp.getValue(); + droppedOperands.push_back(dynamicOffset); + } else { + // Unable to fold dynamic offset. Add it to 'newOperands' list. + newOperands.push_back(dynamicOffset); + dynamicOffsetOperandCount = 1; + } + } + // Fold any dynamic dim operands which are produced by a constant. SmallVector<int64_t, 4> newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); - SmallVector<Value *, 4> newOperands; - SmallVector<Value *, 4> droppedOperands; - unsigned dynamicDimPos = 1; + unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart(); unsigned rank = memrefType.getRank(); for (unsigned dim = 0, e = rank; dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); @@ -2467,29 +2502,6 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { dynamicDimPos++; } - // Get offset from old memref view type 'memRefType'. - int64_t oldOffset; - llvm::SmallVector<int64_t, 4> oldStrides; - if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) - return matchFailure(); - - // Fold dynamic offset operand if it is produced by a constant. - auto *dynamicOffset = viewOp.getDynamicOffset(); - int64_t newOffset = oldOffset; - unsigned dynamicOffsetOperandCount = 0; - if (dynamicOffset != nullptr) { - auto *defOp = dynamicOffset->getDefiningOp(); - if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { - // Dynamic offset will be folded into the map. - newOffset = constantIndexOp.getValue(); - droppedOperands.push_back(dynamicOffset); - } else { - // Unable to fold dynamic offset. Add it to 'newOperands' list. - newOperands.push_back(dynamicOffset); - dynamicOffsetOperandCount = 1; - } - } - // Compute new strides based on 'newShapeConstants'. SmallVector<int64_t, 4> newStrides(rank); newStrides[rank - 1] = 1; |

