diff options
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 8a6946792b2..6a3ff74afcd 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -59,12 +59,12 @@ Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder, } IntegerType vector::getVectorSubscriptType(Builder &builder) { - return builder.getIntegerType(32); + return builder.getIntegerType(64); } ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, - ArrayRef<int32_t> values) { - return builder.getI32ArrayAttr(values); + ArrayRef<int64_t> values) { + return builder.getI64ArrayAttr(values); } //===----------------------------------------------------------------------===// @@ -404,7 +404,7 @@ static Type inferExtractOpResultType(VectorType vectorType, } void vector::ExtractOp::build(Builder *builder, OperationState &result, - Value *source, ArrayRef<int32_t> position) { + Value *source, ArrayRef<int64_t> position) { result.addOperands(source); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(), @@ -475,8 +475,8 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) { result.addOperands(vector); - auto sizesAttr = builder->getI64ArrayAttr(sizes); - auto stridesAttr = builder->getI64ArrayAttr(strides); + auto sizesAttr = getVectorSubscriptAttr(*builder, sizes); + auto stridesAttr = getVectorSubscriptAttr(*builder, strides); result.addTypes(tupleType); result.addAttribute(getSizesAttrName(), sizesAttr); result.addAttribute(getStridesAttrName(), stridesAttr); @@ -648,7 +648,7 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1, - Value *v2, ArrayRef<int32_t> mask) { + Value *v2, ArrayRef<int64_t> mask) { result.addOperands({v1, v2}); auto maskAttr = getVectorSubscriptAttr(*builder, mask); result.addTypes(v1->getType()); @@ -772,7 +772,7 @@ static LogicalResult verify(InsertElementOp op) { //===----------------------------------------------------------------------===// void InsertOp::build(Builder *builder, OperationState &result, Value *source, - Value *dest, ArrayRef<int32_t> position) { + Value *dest, ArrayRef<int64_t> position) { result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(*builder, position); result.addTypes(dest->getType()); @@ -897,8 +897,8 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result, ArrayRef<int64_t> offsets, ArrayRef<int64_t> strides) { result.addOperands({source, dest}); - auto offsetsAttr = builder->getI64ArrayAttr(offsets); - auto stridesAttr = builder->getI64ArrayAttr(strides); + auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets); + auto stridesAttr = getVectorSubscriptAttr(*builder, strides); result.addTypes(dest->getType()); result.addAttribute(getOffsetsAttrName(), offsetsAttr); result.addAttribute(getStridesAttrName(), stridesAttr); @@ -1250,9 +1250,9 @@ void StridedSliceOp::build(Builder *builder, OperationState &result, Value *source, ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) { result.addOperands(source); - auto offsetsAttr = builder->getI64ArrayAttr(offsets); - auto sizesAttr = builder->getI64ArrayAttr(sizes); - auto stridesAttr = builder->getI64ArrayAttr(strides); + auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets); + auto sizesAttr = getVectorSubscriptAttr(*builder, sizes); + auto stridesAttr = getVectorSubscriptAttr(*builder, strides); result.addTypes( inferStridedSliceOpResultType(source->getType().cast<VectorType>(), offsetsAttr, sizesAttr, stridesAttr)); @@ -1375,7 +1375,7 @@ public: // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region. rewriter.replaceOpWithNewOp<ConstantMaskOp>( stridedSliceOp, stridedSliceOp.getResult()->getType(), - rewriter.getI64ArrayAttr(sliceMaskDimSizes)); + vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); return matchSuccess(); } }; @@ -1807,7 +1807,7 @@ public: // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp<ConstantMaskOp>( createMaskOp, createMaskOp.getResult()->getType(), - rewriter.getI64ArrayAttr(maskDimSizes)); + vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); return matchSuccess(); } }; |