summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/VectorOps/VectorOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp30
1 files changed, 15 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 8a6946792b2..6a3ff74afcd 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -59,12 +59,12 @@ Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
}
IntegerType vector::getVectorSubscriptType(Builder &builder) {
- return builder.getIntegerType(32);
+ return builder.getIntegerType(64);
}
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
- ArrayRef<int32_t> values) {
- return builder.getI32ArrayAttr(values);
+ ArrayRef<int64_t> values) {
+ return builder.getI64ArrayAttr(values);
}
//===----------------------------------------------------------------------===//
@@ -404,7 +404,7 @@ static Type inferExtractOpResultType(VectorType vectorType,
}
void vector::ExtractOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int32_t> position) {
+ Value *source, ArrayRef<int64_t> position) {
result.addOperands(source);
auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
@@ -475,8 +475,8 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
result.addOperands(vector);
- auto sizesAttr = builder->getI64ArrayAttr(sizes);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(tupleType);
result.addAttribute(getSizesAttrName(), sizesAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
@@ -648,7 +648,7 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
- Value *v2, ArrayRef<int32_t> mask) {
+ Value *v2, ArrayRef<int64_t> mask) {
result.addOperands({v1, v2});
auto maskAttr = getVectorSubscriptAttr(*builder, mask);
result.addTypes(v1->getType());
@@ -772,7 +772,7 @@ static LogicalResult verify(InsertElementOp op) {
//===----------------------------------------------------------------------===//
void InsertOp::build(Builder *builder, OperationState &result, Value *source,
- Value *dest, ArrayRef<int32_t> position) {
+ Value *dest, ArrayRef<int64_t> position) {
result.addOperands({source, dest});
auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(dest->getType());
@@ -897,8 +897,8 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) {
result.addOperands({source, dest});
- auto offsetsAttr = builder->getI64ArrayAttr(offsets);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(dest->getType());
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
@@ -1250,9 +1250,9 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
Value *source, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
result.addOperands(source);
- auto offsetsAttr = builder->getI64ArrayAttr(offsets);
- auto sizesAttr = builder->getI64ArrayAttr(sizes);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
+ auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(
inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
offsetsAttr, sizesAttr, stridesAttr));
@@ -1375,7 +1375,7 @@ public:
// Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
stridedSliceOp, stridedSliceOp.getResult()->getType(),
- rewriter.getI64ArrayAttr(sliceMaskDimSizes));
+ vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
return matchSuccess();
}
};
@@ -1807,7 +1807,7 @@ public:
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult()->getType(),
- rewriter.getI64ArrayAttr(maskDimSizes));
+ vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
return matchSuccess();
}
};
OpenPOWER on IntegriCloud