diff options
author | Nicolas Vasilache <ntv@google.com> | 2019-08-12 04:08:26 -0700 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-12 04:08:57 -0700 |
commit | 252ada493276eace3e23802fb70ff3c7be53837d (patch) | |
tree | 3a40800f55e5486089cc8d6ae58dda343b88d474 /mlir/lib/Conversion | |
parent | 5290e8c36d4e4aac4d8ce2726f6d373e87501945 (diff) | |
download | bcm5719-llvm-252ada493276eace3e23802fb70ff3c7be53837d.tar.gz bcm5719-llvm-252ada493276eace3e23802fb70ff3c7be53837d.zip |
Add lowering of vector dialect to LLVM dialect.
This CL is step 3/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools.
This CL adds support for converting MLIR n-D vector types to (n-1)-D arrays of 1-D LLVM vectors and a conversion VectorToLLVM that lowers the `vector.extractelement` and `vector.outerproduct` instructions to the proper mix of `llvm.vectorshuffle`, `llvm.extractelement` and `llvm.mulf`.
This has been independently verified to produce proper avx2 code.
Input:
```
func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
%2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
%3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
return %3 : vector<8xf32>
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
vec_1d: # @vec_1d
# %bb.0:
vbroadcastss %xmm0, %ymm0
vmulps %ymm1, %ymm0, %ymm0
retq
```
PiperOrigin-RevId: 262895929
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 22 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt | 15 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 207 |
4 files changed, 235 insertions, 10 deletions
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 1ddd103f28e..6c14f5487a6 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(GPUToNVVM) add_subdirectory(GPUToSPIRV) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5bb281112f5..c62a5d8719d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -145,18 +145,20 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { return LLVM::LLVMType::getStructTy(llvmDialect, types); } -// Convert a 1D vector type to an LLVM vector type. +// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when +// n > 1. +// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and +// `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. Type LLVMTypeConverter::convertVectorType(VectorType type) { - if (type.getRank() != 1) { - auto *mlirContext = llvmDialect->getContext(); - emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported"); + auto elementType = unwrap(convertType(type.getElementType())); + if (!elementType) return {}; - } - - LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); - return elementType - ? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front()) - : Type(); + auto vectorType = + LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); + auto shape = type.getShape(); + for (int i = shape.size() - 2; i >= 0; --i) + vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); + return vectorType; } // Dispatch based on the actual type. Return null type on error. diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..a75b6c1e98a --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLLVM + VectorToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLLVM ${LIBS}) +target_link_libraries(MLIRVectorToLLVM ${LIBS}) diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp new file mode 100644 index 00000000000..bf90edba401 --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -0,0 +1,207 @@ +//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/VectorOps/VectorOps.h" + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; + +template <typename T> +static LLVM::LLVMType getPtrToElementType(T containerType, + LLVMTypeConverter &lowering) { + return lowering.convertType(containerType.getElementType()) + .template cast<LLVM::LLVMType>() + .getPointerTo(); +} + +// Create an array attribute containing integer attributes with values provided +// in `position`. +static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) { + SmallVector<Attribute, 4> attrs; + attrs.reserve(position.size()); + for (auto p : position) + attrs.push_back(builder.getI64IntegerAttr(p)); + return builder.getArrayAttr(attrs); +} + +class ExtractElementOpConversion : public LLVMOpLowering { +public: + explicit ExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractOp.vector()->getType().cast<VectorType>(); + auto resultType = extractOp.getResult()->getType(); + auto llvmResultType = lowering.convertType(resultType); + + auto positionArrayAttr = extractOp.position(); + // One-shot extraction of vector from array (only requires extractvalue). + if (resultType.isa<VectorType>()) { + Value *extracted = + rewriter + .create<LLVM::ExtractValueOp>(loc, llvmResultType, + adaptor.vector(), positionArrayAttr) + .getResult(); + rewriter.replaceOp(op, extracted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from struct. + auto *context = op->getContext(); + Value *extracted = adaptor.vector(); + auto positionAttrs = positionArrayAttr.getValue(); + auto indexType = rewriter.getIndexType(); + if (positionAttrs.size() > 1) { + auto nDVectorType = vectorType; + auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), + nDVectorType.getElementType()); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter + .create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs) + .getResult(); + } + + // Remaining extraction of element from 1-D LLVM vector + auto position = positionAttrs.back().cast<IntegerAttr>(); + auto constant = rewriter + .create<LLVM::ConstantOp>( + loc, lowering.convertType(indexType), position) + .getResult(); + extracted = + rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant) + .getResult(); + rewriter.replaceOp(op, extracted); + + return matchSuccess(); + } +}; + +class OuterProductOpConversion : public LLVMOpLowering { +public: + explicit OuterProductOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::OuterProductOpOperandAdaptor(operands); + auto *ctx = op->getContext(); + auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); + auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); + auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements(); + auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements(); + auto llvmArrayOfVectType = lowering.convertType( + cast<vector::OuterProductOp>(op).getResult()->getType()); + Value *desc = + rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult(); + for (unsigned i = 0, e = rankV1; i < e; ++i) { + // Emit the following pattern: + // vec(a[i]) * b -> llvmStructOfVectType[i] + Value *a = adaptor.lhs(), *b = adaptor.rhs(); + // shufflevector explicitly requires i32 / + auto attr = rewriter.getI32IntegerAttr(i); + SmallVector<Attribute, 4> broadcastAttr(rankV2, attr); + auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx); + auto *broadcasted = + rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr) + .getResult(); + auto *multiplied = + rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult(); + desc = rewriter + .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc, + multiplied, + positionAttr(rewriter, i)) + .getResult(); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// Populate the given list with patterns that convert from Vector to LLVM. +static void +populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( + ctx, converter); +} + +namespace { +struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { + void runOnModule(); +}; +} // namespace + +void LowerVectorToLLVMPass::runOnModule() { + // Convert to the LLVM IR dialect using the converter defined above. + OwningRewritePatternList patterns; + LLVMTypeConverter converter(&getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns, &getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed( + applyPartialConversion(getModule(), target, patterns, &converter))) { + signalPassFailure(); + } +} + +ModulePassBase *mlir::createLowerVectorToLLVMPass() { + return new LowerVectorToLLVMPass(); +} + +static PassRegistration<LowerVectorToLLVMPass> + pass("vector-lower-to-llvm-dialect", + "Lower the operations from the vector dialect into the LLVM dialect"); |