summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/g3doc/ConversionToLLVMDialect.md13
-rw-r--r--mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h26
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp22
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt15
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp207
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir33
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
8 files changed, 303 insertions, 15 deletions
diff --git a/mlir/g3doc/ConversionToLLVMDialect.md b/mlir/g3doc/ConversionToLLVMDialect.md
index 9da27c4ca68..a2898e022c6 100644
--- a/mlir/g3doc/ConversionToLLVMDialect.md
+++ b/mlir/g3doc/ConversionToLLVMDialect.md
@@ -39,11 +39,14 @@ object. For example, on x86-64 CPUs it converts to `!llvm.type<"i64">`.
### Vector Types
LLVM IR only supports *one-dimensional* vectors, unlike MLIR where vectors can
-be multi-dimensional. MLIR vectors are converted to LLVM IR vectors of the same
-size with element type converted using these conversion rules. Vector types
-cannot be nested in either IR.
-
-For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">`.
+be multi-dimensional. Vector types cannot be nested in either IR. In the
+one-dimensional case, MLIR vectors are converted to LLVM IR vectors of the same
+size with element type converted using these conversion rules. In the
+n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types
+of one-dimensional vectors.
+
+For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
+`vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
### Memref Types
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
new file mode 100644
index 00000000000..39b7ee2d03f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
@@ -0,0 +1,26 @@
+//===- VectorToLLVM.h - Pass converting vector to LLVM dialect --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+
+namespace mlir {
+class ModulePassBase;
+
+ModulePassBase *createLowerVectorToLLVMPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 1ddd103f28e..6c14f5487a6 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToSPIRV)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
+add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 5bb281112f5..c62a5d8719d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -145,18 +145,20 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
-// Convert a 1D vector type to an LLVM vector type.
+// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
+// n > 1.
+// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
+// `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
- if (type.getRank() != 1) {
- auto *mlirContext = llvmDialect->getContext();
- emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported");
+ auto elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
return {};
- }
-
- LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
- return elementType
- ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
- : Type();
+ auto vectorType =
+ LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
+ auto shape = type.getShape();
+ for (int i = shape.size() - 2; i >= 0; --i)
+ vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
+ return vectorType;
}
// Dispatch based on the actual type. Return null type on error.
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..a75b6c1e98a
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLLVM
+ VectorToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
+)
+set(LIBS
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRVectorToLLVM ${LIBS})
+target_link_libraries(MLIRVectorToLLVM ${LIBS})
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
new file mode 100644
index 00000000000..bf90edba401
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -0,0 +1,207 @@
+//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/VectorOps/VectorOps.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+
+template <typename T>
+static LLVM::LLVMType getPtrToElementType(T containerType,
+ LLVMTypeConverter &lowering) {
+ return lowering.convertType(containerType.getElementType())
+ .template cast<LLVM::LLVMType>()
+ .getPointerTo();
+}
+
+// Create an array attribute containing integer attributes with values provided
+// in `position`.
+static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
+ SmallVector<Attribute, 4> attrs;
+ attrs.reserve(position.size());
+ for (auto p : position)
+ attrs.push_back(builder.getI64IntegerAttr(p));
+ return builder.getArrayAttr(attrs);
+}
+
+class ExtractElementOpConversion : public LLVMOpLowering {
+public:
+ explicit ExtractElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
+ auto extractOp = cast<vector::ExtractElementOp>(op);
+ auto vectorType = extractOp.vector()->getType().cast<VectorType>();
+ auto resultType = extractOp.getResult()->getType();
+ auto llvmResultType = lowering.convertType(resultType);
+
+ auto positionArrayAttr = extractOp.position();
+ // One-shot extraction of vector from array (only requires extractvalue).
+ if (resultType.isa<VectorType>()) {
+ Value *extracted =
+ rewriter
+ .create<LLVM::ExtractValueOp>(loc, llvmResultType,
+ adaptor.vector(), positionArrayAttr)
+ .getResult();
+ rewriter.replaceOp(op, extracted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from struct.
+ auto *context = op->getContext();
+ Value *extracted = adaptor.vector();
+ auto positionAttrs = positionArrayAttr.getValue();
+ auto indexType = rewriter.getIndexType();
+ if (positionAttrs.size() > 1) {
+ auto nDVectorType = vectorType;
+ auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
+ nDVectorType.getElementType());
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter
+ .create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs)
+ .getResult();
+ }
+
+ // Remaining extraction of element from 1-D LLVM vector
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto constant = rewriter
+ .create<LLVM::ConstantOp>(
+ loc, lowering.convertType(indexType), position)
+ .getResult();
+ extracted =
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
+ .getResult();
+ rewriter.replaceOp(op, extracted);
+
+ return matchSuccess();
+ }
+};
+
+class OuterProductOpConversion : public LLVMOpLowering {
+public:
+ explicit OuterProductOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
+ auto *ctx = op->getContext();
+ auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
+ auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto llvmArrayOfVectType = lowering.convertType(
+ cast<vector::OuterProductOp>(op).getResult()->getType());
+ Value *desc =
+ rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
+ for (unsigned i = 0, e = rankV1; i < e; ++i) {
+ // Emit the following pattern:
+ // vec(a[i]) * b -> llvmStructOfVectType[i]
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ // shufflevector explicitly requires i32 /
+ auto attr = rewriter.getI32IntegerAttr(i);
+ SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
+ auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
+ auto *broadcasted =
+ rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
+ .getResult();
+ auto *multiplied =
+ rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
+ desc = rewriter
+ .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
+ multiplied,
+ positionAttr(rewriter, i))
+ .getResult();
+ }
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
+/// Populate the given list with patterns that convert from Vector to LLVM.
+static void
+populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+ ctx, converter);
+}
+
+namespace {
+struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+ void runOnModule();
+};
+} // namespace
+
+void LowerVectorToLLVMPass::runOnModule() {
+ // Convert to the LLVM IR dialect using the converter defined above.
+ OwningRewritePatternList patterns;
+ LLVMTypeConverter converter(&getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateStdToLLVMConversionPatterns(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ if (failed(
+ applyPartialConversion(getModule(), target, patterns, &converter))) {
+ signalPassFailure();
+ }
+}
+
+ModulePassBase *mlir::createLowerVectorToLLVMPass() {
+ return new LowerVectorToLLVMPass();
+}
+
+static PassRegistration<LowerVectorToLLVMPass>
+ pass("vector-lower-to-llvm-dialect",
+ "Lower the operations from the vector dialect into the LLVM dialect");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
new file mode 100644
index 00000000000..f582de146ba
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
+
+func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+ %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
+ return %3 : vector<8xf32>
+}
+// CHECK-LABEL: vec_1d
+// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
+// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>">
+
+func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+ return %2 : vector<4x8xf32>
+}
+// CHECK-LABEL: vec_2d
+// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
+// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
+
+func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
+// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]"> \ No newline at end of file
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 26f8885a242..ff12852e347 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -43,6 +43,7 @@ set(LIBS
MLIRTestTransforms
MLIRSupport
MLIRVectorOps
+ MLIRVectorToLLVM
)
if(MLIR_CUDA_CONVERSIONS_ENABLED)
list(APPEND LIBS
OpenPOWER on IntegriCloud