diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 53 | ||||
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 74 |
2 files changed, 126 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d4c27a69fb5..71bed9516ca 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -300,6 +300,31 @@ public: } }; +class VectorExtractElementOpConversion : public LLVMOpLowering { +public: + explicit VectorExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractEltOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractEltOp.getVectorType(); + auto llvmType = lowering.convertType(vectorType.getElementType()); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( + op, llvmType, adaptor.vector(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorExtractOpConversion : public LLVMOpLowering { public: explicit VectorExtractOpConversion(MLIRContext *context, @@ -355,6 +380,31 @@ public: } }; +class VectorInsertElementOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::InsertElementOpOperandAdaptor(operands); + auto insertEltOp = cast<vector::InsertElementOp>(op); + auto vectorType = insertEltOp.getDestVectorType(); + auto llvmType = lowering.convertType(vectorType); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorInsertOpConversion : public LLVMOpLowering { public: explicit VectorInsertOpConversion(MLIRContext *context, @@ -566,7 +616,8 @@ public: void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, - VectorExtractOpConversion, VectorInsertOpConversion, + VectorExtractElementOpConversion, VectorExtractOpConversion, + VectorInsertElementOpConversion, VectorInsertOpConversion, VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 2dfa4568a3e..fc8abd710e9 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -347,6 +347,42 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() { } //===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { + p << op.getOperationName() << " " << *op.vector() << "[" << *op.position() + << " : " << op.position()->getType() << "]"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.vector()->getType(); +} + +static ParseResult parseExtractElementOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType vector, position; + Type positionType; + VectorType vectorType; + if (parser.parseOperand(vector) || parser.parseLSquare() || + parser.parseOperand(position) || parser.parseColonType(positionType) || + parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(vectorType)) + return failure(); + Type resultType = vectorType.getElementType(); + return failure( + parser.resolveOperand(vector, vectorType, result.operands) || + parser.resolveOperand(position, positionType, result.operands) || + parser.addTypeToList(resultType, result.types)); +} + +static LogicalResult verify(vector::ExtractElementOp op) { + VectorType vectorType = op.getVectorType(); + if (vectorType.getRank() != 1) + return op.emitOpError("expected 1-D vector"); + return success(); +} + +//===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// @@ -685,6 +721,44 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { } //===----------------------------------------------------------------------===// +// InsertElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, InsertElementOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "[" + << *op.position() << " : " << op.position()->getType() << "]"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.dest()->getType(); +} + +static ParseResult parseInsertElementOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest, position; + Type positionType; + VectorType destType; + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || parser.parseLSquare() || + parser.parseOperand(position) || parser.parseColonType(positionType) || + parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(destType)) + return failure(); + Type sourceType = destType.getElementType(); + return failure( + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.resolveOperand(position, positionType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + +static LogicalResult verify(InsertElementOp op) { + auto dstVectorType = op.getDestVectorType(); + if (dstVectorType.getRank() != 1) + return op.emitOpError("expected 1-D vector"); + return success(); +} + +//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// |

