summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp124
1 files changed, 100 insertions, 24 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8adc415f820..17fb93396d9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -49,6 +49,18 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
.getPointerTo();
}
+// Helper to reduce vector type by one rank at front.
+static VectorType reducedVectorTypeFront(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+}
+
+// Helper to reduce vector type by *all* but one rank at back.
+static VectorType reducedVectorTypeBack(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+}
+
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -135,8 +147,9 @@ private:
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
- Value *expand = expandRanks(value, loc, srcVectorType,
- reducedVectorType(dstVectorType), rewriter);
+ Value *expand =
+ expandRanks(value, loc, srcVectorType,
+ reducedVectorTypeFront(dstVectorType), rewriter);
Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
@@ -183,8 +196,8 @@ private:
result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
}
} else {
- VectorType redSrcType = reducedVectorType(srcVectorType);
- VectorType redDstType = reducedVectorType(dstVectorType);
+ VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
+ VectorType redDstType = reducedVectorTypeFront(dstVectorType);
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
@@ -226,18 +239,12 @@ private:
return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
rewriter.getI64ArrayAttr(pos));
}
-
- // Helper to reduce vector type by one rank.
- static VectorType reducedVectorType(VectorType tp) {
- assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
- }
};
-class VectorExtractElementOpConversion : public LLVMOpLowering {
+class VectorExtractOpConversion : public LLVMOpLowering {
public:
- explicit VectorExtractElementOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
+ explicit VectorExtractOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
@@ -247,11 +254,15 @@ public:
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
auto extractOp = cast<vector::ExtractOp>(op);
- auto vectorType = extractOp.vector()->getType().cast<VectorType>();
+ auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult()->getType();
auto llvmResultType = lowering.convertType(resultType);
-
auto positionArrayAttr = extractOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -260,15 +271,12 @@ public:
return matchSuccess();
}
- // Potential extraction of 1-D vector from struct.
+ // Potential extraction of 1-D vector from array.
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
- auto nDVectorType = vectorType;
- auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
- nDVectorType.getElementType());
+ auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -278,8 +286,8 @@ public:
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, lowering.convertType(i32Type), position);
+ auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
@@ -288,6 +296,73 @@ public:
}
};
+class VectorInsertOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorInsertOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::InsertOpOperandAdaptor(operands);
+ auto insertOp = cast<vector::InsertOp>(op);
+ auto sourceType = insertOp.getSourceType();
+ auto destVectorType = insertOp.getDestVectorType();
+ auto llvmResultType = lowering.convertType(destVectorType);
+ auto positionArrayAttr = insertOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
+ // One-shot insertion of a vector into an array (only requires insertvalue).
+ if (sourceType.isa<VectorType>()) {
+ Value *inserted = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmResultType, adaptor.dest(), adaptor.source(),
+ positionArrayAttr);
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from array.
+ auto *context = op->getContext();
+ Value *extracted = adaptor.dest();
+ auto positionAttrs = positionArrayAttr.getValue();
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto oneDVectorType = destVectorType;
+ if (positionAttrs.size() > 1) {
+ oneDVectorType = reducedVectorTypeBack(destVectorType);
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
+ }
+
+ // Insertion of an element into a 1-D LLVM vector.
+ auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
+ Value *inserted = rewriter.create<LLVM::InsertElementOp>(
+ loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
+ constant);
+
+ // Potential insertion of resulting 1-D vector into array.
+ if (positionAttrs.size() > 1) {
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
+ adaptor.dest(), inserted,
+ nMinusOnePositionAttrs);
+ }
+
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+};
+
class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
explicit VectorOuterProductOpConversion(MLIRContext *context,
@@ -431,8 +506,9 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
+ patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion,
+ VectorInsertOpConversion, VectorOuterProductOpConversion,
+ VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}
OpenPOWER on IntegriCloud