summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-12 15:31:39 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-12 15:32:21 -0800
commite7aa47ff111c53127587d8aea71b088db3a671aa (patch)
treef2b4841362b381de4f2beb1632250b8cdbc49cf2 /mlir/lib
parenta50cb184a0c5ebc342a871b2e338e2591115639f (diff)
downloadbcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.tar.gz
bcm5719-llvm-e7aa47ff111c53127587d8aea71b088db3a671aa.zip
NFC: Cleanup the various Op::print methods.
This cleans up the implementation of the various operation print methods. This is done via a combination of code cleanup, adding new streaming methods to the printer(e.g. operand ranges), etc. PiperOrigin-RevId: 285285181
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp16
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp25
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp16
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp19
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp16
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp41
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp116
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp112
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp64
9 files changed, 146 insertions, 279 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 59e5afec9ce..96a1a68889c 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -1985,18 +1985,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, AffineMinOp op) {
p << op.getOperationName() << ' '
<< op.getAttr(AffineMinOp::getMapAttrName());
- auto begin = op.operand_begin();
- auto end = op.operand_end();
+ auto operands = op.getOperands();
unsigned numDims = op.map().getNumDims();
- p << '(';
- p.printOperands(begin, begin + numDims);
- p << ')';
-
- if (begin + numDims != end) {
- p << '[';
- p.printOperands(begin + numDims, end);
- p << ']';
- }
+ p << '(' << operands.take_front(numDims) << ')';
+
+ if (operands.size() != numDims)
+ p << '[' << operands.drop_front(numDims) << ']';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b8970650806..1f48d6d47e4 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -289,8 +289,7 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the launch configuration.
p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword();
- printSizeAssignment(p, op.getGridSize(),
- operands.drop_back(operands.size() - 3),
+ printSizeAssignment(p, op.getGridSize(), operands.take_front(3),
op.getBlockIds());
p << ' ' << op.getThreadsKeyword();
printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
@@ -303,25 +302,17 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
// Print the data argument remapping.
if (!op.body().empty() && !operands.empty()) {
p << ' ' << op.getArgsKeyword() << '(';
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (i != 0)
- p << ", ";
- p << *op.body().front().getArgument(LaunchOp::kNumConfigRegionAttributes +
- i)
+ Block *entryBlock = &op.body().front();
+ interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
+ p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
<< " = " << *operands[i];
- }
+ });
p << ") ";
}
// Print the types of data arguments.
- if (!operands.empty()) {
- p << ": ";
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (i != 0)
- p << ", ";
- p << operands[i]->getType();
- }
- }
+ if (!operands.empty())
+ p << ": " << operands.getTypes();
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs());
@@ -701,7 +692,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
return;
p << ' ' << keyword << '(';
- interleaveComma(values, p.getStream(),
+ interleaveComma(values, p,
[&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
p << ')';
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 78da9998c6d..d037d2edb72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -177,9 +177,8 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
- p << op.getOperationName() << ' ' << *op.base() << '[';
- p.printOperands(std::next(op.operand_begin()), op.operand_end());
- p << ']';
+ p << op.getOperationName() << ' ' << *op.base() << '['
+ << op.getOperands().drop_front() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << funcTy;
}
@@ -312,10 +311,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
else
p << *op.getOperand(0);
- p << '(';
- p.printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
- p << ')';
-
+ p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
@@ -938,8 +934,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
// Print the trailing type unless it's a string global.
if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
return;
- p << " : ";
- p.printType(op.type());
+ p << " : " << op.type();
Region &initializer = op.getInitializerRegion();
if (!initializer.empty())
@@ -1346,8 +1341,7 @@ static LogicalResult verify(LLVMFuncOp op) {
static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
p << NullOp::getOperationName();
p.printOptionalAttrDict(op.getAttrs());
- p << " : ";
- p.printType(op.getType());
+ p << " : " << op.getType();
}
// <operation> = `llvm.mlir.null` : type
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0b10391f180..e4708fbe535 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -37,18 +37,17 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
-namespace mlir {
-namespace NVVM {
+using namespace mlir;
+using namespace NVVM;
//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
- p << op->getName() << " ";
- p.printOperands(op->getOperands());
+ p << op->getName() << " " << op->getOperands();
if (op->getNumResults() > 0)
- interleaveComma(op->getResultTypes(), p << " : ");
+ p << " : " << op->getResultTypes();
}
// <operation> ::= `llvm.nvvm.XYZ` : type
@@ -141,8 +140,7 @@ static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
}
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
- p << op.getOperationName() << " ";
- p.printOperands(op.getOperands());
+ p << op.getOperationName() << " " << op.getOperands();
p.printOptionalAttrDict(op.getAttrs());
p << " : "
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
@@ -210,10 +208,11 @@ NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
allowUnknownOperations();
}
+namespace mlir {
+namespace NVVM {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
-
-static DialectRegistration<NVVMDialect> nvvmDialect;
-
} // namespace NVVM
} // namespace mlir
+
+static DialectRegistration<NVVMDialect> nvvmDialect;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 487382bb364..30c55b52e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -36,18 +36,17 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
-namespace mlir {
-namespace ROCDL {
+using namespace mlir;
+using namespace ROCDL;
//===----------------------------------------------------------------------===//
// Printing/parsing for ROCDL ops
//===----------------------------------------------------------------------===//
static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
- p << op->getName() << " ";
- p.printOperands(op->getOperands());
+ p << op->getName() << " " << op->getOperands();
if (op->getNumResults() > 0)
- interleaveComma(op->getResultTypes(), p << " : ");
+ p << " : " << op->getResultTypes();
}
// <operation> ::= `rocdl.XYZ` : type
@@ -73,10 +72,11 @@ ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
allowUnknownOperations();
}
+namespace mlir {
+namespace ROCDL {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
-
-static DialectRegistration<ROCDLDialect> rocdlDialect;
-
} // namespace ROCDL
} // namespace mlir
+
+static DialectRegistration<ROCDLDialect> rocdlDialect;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2efd26ad78b..6adfeb592ef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -60,18 +60,16 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
llvm::StringSet<> linalgTraitAttrsSet;
linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
- for (auto attr : op.getAttrs()) {
+ for (auto attr : op.getAttrs())
if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
- }
+
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
- p << op.getOperationName() << " " << dictAttr << " ";
- p.printOperands(op.getOperands());
+ p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
if (!op.region().empty())
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs(), attrNames);
- p << ": ";
- interleaveComma(op.getOperandTypes(), p);
+ p << ": " << op.getOperandTypes();
}
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@@ -342,14 +340,13 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
}
static void print(OpAsmPrinter &p, SliceOp op) {
- p << SliceOp::getOperationName() << " " << *op.view() << "[";
- p.printOperands(op.indexings());
- p << "] ";
+ auto indexings = op.indexings();
+ p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
+ << "] ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getBaseViewType();
- for (auto indexing : op.indexings()) {
- p << ", " << indexing->getType();
- }
+ if (!indexings.empty())
+ p << ", " << op.indexings().getTypes();
p << ", " << op.getType();
}
@@ -455,16 +452,11 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, YieldOp op) {
p << op.getOperationName();
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- }
+ if (op.getNumOperands() > 0)
+ p << ' ' << op.getOperands();
p.printOptionalAttrDict(op.getAttrs());
- if (op.getNumOperands() > 0) {
- p << " : ";
- interleaveComma(op.getOperands(), p,
- [&](Value *e) { p.printType(e->getType()); });
- }
+ if (op.getNumOperands() > 0)
+ p << " : " << op.getOperandTypes();
}
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
@@ -536,12 +528,9 @@ static LogicalResult verify(YieldOp op) {
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
- p << op->getName().getStringRef() << "(";
- interleaveComma(op->getOperands(), p, [&](Value *v) { p << *v; });
- p << ")";
+ p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
p.printOptionalAttrDict(op->getAttrs());
- p << " : ";
- interleaveComma(op->getOperands(), p, [&](Value *v) { p << v->getType(); });
+ p << " : " << op->getOperandTypes();
}
static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 99ab1cdee42..839f134ec8f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
}
static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
- printer << op->getName() << ' ';
- printer.printOperands(op->getOperands());
- printer << " : " << op->getOperand(0)->getType() << ", "
- << op->getOperand(1)->getType() << ", "
- << op->getOperand(2)->getType();
+ printer << op->getName() << ' ' << op->getOperands() << " : "
+ << op->getOperandTypes();
}
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
@@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
}
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
- printer << logicalOp->getName() << ' ';
- printer.printOperands(logicalOp->getOperands());
- printer << " : " << logicalOp->getOperand(0)->getType();
+ printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
+ << logicalOp->getOperand(0)->getType();
}
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
- << '[';
- printer.printOperands(op.indices());
- printer << "] : " << op.base_ptr()->getType();
+ << '[' << op.indices() << "] : " << op.base_ptr()->getType();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
@@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
<< stringifyScope(atomOp.memory_scope()) << "\" \""
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
- << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" ";
- printer.printOperands(atomOp.getOperands());
- printer << " : " << atomOp.pointer()->getType();
+ << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
+ << atomOp.getOperands() << " : " << atomOp.pointer()->getType();
}
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) {
- printer << spirv::BitFieldInsertOp::getOperationName() << ' ';
- printer.printOperands(bitFieldInsertOp.getOperands());
- printer << " : " << bitFieldInsertOp.base()->getType() << ", "
+ printer << spirv::BitFieldInsertOp::getOperationName() << ' '
+ << bitFieldInsertOp.getOperands() << " : "
+ << bitFieldInsertOp.base()->getType() << ", "
<< bitFieldInsertOp.offset()->getType() << ", "
<< bitFieldInsertOp.count()->getType();
}
@@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
}
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
- printer << spirv::BranchConditionalOp::getOperationName() << ' ';
- printer.printOperand(branchOp.condition());
+ printer << spirv::BranchConditionalOp::getOperationName() << ' '
+ << branchOp.condition();
if (auto weights = branchOp.branch_weights()) {
printer << " [";
@@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
static void print(spirv::CompositeConstructOp compositeConstructOp,
OpAsmPrinter &printer) {
- printer << spirv::CompositeConstructOp::getOperationName() << " ";
- printer.printOperands(compositeConstructOp.constituents());
- printer << " : " << compositeConstructOp.getResult()->getType();
+ printer << spirv::CompositeConstructOp::getOperationName() << " "
+ << compositeConstructOp.constituents() << " : "
+ << compositeConstructOp.getResult()->getType();
}
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
- if (constOp.getType().isa<spirv::ArrayType>()) {
+ if (constOp.getType().isa<spirv::ArrayType>())
printer << " : " << constOp.getType();
- }
}
static LogicalResult verify(spirv::ConstantOp constOp) {
@@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
<< execModeOp.fn() << " \""
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
auto values = execModeOp.values();
- if (!values.size()) {
+ if (!values.size())
return;
- }
printer << ", ";
interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
@@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
printer << spirv::FunctionCallOp::getOperationName() << ' '
- << functionCallOp.getAttr(kCallee) << '(';
- printer.printOperands(functionCallOp.arguments());
- printer << ") : " << functionType;
+ << functionCallOp.getAttr(kCallee) << '('
+ << functionCallOp.arguments() << ") : " << functionType;
}
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
@@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
static void print(spirv::GroupNonUniformBallotOp ballotOp,
OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
- << stringifyScope(ballotOp.execution_scope()) << "\" ";
- printer.printOperand(ballotOp.predicate());
- printer << " : " << ballotOp.getType();
+ << stringifyScope(ballotOp.execution_scope()) << "\" "
+ << ballotOp.predicate() << " : " << ballotOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
@@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
- // Print the pointer operand.
- printer.printOperand(loadOp.ptr());
+ printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
+ << loadOp.ptr();
printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
@@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
- auto *op = moduleOp.getOperation();
+ printer << spirv::ModuleOp::getOperationName();
// Only print out addressing model and memory model in a nicer way if both
- // presents. Otherwise, print them in the general form. This helps debugging
- // ill-formed ModuleOp.
+ // presents. Otherwise, print them in the general form. This helps
+ // debugging ill-formed ModuleOp.
SmallVector<StringRef, 2> elidedAttrs;
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
- if (op->getAttr(addressingModelAttrName) &&
- op->getAttr(memoryModelAttrName)) {
- printer << spirv::ModuleOp::getOperationName() << " \""
+ if (moduleOp.getAttr(addressingModelAttrName) &&
+ moduleOp.getAttr(memoryModelAttrName)) {
+ printer << " \""
<< spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
<< '"';
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
}
- printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+ printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
- printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
+ printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
}
static LogicalResult verify(spirv::ModuleOp moduleOp) {
@@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
}
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
- printer << spirv::ReturnValueOp::getOperationName() << ' ';
- printer.printOperand(retValOp.value());
- printer << " : " << retValOp.value()->getType();
+ printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
+ << " : " << retValOp.value()->getType();
}
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
- printer << spirv::SelectOp::getOperationName() << " ";
-
- // Print the operands.
- printer.printOperands(op.getOperands());
-
- // Print colon and types.
- printer << " : " << op.condition()->getType() << ", "
+ printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
+ << " : " << op.condition()->getType() << ", "
<< op.result()->getType();
}
@@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
printer.printSymbolName(constOp.sym_name());
if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName))
printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
- printer << " = ";
- printer.printAttribute(constOp.default_value());
+ printer << " = " << constOp.default_value();
}
static LogicalResult verify(spirv::SpecConstantOp constOp) {
@@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
- // Print the pointer operand
- printer.printOperand(storeOp.ptr());
- printer << ", ";
- // Print the value operand
- printer.printOperand(storeOp.value());
+ printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
+ << storeOp.ptr() << ", " << storeOp.value();
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
printer << " : " << storeOp.value()->getType();
-
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
@@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
}
static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
- printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
- printer.printOperand(ballotOp.predicate());
- printer << " : " << ballotOp.getType();
+ printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
+ << ballotOp.predicate() << " : " << ballotOp.getType();
}
//===----------------------------------------------------------------------===//
@@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
- auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{
spirv::attributeName<spirv::StorageClass>()};
printer << spirv::VariableOp::getOperationName();
// Print optional initializer
- if (op->getNumOperands() > 0) {
- printer << " init(";
- printer.printOperands(varOp.initializer());
- printer << ")";
- }
-
- printVariableDecorations(op, printer, elidedAttrs);
+ if (varOp.getNumOperands() != 0)
+ printer << " init(" << varOp.initializer() << ")";
+ printVariableDecorations(varOp, printer, elidedAttrs);
printer << " : " << varOp.getType();
}
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 7726c0446aa..531be29666a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -166,15 +166,10 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) {
- p << '(';
- p.printOperands(begin, begin + numDims);
- p << ')';
-
- if (begin + numDims != end) {
- p << '[';
- p.printOperands(begin + numDims, end);
- p << ']';
- }
+ Operation::operand_range operands(begin, end);
+ p << '(' << operands.take_front(numDims) << ')';
+ if (operands.size() != numDims)
+ p << '[' << operands.drop_front(numDims) << ']';
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
@@ -485,12 +480,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, CallOp op) {
- p << "call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
+ p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
+ p << " : " << op.getCalleeType();
}
static LogicalResult verify(CallOp op) {
@@ -572,11 +564,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CallIndirectOp op) {
- p << "call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- p.printOperands(op.getArgOperands());
- p << ')';
+ p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCallee()->getType();
}
@@ -690,12 +678,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
- << '"';
-
- p << ", ";
- p.printOperand(op.lhs());
- p << ", ";
- p.printOperand(op.rhs());
+ << '"' << ", " << op.lhs() << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@@ -851,15 +834,8 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
"unknown predicate index");
- Builder b(op.getContext());
- auto predicateStringAttr =
- b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
- p.printAttribute(predicateStringAttr);
-
- p << ", ";
- p.printOperand(op.lhs());
- p << ", ";
- p.printOperand(op.rhs());
+ p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
+ << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
p << " : " << op.lhs()->getType();
@@ -1002,9 +978,7 @@ static ParseResult parseCondBranchOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, CondBranchOp op) {
- p << "cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
+ p << "cond_br " << op.getCondition() << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
@@ -1025,7 +999,7 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
if (op.getAttrs().size() > 1)
p << ' ';
- p.printAttribute(op.getValue());
+ p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
@@ -1407,18 +1381,12 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
}
void DmaStartOp::print(OpAsmPrinter &p) {
- p << "dma_start " << *getSrcMemRef() << '[';
- p.printOperands(getSrcIndices());
- p << "], " << *getDstMemRef() << '[';
- p.printOperands(getDstIndices());
- p << "], " << *getNumElements();
- p << ", " << *getTagMemRef() << '[';
- p.printOperands(getTagIndices());
- p << ']';
- if (isStrided()) {
- p << ", " << *getStride();
- p << ", " << *getNumElementsPerStride();
- }
+ p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
+ << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
+ << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
+ if (isStrided())
+ p << ", " << *getStride() << ", " << *getNumElementsPerStride();
+
p.printOptionalAttrDict(getAttrs());
p << " : " << getSrcMemRef()->getType();
p << ", " << getDstMemRef()->getType();
@@ -1550,12 +1518,8 @@ void DmaWaitOp::build(Builder *builder, OperationState &result,
}
void DmaWaitOp::print(OpAsmPrinter &p) {
- p << "dma_wait ";
- p.printOperand(getTagMemRef());
- p << '[';
- p.printOperands(getTagIndices());
- p << "], ";
- p.printOperand(getNumElements());
+ p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
+ << getNumElements();
p.printOptionalAttrDict(getAttrs());
p << " : " << getTagMemRef()->getType();
}
@@ -1604,8 +1568,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
- p << "extract_element " << *op.getAggregate() << '[';
- p.printOperands(op.getIndices());
+ p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
p << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getAggregate()->getType();
@@ -1686,9 +1649,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) {
- p << "load " << *op.getMemRef() << '[';
- p.printOperands(op.getIndices());
- p << ']';
+ p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -1922,12 +1883,8 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReturnOp op) {
p << "return";
- if (op.getNumOperands() != 0) {
- p << ' ';
- p.printOperands(op.getOperands());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
+ if (op.getNumOperands() != 0)
+ p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
}
static LogicalResult verify(ReturnOp op) {
@@ -1984,9 +1941,7 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SelectOp op) {
- p << "select ";
- p.printOperands(op.getOperands());
- p << " : " << op.getTrueValue()->getType();
+ p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
p.printOptionalAttrDict(op.getAttrs());
}
@@ -2093,9 +2048,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
static void print(OpAsmPrinter &p, StoreOp op) {
p << "store " << *op.getValueToStore();
- p << ", " << *op.getMemRef() << '[';
- p.printOperands(op.getIndices());
- p << ']';
+ p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -2339,9 +2292,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
auto *dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset);
- p << "][";
- p.printOperands(op.getDynamicSizes());
- p << ']';
+ p << "][" << op.getDynamicSizes() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
@@ -2609,13 +2560,8 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SubViewOp op) {
- p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
- p.printOperands(op.offsets());
- p << "][";
- p.printOperands(op.sizes());
- p << "][";
- p.printOperands(op.strides());
- p << ']';
+ p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
+ << "][" << op.sizes() << "][" << op.strides() << ']';
SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 28a03222311..a2345fe1c40 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -110,17 +110,16 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
- for (auto attr : op.getAttrs()) {
+ for (auto attr : op.getAttrs())
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
- }
+
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
p << *op.rhs() << ", " << *op.acc();
- if (llvm::size(op.masks()) == 2) {
- p << ", " << **op.masks().begin();
- p << ", " << **(op.masks().begin() + 1);
- }
+ if (op.masks().size() == 2)
+ p << ", " << op.masks();
+
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType();
@@ -417,9 +416,8 @@ static LogicalResult verify(vector::ExtractOp op) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) {
- p << op.getOperationName() << " " << *op.source();
- p << " : " << op.getSourceType();
- p << " to " << op.getVectorType();
+ p << op.getOperationName() << " " << *op.source() << " : "
+ << op.getSourceType() << " to " << op.getVectorType();
}
static LogicalResult verify(BroadcastOp op) {
@@ -560,8 +558,7 @@ static void print(OpAsmPrinter &p, InsertOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
- p << " : " << op.getSourceType();
- p << " into " << op.getDestVectorType();
+ p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
}
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
@@ -789,8 +786,8 @@ static LogicalResult verify(InsertStridedSliceOp op) {
static void print(OpAsmPrinter &p, OuterProductOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
- if (llvm::size(op.acc()) > 0)
- p << ", " << **op.acc().begin();
+ if (!op.acc().empty())
+ p << ", " << op.acc();
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
@@ -1034,16 +1031,10 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
}
static void print(OpAsmPrinter &p, TransferReadOp op) {
- p << op.getOperationName() << " ";
- p.printOperand(op.memref());
- p << "[";
- p.printOperands(op.indices());
- p << "], ";
- p.printOperand(op.padding());
- p << " ";
+ p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
+ << "], " << op.padding() << " ";
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getMemRefType();
- p << ", " << op.getVectorType();
+ p << " : " << op.getMemRefType() << ", " << op.getVectorType();
}
ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
@@ -1106,15 +1097,10 @@ static LogicalResult verify(TransferReadOp op) {
// TransferWriteOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TransferWriteOp op) {
- p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref();
- p << "[";
- p.printOperands(op.indices());
- p << "]";
+ p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
+ << "[" << op.indices() << "]";
p.printOptionalAttrDict(op.getAttrs());
- p << " : ";
- p.printType(op.getVectorType());
- p << ", ";
- p.printType(op.getMemRefType());
+ p << " : " << op.getVectorType() << ", " << op.getMemRefType();
}
ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
@@ -1180,13 +1166,13 @@ void TypeCastOp::build(Builder *builder, OperationState &result,
inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
}
-static void print(OpAsmPrinter &p, TypeCastOp &op) {
+static void print(OpAsmPrinter &p, TypeCastOp op) {
auto type = op.getOperand()->getType().cast<MemRefType>();
p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
<< inferVectorTypeCastResultType(type);
}
-static LogicalResult verify(TypeCastOp &op) {
+static LogicalResult verify(TypeCastOp op) {
auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
if (op.getResultMemRefType() != resultType)
return op.emitOpError("expects result type to be: ") << resultType;
@@ -1208,9 +1194,9 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types));
}
-static void print(OpAsmPrinter &p, ConstantMaskOp &op) {
- p << op.getOperationName() << ' ' << op.mask_dim_sizes();
- p << " : " << op.getResult()->getType();
+static void print(OpAsmPrinter &p, ConstantMaskOp op) {
+ p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
+ << op.getResult()->getType();
}
static LogicalResult verify(ConstantMaskOp &op) {
@@ -1256,13 +1242,11 @@ ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(resultType, result.types));
}
-static void print(OpAsmPrinter &p, CreateMaskOp &op) {
- p << op.getOperationName() << ' ';
- p.printOperands(op.operands());
- p << " : " << op.getResult()->getType();
+static void print(OpAsmPrinter &p, CreateMaskOp op) {
+ p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
}
-static LogicalResult verify(CreateMaskOp &op) {
+static LogicalResult verify(CreateMaskOp op) {
// Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() !=
op.getResult()->getType().cast<VectorType>().getRank())
OpenPOWER on IntegriCloud