summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/SPIRVOps.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp116
1 files changed, 43 insertions, 73 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 99ab1cdee42..839f134ec8f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
}
static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
- printer << op->getName() << ' ';
- printer.printOperands(op->getOperands());
- printer << " : " << op->getOperand(0)->getType() << ", "
- << op->getOperand(1)->getType() << ", "
- << op->getOperand(2)->getType();
+ printer << op->getName() << ' ' << op->getOperands() << " : "
+ << op->getOperandTypes();
}
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
@@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
}
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
- printer << logicalOp->getName() << ' ';
- printer.printOperands(logicalOp->getOperands());
- printer << " : " << logicalOp->getOperand(0)->getType();
+ printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
+ << logicalOp->getOperand(0)->getType();
}
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
- << '[';
- printer.printOperands(op.indices());
- printer << "] : " << op.base_ptr()->getType();
+ << '[' << op.indices() << "] : " << op.base_ptr()->getType();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
@@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
<< stringifyScope(atomOp.memory_scope()) << "\" \""
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
- << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" ";
- printer.printOperands(atomOp.getOperands());
- printer << " : " << atomOp.pointer()->getType();
+ << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
+ << atomOp.getOperands() << " : " << atomOp.pointer()->getType();
}
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) {
- printer << spirv::BitFieldInsertOp::getOperationName() << ' ';
- printer.printOperands(bitFieldInsertOp.getOperands());
- printer << " : " << bitFieldInsertOp.base()->getType() << ", "
+ printer << spirv::BitFieldInsertOp::getOperationName() << ' '
+ << bitFieldInsertOp.getOperands() << " : "
+ << bitFieldInsertOp.base()->getType() << ", "
<< bitFieldInsertOp.offset()->getType() << ", "
<< bitFieldInsertOp.count()->getType();
}
@@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
}
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
- printer << spirv::BranchConditionalOp::getOperationName() << ' ';
- printer.printOperand(branchOp.condition());
+ printer << spirv::BranchConditionalOp::getOperationName() << ' '
+ << branchOp.condition();
if (auto weights = branchOp.branch_weights()) {
printer << " [";
@@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
static void print(spirv::CompositeConstructOp compositeConstructOp,
OpAsmPrinter &printer) {
- printer << spirv::CompositeConstructOp::getOperationName() << " ";
- printer.printOperands(compositeConstructOp.constituents());
- printer << " : " << compositeConstructOp.getResult()->getType();
+ printer << spirv::CompositeConstructOp::getOperationName() << " "
+ << compositeConstructOp.constituents() << " : "
+ << compositeConstructOp.getResult()->getType();
}
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
- if (constOp.getType().isa<spirv::ArrayType>()) {
+ if (constOp.getType().isa<spirv::ArrayType>())
printer << " : " << constOp.getType();
- }
}
static LogicalResult verify(spirv::ConstantOp constOp) {
@@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
<< execModeOp.fn() << " \""
<< stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
auto values = execModeOp.values();
- if (!values.size()) {
+ if (!values.size())
return;
- }
printer << ", ";
interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
@@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
printer << spirv::FunctionCallOp::getOperationName() << ' '
- << functionCallOp.getAttr(kCallee) << '(';
- printer.printOperands(functionCallOp.arguments());
- printer << ") : " << functionType;
+ << functionCallOp.getAttr(kCallee) << '('
+ << functionCallOp.arguments() << ") : " << functionType;
}
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
@@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
static void print(spirv::GroupNonUniformBallotOp ballotOp,
OpAsmPrinter &printer) {
printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
- << stringifyScope(ballotOp.execution_scope()) << "\" ";
- printer.printOperand(ballotOp.predicate());
- printer << " : " << ballotOp.getType();
+ << stringifyScope(ballotOp.execution_scope()) << "\" "
+ << ballotOp.predicate() << " : " << ballotOp.getType();
}
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
@@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
- // Print the pointer operand.
- printer.printOperand(loadOp.ptr());
+ printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
+ << loadOp.ptr();
printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
@@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
- auto *op = moduleOp.getOperation();
+ printer << spirv::ModuleOp::getOperationName();
// Only print out addressing model and memory model in a nicer way if both
- // presents. Otherwise, print them in the general form. This helps debugging
- // ill-formed ModuleOp.
+ // presents. Otherwise, print them in the general form. This helps
+ // debugging ill-formed ModuleOp.
SmallVector<StringRef, 2> elidedAttrs;
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
- if (op->getAttr(addressingModelAttrName) &&
- op->getAttr(memoryModelAttrName)) {
- printer << spirv::ModuleOp::getOperationName() << " \""
+ if (moduleOp.getAttr(addressingModelAttrName) &&
+ moduleOp.getAttr(memoryModelAttrName)) {
+ printer << " \""
<< spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
<< '"';
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
}
- printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+ printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
- printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
+ printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
}
static LogicalResult verify(spirv::ModuleOp moduleOp) {
@@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
}
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
- printer << spirv::ReturnValueOp::getOperationName() << ' ';
- printer.printOperand(retValOp.value());
- printer << " : " << retValOp.value()->getType();
+ printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
+ << " : " << retValOp.value()->getType();
}
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
- printer << spirv::SelectOp::getOperationName() << " ";
-
- // Print the operands.
- printer.printOperands(op.getOperands());
-
- // Print colon and types.
- printer << " : " << op.condition()->getType() << ", "
+ printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
+ << " : " << op.condition()->getType() << ", "
<< op.result()->getType();
}
@@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
printer.printSymbolName(constOp.sym_name());
if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName))
printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
- printer << " = ";
- printer.printAttribute(constOp.default_value());
+ printer << " = " << constOp.default_value();
}
static LogicalResult verify(spirv::SpecConstantOp constOp) {
@@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
- // Print the pointer operand
- printer.printOperand(storeOp.ptr());
- printer << ", ";
- // Print the value operand
- printer.printOperand(storeOp.value());
+ printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
+ << storeOp.ptr() << ", " << storeOp.value();
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
printer << " : " << storeOp.value()->getType();
-
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
@@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
}
static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
- printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
- printer.printOperand(ballotOp.predicate());
- printer << " : " << ballotOp.getType();
+ printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '
+ << ballotOp.predicate() << " : " << ballotOp.getType();
}
//===----------------------------------------------------------------------===//
@@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
}
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
- auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{
spirv::attributeName<spirv::StorageClass>()};
printer << spirv::VariableOp::getOperationName();
// Print optional initializer
- if (op->getNumOperands() > 0) {
- printer << " init(";
- printer.printOperands(varOp.initializer());
- printer << ")";
- }
-
- printVariableDecorations(op, printer, elidedAttrs);
+ if (varOp.getNumOperands() != 0)
+ printer << " init(" << varOp.initializer() << ")";
+ printVariableDecorations(varOp, printer, elidedAttrs);
printer << " : " << varOp.getType();
}
OpenPOWER on IntegriCloud