summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/bindings/python/pybind.cpp22
-rw-r--r--mlir/bindings/python/test/test_py2and3.py11
-rw-r--r--mlir/include/mlir-c/Core.h6
-rw-r--r--mlir/lib/EDSC/CoreAPIs.cpp12
4 files changed, 49 insertions, 2 deletions
diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp
index eb9984feb1c..825f800c0bd 100644
--- a/mlir/bindings/python/pybind.cpp
+++ b/mlir/bindings/python/pybind.cpp
@@ -193,12 +193,18 @@ struct PythonMLIRModule {
const py::list &arguments, const py::list &successors,
py::kwargs attributes);
- // Create an integer attribute.
+ // Creates an integer attribute.
PythonAttribute integerAttr(PythonType type, int64_t value);
- // Create a boolean attribute.
+ // Creates a boolean attribute.
PythonAttribute boolAttr(bool value);
+ // Creates a float attribute.
+ PythonAttribute floatAttr(float value);
+
+ // Creates a string atrribute.
+ PythonAttribute stringAttr(const std::string &value);
+
// Creates an Array attribute.
PythonAttribute arrayAttr(const std::vector<PythonAttribute> &values);
@@ -713,6 +719,14 @@ PythonAttribute PythonMLIRModule::boolAttr(bool value) {
return PythonAttribute(::makeBoolAttr(&mlirContext, value));
}
+PythonAttribute PythonMLIRModule::floatAttr(float value) {
+ return PythonAttribute(::makeFloatAttr(&mlirContext, value));
+}
+
+PythonAttribute PythonMLIRModule::stringAttr(const std::string &value) {
+ return PythonAttribute(::makeStringAttr(&mlirContext, value.c_str()));
+}
+
PythonAttribute
PythonMLIRModule::arrayAttr(const std::vector<PythonAttribute> &values) {
std::vector<mlir::Attribute> mlir_attributes(values.begin(), values.end());
@@ -910,6 +924,10 @@ PYBIND11_MODULE(pybind, m) {
"integerAttr", &PythonMLIRModule::integerAttr,
"Creates an mlir::IntegerAttr of the given type with the given value "
"in the context associated with this MLIR module.")
+ .def("floatAttr", &PythonMLIRModule::floatAttr,
+ "Creates an mlir::FloatAttr with the given value")
+ .def("stringAttr", &PythonMLIRModule::stringAttr,
+ "Creates an mlir::StringAttr with the given value")
.def("arrayAttr", &PythonMLIRModule::arrayAttr,
"Creates an mlir::ArrayAttr of the given type with the given values "
"in the context associated with this MLIR module.")
diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py
index 02ff4ab3061..02f8f628046 100644
--- a/mlir/bindings/python/test/test_py2and3.py
+++ b/mlir/bindings/python/test/test_py2and3.py
@@ -347,6 +347,17 @@ class EdscTest:
# CHECK-LABEL: testFunctionDeclarationWithArrayAttr
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]})
+ def testFunctionDeclarationWithFloatAndStringAttr(self):
+ self.setUp()
+ float_attr = self.module.floatAttr(23.3)
+ string_attr = self.module.stringAttr("TEST_STRING")
+
+ f = self.module.declare_function(
+ "foo", [], [], float_attr=float_attr, string_attr=string_attr)
+ printWithCurrentFunctionName(str(self.module))
+ # CHECK-LABEL: testFunctionDeclarationWithFloatAndStringAttr
+ # CHECK: func @foo() attributes {float_attr = 2.330000e+01 : f32, string_attr = "TEST_STRING"}
+
def testFunctionMultiple(self):
self.setUp()
with self.module.function_context("foo", [], []):
diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h
index 857d42ecf7a..c205e898901 100644
--- a/mlir/include/mlir-c/Core.h
+++ b/mlir/include/mlir-c/Core.h
@@ -87,6 +87,12 @@ mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value);
/// Returns an `mlir::BoolAttr` with the given value.
mlir_attr_t makeBoolAttr(mlir_context_t context, bool value);
+/// Returns an `mlir::FloatAttr` with the given value.
+mlir_attr_t makeFloatAttr(mlir_context_t context, float value);
+
+/// Returns an `mlir::StringAttr` with the given value.
+mlir_attr_t makeStringAttr(mlir_context_t context, const char *value);
+
/// Parses an MLIR type from the string `type` in the given context. Returns a
/// NULL type on error. If non-NULL, `charsRead` will contain the number of
/// characters that were processed by the parser.
diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp
index ab935742b8d..b88a1fdf4ef 100644
--- a/mlir/lib/EDSC/CoreAPIs.cpp
+++ b/mlir/lib/EDSC/CoreAPIs.cpp
@@ -74,6 +74,18 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
return mlir_attr_t{attr.getAsOpaquePointer()};
}
+mlir_attr_t makeFloatAttr(mlir_context_t context, float value) {
+ auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+ auto attr = FloatAttr::get(FloatType::getF32(ctx), APFloat(value));
+ return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeStringAttr(mlir_context_t context, const char *value) {
+ auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+ auto attr = StringAttr::get(value, ctx);
+ return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
unsigned getFunctionArity(mlir_func_t function) {
auto f = mlir::FuncOp::getFromOpaquePointer(function);
return f.getNumArguments();
OpenPOWER on IntegriCloud