diff options
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 44 |
1 files changed, 19 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 65441674165..c1e88aa0076 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -324,35 +324,33 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() { } //===----------------------------------------------------------------------===// -// ExtractElementOp +// ExtractOp //===----------------------------------------------------------------------===// -static Type inferExtractElementOpResultType(VectorType vectorType, - ArrayAttr position) { +static Type inferExtractOpResultType(VectorType vectorType, + ArrayAttr position) { if (static_cast<int64_t>(position.size()) == vectorType.getRank()) return vectorType.getElementType(); return VectorType::get(vectorType.getShape().drop_front(position.size()), vectorType.getElementType()); } -void vector::ExtractElementOp::build(Builder *builder, OperationState &result, - Value *source, - ArrayRef<int32_t> position) { +void vector::ExtractOp::build(Builder *builder, OperationState &result, + Value *source, ArrayRef<int32_t> position) { result.addOperands(source); auto positionAttr = builder->getI32ArrayAttr(position); - result.addTypes(inferExtractElementOpResultType( - source->getType().cast<VectorType>(), positionAttr)); + result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(), + positionAttr)); result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { +static void print(OpAsmPrinter &p, vector::ExtractOp op) { p << op.getOperationName() << " " << *op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); p << " : " << op.vector()->getType(); } -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc attributeLoc, typeLoc; SmallVector<NamedAttribute, 4> attrs; OpAsmParser::OperandType vector; @@ -375,13 +373,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, attributeLoc, "expected position attribute of rank smaller than vector rank"); - Type resType = inferExtractElementOpResultType(vectorType, positionAttr); + Type resType = inferExtractOpResultType(vectorType, positionAttr); result.attributes = attrs; return failure(parser.resolveOperand(vector, type, result.operands) || parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(vector::ExtractElementOp op) { +static LogicalResult verify(vector::ExtractOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); @@ -447,29 +445,26 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, } //===----------------------------------------------------------------------===// -// InsertElementOp +// InsertOp //===----------------------------------------------------------------------===// -void InsertElementOp::build(Builder *builder, OperationState &result, - Value *source, Value *dest, - ArrayRef<int32_t> position) { +void InsertOp::build(Builder *builder, OperationState &result, Value *source, + Value *dest, ArrayRef<int32_t> position) { result.addOperands({source, dest}); auto positionAttr = builder->getI32ArrayAttr(position); result.addTypes(dest->getType()); result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, InsertElementOp op) { +static void print(OpAsmPrinter &p, InsertOp op) { p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << op.position(); - p.printOptionalAttrDict(op.getAttrs(), - {InsertElementOp::getPositionAttrName()}); + p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); p << " : " << op.getSourceType(); p << " into " << op.getDestVectorType(); } -static ParseResult parseInsertElementOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { SmallVector<NamedAttribute, 4> attrs; OpAsmParser::OperandType source, dest; Type sourceType; @@ -477,8 +472,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, Attribute attr; return failure(parser.parseOperand(source) || parser.parseComma() || parser.parseOperand(dest) || - parser.parseAttribute(attr, - InsertElementOp::getPositionAttrName(), + parser.parseAttribute(attr, InsertOp::getPositionAttrName(), result.attributes) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(sourceType) || @@ -488,7 +482,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, parser.addTypeToList(destType, result.types)); } -static LogicalResult verify(InsertElementOp op) { +static LogicalResult verify(InsertOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); |