diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/g3doc/ConversionToLLVMDialect.md | 13 | ||||
| -rw-r--r-- | mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h | 26 | ||||
| -rw-r--r-- | mlir/lib/Conversion/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 22 | ||||
| -rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt | 15 | ||||
| -rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 207 | ||||
| -rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 33 | ||||
| -rw-r--r-- | mlir/tools/mlir-opt/CMakeLists.txt | 1 |
8 files changed, 303 insertions, 15 deletions
diff --git a/mlir/g3doc/ConversionToLLVMDialect.md b/mlir/g3doc/ConversionToLLVMDialect.md index 9da27c4ca68..a2898e022c6 100644 --- a/mlir/g3doc/ConversionToLLVMDialect.md +++ b/mlir/g3doc/ConversionToLLVMDialect.md @@ -39,11 +39,14 @@ object. For example, on x86-64 CPUs it converts to `!llvm.type<"i64">`. ### Vector Types LLVM IR only supports *one-dimensional* vectors, unlike MLIR where vectors can -be multi-dimensional. MLIR vectors are converted to LLVM IR vectors of the same -size with element type converted using these conversion rules. Vector types -cannot be nested in either IR. - -For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">`. +be multi-dimensional. Vector types cannot be nested in either IR. In the +one-dimensional case, MLIR vectors are converted to LLVM IR vectors of the same +size with element type converted using these conversion rules. In the +n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types +of one-dimensional vectors. + +For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and +`vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. ### Memref Types diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h new file mode 100644 index 00000000000..39b7ee2d03f --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h @@ -0,0 +1,26 @@ +//===- VectorToLLVM.h - Pass converting vector to LLVM dialect --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_ +#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_ + +namespace mlir { +class ModulePassBase; + +ModulePassBase *createLowerVectorToLLVMPass(); +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 1ddd103f28e..6c14f5487a6 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(GPUToNVVM) add_subdirectory(GPUToSPIRV) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5bb281112f5..c62a5d8719d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -145,18 +145,20 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { return LLVM::LLVMType::getStructTy(llvmDialect, types); } -// Convert a 1D vector type to an LLVM vector type. +// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when +// n > 1. +// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and +// `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. Type LLVMTypeConverter::convertVectorType(VectorType type) { - if (type.getRank() != 1) { - auto *mlirContext = llvmDialect->getContext(); - emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported"); + auto elementType = unwrap(convertType(type.getElementType())); + if (!elementType) return {}; - } - - LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); - return elementType - ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front()) - : Type(); + auto vectorType = + LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); + auto shape = type.getShape(); + for (int i = shape.size() - 2; i >= 0; --i) + vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); + return vectorType; } // Dispatch based on the actual type. Return null type on error. diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..a75b6c1e98a --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLLVM + VectorToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLLVM ${LIBS}) +target_link_libraries(MLIRVectorToLLVM ${LIBS}) diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp new file mode 100644 index 00000000000..bf90edba401 --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -0,0 +1,207 @@ +//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/VectorOps/VectorOps.h" + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; + +template <typename T> +static LLVM::LLVMType getPtrToElementType(T containerType, + LLVMTypeConverter &lowering) { + return lowering.convertType(containerType.getElementType()) + .template cast<LLVM::LLVMType>() + .getPointerTo(); +} + +// Create an array attribute containing integer attributes with values provided +// in `position`. +static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) { + SmallVector<Attribute, 4> attrs; + attrs.reserve(position.size()); + for (auto p : position) + attrs.push_back(builder.getI64IntegerAttr(p)); + return builder.getArrayAttr(attrs); +} + +class ExtractElementOpConversion : public LLVMOpLowering { +public: + explicit ExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractOp.vector()->getType().cast<VectorType>(); + auto resultType = extractOp.getResult()->getType(); + auto llvmResultType = lowering.convertType(resultType); + + auto positionArrayAttr = extractOp.position(); + // One-shot extraction of vector from array (only requires extractvalue). + if (resultType.isa<VectorType>()) { + Value *extracted = + rewriter + .create<LLVM::ExtractValueOp>(loc, llvmResultType, + adaptor.vector(), positionArrayAttr) + .getResult(); + rewriter.replaceOp(op, extracted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from struct. + auto *context = op->getContext(); + Value *extracted = adaptor.vector(); + auto positionAttrs = positionArrayAttr.getValue(); + auto indexType = rewriter.getIndexType(); + if (positionAttrs.size() > 1) { + auto nDVectorType = vectorType; + auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), + nDVectorType.getElementType()); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter + .create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs) + .getResult(); + } + + // Remaining extraction of element from 1-D LLVM vector + auto position = positionAttrs.back().cast<IntegerAttr>(); + auto constant = rewriter + .create<LLVM::ConstantOp>( + loc, lowering.convertType(indexType), position) + .getResult(); + extracted = + rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant) + .getResult(); + rewriter.replaceOp(op, extracted); + + return matchSuccess(); + } +}; + +class OuterProductOpConversion : public LLVMOpLowering { +public: + explicit OuterProductOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::OuterProductOpOperandAdaptor(operands); + auto *ctx = op->getContext(); + auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); + auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); + auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements(); + auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements(); + auto llvmArrayOfVectType = lowering.convertType( + cast<vector::OuterProductOp>(op).getResult()->getType()); + Value *desc = + rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult(); + for (unsigned i = 0, e = rankV1; i < e; ++i) { + // Emit the following pattern: + // vec(a[i]) * b -> llvmStructOfVectType[i] + Value *a = adaptor.lhs(), *b = adaptor.rhs(); + // shufflevector explicitly requires i32 / + auto attr = rewriter.getI32IntegerAttr(i); + SmallVector<Attribute, 4> broadcastAttr(rankV2, attr); + auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx); + auto *broadcasted = + rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr) + .getResult(); + auto *multiplied = + rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult(); + desc = rewriter + .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc, + multiplied, + positionAttr(rewriter, i)) + .getResult(); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// Populate the given list with patterns that convert from Vector to LLVM. +static void +populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( + ctx, converter); +} + +namespace { +struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { + void runOnModule(); +}; +} // namespace + +void LowerVectorToLLVMPass::runOnModule() { + // Convert to the LLVM IR dialect using the converter defined above. + OwningRewritePatternList patterns; + LLVMTypeConverter converter(&getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns, &getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed( + applyPartialConversion(getModule(), target, patterns, &converter))) { + signalPassFailure(); + } +} + +ModulePassBase *mlir::createLowerVectorToLLVMPass() { + return new LowerVectorToLLVMPass(); +} + +static PassRegistration<LowerVectorToLLVMPass> + pass("vector-lower-to-llvm-dialect", + "Lower the operations from the vector dialect into the LLVM dialect"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir new file mode 100644 index 00000000000..f582de146ba --- /dev/null +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s + +func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> { + %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> + %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32> + return %3 : vector<8xf32> +} +// CHECK-LABEL: vec_1d +// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> +// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> +// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>"> + +func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> { + %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32> + return %2 : vector<4x8xf32> +} +// CHECK-LABEL: vec_2d +// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]"> +// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>"> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]"> +// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]"> + +func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> { + %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32> + return %0 : vector<8x16xf32> +} +// CHECK-LABEL: vec_3d +// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]"> +// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">
\ No newline at end of file diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 26f8885a242..ff12852e347 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -43,6 +43,7 @@ set(LIBS MLIRTestTransforms MLIRSupport MLIRVectorOps + MLIRVectorToLLVM ) if(MLIR_CUDA_CONVERSIONS_ENABLED) list(APPEND LIBS |

