diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/SPIRVOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 116 |
1 files changed, 43 insertions, 73 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 99ab1cdee42..839f134ec8f 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -500,11 +500,8 @@ static ParseResult parseBitFieldExtractOp(OpAsmParser &parser, } static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName() << ' '; - printer.printOperands(op->getOperands()); - printer << " : " << op->getOperand(0)->getType() << ", " - << op->getOperand(1)->getType() << ", " - << op->getOperand(2)->getType(); + printer << op->getName() << ' ' << op->getOperands() << " : " + << op->getOperandTypes(); } static LogicalResult verifyBitFieldExtractOp(Operation *op) { @@ -580,9 +577,8 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, } static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << logicalOp->getName() << ' '; - printer.printOperands(logicalOp->getOperands()); - printer << " : " << logicalOp->getOperand(0)->getType(); + printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : " + << logicalOp->getOperand(0)->getType(); } static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { @@ -717,9 +713,7 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr() - << '['; - printer.printOperands(op.indices()); - printer << "] : " << op.base_ptr()->getType(); + << '[' << op.indices() << "] : " << op.base_ptr()->getType(); } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { @@ -875,9 +869,8 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp, printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" - << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "; - printer.printOperands(atomOp.getOperands()); - printer << " : " << atomOp.pointer()->getType(); + << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " + << atomOp.getOperands() << " : " << atomOp.pointer()->getType(); } static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { @@ -975,9 +968,9 @@ static ParseResult parseBitFieldInsertOp(OpAsmParser &parser, static void print(spirv::BitFieldInsertOp bitFieldInsertOp, OpAsmPrinter &printer) { - printer << spirv::BitFieldInsertOp::getOperationName() << ' '; - printer.printOperands(bitFieldInsertOp.getOperands()); - printer << " : " << bitFieldInsertOp.base()->getType() << ", " + printer << spirv::BitFieldInsertOp::getOperationName() << ' ' + << bitFieldInsertOp.getOperands() << " : " + << bitFieldInsertOp.base()->getType() << ", " << bitFieldInsertOp.offset()->getType() << ", " << bitFieldInsertOp.count()->getType(); } @@ -1072,8 +1065,8 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, } static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { - printer << spirv::BranchConditionalOp::getOperationName() << ' '; - printer.printOperand(branchOp.condition()); + printer << spirv::BranchConditionalOp::getOperationName() << ' ' + << branchOp.condition(); if (auto weights = branchOp.branch_weights()) { printer << " ["; @@ -1148,9 +1141,9 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser, static void print(spirv::CompositeConstructOp compositeConstructOp, OpAsmPrinter &printer) { - printer << spirv::CompositeConstructOp::getOperationName() << " "; - printer.printOperands(compositeConstructOp.constituents()); - printer << " : " << compositeConstructOp.getResult()->getType(); + printer << spirv::CompositeConstructOp::getOperationName() << " " + << compositeConstructOp.constituents() << " : " + << compositeConstructOp.getResult()->getType(); } static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { @@ -1322,9 +1315,8 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) { static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value(); - if (constOp.getType().isa<spirv::ArrayType>()) { + if (constOp.getType().isa<spirv::ArrayType>()) printer << " : " << constOp.getType(); - } } static LogicalResult verify(spirv::ConstantOp constOp) { @@ -1577,9 +1569,8 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { << execModeOp.fn() << " \"" << stringifyExecutionMode(execModeOp.execution_mode()) << "\""; auto values = execModeOp.values(); - if (!values.size()) { + if (!values.size()) return; - } printer << ", "; interleaveComma(values, printer, [&](Attribute a) { printer << a.cast<IntegerAttr>().getInt(); @@ -1626,9 +1617,8 @@ static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) { FunctionType::get(argTypes, resultTypes, functionCallOp.getContext()); printer << spirv::FunctionCallOp::getOperationName() << ' ' - << functionCallOp.getAttr(kCallee) << '('; - printer.printOperands(functionCallOp.arguments()); - printer << ") : " << functionType; + << functionCallOp.getAttr(kCallee) << '(' + << functionCallOp.arguments() << ") : " << functionType; } static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { @@ -1829,9 +1819,8 @@ static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser, static void print(spirv::GroupNonUniformBallotOp ballotOp, OpAsmPrinter &printer) { printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \"" - << stringifyScope(ballotOp.execution_scope()) << "\" "; - printer.printOperand(ballotOp.predicate()); - printer << " : " << ballotOp.getType(); + << stringifyScope(ballotOp.execution_scope()) << "\" " + << ballotOp.predicate() << " : " << ballotOp.getType(); } static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { @@ -1943,9 +1932,8 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) { SmallVector<StringRef, 4> elidedAttrs; StringRef sc = stringifyStorageClass( loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); - printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "; - // Print the pointer operand. - printer.printOperand(loadOp.ptr()); + printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" " + << loadOp.ptr(); printMemoryAccessAttribute(loadOp, printer, elidedAttrs); @@ -2238,26 +2226,26 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { - auto *op = moduleOp.getOperation(); + printer << spirv::ModuleOp::getOperationName(); // Only print out addressing model and memory model in a nicer way if both - // presents. Otherwise, print them in the general form. This helps debugging - // ill-formed ModuleOp. + // presents. Otherwise, print them in the general form. This helps + // debugging ill-formed ModuleOp. SmallVector<StringRef, 2> elidedAttrs; auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); - if (op->getAttr(addressingModelAttrName) && - op->getAttr(memoryModelAttrName)) { - printer << spirv::ModuleOp::getOperationName() << " \"" + if (moduleOp.getAttr(addressingModelAttrName) && + moduleOp.getAttr(memoryModelAttrName)) { + printer << " \"" << spirv::stringifyAddressingModel(moduleOp.addressing_model()) << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model()) << '"'; elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName}); } - printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); - printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs); + printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs); } static LogicalResult verify(spirv::ModuleOp moduleOp) { @@ -2417,9 +2405,8 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser, } static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { - printer << spirv::ReturnValueOp::getOperationName() << ' '; - printer.printOperand(retValOp.value()); - printer << " : " << retValOp.value()->getType(); + printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value() + << " : " << retValOp.value()->getType(); } static LogicalResult verify(spirv::ReturnValueOp retValOp) { @@ -2471,13 +2458,8 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::SelectOp op, OpAsmPrinter &printer) { - printer << spirv::SelectOp::getOperationName() << " "; - - // Print the operands. - printer.printOperands(op.getOperands()); - - // Print colon and types. - printer << " : " << op.condition()->getType() << ", " + printer << spirv::SelectOp::getOperationName() << " " << op.getOperands() + << " : " << op.condition()->getType() << ", " << op.result()->getType(); } @@ -2788,8 +2770,7 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) { printer.printSymbolName(constOp.sym_name()); if (auto specID = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; - printer << " = "; - printer.printAttribute(constOp.default_value()); + printer << " = " << constOp.default_value(); } static LogicalResult verify(spirv::SpecConstantOp constOp) { @@ -2844,17 +2825,12 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) { SmallVector<StringRef, 4> elidedAttrs; StringRef sc = stringifyStorageClass( storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); - printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "; - // Print the pointer operand - printer.printOperand(storeOp.ptr()); - printer << ", "; - // Print the value operand - printer.printOperand(storeOp.value()); + printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" " + << storeOp.ptr() << ", " << storeOp.value(); printMemoryAccessAttribute(storeOp, printer, elidedAttrs); printer << " : " << storeOp.value()->getType(); - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -2885,9 +2861,8 @@ static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, } static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { - printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '; - printer.printOperand(ballotOp.predicate()); - printer << " : " << ballotOp.getType(); + printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ' + << ballotOp.predicate() << " : " << ballotOp.getType(); } //===----------------------------------------------------------------------===// @@ -2973,20 +2948,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { } static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { - auto *op = varOp.getOperation(); SmallVector<StringRef, 4> elidedAttrs{ spirv::attributeName<spirv::StorageClass>()}; printer << spirv::VariableOp::getOperationName(); // Print optional initializer - if (op->getNumOperands() > 0) { - printer << " init("; - printer.printOperands(varOp.initializer()); - printer << ")"; - } - - printVariableDecorations(op, printer, elidedAttrs); + if (varOp.getNumOperands() != 0) + printer << " init(" << varOp.initializer() << ")"; + printVariableDecorations(varOp, printer, elidedAttrs); printer << " : " << varOp.getType(); } |