summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp111
1 files changed, 53 insertions, 58 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 56005220d3f..b48930c4dda 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -53,10 +53,9 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
}
// Helper that picks the proper sequence for inserting.
-static ValuePtr insertOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc,
- ValuePtr val1, ValuePtr val2, Type llvmType,
- int64_t rank, int64_t pos) {
+static Value insertOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val1,
+ Value val2, Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -70,10 +69,9 @@ static ValuePtr insertOne(ConversionPatternRewriter &rewriter,
}
// Helper that picks the proper sequence for extracting.
-static ValuePtr extractOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc,
- ValuePtr val, Type llvmType, int64_t rank,
- int64_t pos) {
+static Value extractOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val,
+ Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -94,7 +92,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
@@ -122,9 +120,9 @@ private:
// ops once all insert/extract/shuffle operations
// are available with lowering implemention.
//
- ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType,
- ConversionPatternRewriter &rewriter) const {
+ Value expandRanks(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType,
+ ConversionPatternRewriter &rewriter) const {
assert((dstVectorType != nullptr) && "invalid result type in broadcast");
// Determine rank of source and destination.
int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
@@ -161,24 +159,22 @@ private:
// becomes:
// x = [s,s]
// v = [x,x,x,x]
- ValuePtr duplicateOneRank(ValuePtr value, Location loc,
- VectorType srcVectorType, VectorType dstVectorType,
- int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
if (rank == 1) {
- ValuePtr undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- ValuePtr expand =
+ Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value expand =
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
- ValuePtr expand =
- expandRanks(value, loc, srcVectorType,
- reducedVectorTypeFront(dstVectorType), rewriter);
- ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value expand = expandRanks(value, loc, srcVectorType,
+ reducedVectorTypeFront(dstVectorType), rewriter);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
@@ -203,20 +199,19 @@ private:
// y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
// a = [x, y]
// etc.
- ValuePtr stretchOneRank(ValuePtr value, Location loc,
- VectorType srcVectorType, VectorType dstVectorType,
- int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
- ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
bool atStretch = dim != srcVectorType.getDimSize(0);
if (rank == 1) {
assert(atStretch);
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
- ValuePtr one =
+ Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
- ValuePtr expand =
+ Value expand =
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
@@ -227,9 +222,9 @@ private:
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
- ValuePtr one =
+ Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
- ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
+ Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
}
@@ -245,7 +240,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
@@ -269,23 +264,23 @@ public:
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
- ValuePtr shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
return matchSuccess();
}
// For all other cases, insert the individual values individually.
- ValuePtr insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (auto en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = en.value().cast<IntegerAttr>().getInt();
- ValuePtr value = adaptor.v1();
+ Value value = adaptor.v1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
value = adaptor.v2();
}
- ValuePtr extract =
+ Value extract =
extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
rank, insPos++);
@@ -303,7 +298,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
@@ -328,7 +323,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
@@ -344,7 +339,7 @@ public:
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- ValuePtr extracted = rewriter.create<LLVM::ExtractValueOp>(
+ Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -352,7 +347,7 @@ public:
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- ValuePtr extracted = adaptor.vector();
+ Value extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
@@ -383,7 +378,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
@@ -408,7 +403,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::InsertOpOperandAdaptor(operands);
@@ -424,7 +419,7 @@ public:
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
- ValuePtr inserted = rewriter.create<LLVM::InsertValueOp>(
+ Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
@@ -433,7 +428,7 @@ public:
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- ValuePtr extracted = adaptor.dest();
+ Value extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
auto oneDVectorType = destVectorType;
@@ -449,7 +444,7 @@ public:
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
- ValuePtr inserted = rewriter.create<LLVM::InsertElementOp>(
+ Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
constant);
@@ -475,7 +470,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
@@ -486,10 +481,10 @@ public:
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
- ValuePtr a = adaptor.lhs(), b = adaptor.rhs();
- ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
- SmallVector<ValuePtr, 8> lhs, accs;
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value a = adaptor.lhs(), b = adaptor.rhs();
+ Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value, 8> lhs, accs;
lhs.reserve(rankLHS);
accs.reserve(rankLHS);
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
@@ -497,7 +492,7 @@ public:
auto attr = rewriter.getI32IntegerAttr(d);
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
- ValuePtr aD = nullptr, accD = nullptr;
+ Value aD = nullptr, accD = nullptr;
// 1. Broadcast the element a[d] into vector aD.
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
// 2. If acc is present, extract 1-d vector acc[d] into accD.
@@ -505,7 +500,7 @@ public:
accD = rewriter.create<LLVM::ExtractValueOp>(
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
// 3. Compute aD outer b (plus accD, if relevant).
- ValuePtr aOuterbD =
+ Value aOuterbD =
accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
.getResult()
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
@@ -527,7 +522,7 @@ public:
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
@@ -576,12 +571,12 @@ public:
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
- ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc);
+ Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
@@ -627,7 +622,7 @@ public:
// TODO(ajcbik): rely solely on libc in future? something else?
//
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpOperandAdaptor(operands);
@@ -657,7 +652,7 @@ public:
private:
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
- ValuePtr value, VectorType vectorType, Operation *printer,
+ Value value, VectorType vectorType, Operation *printer,
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
@@ -673,7 +668,7 @@ private:
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
auto llvmType = lowering.convertType(
rank > 1 ? reducedType : vectorType.getElementType());
- ValuePtr nestedVal =
+ Value nestedVal =
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
if (d != dim - 1)
OpenPOWER on IntegriCloud