diff options
author | River Riddle <riverriddle@google.com> | 2019-12-12 15:31:39 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-12 15:32:21 -0800 |
commit | e7aa47ff111c53127587d8aea71b088db3a671aa (patch) | |
tree | f2b4841362b381de4f2beb1632250b8cdbc49cf2 /mlir/lib/Dialect/StandardOps/Ops.cpp | |
parent | a50cb184a0c5ebc342a871b2e338e2591115639f (diff) | |
download | bcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.tar.gz bcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.zip |
NFC: Cleanup the various Op::print methods.
This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc.
PiperOrigin-RevId: 285285181
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
-rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 112 |
1 files changed, 29 insertions, 83 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 7726c0446aa..531be29666a 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) void mlir::printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p) { - p << '('; - p.printOperands(begin, begin + numDims); - p << ')'; - - if (begin + numDims != end) { - p << '['; - p.printOperands(begin + numDims, end); - p << ']'; - } + Operation::operand_range operands(begin, end); + p << '(' << operands.take_front(numDims) << ')'; + if (operands.size() != numDims) + p << '[' << operands.drop_front(numDims) << ']'; } // Parses dimension and symbol list, and sets 'numDims' to the number of @@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, CallOp op) { - p << "call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; + p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); + p << " : " << op.getCalleeType(); } static LogicalResult verify(CallOp op) { @@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, CallIndirectOp op) { - p << "call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - p.printOperands(op.getArgOperands()); - p << ')'; + p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : " << op.getCallee()->getType(); } @@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) { auto predicateValue = op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue)) - << '"'; - - p << ", "; - p.printOperand(op.lhs()); - p << ", "; - p.printOperand(op.rhs()); + << '"' << ", " << op.lhs() << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); @@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) { assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && "unknown predicate index"); - Builder b(op.getContext()); - auto predicateStringAttr = - b.getStringAttr(getCmpFPredicateNames()[predicateValue]); - p.printAttribute(predicateStringAttr); - - p << ", "; - p.printOperand(op.lhs()); - p << ", "; - p.printOperand(op.rhs()); + p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs() + << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); @@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, CondBranchOp op) { - p << "cond_br "; - p.printOperand(op.getCondition()); - p << ", "; + p << "cond_br " << op.getCondition() << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); p << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); @@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) { if (op.getAttrs().size() > 1) p << ' '; - p.printAttribute(op.getValue()); + p << op.getValue(); // If the value is a symbol reference, print a trailing type. if (op.getValue().isa<SymbolRefAttr>()) @@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result, } void DmaStartOp::print(OpAsmPrinter &p) { - p << "dma_start " << *getSrcMemRef() << '['; - p.printOperands(getSrcIndices()); - p << "], " << *getDstMemRef() << '['; - p.printOperands(getDstIndices()); - p << "], " << *getNumElements(); - p << ", " << *getTagMemRef() << '['; - p.printOperands(getTagIndices()); - p << ']'; - if (isStrided()) { - p << ", " << *getStride(); - p << ", " << *getNumElementsPerStride(); - } + p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], " + << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements() + << ", " << *getTagMemRef() << '[' << getTagIndices() << ']'; + if (isStrided()) + p << ", " << *getStride() << ", " << *getNumElementsPerStride(); + p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef()->getType(); p << ", " << getDstMemRef()->getType(); @@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result, } void DmaWaitOp::print(OpAsmPrinter &p) { - p << "dma_wait "; - p.printOperand(getTagMemRef()); - p << '['; - p.printOperands(getTagIndices()); - p << "], "; - p.printOperand(getNumElements()); + p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " + << getNumElements(); p.printOptionalAttrDict(getAttrs()); p << " : " << getTagMemRef()->getType(); } @@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ExtractElementOp op) { - p << "extract_element " << *op.getAggregate() << '['; - p.printOperands(op.getIndices()); + p << "extract_element " << *op.getAggregate() << '[' << op.getIndices(); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getAggregate()->getType(); @@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, LoadOp op) { - p << "load " << *op.getMemRef() << '['; - p.printOperands(op.getIndices()); - p << ']'; + p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, ReturnOp op) { p << "return"; - if (op.getNumOperands() != 0) { - p << ' '; - p.printOperands(op.getOperands()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } + if (op.getNumOperands() != 0) + p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); } static LogicalResult verify(ReturnOp op) { @@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SelectOp op) { - p << "select "; - p.printOperands(op.getOperands()); - p << " : " << op.getTrueValue()->getType(); + p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType(); p.printOptionalAttrDict(op.getAttrs()); } @@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { static void print(OpAsmPrinter &p, StoreOp op) { p << "store " << *op.getValueToStore(); - p << ", " << *op.getMemRef() << '['; - p.printOperands(op.getIndices()); - p << ']'; + p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) { auto *dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); - p << "]["; - p.printOperands(op.getDynamicSizes()); - p << ']'; + p << "][" << op.getDynamicSizes() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SubViewOp op) { - p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.offsets()); - p << "]["; - p.printOperands(op.sizes()); - p << "]["; - p.printOperands(op.strides()); - p << ']'; + p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets() + << "][" << op.sizes() << "][" << op.strides() << ']'; SmallVector<StringRef, 1> elidedAttrs = { SubViewOp::getOperandSegmentSizeAttr()}; |