diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp | 26 | ||||
-rw-r--r-- | mlir/lib/Pass/Pass.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/Pass/PassRegistry.cpp | 56 |
3 files changed, 76 insertions, 26 deletions
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index 115096003e1..68392c36765 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -35,17 +35,17 @@ namespace { /// 2) Lower the body of the spirv::ModuleOp. class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { public: - GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) - : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {} + GPUToSPIRVPass() = default; + GPUToSPIRVPass(const GPUToSPIRVPass &) {} + GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { + this->workGroupSize = workGroupSize; + } + void runOnModule() override; private: - SmallVector<int64_t, 3> workGroupSize; -}; - -/// Command line option to specify the workgroup size. -struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> { - List<unsigned> workGroupSize{ + /// Command line option to specify the workgroup size. + ListOption<int64_t> workGroupSize{ *this, "workgroup-size", llvm::cl::desc( "Workgroup Sizes in the SPIR-V module for x, followed by y, followed " @@ -92,11 +92,5 @@ mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { return std::make_unique<GPUToSPIRVPass>(workGroupSize); } -static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions> - pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect", - [](const GPUToSPIRVPassOptions &passOptions) { - SmallVector<int64_t, 3> workGroupSize; - workGroupSize.assign(passOptions.workGroupSize.begin(), - passOptions.workGroupSize.end()); - return std::make_unique<GPUToSPIRVPass>(workGroupSize); - }); +static PassRegistration<GPUToSPIRVPass> + pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect"); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 22e58cc5b63..8877cc5f684 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -36,6 +36,17 @@ using namespace mlir::detail; /// single .o file. void Pass::anchor() {} +/// Attempt to initialize the options of this pass from the given string. +LogicalResult Pass::initializeOptions(StringRef options) { + return passOptions.parseFromString(options); +} + +/// Copy the option values from 'other', which is another instance of this +/// pass. +void Pass::copyOptionValuesFrom(const Pass *other) { + passOptions.copyOptionValuesFrom(other->passOptions); +} + /// Prints out the pass in the textual representation of pipelines. If this is /// an adaptor pass, print with the op_name(sub_pass,...) format. void Pass::printAsTextualPipeline(raw_ostream &os) { @@ -46,11 +57,14 @@ void Pass::printAsTextualPipeline(raw_ostream &os) { pm.printAsTextualPipeline(os); os << ")"; }); - } else if (const PassInfo *info = lookupPassInfo()) { + return; + } + // Otherwise, print the pass argument followed by its options. + if (const PassInfo *info = lookupPassInfo()) os << info->getPassArgument(); - } else { + else os << getName(); - } + passOptions.print(os); } /// Forwarding function to execute this pass. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 93753d363db..1c5193d0539 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -24,10 +24,15 @@ static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry; static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> passPipelineRegistry; -// Helper to avoid exposing OpPassManager. -void mlir::detail::addPassToPassManager(OpPassManager &pm, - std::unique_ptr<Pass> pass) { - pm.addPass(std::move(pass)); +/// Utility to create a default registry function from a pass instance. +static PassRegistryFunction +buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { + return [=](OpPassManager &pm, StringRef options) { + std::unique_ptr<Pass> pass = allocator(); + LogicalResult result = pass->initializeOptions(options); + pm.addPass(std::move(pass)); + return result; + }; } //===----------------------------------------------------------------------===// @@ -46,9 +51,13 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description, // PassInfo //===----------------------------------------------------------------------===// +PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID, + const PassAllocatorFunction &allocator) + : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {} + void mlir::registerPass(StringRef arg, StringRef description, const PassID *passID, - const PassRegistryFunction &function) { + const PassAllocatorFunction &function) { PassInfo passInfo(arg, description, passID, function); bool inserted = passRegistry->try_emplace(passID, passInfo).second; assert(inserted && "Pass registered multiple times"); @@ -67,7 +76,19 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) { // PassOptions //===----------------------------------------------------------------------===// -LogicalResult PassOptionsBase::parseFromString(StringRef options) { +/// Out of line virtual function to provide home for the class. +void detail::PassOptions::OptionBase::anchor() {} + +/// Copy the option values from 'other'. +void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { + assert(options.size() == other.options.size()); + if (options.empty()) + return; + for (auto optionsIt : llvm::zip(options, other.options)) + std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); +} + +LogicalResult detail::PassOptions::parseFromString(StringRef options) { // TODO(parkers): Handle escaping strings. // NOTE: `options` is modified in place to always refer to the unprocessed // part of the string. @@ -99,7 +120,6 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) { auto it = OptionsMap.find(key); if (it == OptionsMap.end()) { llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n"; - return failure(); } if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) @@ -109,6 +129,28 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) { return success(); } +/// Print the options held by this struct in a form that can be parsed via +/// 'parseFromString'. +void detail::PassOptions::print(raw_ostream &os) { + // If there are no options, there is nothing left to do. + if (OptionsMap.empty()) + return; + + // Sort the options to make the ordering deterministic. + SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end()); + llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(), + [](OptionBase *const *lhs, OptionBase *const *rhs) { + return (*lhs)->getArgStr().compare( + (*rhs)->getArgStr()); + }); + + // Interleave the options with ' '. + os << '{'; + interleave( + orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " "); + os << '}'; +} + //===----------------------------------------------------------------------===// // TextualPassPipeline Parser //===----------------------------------------------------------------------===// |