summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h29
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index 262b12c8954..934faf82b36 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -33,6 +33,15 @@
#define MLIR_RUNNER_UTILS_EXPORT
#endif
+template <typename T, int N> struct StridedMemRefType;
+template <typename StreamType, typename T, int N>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V);
+
+template <int N> void dropFront(int64_t arr[N], int64_t *res) {
+ for (unsigned i = 1; i < N; ++i)
+ *(res + i - 1) = arr[i];
+}
+
/// StridedMemRef descriptor type with static rank.
template <typename T, int N> struct StridedMemRefType {
T *basePtr;
@@ -40,6 +49,26 @@ template <typename T, int N> struct StridedMemRefType {
int64_t offset;
int64_t sizes[N];
int64_t strides[N];
+ // This operator[] is extremely slow and only for sugaring purposes.
+ StridedMemRefType<T, N - 1> operator[](int64_t idx) {
+ StridedMemRefType<T, N - 1> res;
+ res.basePtr = basePtr;
+ res.data = data;
+ res.offset = offset + idx * strides[0];
+ dropFront<N>(sizes, res.sizes);
+ dropFront<N>(strides, res.strides);
+ return res;
+ }
+};
+
+/// StridedMemRef descriptor type specialized for rank 1.
+template <typename T> struct StridedMemRefType<T, 1> {
+ T *basePtr;
+ T *data;
+ int64_t offset;
+ int64_t sizes[1];
+ int64_t strides[1];
+ T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
};
/// StridedMemRef descriptor type specialized for rank 0.
OpenPOWER on IntegriCloud