diff options
| author | Smit Hinsu <hinsu@google.com> | 2019-02-05 12:02:53 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:15:08 -0700 |
| commit | 2927297a1cc85dec2e19df339dac696d271aff59 (patch) | |
| tree | 7fd57a5ed70df0a3f633f96492fd1235e533fade | |
| parent | 40d5d09f9d52c581fd4419e8a54f4b952d904bb2 (diff) | |
| download | bcm5719-llvm-2927297a1cc85dec2e19df339dac696d271aff59.tar.gz bcm5719-llvm-2927297a1cc85dec2e19df339dac696d271aff59.zip | |
Add derived type attributes for TensorFlow ops generated by TableGen
Motivation for this change is to remove redundant TF type attributes for
TensorFlow ops. For example, tf$T: "tfdtype$DT_FLOAT". Type attributes can be derived using the MLIR operand or result MLIR types, attribute names and their mapping. This will also allow constant folding of instructions generated within MLIR (and not imported from TensorFlow) without adding type attributes for the instruction.
Derived attributes are populated while exporting MLIR to TF GraphDef using
auto-generated populators. Populators are only available for the ops that are generated by the TableGen.
Also, fixed Operator::getNumArgs method to exclude derived attributes as they are not
part of the arguments.
TESTED with unit test
PiperOrigin-RevId: 232531561
| -rw-r--r-- | mlir/include/mlir/TableGen/Operator.h | 17 | ||||
| -rw-r--r-- | mlir/lib/TableGen/Operator.cpp | 16 |
2 files changed, 25 insertions, 8 deletions
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 502e39b4b3c..a75b909a9d5 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -29,6 +29,7 @@ #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" namespace llvm { class CodeInit; @@ -54,6 +55,9 @@ public: // Returns the TableGen definition name split around '_'. const SmallVectorImpl<StringRef> &getSplitDefName() const; + // Returns dialect name of the op. + StringRef getDialectName() const; + // Returns the C++ class name of the op. StringRef getCppClassName() const; @@ -69,15 +73,16 @@ public: StringRef getResultName(int index) const; // Op attribute interators. - using attribute_iterator = NamedAttribute *; - attribute_iterator attribute_begin(); - attribute_iterator attribute_end(); - llvm::iterator_range<attribute_iterator> getAttributes(); + using attribute_iterator = const NamedAttribute *; + attribute_iterator attribute_begin() const; + attribute_iterator attribute_end() const; + llvm::iterator_range<attribute_iterator> getAttributes() const; // Op attribute accessors. int getNumAttributes() const { return attributes.size(); } // Returns the total number of native attributes. int getNumNativeAttributes() const; + int getNumDerivedAttributes() const; NamedAttribute &getAttribute(int index) { return attributes[index]; } const NamedAttribute &getAttribute(int index) const; @@ -96,7 +101,9 @@ public: Argument getArg(int index); StringRef getArgName(int index) const; // Returns the total number of arguments. - int getNumArgs() const { return operands.size() + attributes.size(); } + int getNumArgs() const { return getNumOperands() + getNumNativeAttributes(); } + + ArrayRef<llvm::SMLoc> getLoc() const; // Query functions for the documentation of the operator. bool hasDescription() const; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index f5435ef4adb..21d855a4b18 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -46,6 +46,10 @@ StringRef tblgen::Operator::getOperationName() const { return def.getValueAsString("opName"); } +StringRef tblgen::Operator::getDialectName() const { + return getSplitDefName().front(); +} + StringRef tblgen::Operator::getCppClassName() const { return getSplitDefName().back(); } @@ -72,6 +76,10 @@ int tblgen::Operator::getNumNativeAttributes() const { return derivedAttrStart - nativeAttrStart; } +int tblgen::Operator::getNumDerivedAttributes() const { + return getNumAttributes() - getNumNativeAttributes(); +} + const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const { return attributes[index]; } @@ -81,13 +89,13 @@ StringRef tblgen::Operator::getArgName(int index) const { return argumentValues->getArgName(index)->getValue(); } -auto tblgen::Operator::attribute_begin() -> attribute_iterator { +auto tblgen::Operator::attribute_begin() const -> attribute_iterator { return attributes.begin(); } -auto tblgen::Operator::attribute_end() -> attribute_iterator { +auto tblgen::Operator::attribute_end() const -> attribute_iterator { return attributes.end(); } -auto tblgen::Operator::getAttributes() +auto tblgen::Operator::getAttributes() const -> llvm::iterator_range<attribute_iterator> { return {attribute_begin(), attribute_end()}; } @@ -173,6 +181,8 @@ void tblgen::Operator::populateOperandsAndAttributes() { } } +ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); } + bool tblgen::Operator::hasDescription() const { return def.getValue("description") != nullptr; } |

