diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-04-04 05:44:58 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-04-05 07:40:41 -0700 |
| commit | c7790df2ed9bdcde12683aee6cb89a2668b56661 (patch) | |
| tree | 61133a877a95d2f40c8d780ed0d77d71c9356c1e /mlir | |
| parent | 3c833344c858ae8af38fbd6f20ce9f07a685c15f (diff) | |
| download | bcm5719-llvm-c7790df2ed9bdcde12683aee6cb89a2668b56661.tar.gz bcm5719-llvm-c7790df2ed9bdcde12683aee6cb89a2668b56661.zip | |
[TableGen] Add PatternSymbolResolver for resolving symbols bound in patterns
We can bind symbols to op arguments/results in source pattern and op results in
result pattern. Previously resolving these symbols is scattered across
RewriterGen.cpp. This CL aggregated them into a `PatternSymbolResolver` class.
While we are here, this CL also cleans up tests for patterns to make them more
focused. Specifically, one-op-one-result.td is superseded by pattern.td;
pattern-tAttr.td is simplified; pattern-bound-symbol.td is added for the change
in this CL.
--
PiperOrigin-RevId: 241913973
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/IR/OpBase.td | 30 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/one-op-one-result.td | 31 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern-bound-symbol.td | 61 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern-tAttr.td | 54 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/pattern.td | 35 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/RewriterGen.cpp | 221 |
6 files changed, 295 insertions, 137 deletions
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 3deddd0dfa4..473997a50da 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -804,13 +804,33 @@ def addBenefit; // A rewrite rule contains two components: a source pattern and one or more // result patterns. Each pattern is specified as a (recursive) DAG node (tree) // in the form of `(node arg0, arg1, ...)`. +// // The `node` are normally MLIR ops, but it can also be one of the directives // listed later in this section. -// In the source pattern, `arg*` can be used to specify matchers (e.g., using -// type/attribute types, etc.) and bound to a name for later use. In -// the result pattern, `arg*` can be used to refer to a previously bound name, -// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself -// be nested DAG node. +// +// In the source pattern, `argN` can be used to specify matchers (e.g., using +// type/attribute type constraints, etc.) and bound to a name for later use. +// We can also bound names to op instances to reference them later in +// multi-entity constraints. +// +// In the result pattern, `argN` can be used to refer to a previously bound +// name, with potential transformations (e.g., using tAttr, etc.). `argN` can +// itself be nested DAG node. We can also bound names to op results to reference +// them later in other result patterns. +// +// For example, +// +// ``` +// def : Pattern<(OneResultOp1:$res1 $arg0, $arg1), +// [(OneResultOp2:$res2 $arg0, $arg1), +// (OneResultOp3 $res2 (OneResultOp4))], +// [(IsStaticShapeTensorTypePred $res1)]>; +// ``` +// +// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to +// build `OneResultOp2`. `$res1` is bound to `OneResultOp1`'s result and used to +// check whether the result's shape is static. `$res2` is bound to the result of +// `OneResultOp2` and used to build `OneResultOp3`. class Pattern<dag source, list<dag> results, list<dag> preds = [], dag benefitAdded = (addBenefit 0)> { dag sourcePattern = source; diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td deleted file mode 100644 index 56e104d87c0..00000000000 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s - -include "mlir/IR/OpBase.td" - -// Create a Type and Attribute. -def T : BuildableType<"buildT">; -def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">; -def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">; - -// Define ops to rewrite. -def U: Type<CPred<"true">, "U">; -def X_AddOp : Op<"x.add"> { - let arguments = (ins U, U); -} -def Y_AddOp : Op<"y.add"> { - let arguments = (ins U, U, T_Attr:$attrName); -} - -// Define rewrite pattern. -def : Pat<(X_AddOp:$res $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; -def : Pat<(X_AddOp (X_AddOp $lhs, $rhs):$res, $rrhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; -def : Pat<(X_AddOp (X_AddOp:$res $lhs, $rhs), $rrhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; - -// CHECK: struct GeneratedConvert0 : public RewritePattern -// CHECK: RewritePattern("x.add", 1, context) -// CHECK: PatternMatchResult match(Operation * -// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> -// CHECK: PatternRewriter &rewriter) -// CHECK: rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType() -// CHECK: void populateWithGenerated -// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context)) diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td new file mode 100644 index 00000000000..55f4d163116 --- /dev/null +++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td @@ -0,0 +1,61 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def OpA : Op<"op_a", []> { + let arguments = (ins I32:$operand, I32Attr:$attr); + let results = (outs I32:$result); +} + +def OpB : Op<"op_b", []> { + let arguments = (ins I32:$operand); + let results = (outs I32:$result); +} + +def OpC : Op<"op_c", []> { + let arguments = (ins I32:$operand); + let results = (outs I32:$result); +} + +def OpD : Op<"op_d", []> { + let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr); + let results = (outs I32:$result); +} + +def hasOneUse: Constraint<CPred<"{0}->hasOneUse()">, "has one use">; + +def : Pattern<(OpA:$res_a $operand, $attr), + [(OpC:$res_c (OpB:$res_b $operand)), + (OpD $res_b, $res_c, $attr)], + [(hasOneUse $res_a)]>; + +// CHECK-LABEL: GeneratedConvert0 + +// Test struct for bound arguments +// --- +// CHECK: struct MatchedState : public PatternState +// CHECK: Value* operand; +// CHECK: IntegerAttr attr; + +// Test bound arguments/results in source pattern +// --- +// CHECK: PatternMatchResult match +// CHECK: auto state = llvm::make_unique<MatchedState>(); +// CHECK: auto &s = *state; +// CHECK: mlir::Operation* tblgen_res_a; (void)tblgen_res_a; +// CHECK: tblgen_res_a = op0; +// CHECK: s.operand = op0->getOperand(0); +// CHECK: s.attr = op0->getAttrOfType<IntegerAttr>("attr"); +// CHECK: if (!(tblgen_res_a->hasOneUse())) return matchFailure(); + +// Test bound results in result pattern +// --- +// CHECK: void rewrite +// CHECK: auto& s = *static_cast<MatchedState *>(state.get()); +// CHECK: auto res_b = rewriter.create<OpB> +// CHECK: auto res_c = rewriter.create<OpC>( +// CHECK: /*operand=*/res_b +// CHECK: auto vOpD0 = rewriter.create<OpD>( +// CHECK: /*input1=*/res_b, +// CHECK: /*input2=*/res_c, +// CHECK: /*attr=*/s.attr diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td index 39fa4792481..08911156ead 100644 --- a/mlir/test/mlir-tblgen/pattern-tAttr.td +++ b/mlir/test/mlir-tblgen/pattern-tAttr.td @@ -3,52 +3,28 @@ include "mlir/IR/OpBase.td" // Create a Type and Attribute. -def T : BuildableType<"buildT">; +def T : BuildableType<"buildT()">; def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">; def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">; def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">; // Define ops to rewrite. -def U: Type<CPred<"true">, "U">; -def X_AddOp : Op<"x.add"> { - let arguments = (ins U, U); +def Y_Op : Op<"y.op"> { + let arguments = (ins T_Attr:$attrName); + let results = (outs I32:$result); } -def Y_AddOp : Op<"y.add"> { - let arguments = (ins U, U, T_Attr:$attrName); - let results = (outs U); -} -def Z_AddOp : Op<"z.add"> { - let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2); - let results = (outs U); +def Z_Op : Op<"z.op"> { + let arguments = (ins T_Attr:$attrName1, T_Attr:$attrName2); + let results = (outs I32:$result); } // Define rewrite pattern. -def : Pat<(Y_AddOp $lhs, $rhs, $attr1), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, T_Const_Attr:$attr2))>; -// CHECK: struct GeneratedConvert0 : public RewritePattern -// CHECK: RewritePattern("y.add", 1, context) -// CHECK: PatternMatchResult match(Operation * -// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> -// CHECK-NEXT: PatternRewriter &rewriter) -// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(), -// CHECK-NEXT: s.lhs, -// CHECK-NEXT: s.rhs, -// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT, attrValue)}) -// CHECK-NEXT: ); -// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0}); - -def : Pat<(Z_AddOp $lhs, $rhs, $attr1, $attr2), (Y_AddOp $lhs, $rhs, (T_Compose_Attr $attr1, $attr2))>; -// CHECK: struct GeneratedConvert1 : public RewritePattern -// CHECK: RewritePattern("z.add", 1, context) -// CHECK: PatternMatchResult match(Operation * -// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> -// CHECK-NEXT: PatternRewriter &rewriter) -// CHECK: auto vAddOp0 = rewriter.create<Y::AddOp>(loc, op->getResult(0)->getType(), -// CHECK-NEXT: s.lhs, -// CHECK-NEXT: s.rhs, -// CHECK-NEXT: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2}) -// CHECK-NEXT: ); -// CHECK-NEXT: rewriter.replaceOp(op, {vAddOp0}); +def : Pat<(Y_Op $attr1), (Y_Op (T_Compose_Attr $attr1, T_Const_Attr))>; +// CHECK-LABEL: struct GeneratedConvert0 +// CHECK: void rewrite( +// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT(), attrValue)}) -// CHECK: void populateWithGenerated -// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context)) -// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert1>(context)) +def : Pat<(Z_Op $attr1, $attr2), (Y_Op (T_Compose_Attr $attr1, $attr2))>; +// CHECK-LABEL: struct GeneratedConvert1 +// CHECK: void rewrite( +// CHECK: /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2}) diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td new file mode 100644 index 00000000000..bea44c93f53 --- /dev/null +++ b/mlir/test/mlir-tblgen/pattern.td @@ -0,0 +1,35 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def OpA : Op<"op_a", []> { + let arguments = (ins I32:$operand, I32Attr:$attr); + let results = (outs I32:$result); +} + +def OpB : Op<"op_b", []> { + let arguments = (ins I32:$operand, I32Attr:$attr); + let results = (outs I32:$result); +} + +def : Pat<(OpA $input, $attr), (OpB $input, $attr)>; + +// Test basic structure generated from Pattern +// --- + +// CHECK: struct GeneratedConvert0 : public RewritePattern + +// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("op_a", 1, context) {} + +// CHECK: struct MatchedState : public PatternState { +// CHECK: Value* input; +// CHECK: IntegerAttr attr; +// CHECK: }; + +// CHECK: PatternMatchResult match(Operation *op0) const override + +// CHECK: void rewrite(Operation *op, std::unique_ptr<PatternState> state, +// CHECK: PatternRewriter &rewriter) const override + +// CHECK: void populateWithGenerated(MLIRContext *context, OwningRewritePatternList *patterns) +// CHECK: patterns->push_back(llvm::make_unique<GeneratedConvert0>(context)); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index d0e400813ef..6622a783a95 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -41,6 +41,100 @@ using namespace llvm; using namespace mlir; using namespace mlir::tblgen; +static const char *const tblgenNamePrefix = "tblgen_"; + +// Returns the bound value for the given op result `symbol`. +static Twine getBoundResult(const StringRef &symbol) { + return tblgenNamePrefix + symbol; +} + +// Returns the bound value for the given op argument `symbol`. +// +// Arguments bound in the source pattern are grouped into a transient +// `PatternState` struct. This struct can be accessed in both `match()` and +// `rewrite()` via the local variable named as `s`. +static Twine getBoundArgument(const StringRef &symbol) { + return Twine("s.") + symbol; +} + +//===----------------------------------------------------------------------===// +// PatternSymbolResolver +//===----------------------------------------------------------------------===// + +namespace { +// A class for resolving symbols bound in patterns. +// +// Symbols can be bound to op arguments/results in the source pattern and op +// results in result patterns. For example, in +// +// ``` +// def : Pattern<(SrcOp:$op1 $arg0, %arg1), +// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; +// ``` +// +// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. +// `$op2` is bound to `ResOp1`. +// +// This class keeps track of such symbols and translates them into their bound +// values. +// +// Note that we also generate local variables for unnamed DAG nodes, like +// `(ResOp3)` in the above. Since we don't bind a symbol to the result, the +// generated local variable will be implicitly named. Those implicit names are +// not tracked in this class. +class PatternSymbolResolver { +public: + PatternSymbolResolver(const StringMap<Argument> &srcArgs, + const StringSet<> &srcResults); + + // Marks the given `symbol` as bound. Returns false if the `symbol` is + // already bound. + bool add(StringRef symbol); + + // Queries the substitution for the given `symbol`. + std::string query(StringRef symbol) const; + +private: + // Symbols bound to arguments in source pattern. + const StringMap<Argument> &sourceArguments; + // Symbols bound to ops (for their results) in source pattern. + const StringSet<> &sourceOps; + // Symbols bound to ops (for their results) in result patterns. + StringSet<> resultOps; +}; +} // end anonymous namespace + +PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs, + const StringSet<> &srcResults) + : sourceArguments(srcArgs), sourceOps(srcResults) {} + +bool PatternSymbolResolver::add(StringRef symbol) { + return resultOps.insert(symbol).second; +} + +std::string PatternSymbolResolver::query(StringRef symbol) const { + { + auto it = resultOps.find(symbol); + if (it != resultOps.end()) + return it->getKey(); + } + { + auto it = sourceArguments.find(symbol); + if (it != sourceArguments.end()) + return getBoundArgument(symbol).str(); + } + { + auto it = sourceOps.find(symbol); + if (it != sourceOps.end()) + return getBoundResult(symbol).str(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// PatternEmitter +//===----------------------------------------------------------------------===// + namespace { class PatternEmitter { public: @@ -109,6 +203,13 @@ private: // Returns the C++ expression to build an argument from the given DAG `tree`. std::string handleOpArgument(DagNode tree); + // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol + // is already bound. + void addSymbol(DagNode node); + + // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. + std::string resolveSymbol(StringRef symbol); + private: // Pattern instantiation location followed by the location of multiclass // prototypes used. This is intended to be used as a whole to @@ -118,16 +219,19 @@ private: RecordOperatorMap *opMap; // Handy wrapper for pattern being emitted Pattern pattern; + PatternSymbolResolver symbolResolver; // The next unused ID for newly created values unsigned nextValueId; raw_ostream &os; }; -} // end namespace +} // end anonymous namespace PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) - : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0), - os(os) {} + : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), + symbolResolver(pattern.getSourcePatternBoundArgs(), + pattern.getSourcePatternBoundResults()), + nextValueId(0), os(os) {} std::string PatternEmitter::handleConstantAttr(Attribute attr, StringRef value) { @@ -140,20 +244,6 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr, value); } -static Twine resultName(const StringRef &name) { return Twine("res_") + name; } - -static Twine boundArgNameInMatch(const StringRef &name) { - // Bound value in the source pattern are grouped into a transient struct. That - // struct is hold in a local variable named as "state" in the match() method. - return Twine("state->") + name; -} - -static Twine boundArgNameInRewrite(const StringRef &name) { - // Bound value in the source pattern are grouped into a transient struct. That - // struct is passed into the rewrite() method as a parameter with name `s`. - return Twine("s.") + name; -} - // Helper function to match patterns. void PatternEmitter::emitOpMatch(DagNode tree, int depth) { Operator &op = tree.getDialectOp(opMap); @@ -176,7 +266,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // If the operand's name is set, set to that variable. auto name = tree.getOpName(); if (!name.empty()) - os.indent(indent) << formatv("{0} = op{1};\n", resultName(name), depth); + os.indent(indent) << formatv("{0} = op{1};\n", getBoundResult(name), depth); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); @@ -232,7 +322,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, // Capture the value auto name = tree.getArgName(index); if (!name.empty()) { - os.indent(indent) << "state->" << name << " = op" << depth + os.indent(indent) << getBoundArgument(name) << " = op" << depth << "->getOperand(" << index << ");\n"; } } @@ -262,7 +352,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, // Capture the value auto name = tree.getArgName(index); if (!name.empty()) { - os.indent(indent) << "state->" << name << " = op" << depth + os.indent(indent) << getBoundArgument(name) << " = op" << depth << "->getAttrOfType<" << namedAttr->attr.getStorageType() << ">(\"" << namedAttr->getName() << "\");\n"; } @@ -273,8 +363,9 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { os << R"( PatternMatchResult match(Operation *op0) const override { auto ctx = op0->getContext(); (void)ctx; - auto state = llvm::make_unique<MatchedState>();)" - << "\n"; + auto state = llvm::make_unique<MatchedState>(); + auto &s = *state; +)"; // The rewrite pattern may specify that certain outputs should be unused in // the source IR. Check it here. @@ -287,20 +378,10 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { for (auto &res : pattern.getSourcePatternBoundResults()) os.indent(4) << formatv("mlir::Operation* {0}; (void){0};\n", - resultName(res.first())); + getBoundResult(res.first())); emitOpMatch(tree, 0); - auto deduceName = [&](const std::string &name) -> std::string { - if (pattern.isArgBoundInSourcePattern(name)) { - return boundArgNameInMatch(name).str(); - } - if (pattern.isResultBoundInSourcePattern(name)) { - return resultName(name).str(); - } - PrintFatalError(loc, formatv("referencing unbound variable '{0}'", name)); - }; - for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; auto &entities = appliedConstraint.entities; @@ -311,8 +392,9 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { if (isa<TypeConstraint>(constraint)) { // TODO(jpienaar): Verify op only has one result. os.indent(4) << formatv( - cmd, formatv(condition.c_str(), "(*" + deduceName(entities.front()) + - "->result_type_begin())")); + cmd, + formatv(condition.c_str(), "(*" + resolveSymbol(entities.front()) + + "->result_type_begin())")); } else if (isa<AttrConstraint>(constraint)) { PrintFatalError( loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); @@ -325,7 +407,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { SmallVector<std::string, 4> names; unsigned i = 0; for (unsigned e = entities.size(); i < e; ++i) - names.push_back(deduceName(entities[i])); + names.push_back(resolveSymbol(entities[i])); for (; i < 4; ++i) names.push_back("<unused>"); os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0], @@ -393,6 +475,8 @@ void PatternEmitter::emitRewriteMethod() { for (unsigned i = 0; i < numProvidedResults; ++i) { DagNode resultTree = pattern.getResultPattern(i); resultValues.push_back(handleRewritePattern(resultTree, i, 0)); + // Keep track of bound symbols at the top-level DAG nodes + addSymbol(resultTree); } // Emit the final replaceOp() statement @@ -420,6 +504,11 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree, PrintFatalError(loc, "verifyUnusedValue directive can only be used to " "verify top-level result"); } + + if (!resultTree.getOpName().empty()) { + PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); + } + // The C++ statements to check that this result value is unused are already // emitted in the match() method. So returning a nullptr here directly // should be safe because the C++ RewritePattern harness will use it to @@ -441,10 +530,14 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { loc, "replaceWithValue directive must take exactly one argument"); } + if (!tree.getOpName().empty()) { + PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); + } + auto name = tree.getArgName(0); pattern.ensureArgBoundInSourcePattern(name); - return boundArgNameInRewrite(name).str(); + return getBoundArgument(name).str(); } void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) { @@ -466,7 +559,7 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, return handleConstantAttr(enumCase, enumCase.getSymbol()); } pattern.ensureArgBoundInSourcePattern(argName); - std::string result = boundArgNameInRewrite(argName).str(); + std::string result = getBoundArgument(argName).str(); if (leaf.isUnspecified() || leaf.isOperandMatcher()) { return result; } @@ -494,6 +587,22 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) { attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]); } +void PatternEmitter::addSymbol(DagNode node) { + StringRef symbol = node.getOpName(); + // Skip empty-named symbols, which happen for unbound ops in result patterns. + if (symbol.empty()) + return; + if (!symbolResolver.add(symbol)) + PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); +} + +std::string PatternEmitter::resolveSymbol(StringRef symbol) { + auto subst = symbolResolver.query(symbol); + if (subst.empty()) + PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); + return subst; +} + std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, int depth) { Operator &resultOp = tree.getDialectOp(opMap); @@ -513,7 +622,10 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, } // A map to collect all nested DAG child nodes' names, with operand index as - // the key. + // the key. This includes both bound and unbound child nodes. Bound child + // nodes will additionally be tracked in `symbolResolver` so they can be + // referenced by other patterns. Unbound child nodes will only be used once + // to build this op. llvm::DenseMap<unsigned, std::string> childNodeNames; // First go through all the child nodes who are nested DAG constructs to @@ -522,6 +634,8 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, for (unsigned i = 0, e = resultOp.getNumOperands(); i != e; ++i) { if (auto child = tree.getArgAsNestedDag(i)) { childNodeNames[i] = handleRewritePattern(child, i, depth + 1); + // Keep track of bound symbols at the middle-level DAG nodes + addSymbol(child); } } @@ -532,27 +646,6 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, // Then we build the new op corresponding to this DAG node. - // Returns the name we should use for the `index`-th argument of this - // DAG node. This is needed because the we can reference an argument - // 1) generated from a nested DAG node and implicitly named, - // 2) bound in the source pattern and explicitly named, - // 3) bound in the result pattern and explicitly named. - auto deduceArgName = [&](unsigned index) -> std::string { - if (tree.isNestedDagArg(index)) { - // Implicitly named - return childNodeNames[index]; - } - - auto name = tree.getArgName(index); - if (this->pattern.isArgBoundInSourcePattern(name)) { - // Bound in source pattern, explicitly named - return boundArgNameInRewrite(name).str(); - } - - // Bound in result pattern, explicitly named - return name.str(); - }; - // TODO: this is a hack to support various constant ops. We are assuming // all of them have no operands and one attribute here. Figure out a better // way to do this. @@ -584,7 +677,11 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, if (!operand.name.empty()) os << "/*" << operand.name << "=*/"; - os << deduceArgName(i); + if (tree.isNestedDagArg(i)) { + os << childNodeNames[i]; + } else { + os << resolveSymbol(tree.getArgName(i)); + } // TODO(jpienaar): verify types ++i; @@ -640,7 +737,7 @@ std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { } if (!first) os << ","; - os << boundArgNameInRewrite(name); + os << getBoundArgument(name); first = false; } if (!printingAttr) |

