diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-01-28 20:29:46 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 15:59:14 -0700 |
| commit | 39d81f246a556ef02a26ee239cb097f6d1aec64a (patch) | |
| tree | ea138d2b367fda0a47f145c4265e32120b3e0d2c | |
| parent | 0f9436e56a2c366c839cf43784c1c7a74f5c9ee3 (diff) | |
| download | bcm5719-llvm-39d81f246a556ef02a26ee239cb097f6d1aec64a.tar.gz bcm5719-llvm-39d81f246a556ef02a26ee239cb097f6d1aec64a.zip | |
Introduce python bindings for MLIR EDSCs
This CL also introduces a set of Python bindings using pybind11. The bindings
are exercised using a `test_py2and3.py` test suite that works for both
Python 2 and Python 3.
`test_py3.py`, on the other hand, uses the more idiomatic,
Python-3-only "PEP 3132 -- Extended Iterable Unpacking" to implement a rank-
and type-agnostic copy with transposition.
Because Python assignment binds by reference, we cannot easily make the
assignment operator use the same type of sugaring as in C++; i.e., the
following:
```cpp
Stmt block = edsc::Block({
For(ivs, zeros, shapeA, ones, {
C[ivs] = IA[ivs] + IB[ivs]
})});
```
has no equivalent in the native Python EDSCs at this time.
However, the sugaring can be built as a simple DSL in Python and is left as
future work.
PiperOrigin-RevId: 231337667
| -rw-r--r-- | mlir/bindings/python/pybind.cpp | 559 | ||||
| -rw-r--r-- | mlir/bindings/python/test/test_py2and3.py | 208 | ||||
| -rw-r--r-- | mlir/bindings/python/test/test_py3.py | 47 | ||||
| -rw-r--r-- | mlir/lib/EDSC/MLIREmitter.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/EDSC/Types.cpp | 42 |
5 files changed, 839 insertions, 20 deletions
diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp new file mode 100644 index 00000000000..e99694b94e1 --- /dev/null +++ b/mlir/bindings/python/pybind.cpp @@ -0,0 +1,559 @@ +#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm/include/llvm/IR/Module.h" +#include "third_party/llvm/llvm/include/llvm/Support/TargetSelect.h" +#include "third_party/llvm/llvm/include/llvm/Support/raw_ostream.h" +#include <cstddef> + +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir-c/Core.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h" +#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Passes.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +#include "mlir/IR/Function.h" +#include "mlir/IR/Types.h" + +static bool inited = [] { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + return true; +}(); + +namespace mlir { +namespace edsc { +namespace python { + +static std::vector<std::unique_ptr<mlir::Pass>> getDefaultPasses( + const std::vector<const mlir::PassInfo *> &mlirPassInfoList = {}) { + std::vector<std::unique_ptr<mlir::Pass>> passList; + passList.reserve(mlirPassInfoList.size() + 4); + // Run each of the passes that were selected. 
+ for (const auto *passInfo : mlirPassInfoList) { + passList.emplace_back(passInfo->createPass()); + } + // Append the extra passes for lowering to MLIR. + passList.emplace_back(mlir::createConstantFoldPass()); + passList.emplace_back(mlir::createCSEPass()); + passList.emplace_back(mlir::createCanonicalizerPass()); + passList.emplace_back(mlir::createLowerAffinePass()); + return passList; +} + +// Run the passes sequentially on the given module. +// Return `nullptr` immediately if any of the passes fails. +static bool runPasses(const std::vector<std::unique_ptr<mlir::Pass>> &passes, + Module *module) { + for (const auto &pass : passes) { + mlir::PassResult result = pass->runOnModule(module); + if (result == mlir::PassResult::Failure || module->verify()) { + llvm::errs() << "Pass failed\n"; + return true; + } + } + return false; +} + +namespace py = pybind11; + +struct PythonBindable; +struct PythonExpr; +struct PythonStmt; + +struct PythonFunction { + PythonFunction() : function{nullptr} {} + PythonFunction(mlir_func_t f) : function{f} {} + PythonFunction(mlir::Function *f) : function{f} {} + operator mlir_func_t() { return function; } + std::string str() { + mlir::Function *f = reinterpret_cast<mlir::Function *>(function); + std::string res; + llvm::raw_string_ostream os(res); + f->print(os); + return res; + } + mlir_func_t function; +}; + +struct PythonType { + PythonType() : type{nullptr} {} + PythonType(mlir_type_t t) : type{t} {} + operator mlir_type_t() { return type; } + std::string str() { + mlir::Type f = mlir::Type::getFromOpaquePointer(type); + std::string res; + llvm::raw_string_ostream os(res); + f.print(os); + return res; + } + mlir_type_t type; +}; + +/// Trivial C++ wrappers make use of the EDSC C API. 
+struct PythonMLIRModule { + PythonMLIRModule() : mlirContext(), module(new mlir::Module(&mlirContext)) {} + + PythonType makeScalarType(const std::string &mlirElemType, + unsigned bitwidth) { + return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(), + bitwidth); + } + PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) { + return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType, + int64_list_t{sizes.data(), sizes.size()}); + } + PythonFunction makeFunction(const std::string &name, + std::vector<PythonType> &inputTypes, + std::vector<PythonType> &outputTypes) { + std::vector<mlir_type_t> ins(inputTypes.begin(), inputTypes.end()); + std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end()); + auto funcType = ::makeFunctionType( + mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()}, + mlir_type_list_t{outs.data(), outs.size()}); + auto *func = new mlir::Function( + UnknownLoc::get(&mlirContext), name, + mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>()); + func->addEntryBlock(); + module->getFunctions().push_back(func); + return mlir_func_t{func}; + } + + void compile() { + auto created = mlir::ExecutionEngine::create(module.get()); + llvm::handleAllErrors(created.takeError(), + [](const llvm::ErrorInfoBase &b) { + b.log(llvm::errs()); + assert(false); + }); + engine = std::move(*created); + } + + std::string getIR() { + std::string res; + llvm::raw_string_ostream os(res); + module->print(os); + return res; + } + + uint64_t getEngineAddress() { + assert(engine && "module must be compiled into engine first"); + return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get())); + } + +private: + mlir::MLIRContext mlirContext; + // One single module in a python-exposed MLIRContext for now. 
+ std::unique_ptr<mlir::Module> module; + std::unique_ptr<mlir::ExecutionEngine> engine; +}; + +struct ContextManager { + void enter() { context = new ScopedEDSCContext(); } + void exit(py::object, py::object, py::object) { + delete context; + context = nullptr; + } + mlir::edsc::ScopedEDSCContext *context; +}; + +struct PythonExpr { + PythonExpr() : expr{nullptr} {} + PythonExpr(const PythonBindable &bindable); + PythonExpr(const edsc_expr_t &expr) : expr{expr} {} + operator edsc_expr_t() { return expr; } + std::string str() { + assert(expr && "unexpected empty expr"); + return Expr(*this).str(); + } + edsc_expr_t expr; +}; + +struct PythonBindable : public PythonExpr { + PythonBindable() : PythonExpr(edsc_expr_t{makeBindable()}) {} + PythonBindable(PythonExpr expr) : PythonExpr(expr) { + assert(Expr(expr).isa<Bindable>() && "Expected Bindable"); + } + std::string str() { + assert(expr && "unexpected empty expr"); + return Expr(expr).str(); + } +}; + +struct PythonStmt { + PythonStmt() : stmt{nullptr} {} + PythonStmt(const edsc_stmt_t &stmt) : stmt{stmt} {} + PythonStmt(const PythonExpr &e) : stmt{makeStmt(e.expr)} {} + operator edsc_stmt_t() { return stmt; } + std::string str() { + assert(stmt && "unexpected empty stmt"); + return Stmt(stmt).str(); + } + edsc_stmt_t stmt; +}; + +struct PythonIndexed : public edsc_indexed_t { + PythonIndexed() : edsc_indexed_t{makeIndexed(PythonBindable())} {} + PythonIndexed(PythonExpr e) : edsc_indexed_t{makeIndexed(e)} {} + PythonIndexed(PythonBindable b) : edsc_indexed_t{makeIndexed(b)} {} + operator PythonExpr() { return PythonExpr(base); } +}; + +struct MLIRFunctionEmitter { + MLIRFunctionEmitter(PythonFunction f) + : currentFunction(reinterpret_cast<mlir::Function *>(f.function)), + currentBuilder(currentFunction), + emitter(¤tBuilder, currentFunction->getLoc()) {} + + PythonExpr bindConstantBF16(double value); + PythonExpr bindConstantF16(float value); + PythonExpr bindConstantF32(float value); + PythonExpr 
bindConstantF64(double value); + PythonExpr bindConstantInt(int64_t value, unsigned bitwidth); + PythonExpr bindConstantIndex(int64_t value); + PythonExpr bindFunctionArgument(unsigned pos); + py::list bindFunctionArguments(); + py::list bindFunctionArgumentView(unsigned pos); + py::list bindMemRefShape(PythonExpr boundMemRef); + py::list bindIndexedMemRefShape(PythonIndexed boundMemRef) { + return bindMemRefShape(boundMemRef.base); + } + py::list bindMemRefView(PythonExpr boundMemRef); + py::list bindIndexedMemRefView(PythonIndexed boundMemRef) { + return bindMemRefView(boundMemRef.base); + } + void emit(PythonStmt stmt); + +private: + mlir::Function *currentFunction; + mlir::FuncBuilder currentBuilder; + mlir::edsc::MLIREmitter emitter; + edsc_mlir_emitter_t c_emitter; +}; + +static edsc_stmt_list_t makeCStmts(llvm::SmallVectorImpl<edsc_stmt_t> &owning, + const py::list &stmts) { + for (auto &inp : stmts) { + owning.push_back(edsc_stmt_t{inp.cast<PythonStmt>()}); + } + return edsc_stmt_list_t{owning.data(), owning.size()}; +} + +static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl<edsc_expr_t> &owning, + const py::list &exprs) { + for (auto &inp : exprs) { + owning.push_back(edsc_expr_t{inp.cast<PythonExpr>()}); + } + return edsc_expr_list_t{owning.data(), owning.size()}; +} + +PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {} + +PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) { + return ::bindConstantBF16(edsc_mlir_emitter_t{&emitter}, value); +} + +PythonExpr MLIRFunctionEmitter::bindConstantF16(float value) { + return ::bindConstantF16(edsc_mlir_emitter_t{&emitter}, value); +} + +PythonExpr MLIRFunctionEmitter::bindConstantF32(float value) { + return ::bindConstantF32(edsc_mlir_emitter_t{&emitter}, value); +} + +PythonExpr MLIRFunctionEmitter::bindConstantF64(double value) { + return ::bindConstantF64(edsc_mlir_emitter_t{&emitter}, value); +} + +PythonExpr MLIRFunctionEmitter::bindConstantInt(int64_t value, + 
unsigned bitwidth) { + return ::bindConstantInt(edsc_mlir_emitter_t{&emitter}, value, bitwidth); +} + +PythonExpr MLIRFunctionEmitter::bindConstantIndex(int64_t value) { + return ::bindConstantIndex(edsc_mlir_emitter_t{&emitter}, value); +} + +PythonExpr MLIRFunctionEmitter::bindFunctionArgument(unsigned pos) { + return ::bindFunctionArgument(edsc_mlir_emitter_t{&emitter}, + mlir_func_t{currentFunction}, pos); +} + +PythonExpr getPythonType(edsc_expr_t e) { return PythonExpr(e); } + +template <typename T> py::list makePyList(llvm::ArrayRef<T> owningResults) { + py::list res; + for (auto e : owningResults) { + res.append(getPythonType(e)); + } + return res; +} + +py::list MLIRFunctionEmitter::bindFunctionArguments() { + auto arity = getFunctionArity(mlir_func_t{currentFunction}); + llvm::SmallVector<edsc_expr_t, 8> owningResults(arity); + edsc_expr_list_t results{owningResults.data(), owningResults.size()}; + ::bindFunctionArguments(edsc_mlir_emitter_t{&emitter}, + mlir_func_t{currentFunction}, &results); + return makePyList(ArrayRef<edsc_expr_t>{owningResults}); +} + +py::list MLIRFunctionEmitter::bindMemRefShape(PythonExpr boundMemRef) { + auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef); + llvm::SmallVector<edsc_expr_t, 8> owningShapes(rank); + edsc_expr_list_t resultShapes{owningShapes.data(), owningShapes.size()}; + ::bindMemRefShape(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultShapes); + return makePyList(ArrayRef<edsc_expr_t>{owningShapes}); +} + +py::list MLIRFunctionEmitter::bindMemRefView(PythonExpr boundMemRef) { + auto rank = getBoundMemRefRank(edsc_mlir_emitter_t{&emitter}, boundMemRef); + // Own the PythonExpr for the arg as well as all its dims. 
+ llvm::SmallVector<edsc_expr_t, 8> owningLbs(rank); + llvm::SmallVector<edsc_expr_t, 8> owningUbs(rank); + llvm::SmallVector<edsc_expr_t, 8> owningSteps(rank); + edsc_expr_list_t resultLbs{owningLbs.data(), owningLbs.size()}; + edsc_expr_list_t resultUbs{owningUbs.data(), owningUbs.size()}; + edsc_expr_list_t resultSteps{owningSteps.data(), owningSteps.size()}; + ::bindMemRefView(edsc_mlir_emitter_t{&emitter}, boundMemRef, &resultLbs, + &resultUbs, &resultSteps); + py::list res; + res.append(makePyList(ArrayRef<edsc_expr_t>{owningLbs})); + res.append(makePyList(ArrayRef<edsc_expr_t>{owningUbs})); + res.append(makePyList(ArrayRef<edsc_expr_t>{owningSteps})); + return res; +} + +void MLIRFunctionEmitter::emit(PythonStmt stmt) { + emitter.emitStmt(Stmt(stmt)); +} + +PYBIND11_MODULE(pybind, m) { + m.doc() = + "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)"; + m.def("version", []() { return "EDSC Python extensions v0.0"; }); + m.def("initContext", + []() { return static_cast<void *>(new ScopedEDSCContext()); }); + m.def("deleteContext", + [](void *ctx) { delete reinterpret_cast<ScopedEDSCContext *>(ctx); }); + + m.def("Block", [](const py::list &stmts) { + SmallVector<edsc_stmt_t, 8> owning; + return PythonStmt(::Block(makeCStmts(owning, stmts))); + }); + m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs, + const py::list &steps, const py::list &stmts) { + SmallVector<edsc_expr_t, 8> owningIVs; + SmallVector<edsc_expr_t, 8> owningLBs; + SmallVector<edsc_expr_t, 8> owningUBs; + SmallVector<edsc_expr_t, 8> owningSteps; + SmallVector<edsc_stmt_t, 8> owningStmts; + return PythonStmt( + ::ForNest(makeCExprs(owningIVs, ivs), makeCExprs(owningLBs, lbs), + makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps), + makeCStmts(owningStmts, stmts))); + }); + m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step, + const py::list &stmts) { + SmallVector<edsc_stmt_t, 8> owning; + return 
PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts))); + }); + m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) { + return PythonExpr(::Select(cond, e1, e2)); + }); + m.def("Return", []() { + return PythonStmt(::Return(edsc_expr_list_t{nullptr, 0})); + }); + m.def("Return", [](const py::list &returns) { + SmallVector<edsc_expr_t, 8> owningExprs; + return PythonStmt(::Return(makeCExprs(owningExprs, returns))); + }); + +#define DEFINE_PYBIND_BINARY_OP(PYTHON_NAME, C_NAME) \ + m.def(PYTHON_NAME, [](PythonExpr e1, PythonExpr e2) { \ + return PythonExpr(::C_NAME(e1, e2)); \ + }); + + DEFINE_PYBIND_BINARY_OP("Add", Add); + DEFINE_PYBIND_BINARY_OP("Mul", Mul); + DEFINE_PYBIND_BINARY_OP("Sub", Sub); + // DEFINE_PYBIND_BINARY_OP("Div", Div); + DEFINE_PYBIND_BINARY_OP("LT", LT); + DEFINE_PYBIND_BINARY_OP("LE", LE); + DEFINE_PYBIND_BINARY_OP("GT", GT); + DEFINE_PYBIND_BINARY_OP("GE", GE); + +#undef DEFINE_PYBIND_BINARY_OP + + py::class_<PythonFunction>(m, "Function", + "Wrapping class for mlir::Function.") + .def(py::init<PythonFunction>()) + .def("__str__", &PythonFunction::str); + + py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.") + .def(py::init<PythonType>()) + .def("__str__", &PythonType::str); + + py::class_<PythonMLIRModule>( + m, "MLIRModule", + "An MLIRModule is the abstraction that owns the allocations to support " + "compilation of a single mlir::Module into an ExecutionEngine backed by " + "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, " + "adding functions, compiling the module to obtain an ExecutionEngine on " + "which named functions may be called. For now the only means to retrieve " + "the ExecutionEngine is by calling `get_engine_address`. This mode of " + "execution is limited to passing the pointer to C++ where the function " + "is called. Extending the API to allow calling JIT compiled functions " + "directly require integration with a tensor library (e.g. numpy). 
This " + "is left as the prerogative of libraries and frameworks for now.") + .def(py::init<>()) + .def("make_function", &PythonMLIRModule::makeFunction, + "Creates a new mlir::Function in the current mlir::Module.") + .def( + "make_scalar_type", + [](PythonMLIRModule &instance, const std::string &type, + unsigned bitwidth) { + return instance.makeScalarType(type, bitwidth); + }, + py::arg("type"), py::arg("bitwidth") = 0, + "Returns a scalar mlir::Type using the following convention:\n" + " - makeScalarType(c, \"bf16\") return an `mlir::Type::getBF16`\n" + " - makeScalarType(c, \"f16\") return an `mlir::Type::getF16`\n" + " - makeScalarType(c, \"f32\") return an `mlir::Type::getF32`\n" + " - makeScalarType(c, \"f64\") return an `mlir::Type::getF64`\n" + " - makeScalarType(c, \"index\") return an `mlir::Type::getIndex`\n" + " - makeScalarType(c, \"i\", bitwidth) return an " + "`mlir::Type::getInteger(bitwidth)`\n\n" + " No other combinations are currently supported.") + .def("make_memref_type", &PythonMLIRModule::makeMemRefType, + "Returns an mlir::MemRefType of an elemental scalar. -1 is used to " + "denote symbolic dimensions in the resulting memref shape.") + .def("compile", &PythonMLIRModule::compile, + "Compiles the mlir::Module to LLVMIR a creates new opaque " + "ExecutionEngine backed by the ORC JIT.") + .def("get_ir", &PythonMLIRModule::getIR, + "Returns a dump of the MLIR representation of the module. This is " + "used for serde to support out-of-process execution as well as " + "debugging purposes.") + .def("get_engine_address", &PythonMLIRModule::getEngineAddress, + "Returns the address of the compiled ExecutionEngine. 
This is used " + "for in-process execution."); + + py::class_<ContextManager>( + m, "ContextManager", + "An EDSC context manager is the memory arena containing all the EDSC " + "allocations.\nUsage:\n\n" + "with E.ContextManager() as _:\n i = E.Expr(E.Bindable())\n ...") + .def(py::init<>()) + .def("__enter__", &ContextManager::enter) + .def("__exit__", &ContextManager::exit); + + py::class_<MLIRFunctionEmitter>( + m, "MLIRFunctionEmitter", + "An MLIRFunctionEmitter is used to fill an empty function body. This is " + "a staged process:\n" + " 1. create or retrieve an mlir::Function `f` with an empty body;\n" + " 2. make an `MLIRFunctionEmitter(f)` to build the current function;\n" + " 3. create leaf Expr that are either Bindable or already Expr that are" + " bound to constants and function arguments by using methods of " + " `MLIRFunctionEmitter`;\n" + " 4. build the function body using Expr, Indexed and Stmt;\n" + " 5. emit the MLIR to implement the function body.") + .def(py::init<PythonFunction>()) + .def("bind_constant_bf16", &MLIRFunctionEmitter::bindConstantBF16) + .def("bind_constant_f16", &MLIRFunctionEmitter::bindConstantF16) + .def("bind_constant_f32", &MLIRFunctionEmitter::bindConstantF32) + .def("bind_constant_f64", &MLIRFunctionEmitter::bindConstantF64) + .def("bind_constant_int", &MLIRFunctionEmitter::bindConstantInt) + .def("bind_constant_index", &MLIRFunctionEmitter::bindConstantIndex) + .def("bind_function_argument", &MLIRFunctionEmitter::bindFunctionArgument, + "Returns an Expr that has been bound to a positional argument in " + "the current Function.") + .def("bind_function_arguments", + &MLIRFunctionEmitter::bindFunctionArguments, + "Returns a list of Expr where each Expr has been bound to the " + "corresponding positional argument in the current Function.") + .def("bind_memref_shape", &MLIRFunctionEmitter::bindMemRefShape, + "Returns a list of Expr where each Expr has been bound to the " + "corresponding dimension of the memref.") + 
.def("bind_memref_view", &MLIRFunctionEmitter::bindMemRefView, + "Returns three lists (lower bound, upper bound and step) of Expr " + "where each triplet of Expr has been bound to the minimal offset, " + "extent and stride of the corresponding dimension of the memref.") + .def("bind_indexed_shape", &MLIRFunctionEmitter::bindIndexedMemRefShape, + "Same as bind_memref_shape but returns a list of `Indexed` that " + "support load and store operations") + .def("bind_indexed_view", &MLIRFunctionEmitter::bindIndexedMemRefView, + "Same as bind_memref_view but returns lists of `Indexed` that " + "support load and store operations") + .def("emit", &MLIRFunctionEmitter::emit, + "Emits the MLIR for the EDSC expressions and statements in the " + "current function body."); + + py::class_<PythonExpr>(m, "Expr", "Wrapping class for mlir::edsc::Expr") + .def(py::init<PythonBindable>()) + .def("__add__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::Add(e1, e2)); }) + .def("__sub__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::Sub(e1, e2)); }) + .def("__mul__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::Mul(e1, e2)); }) + // .def("__div__", [](PythonExpr e1, PythonExpr e2) { return + // PythonExpr(::Div(e1, e2)); }) + .def("__lt__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::LT(e1, e2)); }) + .def("__le__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::LE(e1, e2)); }) + .def("__gt__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::GT(e1, e2)); }) + .def("__ge__", [](PythonExpr e1, + PythonExpr e2) { return PythonExpr(::GE(e1, e2)); }) + .def("__str__", &PythonExpr::str, + R"DOC(Returns the string value for the Expr)DOC"); + + py::class_<PythonBindable>( + m, "Bindable", + "Wrapping class for mlir::edsc::Bindable.\nA Bindable is a special Expr " + "that can be bound manually to specific MLIR SSA Values.") + .def(py::init<>()) + .def("__str__", &PythonBindable::str); + + py::class_<PythonStmt>(m, "Stmt", 
"Wrapping class for mlir::edsc::Stmt.") + .def(py::init<PythonExpr>()) + .def("__str__", &PythonStmt::str, + R"DOC(Returns the string value for the Expr)DOC"); + + py::class_<PythonIndexed>( + m, "Indexed", + "Wrapping class for mlir::edsc::Indexed.\nAn Indexed is a wrapper class " + "that support load and store operations.") + .def(py::init<>(), R"DOC(Build from fresh Bindable)DOC") + .def(py::init<PythonExpr>(), R"DOC(Build from existing Expr)DOC") + .def(py::init<PythonBindable>(), R"DOC(Build from existing Bindable)DOC") + .def( + "load", + [](PythonIndexed &instance, const py::list &indices) { + SmallVector<edsc_expr_t, 8> owning; + return PythonExpr(Load(instance, makeCExprs(owning, indices))); + }, + R"DOC(Returns an Expr that loads from an Indexed)DOC") + .def( + "store", + [](PythonIndexed &instance, const py::list &indices, + PythonExpr value) { + SmallVector<edsc_expr_t, 8> owning; + return PythonStmt( + Store(value, instance, makeCExprs(owning, indices))); + }, + R"DOC(Returns the Stmt that stores into an Indexed)DOC"); +} + +} // namespace python +} // namespace edsc +} // namespace mlir diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py new file mode 100644 index 00000000000..77798497a94 --- /dev/null +++ b/mlir/bindings/python/test/test_py2and3.py @@ -0,0 +1,208 @@ +"""Python2 and 3 test for the MLIR EDSC C API and Python bindings""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import google_mlir.bindings.python.pybind as E + +help(E) + + +class EdscTest(unittest.TestCase): + + def testBindables(self): + with E.ContextManager(): + i = E.Expr(E.Bindable()) + self.assertIn("$1", i.__str__()) + + def testOneExpr(self): + with E.ContextManager(): + i, lb, ub = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + expr = E.Mul(i, E.Add(lb, ub)) + str = expr.__str__() + self.assertIn("($1 * ($2 + $3))", str) + + def 
testOneLoop(self): + with E.ContextManager(): + i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + loop = E.For(i, lb, ub, step, [E.Stmt(E.Add(lb, ub))]) + str = loop.__str__() + self.assertIn("for($1 = $2 to $3 step $4) {", str) + self.assertIn("$5 = ($2 + $3)", str) + + def testTwoLoops(self): + with E.ContextManager(): + i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + loop = E.For(i, lb, ub, step, [E.For(i, lb, ub, step, [E.Stmt(i)])]) + str = loop.__str__() + self.assertIn("for($1 = $2 to $3 step $4) {", str) + self.assertIn("for($1 = $2 to $3 step $4) {", str) + self.assertIn("$5 = $1;", str) + + def testNestedLoops(self): + with E.ContextManager(): + i, lb, ub, step = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + ivs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + lbs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + ubs = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + steps = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + loop = E.For(ivs, lbs, ubs, steps, [ + E.For(i, lb, ub, step, [E.Stmt(ub * step - lb)]), + ]) + str = loop.__str__() + self.assertIn("for($5 = $9 to $13 step $17) {", str) + self.assertIn("for($6 = $10 to $14 step $18) {", str) + self.assertIn("for($7 = $11 to $15 step $19) {", str) + self.assertIn("for($8 = $12 to $16 step $20) {", str) + self.assertIn("for($1 = $2 to $3 step $4) {", str) + self.assertIn("= (($3 * $4) - $2);", str) + + def testIndexed(self): + with E.ContextManager(): + i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)])) + stmt = C.store([i, j], A.load([i, k]) * B.load([k, j])) + str = stmt.__str__() + self.assertIn(" = store( ... 
)", str) + + def testMatmul(self): + with E.ContextManager(): + ivs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + lbs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + ubs = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + steps = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + i, j, k = ivs[0], ivs[1], ivs[2] + A, B, C = list(map(E.Indexed, [E.Bindable() for _ in range(3)])) + loop = E.For( + ivs, lbs, ubs, steps, + [C.store([i, j], + C.load([i, j]) + A.load([i, k]) * B.load([k, j]))]) + str = loop.__str__() + self.assertIn("for($1 = $4 to $7 step $10) {", str) + self.assertIn("for($2 = $5 to $8 step $11) {", str) + self.assertIn("for($3 = $6 to $9 step $12) {", str) + self.assertIn(" = store( ... )", str) + + def testArithmetic(self): + with E.ContextManager(): + i, j, k, l = list(map(E.Expr, [E.Bindable() for _ in range(4)])) + stmt = i + j * k - l + str = stmt.__str__() + self.assertIn("(($1 + ($2 * $3)) - $4)", str) + + def testSelect(self): + with E.ContextManager(): + i, j, k = list(map(E.Expr, [E.Bindable() for _ in range(3)])) + stmt = E.Select(i > j, i, j) + str = stmt.__str__() + self.assertIn("select(($1 > $2), $1, $2)", str) + + def testBlock(self): + with E.ContextManager(): + i, j = list(map(E.Expr, [E.Bindable() for _ in range(2)])) + stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)]) + str = stmt.__str__() + self.assertIn("block {", str) + self.assertIn("$3 = ($1 + $2)", str) + self.assertIn("$4 = ($1 - $2)", str) + self.assertIn("}", str) + + def testMLIRScalarTypes(self): + module = E.MLIRModule() + t = module.make_scalar_type("bf16") + self.assertIn("bf16", t.__str__()) + t = module.make_scalar_type("f16") + self.assertIn("f16", t.__str__()) + t = module.make_scalar_type("f32") + self.assertIn("f32", t.__str__()) + t = module.make_scalar_type("f64") + self.assertIn("f64", t.__str__()) + t = module.make_scalar_type("i", 1) + self.assertIn("i1", t.__str__()) + t = module.make_scalar_type("i", 8) + self.assertIn("i8", 
t.__str__()) + t = module.make_scalar_type("i", 32) + self.assertIn("i32", t.__str__()) + t = module.make_scalar_type("i", 123) + self.assertIn("i123", t.__str__()) + t = module.make_scalar_type("index") + self.assertIn("index", t.__str__()) + + def testMLIRFunctionCreation(self): + module = E.MLIRModule() + t = module.make_scalar_type("f32") + self.assertIn("f32", t.__str__()) + m = module.make_memref_type(t, [3, 4, -1, 5]) + self.assertIn("memref<3x4x?x5xf32>", m.__str__()) + f = module.make_function("copy", [m, m], []) + self.assertIn( + "func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {", + f.__str__()) + + def testMLIRConstantEmission(self): + module = E.MLIRModule() + f = module.make_function("constants", [], []) + with E.ContextManager(): + emitter = E.MLIRFunctionEmitter(f) + emitter.bind_constant_bf16(1.23) + emitter.bind_constant_f16(1.23) + emitter.bind_constant_f32(1.23) + emitter.bind_constant_f64(1.23) + emitter.bind_constant_int(1, 1) + emitter.bind_constant_int(123, 8) + emitter.bind_constant_int(123, 16) + emitter.bind_constant_int(123, 32) + emitter.bind_constant_index(123) + str = f.__str__() + self.assertIn("constant 1.230000e+00 : bf16", str) + self.assertIn("constant 1.230470e+00 : f16", str) + self.assertIn("constant 1.230000e+00 : f32", str) + self.assertIn("constant 1.230000e+00 : f64", str) + self.assertIn("constant 1 : i1", str) + self.assertIn("constant 123 : i8", str) + self.assertIn("constant 123 : i16", str) + self.assertIn("constant 123 : i32", str) + self.assertIn("constant 123 : index", str) + + # TODO(ntv): support symbolic For bounds with EDSCs + def testMLIREmission(self): + shape = [3, 4, 5] + module = E.MLIRModule() + index = module.make_scalar_type("index") + t = module.make_scalar_type("f32") + m = module.make_memref_type(t, shape) + f = module.make_function("copy", [m, m], []) + + with E.ContextManager(): + emitter = E.MLIRFunctionEmitter(f) + zero = emitter.bind_constant_index(0) + one = 
emitter.bind_constant_index(1) + input, output = list(map(E.Indexed, emitter.bind_function_arguments())) + M, N, O = emitter.bind_indexed_shape(input) + + ivs = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))])) + lbs = [zero, zero, zero] + ubs = [M, N, O] + steps = [one, one, one] + + # TODO(ntv): emitter.assertEqual(M, oM) etc + loop = E.Block([ + E.For(ivs, lbs, ubs, steps, [output.store(ivs, input.load(ivs))]), + E.Return() + ]) + emitter.emit(loop) + + # print(f) # uncomment to see the emitted IR + str = f.__str__() + self.assertIn("""store %0, %arg1[%i0, %i1, %i2] : memref<3x4x5xf32>""", + str) + + module.compile() + self.assertNotEqual(module.get_engine_address(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlir/bindings/python/test/test_py3.py b/mlir/bindings/python/test/test_py3.py new file mode 100644 index 00000000000..427add5bf18 --- /dev/null +++ b/mlir/bindings/python/test/test_py3.py @@ -0,0 +1,47 @@ +"""Python3 test for the MLIR EDSC C API and Python bindings""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import google_mlir.bindings.python.pybind as E + + +class EdscTest(unittest.TestCase): + + def testSugaredMLIREmission(self): + shape = [3, 4, 5, 6, 7] + shape_t = [7, 4, 5, 6, 3] + module = E.MLIRModule() + t = module.make_scalar_type("f32") + m = module.make_memref_type(t, shape) + m_t = module.make_memref_type(t, shape_t) + f = module.make_function("copy", [m, m_t], []) + + with E.ContextManager(): + emitter = E.MLIRFunctionEmitter(f) + input, output = list(map(E.Indexed, emitter.bind_function_arguments())) + lbs, ubs, steps = emitter.bind_indexed_view(input) + i, *ivs, j = list(map(E.Expr, [E.Bindable() for _ in range(len(shape))])) + + # n-D type and rank agnostic copy-transpose-first-last (where n >= 2). 
+ loop = E.Block([ + E.For([i, *ivs, j], lbs, ubs, steps, + [output.store([i, *ivs, j], input.load([j, *ivs, i]))]), + E.Return() + ]) + emitter.emit(loop) + + # print(f) # uncomment to see the emitted IR + str = f.__str__() + self.assertIn("load %arg0[%i4, %i1, %i2, %i3, %i0]", str) + self.assertIn("store %0, %arg1[%i0, %i1, %i2, %i3, %i4]", str) + + module.compile() + self.assertNotEqual(module.get_engine_address(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index f0efdb5081a..72d30ff31b5 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -459,7 +459,8 @@ edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value, auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter); Bindable b; e->bindConstant<mlir::ConstantIntOp>( - b, value, e->getBuilder()->getIntegerType(bitwidth)); + b, // mlir::APInt(bitwidth, value), + value, e->getBuilder()->getIntegerType(bitwidth)); return b; } diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index e3460936045..e261cc8c836 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -23,7 +23,6 @@ #include "mlir/Support/STLExtras.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -624,8 +623,9 @@ operator[](llvm::ArrayRef<Bindable> indices) const { return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}]; } -// NOLINTNEXTLINE: unconventional-assign-operator -Stmt mlir::edsc::Indexed::operator=(Expr expr) { +// clang-format off +Stmt mlir::edsc::Indexed::operator=(Expr expr) { // NOLINT: unconventional-assign-operator + // clang-format on assert(!indices.empty() && "Expected attached indices to Indexed"); assert(base); Stmt stmt(store(expr, base, indices)); @@ -644,23 +644,27 @@ edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) { mlir_type_t 
makeScalarType(mlir_context_t context, const char *name, unsigned bitwidth) { mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context); - mlir_type_t res = - llvm::StringSwitch<mlir_type_t>(name) - .Case("bf16", - mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()}) - .Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()}) - .Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()}) - .Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()}) - .Case("index", - mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()}) - .Case("i", - mlir_type_t{ - mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()}) - .Default(mlir_type_t{nullptr}); - if (!res) { - llvm_unreachable("Invalid type specifier"); + if (llvm::StringRef(name) == "bf16") { + return mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()}; } - return res; + if (llvm::StringRef(name) == "f16") { + return mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()}; + } + if (llvm::StringRef(name) == "f32") { + return mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()}; + } + if (llvm::StringRef(name) == "f64") { + return mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()}; + } + if (llvm::StringRef(name) == "index") { + return mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()}; + } + if (llvm::StringRef(name) == "i") { + return mlir_type_t{ + mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()}; + } + assert(false && "unknown scalar type"); + return mlir_type_t{nullptr}; } mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, |

