summaryrefslogtreecommitdiffstats
path: root/mlir/tools/mlir-tblgen/RewriterGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/RewriterGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp1036
1 files changed, 1036 insertions, 0 deletions
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
new file mode 100644
index 00000000000..824ddae85aa
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -0,0 +1,1036 @@
+//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Pattern.h"
+#include "mlir/TableGen/Predicate.h"
+#include "mlir/TableGen/Type.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatAdapters.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Main.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::formatv;
+using llvm::Record;
+using llvm::RecordKeeper;
+
+#define DEBUG_TYPE "mlir-tblgen-rewritergen"
+
+namespace llvm {
+template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
+ static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
+ raw_ostream &os, StringRef style) {
+ os << v.first << ":" << v.second;
+ }
+};
+} // end namespace llvm
+
+//===----------------------------------------------------------------------===//
+// PatternEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class PatternEmitter {
+public:
+ PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
+
+ // Emits the mlir::RewritePattern struct named `rewriteName`.
+ void emit(StringRef rewriteName);
+
+private:
+ // Emits the code for matching ops.
+ void emitMatchLogic(DagNode tree);
+
+ // Emits the code for rewriting ops.
+ void emitRewriteLogic();
+
+ //===--------------------------------------------------------------------===//
+ // Match utilities
+ //===--------------------------------------------------------------------===//
+
+ // Emits C++ statements for matching the op constrained by the given DAG
+ // `tree`.
+ void emitOpMatch(DagNode tree, int depth);
+
+ // Emits C++ statements for matching the `argIndex`-th argument of the given
+ // DAG `tree` as an operand.
+ void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
+
+ // Emits C++ statements for matching the `argIndex`-th argument of the given
+ // DAG `tree` as an attribute.
+ void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
+
+ //===--------------------------------------------------------------------===//
+ // Rewrite utilities
+ //===--------------------------------------------------------------------===//
+
+ // The entry point for handling a result pattern rooted at `resultTree`. This
+ // method dispatches to concrete handlers according to `resultTree`'s kind and
+ // returns a symbol representing the whole value pack. Callers are expected to
+ // further resolve the symbol according to the specific use case.
+ //
+ // `depth` is the nesting level of `resultTree`; 0 means top-level result
+ // pattern. For top-level result pattern, `resultIndex` indicates which result
+ // of the matched root op this pattern is intended to replace, which can be
+ // used to deduce the result type of the op generated from this result
+ // pattern.
+ std::string handleResultPattern(DagNode resultTree, int resultIndex,
+ int depth);
+
+ // Emits the C++ statement to replace the matched DAG with a value built via
+ // calling native C++ code.
+ std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
+
+ // Returns the C++ expression referencing the old value serving as the
+ // replacement.
+ std::string handleReplaceWithValue(DagNode tree);
+
+ // Emits the C++ statement to build a new op out of the given DAG `tree` and
+ // returns the variable name that this op is assigned to. If the root op in
+ // DAG `tree` has a specified name, the created op will be assigned to a
+ // variable of the given name. Otherwise, a unique name will be used as the
+ // result value name.
+ std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
+
+ using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
+
+ // Emits a local variable for each value and attribute to be used for creating
+ // an op.
+ void createSeparateLocalVarsForOpArgs(DagNode node,
+ ChildNodeIndexNameMap &childNodeNames);
+
+ // Emits the concrete arguments used to call a op's builder.
+ void supplyValuesForOpArgs(DagNode node,
+ const ChildNodeIndexNameMap &childNodeNames);
+
+ // Emits the local variables for holding all values as a whole and all named
+ // attributes as a whole to be used for creating an op.
+ void createAggregateLocalVarsForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+
+ // Returns the C++ expression to construct a constant attribute of the given
+ // `value` for the given attribute kind `attr`.
+ std::string handleConstantAttr(Attribute attr, StringRef value);
+
+ // Returns the C++ expression to build an argument from the given DAG `leaf`.
+ // `patArgName` is used to bound the argument to the source pattern.
+ std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
+
+ //===--------------------------------------------------------------------===//
+ // General utilities
+ //===--------------------------------------------------------------------===//
+
+ // Collects all of the operations within the given dag tree.
+ void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
+
+ // Returns a unique symbol for a local variable of the given `op`.
+ std::string getUniqueSymbol(const Operator *op);
+
+ //===--------------------------------------------------------------------===//
+ // Symbol utilities
+ //===--------------------------------------------------------------------===//
+
+ // Returns how many static values the given DAG `node` correspond to.
+ int getNodeValueCount(DagNode node);
+
+private:
+ // Pattern instantiation location followed by the location of multiclass
+ // prototypes used. This is intended to be used as a whole to
+ // PrintFatalError() on errors.
+ ArrayRef<llvm::SMLoc> loc;
+
+ // Op's TableGen Record to wrapper object.
+ RecordOperatorMap *opMap;
+
+ // Handy wrapper for pattern being emitted.
+ Pattern pattern;
+
+ // Map for all bound symbols' info.
+ SymbolInfoMap symbolInfoMap;
+
+ // The next unused ID for newly created values.
+ unsigned nextValueId;
+
+ raw_ostream &os;
+
+ // Format contexts containing placeholder substitutions.
+ FmtContext fmtCtx;
+
+ // Number of op processed.
+ int opCounter = 0;
+};
+} // end anonymous namespace
+
+PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
+ raw_ostream &os)
+ : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
+ symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
+ fmtCtx.withBuilder("rewriter");
+}
+
+std::string PatternEmitter::handleConstantAttr(Attribute attr,
+ StringRef value) {
+ if (!attr.isConstBuildable())
+ PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
+ " does not have the 'constBuilderCall' field");
+
+ // TODO(jpienaar): Verify the constants here
+ return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
+ Operator &op = tree.getDialectOp(opMap);
+ LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
+ << op.getOperationName() << "' at depth " << depth
+ << '\n');
+
+ int indent = 4 + 2 * depth;
+ os.indent(indent) << formatv(
+ "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
+ depth, op.getQualCppClassName());
+ // Skip the operand matching at depth 0 as the pattern rewriter already does.
+ if (depth != 0) {
+ // Skip if there is no defining operation (e.g., arguments to function).
+ os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
+ depth);
+ }
+ if (tree.getNumArgs() != op.getNumArgs()) {
+ PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
+ "pattern vs. {2} in definition",
+ op.getOperationName(), tree.getNumArgs(),
+ op.getNumArgs()));
+ }
+
+ // If the operand's name is set, set to that variable.
+ auto name = tree.getSymbol();
+ if (!name.empty())
+ os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
+
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ auto opArg = op.getArg(i);
+
+ // Handle nested DAG construct first
+ if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (operand->isVariadic()) {
+ auto error = formatv("use nested DAG construct to match op {0}'s "
+ "variadic operand #{1} unsupported now",
+ op.getOperationName(), i);
+ PrintFatalError(loc, error);
+ }
+ }
+ os.indent(indent) << "{\n";
+
+ os.indent(indent + 2) << formatv(
+ "auto *op{0} = "
+ "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
+ depth + 1, depth, i);
+ emitOpMatch(argTree, depth + 1);
+ os.indent(indent + 2)
+ << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+ os.indent(indent) << "}\n";
+ continue;
+ }
+
+ // Next handle DAG leaf: operand or attribute
+ if (opArg.is<NamedTypeConstraint *>()) {
+ emitOperandMatch(tree, i, depth, indent);
+ } else if (opArg.is<NamedAttribute *>()) {
+ emitAttributeMatch(tree, i, depth, indent);
+ } else {
+ PrintFatalError(loc, "unhandled case when matching op");
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
+ << op.getOperationName() << "' at depth " << depth
+ << '\n');
+}
+
+void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
+ int indent) {
+ Operator &op = tree.getDialectOp(opMap);
+ auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
+ auto matcher = tree.getArgAsLeaf(argIndex);
+
+ // If a constraint is specified, we need to generate C++ statements to
+ // check the constraint.
+ if (!matcher.isUnspecified()) {
+ if (!matcher.isOperandMatcher()) {
+ PrintFatalError(
+ loc, formatv("the {1}-th argument of op '{0}' should be an operand",
+ op.getOperationName(), argIndex + 1));
+ }
+
+ // Only need to verify if the matcher's type is different from the one
+ // of op definition.
+ if (operand->constraint != matcher.getAsConstraint()) {
+ if (operand->isVariadic()) {
+ auto error = formatv(
+ "further constrain op {0}'s variadic operand #{1} unsupported now",
+ op.getOperationName(), argIndex);
+ PrintFatalError(loc, error);
+ }
+ auto self =
+ formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
+ depth, argIndex);
+ os.indent(indent) << "if (!("
+ << tgfmt(matcher.getConditionTemplate(),
+ &fmtCtx.withSelf(self))
+ << ")) return matchFailure();\n";
+ }
+ }
+
+ // Capture the value
+ auto name = tree.getArgName(argIndex);
+ // `$_` is a special symbol to ignore op argument matching.
+ if (!name.empty() && name != "_") {
+ // We need to subtract the number of attributes before this operand to get
+ // the index in the operand list.
+ auto numPrevAttrs = std::count_if(
+ op.arg_begin(), op.arg_begin() + argIndex,
+ [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
+
+ os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
+ name, depth, argIndex - numPrevAttrs);
+ }
+}
+
+void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
+ int indent) {
+
+ Operator &op = tree.getDialectOp(opMap);
+ auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
+ const auto &attr = namedAttr->attr;
+
+ os.indent(indent) << "{\n";
+ indent += 2;
+ os.indent(indent) << formatv(
+ "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
+ attr.getStorageType(), namedAttr->name);
+
+ // TODO(antiagainst): This should use getter method to avoid duplication.
+ if (attr.hasDefaultValue()) {
+ os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
+ << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
+ attr.getDefaultValue())
+ << ";\n";
+ } else if (attr.isOptional()) {
+ // For a missing attribute that is optional according to definition, we
+ // should just capture a mlir::Attribute() to signal the missing state.
+ // That is precisely what getAttr() returns on missing attributes.
+ } else {
+ os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
+ }
+
+ auto matcher = tree.getArgAsLeaf(argIndex);
+ if (!matcher.isUnspecified()) {
+ if (!matcher.isAttrMatcher()) {
+ PrintFatalError(
+ loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
+ op.getOperationName(), argIndex + 1));
+ }
+
+ // If a constraint is specified, we need to generate C++ statements to
+ // check the constraint.
+ os.indent(indent) << "if (!("
+ << tgfmt(matcher.getConditionTemplate(),
+ &fmtCtx.withSelf("tblgen_attr"))
+ << ")) return matchFailure();\n";
+ }
+
+ // Capture the value
+ auto name = tree.getArgName(argIndex);
+ // `$_` is a special symbol to ignore op argument matching.
+ if (!name.empty() && name != "_") {
+ os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
+ }
+
+ indent -= 2;
+ os.indent(indent) << "}\n";
+}
+
+void PatternEmitter::emitMatchLogic(DagNode tree) {
+ LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
+ emitOpMatch(tree, 0);
+
+ for (auto &appliedConstraint : pattern.getConstraints()) {
+ auto &constraint = appliedConstraint.constraint;
+ auto &entities = appliedConstraint.entities;
+
+ auto condition = constraint.getConditionTemplate();
+ auto cmd = "if (!({0})) return matchFailure();\n";
+
+ if (isa<TypeConstraint>(constraint)) {
+ auto self = formatv("({0}->getType())",
+ symbolInfoMap.getValueAndRangeUse(entities.front()));
+ os.indent(4) << formatv(cmd,
+ tgfmt(condition, &fmtCtx.withSelf(self.str())));
+ } else if (isa<AttrConstraint>(constraint)) {
+ PrintFatalError(
+ loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
+ } else {
+ // TODO(b/138794486): replace formatv arguments with the exact specified
+ // args.
+ if (entities.size() > 4) {
+ PrintFatalError(loc, "only support up to 4-entity constraints now");
+ }
+ SmallVector<std::string, 4> names;
+ int i = 0;
+ for (int e = entities.size(); i < e; ++i)
+ names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
+ std::string self = appliedConstraint.self;
+ if (!self.empty())
+ self = symbolInfoMap.getValueAndRangeUse(self);
+ for (; i < 4; ++i)
+ names.push_back("<unused>");
+ os.indent(4) << formatv(cmd,
+ tgfmt(condition, &fmtCtx.withSelf(self), names[0],
+ names[1], names[2], names[3]));
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
+}
+
+void PatternEmitter::collectOps(DagNode tree,
+ llvm::SmallPtrSetImpl<const Operator *> &ops) {
+ // Check if this tree is an operation.
+ if (tree.isOperation()) {
+ const Operator &op = tree.getDialectOp(opMap);
+ LLVM_DEBUG(llvm::dbgs()
+ << "found operation " << op.getOperationName() << '\n');
+ ops.insert(&op);
+ }
+
+ // Recurse the arguments of the tree.
+ for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
+ if (auto child = tree.getArgAsNestedDag(i))
+ collectOps(child, ops);
+}
+
+void PatternEmitter::emit(StringRef rewriteName) {
+ // Get the DAG tree for the source pattern.
+ DagNode sourceTree = pattern.getSourcePattern();
+
+ const Operator &rootOp = pattern.getSourceRootOp();
+ auto rootName = rootOp.getOperationName();
+
+ // Collect the set of result operations.
+ llvm::SmallPtrSet<const Operator *, 4> resultOps;
+ LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
+ for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
+ collectOps(pattern.getResultPattern(i), resultOps);
+ }
+ LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
+
+ // Emit RewritePattern for Pattern.
+ auto locs = pattern.getLocation();
+ os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+ make_range(locs.rbegin(), locs.rend()));
+ os << formatv(R"(struct {0} : public RewritePattern {
+ {0}(MLIRContext *context)
+ : RewritePattern("{1}", {{)",
+ rewriteName, rootName);
+ // Sort result operators by name.
+ llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
+ resultOps.end());
+ llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
+ return lhs->getOperationName() < rhs->getOperationName();
+ });
+ interleaveComma(sortedResultOps, os, [&](const Operator *op) {
+ os << '"' << op->getOperationName() << '"';
+ });
+ os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
+
+ // Emit matchAndRewrite() function.
+ os << R"(
+ PatternMatchResult matchAndRewrite(Operation *op0,
+ PatternRewriter &rewriter) const override {
+)";
+
+ // Register all symbols bound in the source pattern.
+ pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
+ LLVM_DEBUG(
+ llvm::dbgs() << "start creating local variables for capturing matches\n");
+ os.indent(4) << "// Variables for capturing values and attributes used for "
+ "creating ops\n";
+ // Create local variables for storing the arguments and results bound
+ // to symbols.
+ for (const auto &symbolInfoPair : symbolInfoMap) {
+ StringRef symbol = symbolInfoPair.getKey();
+ auto &info = symbolInfoPair.getValue();
+ os.indent(4) << info.getVarDecl(symbol);
+ }
+ // TODO(jpienaar): capture ops with consistent numbering so that it can be
+ // reused for fused loc.
+ os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
+ pattern.getSourcePattern().getNumOps());
+ LLVM_DEBUG(
+ llvm::dbgs() << "done creating local variables for capturing matches\n");
+
+ os.indent(4) << "// Match\n";
+ os.indent(4) << "tblgen_ops[0] = op0;\n";
+ emitMatchLogic(sourceTree);
+ os << "\n";
+
+ os.indent(4) << "// Rewrite\n";
+ emitRewriteLogic();
+
+ os.indent(4) << "return matchSuccess();\n";
+ os << " };\n";
+ os << "};\n";
+}
+
+void PatternEmitter::emitRewriteLogic() {
+ LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
+ const Operator &rootOp = pattern.getSourceRootOp();
+ int numExpectedResults = rootOp.getNumResults();
+ int numResultPatterns = pattern.getNumResultPatterns();
+
+ // First register all symbols bound to ops generated in result patterns.
+ pattern.collectResultPatternBoundSymbols(symbolInfoMap);
+
+ // Only the last N static values generated are used to replace the matched
+ // root N-result op. We need to calculate the starting index (of the results
+ // of the matched op) each result pattern is to replace.
+ SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
+ // If we don't need to replace any value at all, set the replacement starting
+ // index as the number of result patterns so we skip all of them when trying
+ // to replace the matched op's results.
+ int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
+ for (int i = numResultPatterns - 1; i >= 0; --i) {
+ auto numValues = getNodeValueCount(pattern.getResultPattern(i));
+ offsets[i] = offsets[i + 1] - numValues;
+ if (offsets[i] == 0) {
+ if (replStartIndex == -1)
+ replStartIndex = i;
+ } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
+ auto error = formatv(
+ "cannot use the same multi-result op '{0}' to generate both "
+ "auxiliary values and values to be used for replacing the matched op",
+ pattern.getResultPattern(i).getSymbol());
+ PrintFatalError(loc, error);
+ }
+ }
+
+ if (offsets.front() > 0) {
+ const char error[] = "no enough values generated to replace the matched op";
+ PrintFatalError(loc, error);
+ }
+
+ os.indent(4) << "auto loc = rewriter.getFusedLoc({";
+ for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
+ os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
+ }
+ os << "}); (void)loc;\n";
+
+ // Process auxiliary result patterns.
+ for (int i = 0; i < replStartIndex; ++i) {
+ DagNode resultTree = pattern.getResultPattern(i);
+ auto val = handleResultPattern(resultTree, offsets[i], 0);
+ // Normal op creation will be streamed to `os` by the above call; but
+ // NativeCodeCall will only be materialized to `os` if it is used. Here
+ // we are handling auxiliary patterns so we want the side effect even if
+ // NativeCodeCall is not replacing matched root op's results.
+ if (resultTree.isNativeCodeCall())
+ os.indent(4) << val << ";\n";
+ }
+
+ if (numExpectedResults == 0) {
+ assert(replStartIndex >= numResultPatterns &&
+ "invalid auxiliary vs. replacement pattern division!");
+ // No result to replace. Just erase the op.
+ os.indent(4) << "rewriter.eraseOp(op0);\n";
+ } else {
+ // Process replacement result patterns.
+ os.indent(4) << "SmallVector<Value, 4> tblgen_repl_values;\n";
+ for (int i = replStartIndex; i < numResultPatterns; ++i) {
+ DagNode resultTree = pattern.getResultPattern(i);
+ auto val = handleResultPattern(resultTree, offsets[i], 0);
+ os.indent(4) << "\n";
+ // Resolve each symbol for all range use so that we can loop over them.
+ os << symbolInfoMap.getAllRangeUse(
+ val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }",
+ "\n");
+ }
+ os.indent(4) << "\n";
+ os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
+}
+
+std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
+ return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
+}
+
+std::string PatternEmitter::handleResultPattern(DagNode resultTree,
+ int resultIndex, int depth) {
+ LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
+ LLVM_DEBUG(resultTree.print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+
+ if (resultTree.isNativeCodeCall()) {
+ auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+ symbolInfoMap.bindValue(symbol);
+ return symbol;
+ }
+
+ if (resultTree.isReplaceWithValue()) {
+ return handleReplaceWithValue(resultTree);
+ }
+
+ // Normal op creation.
+ auto symbol = handleOpCreation(resultTree, resultIndex, depth);
+ if (resultTree.getSymbol().empty()) {
+ // This is an op not explicitly bound to a symbol in the rewrite rule.
+ // Register the auto-generated symbol for it.
+ symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
+ }
+ return symbol;
+}
+
+std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
+ assert(tree.isReplaceWithValue());
+
+ if (tree.getNumArgs() != 1) {
+ PrintFatalError(
+ loc, "replaceWithValue directive must take exactly one argument");
+ }
+
+ if (!tree.getSymbol().empty()) {
+ PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
+ }
+
+ return tree.getArgName(0);
+}
+
+std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
+ StringRef patArgName) {
+ if (leaf.isConstantAttr()) {
+ auto constAttr = leaf.getAsConstantAttr();
+ return handleConstantAttr(constAttr.getAttribute(),
+ constAttr.getConstantValue());
+ }
+ if (leaf.isEnumAttrCase()) {
+ auto enumCase = leaf.getAsEnumAttrCase();
+ if (enumCase.isStrCase())
+ return handleConstantAttr(enumCase, enumCase.getSymbol());
+ // This is an enum case backed by an IntegerAttr. We need to get its value
+ // to build the constant.
+ std::string val = std::to_string(enumCase.getValue());
+ return handleConstantAttr(enumCase, val);
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
+ auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
+ if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
+ LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
+ << "' (via symbol ref)\n");
+ return argName;
+ }
+ if (leaf.isNativeCodeCall()) {
+ auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
+ LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
+ << "' (via NativeCodeCall)\n");
+ return repl;
+ }
+ PrintFatalError(loc, "unhandled case when rewriting op");
+}
+
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
+ LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
+ LLVM_DEBUG(tree.print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+
+ auto fmt = tree.getNativeCodeTemplate();
+ // TODO(b/138794486): replace formatv arguments with the exact specified args.
+ SmallVector<std::string, 8> attrs(8);
+ if (tree.getNumArgs() > 8) {
+ PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
+ Twine(tree.getNumArgs()));
+ }
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+ LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
+ << " replacement: " << attrs[i] << "\n");
+ }
+ return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
+ attrs[5], attrs[6], attrs[7]);
+}
+
+int PatternEmitter::getNodeValueCount(DagNode node) {
+ if (node.isOperation()) {
+ // If the op is bound to a symbol in the rewrite rule, query its result
+ // count from the symbol info map.
+ auto symbol = node.getSymbol();
+ if (!symbol.empty()) {
+ return symbolInfoMap.getStaticValueCount(symbol);
+ }
+ // Otherwise this is an unbound op; we will use all its results.
+ return pattern.getDialectOp(node).getNumResults();
+ }
+ // TODO(antiagainst): This considers all NativeCodeCall as returning one
+ // value. Enhance if multi-value ones are needed.
+ return 1;
+}
+
+std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
+ int depth) {
+ LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
+ LLVM_DEBUG(tree.print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+
+ Operator &resultOp = tree.getDialectOp(opMap);
+ auto numOpArgs = resultOp.getNumArgs();
+
+ if (numOpArgs != tree.getNumArgs()) {
+ PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ resultOp.getOperationName(), tree.getNumArgs(),
+ numOpArgs));
+ }
+
+ // A map to collect all nested DAG child nodes' names, with operand index as
+ // the key. This includes both bound and unbound child nodes.
+ ChildNodeIndexNameMap childNodeNames;
+
+ // First go through all the child nodes who are nested DAG constructs to
+ // create ops for them and remember the symbol names for them, so that we can
+ // use the results in the current node. This happens in a recursive manner.
+ for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
+ if (auto child = tree.getArgAsNestedDag(i)) {
+ childNodeNames[i] = handleResultPattern(child, i, depth + 1);
+ }
+ }
+
+ // The name of the local variable holding this op.
+ std::string valuePackName;
+ // The symbol for holding the result of this pattern. Note that the result of
+ // this pattern is not necessarily the same as the variable created by this
+ // pattern because we can use `__N` suffix to refer only a specific result if
+ // the generated op is a multi-result op.
+ std::string resultValue;
+ if (tree.getSymbol().empty()) {
+ // No symbol is explicitly bound to this op in the pattern. Generate a
+ // unique name.
+ valuePackName = resultValue = getUniqueSymbol(&resultOp);
+ } else {
+ resultValue = tree.getSymbol();
+ // Strip the index to get the name for the value pack and use it to name the
+ // local variable for the op.
+ valuePackName = SymbolInfoMap::getValuePackName(resultValue);
+ }
+
+ // Create the local variable for this op.
+ os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
+ valuePackName);
+ os.indent(4) << "{\n";
+
+ // Right now ODS don't have general type inference support. Except a few
+ // special cases listed below, DRR needs to supply types for all results
+ // when building an op.
+ bool isSameOperandsAndResultType =
+ resultOp.getTrait("OpTrait::SameOperandsAndResultType");
+ bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType");
+
+ if (isSameOperandsAndResultType || useFirstAttr) {
+ // We know how to deduce the result type for ops with these traits and we've
+ // generated builders taking aggregate parameters. Use those builders to
+ // create the ops.
+
+ // First prepare local variables for op arguments used in builder call.
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+ // Then create the op.
+ os.indent(6) << formatv(
+ "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
+ valuePackName, resultOp.getQualCppClassName());
+ os.indent(4) << "}\n";
+ return resultValue;
+ }
+
+ bool isBroadcastable =
+ resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
+ bool usePartialResults = valuePackName != resultValue;
+
+ if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
+ // For these cases (broadcastable ops, op results used both as auxiliary
+ // values and replacement values, ops in nested patterns, auxiliary ops), we
+ // still need to supply the result types when building the op. But because
+ // we don't generate a builder automatically with ODS for them, it's the
+ // developer's responsiblity to make sure such a builder (with result type
+ // deduction ability) exists. We go through the separate-parameter builder
+ // here given that it's easier for developers to write compared to
+ // aggregate-parameter builders.
+ createSeparateLocalVarsForOpArgs(tree, childNodeNames);
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
+ resultOp.getQualCppClassName());
+ supplyValuesForOpArgs(tree, childNodeNames);
+ os << "\n );\n";
+ os.indent(4) << "}\n";
+ return resultValue;
+ }
+
+ // If depth == 0 and resultIndex >= 0, it means we are replacing the values
+ // generated from the source pattern root op. Then we can use the source
+ // pattern's value types to determine the value type of the generated op
+ // here.
+
+ // First prepare local variables for op arguments used in builder call.
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+
+ // Then prepare the result types. We need to specify the types for all
+ // results.
+ os.indent(6) << formatv(
+ "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
+ int numResults = resultOp.getNumResults();
+ if (numResults != 0) {
+ for (int i = 0; i < numResults; ++i)
+ os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
+ "tblgen_types.push_back(v->getType()); }\n",
+ resultIndex + i);
+ }
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
+ "tblgen_values, tblgen_attrs);\n",
+ valuePackName, resultOp.getQualCppClassName());
+ os.indent(4) << "}\n";
+ return resultValue;
+}
+
+void PatternEmitter::createSeparateLocalVarsForOpArgs(
+ DagNode node, ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+
+ // Now prepare operands used for building this op:
+ // * If the operand is non-variadic, we create a `Value` local variable.
+ // * If the operand is variadic, we create a `SmallVector<Value>` local
+ // variable.
+
+ int valueIndex = 0; // An index for uniquing local variable names.
+ for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+ const auto *operand =
+ resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
+ if (!operand) {
+ // We do not need special handling for attributes.
+ continue;
+ }
+
+ std::string varName;
+ if (operand->isVariadic()) {
+ varName = formatv("tblgen_values_{0}", valueIndex++);
+ os.indent(6) << formatv("SmallVector<Value, 4> {0};\n", varName);
+ std::string range;
+ if (node.isNestedDagArg(argIndex)) {
+ range = childNodeNames[argIndex];
+ } else {
+ range = node.getArgName(argIndex);
+ }
+ // Resolve the symbol for all range use so that we have a uniform way of
+ // capturing the values.
+ range = symbolInfoMap.getValueAndRangeUse(range);
+ os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
+ varName);
+ } else {
+ varName = formatv("tblgen_value_{0}", valueIndex++);
+ os.indent(6) << formatv("Value {0} = ", varName);
+ if (node.isNestedDagArg(argIndex)) {
+ os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
+ } else {
+ DagLeaf leaf = node.getArgAsLeaf(argIndex);
+ auto symbol =
+ symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
+ if (leaf.isNativeCodeCall()) {
+ os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
+ } else {
+ os << symbol;
+ }
+ }
+ os << ";\n";
+ }
+
+ // Update to use the newly created local variable for building the op later.
+ childNodeNames[argIndex] = varName;
+ }
+}
+
+void PatternEmitter::supplyValuesForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+ for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
+ argIndex != numOpArgs; ++argIndex) {
+ // Start each argument on its own line.
+ (os << ",\n").indent(8);
+
+ Argument opArg = resultOp.getArg(argIndex);
+ // Handle the case of operand first.
+ if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (!operand->name.empty())
+ os << "/*" << operand->name << "=*/";
+ os << childNodeNames.lookup(argIndex);
+ continue;
+ }
+
+ // The argument in the op definition.
+ auto opArgName = resultOp.getArgName(argIndex);
+ if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+ if (!subTree.isNativeCodeCall())
+ PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+ "for creating attribute");
+ os << formatv("/*{0}=*/{1}", opArgName,
+ handleReplaceWithNativeCodeCall(subTree));
+ } else {
+ auto leaf = node.getArgAsLeaf(argIndex);
+ // The argument in the result DAG pattern.
+ auto patArgName = node.getArgName(argIndex);
+ if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
+ // TODO(jpienaar): Refactor out into map to avoid recomputing these.
+ if (!opArg.is<NamedAttribute *>())
+ PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
+ if (!patArgName.empty())
+ os << "/*" << patArgName << "=*/";
+ } else {
+ os << "/*" << opArgName << "=*/";
+ }
+ os << handleOpArgument(leaf, patArgName);
+ }
+ }
+}
+
+void PatternEmitter::createAggregateLocalVarsForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+
+ os.indent(6) << formatv(
+ "SmallVector<Value, 4> tblgen_values; (void)tblgen_values;\n");
+ os.indent(6) << formatv(
+ "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
+
+ for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+ if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
+ const char *addAttrCmd = "if ({1}) {{"
+ " tblgen_attrs.emplace_back(rewriter."
+ "getIdentifier(\"{0}\"), {1}); }\n";
+ // The argument in the op definition.
+ auto opArgName = resultOp.getArgName(argIndex);
+ if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+ if (!subTree.isNativeCodeCall())
+ PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+ "for creating attribute");
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleReplaceWithNativeCodeCall(subTree));
+ } else {
+ auto leaf = node.getArgAsLeaf(argIndex);
+ // The argument in the result DAG pattern.
+ auto patArgName = node.getArgName(argIndex);
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ }
+ continue;
+ }
+
+ const auto *operand =
+ resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
+ std::string varName;
+ if (operand->isVariadic()) {
+ std::string range;
+ if (node.isNestedDagArg(argIndex)) {
+ range = childNodeNames.lookup(argIndex);
+ } else {
+ range = node.getArgName(argIndex);
+ }
+ // Resolve the symbol for all range use so that we have a uniform way of
+ // capturing the values.
+ range = symbolInfoMap.getValueAndRangeUse(range);
+ os.indent(6) << formatv(
+ "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
+ } else {
+ os.indent(6) << formatv("tblgen_values.push_back(", varName);
+ if (node.isNestedDagArg(argIndex)) {
+ os << symbolInfoMap.getValueAndRangeUse(
+ childNodeNames.lookup(argIndex));
+ } else {
+ DagLeaf leaf = node.getArgAsLeaf(argIndex);
+ auto symbol =
+ symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
+ if (leaf.isNativeCodeCall()) {
+ os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
+ } else {
+ os << symbol;
+ }
+ }
+ os << ");\n";
+ }
+ }
+}
+
+static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ emitSourceFileHeader("Rewriters", os);
+
+ const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+ auto numPatterns = patterns.size();
+
+ // We put the map here because it can be shared among multiple patterns.
+ RecordOperatorMap recordOpMap;
+
+ std::vector<std::string> rewriterNames;
+ rewriterNames.reserve(numPatterns);
+
+ std::string baseRewriterName = "GeneratedConvert";
+ int rewriterIndex = 0;
+
+ for (Record *p : patterns) {
+ std::string name;
+ if (p->isAnonymous()) {
+ // If no name is provided, ensure unique rewriter names simply by
+ // appending unique suffix.
+ name = baseRewriterName + llvm::utostr(rewriterIndex++);
+ } else {
+ name = p->getName();
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "=== start generating pattern '" << name << "' ===\n");
+ PatternEmitter(p, &recordOpMap, os).emit(name);
+ LLVM_DEBUG(llvm::dbgs()
+ << "=== done generating pattern '" << name << "' ===\n");
+ rewriterNames.push_back(std::move(name));
+ }
+
+ // Emit function to add the generated matchers to the pattern list.
+ os << "void populateWithGenerated(MLIRContext *context, "
+ << "OwningRewritePatternList *patterns) {\n";
+ for (const auto &name : rewriterNames) {
+ os << " patterns->insert<" << name << ">(context);\n";
+ }
+ os << "}\n";
+}
+
+static mlir::GenRegistration
+ genRewriters("gen-rewriters", "Generate pattern rewriters",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ emitRewriters(records, os);
+ return false;
+ });
OpenPOWER on IntegriCloud