diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 9 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 9 |
2 files changed, 16 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index c2493f773d1..00f89d3644f 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -712,7 +712,8 @@ public: populateGpuToNVVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); target.addIllegalDialect<gpu::GPUDialect>(); - target.addIllegalOp<LLVM::ExpOp>(); + target.addIllegalOp<LLVM::FAbsOp, LLVM::FCeilOp, LLVM::CosOp, + LLVM::ExpOP>(); target.addIllegalOp<FuncOp>(); target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalDialect<NVVM::NVVMDialect>(); @@ -739,6 +740,12 @@ void mlir::populateGpuToNVVMConversionPatterns( NVVM::GridDimYOp, NVVM::GridDimZOp>, GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(converter); + patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf", + "__nv_fabs"); + patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf", + "__nv_ceil"); + patterns.insert<OpToFuncCallLowering<CosOp>>(converter, "__nv_cosf", + "__nv_cos"); patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf", "__nv_exp"); } diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 83770641bd4..119479d7ec1 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -51,12 +51,19 @@ public: GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>( converter); + patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "_ocml_fabs_f32", + "_ocml_fabs_f64"); + patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "_ocml_ceil_f32", + "_ocml_ceil_f64"); + patterns.insert<OpToFuncCallLowering<CosOp>>(converter, "_ocml_cos_f32", + "_ocml_cos_f64"); patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "_ocml_exp_f32", "_ocml_exp_f64"); ConversionTarget target(getContext()); target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); - target.addIllegalOp<LLVM::ExpOp>(); + target.addIllegalOp<LLVM::FAbsOp, LLVM::FCeilOp, LLVM::CosOP, + LLVM::ExpOp>(); target.addDynamicallyLegalOp<FuncOp>( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(m, target, patterns, &converter))) |