summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h5
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp15
-rw-r--r--mlir/test/IR/pretty-region-args.mlir12
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp14
4 files changed, 46 insertions, 0 deletions
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 97569cc06d9..7dd11d089c2 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -661,6 +661,11 @@ public:
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}
+
+ /// Get a special name to use when printing the entry block arguments of the
+ /// region contained by an operation in this dialect.
+ virtual void getAsmBlockArgumentNames(Block *block,
+ OpAsmSetValueNameFn setNameFn) const {}
};
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e1903d560b1..f3c92ada0a0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1619,13 +1619,28 @@ void OperationPrinter::numberValuesInRegion(Region &region) {
}
void OperationPrinter::numberValuesInBlock(Block &block) {
+ auto setArgNameFn = [&](Value *arg, StringRef name) {
+ assert(!valueIDs.count(arg) && "arg numbered multiple times");
+ assert(cast<BlockArgument>(arg)->getOwner() == &block &&
+ "arg not defined in 'block'");
+ setValueName(arg, name);
+ };
+
bool isEntryBlock = block.isEntryBlock();
+ if (isEntryBlock && state) {
+ if (auto *op = block.getParentOp()) {
+ if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect()))
+ dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
+ }
+ }
// Number the block arguments. We give entry block arguments a special name
// 'arg'.
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
llvm::raw_svector_ostream specialName(specialNameBuffer);
for (auto *arg : block.getArguments()) {
+ if (valueIDs.count(arg))
+ continue;
if (isEntryBlock) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
diff --git a/mlir/test/IR/pretty-region-args.mlir b/mlir/test/IR/pretty-region-args.mlir
new file mode 100644
index 00000000000..59a9ebce092
--- /dev/null
+++ b/mlir/test/IR/pretty-region-args.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: func @custom_region_names
+func @custom_region_names() -> () {
+ "test.polyfor"() ( {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index):
+ "foo"() : () -> ()
+ }) { arg_names = ["i", "j", "k"] } : () -> ()
+ // CHECK: test.polyfor
+ // CHECK-NEXT: ^bb{{.*}}(%i: index, %j: index, %k: index):
+ return
+}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 059cfb3dce7..7462db4544f 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -41,6 +41,20 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
+
+ void getAsmBlockArgumentNames(Block *block,
+ OpAsmSetValueNameFn setNameFn) const final {
+ auto op = block->getParentOp();
+ auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
+ if (!arrayAttr)
+ return;
+ auto args = block->getArguments();
+ auto e = std::min(arrayAttr.size(), args.size());
+ for (unsigned i = 0; i < e; ++i) {
+ if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
+ setNameFn(args[i], strAttr.getValue());
+ }
+ }
};
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
OpenPOWER on IntegriCloud