diff options
Diffstat (limited to 'mlir/lib/Transforms/DialectConversion.cpp')
-rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 54 |
1 files changed, 16 insertions, 38 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index be60ada6a43..84f00b97e38 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -849,7 +849,7 @@ struct FunctionConverter { /// error, success otherwise. If 'signatureConversion' is provided, the /// arguments of the entry block are updated accordingly. LogicalResult - convertFunction(Function *f, + convertFunction(Function f, TypeConverter::SignatureConversion *signatureConversion); /// Converts the given region starting from the entry block and following the @@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, } LogicalResult FunctionConverter::convertFunction( - Function *f, TypeConverter::SignatureConversion *signatureConversion) { + Function f, TypeConverter::SignatureConversion *signatureConversion) { // If this is an external function, there is nothing else to do. - if (f->isExternal()) + if (f.isExternal()) return success(); - DialectConversionRewriter rewriter(f->getBody(), typeConverter); + DialectConversionRewriter rewriter(f.getBody(), typeConverter); // Update the signature of the entry block. if (signatureConversion) { rewriter.argConverter.convertSignature( - &f->getBody().front(), *signatureConversion, rewriter.mapping); + &f.getBody().front(), *signatureConversion, rewriter.mapping); } // Rewrite the function body. if (failed( - convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) { + convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) { // Reset any of the generated rewrites. rewriter.discardRewrites(); return failure(); @@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const // applyConversionPatterns //===----------------------------------------------------------------------===// -namespace { -/// This class represents a function to be converted. It allows for converting -/// the body of functions and the signature in two phases. -struct ConvertedFunction { - ConvertedFunction(Function *fn, FunctionType newType, - ArrayRef<NamedAttributeList> newFunctionArgAttrs) - : fn(fn), newType(newType), - newFunctionArgAttrs(newFunctionArgAttrs.begin(), - newFunctionArgAttrs.end()) {} - - /// The function to convert. - Function *fn; - /// The new type and argument attributes for the function. - FunctionType newType; - SmallVector<NamedAttributeList, 4> newFunctionArgAttrs; -}; -} // end anonymous namespace - /// Convert the given module with the provided conversion patterns and type /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. @@ -1149,37 +1131,33 @@ LogicalResult mlir::applyConversionPatterns(Module &module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { - std::vector<Function *> allFunctions; - allFunctions.reserve(module.getFunctions().size()); - for (auto &func : module) - allFunctions.push_back(&func); + SmallVector<Function, 32> allFunctions(module.getFunctions()); return applyConversionPatterns(allFunctions, target, converter, std::move(patterns)); } /// Convert the given functions with the provided conversion patterns. LogicalResult mlir::applyConversionPatterns( - ArrayRef<Function *> fns, ConversionTarget &target, + MutableArrayRef<Function> fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { if (fns.empty()) return success(); // Build the function converter. - FunctionConverter funcConverter(fns.front()->getContext(), target, patterns, - &converter); + auto *ctx = fns.front().getContext(); + FunctionConverter funcConverter(ctx, target, patterns, &converter); // Try to convert each of the functions within the module. - auto *ctx = fns.front()->getContext(); - for (auto *func : fns) { + for (auto func : fns) { // Convert the function type using the type converter. auto conversion = - converter.convertSignature(func->getType(), func->getAllArgAttrs()); + converter.convertSignature(func.getType(), func.getAllArgAttrs()); if (!conversion) return failure(); // Update the function signature. - func->setType(conversion->getConvertedType(ctx)); - func->setAllArgAttrs(conversion->getConvertedArgAttrs()); + func.setType(conversion->getConvertedType(ctx)); + func.setAllArgAttrs(conversion->getConvertedArgAttrs()); // Convert the body of this function. if (failed(funcConverter.convertFunction(func, &*conversion))) @@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns( /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function &fn, ConversionTarget &target, +mlir::applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. FunctionConverter converter(fn.getContext(), target, patterns); - return converter.convertFunction(&fn, /*signatureConversion=*/nullptr); + return converter.convertFunction(fn, /*signatureConversion=*/nullptr); } |