summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/VectorOps/VectorOps.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-22 21:59:55 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-22 22:00:23 -0800
commit35807bc4c5c9d8abc31ba0b2f955a82abf276e12 (patch)
treed083d37d993a774239081509a50e3e6c65366421 /mlir/lib/Dialect/VectorOps/VectorOps.cpp
parent22954a0e408afde1d8686dffb3a3dcab107a2cd3 (diff)
downloadbcm5719-llvm-35807bc4c5c9d8abc31ba0b2f955a82abf276e12.tar.gz
bcm5719-llvm-35807bc4c5c9d8abc31ba0b2f955a82abf276e12.zip
NFC: Introduce new ValuePtr/ValueRef typedefs to simplify the transition to Value being value-typed.
This is an initial step to refactoring the representation of OpResult as proposed in: https://groups.google.com/a/tensorflow.org/g/mlir/c/XXzzKhqqF_0/m/v6bKb08WCgAJ This change will make it much simpler to incrementally transition all of the existing code to use value-typed semantics. PiperOrigin-RevId: 286844725
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp30
1 files changed, 15 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 6a3ff74afcd..18c1714f403 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -72,7 +72,7 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
//===----------------------------------------------------------------------===//
void vector::ContractionOp::build(Builder *builder, OperationState &result,
- Value *lhs, Value *rhs, Value *acc,
+ ValuePtr lhs, ValuePtr rhs, ValuePtr acc,
ArrayAttr indexingMaps,
ArrayAttr iteratorTypes) {
result.addOperands({lhs, rhs, acc});
@@ -404,7 +404,7 @@ static Type inferExtractOpResultType(VectorType vectorType,
}
void vector::ExtractOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int64_t> position) {
+ ValuePtr source, ArrayRef<int64_t> position) {
result.addOperands(source);
auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
@@ -471,7 +471,7 @@ static LogicalResult verify(vector::ExtractOp op) {
//===----------------------------------------------------------------------===//
void ExtractSlicesOp::build(Builder *builder, OperationState &result,
- TupleType tupleType, Value *vector,
+ TupleType tupleType, ValuePtr vector,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
result.addOperands(vector);
@@ -647,8 +647,8 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
// ShuffleOp
//===----------------------------------------------------------------------===//
-void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
- Value *v2, ArrayRef<int64_t> mask) {
+void ShuffleOp::build(Builder *builder, OperationState &result, ValuePtr v1,
+ ValuePtr v2, ArrayRef<int64_t> mask) {
result.addOperands({v1, v2});
auto maskAttr = getVectorSubscriptAttr(*builder, mask);
result.addTypes(v1->getType());
@@ -771,8 +771,8 @@ static LogicalResult verify(InsertElementOp op) {
// InsertOp
//===----------------------------------------------------------------------===//
-void InsertOp::build(Builder *builder, OperationState &result, Value *source,
- Value *dest, ArrayRef<int64_t> position) {
+void InsertOp::build(Builder *builder, OperationState &result, ValuePtr source,
+ ValuePtr dest, ArrayRef<int64_t> position) {
result.addOperands({source, dest});
auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(dest->getType());
@@ -893,7 +893,7 @@ void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===//
void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
- Value *source, Value *dest,
+ ValuePtr source, ValuePtr dest,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) {
result.addOperands({source, dest});
@@ -1201,17 +1201,17 @@ static LogicalResult verify(ReshapeOp op) {
// If all shape operands are produced by constant ops, verify that product
// of dimensions for input/output shape match.
- auto isDefByConstant = [](Value *operand) {
+ auto isDefByConstant = [](ValuePtr operand) {
return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
};
if (llvm::all_of(op.input_shape(), isDefByConstant) &&
llvm::all_of(op.output_shape(), isDefByConstant)) {
int64_t numInputElements = 1;
- for (auto *operand : op.input_shape())
+ for (auto operand : op.input_shape())
numInputElements *=
cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
int64_t numOutputElements = 1;
- for (auto *operand : op.output_shape())
+ for (auto operand : op.output_shape())
numOutputElements *=
cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
if (numInputElements != numOutputElements)
@@ -1247,7 +1247,7 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
}
void StridedSliceOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int64_t> offsets,
+ ValuePtr source, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
result.addOperands(source);
auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
@@ -1603,7 +1603,7 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) {
}
void TypeCastOp::build(Builder *builder, OperationState &result,
- Value *source) {
+ ValuePtr source) {
result.addOperands(source);
result.addTypes(
inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
@@ -1793,14 +1793,14 @@ public:
PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
- auto is_not_def_by_constant = [](Value *operand) {
+ auto is_not_def_by_constant = [](ValuePtr operand) {
return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
return matchFailure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
- for (auto *operand : createMaskOp.operands()) {
+ for (auto operand : createMaskOp.operands()) {
auto defOp = operand->getDefiningOp();
maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
}
OpenPOWER on IntegriCloud