summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.h3
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.td4
-rw-r--r--mlir/include/mlir/IR/CMakeLists.txt4
-rw-r--r--mlir/include/mlir/IR/OpAsmInterface.td65
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h18
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp58
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp193
-rw-r--r--mlir/lib/IR/CMakeLists.txt2
-rw-r--r--mlir/test/IR/parser.mlir15
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp31
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td14
12 files changed, 315 insertions, 93 deletions
diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt
index 1a5094df90d..43eacfc91d5 100644
--- a/mlir/include/mlir/CMakeLists.txt
+++ b/mlir/include/mlir/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(EDSC)
+add_subdirectory(IR)
add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h
index 64e52ba9fd5..cd4ce2c9f48 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h
@@ -24,10 +24,9 @@
#define MLIR_DIALECT_STANDARDOPS_OPS_H
#include "mlir/Analysis/CallInterfaces.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
// Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index eb7ebbb8f6e..f788de76e84 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -26,6 +26,7 @@
include "mlir/IR/OpBase.td"
#endif // OP_BASE
+include "mlir/IR/OpAsmInterface.td"
include "mlir/Analysis/CallInterfaces.td"
def Std_Dialect : Dialect {
@@ -580,7 +581,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
let hasCanonicalizer = 1;
}
-def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
+def ConstantOp : Std_Op<"constant",
+ [NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "constant";
let arguments = (ins AnyAttr:$value);
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
new file mode 100644
index 00000000000..555b16fd29d
--- /dev/null
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
+mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIROpAsmInterfacesIncGen)
diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
new file mode 100644
index 00000000000..974360e72d2
--- /dev/null
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -0,0 +1,65 @@
+//===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains Interfaces for interacting with the AsmParser and
+// AsmPrinter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_OPASMINTERFACE
+#define MLIR_OPASMINTERFACE
+
+#ifndef OP_BASE
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+/// Interface for hooking into the OpAsmPrinter and OpAsmParser.
+def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
+ let description = [{
+ This interface provides hooks to interact with the AsmPrinter and AsmParser
+ classes.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Get a special name to use when printing the results of this operation.
+ The given callback is invoked with a specific result value that starts a
+ result "pack", and the name to give this result pack. To signal that a
+ result pack should use the default naming scheme, a None can be passed
+ in instead of the name.
+
+ For example, if you have an operation that has four results and you want
+ to split these into three distinct groups you could do the following:
+
+ ```c++
+ setNameFn(getResult(0), "first_result");
+ setNameFn(getResult(1), "middle_results");
+ setNameFn(getResult(3), ""); // use the default numbering.
+ ```
+
+ This would print the operation as follows:
+
+ ```mlir
+ %first_result, %middle_results:2, %0 = "my.op" ...
+ ```
+ }],
+ "void", "getAsmResultNames", (ins "OpAsmSetValueNameFn":$setNameFn)
+ >,
+ ];
+}
+
+#endif // MLIR_OPASMINTERFACE
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 4a970c08b44..666a90ec6e1 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -600,6 +600,10 @@ private:
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
+/// A functor used to set the name of the start of a result group of an
+/// operation. See 'getAsmResultNames' below for more details.
+using OpAsmSetValueNameFn = function_ref<void(Value *, StringRef)>;
+
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
@@ -621,11 +625,19 @@ public:
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
- /// Get a special name to use when printing the given operation. The desired
- /// name should be streamed into 'os'.
- virtual void getOpResultName(Operation *op, raw_ostream &os) const {}
+ /// Get a special name to use when printing the given operation. See
+ /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
+ virtual void getAsmResultNames(Operation *op,
+ OpAsmSetValueNameFn setNameFn) const {}
};
+//===--------------------------------------------------------------------===//
+// Operation OpAsm interface.
+//===--------------------------------------------------------------------===//
+
+/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmInterface.h.inc"
+
} // end namespace mlir
#endif
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 83c086784c4..21535515a66 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -44,37 +44,6 @@ using namespace mlir;
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
-struct StdOpAsmInterface : public OpAsmDialectInterface {
- using OpAsmDialectInterface::OpAsmDialectInterface;
-
- /// Get a special name to use when printing the given operation. The desired
- /// name should be streamed into 'os'.
- void getOpResultName(Operation *op, raw_ostream &os) const final {
- if (ConstantOp constant = dyn_cast<ConstantOp>(op))
- return getConstantOpResultName(constant, os);
- }
-
- /// Get a special name to use when printing the given constant.
- static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
- Type type = op.getType();
- Attribute value = op.getValue();
- if (auto intCst = value.dyn_cast<IntegerAttr>()) {
- if (type.isIndex()) {
- os << 'c' << intCst.getInt();
- } else if (type.cast<IntegerType>().isInteger(1)) {
- // i1 constants get special names.
- os << (intCst.getInt() ? "true" : "false");
- } else {
- os << 'c' << intCst.getInt() << '_' << type;
- }
- } else if (type.isa<FunctionType>()) {
- os << 'f';
- } else {
- os << "cst";
- }
- }
-};
-
/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
@@ -191,7 +160,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
>();
- addInterfaces<StdInlinerInterface, StdOpAsmInterface>();
+ addInterfaces<StdInlinerInterface>();
}
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
@@ -1183,6 +1152,31 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}
+void ConstantOp::getAsmResultNames(
+ function_ref<void(Value *, StringRef)> setNameFn) {
+ Type type = getType();
+ if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
+ IntegerType intTy = type.dyn_cast<IntegerType>();
+
+ // Sugar i1 constants with 'true' and 'false'.
+ if (intTy && intTy.getWidth() == 1)
+ return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
+
+ // Otherwise, build a complex name with the value and type.
+ SmallString<32> specialNameBuffer;
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ specialName << 'c' << intCst.getInt();
+ if (intTy)
+ specialName << '_' << type;
+ setNameFn(getResult(), specialName.str());
+
+ } else if (type.isa<FunctionType>()) {
+ setNameFn(getResult(), "f");
+ } else {
+ setNameFn(getResult(), "cst");
+ }
+}
+
/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8ffe9c51a9c..655a776118c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {}
OpAsmPrinter::~OpAsmPrinter() {}
+//===--------------------------------------------------------------------===//
+// Operation OpAsm interface.
+//===--------------------------------------------------------------------===//
+
+/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmInterface.cpp.inc"
+
//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
@@ -1490,17 +1497,26 @@ public:
const static unsigned indentWidth = 2;
protected:
- void numberValueID(Value *value);
void numberValuesInRegion(Region &region);
void numberValuesInBlock(Block &block);
+ void numberValuesInOp(Operation &op);
void printValueID(Value *value, bool printResultNo = true) const {
printValueIDImpl(value, printResultNo, os);
}
private:
+ /// Given a result of an operation 'result', find the result group head
+ /// 'lookupValue' and the result of 'result' within that group in
+ /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
+ /// has more than 1 result.
+ void getResultIDAndNumber(OpResult *result, Value *&lookupValue,
+ int &lookupResultNo) const;
void printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const;
+ /// Set a special value name for the given value.
+ void setValueName(Value *value, StringRef name);
+
/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
@@ -1510,6 +1526,11 @@ private:
DenseMap<Value *, unsigned> valueIDs;
DenseMap<Value *, StringRef> valueNames;
+ /// This is a map of operations that contain multiple named result groups,
+ /// i.e. there may be multiple names for the results of the operation. The key
+ /// of this map are the result numbers that start a result group.
+ DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
+
/// This is the block ID for each block in the current.
DenseMap<Block *, unsigned> blockIDs;
@@ -1534,8 +1555,7 @@ private:
OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
: ModulePrinter(other) {
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
- if (op->getNumResults() != 0)
- numberValueID(op->getResult(0));
+ numberValuesInOp(*op);
for (auto &region : op->getRegions())
numberValuesInRegion(region);
@@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
numberValuesInRegion(*region);
}
-/// Number all of the SSA values in the specified region.
void OperationPrinter::numberValuesInRegion(Region &region) {
// Save the current value ids to allow for numbering values in sibling regions
// the same.
@@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region &region) {
nextConflictID = curConflictID;
}
-/// Number all of the SSA values in the specified block, without traversing
-/// nested regions.
void OperationPrinter::numberValuesInBlock(Block &block) {
- // Number the block arguments.
- for (auto *arg : block.getArguments())
- numberValueID(arg);
+ bool isEntryBlock = block.isEntryBlock();
+
+ // Number the block arguments. We give entry block arguments a special name
+ // 'arg'.
+ SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ for (auto *arg : block.getArguments()) {
+ if (isEntryBlock) {
+ specialNameBuffer.resize(strlen("arg"));
+ specialName << nextArgumentID++;
+ }
+ setValueName(arg, specialName.str());
+ }
- // We number operation that have results, and we only number the first result.
+ // Number the operations in this block.
for (auto &op : block)
- if (op.getNumResults() != 0)
- numberValueID(op.getResult(0));
+ numberValuesInOp(op);
}
-void OperationPrinter::numberValueID(Value *value) {
- assert(!valueIDs.count(value) && "Value numbered multiple times");
-
- SmallString<32> specialNameBuffer;
- llvm::raw_svector_ostream specialName(specialNameBuffer);
+void OperationPrinter::numberValuesInOp(Operation &op) {
+ unsigned numResults = op.getNumResults();
+ if (numResults == 0)
+ return;
+ Value *resultBegin = op.getResult(0);
+
+ // Function used to set the special result names for the operation.
+ SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ auto setResultNameFn = [&](Value *result, StringRef name) {
+ assert(!valueIDs.count(result) && "result numbered multiple times");
+ assert(result->getDefiningOp() == &op && "result not defined by 'op'");
+ setValueName(result, name);
+
+ // Record the result number for groups not anchored at 0.
+ if (int resultNo = cast<OpResult>(result)->getResultNumber())
+ resultGroups.push_back(resultNo);
+ };
- // Check to see if this value requested a special name.
- auto *op = value->getDefiningOp();
- if (state && op) {
- if (auto *interface = state->getOpAsmInterface(op->getDialect()))
- interface->getOpResultName(op, specialName);
+ if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+ asmInterface.getAsmResultNames(setResultNameFn);
+ } else if (auto *dialectAsmInterface =
+ state ? state->getOpAsmInterface(op.getDialect()) : nullptr) {
+ dialectAsmInterface->getAsmResultNames(&op, setResultNameFn);
}
- if (specialNameBuffer.empty()) {
- auto *blockArg = dyn_cast<BlockArgument>(value);
- if (!blockArg) {
- // This is an uninteresting operation result, give it a boring number and
- // be done with it.
- valueIDs[value] = nextValueID++;
- return;
- }
+ // If the first result wasn't numbered, give it a default number.
+ if (valueIDs.try_emplace(resultBegin, nextValueID).second)
+ ++nextValueID;
- // Otherwise, if this is an argument to the entry block of a region, give it
- // an 'arg' name.
- if (auto *block = blockArg->getOwner()) {
- auto *parentRegion = block->getParent();
- if (parentRegion && block == &parentRegion->front())
- specialName << "arg" << nextArgumentID++;
- }
+ // If this operation has multiple result groups, mark it.
+ if (resultGroups.size() != 1) {
+ llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
+ opResultGroups.try_emplace(&op, std::move(resultGroups));
+ }
+}
- // Otherwise number it normally.
- if (specialNameBuffer.empty()) {
- valueIDs[value] = nextValueID++;
- return;
- }
+/// Set a special value name for the given value.
+void OperationPrinter::setValueName(Value *value, StringRef name) {
+ // If the name is empty, the value uses the default numbering.
+ if (name.empty()) {
+ valueIDs[value] = nextValueID++;
+ return;
}
- // Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
- valueNames[value] = uniqueValueName(specialName.str());
+ valueNames[value] = uniqueValueName(name);
}
/// Uniques the given value name within the printer. If the given name
@@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) {
printTrailingLocation(op->getLoc());
}
+void OperationPrinter::getResultIDAndNumber(OpResult *result,
+ Value *&lookupValue,
+ int &lookupResultNo) const {
+ Operation *owner = result->getOwner();
+ if (owner->getNumResults() == 1)
+ return;
+ int resultNo = result->getResultNumber();
+
+ // If this operation has multiple result groups, we will need to find the
+ // one corresponding to this result.
+ auto resultGroupIt = opResultGroups.find(owner);
+ if (resultGroupIt == opResultGroups.end()) {
+ // If not, just use the first result.
+ lookupResultNo = resultNo;
+ lookupValue = owner->getResult(0);
+ return;
+ }
+
+ // Find the correct index using a binary search, as the groups are ordered.
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ auto it = llvm::upper_bound(resultGroups, resultNo);
+ int groupResultNo = 0, groupSize = 0;
+
+ // If there are no smaller elements, the last result group is the lookup.
+ if (it == resultGroups.end()) {
+ groupResultNo = resultGroups.back();
+ groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
+ } else {
+ // Otherwise, the previous element is the lookup.
+ groupResultNo = *std::prev(it);
+ groupSize = *it - groupResultNo;
+ }
+
+ // We only record the result number for a group of size greater than 1.
+ if (groupSize != 1)
+ lookupResultNo = resultNo - groupResultNo;
+ lookupValue = owner->getResult(groupResultNo);
+}
+
void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const {
if (!value) {
@@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
// If this is a reference to the result of a multi-result operation or
// operation, print out the # identifier and make sure to map our lookup
// to the first result of the operation.
- if (auto *result = dyn_cast<OpResult>(value)) {
- if (result->getOwner()->getNumResults() != 1) {
- resultNo = result->getResultNumber();
- lookupValue = result->getOwner()->getResult(0);
- }
- }
+ if (OpResult *result = dyn_cast<OpResult>(value))
+ getResultIDAndNumber(result, lookupValue, resultNo);
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
@@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region &region,
void OperationPrinter::printOperation(Operation *op) {
if (size_t numResults = op->getNumResults()) {
- printValueID(op->getResult(0), /*printResultNo=*/false);
- if (numResults > 1)
- os << ':' << numResults;
+ auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
+ printValueID(op->getResult(resultNo), /*printResultNo=*/false);
+ if (resultCount > 1)
+ os << ':' << resultCount;
+ };
+
+ // Check to see if this operation has multiple result groups.
+ auto resultGroupIt = opResultGroups.find(op);
+ if (resultGroupIt != opResultGroups.end()) {
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ // Interleave the groups excluding the last one, this one will be handled
+ // separately.
+ interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
+ printResultGroup(resultGroups[i],
+ resultGroups[i + 1] - resultGroups[i]);
+ });
+ os << ", ";
+ printResultGroup(resultGroups.back(), numResults - resultGroups.back());
+
+ } else {
+ printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
+ }
+
os << " = ";
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 2519b83eede..415d9d66e22 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -5,5 +5,5 @@ add_llvm_library(MLIRIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
-add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport)
+add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport)
target_link_libraries(MLIRIR MLIRSupport LLVMSupport)
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index dc85fbb14b3..a6460b46dd9 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1116,3 +1116,18 @@ func @"\"_string_symbol_reference\""() {
// CHECK-LABEL: func @nested_reference
// CHECK-NEXT: ref = @some_symbol::@some_nested_symbol
func @nested_reference() attributes {test.ref = @some_symbol::@some_nested_symbol }
+
+// CHECK-LABEL: func @custom_asm_names
+func @custom_asm_names() -> (i32, i32, i32, i32, i32, i32, i32) {
+ // CHECK: %[[FIRST:first.*]], %[[MIDDLE:middle_results.*]]:2, %[[LAST:[0-9]+]]
+ %0, %1:2, %2 = "test.asm_interface_op"() : () -> (i32, i32, i32, i32)
+
+ // CHECK: %[[FIRST_2:first.*]], %[[LAST_2:[0-9]+]]
+ %3, %4 = "test.asm_interface_op"() : () -> (i32, i32)
+
+ // CHECK: %[[RESULT:result.*]]
+ %5 = "test.asm_dialect_interface_op"() : () -> (i32)
+
+ // CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]]
+ return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32
+}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 01780432a1a..d838f75f7e7 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -30,6 +30,18 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
namespace {
+
+// Test support for interacting with the AsmPrinter.
+struct TestOpAsmInterface : public OpAsmDialectInterface {
+ using OpAsmDialectInterface::OpAsmDialectInterface;
+
+ void getAsmResultNames(Operation *op,
+ OpAsmSetValueNameFn setNameFn) const final {
+ if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
+ setNameFn(asmOp, "result");
+ }
+};
+
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
@@ -112,7 +124,8 @@ TestDialect::TestDialect(MLIRContext *context)
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
- addInterfaces<TestOpFolderDialectInterface, TestInlinerInterface>();
+ addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
+ TestInlinerInterface>();
allowUnknownOperations();
}
@@ -227,6 +240,7 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) {
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
+
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
// Parse list of region arguments without a delimiter.
@@ -241,6 +255,21 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
}
//===----------------------------------------------------------------------===//
+// Test OpAsmInterface.
+//===----------------------------------------------------------------------===//
+
+void AsmInterfaceOp::getAsmResultNames(
+ function_ref<void(Value *, StringRef)> setNameFn) {
+ // Give a name to the first and middle results.
+ setNameFn(firstResult(), "first");
+ if (!llvm::empty(middleResults()))
+ setNameFn(*middleResults().begin(), "middle_results");
+
+ // Use default numbering for the last result.
+ setNameFn(getResult(getNumResults() - 1), "");
+}
+
+//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index a0e1cd61ba4..1ccbda2f9b1 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -19,6 +19,7 @@
#define TEST_OPS
include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/InferTypeOpInterface.td"
@@ -995,4 +996,17 @@ def PolyForOp : TEST_Op<"polyfor">
let parser = [{ return ::parse$cppClass(parser, result); }];
}
+//===----------------------------------------------------------------------===//
+// Test OpAsmInterface.
+
+def AsmInterfaceOp : TEST_Op<"asm_interface_op",
+ [DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+ let results = (outs AnyType:$firstResult, Variadic<AnyType>:$middleResults,
+ AnyType);
+}
+
+def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> {
+ let results = (outs AnyType);
+}
+
#endif // TEST_OPS
OpenPOWER on IntegriCloud