diff options
author | Aart Bik <ajcbik@google.com> | 2019-12-12 14:11:27 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-12 14:11:56 -0800 |
commit | 1c81adf362ec79750850dc5ecb0bf3e60399e54f (patch) | |
tree | 67355d032779f502620845e1d6bb5a44aa95bf0d /mlir/lib/Conversion/VectorToLLVM | |
parent | 41a73ddce8923f506eaf6e8c5a61d32add5e4c06 (diff) | |
download | bcm5719-llvm-1c81adf362ec79750850dc5ecb0bf3e60399e54f.tar.gz bcm5719-llvm-1c81adf362ec79750850dc5ecb0bf3e60399e54f.zip |
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 169 |
1 files changed, 114 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 17fb93396d9..d4c27a69fb5 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -61,6 +61,38 @@ static VectorType reducedVectorTypeBack(VectorType tp) { return VectorType::get(tp.getShape().take_back(), tp.getElementType()); } +// Helper that picks the proper sequence for inserting. +static Value *insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value *val1, + Value *val2, Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, + constant); + } + return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, + rewriter.getI64ArrayAttr(pos)); +} + +// Helper that picks the proper sequence for extracting. +static Value *extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value *val, + Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, + constant); + } + return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, + rewriter.getI64ArrayAttr(pos)); +} + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -77,11 +109,12 @@ public: return matchFailure(); // Rewrite when the full vector type can be lowered (which // implies all 'reduced' types can be lowered too). + auto adaptor = vector::BroadcastOpOperandAdaptor(operands); VectorType srcVectorType = broadcastOp.getSourceType().dyn_cast<VectorType>(); rewriter.replaceOp( - op, expandRanks(operands[0], // source value to be expanded - op->getLoc(), // location of original broadcast + op, expandRanks(adaptor.source(), // source value to be expanded + op->getLoc(), // location of original broadcast srcVectorType, dstVectorType, rewriter)); return matchSuccess(); } @@ -142,7 +175,8 @@ private: assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); - Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); + Value *expand = + insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); SmallVector<int32_t, 4> zeroValues(dim, 0); return rewriter.create<LLVM::ShuffleVectorOp>( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); @@ -152,7 +186,8 @@ private: reducedVectorTypeFront(dstVectorType), rewriter); Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { - result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } return result; } @@ -182,62 +217,86 @@ private: Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { + assert(atStretch); Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); - if (atStretch) { - Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); - Value *expand = - insertOne(result, one, loc, llvmType, rank, 0, rewriter); - SmallVector<int32_t, 4> zeroValues(dim, 0); - return rewriter.create<LLVM::ShuffleVectorOp>( - loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); - } - for (int64_t d = 0; d < dim; ++d) { - Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); - result = insertOne(result, one, loc, llvmType, rank, d, rewriter); - } - } else { - VectorType redSrcType = reducedVectorTypeFront(srcVectorType); - VectorType redDstType = reducedVectorTypeFront(dstVectorType); - Type redLlvmType = lowering.convertType(redSrcType); - for (int64_t d = 0; d < dim; ++d) { - int64_t pos = atStretch ? 0 : d; - Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); - Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); - result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); - } + Value *one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); + Value *expand = + insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); + } + VectorType redSrcType = reducedVectorTypeFront(srcVectorType); + VectorType redDstType = reducedVectorTypeFront(dstVectorType); + Type redLlvmType = lowering.convertType(redSrcType); + for (int64_t d = 0; d < dim; ++d) { + int64_t pos = atStretch ? 0 : d; + Value *one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); + Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); } return result; } +}; - // Picks the proper sequence for inserting. - Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, - int64_t rank, int64_t pos, - ConversionPatternRewriter &rewriter) const { - if (rank == 1) { - auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, lowering.convertType(idxType), - rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, - constant); +class VectorShuffleOpConversion : public LLVMOpLowering { +public: + explicit VectorShuffleOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ShuffleOpOperandAdaptor(operands); + auto shuffleOp = cast<vector::ShuffleOp>(op); + auto v1Type = shuffleOp.getV1VectorType(); + auto v2Type = shuffleOp.getV2VectorType(); + auto vectorType = shuffleOp.getVectorType(); + Type llvmType = lowering.convertType(vectorType); + auto maskArrayAttr = shuffleOp.mask(); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + // Get rank and dimension sizes. + int64_t rank = vectorType.getRank(); + assert(v1Type.getRank() == rank); + assert(v2Type.getRank() == rank); + int64_t v1Dim = v1Type.getDimSize(0); + + // For rank 1, where both operands have *exactly* the same vector type, + // there is direct shuffle support in LLVM. Use it! + if (rank == 1 && v1Type == v2Type) { + Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>( + loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); + rewriter.replaceOp(op, shuffle); + return matchSuccess(); } - return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, - rewriter.getI64ArrayAttr(pos)); - } - // Picks the proper sequence for extracting. - Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, - int64_t pos, ConversionPatternRewriter &rewriter) const { - if (rank == 1) { - auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, lowering.convertType(idxType), - rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value, - constant); + // For all other cases, insert the individual values individually. + Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); + int64_t insPos = 0; + for (auto en : llvm::enumerate(maskArrayAttr)) { + int64_t extPos = en.value().cast<IntegerAttr>().getInt(); + Value *value = adaptor.v1(); + if (extPos >= v1Dim) { + extPos -= v1Dim; + value = adaptor.v2(); + } + Value *extract = + extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); + insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, + rank, insPos++); } - return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, - rewriter.getI64ArrayAttr(pos)); + rewriter.replaceOp(op, insert); + return matchSuccess(); } }; @@ -506,9 +565,9 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion, - VectorInsertOpConversion, VectorOuterProductOpConversion, - VectorTypeCastOpConversion>( + patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, + VectorExtractOpConversion, VectorInsertOpConversion, + VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } |