summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-12-19 16:04:59 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-19 16:05:32 -0800
commit8020ad3e396bcca8dba94cea397cece81b76b119 (patch)
treec5836064e80ece1a666f56a6441fe51bc8631431 /mlir/lib/Dialect
parent6685282253c33fa2c5dc7487b04fc92d47082e78 (diff)
downloadbcm5719-llvm-8020ad3e396bcca8dba94cea397cece81b76b119.tar.gz
bcm5719-llvm-8020ad3e396bcca8dba94cea397cece81b76b119.zip
[VectorOps] Update vector transfer_read/write ops to operatate on memrefs with vector element type.
Update vector transfer_read/write ops to operatate on memrefs with vector element type. This handle cases where the memref vector element type represents the minimal memory transfer unit (or multiple of the minimal memory transfer unit). PiperOrigin-RevId: 286482115
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp121
1 files changed, 87 insertions, 34 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 541b5427af9..8a6946792b2 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1420,6 +1420,59 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}
+static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+ VectorType vectorType,
+ AffineMap permutationMap) {
+ auto memrefElementType = memrefType.getElementType();
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+
+ // Check that 'memrefVectorElementType' and vector element types match.
+ if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that memref vector type is a suffix of 'vectorType.
+ unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned resultVecRank = vectorType.getRank();
+ if (memrefVecEltRank > resultVecRank)
+ return op->emitOpError(
+ "requires memref vector element and vector result ranks to match.");
+ // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
+ unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ auto memrefVecEltShape = memrefVectorElementType.getShape();
+ auto resultVecShape = vectorType.getShape();
+ for (unsigned i = 0; i < memrefVecEltRank; ++i)
+ if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
+ return op->emitOpError(
+ "requires memref vector element shape to match suffix of "
+ "vector result shape.");
+ // Check that permutation map results match 'rankOffset' of vector type.
+ if (permutationMap.getNumResults() != rankOffset)
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ } else {
+ // Memref has scalar element type.
+
+ // Check that memref and vector element types match.
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that permutation map results match rank of vector type.
+ if (permutationMap.getNumResults() != vectorType.getRank())
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ }
+
+ if (permutationMap.getNumSymbols() != 0)
+ return op->emitOpError("requires permutation_map without symbols");
+ if (permutationMap.getNumInputs() != memrefType.getRank())
+ return op->emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ return success();
+}
+
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
@@ -1459,26 +1512,35 @@ static LogicalResult verify(TransferReadOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
- auto elementalType = op.padding()->getType();
- if (!VectorType::isValidElementType(elementalType))
- return op.emitOpError("requires valid padding vector elemental type");
- if (elementalType != vectorType.getElementType())
- return op.emitOpError(
- "requires formal padding and vector of the same elemental type");
- if (llvm::size(op.indices()) != memrefType.getRank())
- return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ auto paddingType = op.padding()->getType();
auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires permutation_map without symbols");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type");
+ auto memrefElementType = memrefType.getElementType();
+
+ if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
+ return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+ // Check that 'memrefVectorElementType' and 'paddingType' types match.
+ if (memrefVectorElementType != paddingType)
+ return op.emitOpError(
+ "requires memref element type and padding type to match.");
+
+ } else {
+ // Check that 'paddingType' is valid to store in a vector type.
+ if (!VectorType::isValidElementType(paddingType))
+ return op.emitOpError("requires valid padding vector elemental type");
+
+ // Check that padding type and vector element types match.
+ if (paddingType != vectorType.getElementType())
+ return op.emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ }
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
@@ -1519,24 +1581,15 @@ static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
+ auto permutationMap = op.permutation_map();
+
if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
- // Consistency of AffineMap attribute.
- auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires a symbol-less permutation_map");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type: ")
- << permutationMap.getNumInputs() << " vs " << memrefType;
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type.")
- << permutationMap.getNumResults() << " vs " << vectorType;
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
OpenPOWER on IntegriCloud