summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-02-01 15:40:22 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:06:31 -0700
commite0774c008fdcee1d4007ede4fde4cf7ad83cfda8 (patch)
tree8c9fc6e4951157a6cb7d76392f417d1bbe53a799
parent70e3873e86ac578b86e29637a14e44c334be4944 (diff)
downloadbcm5719-llvm-e0774c008fdcee1d4007ede4fde4cf7ad83cfda8.tar.gz
bcm5719-llvm-e0774c008fdcee1d4007ede4fde4cf7ad83cfda8.zip
[TableGen] Use tblgen::DagLeaf to model DAG arguments
This CL added a tblgen::DagLeaf wrapper class with several helper methods for handling DAG arguments. It helps to refactor the rewriter generation logic to be more higher level. This CL also added a tblgen::ConstantAttr wrapper class for constant attributes. PiperOrigin-RevId: 232050683
-rw-r--r--mlir/include/mlir/TableGen/Attribute.h18
-rw-r--r--mlir/include/mlir/TableGen/Pattern.h82
-rw-r--r--mlir/include/mlir/TableGen/Type.h1
-rw-r--r--mlir/lib/TableGen/Attribute.cpp14
-rw-r--r--mlir/lib/TableGen/Pattern.cpp72
-rw-r--r--mlir/test/mlir-tblgen/one-op-one-result.td17
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp232
7 files changed, 292 insertions, 144 deletions
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index b126617f289..e601fdf22ea 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -111,6 +111,24 @@ public:
StringRef getDerivedCodeBody() const;
};
+// Wrapper class providing helper methods for accessing MLIR constant attribute
+// defined in TableGen. This class should closely reflect what is defined as
+// class `ConstantAttr` in TableGen.
+class ConstantAttr {
+public:
+ explicit ConstantAttr(const llvm::DefInit *init);
+
+ // Returns the attribute kind.
+ Attribute getAttribute() const;
+
+ // Returns the constant value.
+ StringRef getConstantValue() const;
+
+private:
+ // The TableGen definition of this constant attribute.
+ const llvm::Record *def;
+};
+
} // end namespace tblgen
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 6544316313d..80a38329a33 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -23,6 +23,7 @@
#ifndef MLIR_TABLEGEN_PATTERN_H_
#define MLIR_TABLEGEN_PATTERN_H_
+#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
@@ -30,9 +31,9 @@
#include "llvm/TableGen/Error.h"
namespace llvm {
-class Record;
-class Init;
class DagInit;
+class Init;
+class Record;
class StringRef;
} // end namespace llvm
@@ -42,19 +43,61 @@ namespace tblgen {
// Mapping from TableGen Record to Operator wrapper object
using RecordOperatorMap = llvm::DenseMap<const llvm::Record *, Operator>;
-// Wrapper around DAG argument.
-struct DagArg {
- DagArg(Argument arg, llvm::Init *constraint)
- : arg(arg), constraint(constraint) {}
+class Pattern;
- // Returns true if this DAG argument concerns an operation attribute.
- bool isAttr() const;
+// Wrapper class providing helper methods for accessing TableGen DAG leaves
+// used inside Patterns. This class is lightweight and designed to be used like
+// values.
+//
+// A TableGen DAG construct is of the syntax
+// `(operator, arg0, arg1, ...)`.
+//
+// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
+// for handy helper methods. It only works on `arg*`s that are not nested DAG
+// constructs.
+class DagLeaf {
+public:
+ explicit DagLeaf(const llvm::Init *def) : def(def) {}
- Argument arg;
- llvm::Init *constraint;
-};
+ // Returns true if this DAG leaf is not specified in the pattern. That is, it
+ // places no further constraints/transforms and just carries over the original
+ // value.
+ bool isUnspecified() const;
-class Pattern;
+ // Returns true if this DAG leaf is matching an operand. That is, it specifies
+ // a type constraint.
+ bool isOperandMatcher() const;
+
+ // Returns true if this DAG leaf is matching an attribute. That is, it
+ // specifies an attribute constraint.
+ bool isAttrMatcher() const;
+
+ // Returns true if this DAG leaf is transforming an attribute.
+ bool isAttrTransformer() const;
+
+ // Returns true if this DAG leaf is specifying a constant attribute.
+ bool isConstantAttr() const;
+
+ // Returns this DAG leaf as a type constraint. Asserts if fails.
+ TypeConstraint getAsTypeConstraint() const;
+
+ // Returns this DAG leaf as an attribute constraint. Asserts if fails.
+ AttrConstraint getAsAttrConstraint() const;
+
+ // Returns this DAG leaf as an constant attribute. Asserts if fails.
+ ConstantAttr getAsConstantAttr() const;
+
+ // Returns the matching condition template inside this DAG leaf. Assumes the
+ // leaf is an operand/attribute matcher and asserts otherwise.
+ std::string getConditionTemplate() const;
+
+ // Returns the transformation template inside this DAG leaf. Assumes the
+ // leaf is an attribute matcher and asserts otherwise.
+ std::string getTransformationTemplate() const;
+
+private:
+ const llvm::Init *def;
+};
// Wrapper class providing helper methods for accessing TableGen DAG constructs
// used inside Patterns. This class is lightweight and designed to be used like
@@ -96,10 +139,9 @@ public:
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
// null DagNode otherwise.
DagNode getArgAsNestedDag(unsigned index) const;
- // Gets the `index`-th argument as a TableGen DefInit* if possible. Returns
- // nullptr otherwise.
- // TODO: This method is exposing raw TableGen object and should be changed.
- llvm::DefInit *getArgAsDefInit(unsigned index) const;
+
+ // Gets the `index`-th argument as a DAG leaf.
+ DagLeaf getArgAsLeaf(unsigned index) const;
// Returns the specified name of the `index`-th argument.
llvm::StringRef getArgName(unsigned index) const;
@@ -146,7 +188,7 @@ public:
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
// Returns a reference to all the bound arguments in the source pattern.
- llvm::StringMap<DagArg> &getSourcePatternBoundArgs();
+ llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@@ -159,8 +201,10 @@ private:
// The TableGen definition of this pattern.
const llvm::Record &def;
- RecordOperatorMap *recordOpMap; // All operators
- llvm::StringMap<DagArg> boundArguments; // All bound arguments
+ // All operators
+ RecordOperatorMap *recordOpMap;
+ // All bound arguments
+ llvm::StringMap<Argument> boundArguments;
};
} // end namespace tblgen
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index fd91ff4dc2c..247e0fc8e4b 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -42,6 +42,7 @@ public:
explicit TypeConstraint(const llvm::DefInit &init);
bool operator==(const TypeConstraint &that) { return def == that.def; }
+ bool operator!=(const TypeConstraint &that) { return def != that.def; }
// Returns the predicate that can be used to check if a type satisfies this
// type constraint.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 1e9c37cac9a..2b8cda031ef 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -133,3 +133,17 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body");
}
+
+tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init)
+ : def(init->getDef()) {
+ assert(def->isSubClassOf("ConstantAttr") &&
+ "must be subclass of TableGen 'ConstantAttr' class");
+}
+
+tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
+ return Attribute(def->getValueAsDef("attr"));
+}
+
+StringRef tblgen::ConstantAttr::getConstantValue() const {
+ return def->getValueAsString("value");
+}
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index f9bb4a4b08c..5262141e753 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -28,8 +28,66 @@ using namespace mlir;
using mlir::tblgen::Operator;
-bool tblgen::DagArg::isAttr() const {
- return arg.is<tblgen::NamedAttribute *>();
+bool tblgen::DagLeaf::isUnspecified() const {
+ return !def || isa<llvm::UnsetInit>(def);
+}
+
+bool tblgen::DagLeaf::isOperandMatcher() const {
+ if (!def || !isa<llvm::DefInit>(def))
+ return false;
+ // Operand matchers specify a type constraint.
+ return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
+}
+
+bool tblgen::DagLeaf::isAttrMatcher() const {
+ if (!def || !isa<llvm::DefInit>(def))
+ return false;
+ // Attribute matchers specify a type constraint.
+ return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
+}
+
+bool tblgen::DagLeaf::isAttrTransformer() const {
+ if (!def || !isa<llvm::DefInit>(def))
+ return false;
+ return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
+}
+
+bool tblgen::DagLeaf::isConstantAttr() const {
+ if (!def || !isa<llvm::DefInit>(def))
+ return false;
+ return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
+}
+
+tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const {
+ assert(isOperandMatcher() && "the DAG leaf must be operand");
+ return TypeConstraint(*cast<llvm::DefInit>(def)->getDef());
+}
+
+tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const {
+ assert(isAttrMatcher() && "the DAG leaf must be attribute");
+ return AttrConstraint(cast<llvm::DefInit>(def)->getDef());
+}
+
+tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
+ assert(isConstantAttr() && "the DAG leaf must be constant attribute");
+ return ConstantAttr(cast<llvm::DefInit>(def));
+}
+
+std::string tblgen::DagLeaf::getConditionTemplate() const {
+ assert((isOperandMatcher() || isAttrMatcher()) &&
+ "the DAG leaf must be operand/attribute matcher");
+ if (isOperandMatcher()) {
+ return getAsTypeConstraint().getConditionTemplate();
+ }
+ return getAsAttrConstraint().getConditionTemplate();
+}
+
+std::string tblgen::DagLeaf::getTransformationTemplate() const {
+ assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
+ return cast<llvm::DefInit>(def)
+ ->getDef()
+ ->getValueAsString("attrTransform")
+ .str();
}
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
@@ -56,8 +114,9 @@ tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
}
-llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const {
- return dyn_cast<llvm::DefInit>(node->getArg(index));
+tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
+ assert(!isNestedDagArg(index));
+ return DagLeaf(node->getArg(index));
}
StringRef tblgen::DagNode::getArgName(unsigned index) const {
@@ -81,7 +140,7 @@ static void collectBoundArguments(const llvm::DagInit *tree,
if (name.empty())
continue;
- pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg);
+ pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i));
}
}
@@ -131,7 +190,8 @@ void tblgen::Pattern::ensureArgBoundInSourcePattern(
Twine("referencing unbound variable '") + name + "'");
}
-llvm::StringMap<tblgen::DagArg> &tblgen::Pattern::getSourcePatternBoundArgs() {
+llvm::StringMap<tblgen::Argument> &
+tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
}
diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td
index 45056b154ed..3bdd9aa1b96 100644
--- a/mlir/test/mlir-tblgen/one-op-one-result.td
+++ b/mlir/test/mlir-tblgen/one-op-one-result.td
@@ -3,24 +3,21 @@
include "mlir/IR/op_base.td"
// Create a Type and Attribute.
-def YT : BuildableType<"buildYT">;
-def Y_Attr : TypeBasedAttr<YT, "Attribute", "attribute of Y type">;
-def Y_Const_Attr {
- Attr attr = Y_Attr;
- string value = "attrValue";
-}
+def T : BuildableType<"buildT">;
+def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
+def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
// Define ops to rewrite.
-def T1: Type<CPred<"true">, "T1">;
+def U: Type<CPred<"true">, "U">;
def X_AddOp : Op<"x.add"> {
- let arguments = (ins T1, T1);
+ let arguments = (ins U, U);
}
def Y_AddOp : Op<"y.add"> {
- let arguments = (ins T1, T1, Y_Attr:$attrName);
+ let arguments = (ins U, U, T_Attr:$attrName);
}
// Define rewrite pattern.
-def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>;
+def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
// CHECK: struct GeneratedConvert0 : public RewritePattern
// CHECK: RewritePattern("x.add", 1, context)
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6b62da7eb04..7ca663071d8 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -39,31 +39,11 @@
using namespace llvm;
using namespace mlir;
-using mlir::tblgen::Argument;
-using mlir::tblgen::Attribute;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::Operand;
using mlir::tblgen::Operator;
-using mlir::tblgen::Pattern;
using mlir::tblgen::RecordOperatorMap;
-using mlir::tblgen::Type;
-
-namespace {
-
-// Wrapper around DAG argument.
-struct DagArg {
- DagArg(Argument arg, Init *constraintInit)
- : arg(arg), constraintInit(constraintInit) {}
- bool isAttr();
-
- Argument arg;
- Init *constraintInit;
-};
-
-} // end namespace
-
-bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
namespace {
class PatternEmitter {
@@ -93,12 +73,19 @@ private:
void emitReplaceWithNativeBuilder(DagNode resultTree);
// Emits the value of constant attribute to `os`.
- void emitAttributeValue(Record *constAttr);
+ void emitConstantAttr(tblgen::ConstantAttr constAttr);
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`.
void emitOpMatch(DagNode tree, int depth);
+ // Emits C++ statements for matching the `index`-th argument of the given DAG
+ // `tree` as an operand.
+ void emitOperandMatch(DagNode tree, int index, int depth, int indent);
+ // Emits C++ statements for matching the `index`-th argument of the given DAG
+ // `tree` as an attribute.
+ void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
+
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@@ -107,14 +94,13 @@ private:
// Op's TableGen Record to wrapper object
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
- Pattern pattern;
+ tblgen::Pattern pattern;
raw_ostream &os;
};
} // end namespace
-void PatternEmitter::emitAttributeValue(Record *constAttr) {
- Attribute attr(constAttr->getValueAsDef("attr"));
- auto value = constAttr->getValue("value");
+void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
+ auto attr = constAttr.getAttribute();
if (!attr.isConstBuildable())
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
@@ -122,7 +108,7 @@ void PatternEmitter::emitAttributeValue(Record *constAttr) {
// TODO(jpienaar): Verify the constants here
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
- value->getValue()->getAsUnquotedString());
+ constAttr.getConstantValue());
}
// Helper function to match patterns.
@@ -137,13 +123,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
op.getQualCppClassName());
}
- if (tree.getNumArgs() != op.getNumArgs())
- PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
- op.getOperationName() +
- "' in pattern and op's definition");
+ if (tree.getNumArgs() != op.getNumArgs()) {
+ PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
+ "pattern vs. {2} in definition",
+ op.getOperationName(), tree.getNumArgs(),
+ op.getNumArgs()));
+ }
+
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
+ // Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
@@ -154,50 +144,78 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
continue;
}
- // Verify arguments.
- if (auto defInit = tree.getArgAsDefInit(i)) {
- // Verify operands.
- if (auto *operand = opArg.dyn_cast<Operand *>()) {
- // Skip verification where not needed due to definition of op.
- if (operand->type == Type(defInit))
- goto StateCapture;
-
- if (!defInit->getDef()->isSubClassOf("Type"))
- PrintFatalError(loc, "type argument required for operand");
-
- auto constraint = tblgen::TypeConstraint(*defInit);
- os.indent(indent)
- << "if (!("
- << formatv(constraint.getConditionTemplate().c_str(),
- formatv("op{0}->getOperand({1})->getType()", depth, i))
- << ")) return matchFailure();\n";
- }
-
- // TODO(jpienaar): Verify attributes.
- if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
- auto constraint = tblgen::AttrConstraint(defInit);
- std::string condition = formatv(
- constraint.getConditionTemplate().c_str(),
- formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
- namedAttr->attr.getStorageType(), namedAttr->getName()));
- os.indent(indent) << "if (!(" << condition
- << ")) return matchFailure();\n";
- }
+ // Next handle DAG leaf: operand or attribute
+ if (auto *operand = opArg.dyn_cast<Operand *>()) {
+ emitOperandMatch(tree, i, depth, indent);
+ } else if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
+ emitAttributeMatch(tree, i, depth, indent);
+ } else {
+ PrintFatalError(loc, "unhandled case when matching op");
}
+ }
+}
- StateCapture:
- auto name = tree.getArgName(i);
- if (name.empty())
- continue;
- if (opArg.is<Operand *>())
- os.indent(indent) << "state->" << name << " = op" << depth
- << "->getOperand(" << i << ");\n";
- if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
- os.indent(indent) << "state->" << name << " = op" << depth
- << "->getAttrOfType<"
- << namedAttr->attr.getStorageType() << ">(\""
- << namedAttr->getName() << "\");\n";
+void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
+ int indent) {
+ Operator &op = tree.getDialectOp(opMap);
+ auto *operand = op.getArg(index).get<Operand *>();
+ auto matcher = tree.getArgAsLeaf(index);
+
+ // If a constraint is specified, we need to generate C++ statements to
+ // check the constraint.
+ if (!matcher.isUnspecified()) {
+ if (!matcher.isOperandMatcher()) {
+ PrintFatalError(
+ loc, formatv("the {1}-th argument of op '{0}' should be an operand",
+ op.getOperationName(), index + 1));
}
+
+ // Only need to verify if the matcher's type is different from the one
+ // of op definition.
+ if (static_cast<tblgen::TypeConstraint>(operand->type) !=
+ matcher.getAsTypeConstraint()) {
+ os.indent(indent) << "if (!("
+ << formatv(matcher.getConditionTemplate().c_str(),
+ formatv("op{0}->getOperand({1})->getType()",
+ depth, index))
+ << ")) return matchFailure();\n";
+ }
+ }
+
+ // Capture the value
+ auto name = tree.getArgName(index);
+ if (!name.empty()) {
+ os.indent(indent) << "state->" << name << " = op" << depth
+ << "->getOperand(" << index << ");\n";
+ }
+}
+
+void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
+ int indent) {
+ Operator &op = tree.getDialectOp(opMap);
+ auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
+ auto matcher = tree.getArgAsLeaf(index);
+
+ if (!matcher.isUnspecified() && !matcher.isAttrMatcher()) {
+ PrintFatalError(
+ loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
+ op.getOperationName(), index + 1));
+ }
+
+ // If a constraint is specified, we need to generate C++ statements to
+ // check the constraint.
+ std::string condition =
+ formatv(matcher.getConditionTemplate().c_str(),
+ formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
+ namedAttr->attr.getStorageType(), namedAttr->getName()));
+ os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n";
+
+ // Capture the value
+ auto name = tree.getArgName(index);
+ if (!name.empty()) {
+ os.indent(indent) << "state->" << name << " = op" << depth
+ << "->getAttrOfType<" << namedAttr->attr.getStorageType()
+ << ">(\"" << namedAttr->getName() << "\");\n";
}
}
@@ -234,11 +252,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
- if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
- os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
+ auto fieldName = arg.first();
+ if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
+ os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
<< ";\n";
} else {
- os.indent(4) << "Value* " << arg.first() << ";\n";
+ os.indent(4) << "Value* " << fieldName << ";\n";
}
}
os << " };\n";
@@ -285,10 +304,10 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.getCppClassName());
if (numOpArgs != resultTree.getNumArgs()) {
- PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
- Twine(numOpArgs) +
- ") and arguments provided for rewrite (" +
- Twine(resultTree.getNumArgs()) + Twine(')'));
+ PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ resultOp.getOperationName(),
+ resultTree.getNumArgs(), numOpArgs));
}
// Create the builder call for the result.
@@ -312,38 +331,33 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
+ auto leaf = resultTree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
- auto argName = resultTree.getArgName(i);
- auto opName = resultOp.getArgName(i);
- auto *defInit = resultTree.getArgAsDefInit(i);
- auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
- if (!value) {
- pattern.ensureArgBoundInSourcePattern(argName);
- auto result = "s." + argName;
- os << "/*" << opName << "=*/";
- if (defInit) {
- auto transform = defInit->getDef();
- if (transform->isSubClassOf("tAttr")) {
- // TODO(jpienaar): move to helper class.
- os << formatv(
- transform->getValueAsString("attrTransform").str().c_str(),
- result);
- continue;
- }
- }
- os << result;
- continue;
+ auto patArgName = resultTree.getArgName(i);
+ // The argument in the op definition.
+ auto opArgName = resultOp.getArgName(i);
+
+ if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
+ pattern.ensureArgBoundInSourcePattern(patArgName);
+ os << formatv("/*{0}=*/s.{1}", opArgName, patArgName);
+ } else if (leaf.isAttrTransformer()) {
+ pattern.ensureArgBoundInSourcePattern(patArgName);
+ std::string result = std::string("s.") + patArgName.str();
+ result = formatv(leaf.getTransformationTemplate().c_str(), result);
+ os << formatv("/*{0}=*/{1}", opArgName, result);
+ } else if (leaf.isConstantAttr()) {
+ // TODO(jpienaar): Refactor out into map to avoid recomputing these.
+ auto argument = resultOp.getArg(i);
+ if (!argument.is<NamedAttribute *>())
+ PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
+
+ if (!patArgName.empty())
+ os << "/*" << patArgName << "=*/";
+ emitConstantAttr(leaf.getAsConstantAttr());
+ // TODO(jpienaar): verify types
+ } else {
+ PrintFatalError(loc, "unhandled case when rewriting op");
}
-
- // TODO(jpienaar): Refactor out into map to avoid recomputing these.
- auto argument = resultOp.getArg(i);
- if (!argument.is<NamedAttribute *>())
- PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
-
- if (!argName.empty())
- os << "/*" << argName << "=*/";
- emitAttributeValue(defInit->getDef());
- // TODO(jpienaar): verify types
}
os << "\n );\n";
}
@@ -367,7 +381,7 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
auto name = resultTree.getArgName(i);
pattern.ensureArgBoundInSourcePattern(name);
const auto &val = boundedValues.find(name);
- if (val->second.isAttr() && !printingAttr) {
+ if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
os << "}, {";
first = true;
printingAttr = true;
OpenPOWER on IntegriCloud