diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/bindings/python/pybind.cpp | 22 | ||||
| -rw-r--r-- | mlir/bindings/python/test/test_py2and3.py | 11 | ||||
| -rw-r--r-- | mlir/include/mlir-c/Core.h | 6 | ||||
| -rw-r--r-- | mlir/lib/EDSC/CoreAPIs.cpp | 12 |
4 files changed, 49 insertions, 2 deletions
diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index eb9984feb1c..825f800c0bd 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -193,12 +193,18 @@ struct PythonMLIRModule { const py::list &arguments, const py::list &successors, py::kwargs attributes); - // Create an integer attribute. + // Creates an integer attribute. PythonAttribute integerAttr(PythonType type, int64_t value); - // Create a boolean attribute. + // Creates a boolean attribute. PythonAttribute boolAttr(bool value); + // Creates a float attribute. + PythonAttribute floatAttr(float value); + + // Creates a string atrribute. + PythonAttribute stringAttr(const std::string &value); + // Creates an Array attribute. PythonAttribute arrayAttr(const std::vector<PythonAttribute> &values); @@ -713,6 +719,14 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) { return PythonAttribute(::makeBoolAttr(&mlirContext, value)); } +PythonAttribute PythonMLIRModule::floatAttr(float value) { + return PythonAttribute(::makeFloatAttr(&mlirContext, value)); +} + +PythonAttribute PythonMLIRModule::stringAttr(const std::string &value) { + return PythonAttribute(::makeStringAttr(&mlirContext, value.c_str())); +} + PythonAttribute PythonMLIRModule::arrayAttr(const std::vector<PythonAttribute> &values) { std::vector<mlir::Attribute> mlir_attributes(values.begin(), values.end()); @@ -910,6 +924,10 @@ PYBIND11_MODULE(pybind, m) { "integerAttr", &PythonMLIRModule::integerAttr, "Creates an mlir::IntegerAttr of the given type with the given value " "in the context associated with this MLIR module.") + .def("floatAttr", &PythonMLIRModule::floatAttr, + "Creates an mlir::FloatAttr with the given value") + .def("stringAttr", &PythonMLIRModule::stringAttr, + "Creates an mlir::StringAttr with the given value") .def("arrayAttr", &PythonMLIRModule::arrayAttr, "Creates an mlir::ArrayAttr of the given type with the given values " "in the context associated with this MLIR module.") diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index 02ff4ab3061..02f8f628046 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -347,6 +347,17 @@ class EdscTest: # CHECK-LABEL: testFunctionDeclarationWithArrayAttr # CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]}) + def testFunctionDeclarationWithFloatAndStringAttr(self): + self.setUp() + float_attr = self.module.floatAttr(23.3) + string_attr = self.module.stringAttr("TEST_STRING") + + f = self.module.declare_function( + "foo", [], [], float_attr=float_attr, string_attr=string_attr) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testFunctionDeclarationWithFloatAndStringAttr + # CHECK: func @foo() attributes {float_attr = 2.330000e+01 : f32, string_attr = "TEST_STRING"} + def testFunctionMultiple(self): self.setUp() with self.module.function_context("foo", [], []): diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 857d42ecf7a..c205e898901 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -87,6 +87,12 @@ mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value); /// Returns an `mlir::BoolAttr` with the given value. mlir_attr_t makeBoolAttr(mlir_context_t context, bool value); +/// Returns an `mlir::FloatAttr` with the given value. +mlir_attr_t makeFloatAttr(mlir_context_t context, float value); + +/// Returns an `mlir::StringAttr` with the given value. +mlir_attr_t makeStringAttr(mlir_context_t context, const char *value); + /// Parses an MLIR type from the string `type` in the given context. Returns a /// NULL type on error. If non-NULL, `charsRead` will contain the number of /// characters that were processed by the parser. diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index ab935742b8d..b88a1fdf4ef 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -74,6 +74,18 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) { return mlir_attr_t{attr.getAsOpaquePointer()}; } +mlir_attr_t makeFloatAttr(mlir_context_t context, float value) { + auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context); + auto attr = FloatAttr::get(FloatType::getF32(ctx), APFloat(value)); + return mlir_attr_t{attr.getAsOpaquePointer()}; +} + +mlir_attr_t makeStringAttr(mlir_context_t context, const char *value) { + auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context); + auto attr = StringAttr::get(value, ctx); + return mlir_attr_t{attr.getAsOpaquePointer()}; +} + unsigned getFunctionArity(mlir_func_t function) { auto f = mlir::FuncOp::getFromOpaquePointer(function); return f.getNumArguments(); |

