summaryrefslogtreecommitdiffstats
path: root/mlir/test/mlir-cpu-runner/cblas_interface.cpp
diff options
context:
space:
mode:
authorMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
committerMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
commit0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch)
tree31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/test/mlir-cpu-runner/cblas_interface.cpp
parent6f635f90929da9545dd696071a829a1a42f84b30 (diff)
parent5b4a01d4a63cb66ab981e52548f940813393bf42 (diff)
downloadbcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz
bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/test/mlir-cpu-runner/cblas_interface.cpp')
-rw-r--r--mlir/test/mlir-cpu-runner/cblas_interface.cpp105
1 files changed, 105 insertions, 0 deletions
diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
new file mode 100644
index 00000000000..5e3a00e7fd1
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
@@ -0,0 +1,105 @@
+//===- cblas_interface.cpp - Simple Blas subset interface -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Simple Blas subset interface implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "include/cblas.h"
+#include <assert.h>
+#include <iostream>
+
+extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
+ float f) {
+ X->data[X->offset] = f;
+}
+
+extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
+ float f) {
+ for (unsigned i = 0; i < X->sizes[0]; ++i)
+ *(X->data + X->offset + i * X->strides[0]) = f;
+}
+
+extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
+ float f) {
+ for (unsigned i = 0; i < X->sizes[0]; ++i)
+ for (unsigned j = 0; j < X->sizes[1]; ++j)
+ *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
+}
+
+extern "C" void linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
+ StridedMemRefType<float, 0> *O) {
+ O->data[O->offset] = I->data[I->offset];
+}
+
+extern "C" void
+linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
+ StridedMemRefType<float, 1> *O) {
+ if (I->sizes[0] != O->sizes[0]) {
+ std::cerr << "Incompatible strided memrefs\n";
+ printMemRefMetaData(std::cerr, *I);
+ printMemRefMetaData(std::cerr, *O);
+ return;
+ }
+ for (unsigned i = 0; i < I->sizes[0]; ++i)
+ O->data[O->offset + i * O->strides[0]] =
+ I->data[I->offset + i * I->strides[0]];
+}
+
+extern "C" void
+linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float, 2> *I,
+ StridedMemRefType<float, 2> *O) {
+ if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
+ std::cerr << "Incompatible strided memrefs\n";
+ printMemRefMetaData(std::cerr, *I);
+ printMemRefMetaData(std::cerr, *O);
+ return;
+ }
+ auto so0 = O->strides[0], so1 = O->strides[1];
+ auto si0 = I->strides[0], si1 = I->strides[1];
+ for (unsigned i = 0; i < I->sizes[0]; ++i)
+ for (unsigned j = 0; j < I->sizes[1]; ++j)
+ O->data[O->offset + i * so0 + j * so1] =
+ I->data[I->offset + i * si0 + j * si1];
+}
+
+extern "C" void
+linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float, 1> *X,
+ StridedMemRefType<float, 1> *Y,
+ StridedMemRefType<float, 0> *Z) {
+ if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
+ std::cerr << "Incompatible strided memrefs\n";
+ printMemRefMetaData(std::cerr, *X);
+ printMemRefMetaData(std::cerr, *Y);
+ printMemRefMetaData(std::cerr, *Z);
+ return;
+ }
+ Z->data[Z->offset] +=
+ cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
+ Y->data + Y->offset, Y->strides[0]);
+}
+
+extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
+ StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
+ StridedMemRefType<float, 2> *C) {
+ if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
+ A->strides[1] != 1 || A->sizes[0] < A->strides[1] ||
+ B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] ||
+ C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] ||
+ A->sizes[1] != B->sizes[0]) {
+ printMemRefMetaData(std::cerr, *A);
+ printMemRefMetaData(std::cerr, *B);
+ printMemRefMetaData(std::cerr, *C);
+ return;
+ }
+ cblas_sgemm(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
+ CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1],
+ A->sizes[1], 1.0f, A->data + A->offset, A->strides[0],
+ B->data + B->offset, B->strides[0], 1.0f, C->data + C->offset,
+ C->strides[0]);
+}
OpenPOWER on IntegriCloud