-rw-r--r--   mlir/include/mlir/Dialect/GPU/GPUDialect.h              |  1
-rw-r--r--   mlir/include/mlir/Dialect/GPU/GPUOps.td                 | 35
-rw-r--r--   mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp  | 62
-rw-r--r--   mlir/lib/Dialect/GPU/IR/GPUDialect.cpp                  | 41
-rw-r--r--   mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir         | 25
-rw-r--r--   mlir/test/Dialect/GPU/invalid.mlir                      | 14
-rw-r--r--   mlir/test/Dialect/GPU/ops.mlir                          |  5
-rw-r--r--   mlir/test/mlir-cuda-runner/shuffle.mlir                 | 32
8 files changed, 213 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 495238ffea6..93c0b13ee3e 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -26,6 +26,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 46433c6edd5..6751f0a3f70 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -536,6 +536,41 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let verifier = [{ return ::verifyAllReduce(*this); }];
 }
 
+def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+
+def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
+    "Indexing modes supported by gpu.shuffle.",
+    [
+      GPU_ShuffleOpXor,
+    ]>;
+
+def GPU_ShuffleOp : GPU_Op<"shuffle", [NoSideEffect]>,
+    Arguments<(ins AnyType:$value, I32:$offset, I32:$width,
+                   GPU_ShuffleModeAttr:$mode)>,
+    Results<(outs AnyType:$result, I1:$valid)> {
+  let summary = "Shuffles values within a subgroup.";
+  let description = [{
+    The "shuffle" op moves values to a different invocation within the same
+    subgroup.
+
+    For example,
+    ```
+    %1, %2 = gpu.shuffle %0, %offset, %width xor : f32
+    ```
+    returns, for lane `k`, the value from lane `k ^ offset` and `true` if that
+    lane is smaller than `%width`; otherwise it returns an unspecified value
+    and `false`. A lane is the index of an invocation relative to its subgroup.
+
+    The width specifies the number of invocations that participate in the
+    shuffle and needs to be the same for all participating invocations.
+    Exactly the first `width` invocations of a subgroup need to execute this
+    op in convergence.
+  }];
+  let verifier = [{ return ::verifyShuffleOp(*this); }];
+  let printer = [{ printShuffleOp(p, *this); }];
+  let parser = [{ return parseShuffleOp(parser, result); }];
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 78fe15dff50..220df53b977 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -473,6 +473,64 @@ private:
   static constexpr int kWarpSize = 32;
 };
 
+struct GPUShuffleOpLowering : public LLVMOpLowering {
+  explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_) {}
+
+  /// Lowers a shuffle to the corresponding NVVM op.
+  ///
+  /// Converts the `width` argument into an activeMask (a bitmask which
+  /// specifies which threads participate in the shuffle) and a maskAndClamp
+  /// (specifying the highest lane which participates in the shuffle):
+  ///
+  ///     %one = llvm.constant(1 : i32) : !llvm.i32
+  ///     %shl = llvm.shl %one, %width : !llvm.i32
+  ///     %active_mask = llvm.sub %shl, %one : !llvm.i32
+  ///     %mask_and_clamp = llvm.sub %width, %one : !llvm.i32
+  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
+  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
+  ///     %shfl_value = llvm.extractvalue %shfl[0 : index] :
+  ///         !llvm<"{ float, i1 }">
+  ///     %shfl_pred = llvm.extractvalue %shfl[1 : index] :
+  ///         !llvm<"{ float, i1 }">
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    gpu::ShuffleOpOperandAdaptor adaptor(operands);
+
+    auto dialect = lowering.getDialect();
+    auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
+    auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
+    auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
+    auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+
+    Value *one = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(1));
+    // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+    Value *activeMask = rewriter.create<LLVM::SubOp>(
+        loc, int32Type,
+        rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()),
+        one);
+    // Clamp lane: `activeWidth - 1`.
+    Value *maskAndClamp =
+        rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
+
+    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+    Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+        loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
+        maskAndClamp, returnValueAndIsValidAttr);
+    Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
+    Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+        loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
+
+    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    return matchSuccess();
+  }
+};
+
 struct GPUFuncOpLowering : LLVMOpLowering {
   explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
       : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
@@ -688,8 +746,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                   NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                   NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPUAllReduceOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
-      converter);
+      GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
+      GPUReturnOpLowering>(converter);
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                                "__nv_exp");
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7324b96a7e1..9c0183eb90f 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -165,6 +165,47 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
   return success();
 }
 
+static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
+  auto type = shuffleOp.value()->getType();
+  if (shuffleOp.result()->getType() != type) {
+    return shuffleOp.emitOpError()
+           << "requires the same type for value operand and result";
+  }
+  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
+    return shuffleOp.emitOpError()
+           << "requires value operand type to be f32 or i32";
+  }
+  return success();
+}
+
+static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
+  p << ShuffleOp::getOperationName() << ' ';
+  p.printOperands(op.getOperands());
+  p << ' ' << op.mode() << " : ";
+  p.printType(op.value()->getType());
+}
+
+static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
+  if (parser.parseOperandList(operandInfo, 3))
+    return failure();
+
+  StringRef mode;
+  if (parser.parseKeyword(&mode))
+    return failure();
+  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));
+
+  Type valueType;
+  Type int32Type = parser.getBuilder().getIntegerType(32);
+  Type int1Type = parser.getBuilder().getI1Type();
+  if (parser.parseColonType(valueType) ||
+      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
+                             parser.getCurrentLocation(), state.operands) ||
+      parser.addTypesToList({valueType, int1Type}, state.types))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LaunchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 525016b7cf4..b1820cb778f 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -75,6 +75,31 @@ module attributes {gpu.kernel_module} {
 // -----
 
 module attributes {gpu.kernel_module} {
+  // CHECK-LABEL: func @gpu_shuffle()
+  func @gpu_shuffle()
+      attributes { gpu.kernel } {
+    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+    %arg0 = constant 1.0 : f32
+    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
+    %arg1 = constant 4 : i32
+    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : !llvm.i32
+    %arg2 = constant 23 : i32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+    // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : !llvm.i32
+    // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : !llvm.i32
+    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : !llvm.i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm<"{ float, i1 }">
+    // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm<"{ float, i1 }">
+    // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm<"{ float, i1 }">
+    %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
+
+    std.return
+  }
+}
+
+// -----
+
+module attributes {gpu.kernel_module} {
   // CHECK-LABEL: func @gpu_sync()
   func @gpu_sync()
       attributes { gpu.kernel } {
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index f8ed1a9d783..8323fdf8709 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -362,6 +362,20 @@ func @reduce_incorrect_yield(%arg0 : f32) {
 
 // -----
 
+func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}}
+  %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
+}
+
+// -----
+
+func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}}
+  %shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
+}
+
+// -----
+
 module {
   module @gpu_funcs attributes {gpu.kernel_module} {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index ff5a40d64b4..1dd08cea492 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -81,6 +81,11 @@ module attributes {gpu.container_module} {
       %one = constant 1.0 : f32
       %sum = "gpu.all_reduce"(%one) ({}) {op = "add"} : (f32) -> (f32)
 
+      %width = constant 7 : i32
+      %offset = constant 3 : i32
+      // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
+      %shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
+
       "gpu.barrier"() : () -> ()
 
       "some_op"(%bIdX, %tIdX) : (index, index) -> ()
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
new file mode 100644
index 00000000000..1b01399cb6d
--- /dev/null
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK: [4, 5, 6, 7, 0, 1, 2, 3, 12, -1, -1, -1, 8]
+func @main() {
+  %arg = alloc() : memref<13xf32>
+  %dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
+  %one = constant 1 : index
+  %sx = dim %dst, 0 : memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
+             args(%kernel_dst = %dst) : memref<?xf32> {
+    %t0 = index_cast %tx : index to i32
+    %val = sitofp %t0 : i32 to f32
+    %width = index_cast %block_x : index to i32
+    %offset = constant 4 : i32
+    %shfl, %valid = gpu.shuffle %val, %offset, %width xor : f32
+    cond_br %valid, ^bb1(%shfl : f32), ^bb0
+  ^bb0:
+    %m1 = constant -1.0 : f32
+    br ^bb1(%m1 : f32)
+  ^bb1(%value : f32):
+    store %value, %kernel_dst[%tx] : memref<?xf32>
+    gpu.return
+  }
+  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  return
+}
+
+func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)
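As an illustration beyond the tests above (not part of this patch): `gpu.shuffle` in `xor` mode is the usual building block for subgroup butterfly reductions. The sketch below is a hypothetical kernel fragment that assumes a full subgroup of width 32 and an already-defined `%val : f32`; the names `%offsetN`, `%otherN`, and `%sumN` are made up for the example. As a worked instance of the lowering arithmetic, note that for `%width = 8` the pattern above computes `activeMask = (1 << 8) - 1 = 0xff` and `maskAndClamp = 7`.

```
// Hypothetical fragment, not from this commit: butterfly sum across a
// subgroup of width 32. All 32 lanes execute the shuffles in convergence,
// so the `valid` results are always true and can be ignored here.
%width = constant 32 : i32
%offset16 = constant 16 : i32
%other16, %valid16 = gpu.shuffle %val, %offset16, %width xor : f32
%sum16 = addf %val, %other16 : f32
%offset8 = constant 8 : i32
%other8, %valid8 = gpu.shuffle %sum16, %offset8, %width xor : f32
%sum8 = addf %sum16, %other8 : f32
// Continuing with offsets 4, 2, and 1 leaves the subgroup-wide sum in
// every lane.
```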