summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-04-04 05:44:58 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-04-05 07:40:41 -0700
commitc7790df2ed9bdcde12683aee6cb89a2668b56661 (patch)
tree61133a877a95d2f40c8d780ed0d77d71c9356c1e /mlir
parent3c833344c858ae8af38fbd6f20ce9f07a685c15f (diff)
downloadbcm5719-llvm-c7790df2ed9bdcde12683aee6cb89a2668b56661.tar.gz
bcm5719-llvm-c7790df2ed9bdcde12683aee6cb89a2668b56661.zip
[TableGen] Add PatternSymbolResolver for resolving symbols bound in patterns
We can bind symbols to op arguments/results in source pattern and op results in result pattern. Previously resolving these symbols is scattered across RewriterGen.cpp. This CL aggregated them into a `PatternSymbolResolver` class. While we are here, this CL also cleans up tests for patterns to make them more focused. Specifically, one-op-one-result.td is superseded by pattern.td; pattern-tAttr.td is simplified; pattern-bound-symbol.td is added for the change in this CL. -- PiperOrigin-RevId: 241913973
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/IR/OpBase.td30
-rw-r--r--mlir/test/mlir-tblgen/one-op-one-result.td31
-rw-r--r--mlir/test/mlir-tblgen/pattern-bound-symbol.td61
-rw-r--r--mlir/test/mlir-tblgen/pattern-tAttr.td54
-rw-r--r--mlir/test/mlir-tblgen/pattern.td35
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp221
6 files changed, 295 insertions, 137 deletions
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 3deddd0dfa4..473997a50da 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -804,13 +804,33 @@ def addBenefit;
// A rewrite rule contains two components: a source pattern and one or more
// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`.
+//
// The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section.
-// In the source pattern, `arg*` can be used to specify matchers (e.g., using
-// type/attribute types, etc.) and bound to a name for later use. In
-// the result pattern, `arg*` can be used to refer to a previously bound name,
-// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
-// be nested DAG node.
+//
+// In the source pattern, `argN` can be used to specify matchers (e.g., using
+// type/attribute type constraints, etc.) and bound to a name for later use.
+// We can also bound names to op instances to reference them later in
+// multi-entity constraints.
+//
+// In the result pattern, `argN` can be used to refer to a previously bound
+// name, with potential transformations (e.g., using tAttr, etc.). `argN` can
+// itself be nested DAG node. We can also bound names to op results to reference
+// them later in other result patterns.
+//
+// For example,
+//
+// ```
+// def : Pattern<(OneResultOp1:$res1 $arg0, $arg1),
+// [(OneResultOp2:$res2 $arg0, $arg1),
+// (OneResultOp3 $res2 (OneResultOp4))],
+// [(IsStaticShapeTensorTypePred $res1)]>;
+// ```
+//
+// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to
+// build `OneResultOp2`. `$res1` is bound to `OneResultOp1`'s result and used to
+// check whether the result's shape is static. `$res2` is bound to the result of
+// `OneResultOp2` and used to build `OneResultOp3`.
class Pattern<dag source, list<dag> results, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td
deleted file mode 100644
index 56e104d87c0..00000000000
--- a/mlir/test/mlir-tblgen/one-op-one-result.td
+++ /dev/null
@@ -1,31 +0,0 @@
-// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-// Create a Type and Attribute.
-def T : BuildableType<"buildT">;
-def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
-def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
-
-// Define ops to rewrite.
-def U: Type<CPred<"true">, "U">;
-def X_AddOp : Op<"x.add"> {
- let arguments = (ins U, U);
-}
-def Y_AddOp : Op<"y.add"> {
- let arguments = (ins U, U, T_Attr:$attrName);
-}
-
-// Define rewrite pattern.
-def : Pat<(X_AddOp:$res $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
-def : Pat<(X_AddOp (X_AddOp $lhs, $rhs):$res, $rrhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
-def : Pat<(X_AddOp (X_AddOp:$res $lhs, $rhs), $rrhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
-
-// CHECK: struct GeneratedConvert0 : public RewritePattern
-// CHECK: RewritePattern("x.add", 1, context)
-// CHECK: PatternMatchResult match(Operation *
-// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState>
-// CHECK: PatternRewriter &rewriter)
-// CHECK: rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType()
-// CHECK: void populateWithGenerated
-// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context))
diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td
new file mode 100644
index 00000000000..55f4d163116
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td
@@ -0,0 +1,61 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def OpA : Op<"op_a", []> {
+ let arguments = (ins I32:$operand, I32Attr:$attr);
+ let results = (outs I32:$result);
+}
+
+def OpB : Op<"op_b", []> {
+ let arguments = (ins I32:$operand);
+ let results = (outs I32:$result);
+}
+
+def OpC : Op<"op_c", []> {
+ let arguments = (ins I32:$operand);
+ let results = (outs I32:$result);
+}
+
+def OpD : Op<"op_d", []> {
+ let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr);
+ let results = (outs I32:$result);
+}
+
+def hasOneUse: Constraint<CPred<"{0}->hasOneUse()">, "has one use">;
+
+def : Pattern<(OpA:$res_a $operand, $attr),
+ [(OpC:$res_c (OpB:$res_b $operand)),
+ (OpD $res_b, $res_c, $attr)],
+ [(hasOneUse $res_a)]>;
+
+// CHECK-LABEL: GeneratedConvert0
+
+// Test struct for bound arguments
+// ---
+// CHECK: struct MatchedState : public PatternState
+// CHECK: Value* operand;
+// CHECK: IntegerAttr attr;
+
+// Test bound arguments/results in source pattern
+// ---
+// CHECK: PatternMatchResult match
+// CHECK: auto state = llvm::make_unique<MatchedState>();
+// CHECK: auto &s = *state;
+// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a;
+// CHECK: tblgen_res_a = op0;
+// CHECK: s.operand = op0->getOperand(0);
+// CHECK: s.attr = op0->getAttrOfType<IntegerAttr>("attr");
+// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure();
+
+// Test bound results in result pattern
+// ---
+// CHECK: void rewrite
+// CHECK: auto& s = *static_cast<MatchedState *>(state.get());
+// CHECK: auto res_b = rewriter.create<OpB>
+// CHECK: auto res_c = rewriter.create<OpC>(
+// CHECK: /*operand=*/res_b
+// CHECK: auto vOpD0 = rewriter.create<OpD>(
+// CHECK: /*input1=*/res_b,
+// CHECK: /*input2=*/res_c,
+// CHECK: /*attr=*/s.attr
diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td
index 39fa4792481..08911156ead 100644
--- a/mlir/test/mlir-tblgen/pattern-tAttr.td
+++ b/mlir/test/mlir-tblgen/pattern-tAttr.td
@@ -3,52 +3,28 @@
include "mlir/IR/OpBase.td"
// Create a Type and Attribute.
-def T : BuildableType<"buildT">;
+def T : BuildableType<"buildT()">;
def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
// Define ops to rewrite.
-def U: Type<CPred<"true">, "U">;
-def X_AddOp : Op<"x.add"> {
- let arguments = (ins U, U);
+def Y_Op : Op<"y.op"> {
+ let arguments = (ins T_Attr:$attrName);
+ let results = (outs I32:$result);
}
-def Y_AddOp : Op<"y.add"> {
- let arguments = (ins U, U, T_Attr:$attrName);
- let results = (outs U);
-}
-def Z_AddOp : Op<"z.add"> {
- let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2);
- let results = (outs U);
+def Z_Op : Op<"z.op"> {
+ let arguments = (ins T_Attr:$attrName1, T_Attr:$attrName2);
+ let results = (outs I32:$result);
}
// Define rewrite pattern.
-def : Pat<(Y_AddOp $lhs, $rhs, $attr1), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, T_Const_Attr:$attr2))>;
-// CHECK: struct GeneratedConvert0 : public RewritePattern
-// CHECK: RewritePattern("y.add", 1, context)
-// CHECK: PatternMatchResult match(Operation *
-// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState>
-// CHECK-NEXT: PatternRewriter &rewriter)
-// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
-// CHECK-NEXT: s.lhs,
-// CHECK-NEXT: s.rhs,
-// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT, attrValue)})
-// CHECK-NEXT: );
-// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
-
-def : Pat<(Z_AddOp $lhs, $rhs, $attr1, $attr2), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, $attr2))>;
-// CHECK: struct GeneratedConvert1 : public RewritePattern
-// CHECK: RewritePattern("z.add", 1, context)
-// CHECK: PatternMatchResult match(Operation *
-// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState>
-// CHECK-NEXT: PatternRewriter &rewriter)
-// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(),
-// CHECK-NEXT: s.lhs,
-// CHECK-NEXT: s.rhs,
-// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2})
-// CHECK-NEXT: );
-// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0});
+def : Pat<(Y_Op $attr1), (Y_Op (T_Compose_Attr $attr1, T_Const_Attr))>;
+// CHECK-LABEL: struct GeneratedConvert0
+// CHECK: void rewrite(
+// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT(), attrValue)})
-// CHECK: void populateWithGenerated
-// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context))
-// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert1>(context))
+def : Pat<(Z_Op $attr1, $attr2), (Y_Op (T_Compose_Attr $attr1, $attr2))>;
+// CHECK-LABEL: struct GeneratedConvert1
+// CHECK: void rewrite(
+// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2})
diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td
new file mode 100644
index 00000000000..bea44c93f53
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pattern.td
@@ -0,0 +1,35 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def OpA : Op<"op_a", []> {
+ let arguments = (ins I32:$operand, I32Attr:$attr);
+ let results = (outs I32:$result);
+}
+
+def OpB : Op<"op_b", []> {
+ let arguments = (ins I32:$operand, I32Attr:$attr);
+ let results = (outs I32:$result);
+}
+
+def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+
+// Test basic structure generated from Pattern
+// ---
+
+// CHECK: struct GeneratedConvert0 : public RewritePattern
+
+// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {}
+
+// CHECK: struct MatchedState : public PatternState {
+// CHECK: Value* input;
+// CHECK: IntegerAttr attr;
+// CHECK: };
+
+// CHECK: PatternMatchResult match(Operation *op0) const override
+
+// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+// CHECK: PatternRewriter &rewriter) const override
+
+// CHECK: void populateWithGenerated(MLIRContext *context, OwningRewritePatternList *patterns)
+// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context));
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index d0e400813ef..6622a783a95 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -41,6 +41,100 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
+static const char *const tblgenNamePrefix = "tblgen_";
+
+// Returns the bound value for the given op result `symbol`.
+static Twine getBoundResult(const StringRef &symbol) {
+ return tblgenNamePrefix + symbol;
+}
+
+// Returns the bound value for the given op argument `symbol`.
+//
+// Arguments bound in the source pattern are grouped into a transient
+// `PatternState` struct. This struct can be accessed in both `match()` and
+// `rewrite()` via the local variable named as `s`.
+static Twine getBoundArgument(const StringRef &symbol) {
+ return Twine("s.") + symbol;
+}
+
+//===----------------------------------------------------------------------===//
+// PatternSymbolResolver
+//===----------------------------------------------------------------------===//
+
+namespace {
+// A class for resolving symbols bound in patterns.
+//
+// Symbols can be bound to op arguments/results in the source pattern and op
+// results in result patterns. For example, in
+//
+// ```
+// def : Pattern<(SrcOp:$op1 $arg0, %arg1),
+// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
+// ```
+//
+// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
+// `$op2` is bound to `ResOp1`.
+//
+// This class keeps track of such symbols and translates them into their bound
+// values.
+//
+// Note that we also generate local variables for unnamed DAG nodes, like
+// `(ResOp3)` in the above. Since we don't bind a symbol to the result, the
+// generated local variable will be implicitly named. Those implicit names are
+// not tracked in this class.
+class PatternSymbolResolver {
+public:
+ PatternSymbolResolver(const StringMap<Argument> &srcArgs,
+ const StringSet<> &srcResults);
+
+ // Marks the given `symbol` as bound. Returns false if the `symbol` is
+ // already bound.
+ bool add(StringRef symbol);
+
+ // Queries the substitution for the given `symbol`.
+ std::string query(StringRef symbol) const;
+
+private:
+ // Symbols bound to arguments in source pattern.
+ const StringMap<Argument> &sourceArguments;
+ // Symbols bound to ops (for their results) in source pattern.
+ const StringSet<> &sourceOps;
+ // Symbols bound to ops (for their results) in result patterns.
+ StringSet<> resultOps;
+};
+} // end anonymous namespace
+
+PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
+ const StringSet<> &srcResults)
+ : sourceArguments(srcArgs), sourceOps(srcResults) {}
+
+bool PatternSymbolResolver::add(StringRef symbol) {
+ return resultOps.insert(symbol).second;
+}
+
+std::string PatternSymbolResolver::query(StringRef symbol) const {
+ {
+ auto it = resultOps.find(symbol);
+ if (it != resultOps.end())
+ return it->getKey();
+ }
+ {
+ auto it = sourceArguments.find(symbol);
+ if (it != sourceArguments.end())
+ return getBoundArgument(symbol).str();
+ }
+ {
+ auto it = sourceOps.find(symbol);
+ if (it != sourceOps.end())
+ return getBoundResult(symbol).str();
+ }
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// PatternEmitter
+//===----------------------------------------------------------------------===//
+
namespace {
class PatternEmitter {
public:
@@ -109,6 +203,13 @@ private:
// Returns the C++ expression to build an argument from the given DAG `tree`.
std::string handleOpArgument(DagNode tree);
+ // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
+ // is already bound.
+ void addSymbol(DagNode node);
+
+ // Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
+ std::string resolveSymbol(StringRef symbol);
+
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@@ -118,16 +219,19 @@ private:
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
Pattern pattern;
+ PatternSymbolResolver symbolResolver;
// The next unused ID for newly created values
unsigned nextValueId;
raw_ostream &os;
};
-} // end namespace
+} // end anonymous namespace
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
- : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
- os(os) {}
+ : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
+ symbolResolver(pattern.getSourcePatternBoundArgs(),
+ pattern.getSourcePatternBoundResults()),
+ nextValueId(0), os(os) {}
std::string PatternEmitter::handleConstantAttr(Attribute attr,
StringRef value) {
@@ -140,20 +244,6 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
value);
}
-static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
-
-static Twine boundArgNameInMatch(const StringRef &name) {
- // Bound value in the source pattern are grouped into a transient struct. That
- // struct is hold in a local variable named as "state" in the match() method.
- return Twine("state->") + name;
-}
-
-static Twine boundArgNameInRewrite(const StringRef &name) {
- // Bound value in the source pattern are grouped into a transient struct. That
- // struct is passed into the rewrite() method as a parameter with name `s`.
- return Twine("s.") + name;
-}
-
// Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
Operator &op = tree.getDialectOp(opMap);
@@ -176,7 +266,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getOpName();
if (!name.empty())
- os.indent(indent) << formatv("{0} = op{1};\n", resultName(name), depth);
+ os.indent(indent) << formatv("{0} = op{1};\n", getBoundResult(name), depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
@@ -232,7 +322,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
- os.indent(indent) << "state->" << name << " = op" << depth
+ os.indent(indent) << getBoundArgument(name) << " = op" << depth
<< "->getOperand(" << index << ");\n";
}
}
@@ -262,7 +352,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
- os.indent(indent) << "state->" << name << " = op" << depth
+ os.indent(indent) << getBoundArgument(name) << " = op" << depth
<< "->getAttrOfType<" << namedAttr->attr.getStorageType()
<< ">(\"" << namedAttr->getName() << "\");\n";
}
@@ -273,8 +363,9 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
os << R"(
PatternMatchResult match(Operation *op0) const override {
auto ctx = op0->getContext(); (void)ctx;
- auto state = llvm::make_unique<MatchedState>();)"
- << "\n";
+ auto state = llvm::make_unique<MatchedState>();
+ auto &s = *state;
+)";
// The rewrite pattern may specify that certain outputs should be unused in
// the source IR. Check it here.
@@ -287,20 +378,10 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
for (auto &res : pattern.getSourcePatternBoundResults())
os.indent(4) << formatv("mlir::Operation* {0}; (void){0};\n",
- resultName(res.first()));
+ getBoundResult(res.first()));
emitOpMatch(tree, 0);
- auto deduceName = [&](const std::string &name) -> std::string {
- if (pattern.isArgBoundInSourcePattern(name)) {
- return boundArgNameInMatch(name).str();
- }
- if (pattern.isResultBoundInSourcePattern(name)) {
- return resultName(name).str();
- }
- PrintFatalError(loc, formatv("referencing unbound variable '{0}'", name));
- };
-
for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
@@ -311,8 +392,9 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
if (isa<TypeConstraint>(constraint)) {
// TODO(jpienaar): Verify op only has one result.
os.indent(4) << formatv(
- cmd, formatv(condition.c_str(), "(*" + deduceName(entities.front()) +
- "->result_type_begin())"));
+ cmd,
+ formatv(condition.c_str(), "(*" + resolveSymbol(entities.front()) +
+ "->result_type_begin())"));
} else if (isa<AttrConstraint>(constraint)) {
PrintFatalError(
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
@@ -325,7 +407,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
SmallVector<std::string, 4> names;
unsigned i = 0;
for (unsigned e = entities.size(); i < e; ++i)
- names.push_back(deduceName(entities[i]));
+ names.push_back(resolveSymbol(entities[i]));
for (; i < 4; ++i)
names.push_back("<unused>");
os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0],
@@ -393,6 +475,8 @@ void PatternEmitter::emitRewriteMethod() {
for (unsigned i = 0; i < numProvidedResults; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleRewritePattern(resultTree, i, 0));
+ // Keep track of bound symbols at the top-level DAG nodes
+ addSymbol(resultTree);
}
// Emit the final replaceOp() statement
@@ -420,6 +504,11 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
"verify top-level result");
}
+
+ if (!resultTree.getOpName().empty()) {
+ PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
+ }
+
// The C++ statements to check that this result value is unused are already
// emitted in the match() method. So returning a nullptr here directly
// should be safe because the C++ RewritePattern harness will use it to
@@ -441,10 +530,14 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
loc, "replaceWithValue directive must take exactly one argument");
}
+ if (!tree.getOpName().empty()) {
+ PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
+ }
+
auto name = tree.getArgName(0);
pattern.ensureArgBoundInSourcePattern(name);
- return boundArgNameInRewrite(name).str();
+ return getBoundArgument(name).str();
}
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
@@ -466,7 +559,7 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
return handleConstantAttr(enumCase, enumCase.getSymbol());
}
pattern.ensureArgBoundInSourcePattern(argName);
- std::string result = boundArgNameInRewrite(argName).str();
+ std::string result = getBoundArgument(argName).str();
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return result;
}
@@ -494,6 +587,22 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) {
attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]);
}
+void PatternEmitter::addSymbol(DagNode node) {
+ StringRef symbol = node.getOpName();
+ // Skip empty-named symbols, which happen for unbound ops in result patterns.
+ if (symbol.empty())
+ return;
+ if (!symbolResolver.add(symbol))
+ PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
+}
+
+std::string PatternEmitter::resolveSymbol(StringRef symbol) {
+ auto subst = symbolResolver.query(symbol);
+ if (subst.empty())
+ PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
+ return subst;
+}
+
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth) {
Operator &resultOp = tree.getDialectOp(opMap);
@@ -513,7 +622,10 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
}
// A map to collect all nested DAG child nodes' names, with operand index as
- // the key.
+ // the key. This includes both bound and unbound child nodes. Bound child
+ // nodes will additionally be tracked in `symbolResolver` so they can be
+ // referenced by other patterns. Unbound child nodes will only be used once
+ // to build this op.
llvm::DenseMap<unsigned, std::string> childNodeNames;
// First go through all the child nodes who are nested DAG constructs to
@@ -522,6 +634,8 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
for (unsigned i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) {
childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
+ // Keep track of bound symbols at the middle-level DAG nodes
+ addSymbol(child);
}
}
@@ -532,27 +646,6 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
// Then we build the new op corresponding to this DAG node.
- // Returns the name we should use for the `index`-th argument of this
- // DAG node. This is needed because the we can reference an argument
- // 1) generated from a nested DAG node and implicitly named,
- // 2) bound in the source pattern and explicitly named,
- // 3) bound in the result pattern and explicitly named.
- auto deduceArgName = [&](unsigned index) -> std::string {
- if (tree.isNestedDagArg(index)) {
- // Implicitly named
- return childNodeNames[index];
- }
-
- auto name = tree.getArgName(index);
- if (this->pattern.isArgBoundInSourcePattern(name)) {
- // Bound in source pattern, explicitly named
- return boundArgNameInRewrite(name).str();
- }
-
- // Bound in result pattern, explicitly named
- return name.str();
- };
-
// TODO: this is a hack to support various constant ops. We are assuming
// all of them have no operands and one attribute here. Figure out a better
// way to do this.
@@ -584,7 +677,11 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
- os << deduceArgName(i);
+ if (tree.isNestedDagArg(i)) {
+ os << childNodeNames[i];
+ } else {
+ os << resolveSymbol(tree.getArgName(i));
+ }
// TODO(jpienaar): verify types
++i;
@@ -640,7 +737,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
}
if (!first)
os << ",";
- os << boundArgNameInRewrite(name);
+ os << getBoundArgument(name);
first = false;
}
if (!printingAttr)
OpenPOWER on IntegriCloud