author     Andy Davis <andydavis@google.com>                 2019-12-04 13:00:14 -0800
committer  A. Unique TensorFlower <gardener@tensorflow.org>  2019-12-04 13:00:43 -0800
commit     d20d763241020161ea173efe358d207b93310a34 (patch)
tree       66d6390a6e106ad272dd5d0ee6ae6b6a12f650cf /mlir
parent     6f895bec7d63e31ad005b0ae05395eb016e5014f (diff)
Add canonicalization patterns for vector CreateMaskOp and StridedSliceOp to be used in the unroll vector op transformation.
Adds a ConstantMaskOp to the vector ops dialect. Adds the following canonicalization patterns:

  CreateMaskOp -> ConstantMaskOp
  StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp

PiperOrigin-RevId: 283816752
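For orientation, a minimal before/after sketch of the two rewrites, mirroring the tests added in canonicalize.mlir below (standard-dialect 'constant' syntax as used in those tests):

    // A mask built from constant index operands folds to the new constant form.
    %c2 = constant 2 : index
    %c3 = constant 3 : index
    %m = vector.create_mask %c3, %c2 : vector<4x3xi1>
    // ==> canonicalizes to: vector.constant_mask [3, 2] : vector<4x3xi1>

    // A unit-stride slice of a constant mask folds to a constant mask of the
    // sliced region.
    %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
    %1 = vector.strided_slice %0
        {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
        : vector<4x3xi1> to vector<2x2xi1>
    // ==> canonicalizes to: vector.constant_mask [2, 2] : vector<2x2xi1>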
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/VectorOps/VectorOps.td |  41
-rw-r--r--  mlir/lib/Dialect/VectorOps/VectorOps.cpp         | 162
-rw-r--r--  mlir/test/Dialect/VectorOps/canonicalize.mlir    |  89
-rw-r--r--  mlir/test/Dialect/VectorOps/invalid.mlir         |  33
-rw-r--r--  mlir/test/Dialect/VectorOps/ops.mlir             |   7
5 files changed, 326 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 36c26fe577f..f4bfeb73dd7 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -402,6 +402,7 @@ def Vector_StridedSliceOp :
static StringRef getStridesAttrName() { return "strides"; }
VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
}];
+ let hasCanonicalizer = 1;
}
def Vector_TransferReadOp :
@@ -639,7 +640,41 @@ def Vector_TypeCastOp :
}];
}
-// TODO(andydavis) Add constant folding support.
+def Vector_ConstantMaskOp :
+ Vector_Op<"constant_mask", [NoSideEffect]>,
+ Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
+ Results<(outs VectorOf<[I1]>)> {
+ let summary = "creates a constant vector mask";
+ let description = [{
+ Creates and returns a vector mask where elements of the result vector
+ are set to '0' or '1', based on whether the element indices are contained
+ within a hyper-rectangular region specified by the 'mask_dim_sizes'
+ array attribute argument. Each element of the 'mask_dim_sizes' array
+ specifies an exclusive upper bound [0, mask-dim-size-element-value)
+ for a unique dimension in the vector result. The conjunction of these ranges
+ defines a hyper-rectangular region within which element values are set to 1
+ (otherwise element values are set to 0).
+
+ Example: create a constant vector mask of size 4x3xi1 where elements in the
+ ranges 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+
+ %1 = vector.constant_mask [3, 2] : vector<4x3xi1>
+
+ print %1
+ columns
+ 0 1 2
+ |------------
+ 0 | 1 1 0
+ rows 1 | 1 1 0
+ 2 | 1 1 0
+ 3 | 0 0 0
+ }];
+
+ let extraClassDeclaration = [{
+ static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
+ }];
+}
+
def Vector_CreateMaskOp :
Vector_Op<"create_mask", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
@@ -649,7 +684,7 @@ def Vector_CreateMaskOp :
are set to '0' or '1', based on whether the element indices are contained
within a hyper-rectangular region specified by the operands. Specifically,
each operand specifies a range [0, operand-value) for a unique dimension in
- the vector result. The conjunction of the operand ranges define
+ the vector result. The conjunction of the operand ranges define a
hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0).
@@ -667,6 +702,8 @@ def Vector_CreateMaskOp :
2 | 1 1 0
3 | 0 0 0
}];
+
+ let hasCanonicalizer = 1;
}
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index ab457a6b833..f96d3bacacf 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -21,10 +21,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
@@ -342,8 +344,9 @@ static Type inferExtractElementOpResultType(VectorType vectorType,
vectorType.getElementType());
}
-void ExtractElementOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int32_t> position) {
+void vector::ExtractElementOp::build(Builder *builder, OperationState &result,
+ Value *source,
+ ArrayRef<int32_t> position) {
result.addOperands(source);
auto positionAttr = builder->getI32ArrayAttr(position);
result.addTypes(inferExtractElementOpResultType(
@@ -351,7 +354,7 @@ void ExtractElementOp::build(Builder *builder, OperationState &result,
result.addAttribute(getPositionAttrName(), positionAttr);
}
-static void print(OpAsmPrinter &p, ExtractElementOp op) {
+static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
p << op.getOperationName() << " " << *op.vector() << op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
p << " : " << op.vector()->getType();
@@ -387,7 +390,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
parser.addTypeToList(resType, result.types));
}
-static LogicalResult verify(ExtractElementOp op) {
+static LogicalResult verify(vector::ExtractElementOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
return op.emitOpError("expected non-empty position attribute");
@@ -841,6 +844,74 @@ static LogicalResult verify(StridedSliceOp op) {
return success();
}
+namespace {
+
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+ SmallVectorImpl<int64_t> &results) {
+ for (auto attr : arrayAttr)
+ results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
+// Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+class StridedSliceConstantMaskFolder final
+ : public OpRewritePattern<StridedSliceOp> {
+public:
+ using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
+ auto defOp = stridedSliceOp.vector()->getDefiningOp();
+ auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
+ if (!constantMaskOp)
+ return matchFailure();
+ // Return if 'stridedSliceOp' has non-unit strides.
+ if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() != 1;
+ }))
+ return matchFailure();
+ // Gather constant mask dimension sizes.
+ SmallVector<int64_t, 4> maskDimSizes;
+ populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
+ // Gather strided slice offsets and sizes.
+ SmallVector<int64_t, 4> sliceOffsets;
+ populateFromInt64AttrArray(stridedSliceOp.offsets(), sliceOffsets);
+ SmallVector<int64_t, 4> sliceSizes;
+ populateFromInt64AttrArray(stridedSliceOp.sizes(), sliceSizes);
+
+ // Compute slice of vector mask region.
+ SmallVector<int64_t, 4> sliceMaskDimSizes;
+ assert(sliceOffsets.size() == maskDimSizes.size());
+ for (const auto &it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+ int64_t maskDimSize = std::get<0>(it);
+ int64_t sliceOffset = std::get<1>(it);
+ int64_t sliceSize = std::get<2>(it);
+ int64_t sliceMaskDimSize = std::max(
+ static_cast<int64_t>(0),
+ std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
+ sliceMaskDimSizes.push_back(sliceMaskDimSize);
+ }
+ // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
+ // region is a conjunction of mask dim intervals).
+ if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
+ sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
+
+ // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
+ rewriter.replaceOpWithNewOp<ConstantMaskOp>(
+ stridedSliceOp, stridedSliceOp.getResult()->getType(),
+ rewriter.getI64ArrayAttr(sliceMaskDimSizes));
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void StridedSliceOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+ results.insert<StridedSliceConstantMaskFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//
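As a worked instance of the slice computation above, using the offsets = [1, 0] case from the new canonicalize.mlir tests:

    // vector.constant_mask [2, 2] : vector<4x3xi1>, sliced with
    //   offsets = [1, 0], sizes = [2, 2], strides = [1, 1]
    // Per dimension: max(0, min(sliceOffset + sliceSize, maskDimSize) - sliceOffset)
    //   dim 0: max(0, min(1 + 2, 2) - 1) = 1
    //   dim 1: max(0, min(0 + 2, 2) - 0) = 2
    // ==> folds to: vector.constant_mask [1, 2] : vector<2x2xi1>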
@@ -1034,6 +1105,53 @@ static LogicalResult verify(TypeCastOp &op) {
}
//===----------------------------------------------------------------------===//
+// ConstantMaskOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
+ Type resultType;
+ ArrayAttr maskDimSizesAttr;
+ StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
+ return failure(
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
+ parser.parseColonType(resultType) ||
+ parser.addTypeToList(resultType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ConstantMaskOp &op) {
+ p << op.getOperationName() << ' ' << op.mask_dim_sizes();
+ p << " : " << op.getResult()->getType();
+}
+
+static LogicalResult verify(ConstantMaskOp &op) {
+ // Verify that array attr size matches the rank of the vector result.
+ auto resultType = op.getResult()->getType().cast<VectorType>();
+ if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
+ return op.emitOpError(
+ "must specify array attr of size equal vector result rank");
+ // Verify that each array attr element is in bounds of corresponding vector
+ // result dimension size.
+ auto resultShape = resultType.getShape();
+ SmallVector<int64_t, 4> maskDimSizes;
+ for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
+ int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
+ if (attrValue < 0 || attrValue > resultShape[it.index()])
+ return op.emitOpError(
+ "array attr of size out of bounds of vector result dimension size");
+ maskDimSizes.push_back(attrValue);
+ }
+ // Verify that if any mask dim size is zero, then all of them are zero (the
+ // mask region is a conjunction of the per-dimension intervals).
+ bool any_zeros = llvm::is_contained(maskDimSizes, 0);
+ bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
+ if (any_zeros && !all_zeros)
+ return op.emitOpError("expected all mask dim sizes to be zeros, "
+ "as a result of conjunction with zero mask dim");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
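As a quick illustration of the checks in verify(ConstantMaskOp) above (the rejected forms are exercised in the new invalid.mlir tests below):

    // Accepted: two sizes for the rank-2 result, each within its dimension.
    %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
    // Rejected: three mask sizes for a rank-2 result.
    //   vector.constant_mask [3, 2, 7] : vector<4x3xi1>
    // Rejected: size 4 exceeds the second result dimension (3).
    //   vector.constant_mask [3, 4] : vector<4x3xi1>
    // Rejected: a zero size in one dimension requires all sizes to be zero.
    //   vector.constant_mask [0, 2] : vector<4x3xi1>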
@@ -1064,6 +1182,42 @@ static LogicalResult verify(CreateMaskOp &op) {
return success();
}
+namespace {
+
+// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
+public:
+ using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
+ // Return if any of 'createMaskOp' operands are not defined by a constant.
+ auto is_not_def_by_constant = [](Value *operand) {
+ return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
+ };
+ if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
+ return matchFailure();
+ // Gather constant mask dimension sizes.
+ SmallVector<int64_t, 4> maskDimSizes;
+ for (auto *operand : createMaskOp.operands()) {
+ auto defOp = operand->getDefiningOp();
+ maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
+ }
+ // Replace 'createMaskOp' with ConstantMaskOp.
+ rewriter.replaceOpWithNewOp<ConstantMaskOp>(
+ createMaskOp, createMaskOp.getResult()->getType(),
+ rewriter.getI64ArrayAttr(maskDimSizes));
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void CreateMaskOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<CreateMaskFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// IndexTupleOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/VectorOps/canonicalize.mlir b/mlir/test/Dialect/VectorOps/canonicalize.mlir
new file mode 100644
index 00000000000..8dca47515ce
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/canonicalize.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: create_vector_mask_to_constant_mask
+func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
+ %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+ return %0 : vector<4x3xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [1, 0], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK: vector.constant_mask [1, 2] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK: vector.constant_mask [2, 1] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x1xi1>
+ // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+ return %1 : vector<2x1xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x1xi1>
+ // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+ return %1 : vector<2x1xi1>
+}
+
+// -----
+
+func @strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
+ %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
+ %1 = vector.strided_slice %0
+ {offsets = [1, 1], sizes = [2, 1], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x1xi1>
+ // CHECK: vector.constant_mask [1, 1] : vector<2x1xi1>
+ return %1 : vector<2x1xi1>
+}
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 0f19033fb42..bd664f71575 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -616,3 +616,36 @@ func @create_mask() {
%0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1>
return
}
+
+
+// -----
+
+func @constant_mask() {
+ // expected-error@+1 {{must specify array attr of size equal vector result rank}}
+ %0 = vector.constant_mask [3, 2, 7] : vector<4x3xi1>
+ return
+}
+
+// -----
+
+func @constant_mask_out_of_bounds() {
+ // expected-error@+1 {{array attr of size out of bounds of vector result dimension size}}
+ %0 = vector.constant_mask [-1, 2] : vector<4x3xi1>
+ return
+}
+
+// -----
+
+func @constant_mask_out_of_bounds() {
+ // expected-error@+1 {{array attr of size out of bounds of vector result dimension size}}
+ %0 = vector.constant_mask [3, 4] : vector<4x3xi1>
+ return
+}
+
+// -----
+
+func @constant_mask_with_zero_mask_dim_size() {
+ // expected-error@+1 {{expected all mask dim sizes to be zeros, as a result of conjunction with zero mask dim}}
+ %0 = vector.constant_mask [0, 2] : vector<4x3xi1>
+ return
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index 0a52a1ea45b..cb87c20a2b9 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -136,3 +136,10 @@ func @create_vector_mask() {
return
}
+
+// CHECK-LABEL: constant_vector_mask
+func @constant_vector_mask() {
+ // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
+ %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
+ return
+}