diff options
| -rw-r--r-- | mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h index 262b12c8954..934faf82b36 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h +++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @@ -33,6 +33,15 @@ #define MLIR_RUNNER_UTILS_EXPORT #endif +template <typename T, int N> struct StridedMemRefType; +template <typename StreamType, typename T, int N> +void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V); + +template <int N> void dropFront(int64_t arr[N], int64_t *res) { + for (unsigned i = 1; i < N; ++i) + *(res + i - 1) = arr[i]; +} + /// StridedMemRef descriptor type with static rank. template <typename T, int N> struct StridedMemRefType { T *basePtr; @@ -40,6 +49,26 @@ template <typename T, int N> struct StridedMemRefType { int64_t offset; int64_t sizes[N]; int64_t strides[N]; + // This operator[] is extremely slow and only for sugaring purposes. + StridedMemRefType<T, N - 1> operator[](int64_t idx) { + StridedMemRefType<T, N - 1> res; + res.basePtr = basePtr; + res.data = data; + res.offset = offset + idx * strides[0]; + dropFront<N>(sizes, res.sizes); + dropFront<N>(strides, res.strides); + return res; + } +}; + +/// StridedMemRef descriptor type specialized for rank 1. +template <typename T> struct StridedMemRefType<T, 1> { + T *basePtr; + T *data; + int64_t offset; + int64_t sizes[1]; + int64_t strides[1]; + T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } }; /// StridedMemRef descriptor type specialized for rank 0. |

