summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp53
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp74
2 files changed, 126 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d4c27a69fb5..71bed9516ca 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -300,6 +300,31 @@ public:
}
};
+class VectorExtractElementOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorExtractElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
+ auto extractEltOp = cast<vector::ExtractElementOp>(op);
+ auto vectorType = extractEltOp.getVectorType();
+ auto llvmType = lowering.convertType(vectorType.getElementType());
+
+ // Bail if result type cannot be lowered.
+ if (!llvmType)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+ op, llvmType, adaptor.vector(), adaptor.position());
+ return matchSuccess();
+ }
+};
+
class VectorExtractOpConversion : public LLVMOpLowering {
public:
explicit VectorExtractOpConversion(MLIRContext *context,
@@ -355,6 +380,31 @@ public:
}
};
+class VectorInsertElementOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorInsertElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
+ auto insertEltOp = cast<vector::InsertElementOp>(op);
+ auto vectorType = insertEltOp.getDestVectorType();
+ auto llvmType = lowering.convertType(vectorType);
+
+ // Bail if result type cannot be lowered.
+ if (!llvmType)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+ op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
+ return matchSuccess();
+ }
+};
+
class VectorInsertOpConversion : public LLVMOpLowering {
public:
explicit VectorInsertOpConversion(MLIRContext *context,
@@ -566,7 +616,8 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
- VectorExtractOpConversion, VectorInsertOpConversion,
+ VectorExtractElementOpConversion, VectorExtractOpConversion,
+ VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 2dfa4568a3e..fc8abd710e9 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -347,6 +347,42 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
}
//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
+ p << op.getOperationName() << " " << *op.vector() << "[" << *op.position()
+ << " : " << op.position()->getType() << "]";
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.vector()->getType();
+}
+
+static ParseResult parseExtractElementOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType vector, position;
+ Type positionType;
+ VectorType vectorType;
+ if (parser.parseOperand(vector) || parser.parseLSquare() ||
+ parser.parseOperand(position) || parser.parseColonType(positionType) ||
+ parser.parseRSquare() ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(vectorType))
+ return failure();
+ Type resultType = vectorType.getElementType();
+ return failure(
+ parser.resolveOperand(vector, vectorType, result.operands) ||
+ parser.resolveOperand(position, positionType, result.operands) ||
+ parser.addTypeToList(resultType, result.types));
+}
+
+static LogicalResult verify(vector::ExtractElementOp op) {
+ VectorType vectorType = op.getVectorType();
+ if (vectorType.getRank() != 1)
+ return op.emitOpError("expected 1-D vector");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -685,6 +721,44 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
}
//===----------------------------------------------------------------------===//
+// InsertElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, InsertElementOp op) {
+ p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "["
+ << *op.position() << " : " << op.position()->getType() << "]";
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.dest()->getType();
+}
+
+static ParseResult parseInsertElementOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType source, dest, position;
+ Type positionType;
+ VectorType destType;
+ if (parser.parseOperand(source) || parser.parseComma() ||
+ parser.parseOperand(dest) || parser.parseLSquare() ||
+ parser.parseOperand(position) || parser.parseColonType(positionType) ||
+ parser.parseRSquare() ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(destType))
+ return failure();
+ Type sourceType = destType.getElementType();
+ return failure(
+ parser.resolveOperand(source, sourceType, result.operands) ||
+ parser.resolveOperand(dest, destType, result.operands) ||
+ parser.resolveOperand(position, positionType, result.operands) ||
+ parser.addTypeToList(destType, result.types));
+}
+
+static LogicalResult verify(InsertElementOp op) {
+ auto dstVectorType = op.getDestVectorType();
+ if (dstVectorType.getRank() != 1)
+ return op.emitOpError("expected 1-D vector");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud