diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 111 |
1 files changed, 53 insertions, 58 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 56005220d3f..b48930c4dda 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -53,10 +53,9 @@ static VectorType reducedVectorTypeBack(VectorType tp) { } // Helper that picks the proper sequence for inserting. -static ValuePtr insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, - ValuePtr val1, ValuePtr val2, Type llvmType, - int64_t rank, int64_t pos) { +static Value insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val1, + Value val2, Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -70,10 +69,9 @@ static ValuePtr insertOne(ConversionPatternRewriter &rewriter, } // Helper that picks the proper sequence for extracting. -static ValuePtr extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, - ValuePtr val, Type llvmType, int64_t rank, - int64_t pos) { +static Value extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val, + Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create<LLVM::ConstantOp>( @@ -94,7 +92,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast<vector::BroadcastOp>(op); VectorType dstVectorType = broadcastOp.getVectorType(); @@ -122,9 +120,9 @@ private: // ops once all insert/extract/shuffle operations // are available with lowering implemention. // - ValuePtr expandRanks(ValuePtr value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, - ConversionPatternRewriter &rewriter) const { + Value expandRanks(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { assert((dstVectorType != nullptr) && "invalid result type in broadcast"); // Determine rank of source and destination. int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; @@ -161,24 +159,22 @@ private: // becomes: // x = [s,s] // v = [x,x,x,x] - ValuePtr duplicateOneRank(ValuePtr value, Location loc, - VectorType srcVectorType, VectorType dstVectorType, - int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { - ValuePtr undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); - ValuePtr expand = + Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value expand = insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - ValuePtr expand = - expandRanks(value, loc, srcVectorType, - reducedVectorTypeFront(dstVectorType), rewriter); - ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value expand = expandRanks(value, loc, srcVectorType, + reducedVectorTypeFront(dstVectorType), rewriter); + Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); @@ -203,20 +199,19 @@ private: // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> // a = [x, y] // etc. - ValuePtr stretchOneRank(ValuePtr value, Location loc, - VectorType srcVectorType, VectorType dstVectorType, - int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { + Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); - ValuePtr result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - ValuePtr one = + Value one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); - ValuePtr expand = + Value expand = insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( @@ -227,9 +222,9 @@ private: Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; - ValuePtr one = + Value one = extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); - ValuePtr expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); result = insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } @@ -245,7 +240,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ShuffleOpOperandAdaptor(operands); @@ -269,23 +264,23 @@ public: // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { - ValuePtr shuffle = rewriter.create<LLVM::ShuffleVectorOp>( + Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); return matchSuccess(); } // For all other cases, insert the individual values individually. - ValuePtr insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast<IntegerAttr>().getInt(); - ValuePtr value = adaptor.v1(); + Value value = adaptor.v1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.v2(); } - ValuePtr extract = + Value extract = extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, rank, insPos++); @@ -303,7 +298,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); auto extractEltOp = cast<vector::ExtractElementOp>(op); @@ -328,7 +323,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); @@ -344,7 +339,7 @@ public: // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { - ValuePtr extracted = rewriter.create<LLVM::ExtractValueOp>( + Value extracted = rewriter.create<LLVM::ExtractValueOp>( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); @@ -352,7 +347,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - ValuePtr extracted = adaptor.vector(); + Value extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -383,7 +378,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpOperandAdaptor(operands); auto insertEltOp = cast<vector::InsertElementOp>(op); @@ -408,7 +403,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpOperandAdaptor(operands); @@ -424,7 +419,7 @@ public: // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa<VectorType>()) { - ValuePtr inserted = rewriter.create<LLVM::InsertValueOp>( + Value inserted = rewriter.create<LLVM::InsertValueOp>( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); @@ -433,7 +428,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = op->getContext(); - ValuePtr extracted = adaptor.dest(); + Value extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast<IntegerAttr>(); auto oneDVectorType = destVectorType; @@ -449,7 +444,7 @@ public: // Insertion of an element into a 1-D LLVM vector. auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); - ValuePtr inserted = rewriter.create<LLVM::InsertElementOp>( + Value inserted = rewriter.create<LLVM::InsertElementOp>( loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), constant); @@ -475,7 +470,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); @@ -486,10 +481,10 @@ public: auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast<vector::OuterProductOp>(op).getResult()->getType()); - ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); - ValuePtr a = adaptor.lhs(), b = adaptor.rhs(); - ValuePtr acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); - SmallVector<ValuePtr, 8> lhs, accs; + Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); + Value a = adaptor.lhs(), b = adaptor.rhs(); + Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector<Value, 8> lhs, accs; lhs.reserve(rankLHS); accs.reserve(rankLHS); for (unsigned d = 0, e = rankLHS; d < e; ++d) { @@ -497,7 +492,7 @@ public: auto attr = rewriter.getI32IntegerAttr(d); SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); - ValuePtr aD = nullptr, accD = nullptr; + Value aD = nullptr, accD = nullptr; // 1. Broadcast the element a[d] into vector aD. aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); // 2. If acc is present, extract 1-d vector acc[d] into accD. @@ -505,7 +500,7 @@ public: accD = rewriter.create<LLVM::ExtractValueOp>( loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). - ValuePtr aOuterbD = + Value aOuterbD = accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) .getResult() : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); @@ -527,7 +522,7 @@ public: typeConverter) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); @@ -576,12 +571,12 @@ public: auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc); + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); desc.setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. - ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc); + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. @@ -627,7 +622,7 @@ public: // TODO(ajcbik): rely solely on libc in future? something else? // PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast<vector::PrintOp>(op); auto adaptor = vector::PrintOpOperandAdaptor(operands); @@ -657,7 +652,7 @@ public: private: void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, - ValuePtr value, VectorType vectorType, Operation *printer, + Value value, VectorType vectorType, Operation *printer, int64_t rank) const { Location loc = op->getLoc(); if (rank == 0) { @@ -673,7 +668,7 @@ private: rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; auto llvmType = lowering.convertType( rank > 1 ? reducedType : vectorType.getElementType()); - ValuePtr nestedVal = + Value nestedVal = extractOne(rewriter, lowering, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); if (d != dim - 1) |