diff options
-rw-r--r-- | mlir/g3doc/WritingAPass.md | 52 | ||||
-rw-r--r-- | mlir/include/mlir/Pass/Pass.h | 47 | ||||
-rw-r--r-- | mlir/include/mlir/Pass/PassOptions.h | 185 | ||||
-rw-r--r-- | mlir/include/mlir/Pass/PassRegistry.h | 83 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp | 26 | ||||
-rw-r--r-- | mlir/lib/Pass/Pass.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/Pass/PassRegistry.cpp | 56 | ||||
-rw-r--r-- | mlir/test/Pass/pipeline-options-parsing.mlir | 6 | ||||
-rw-r--r-- | mlir/test/lib/Pass/TestPassManager.cpp | 47 | ||||
-rw-r--r-- | mlir/test/lib/Transforms/TestLoopParametricTiling.cpp | 24 |
10 files changed, 369 insertions, 177 deletions
diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md index 784757139d3..5119c469e20 100644 --- a/mlir/g3doc/WritingAPass.md +++ b/mlir/g3doc/WritingAPass.md @@ -421,7 +421,8 @@ options ::= '{' (key ('=' value)?)+ '}' pass pipeline, e.g. `cse` or `canonicalize`. * `options` * Options are pass specific key value pairs that are handled as described - in the instance specific pass options section. + in the [instance specific pass options](#instance-specific-pass-options) + section. For example, the following pipeline: @@ -443,30 +444,47 @@ options in the format described above. ### Instance Specific Pass Options Options may be specified for a parametric pass. Individual options are defined -using `llvm::cl::opt` flag definition rules. These options will then be parsed -at pass construction time independently for each instance of the pass. The -`PassRegistration` and `PassPipelineRegistration` templates take an additional -optional template parameter that is the Option struct definition to be used for -that pass. To use pass specific options, create a class that inherits from -`mlir::PassOptions` and then add a new constructor that takes `const -MyPassOptions&` and constructs the pass. When using `PassPipelineRegistration`, -the constructor now takes a function with the signature `void (OpPassManager -&pm, const MyPassOptions&)` which should construct the passes from the options -and pass them to the pm. The user code will look like the following: +using the [LLVM command line](https://llvm.org/docs/CommandLine.html) flag +definition rules. These options will then be parsed at pass construction time +independently for each instance of the pass. To provide options for passes, the +`Option<>` and `OptionList<>` classes may be used: ```c++ -class MyPass ... { -public: - MyPass(const MyPassOptions& options) ... +struct MyPass ... { + /// Make sure that we have a valid default constructor and copy constructor to + /// make sure that the options are initialized properly. + MyPass() = default; + MyPass(const MyPass& pass) {} + + // These just forward onto llvm::cl::list and llvm::cl::opt respectively. + Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")}; + ListOption<int> exampleListOption{*this, "list-flag-name", + llvm::cl::desc("...")}; }; +``` -struct MyPassOptions : public PassOptions<MyPassOptions> { +For pass pipelines, the `PassPipelineRegistration` templates take an additional +optional template parameter that is the Option struct definition to be used for +that pipeline. To use pipeline specific options, create a class that inherits +from `mlir::PassPipelineOptions` that contains the desired options. When using +`PassPipelineRegistration`, the constructor now takes a function with the +signature `void (OpPassManager &pm, const MyPipelineOptions&)` which should +construct the passes from the options and pass them to the pm: + +```c++ +struct MyPipelineOptions : public PassPipelineOptions { // These just forward onto llvm::cl::list and llvm::cl::opt respectively. Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")}; - List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")}; + ListOption<int> exampleListOption{*this, "list-flag-name", + llvm::cl::desc("...")}; }; -static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description"); + +static mlir::PassPipelineRegistration<MyPipelineOptions> pipeline( + "example-pipeline", "Run an example pipeline.", + [](OpPassManager &pm, const MyPipelineOptions &pipelineOptions) { + // Initialize the pass manager. + }); ``` ## Pass Statistics diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index b4e8db86ff0..bcb297356fa 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -61,12 +61,40 @@ public: /// this is a generic OperationPass. Optional<StringRef> getOpName() const { return opName; } + //===--------------------------------------------------------------------===// + // Options + //===--------------------------------------------------------------------===// + + /// This class represents a specific pass option, with a provided data type. + template <typename DataType> + struct Option : public detail::PassOptions::Option<DataType> { + template <typename... Args> + Option(Pass &parent, StringRef arg, Args &&... args) + : detail::PassOptions::Option<DataType>(parent.passOptions, arg, + std::forward<Args>(args)...) {} + using detail::PassOptions::Option<DataType>::operator=; + }; + /// This class represents a specific pass option that contains a list of + /// values of the provided data type. + template <typename DataType> + struct ListOption : public detail::PassOptions::ListOption<DataType> { + template <typename... Args> + ListOption(Pass &parent, StringRef arg, Args &&... args) + : detail::PassOptions::ListOption<DataType>( + parent.passOptions, arg, std::forward<Args>(args)...) {} + using detail::PassOptions::ListOption<DataType>::operator=; + }; + + /// Attempt to initialize the options of this pass from the given string. + LogicalResult initializeOptions(StringRef options); + /// Prints out the pass in the textual representation of pipelines. If this is /// an adaptor pass, print with the op_name(sub_pass,...) format. - /// Note: The default implementation uses the class name and does not respect - /// options used to construct the pass. Override this method to allow for your - /// pass to be to be round-trippable to the textual format. - virtual void printAsTextualPipeline(raw_ostream &os); + void printAsTextualPipeline(raw_ostream &os); + + //===--------------------------------------------------------------------===// + // Statistics + //===--------------------------------------------------------------------===// /// This class represents a single pass statistic. This statistic functions /// similarly to an unsigned integer value, and may be updated and incremented @@ -119,6 +147,10 @@ protected: return getPassState().analysisManager; } + /// Copy the option values from 'other', which is another instance of this + /// pass. + void copyOptionValuesFrom(const Pass *other); + private: /// Forwarding function to execute this pass on the given operation. LLVM_NODISCARD @@ -141,6 +173,9 @@ private: /// The set of statistics held by this pass. std::vector<Statistic *> statistics; + /// The pass options registered to this pass instance. + detail::PassOptions passOptions; + /// Allow access to 'clone' and 'run'. friend class OpPassManager; }; @@ -204,7 +239,9 @@ protected: /// A clone method to create a copy of this pass. std::unique_ptr<Pass> clone() const override { - return std::make_unique<PassT>(*static_cast<const PassT *>(this)); + auto newInst = std::make_unique<PassT>(*static_cast<const PassT *>(this)); + newInst->copyOptionValuesFrom(this); + return newInst; } /// Returns the analysis for the parent operation if it exists. diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index 8ebeead90c8..0ecb7ba970a 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -24,50 +24,202 @@ namespace mlir { namespace detail { -/// Base class for PassOptions<T> that holds all of the non-CRTP features. -class PassOptionsBase : protected llvm::cl::SubCommand { +/// Base container class and manager for all pass options. +class PassOptions : protected llvm::cl::SubCommand { +private: + /// This is the type-erased option base class. This provides some additional + /// hooks into the options that are not available via llvm::cl::Option. + class OptionBase { + public: + virtual ~OptionBase() = default; + + /// Out of line virtual function to provide home for the class. + virtual void anchor(); + + /// Print the name and value of this option to the given stream. + virtual void print(raw_ostream &os) = 0; + + /// Return the argument string of this option. + StringRef getArgStr() const { return getOption()->ArgStr; } + + protected: + /// Return the main option instance. + virtual const llvm::cl::Option *getOption() const = 0; + + /// Copy the value from the given option into this one. + virtual void copyValueFrom(const OptionBase &other) = 0; + + /// Allow access to private methods. + friend PassOptions; + }; + + /// This is the parser that is used by pass options that use literal options. + /// This is a thin wrapper around the llvm::cl::parser, that exposes some + /// additional methods. + template <typename DataType> + struct GenericOptionParser : public llvm::cl::parser<DataType> { + using llvm::cl::parser<DataType>::parser; + + /// Returns an argument name that maps to the specified value. + Optional<StringRef> findArgStrForValue(const DataType &value) { + for (auto &it : this->Values) + if (it.V.compare(value)) + return it.Name; + return llvm::None; + } + }; + + /// The specific parser to use depending on llvm::cl parser used. This is only + /// necessary because we need to provide additional methods for certain data + /// type parsers. + /// TODO(riverriddle) We should upstream the methods in GenericOptionParser to + /// avoid the need to do this. + template <typename DataType> + using OptionParser = + std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base, + llvm::cl::parser<DataType>>::value, + GenericOptionParser<DataType>, + llvm::cl::parser<DataType>>; + + /// Utility methods for printing option values. + template <typename DataT> + static void printOptionValue(raw_ostream &os, + GenericOptionParser<DataT> &parser, + const DataT &value) { + if (Optional<StringRef> argStr = parser.findArgStrForValue(value)) + os << argStr; + else + llvm_unreachable("unknown data value for option"); + } + template <typename DataT, typename ParserT> + static void printOptionValue(raw_ostream &os, ParserT &parser, + const DataT &value) { + os << value; + } + template <typename ParserT> + static void printOptionValue(raw_ostream &os, ParserT &parser, + const bool &value) { + os << (value ? StringRef("true") : StringRef("false")); + } + public: /// This class represents a specific pass option, with a provided data type. - template <typename DataType> struct Option : public llvm::cl::opt<DataType> { + template <typename DataType> + class Option : public llvm::cl::opt<DataType, /*ExternalStorage=*/false, + OptionParser<DataType>>, + public OptionBase { + public: template <typename... Args> - Option(PassOptionsBase &parent, StringRef arg, Args &&... args) - : llvm::cl::opt<DataType>(arg, llvm::cl::sub(parent), - std::forward<Args>(args)...) { + Option(PassOptions &parent, StringRef arg, Args &&... args) + : llvm::cl::opt<DataType, /*ExternalStorage=*/false, + OptionParser<DataType>>(arg, llvm::cl::sub(parent), + std::forward<Args>(args)...) { assert(!this->isPositional() && !this->isSink() && "sink and positional options are not supported"); + parent.options.push_back(this); + } + using llvm::cl::opt<DataType, /*ExternalStorage=*/false, + OptionParser<DataType>>::operator=; + ~Option() override = default; + + private: + /// Return the main option instance. + const llvm::cl::Option *getOption() const final { return this; } + + /// Print the name and value of this option to the given stream. + void print(raw_ostream &os) final { + os << this->ArgStr << '='; + printOptionValue(os, this->getParser(), this->getValue()); + } + + /// Copy the value from the given option into this one. + void copyValueFrom(const OptionBase &other) final { + this->setValue(static_cast<const Option<DataType> &>(other).getValue()); } }; /// This class represents a specific pass option that contains a list of /// values of the provided data type. - template <typename DataType> struct List : public llvm::cl::list<DataType> { + template <typename DataType> + class ListOption : public llvm::cl::list<DataType, /*StorageClass=*/bool, + OptionParser<DataType>>, + public OptionBase { + public: template <typename... Args> - List(PassOptionsBase &parent, StringRef arg, Args &&... args) - : llvm::cl::list<DataType>(arg, llvm::cl::sub(parent), - std::forward<Args>(args)...) { + ListOption(PassOptions &parent, StringRef arg, Args &&... args) + : llvm::cl::list<DataType, /*StorageClass=*/bool, + OptionParser<DataType>>(arg, llvm::cl::sub(parent), + std::forward<Args>(args)...) { assert(!this->isPositional() && !this->isSink() && "sink and positional options are not supported"); + parent.options.push_back(this); + } + ~ListOption() override = default; + + /// Allow assigning from an ArrayRef. + ListOption<DataType> &operator=(ArrayRef<DataType> values) { + (*this)->assign(values.begin(), values.end()); + return *this; + } + + std::vector<DataType> *operator->() { return &*this; } + + private: + /// Return the main option instance. + const llvm::cl::Option *getOption() const final { return this; } + + /// Print the name and value of this option to the given stream. + void print(raw_ostream &os) final { + os << this->ArgStr << '='; + auto printElementFn = [&](const DataType &value) { + printOptionValue(os, this->getParser(), value); + }; + interleave(*this, os, printElementFn, ","); + } + + /// Copy the value from the given option into this one. + void copyValueFrom(const OptionBase &other) final { + (*this) = ArrayRef<DataType>((ListOption<DataType> &)other); } }; + PassOptions() = default; + + /// Copy the option values from 'other' into 'this', where 'other' has the + /// same options as 'this'. + void copyOptionValuesFrom(const PassOptions &other); + /// Parse options out as key=value pairs that can then be handed off to the /// `llvm::cl` command line passing infrastructure. Everything is space /// separated. LogicalResult parseFromString(StringRef options); + + /// Print the options held by this struct in a form that can be parsed via + /// 'parseFromString'. + void print(raw_ostream &os); + +private: + /// A list of all of the opaque options. + std::vector<OptionBase *> options; }; } // end namespace detail -/// Subclasses of PassOptions provide a set of options that can be used to -/// initialize a pass instance. See PassRegistration for usage details. +//===----------------------------------------------------------------------===// +// PassPipelineOptions +//===----------------------------------------------------------------------===// + +/// Subclasses of PassPipelineOptions provide a set of options that can be used +/// to initialize a pass pipeline. See PassPipelineRegistration for usage +/// details. /// /// Usage: /// -/// struct MyPassOptions : PassOptions<MyPassOptions> { -/// List<int> someListFlag{ +/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> { +/// ListOption<int> someListFlag{ /// *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated, /// llvm::cl::desc("...")}; /// }; -template <typename T> class PassOptions : public detail::PassOptionsBase { +template <typename T> class PassPipelineOptions : public detail::PassOptions { public: /// Factory that parses the provided options and returns a unique_ptr to the /// struct. @@ -81,7 +233,8 @@ public: /// A default empty option struct to be used for passes that do not need to take /// any options. -struct EmptyPassOptions : public PassOptions<EmptyPassOptions> {}; +struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> { +}; } // end namespace mlir diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index e07b9855c8d..c5604c04616 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -25,6 +25,7 @@ class Pass; /// also parse options and return success() if parsing succeeded. using PassRegistryFunction = std::function<LogicalResult(OpPassManager &, StringRef options)>; +using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>; /// A special type used by transformation passes to provide an address that can /// act as a unique identifier during pass registration. @@ -56,7 +57,7 @@ public: protected: PassRegistryEntry(StringRef arg, StringRef description, - PassRegistryFunction builder) + const PassRegistryFunction &builder) : arg(arg), description(description), builder(builder) {} private: @@ -74,7 +75,7 @@ private: class PassPipelineInfo : public PassRegistryEntry { public: PassPipelineInfo(StringRef arg, StringRef description, - PassRegistryFunction builder) + const PassRegistryFunction &builder) : PassRegistryEntry(arg, description, builder) {} }; @@ -84,8 +85,7 @@ public: /// PassInfo constructor should not be invoked directly, instead use /// PassRegistration or registerPass. PassInfo(StringRef arg, StringRef description, const PassID *passID, - PassRegistryFunction allocator) - : PassRegistryEntry(arg, description, allocator) {} + const PassAllocatorFunction &allocator); }; //===----------------------------------------------------------------------===// @@ -100,80 +100,28 @@ void registerPassPipeline(StringRef arg, StringRef description, /// Register a specific dialect pass allocator function with the system, /// typically used through the PassRegistration template. void registerPass(StringRef arg, StringRef description, const PassID *passID, - const PassRegistryFunction &function); - -namespace detail { - -// Calls `pm.addPass(std::move(pass))` to avoid including the PassManager -// header. Only used in `makePassRegistryFunction`. -void addPassToPassManager(OpPassManager &pm, std::unique_ptr<Pass> pass); - -// Helper function which constructs a PassRegistryFunction that parses options -// into a struct of type `Options` and then calls constructor(options) to -// build the pass. -template <typename Options, typename PassConstructor> -PassRegistryFunction makePassRegistryFunction(PassConstructor constructor) { - return [=](OpPassManager &pm, StringRef optionsStr) { - Options options; - if (failed(options.parseFromString(optionsStr))) - return failure(); - addPassToPassManager(pm, constructor(options)); - return success(); - }; -} - -} // end namespace detail + const PassAllocatorFunction &function); /// PassRegistration provides a global initializer that registers a Pass -/// allocation routine for a concrete pass instance. The third argument is +/// allocation routine for a concrete pass instance. The third argument is /// optional and provides a callback to construct a pass that does not have /// a default constructor. /// /// Usage: /// -/// // At namespace scope. +/// /// At namespace scope. /// static PassRegistration<MyPass> reg("my-pass", "My Pass Description."); /// -/// // Same, but also providing an Options struct. -/// static PassRegistration<MyPass, MyPassOptions> reg("my-pass", "Docs..."); -template <typename ConcretePass, typename Options = EmptyPassOptions> -struct PassRegistration { +template <typename ConcretePass> struct PassRegistration { PassRegistration(StringRef arg, StringRef description, - const std::function<std::unique_ptr<Pass>(const Options &)> - &constructor) { - registerPass(arg, description, PassID::getID<ConcretePass>(), - detail::makePassRegistryFunction<Options>(constructor)); + const PassAllocatorFunction &constructor) { + registerPass(arg, description, PassID::getID<ConcretePass>(), constructor); } - PassRegistration(StringRef arg, StringRef description) { - registerPass( - arg, description, PassID::getID<ConcretePass>(), - detail::makePassRegistryFunction<Options>([](const Options &options) { - return std::make_unique<ConcretePass>(options); - })); - } -}; - -/// Convenience specialization of PassRegistration for EmptyPassOptions that -/// does not pass an empty options struct to the pass constructor. -template <typename ConcretePass> -struct PassRegistration<ConcretePass, EmptyPassOptions> { - PassRegistration(StringRef arg, StringRef description, - const std::function<std::unique_ptr<Pass>()> &constructor) { - registerPass( - arg, description, PassID::getID<ConcretePass>(), - detail::makePassRegistryFunction<EmptyPassOptions>( - [=](const EmptyPassOptions &options) { return constructor(); })); - } - - PassRegistration(StringRef arg, StringRef description) { - registerPass(arg, description, PassID::getID<ConcretePass>(), - detail::makePassRegistryFunction<EmptyPassOptions>( - [](const EmptyPassOptions &options) { - return std::make_unique<ConcretePass>(); - })); - } + PassRegistration(StringRef arg, StringRef description) + : PassRegistration(arg, description, + [] { return std::make_unique<ConcretePass>(); }) {} }; /// PassPipelineRegistration provides a global initializer that registers a Pass @@ -189,7 +137,8 @@ struct PassRegistration<ConcretePass, EmptyPassOptions> { /// /// static PassPipelineRegistration Unused("unused", "Unused pass", /// pipelineBuilder); -template <typename Options = EmptyPassOptions> struct PassPipelineRegistration { +template <typename Options = EmptyPipelineOptions> +struct PassPipelineRegistration { PassPipelineRegistration( StringRef arg, StringRef description, std::function<void(OpPassManager &, const Options &options)> builder) { @@ -206,7 +155,7 @@ template <typename Options = EmptyPassOptions> struct PassPipelineRegistration { /// Convenience specialization of PassPipelineRegistration for EmptyPassOptions /// that does not pass an empty options struct to the pass builder function. -template <> struct PassPipelineRegistration<EmptyPassOptions> { +template <> struct PassPipelineRegistration<EmptyPipelineOptions> { PassPipelineRegistration(StringRef arg, StringRef description, std::function<void(OpPassManager &)> builder) { registerPassPipeline(arg, description, diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index 115096003e1..68392c36765 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -35,17 +35,17 @@ namespace { /// 2) Lower the body of the spirv::ModuleOp. class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { public: - GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) - : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {} + GPUToSPIRVPass() = default; + GPUToSPIRVPass(const GPUToSPIRVPass &) {} + GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { + this->workGroupSize = workGroupSize; + } + void runOnModule() override; private: - SmallVector<int64_t, 3> workGroupSize; -}; - -/// Command line option to specify the workgroup size. -struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> { - List<unsigned> workGroupSize{ + /// Command line option to specify the workgroup size. + ListOption<int64_t> workGroupSize{ *this, "workgroup-size", llvm::cl::desc( "Workgroup Sizes in the SPIR-V module for x, followed by y, followed " @@ -92,11 +92,5 @@ mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { return std::make_unique<GPUToSPIRVPass>(workGroupSize); } -static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions> - pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect", - [](const GPUToSPIRVPassOptions &passOptions) { - SmallVector<int64_t, 3> workGroupSize; - workGroupSize.assign(passOptions.workGroupSize.begin(), - passOptions.workGroupSize.end()); - return std::make_unique<GPUToSPIRVPass>(workGroupSize); - }); +static PassRegistration<GPUToSPIRVPass> + pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect"); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 22e58cc5b63..8877cc5f684 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -36,6 +36,17 @@ using namespace mlir::detail; /// single .o file. void Pass::anchor() {} +/// Attempt to initialize the options of this pass from the given string. +LogicalResult Pass::initializeOptions(StringRef options) { + return passOptions.parseFromString(options); +} + +/// Copy the option values from 'other', which is another instance of this +/// pass. +void Pass::copyOptionValuesFrom(const Pass *other) { + passOptions.copyOptionValuesFrom(other->passOptions); +} + /// Prints out the pass in the textual representation of pipelines. If this is /// an adaptor pass, print with the op_name(sub_pass,...) format. void Pass::printAsTextualPipeline(raw_ostream &os) { @@ -46,11 +57,14 @@ void Pass::printAsTextualPipeline(raw_ostream &os) { pm.printAsTextualPipeline(os); os << ")"; }); - } else if (const PassInfo *info = lookupPassInfo()) { + return; + } + // Otherwise, print the pass argument followed by its options. + if (const PassInfo *info = lookupPassInfo()) os << info->getPassArgument(); - } else { + else os << getName(); - } + passOptions.print(os); } /// Forwarding function to execute this pass. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 93753d363db..1c5193d0539 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -24,10 +24,15 @@ static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry; static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> passPipelineRegistry; -// Helper to avoid exposing OpPassManager. -void mlir::detail::addPassToPassManager(OpPassManager &pm, - std::unique_ptr<Pass> pass) { - pm.addPass(std::move(pass)); +/// Utility to create a default registry function from a pass instance. +static PassRegistryFunction +buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { + return [=](OpPassManager &pm, StringRef options) { + std::unique_ptr<Pass> pass = allocator(); + LogicalResult result = pass->initializeOptions(options); + pm.addPass(std::move(pass)); + return result; + }; } //===----------------------------------------------------------------------===// @@ -46,9 +51,13 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description, // PassInfo //===----------------------------------------------------------------------===// +PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID, + const PassAllocatorFunction &allocator) + : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {} + void mlir::registerPass(StringRef arg, StringRef description, const PassID *passID, - const PassRegistryFunction &function) { + const PassAllocatorFunction &function) { PassInfo passInfo(arg, description, passID, function); bool inserted = passRegistry->try_emplace(passID, passInfo).second; assert(inserted && "Pass registered multiple times"); @@ -67,7 +76,19 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) { // PassOptions //===----------------------------------------------------------------------===// -LogicalResult PassOptionsBase::parseFromString(StringRef options) { +/// Out of line virtual function to provide home for the class. +void detail::PassOptions::OptionBase::anchor() {} + +/// Copy the option values from 'other'. +void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { + assert(options.size() == other.options.size()); + if (options.empty()) + return; + for (auto optionsIt : llvm::zip(options, other.options)) + std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); +} + +LogicalResult detail::PassOptions::parseFromString(StringRef options) { // TODO(parkers): Handle escaping strings. // NOTE: `options` is modified in place to always refer to the unprocessed // part of the string. @@ -99,7 +120,6 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) { auto it = OptionsMap.find(key); if (it == OptionsMap.end()) { llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n"; - return failure(); } if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) @@ -109,6 +129,28 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) { return success(); } +/// Print the options held by this struct in a form that can be parsed via +/// 'parseFromString'. +void detail::PassOptions::print(raw_ostream &os) { + // If there are no options, there is nothing left to do. + if (OptionsMap.empty()) + return; + + // Sort the options to make the ordering deterministic. + SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end()); + llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(), + [](OptionBase *const *lhs, OptionBase *const *rhs) { + return (*lhs)->getArgStr().compare( + (*rhs)->getArgStr()); + }); + + // Interleave the options with ' '. + os << '{'; + interleave( + orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " "); + os << '}'; +} + //===----------------------------------------------------------------------===// // TextualPassPipeline Parser //===----------------------------------------------------------------------===// diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir index 02452a35f23..bfb24af9302 100644 --- a/mlir/test/Pass/pipeline-options-parsing.mlir +++ b/mlir/test/Pass/pipeline-options-parsing.mlir @@ -13,6 +13,6 @@ // CHECK_ERROR_4: 'notaninteger' value invalid for integer argument // CHECK_ERROR_5: string option: may only occur zero or one times -// CHECK_1: test-options-pass{list=1,2,3,4,5 string-list=a,b,c,d string=some_value} -// CHECK_2: test-options-pass{list=1 string-list=a,b} -// CHECK_3: module(func(test-options-pass{list=3}), func(test-options-pass{list=1,2,3,4})) +// CHECK_1: test-options-pass{list=1,2,3,4,5 string=some_value string-list=a,b,c,d} +// CHECK_2: test-options-pass{list=1 string= string-list=a,b} +// CHECK_3: module(func(test-options-pass{list=3 string= string-list=}), func(test-options-pass{list=1,2,3,4 string= string-list=})) diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 2e811634880..cc926e1c01e 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -21,43 +21,34 @@ struct TestFunctionPass : public FunctionPass<TestFunctionPass> { }; class TestOptionsPass : public FunctionPass<TestOptionsPass> { public: - struct Options : public PassOptions<Options> { - List<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated, - llvm::cl::desc("Example list option")}; - List<std::string> stringListOption{ + struct Options : public PassPipelineOptions<Options> { + ListOption<int> listOption{*this, "list", + llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Example list option")}; + ListOption<std::string> stringListOption{ *this, "string-list", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Example string list option")}; Option<std::string> stringOption{*this, "string", llvm::cl::desc("Example string option")}; }; + TestOptionsPass() = default; + TestOptionsPass(const TestOptionsPass &) {} TestOptionsPass(const Options &options) { - listOption.assign(options.listOption.begin(), options.listOption.end()); - stringOption = options.stringOption; - stringListOption.assign(options.stringListOption.begin(), - options.stringListOption.end()); - } - - void printAsTextualPipeline(raw_ostream &os) final { - os << "test-options-pass{"; - if (!listOption.empty()) { - os << "list="; - // Not interleaveComma to avoid spaces between the elements. - interleave(listOption, os, ","); - } - if (!stringListOption.empty()) { - os << " string-list="; - interleave(stringListOption, os, ","); - } - if (!stringOption.empty()) - os << " string=" << stringOption; - os << "}"; + listOption->assign(options.listOption.begin(), options.listOption.end()); + stringOption.setValue(options.stringOption); + stringListOption->assign(options.stringListOption.begin(), + options.stringListOption.end()); } void runOnFunction() final {} - SmallVector<int64_t, 4> listOption; - SmallVector<std::string, 4> stringListOption; - std::string stringOption; + ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Example list option")}; + ListOption<std::string> stringListOption{ + *this, "string-list", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Example string list option")}; + Option<std::string> stringOption{*this, "string", + llvm::cl::desc("Example string option")}; }; /// A test pass that always aborts to enable testing the crash recovery @@ -97,7 +88,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) { (void)parsePassPipeline("test-pm-nested-pipeline", pm); } -static PassRegistration<TestOptionsPass, TestOptionsPass::Options> +static PassRegistration<TestOptionsPass> reg("test-options-pass", "Test options parsing capabilities"); static PassRegistration<TestModulePass> diff --git a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp index 7b0cdcade4d..e793ee54cda 100644 --- a/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp +++ b/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp @@ -25,18 +25,10 @@ namespace { class SimpleParametricLoopTilingPass : public FunctionPass<SimpleParametricLoopTilingPass> { public: - struct Options : public PassOptions<Options> { - List<int> clOuterLoopSizes{ - *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated, - llvm::cl::desc( - "fixed number of iterations that the outer loops should have")}; - }; - - explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) - : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {} - explicit SimpleParametricLoopTilingPass(const Options &options) { - sizes.assign(options.clOuterLoopSizes.begin(), - options.clOuterLoopSizes.end()); + SimpleParametricLoopTilingPass() = default; + SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {} + explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) { + sizes = outerLoopSizes; } void runOnFunction() override { @@ -49,7 +41,10 @@ public: }); } - SmallVector<int64_t, 4> sizes; + ListOption<int64_t> sizes{ + *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc( + "fixed number of iterations that the outer loops should have")}; }; } // end namespace @@ -58,8 +53,7 @@ mlir::createSimpleParametricTilingPass(ArrayRef<int64_t> outerLoopSizes) { return std::make_unique<SimpleParametricLoopTilingPass>(outerLoopSizes); } -static PassRegistration<SimpleParametricLoopTilingPass, - SimpleParametricLoopTilingPass::Options> +static PassRegistration<SimpleParametricLoopTilingPass> reg("test-extract-fixed-outer-loops", "test application of parametric tiling to the outer loops so that the " "ranges of outer loops become static"); |