summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-07-01 10:29:09 -0700
committerjpienaar <jpienaar@google.com>2019-07-01 11:39:00 -0700
commit54cd6a7e97a226738e2c85b86559918dd9e3cd5d (patch)
treeaffa803347d6695be575137d1ad55a055a8021e3
parent84bd67fc4fd116e80f7a66bfadfe9a7fd6fd5e82 (diff)
downloadbcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.tar.gz
bcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.zip
NFC: Refactor Function to be value typed.
Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed). PiperOrigin-RevId: 255983022
-rw-r--r--mlir/bindings/python/pybind.cpp35
-rw-r--r--mlir/examples/Linalg/Linalg1/include/linalg1/Common.h24
-rw-r--r--mlir/examples/Linalg/Linalg2/Example.cpp20
-rw-r--r--mlir/examples/Linalg/Linalg3/Conversion.cpp22
-rw-r--r--mlir/examples/Linalg/Linalg3/Example.cpp32
-rw-r--r--mlir/examples/Linalg/Linalg3/Execution.cpp22
-rw-r--r--mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h6
-rw-r--r--mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp2
-rw-r--r--mlir/examples/Linalg/Linalg3/lib/Transforms.cpp12
-rw-r--r--mlir/examples/Linalg/Linalg4/Example.cpp40
-rw-r--r--mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h4
-rw-r--r--mlir/examples/Linalg/Linalg4/lib/Transforms.cpp9
-rw-r--r--mlir/examples/toy/Ch2/mlir/MLIRGen.cpp28
-rw-r--r--mlir/examples/toy/Ch3/mlir/MLIRGen.cpp28
-rw-r--r--mlir/examples/toy/Ch4/mlir/MLIRGen.cpp28
-rw-r--r--mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp54
-rw-r--r--mlir/examples/toy/Ch5/mlir/LateLowering.cpp22
-rw-r--r--mlir/examples/toy/Ch5/mlir/MLIRGen.cpp28
-rw-r--r--mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp54
-rw-r--r--mlir/include/mlir/Analysis/Dominance.h4
-rw-r--r--mlir/include/mlir/Analysis/NestedMatcher.h4
-rw-r--r--mlir/include/mlir/ExecutionEngine/MemRefUtils.h2
-rw-r--r--mlir/include/mlir/GPU/GPUDialect.h6
-rw-r--r--mlir/include/mlir/IR/Attributes.h3
-rw-r--r--mlir/include/mlir/IR/Block.h2
-rw-r--r--mlir/include/mlir/IR/Builders.h2
-rw-r--r--mlir/include/mlir/IR/Dialect.h10
-rw-r--r--mlir/include/mlir/IR/Function.h234
-rw-r--r--mlir/include/mlir/IR/Module.h66
-rw-r--r--mlir/include/mlir/IR/Operation.h2
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h2
-rw-r--r--mlir/include/mlir/IR/Region.h11
-rw-r--r--mlir/include/mlir/IR/SymbolTable.h12
-rw-r--r--mlir/include/mlir/IR/Value.h4
-rw-r--r--mlir/include/mlir/LLVMIR/LLVMDialect.h2
-rw-r--r--mlir/include/mlir/Pass/AnalysisManager.h18
-rw-r--r--mlir/include/mlir/Pass/Pass.h17
-rw-r--r--mlir/include/mlir/Pass/PassInstrumentation.h10
-rw-r--r--mlir/include/mlir/StandardOps/Ops.td4
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h4
-rw-r--r--mlir/include/mlir/Transforms/LowerAffine.h2
-rw-r--r--mlir/include/mlir/Transforms/ViewFunctionGraph.h4
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp2
-rw-r--r--mlir/lib/Analysis/Dominance.cpp9
-rw-r--r--mlir/lib/Analysis/OpStats.cpp2
-rw-r--r--mlir/lib/Analysis/TestParallelismDetection.cpp2
-rw-r--r--mlir/lib/Analysis/Verifier.cpp14
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp6
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp106
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp26
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp18
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp4
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp2
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/MemRefUtils.cpp10
-rw-r--r--mlir/lib/GPU/IR/GPUDialect.cpp18
-rw-r--r--mlir/lib/GPU/Transforms/KernelOutlining.cpp28
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp50
-rw-r--r--mlir/lib/IR/Attributes.cpp5
-rw-r--r--mlir/lib/IR/Block.cpp2
-rw-r--r--mlir/lib/IR/Builders.cpp4
-rw-r--r--mlir/lib/IR/Dialect.cpp15
-rw-r--r--mlir/lib/IR/Function.cpp59
-rw-r--r--mlir/lib/IR/Operation.cpp9
-rw-r--r--mlir/lib/IR/Region.cpp10
-rw-r--r--mlir/lib/IR/SymbolTable.cpp22
-rw-r--r--mlir/lib/IR/Value.cpp8
-rw-r--r--mlir/lib/LLVMIR/IR/LLVMDialect.cpp4
-rw-r--r--mlir/lib/Linalg/Transforms/Fusion.cpp2
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp69
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLoops.cpp3
-rw-r--r--mlir/lib/Linalg/Transforms/Tiling.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp28
-rw-r--r--mlir/lib/Pass/IRPrinting.cpp12
-rw-r--r--mlir/lib/Pass/Pass.cpp23
-rw-r--r--mlir/lib/Pass/PassDetail.h2
-rw-r--r--mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp2
-rw-r--r--mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp2
-rw-r--r--mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp2
-rw-r--r--mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp6
-rw-r--r--mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp2
-rw-r--r--mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp2
-rw-r--r--mlir/lib/StandardOps/Ops.cpp12
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp31
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp4
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp2
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp54
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp10
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp12
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp2
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp8
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp2
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp12
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp2
-rw-r--r--mlir/lib/Transforms/StripDebugInfo.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp4
-rw-r--r--mlir/lib/Transforms/ViewFunctionGraph.cpp4
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp150
-rw-r--r--mlir/test/lib/Transforms/TestVectorizationUtils.cpp16
-rw-r--r--mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp18
-rw-r--r--mlir/unittests/Pass/AnalysisManagerTest.cpp20
103 files changed, 986 insertions, 874 deletions
diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp
index 222ef52b9be..cdf4a7fe89c 100644
--- a/mlir/bindings/python/pybind.cpp
+++ b/mlir/bindings/python/pybind.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
@@ -110,13 +111,14 @@ struct PythonValueHandle {
struct PythonFunction {
PythonFunction() : function{nullptr} {}
PythonFunction(mlir_func_t f) : function{f} {}
- PythonFunction(mlir::Function *f) : function{f} {}
+ PythonFunction(mlir::Function f)
+ : function(const_cast<void *>(f.getAsOpaquePointer())) {}
operator mlir_func_t() { return function; }
std::string str() {
- mlir::Function *f = reinterpret_cast<mlir::Function *>(function);
+ mlir::Function f = mlir::Function::getFromOpaquePointer(function);
std::string res;
llvm::raw_string_ostream os(res);
- f->print(os);
+ f.print(os);
return res;
}
@@ -124,18 +126,18 @@ struct PythonFunction {
// declaration, add the entry block, transforming the declaration into a
// definition. Return true if the block was added, false otherwise.
bool define() {
- auto *f = reinterpret_cast<mlir::Function *>(function);
- if (!f->getBlocks().empty())
+ auto f = mlir::Function::getFromOpaquePointer(function);
+ if (!f.getBlocks().empty())
return false;
- f->addEntryBlock();
+ f.addEntryBlock();
return true;
}
PythonValueHandle arg(unsigned index) {
- Function *f = static_cast<Function *>(function);
- assert(index < f->getNumArguments() && "argument index out of bounds");
- return PythonValueHandle(ValueHandle(f->getArgument(index)));
+ auto f = mlir::Function::getFromOpaquePointer(function);
+ assert(index < f.getNumArguments() && "argument index out of bounds");
+ return PythonValueHandle(ValueHandle(f.getArgument(index)));
}
mlir_func_t function;
@@ -250,10 +252,9 @@ struct PythonFunctionContext {
PythonFunction enter() {
assert(function.function && "function is not set up");
- auto *mlirFunc = static_cast<mlir::Function *>(function.function);
- contextBuilder.emplace(mlirFunc->getBody());
- context =
- new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc());
+ auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function);
+ contextBuilder.emplace(mlirFunc.getBody());
+ context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
return function;
}
@@ -594,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
}
// Create the function itself.
- auto *func = new mlir::Function(
+ auto func = mlir::Function::create(
UnknownLoc::get(&mlirContext), name,
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
inputAttrs);
@@ -652,9 +653,9 @@ PYBIND11_MODULE(pybind, m) {
return ValueHandle::create<ConstantFloatOp>(value, floatType);
});
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
- auto *function = reinterpret_cast<Function *>(func.function);
- auto attr = FunctionAttr::get(function);
- return ValueHandle::create<ConstantOp>(function->getType(), attr);
+ auto function = Function::getFromOpaquePointer(func.function);
+ auto attr = FunctionAttr::get(function.getName(), function.getContext());
+ return ValueHandle::create<ConstantOp>(function.getType(), attr);
});
m.def("appendTo", [](const PythonBlockHandle &handle) {
return PythonBlockAppender(handle);
diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
index ddd6df9fb89..1f129c6b283 100644
--- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
+++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
@@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
}
/// A basic function builder
-inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
- llvm::ArrayRef<mlir::Type> types,
- llvm::ArrayRef<mlir::Type> resultTypes) {
+inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
+ llvm::ArrayRef<mlir::Type> types,
+ llvm::ArrayRef<mlir::Type> resultTypes) {
auto *context = module.getContext();
- auto *function = new mlir::Function(
+ auto function = mlir::Function::create(
mlir::UnknownLoc::get(context), name,
mlir::FunctionType::get({types}, resultTypes, context));
- function->addEntryBlock();
- module.getFunctions().push_back(function);
+ function.addEntryBlock();
+ module.push_back(function);
return function;
}
@@ -83,19 +83,19 @@ inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
/// llvm::outs() for FileCheck'ing.
/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
/// which will make the associated FileCheck test fail.
-inline void cleanupAndPrintFunction(mlir::Function *f) {
+inline void cleanupAndPrintFunction(mlir::Function f) {
bool printToOuts = true;
- auto check = [f, &printToOuts](mlir::LogicalResult result) {
+ auto check = [&f, &printToOuts](mlir::LogicalResult result) {
if (failed(result)) {
- f->emitError("Verification and cleanup passes failed");
+ f.emitError("Verification and cleanup passes failed");
printToOuts = false;
}
};
auto pm = cleanupPassManager();
- check(f->getModule()->verify());
- check(pm->run(f->getModule()));
+ check(f.getModule()->verify());
+ check(pm->run(f.getModule()));
if (printToOuts)
- f->print(llvm::outs());
+ f.print(llvm::outs());
}
/// Helper class to sugar building loop nests from indexings that appear in
diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp
index a415daebdf5..9534711f1f4 100644
--- a/mlir/examples/Linalg/Linalg2/Example.cpp
+++ b/mlir/examples/Linalg/Linalg2/Example.cpp
@@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
- mlir::Function *f =
+ mlir::Function f =
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
- ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
+ ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
@@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
- mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
- {indexType, indexType, indexType}, {});
+ mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
+ {indexType, indexType, indexType}, {});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
- ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
+ ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
@@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) {
// CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
// clang-format on
- f->walk<SliceOp>([](SliceOp slice) {
+ f.walk<SliceOp>([](SliceOp slice) {
auto *sliceResult = slice.getResult();
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
sliceResult->replaceAllUsesWith(viewOp.getResult());
diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp
index 37d1b51f53e..23d1cfef5dc 100644
--- a/mlir/examples/Linalg/Linalg3/Conversion.cpp
+++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp
@@ -37,26 +37,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
- mlir::Function *f = linalg::common::makeFunction(
+ mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
- M = dim(f->getArgument(0), 0),
- N = dim(f->getArgument(2), 1),
- K = dim(f->getArgument(0), 1),
+ M = dim(f.getArgument(0), 0),
+ N = dim(f.getArgument(2), 1),
+ K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
- vA = view(f->getArgument(0), {rM, rK}),
- vB = view(f->getArgument(1), {rK, rN}),
- vC = view(f->getArgument(2), {rM, rN});
+ vA = view(f.getArgument(0), {rM, rK}),
+ vB = view(f.getArgument(1), {rK, rN}),
+ vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(foo) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
convertLinalg3ToLLVM(module);
diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp
index f02aef920e4..8b04344b19e 100644
--- a/mlir/examples/Linalg/Linalg3/Example.cpp
+++ b/mlir/examples/Linalg/Linalg3/Example.cpp
@@ -34,26 +34,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
- mlir::Function *f = linalg::common::makeFunction(
+ mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- mlir::OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ mlir::OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
- M = dim(f->getArgument(0), 0),
- N = dim(f->getArgument(2), 1),
- K = dim(f->getArgument(0), 1),
+ M = dim(f.getArgument(0), 0),
+ N = dim(f.getArgument(2), 1),
+ K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
- vA = view(f->getArgument(0), {rM, rK}),
- vB = view(f->getArgument(1), {rK, rN}),
- vC = view(f->getArgument(2), {rM, rN});
+ vA = view(f.getArgument(0), {rM, rK}),
+ vB = view(f.getArgument(1), {rK, rN}),
+ vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(matmul_as_matvec) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
// clang-format off
@@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) {
TEST_FUNC(matmul_as_dot) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
lowerToFinerGrainedTensorContraction(f);
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
@@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) {
TEST_FUNC(matmul_as_loops) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
composeSliceOps(f);
// clang-format off
@@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) {
TEST_FUNC(matmul_as_matvec_as_loops) {
MLIRContext context;
Module module(&context);
- mlir::Function *f =
+ mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
lowerToFinerGrainedTensorContraction(f);
lowerToLoops(f);
@@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
TEST_FUNC(matmul_as_matvec_as_affine) {
MLIRContext context;
Module module(&context);
- mlir::Function *f =
+ mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
lowerToLoops(f);
PassManager pm;
pm.addPass(createLowerLinalgLoadStorePass());
- if (succeeded(pm.run(f->getModule())))
+ if (succeeded(pm.run(f.getModule())))
cleanupAndPrintFunction(f);
// clang-format off
diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp
index 00d571cbc99..94b233a56b0 100644
--- a/mlir/examples/Linalg/Linalg3/Execution.cpp
+++ b/mlir/examples/Linalg/Linalg3/Execution.cpp
@@ -37,26 +37,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
- mlir::Function *f = linalg::common::makeFunction(
+ mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- mlir::OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ mlir::OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
- M = dim(f->getArgument(0), 0),
- N = dim(f->getArgument(2), 1),
- K = dim(f->getArgument(0), 1),
+ M = dim(f.getArgument(0), 0),
+ N = dim(f.getArgument(2), 1),
+ K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
- vA = view(f->getArgument(0), {rM, rK}),
- vB = view(f->getArgument(1), {rK, rN}),
- vC = view(f->getArgument(2), {rM, rN});
+ vA = view(f.getArgument(0), {rM, rK}),
+ vB = view(f.getArgument(1), {rK, rN}),
+ vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@@ -110,7 +110,7 @@ TEST_FUNC(execution) {
// dialect through partial conversions.
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
convertLinalg3ToLLVM(module);
diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h
index 9af528e8c51..6c0aec0b000 100644
--- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h
+++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h
@@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
/// to only use linalg.view operations.
-void composeSliceOps(mlir::Function *f);
+void composeSliceOps(mlir::Function f);
/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec)
/// as linalg.matvec(resp. linalg.dot).
-void lowerToFinerGrainedTensorContraction(mlir::Function *f);
+void lowerToFinerGrainedTensorContraction(mlir::Function f);
/// Operation-wise writing of linalg operations to loop form.
/// It is the caller's responsibility to erase the `op` if necessary.
@@ -69,7 +69,7 @@ llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
writeAsLoops(mlir::Operation *op);
/// Traverses `f` and rewrites linalg operations in loop form.
-void lowerToLoops(mlir::Function *f);
+void lowerToLoops(mlir::Function f);
/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and
/// affine.store operations.
diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
index 7b559bf2f21..96b0f371ef1 100644
--- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
@@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
void linalg::convertLinalg3ToLLVM(Module &module) {
// Remove affine constructs.
- for (auto &func : module) {
+ for (auto func : module) {
auto rr = lowerAffineConstructs(func);
(void)rr;
assert(succeeded(rr) && "affine loop lowering failed");
diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
index d5c8641acbe..7b9e5ffee96 100644
--- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
-void linalg::composeSliceOps(mlir::Function *f) {
- f->walk<SliceOp>([](SliceOp sliceOp) {
+void linalg::composeSliceOps(mlir::Function f) {
+ f.walk<SliceOp>([](SliceOp sliceOp) {
auto *sliceResult = sliceOp.getResult();
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
sliceResult->replaceAllUsesWith(viewOp.getResult());
@@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) {
});
}
-void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
- f->walk([](Operation *op) {
+void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) {
+ f.walk([](Operation *op) {
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
matmulOp.writeAsFinerGrainTensorContraction();
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
@@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) {
return llvm::None;
}
-void linalg::lowerToLoops(mlir::Function *f) {
- f->walk([](Operation *op) {
+void linalg::lowerToLoops(mlir::Function f) {
+ f.walk([](Operation *op) {
if (writeAsLoops(op))
op->erase();
});
diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp
index cdc05a1cc21..873e57e78f3 100644
--- a/mlir/examples/Linalg/Linalg4/Example.cpp
+++ b/mlir/examples/Linalg/Linalg4/Example.cpp
@@ -34,27 +34,27 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
- mlir::Function *f = linalg::common::makeFunction(
+ mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
- M = dim(f->getArgument(0), 0),
- N = dim(f->getArgument(2), 1),
- K = dim(f->getArgument(0), 1),
+ M = dim(f.getArgument(0), 0),
+ N = dim(f.getArgument(2), 1),
+ K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
- vA = view(f->getArgument(0), {rM, rK}),
- vB = view(f->getArgument(1), {rK, rN}),
- vC = view(f->getArgument(2), {rM, rN});
+ vA = view(f.getArgument(0), {rM, rK}),
+ vB = view(f.getArgument(1), {rK, rN}),
+ vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(matmul_tiled_loops) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
lowerToTiledLoops(f, {8, 9});
PassManager pm;
pm.addPass(createLowerLinalgLoadStorePass());
- if (succeeded(pm.run(f->getModule())))
+ if (succeeded(pm.run(f.getModule())))
cleanupAndPrintFunction(f);
// clang-format off
@@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) {
TEST_FUNC(matmul_tiled_views) {
MLIRContext context;
Module module(&context);
- mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
- OpBuilder b(f->getBody());
- lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
- b.create<ConstantIndexOp>(f->getLoc(), 9)});
+ mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
+ OpBuilder b(f.getBody());
+ lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
+ b.create<ConstantIndexOp>(f.getLoc(), 9)});
composeSliceOps(f);
// clang-format off
@@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) {
TEST_FUNC(matmul_tiled_views_as_loops) {
MLIRContext context;
Module module(&context);
- mlir::Function *f =
+ mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
- OpBuilder b(f->getBody());
- lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
- b.create<ConstantIndexOp>(f->getLoc(), 9)});
+ OpBuilder b(f.getBody());
+ lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
+ b.create<ConstantIndexOp>(f.getLoc(), 9)});
composeSliceOps(f);
lowerToLoops(f);
// This cannot lower below linalg.load and linalg.store due to lost
diff --git a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h
index 2165cab6ac1..ba7273e409d 100644
--- a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h
+++ b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h
@@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes);
/// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
-void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef<uint64_t> tileSizes);
+void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef<uint64_t> tileSizes);
/// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
-void lowerToTiledViews(mlir::Function *f,
+void lowerToTiledViews(mlir::Function f,
llvm::ArrayRef<mlir::Value *> tileSizes);
} // namespace linalg
diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
index 1a308df1313..16b395da506 100644
--- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp
@@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) {
return llvm::None;
}
-void linalg::lowerToTiledLoops(mlir::Function *f,
- ArrayRef<uint64_t> tileSizes) {
- f->walk([tileSizes](Operation *op) {
+void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef<uint64_t> tileSizes) {
+ f.walk([tileSizes](Operation *op) {
if (writeAsTiledLoops(op, tileSizes).hasValue())
op->erase();
});
@@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) {
return llvm::None;
}
-void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) {
- f->walk([tileSizes](Operation *op) {
+void linalg::lowerToTiledViews(mlir::Function f, ArrayRef<Value *> tileSizes) {
+ f.walk([tileSizes](Operation *op) {
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
writeAsTiledViews(matmulOp, tileSizes);
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
index 842c7a1d0f8..73789fa41a4 100644
--- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
@@ -75,7 +75,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->getFunctions().push_back(func.release());
+ theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@@ -129,40 +129,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
- mlir::Function *mlirGen(PrototypeAST &proto) {
+ mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
+ func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
- if (function->getNumArguments())
- function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+ if (function.getNumArguments())
+ function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
- std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+ mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
- std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+ mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
- function->addEntryBlock();
+ function.addEntryBlock();
- auto &entryBlock = function->front();
+ auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@@ -172,16 +172,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
- builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
+ builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
- if (!mlirGen(*funcAST.getBody()))
+ if (!mlirGen(*funcAST.getBody())) {
+ function.erase();
return nullptr;
+ }
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function->getBlocks().back().back().getName().getStringRef() !=
+ if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
index e365f37f8c8..23cb85309c2 100644
--- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
@@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->getFunctions().push_back(func.release());
+ theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
- mlir::Function *mlirGen(PrototypeAST &proto) {
+ mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
+ func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
- if (function->getNumArguments())
- function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+ if (function.getNumArguments())
+ function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
- std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+ mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
- std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+ mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
- function->addEntryBlock();
+ function.addEntryBlock();
- auto &entryBlock = function->front();
+ auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
- builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
+ builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
- if (!mlirGen(*funcAST.getBody()))
+ if (!mlirGen(*funcAST.getBody())) {
+ function.erase();
return nullptr;
+ }
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function->getBlocks().back().back().getName().getStringRef() !=
+ if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
index 032766a547f..f2132c29c33 100644
--- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
@@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->getFunctions().push_back(func.release());
+ theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
- mlir::Function *mlirGen(PrototypeAST &proto) {
+ mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
+ func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
- if (function->getNumArguments())
- function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+ if (function.getNumArguments())
+ function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
- std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+ mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
- std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+ mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
- function->addEntryBlock();
+ function.addEntryBlock();
- auto &entryBlock = function->front();
+ auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
- builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
+ builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
- if (!mlirGen(*funcAST.getBody()))
+ if (!mlirGen(*funcAST.getBody())) {
+ function.erase();
return nullptr;
+ }
// Implicitly return void if no return statement was emited.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function->getBlocks().back().back().getName().getStringRef() !=
+ if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 688c73645a5..f237fd9fb53 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -113,14 +113,14 @@ public:
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
- mlir::Function *function;
+ mlir::Function function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
void runOnModule() override {
auto &module = getModule();
- auto *main = module.getNamedFunction("main");
+ auto main = module.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@@ -139,7 +139,7 @@ public:
// Delete any generic function left
// FIXME: we may want this as a separate pass.
- for (mlir::Function &function : llvm::make_early_inc_range(module)) {
+ for (mlir::Function function : llvm::make_early_inc_range(module)) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
@@ -153,7 +153,7 @@ public:
mlir::LogicalResult
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
- mlir::Function *f = functionToSpecialize.function;
+ mlir::Function f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
@@ -169,36 +169,36 @@ public:
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
- auto *newFunction = new mlir::Function(
- f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
- getModule().getFunctions().push_back(newFunction);
+ auto newFunction = mlir::Function::create(
+ f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
+ getModule().push_back(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
- f->cloneInto(newFunction, mapper);
+ f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
- f->dump();
+ f.dump();
llvm::dbgs() << "====== Into : \n";
- newFunction->dump();
+ newFunction.dump();
});
f = newFunction;
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
- auto &entryBlock = f->getBlocks().front();
+ auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
- assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
- entryBlock.addArguments(f->getType().getInputs());
+ assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
+ entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
- assert(succeeded(f->verify()));
+ assert(succeeded(f.verify()));
}
LLVM_DEBUG(llvm::dbgs()
- << "Run shape inference on : '" << f->getName() << "'\n");
+ << "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
@@ -211,7 +211,7 @@ public:
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
- f->walk([&](mlir::Operation *op) {
+ f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
@@ -292,9 +292,9 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
- auto *callee = getModule().getNamedFunction(calleeName);
+ auto callee = getModule().getNamedFunction(calleeName);
if (!callee) {
- f->emitError("Shape inference failed, call to unknown '")
+ f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
signalPassFailure();
return mlir::failure();
@@ -302,7 +302,7 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
- auto *mangledCallee = getModule().getNamedFunction(mangledName);
+ auto mangledCallee = getModule().getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
@@ -327,7 +327,7 @@ public:
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
@@ -337,31 +337,31 @@ public:
<< " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
errorMsg << " - " << *ope << "\n";
- f->emitError(errorMsg.str());
+ f.emitError(errorMsg.str());
signalPassFailure();
return mlir::failure();
}
// Finally, update the return type of the function based on the argument to
// the return operation.
- for (auto &block : f->getBlocks()) {
+ for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
- f->getType().getResult(0) == ret.getOperand()->getType())
+ f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
- for (auto arg : f->getArguments())
+ for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
- f->setType(newType);
- assert(succeeded(f->verify()));
+ f.setType(newType);
+ assert(succeeded(f.verify()));
break;
}
return mlir::success();
diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
index 8b2a3927d78..60a8b5a3b9a 100644
--- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
@@ -136,14 +136,14 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
- Function *printfFunc = getPrintf(*op->getFunction()->getModule());
+ Function printfFunc = getPrintf(*op->getFunction().getModule());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
- Type retTy = printfFunc->getType().getResult(0);
+ Type retTy = printfFunc.getType().getResult(0);
// Create our loop nest now
using namespace edsc;
@@ -205,8 +205,8 @@ private:
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
- Function *getPrintf(Module &module) const {
- auto *printfFunc = module.getNamedFunction("printf");
+ Function getPrintf(Module &module) const {
+ auto printfFunc = module.getNamedFunction("printf");
if (printfFunc)
return printfFunc;
@@ -218,10 +218,10 @@ private:
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
- printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
+ printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
- printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
- module.getFunctions().push_back(printfFunc);
+ printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
+ module.push_back(printfFunc);
return printfFunc;
}
};
@@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
// affine dialect: they already include conversion to the LLVM dialect.
// First patch calls type to return memref instead of ToyArray
- for (auto &function : getModule()) {
+ for (auto function : getModule()) {
function.walk([&](Operation *op) {
auto callOp = dyn_cast<CallOp>(op);
if (!callOp)
@@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
});
}
- for (auto &function : getModule()) {
+ for (auto function : getModule()) {
function.walk([&](Operation *op) {
// Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
@@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
}
// Lower Linalg to affine
- for (auto &function : getModule())
- linalg::lowerToLoops(&function);
+ for (auto function : getModule())
+ linalg::lowerToLoops(function);
getModule().dump();
diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
index f7e6fad568e..9ebfeb438ca 100644
--- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
@@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->getFunctions().push_back(func.release());
+ theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
- mlir::Function *mlirGen(PrototypeAST &proto) {
+ mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
+ func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
- if (function->getNumArguments())
- function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+ if (function.getNumArguments())
+ function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
- std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+ mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
- std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+ mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
- function->addEntryBlock();
+ function.addEntryBlock();
- auto &entryBlock = function->front();
+ auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
- builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
+ builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
- if (!mlirGen(*funcAST.getBody()))
+ if (!mlirGen(*funcAST.getBody())) {
+ function.erase();
return nullptr;
+ }
// Implicitly return void if no return statement was emited.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function->getBlocks().back().back().getName().getStringRef() !=
+ if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index cad2deda57e..0abcb4bb850 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -113,7 +113,7 @@ public:
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
- mlir::Function *function;
+ mlir::Function function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
@@ -121,7 +121,7 @@ public:
void runOnModule() override {
auto &module = getModule();
mlir::ModuleManager moduleManager(&module);
- auto *main = moduleManager.getNamedFunction("main");
+ auto main = moduleManager.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@@ -140,7 +140,7 @@ public:
// Delete any generic function left
// FIXME: we may want this as a separate pass.
- for (mlir::Function &function : llvm::make_early_inc_range(module)) {
+ for (mlir::Function function : llvm::make_early_inc_range(module)) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
@@ -155,7 +155,7 @@ public:
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
mlir::ModuleManager &moduleManager) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
- mlir::Function *f = functionToSpecialize.function;
+ mlir::Function f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
@@ -171,36 +171,36 @@ public:
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
- auto *newFunction = new mlir::Function(
- f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
+ auto newFunction = mlir::Function::create(
+ f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
moduleManager.insert(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
- f->cloneInto(newFunction, mapper);
+ f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
- f->dump();
+ f.dump();
llvm::dbgs() << "====== Into : \n";
- newFunction->dump();
+ newFunction.dump();
});
f = newFunction;
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
- auto &entryBlock = f->getBlocks().front();
+ auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
- assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
- entryBlock.addArguments(f->getType().getInputs());
+ assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
+ entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
- assert(succeeded(f->verify()));
+ assert(succeeded(f.verify()));
}
LLVM_DEBUG(llvm::dbgs()
- << "Run shape inference on : '" << f->getName() << "'\n");
+ << "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
@@ -212,7 +212,7 @@ public:
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
- f->walk([&](mlir::Operation *op) {
+ f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
@@ -295,16 +295,16 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
- auto *callee = moduleManager.getNamedFunction(calleeName);
+ auto callee = moduleManager.getNamedFunction(calleeName);
if (!callee) {
signalPassFailure();
- return f->emitError("Shape inference failed, call to unknown '")
+ return f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
}
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
- auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
+ auto mangledCallee = moduleManager.getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
@@ -315,7 +315,7 @@ public:
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands(op->getOperands());
- mlir::OpBuilder builder(f->getBody());
+ mlir::OpBuilder builder(f.getBody());
builder.setInsertionPoint(op);
auto newCall =
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
@@ -330,12 +330,12 @@ public:
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
signalPassFailure();
- auto diag = f->emitError("Shape inference failed, ")
+ auto diag = f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
diag << " - " << *ope << "\n";
@@ -344,24 +344,24 @@ public:
// Finally, update the return type of the function based on the argument to
// the return operation.
- for (auto &block : f->getBlocks()) {
+ for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
- f->getType().getResult(0) == ret.getOperand()->getType())
+ f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
- for (auto arg : f->getArguments())
+ for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
- f->setType(newType);
- assert(succeeded(f->verify()));
+ f.setType(newType);
+ assert(succeeded(f.verify()));
break;
}
return mlir::success();
diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h
index e69756e73f4..8d7b2d59afe 100644
--- a/mlir/include/mlir/Analysis/Dominance.h
+++ b/mlir/include/mlir/Analysis/Dominance.h
@@ -34,7 +34,7 @@ template <bool IsPostDom> class DominanceInfoBase {
using base = llvm::DominatorTreeBase<Block, IsPostDom>;
public:
- DominanceInfoBase(Function *function) { recalculate(function); }
+ DominanceInfoBase(Function function) { recalculate(function); }
DominanceInfoBase(Operation *op) { recalculate(op); }
DominanceInfoBase(DominanceInfoBase &&) = default;
DominanceInfoBase &operator=(DominanceInfoBase &&) = default;
@@ -43,7 +43,7 @@ public:
DominanceInfoBase &operator=(const DominanceInfoBase &) = delete;
/// Recalculate the dominance info.
- void recalculate(Function *function);
+ void recalculate(Function function);
void recalculate(Operation *op);
/// Get the root dominance node of the given region.
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index 3ab24f84640..b89011a28e3 100644
--- a/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -104,8 +104,8 @@ struct NestedPattern {
NestedPattern &operator=(const NestedPattern &) = default;
/// Returns all the top-level matches in `func`.
- void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
- func->walk([&](Operation *op) { matchOne(op, matches); });
+ void match(Function func, SmallVectorImpl<NestedMatch> *matches) {
+ func.walk([&](Operation *op) { matchOne(op, matches); });
}
/// Returns all the top-level matches in `op`.
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index a2d982d299b..3d20eaff46c 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -44,7 +44,7 @@ struct StaticFloatMemRef {
/// each of the arguments, initialize the storage with `initialValue`, and
/// return a list of type-erased descriptor pointers.
llvm::Expected<SmallVector<void *, 8>>
-allocateMemRefArguments(Function *func, float initialValue = 0.0);
+allocateMemRefArguments(Function func, float initialValue = 0.0);
/// Free a list of type-erased descriptors to statically-shaped memrefs with
/// element type f32.
diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h
index 8f682ce7c2e..c0326deb7cd 100644
--- a/mlir/include/mlir/GPU/GPUDialect.h
+++ b/mlir/include/mlir/GPU/GPUDialect.h
@@ -44,7 +44,7 @@ public:
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
- static bool isKernel(Function *function);
+ static bool isKernel(Function function);
};
/// Utility class for the GPU dialect to represent triples of `Value`s
@@ -122,12 +122,12 @@ public:
using Op::Op;
static void build(Builder *builder, OperationState *result,
- Function *kernelFunc, Value *gridSizeX, Value *gridSizeY,
+ Function kernelFunc, Value *gridSizeX, Value *gridSizeY,
Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
Value *blockSizeZ, ArrayRef<Value *> kernelOperands);
static void build(Builder *builder, OperationState *result,
- Function *kernelFunc, KernelDim3 gridSize,
+ Function kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
/// The kernel function specified by the operation's `kernel` attribute.
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 5b9bfca35ad..b46e160174b 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -313,9 +313,8 @@ class FunctionAttr
detail::StringAttributeStorage> {
public:
using Base::Base;
- using ValueType = Function *;
+ using ValueType = StringRef;
- static FunctionAttr get(Function *value);
static FunctionAttr get(StringRef value, MLIRContext *ctx);
/// Returns the name of the held function reference.
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index f4ecb4ec6d7..feae5c93fea 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -101,7 +101,7 @@ public:
/// Returns the function that this block is part of, even if the block is
/// nested under an operation region.
- Function *getFunction();
+ Function getFunction();
/// Insert this block (which must not already be in a function) right before
/// the specified block.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 6ce5c22eadc..e5c8c035c46 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -112,7 +112,7 @@ public:
AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
- FunctionAttr getFunctionAttr(Function *value);
+ FunctionAttr getFunctionAttr(Function value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 56d06619c79..4e82689efff 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -145,17 +145,13 @@ public:
/// Verify an attribute from this dialect on the given function. Returns
/// failure if the verification failed, success otherwise.
- virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) {
- return success();
- }
+ virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute);
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the given function. Returns failure if the verification failed, success
/// otherwise.
- virtual LogicalResult
- verifyFunctionArgAttribute(Function *, unsigned argIndex, NamedAttribute) {
- return success();
- }
+ virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex,
+ NamedAttribute);
/// Verify an attribute from this dialect on the given operation. Returns
/// failure if the verification failed, success otherwise.
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 8f3b3b0df13..e11a45ba033 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -29,29 +29,79 @@
namespace mlir {
class BlockAndValueMapping;
class FunctionType;
+class Function;
class MLIRContext;
class Module;
-/// This is the base class for all of the MLIR function types.
-class Function : public llvm::ilist_node_with_parent<Function, Module> {
+namespace detail {
+/// This class represents all of the internal state of a Function. This allows
+/// for the Function class to be value typed.
+class FunctionStorage
+ : public llvm::ilist_node_with_parent<FunctionStorage, Module> {
+ FunctionStorage(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs = {});
+ FunctionStorage(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<NamedAttributeList> argAttrs);
+ /// The name of the function.
+ Identifier name;
+
+ /// The module this function is embedded into.
+ Module *module = nullptr;
+
+ /// The source location the function was defined or derived from.
+ Location location;
+
+ /// The type of the function.
+ FunctionType type;
+
+ /// This holds general named attributes for the function.
+ NamedAttributeList attrs;
+
+ /// The attributes lists for each of the function arguments.
+ std::vector<NamedAttributeList> argAttrs;
+
+ /// The body of the function.
+ Region body;
+
+ friend struct llvm::ilist_traits<FunctionStorage>;
+ friend Function;
+};
+} // namespace detail
+
+/// This class represents an MLIR function, or the common unit of computation.
+/// The region of a function is not allowed to implicitly capture global values,
+/// and all external references must use Function arguments or attributes.
+class Function {
public:
- Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs = {});
- Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs,
- ArrayRef<NamedAttributeList> argAttrs);
+ Function(detail::FunctionStorage *impl = nullptr) : impl(impl) {}
+
+ static Function create(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs = {}) {
+ return new detail::FunctionStorage(location, name, type, attrs);
+ }
+ static Function create(Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<NamedAttributeList> argAttrs) {
+ return new detail::FunctionStorage(location, name, type, attrs, argAttrs);
+ }
+
+ /// Allow converting a Function to bool for null checks.
+ operator bool() const { return impl; }
+ bool operator==(Function other) const { return impl == other.impl; }
+ bool operator!=(Function other) const { return !(*this == other); }
/// The source location the function was defined or derived from.
- Location getLoc() { return location; }
+ Location getLoc() { return impl->location; }
/// Set the source location this function was defined or derived from.
- void setLoc(Location loc) { location = loc; }
+ void setLoc(Location loc) { impl->location = loc; }
/// Return the name of this function, without the @.
- Identifier getName() { return name; }
+ Identifier getName() { return impl->name; }
/// Return the type of this function.
- FunctionType getType() { return type; }
+ FunctionType getType() { return impl->type; }
/// Change the type of this function in place. This is an extremely dangerous
/// operation and it is up to the caller to ensure that this is legal for this
@@ -61,12 +111,12 @@ public:
/// parameters we drop the extra attributes, if there are more parameters
/// they won't have any attributes.
void setType(FunctionType newType) {
- type = newType;
- argAttrs.resize(type.getNumInputs());
+ impl->type = newType;
+ impl->argAttrs.resize(newType.getNumInputs());
}
MLIRContext *getContext();
- Module *getModule() { return module; }
+ Module *getModule() { return impl->module; }
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
@@ -82,28 +132,28 @@ public:
// Body Handling
//===--------------------------------------------------------------------===//
- Region &getBody() { return body; }
- void eraseBody() { body.getBlocks().clear(); }
+ Region &getBody() { return impl->body; }
+ void eraseBody() { getBody().getBlocks().clear(); }
/// This is the list of blocks in the function.
using RegionType = Region::RegionType;
- RegionType &getBlocks() { return body.getBlocks(); }
+ RegionType &getBlocks() { return getBody().getBlocks(); }
// Iteration over the block in the function.
using iterator = RegionType::iterator;
using reverse_iterator = RegionType::reverse_iterator;
- iterator begin() { return body.begin(); }
- iterator end() { return body.end(); }
- reverse_iterator rbegin() { return body.rbegin(); }
- reverse_iterator rend() { return body.rend(); }
+ iterator begin() { return getBody().begin(); }
+ iterator end() { return getBody().end(); }
+ reverse_iterator rbegin() { return getBody().rbegin(); }
+ reverse_iterator rend() { return getBody().rend(); }
- bool empty() { return body.empty(); }
- void push_back(Block *block) { body.push_back(block); }
- void push_front(Block *block) { body.push_front(block); }
+ bool empty() { return getBody().empty(); }
+ void push_back(Block *block) { getBody().push_back(block); }
+ void push_front(Block *block) { getBody().push_front(block); }
- Block &back() { return body.back(); }
- Block &front() { return body.front(); }
+ Block &back() { return getBody().back(); }
+ Block &front() { return getBody().front(); }
//===--------------------------------------------------------------------===//
// Operation Walkers
@@ -150,53 +200,55 @@ public:
/// the lifetime of an function.
/// Return all of the attributes on this function.
- ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
+ ArrayRef<NamedAttribute> getAttrs() { return impl->attrs.getAttrs(); }
/// Return the internal attribute list on this function.
- NamedAttributeList &getAttrList() { return attrs; }
+ NamedAttributeList &getAttrList() { return impl->attrs; }
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
assert(index < getNumArguments() && "invalid argument number");
- return argAttrs[index].getAttrs();
+ return impl->argAttrs[index].getAttrs();
}
/// Set the attributes held by this function.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
- attrs.setAttrs(attributes);
+ impl->attrs.setAttrs(attributes);
}
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
- argAttrs[index].setAttrs(attributes);
+ impl->argAttrs[index].setAttrs(attributes);
}
void setArgAttrs(unsigned index, NamedAttributeList attributes) {
assert(index < getNumArguments() && "invalid argument number");
- argAttrs[index] = attributes;
+ impl->argAttrs[index] = attributes;
}
void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
assert(attributes.size() == getNumArguments());
for (unsigned i = 0, e = attributes.size(); i != e; ++i)
- argAttrs[i] = attributes[i];
+ impl->argAttrs[i] = attributes[i];
}
/// Return all argument attributes of this function.
- MutableArrayRef<NamedAttributeList> getAllArgAttrs() { return argAttrs; }
+ MutableArrayRef<NamedAttributeList> getAllArgAttrs() {
+ return impl->argAttrs;
+ }
/// Return the specified attribute if present, null otherwise.
- Attribute getAttr(Identifier name) { return attrs.get(name); }
- Attribute getAttr(StringRef name) { return attrs.get(name); }
+ Attribute getAttr(Identifier name) { return impl->attrs.get(name); }
+ Attribute getAttr(StringRef name) { return impl->attrs.get(name); }
/// Return the specified attribute, if present, for the argument at 'index',
/// null otherwise.
Attribute getArgAttr(unsigned index, Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
- return argAttrs[index].get(name);
+ return impl->argAttrs[index].get(name);
}
Attribute getArgAttr(unsigned index, StringRef name) {
assert(index < getNumArguments() && "invalid argument number");
- return argAttrs[index].get(name);
+ return impl->argAttrs[index].get(name);
}
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
@@ -219,13 +271,15 @@ public:
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
- void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
+ void setAttr(Identifier name, Attribute value) {
+ impl->attrs.set(name, value);
+ }
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
void setArgAttr(unsigned index, Identifier name, Attribute value) {
assert(index < getNumArguments() && "invalid argument number");
- argAttrs[index].set(name, value);
+ impl->argAttrs[index].set(name, value);
}
void setArgAttr(unsigned index, StringRef name, Attribute value) {
setArgAttr(index, Identifier::get(name, getContext()), value);
@@ -234,12 +288,12 @@ public:
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
- return attrs.remove(name);
+ return impl->attrs.remove(name);
}
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
- return attrs.remove(name);
+ return impl->attrs.remove(name);
}
//===--------------------------------------------------------------------===//
@@ -281,44 +335,37 @@ public:
/// contains entries for function arguments, these arguments are not included
/// in the new function. Replaces references to cloned sub-values with the
/// corresponding value that is copied, and adds those mappings to the mapper.
- Function *clone(BlockAndValueMapping &mapper);
- Function *clone();
+ Function clone(BlockAndValueMapping &mapper);
+ Function clone();
/// Clone the internal blocks and attributes from this function into dest. Any
/// cloned blocks are appended to the back of dest. This function asserts that
/// the attributes of the current function and dest are compatible.
- void cloneInto(Function *dest, BlockAndValueMapping &mapper);
+ void cloneInto(Function dest, BlockAndValueMapping &mapper);
+
+ /// Methods for supporting PointerLikeTypeTraits.
+ const void *getAsOpaquePointer() const {
+ return static_cast<const void *>(impl);
+ }
+ static Function getFromOpaquePointer(const void *pointer) {
+ return reinterpret_cast<detail::FunctionStorage *>(
+ const_cast<void *>(pointer));
+ }
private:
/// Set the name of this function.
- void setName(Identifier newName) { name = newName; }
-
- /// The name of the function.
- Identifier name;
-
- /// The module this function is embedded into.
- Module *module = nullptr;
-
- /// The source location the function was defined or derived from.
- Location location;
-
- /// The type of the function.
- FunctionType type;
-
- /// This holds general named attributes for the function.
- NamedAttributeList attrs;
+ void setName(Identifier newName) { impl->name = newName; }
- /// The attributes lists for each of the function arguments.
- std::vector<NamedAttributeList> argAttrs;
-
- /// The body of the function.
- Region body;
-
- void operator=(Function &) = delete;
- friend struct llvm::ilist_traits<Function>;
+ /// A pointer to the impl storage instance for this function. This allows for
+ /// 'Function' to be treated as a value type.
+ detail::FunctionStorage *impl = nullptr;
// Allow access to 'setName'.
friend class SymbolTable;
+
+ // Allow access to 'impl'.
+ friend class Module;
+ friend class Region;
};
//===--------------------------------------------------------------------===//
@@ -487,21 +534,52 @@ private:
namespace llvm {
template <>
-struct ilist_traits<::mlir::Function>
- : public ilist_alloc_traits<::mlir::Function> {
- using Function = ::mlir::Function;
- using function_iterator = simple_ilist<Function>::iterator;
+struct ilist_traits<::mlir::detail::FunctionStorage>
+ : public ilist_alloc_traits<::mlir::detail::FunctionStorage> {
+ using FunctionStorage = ::mlir::detail::FunctionStorage;
+ using function_iterator = simple_ilist<FunctionStorage>::iterator;
- static void deleteNode(Function *function) { delete function; }
+ static void deleteNode(FunctionStorage *function) { delete function; }
- void addNodeToList(Function *function);
- void removeNodeFromList(Function *function);
- void transferNodesFromList(ilist_traits<Function> &otherList,
+ void addNodeToList(FunctionStorage *function);
+ void removeNodeFromList(FunctionStorage *function);
+ void transferNodesFromList(ilist_traits<FunctionStorage> &otherList,
function_iterator first, function_iterator last);
private:
mlir::Module *getContainingModule();
};
-} // end namespace llvm
+
+// Functions hash just like pointers.
+template <> struct DenseMapInfo<mlir::Function> {
+ static mlir::Function getEmptyKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::Function::getFromOpaquePointer(pointer);
+ }
+ static mlir::Function getTombstoneKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::Function::getFromOpaquePointer(pointer);
+ }
+ static unsigned getHashValue(mlir::Function val) {
+ return hash_value(val.getAsOpaquePointer());
+ }
+ static bool isEqual(mlir::Function LHS, mlir::Function RHS) {
+ return LHS == RHS;
+ }
+};
+
+/// Allow stealing the low bits of FunctionStorage.
+template <> struct PointerLikeTypeTraits<mlir::Function> {
+public:
+ static inline void *getAsVoidPointer(mlir::Function I) {
+ return const_cast<void *>(I.getAsOpaquePointer());
+ }
+ static inline mlir::Function getFromVoidPointer(void *P) {
+ return mlir::Function::getFromOpaquePointer(P);
+ }
+ enum { NumLowBitsAvailable = 3 };
+};
+
+} // namespace llvm
#endif // MLIR_IR_FUNCTION_H
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
index 8161a305fb5..d8a47891ace 100644
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -34,34 +34,54 @@ public:
MLIRContext *getContext() { return context; }
+ /// An iterator class used to iterate over the held functions.
+ class iterator : public llvm::mapped_iterator<
+ llvm::iplist<detail::FunctionStorage>::iterator,
+ Function (*)(detail::FunctionStorage &)> {
+ static Function unwrap(detail::FunctionStorage &impl) { return &impl; }
+
+ public:
+ using reference = Function;
+
+ /// Initializes the operand type iterator to the specified operand iterator.
+ iterator(llvm::iplist<detail::FunctionStorage>::iterator it)
+ : llvm::mapped_iterator<llvm::iplist<detail::FunctionStorage>::iterator,
+ Function (*)(detail::FunctionStorage &)>(
+ it, &unwrap) {}
+ iterator(Function it)
+ : iterator(llvm::iplist<detail::FunctionStorage>::iterator(it.impl)) {}
+ };
+
/// This is the list of functions in the module.
- using FunctionListType = llvm::iplist<Function>;
- FunctionListType &getFunctions() { return functions; }
+ llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
// Iteration over the functions in the module.
- using iterator = FunctionListType::iterator;
- using reverse_iterator = FunctionListType::reverse_iterator;
-
iterator begin() { return functions.begin(); }
iterator end() { return functions.end(); }
- reverse_iterator rbegin() { return functions.rbegin(); }
- reverse_iterator rend() { return functions.rend(); }
+ Function front() { return &functions.front(); }
+ Function back() { return &functions.back(); }
+
+ void push_back(Function fn) { functions.push_back(fn.impl); }
+ void insert(iterator insertPt, Function fn) {
+ functions.insert(insertPt.getCurrent(), fn.impl);
+ }
// Interfaces for working with the symbol table.
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
- Function *getNamedFunction(StringRef name) {
+ Function getNamedFunction(StringRef name) {
return getNamedFunction(Identifier::get(name, getContext()));
}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
- Function *getNamedFunction(Identifier name) {
- auto it = llvm::find_if(
- functions, [name](Function &fn) { return fn.getName() == name; });
+ Function getNamedFunction(Identifier name) {
+ auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
+ return Function(&fn).getName() == name;
+ });
return it == functions.end() ? nullptr : &*it;
}
@@ -74,11 +94,13 @@ public:
void dump();
private:
- friend struct llvm::ilist_traits<Function>;
- friend class Function;
+ friend struct llvm::ilist_traits<detail::FunctionStorage>;
+ friend detail::FunctionStorage;
+ friend Function;
/// getSublistAccess() - Returns pointer to member of function list
- static FunctionListType Module::*getSublistAccess(Function *) {
+ static llvm::iplist<detail::FunctionStorage> Module::*
+ getSublistAccess(detail::FunctionStorage *) {
return &Module::functions;
}
@@ -86,7 +108,7 @@ private:
MLIRContext *context;
/// This is the actual list of functions the module contains.
- FunctionListType functions;
+ llvm::iplist<detail::FunctionStorage> functions;
};
/// A class used to manage the symbols held by a module. This class handles
@@ -98,24 +120,24 @@ public:
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
- template <typename NameTy> Function *getNamedFunction(NameTy &&name) const {
+ template <typename NameTy> Function getNamedFunction(NameTy &&name) const {
return symbolTable.lookup(name);
}
/// Insert a new symbol into the module, auto-renaming it as necessary.
- void insert(Function *function) {
+ void insert(Function function) {
symbolTable.insert(function);
- module->getFunctions().push_back(function);
+ module->push_back(function);
}
- void insert(Module::iterator insertPt, Function *function) {
+ void insert(Module::iterator insertPt, Function function) {
symbolTable.insert(function);
- module->getFunctions().insert(insertPt, function);
+ module->insert(insertPt, function);
}
/// Remove the given symbol from the module symbol table and then erase it.
- void erase(Function *function) {
+ void erase(Function function) {
symbolTable.erase(function);
- function->erase();
+ function.erase();
}
/// Return the internally held module.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index e5323999df7..f916f4ba583 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -128,7 +128,7 @@ public:
/// Returns the function that this operation is part of.
/// The function is determined by traversing the chain of parent operations.
/// Returns nullptr if the operation is unlinked.
- Function *getFunction();
+ Function getFunction();
/// Replace any uses of 'from' with 'to' within this operation.
void replaceUsesOfWith(Value *from, Value *to);
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index a1b81fcde40..921437601e1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -420,7 +420,7 @@ private:
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
-bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
+bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 2189ad490f8..ad0692b0864 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -27,11 +27,16 @@
namespace mlir {
class BlockAndValueMapping;
+namespace detail {
+class FunctionStorage;
+}
+
/// This class contains a list of basic blocks and has a notion of the object it
/// is part of - a Function or an Operation.
class Region {
public:
- explicit Region(Function *container = nullptr);
+ Region() = default;
+ explicit Region(Function container);
explicit Region(Operation *container);
~Region();
@@ -77,7 +82,7 @@ public:
/// A Region is either a function body or a part of an operation. If it is
/// a Function body, then return this function, otherwise return null.
- Function *getContainingFunction();
+ Function getContainingFunction();
/// Return true if this region is a proper ancestor of the `other` region.
bool isProperAncestor(Region *other);
@@ -118,7 +123,7 @@ private:
RegionType blocks;
/// This is the object we are part of.
- llvm::PointerUnion<Function *, Operation *> container;
+ llvm::PointerUnion<detail::FunctionStorage *, Operation *> container;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 30749582031..a351f66eb2e 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -18,7 +18,7 @@
#ifndef MLIR_IR_SYMBOLTABLE_H
#define MLIR_IR_SYMBOLTABLE_H
-#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Function.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@@ -35,18 +35,18 @@ public:
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
- Function *lookup(StringRef name) const;
+ Function lookup(StringRef name) const;
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
- Function *lookup(Identifier name) const;
+ Function lookup(Identifier name) const;
/// Erase the given symbol from the table.
- void erase(Function *symbol);
+ void erase(Function symbol);
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
- void insert(Function *symbol);
+ void insert(Function symbol);
/// Returns the context held by this symbol table.
MLIRContext *getContext() const { return context; }
@@ -55,7 +55,7 @@ private:
MLIRContext *context;
/// This is a mapping from a name to the function with that name.
- llvm::DenseMap<Identifier, Function *> symbolTable;
+ llvm::DenseMap<Identifier, Function> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index e90505ec90d..4604ed99c77 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -72,7 +72,7 @@ public:
}
/// Return the function that this Value is defined in.
- Function *getFunction();
+ Function getFunction();
/// If this value is the result of an operation, return the operation that
/// defines it.
@@ -128,7 +128,7 @@ public:
}
/// Return the function that this argument is defined in.
- Function *getFunction();
+ Function getFunction();
Block *getOwner() { return owner; }
diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h
index bd3286df8f4..a28aa719965 100644
--- a/mlir/include/mlir/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h
@@ -153,7 +153,7 @@ public:
/// Verify a function argument attribute registered to this dialect.
/// Returns failure if the verification failed, success otherwise.
- LogicalResult verifyFunctionArgAttribute(Function *func, unsigned argIdx,
+ LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx,
NamedAttribute argAttr) override;
private:
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 3751a93629d..c44f88f6763 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -106,7 +106,7 @@ template <typename IRUnitT> class AnalysisMap {
}
public:
- explicit AnalysisMap(IRUnitT *ir) : ir(ir) {}
+ explicit AnalysisMap(IRUnitT ir) : ir(ir) {}
/// Get an analysis for the current IR unit, computing it if necessary.
template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
@@ -140,8 +140,8 @@ public:
}
/// Returns the IR unit that this analysis map represents.
- IRUnitT *getIRUnit() { return ir; }
- const IRUnitT *getIRUnit() const { return ir; }
+ IRUnitT getIRUnit() { return ir; }
+ const IRUnitT getIRUnit() const { return ir; }
/// Clear any held analyses.
void clear() { analyses.clear(); }
@@ -158,7 +158,7 @@ public:
}
private:
- IRUnitT *ir;
+ IRUnitT ir;
ConceptMap analyses;
};
@@ -231,14 +231,14 @@ public:
/// Query for the analysis of a function. The analysis is computed if it does
/// not exist.
template <typename AnalysisT>
- AnalysisT &getFunctionAnalysis(Function *function) {
+ AnalysisT &getFunctionAnalysis(Function function) {
return slice(function).getAnalysis<AnalysisT>();
}
/// Query for a cached analysis of a child function, or return null.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
- getCachedFunctionAnalysis(Function *function) const {
+ getCachedFunctionAnalysis(Function function) const {
auto it = functionAnalyses.find(function);
if (it == functionAnalyses.end())
return llvm::None;
@@ -258,7 +258,7 @@ public:
}
/// Create an analysis slice for the given child function.
- FunctionAnalysisManager slice(Function *function);
+ FunctionAnalysisManager slice(Function function);
/// Invalidate any non preserved analyses.
void invalidate(const detail::PreservedAnalyses &pa);
@@ -269,11 +269,11 @@ public:
private:
/// The cached analyses for functions within the current module.
- llvm::DenseMap<Function *, std::unique_ptr<detail::AnalysisMap<Function>>>
+ llvm::DenseMap<Function, std::unique_ptr<detail::AnalysisMap<Function>>>
functionAnalyses;
/// The analyses for the owning module.
- detail::AnalysisMap<Module> moduleAnalyses;
+ detail::AnalysisMap<Module *> moduleAnalyses;
/// An optional instrumentation object.
PassInstrumentor *passInstrumentor;
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 5fd6dfd18b5..41d20ccdd63 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -70,12 +70,12 @@ class ModulePassExecutor;
/// interface for accessing and initializing necessary state for pass execution.
template <typename IRUnitT, typename AnalysisManagerT>
struct PassExecutionState {
- PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager)
+ PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager)
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
/// The current IR unit being transformed and a bool for if the pass signaled
/// a failure.
- llvm::PointerIntPair<IRUnitT *, 1, bool> irAndPassFailed;
+ llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
/// The analysis manager for the IR unit.
AnalysisManagerT &analysisManager;
@@ -107,9 +107,7 @@ protected:
virtual FunctionPassBase *clone() const = 0;
/// Return the current function being transformed.
- Function &getFunction() {
- return *getPassState().irAndPassFailed.getPointer();
- }
+ Function getFunction() { return getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current function being transformed.
MLIRContext &getContext() { return *getFunction().getContext(); }
@@ -128,7 +126,7 @@ protected:
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
- LogicalResult run(Function *fn, FunctionAnalysisManager &fam);
+ LogicalResult run(Function fn, FunctionAnalysisManager &fam);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
@@ -140,7 +138,8 @@ private:
/// Pass to transform a module. Derived passes should not inherit from this
/// class directly, and instead should use the CRTP ModulePass class.
class ModulePassBase : public Pass {
- using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
+ using PassStateT =
+ detail::PassExecutionState<Module *, ModuleAnalysisManager>;
public:
static bool classof(const Pass *pass) {
@@ -272,7 +271,7 @@ struct FunctionPass : public detail::PassModel<Function, T, FunctionPassBase> {
template <typename T>
struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
/// Returns the analysis for a child function.
- template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function *f) {
+ template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function f) {
return this->getAnalysisManager().template getFunctionAnalysis<AnalysisT>(
f);
}
@@ -280,7 +279,7 @@ struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
/// Returns an existing analysis for a child function if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
- getCachedFunctionAnalysis(Function *f) {
+ getCachedFunctionAnalysis(Function f) {
return this->getAnalysisManager()
.template getCachedFunctionAnalysis<AnalysisT>(f);
}
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 0f427066296..40358329f45 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -77,29 +77,29 @@ public:
~PassInstrumentor();
/// See PassInstrumentation::runBeforePass for details.
- template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT *ir) {
+ template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
runBeforePass(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterPass for details.
- template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT *ir) {
+ template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
runAfterPass(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterPassFailed for details.
- template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT *ir) {
+ template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
runAfterPassFailed(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runBeforeAnalysis for details.
template <typename IRUnitT>
- void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
+ void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
runBeforeAnalysis(name, id, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterAnalysis for details.
template <typename IRUnitT>
- void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
+ void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
runAfterAnalysis(name, id, llvm::Any(ir));
}
diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td
index 1b14e2a2a9c..a7afe1f9e7c 100644
--- a/mlir/include/mlir/StandardOps/Ops.td
+++ b/mlir/include/mlir/StandardOps/Ops.td
@@ -214,11 +214,11 @@ def CallOp : Std_Op<"call"> {
let results = (outs Variadic<AnyType>);
let builders = [OpBuilder<
- "Builder *builder, OperationState *result, Function *callee,"
+ "Builder *builder, OperationState *result, Function callee,"
"ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
- result->addTypes(callee->getType().getResults());
+ result->addTypes(callee.getType().getResults());
}]>, OpBuilder<
"Builder *builder, OperationState *result, StringRef callee,"
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 00da0d5fcc0..c8ede78ec20 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns(
/// Convert the given functions with the provided conversion patterns. This
/// function returns failure if a type conversion failed.
LLVM_NODISCARD
-LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
+LogicalResult applyConversionPatterns(MutableArrayRef<Function> fns,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns);
@@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LLVM_NODISCARD
-LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target,
+LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns);
} // end namespace mlir
diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h
index d77b35a8044..09aa7dc8acd 100644
--- a/mlir/include/mlir/Transforms/LowerAffine.h
+++ b/mlir/include/mlir/Transforms/LowerAffine.h
@@ -37,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
/// Convert from the Affine dialect to the Standard dialect, in particular
/// convert structured affine control flow into CFG branch-based control flow.
-LogicalResult lowerAffineConstructs(Function &function);
+LogicalResult lowerAffineConstructs(Function function);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
diff --git a/mlir/include/mlir/Transforms/ViewFunctionGraph.h b/mlir/include/mlir/Transforms/ViewFunctionGraph.h
index c1da5ef9638..5780df5c21b 100644
--- a/mlir/include/mlir/Transforms/ViewFunctionGraph.h
+++ b/mlir/include/mlir/Transforms/ViewFunctionGraph.h
@@ -33,11 +33,11 @@ class FunctionPassBase;
/// Displays the CFG in a window. This is for use from the debugger and
/// depends on Graphviz to generate the graph.
-void viewGraph(Function &function, const Twine &name, bool shortNames = false,
+void viewGraph(Function function, const Twine &name, bool shortNames = false,
const Twine &title = "",
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function,
+llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames = false, const Twine &title = "");
/// Creates a pass to print CFG graphs.
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 016ef43a84a..d7650dcb127 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
if (inserted) {
reorderedDims.push_back(v);
}
- return getAffineDimExpr(iterPos->second, v->getFunction()->getContext())
+ return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
.cast<AffineDimExpr>();
}
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index 954a01b4843..b4cdeb7d886 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase<Block>;
/// Recalculate the dominance info.
template <bool IsPostDom>
-void DominanceInfoBase<IsPostDom>::recalculate(Function *function) {
+void DominanceInfoBase<IsPostDom>::recalculate(Function function) {
dominanceInfos.clear();
// Build the top level function dominance.
auto functionDominance = llvm::make_unique<base>();
- functionDominance->recalculate(function->getBody());
- dominanceInfos.try_emplace(&function->getBody(),
- std::move(functionDominance));
+ functionDominance->recalculate(function.getBody());
+ dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance));
/// Build the dominance for each of the operation regions.
- function->walk([&](Operation *op) {
+ function.walk([&](Operation *op) {
for (auto &region : op->getRegions()) {
// Don't compute dominance if the region is empty.
if (region.empty())
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index 5177afcee67..75a2fc1a5dc 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() {
opCount.clear();
// Compute the operation statistics for each function in the module.
- for (auto &fn : getModule())
+ for (auto fn : getModule())
fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
printSummary();
}
diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp
index cbda6d40224..473d253cfa2 100644
--- a/mlir/lib/Analysis/TestParallelismDetection.cpp
+++ b/mlir/lib/Analysis/TestParallelismDetection.cpp
@@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() {
// Walks the function and emits a note for all 'affine.for' ops detected as
// parallel.
void TestParallelismDetection::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
OpBuilder b(f.getBody());
f.walk<AffineForOp>([&](AffineForOp forOp) {
if (isLoopParallel(forOp))
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 1330fe0fb94..0d0525145ef 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -53,7 +53,7 @@ public:
: ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
/// Verify the body of the given function.
- LogicalResult verify(Function &fn);
+ LogicalResult verify(Function fn);
/// Verify the given operation.
LogicalResult verify(Operation &op);
@@ -104,7 +104,7 @@ private:
} // end anonymous namespace
/// Verify the body of the given function.
-LogicalResult OperationVerifier::verify(Function &fn) {
+LogicalResult OperationVerifier::verify(Function fn) {
// Verify the body first.
if (failed(verifyRegion(fn.getBody())))
return failure();
@@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) {
// check. We do this as a second pass since malformed CFG's can cause
// dominator analysis constructure to crash and we want the verifier to be
// resilient to malformed code.
- DominanceInfo theDomInfo(&fn);
+ DominanceInfo theDomInfo(fn);
domInfo = &theDomInfo;
if (failed(verifyDominance(fn.getBody())))
return failure();
@@ -313,7 +313,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
- if (failed(dialect->verifyFunctionAttribute(this, attr)))
+ if (failed(dialect->verifyFunctionAttribute(*this, attr)))
return failure();
}
@@ -331,7 +331,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
- if (failed(dialect->verifyFunctionArgAttribute(this, i, attr)))
+ if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr)))
return failure();
}
}
@@ -369,7 +369,7 @@ LogicalResult Operation::verify() {
LogicalResult Module::verify() {
// Check that all functions are uniquely named.
llvm::StringMap<Location> nameToOrigLoc;
- for (auto &fn : *this) {
+ for (auto fn : *this) {
auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
if (!it.second)
return fn.emitError()
@@ -379,7 +379,7 @@ LogicalResult Module::verify() {
}
// Check that each function is correct.
- for (auto &fn : *this)
+ for (auto fn : *this)
if (failed(fn.verify()))
return failure();
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
index 9d7aeeb6321..022d8c70cc6 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -64,8 +64,8 @@ public:
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
- for (auto &function : getModule()) {
- if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
+ for (auto function : getModule()) {
+ if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
continue;
}
if (failed(translateGpuKernelToCubinAnnotation(function)))
@@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
std::unique_ptr<Module> module(builder.createModule());
// TODO(herhut): Also handle called functions.
- module->getFunctions().push_back(function.clone());
+ module->push_back(function.clone());
auto llvmModule = translateModuleToNVVMIR(*module);
auto cubin = convertModuleToCubin(*llvmModule, function);
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index bd96f396b22..f9d5899456a 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -118,7 +118,7 @@ private:
void declareCudaFunctions(Location loc);
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
- Value *generateKernelNameConstant(Function *kernelFunction, Location &loc,
+ Value *generateKernelNameConstant(Function kernelFunction, Location &loc,
OpBuilder &builder);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
@@ -130,7 +130,7 @@ public:
// Cache the used LLVM types.
initializeCachedTypes();
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchFuncOp>(
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
}
@@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
Module &module = getModule();
Builder builder(&module);
if (!module.getNamedFunction(cuModuleLoadName)) {
- module.getFunctions().push_back(
- new Function(loc, cuModuleLoadName,
- builder.getFunctionType(
- {
- getPointerPointerType(), /* CUmodule *module */
- getPointerType() /* void *cubin */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuModuleLoadName,
+ builder.getFunctionType(
+ {
+ getPointerPointerType(), /* CUmodule *module */
+ getPointerType() /* void *cubin */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction.
- module.getFunctions().push_back(
- new Function(loc, cuModuleGetFunctionName,
- builder.getFunctionType(
- {
- getPointerPointerType(), /* void **function */
- getPointerType(), /* void *module */
- getPointerType() /* char *name */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuModuleGetFunctionName,
+ builder.getFunctionType(
+ {
+ getPointerPointerType(), /* void **function */
+ getPointerType(), /* void *module */
+ getPointerType() /* char *name */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuLaunchKernelName)) {
// Other than the CUDA api, the wrappers use uintptr_t to match the
// LLVM type if MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
// CUstream.
- module.getFunctions().push_back(
- new Function(loc, cuLaunchKernelName,
- builder.getFunctionType(
- {
- getPointerType(), /* void* f */
- getIntPtrType(), /* intptr_t gridXDim */
- getIntPtrType(), /* intptr_t gridyDim */
- getIntPtrType(), /* intptr_t gridZDim */
- getIntPtrType(), /* intptr_t blockXDim */
- getIntPtrType(), /* intptr_t blockYDim */
- getIntPtrType(), /* intptr_t blockZDim */
- getInt32Type(), /* unsigned int sharedMemBytes */
- getPointerType(), /* void *hstream */
- getPointerPointerType(), /* void **kernelParams */
- getPointerPointerType() /* void **extra */
- },
- getCUResultType())));
+ module.push_back(Function::create(
+ loc, cuLaunchKernelName,
+ builder.getFunctionType(
+ {
+ getPointerType(), /* void* f */
+ getIntPtrType(), /* intptr_t gridXDim */
+ getIntPtrType(), /* intptr_t gridyDim */
+ getIntPtrType(), /* intptr_t gridZDim */
+ getIntPtrType(), /* intptr_t blockXDim */
+ getIntPtrType(), /* intptr_t blockYDim */
+ getIntPtrType(), /* intptr_t blockZDim */
+ getInt32Type(), /* unsigned int sharedMemBytes */
+ getPointerType(), /* void *hstream */
+ getPointerPointerType(), /* void **kernelParams */
+ getPointerPointerType() /* void **extra */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of
// CUDAs opaque CUstream.
- module.getFunctions().push_back(new Function(
+ module.push_back(Function::create(
loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */)));
}
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
- module.getFunctions().push_back(
- new Function(loc, cuStreamSynchronizeName,
- builder.getFunctionType(
- {
- getPointerType() /* CUstream stream */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuStreamSynchronizeName,
+ builder.getFunctionType(
+ {
+ getPointerType() /* CUstream stream */
+ },
+ getCUResultType())));
}
}
@@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
// %0[n] = constant name[n]
// %0[n+1] = 0
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
- Function *kernelFunction, Location &loc, OpBuilder &builder) {
+ Function kernelFunction, Location &loc, OpBuilder &builder) {
// TODO(herhut): Make this a constant once this is supported.
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction->getName().size() + 1));
+ builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
auto kernelName =
builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
- for (auto byte : llvm::enumerate(kernelFunction->getName())) {
+ for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
@@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
// Add trailing zero to terminate string.
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction->getName().size()));
+ builder.getI32IntegerAttr(kernelFunction.getName().size()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
ArrayRef<Value *>{index});
auto value = builder.create<LLVM::ConstantOp>(
@@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
auto cubinGetter =
- kernelFunction->getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
+ kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
if (!cubinGetter) {
- kernelFunction->emitError("Missing ")
+ kernelFunction.emitError("Missing ")
<< kCubinGetterAnnotation << " attribute.";
return signalPassFailure();
}
@@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
- Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
+ Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)});
@@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc);
- Function *cuModuleGetFunction =
+ Function cuModuleGetFunction =
getModule().getNamedFunction(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
// Grab the global stream needed for execution.
- Function *cuGetStreamHelper =
+ Function cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
index c1d4af380ce..97790a5afce 100644
--- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
@@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc";
class GpuGenerateCubinAccessorsPass
: public ModulePass<GpuGenerateCubinAccessorsPass> {
private:
- Function *getMallocHelper(Location loc, Builder &builder) {
- Function *result = getModule().getNamedFunction(kMallocHelperName);
+ Function getMallocHelper(Location loc, Builder &builder) {
+ Function result = getModule().getNamedFunction(kMallocHelperName);
if (!result) {
- result = new Function(
+ result = Function::create(
loc, kMallocHelperName,
builder.getFunctionType(
ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
- getModule().getFunctions().push_back(result);
+ getModule().push_back(result);
}
return result;
}
@@ -70,18 +70,18 @@ private:
// data from blob. As there are currently no global constants, this uses a
// sequence of store operations.
// TODO(herhut): Use global constants instead.
- Function *generateCubinAccessor(Builder &builder, Function &orig,
- StringAttr blob) {
+ Function generateCubinAccessor(Builder &builder, Function &orig,
+ StringAttr blob) {
Location loc = orig.getLoc();
SmallString<128> nameBuffer(orig.getName());
nameBuffer.append(kCubinGetterSuffix);
// Generate a function that returns void*.
- Function *result = new Function(
+ Function result = Function::create(
loc, mlir::Identifier::get(nameBuffer, &getContext()),
builder.getFunctionType(ArrayRef<Type>{},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
// Insert a body block that just returns the constant.
- OpBuilder ob(result->getBody());
+ OpBuilder ob(result.getBody());
ob.createBlock();
auto sizeConstant = ob.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
@@ -115,18 +115,18 @@ public:
void runOnModule() override {
llvmDialect =
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- Builder builder(getModule().getContext());
+ auto &module = getModule();
+ Builder builder(&getContext());
- auto &functions = getModule().getFunctions();
+ auto functions = module.getFunctions();
for (auto it = functions.begin(); it != functions.end();) {
// Move iterator to after the current function so that potential insertion
// of the accessor is after the kernel with cubin iself.
- Function &orig = *it++;
+ Function orig = *it++;
StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation);
if (!cubinBlob)
continue;
- it =
- functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
+ module.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
}
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 872707842d7..e849f6fd023 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
createIndexConstant(rewriter, op->getLoc(), elementSize)});
// Insert the `malloc` declaration if it is not already present.
- Function *mallocFunc =
- op->getFunction()->getModule()->getNamedFunction("malloc");
+ Function mallocFunc =
+ op->getFunction().getModule()->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
- mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
- op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+ mallocFunc =
+ Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+ op->getFunction().getModule()->push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
@@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
- Function *freeFunc =
- op->getFunction()->getModule()->getNamedFunction("free");
+ Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
- freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
- op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+ freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
+ op->getFunction().getModule()->push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
@@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) {
}
void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
- for (auto &f : *m) {
+ for (auto f : *m) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index ff198217bb7..dafc8e711f5 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
- auto &fn = getFunction();
+ auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
@@ -386,7 +386,7 @@ static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
- auto &fn = getFunction();
+ auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 9dcc6df6bea..8469fa2ea70 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
applyPatternsGreedily(func, std::move(patterns));
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index ea8095b791c..0c93146a232 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -95,7 +95,7 @@ public:
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp
index 51636037382..f13b743de0c 100644
--- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp
+++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp
@@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
}
llvm::Expected<SmallVector<void *, 8>>
-mlir::allocateMemRefArguments(Function *func, float initialValue) {
+mlir::allocateMemRefArguments(Function func, float initialValue) {
SmallVector<void *, 8> args;
- args.reserve(func->getNumArguments());
- for (const auto &arg : func->getArguments()) {
+ args.reserve(func.getNumArguments());
+ for (const auto &arg : func.getArguments()) {
auto descriptor =
allocMemRefDescriptor(arg->getType(),
/*allocateData=*/true, initialValue);
@@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) {
args.push_back(*descriptor);
}
- if (func->getType().getNumResults() > 1)
+ if (func.getType().getNumResults() > 1)
return make_string_error("functions with more than 1 result not supported");
- for (Type resType : func->getType().getResults()) {
+ for (Type resType : func.getType().getResults()) {
auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
if (!descriptor)
return descriptor.takeError();
diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp
index e39860bddda..5e8090b42b4 100644
--- a/mlir/lib/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/GPU/IR/GPUDialect.cpp
@@ -30,9 +30,9 @@ using namespace mlir::gpu;
StringRef GPUDialect::getDialectName() { return "gpu"; }
-bool GPUDialect::isKernel(Function *function) {
+bool GPUDialect::isKernel(Function function) {
UnitAttr isKernelAttr =
- function->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
+ function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
@@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState *result,
- Function *kernelFunc, Value *gridSizeX,
+ Function kernelFunc, Value *gridSizeX,
Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
Value *blockSizeY, Value *blockSizeZ,
ArrayRef<Value *> kernelOperands) {
@@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
}
void LaunchFuncOp::build(Builder *builder, OperationState *result,
- Function *kernelFunc, KernelDim3 gridSize,
+ Function kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() {
return emitOpError("attribute 'kernel' must be a function");
}
- auto *module = getOperation()->getFunction()->getModule();
- Function *kernelFunc = module->getNamedFunction(kernel());
+ auto *module = getOperation()->getFunction().getModule();
+ Function kernelFunc = module->getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
- if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
+ if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName())) {
return emitError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
}
- unsigned numKernelFuncArgs = kernelFunc->getNumArguments();
+ unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
if (getNumKernelOperands() != numKernelFuncArgs) {
return emitOpError("got ")
<< getNumKernelOperands() << " kernel operands but expected "
<< numKernelFuncArgs;
}
- auto functionType = kernelFunc->getType();
+ auto functionType = kernelFunc.getType();
for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
return emitOpError("type of function argument ")
diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp
index 46363f06f72..f93febcf5da 100644
--- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp
@@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc,
// Add operations generating block/thread ids and gird/block dimensions at the
// beginning of `kernelFunc` and replace uses of the respective function args.
-static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
+static void injectGpuIndexOperations(Location loc, Function kernelFunc) {
OpBuilder OpBuilder(kernelFunc.getBody());
SmallVector<Value *, 12> indexOps;
createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
@@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
// Outline the `gpu.launch` operation body into a kernel function. Replace
// `gpu.return` operations by `std.return` in the generated functions.
-static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
+static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
Location loc = launchOp.getLoc();
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
- Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
- Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
- outlinedFunc->getBody().takeBody(launchOp.getBody());
+ Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
+ Function outlinedFunc = Function::create(loc, kernelFuncName, type);
+ outlinedFunc.getBody().takeBody(launchOp.getBody());
Builder builder(launchOp.getContext());
- outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
- builder.getUnitAttr());
- injectGpuIndexOperations(loc, *outlinedFunc);
- outlinedFunc->walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
+ outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+ injectGpuIndexOperations(loc, outlinedFunc);
+ outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
OpBuilder replacer(op);
replacer.create<ReturnOp>(op.getLoc());
op.erase();
@@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
// `kernelFunc`.
static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
- Function &kernelFunc) {
+ Function kernelFunc) {
OpBuilder builder(launchOp);
SmallVector<Value *, 4> kernelOperandValues(
launchOp.getKernelOperandValues());
builder.create<gpu::LaunchFuncOp>(
- launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(),
+ launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
launchOp.erase();
}
@@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
ModuleManager moduleManager(&getModule());
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
- Function *outlinedFunc = outlineKernelFunc(op);
+ Function outlinedFunc = outlineKernelFunc(op);
moduleManager.insert(outlinedFunc);
- convertToLaunchFuncOp(op, *outlinedFunc);
+ convertToLaunchFuncOp(op, outlinedFunc);
});
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8e3d5788bb1..346d35af231 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) {
initializeSymbolAliases();
// Walk the module and visit each operation.
- for (auto &fn : *module) {
+ for (auto fn : *module) {
visitType(fn.getType());
for (auto attr : fn.getAttrs())
ModuleState::visitAttribute(attr.second);
@@ -342,7 +342,7 @@ public:
void printAttribute(Attribute attr, bool mayElideType = false);
void printType(Type type);
- void print(Function *fn);
+ void print(Function fn);
void printLocation(LocationAttr loc);
void printAffineMap(AffineMap map);
@@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) {
state.printTypeAliases(os);
// Print the module.
- for (auto &fn : *module)
- print(&fn);
+ for (auto fn : *module)
+ print(fn);
}
/// Print a floating point value in a way that the parser will be able to
@@ -1186,7 +1186,7 @@ namespace {
// CFG and ML functions.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
- FunctionPrinter(Function *function, ModulePrinter &other);
+ FunctionPrinter(Function function, ModulePrinter &other);
// Prints the function as a whole.
void print();
@@ -1275,7 +1275,7 @@ protected:
void printValueID(Value *value, bool printResultNo = true) const;
private:
- Function *function;
+ Function function;
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
@@ -1305,10 +1305,10 @@ private:
};
} // end anonymous namespace
-FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other)
+FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other)
: ModulePrinter(other), function(function) {
- for (auto &block : *function)
+ for (auto &block : function)
numberValuesInBlock(block);
}
@@ -1419,17 +1419,17 @@ void FunctionPrinter::print() {
printFunctionSignature();
// Print out function attributes, if present.
- auto attrs = function->getAttrs();
+ auto attrs = function.getAttrs();
if (!attrs.empty()) {
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
// Print the trailing location.
- printTrailingLocation(function->getLoc());
+ printTrailingLocation(function.getLoc());
- if (!function->empty()) {
- printRegion(function->getBody(), /*printEntryBlockArgs=*/false,
+ if (!function.empty()) {
+ printRegion(function.getBody(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
os << "\n";
}
@@ -1437,24 +1437,24 @@ void FunctionPrinter::print() {
}
void FunctionPrinter::printFunctionSignature() {
- os << "func @" << function->getName() << '(';
+ os << "func @" << function.getName() << '(';
- auto fnType = function->getType();
- bool isExternal = function->isExternal();
- for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+ auto fnType = function.getType();
+ bool isExternal = function.isExternal();
+ for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
// If this is an external function, don't print argument labels.
if (!isExternal) {
- printOperand(function->getArgument(i));
+ printOperand(function.getArgument(i));
os << ": ";
}
printType(fnType.getInput(i));
// Print the attributes for this argument.
- printOptionalAttrDict(function->getArgAttrs(i));
+ printOptionalAttrDict(function.getArgAttrs(i));
}
os << ')';
@@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term,
}
// Prints function with initialized module state.
-void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
+void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); }
//===----------------------------------------------------------------------===//
// print and dump methods
@@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) {
void Value::dump() { print(llvm::errs()); }
void Operation::print(raw_ostream &os) {
- auto *function = getFunction();
+ auto function = getFunction();
if (!function) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
- ModuleState state(function->getContext());
+ ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@@ -1754,13 +1754,13 @@ void Operation::dump() {
}
void Block::print(raw_ostream &os) {
- auto *function = getFunction();
+ auto function = getFunction();
if (!function) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
- ModuleState state(function->getContext());
+ ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
- ModuleState state(getFunction()->getContext());
+ ModuleState state(getFunction().getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(getFunction(), modulePrinter).printBlockName(this);
}
void Function::print(raw_ostream &os) {
ModuleState state(getContext());
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).print(*this);
}
void Function::dump() { print(llvm::errs()); }
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 01f9a060bd9..9cbba0fe429 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// FunctionAttr
//===----------------------------------------------------------------------===//
-FunctionAttr FunctionAttr::get(Function *value) {
- assert(value && "Cannot get FunctionAttr for a null function");
- return get(value->getName(), value->getContext());
-}
-
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::Function, value,
NoneType::get(ctx));
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index e7616f6d7d0..134f6e468a0 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -50,7 +50,7 @@ Operation *Block::getContainingOp() {
return getParent() ? getParent()->getContainingOp() : nullptr;
}
-Function *Block::getFunction() {
+Function Block::getFunction() {
Block *block = this;
while (auto *op = block->getContainingOp()) {
block = op->getBlock();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9b30205abdb..89df64260d3 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
-FunctionAttr Builder::getFunctionAttr(Function *value) {
- return FunctionAttr::get(value);
+FunctionAttr Builder::getFunctionAttr(Function value) {
+ return getFunctionAttr(value.getName());
}
FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 4547452eb55..e38b95ff0f7 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
@@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context)
Dialect::~Dialect() {}
+/// Verify an attribute from this dialect on the given function. Returns
+/// failure if the verification failed, success otherwise.
+LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) {
+ return success();
+}
+
+/// Verify an attribute from this dialect on the argument at 'argIndex' for
+/// the given function. Returns failure if the verification failed, success
+/// otherwise.
+LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex,
+ NamedAttribute) {
+ return success();
+}
+
/// Parse an attribute registered to this dialect.
Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 7d17ed1d705..f8835f02c26 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -27,45 +27,50 @@
#include "llvm/ADT/Twine.h"
using namespace mlir;
+using namespace mlir::detail;
-Function::Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs)
+FunctionStorage::FunctionStorage(Location location, StringRef name,
+ FunctionType type,
+ ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
-Function::Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs,
- ArrayRef<NamedAttributeList> argAttrs)
+FunctionStorage::FunctionStorage(Location location, StringRef name,
+ FunctionType type,
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<NamedAttributeList> argAttrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
MLIRContext *Function::getContext() { return getType().getContext(); }
-Module *llvm::ilist_traits<Function>::getContainingModule() {
+Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
- iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
+ iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
}
/// This is a trait method invoked when a Function is added to a Module. We
/// keep the module pointer and module symbol table up to date.
-void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
- assert(!function->getModule() && "already in a module!");
+void llvm::ilist_traits<FunctionStorage>::addNodeToList(
+ FunctionStorage *function) {
+ assert(!function->module && "already in a module!");
function->module = getContainingModule();
}
/// This is a trait method invoked when a Function is removed from a Module.
/// We keep the module pointer up to date.
-void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
+void llvm::ilist_traits<FunctionStorage>::removeNodeFromList(
+ FunctionStorage *function) {
assert(function->module && "not already in a module!");
function->module = nullptr;
}
/// This is a trait method invoked when an operation is moved from one block
/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<Function>::transferNodesFromList(
- ilist_traits<Function> &otherList, function_iterator first,
+void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
+ ilist_traits<FunctionStorage> &otherList, function_iterator first,
function_iterator last) {
// If we are transferring functions within the same module, the Module
// pointer doesn't need to be updated.
@@ -82,8 +87,10 @@ void llvm::ilist_traits<Function>::transferNodesFromList(
/// Unlink this function from its Module and delete it.
void Function::erase() {
- assert(getModule() && "Function has no parent");
- getModule()->getFunctions().erase(this);
+ if (auto *module = getModule())
+ getModule()->functions.erase(impl);
+ else
+ delete impl;
}
/// Emit an error about fatal conditions with this function, reporting up to
@@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) {
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
-void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
+void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) {
// Add the attributes of this function to dest.
llvm::MapVector<Identifier, Attribute> newAttrs;
- for (auto &attr : dest->getAttrs())
+ for (auto &attr : dest.getAttrs())
newAttrs.insert(attr);
for (auto &attr : getAttrs()) {
auto insertPair = newAttrs.insert(attr);
@@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
assert((insertPair.second || insertPair.first->second == attr.second) &&
"the two functions have incompatible attributes");
}
- dest->setAttrs(newAttrs.takeVector());
+ dest.setAttrs(newAttrs.takeVector());
// Clone the body.
- body.cloneInto(&dest->body, mapper);
+ impl->body.cloneInto(&dest.impl->body, mapper);
}
/// Create a deep copy of this function and all of its blocks, remapping
@@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
/// provided (leaving them alone if no entry is present). Replaces references
/// to cloned sub-values with the corresponding value that is copied, and adds
/// those mappings to the mapper.
-Function *Function::clone(BlockAndValueMapping &mapper) {
- FunctionType newType = type;
+Function Function::clone(BlockAndValueMapping &mapper) {
+ FunctionType newType = impl->type;
// If the function has a body, then the user might be deleting arguments to
// the function by specifying them in the mapper. If so, we don't add the
@@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) {
SmallVector<Type, 4> inputTypes;
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
- inputTypes.push_back(type.getInput(i));
- newType = FunctionType::get(inputTypes, type.getResults(), getContext());
+ inputTypes.push_back(newType.getInput(i));
+ newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
}
// Create the new function.
- Function *newFunc = new Function(getLoc(), getName(), newType);
+ Function newFunc = Function::create(getLoc(), getName(), newType);
/// Set the argument attributes for arguments that aren't being replaced.
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
if (isExternalFn || !mapper.contains(getArgument(i)))
- newFunc->setArgAttrs(destI++, getArgAttrs(i));
+ newFunc.setArgAttrs(destI++, getArgAttrs(i));
/// Clone the current function into the new one and return it.
cloneInto(newFunc, mapper);
return newFunc;
}
-Function *Function::clone() {
+Function Function::clone() {
BlockAndValueMapping mapper;
return clone(mapper);
}
@@ -178,7 +185,7 @@ void Function::addEntryBlock() {
assert(empty() && "function already has an entry block");
auto *entry = new Block();
push_back(entry);
- entry->addArguments(type.getInputs());
+ entry->addArguments(impl->type.getInputs());
}
void Function::walk(const std::function<void(Operation *)> &callback) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 83171f12d1d..f953cd27a56 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -281,7 +281,7 @@ Operation *Operation::getParentOp() {
return block ? block->getContainingOp() : nullptr;
}
-Function *Operation::getFunction() {
+Function Operation::getFunction() {
return block ? block->getFunction() : nullptr;
}
@@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands,
}
static LogicalResult verifyTerminatorSuccessors(Operation *op) {
+ auto *parent = op->getContainingRegion();
+
// Verify that the operands lines up with the BB arguments in the successor.
- Function *fn = op->getFunction();
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
auto *succ = op->getSuccessor(i);
- if (succ->getFunction() != fn)
- return op->emitError("reference to block defined in another function");
+ if (succ->getParent() != parent)
+ return op->emitError("reference to block defined in another region");
if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
return failure();
}
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 992d9112beb..74c71b7aeac 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -21,7 +21,7 @@
#include "mlir/IR/Operation.h"
using namespace mlir;
-Region::Region(Function *container) : container(container) {}
+Region::Region(Function container) : container(container.impl) {}
Region::Region(Operation *container) : container(container) {}
@@ -38,7 +38,7 @@ MLIRContext *Region::getContext() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getContext();
- return getContainingFunction()->getContext();
+ return getContainingFunction().getContext();
}
/// Return a location for this region. This is the location attached to the
@@ -47,7 +47,7 @@ Location Region::getLoc() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getLoc();
- return getContainingFunction()->getLoc();
+ return getContainingFunction().getLoc();
}
Region *Region::getContainingRegion() {
@@ -60,8 +60,8 @@ Operation *Region::getContainingOp() {
return container.dyn_cast<Operation *>();
}
-Function *Region::getContainingFunction() {
- return container.dyn_cast<Function *>();
+Function Region::getContainingFunction() {
+ return container.dyn_cast<detail::FunctionStorage *>();
}
bool Region::isProperAncestor(Region *other) {
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index a0819a78fc1..dafbd48f513 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -22,8 +22,8 @@ using namespace mlir;
/// Build a symbol table with the symbols within the given module.
SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
- for (auto &func : *module) {
- auto inserted = symbolTable.insert({func.getName(), &func});
+ for (auto func : *module) {
+ auto inserted = symbolTable.insert({func.getName(), func});
(void)inserted;
assert(inserted.second &&
"expected module to contain uniquely named functions");
@@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
-Function *SymbolTable::lookup(StringRef name) const {
+Function SymbolTable::lookup(StringRef name) const {
return lookup(Identifier::get(name, context));
}
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
-Function *SymbolTable::lookup(Identifier name) const {
+Function SymbolTable::lookup(Identifier name) const {
return symbolTable.lookup(name);
}
/// Erase the given symbol from the table.
-void SymbolTable::erase(Function *symbol) {
- auto it = symbolTable.find(symbol->getName());
+void SymbolTable::erase(Function symbol) {
+ auto it = symbolTable.find(symbol.getName());
if (it != symbolTable.end() && it->second == symbol)
symbolTable.erase(it);
}
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
-void SymbolTable::insert(Function *symbol) {
+void SymbolTable::insert(Function symbol) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
- if (symbolTable.insert({symbol->getName(), symbol}).second)
+ if (symbolTable.insert({symbol.getName(), symbol}).second)
return;
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(symbol->getName());
+ SmallString<128> nameBuffer(symbol.getName());
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used. We use a
@@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
- symbol->setName(Identifier::get(nameBuffer, context));
- } while (!symbolTable.insert({symbol->getName(), symbol}).second);
+ symbol.setName(Identifier::get(nameBuffer, context));
+ } while (!symbolTable.insert({symbol.getName(), symbol}).second);
}
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 073c3b369c6..65a98f7ee59 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() {
}
/// Return the function that this Value is defined in.
-Function *Value::getFunction() {
+Function Value::getFunction() {
switch (getKind()) {
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
@@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() {
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
-Function *BlockArgument::getFunction() {
+Function BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->getFunction();
return nullptr;
@@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() {
/// Returns if the current argument is a function argument.
bool BlockArgument::isFunctionArgument() {
- auto *containingFn = getFunction();
- return containingFn && &containingFn->front() == getOwner();
+ auto containingFn = getFunction();
+ return containingFn && &containingFn.front() == getOwner();
}
diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
index 0d3a5ca2756..0dbf63a3ce7 100644
--- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
@@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const {
}
/// Verify LLVMIR function argument attributes.
-LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
+LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func,
unsigned argIdx,
NamedAttribute argAttr) {
// Check that llvm.noalias is a boolean attribute.
if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
- return func->emitError()
+ return func.emitError()
<< "llvm.noalias argument attribute of non boolean type";
return success();
}
diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp
index 7ddb7b0c19f..5761cc637b7 100644
--- a/mlir/lib/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Linalg/Transforms/Fusion.cpp
@@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
return true;
}
-static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
+static void fuseLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
DenseSet<Operation *> eraseSet;
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index a8099aaff99..5fe4f07613a 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -170,12 +170,13 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *mallocFunc = module->getNamedFunction("malloc");
+ auto *module = op->getFunction().getModule();
+ Function mallocFunc = module->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
- mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
- module->getFunctions().push_back(mallocFunc);
+ mallocFunc =
+ Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+ module->push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
@@ -230,12 +231,12 @@ public:
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *freeFunc = module->getNamedFunction("free");
+ auto *module = op->getFunction().getModule();
+ Function freeFunc = module->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
- freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
- module->getFunctions().push_back(freeFunc);
+ freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
+ module->push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
@@ -572,37 +573,37 @@ public:
// Create a function definition which takes as argument pointers to the input
// types and returns pointers to the output types.
-static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
- auto implFnName = (libFn->getName().str() + "_impl");
- auto module = libFn->getModule();
- if (auto *f = module->getNamedFunction(implFnName)) {
+static Function getLLVMLibraryCallImplDefinition(Function libFn) {
+ auto implFnName = (libFn.getName().str() + "_impl");
+ auto module = libFn.getModule();
+ if (auto f = module->getNamedFunction(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
- for (auto t : libFn->getType().getInputs()) {
+ for (auto t : libFn.getType().getInputs()) {
assert(t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
}
- auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
+ auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
// Insert the implementation function definition.
- auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
- module->getFunctions().push_back(implFnDefn);
+ auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
+ module->push_back(implFnDefn);
return implFnDefn;
}
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
-static Function *getLLVMLibraryCallDeclaration(Operation *op,
- LLVMTypeConverter &lowering,
- PatternRewriter &rewriter) {
+static Function getLLVMLibraryCallDeclaration(Operation *op,
+ LLVMTypeConverter &lowering,
+ PatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
- auto module = op->getFunction()->getModule();
- if (auto *f = module->getNamedFunction(fnName)) {
+ auto module = op->getFunction().getModule();
+ if (auto f = module->getNamedFunction(fnName)) {
return f;
}
@@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op,
"Library call for linalg operation can be generated only for ops that "
"have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
- auto libFn = new Function(op->getLoc(), fnName, libFnType);
- module->getFunctions().push_back(libFn);
+ auto libFn = Function::create(op->getLoc(), fnName, libFnType);
+ module->push_back(libFn);
// Return after creating the function definition. The body will be created
// later.
return libFn;
}
-static void getLLVMLibraryCallDefinition(Function *fn,
+static void getLLVMLibraryCallDefinition(Function fn,
LLVMTypeConverter &lowering) {
// Generate the implementation function definition.
auto implFn = getLLVMLibraryCallImplDefinition(fn);
// Generate the function body.
- fn->addEntryBlock();
+ fn.addEntryBlock();
- OpBuilder builder(fn->getBody());
- edsc::ScopedContext scope(builder, fn->getLoc());
+ OpBuilder builder(fn.getBody());
+ edsc::ScopedContext scope(builder, fn.getLoc());
SmallVector<Value *, 4> implFnArgs;
// Create a constant 1.
auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
- IntegerAttr::get(IndexType::get(fn->getContext()), 1));
- for (auto arg : fn->getArguments()) {
+ IntegerAttr::get(IndexType::get(fn.getContext()), 1));
+ for (auto arg : fn.getArguments()) {
// Allocate a stack for storing the argument value. The stack is passed to
// the implementation function.
auto alloca =
@@ -665,17 +666,17 @@ public:
return convertLinalgType(t, *this);
}
- void addLibraryFnDeclaration(Function *fn) {
+ void addLibraryFnDeclaration(Function fn) {
libraryFnDeclarations.push_back(fn);
}
- ArrayRef<Function *> getLibraryFnDeclarations() {
+ ArrayRef<Function> getLibraryFnDeclarations() {
return libraryFnDeclarations;
}
private:
/// List of library functions declarations needed during dialect conversion
- SmallVector<Function *, 2> libraryFnDeclarations;
+ SmallVector<Function, 2> libraryFnDeclarations;
};
} // end anonymous namespace
@@ -692,7 +693,7 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Only emit library call declaration. Fill in the body later.
- auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+ auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
@@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) {
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
- for (auto &f : module.getFunctions()) {
+ for (auto f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))
diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
index d31ba5bf22d..2e616c35f1d 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
@@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
} // namespace
void LowerLinalgToLoopsPass::runOnFunction() {
- auto &f = getFunction();
OperationFolder state;
- f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
+ getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) {
emitLinalgOpAsLoops(linalgOp, state);
linalgOp.getOperation()->erase();
});
diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp
index c63e1cf197d..2f752b2b637 100644
--- a/mlir/lib/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
return tileLinalgOp(op, tileSizeValues, state);
}
-static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
+static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 44f05963727..4af2f093daf 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -254,7 +254,7 @@ public:
/// trailing-location ::= location?
///
template <typename Owner>
- ParseResult parseOptionalTrailingLocation(Owner *owner) {
+ ParseResult parseOptionalTrailingLocation(Owner &owner) {
// If there is a 'loc' we parse a trailing location.
if (!getToken().is(Token::kw_loc))
return success();
@@ -263,7 +263,7 @@ public:
LocationAttr directLoc;
if (parseLocation(directLoc))
return failure();
- owner->setLoc(directLoc);
+ owner.setLoc(directLoc);
return success();
}
@@ -2472,8 +2472,8 @@ namespace {
/// operations.
class OperationParser : public Parser {
public:
- OperationParser(ParserState &state, Function *function)
- : Parser(state), function(function), opBuilder(function->getBody()) {}
+ OperationParser(ParserState &state, Function function)
+ : Parser(state), function(function), opBuilder(function.getBody()) {}
~OperationParser();
@@ -2588,7 +2588,7 @@ public:
Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
private:
- Function *function;
+ Function function;
/// Returns the info for a block at the current scope for the given name.
std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
@@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() {
for (auto entry : forwardRefInCurrentScope) {
errors.push_back({entry.second.getPointer(), entry.first});
// Add this block to the top-level region to allow for automatic cleanup.
- function->push_back(entry.first);
+ function.push_back(entry.first);
}
llvm::array_pod_sort(errors.begin(), errors.end());
@@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() {
}
// Try to parse the optional trailing location.
- if (parseOptionalTrailingLocation(op))
+ if (parseOptionalTrailingLocation(*op))
return failure();
return success();
@@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) {
}
// Okay, the function signature was parsed correctly, create the function now.
- auto *function =
- new Function(getEncodedSourceLocation(loc), name, type, attrs);
- module->getFunctions().push_back(function);
+ auto function =
+ Function::create(getEncodedSourceLocation(loc), name, type, attrs);
+ module->push_back(function);
// Parse an optional trailing location.
if (parseOptionalTrailingLocation(function))
return failure();
// Add the attributes to the function arguments.
- for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
- function->setArgAttrs(i, argAttrs[i]);
+ for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i)
+ function.setArgAttrs(i, argAttrs[i]);
// External functions have no body.
if (getToken().isNot(Token::l_brace))
@@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) {
// Parse the function body.
auto parser = OperationParser(getState(), function);
- if (parser.parseRegion(function->getBody(), entryArgs))
+ if (parser.parseRegion(function.getBody(), entryArgs))
return failure();
// Verify that a valid function body was parsed.
- if (function->empty())
+ if (function.empty())
return emitError(braceLoc, "function must have a body");
return parser.finalize(braceLoc);
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 868d492e094..057f2655207 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -61,12 +61,12 @@ private:
static void printIR(const llvm::Any &ir, bool printModuleScope,
raw_ostream &out) {
// Check for printing at module scope.
- if (printModuleScope && llvm::any_isa<Function *>(ir)) {
- Function *function = llvm::any_cast<Function *>(ir);
+ if (printModuleScope && llvm::any_isa<Function>(ir)) {
+ Function function = llvm::any_cast<Function>(ir);
// Print the function name and a newline before the Module.
- out << " (function: " << function->getName() << ")\n";
- function->getModule()->print(out);
+ out << " (function: " << function.getName() << ")\n";
+ function.getModule()->print(out);
return;
}
@@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
out << "\n";
// Print the given function.
- if (llvm::any_isa<Function *>(ir)) {
- llvm::any_cast<Function *>(ir)->print(out);
+ if (llvm::any_isa<Function>(ir)) {
+ llvm::any_cast<Function>(ir).print(out);
return;
}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 2f605b6690b..27ec74c23c2 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -46,8 +46,7 @@ static llvm::cl::opt<bool>
void Pass::anchor() {}
/// Forwarding function to execute this pass.
-LogicalResult FunctionPassBase::run(Function *fn,
- FunctionAnalysisManager &fam) {
+LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
// Initialize the pass state.
passState.emplace(fn, fam);
@@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
}
/// Run all of the passes in this manager over the current function.
-LogicalResult detail::FunctionPassExecutor::run(Function *function,
+LogicalResult detail::FunctionPassExecutor::run(Function function,
FunctionAnalysisManager &fam) {
// Run each of the held passes.
for (auto &pass : passes)
@@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module,
/// Utility to run the given function and analysis manager on a provided
/// function pass executor.
static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
- Function *func,
+ Function func,
FunctionAnalysisManager &fam) {
// Run the function pipeline over the provided function.
auto result = fpe.run(func, fam);
@@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
/// module.
void ModuleToFunctionPassAdaptor::runOnModule() {
ModuleAnalysisManager &mam = getAnalysisManager();
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
// Skip external functions.
if (func.isExternal())
continue;
// Run the held function pipeline over the current function.
- auto fam = mam.slice(&func);
- if (failed(runFunctionPipeline(fpe, &func, fam)))
+ auto fam = mam.slice(func);
+ if (failed(runFunctionPipeline(fpe, func, fam)))
return signalPassFailure();
// Clear out any computed function analyses. These analyses won't be used
@@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
// Run a prepass over the module to collect the functions to execute a over.
// This ensures that an analysis manager exists for each function, as well as
// providing a queue of functions to execute over.
- std::vector<std::pair<Function *, FunctionAnalysisManager>> funcAMPairs;
- for (auto &func : getModule())
+ std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs;
+ for (auto func : getModule())
if (!func.isExternal())
- funcAMPairs.emplace_back(&func, mam.slice(&func));
+ funcAMPairs.emplace_back(func, mam.slice(func));
// A parallel diagnostic handler that provides deterministic diagnostic
// ordering.
@@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
}
/// Create an analysis slice for the given child function.
-FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) {
- assert(func->getModule() == moduleAnalyses.getIRUnit() &&
+FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) {
+ assert(func.getModule() == moduleAnalyses.getIRUnit() &&
"function has a different parent module");
auto it = functionAnalyses.find(func);
if (it == functionAnalyses.end()) {
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 46addfb8e9c..d2563fb62cd 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -48,7 +48,7 @@ public:
FunctionPassExecutor(const FunctionPassExecutor &rhs);
/// Run the executor on the given function.
- LogicalResult run(Function *function, FunctionAnalysisManager &fam);
+ LogicalResult run(Function function, FunctionAnalysisManager &fam);
/// Add a pass to the current executor. This takes ownership over the provided
/// pass pointer.
diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
index 375a64d8f2d..3f26bf075af 100644
--- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
@@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() {
void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
- auto &func = getFunction();
+ auto func = getFunction();
// Insert stats for each argument.
for (auto *arg : func.getArguments()) {
diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
index dec4ea90db8..169fec3b39a 100644
--- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
@@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() {
void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
CAGSlice cag(solverContext);
- for (auto &f : getModule()) {
+ for (auto f : getModule()) {
f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
}
config.finalizeAnchors(cag);
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
index ed3b0956a16..6b376db8516 100644
--- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -58,7 +58,7 @@ public:
void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
index 3add211fdd5..543b7300af0 100644
--- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
+++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
@@ -36,11 +36,11 @@ using namespace mlir;
// block. The created block will be terminated by `std.return`.
Block *createOneBlockFunction(Builder builder, Module *module) {
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
- auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType);
- module->getFunctions().push_back(fn);
+ auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
+ module->push_back(fn);
auto *block = new Block();
- fn->push_back(block);
+ fn.push_back(block);
OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName());
ReturnOp::build(&builder, &state);
diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
index ebdcaf73717..33572d5adbe 100644
--- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
+++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
@@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
// to take in the SPIR-V ModuleOp directly after module and function are
// migrated to be general ops.
- for (auto &fn : *module) {
+ for (auto fn : *module) {
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
if (done) {
spirvModule.emitError("found more than one 'spv.module' op");
diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
index 1a8d79c1790..1ce2b69f055 100644
--- a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
+++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
@@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass
void StdOpsToSPIRVConversionPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
populateWithGenerated(func.getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index 6d5073f1c37..9fc216eef25 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
- auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
- auto fnType = fn->getType();
+ auto fnType = fn.getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
@@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
- auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
// Check that the referenced function has the correct type.
- if (fn->getType() != type)
+ if (fn.getType() != type)
return op.emitOpError("reference to function with mismatched type");
return success();
@@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
}
static LogicalResult verify(ReturnOp op) {
- auto *function = op.getOperation()->getFunction();
+ auto function = op.getOperation()->getFunction();
// The operand number and types must match the function signature.
- const auto &results = function->getType().getResults();
+ const auto &results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError("has ")
<< op.getNumOperands()
diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
index 74ade942fc7..1e8409246ef 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -69,7 +69,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
- for (Function &func : m) {
+ for (Function func : m) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;
@@ -89,20 +89,21 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
return llvmModule;
}
-static TranslateFromMLIRRegistration registration(
- "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) {
- if (!module)
- return true;
+static TranslateFromMLIRRegistration
+ registration("mlir-to-nvvmir",
+ [](Module *module, llvm::StringRef outputFilename) {
+ if (!module)
+ return true;
- auto llvmModule = mlir::translateModuleToNVVMIR(*module);
- if (!llvmModule)
- return true;
+ auto llvmModule = mlir::translateModuleToNVVMIR(*module);
+ if (!llvmModule)
+ return true;
- auto file = openOutputFile(outputFilename);
- if (!file)
- return true;
+ auto file = openOutputFile(outputFilename);
+ if (!file)
+ return true;
- llvmModule->print(file->os(), nullptr);
- file->keep();
- return false;
- });
+ llvmModule->print(file->os(), nullptr);
+ file->keep();
+ return false;
+ });
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ef286cb64fd..4a68ac71ee0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
bool ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
- for (Function &function : mlirModule) {
+ for (Function function : mlirModule) {
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
@@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() {
}
// Convert functions.
- for (Function &function : mlirModule) {
+ for (Function function : mlirModule) {
// Ignore external functions.
if (function.isExternal())
continue;
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 8a2002ce368..394b3ef8db5 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> {
void Canonicalizer::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index be60ada6a43..84f00b97e38 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -849,7 +849,7 @@ struct FunctionConverter {
/// error, success otherwise. If 'signatureConversion' is provided, the
/// arguments of the entry block are updated accordingly.
LogicalResult
- convertFunction(Function *f,
+ convertFunction(Function f,
TypeConverter::SignatureConversion *signatureConversion);
/// Converts the given region starting from the entry block and following the
@@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
}
LogicalResult FunctionConverter::convertFunction(
- Function *f, TypeConverter::SignatureConversion *signatureConversion) {
+ Function f, TypeConverter::SignatureConversion *signatureConversion) {
// If this is an external function, there is nothing else to do.
- if (f->isExternal())
+ if (f.isExternal())
return success();
- DialectConversionRewriter rewriter(f->getBody(), typeConverter);
+ DialectConversionRewriter rewriter(f.getBody(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
rewriter.argConverter.convertSignature(
- &f->getBody().front(), *signatureConversion, rewriter.mapping);
+ &f.getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
if (failed(
- convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) {
+ convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
// Reset any of the generated rewrites.
rewriter.discardRewrites();
return failure();
@@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const
// applyConversionPatterns
//===----------------------------------------------------------------------===//
-namespace {
-/// This class represents a function to be converted. It allows for converting
-/// the body of functions and the signature in two phases.
-struct ConvertedFunction {
- ConvertedFunction(Function *fn, FunctionType newType,
- ArrayRef<NamedAttributeList> newFunctionArgAttrs)
- : fn(fn), newType(newType),
- newFunctionArgAttrs(newFunctionArgAttrs.begin(),
- newFunctionArgAttrs.end()) {}
-
- /// The function to convert.
- Function *fn;
- /// The new type and argument attributes for the function.
- FunctionType newType;
- SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-};
-} // end anonymous namespace
-
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
@@ -1149,37 +1131,33 @@ LogicalResult
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
- std::vector<Function *> allFunctions;
- allFunctions.reserve(module.getFunctions().size());
- for (auto &func : module)
- allFunctions.push_back(&func);
+ SmallVector<Function, 32> allFunctions(module.getFunctions());
return applyConversionPatterns(allFunctions, target, converter,
std::move(patterns));
}
/// Convert the given functions with the provided conversion patterns.
LogicalResult mlir::applyConversionPatterns(
- ArrayRef<Function *> fns, ConversionTarget &target,
+ MutableArrayRef<Function> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns) {
if (fns.empty())
return success();
// Build the function converter.
- FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
- &converter);
+ auto *ctx = fns.front().getContext();
+ FunctionConverter funcConverter(ctx, target, patterns, &converter);
// Try to convert each of the functions within the module.
- auto *ctx = fns.front()->getContext();
- for (auto *func : fns) {
+ for (auto func : fns) {
// Convert the function type using the type converter.
auto conversion =
- converter.convertSignature(func->getType(), func->getAllArgAttrs());
+ converter.convertSignature(func.getType(), func.getAllArgAttrs());
if (!conversion)
return failure();
// Update the function signature.
- func->setType(conversion->getConvertedType(ctx));
- func->setAllArgAttrs(conversion->getConvertedArgAttrs());
+ func.setType(conversion->getConvertedType(ctx));
+ func.setAllArgAttrs(conversion->getConvertedArgAttrs());
// Convert the body of this function.
if (failed(funcConverter.convertFunction(func, &*conversion)))
@@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns(
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
-mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
+mlir::applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(fn.getContext(), target, patterns);
- return converter.convertFunction(&fn, /*signatureConversion=*/nullptr);
+ return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
}
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 5a926ceaa92..a3aa092b0ec 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block &block) {
auto *op = block.getContainingOp();
- return op ? op->emitRemark() : block.getFunction()->emitRemark();
+ return op ? op->emitRemark() : block.getFunction().emitRemark();
}
/// Creates a buffer in the faster memory space for the specified region;
@@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
OpBuilder &b = region.isWrite() ? epilogue : prologue;
// Builder to create constants at the top level.
- auto *func = block->getFunction();
- OpBuilder top(func->getBody());
+ auto func = block->getFunction();
+ OpBuilder top(func.getBody());
auto loc = region.loc;
auto *memref = region.memref;
@@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
if (auto *op = block->getContainingOp())
op->emitError(str);
else
- block->getFunction()->emitError(str);
+ block->getFunction().emitError(str);
}
return totalDmaBuffersSizeInBytes;
}
void DmaGeneration::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
OpBuilder topBuilder(f.getBody());
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 8d2e75b2dca..77b944f3e01 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -257,7 +257,7 @@ public:
// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
- bool init(Function &f);
+ bool init(Function f);
// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@@ -637,7 +637,7 @@ public:
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
-bool MemRefDependenceGraph::init(Function &f) {
+bool MemRefDependenceGraph::init(Function f) {
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
@@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getFunction()->getBody());
+ OpBuilder top(forInst->getFunction().getBody());
// Create new memref type based on slice bounds.
auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
@@ -1750,9 +1750,9 @@ public:
};
// Search for siblings which load the same memref function argument.
- auto *fn = dstNode->op->getFunction();
- for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
- for (auto *user : fn->getArgument(i)->getUsers()) {
+ auto fn = dstNode->op->getFunction();
+ for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+ for (auto *user : fn.getArgument(i)->getUsers()) {
if (auto loadOp = dyn_cast<LoadOp>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index c1be6e8f6b1..2744e5ca05c 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
-static void getTileableBands(Function &f,
+static void getTileableBands(Function f,
std::vector<SmallVector<AffineForOp, 6>> *bands) {
// Get maximal perfect nest of 'affine.for' insts starting from root
// (inclusive).
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 05953926376..6f13f623fe8 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() {
// Store innermost loops as we walk.
std::vector<AffineForOp> loops;
- void walkPostOrder(Function *f) {
- for (auto &b : *f)
+ void walkPostOrder(Function f) {
+ for (auto &b : f)
walkPostOrder(b.begin(), b.end());
}
@@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() {
? clUnrollNumRepetitions
: 1;
// If the call back is provided, we will recurse until no loops are found.
- Function &func = getFunction();
+ Function func = getFunction();
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
InnermostLoopGatherer ilg;
- ilg.walkPostOrder(&func);
+ ilg.walkPostOrder(func);
auto &loops = ilg.loops;
if (loops.empty())
break;
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 77a23b156b0..df30e270fe6 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -726,7 +726,7 @@ public:
} // end namespace
-LogicalResult mlir::lowerAffineConstructs(Function &function) {
+LogicalResult mlir::lowerAffineConstructs(Function function) {
OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index cd92198ac04..f59f1006ec5 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
+ LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state,
/// because we currently disallow vectorization of defs that come from another
/// scope.
/// TODO(ntv): please document return value.
-static bool materialize(Function *f, const SetVector<Operation *> &terminators,
+static bool materialize(Function f, const SetVector<Operation *> &terminators,
MaterializationState *state) {
DenseSet<Operation *> seen;
DominanceInfo domInfo(f);
@@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
return true;
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
}
return false;
}
@@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() {
NestedPatternContext mlContext;
// TODO(ntv): Check to see if this supports arbitrary top-level code.
- Function *f = &getFunction();
- if (f->getBlocks().size() != 1)
+ Function f = getFunction();
+ if (f.getBlocks().size() != 1)
return;
using matcher::Op;
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
MaterializationState state(hwVectorSize);
// Get the hardware vector type.
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index c5676afaf63..1208e2fdd15 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
void MemRefDataFlowOpt::runOnFunction() {
// Only supports single block functions at the moment.
- Function &f = getFunction();
+ Function f = getFunction();
if (f.getBlocks().size() != 1) {
markAllAnalysesPreserved();
return;
diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp
index f97f549c93e..c7c3621781a 100644
--- a/mlir/lib/Transforms/StripDebugInfo.cpp
+++ b/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
} // end anonymous namespace
void StripDebugInfo::runOnFunction() {
- Function &func = getFunction();
+ Function func = getFunction();
auto unknownLoc = UnknownLoc::get(&getContext());
// Strip the debug info from the function and its operations.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 47ca378f324..e185f702d27 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -44,7 +44,7 @@ namespace {
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
- explicit GreedyPatternRewriteDriver(Function &fn,
+ explicit GreedyPatternRewriteDriver(Function fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn.getBody()), matcher(std::move(patterns)) {
worklist.reserve(64);
@@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
-bool mlir::applyPatternsGreedily(Function &fn,
+bool mlir::applyPatternsGreedily(Function fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
bool converged = driver.simplifyFunction(maxPatternMatchIterations);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 728123f71a5..4ddf93c2232 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
- OpBuilder topBuilder(op->getFunction()->getBody());
+ OpBuilder topBuilder(op->getFunction().getBody());
auto constOp = topBuilder.create<ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv->replaceAllUsesWith(constOp);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 39a05d8c300..3fca26bef19 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
void Vectorize::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
if (!fastestVaryingPattern.empty() &&
fastestVaryingPattern.size() != vectorSizes.size()) {
f.emitRemark("Fastest varying pattern specified with different size than "
@@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() {
unsigned patternDepth = pat.getDepth();
SmallVector<NestedMatch, 8> matches;
- pat.match(&f, &matches);
+ pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {
diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp
index 1f2ab69409e..3c1a1b3b481 100644
--- a/mlir/lib/Transforms/ViewFunctionGraph.cpp
+++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp
@@ -53,13 +53,13 @@ std::string DOTGraphTraits<Function *>::getNodeLabel(Block *Block, Function *) {
} // end namespace llvm
-void mlir::viewGraph(Function &function, const llvm::Twine &name,
+void mlir::viewGraph(Function function, const llvm::Twine &name,
bool shortNames, const llvm::Twine &title,
llvm::GraphProgram::Name program) {
llvm::ViewGraph(&function, name, shortNames, title, program);
}
-llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
+llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames, const llvm::Twine &title) {
return llvm::WriteGraph(os, &function, shortNames, title);
}
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 834e7c98228..a88312dba9b 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -43,13 +43,12 @@ static MLIRContext &globalContext() {
return context;
}
-static std::unique_ptr<Function> makeFunction(StringRef name,
- ArrayRef<Type> results = {},
- ArrayRef<Type> args = {}) {
+static Function makeFunction(StringRef name, ArrayRef<Type> results = {},
+ ArrayRef<Type> args = {}) {
auto &ctx = globalContext();
- auto function = llvm::make_unique<Function>(
- UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx));
- function->addEntryBlock();
+ auto function = Function::create(UnknownLoc::get(&ctx), name,
+ FunctionType::get(args, results, &ctx));
+ function.addEntryBlock();
return function;
}
@@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) {
auto f =
makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
- ub(f->getArgument(1));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
+ ub(f.getArgument(1));
ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
ValueHandle i7(constant_int(7, 32));
@@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) {
// CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32
// CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_dynamic_for) {
@@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) {
auto f = makeFunction("builder_dynamic_for", {},
{indexType, indexType, indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
- c(f->getArgument(2)), d(f->getArgument(3));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
+ c(f.getArgument(2)), d(f.getArgument(3));
LoopBuilder(&i, a - b, c + d, 2)();
// clang-format off
@@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) {
// CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3]
// CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 {
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_max_min_for) {
@@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) {
auto f = makeFunction("builder_max_min_for", {},
{indexType, indexType, indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
- ub1(f->getArgument(2)), ub2(f->getArgument(3));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
+ ub1(f.getArgument(2)), ub2(f.getArgument(3));
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
ret();
@@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) {
// CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) {
// CHECK: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_blocks) {
@@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks");
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
arg4(c1.getType()), r(c1.getType());
- BlockHandle b1, b2, functionBlock(&f->front());
+ BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1, &arg2})(
// b2 has not yet been constructed, need to come back later.
// This is a byproduct of non-structured control-flow.
@@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_blocks_eager) {
@@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks_eager");
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_cond_branch) {
@@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) {
auto f = makeFunction("builder_cond_branch", {},
{IntegerType::get(1, &globalContext())});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle funcArg(f->getArgument(0));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
- BlockHandle b1, b2, functionBlock(&f->front());
+ BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1})([&] { ret(); });
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
// Get back to entry block and add a conditional branch
@@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_cond_branch_eager) {
@@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) {
auto f = makeFunction("builder_cond_branch_eager", {},
{IntegerType::get(1, &globalContext())});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle funcArg(f->getArgument(0));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
@@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_helpers) {
@@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) {
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)),
+ vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
int64_t step0, step1, step2;
std::tie(lb0, ub0, step0) = vA.range(0);
@@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) {
// CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32
// CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(custom_ops) {
@@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
@@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) {
// clang-format off
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
OperationHandle ih0, ih2;
- IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
+ IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1));
IndexHandle ten(index_t(10)), twenty(index_t(20));
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
@@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) {
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(insertion_in_block) {
@@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
BlockHandle b1;
// clang-format off
ValueHandle::create<ConstantIntOp>(0, 32);
@@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) {
// CHECK: ^bb1: // no predecessors
// CHECK: {{.*}} = constant 1 : i32
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(select_op) {
@@ -438,12 +447,12 @@ TEST_FUNC(select_op) {
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle zero = constant_index(0), one = constant_index(1);
- MemRefView vA(f->getArgument(0));
- IndexedValue A(f->getArgument(0));
+ MemRefView vA(f.getArgument(0));
+ IndexedValue A(f.getArgument(0));
IndexHandle i, j;
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
// This test exercises IndexedValue::operator Value*.
@@ -461,7 +470,8 @@ TEST_FUNC(select_op) {
// CHECK-DAG: {{.*}} = load
// CHECK-NEXT: {{.*}} = select
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
@@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) {
MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
// clang-format off
@@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) {
// CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32
// CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref<?x?x?xf32>
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
// Inject an EDSC-constructed computation to exercise 2-d vectorization.
@@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) {
auto owningF =
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
- mlir::Function *f = owningF.release();
+ mlir::Function f = owningF;
mlir::Module module(&globalContext());
- module.getFunctions().push_back(f);
+ module.push_back(f);
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2));
// clang-format off
@@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) {
pm.addPass(mlir::createCanonicalizerPass());
SmallVector<int64_t, 2> vectorSizes{4, 4};
pm.addPass(mlir::createVectorizePass(vectorSizes));
- auto result = pm.run(f->getModule());
+ auto result = pm.run(f.getModule());
if (succeeded(result))
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
int main() {
diff --git a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
index 4767e3367be..7bfb5564064 100644
--- a/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
@@ -97,12 +97,12 @@ struct VectorizerTestPass : public FunctionPass<VectorizerTestPass> {
} // end anonymous namespace
void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
- auto *f = &getFunction();
+ auto f = getFunction();
using matcher::Op;
SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
clTestVectorShapeRatio.end());
auto subVectorType =
- VectorType::get(shape, FloatType::getF32(f->getContext()));
+ VectorType::get(shape, FloatType::getF32(f.getContext()));
// Only filter operations that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [&](Operation &op) {
@@ -148,7 +148,7 @@ static NestedPattern patternTestSlicingOps() {
}
void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
- auto *f = &getFunction();
+ auto f = getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
@@ -163,7 +163,7 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
}
void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
- auto *f = &getFunction();
+ auto f = getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
@@ -177,7 +177,7 @@ void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
}
void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) {
- auto *f = &getFunction();
+ auto f = getFunction();
SmallVector<NestedMatch, 8> matches;
patternTestSlicingOps().match(f, &matches);
@@ -195,7 +195,7 @@ static bool customOpWithAffineMapAttribute(Operation &op) {
}
void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
- auto *f = &getFunction();
+ auto f = getFunction();
using matcher::Op;
auto pattern = Op(customOpWithAffineMapAttribute);
@@ -227,7 +227,7 @@ static bool singleResultAffineApplyOpWithoutUses(Operation &op) {
void VectorizerTestPass::testNormalizeMaps() {
using matcher::Op;
- auto *f = &getFunction();
+ auto f = getFunction();
// Save matched AffineApplyOp that all need to be erased in the end.
auto pattern = Op(affineApplyOp);
@@ -256,7 +256,7 @@ void VectorizerTestPass::runOnFunction() {
NestedPatternContext mlContext;
// Only support single block functions at this point.
- Function &f = getFunction();
+ Function f = getFunction();
if (f.getBlocks().size() != 1)
return;
diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
index 54a9c6ce95c..1ac6c402630 100644
--- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
+++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
@@ -163,8 +163,8 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
static Error compileAndExecuteFunctionWithMemRefs(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
- Function *mainFunction = module->getNamedFunction(entryPoint);
- if (!mainFunction || mainFunction->getBlocks().empty()) {
+ Function mainFunction = module->getNamedFunction(entryPoint);
+ if (!mainFunction || mainFunction.getBlocks().empty()) {
return make_string_error("entry point not found");
}
@@ -172,9 +172,9 @@ static Error compileAndExecuteFunctionWithMemRefs(
// pretty print the results, because the function itself will be rewritten
// to use the LLVM dialect.
SmallVector<Type, 8> argTypes =
- llvm::to_vector<8>(mainFunction->getType().getInputs());
+ llvm::to_vector<8>(mainFunction.getType().getInputs());
SmallVector<Type, 8> resTypes =
- llvm::to_vector<8>(mainFunction->getType().getResults());
+ llvm::to_vector<8>(mainFunction.getType().getResults());
float init = std::stof(initValue.getValue());
@@ -206,18 +206,18 @@ static Error compileAndExecuteFunctionWithMemRefs(
static Error compileAndExecuteSingleFloatReturnFunction(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
- Function *mainFunction = module->getNamedFunction(entryPoint);
- if (!mainFunction || mainFunction->isExternal()) {
+ Function mainFunction = module->getNamedFunction(entryPoint);
+ if (!mainFunction || mainFunction.isExternal()) {
return make_string_error("entry point not found");
}
- if (!mainFunction->getType().getInputs().empty())
+ if (!mainFunction.getType().getInputs().empty())
return make_string_error("function inputs not supported");
- if (mainFunction->getType().getResults().size() != 1)
+ if (mainFunction.getType().getResults().size() != 1)
return make_string_error("only single f32 function result supported");
- auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
+ auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
if (!t)
return make_string_error("only single llvm.f32 function result supported");
auto *llvmTy = t.getUnderlyingType();
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index 38a059b3ba5..d2a82374124 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -25,11 +25,11 @@ using namespace mlir::detail;
namespace {
/// Minimal class definitions for two analyses.
struct MyAnalysis {
- MyAnalysis(Function *) {}
+ MyAnalysis(Function) {}
MyAnalysis(Module *) {}
};
struct OtherAnalysis {
- OtherAnalysis(Function *) {}
+ OtherAnalysis(Function) {}
OtherAnalysis(Module *) {}
};
@@ -59,10 +59,10 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
// Create a function and a module.
std::unique_ptr<Module> module(new Module(&context));
- Function *func1 =
- new Function(builder.getUnknownLoc(), "foo",
- builder.getFunctionType(llvm::None, llvm::None));
- module->getFunctions().push_back(func1);
+ Function func1 =
+ Function::create(builder.getUnknownLoc(), "foo",
+ builder.getFunctionType(llvm::None, llvm::None));
+ module->push_back(func1);
// Test fine grain invalidation of the function analysis manager.
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
@@ -87,10 +87,10 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
// Create a function and a module.
std::unique_ptr<Module> module(new Module(&context));
- Function *func1 =
- new Function(builder.getUnknownLoc(), "foo",
- builder.getFunctionType(llvm::None, llvm::None));
- module->getFunctions().push_back(func1);
+ Function func1 =
+ Function::create(builder.getUnknownLoc(), "foo",
+ builder.getFunctionType(llvm::None, llvm::None));
+ module->push_back(func1);
// Test fine grain invalidation of a function analysis from within a module
// analysis manager.
OpenPOWER on IntegriCloud