Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 64
1 file changed, 63 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a9566422a83..c39fdac93b9 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -402,9 +402,71 @@ public:
     }
   }
 
+  /// Propagate the shape to operands of instructions with shape information.
+  void propagateShapeBackward() {
+    SmallVector<Value *, 8> WorkList;
+    // Worklist contains instruction for which we already know the shape.
+    for (auto &V : ShapeMap)
+      WorkList.push_back(V.first);
+
+    // Pop an element with known shape. Traverse the operands, if their shape
+    // derives from the result shape and is unknown, add it and add them to the
+    // worklist.
+    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+    while (!WorkList.empty()) {
+      Value *V = WorkList.back();
+      WorkList.pop_back();
+
+      if (!isa<Instruction>(V))
+        continue;
+
+      Value *MatrixA;
+      Value *MatrixB;
+      Value *M;
+      Value *N;
+      Value *K;
+      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
+                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
+                       m_Value(N), m_Value(K)))) {
+        if (setShapeInfo(MatrixA, {M, N}))
+          WorkList.push_back(MatrixA);
+
+        if (setShapeInfo(MatrixB, {N, K}))
+          WorkList.push_back(MatrixB);
+
+      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
+                            m_Value(MatrixA), m_Value(M), m_Value(N)))) {
+        // Flip dimensions.
+        if (setShapeInfo(MatrixA, {M, N}))
+          WorkList.push_back(MatrixA);
+      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
+                            m_Value(MatrixA), m_Value(), m_Value(),
+                            m_Value(M), m_Value(N)))) {
+        if (setShapeInfo(MatrixA, {M, N})) {
+          WorkList.push_back(MatrixA);
+        }
+      } else if (isa<LoadInst>(V) ||
+                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
+        // Nothing to do, no matrix input.
+      } else if (isa<StoreInst>(V)) {
+        // Nothing to do. We forward-propagated to this so we would just
+        // backward propagate to an instruction with an already known shape.
+      } else if (isUniformShape(V)) {
+        // Propagate to all operands.
+        ShapeInfo Shape = ShapeMap[V];
+        for (Use &U : cast<Instruction>(V)->operands()) {
+          if (setShapeInfo(U.get(), Shape))
+            WorkList.push_back(U.get());
+        }
+      }
+    }
+  }
+
   bool Visit() {
-    if (EnableShapePropagation)
+    if (EnableShapePropagation) {
       propagateShapeForward();
+      propagateShapeBackward();
+    }
 
     ReversePostOrderTraversal<Function *> RPOT(&Func);
     bool Changed = false;
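The added propagateShapeBackward() seeds a worklist with every value whose shape is already known (typically from forward propagation) and derives operand shapes from result shapes: the multiply, transpose and columnwise-store intrinsics carry their dimensions as arguments, while element-wise ("uniform shape") instructions simply hand the result shape to all of their operands. The following standalone sketch shows the same worklist pattern on a toy expression graph; Node, Op, ShapeInfo, setShape and propagateBackward are hypothetical stand-ins for the pass's llvm::Value graph, the matrix intrinsics and its setShapeInfo() helper, not LLVM API.

// Standalone sketch (not LLVM code) of worklist-based backward shape
// propagation over a toy expression graph.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct ShapeInfo {
  unsigned Rows = 0, Cols = 0;
};

enum class Op { Leaf, Multiply, Transpose, Add };

struct Node {
  std::string Name;
  Op Kind = Op::Leaf;
  std::vector<Node *> Operands;
  unsigned Inner = 0; // Inner dimension for Multiply; in the real pass this is
                      // an argument of the llvm.matrix.multiply intrinsic.
};

// Record a shape for N unless one is already known. Returns true if the map
// changed, mirroring the role of setShapeInfo() in the pass.
static bool setShape(std::map<Node *, ShapeInfo> &Shapes, Node *N,
                     ShapeInfo S) {
  if (Shapes.count(N))
    return false;
  Shapes[N] = S;
  return true;
}

// Pop a node with a known shape, derive its operands' shapes from it, and
// push any operand whose shape just became known so its own operands get
// visited as well.
static void propagateBackward(std::map<Node *, ShapeInfo> &Shapes) {
  std::vector<Node *> WorkList;
  for (auto &KV : Shapes)
    WorkList.push_back(KV.first);

  while (!WorkList.empty()) {
    Node *N = WorkList.back();
    WorkList.pop_back();
    ShapeInfo Res = Shapes[N];

    if (N->Kind == Op::Multiply) {
      // (Rows x Cols) = (Rows x Inner) * (Inner x Cols).
      if (setShape(Shapes, N->Operands[0], {Res.Rows, N->Inner}))
        WorkList.push_back(N->Operands[0]);
      if (setShape(Shapes, N->Operands[1], {N->Inner, Res.Cols}))
        WorkList.push_back(N->Operands[1]);
    } else if (N->Kind == Op::Transpose) {
      // The operand's dimensions are the result's, flipped.
      if (setShape(Shapes, N->Operands[0], {Res.Cols, Res.Rows}))
        WorkList.push_back(N->Operands[0]);
    } else if (N->Kind == Op::Add) {
      // Element-wise operation: every operand has the result's shape.
      for (Node *Opnd : N->Operands)
        if (setShape(Shapes, Opnd, Res))
          WorkList.push_back(Opnd);
    }
    // Leaves have no operands to propagate to.
  }
}

int main() {
  Node A{"A"}, B{"B"};
  Node T{"T", Op::Transpose, {&A}};
  Node M{"M", Op::Multiply, {&T, &B}, /*Inner=*/5};

  // Only the final multiply's shape is known up front; backward propagation
  // fills in the shapes of T, A and B.
  std::map<Node *, ShapeInfo> Shapes;
  Shapes[&M] = {4, 3};
  propagateBackward(Shapes);

  for (Node *N : {&M, &T, &A, &B})
    std::printf("%s: %u x %u\n", N->Name.c_str(), Shapes[N].Rows,
                Shapes[N].Cols);
  return 0;
}

Compiled with -std=c++17, this should print 4 x 3, 4 x 5, 5 x 4 and 5 x 3 for M, T, A and B: knowing only the final multiply's shape is enough to recover every operand's shape, which is the situation the backward pass handles when forward propagation alone cannot reach an operand.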