diff options
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 64 |
1 files changed, 24 insertions, 40 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 28a03222311..a2345fe1c40 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) { llvm::StringSet<> traitAttrsSet; traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector<NamedAttribute, 8> attrs; - for (auto attr : op.getAttrs()) { + for (auto attr : op.getAttrs()) if (traitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - } + auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", "; p << *op.rhs() << ", " << *op.acc(); - if (llvm::size(op.masks()) == 2) { - p << ", " << **op.masks().begin(); - p << ", " << **(op.masks().begin() + 1); - } + if (op.masks().size() == 2) + p << ", " << op.masks(); + p.printOptionalAttrDict(op.getAttrs(), attrNames); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into " << op.getResultType(); @@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << *op.source(); - p << " : " << op.getSourceType(); - p << " to " << op.getVectorType(); + p << op.getOperationName() << " " << *op.source() << " : " + << op.getSourceType() << " to " << op.getVectorType(); } static LogicalResult verify(BroadcastOp op) { @@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) { p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); - p << " : " << op.getSourceType(); - p << " into " << op.getDestVectorType(); + p << " : " << op.getSourceType() << " into " << op.getDestVectorType(); } static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { @@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) { static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); - if (llvm::size(op.acc()) > 0) - p << ", " << **op.acc().begin(); + if (!op.acc().empty()) + p << ", " << op.acc(); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); } @@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, } static void print(OpAsmPrinter &p, TransferReadOp op) { - p << op.getOperationName() << " "; - p.printOperand(op.memref()); - p << "["; - p.printOperands(op.indices()); - p << "], "; - p.printOperand(op.padding()); - p << " "; + p << op.getOperationName() << " " << op.memref() << "[" << op.indices() + << "], " << op.padding() << " "; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getMemRefType(); - p << ", " << op.getVectorType(); + p << " : " << op.getMemRefType() << ", " << op.getVectorType(); } ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) { @@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) { // TransferWriteOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref(); - p << "["; - p.printOperands(op.indices()); - p << "]"; + p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref() + << "[" << op.indices() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : "; - p.printType(op.getVectorType()); - p << ", "; - p.printType(op.getMemRefType()); + p << " : " << op.getVectorType() << ", " << op.getMemRefType(); } ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) { @@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result, inferVectorTypeCastResultType(source->getType().cast<MemRefType>())); } -static void print(OpAsmPrinter &p, TypeCastOp &op) { +static void print(OpAsmPrinter &p, TypeCastOp op) { auto type = op.getOperand()->getType().cast<MemRefType>(); p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to " << inferVectorTypeCastResultType(type); } -static LogicalResult verify(TypeCastOp &op) { +static LogicalResult verify(TypeCastOp op) { auto resultType = inferVectorTypeCastResultType(op.getMemRefType()); if (op.getResultMemRefType() != resultType) return op.emitOpError("expects result type to be: ") << resultType; @@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) { parser.addTypeToList(resultType, result.types)); } -static void print(OpAsmPrinter &p, ConstantMaskOp &op) { - p << op.getOperationName() << ' ' << op.mask_dim_sizes(); - p << " : " << op.getResult()->getType(); +static void print(OpAsmPrinter &p, ConstantMaskOp op) { + p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : " + << op.getResult()->getType(); } static LogicalResult verify(ConstantMaskOp &op) { @@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) { parser.addTypeToList(resultType, result.types)); } -static void print(OpAsmPrinter &p, CreateMaskOp &op) { - p << op.getOperationName() << ' '; - p.printOperands(op.operands()); - p << " : " << op.getResult()->getType(); +static void print(OpAsmPrinter &p, CreateMaskOp op) { + p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType(); } -static LogicalResult verify(CreateMaskOp &op) { +static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector each dimension. if (op.getNumOperands() != op.getResult()->getType().cast<VectorType>().getRank()) |