diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-20 10:19:01 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-20 10:45:45 -0800 |
| commit | eb418559ef29716cc34c891c93490c38ac5ea1dd (patch) | |
| tree | be1e4ac3d32e5df31b8668785ba15f92fe1895a9 | |
| parent | 3c055957de7e47e53d3ee8f5ab283cdb5c0ea535 (diff) | |
| download | bcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.tar.gz bcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.zip | |
Add a new OpAsmOpInterface to allow for ops to directly hook into the AsmPrinter.
This interface provides more fine-grained hooks into the AsmPrinter than the dialect interface, allowing for operations to define the asm name to use for results directly on the operations themselves. The hook is also expanded to enable defining named result "groups". Get a special name to use when printing the results of this operation.
The given callback is invoked with a specific result value that starts a
result "pack", and the name to give this result pack. To signal that a
result pack should use the default naming scheme, a None can be passed
in instead of the name.
For example, if you have an operation that has four results and you want
to split these into three distinct groups you could do the following:
setNameFn(getResult(0), "first_result");
setNameFn(getResult(1), "middle_results");
setNameFn(getResult(3), ""); // use the default numbering.
This would print the operation as follows:
%first_result, %middle_results:2, %0 = "my.op" ...
PiperOrigin-RevId: 281546873
| -rw-r--r-- | mlir/include/mlir/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.h | 3 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.td | 4 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/OpAsmInterface.td | 65 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/OpImplementation.h | 18 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 58 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 193 | ||||
| -rw-r--r-- | mlir/lib/IR/CMakeLists.txt | 2 | ||||
| -rw-r--r-- | mlir/test/IR/parser.mlir | 15 | ||||
| -rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 31 | ||||
| -rw-r--r-- | mlir/test/lib/TestDialect/TestOps.td | 14 |
12 files changed, 315 insertions, 93 deletions
diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt index 1a5094df90d..43eacfc91d5 100644 --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Analysis) add_subdirectory(Dialect) add_subdirectory(EDSC) +add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index 64e52ba9fd5..cd4ce2c9f48 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -24,10 +24,9 @@ #define MLIR_DIALECT_STANDARDOPS_OPS_H #include "mlir/Analysis/CallInterfaces.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" // Pull in all enum type definitions and utility function declarations. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index eb7ebbb8f6e..f788de76e84 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -26,6 +26,7 @@ include "mlir/IR/OpBase.td" #endif // OP_BASE +include "mlir/IR/OpAsmInterface.td" include "mlir/Analysis/CallInterfaces.td" def Std_Dialect : Dialect { @@ -580,7 +581,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> { let hasCanonicalizer = 1; } -def ConstantOp : Std_Op<"constant", [NoSideEffect]> { +def ConstantOp : Std_Op<"constant", + [NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> { let summary = "constant"; let arguments = (ins AnyAttr:$value); diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt new file mode 100644 index 00000000000..555b16fd29d --- /dev/null +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td) +mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIROpAsmInterfacesIncGen) diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td new file mode 100644 index 00000000000..974360e72d2 --- /dev/null +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -0,0 +1,65 @@ +//===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains Interfaces for interacting with the AsmParser and +// AsmPrinter. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_OPASMINTERFACE +#define MLIR_OPASMINTERFACE + +#ifndef OP_BASE +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +/// Interface for hooking into the OpAsmPrinter and OpAsmParser. +def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> { + let description = [{ + This interface provides hooks to interact with the AsmPrinter and AsmParser + classes. + }]; + + let methods = [ + InterfaceMethod<[{ + Get a special name to use when printing the results of this operation. + The given callback is invoked with a specific result value that starts a + result "pack", and the name to give this result pack. To signal that a + result pack should use the default naming scheme, a None can be passed + in instead of the name. + + For example, if you have an operation that has four results and you want + to split these into three distinct groups you could do the following: + + ```c++ + setNameFn(getResult(0), "first_result"); + setNameFn(getResult(1), "middle_results"); + setNameFn(getResult(3), ""); // use the default numbering. + ``` + + This would print the operation as follows: + + ```mlir + %first_result, %middle_results:2, %0 = "my.op" ... + ``` + }], + "void", "getAsmResultNames", (ins "OpAsmSetValueNameFn":$setNameFn) + >, + ]; +} + +#endif // MLIR_OPASMINTERFACE diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 4a970c08b44..666a90ec6e1 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -600,6 +600,10 @@ private: // Dialect OpAsm interface. //===--------------------------------------------------------------------===// +/// A functor used to set the name of the start of a result group of an +/// operation. See 'getAsmResultNames' below for more details. +using OpAsmSetValueNameFn = function_ref<void(Value *, StringRef)>; + class OpAsmDialectInterface : public DialectInterface::Base<OpAsmDialectInterface> { public: @@ -621,11 +625,19 @@ public: virtual void getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {} - /// Get a special name to use when printing the given operation. The desired - /// name should be streamed into 'os'. - virtual void getOpResultName(Operation *op, raw_ostream &os) const {} + /// Get a special name to use when printing the given operation. See + /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. + virtual void getAsmResultNames(Operation *op, + OpAsmSetValueNameFn setNameFn) const {} }; +//===--------------------------------------------------------------------===// +// Operation OpAsm interface. +//===--------------------------------------------------------------------===// + +/// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmInterface.h.inc" + } // end namespace mlir #endif diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 83c086784c4..21535515a66 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -44,37 +44,6 @@ using namespace mlir; // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// namespace { -struct StdOpAsmInterface : public OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - - /// Get a special name to use when printing the given operation. The desired - /// name should be streamed into 'os'. - void getOpResultName(Operation *op, raw_ostream &os) const final { - if (ConstantOp constant = dyn_cast<ConstantOp>(op)) - return getConstantOpResultName(constant, os); - } - - /// Get a special name to use when printing the given constant. - static void getConstantOpResultName(ConstantOp op, raw_ostream &os) { - Type type = op.getType(); - Attribute value = op.getValue(); - if (auto intCst = value.dyn_cast<IntegerAttr>()) { - if (type.isIndex()) { - os << 'c' << intCst.getInt(); - } else if (type.cast<IntegerType>().isInteger(1)) { - // i1 constants get special names. - os << (intCst.getInt() ? "true" : "false"); - } else { - os << 'c' << intCst.getInt() << '_' << type; - } - } else if (type.isa<FunctionType>()) { - os << 'f'; - } else { - os << "cst"; - } - } -}; - /// This class defines the interface for handling inlining with standard /// operations. struct StdInlinerInterface : public DialectInlinerInterface { @@ -191,7 +160,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) #define GET_OP_LIST #include "mlir/Dialect/StandardOps/Ops.cpp.inc" >(); - addInterfaces<StdInlinerInterface, StdOpAsmInterface>(); + addInterfaces<StdInlinerInterface>(); } void mlir::printDimAndSymbolList(Operation::operand_iterator begin, @@ -1183,6 +1152,31 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return getValue(); } +void ConstantOp::getAsmResultNames( + function_ref<void(Value *, StringRef)> setNameFn) { + Type type = getType(); + if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { + IntegerType intTy = type.dyn_cast<IntegerType>(); + + // Sugar i1 constants with 'true' and 'false'. + if (intTy && intTy.getWidth() == 1) + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + + // Otherwise, build a complex name with the value and type. + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst.getInt(); + if (intTy) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + + } else if (type.isa<FunctionType>()) { + setNameFn(getResult(), "f"); + } else { + setNameFn(getResult(), "cst"); + } +} + /// Returns true if a constant operation can be built with the given value and /// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8ffe9c51a9c..655a776118c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {} OpAsmPrinter::~OpAsmPrinter() {} +//===--------------------------------------------------------------------===// +// Operation OpAsm interface. +//===--------------------------------------------------------------------===// + +/// The OpAsmOpInterface, see OpAsmInterface.td for more details. +#include "mlir/IR/OpAsmInterface.cpp.inc" + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1490,17 +1497,26 @@ public: const static unsigned indentWidth = 2; protected: - void numberValueID(Value *value); void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); + void numberValuesInOp(Operation &op); void printValueID(Value *value, bool printResultNo = true) const { printValueIDImpl(value, printResultNo, os); } private: + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult *result, Value *&lookupValue, + int &lookupResultNo) const; void printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const; + /// Set a special value name for the given value. + void setValueName(Value *value, StringRef name); + /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); @@ -1510,6 +1526,11 @@ private: DenseMap<Value *, unsigned> valueIDs; DenseMap<Value *, StringRef> valueNames; + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The key + /// of this map are the result numbers that start a result group. + DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; + /// This is the block ID for each block in the current. DenseMap<Block *, unsigned> blockIDs; @@ -1534,8 +1555,7 @@ private: OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) : ModulePrinter(other) { llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); - if (op->getNumResults() != 0) - numberValueID(op->getResult(0)); + numberValuesInOp(*op); for (auto ®ion : op->getRegions()) numberValuesInRegion(region); @@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) numberValuesInRegion(*region); } -/// Number all of the SSA values in the specified region. void OperationPrinter::numberValuesInRegion(Region ®ion) { // Save the current value ids to allow for numbering values in sibling regions // the same. @@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { nextConflictID = curConflictID; } -/// Number all of the SSA values in the specified block, without traversing -/// nested regions. void OperationPrinter::numberValuesInBlock(Block &block) { - // Number the block arguments. - for (auto *arg : block.getArguments()) - numberValueID(arg); + bool isEntryBlock = block.isEntryBlock(); + + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto *arg : block.getArguments()) { + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; + } + setValueName(arg, specialName.str()); + } - // We number operation that have results, and we only number the first result. + // Number the operations in this block. for (auto &op : block) - if (op.getNumResults() != 0) - numberValueID(op.getResult(0)); + numberValuesInOp(op); } -void OperationPrinter::numberValueID(Value *value) { - assert(!valueIDs.count(value) && "Value numbered multiple times"); - - SmallString<32> specialNameBuffer; - llvm::raw_svector_ostream specialName(specialNameBuffer); +void OperationPrinter::numberValuesInOp(Operation &op) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) + return; + Value *resultBegin = op.getResult(0); + + // Function used to set the special result names for the operation. + SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value *result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); + + // Record the result number for groups not anchored at 0. + if (int resultNo = cast<OpResult>(result)->getResultNumber()) + resultGroups.push_back(resultNo); + }; - // Check to see if this value requested a special name. - auto *op = value->getDefiningOp(); - if (state && op) { - if (auto *interface = state->getOpAsmInterface(op->getDialect())) - interface->getOpResultName(op, specialName); + if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { + asmInterface.getAsmResultNames(setResultNameFn); + } else if (auto *dialectAsmInterface = + state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { + dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); } - if (specialNameBuffer.empty()) { - auto *blockArg = dyn_cast<BlockArgument>(value); - if (!blockArg) { - // This is an uninteresting operation result, give it a boring number and - // be done with it. - valueIDs[value] = nextValueID++; - return; - } + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; - // Otherwise, if this is an argument to the entry block of a region, give it - // an 'arg' name. - if (auto *block = blockArg->getOwner()) { - auto *parentRegion = block->getParent(); - if (parentRegion && block == &parentRegion->front()) - specialName << "arg" << nextArgumentID++; - } + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); + } +} - // Otherwise number it normally. - if (specialNameBuffer.empty()) { - valueIDs[value] = nextValueID++; - return; - } +/// Set a special value name for the given value. +void OperationPrinter::setValueName(Value *value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; } - // Ok, this value had an interesting name. Remember it with a sentinel. valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(specialName.str()); + valueNames[value] = uniqueValueName(name); } /// Uniques the given value name within the printer. If the given name @@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } +void OperationPrinter::getResultIDAndNumber(OpResult *result, + Value *&lookupValue, + int &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef<int> resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, raw_ostream &stream) const { if (!value) { @@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, // If this is a reference to the result of a multi-result operation or // operation, print out the # identifier and make sure to map our lookup // to the first result of the operation. - if (auto *result = dyn_cast<OpResult>(value)) { - if (result->getOwner()->getNumResults() != 1) { - resultNo = result->getResultNumber(); - lookupValue = result->getOwner()->getResult(0); - } - } + if (OpResult *result = dyn_cast<OpResult>(value)) + getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { @@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region ®ion, void OperationPrinter::printOperation(Operation *op) { if (size_t numResults = op->getNumResults()) { - printValueID(op->getResult(0), /*printResultNo=*/false); - if (numResults > 1) - os << ':' << numResults; + auto printResultGroup = [&](size_t resultNo, size_t resultCount) { + printValueID(op->getResult(resultNo), /*printResultNo=*/false); + if (resultCount > 1) + os << ':' << resultCount; + }; + + // Check to see if this operation has multiple result groups. + auto resultGroupIt = opResultGroups.find(op); + if (resultGroupIt != opResultGroups.end()) { + ArrayRef<int> resultGroups = resultGroupIt->second; + // Interleave the groups excluding the last one, this one will be handled + // separately. + interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { + printResultGroup(resultGroups[i], + resultGroups[i + 1] - resultGroups[i]); + }); + os << ", "; + printResultGroup(resultGroups.back(), numResults - resultGroups.back()); + + } else { + printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); + } + os << " = "; } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 2519b83eede..415d9d66e22 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -5,5 +5,5 @@ add_llvm_library(MLIRIR ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR ) -add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport) +add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport) target_link_libraries(MLIRIR MLIRSupport LLVMSupport) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index dc85fbb14b3..a6460b46dd9 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1116,3 +1116,18 @@ func @"\"_string_symbol_reference\""() { // CHECK-LABEL: func @nested_reference // CHECK-NEXT: ref = @some_symbol::@some_nested_symbol func @nested_reference() attributes {test.ref = @some_symbol::@some_nested_symbol } + +// CHECK-LABEL: func @custom_asm_names +func @custom_asm_names() -> (i32, i32, i32, i32, i32, i32, i32) { + // CHECK: %[[FIRST:first.*]], %[[MIDDLE:middle_results.*]]:2, %[[LAST:[0-9]+]] + %0, %1:2, %2 = "test.asm_interface_op"() : () -> (i32, i32, i32, i32) + + // CHECK: %[[FIRST_2:first.*]], %[[LAST_2:[0-9]+]] + %3, %4 = "test.asm_interface_op"() : () -> (i32, i32) + + // CHECK: %[[RESULT:result.*]] + %5 = "test.asm_dialect_interface_op"() : () -> (i32) + + // CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]] + return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32 +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 01780432a1a..d838f75f7e7 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -30,6 +30,18 @@ using namespace mlir; //===----------------------------------------------------------------------===// namespace { + +// Test support for interacting with the AsmPrinter. +struct TestOpAsmInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + + void getAsmResultNames(Operation *op, + OpAsmSetValueNameFn setNameFn) const final { + if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) + setNameFn(asmOp, "result"); + } +}; + struct TestOpFolderDialectInterface : public OpFolderDialectInterface { using OpFolderDialectInterface::OpFolderDialectInterface; @@ -112,7 +124,8 @@ TestDialect::TestDialect(MLIRContext *context) #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addInterfaces<TestOpFolderDialectInterface, TestInlinerInterface>(); + addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, + TestInlinerInterface>(); allowUnknownOperations(); } @@ -227,6 +240,7 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) { //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// + static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 4> ivsInfo; // Parse list of region arguments without a delimiter. @@ -241,6 +255,21 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { } //===----------------------------------------------------------------------===// +// Test OpAsmInterface. +//===----------------------------------------------------------------------===// + +void AsmInterfaceOp::getAsmResultNames( + function_ref<void(Value *, StringRef)> setNameFn) { + // Give a name to the first and middle results. + setNameFn(firstResult(), "first"); + if (!llvm::empty(middleResults())) + setNameFn(*middleResults().begin(), "middle_results"); + + // Use default numbering for the last result. + setNameFn(getResult(getNumResults() - 1), ""); +} + +//===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index a0e1cd61ba4..1ccbda2f9b1 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -19,6 +19,7 @@ #define TEST_OPS include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/Analysis/CallInterfaces.td" include "mlir/Analysis/InferTypeOpInterface.td" @@ -995,4 +996,17 @@ def PolyForOp : TEST_Op<"polyfor"> let parser = [{ return ::parse$cppClass(parser, result); }]; } +//===----------------------------------------------------------------------===// +// Test OpAsmInterface. + +def AsmInterfaceOp : TEST_Op<"asm_interface_op", + [DeclareOpInterfaceMethods<OpAsmOpInterface>]> { + let results = (outs AnyType:$firstResult, Variadic<AnyType>:$middleResults, + AnyType); +} + +def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> { + let results = (outs AnyType); +} + #endif // TEST_OPS |

