diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 108 |
1 files changed, 56 insertions, 52 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9ec8ec6f88d..5099cb01bbc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -62,9 +62,10 @@ static VectorType reducedVectorTypeBack(VectorType tp) { } // Helper that picks the proper sequence for inserting. -static Value *insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value *val1, - Value *val2, Type llvmType, int64_t rank, int64_t pos) { +static ValuePtr insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, + ValuePtr val1, ValuePtr val2, Type llvmType, + int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -78,9 +79,10 @@ static Value *insertOne(ConversionPatternRewriter &rewriter, } // Helper that picks the proper sequence for extracting. -static Value *extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value *val, - Type llvmType, int64_t rank, int64_t pos) { +static ValuePtr extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, + ValuePtr val, Type llvmType, int64_t rank, + int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -101,7 +103,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast<vector::BroadcastOp>(op); VectorType dstVectorType = broadcastOp.getVectorType(); @@ -129,9 +131,9 @@ private: // ops once all insert/extract/shuffle operations // are available with lowering implemention. // - Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, - ConversionPatternRewriter &rewriter) const { + ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { assert((dstVectorType != nullptr) && "invalid result type in broadcast"); // Determine rank of source and destination. int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; @@ -168,23 +170,24 @@ private: // becomes: // x = [s,s] // v = [x,x,x,x] - Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + ValuePtr duplicateOneRank(ValuePtr value, Location loc, + VectorType srcVectorType, VectorType dstVectorType, + int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { - Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); - Value *expand = + ValuePtr undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); + ValuePtr expand = insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - Value *expand = + ValuePtr expand = expandRanks(value, loc, srcVectorType, reducedVectorTypeFront(dstVectorType), rewriter); - Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); @@ -209,19 +212,20 @@ private: // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> // a = [x, y] // etc. - Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + ValuePtr stretchOneRank(ValuePtr value, Location loc, + VectorType srcVectorType, VectorType dstVectorType, + int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); - Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - Value *one = + ValuePtr one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); - Value *expand = + ValuePtr expand = insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( @@ -232,9 +236,9 @@ private: Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; - Value *one = + ValuePtr one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); - Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } @@ -250,7 +254,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ShuffleOpOperandAdaptor(operands); @@ -274,23 +278,23 @@ public: // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { - Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>( + ValuePtr shuffle = rewriter.create<LLVM::ShuffleVectorOp>( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); return matchSuccess(); } // For all other cases, insert the individual values individually. - Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); + ValuePtr insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast<IntegerAttr>().getInt(); - Value *value = adaptor.v1(); + ValuePtr value = adaptor.v1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.v2(); } - Value *extract = + ValuePtr extract = extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, rank, insPos++); @@ -308,7 +312,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); auto extractEltOp = cast<vector::ExtractElementOp>(op); @@ -333,7 +337,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); @@ -349,7 +353,7 @@ public: // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { - Value *extracted = rewriter.create<LLVM::ExtractValueOp>( + ValuePtr extracted = rewriter.create<LLVM::ExtractValueOp>( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -357,7 +361,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - Value *extracted = adaptor.vector(); + ValuePtr extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -388,7 +392,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpOperandAdaptor(operands); auto insertEltOp = cast<vector::InsertElementOp>(op); @@ -413,7 +417,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpOperandAdaptor(operands); @@ -429,7 +433,7 @@ public: // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa<VectorType>()) { - Value *inserted = rewriter.create<LLVM::InsertValueOp>( + ValuePtr inserted = rewriter.create<LLVM::InsertValueOp>( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); @@ -438,7 +442,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - Value *extracted = adaptor.dest(); + ValuePtr extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast<IntegerAttr>(); auto oneDVectorType = destVectorType; @@ -454,7 +458,7 @@ public: // Insertion of an element into a 1-D LLVM vector. auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); - Value *inserted = rewriter.create<LLVM::InsertElementOp>( + ValuePtr inserted = rewriter.create<LLVM::InsertElementOp>( loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), constant); @@ -480,7 +484,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); @@ -491,10 +495,10 @@ public: auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast<vector::OuterProductOp>(op).getResult()->getType()); - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); - Value *a = adaptor.lhs(), *b = adaptor.rhs(); - Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); - SmallVector<Value *, 8> lhs, accs; + ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); + ValuePtr a = adaptor.lhs(), b = adaptor.rhs(); + ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector<ValuePtr, 8> lhs, accs; lhs.reserve(rankLHS); accs.reserve(rankLHS); for (unsigned d = 0, e = rankLHS; d < e; ++d) { @@ -502,7 +506,7 @@ public: auto attr = rewriter.getI32IntegerAttr(d); SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); - Value *aD = nullptr, *accD = nullptr; + ValuePtr aD = nullptr, accD = nullptr; // 1. Broadcast the element a[d] into vector aD. aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); // 2. If acc is present, extract 1-d vector acc[d] into accD. @@ -510,7 +514,7 @@ public: accD = rewriter.create<LLVM::ExtractValueOp>( loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). - Value *aOuterbD = + ValuePtr aOuterbD = accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) .getResult() : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); @@ -532,7 +536,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); @@ -581,12 +585,12 @@ public: auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); desc.setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. - Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); + ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. @@ -632,7 +636,7 @@ public: // TODO(ajcbik): rely solely on libc in future? something else? // PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast<vector::PrintOp>(op); auto adaptor = vector::PrintOpOperandAdaptor(operands); @@ -662,7 +666,7 @@ public: private: void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, - Value *value, VectorType vectorType, Operation *printer, + ValuePtr value, VectorType vectorType, Operation *printer, int64_t rank) const { Location loc = op->getLoc(); if (rank == 0) { @@ -678,7 +682,7 @@ private: rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; auto llvmType = lowering.convertType( rank > 1 ? reducedType : vectorType.getElementType()); - Value *nestedVal = + ValuePtr nestedVal = extractOne(rewriter, lowering, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); if (d != dim - 1) |