summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSmit Hinsu <hinsu@google.com>2019-02-05 12:02:53 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:15:08 -0700
commit2927297a1cc85dec2e19df339dac696d271aff59 (patch)
tree7fd57a5ed70df0a3f633f96492fd1235e533fade
parent40d5d09f9d52c581fd4419e8a54f4b952d904bb2 (diff)
downloadbcm5719-llvm-2927297a1cc85dec2e19df339dac696d271aff59.tar.gz
bcm5719-llvm-2927297a1cc85dec2e19df339dac696d271aff59.zip
Add derived type attributes for TensorFlow ops generated by TableGen
Motivation for this change is to remove redundant TF type attributes for TensorFlow ops. For example, tf$T: "tfdtype$DT_FLOAT". Type attributes can be derived using the MLIR operand or result MLIR types, attribute names and their mapping. This will also allow constant folding of instructions generated within MLIR (and not imported from TensorFlow) without adding type attributes for the instruction. Derived attributes are populated while exporting MLIR to TF GraphDef using auto-generated populators. Populators are only available for the ops that are generated by the TableGen. Also, fixed Operator::getNumArgs method to exclude derived attributes as they are not part of the arguments. TESTED with unit test PiperOrigin-RevId: 232531561
-rw-r--r--mlir/include/mlir/TableGen/Operator.h17
-rw-r--r--mlir/lib/TableGen/Operator.cpp16
2 files changed, 25 insertions, 8 deletions
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 502e39b4b3c..a75b909a9d5 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -29,6 +29,7 @@
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
namespace llvm {
class CodeInit;
@@ -54,6 +55,9 @@ public:
// Returns the TableGen definition name split around '_'.
const SmallVectorImpl<StringRef> &getSplitDefName() const;
+ // Returns dialect name of the op.
+ StringRef getDialectName() const;
+
// Returns the C++ class name of the op.
StringRef getCppClassName() const;
@@ -69,15 +73,16 @@ public:
StringRef getResultName(int index) const;
// Op attribute interators.
- using attribute_iterator = NamedAttribute *;
- attribute_iterator attribute_begin();
- attribute_iterator attribute_end();
- llvm::iterator_range<attribute_iterator> getAttributes();
+ using attribute_iterator = const NamedAttribute *;
+ attribute_iterator attribute_begin() const;
+ attribute_iterator attribute_end() const;
+ llvm::iterator_range<attribute_iterator> getAttributes() const;
// Op attribute accessors.
int getNumAttributes() const { return attributes.size(); }
// Returns the total number of native attributes.
int getNumNativeAttributes() const;
+ int getNumDerivedAttributes() const;
NamedAttribute &getAttribute(int index) { return attributes[index]; }
const NamedAttribute &getAttribute(int index) const;
@@ -96,7 +101,9 @@ public:
Argument getArg(int index);
StringRef getArgName(int index) const;
// Returns the total number of arguments.
- int getNumArgs() const { return operands.size() + attributes.size(); }
+ int getNumArgs() const { return getNumOperands() + getNumNativeAttributes(); }
+
+ ArrayRef<llvm::SMLoc> getLoc() const;
// Query functions for the documentation of the operator.
bool hasDescription() const;
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index f5435ef4adb..21d855a4b18 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -46,6 +46,10 @@ StringRef tblgen::Operator::getOperationName() const {
return def.getValueAsString("opName");
}
+StringRef tblgen::Operator::getDialectName() const {
+ return getSplitDefName().front();
+}
+
StringRef tblgen::Operator::getCppClassName() const {
return getSplitDefName().back();
}
@@ -72,6 +76,10 @@ int tblgen::Operator::getNumNativeAttributes() const {
return derivedAttrStart - nativeAttrStart;
}
+int tblgen::Operator::getNumDerivedAttributes() const {
+ return getNumAttributes() - getNumNativeAttributes();
+}
+
const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
return attributes[index];
}
@@ -81,13 +89,13 @@ StringRef tblgen::Operator::getArgName(int index) const {
return argumentValues->getArgName(index)->getValue();
}
-auto tblgen::Operator::attribute_begin() -> attribute_iterator {
+auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
return attributes.begin();
}
-auto tblgen::Operator::attribute_end() -> attribute_iterator {
+auto tblgen::Operator::attribute_end() const -> attribute_iterator {
return attributes.end();
}
-auto tblgen::Operator::getAttributes()
+auto tblgen::Operator::getAttributes() const
-> llvm::iterator_range<attribute_iterator> {
return {attribute_begin(), attribute_end()};
}
@@ -173,6 +181,8 @@ void tblgen::Operator::populateOperandsAndAttributes() {
}
}
+ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
+
bool tblgen::Operator::hasDescription() const {
return def.getValue("description") != nullptr;
}
OpenPOWER on IntegriCloud