summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-08-12 04:08:26 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-12 04:08:57 -0700
commit252ada493276eace3e23802fb70ff3c7be53837d (patch)
tree3a40800f55e5486089cc8d6ae58dda343b88d474
parent5290e8c36d4e4aac4d8ce2726f6d373e87501945 (diff)
downloadbcm5719-llvm-252ada493276eace3e23802fb70ff3c7be53837d.tar.gz
bcm5719-llvm-252ada493276eace3e23802fb70ff3c7be53837d.zip
Add lowering of vector dialect to LLVM dialect.
This CL is step 3/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools. This CL adds support for converting MLIR n-D vector types to (n-1)-D arrays of 1-D LLVM vectors and a conversion VectorToLLVM that lowers the `vector.extractelement` and `vector.outerproduct` instructions to the proper mix of `llvm.vectorshuffle`, `llvm.extractelement` and `llvm.mulf`. This has been independently verified to produce proper avx2 code. Input: ``` func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> { %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32> return %3 : vector<8xf32> } ``` Command: ``` mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2 ``` Output: ``` vec_1d: # @vec_1d # %bb.0: vbroadcastss %xmm0, %ymm0 vmulps %ymm1, %ymm0, %ymm0 retq ``` PiperOrigin-RevId: 262895929
-rw-r--r--mlir/g3doc/ConversionToLLVMDialect.md13
-rw-r--r--mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h26
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp22
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt15
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp207
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir33
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
8 files changed, 303 insertions, 15 deletions
diff --git a/mlir/g3doc/ConversionToLLVMDialect.md b/mlir/g3doc/ConversionToLLVMDialect.md
index 9da27c4ca68..a2898e022c6 100644
--- a/mlir/g3doc/ConversionToLLVMDialect.md
+++ b/mlir/g3doc/ConversionToLLVMDialect.md
@@ -39,11 +39,14 @@ object. For example, on x86-64 CPUs it converts to `!llvm.type<"i64">`.
### Vector Types
LLVM IR only supports *one-dimensional* vectors, unlike MLIR where vectors can
-be multi-dimensional. MLIR vectors are converted to LLVM IR vectors of the same
-size with element type converted using these conversion rules. Vector types
-cannot be nested in either IR.
-
-For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">`.
+be multi-dimensional. Vector types cannot be nested in either IR. In the
+one-dimensional case, MLIR vectors are converted to LLVM IR vectors of the same
+size with element type converted using these conversion rules. In the
+n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types
+of one-dimensional vectors.
+
+For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
+`vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
### Memref Types
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
new file mode 100644
index 00000000000..39b7ee2d03f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
@@ -0,0 +1,26 @@
+//===- VectorToLLVM.h - Pass converting vector to LLVM dialect --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+
+namespace mlir {
+class ModulePassBase;
+
+ModulePassBase *createLowerVectorToLLVMPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 1ddd103f28e..6c14f5487a6 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToSPIRV)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
+add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 5bb281112f5..c62a5d8719d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -145,18 +145,20 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
-// Convert a 1D vector type to an LLVM vector type.
+// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
+// n > 1.
+// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
+// `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
- if (type.getRank() != 1) {
- auto *mlirContext = llvmDialect->getContext();
- emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported");
+ auto elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
return {};
- }
-
- LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
- return elementType
- ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
- : Type();
+ auto vectorType =
+ LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
+ auto shape = type.getShape();
+ for (int i = shape.size() - 2; i >= 0; --i)
+ vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
+ return vectorType;
}
// Dispatch based on the actual type. Return null type on error.
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..a75b6c1e98a
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLLVM
+ VectorToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
+)
+set(LIBS
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRVectorToLLVM ${LIBS})
+target_link_libraries(MLIRVectorToLLVM ${LIBS})
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
new file mode 100644
index 00000000000..bf90edba401
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -0,0 +1,207 @@
+//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+
+template <typename T>
+static LLVM::LLVMType getPtrToElementType(T containerType,
+ LLVMTypeConverter &lowering) {
+ return lowering.convertType(containerType.getElementType())
+ .template cast<LLVM::LLVMType>()
+ .getPointerTo();
+}
+
+// Create an array attribute containing integer attributes with values provided
+// in `position`.
+static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
+ SmallVector<Attribute, 4> attrs;
+ attrs.reserve(position.size());
+ for (auto p : position)
+ attrs.push_back(builder.getI64IntegerAttr(p));
+ return builder.getArrayAttr(attrs);
+}
+
+class ExtractElementOpConversion : public LLVMOpLowering {
+public:
+ explicit ExtractElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
+ auto extractOp = cast<vector::ExtractElementOp>(op);
+ auto vectorType = extractOp.vector()->getType().cast<VectorType>();
+ auto resultType = extractOp.getResult()->getType();
+ auto llvmResultType = lowering.convertType(resultType);
+
+ auto positionArrayAttr = extractOp.position();
+ // One-shot extraction of vector from array (only requires extractvalue).
+ if (resultType.isa<VectorType>()) {
+ Value *extracted =
+ rewriter
+ .create<LLVM::ExtractValueOp>(loc, llvmResultType,
+ adaptor.vector(), positionArrayAttr)
+ .getResult();
+ rewriter.replaceOp(op, extracted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from struct.
+ auto *context = op->getContext();
+ Value *extracted = adaptor.vector();
+ auto positionAttrs = positionArrayAttr.getValue();
+ auto indexType = rewriter.getIndexType();
+ if (positionAttrs.size() > 1) {
+ auto nDVectorType = vectorType;
+ auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
+ nDVectorType.getElementType());
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter
+ .create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs)
+ .getResult();
+ }
+
+ // Remaining extraction of element from 1-D LLVM vector
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto constant = rewriter
+ .create<LLVM::ConstantOp>(
+ loc, lowering.convertType(indexType), position)
+ .getResult();
+ extracted =
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
+ .getResult();
+ rewriter.replaceOp(op, extracted);
+
+ return matchSuccess();
+ }
+};
+
+class OuterProductOpConversion : public LLVMOpLowering {
+public:
+ explicit OuterProductOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
+ auto *ctx = op->getContext();
+ auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
+ auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto llvmArrayOfVectType = lowering.convertType(
+ cast<vector::OuterProductOp>(op).getResult()->getType());
+ Value *desc =
+ rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
+ for (unsigned i = 0, e = rankV1; i < e; ++i) {
+ // Emit the following pattern:
+ // vec(a[i]) * b -> llvmStructOfVectType[i]
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ // shufflevector explicitly requires i32 /
+ auto attr = rewriter.getI32IntegerAttr(i);
+ SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
+ auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
+ auto *broadcasted =
+ rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
+ .getResult();
+ auto *multiplied =
+ rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
+ desc = rewriter
+ .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
+ multiplied,
+ positionAttr(rewriter, i))
+ .getResult();
+ }
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
+/// Populate the given list with patterns that convert from Vector to LLVM.
+static void
+populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+ ctx, converter);
+}
+
+namespace {
+struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+ void runOnModule();
+};
+} // namespace
+
+void LowerVectorToLLVMPass::runOnModule() {
+ // Convert to the LLVM IR dialect using the converter defined above.
+ OwningRewritePatternList patterns;
+ LLVMTypeConverter converter(&getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateStdToLLVMConversionPatterns(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ if (failed(
+ applyPartialConversion(getModule(), target, patterns, &converter))) {
+ signalPassFailure();
+ }
+}
+
+ModulePassBase *mlir::createLowerVectorToLLVMPass() {
+ return new LowerVectorToLLVMPass();
+}
+
+static PassRegistration<LowerVectorToLLVMPass>
+ pass("vector-lower-to-llvm-dialect",
+ "Lower the operations from the vector dialect into the LLVM dialect");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
new file mode 100644
index 00000000000..f582de146ba
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
+
+func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+ %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
+ return %3 : vector<8xf32>
+}
+// CHECK-LABEL: vec_1d
+// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
+// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>">
+
+func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+ return %2 : vector<4x8xf32>
+}
+// CHECK-LABEL: vec_2d
+// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
+// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
+
+func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]"> \ No newline at end of file
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 26f8885a242..ff12852e347 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -43,6 +43,7 @@ set(LIBS
MLIRTestTransforms
MLIRSupport
MLIRVectorOps
+ MLIRVectorToLLVM
)
if(MLIR_CUDA_CONVERSIONS_ENABLED)
list(APPEND LIBS
OpenPOWER on IntegriCloud