summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-20 10:19:01 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-20 10:45:45 -0800
commiteb418559ef29716cc34c891c93490c38ac5ea1dd (patch)
treebe1e4ac3d32e5df31b8668785ba15f92fe1895a9 /mlir/lib
parent3c055957de7e47e53d3ee8f5ab283cdb5c0ea535 (diff)
downloadbcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.tar.gz
bcm5719-llvm-eb418559ef29716cc34c891c93490c38ac5ea1dd.zip
Add a new OpAsmOpInterface to allow for ops to directly hook into the AsmPrinter.
This interface provides more fine-grained hooks into the AsmPrinter than the dialect interface, allowing for operations to define the asm name to use for results directly on the operations themselves. The hook is also expanded to enable defining named result "groups". Get a special name to use when printing the results of this operation. The given callback is invoked with a specific result value that starts a result "pack", and the name to give this result pack. To signal that a result pack should use the default naming scheme, a None can be passed in instead of the name. For example, if you have an operation that has four results and you want to split these into three distinct groups you could do the following: setNameFn(getResult(0), "first_result"); setNameFn(getResult(1), "middle_results"); setNameFn(getResult(3), ""); // use the default numbering. This would print the operation as follows: %first_result, %middle_results:2, %0 = "my.op" ... PiperOrigin-RevId: 281546873
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp58
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp193
-rw-r--r--mlir/lib/IR/CMakeLists.txt2
3 files changed, 167 insertions, 86 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 83c086784c4..21535515a66 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -44,37 +44,6 @@ using namespace mlir;
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
-struct StdOpAsmInterface : public OpAsmDialectInterface {
- using OpAsmDialectInterface::OpAsmDialectInterface;
-
- /// Get a special name to use when printing the given operation. The desired
- /// name should be streamed into 'os'.
- void getOpResultName(Operation *op, raw_ostream &os) const final {
- if (ConstantOp constant = dyn_cast<ConstantOp>(op))
- return getConstantOpResultName(constant, os);
- }
-
- /// Get a special name to use when printing the given constant.
- static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
- Type type = op.getType();
- Attribute value = op.getValue();
- if (auto intCst = value.dyn_cast<IntegerAttr>()) {
- if (type.isIndex()) {
- os << 'c' << intCst.getInt();
- } else if (type.cast<IntegerType>().isInteger(1)) {
- // i1 constants get special names.
- os << (intCst.getInt() ? "true" : "false");
- } else {
- os << 'c' << intCst.getInt() << '_' << type;
- }
- } else if (type.isa<FunctionType>()) {
- os << 'f';
- } else {
- os << "cst";
- }
- }
-};
-
/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
@@ -191,7 +160,7 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
>();
- addInterfaces<StdInlinerInterface, StdOpAsmInterface>();
+ addInterfaces<StdInlinerInterface>();
}
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
@@ -1183,6 +1152,31 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}
+void ConstantOp::getAsmResultNames(
+ function_ref<void(Value *, StringRef)> setNameFn) {
+ Type type = getType();
+ if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
+ IntegerType intTy = type.dyn_cast<IntegerType>();
+
+ // Sugar i1 constants with 'true' and 'false'.
+ if (intTy && intTy.getWidth() == 1)
+ return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
+
+ // Otherwise, build a complex name with the value and type.
+ SmallString<32> specialNameBuffer;
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ specialName << 'c' << intCst.getInt();
+ if (intTy)
+ specialName << '_' << type;
+ setNameFn(getResult(), specialName.str());
+
+ } else if (type.isa<FunctionType>()) {
+ setNameFn(getResult(), "f");
+ } else {
+ setNameFn(getResult(), "cst");
+ }
+}
+
/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8ffe9c51a9c..655a776118c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -58,6 +58,13 @@ DialectAsmPrinter::~DialectAsmPrinter() {}
OpAsmPrinter::~OpAsmPrinter() {}
+//===--------------------------------------------------------------------===//
+// Operation OpAsm interface.
+//===--------------------------------------------------------------------===//
+
+/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
+#include "mlir/IR/OpAsmInterface.cpp.inc"
+
//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
@@ -1490,17 +1497,26 @@ public:
const static unsigned indentWidth = 2;
protected:
- void numberValueID(Value *value);
void numberValuesInRegion(Region &region);
void numberValuesInBlock(Block &block);
+ void numberValuesInOp(Operation &op);
void printValueID(Value *value, bool printResultNo = true) const {
printValueIDImpl(value, printResultNo, os);
}
private:
+ /// Given a result of an operation 'result', find the result group head
+ /// 'lookupValue' and the result of 'result' within that group in
+ /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
+ /// has more than 1 result.
+ void getResultIDAndNumber(OpResult *result, Value *&lookupValue,
+ int &lookupResultNo) const;
void printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const;
+ /// Set a special value name for the given value.
+ void setValueName(Value *value, StringRef name);
+
/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
@@ -1510,6 +1526,11 @@ private:
DenseMap<Value *, unsigned> valueIDs;
DenseMap<Value *, StringRef> valueNames;
+ /// This is a map of operations that contain multiple named result groups,
+ /// i.e. there may be multiple names for the results of the operation. The key
+ /// of this map are the result numbers that start a result group.
+ DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
+
/// This is the block ID for each block in the current.
DenseMap<Block *, unsigned> blockIDs;
@@ -1534,8 +1555,7 @@ private:
OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
: ModulePrinter(other) {
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
- if (op->getNumResults() != 0)
- numberValueID(op->getResult(0));
+ numberValuesInOp(*op);
for (auto &region : op->getRegions())
numberValuesInRegion(region);
@@ -1546,7 +1566,6 @@ OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
numberValuesInRegion(*region);
}
-/// Number all of the SSA values in the specified region.
void OperationPrinter::numberValuesInRegion(Region &region) {
// Save the current value ids to allow for numbering values in sibling regions
// the same.
@@ -1580,59 +1599,72 @@ void OperationPrinter::numberValuesInRegion(Region &region) {
nextConflictID = curConflictID;
}
-/// Number all of the SSA values in the specified block, without traversing
-/// nested regions.
void OperationPrinter::numberValuesInBlock(Block &block) {
- // Number the block arguments.
- for (auto *arg : block.getArguments())
- numberValueID(arg);
+ bool isEntryBlock = block.isEntryBlock();
+
+ // Number the block arguments. We give entry block arguments a special name
+ // 'arg'.
+ SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ for (auto *arg : block.getArguments()) {
+ if (isEntryBlock) {
+ specialNameBuffer.resize(strlen("arg"));
+ specialName << nextArgumentID++;
+ }
+ setValueName(arg, specialName.str());
+ }
- // We number operation that have results, and we only number the first result.
+ // Number the operations in this block.
for (auto &op : block)
- if (op.getNumResults() != 0)
- numberValueID(op.getResult(0));
+ numberValuesInOp(op);
}
-void OperationPrinter::numberValueID(Value *value) {
- assert(!valueIDs.count(value) && "Value numbered multiple times");
-
- SmallString<32> specialNameBuffer;
- llvm::raw_svector_ostream specialName(specialNameBuffer);
+void OperationPrinter::numberValuesInOp(Operation &op) {
+ unsigned numResults = op.getNumResults();
+ if (numResults == 0)
+ return;
+ Value *resultBegin = op.getResult(0);
+
+ // Function used to set the special result names for the operation.
+ SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ auto setResultNameFn = [&](Value *result, StringRef name) {
+ assert(!valueIDs.count(result) && "result numbered multiple times");
+ assert(result->getDefiningOp() == &op && "result not defined by 'op'");
+ setValueName(result, name);
+
+ // Record the result number for groups not anchored at 0.
+ if (int resultNo = cast<OpResult>(result)->getResultNumber())
+ resultGroups.push_back(resultNo);
+ };
- // Check to see if this value requested a special name.
- auto *op = value->getDefiningOp();
- if (state && op) {
- if (auto *interface = state->getOpAsmInterface(op->getDialect()))
- interface->getOpResultName(op, specialName);
+ if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+ asmInterface.getAsmResultNames(setResultNameFn);
+ } else if (auto *dialectAsmInterface =
+ state ? state->getOpAsmInterface(op.getDialect()) : nullptr) {
+ dialectAsmInterface->getAsmResultNames(&op, setResultNameFn);
}
- if (specialNameBuffer.empty()) {
- auto *blockArg = dyn_cast<BlockArgument>(value);
- if (!blockArg) {
- // This is an uninteresting operation result, give it a boring number and
- // be done with it.
- valueIDs[value] = nextValueID++;
- return;
- }
+ // If the first result wasn't numbered, give it a default number.
+ if (valueIDs.try_emplace(resultBegin, nextValueID).second)
+ ++nextValueID;
- // Otherwise, if this is an argument to the entry block of a region, give it
- // an 'arg' name.
- if (auto *block = blockArg->getOwner()) {
- auto *parentRegion = block->getParent();
- if (parentRegion && block == &parentRegion->front())
- specialName << "arg" << nextArgumentID++;
- }
+ // If this operation has multiple result groups, mark it.
+ if (resultGroups.size() != 1) {
+ llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
+ opResultGroups.try_emplace(&op, std::move(resultGroups));
+ }
+}
- // Otherwise number it normally.
- if (specialNameBuffer.empty()) {
- valueIDs[value] = nextValueID++;
- return;
- }
+/// Set a special value name for the given value.
+void OperationPrinter::setValueName(Value *value, StringRef name) {
+ // If the name is empty, the value uses the default numbering.
+ if (name.empty()) {
+ valueIDs[value] = nextValueID++;
+ return;
}
- // Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
- valueNames[value] = uniqueValueName(specialName.str());
+ valueNames[value] = uniqueValueName(name);
}
/// Uniques the given value name within the printer. If the given name
@@ -1722,6 +1754,45 @@ void OperationPrinter::print(Operation *op) {
printTrailingLocation(op->getLoc());
}
+void OperationPrinter::getResultIDAndNumber(OpResult *result,
+ Value *&lookupValue,
+ int &lookupResultNo) const {
+ Operation *owner = result->getOwner();
+ if (owner->getNumResults() == 1)
+ return;
+ int resultNo = result->getResultNumber();
+
+ // If this operation has multiple result groups, we will need to find the
+ // one corresponding to this result.
+ auto resultGroupIt = opResultGroups.find(owner);
+ if (resultGroupIt == opResultGroups.end()) {
+ // If not, just use the first result.
+ lookupResultNo = resultNo;
+ lookupValue = owner->getResult(0);
+ return;
+ }
+
+ // Find the correct index using a binary search, as the groups are ordered.
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ auto it = llvm::upper_bound(resultGroups, resultNo);
+ int groupResultNo = 0, groupSize = 0;
+
+ // If there are no smaller elements, the last result group is the lookup.
+ if (it == resultGroups.end()) {
+ groupResultNo = resultGroups.back();
+ groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
+ } else {
+ // Otherwise, the previous element is the lookup.
+ groupResultNo = *std::prev(it);
+ groupSize = *it - groupResultNo;
+ }
+
+ // We only record the result number for a group of size greater than 1.
+ if (groupSize != 1)
+ lookupResultNo = resultNo - groupResultNo;
+ lookupValue = owner->getResult(groupResultNo);
+}
+
void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const {
if (!value) {
@@ -1735,12 +1806,8 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
// If this is a reference to the result of a multi-result operation or
// operation, print out the # identifier and make sure to map our lookup
// to the first result of the operation.
- if (auto *result = dyn_cast<OpResult>(value)) {
- if (result->getOwner()->getNumResults() != 1) {
- resultNo = result->getResultNumber();
- lookupValue = result->getOwner()->getResult(0);
- }
- }
+ if (OpResult *result = dyn_cast<OpResult>(value))
+ getResultIDAndNumber(result, lookupValue, resultNo);
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
@@ -1798,9 +1865,29 @@ void OperationPrinter::shadowRegionArgs(Region &region,
void OperationPrinter::printOperation(Operation *op) {
if (size_t numResults = op->getNumResults()) {
- printValueID(op->getResult(0), /*printResultNo=*/false);
- if (numResults > 1)
- os << ':' << numResults;
+ auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
+ printValueID(op->getResult(resultNo), /*printResultNo=*/false);
+ if (resultCount > 1)
+ os << ':' << resultCount;
+ };
+
+ // Check to see if this operation has multiple result groups.
+ auto resultGroupIt = opResultGroups.find(op);
+ if (resultGroupIt != opResultGroups.end()) {
+ ArrayRef<int> resultGroups = resultGroupIt->second;
+ // Interleave the groups excluding the last one, this one will be handled
+ // separately.
+ interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
+ printResultGroup(resultGroups[i],
+ resultGroups[i + 1] - resultGroups[i]);
+ });
+ os << ", ";
+ printResultGroup(resultGroups.back(), numResults - resultGroups.back());
+
+ } else {
+ printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
+ }
+
os << " = ";
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 2519b83eede..415d9d66e22 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -5,5 +5,5 @@ add_llvm_library(MLIRIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
-add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIRSupport LLVMSupport)
+add_dependencies(MLIRIR MLIRCallOpInterfacesIncGen MLIROpAsmInterfacesIncGen MLIRSupport LLVMSupport)
target_link_libraries(MLIRIR MLIRSupport LLVMSupport)
OpenPOWER on IntegriCloud