diff options
| author | Andy Davis <andydavis@google.com> | 2019-12-04 13:00:14 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-04 13:00:43 -0800 |
| commit | d20d763241020161ea173efe358d207b93310a34 (patch) | |
| tree | 66d6390a6e106ad272dd5d0ee6ae6b6a12f650cf /mlir | |
| parent | 6f895bec7d63e31ad005b0ae05395eb016e5014f (diff) | |
| download | bcm5719-llvm-d20d763241020161ea173efe358d207b93310a34.tar.gz bcm5719-llvm-d20d763241020161ea173efe358d207b93310a34.zip | |
Add canonicalization patterns for vector CreateMaskOp and StridedSliceOp to be used in the unroll vector op transformation.
Adds a ConstantMaskOp to the vector ops dialect.
Adds the following canonicalization patterns:
CreateMaskOp -> ConstantMaskOp
StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp
PiperOrigin-RevId: 283816752
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 41 | ||||
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 162 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/canonicalize.mlir | 89 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 33 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 7 |
5 files changed, 326 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 36c26fe577f..f4bfeb73dd7 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -402,6 +402,7 @@ def Vector_StridedSliceOp : static StringRef getStridesAttrName() { return "strides"; } VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); } }]; + let hasCanonicalizer = 1; } def Vector_TransferReadOp : @@ -639,7 +640,41 @@ def Vector_TypeCastOp : }]; } -// TODO(andydavis) Add constant folding support. +def Vector_ConstantMaskOp : + Vector_Op<"constant_mask", [NoSideEffect]>, + Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, + Results<(outs VectorOf<[I1]>)> { + let summary = "creates a constant vector mask"; + let description = [{ + Creates and returns a vector mask where elements of the result vector + are set to '0' or '1', based on whether the element indices are contained + within a hyper-rectangular region specified by the 'mask_dim_sizes' + array attribute argument. Each element of the 'mask_dim_sizes' array, + specifices an exclusive upper bound [0, mask-dim-size-element-value) + for a unique dimension in the vector result. The conjunction of the ranges + define a hyper-rectangular region within which elements values are set to 1 + (otherwise element values are set to 0). + + Example: create a constant vector mask of size 4x3xi1 with elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + + %1 = vector.constant_mask [3, 2] : vector<4x3xi1> + + print %1 + columns + 0 1 2 + |------------ + 0 | 1 1 0 + rows 1 | 1 1 0 + 2 | 1 1 0 + 3 | 0 0 0 + }]; + + let extraClassDeclaration = [{ + static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } + }]; +} + def Vector_CreateMaskOp : Vector_Op<"create_mask", [NoSideEffect]>, Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> { @@ -649,7 +684,7 @@ def Vector_CreateMaskOp : are set to '0' or '1', based on whether the element indices are contained within a hyper-rectangular region specified by the operands. Specifically, each operand specifies a range [0, operand-value) for a unique dimension in - the vector result. The conjunction of the operand ranges define + the vector result. The conjunction of the operand ranges define a hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). @@ -667,6 +702,8 @@ def Vector_CreateMaskOp : 2 | 1 1 0 3 | 0 0 0 }]; + + let hasCanonicalizer = 1; } // TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index ab457a6b833..f96d3bacacf 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -21,10 +21,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" @@ -342,8 +344,9 @@ static Type inferExtractElementOpResultType(VectorType vectorType, vectorType.getElementType()); } -void ExtractElementOp::build(Builder *builder, OperationState &result, - Value *source, ArrayRef<int32_t> position) { +void vector::ExtractElementOp::build(Builder *builder, OperationState &result, + Value *source, + ArrayRef<int32_t> position) { result.addOperands(source); auto positionAttr = builder->getI32ArrayAttr(position); result.addTypes(inferExtractElementOpResultType( @@ -351,7 +354,7 @@ void ExtractElementOp::build(Builder *builder, OperationState &result, result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, ExtractElementOp op) { +static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { p << op.getOperationName() << " " << *op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); p << " : " << op.vector()->getType(); @@ -387,7 +390,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(ExtractElementOp op) { +static LogicalResult verify(vector::ExtractElementOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); @@ -841,6 +844,74 @@ static LogicalResult verify(StridedSliceOp op) { return success(); } +namespace { + +static void populateFromInt64AttrArray(ArrayAttr arrayAttr, + SmallVectorImpl<int64_t> &results) { + for (auto attr : arrayAttr) + results.push_back(attr.cast<IntegerAttr>().getInt()); +} + +// Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. +class StridedSliceConstantMaskFolder final + : public OpRewritePattern<StridedSliceOp> { +public: + using OpRewritePattern<StridedSliceOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp, + PatternRewriter &rewriter) const override { + // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp. + auto defOp = stridedSliceOp.vector()->getDefiningOp(); + auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); + if (!constantMaskOp) + return matchFailure(); + // Return if 'stridedSliceOp' has non-unit strides. + if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) { + return attr.cast<IntegerAttr>().getInt() != 1; + })) + return matchFailure(); + // Gather constant mask dimension sizes. + SmallVector<int64_t, 4> maskDimSizes; + populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); + // Gather strided slice offsets and sizes. + SmallVector<int64_t, 4> sliceOffsets; + populateFromInt64AttrArray(stridedSliceOp.offsets(), sliceOffsets); + SmallVector<int64_t, 4> sliceSizes; + populateFromInt64AttrArray(stridedSliceOp.sizes(), sliceSizes); + + // Compute slice of vector mask region. + SmallVector<int64_t, 4> sliceMaskDimSizes; + assert(sliceOffsets.size() == maskDimSizes.size()); + for (const auto &it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { + int64_t maskDimSize = std::get<0>(it); + int64_t sliceOffset = std::get<1>(it); + int64_t sliceSize = std::get<2>(it); + int64_t sliceMaskDimSize = std::max( + static_cast<int64_t>(0), + std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); + sliceMaskDimSizes.push_back(sliceMaskDimSize); + } + // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked + // region is a conjunction of mask dim intervals). + if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; })) + sliceMaskDimSizes.assign(maskDimSizes.size(), 0); + + // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region. + rewriter.replaceOpWithNewOp<ConstantMaskOp>( + stridedSliceOp, stridedSliceOp.getResult()->getType(), + rewriter.getI64ArrayAttr(sliceMaskDimSizes)); + return matchSuccess(); + } +}; + +} // end anonymous namespace + +void StridedSliceOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. + results.insert<StridedSliceConstantMaskFolder>(context); +} + //===----------------------------------------------------------------------===// // TransferReadOp //===----------------------------------------------------------------------===// @@ -1034,6 +1105,53 @@ static LogicalResult verify(TypeCastOp &op) { } //===----------------------------------------------------------------------===// +// ConstantMaskOp +//===----------------------------------------------------------------------===// + +ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) { + Type resultType; + ArrayAttr maskDimSizesAttr; + StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName(); + return failure( + parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) || + parser.parseColonType(resultType) || + parser.addTypeToList(resultType, result.types)); +} + +static void print(OpAsmPrinter &p, ConstantMaskOp &op) { + p << op.getOperationName() << ' ' << op.mask_dim_sizes(); + p << " : " << op.getResult()->getType(); +} + +static LogicalResult verify(ConstantMaskOp &op) { + // Verify that array attr size matches the rank of the vector result. + auto resultType = op.getResult()->getType().cast<VectorType>(); + if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank()) + return op.emitOpError( + "must specify array attr of size equal vector result rank"); + // Verify that each array attr element is in bounds of corresponding vector + // result dimension size. + auto resultShape = resultType.getShape(); + SmallVector<int64_t, 4> maskDimSizes; + for (auto it : llvm::enumerate(op.mask_dim_sizes())) { + int64_t attrValue = it.value().cast<IntegerAttr>().getInt(); + if (attrValue < 0 || attrValue > resultShape[it.index()]) + return op.emitOpError( + "array attr of size out of bounds of vector result dimension size"); + maskDimSizes.push_back(attrValue); + } + // Verify that if one mask dim size is zero, they all should be zero (because + // the mask region is a conjunction of each mask dimension interval). + bool any_zeros = llvm::is_contained(maskDimSizes, 0); + bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); + if (any_zeros && !all_zeros) + return op.emitOpError("expected all mask dim sizes to be zeros, " + "as a result of conjunction with zero mask dim"); + return success(); +} + +//===----------------------------------------------------------------------===// // CreateMaskOp //===----------------------------------------------------------------------===// @@ -1064,6 +1182,42 @@ static LogicalResult verify(CreateMaskOp &op) { return success(); } +namespace { + +// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. +class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { +public: + using OpRewritePattern<CreateMaskOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + // Return if any of 'createMaskOp' operands are not defined by a constant. + auto is_not_def_by_constant = [](Value *operand) { + return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp()); + }; + if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant)) + return matchFailure(); + // Gather constant mask dimension sizes. + SmallVector<int64_t, 4> maskDimSizes; + for (auto *operand : createMaskOp.operands()) { + auto defOp = operand->getDefiningOp(); + maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue()); + } + // Replace 'createMaskOp' with ConstantMaskOp. + rewriter.replaceOpWithNewOp<ConstantMaskOp>( + createMaskOp, createMaskOp.getResult()->getType(), + rewriter.getI64ArrayAttr(maskDimSizes)); + return matchSuccess(); + } +}; + +} // end anonymous namespace + +void CreateMaskOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<CreateMaskFolder>(context); +} + //===----------------------------------------------------------------------===// // IndexTupleOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/canonicalize.mlir b/mlir/test/Dialect/VectorOps/canonicalize.mlir new file mode 100644 index 00000000000..8dca47515ce --- /dev/null +++ b/mlir/test/Dialect/VectorOps/canonicalize.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// ----- + +// CHECK-LABEL: create_vector_mask_to_constant_mask +func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) { + %c2 = constant 2 : index + %c3 = constant 3 : index + // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> + %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> + return %0 : vector<4x3xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [1, 0], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK: vector.constant_mask [1, 2] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK: vector.constant_mask [2, 1] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} + : vector<4x3xi1> to vector<2x1xi1> + // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1> + return %1 : vector<2x1xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} + : vector<4x3xi1> to vector<2x1xi1> + // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1> + return %1 : vector<2x1xi1> +} + +// ----- + +func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) { + %0 = vector.constant_mask [2, 2] : vector<4x3xi1> + %1 = vector.strided_slice %0 + {offsets = [1, 1], sizes = [2, 1], strides = [1, 1]} + : vector<4x3xi1> to vector<2x1xi1> + // CHECK: vector.constant_mask [1, 1] : vector<2x1xi1> + return %1 : vector<2x1xi1> +} diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 0f19033fb42..bd664f71575 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -616,3 +616,36 @@ func @create_mask() { %0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1> return } + + +// ----- + +func @constant_mask() { + // expected-error@+1 {{must specify array attr of size equal vector result rank}} + %0 = vector.constant_mask [3, 2, 7] : vector<4x3xi1> + return +} + +// ----- + +func @constant_mask_out_of_bounds() { + // expected-error@+1 {{array attr of size out of bounds of vector result dimension size}} + %0 = vector.constant_mask [-1, 2] : vector<4x3xi1> + return +} + +// ----- + +func @constant_mask_out_of_bounds() { + // expected-error@+1 {{array attr of size out of bounds of vector result dimension size}} + %0 = vector.constant_mask [3, 4] : vector<4x3xi1> + return +} + +// ----- + +func @constant_mask_with_zero_mask_dim_size() { + // expected-error@+1 {{expected all mask dim sizes to be zeros, as a result of conjunction with zero mask dim}} + %0 = vector.constant_mask [0, 2] : vector<4x3xi1> + return +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 0a52a1ea45b..cb87c20a2b9 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -136,3 +136,10 @@ func @create_vector_mask() { return } + +// CHECK-LABEL: constant_vector_mask +func @constant_vector_mask() { + // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> + %0 = vector.constant_mask [3, 2] : vector<4x3xi1> + return +} |

