summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp64
1 files changed, 63 insertions, 1 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a9566422a83..c39fdac93b9 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -402,9 +402,71 @@ public:
}
}
+ /// Propagate the shape to operands of instructions with shape information.
+ void propagateShapeBackward() {
+ SmallVector<Value *, 8> WorkList;
+ // Worklist contains instruction for which we already know the shape.
+ for (auto &V : ShapeMap)
+ WorkList.push_back(V.first);
+
+ // Pop an element with known shape. Traverse the operands, if their shape
+ // derives from the result shape and is unknown, add it and add them to the
+ // worklist.
+ LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+ while (!WorkList.empty()) {
+ Value *V = WorkList.back();
+ WorkList.pop_back();
+
+ if (!isa<Instruction>(V))
+ continue;
+
+ Value *MatrixA;
+ Value *MatrixB;
+ Value *M;
+ Value *N;
+ Value *K;
+ if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
+ m_Value(N), m_Value(K)))) {
+ if (setShapeInfo(MatrixA, {M, N}))
+ WorkList.push_back(MatrixA);
+
+ if (setShapeInfo(MatrixB, {N, K}))
+ WorkList.push_back(MatrixB);
+
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(MatrixA), m_Value(M), m_Value(N)))) {
+ // Flip dimensions.
+ if (setShapeInfo(MatrixA, {M, N}))
+ WorkList.push_back(MatrixA);
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
+ m_Value(MatrixA), m_Value(), m_Value(),
+ m_Value(M), m_Value(N)))) {
+ if (setShapeInfo(MatrixA, {M, N})) {
+ WorkList.push_back(MatrixA);
+ }
+ } else if (isa<LoadInst>(V) ||
+ match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
+ // Nothing to do, no matrix input.
+ } else if (isa<StoreInst>(V)) {
+ // Nothing to do. We forward-propagated to this so we would just
+ // backward propagate to an instruction with an already known shape.
+ } else if (isUniformShape(V)) {
+ // Propagate to all operands.
+ ShapeInfo Shape = ShapeMap[V];
+ for (Use &U : cast<Instruction>(V)->operands()) {
+ if (setShapeInfo(U.get(), Shape))
+ WorkList.push_back(U.get());
+ }
+ }
+ }
+ }
+
bool Visit() {
- if (EnableShapePropagation)
+ if (EnableShapePropagation) {
propagateShapeForward();
+ propagateShapeBackward();
+ }
ReversePostOrderTraversal<Function *> RPOT(&Func);
bool Changed = false;
OpenPOWER on IntegriCloud