diff options
Diffstat (limited to 'clang/lib/CodeGen/CGBuiltin.cpp')
| -rw-r--r-- | clang/lib/CodeGen/CGBuiltin.cpp | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index f25a70a8537..e8d2c59c7cd 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -9731,6 +9731,204 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, Builder.CreateStore(Pred, PredOutPtr); return Builder.CreateExtractValue(ResultPair, 0); } + case NVPTX::BI__hmma_m16n16k16_ld_a: + case NVPTX::BI__hmma_m16n16k16_ld_b: + case NVPTX::BI__hmma_m16n16k16_ld_c_f16: + case NVPTX::BI__hmma_m16n16k16_ld_c_f32: { + Address Dst = EmitPointerWithAlignment(E->getArg(0)); + Value *Src = EmitScalarExpr(E->getArg(1)); + Value *Ldm = EmitScalarExpr(E->getArg(2)); + llvm::APSInt isColMajorArg; + if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext())) + return nullptr; + bool isColMajor = isColMajorArg.getSExtValue(); + unsigned IID; + unsigned NumResults; + switch (BuiltinID) { + case NVPTX::BI__hmma_m16n16k16_ld_a: + IID = isColMajor ? Intrinsic::nvvm_wmma_load_a_f16_col_stride + : Intrinsic::nvvm_wmma_load_a_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m16n16k16_ld_b: + IID = isColMajor ? Intrinsic::nvvm_wmma_load_b_f16_col_stride + : Intrinsic::nvvm_wmma_load_b_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m16n16k16_ld_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f16_col_stride + : Intrinsic::nvvm_wmma_load_c_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m16n16k16_ld_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f32_col_stride + : Intrinsic::nvvm_wmma_load_c_f32_row_stride; + NumResults = 8; + break; + default: + llvm_unreachable("Unexpected builtin ID."); + } + Value *Result = + Builder.CreateCall(CGM.getIntrinsic(IID), + {Builder.CreatePointerCast(Src, VoidPtrTy), Ldm}); + + // Save returned values. + for (unsigned i = 0; i < NumResults; ++i) { + Builder.CreateAlignedStore( + Builder.CreateBitCast(Builder.CreateExtractValue(Result, i), + Dst.getElementType()), + Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + } + return Result; + } + + case NVPTX::BI__hmma_m16n16k16_st_c_f16: + case NVPTX::BI__hmma_m16n16k16_st_c_f32: { + Value *Dst = EmitScalarExpr(E->getArg(0)); + Address Src = EmitPointerWithAlignment(E->getArg(1)); + Value *Ldm = EmitScalarExpr(E->getArg(2)); + llvm::APSInt isColMajorArg; + if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext())) + return nullptr; + bool isColMajor = isColMajorArg.getSExtValue(); + unsigned IID; + unsigned NumResults = 8; + // PTX Instructions (and LLVM instrinsics) are defined for slice _d_, yet + // for some reason nvcc builtins use _c_. + switch (BuiltinID) { + case NVPTX::BI__hmma_m16n16k16_st_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f16_col_stride + : Intrinsic::nvvm_wmma_store_d_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m16n16k16_st_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f32_col_stride + : Intrinsic::nvvm_wmma_store_d_f32_row_stride; + break; + default: + llvm_unreachable("Unexpected builtin ID."); + } + Function *Intrinsic = CGM.getIntrinsic(IID); + llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1); + SmallVector<Value *, 10> Values; + Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy)); + for (unsigned i = 0; i < NumResults; ++i) { + Value *V = Builder.CreateAlignedLoad( + Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + Values.push_back(Builder.CreateBitCast(V, ParamType)); + } + Values.push_back(Ldm); + Value *Result = Builder.CreateCall(Intrinsic, Values); + return Result; + } + + // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf) + // --> Intrinsic::nvvm_wmma_mma_sync<layout A,B><DType><CType><Satf> + case NVPTX::BI__hmma_m16n16k16_mma_f16f16: + case NVPTX::BI__hmma_m16n16k16_mma_f32f16: + case NVPTX::BI__hmma_m16n16k16_mma_f32f32: + case NVPTX::BI__hmma_m16n16k16_mma_f16f32: { + Address Dst = EmitPointerWithAlignment(E->getArg(0)); + Address SrcA = EmitPointerWithAlignment(E->getArg(1)); + Address SrcB = EmitPointerWithAlignment(E->getArg(2)); + Address SrcC = EmitPointerWithAlignment(E->getArg(3)); + llvm::APSInt LayoutArg; + if (!E->getArg(4)->isIntegerConstantExpr(LayoutArg, getContext())) + return nullptr; + int Layout = LayoutArg.getSExtValue(); + if (Layout < 0 || Layout > 3) + return nullptr; + llvm::APSInt SatfArg; + if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext())) + return nullptr; + bool Satf = SatfArg.getSExtValue(); + + // clang-format off +#define MMA_VARIANTS(type) {{ \ + Intrinsic::nvvm_wmma_mma_sync_row_row_##type, \ + Intrinsic::nvvm_wmma_mma_sync_row_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_mma_sync_row_col_##type, \ + Intrinsic::nvvm_wmma_mma_sync_row_col_##type##_satfinite, \ + Intrinsic::nvvm_wmma_mma_sync_col_row_##type, \ + Intrinsic::nvvm_wmma_mma_sync_col_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_mma_sync_col_col_##type, \ + Intrinsic::nvvm_wmma_mma_sync_col_col_##type##_satfinite \ + }} + // clang-format on + + auto getMMAIntrinsic = [Layout, Satf](std::array<unsigned, 8> Variants) { + unsigned Index = Layout * 2 + Satf; + assert(Index < 8); + return Variants[Index]; + }; + unsigned IID; + unsigned NumEltsC; + unsigned NumEltsD; + switch (BuiltinID) { + case NVPTX::BI__hmma_m16n16k16_mma_f16f16: + IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16)); + NumEltsC = 4; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m16n16k16_mma_f32f16: + IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16)); + NumEltsC = 4; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m16n16k16_mma_f16f32: + IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32)); + NumEltsC = 8; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m16n16k16_mma_f32f32: + IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32)); + NumEltsC = 8; + NumEltsD = 8; + break; + default: + llvm_unreachable("Unexpected builtin ID."); + } +#undef MMA_VARIANTS + + SmallVector<Value *, 24> Values; + Function *Intrinsic = CGM.getIntrinsic(IID); + llvm::Type *ABType = Intrinsic->getFunctionType()->getParamType(0); + // Load A + for (unsigned i = 0; i < 8; ++i) { + Value *V = Builder.CreateAlignedLoad( + Builder.CreateGEP(SrcA.getPointer(), + llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + Values.push_back(Builder.CreateBitCast(V, ABType)); + } + // Load B + for (unsigned i = 0; i < 8; ++i) { + Value *V = Builder.CreateAlignedLoad( + Builder.CreateGEP(SrcB.getPointer(), + llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + Values.push_back(Builder.CreateBitCast(V, ABType)); + } + // Load C + llvm::Type *CType = Intrinsic->getFunctionType()->getParamType(16); + for (unsigned i = 0; i < NumEltsC; ++i) { + Value *V = Builder.CreateAlignedLoad( + Builder.CreateGEP(SrcC.getPointer(), + llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + Values.push_back(Builder.CreateBitCast(V, CType)); + } + Value *Result = Builder.CreateCall(Intrinsic, Values); + llvm::Type *DType = Dst.getElementType(); + for (unsigned i = 0; i < NumEltsD; ++i) + Builder.CreateAlignedStore( + Builder.CreateBitCast(Builder.CreateExtractValue(Result, i), DType), + Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)), + CharUnits::fromQuantity(4)); + return Result; + } default: return nullptr; } |

