diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 169 |
1 files changed, 114 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 17fb93396d9..d4c27a69fb5 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -61,6 +61,38 @@ static VectorType reducedVectorTypeBack(VectorType tp) { return VectorType::get(tp.getShape().take_back(), tp.getElementType()); } +// Helper that picks the proper sequence for inserting. +static Value *insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value *val1, + Value *val2, Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, + constant); + } + return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, + rewriter.getI64ArrayAttr(pos)); +} + +// Helper that picks the proper sequence for extracting. +static Value *extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value *val, + Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, + constant); + } + return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, + rewriter.getI64ArrayAttr(pos)); +} + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -77,11 +109,12 @@ public: return matchFailure(); // Rewrite when the full vector type can be lowered (which // implies all 'reduced' types can be lowered too). + auto adaptor = vector::BroadcastOpOperandAdaptor(operands); VectorType srcVectorType = broadcastOp.getSourceType().dyn_cast<VectorType>(); rewriter.replaceOp( - op, expandRanks(operands[0], // source value to be expanded - op->getLoc(), // location of original broadcast + op, expandRanks(adaptor.source(), // source value to be expanded + op->getLoc(), // location of original broadcast srcVectorType, dstVectorType, rewriter)); return matchSuccess(); } @@ -142,7 +175,8 @@ private: assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); - Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); + Value *expand = + insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); @@ -152,7 +186,8 @@ private: reducedVectorTypeFront(dstVectorType), rewriter); Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { - result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } return result; } @@ -182,62 +217,86 @@ private: Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { + assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - if (atStretch) { - Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); - Value *expand = - insertOne(result, one, loc, llvmType, rank, 0, rewriter); - SmallVector<int32_t, 4> zeroValues(dim, 0); - return rewriter.create<LLVM::ShuffleVectorOp>( - loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); - } - for (int64_t d = 0; d < dim; ++d) { - Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); - result = insertOne(result, one, loc, llvmType, rank, d, rewriter); - } - } else { - VectorType redSrcType = reducedVectorTypeFront(srcVectorType); - VectorType redDstType = reducedVectorTypeFront(dstVectorType); - Type redLlvmType = lowering.convertType(redSrcType); - for (int64_t d = 0; d < dim; ++d) { - int64_t pos = atStretch ? 0 : d; - Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); - Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); - result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); - } + Value *one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); + Value *expand = + insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); + } + VectorType redSrcType = reducedVectorTypeFront(srcVectorType); + VectorType redDstType = reducedVectorTypeFront(dstVectorType); + Type redLlvmType = lowering.convertType(redSrcType); + for (int64_t d = 0; d < dim; ++d) { + int64_t pos = atStretch ? 0 : d; + Value *one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); + Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } return result; } +}; - // Picks the proper sequence for inserting. - Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, - int64_t rank, int64_t pos, - ConversionPatternRewriter &rewriter) const { - if (rank == 1) { - auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, lowering.convertType(idxType), - rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, - constant); +class VectorShuffleOpConversion : public LLVMOpLowering { +public: + explicit VectorShuffleOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ShuffleOpOperandAdaptor(operands); + auto shuffleOp = cast<vector::ShuffleOp>(op); + auto v1Type = shuffleOp.getV1VectorType(); + auto v2Type = shuffleOp.getV2VectorType(); + auto vectorType = shuffleOp.getVectorType(); + Type llvmType = lowering.convertType(vectorType); + auto maskArrayAttr = shuffleOp.mask(); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + // Get rank and dimension sizes. + int64_t rank = vectorType.getRank(); + assert(v1Type.getRank() == rank); + assert(v2Type.getRank() == rank); + int64_t v1Dim = v1Type.getDimSize(0); + + // For rank 1, where both operands have *exactly* the same vector type, + // there is direct shuffle support in LLVM. Use it! + if (rank == 1 && v1Type == v2Type) { + Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>( + loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); + rewriter.replaceOp(op, shuffle); + return matchSuccess(); } - return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, - rewriter.getI64ArrayAttr(pos)); - } - // Picks the proper sequence for extracting. - Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, - int64_t pos, ConversionPatternRewriter &rewriter) const { - if (rank == 1) { - auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, lowering.convertType(idxType), - rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value, - constant); + // For all other cases, insert the individual values individually. + Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); + int64_t insPos = 0; + for (auto en : llvm::enumerate(maskArrayAttr)) { + int64_t extPos = en.value().cast<IntegerAttr>().getInt(); + Value *value = adaptor.v1(); + if (extPos >= v1Dim) { + extPos -= v1Dim; + value = adaptor.v2(); + } + Value *extract = + extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); + insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, + rank, insPos++); } - return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, - rewriter.getI64ArrayAttr(pos)); + rewriter.replaceOp(op, insert); + return matchSuccess(); } }; @@ -506,9 +565,9 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion, - VectorInsertOpConversion, VectorOuterProductOpConversion, - VectorTypeCastOpConversion>( + patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, + VectorExtractOpConversion, VectorInsertOpConversion, + VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } |