diff options
author | Aart Bik <ajcbik@google.com> | 2019-12-06 11:01:54 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 11:02:29 -0800 |
commit | b36aaeafb1b026213432b5a8110467e16ed3f306 (patch) | |
tree | 00d62d7455c2fd1b9b47f19b5681492ab10e1a51 /mlir/lib/Dialect/VectorOps/VectorOps.cpp | |
parent | 398f04aa49109fd5d1eff2c1946a2956dc6b29c6 (diff) | |
download | bcm5719-llvm-b36aaeafb1b026213432b5a8110467e16ed3f306.tar.gz bcm5719-llvm-b36aaeafb1b026213432b5a8110467e16ed3f306.zip |
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 5d596f388ed..65441674165 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -416,16 +416,16 @@ static LogicalResult verify(BroadcastOp op) { // Scalar to vector broadcast is always valid. A vector // to vector broadcast needs some additional checking. if (srcVectorType) { - const int64_t srcRank = srcVectorType.getRank(); - const int64_t dstRank = dstVectorType.getRank(); + int64_t srcRank = srcVectorType.getRank(); + int64_t dstRank = dstVectorType.getRank(); if (srcRank > dstRank) return op.emitOpError("source rank higher than destination rank"); // Source has an exact match or singleton value for all trailing dimensions // (all leading dimensions are simply duplicated). - const int64_t lead = dstRank - srcRank; - for (int64_t i = 0; i < srcRank; i++) { - const int64_t srcDim = srcVectorType.getDimSize(i); - const int64_t dstDim = dstVectorType.getDimSize(lead + i); + int64_t lead = dstRank - srcRank; + for (int64_t r = 0; r < srcRank; ++r) { + int64_t srcDim = srcVectorType.getDimSize(r); + int64_t dstDim = dstVectorType.getDimSize(lead + r); if (srcDim != 1 && srcDim != dstDim) return op.emitOpError("dimension mismatch (") << srcDim << " vs. " << dstDim << ")"; |