//===- mlir_runner_utils.h - Utils for debugging MLIR CPU execution -------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CPU_RUNNER_MLIRUTILS_H_
#define MLIR_CPU_RUNNER_MLIRUTILS_H_

#include <assert.h>
#include <cstdint>
#include <iostream>

// Symbol visibility: on Windows the library's entry points must be marked
// dllexport while the library itself is being built and dllimport when it is
// consumed. Everywhere else the macro expands to nothing.
#ifdef _WIN32
#ifndef MLIR_RUNNER_UTILS_EXPORT
#ifdef mlir_runner_utils_EXPORTS
/* We are building this library */
#define MLIR_RUNNER_UTILS_EXPORT __declspec(dllexport)
#else
/* We are using this library */
#define MLIR_RUNNER_UTILS_EXPORT __declspec(dllimport)
#endif // mlir_runner_utils_EXPORTS
#endif // MLIR_RUNNER_UTILS_EXPORT
#else
#define MLIR_RUNNER_UTILS_EXPORT
#endif // _WIN32

template <typename T, int N> struct StridedMemRefType;

template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V);

/// Copies `arr[1] .. arr[N - 1]` into `res[0] .. res[N - 2]`, i.e. drops the
/// leading entry of an `N`-element array. `res` must have room for `N - 1`
/// entries.
template <int N> void dropFront(int64_t arr[N], int64_t *res) {
  for (int i = 1; i < N; ++i)
    res[i - 1] = arr[i];
}

/// StridedMemRef descriptor type with static rank.
///
/// Elements are addressed relative to `data`: the rank-1 specialization reads
/// `data[offset + idx * strides[0]]`, so `offset` and `strides` are expressed
/// in elements, not bytes. `basePtr` is only carried along by this header.
/// NOTE(review): presumably `basePtr` is the allocation base while `data` is
/// the aligned pointer -- confirm against the lowering convention.
template <typename T, int N> struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];

  /// Returns the rank `N - 1` descriptor obtained by fixing the outermost
  /// index to `idx`: the offset is advanced by `idx * strides[0]` and the
  /// leading size/stride entries are dropped.
  // This operator[] is extremely slow and only for sugaring purposes.
  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
    StridedMemRefType<T, N - 1> res;
    res.basePtr = basePtr;
    res.data = data;
    res.offset = offset + idx * strides[0];
    dropFront<N>(sizes, res.sizes);
    dropFront<N>(strides, res.strides);
    return res;
  }
};

/// StridedMemRef descriptor type specialized for rank 1.
template <typename T> struct StridedMemRefType<T, 1> {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];

  /// Returns a reference to element `idx`. No bounds checking is performed.
  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
};

/// StridedMemRef descriptor type specialized for rank 0.
template <typename T> struct StridedMemRefType<T, 0> {
  T *basePtr;
  T *data;
  int64_t offset;
};

// Unranked MemRef
/// Pairs a runtime `rank` with a type-erased pointer to a descriptor.
/// NOTE(review): presumably `descriptor` points at a
/// `StridedMemRefType<T, rank>` -- the consumers are not visible in this
/// header, confirm against the implementation file.
template <typename T> struct UnrankedMemRefType {
  int64_t rank;
  void *descriptor;
};

/// Streams the address of `V.data`, the rank, the offset, the sizes and the
/// strides of a ranked (N > 0) memref descriptor to `os`. No trailing newline
/// is written.
template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
  static_assert(N > 0, "Expected N > 0");
  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
     << " offset = " << V.offset << " sizes = [" << V.sizes[0];
  for (int i = 1; i < N; ++i)
    os << ", " << V.sizes[i];
  os << "] strides = [" << V.strides[0];
  for (int i = 1; i < N; ++i)
    os << ", " << V.strides[i];
  os << "]";
}

/// Rank-0 overload: there are no sizes or strides to print.
template <typename StreamType, typename T>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
     << " offset = " << V.offset;
}

/// Streams the rank and the descriptor address of an unranked memref to `os`,
/// followed by a newline.
template <typename T, typename StreamType>
void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
  os << "Unranked Memref rank = " << V.rank << " "
     << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
}

/// Statically shaped n-D vector: `Dim` nested vectors of the remaining `Dims`.
template <typename T, int Dim, int... Dims> struct Vector {
  Vector<T, Dims...> vector[Dim];
};

/// Innermost (1-D) case: a plain array of `Dim` elements.
template <typename T, int Dim> struct Vector<T, Dim> {
  T vector[Dim];
};

// Convenience aliases that take the dimensions first and the element type
// last.
template <int D1, typename T> using Vector1D = Vector<T, D1>;
template <int D1, int D2, typename T> using Vector2D = Vector<T, D1, D2>;
template <int D1, int D2, int D3, typename T>
using Vector3D = Vector<T, D1, D2, D3>;
template <int D1, int D2, int D3, int D4, typename T>
using Vector4D = Vector<T, D1, D2, D3, D4>;

////////////////////////////////////////////////////////////////////////////////
// Templated instantiation follows.
////////////////////////////////////////////////////////////////////////////////
namespace impl {

// Declared up front so that element printing inside this namespace can recurse
// into nested vectors (including memrefs whose element type is a Vector).
template <typename T, int M, int... Dims>
std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);

/// Compile-time product of a pack of dimensions; the empty pack yields 1.
template <int... Dims> struct StaticSizeMult {
  static constexpr int value = 1;
};

template <int N, int... Dims> struct StaticSizeMult<N, Dims...> {
  static constexpr int value = N * StaticSizeMult<Dims...>::value;
};

/// Writes `count` space characters to `os` (nothing when `count <= 0`).
static inline void printSpace(std::ostream &os, int count) {
  for (int i = 0; i < count; ++i) {
    os << ' ';
  }
}

/// Pretty-printer for `Vector<T, M, Dims...>` values.
template <typename T, int M, int... Dims> struct VectorDataPrinter {
  static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
};

/// Prints `val` as a parenthesized, comma-separated list; nested vectors are
/// printed recursively through `operator<<`.
template <typename T, int M, int... Dims>
void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
                                             const Vector<T, M, Dims...> &val) {
  static_assert(M > 0, "0 dimensioned tensor");
  // Guards against padding sneaking into the nested-array layout.
  static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
                "Incorrect vector size!");
  // First
  os << "(" << val.vector[0];
  if (M > 1)
    os << ", ";
  if (sizeof...(Dims) > 1)
    os << "\n";
  // Kernel
  // NOTE(review): interior rows are indented by 2 * sizeof...(Dims) while the
  // last row uses sizeof...(Dims); kept as is to preserve the emitted text.
  for (int i = 1; i + 1 < M; ++i) {
    printSpace(os, 2 * sizeof...(Dims));
    os << val.vector[i] << ", ";
    if (sizeof...(Dims) > 1)
      os << "\n";
  }
  // Last
  if (M > 1) {
    printSpace(os, sizeof...(Dims));
    os << val.vector[M - 1];
  }
  os << ")";
}

template <typename T, int M, int... Dims>
std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
  VectorDataPrinter<T, M, Dims...>::print(os, v);
  return os;
}

/// Recursive printer for the data of a strided memref.
///
/// `N` is the number of dimensions still to be printed and `rank` the rank of
/// the memref being printed (used only for indentation). `sizes` and `strides`
/// point at the entries of the current dimension; `offset` is the element
/// offset from `base` accumulated so far.
template <typename T, int N> struct MemRefDataPrinter {
  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
                    int64_t *sizes, int64_t *strides);
  static void printFirst(std::ostream &os, T *base, int64_t rank,
                         int64_t offset, int64_t *sizes, int64_t *strides);
  static void printLast(std::ostream &os, T *base, int64_t rank,
                        int64_t offset, int64_t *sizes, int64_t *strides);
};

/// Recursion terminator: prints the single element at `base[offset]`.
template <typename T> struct MemRefDataPrinter<T, 0> {
  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
                    int64_t *sizes = nullptr, int64_t *strides = nullptr);
};

/// Opens the bracket for the current dimension and prints its first slice.
/// Closes the bracket right away when the dimension holds at most one slice.
template <typename T, int N>
void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
                                         int64_t rank, int64_t offset,
                                         int64_t *sizes, int64_t *strides) {
  os << "[";
  // An empty dimension has no first slice: recursing would read elements that
  // do not exist, so emit "[]" without touching the data.
  if (sizes[0] <= 0) {
    os << "]";
    return;
  }
  MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
                                     strides + 1);
  // If single element, close square bracket and return early.
  if (sizes[0] == 1) {
    os << "]";
    return;
  }
  os << ", ";
  if (N > 1)
    os << "\n";
}

/// Prints every slice of the current dimension: first slice, interior slices,
/// then the last slice together with the closing bracket.
template <typename T, int N>
void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
                                    int64_t offset, int64_t *sizes,
                                    int64_t *strides) {
  printFirst(os, base, rank, offset, sizes, strides);
  // A 64-bit counter matches the type of sizes[0], so `i + 1` cannot wrap for
  // extents that do not fit in 32 bits.
  for (int64_t i = 1; i + 1 < sizes[0]; ++i) {
    printSpace(os, rank - N + 1);
    MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
                                       sizes + 1, strides + 1);
    os << ", ";
    if (N > 1)
      os << "\n";
  }
  // printFirst already closed the bracket for 0- and 1-slice dimensions.
  if (sizes[0] <= 1)
    return;
  printLast(os, base, rank, offset, sizes, strides);
}

/// Prints the last slice of the current dimension and closes its bracket.
template <typename T, int N>
void MemRefDataPrinter<T, N>::printLast(std::ostream &os, T *base,
                                        int64_t rank, int64_t offset,
                                        int64_t *sizes, int64_t *strides) {
  printSpace(os, rank - N + 1);
  MemRefDataPrinter<T, N - 1>::print(os, base, rank,
                                     offset + (sizes[0] - 1) * (*strides),
                                     sizes + 1, strides + 1);
  os << "]";
}

template <typename T>
void MemRefDataPrinter<T, 0>::print(std::ostream &os, T *base, int64_t rank,
                                    int64_t offset, int64_t *sizes,
                                    int64_t *strides) {
  os << base[offset];
}

/// Prints the metadata and then the data of a ranked (N > 0) memref to
/// `std::cout`.
template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
  static_assert(N > 0, "Expected N > 0");
  printMemRefMetaData(std::cout, M);
  std::cout << " data = " << std::endl;
  MemRefDataPrinter<T, N>::print(std::cout, M.data, N, M.offset, M.sizes,
                                 M.strides);
  std::cout << std::endl;
}

/// Rank-0 overload: the single element is wrapped in one pair of brackets.
template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
  printMemRefMetaData(std::cout, M);
  std::cout << " data = " << std::endl;
  std::cout << "[";
  MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
  std::cout << "]" << std::endl;
}

} // namespace impl

////////////////////////////////////////////////////////////////////////////////
// Currently exposed C API.
////////////////////////////////////////////////////////////////////////////////
// Only declarations live here; the definitions are not part of this header.
// NOTE(review): the element types below follow the `_i8` / `_f32` suffixes of
// the function names -- confirm against the implementation file.
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_i8(UnrankedMemRefType<int8_t> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_f32(UnrankedMemRefType<float> *M);

extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_1d_f32(StridedMemRefType<float, 1> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_2d_f32(StridedMemRefType<float, 2> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_3d_f32(StridedMemRefType<float, 3> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_4d_f32(StridedMemRefType<float, 4> *M);

extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);

// Small runtime support "lib" for vector.print lowering.
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f64(double d);
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_open();
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_close();
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_comma();
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_newline();

#endif // MLIR_CPU_RUNNER_MLIRUTILS_H_