diff options
| -rw-r--r-- | mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index ccad2cd0c0b..e4bdd7cb2be 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -308,8 +308,6 @@ private:
                              ConversionPatternRewriter &rewriter) const {
     Value *warpSize = rewriter.create<LLVM::ConstantOp>(
         loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
     Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
         loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
     auto type = operand->getType().cast<LLVM::LLVMType>();
@@ -326,6 +324,9 @@ private:
               loc, int32Type,
               rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
               one);
+          // Clamp lane: `activeWidth - 1`
+          Value *maskAndClamp =
+              rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
           auto dialect = lowering.getDialect();
           auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
           auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
@@ -363,6 +364,8 @@ private:
     Value *value = operand;
     Value *activeMask = rewriter.create<LLVM::ConstantOp>(
         loc, int32Type, rewriter.getI32IntegerAttr(~0u));
+    Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
     for (int i = 1; i < kWarpSize; i <<= 1) {
       Value *offset = rewriter.create<LLVM::ConstantOp>(
           loc, int32Type, rewriter.getI32IntegerAttr(i));

