diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 42 |
1 files changed, 29 insertions, 13 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index c39fdac93b9..afe1b4e7cc7 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -95,20 +95,20 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); // Compute the start of the column with index Col as Col * Stride. - Value *ColumnStart = Builder.CreateMul(Col, Stride); + Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start"); // Get pointer to the start of the selected column. Skip GEP creation, // if we select column 0. if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) ColumnStart = BasePtr; else - ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart); + ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep"); // Cast elementwise column start pointer to a pointer to a column // (EltType x NumRows)*. Type *ColumnType = VectorType::get(EltType, NumRows); Type *ColumnPtrType = PointerType::get(ColumnType, AS); - return Builder.CreatePointerCast(ColumnStart, ColumnPtrType); + return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast"); } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -317,7 +317,7 @@ public: default: return false; } - return isUniformShape(V) || isa<StoreInst>(V); + return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V); } /// Propagate the shape information of instructions to their users. @@ -481,6 +481,8 @@ public: Value *Op2; if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst)) Changed |= VisitBinaryOperator(BinOp); + if (match(&Inst, m_Load(m_Value(Op1)))) + Changed |= VisitLoad(&Inst, Op1, Builder); else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2)))) Changed |= VisitStore(&Inst, Op1, Op2, Builder); } @@ -495,7 +497,7 @@ public: LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, IRBuilder<> Builder) { unsigned Align = DL.getABITypeAlignment(EltType); - return Builder.CreateAlignedLoad(ColumnPtr, Align); + return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load"); } StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, @@ -536,17 +538,11 @@ public: return true; } - /// Lowers llvm.matrix.columnwise.load. - /// - /// The intrinsic loads a matrix from memory using a stride between columns. - void LowerColumnwiseLoad(CallInst *Inst) { + void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, + ShapeInfo Shape) { IRBuilder<> Builder(Inst); - Value *Ptr = Inst->getArgOperand(0); - Value *Stride = Inst->getArgOperand(1); auto VType = cast<VectorType>(Inst->getType()); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); - ShapeInfo Shape(Inst->getArgOperand(2), Inst->getArgOperand(3)); - ColumnMatrixTy Result; // Distance between start of one column and the start of the next for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { @@ -560,6 +556,16 @@ public: finalizeLowering(Inst, Result, Builder); } + /// Lowers llvm.matrix.columnwise.load. + /// + /// The intrinsic loads a matrix from memory using a stride between columns. + void LowerColumnwiseLoad(CallInst *Inst) { + Value *Ptr = Inst->getArgOperand(0); + Value *Stride = Inst->getArgOperand(1); + LowerLoad(Inst, Ptr, Stride, + {Inst->getArgOperand(2), Inst->getArgOperand(3)}); + } + void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, ShapeInfo Shape) { IRBuilder<> Builder(Inst); @@ -755,6 +761,16 @@ public: finalizeLowering(Inst, Result, Builder); } + /// Lower load instructions, if shape information is available. + bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) { + auto I = ShapeMap.find(Inst); + if (I == ShapeMap.end()) + return false; + + LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second); + return true; + } + bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr, IRBuilder<> &Builder) { auto I = ShapeMap.find(StoredVal); |