diff options
author | Lei Zhang <antiagainst@google.com> | 2019-08-09 19:03:58 -0700 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-09 19:04:23 -0700 |
commit | ac68637ba94d05f413b1b963950c300fb2f81a99 (patch) | |
tree | fa09796c3af34c3b170cf2ee087878a0ead7c871 /mlir/lib/TableGen/Pattern.cpp | |
parent | 41968fb4753118afe5a9f4fecf184fac90d96fe6 (diff) | |
download | bcm5719-llvm-ac68637ba94d05f413b1b963950c300fb2f81a99.tar.gz bcm5719-llvm-ac68637ba94d05f413b1b963950c300fb2f81a99.zip |
NFC: Refactoring PatternSymbolResolver into SymbolInfoMap
In declarative rewrite rules, a symbol can be bound to op arguments or
results in the source pattern, and it can be bound to op results in the
result pattern. This means given a symbol in the pattern, it can stands
for different things: op operand, op attribute, single op result,
op result pack. We need a better way to model this complexity so that
we can handle according to the specific kind a symbol corresponds to.
Created SymbolInfo class for maintaining the information regarding a
symbol. Also created a companion SymbolInfoMap class for a map of
such symbols, providing insertion and querying depending on use cases.
PiperOrigin-RevId: 262675515
Diffstat (limited to 'mlir/lib/TableGen/Pattern.cpp')
-rw-r--r-- | mlir/lib/TableGen/Pattern.cpp | 219 |
1 files changed, 184 insertions, 35 deletions
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index fa37d22cc5e..51e4c3b376b 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -31,6 +31,10 @@ using namespace mlir; using llvm::formatv; using mlir::tblgen::Operator; +//===----------------------------------------------------------------------===// +// DagLeaf +//===----------------------------------------------------------------------===// + bool tblgen::DagLeaf::isUnspecified() const { return dyn_cast_or_null<llvm::UnsetInit>(def); } @@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { return false; } +//===----------------------------------------------------------------------===// +// DagNode +//===----------------------------------------------------------------------===// + bool tblgen::DagNode::isNativeCodeCall() const { if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) return defInit->getDef()->isSubClassOf("NativeCodeCall"); @@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const { return dagOpDef->getName() == "replaceWithValue"; } -tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) - : def(*def), recordOpMap(mapper) { - collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true); - for (int i = 0, e = getNumResultPatterns(); i < e; ++i) - collectBoundSymbols(getResultPattern(i), resBoundOps, - /*isSrcPattern=*/false); +//===----------------------------------------------------------------------===// +// SymbolInfoMap +//===----------------------------------------------------------------------===// + +StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, + int *index) { + StringRef name, indexStr; + int idx = -1; + std::tie(name, indexStr) = symbol.rsplit("__"); + + if (indexStr.consumeInteger(10, idx)) { + // The second part is not an index; we return the whole symbol as-is. + return symbol; + } + if (index) { + *index = idx; + } + return name; +} + +tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, + SymbolInfo::Kind kind, + Optional<int> index) + : op(op), kind(kind), argIndex(index) {} + +int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { + switch (kind) { + case Kind::Attr: + case Kind::Operand: + case Kind::Value: + return 1; + case Kind::Result: + return op->getNumResults(); + } +} + +std::string +tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { + switch (kind) { + case Kind::Attr: { + auto type = + op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); + return formatv("{0} {1};\n", type, name); + } + case Kind::Operand: + case Kind::Value: { + return formatv("Value *{0};\n", name); + } + case Kind::Result: { + // Use the op itself for the results. + return formatv("{0} {1};\n", op->getQualCppClassName(), name); + } + } +} + +std::string +tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name, + int index) const { + switch (kind) { + case Kind::Attr: + case Kind::Operand: { + assert(index < 0 && "only allowed for symbol bound to result"); + return name; + } + case Kind::Result: { + // TODO(b/133341698): The following is incorrect for variadic results. We + // should use getODSResults(). + if (index >= 0) { + return formatv("{0}.getOperation()->getResult({1})", name, index); + } + + // If referencing multiple results, compose a comma-separated list. + SmallVector<std::string, 4> values; + for (int i = 0, e = op->getNumResults(); i < e; ++i) { + values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i)); + } + return llvm::join(values, ", "); + } + case Kind::Value: { + assert(index < 0 && "only allowed for symbol bound to result"); + assert(op == nullptr); + return name; + } + } +} + +bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, + int argIndex) { + StringRef name = getValuePackName(symbol); + if (name != symbol) { + auto error = formatv( + "symbol '{0}' with trailing index cannot bind to op argument", symbol); + PrintFatalError(loc, error); + } + + auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() + ? SymbolInfo::getAttr(&op, argIndex) + : SymbolInfo::getOperand(&op, argIndex); + + return symbolInfoMap.insert({symbol, symInfo}).second; +} + +bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { + StringRef name = getValuePackName(symbol); + return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; +} + +bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { + return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; +} + +bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { + return find(symbol) != symbolInfoMap.end(); +} + +tblgen::SymbolInfoMap::const_iterator +tblgen::SymbolInfoMap::find(StringRef key) const { + StringRef name = getValuePackName(key); + return symbolInfoMap.find(name); +} + +int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { + StringRef name = getValuePackName(symbol); + if (name != symbol) { + // If there is a trailing index inside symbol, it references just one + // static value. + return 1; + } + // Otherwise, find how many it represents by querying the symbol's info. + return find(name)->getValue().getStaticValueCount(); } +std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const { + int index = -1; + StringRef name = getValuePackName(symbol, &index); + + auto it = symbolInfoMap.find(name); + if (it == symbolInfoMap.end()) { + auto error = formatv("referencing unbound symbol '{0}'", symbol); + PrintFatalError(loc, error); + } + + return it->getValue().getValueAndRangeUse(name, index); +} + +//===----------------------------------------------------------------------===// +// Pattern +//==----------------------------------------------------------------------===// + +tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) + : def(*def), recordOpMap(mapper) {} + tblgen::DagNode tblgen::Pattern::getSourcePattern() const { return tblgen::DagNode(def.getValueAsDag("sourcePattern")); } @@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); } -void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const { - if (srcBoundArguments.find(name) == srcBoundArguments.end() && - srcBoundOps.find(name) == srcBoundOps.end()) - PrintFatalError(def.getLoc(), - Twine("referencing unbound variable '") + name + "'"); +void tblgen::Pattern::collectSourcePatternBoundSymbols( + tblgen::SymbolInfoMap &infoMap) { + collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); } -llvm::StringMap<tblgen::Argument> & -tblgen::Pattern::getSourcePatternBoundArgs() { - return srcBoundArguments; -} - -llvm::StringMap<const tblgen::Operator *> & -tblgen::Pattern::getSourcePatternBoundOps() { - return srcBoundOps; -} - -llvm::StringMap<const tblgen::Operator *> & -tblgen::Pattern::getResultPatternBoundOps() { - return resBoundOps; +void tblgen::Pattern::collectResultPatternBoundSymbols( + tblgen::SymbolInfoMap &infoMap) { + for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { + auto pattern = getResultPattern(i); + collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); + } } const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { @@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const { return result; } -void tblgen::Pattern::collectBoundSymbols(DagNode tree, - SymbolOperatorMap &symOpMap, +void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern) { auto treeName = tree.getSymbol(); if (!tree.isOperation()) { @@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, auto numTreeArgs = tree.getNumArgs(); if (numOpArgs != numTreeArgs) { - PrintFatalError(def.getLoc(), - formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs)); + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs); + PrintFatalError(def.getLoc(), err); } // The name attached to the DAG node's operator is for representing the // results generated from this op. It should be remembered as bound results. - if (!treeName.empty()) - symOpMap.try_emplace(treeName, &op); + if (!treeName.empty()) { + if (!infoMap.bindOpResult(treeName, op)) + PrintFatalError(def.getLoc(), + formatv("symbol '{0}' bound more than once", treeName)); + } for (int i = 0; i != numTreeArgs; ++i) { if (auto treeArg = tree.getArgAsNestedDag(i)) { // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, symOpMap, isSrcPattern); + collectBoundSymbols(treeArg, infoMap, isSrcPattern); } else if (isSrcPattern) { // We can only bind symbols to op arguments in source pattern. Those // symbols are referenced in result patterns. auto treeArgName = tree.getArgName(i); - if (!treeArgName.empty()) - srcBoundArguments.try_emplace(treeArgName, op.getArg(i)); + if (!treeArgName.empty()) { + if (!infoMap.bindOpArgument(treeArgName, op, i)) { + auto err = formatv("symbol '{0}' bound more than once", treeArgName); + PrintFatalError(def.getLoc(), err); + } + } } } } |