summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-08-16 03:52:56 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-16 03:53:26 -0700
commitf826ceef3ce8bfea1b78ab7bb2c60c53eb13729a (patch)
tree574e65c0dbb0f8f89c1219a7f4ccfcf0547d20ba /mlir/lib/Conversion/VectorToLLVM
parentcc980aa41651c2cbfcbd9048fb0788f4aa9ae475 (diff)
downloadbcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.tar.gz
bcm5719-llvm-f826ceef3ce8bfea1b78ab7bb2c60c53eb13729a.zip
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction. When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it. In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...). This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage. This has been independently verified to result in proper fma instructions for haswell as follows. Input: ``` func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> { %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32> return %2 : vector<17x8xf32> } } ``` Command: ``` mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2 ``` Output: ``` outerproduct_add: # @outerproduct_add # %bb.0: ... vmovaps 112(%rbp), %ymm8 vbroadcastss %xmm0, %ymm0 ... vbroadcastss 64(%rbp), %ymm15 vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem ... vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem ... ``` PiperOrigin-RevId: 263743359
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp90
1 files changed, 43 insertions, 47 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index bf90edba401..1e4b8ca6419 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -79,11 +79,8 @@ public:
auto positionArrayAttr = extractOp.position();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted =
- rewriter
- .create<LLVM::ExtractValueOp>(loc, llvmResultType,
- adaptor.vector(), positionArrayAttr)
- .getResult();
+ Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
}
@@ -92,29 +89,24 @@ public:
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto indexType = rewriter.getIndexType();
+ auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
auto nDVectorType = vectorType;
auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
nDVectorType.getElementType());
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
- extracted = rewriter
- .create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
- nMinusOnePositionAttrs)
- .getResult();
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter
- .create<LLVM::ConstantOp>(
- loc, lowering.convertType(indexType), position)
- .getResult();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(i32Type), position);
extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
- .getResult();
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -134,32 +126,38 @@ public:
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
- auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
- auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
- auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
- auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+ auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc =
- rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
- for (unsigned i = 0, e = rankV1; i < e; ++i) {
- // Emit the following pattern:
- // vec(a[i]) * b -> llvmStructOfVectType[i]
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- // shufflevector explicitly requires i32 /
- auto attr = rewriter.getI32IntegerAttr(i);
- SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
- auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
- auto *broadcasted =
- rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
- .getResult();
- auto *multiplied =
- rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
- desc = rewriter
- .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
- multiplied,
- positionAttr(rewriter, i))
- .getResult();
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value *, 8> lhs, accs;
+ lhs.reserve(rankLHS);
+ accs.reserve(rankLHS);
+ for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+ // shufflevector explicitly requires i32.
+ auto attr = rewriter.getI32IntegerAttr(d);
+ SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+ auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+ Value *aD = nullptr, *accD = nullptr;
+ // 1. Broadcast the element a[d] into vector aD.
+ aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+ // 2. If acc is present, extract 1-d vector acc[d] into accD.
+ if (acc)
+ accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc,
+ positionAttr(rewriter, d));
+ // 3. Compute aD outer b (plus accD, if relevant).
+ Value *aOuterbD =
+ accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD)
+ .getResult()
+ : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+ // 4. Insert as value `d` in the descriptor.
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d));
}
rewriter.replaceOp(op, desc);
return matchSuccess();
@@ -167,12 +165,10 @@ public:
};
/// Populate the given list with patterns that convert from Vector to LLVM.
-static void
-populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+void mlir::populateVectorToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
- ctx, converter);
+ converter.getDialect()->getContext(), converter);
}
namespace {
@@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
- populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
OpenPOWER on IntegriCloud