summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp90
1 files changed, 43 insertions, 47 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index bf90edba401..1e4b8ca6419 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -79,11 +79,8 @@ public:
auto positionArrayAttr = extractOp.position();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted =
- rewriter
- .create<LLVM::ExtractValueOp>(loc, llvmResultType,
- adaptor.vector(), positionArrayAttr)
- .getResult();
+ Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
}
@@ -92,29 +89,24 @@ public:
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto indexType = rewriter.getIndexType();
+ auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
auto nDVectorType = vectorType;
auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
nDVectorType.getElementType());
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
- extracted = rewriter
- .create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
- nMinusOnePositionAttrs)
- .getResult();
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter
- .create<LLVM::ConstantOp>(
- loc, lowering.convertType(indexType), position)
- .getResult();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(i32Type), position);
extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
- .getResult();
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -134,32 +126,38 @@ public:
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
- auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
- auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
- auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
- auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+ auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc =
- rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
- for (unsigned i = 0, e = rankV1; i < e; ++i) {
- // Emit the following pattern:
- // vec(a[i]) * b -> llvmStructOfVectType[i]
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- // shufflevector explicitly requires i32 /
- auto attr = rewriter.getI32IntegerAttr(i);
- SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
- auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
- auto *broadcasted =
- rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
- .getResult();
- auto *multiplied =
- rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
- desc = rewriter
- .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
- multiplied,
- positionAttr(rewriter, i))
- .getResult();
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value *, 8> lhs, accs;
+ lhs.reserve(rankLHS);
+ accs.reserve(rankLHS);
+ for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+ // shufflevector explicitly requires i32.
+ auto attr = rewriter.getI32IntegerAttr(d);
+ SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+ auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+ Value *aD = nullptr, *accD = nullptr;
+ // 1. Broadcast the element a[d] into vector aD.
+ aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+ // 2. If acc is present, extract 1-d vector acc[d] into accD.
+ if (acc)
+ accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc,
+ positionAttr(rewriter, d));
+ // 3. Compute aD outer b (plus accD, if relevant).
+ Value *aOuterbD =
+ accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD)
+ .getResult()
+ : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+ // 4. Insert as value `d` in the descriptor.
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d));
}
rewriter.replaceOp(op, desc);
return matchSuccess();
@@ -167,12 +165,10 @@ public:
};
/// Populate the given list with patterns that convert from Vector to LLVM.
-static void
-populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+void mlir::populateVectorToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
- ctx, converter);
+ converter.getDialect()->getContext(), converter);
}
namespace {
@@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
- populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
OpenPOWER on IntegriCloud