summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp')
-rw-r--r-- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 108
1 files changed, 56 insertions, 52 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9ec8ec6f88d..5099cb01bbc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -62,9 +62,10 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
}
// Helper that picks the proper sequence for inserting.
-static Value *insertOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value *val1,
- Value *val2, Type llvmType, int64_t rank, int64_t pos) {
+static ValuePtr insertOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc,
+ ValuePtr val1, ValuePtr val2, Type llvmType,
+ int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -78,9 +79,10 @@ static Value *insertOne(ConversionPatternRewriter &rewriter,
}
// Helper that picks the proper sequence for extracting.
-static Value *extractOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value *val,
- Type llvmType, int64_t rank, int64_t pos) {
+static ValuePtr extractOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc,
+ ValuePtr val, Type llvmType, int64_t rank,
+ int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -101,7 +103,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
@@ -129,9 +131,9 @@ private:
// ops once all insert/extract/shuffle operations
// are available with lowering implemention.
//
- Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType,
- ConversionPatternRewriter &rewriter) const {
+ ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType,
+ ConversionPatternRewriter &rewriter) const {
assert((dstVectorType != nullptr) && "invalid result type in broadcast");
// Determine rank of source and destination.
int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
@@ -168,23 +170,24 @@ private:
// becomes:
// x = [s,s]
// v = [x,x,x,x]
- Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ ValuePtr duplicateOneRank(ValuePtr value, Location loc,
+ VectorType srcVectorType, VectorType dstVectorType,
+ int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
if (rank == 1) {
- Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- Value *expand =
+ ValuePtr undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ ValuePtr expand =
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
- Value *expand =
+ ValuePtr expand =
expandRanks(value, loc, srcVectorType,
reducedVectorTypeFront(dstVectorType), rewriter);
- Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
@@ -209,19 +212,20 @@ private:
// y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
// a = [x, y]
// etc.
- Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ ValuePtr stretchOneRank(ValuePtr value, Location loc,
+ VectorType srcVectorType, VectorType dstVectorType,
+ int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
- Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
bool atStretch = dim != srcVectorType.getDimSize(0);
if (rank == 1) {
assert(atStretch);
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
- Value *one =
+ ValuePtr one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
- Value *expand =
+ ValuePtr expand =
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
@@ -232,9 +236,9 @@ private:
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
- Value *one =
+ ValuePtr one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
- Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
+ ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
}
@@ -250,7 +254,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
@@ -274,23 +278,23 @@ public:
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
- Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ ValuePtr shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
return matchSuccess();
}
// For all other cases, insert the individual values individually.
- Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ ValuePtr insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (auto en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = en.value().cast<IntegerAttr>().getInt();
- Value *value = adaptor.v1();
+ ValuePtr value = adaptor.v1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
value = adaptor.v2();
}
- Value *extract =
+ ValuePtr extract =
extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
rank, insPos++);
@@ -308,7 +312,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
@@ -333,7 +337,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
@@ -349,7 +353,7 @@ public:
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ ValuePtr extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -357,7 +361,7 @@ public:
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- Value *extracted = adaptor.vector();
+ ValuePtr extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
@@ -388,7 +392,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
@@ -413,7 +417,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::InsertOpOperandAdaptor(operands);
@@ -429,7 +433,7 @@ public:
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
- Value *inserted = rewriter.create<LLVM::InsertValueOp>(
+ ValuePtr inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
@@ -438,7 +442,7 @@ public:
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- Value *extracted = adaptor.dest();
+ ValuePtr extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
auto oneDVectorType = destVectorType;
@@ -454,7 +458,7 @@ public:
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
- Value *inserted = rewriter.create<LLVM::InsertElementOp>(
+ ValuePtr inserted = rewriter.create<LLVM::InsertElementOp>(
loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
constant);
@@ -480,7 +484,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
@@ -491,10 +495,10 @@ public:
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
- SmallVector<Value *, 8> lhs, accs;
+ ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ ValuePtr a = adaptor.lhs(), b = adaptor.rhs();
+ ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<ValuePtr, 8> lhs, accs;
lhs.reserve(rankLHS);
accs.reserve(rankLHS);
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
@@ -502,7 +506,7 @@ public:
auto attr = rewriter.getI32IntegerAttr(d);
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
- Value *aD = nullptr, *accD = nullptr;
+ ValuePtr aD = nullptr, accD = nullptr;
// 1. Broadcast the element a[d] into vector aD.
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
// 2. If acc is present, extract 1-d vector acc[d] into accD.
@@ -510,7 +514,7 @@ public:
accD = rewriter.create<LLVM::ExtractValueOp>(
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
// 3. Compute aD outer b (plus accD, if relevant).
- Value *aOuterbD =
+ ValuePtr aOuterbD =
accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
.getResult()
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
@@ -532,7 +536,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
@@ -581,12 +585,12 @@ public:
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+ ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
- Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
+ ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
@@ -632,7 +636,7 @@ public:
// TODO(ajcbik): rely solely on libc in future? something else?
//
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpOperandAdaptor(operands);
@@ -662,7 +666,7 @@ public:
private:
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
- Value *value, VectorType vectorType, Operation *printer,
+ ValuePtr value, VectorType vectorType, Operation *printer,
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
@@ -678,7 +682,7 @@ private:
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
auto llvmType = lowering.convertType(
rank > 1 ? reducedType : vectorType.getElementType());
- Value *nestedVal =
+ ValuePtr nestedVal =
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
if (d != dim - 1)
OpenPOWER on IntegriCloud