diff options
Diffstat (limited to 'mlir/lib/Pass/PassRegistry.cpp')
-rw-r--r-- | mlir/lib/Pass/PassRegistry.cpp | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp new file mode 100644 index 00000000000..1c5193d0539 --- /dev/null +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -0,0 +1,542 @@ +//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" + +using namespace mlir; +using namespace detail; + +/// Static mapping of all of the registered passes. +static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry; + +/// Static mapping of all of the registered pass pipelines. +static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> + passPipelineRegistry; + +/// Utility to create a default registry function from a pass instance. +static PassRegistryFunction +buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { + return [=](OpPassManager &pm, StringRef options) { + std::unique_ptr<Pass> pass = allocator(); + LogicalResult result = pass->initializeOptions(options); + pm.addPass(std::move(pass)); + return result; + }; +} + +//===----------------------------------------------------------------------===// +// PassPipelineInfo +//===----------------------------------------------------------------------===// + +void mlir::registerPassPipeline(StringRef arg, StringRef description, + const PassRegistryFunction &function) { + PassPipelineInfo pipelineInfo(arg, description, function); + bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; + assert(inserted && "Pass pipeline registered multiple times"); + (void)inserted; +} + +//===----------------------------------------------------------------------===// +// PassInfo +//===----------------------------------------------------------------------===// + +PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID, + const PassAllocatorFunction &allocator) + : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {} + +void mlir::registerPass(StringRef arg, StringRef description, + const PassID *passID, + const PassAllocatorFunction &function) { + PassInfo passInfo(arg, description, passID, function); + bool inserted = passRegistry->try_emplace(passID, passInfo).second; + assert(inserted && "Pass registered multiple times"); + (void)inserted; +} + +/// Returns the pass info for the specified pass class or null if unknown. +const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) { + auto it = passRegistry->find(passID); + if (it == passRegistry->end()) + return nullptr; + return &it->getSecond(); +} + +//===----------------------------------------------------------------------===// +// PassOptions +//===----------------------------------------------------------------------===// + +/// Out of line virtual function to provide home for the class. +void detail::PassOptions::OptionBase::anchor() {} + +/// Copy the option values from 'other'. +void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { + assert(options.size() == other.options.size()); + if (options.empty()) + return; + for (auto optionsIt : llvm::zip(options, other.options)) + std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); +} + +LogicalResult detail::PassOptions::parseFromString(StringRef options) { + // TODO(parkers): Handle escaping strings. + // NOTE: `options` is modified in place to always refer to the unprocessed + // part of the string. + while (!options.empty()) { + size_t spacePos = options.find(' '); + StringRef arg = options; + if (spacePos != StringRef::npos) { + arg = options.substr(0, spacePos); + options = options.substr(spacePos + 1); + } else { + options = StringRef(); + } + if (arg.empty()) + continue; + + // At this point, arg refers to everything that is non-space in options + // upto the next space, and options refers to the rest of the string after + // that point. + + // Split the individual option on '=' to form key and value. If there is no + // '=', then value is `StringRef()`. + size_t equalPos = arg.find('='); + StringRef key = arg; + StringRef value; + if (equalPos != StringRef::npos) { + key = arg.substr(0, equalPos); + value = arg.substr(equalPos + 1); + } + auto it = OptionsMap.find(key); + if (it == OptionsMap.end()) { + llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n"; + return failure(); + } + if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) + return failure(); + } + + return success(); +} + +/// Print the options held by this struct in a form that can be parsed via +/// 'parseFromString'. +void detail::PassOptions::print(raw_ostream &os) { + // If there are no options, there is nothing left to do. + if (OptionsMap.empty()) + return; + + // Sort the options to make the ordering deterministic. + SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end()); + llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(), + [](OptionBase *const *lhs, OptionBase *const *rhs) { + return (*lhs)->getArgStr().compare( + (*rhs)->getArgStr()); + }); + + // Interleave the options with ' '. + os << '{'; + interleave( + orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " "); + os << '}'; +} + +//===----------------------------------------------------------------------===// +// TextualPassPipeline Parser +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a textual description of a pass pipeline. +class TextualPipeline { +public: + /// Try to initialize this pipeline with the given pipeline text. + /// `errorStream` is the output stream to emit errors to. + LogicalResult initialize(StringRef text, raw_ostream &errorStream); + + /// Add the internal pipeline elements to the provided pass manager. + LogicalResult addToPipeline(OpPassManager &pm) const; + +private: + /// A functor used to emit errors found during pipeline handling. The first + /// parameter corresponds to the raw location within the pipeline string. This + /// should always return failure. + using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; + + /// A struct to capture parsed pass pipeline names. + /// + /// A pipeline is defined as a series of names, each of which may in itself + /// recursively contain a nested pipeline. A name is either the name of a pass + /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name + /// is the name of a pass, the InnerPipeline is empty, since passes cannot + /// contain inner pipelines. + struct PipelineElement { + PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {} + + StringRef name; + StringRef options; + const PassRegistryEntry *registryEntry; + std::vector<PipelineElement> innerPipeline; + }; + + /// Parse the given pipeline text into the internal pipeline vector. This + /// function only parses the structure of the pipeline, and does not resolve + /// its elements. + LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); + + /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to + /// the corresponding registry entry. + LogicalResult + resolvePipelineElements(MutableArrayRef<PipelineElement> elements, + ErrorHandlerT errorHandler); + + /// Resolve a single element of the pipeline. + LogicalResult resolvePipelineElement(PipelineElement &element, + ErrorHandlerT errorHandler); + + /// Add the given pipeline elements to the provided pass manager. + LogicalResult addToPipeline(ArrayRef<PipelineElement> elements, + OpPassManager &pm) const; + + std::vector<PipelineElement> pipeline; +}; + +} // end anonymous namespace + +/// Try to initialize this pipeline with the given pipeline text. An option is +/// given to enable accurate error reporting. +LogicalResult TextualPipeline::initialize(StringRef text, + raw_ostream &errorStream) { + // Build a source manager to use for error reporting. + llvm::SourceMgr pipelineMgr; + pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer( + text, "MLIR Textual PassPipeline Parser"), + llvm::SMLoc()); + auto errorHandler = [&](const char *rawLoc, Twine msg) { + pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc), + llvm::SourceMgr::DK_Error, msg); + return failure(); + }; + + // Parse the provided pipeline string. + if (failed(parsePipelineText(text, errorHandler))) + return failure(); + return resolvePipelineElements(pipeline, errorHandler); +} + +/// Add the internal pipeline elements to the provided pass manager. +LogicalResult TextualPipeline::addToPipeline(OpPassManager &pm) const { + return addToPipeline(pipeline, pm); +} + +/// Parse the given pipeline text into the internal pipeline vector. This +/// function only parses the structure of the pipeline, and does not resolve +/// its elements. +LogicalResult TextualPipeline::parsePipelineText(StringRef text, + ErrorHandlerT errorHandler) { + SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; + for (;;) { + std::vector<PipelineElement> &pipeline = *pipelineStack.back(); + size_t pos = text.find_first_of(",(){"); + pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); + + // If we have a single terminating name, we're done. + if (pos == text.npos) + break; + + text = text.substr(pos); + char sep = text[0]; + + // Handle pulling ... from 'pass{...}' out as PipelineElement.options. + if (sep == '{') { + text = text.substr(1); + + // Skip over everything until the closing '}' and store as options. + size_t close = text.find('}'); + + // TODO(parkers): Handle skipping over quoted sub-strings. + if (close == StringRef::npos) { + return errorHandler( + /*rawLoc=*/text.data() - 1, + "missing closing '}' while processing pass options"); + } + pipeline.back().options = text.substr(0, close); + text = text.substr(close + 1); + + // Skip checking for '(' because nested pipelines cannot have options. + } else if (sep == '(') { + text = text.substr(1); + + // Push the inner pipeline onto the stack to continue processing. + pipelineStack.push_back(&pipeline.back().innerPipeline); + continue; + } + + // When handling the close parenthesis, we greedily consume them to avoid + // empty strings in the pipeline. + while (text.consume_front(")")) { + // If we try to pop the outer pipeline we have unbalanced parentheses. + if (pipelineStack.size() == 1) + return errorHandler(/*rawLoc=*/text.data() - 1, + "encountered extra closing ')' creating unbalanced " + "parentheses while parsing pipeline"); + + pipelineStack.pop_back(); + } + + // Check if we've finished parsing. + if (text.empty()) + break; + + // Otherwise, the end of an inner pipeline always has to be followed by + // a comma, and then we can continue. + if (!text.consume_front(",")) + return errorHandler(text.data(), "expected ',' after parsing pipeline"); + } + + // Check for unbalanced parentheses. + if (pipelineStack.size() > 1) + return errorHandler( + text.data(), + "encountered unbalanced parentheses while parsing pipeline"); + + assert(pipelineStack.back() == &pipeline && + "wrong pipeline at the bottom of the stack"); + return success(); +} + +/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to +/// the corresponding registry entry. +LogicalResult TextualPipeline::resolvePipelineElements( + MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { + for (auto &elt : elements) + if (failed(resolvePipelineElement(elt, errorHandler))) + return failure(); + return success(); +} + +/// Resolve a single element of the pipeline. +LogicalResult +TextualPipeline::resolvePipelineElement(PipelineElement &element, + ErrorHandlerT errorHandler) { + // If the inner pipeline of this element is not empty, this is an operation + // pipeline. + if (!element.innerPipeline.empty()) + return resolvePipelineElements(element.innerPipeline, errorHandler); + + // Otherwise, this must be a pass or pass pipeline. + // Check to see if a pipeline was registered with this name. + auto pipelineRegistryIt = passPipelineRegistry->find(element.name); + if (pipelineRegistryIt != passPipelineRegistry->end()) { + element.registryEntry = &pipelineRegistryIt->second; + return success(); + } + + // If not, then this must be a specific pass name. + for (auto &passIt : *passRegistry) { + if (passIt.second.getPassArgument() == element.name) { + element.registryEntry = &passIt.second; + return success(); + } + } + + // Emit an error for the unknown pass. + auto *rawLoc = element.name.data(); + return errorHandler(rawLoc, "'" + element.name + + "' does not refer to a " + "registered pass or pass pipeline"); +} + +/// Add the given pipeline elements to the provided pass manager. +LogicalResult TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements, + OpPassManager &pm) const { + for (auto &elt : elements) { + if (elt.registryEntry) { + if (failed(elt.registryEntry->addToPipeline(pm, elt.options))) + return failure(); + } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name)))) { + return failure(); + } + } + return success(); +} + +/// This function parses the textual representation of a pass pipeline, and adds +/// the result to 'pm' on success. This function returns failure if the given +/// pipeline was invalid. 'errorStream' is an optional parameter that, if +/// non-null, will be used to emit errors found during parsing. +LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, + raw_ostream &errorStream) { + TextualPipeline pipelineParser; + if (failed(pipelineParser.initialize(pipeline, errorStream))) + return failure(); + if (failed(pipelineParser.addToPipeline(pm))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// PassNameParser +//===----------------------------------------------------------------------===// + +namespace { +/// This struct represents the possible data entries in a parsed pass pipeline +/// list. +struct PassArgData { + PassArgData() : registryEntry(nullptr) {} + PassArgData(const PassRegistryEntry *registryEntry) + : registryEntry(registryEntry) {} + + /// This field is used when the parsed option corresponds to a registered pass + /// or pass pipeline. + const PassRegistryEntry *registryEntry; + + /// This field is set when instance specific pass options have been provided + /// on the command line. + StringRef options; + + /// This field is used when the parsed option corresponds to an explicit + /// pipeline. + TextualPipeline pipeline; +}; +} // end anonymous namespace + +namespace llvm { +namespace cl { +/// Define a valid OptionValue for the command line pass argument. +template <> +struct OptionValue<PassArgData> final + : OptionValueBase<PassArgData, /*isClass=*/true> { + OptionValue(const PassArgData &value) { this->setValue(value); } + OptionValue() = default; + void anchor() override {} + + bool hasValue() const { return true; } + const PassArgData &getValue() const { return value; } + void setValue(const PassArgData &value) { this->value = value; } + + PassArgData value; +}; +} // end namespace cl +} // end namespace llvm + +namespace { + +/// The name for the command line option used for parsing the textual pass +/// pipeline. +static constexpr StringLiteral passPipelineArg = "pass-pipeline"; + +/// Adds command line option for each registered pass or pass pipeline, as well +/// as textual pass pipelines. +struct PassNameParser : public llvm::cl::parser<PassArgData> { + PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {} + + void initialize(); + void printOptionInfo(const llvm::cl::Option &opt, + size_t globalWidth) const override; + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + PassArgData &value); +}; +} // namespace + +void PassNameParser::initialize() { + llvm::cl::parser<PassArgData>::initialize(); + + /// Add an entry for the textual pass pipeline option. + addLiteralOption(passPipelineArg, PassArgData(), + "A textual description of a pass pipeline to run"); + + /// Add the pass entries. + for (const auto &kv : *passRegistry) { + addLiteralOption(kv.second.getPassArgument(), &kv.second, + kv.second.getPassDescription()); + } + /// Add the pass pipeline entries. + for (const auto &kv : *passPipelineRegistry) { + addLiteralOption(kv.second.getPassArgument(), &kv.second, + kv.second.getPassDescription()); + } +} + +void PassNameParser::printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const { + PassNameParser *TP = const_cast<PassNameParser *>(this); + llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(), + [](const PassNameParser::OptionInfo *VT1, + const PassNameParser::OptionInfo *VT2) { + return VT1->Name.compare(VT2->Name); + }); + llvm::cl::parser<PassArgData>::printOptionInfo(O, GlobalWidth); +} + +bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, + StringRef arg, PassArgData &value) { + // Handle the pipeline option explicitly. + if (argName == passPipelineArg) + return failed(value.pipeline.initialize(arg, llvm::errs())); + + // Otherwise, default to the base for handling. + if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value)) + return true; + value.options = arg; + return false; +} + +//===----------------------------------------------------------------------===// +// PassPipelineCLParser +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace detail { +struct PassPipelineCLParserImpl { + PassPipelineCLParserImpl(StringRef arg, StringRef description) + : passList(arg, llvm::cl::desc(description)) { + passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); + } + + /// The set of passes and pass pipelines to run. + llvm::cl::list<PassArgData, bool, PassNameParser> passList; +}; +} // end namespace detail +} // end namespace mlir + +/// Construct a pass pipeline parser with the given command line description. +PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) + : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg, + description)) {} +PassPipelineCLParser::~PassPipelineCLParser() {} + +/// Returns true if this parser contains any valid options to add. +bool PassPipelineCLParser::hasAnyOccurrences() const { + return impl->passList.getNumOccurrences() != 0; +} + +/// Returns true if the given pass registry entry was registered at the +/// top-level of the parser, i.e. not within an explicit textual pipeline. +bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { + return llvm::any_of(impl->passList, [&](const PassArgData &data) { + return data.registryEntry == entry; + }); +} + +/// Adds the passes defined by this parser entry to the given pass manager. +LogicalResult PassPipelineCLParser::addToPipeline(OpPassManager &pm) const { + for (auto &passIt : impl->passList) { + if (passIt.registryEntry) { + if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options))) + return failure(); + } else if (failed(passIt.pipeline.addToPipeline(pm))) { + return failure(); + } + } + return success(); +} |