diff options
-rw-r--r-- | mlir/include/mlir/IR/OpImplementation.h | 5 | ||||
-rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 15 | ||||
-rw-r--r-- | mlir/test/IR/pretty-region-args.mlir | 12 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 14 |
4 files changed, 46 insertions, 0 deletions
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 97569cc06d9..7dd11d089c2 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -661,6 +661,11 @@ public: /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. virtual void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const {} + + /// Get a special name to use when printing the entry block arguments of the + /// region contained by an operation in this dialect. + virtual void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const {} }; //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e1903d560b1..f3c92ada0a0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1619,13 +1619,28 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { } void OperationPrinter::numberValuesInBlock(Block &block) { + auto setArgNameFn = [&](Value *arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(cast<BlockArgument>(arg)->getOwner() == &block && + "arg not defined in 'block'"); + setValueName(arg, name); + }; + bool isEntryBlock = block.isEntryBlock(); + if (isEntryBlock && state) { + if (auto *op = block.getParentOp()) { + if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect())) + dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); + } + } // Number the block arguments. We give entry block arguments a special name // 'arg'. SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); llvm::raw_svector_ostream specialName(specialNameBuffer); for (auto *arg : block.getArguments()) { + if (valueIDs.count(arg)) + continue; if (isEntryBlock) { specialNameBuffer.resize(strlen("arg")); specialName << nextArgumentID++; diff --git a/mlir/test/IR/pretty-region-args.mlir b/mlir/test/IR/pretty-region-args.mlir new file mode 100644 index 00000000000..59a9ebce092 --- /dev/null +++ b/mlir/test/IR/pretty-region-args.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: func @custom_region_names +func @custom_region_names() -> () { + "test.polyfor"() ( { + ^bb0(%arg0: index, %arg1: index, %arg2: index): + "foo"() : () -> () + }) { arg_names = ["i", "j", "k"] } : () -> () + // CHECK: test.polyfor + // CHECK-NEXT: ^bb{{.*}}(%i: index, %j: index, %k: index): + return +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 059cfb3dce7..7462db4544f 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -41,6 +41,20 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) setNameFn(asmOp, "result"); } + + void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const final { + auto op = block->getParentOp(); + auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); + if (!arrayAttr) + return; + auto args = block->getArguments(); + auto e = std::min(arrayAttr.size(), args.size()); + for (unsigned i = 0; i < e; ++i) { + if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>()) + setNameFn(args[i], strAttr.getValue()); + } + } }; struct TestOpFolderDialectInterface : public OpFolderDialectInterface { |