summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/LoopsToGPU
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/LoopsToGPU')
-rw-r--r--mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp253
-rw-r--r--mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp77
2 files changed, 304 insertions, 26 deletions
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index 2229455ef33..e33b8401c74 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -22,15 +22,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/RegionUtils.h"
-
+#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "loops-to-gpu"
@@ -38,6 +40,8 @@
using namespace mlir;
using namespace mlir::loop;
+using llvm::seq;
+
// Extract an indexed value from KernelDim3.
static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
switch (pos) {
@@ -97,12 +101,38 @@ static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
}
// Check the structure of the loop nest:
-// - there are enough loops to map to numBlockDims + numThreadDims;
+// - there are enough loops to map to numDims;
// - the loops are perfectly nested;
// - the loop bounds can be computed above the outermost loop.
// This roughly corresponds to the "matcher" part of the pattern-based
// rewriting infrastructure.
template <typename OpTy>
+LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
+ Region &limit = forOp.region();
+ for (unsigned i = 0, e = numDims; i < e; ++i) {
+ Operation *nested = &forOp.getBody()->front();
+ if (!areValuesDefinedAbove(getLowerBoundOperands(forOp), limit) ||
+ !areValuesDefinedAbove(getUpperBoundOperands(forOp), limit))
+ return forOp.emitError(
+ "loops with bounds depending on other mapped loops "
+ "are not supported");
+
+ // The innermost loop can have an arbitrary body, skip the perfect nesting
+ // check for it.
+ if (i == e - 1)
+ break;
+
+ auto begin = forOp.getBody()->begin(), end = forOp.getBody()->end();
+ if (forOp.getBody()->empty() || std::next(begin, 2) != end)
+ return forOp.emitError("expected perfectly nested loops in the body");
+
+ if (!(forOp = dyn_cast<OpTy>(nested)))
+ return nested->emitError("expected a nested loop");
+ }
+ return success();
+}
+
+template <typename OpTy>
LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
unsigned numThreadDims) {
if (numBlockDims < 1 || numThreadDims < 1) {
@@ -112,39 +142,61 @@ LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
OpBuilder builder(forOp.getOperation());
if (numBlockDims > 3) {
- return emitError(builder.getUnknownLoc(),
- "cannot map to more than 3 block dimensions");
+ return forOp.emitError("cannot map to more than 3 block dimensions");
}
if (numThreadDims > 3) {
- return emitError(builder.getUnknownLoc(),
- "cannot map to more than 3 thread dimensions");
+ return forOp.emitError("cannot map to more than 3 thread dimensions");
}
+ return checkLoopNestMappableImpl(forOp, numBlockDims + numThreadDims);
+}
- OpTy currentLoop = forOp;
- Region &limit = forOp.region();
- for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) {
- Operation *nested = &currentLoop.getBody()->front();
- if (!areValuesDefinedAbove(getLowerBoundOperands(currentLoop), limit) ||
- !areValuesDefinedAbove(getUpperBoundOperands(currentLoop), limit))
- return currentLoop.emitError(
- "loops with bounds depending on other mapped loops "
- "are not supported");
+template <typename OpTy>
+LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims,
+ unsigned numThreadDims) {
+ if (numBlockDims < 1 || numThreadDims < 1) {
+ LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ return success();
+ }
- // The innermost loop can have an arbitrary body, skip the perfect nesting
- // check for it.
- if (i == e - 1)
- break;
+ if (numBlockDims > 3) {
+ return forOp.emitError("cannot map to more than 3 block dimensions");
+ }
+ if (numThreadDims > 3) {
+ return forOp.emitError("cannot map to more than 3 thread dimensions");
+ }
+ if (numBlockDims != numThreadDims) {
+ // TODO(ravishankarm) : This can probably be relaxed by having a one-trip
+ // loop for the missing dimension, but there is not reason to handle this
+ // case for now.
+ return forOp.emitError(
+ "mismatch in block dimensions and thread dimensions");
+ }
- auto begin = currentLoop.getBody()->begin(),
- end = currentLoop.getBody()->end();
- if (currentLoop.getBody()->empty() || std::next(begin, 2) != end)
- return currentLoop.emitError(
- "expected perfectly nested loops in the body");
+ // Check that the forOp contains perfectly nested loops for numBlockDims
+ if (failed(checkLoopNestMappableImpl(forOp, numBlockDims))) {
+ return failure();
+ }
- if (!(currentLoop = dyn_cast<OpTy>(nested)))
- return nested->emitError("expected a nested loop");
+ // Get to the innermost loop.
+ for (auto i : seq<unsigned>(0, numBlockDims - 1)) {
+ forOp = cast<OpTy>(&forOp.getBody()->front());
+ (void)i;
}
+ // The forOp now points to the body of the innermost loop mapped to blocks.
+ for (Operation &op : *forOp.getBody()) {
+ // If the operation is a loop, check that it is mappable to workItems.
+ if (auto innerLoop = dyn_cast<OpTy>(&op)) {
+ if (failed(checkLoopNestMappableImpl(innerLoop, numThreadDims))) {
+ return failure();
+ }
+ continue;
+ }
+ // TODO(ravishankarm) : If it is not a loop op, it is assumed that the
+ // statement is executed by all threads. It might be a collective operation,
+ // or some non-side effect instruction. Have to decide on "allowable"
+ // statements and check for those here.
+ }
return success();
}
@@ -215,10 +267,140 @@ Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp,
return currentLoop;
}
+/// Given `nDims` perfectly nested loops rooted as `rootForOp`, convert them o
+/// be partitioned across workgroups or workitems. The values for the
+/// workgroup/workitem id along each dimension is passed in with `ids`. The
+/// number of workgroups/workitems along each dimension are passed in with
+/// `nids`. The innermost loop is mapped to the x-dimension, followed by the
+/// next innermost loop to y-dimension, followed by z-dimension.
+template <typename OpTy>
+OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value *> ids,
+ ArrayRef<Value *> nids) {
+ auto nDims = ids.size();
+ assert(nDims == nids.size());
+ for (auto dim : llvm::seq<unsigned>(0, nDims)) {
+ // TODO(ravishankarm): Don't always need to generate a loop here. If nids >=
+ // number of iterations of the original loop, this becomes a if
+ // condition. Though that does rely on how the workgroup/workitem sizes are
+ // specified to begin with.
+ mapLoopToProcessorIds(rootForOp, ids[dim], nids[dim]);
+ if (dim != nDims - 1) {
+ rootForOp = cast<OpTy>(rootForOp.getBody()->front());
+ }
+ }
+ return rootForOp;
+}
+
+/// Utility method to convert the gpu::KernelDim3 object for representing id of
+/// each workgroup/workitem and number of workgroup/workitems along a dimension
+/// of the launch into a container.
+void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids,
+ unsigned nDims, SmallVectorImpl<Value *> &ids,
+ SmallVectorImpl<Value *> &nids) {
+ assert(nDims <= 3 && "invalid number of launch dimensions");
+ SmallVector<Value *, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
+ SmallVector<Value *, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
+ ids.clear();
+ ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end());
+ nids.clear();
+ nids.append(std::next(allNids.begin(), allNids.size() - nDims),
+ allNids.end());
+}
+
+/// Generate the body of the launch operation.
+template <typename OpTy>
+LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp,
+ gpu::LaunchOp launchOp, unsigned numBlockDims,
+ unsigned numThreadDims) {
+ OpBuilder::InsertionGuard bodyInsertionGuard(builder);
+ builder.setInsertionPointToEnd(&launchOp.getBody().front());
+ auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
+
+ rootForOp.getOperation()->moveBefore(returnOp);
+ SmallVector<Value *, 3> workgroupID, numWorkGroups;
+ packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims,
+ workgroupID, numWorkGroups);
+
+ // Partition the loop for mapping to workgroups.
+ auto loopOp = createGPULaunchLoops(rootForOp, workgroupID, numWorkGroups);
+
+ // Iterate over the body of the loopOp and get the loops to partition for
+ // thread blocks.
+ SmallVector<OpTy, 1> threadRootForOps;
+ for (Operation &op : *loopOp.getBody()) {
+ if (auto threadRootForOp = dyn_cast<OpTy>(&op)) {
+ threadRootForOps.push_back(threadRootForOp);
+ }
+ }
+
+ SmallVector<Value *, 3> workItemID, workGroupSize;
+ packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(),
+ numThreadDims, workItemID, workGroupSize);
+ for (auto &loopOp : threadRootForOps) {
+ builder.setInsertionPoint(loopOp);
+ createGPULaunchLoops(loopOp, workItemID, workGroupSize);
+ }
+ return success();
+}
+
+// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the
+// given workgroup size and number of workgroups.
+template <typename OpTy>
+LogicalResult createLaunchFromOp(OpTy rootForOp,
+ ArrayRef<Value *> numWorkGroups,
+ ArrayRef<Value *> workGroupSizes) {
+ OpBuilder builder(rootForOp.getOperation());
+ if (numWorkGroups.size() > 3) {
+ return rootForOp.emitError("invalid ")
+ << numWorkGroups.size() << "-D workgroup specification";
+ }
+ auto loc = rootForOp.getLoc();
+ Value *one = builder.create<ConstantOp>(
+ loc, builder.getIntegerAttr(builder.getIndexType(), 1));
+ SmallVector<Value *, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one);
+ for (auto numWorkGroup : enumerate(numWorkGroups)) {
+ numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value();
+ }
+ for (auto workGroupSize : enumerate(workGroupSizes)) {
+ workGroupSize3D[workGroupSize.index()] = workGroupSize.value();
+ }
+
+ // Get the values used within the region of the rootForOp but defined above
+ // it.
+ llvm::SetVector<Value *> valuesToForwardSet;
+ getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(),
+ valuesToForwardSet);
+ // Also add the values used for the lb, ub, and step of the rootForOp.
+ valuesToForwardSet.insert(rootForOp.getOperands().begin(),
+ rootForOp.getOperands().end());
+ auto valuesToForward = valuesToForwardSet.takeVector();
+ auto launchOp = builder.create<gpu::LaunchOp>(
+ rootForOp.getLoc(), numWorkGroups3D[0], numWorkGroups3D[1],
+ numWorkGroups3D[2], workGroupSize3D[0], workGroupSize3D[1],
+ workGroupSize3D[2], valuesToForward);
+ if (failed(createLaunchBody(builder, rootForOp, launchOp,
+ numWorkGroups.size(), workGroupSizes.size()))) {
+ return failure();
+ }
+
+ // Replace values that are used within the region of the launchOp but are
+ // defined outside. They all are replaced with kernel arguments.
+ for (const auto &pair :
+ llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
+ Value *from = std::get<0>(pair);
+ Value *to = std::get<1>(pair);
+ replaceAllUsesInRegionWith(from, to, launchOp.getBody());
+ }
+ return success();
+}
+
// Replace the rooted at "rootForOp" with a GPU launch operation. This expects
// "innermostForOp" to point to the last loop to be transformed to the kernel,
// and to have (numBlockDims + numThreadDims) perfectly nested loops between
// "rootForOp" and "innermostForOp".
+// TODO(ravishankarm) : This method can be modified to use the
+// createLaunchFromOp method, since that is a strict generalization of this
+// method.
template <typename OpTy>
void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
unsigned numBlockDims,
@@ -324,6 +506,19 @@ static LogicalResult convertLoopNestToGPULaunch(OpTy forOp,
return success();
}
+// Generic loop to GPU kernel conversion function when loop is imperfectly
+// nested. The workgroup size and num workgroups is provided as input
+template <typename OpTy>
+static LogicalResult convertLoopToGPULaunch(OpTy forOp,
+ ArrayRef<Value *> numWorkGroups,
+ ArrayRef<Value *> workGroupSize) {
+ if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(),
+ workGroupSize.size()))) {
+ return failure();
+ }
+ return createLaunchFromOp(forOp, numWorkGroups, workGroupSize);
+}
+
LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims) {
@@ -335,3 +530,9 @@ LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp,
unsigned numThreadDims) {
return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims);
}
+
+LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp,
+ ArrayRef<Value *> numWorkGroups,
+ ArrayRef<Value *> workGroupSizes) {
+ return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes);
+}
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
index 6d4cb9d8256..21abc3cf99b 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -19,11 +19,14 @@
#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/CommandLine.h"
#define PASS_NAME "convert-loops-to-gpu"
+#define LOOPOP_TO_GPU_PASS_NAME "convert-loop-op-to-gpu"
using namespace mlir;
using namespace mlir::loop;
@@ -38,6 +41,19 @@ static llvm::cl::opt<unsigned> clNumThreadDims(
llvm::cl::desc("Number of GPU thread dimensions for mapping"),
llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+static llvm::cl::OptionCategory clLoopOpToGPUCategory(LOOPOP_TO_GPU_PASS_NAME
+ " options");
+static llvm::cl::list<unsigned>
+ clNumWorkGroups("gpu-num-workgroups",
+ llvm::cl::desc("Num workgroups in the GPU launch"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clLoopOpToGPUCategory));
+static llvm::cl::list<unsigned>
+ clWorkGroupSize("gpu-workgroup-size",
+ llvm::cl::desc("Workgroup Size in the GPU launch"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clLoopOpToGPUCategory));
+
namespace {
// A pass that traverses top-level loops in the function and converts them to
// GPU launch operations. Nested launches are not allowed, so this does not
@@ -64,6 +80,50 @@ struct ForLoopMapper : public FunctionPass<ForLoopMapper> {
unsigned numBlockDims;
unsigned numThreadDims;
};
+
+// A pass that traverses top-level loops in the function and convertes them to
+// GPU launch operations. The top-level loops itself does not have to be
+// perfectly nested. The only requirement is that there be as many perfectly
+// nested loops as the size of `numWorkGroups`. Within these any loop nest has
+// to be perfectly nested upto depth equal to size of `workGroupSize`.
+struct ImperfectlyNestedForLoopMapper
+ : public FunctionPass<ImperfectlyNestedForLoopMapper> {
+ ImperfectlyNestedForLoopMapper(ArrayRef<int64_t> numWorkGroups,
+ ArrayRef<int64_t> workGroupSize)
+ : numWorkGroups(numWorkGroups.begin(), numWorkGroups.end()),
+ workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+
+ void runOnFunction() override {
+ // Insert the num work groups and workgroup sizes as constant values. This
+ // pass is only used for testing.
+ FuncOp funcOp = getFunction();
+ OpBuilder builder(funcOp.getOperation()->getRegion(0));
+ SmallVector<Value *, 3> numWorkGroupsVal, workGroupSizeVal;
+ for (auto val : numWorkGroups) {
+ auto constOp = builder.create<ConstantOp>(
+ funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+ numWorkGroupsVal.push_back(constOp);
+ }
+ for (auto val : workGroupSize) {
+ auto constOp = builder.create<ConstantOp>(
+ funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+ workGroupSizeVal.push_back(constOp);
+ }
+ for (Block &block : getFunction()) {
+ for (Operation &op : llvm::make_early_inc_range(block)) {
+ if (auto forOp = dyn_cast<ForOp>(&op)) {
+ if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+ workGroupSizeVal))) {
+ return signalPassFailure();
+ }
+ }
+ }
+ }
+ }
+ SmallVector<int64_t, 3> numWorkGroups;
+ SmallVector<int64_t, 3> workGroupSize;
+};
+
} // namespace
std::unique_ptr<OpPassBase<FuncOp>>
@@ -72,8 +132,25 @@ mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims,
return std::make_unique<ForLoopMapper>(numBlockDims, numThreadDims);
}
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
+ ArrayRef<int64_t> workGroupSize) {
+ return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
+ workGroupSize);
+}
+
static PassRegistration<ForLoopMapper>
registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] {
return std::make_unique<ForLoopMapper>(clNumBlockDims.getValue(),
clNumThreadDims.getValue());
});
+
+static PassRegistration<ImperfectlyNestedForLoopMapper> loopOpToGPU(
+ LOOPOP_TO_GPU_PASS_NAME, "Convert top-level loop::ForOp to GPU kernels",
+ [] {
+ SmallVector<int64_t, 3> numWorkGroups, workGroupSize;
+ numWorkGroups.assign(clNumWorkGroups.begin(), clNumWorkGroups.end());
+ workGroupSize.assign(clWorkGroupSize.begin(), clWorkGroupSize.end());
+ return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
+ workGroupSize);
+ });
OpenPOWER on IntegriCloud