diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 90 |
1 files changed, 43 insertions, 47 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index bf90edba401..1e4b8ca6419 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -79,11 +79,8 @@ public: auto positionArrayAttr = extractOp.position(); // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { - Value *extracted = - rewriter - .create<LLVM::ExtractValueOp>(loc, llvmResultType, - adaptor.vector(), positionArrayAttr) - .getResult(); + Value *extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); } @@ -92,29 +89,24 @@ public: auto *context = op->getContext(); Value *extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); - auto indexType = rewriter.getIndexType(); + auto i32Type = rewriter.getIntegerType(32); if (positionAttrs.size() > 1) { auto nDVectorType = vectorType; auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), nDVectorType.getElementType()); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); - extracted = rewriter - .create<LLVM::ExtractValueOp>( - loc, lowering.convertType(oneDVectorType), extracted, - nMinusOnePositionAttrs) - .getResult(); + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast<IntegerAttr>(); - auto constant = rewriter - .create<LLVM::ConstantOp>( - loc, lowering.convertType(indexType), position) - .getResult(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(i32Type), position); extracted = - rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant) - .getResult(); + rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -134,32 +126,38 @@ public: auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); auto *ctx = op->getContext(); - auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); - auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); - auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements(); - auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements(); + auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); + auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); + auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); + auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast<vector::OuterProductOp>(op).getResult()->getType()); - Value *desc = - rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult(); - for (unsigned i = 0, e = rankV1; i < e; ++i) { - // Emit the following pattern: - // vec(a[i]) * b -> llvmStructOfVectType[i] - Value *a = adaptor.lhs(), *b = adaptor.rhs(); - // shufflevector explicitly requires i32 / - auto attr = rewriter.getI32IntegerAttr(i); - SmallVector<Attribute, 4> broadcastAttr(rankV2, attr); - auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx); - auto *broadcasted = - rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr) - .getResult(); - auto *multiplied = - rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult(); - desc = rewriter - .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc, - multiplied, - positionAttr(rewriter, i)) - .getResult(); + Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); + Value *a = adaptor.lhs(), *b = adaptor.rhs(); + Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector<Value *, 8> lhs, accs; + lhs.reserve(rankLHS); + accs.reserve(rankLHS); + for (unsigned d = 0, e = rankLHS; d < e; ++d) { + // shufflevector explicitly requires i32. + auto attr = rewriter.getI32IntegerAttr(d); + SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); + auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); + Value *aD = nullptr, *accD = nullptr; + // 1. Broadcast the element a[d] into vector aD. + aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); + // 2. If acc is present, extract 1-d vector acc[d] into accD. + if (acc) + accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc, + positionAttr(rewriter, d)); + // 3. Compute aD outer b (plus accD, if relevant). + Value *aOuterbD = + accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD) + .getResult() + : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); + // 4. Insert as value `d` in the descriptor. + desc = rewriter.create<LLVM::InsertValueOp>( + loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d)); } rewriter.replaceOp(op, desc); return matchSuccess(); @@ -167,12 +165,10 @@ public: }; /// Populate the given list with patterns that convert from Vector to LLVM. -static void -populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns, - MLIRContext *ctx) { +void mlir::populateVectorToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( - ctx, converter); + converter.getDialect()->getContext(), converter); } namespace { @@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() { // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); - populateVectorToLLVMConversionPatterns(converter, patterns, &getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); |