summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp26
-rw-r--r--mlir/lib/Pass/Pass.cpp20
-rw-r--r--mlir/lib/Pass/PassRegistry.cpp56
3 files changed, 76 insertions, 26 deletions
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 115096003e1..68392c36765 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -35,17 +35,17 @@ namespace {
/// 2) Lower the body of the spirv::ModuleOp.
class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
public:
- GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize)
- : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+ GPUToSPIRVPass() = default;
+ GPUToSPIRVPass(const GPUToSPIRVPass &) {}
+ GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+ this->workGroupSize = workGroupSize;
+ }
+
void runOnModule() override;
private:
- SmallVector<int64_t, 3> workGroupSize;
-};
-
-/// Command line option to specify the workgroup size.
-struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> {
- List<unsigned> workGroupSize{
+ /// Command line option to specify the workgroup size.
+ ListOption<int64_t> workGroupSize{
*this, "workgroup-size",
llvm::cl::desc(
"Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
@@ -92,11 +92,5 @@ mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
return std::make_unique<GPUToSPIRVPass>(workGroupSize);
}
-static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions>
- pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect",
- [](const GPUToSPIRVPassOptions &passOptions) {
- SmallVector<int64_t, 3> workGroupSize;
- workGroupSize.assign(passOptions.workGroupSize.begin(),
- passOptions.workGroupSize.end());
- return std::make_unique<GPUToSPIRVPass>(workGroupSize);
- });
+static PassRegistration<GPUToSPIRVPass>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 22e58cc5b63..8877cc5f684 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -36,6 +36,17 @@ using namespace mlir::detail;
/// single .o file.
void Pass::anchor() {}
+/// Attempt to initialize the options of this pass from the given string.
+LogicalResult Pass::initializeOptions(StringRef options) {
+ return passOptions.parseFromString(options);
+}
+
+/// Copy the option values from 'other', which is another instance of this
+/// pass.
+void Pass::copyOptionValuesFrom(const Pass *other) {
+ passOptions.copyOptionValuesFrom(other->passOptions);
+}
+
/// Prints out the pass in the textual representation of pipelines. If this is
/// an adaptor pass, print with the op_name(sub_pass,...) format.
void Pass::printAsTextualPipeline(raw_ostream &os) {
@@ -46,11 +57,14 @@ void Pass::printAsTextualPipeline(raw_ostream &os) {
pm.printAsTextualPipeline(os);
os << ")";
});
- } else if (const PassInfo *info = lookupPassInfo()) {
+ return;
+ }
+ // Otherwise, print the pass argument followed by its options.
+ if (const PassInfo *info = lookupPassInfo())
os << info->getPassArgument();
- } else {
+ else
os << getName();
- }
+ passOptions.print(os);
}
/// Forwarding function to execute this pass.
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 93753d363db..1c5193d0539 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -24,10 +24,15 @@ static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry;
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
passPipelineRegistry;
-// Helper to avoid exposing OpPassManager.
-void mlir::detail::addPassToPassManager(OpPassManager &pm,
- std::unique_ptr<Pass> pass) {
- pm.addPass(std::move(pass));
+/// Utility to create a default registry function from a pass instance.
+static PassRegistryFunction
+buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
+ return [=](OpPassManager &pm, StringRef options) {
+ std::unique_ptr<Pass> pass = allocator();
+ LogicalResult result = pass->initializeOptions(options);
+ pm.addPass(std::move(pass));
+ return result;
+ };
}
//===----------------------------------------------------------------------===//
@@ -46,9 +51,13 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description,
// PassInfo
//===----------------------------------------------------------------------===//
+PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
+ const PassAllocatorFunction &allocator)
+ : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
+
void mlir::registerPass(StringRef arg, StringRef description,
const PassID *passID,
- const PassRegistryFunction &function) {
+ const PassAllocatorFunction &function) {
PassInfo passInfo(arg, description, passID, function);
bool inserted = passRegistry->try_emplace(passID, passInfo).second;
assert(inserted && "Pass registered multiple times");
@@ -67,7 +76,19 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
// PassOptions
//===----------------------------------------------------------------------===//
-LogicalResult PassOptionsBase::parseFromString(StringRef options) {
+/// Out of line virtual function to provide home for the class.
+void detail::PassOptions::OptionBase::anchor() {}
+
+/// Copy the option values from 'other'.
+void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
+ assert(options.size() == other.options.size());
+ if (options.empty())
+ return;
+ for (auto optionsIt : llvm::zip(options, other.options))
+ std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
+}
+
+LogicalResult detail::PassOptions::parseFromString(StringRef options) {
// TODO(parkers): Handle escaping strings.
// NOTE: `options` is modified in place to always refer to the unprocessed
// part of the string.
@@ -99,7 +120,6 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) {
auto it = OptionsMap.find(key);
if (it == OptionsMap.end()) {
llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
-
return failure();
}
if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
@@ -109,6 +129,28 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) {
return success();
}
+/// Print the options held by this struct in a form that can be parsed via
+/// 'parseFromString'.
+void detail::PassOptions::print(raw_ostream &os) {
+ // If there are no options, there is nothing left to do.
+ if (OptionsMap.empty())
+ return;
+
+ // Sort the options to make the ordering deterministic.
+ SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end());
+ llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(),
+ [](OptionBase *const *lhs, OptionBase *const *rhs) {
+ return (*lhs)->getArgStr().compare(
+ (*rhs)->getArgStr());
+ });
+
+ // Interleave the options with ' '.
+ os << '{';
+ interleave(
+ orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " ");
+ os << '}';
+}
+
//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud