summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-08-09 19:03:58 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-09 19:04:23 -0700
commitac68637ba94d05f413b1b963950c300fb2f81a99 (patch)
treefa09796c3af34c3b170cf2ee087878a0ead7c871 /mlir/lib
parent41968fb4753118afe5a9f4fecf184fac90d96fe6 (diff)
downloadbcm5719-llvm-ac68637ba94d05f413b1b963950c300fb2f81a99.tar.gz
bcm5719-llvm-ac68637ba94d05f413b1b963950c300fb2f81a99.zip
NFC: Refactoring PatternSymbolResolver into SymbolInfoMap
In declarative rewrite rules, a symbol can be bound to op arguments or results in the source pattern, and it can be bound to op results in the result pattern. This means given a symbol in the pattern, it can stands for different things: op operand, op attribute, single op result, op result pack. We need a better way to model this complexity so that we can handle according to the specific kind a symbol corresponds to. Created SymbolInfo class for maintaining the information regarding a symbol. Also created a companion SymbolInfoMap class for a map of such symbols, providing insertion and querying depending on use cases. PiperOrigin-RevId: 262675515
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/TableGen/Pattern.cpp219
1 files changed, 184 insertions, 35 deletions
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index fa37d22cc5e..51e4c3b376b 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -31,6 +31,10 @@ using namespace mlir;
using llvm::formatv;
using mlir::tblgen::Operator;
+//===----------------------------------------------------------------------===//
+// DagLeaf
+//===----------------------------------------------------------------------===//
+
bool tblgen::DagLeaf::isUnspecified() const {
return dyn_cast_or_null<llvm::UnsetInit>(def);
}
@@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
return false;
}
+//===----------------------------------------------------------------------===//
+// DagNode
+//===----------------------------------------------------------------------===//
+
bool tblgen::DagNode::isNativeCodeCall() const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
return defInit->getDef()->isSubClassOf("NativeCodeCall");
@@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
-tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
- : def(*def), recordOpMap(mapper) {
- collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
- for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
- collectBoundSymbols(getResultPattern(i), resBoundOps,
- /*isSrcPattern=*/false);
+//===----------------------------------------------------------------------===//
+// SymbolInfoMap
+//===----------------------------------------------------------------------===//
+
+StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
+ int *index) {
+ StringRef name, indexStr;
+ int idx = -1;
+ std::tie(name, indexStr) = symbol.rsplit("__");
+
+ if (indexStr.consumeInteger(10, idx)) {
+ // The second part is not an index; we return the whole symbol as-is.
+ return symbol;
+ }
+ if (index) {
+ *index = idx;
+ }
+ return name;
+}
+
+tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
+ SymbolInfo::Kind kind,
+ Optional<int> index)
+ : op(op), kind(kind), argIndex(index) {}
+
+int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
+ switch (kind) {
+ case Kind::Attr:
+ case Kind::Operand:
+ case Kind::Value:
+ return 1;
+ case Kind::Result:
+ return op->getNumResults();
+ }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+ switch (kind) {
+ case Kind::Attr: {
+ auto type =
+ op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+ return formatv("{0} {1};\n", type, name);
+ }
+ case Kind::Operand:
+ case Kind::Value: {
+ return formatv("Value *{0};\n", name);
+ }
+ case Kind::Result: {
+ // Use the op itself for the results.
+ return formatv("{0} {1};\n", op->getQualCppClassName(), name);
+ }
+ }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
+ int index) const {
+ switch (kind) {
+ case Kind::Attr:
+ case Kind::Operand: {
+ assert(index < 0 && "only allowed for symbol bound to result");
+ return name;
+ }
+ case Kind::Result: {
+ // TODO(b/133341698): The following is incorrect for variadic results. We
+ // should use getODSResults().
+ if (index >= 0) {
+ return formatv("{0}.getOperation()->getResult({1})", name, index);
+ }
+
+ // If referencing multiple results, compose a comma-separated list.
+ SmallVector<std::string, 4> values;
+ for (int i = 0, e = op->getNumResults(); i < e; ++i) {
+ values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
+ }
+ return llvm::join(values, ", ");
+ }
+ case Kind::Value: {
+ assert(index < 0 && "only allowed for symbol bound to result");
+ assert(op == nullptr);
+ return name;
+ }
+ }
+}
+
+bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
+ int argIndex) {
+ StringRef name = getValuePackName(symbol);
+ if (name != symbol) {
+ auto error = formatv(
+ "symbol '{0}' with trailing index cannot bind to op argument", symbol);
+ PrintFatalError(loc, error);
+ }
+
+ auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
+ ? SymbolInfo::getAttr(&op, argIndex)
+ : SymbolInfo::getOperand(&op, argIndex);
+
+ return symbolInfoMap.insert({symbol, symInfo}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
+ StringRef name = getValuePackName(symbol);
+ return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
+ return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+}
+
+bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
+ return find(symbol) != symbolInfoMap.end();
+}
+
+tblgen::SymbolInfoMap::const_iterator
+tblgen::SymbolInfoMap::find(StringRef key) const {
+ StringRef name = getValuePackName(key);
+ return symbolInfoMap.find(name);
+}
+
+int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
+ StringRef name = getValuePackName(symbol);
+ if (name != symbol) {
+ // If there is a trailing index inside symbol, it references just one
+ // static value.
+ return 1;
+ }
+ // Otherwise, find how many it represents by querying the symbol's info.
+ return find(name)->getValue().getStaticValueCount();
}
+std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
+ int index = -1;
+ StringRef name = getValuePackName(symbol, &index);
+
+ auto it = symbolInfoMap.find(name);
+ if (it == symbolInfoMap.end()) {
+ auto error = formatv("referencing unbound symbol '{0}'", symbol);
+ PrintFatalError(loc, error);
+ }
+
+ return it->getValue().getValueAndRangeUse(name, index);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern
+//==----------------------------------------------------------------------===//
+
+tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
+ : def(*def), recordOpMap(mapper) {}
+
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
}
@@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
-void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
- if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
- srcBoundOps.find(name) == srcBoundOps.end())
- PrintFatalError(def.getLoc(),
- Twine("referencing unbound variable '") + name + "'");
+void tblgen::Pattern::collectSourcePatternBoundSymbols(
+ tblgen::SymbolInfoMap &infoMap) {
+ collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
}
-llvm::StringMap<tblgen::Argument> &
-tblgen::Pattern::getSourcePatternBoundArgs() {
- return srcBoundArguments;
-}
-
-llvm::StringMap<const tblgen::Operator *> &
-tblgen::Pattern::getSourcePatternBoundOps() {
- return srcBoundOps;
-}
-
-llvm::StringMap<const tblgen::Operator *> &
-tblgen::Pattern::getResultPatternBoundOps() {
- return resBoundOps;
+void tblgen::Pattern::collectResultPatternBoundSymbols(
+ tblgen::SymbolInfoMap &infoMap) {
+ for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
+ auto pattern = getResultPattern(i);
+ collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
+ }
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const {
return result;
}
-void tblgen::Pattern::collectBoundSymbols(DagNode tree,
- SymbolOperatorMap &symOpMap,
+void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
if (!tree.isOperation()) {
@@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree,
auto numTreeArgs = tree.getNumArgs();
if (numOpArgs != numTreeArgs) {
- PrintFatalError(def.getLoc(),
- formatv("op '{0}' argument number mismatch: "
- "{1} in pattern vs. {2} in definition",
- op.getOperationName(), numTreeArgs, numOpArgs));
+ auto err = formatv("op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ op.getOperationName(), numTreeArgs, numOpArgs);
+ PrintFatalError(def.getLoc(), err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
- if (!treeName.empty())
- symOpMap.try_emplace(treeName, &op);
+ if (!treeName.empty()) {
+ if (!infoMap.bindOpResult(treeName, op))
+ PrintFatalError(def.getLoc(),
+ formatv("symbol '{0}' bound more than once", treeName));
+ }
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
- collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern);
} else if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
- if (!treeArgName.empty())
- srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
+ if (!treeArgName.empty()) {
+ if (!infoMap.bindOpArgument(treeArgName, op, i)) {
+ auto err = formatv("symbol '{0}' bound more than once", treeArgName);
+ PrintFatalError(def.getLoc(), err);
+ }
+ }
}
}
}
OpenPOWER on IntegriCloud