summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/StandardOps/Ops.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-12 15:31:39 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-12 15:32:21 -0800
commite7aa47ff111c53127587d8aea71b088db3a671aa (patch)
treef2b4841362b381de4f2beb1632250b8cdbc49cf2 /mlir/lib/Dialect/StandardOps/Ops.cpp
parenta50cb184a0c5ebc342a871b2e338e2591115639f (diff)
downloadbcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.tar.gz
bcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.zip
NFC: Cleanup the various Op::print methods.
This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc. PiperOrigin-RevId: 285285181
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp112
1 files changed, 29 insertions, 83 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 7726c0446aa..531be29666a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) {
- p << '(';
- p.printOperands(begin, begin + numDims);
- p << ')';
-
- if (begin + numDims != end) {
- p << '[';
- p.printOperands(begin + numDims, end);
- p << ']';
- }
+ Operation::operand_range operands(begin, end);
+ p << '(' << operands.take_front(numDims) << ')';
+ if (operands.size() != numDims)
+ p << '[' << operands.drop_front(numDims) << ']';
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
@@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, CallOp op) {
- p << "call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
+ p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
+ p << " : " << op.getCalleeType();
}
static LogicalResult verify(CallOp op) {
@@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CallIndirectOp op) {
- p << "call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- p.printOperands(op.getArgOperands());
- p << ')';
+ p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCallee()->getType();
}
@@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
- << '"';
-
- p << ", ";
- p.printOperand(op.lhs());
- p << ", ";
- p.printOperand(op.rhs());
+ << '"' << ", " << op.lhs() << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
"unknown predicate index");
- Builder b(op.getContext());
- auto predicateStringAttr =
- b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
- p.printAttribute(predicateStringAttr);
-
- p << ", ";
- p.printOperand(op.lhs());
- p << ", ";
- p.printOperand(op.rhs());
+ p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
+ << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CondBranchOp op) {
- p << "cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
+ p << "cond_br " << op.getCondition() << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
@@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
if (op.getAttrs().size() > 1)
p << ' ';
- p.printAttribute(op.getValue());
+ p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
@@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
}
void DmaStartOp::print(OpAsmPrinter &p) {
- p << "dma_start " << *getSrcMemRef() << '[';
- p.printOperands(getSrcIndices());
- p << "], " << *getDstMemRef() << '[';
- p.printOperands(getDstIndices());
- p << "], " << *getNumElements();
- p << ", " << *getTagMemRef() << '[';
- p.printOperands(getTagIndices());
- p << ']';
- if (isStrided()) {
- p << ", " << *getStride();
- p << ", " << *getNumElementsPerStride();
- }
+ p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
+ << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
+ << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
+ if (isStrided())
+ p << ", " << *getStride() << ", " << *getNumElementsPerStride();
+
p.printOptionalAttrDict(getAttrs());
p << " : " << getSrcMemRef()->getType();
p << ", " << getDstMemRef()->getType();
@@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
}
void DmaWaitOp::print(OpAsmPrinter &p) {
- p << "dma_wait ";
- p.printOperand(getTagMemRef());
- p << '[';
- p.printOperands(getTagIndices());
- p << "], ";
- p.printOperand(getNumElements());
+ p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
+ << getNumElements();
p.printOptionalAttrDict(getAttrs());
p << " : " << getTagMemRef()->getType();
}
@@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
- p << "extract_element " << *op.getAggregate() << '[';
- p.printOperands(op.getIndices());
+ p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
p << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getAggregate()->getType();
@@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) {
- p << "load " << *op.getMemRef() << '[';
- p.printOperands(op.getIndices());
- p << ']';
+ p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReturnOp op) {
p << "return";
- if (op.getNumOperands() != 0) {
- p << ' ';
- p.printOperands(op.getOperands());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
+ if (op.getNumOperands() != 0)
+ p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
}
static LogicalResult verify(ReturnOp op) {
@@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SelectOp op) {
- p << "select ";
- p.printOperands(op.getOperands());
- p << " : " << op.getTrueValue()->getType();
+ p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
p.printOptionalAttrDict(op.getAttrs());
}
@@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
static void print(OpAsmPrinter &p, StoreOp op) {
p << "store " << *op.getValueToStore();
- p << ", " << *op.getMemRef() << '[';
- p.printOperands(op.getIndices());
- p << ']';
+ p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
auto *dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset);
- p << "][";
- p.printOperands(op.getDynamicSizes());
- p << ']';
+ p << "][" << op.getDynamicSizes() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
@@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SubViewOp op) {
- p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
- p.printOperands(op.offsets());
- p << "][";
- p.printOperands(op.sizes());
- p << "][";
- p.printOperands(op.strides());
- p << ']';
+ p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
+ << "][" << op.sizes() << "][" << op.strides() << ']';
SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};
OpenPOWER on IntegriCloud