-rw-r--r-- | mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h      |  29
-rw-r--r-- | mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h  |  12
-rw-r--r-- | mlir/include/mlir/Transforms/LoopUtils.h                  |   4
-rw-r--r-- | mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp             | 253
-rw-r--r-- | mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp         |  77
-rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp                   |   7
-rw-r--r-- | mlir/test/Conversion/LoopsToGPU/imperfect_2D.mlir         |  83
-rw-r--r-- | mlir/test/Conversion/LoopsToGPU/imperfect_3D.mlir         |  83
-rw-r--r-- | mlir/test/Conversion/LoopsToGPU/imperfect_4D.mlir         |  86
-rw-r--r-- | mlir/test/Conversion/LoopsToGPU/imperfect_linalg.mlir     |  48
-rw-r--r-- | mlir/test/Conversion/LoopsToGPU/perfect_1D_setlaunch.mlir |  26
-rw-r--r-- | mlir/test/Transforms/parametric-mapping.mlir              |  20
12 files changed, 691 insertions, 37 deletions
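A minimal sketch of the mapping implemented by the new convertLoopToGPULaunch entry point, distilled from the doc comments and tests in the diff below. The loop bounds %n and %m, the tile sizes %t1 and %t2, and the flag values are illustrative only and are not taken from the patch:

  // mlir-opt input.mlir -convert-loop-op-to-gpu -gpu-num-workgroups=2,2 -gpu-workgroup-size=32,4
  loop.for %i = %c0 to %n step %t1 {    // outer mapped loop   -> blockIdx.y / gridDim.y
    loop.for %j = %c0 to %m step %t2 {  // innermost mapped loop -> blockIdx.x / gridDim.x
      // body: either statements executed by every thread, or further loop
      // nests whose outer loops are mapped to threadIdx/blockDim in the same
      // x-then-y-then-z order
    }
  }

Each mapped loop is rewritten in place inside the generated gpu.launch as
  lb' = lb + step * id,    step' = step * nid
where (id, nid) is (blockIdx, gridDim) for block loops and (threadIdx, blockDim) for thread loops; the innermost loop of each mapped band takes the x dimension, the next one y, then z.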
diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
index 973b995f10b..0aab8723eab 100644
--- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
+++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
@@ -17,9 +17,12 @@
 #ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
 #define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class AffineForOp;
 struct LogicalResult;
+class Value;
 
 namespace loop {
 class ForOp;
@@ -52,6 +55,32 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
 LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp,
                                          unsigned numBlockDims,
                                          unsigned numThreadDims);
+
+/// Convert a loop operation into a GPU launch with the values provided in
+/// `numWorkGroups` as the grid size and the values provided in `workGroupSizes`
+/// as the block size. The sizes of `numWorkGroups` and `workGroupSizes` must be
+/// less than or equal to 3. The loop operation can be an imperfectly nested
+/// computation with the following restrictions:
+/// 1) The loop nest must contain as many perfectly nested loops as the number
+/// of values passed in through `numWorkGroups`. This corresponds to the number
+/// of grid dimensions of the launch. All loops within the loop nest must be
+/// parallel.
+/// 2) The body of the innermost loop of the above perfectly nested loops must
+/// contain statements that satisfy one of the two conditions below:
+///   a) A perfect loop nest of depth greater than or equal to the number of
+///   values passed in through `workGroupSizes`, i.e. the number of thread
+///   dimensions of the launch. Loops at depth less than or equal to the size of
+///   `workGroupSizes` must be parallel. Loops nested deeper can be sequential
+///   and are retained as such in the generated GPU launch code.
+///   b) Statements that are safe to be executed by all threads within the
+///   workgroup. No checks are performed that this is indeed the case.
+///   TODO(ravishankarm) : Add checks that verify 2(b) above.
+/// The above conditions are assumed to be satisfied by the computation rooted
+/// at `forOp`.
+LogicalResult convertLoopToGPULaunch(loop::ForOp forOp,
+                                     ArrayRef<Value *> numWorkGroups,
+                                     ArrayRef<Value *> workGroupSizes);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
index 960a93dd566..a42320c9bdf 100644
--- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
+++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
@@ -17,6 +17,8 @@
 #ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
 #define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
 
+#include "mlir/Support/LLVM.h"
+
 #include <memory>
 
 namespace mlir {
@@ -33,6 +35,16 @@ template <typename T> class OpPassBase;
 /// calling the conversion.
 std::unique_ptr<OpPassBase<FuncOp>>
 createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims);
+
+/// Create a pass that converts every loop operation within the body of the
+/// FuncOp into a GPU launch. The number of workgroups and the workgroup size
+/// for the implementation are controlled by SSA values passed into the
+/// conversion method. For testing, the values are set as constants obtained
+/// from command-line flags. See convertLoopToGPULaunch for a description of
+/// the required semantics of the converted loop operation.
+std::unique_ptr<OpPassBase<FuncOp>> +createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups, + ArrayRef<int64_t> workGroupSize); } // namespace mlir #endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_ diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index 6bfe5564f78..5ca3f7f6510 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -224,8 +224,8 @@ void coalesceLoops(MutableArrayRef<loop::ForOp> loops); /// is rewritten into a version resembling the following pseudo-IR: /// /// ``` -/// loop.for %i = %lb + threadIdx.x + blockIdx.x * blockDim.x to %ub -/// step %gridDim.x * blockDim.x { +/// loop.for %i = %lb + %step * (threadIdx.x + blockIdx.x * blockDim.x) +/// to %ub step %gridDim.x * blockDim.x * %step { /// ... /// } /// ``` diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 2229455ef33..e33b8401c74 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -22,15 +22,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" + #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" +#include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/RegionUtils.h" - +#include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "loops-to-gpu" @@ -38,6 +40,8 @@ using namespace mlir; using namespace mlir::loop; +using llvm::seq; + // Extract an indexed value from KernelDim3. static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { @@ -97,12 +101,38 @@ static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) { } // Check the structure of the loop nest: -// - there are enough loops to map to numBlockDims + numThreadDims; +// - there are enough loops to map to numDims; // - the loops are perfectly nested; // - the loop bounds can be computed above the outermost loop. // This roughly corresponds to the "matcher" part of the pattern-based // rewriting infrastructure. template <typename OpTy> +LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) { + Region &limit = forOp.region(); + for (unsigned i = 0, e = numDims; i < e; ++i) { + Operation *nested = &forOp.getBody()->front(); + if (!areValuesDefinedAbove(getLowerBoundOperands(forOp), limit) || + !areValuesDefinedAbove(getUpperBoundOperands(forOp), limit)) + return forOp.emitError( + "loops with bounds depending on other mapped loops " + "are not supported"); + + // The innermost loop can have an arbitrary body, skip the perfect nesting + // check for it. 
+ if (i == e - 1) + break; + + auto begin = forOp.getBody()->begin(), end = forOp.getBody()->end(); + if (forOp.getBody()->empty() || std::next(begin, 2) != end) + return forOp.emitError("expected perfectly nested loops in the body"); + + if (!(forOp = dyn_cast<OpTy>(nested))) + return nested->emitError("expected a nested loop"); + } + return success(); +} + +template <typename OpTy> LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims, unsigned numThreadDims) { if (numBlockDims < 1 || numThreadDims < 1) { @@ -112,39 +142,61 @@ LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims, OpBuilder builder(forOp.getOperation()); if (numBlockDims > 3) { - return emitError(builder.getUnknownLoc(), - "cannot map to more than 3 block dimensions"); + return forOp.emitError("cannot map to more than 3 block dimensions"); } if (numThreadDims > 3) { - return emitError(builder.getUnknownLoc(), - "cannot map to more than 3 thread dimensions"); + return forOp.emitError("cannot map to more than 3 thread dimensions"); } + return checkLoopNestMappableImpl(forOp, numBlockDims + numThreadDims); +} - OpTy currentLoop = forOp; - Region &limit = forOp.region(); - for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) { - Operation *nested = ¤tLoop.getBody()->front(); - if (!areValuesDefinedAbove(getLowerBoundOperands(currentLoop), limit) || - !areValuesDefinedAbove(getUpperBoundOperands(currentLoop), limit)) - return currentLoop.emitError( - "loops with bounds depending on other mapped loops " - "are not supported"); +template <typename OpTy> +LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims, + unsigned numThreadDims) { + if (numBlockDims < 1 || numThreadDims < 1) { + LLVM_DEBUG(llvm::dbgs() << "nothing to map"); + return success(); + } - // The innermost loop can have an arbitrary body, skip the perfect nesting - // check for it. - if (i == e - 1) - break; + if (numBlockDims > 3) { + return forOp.emitError("cannot map to more than 3 block dimensions"); + } + if (numThreadDims > 3) { + return forOp.emitError("cannot map to more than 3 thread dimensions"); + } + if (numBlockDims != numThreadDims) { + // TODO(ravishankarm) : This can probably be relaxed by having a one-trip + // loop for the missing dimension, but there is not reason to handle this + // case for now. + return forOp.emitError( + "mismatch in block dimensions and thread dimensions"); + } - auto begin = currentLoop.getBody()->begin(), - end = currentLoop.getBody()->end(); - if (currentLoop.getBody()->empty() || std::next(begin, 2) != end) - return currentLoop.emitError( - "expected perfectly nested loops in the body"); + // Check that the forOp contains perfectly nested loops for numBlockDims + if (failed(checkLoopNestMappableImpl(forOp, numBlockDims))) { + return failure(); + } - if (!(currentLoop = dyn_cast<OpTy>(nested))) - return nested->emitError("expected a nested loop"); + // Get to the innermost loop. + for (auto i : seq<unsigned>(0, numBlockDims - 1)) { + forOp = cast<OpTy>(&forOp.getBody()->front()); + (void)i; } + // The forOp now points to the body of the innermost loop mapped to blocks. + for (Operation &op : *forOp.getBody()) { + // If the operation is a loop, check that it is mappable to workItems. + if (auto innerLoop = dyn_cast<OpTy>(&op)) { + if (failed(checkLoopNestMappableImpl(innerLoop, numThreadDims))) { + return failure(); + } + continue; + } + // TODO(ravishankarm) : If it is not a loop op, it is assumed that the + // statement is executed by all threads. 
It might be a collective operation, + // or some non-side effect instruction. Have to decide on "allowable" + // statements and check for those here. + } return success(); } @@ -215,10 +267,140 @@ Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp, return currentLoop; } +/// Given `nDims` perfectly nested loops rooted as `rootForOp`, convert them o +/// be partitioned across workgroups or workitems. The values for the +/// workgroup/workitem id along each dimension is passed in with `ids`. The +/// number of workgroups/workitems along each dimension are passed in with +/// `nids`. The innermost loop is mapped to the x-dimension, followed by the +/// next innermost loop to y-dimension, followed by z-dimension. +template <typename OpTy> +OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value *> ids, + ArrayRef<Value *> nids) { + auto nDims = ids.size(); + assert(nDims == nids.size()); + for (auto dim : llvm::seq<unsigned>(0, nDims)) { + // TODO(ravishankarm): Don't always need to generate a loop here. If nids >= + // number of iterations of the original loop, this becomes a if + // condition. Though that does rely on how the workgroup/workitem sizes are + // specified to begin with. + mapLoopToProcessorIds(rootForOp, ids[dim], nids[dim]); + if (dim != nDims - 1) { + rootForOp = cast<OpTy>(rootForOp.getBody()->front()); + } + } + return rootForOp; +} + +/// Utility method to convert the gpu::KernelDim3 object for representing id of +/// each workgroup/workitem and number of workgroup/workitems along a dimension +/// of the launch into a container. +void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids, + unsigned nDims, SmallVectorImpl<Value *> &ids, + SmallVectorImpl<Value *> &nids) { + assert(nDims <= 3 && "invalid number of launch dimensions"); + SmallVector<Value *, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; + SmallVector<Value *, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; + ids.clear(); + ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end()); + nids.clear(); + nids.append(std::next(allNids.begin(), allNids.size() - nDims), + allNids.end()); +} + +/// Generate the body of the launch operation. +template <typename OpTy> +LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, + gpu::LaunchOp launchOp, unsigned numBlockDims, + unsigned numThreadDims) { + OpBuilder::InsertionGuard bodyInsertionGuard(builder); + builder.setInsertionPointToEnd(&launchOp.getBody().front()); + auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc()); + + rootForOp.getOperation()->moveBefore(returnOp); + SmallVector<Value *, 3> workgroupID, numWorkGroups; + packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims, + workgroupID, numWorkGroups); + + // Partition the loop for mapping to workgroups. + auto loopOp = createGPULaunchLoops(rootForOp, workgroupID, numWorkGroups); + + // Iterate over the body of the loopOp and get the loops to partition for + // thread blocks. 
+ SmallVector<OpTy, 1> threadRootForOps; + for (Operation &op : *loopOp.getBody()) { + if (auto threadRootForOp = dyn_cast<OpTy>(&op)) { + threadRootForOps.push_back(threadRootForOp); + } + } + + SmallVector<Value *, 3> workItemID, workGroupSize; + packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(), + numThreadDims, workItemID, workGroupSize); + for (auto &loopOp : threadRootForOps) { + builder.setInsertionPoint(loopOp); + createGPULaunchLoops(loopOp, workItemID, workGroupSize); + } + return success(); +} + +// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the +// given workgroup size and number of workgroups. +template <typename OpTy> +LogicalResult createLaunchFromOp(OpTy rootForOp, + ArrayRef<Value *> numWorkGroups, + ArrayRef<Value *> workGroupSizes) { + OpBuilder builder(rootForOp.getOperation()); + if (numWorkGroups.size() > 3) { + return rootForOp.emitError("invalid ") + << numWorkGroups.size() << "-D workgroup specification"; + } + auto loc = rootForOp.getLoc(); + Value *one = builder.create<ConstantOp>( + loc, builder.getIntegerAttr(builder.getIndexType(), 1)); + SmallVector<Value *, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one); + for (auto numWorkGroup : enumerate(numWorkGroups)) { + numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value(); + } + for (auto workGroupSize : enumerate(workGroupSizes)) { + workGroupSize3D[workGroupSize.index()] = workGroupSize.value(); + } + + // Get the values used within the region of the rootForOp but defined above + // it. + llvm::SetVector<Value *> valuesToForwardSet; + getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(), + valuesToForwardSet); + // Also add the values used for the lb, ub, and step of the rootForOp. + valuesToForwardSet.insert(rootForOp.getOperands().begin(), + rootForOp.getOperands().end()); + auto valuesToForward = valuesToForwardSet.takeVector(); + auto launchOp = builder.create<gpu::LaunchOp>( + rootForOp.getLoc(), numWorkGroups3D[0], numWorkGroups3D[1], + numWorkGroups3D[2], workGroupSize3D[0], workGroupSize3D[1], + workGroupSize3D[2], valuesToForward); + if (failed(createLaunchBody(builder, rootForOp, launchOp, + numWorkGroups.size(), workGroupSizes.size()))) { + return failure(); + } + + // Replace values that are used within the region of the launchOp but are + // defined outside. They all are replaced with kernel arguments. + for (const auto &pair : + llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { + Value *from = std::get<0>(pair); + Value *to = std::get<1>(pair); + replaceAllUsesInRegionWith(from, to, launchOp.getBody()); + } + return success(); +} + // Replace the rooted at "rootForOp" with a GPU launch operation. This expects // "innermostForOp" to point to the last loop to be transformed to the kernel, // and to have (numBlockDims + numThreadDims) perfectly nested loops between // "rootForOp" and "innermostForOp". +// TODO(ravishankarm) : This method can be modified to use the +// createLaunchFromOp method, since that is a strict generalization of this +// method. template <typename OpTy> void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, unsigned numBlockDims, @@ -324,6 +506,19 @@ static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, return success(); } +// Generic loop to GPU kernel conversion function when loop is imperfectly +// nested. 
The workgroup size and num workgroups is provided as input +template <typename OpTy> +static LogicalResult convertLoopToGPULaunch(OpTy forOp, + ArrayRef<Value *> numWorkGroups, + ArrayRef<Value *> workGroupSize) { + if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(), + workGroupSize.size()))) { + return failure(); + } + return createLaunchFromOp(forOp, numWorkGroups, workGroupSize); +} + LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp, unsigned numBlockDims, unsigned numThreadDims) { @@ -335,3 +530,9 @@ LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp, unsigned numThreadDims) { return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); } + +LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp, + ArrayRef<Value *> numWorkGroups, + ArrayRef<Value *> workGroupSizes) { + return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes); +} diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 6d4cb9d8256..21abc3cf99b 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -19,11 +19,14 @@ #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/CommandLine.h" #define PASS_NAME "convert-loops-to-gpu" +#define LOOPOP_TO_GPU_PASS_NAME "convert-loop-op-to-gpu" using namespace mlir; using namespace mlir::loop; @@ -38,6 +41,19 @@ static llvm::cl::opt<unsigned> clNumThreadDims( llvm::cl::desc("Number of GPU thread dimensions for mapping"), llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u)); +static llvm::cl::OptionCategory clLoopOpToGPUCategory(LOOPOP_TO_GPU_PASS_NAME + " options"); +static llvm::cl::list<unsigned> + clNumWorkGroups("gpu-num-workgroups", + llvm::cl::desc("Num workgroups in the GPU launch"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clLoopOpToGPUCategory)); +static llvm::cl::list<unsigned> + clWorkGroupSize("gpu-workgroup-size", + llvm::cl::desc("Workgroup Size in the GPU launch"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clLoopOpToGPUCategory)); + namespace { // A pass that traverses top-level loops in the function and converts them to // GPU launch operations. Nested launches are not allowed, so this does not @@ -64,6 +80,50 @@ struct ForLoopMapper : public FunctionPass<ForLoopMapper> { unsigned numBlockDims; unsigned numThreadDims; }; + +// A pass that traverses top-level loops in the function and convertes them to +// GPU launch operations. The top-level loops itself does not have to be +// perfectly nested. The only requirement is that there be as many perfectly +// nested loops as the size of `numWorkGroups`. Within these any loop nest has +// to be perfectly nested upto depth equal to size of `workGroupSize`. +struct ImperfectlyNestedForLoopMapper + : public FunctionPass<ImperfectlyNestedForLoopMapper> { + ImperfectlyNestedForLoopMapper(ArrayRef<int64_t> numWorkGroups, + ArrayRef<int64_t> workGroupSize) + : numWorkGroups(numWorkGroups.begin(), numWorkGroups.end()), + workGroupSize(workGroupSize.begin(), workGroupSize.end()) {} + + void runOnFunction() override { + // Insert the num work groups and workgroup sizes as constant values. This + // pass is only used for testing. 
+ FuncOp funcOp = getFunction(); + OpBuilder builder(funcOp.getOperation()->getRegion(0)); + SmallVector<Value *, 3> numWorkGroupsVal, workGroupSizeVal; + for (auto val : numWorkGroups) { + auto constOp = builder.create<ConstantOp>( + funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); + numWorkGroupsVal.push_back(constOp); + } + for (auto val : workGroupSize) { + auto constOp = builder.create<ConstantOp>( + funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); + workGroupSizeVal.push_back(constOp); + } + for (Block &block : getFunction()) { + for (Operation &op : llvm::make_early_inc_range(block)) { + if (auto forOp = dyn_cast<ForOp>(&op)) { + if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal, + workGroupSizeVal))) { + return signalPassFailure(); + } + } + } + } + } + SmallVector<int64_t, 3> numWorkGroups; + SmallVector<int64_t, 3> workGroupSize; +}; + } // namespace std::unique_ptr<OpPassBase<FuncOp>> @@ -72,8 +132,25 @@ mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims, return std::make_unique<ForLoopMapper>(numBlockDims, numThreadDims); } +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups, + ArrayRef<int64_t> workGroupSize) { + return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups, + workGroupSize); +} + static PassRegistration<ForLoopMapper> registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] { return std::make_unique<ForLoopMapper>(clNumBlockDims.getValue(), clNumThreadDims.getValue()); }); + +static PassRegistration<ImperfectlyNestedForLoopMapper> loopOpToGPU( + LOOPOP_TO_GPU_PASS_NAME, "Convert top-level loop::ForOp to GPU kernels", + [] { + SmallVector<int64_t, 3> numWorkGroups, workGroupSize; + numWorkGroups.assign(clNumWorkGroups.begin(), clNumWorkGroups.end()); + workGroupSize.assign(clWorkGroupSize.begin(), clWorkGroupSize.end()); + return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups, + workGroupSize); + }); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index e09d8c89b37..405116e72e7 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1118,11 +1118,12 @@ void mlir::mapLoopToProcessorIds(loop::ForOp forOp, for (unsigned i = 1, e = processorId.size(); i < e; ++i) mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]), processorId[i]); - Value *lb = b.create<AddIOp>(loc, forOp.lowerBound(), mul); + Value *lb = b.create<AddIOp>(loc, forOp.lowerBound(), + b.create<MulIOp>(loc, forOp.step(), mul)); forOp.setLowerBound(lb); - Value *step = numProcessors.front(); - for (auto *numProcs : numProcessors.drop_front()) + Value *step = forOp.step(); + for (auto *numProcs : numProcessors) step = b.create<MulIOp>(loc, step, numProcs); forOp.setStep(step); } diff --git a/mlir/test/Conversion/LoopsToGPU/imperfect_2D.mlir b/mlir/test/Conversion/LoopsToGPU/imperfect_2D.mlir new file mode 100644 index 00000000000..cc80954c6a9 --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/imperfect_2D.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt -convert-loop-op-to-gpu -gpu-num-workgroups=2,2 -gpu-workgroup-size=32,4 %s | FileCheck %s + +module { + // arg2 = arg0 * transpose(arg1) ; with intermediate buffer and tile size passed as argument + // CHECK: func {{@.*}}([[ARG0:%.*]]: memref<?x?xf32>, [[ARG1:%.*]]: memref<?x?xf32>, [[ARG2:%.*]]: memref<?x?xf32>, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index) + func @foo(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>, 
%arg2 : memref<?x?xf32>, %arg3 : index, %arg4 : index) { + %0 = dim %arg0, 0 : memref<?x?xf32> + %1 = dim %arg0, 1 : memref<?x?xf32> + %c0 = constant 0 : index + %c1 = constant 1 : index + // CHECK: gpu.launch blocks([[ARG5:%.*]], [[ARG6:%.*]], [[ARG7:%.*]]) in ([[ARG11:%.*]] = {{%.*}}, [[ARG12:%.*]] = {{%.*}}, [[ARG13:%.*]] = {{%.*}}) threads([[ARG8:%.*]], [[ARG9:%.*]], [[ARG10:%.*]]) in ([[ARG14:%.*]] = {{%.*}}, [[ARG15:%.*]] = {{%.*}}, [[ARG16:%.*]] = {{%.*}}) args([[ARG17:%.*]] = [[ARG3]], [[ARG18:%.*]] = [[ARG4]], [[ARG19:%.*]] = [[ARG1]], [[ARG20:%.*]] = {{%.*}}, {{%.*}} = {{%.*}}, [[ARG22:%.*]] = [[ARG0]], [[ARG23:%.*]] = [[ARG2]] + // CHECK: [[TEMP1:%.*]] = muli [[ARG17]], [[ARG6]] : index + // CHECK: [[BLOCKLOOPYLB:%.*]] = addi {{%.*}}, [[TEMP1]] : index + // CHECK: [[BLOCKLOOPYSTEP:%.*]] = muli [[ARG17]], [[ARG12]] : index + // CHECK: loop.for [[BLOCKLOOPYIV:%.*]] = [[BLOCKLOOPYLB]] to {{%.*}} step [[BLOCKLOOPYSTEP]] + loop.for %iv1 = %c0 to %0 step %arg3 { + + // CHECK: [[TEMP2:%.*]] = muli [[ARG18]], [[ARG5]] : index + // CHECK: [[BLOCKLOOPXLB:%.*]] = addi {{%.*}}, [[TEMP2]] : index + // CHECK: [[BLOCKLOOPXSTEP:%.*]] = muli [[ARG18]], [[ARG11]] : index + // CHECK: loop.for [[BLOCKLOOPXIV:%.*]] = [[BLOCKLOOPXLB]] to {{%.*}} step [[BLOCKLOOPXSTEP]] + + loop.for %iv2 = %c0 to %1 step %arg4 { + + // TODO: This is effectively shared memory. Lower it to a + // shared memory. + %2 = alloc(%arg3, %arg4) : memref<?x?xf32> + + // Load transpose tile + // CHECK: [[TEMP3:%.*]] = muli [[ARG20]], [[ARG9:%.*]] : index + // CHECK: [[THREADLOOP1YLB:%.*]] = addi {{%.*}}, [[TEMP3]] : index + // CHECK: [[THREADLOOP1YSTEP:%.*]] = muli [[ARG20]], [[ARG15]] : index + // CHECK: loop.for [[THREADLOOP1YIV:%.*]] = [[THREADLOOP1YLB]] to {{%.*}} step [[THREADLOOP1YSTEP]] + loop.for %iv3 = %c0 to %arg3 step %c1 { + // CHECK: [[TEMP4:%.*]] = muli [[ARG20]], [[ARG8]] : index + // CHECK: [[THREADLOOP1XLB:%.*]] = addi {{%.*}}, [[TEMP4]] : index + // CHECK: [[THREADLOOP1XSTEP:%.*]] = muli [[ARG20]], [[ARG14]] : index + // CHECK: loop.for [[THREADLOOP1XIV:%.*]] = [[THREADLOOP1XLB]] to {{%.*}} step [[THREADLOOP1XSTEP]] + loop.for %iv4 = %c1 to %arg4 step %c1 { + // CHECK: [[INDEX2:%.*]] = addi [[BLOCKLOOPYIV]], [[THREADLOOP1YIV]] : index + %10 = addi %iv1, %iv3 : index + // CHECK: [[INDEX1:%.*]] = addi [[BLOCKLOOPXIV]], [[THREADLOOP1XIV]] : index + %11 = addi %iv2, %iv4 : index + // CHECK: [[VAL1:%.*]] = load [[ARG19]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} : memref<?x?xf32> + %12 = load %arg1[%11, %10] : memref<?x?xf32> + // CHECK: store [[VAL1]], [[SCRATCHSPACE:%.*]]{{\[}}[[THREADLOOP1XIV]], [[THREADLOOP1YIV]]{{\]}} : memref<?x?xf32> + store %12, %2[%iv4, %iv3] : memref<?x?xf32> + } + } + + // TODO: There needs to be a sync here for correctness, but + // testing only loop partitioning for now. 
+ + // CHECK: [[TEMP5:%.*]] = muli [[ARG20]], [[ARG9]] : index + // CHECK: [[THREADLOOP2YLB:%.*]] = addi {{%.*}}, [[TEMP5]] : index + // CHECK: [[THREADLOOP2YSTEP:%.*]] = muli [[ARG20]], [[ARG15]] : index + // CHECK: loop.for [[THREADLOOP2YIV:%.*]] = [[THREADLOOP2YLB]] to {{%.*}} step [[THREADLOOP2YSTEP]] + loop.for %iv3 = %c0 to %arg3 step %c1 { + // CHECK: [[TEMP6:%.*]] = muli [[ARG20]], [[ARG8]] : index + // CHECK: [[THREADLOOP2XLB:%.*]] = addi {{%.*}}, [[TEMP6]] : index + // CHECK: [[THREADLOOP2XSTEP:%.*]] = muli [[ARG20]], [[ARG14]] : index + // CHECK: loop.for [[THREADLOOP2XIV:%.*]] = [[THREADLOOP2XLB]] to {{%.*}} step [[THREADLOOP2XSTEP]] + loop.for %iv4 = %c1 to %arg4 step %c1 { + // CHECK: [[INDEX3:%.*]] = addi [[BLOCKLOOPYIV]], [[THREADLOOP2YIV]] : index + %13 = addi %iv1, %iv3 : index + // CHECK: [[INDEX4:%.*]] = addi [[BLOCKLOOPXIV]], [[THREADLOOP2XIV]] : index + %14 = addi %iv2, %iv4 : index + // CHECK: {{%.*}} = load [[SCRATCHSPACE]]{{\[}}[[THREADLOOP2XIV]], [[THREADLOOP2YIV]]{{\]}} : memref<?x?xf32> + %15 = load %2[%iv4, %iv3] : memref<?x?xf32> + // CHECK: {{%.*}} = load [[ARG22]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}} + %16 = load %arg0[%13, %14] : memref<?x?xf32> + %17 = mulf %15, %16 : f32 + // CHECK: store {{%.*}}, [[ARG23]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}} + store %17, %arg2[%13, %14] : memref<?x?xf32> + } + } + + dealloc %2 : memref<?x?xf32> + } + } + return + } +}
\ No newline at end of file diff --git a/mlir/test/Conversion/LoopsToGPU/imperfect_3D.mlir b/mlir/test/Conversion/LoopsToGPU/imperfect_3D.mlir new file mode 100644 index 00000000000..4741c385533 --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/imperfect_3D.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt -convert-loop-op-to-gpu -gpu-num-workgroups=4,2,2 -gpu-workgroup-size=32,2,2 %s | FileCheck %s + +module { + func @imperfect_3D(%arg0 : memref<?x?x?xf32>, %arg1 : memref<?x?x?xf32>, %arg2 : memref<?x?x?xf32>, %arg3 : memref<?x?x?xf32>, %t1 : index, %t2 : index, %t3 : index, %step1 : index, %step2 : index, %step3 : index) { + %0 = dim %arg0, 0 : memref<?x?x?xf32> + %1 = dim %arg0, 1 : memref<?x?x?xf32> + %2 = dim %arg0, 2 : memref<?x?x?xf32> + %c0 = constant 0 : index + // CHECK: gpu.launch + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: alloc + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: load + // CHECK: load + // CHECK: addf + // CHECK: store + // CHECK: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: loop.for {{.*}} { + // CHECK: load + // CHECK: load + // CHECK: mulf + // CHECK: store + // CHECK: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK: dealloc + loop.for %iv1 = %c0 to %0 step %t1 { + loop.for %iv2 = %c0 to %1 step %t2 { + loop.for %iv3 = %c0 to %2 step %t3 { + %6 = alloc(%t1, %t2, %t3) : memref<?x?x?xf32> + %ubcmp1 = cmpi "slt", %0, %t1 : index + %ub1 = select %ubcmp1, %0, %t1 : index + %ubcmp2 = cmpi "slt", %1, %t2 : index + %ub2 = select %ubcmp2, %1, %t2 : index + %ubcmp3 = cmpi "slt", %2, %t3 : index + %ub3 = select %ubcmp3, %2, %t3 : index + loop.for %iv4 = %iv1 to %ub1 step %step1 { + loop.for %iv5 = %iv2 to %ub2 step %step2 { + loop.for %iv6 = %iv3 to %ub3 step %step3 { + %7 = load %arg0[%iv4, %iv5, %iv6] : memref<?x?x?xf32> + %8 = load %arg1[%iv4, %iv6, %iv5] : memref<?x?x?xf32> + %9 = addf %7, %8 : f32 + %10 = subi %iv4, %iv1 : index + %11 = divis %10, %step1 : index + %12 = subi %iv5, %iv2 : index + %13 = divis %12, %step2 : index + %14 = subi %iv6, %iv3 : index + %15 = divis %14, %step3 : index + store %9, %6[%11, %13, %15] : memref<?x?x?xf32> + } + } + } + loop.for %iv7 = %iv1 to %ub1 step %step1 { + loop.for %iv8 = %iv2 to %ub2 step %step2 { + loop.for %iv9 = %iv3 to %ub3 step %step3 { + %16 = subi %iv7, %iv1 : index + %17 = divis %16, %step1 : index + %18 = subi %iv8, %iv2 : index + %19 = divis %18, %step2 : index + %20 = subi %iv9, %iv3 : index + %21 = divis %20, %step3 : index + %22 = load %6[%17, %19, %21] : memref<?x?x?xf32> + %23 = load %arg2[%iv9, %iv8, %iv7] : memref<?x?x?xf32> + %24 = mulf %22, %23 : f32 + store %24, %arg3[%iv7, %iv8, %iv9] : memref<?x?x?xf32> + } + } + } + dealloc %6 : memref<?x?x?xf32> + } + } + } + return + } +}
\ No newline at end of file diff --git a/mlir/test/Conversion/LoopsToGPU/imperfect_4D.mlir b/mlir/test/Conversion/LoopsToGPU/imperfect_4D.mlir new file mode 100644 index 00000000000..2753cd28188 --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/imperfect_4D.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-opt -convert-loop-op-to-gpu -gpu-num-workgroups=4,2,2 -gpu-workgroup-size=32,2,2 %s | FileCheck %s + +module { + func @imperfect_3D(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>, %arg2 : memref<?x?x?x?xf32>, %arg3 : memref<?x?x?x?xf32>, %t1 : index, %t2 : index, %t3 : index, %t4 : index, %step1 : index, %step2 : index, %step3 : index, %step4 : index) { + %0 = dim %arg0, 0 : memref<?x?x?x?xf32> + %1 = dim %arg0, 1 : memref<?x?x?x?xf32> + %2 = dim %arg0, 2 : memref<?x?x?x?xf32> + %3 = dim %arg0, 3 : memref<?x?x?x?xf32> + %c0 = constant 0 : index + // CHECK: gpu.launch + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: alloc + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: load + // CHECK: load + // CHECK: addf + // CHECK: store + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: load + // CHECK: load + // CHECK: mulf + // CHECK: store + // CHECK: dealloc + loop.for %iv1 = %c0 to %0 step %t1 { + loop.for %iv2 = %c0 to %1 step %t2 { + loop.for %iv3 = %c0 to %2 step %t3 { + %6 = alloc(%t1, %t2, %t3, %3) : memref<?x?x?x?xf32> + %ubcmp1 = cmpi "slt", %0, %t1 : index + %ub1 = select %ubcmp1, %0, %t1 : index + %ubcmp2 = cmpi "slt", %1, %t2 : index + %ub2 = select %ubcmp2, %1, %t2 : index + %ubcmp3 = cmpi "slt", %2, %t3 : index + %ub3 = select %ubcmp3, %2, %t3 : index + %ubcmp4 = cmpi "slt", %3, %t4 : index + %ub4 = select %ubcmp3, %3, %t4 : index + loop.for %iv5 = %iv1 to %ub1 step %step1 { + loop.for %iv6 = %iv2 to %ub2 step %step2 { + loop.for %iv7 = %iv3 to %ub3 step %step3 { + loop.for %iv8 = %c0 to %3 step %step4 { + %7 = load %arg0[%iv5, %iv6, %iv7, %iv8] : memref<?x?x?x?xf32> + %8 = load %arg1[%iv5, %iv6, %iv7, %iv8] : memref<?x?x?x?xf32> + %9 = addf %7, %8 : f32 + %10 = subi %iv5, %iv1 : index + %11 = divis %10, %step1 : index + %12 = subi %iv6, %iv2 : index + %13 = divis %12, %step2 : index + %14 = subi %iv7, %iv3 : index + %15 = divis %14, %step3 : index + store %9, %6[%11, %13, %15, %iv8] : memref<?x?x?x?xf32> + } + } + } + } + loop.for %iv9 = %iv1 to %ub1 step %step1 { + loop.for %iv10 = %iv2 to %ub2 step %step2 { + loop.for %iv11 = %iv3 to %ub3 step %step3 { + loop.for %iv12 = %c0 to %3 step %step4 { + %18 = subi %iv9, %iv1 : index + %19 = divis %18, %step1 : index + %20 = subi %iv10, %iv2 : index + %21 = divis %20, %step2 : index + %22 = subi %iv11, %iv3 : index + %23 = divis %22, %step3 : index + %26 = load %6[%19, %21, %23, %iv12] : memref<?x?x?x?xf32> + %27 = load %arg2[%iv9, %iv10, %iv12, %iv11] : memref<?x?x?x?xf32> + %28 = mulf %26, %27 : f32 + store %28, %arg3[%iv9, %iv10, %iv11, %iv12] : memref<?x?x?x?xf32> + } + } + } + } + dealloc %6 : memref<?x?x?x?xf32> + } + } + } + return + } +}
\ No newline at end of file diff --git a/mlir/test/Conversion/LoopsToGPU/imperfect_linalg.mlir b/mlir/test/Conversion/LoopsToGPU/imperfect_linalg.mlir new file mode 100644 index 00000000000..cdcaf42c1d1 --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/imperfect_linalg.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -convert-loop-op-to-gpu -gpu-num-workgroups=2,16 -gpu-workgroup-size=32,4 | FileCheck %s + +#map0 = (d0) -> (d0 + 2) +#map1 = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +module { + func @fmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) { + %c1 = constant 1 : index + %c0 = constant 0 : index + %c2 = constant 2 : index + %0 = dim %arg0, 0 : memref<?x?xf32> + %1 = dim %arg0, 1 : memref<?x?xf32> + // CHECK-LABEL: gpu.launch + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: loop.for + // CHECK: load + // CHECK: load + // CHECK: load + // CHECK: mulf + // CHECK: store + loop.for %arg3 = %c0 to %0 step %c2 { + loop.for %arg4 = %c0 to %1 step %c2 { + %2 = affine.apply #map0(%arg3) + %3 = affine.apply #map0(%arg4) + %4 = linalg.subview %arg0[%arg3, %2, %c1, %arg4, %3, %c1] : memref<?x?xf32> + %5 = affine.apply #map0(%arg3) + %6 = affine.apply #map0(%arg4) + %7 = linalg.subview %arg1[%arg3, %5, %c1, %arg4, %6, %c1] : memref<?x?xf32> + %8 = affine.apply #map0(%arg3) + %9 = affine.apply #map0(%arg4) + %10 = linalg.subview %arg2[%arg3, %8, %c1, %arg4, %9, %c1] : memref<?x?xf32> + %11 = dim %4, 0 : memref<?x?xf32, #map1> + %12 = dim %4, 1 : memref<?x?xf32, #map1> + loop.for %arg5 = %c0 to %11 step %c1 { + loop.for %arg6 = %c0 to %12 step %c1 { + %13 = load %4[%arg5, %arg6] : memref<?x?xf32, #map1> + %14 = load %7[%arg5, %arg6] : memref<?x?xf32, #map1> + %15 = load %10[%arg5, %arg6] : memref<?x?xf32, #map1> + %16 = mulf %13, %14 : f32 + store %16, %10[%arg5, %arg6] : memref<?x?xf32, #map1> + } + } + } + } + return + } +} diff --git a/mlir/test/Conversion/LoopsToGPU/perfect_1D_setlaunch.mlir b/mlir/test/Conversion/LoopsToGPU/perfect_1D_setlaunch.mlir new file mode 100644 index 00000000000..bf437a348b6 --- /dev/null +++ b/mlir/test/Conversion/LoopsToGPU/perfect_1D_setlaunch.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt -convert-loop-op-to-gpu -gpu-num-workgroups=2 -gpu-workgroup-size=32 %s | FileCheck %s + +module { + func @foo(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) { + %0 = dim %arg0, 0 : memref<?x?xf32> + %1 = dim %arg0, 1 : memref<?x?xf32> + %c0 = constant 0 : index + %c1 = constant 1 : index + // CHECK: gpu.launch + // CHECK: loop.for + // CHECK: loop.for + // CHECK: load + // CHECK: load + // CHECK: add + // CHECK: store + loop.for %iv1 = %c0 to %0 step %c1 { + loop.for %iv2 = %c0 to %1 step %c1 { + %12 = load %arg0[%iv1, %iv2] : memref<?x?xf32> + %13 = load %arg1[%iv2, %iv1] : memref<?x?xf32> + %14 = addf %12, %13 : f32 + store %12, %arg2[%iv1, %iv2] : memref<?x?xf32> + } + } + return + } +}
\ No newline at end of file diff --git a/mlir/test/Transforms/parametric-mapping.mlir b/mlir/test/Transforms/parametric-mapping.mlir index debee29f940..fdfd8cf526e 100644 --- a/mlir/test/Transforms/parametric-mapping.mlir +++ b/mlir/test/Transforms/parametric-mapping.mlir @@ -6,8 +6,10 @@ func @map1d(%lb: index, %ub: index, %step: index) { // CHECK: %[[threads:.*]]:2 = "new_processor_id_and_range"() : () -> (index, index) %0:2 = "new_processor_id_and_range"() : () -> (index, index) - // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[threads]]#0 - // CHECK: loop.for %{{.*}} = %[[new_lb]] to %[[ub]] step %[[threads]]#1 { + // CHECK: %[[thread_offset:.*]] = muli %[[step]], %[[threads]]#0 + // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[thread_offset]] + // CHECK: %[[new_step:.*]] = muli %[[step]], %[[threads]]#1 + // CHECK: loop.for %{{.*}} = %[[new_lb]] to %[[ub]] step %[[new_step]] { loop.for %i = %lb to %ub step %step {} return } @@ -27,11 +29,17 @@ func @map2d(%lb : index, %ub : index, %step : index) { // threadIdx.x + blockIdx.x * blockDim.x // CHECK: %[[tidxpbidxXbdimx:.*]] = addi %[[bidxXbdimx]], %[[threads]]#0 : index // - // new_lb = lb + threadIdx.x + blockIdx.x * blockDim.x - // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[tidxpbidxXbdimx]] : index + // thread_offset = step * (threadIdx.x + blockIdx.x * blockDim.x) + // CHECK: %[[thread_offset:.*]] = muli %[[step]], %[[tidxpbidxXbdimx]] : index + // + // new_lb = lb + thread_offset + // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[thread_offset]] : index + // + // stepXgdimx = step * gridDim.x + // CHECK: %[[stepXgdimx:.*]] = muli %[[step]], %[[blocks]]#1 : index // - // new_step = gridDim.x * blockDim.x - // CHECK: %[[new_step:.*]] = muli %[[blocks]]#1, %[[threads]]#1 : index + // new_step = step * gridDim.x * blockDim.x + // CHECK: %[[new_step:.*]] = muli %[[stepXgdimx]], %[[threads]]#1 : index // // CHECK: loop.for %{{.*}} = %[[new_lb]] to %[[ub]] step %[[new_step]] { loop.for %i = %lb to %ub step %step {} |
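For concreteness, a hand-worked instance of the remapping checked in map2d above; the numbers are illustrative and are not taken from the test:

  // with lb = 0, step = 2, blockDim.x = 32, gridDim.x = 2,
  // the thread with blockIdx.x = 1 and threadIdx.x = 3 gets
  //   new_lb   = lb + step * (threadIdx.x + blockIdx.x * blockDim.x)
  //            = 0 + 2 * (3 + 1 * 32) = 70
  //   new_step = step * gridDim.x * blockDim.x = 2 * 2 * 32 = 128
  // so it executes the original induction values 70, 198, 326, ..., i.e. every
  // 64th iteration of the source loop, which is the intended cyclic
  // distribution over the 64 threads; the previous formula only produced this
  // distribution when the step was 1.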