summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-07-01 10:29:09 -0700
committerjpienaar <jpienaar@google.com>2019-07-01 11:39:00 -0700
commit54cd6a7e97a226738e2c85b86559918dd9e3cd5d (patch)
treeaffa803347d6695be575137d1ad55a055a8021e3 /mlir/lib
parent84bd67fc4fd116e80f7a66bfadfe9a7fd6fd5e82 (diff)
downloadbcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.tar.gz
bcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.zip
NFC: Refactor Function to be value typed.
Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed). PiperOrigin-RevId: 255983022
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp2
-rw-r--r--mlir/lib/Analysis/Dominance.cpp9
-rw-r--r--mlir/lib/Analysis/OpStats.cpp2
-rw-r--r--mlir/lib/Analysis/TestParallelismDetection.cpp2
-rw-r--r--mlir/lib/Analysis/Verifier.cpp14
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp6
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp106
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp26
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp18
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp4
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp2
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/MemRefUtils.cpp10
-rw-r--r--mlir/lib/GPU/IR/GPUDialect.cpp18
-rw-r--r--mlir/lib/GPU/Transforms/KernelOutlining.cpp28
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp50
-rw-r--r--mlir/lib/IR/Attributes.cpp5
-rw-r--r--mlir/lib/IR/Block.cpp2
-rw-r--r--mlir/lib/IR/Builders.cpp4
-rw-r--r--mlir/lib/IR/Dialect.cpp15
-rw-r--r--mlir/lib/IR/Function.cpp59
-rw-r--r--mlir/lib/IR/Operation.cpp9
-rw-r--r--mlir/lib/IR/Region.cpp10
-rw-r--r--mlir/lib/IR/SymbolTable.cpp22
-rw-r--r--mlir/lib/IR/Value.cpp8
-rw-r--r--mlir/lib/LLVMIR/IR/LLVMDialect.cpp4
-rw-r--r--mlir/lib/Linalg/Transforms/Fusion.cpp2
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp69
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLoops.cpp3
-rw-r--r--mlir/lib/Linalg/Transforms/Tiling.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp28
-rw-r--r--mlir/lib/Pass/IRPrinting.cpp12
-rw-r--r--mlir/lib/Pass/Pass.cpp23
-rw-r--r--mlir/lib/Pass/PassDetail.h2
-rw-r--r--mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp2
-rw-r--r--mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp2
-rw-r--r--mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp2
-rw-r--r--mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp6
-rw-r--r--mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp2
-rw-r--r--mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp2
-rw-r--r--mlir/lib/StandardOps/Ops.cpp12
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp31
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp4
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp2
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp54
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp10
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp12
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp2
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp8
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp2
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp12
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp2
-rw-r--r--mlir/lib/Transforms/StripDebugInfo.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp4
-rw-r--r--mlir/lib/Transforms/ViewFunctionGraph.cpp4
57 files changed, 378 insertions, 383 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 016ef43a84a..d7650dcb127 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
if (inserted) {
reorderedDims.push_back(v);
}
- return getAffineDimExpr(iterPos->second, v->getFunction()->getContext())
+ return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
.cast<AffineDimExpr>();
}
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index 954a01b4843..b4cdeb7d886 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase<Block>;
/// Recalculate the dominance info.
template <bool IsPostDom>
-void DominanceInfoBase<IsPostDom>::recalculate(Function *function) {
+void DominanceInfoBase<IsPostDom>::recalculate(Function function) {
dominanceInfos.clear();
// Build the top level function dominance.
auto functionDominance = llvm::make_unique<base>();
- functionDominance->recalculate(function->getBody());
- dominanceInfos.try_emplace(&function->getBody(),
- std::move(functionDominance));
+ functionDominance->recalculate(function.getBody());
+ dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance));
/// Build the dominance for each of the operation regions.
- function->walk([&](Operation *op) {
+ function.walk([&](Operation *op) {
for (auto &region : op->getRegions()) {
// Don't compute dominance if the region is empty.
if (region.empty())
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index 5177afcee67..75a2fc1a5dc 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() {
opCount.clear();
// Compute the operation statistics for each function in the module.
- for (auto &fn : getModule())
+ for (auto fn : getModule())
fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
printSummary();
}
diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp
index cbda6d40224..473d253cfa2 100644
--- a/mlir/lib/Analysis/TestParallelismDetection.cpp
+++ b/mlir/lib/Analysis/TestParallelismDetection.cpp
@@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() {
// Walks the function and emits a note for all 'affine.for' ops detected as
// parallel.
void TestParallelismDetection::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
OpBuilder b(f.getBody());
f.walk<AffineForOp>([&](AffineForOp forOp) {
if (isLoopParallel(forOp))
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 1330fe0fb94..0d0525145ef 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -53,7 +53,7 @@ public:
: ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
/// Verify the body of the given function.
- LogicalResult verify(Function &fn);
+ LogicalResult verify(Function fn);
/// Verify the given operation.
LogicalResult verify(Operation &op);
@@ -104,7 +104,7 @@ private:
} // end anonymous namespace
/// Verify the body of the given function.
-LogicalResult OperationVerifier::verify(Function &fn) {
+LogicalResult OperationVerifier::verify(Function fn) {
// Verify the body first.
if (failed(verifyRegion(fn.getBody())))
return failure();
@@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) {
// check. We do this as a second pass since malformed CFG's can cause
// dominator analysis constructure to crash and we want the verifier to be
// resilient to malformed code.
- DominanceInfo theDomInfo(&fn);
+ DominanceInfo theDomInfo(fn);
domInfo = &theDomInfo;
if (failed(verifyDominance(fn.getBody())))
return failure();
@@ -313,7 +313,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
- if (failed(dialect->verifyFunctionAttribute(this, attr)))
+ if (failed(dialect->verifyFunctionAttribute(*this, attr)))
return failure();
}
@@ -331,7 +331,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
- if (failed(dialect->verifyFunctionArgAttribute(this, i, attr)))
+ if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr)))
return failure();
}
}
@@ -369,7 +369,7 @@ LogicalResult Operation::verify() {
LogicalResult Module::verify() {
// Check that all functions are uniquely named.
llvm::StringMap<Location> nameToOrigLoc;
- for (auto &fn : *this) {
+ for (auto fn : *this) {
auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
if (!it.second)
return fn.emitError()
@@ -379,7 +379,7 @@ LogicalResult Module::verify() {
}
// Check that each function is correct.
- for (auto &fn : *this)
+ for (auto fn : *this)
if (failed(fn.verify()))
return failure();
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
index 9d7aeeb6321..022d8c70cc6 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -64,8 +64,8 @@ public:
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
- for (auto &function : getModule()) {
- if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
+ for (auto function : getModule()) {
+ if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
continue;
}
if (failed(translateGpuKernelToCubinAnnotation(function)))
@@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
std::unique_ptr<Module> module(builder.createModule());
// TODO(herhut): Also handle called functions.
- module->getFunctions().push_back(function.clone());
+ module->push_back(function.clone());
auto llvmModule = translateModuleToNVVMIR(*module);
auto cubin = convertModuleToCubin(*llvmModule, function);
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index bd96f396b22..f9d5899456a 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -118,7 +118,7 @@ private:
void declareCudaFunctions(Location loc);
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
- Value *generateKernelNameConstant(Function *kernelFunction, Location &loc,
+ Value *generateKernelNameConstant(Function kernelFunction, Location &loc,
OpBuilder &builder);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
@@ -130,7 +130,7 @@ public:
// Cache the used LLVM types.
initializeCachedTypes();
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchFuncOp>(
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
}
@@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
Module &module = getModule();
Builder builder(&module);
if (!module.getNamedFunction(cuModuleLoadName)) {
- module.getFunctions().push_back(
- new Function(loc, cuModuleLoadName,
- builder.getFunctionType(
- {
- getPointerPointerType(), /* CUmodule *module */
- getPointerType() /* void *cubin */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuModuleLoadName,
+ builder.getFunctionType(
+ {
+ getPointerPointerType(), /* CUmodule *module */
+ getPointerType() /* void *cubin */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction.
- module.getFunctions().push_back(
- new Function(loc, cuModuleGetFunctionName,
- builder.getFunctionType(
- {
- getPointerPointerType(), /* void **function */
- getPointerType(), /* void *module */
- getPointerType() /* char *name */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuModuleGetFunctionName,
+ builder.getFunctionType(
+ {
+ getPointerPointerType(), /* void **function */
+ getPointerType(), /* void *module */
+ getPointerType() /* char *name */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuLaunchKernelName)) {
// Other than the CUDA api, the wrappers use uintptr_t to match the
// LLVM type if MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
// CUstream.
- module.getFunctions().push_back(
- new Function(loc, cuLaunchKernelName,
- builder.getFunctionType(
- {
- getPointerType(), /* void* f */
- getIntPtrType(), /* intptr_t gridXDim */
- getIntPtrType(), /* intptr_t gridyDim */
- getIntPtrType(), /* intptr_t gridZDim */
- getIntPtrType(), /* intptr_t blockXDim */
- getIntPtrType(), /* intptr_t blockYDim */
- getIntPtrType(), /* intptr_t blockZDim */
- getInt32Type(), /* unsigned int sharedMemBytes */
- getPointerType(), /* void *hstream */
- getPointerPointerType(), /* void **kernelParams */
- getPointerPointerType() /* void **extra */
- },
- getCUResultType())));
+ module.push_back(Function::create(
+ loc, cuLaunchKernelName,
+ builder.getFunctionType(
+ {
+ getPointerType(), /* void* f */
+ getIntPtrType(), /* intptr_t gridXDim */
+ getIntPtrType(), /* intptr_t gridyDim */
+ getIntPtrType(), /* intptr_t gridZDim */
+ getIntPtrType(), /* intptr_t blockXDim */
+ getIntPtrType(), /* intptr_t blockYDim */
+ getIntPtrType(), /* intptr_t blockZDim */
+ getInt32Type(), /* unsigned int sharedMemBytes */
+ getPointerType(), /* void *hstream */
+ getPointerPointerType(), /* void **kernelParams */
+ getPointerPointerType() /* void **extra */
+ },
+ getCUResultType())));
}
if (!module.getNamedFunction(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of
// CUDAs opaque CUstream.
- module.getFunctions().push_back(new Function(
+ module.push_back(Function::create(
loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */)));
}
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
- module.getFunctions().push_back(
- new Function(loc, cuStreamSynchronizeName,
- builder.getFunctionType(
- {
- getPointerType() /* CUstream stream */
- },
- getCUResultType())));
+ module.push_back(
+ Function::create(loc, cuStreamSynchronizeName,
+ builder.getFunctionType(
+ {
+ getPointerType() /* CUstream stream */
+ },
+ getCUResultType())));
}
}
@@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
// %0[n] = constant name[n]
// %0[n+1] = 0
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
- Function *kernelFunction, Location &loc, OpBuilder &builder) {
+ Function kernelFunction, Location &loc, OpBuilder &builder) {
// TODO(herhut): Make this a constant once this is supported.
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction->getName().size() + 1));
+ builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
auto kernelName =
builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
- for (auto byte : llvm::enumerate(kernelFunction->getName())) {
+ for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
@@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
// Add trailing zero to terminate string.
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction->getName().size()));
+ builder.getI32IntegerAttr(kernelFunction.getName().size()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
ArrayRef<Value *>{index});
auto value = builder.create<LLVM::ConstantOp>(
@@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
auto cubinGetter =
- kernelFunction->getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
+ kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
if (!cubinGetter) {
- kernelFunction->emitError("Missing ")
+ kernelFunction.emitError("Missing ")
<< kCubinGetterAnnotation << " attribute.";
return signalPassFailure();
}
@@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
- Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
+ Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)});
@@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc);
- Function *cuModuleGetFunction =
+ Function cuModuleGetFunction =
getModule().getNamedFunction(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
// Grab the global stream needed for execution.
- Function *cuGetStreamHelper =
+ Function cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
index c1d4af380ce..97790a5afce 100644
--- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
@@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc";
class GpuGenerateCubinAccessorsPass
: public ModulePass<GpuGenerateCubinAccessorsPass> {
private:
- Function *getMallocHelper(Location loc, Builder &builder) {
- Function *result = getModule().getNamedFunction(kMallocHelperName);
+ Function getMallocHelper(Location loc, Builder &builder) {
+ Function result = getModule().getNamedFunction(kMallocHelperName);
if (!result) {
- result = new Function(
+ result = Function::create(
loc, kMallocHelperName,
builder.getFunctionType(
ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
- getModule().getFunctions().push_back(result);
+ getModule().push_back(result);
}
return result;
}
@@ -70,18 +70,18 @@ private:
// data from blob. As there are currently no global constants, this uses a
// sequence of store operations.
// TODO(herhut): Use global constants instead.
- Function *generateCubinAccessor(Builder &builder, Function &orig,
- StringAttr blob) {
+ Function generateCubinAccessor(Builder &builder, Function &orig,
+ StringAttr blob) {
Location loc = orig.getLoc();
SmallString<128> nameBuffer(orig.getName());
nameBuffer.append(kCubinGetterSuffix);
// Generate a function that returns void*.
- Function *result = new Function(
+ Function result = Function::create(
loc, mlir::Identifier::get(nameBuffer, &getContext()),
builder.getFunctionType(ArrayRef<Type>{},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
// Insert a body block that just returns the constant.
- OpBuilder ob(result->getBody());
+ OpBuilder ob(result.getBody());
ob.createBlock();
auto sizeConstant = ob.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
@@ -115,18 +115,18 @@ public:
void runOnModule() override {
llvmDialect =
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- Builder builder(getModule().getContext());
+ auto &module = getModule();
+ Builder builder(&getContext());
- auto &functions = getModule().getFunctions();
+ auto functions = module.getFunctions();
for (auto it = functions.begin(); it != functions.end();) {
// Move iterator to after the current function so that potential insertion
// of the accessor is after the kernel with cubin iself.
- Function &orig = *it++;
+ Function orig = *it++;
StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation);
if (!cubinBlob)
continue;
- it =
- functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
+ module.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
}
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 872707842d7..e849f6fd023 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
createIndexConstant(rewriter, op->getLoc(), elementSize)});
// Insert the `malloc` declaration if it is not already present.
- Function *mallocFunc =
- op->getFunction()->getModule()->getNamedFunction("malloc");
+ Function mallocFunc =
+ op->getFunction().getModule()->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
- mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
- op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+ mallocFunc =
+ Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+ op->getFunction().getModule()->push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
@@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
- Function *freeFunc =
- op->getFunction()->getModule()->getNamedFunction("free");
+ Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
- freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
- op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+ freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
+ op->getFunction().getModule()->push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
@@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) {
}
void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
- for (auto &f : *m) {
+ for (auto f : *m) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index ff198217bb7..dafc8e711f5 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
- auto &fn = getFunction();
+ auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
@@ -386,7 +386,7 @@ static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
- auto &fn = getFunction();
+ auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 9dcc6df6bea..8469fa2ea70 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
applyPatternsGreedily(func, std::move(patterns));
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index ea8095b791c..0c93146a232 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -95,7 +95,7 @@ public:
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp
index 51636037382..f13b743de0c 100644
--- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp
+++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp
@@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
}
llvm::Expected<SmallVector<void *, 8>>
-mlir::allocateMemRefArguments(Function *func, float initialValue) {
+mlir::allocateMemRefArguments(Function func, float initialValue) {
SmallVector<void *, 8> args;
- args.reserve(func->getNumArguments());
- for (const auto &arg : func->getArguments()) {
+ args.reserve(func.getNumArguments());
+ for (const auto &arg : func.getArguments()) {
auto descriptor =
allocMemRefDescriptor(arg->getType(),
/*allocateData=*/true, initialValue);
@@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) {
args.push_back(*descriptor);
}
- if (func->getType().getNumResults() > 1)
+ if (func.getType().getNumResults() > 1)
return make_string_error("functions with more than 1 result not supported");
- for (Type resType : func->getType().getResults()) {
+ for (Type resType : func.getType().getResults()) {
auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
if (!descriptor)
return descriptor.takeError();
diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp
index e39860bddda..5e8090b42b4 100644
--- a/mlir/lib/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/GPU/IR/GPUDialect.cpp
@@ -30,9 +30,9 @@ using namespace mlir::gpu;
StringRef GPUDialect::getDialectName() { return "gpu"; }
-bool GPUDialect::isKernel(Function *function) {
+bool GPUDialect::isKernel(Function function) {
UnitAttr isKernelAttr =
- function->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
+ function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
@@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState *result,
- Function *kernelFunc, Value *gridSizeX,
+ Function kernelFunc, Value *gridSizeX,
Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
Value *blockSizeY, Value *blockSizeZ,
ArrayRef<Value *> kernelOperands) {
@@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
}
void LaunchFuncOp::build(Builder *builder, OperationState *result,
- Function *kernelFunc, KernelDim3 gridSize,
+ Function kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() {
return emitOpError("attribute 'kernel' must be a function");
}
- auto *module = getOperation()->getFunction()->getModule();
- Function *kernelFunc = module->getNamedFunction(kernel());
+ auto *module = getOperation()->getFunction().getModule();
+ Function kernelFunc = module->getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
- if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
+ if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName())) {
return emitError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
}
- unsigned numKernelFuncArgs = kernelFunc->getNumArguments();
+ unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
if (getNumKernelOperands() != numKernelFuncArgs) {
return emitOpError("got ")
<< getNumKernelOperands() << " kernel operands but expected "
<< numKernelFuncArgs;
}
- auto functionType = kernelFunc->getType();
+ auto functionType = kernelFunc.getType();
for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
return emitOpError("type of function argument ")
diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp
index 46363f06f72..f93febcf5da 100644
--- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp
@@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc,
// Add operations generating block/thread ids and gird/block dimensions at the
// beginning of `kernelFunc` and replace uses of the respective function args.
-static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
+static void injectGpuIndexOperations(Location loc, Function kernelFunc) {
OpBuilder OpBuilder(kernelFunc.getBody());
SmallVector<Value *, 12> indexOps;
createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
@@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
// Outline the `gpu.launch` operation body into a kernel function. Replace
// `gpu.return` operations by `std.return` in the generated functions.
-static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
+static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
Location loc = launchOp.getLoc();
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
- Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
- Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
- outlinedFunc->getBody().takeBody(launchOp.getBody());
+ Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
+ Function outlinedFunc = Function::create(loc, kernelFuncName, type);
+ outlinedFunc.getBody().takeBody(launchOp.getBody());
Builder builder(launchOp.getContext());
- outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
- builder.getUnitAttr());
- injectGpuIndexOperations(loc, *outlinedFunc);
- outlinedFunc->walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
+ outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+ injectGpuIndexOperations(loc, outlinedFunc);
+ outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
OpBuilder replacer(op);
replacer.create<ReturnOp>(op.getLoc());
op.erase();
@@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
// `kernelFunc`.
static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
- Function &kernelFunc) {
+ Function kernelFunc) {
OpBuilder builder(launchOp);
SmallVector<Value *, 4> kernelOperandValues(
launchOp.getKernelOperandValues());
builder.create<gpu::LaunchFuncOp>(
- launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(),
+ launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
launchOp.erase();
}
@@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
ModuleManager moduleManager(&getModule());
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
- Function *outlinedFunc = outlineKernelFunc(op);
+ Function outlinedFunc = outlineKernelFunc(op);
moduleManager.insert(outlinedFunc);
- convertToLaunchFuncOp(op, *outlinedFunc);
+ convertToLaunchFuncOp(op, outlinedFunc);
});
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8e3d5788bb1..346d35af231 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) {
initializeSymbolAliases();
// Walk the module and visit each operation.
- for (auto &fn : *module) {
+ for (auto fn : *module) {
visitType(fn.getType());
for (auto attr : fn.getAttrs())
ModuleState::visitAttribute(attr.second);
@@ -342,7 +342,7 @@ public:
void printAttribute(Attribute attr, bool mayElideType = false);
void printType(Type type);
- void print(Function *fn);
+ void print(Function fn);
void printLocation(LocationAttr loc);
void printAffineMap(AffineMap map);
@@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) {
state.printTypeAliases(os);
// Print the module.
- for (auto &fn : *module)
- print(&fn);
+ for (auto fn : *module)
+ print(fn);
}
/// Print a floating point value in a way that the parser will be able to
@@ -1186,7 +1186,7 @@ namespace {
// CFG and ML functions.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
- FunctionPrinter(Function *function, ModulePrinter &other);
+ FunctionPrinter(Function function, ModulePrinter &other);
// Prints the function as a whole.
void print();
@@ -1275,7 +1275,7 @@ protected:
void printValueID(Value *value, bool printResultNo = true) const;
private:
- Function *function;
+ Function function;
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
@@ -1305,10 +1305,10 @@ private:
};
} // end anonymous namespace
-FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other)
+FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other)
: ModulePrinter(other), function(function) {
- for (auto &block : *function)
+ for (auto &block : function)
numberValuesInBlock(block);
}
@@ -1419,17 +1419,17 @@ void FunctionPrinter::print() {
printFunctionSignature();
// Print out function attributes, if present.
- auto attrs = function->getAttrs();
+ auto attrs = function.getAttrs();
if (!attrs.empty()) {
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
// Print the trailing location.
- printTrailingLocation(function->getLoc());
+ printTrailingLocation(function.getLoc());
- if (!function->empty()) {
- printRegion(function->getBody(), /*printEntryBlockArgs=*/false,
+ if (!function.empty()) {
+ printRegion(function.getBody(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
os << "\n";
}
@@ -1437,24 +1437,24 @@ void FunctionPrinter::print() {
}
void FunctionPrinter::printFunctionSignature() {
- os << "func @" << function->getName() << '(';
+ os << "func @" << function.getName() << '(';
- auto fnType = function->getType();
- bool isExternal = function->isExternal();
- for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+ auto fnType = function.getType();
+ bool isExternal = function.isExternal();
+ for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
// If this is an external function, don't print argument labels.
if (!isExternal) {
- printOperand(function->getArgument(i));
+ printOperand(function.getArgument(i));
os << ": ";
}
printType(fnType.getInput(i));
// Print the attributes for this argument.
- printOptionalAttrDict(function->getArgAttrs(i));
+ printOptionalAttrDict(function.getArgAttrs(i));
}
os << ')';
@@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term,
}
// Prints function with initialized module state.
-void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
+void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); }
//===----------------------------------------------------------------------===//
// print and dump methods
@@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) {
void Value::dump() { print(llvm::errs()); }
void Operation::print(raw_ostream &os) {
- auto *function = getFunction();
+ auto function = getFunction();
if (!function) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
- ModuleState state(function->getContext());
+ ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@@ -1754,13 +1754,13 @@ void Operation::dump() {
}
void Block::print(raw_ostream &os) {
- auto *function = getFunction();
+ auto function = getFunction();
if (!function) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
- ModuleState state(function->getContext());
+ ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
- ModuleState state(getFunction()->getContext());
+ ModuleState state(getFunction().getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(getFunction(), modulePrinter).printBlockName(this);
}
void Function::print(raw_ostream &os) {
ModuleState state(getContext());
- ModulePrinter(os, state).print(this);
+ ModulePrinter(os, state).print(*this);
}
void Function::dump() { print(llvm::errs()); }
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 01f9a060bd9..9cbba0fe429 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// FunctionAttr
//===----------------------------------------------------------------------===//
-FunctionAttr FunctionAttr::get(Function *value) {
- assert(value && "Cannot get FunctionAttr for a null function");
- return get(value->getName(), value->getContext());
-}
-
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::Function, value,
NoneType::get(ctx));
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index e7616f6d7d0..134f6e468a0 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -50,7 +50,7 @@ Operation *Block::getContainingOp() {
return getParent() ? getParent()->getContainingOp() : nullptr;
}
-Function *Block::getFunction() {
+Function Block::getFunction() {
Block *block = this;
while (auto *op = block->getContainingOp()) {
block = op->getBlock();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9b30205abdb..89df64260d3 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
-FunctionAttr Builder::getFunctionAttr(Function *value) {
- return FunctionAttr::get(value);
+FunctionAttr Builder::getFunctionAttr(Function value) {
+ return getFunctionAttr(value.getName());
}
FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 4547452eb55..e38b95ff0f7 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
@@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context)
Dialect::~Dialect() {}
+/// Verify an attribute from this dialect on the given function. Returns
+/// failure if the verification failed, success otherwise.
+LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) {
+ return success();
+}
+
+/// Verify an attribute from this dialect on the argument at 'argIndex' for
+/// the given function. Returns failure if the verification failed, success
+/// otherwise.
+LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex,
+ NamedAttribute) {
+ return success();
+}
+
/// Parse an attribute registered to this dialect.
Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 7d17ed1d705..f8835f02c26 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -27,45 +27,50 @@
#include "llvm/ADT/Twine.h"
using namespace mlir;
+using namespace mlir::detail;
-Function::Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs)
+FunctionStorage::FunctionStorage(Location location, StringRef name,
+ FunctionType type,
+ ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
-Function::Function(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs,
- ArrayRef<NamedAttributeList> argAttrs)
+FunctionStorage::FunctionStorage(Location location, StringRef name,
+ FunctionType type,
+ ArrayRef<NamedAttribute> attrs,
+ ArrayRef<NamedAttributeList> argAttrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
MLIRContext *Function::getContext() { return getType().getContext(); }
-Module *llvm::ilist_traits<Function>::getContainingModule() {
+Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
- iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
+ iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
}
/// This is a trait method invoked when a Function is added to a Module. We
/// keep the module pointer and module symbol table up to date.
-void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
- assert(!function->getModule() && "already in a module!");
+void llvm::ilist_traits<FunctionStorage>::addNodeToList(
+ FunctionStorage *function) {
+ assert(!function->module && "already in a module!");
function->module = getContainingModule();
}
/// This is a trait method invoked when a Function is removed from a Module.
/// We keep the module pointer up to date.
-void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
+void llvm::ilist_traits<FunctionStorage>::removeNodeFromList(
+ FunctionStorage *function) {
assert(function->module && "not already in a module!");
function->module = nullptr;
}
/// This is a trait method invoked when an operation is moved from one block
/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<Function>::transferNodesFromList(
- ilist_traits<Function> &otherList, function_iterator first,
+void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
+ ilist_traits<FunctionStorage> &otherList, function_iterator first,
function_iterator last) {
// If we are transferring functions within the same module, the Module
// pointer doesn't need to be updated.
@@ -82,8 +87,10 @@ void llvm::ilist_traits<Function>::transferNodesFromList(
/// Unlink this function from its Module and delete it.
void Function::erase() {
- assert(getModule() && "Function has no parent");
- getModule()->getFunctions().erase(this);
+ if (auto *module = getModule())
+ getModule()->functions.erase(impl);
+ else
+ delete impl;
}
/// Emit an error about fatal conditions with this function, reporting up to
@@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) {
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
-void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
+void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) {
// Add the attributes of this function to dest.
llvm::MapVector<Identifier, Attribute> newAttrs;
- for (auto &attr : dest->getAttrs())
+ for (auto &attr : dest.getAttrs())
newAttrs.insert(attr);
for (auto &attr : getAttrs()) {
auto insertPair = newAttrs.insert(attr);
@@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
assert((insertPair.second || insertPair.first->second == attr.second) &&
"the two functions have incompatible attributes");
}
- dest->setAttrs(newAttrs.takeVector());
+ dest.setAttrs(newAttrs.takeVector());
// Clone the body.
- body.cloneInto(&dest->body, mapper);
+ impl->body.cloneInto(&dest.impl->body, mapper);
}
/// Create a deep copy of this function and all of its blocks, remapping
@@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
/// provided (leaving them alone if no entry is present). Replaces references
/// to cloned sub-values with the corresponding value that is copied, and adds
/// those mappings to the mapper.
-Function *Function::clone(BlockAndValueMapping &mapper) {
- FunctionType newType = type;
+Function Function::clone(BlockAndValueMapping &mapper) {
+ FunctionType newType = impl->type;
// If the function has a body, then the user might be deleting arguments to
// the function by specifying them in the mapper. If so, we don't add the
@@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) {
SmallVector<Type, 4> inputTypes;
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
- inputTypes.push_back(type.getInput(i));
- newType = FunctionType::get(inputTypes, type.getResults(), getContext());
+ inputTypes.push_back(newType.getInput(i));
+ newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
}
// Create the new function.
- Function *newFunc = new Function(getLoc(), getName(), newType);
+ Function newFunc = Function::create(getLoc(), getName(), newType);
/// Set the argument attributes for arguments that aren't being replaced.
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
if (isExternalFn || !mapper.contains(getArgument(i)))
- newFunc->setArgAttrs(destI++, getArgAttrs(i));
+ newFunc.setArgAttrs(destI++, getArgAttrs(i));
/// Clone the current function into the new one and return it.
cloneInto(newFunc, mapper);
return newFunc;
}
-Function *Function::clone() {
+Function Function::clone() {
BlockAndValueMapping mapper;
return clone(mapper);
}
@@ -178,7 +185,7 @@ void Function::addEntryBlock() {
assert(empty() && "function already has an entry block");
auto *entry = new Block();
push_back(entry);
- entry->addArguments(type.getInputs());
+ entry->addArguments(impl->type.getInputs());
}
void Function::walk(const std::function<void(Operation *)> &callback) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 83171f12d1d..f953cd27a56 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -281,7 +281,7 @@ Operation *Operation::getParentOp() {
return block ? block->getContainingOp() : nullptr;
}
-Function *Operation::getFunction() {
+Function Operation::getFunction() {
return block ? block->getFunction() : nullptr;
}
@@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands,
}
static LogicalResult verifyTerminatorSuccessors(Operation *op) {
+ auto *parent = op->getContainingRegion();
+
// Verify that the operands lines up with the BB arguments in the successor.
- Function *fn = op->getFunction();
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
auto *succ = op->getSuccessor(i);
- if (succ->getFunction() != fn)
- return op->emitError("reference to block defined in another function");
+ if (succ->getParent() != parent)
+ return op->emitError("reference to block defined in another region");
if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
return failure();
}
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 992d9112beb..74c71b7aeac 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -21,7 +21,7 @@
#include "mlir/IR/Operation.h"
using namespace mlir;
-Region::Region(Function *container) : container(container) {}
+Region::Region(Function container) : container(container.impl) {}
Region::Region(Operation *container) : container(container) {}
@@ -38,7 +38,7 @@ MLIRContext *Region::getContext() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getContext();
- return getContainingFunction()->getContext();
+ return getContainingFunction().getContext();
}
/// Return a location for this region. This is the location attached to the
@@ -47,7 +47,7 @@ Location Region::getLoc() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getLoc();
- return getContainingFunction()->getLoc();
+ return getContainingFunction().getLoc();
}
Region *Region::getContainingRegion() {
@@ -60,8 +60,8 @@ Operation *Region::getContainingOp() {
return container.dyn_cast<Operation *>();
}
-Function *Region::getContainingFunction() {
- return container.dyn_cast<Function *>();
+Function Region::getContainingFunction() {
+ return container.dyn_cast<detail::FunctionStorage *>();
}
bool Region::isProperAncestor(Region *other) {
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index a0819a78fc1..dafbd48f513 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -22,8 +22,8 @@ using namespace mlir;
/// Build a symbol table with the symbols within the given module.
SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
- for (auto &func : *module) {
- auto inserted = symbolTable.insert({func.getName(), &func});
+ for (auto func : *module) {
+ auto inserted = symbolTable.insert({func.getName(), func});
(void)inserted;
assert(inserted.second &&
"expected module to contain uniquely named functions");
@@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
-Function *SymbolTable::lookup(StringRef name) const {
+Function SymbolTable::lookup(StringRef name) const {
return lookup(Identifier::get(name, context));
}
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
-Function *SymbolTable::lookup(Identifier name) const {
+Function SymbolTable::lookup(Identifier name) const {
return symbolTable.lookup(name);
}
/// Erase the given symbol from the table.
-void SymbolTable::erase(Function *symbol) {
- auto it = symbolTable.find(symbol->getName());
+void SymbolTable::erase(Function symbol) {
+ auto it = symbolTable.find(symbol.getName());
if (it != symbolTable.end() && it->second == symbol)
symbolTable.erase(it);
}
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
-void SymbolTable::insert(Function *symbol) {
+void SymbolTable::insert(Function symbol) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
- if (symbolTable.insert({symbol->getName(), symbol}).second)
+ if (symbolTable.insert({symbol.getName(), symbol}).second)
return;
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(symbol->getName());
+ SmallString<128> nameBuffer(symbol.getName());
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used. We use a
@@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
- symbol->setName(Identifier::get(nameBuffer, context));
- } while (!symbolTable.insert({symbol->getName(), symbol}).second);
+ symbol.setName(Identifier::get(nameBuffer, context));
+ } while (!symbolTable.insert({symbol.getName(), symbol}).second);
}
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 073c3b369c6..65a98f7ee59 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() {
}
/// Return the function that this Value is defined in.
-Function *Value::getFunction() {
+Function Value::getFunction() {
switch (getKind()) {
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
@@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() {
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
-Function *BlockArgument::getFunction() {
+Function BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->getFunction();
return nullptr;
@@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() {
/// Returns if the current argument is a function argument.
bool BlockArgument::isFunctionArgument() {
- auto *containingFn = getFunction();
- return containingFn && &containingFn->front() == getOwner();
+ auto containingFn = getFunction();
+ return containingFn && &containingFn.front() == getOwner();
}
diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
index 0d3a5ca2756..0dbf63a3ce7 100644
--- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp
@@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const {
}
/// Verify LLVMIR function argument attributes.
-LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
+LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func,
unsigned argIdx,
NamedAttribute argAttr) {
// Check that llvm.noalias is a boolean attribute.
if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
- return func->emitError()
+ return func.emitError()
<< "llvm.noalias argument attribute of non boolean type";
return success();
}
diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp
index 7ddb7b0c19f..5761cc637b7 100644
--- a/mlir/lib/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Linalg/Transforms/Fusion.cpp
@@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
return true;
}
-static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
+static void fuseLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
DenseSet<Operation *> eraseSet;
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index a8099aaff99..5fe4f07613a 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -170,12 +170,13 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *mallocFunc = module->getNamedFunction("malloc");
+ auto *module = op->getFunction().getModule();
+ Function mallocFunc = module->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
- mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
- module->getFunctions().push_back(mallocFunc);
+ mallocFunc =
+ Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+ module->push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
@@ -230,12 +231,12 @@ public:
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *freeFunc = module->getNamedFunction("free");
+ auto *module = op->getFunction().getModule();
+ Function freeFunc = module->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
- freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
- module->getFunctions().push_back(freeFunc);
+ freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
+ module->push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
@@ -572,37 +573,37 @@ public:
// Create a function definition which takes as argument pointers to the input
// types and returns pointers to the output types.
-static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
- auto implFnName = (libFn->getName().str() + "_impl");
- auto module = libFn->getModule();
- if (auto *f = module->getNamedFunction(implFnName)) {
+static Function getLLVMLibraryCallImplDefinition(Function libFn) {
+ auto implFnName = (libFn.getName().str() + "_impl");
+ auto module = libFn.getModule();
+ if (auto f = module->getNamedFunction(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
- for (auto t : libFn->getType().getInputs()) {
+ for (auto t : libFn.getType().getInputs()) {
assert(t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
}
- auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
+ auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
// Insert the implementation function definition.
- auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
- module->getFunctions().push_back(implFnDefn);
+ auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
+ module->push_back(implFnDefn);
return implFnDefn;
}
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
-static Function *getLLVMLibraryCallDeclaration(Operation *op,
- LLVMTypeConverter &lowering,
- PatternRewriter &rewriter) {
+static Function getLLVMLibraryCallDeclaration(Operation *op,
+ LLVMTypeConverter &lowering,
+ PatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
- auto module = op->getFunction()->getModule();
- if (auto *f = module->getNamedFunction(fnName)) {
+ auto module = op->getFunction().getModule();
+ if (auto f = module->getNamedFunction(fnName)) {
return f;
}
@@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op,
"Library call for linalg operation can be generated only for ops that "
"have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
- auto libFn = new Function(op->getLoc(), fnName, libFnType);
- module->getFunctions().push_back(libFn);
+ auto libFn = Function::create(op->getLoc(), fnName, libFnType);
+ module->push_back(libFn);
// Return after creating the function definition. The body will be created
// later.
return libFn;
}
-static void getLLVMLibraryCallDefinition(Function *fn,
+static void getLLVMLibraryCallDefinition(Function fn,
LLVMTypeConverter &lowering) {
// Generate the implementation function definition.
auto implFn = getLLVMLibraryCallImplDefinition(fn);
// Generate the function body.
- fn->addEntryBlock();
+ fn.addEntryBlock();
- OpBuilder builder(fn->getBody());
- edsc::ScopedContext scope(builder, fn->getLoc());
+ OpBuilder builder(fn.getBody());
+ edsc::ScopedContext scope(builder, fn.getLoc());
SmallVector<Value *, 4> implFnArgs;
// Create a constant 1.
auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
- IntegerAttr::get(IndexType::get(fn->getContext()), 1));
- for (auto arg : fn->getArguments()) {
+ IntegerAttr::get(IndexType::get(fn.getContext()), 1));
+ for (auto arg : fn.getArguments()) {
// Allocate a stack for storing the argument value. The stack is passed to
// the implementation function.
auto alloca =
@@ -665,17 +666,17 @@ public:
return convertLinalgType(t, *this);
}
- void addLibraryFnDeclaration(Function *fn) {
+ void addLibraryFnDeclaration(Function fn) {
libraryFnDeclarations.push_back(fn);
}
- ArrayRef<Function *> getLibraryFnDeclarations() {
+ ArrayRef<Function> getLibraryFnDeclarations() {
return libraryFnDeclarations;
}
private:
/// List of library functions declarations needed during dialect conversion
- SmallVector<Function *, 2> libraryFnDeclarations;
+ SmallVector<Function, 2> libraryFnDeclarations;
};
} // end anonymous namespace
@@ -692,7 +693,7 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Only emit library call declaration. Fill in the body later.
- auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+ auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
@@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) {
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
- for (auto &f : module.getFunctions()) {
+ for (auto f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))
diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
index d31ba5bf22d..2e616c35f1d 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp
@@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
} // namespace
void LowerLinalgToLoopsPass::runOnFunction() {
- auto &f = getFunction();
OperationFolder state;
- f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
+ getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) {
emitLinalgOpAsLoops(linalgOp, state);
linalgOp.getOperation()->erase();
});
diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp
index c63e1cf197d..2f752b2b637 100644
--- a/mlir/lib/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
return tileLinalgOp(op, tileSizeValues, state);
}
-static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
+static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 44f05963727..4af2f093daf 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -254,7 +254,7 @@ public:
/// trailing-location ::= location?
///
template <typename Owner>
- ParseResult parseOptionalTrailingLocation(Owner *owner) {
+ ParseResult parseOptionalTrailingLocation(Owner &owner) {
// If there is a 'loc' we parse a trailing location.
if (!getToken().is(Token::kw_loc))
return success();
@@ -263,7 +263,7 @@ public:
LocationAttr directLoc;
if (parseLocation(directLoc))
return failure();
- owner->setLoc(directLoc);
+ owner.setLoc(directLoc);
return success();
}
@@ -2472,8 +2472,8 @@ namespace {
/// operations.
class OperationParser : public Parser {
public:
- OperationParser(ParserState &state, Function *function)
- : Parser(state), function(function), opBuilder(function->getBody()) {}
+ OperationParser(ParserState &state, Function function)
+ : Parser(state), function(function), opBuilder(function.getBody()) {}
~OperationParser();
@@ -2588,7 +2588,7 @@ public:
Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
private:
- Function *function;
+ Function function;
/// Returns the info for a block at the current scope for the given name.
std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
@@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() {
for (auto entry : forwardRefInCurrentScope) {
errors.push_back({entry.second.getPointer(), entry.first});
// Add this block to the top-level region to allow for automatic cleanup.
- function->push_back(entry.first);
+ function.push_back(entry.first);
}
llvm::array_pod_sort(errors.begin(), errors.end());
@@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() {
}
// Try to parse the optional trailing location.
- if (parseOptionalTrailingLocation(op))
+ if (parseOptionalTrailingLocation(*op))
return failure();
return success();
@@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) {
}
// Okay, the function signature was parsed correctly, create the function now.
- auto *function =
- new Function(getEncodedSourceLocation(loc), name, type, attrs);
- module->getFunctions().push_back(function);
+ auto function =
+ Function::create(getEncodedSourceLocation(loc), name, type, attrs);
+ module->push_back(function);
// Parse an optional trailing location.
if (parseOptionalTrailingLocation(function))
return failure();
// Add the attributes to the function arguments.
- for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
- function->setArgAttrs(i, argAttrs[i]);
+ for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i)
+ function.setArgAttrs(i, argAttrs[i]);
// External functions have no body.
if (getToken().isNot(Token::l_brace))
@@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) {
// Parse the function body.
auto parser = OperationParser(getState(), function);
- if (parser.parseRegion(function->getBody(), entryArgs))
+ if (parser.parseRegion(function.getBody(), entryArgs))
return failure();
// Verify that a valid function body was parsed.
- if (function->empty())
+ if (function.empty())
return emitError(braceLoc, "function must have a body");
return parser.finalize(braceLoc);
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 868d492e094..057f2655207 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -61,12 +61,12 @@ private:
static void printIR(const llvm::Any &ir, bool printModuleScope,
raw_ostream &out) {
// Check for printing at module scope.
- if (printModuleScope && llvm::any_isa<Function *>(ir)) {
- Function *function = llvm::any_cast<Function *>(ir);
+ if (printModuleScope && llvm::any_isa<Function>(ir)) {
+ Function function = llvm::any_cast<Function>(ir);
// Print the function name and a newline before the Module.
- out << " (function: " << function->getName() << ")\n";
- function->getModule()->print(out);
+ out << " (function: " << function.getName() << ")\n";
+ function.getModule()->print(out);
return;
}
@@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
out << "\n";
// Print the given function.
- if (llvm::any_isa<Function *>(ir)) {
- llvm::any_cast<Function *>(ir)->print(out);
+ if (llvm::any_isa<Function>(ir)) {
+ llvm::any_cast<Function>(ir).print(out);
return;
}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 2f605b6690b..27ec74c23c2 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -46,8 +46,7 @@ static llvm::cl::opt<bool>
void Pass::anchor() {}
/// Forwarding function to execute this pass.
-LogicalResult FunctionPassBase::run(Function *fn,
- FunctionAnalysisManager &fam) {
+LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
// Initialize the pass state.
passState.emplace(fn, fam);
@@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
}
/// Run all of the passes in this manager over the current function.
-LogicalResult detail::FunctionPassExecutor::run(Function *function,
+LogicalResult detail::FunctionPassExecutor::run(Function function,
FunctionAnalysisManager &fam) {
// Run each of the held passes.
for (auto &pass : passes)
@@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module,
/// Utility to run the given function and analysis manager on a provided
/// function pass executor.
static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
- Function *func,
+ Function func,
FunctionAnalysisManager &fam) {
// Run the function pipeline over the provided function.
auto result = fpe.run(func, fam);
@@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
/// module.
void ModuleToFunctionPassAdaptor::runOnModule() {
ModuleAnalysisManager &mam = getAnalysisManager();
- for (auto &func : getModule()) {
+ for (auto func : getModule()) {
// Skip external functions.
if (func.isExternal())
continue;
// Run the held function pipeline over the current function.
- auto fam = mam.slice(&func);
- if (failed(runFunctionPipeline(fpe, &func, fam)))
+ auto fam = mam.slice(func);
+ if (failed(runFunctionPipeline(fpe, func, fam)))
return signalPassFailure();
// Clear out any computed function analyses. These analyses won't be used
@@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
// Run a prepass over the module to collect the functions to execute a over.
// This ensures that an analysis manager exists for each function, as well as
// providing a queue of functions to execute over.
- std::vector<std::pair<Function *, FunctionAnalysisManager>> funcAMPairs;
- for (auto &func : getModule())
+ std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs;
+ for (auto func : getModule())
if (!func.isExternal())
- funcAMPairs.emplace_back(&func, mam.slice(&func));
+ funcAMPairs.emplace_back(func, mam.slice(func));
// A parallel diagnostic handler that provides deterministic diagnostic
// ordering.
@@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
}
/// Create an analysis slice for the given child function.
-FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) {
- assert(func->getModule() == moduleAnalyses.getIRUnit() &&
+FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) {
+ assert(func.getModule() == moduleAnalyses.getIRUnit() &&
"function has a different parent module");
auto it = functionAnalyses.find(func);
if (it == functionAnalyses.end()) {
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 46addfb8e9c..d2563fb62cd 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -48,7 +48,7 @@ public:
FunctionPassExecutor(const FunctionPassExecutor &rhs);
/// Run the executor on the given function.
- LogicalResult run(Function *function, FunctionAnalysisManager &fam);
+ LogicalResult run(Function function, FunctionAnalysisManager &fam);
/// Add a pass to the current executor. This takes ownership over the provided
/// pass pointer.
diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
index 375a64d8f2d..3f26bf075af 100644
--- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
@@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() {
void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
- auto &func = getFunction();
+ auto func = getFunction();
// Insert stats for each argument.
for (auto *arg : func.getArguments()) {
diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
index dec4ea90db8..169fec3b39a 100644
--- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
@@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() {
void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
CAGSlice cag(solverContext);
- for (auto &f : getModule()) {
+ for (auto f : getModule()) {
f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
}
config.finalizeAnchors(cag);
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
index ed3b0956a16..6b376db8516 100644
--- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -58,7 +58,7 @@ public:
void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
index 3add211fdd5..543b7300af0 100644
--- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
+++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
@@ -36,11 +36,11 @@ using namespace mlir;
// block. The created block will be terminated by `std.return`.
Block *createOneBlockFunction(Builder builder, Module *module) {
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
- auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType);
- module->getFunctions().push_back(fn);
+ auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
+ module->push_back(fn);
auto *block = new Block();
- fn->push_back(block);
+ fn.push_back(block);
OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName());
ReturnOp::build(&builder, &state);
diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
index ebdcaf73717..33572d5adbe 100644
--- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
+++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
@@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
// to take in the SPIR-V ModuleOp directly after module and function are
// migrated to be general ops.
- for (auto &fn : *module) {
+ for (auto fn : *module) {
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
if (done) {
spirvModule.emitError("found more than one 'spv.module' op");
diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
index 1a8d79c1790..1ce2b69f055 100644
--- a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
+++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
@@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass
void StdOpsToSPIRVConversionPass::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
populateWithGenerated(func.getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index 6d5073f1c37..9fc216eef25 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
- auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
- auto fnType = fn->getType();
+ auto fnType = fn.getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
@@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
- auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
// Check that the referenced function has the correct type.
- if (fn->getType() != type)
+ if (fn.getType() != type)
return op.emitOpError("reference to function with mismatched type");
return success();
@@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
}
static LogicalResult verify(ReturnOp op) {
- auto *function = op.getOperation()->getFunction();
+ auto function = op.getOperation()->getFunction();
// The operand number and types must match the function signature.
- const auto &results = function->getType().getResults();
+ const auto &results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError("has ")
<< op.getNumOperands()
diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
index 74ade942fc7..1e8409246ef 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -69,7 +69,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
- for (Function &func : m) {
+ for (Function func : m) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;
@@ -89,20 +89,21 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
return llvmModule;
}
-static TranslateFromMLIRRegistration registration(
- "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) {
- if (!module)
- return true;
+static TranslateFromMLIRRegistration
+ registration("mlir-to-nvvmir",
+ [](Module *module, llvm::StringRef outputFilename) {
+ if (!module)
+ return true;
- auto llvmModule = mlir::translateModuleToNVVMIR(*module);
- if (!llvmModule)
- return true;
+ auto llvmModule = mlir::translateModuleToNVVMIR(*module);
+ if (!llvmModule)
+ return true;
- auto file = openOutputFile(outputFilename);
- if (!file)
- return true;
+ auto file = openOutputFile(outputFilename);
+ if (!file)
+ return true;
- llvmModule->print(file->os(), nullptr);
- file->keep();
- return false;
- });
+ llvmModule->print(file->os(), nullptr);
+ file->keep();
+ return false;
+ });
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ef286cb64fd..4a68ac71ee0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
bool ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
- for (Function &function : mlirModule) {
+ for (Function function : mlirModule) {
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
@@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() {
}
// Convert functions.
- for (Function &function : mlirModule) {
+ for (Function function : mlirModule) {
// Ignore external functions.
if (function.isExternal())
continue;
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 8a2002ce368..394b3ef8db5 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> {
void Canonicalizer::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index be60ada6a43..84f00b97e38 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -849,7 +849,7 @@ struct FunctionConverter {
/// error, success otherwise. If 'signatureConversion' is provided, the
/// arguments of the entry block are updated accordingly.
LogicalResult
- convertFunction(Function *f,
+ convertFunction(Function f,
TypeConverter::SignatureConversion *signatureConversion);
/// Converts the given region starting from the entry block and following the
@@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
}
LogicalResult FunctionConverter::convertFunction(
- Function *f, TypeConverter::SignatureConversion *signatureConversion) {
+ Function f, TypeConverter::SignatureConversion *signatureConversion) {
// If this is an external function, there is nothing else to do.
- if (f->isExternal())
+ if (f.isExternal())
return success();
- DialectConversionRewriter rewriter(f->getBody(), typeConverter);
+ DialectConversionRewriter rewriter(f.getBody(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
rewriter.argConverter.convertSignature(
- &f->getBody().front(), *signatureConversion, rewriter.mapping);
+ &f.getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
if (failed(
- convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) {
+ convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
// Reset any of the generated rewrites.
rewriter.discardRewrites();
return failure();
@@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const
// applyConversionPatterns
//===----------------------------------------------------------------------===//
-namespace {
-/// This class represents a function to be converted. It allows for converting
-/// the body of functions and the signature in two phases.
-struct ConvertedFunction {
- ConvertedFunction(Function *fn, FunctionType newType,
- ArrayRef<NamedAttributeList> newFunctionArgAttrs)
- : fn(fn), newType(newType),
- newFunctionArgAttrs(newFunctionArgAttrs.begin(),
- newFunctionArgAttrs.end()) {}
-
- /// The function to convert.
- Function *fn;
- /// The new type and argument attributes for the function.
- FunctionType newType;
- SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-};
-} // end anonymous namespace
-
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
@@ -1149,37 +1131,33 @@ LogicalResult
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
- std::vector<Function *> allFunctions;
- allFunctions.reserve(module.getFunctions().size());
- for (auto &func : module)
- allFunctions.push_back(&func);
+ SmallVector<Function, 32> allFunctions(module.getFunctions());
return applyConversionPatterns(allFunctions, target, converter,
std::move(patterns));
}
/// Convert the given functions with the provided conversion patterns.
LogicalResult mlir::applyConversionPatterns(
- ArrayRef<Function *> fns, ConversionTarget &target,
+ MutableArrayRef<Function> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns) {
if (fns.empty())
return success();
// Build the function converter.
- FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
- &converter);
+ auto *ctx = fns.front().getContext();
+ FunctionConverter funcConverter(ctx, target, patterns, &converter);
// Try to convert each of the functions within the module.
- auto *ctx = fns.front()->getContext();
- for (auto *func : fns) {
+ for (auto func : fns) {
// Convert the function type using the type converter.
auto conversion =
- converter.convertSignature(func->getType(), func->getAllArgAttrs());
+ converter.convertSignature(func.getType(), func.getAllArgAttrs());
if (!conversion)
return failure();
// Update the function signature.
- func->setType(conversion->getConvertedType(ctx));
- func->setAllArgAttrs(conversion->getConvertedArgAttrs());
+ func.setType(conversion->getConvertedType(ctx));
+ func.setAllArgAttrs(conversion->getConvertedArgAttrs());
// Convert the body of this function.
if (failed(funcConverter.convertFunction(func, &*conversion)))
@@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns(
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
-mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
+mlir::applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(fn.getContext(), target, patterns);
- return converter.convertFunction(&fn, /*signatureConversion=*/nullptr);
+ return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
}
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 5a926ceaa92..a3aa092b0ec 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block &block) {
auto *op = block.getContainingOp();
- return op ? op->emitRemark() : block.getFunction()->emitRemark();
+ return op ? op->emitRemark() : block.getFunction().emitRemark();
}
/// Creates a buffer in the faster memory space for the specified region;
@@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
OpBuilder &b = region.isWrite() ? epilogue : prologue;
// Builder to create constants at the top level.
- auto *func = block->getFunction();
- OpBuilder top(func->getBody());
+ auto func = block->getFunction();
+ OpBuilder top(func.getBody());
auto loc = region.loc;
auto *memref = region.memref;
@@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
if (auto *op = block->getContainingOp())
op->emitError(str);
else
- block->getFunction()->emitError(str);
+ block->getFunction().emitError(str);
}
return totalDmaBuffersSizeInBytes;
}
void DmaGeneration::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
OpBuilder topBuilder(f.getBody());
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 8d2e75b2dca..77b944f3e01 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -257,7 +257,7 @@ public:
// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
- bool init(Function &f);
+ bool init(Function f);
// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@@ -637,7 +637,7 @@ public:
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
-bool MemRefDependenceGraph::init(Function &f) {
+bool MemRefDependenceGraph::init(Function f) {
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
@@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getFunction()->getBody());
+ OpBuilder top(forInst->getFunction().getBody());
// Create new memref type based on slice bounds.
auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
@@ -1750,9 +1750,9 @@ public:
};
// Search for siblings which load the same memref function argument.
- auto *fn = dstNode->op->getFunction();
- for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
- for (auto *user : fn->getArgument(i)->getUsers()) {
+ auto fn = dstNode->op->getFunction();
+ for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+ for (auto *user : fn.getArgument(i)->getUsers()) {
if (auto loadOp = dyn_cast<LoadOp>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index c1be6e8f6b1..2744e5ca05c 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
-static void getTileableBands(Function &f,
+static void getTileableBands(Function f,
std::vector<SmallVector<AffineForOp, 6>> *bands) {
// Get maximal perfect nest of 'affine.for' insts starting from root
// (inclusive).
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 05953926376..6f13f623fe8 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() {
// Store innermost loops as we walk.
std::vector<AffineForOp> loops;
- void walkPostOrder(Function *f) {
- for (auto &b : *f)
+ void walkPostOrder(Function f) {
+ for (auto &b : f)
walkPostOrder(b.begin(), b.end());
}
@@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() {
? clUnrollNumRepetitions
: 1;
// If the call back is provided, we will recurse until no loops are found.
- Function &func = getFunction();
+ Function func = getFunction();
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
InnermostLoopGatherer ilg;
- ilg.walkPostOrder(&func);
+ ilg.walkPostOrder(func);
auto &loops = ilg.loops;
if (loops.empty())
break;
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 77a23b156b0..df30e270fe6 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -726,7 +726,7 @@ public:
} // end namespace
-LogicalResult mlir::lowerAffineConstructs(Function &function) {
+LogicalResult mlir::lowerAffineConstructs(Function function) {
OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index cd92198ac04..f59f1006ec5 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
+ LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state,
/// because we currently disallow vectorization of defs that come from another
/// scope.
/// TODO(ntv): please document return value.
-static bool materialize(Function *f, const SetVector<Operation *> &terminators,
+static bool materialize(Function f, const SetVector<Operation *> &terminators,
MaterializationState *state) {
DenseSet<Operation *> seen;
DominanceInfo domInfo(f);
@@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
return true;
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
}
return false;
}
@@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() {
NestedPatternContext mlContext;
// TODO(ntv): Check to see if this supports arbitrary top-level code.
- Function *f = &getFunction();
- if (f->getBlocks().size() != 1)
+ Function f = getFunction();
+ if (f.getBlocks().size() != 1)
return;
using matcher::Op;
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
MaterializationState state(hwVectorSize);
// Get the hardware vector type.
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index c5676afaf63..1208e2fdd15 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
void MemRefDataFlowOpt::runOnFunction() {
// Only supports single block functions at the moment.
- Function &f = getFunction();
+ Function f = getFunction();
if (f.getBlocks().size() != 1) {
markAllAnalysesPreserved();
return;
diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp
index f97f549c93e..c7c3621781a 100644
--- a/mlir/lib/Transforms/StripDebugInfo.cpp
+++ b/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
} // end anonymous namespace
void StripDebugInfo::runOnFunction() {
- Function &func = getFunction();
+ Function func = getFunction();
auto unknownLoc = UnknownLoc::get(&getContext());
// Strip the debug info from the function and its operations.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 47ca378f324..e185f702d27 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -44,7 +44,7 @@ namespace {
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
- explicit GreedyPatternRewriteDriver(Function &fn,
+ explicit GreedyPatternRewriteDriver(Function fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn.getBody()), matcher(std::move(patterns)) {
worklist.reserve(64);
@@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
-bool mlir::applyPatternsGreedily(Function &fn,
+bool mlir::applyPatternsGreedily(Function fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
bool converged = driver.simplifyFunction(maxPatternMatchIterations);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 728123f71a5..4ddf93c2232 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
- OpBuilder topBuilder(op->getFunction()->getBody());
+ OpBuilder topBuilder(op->getFunction().getBody());
auto constOp = topBuilder.create<ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv->replaceAllUsesWith(constOp);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 39a05d8c300..3fca26bef19 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
void Vectorize::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
if (!fastestVaryingPattern.empty() &&
fastestVaryingPattern.size() != vectorSizes.size()) {
f.emitRemark("Fastest varying pattern specified with different size than "
@@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() {
unsigned patternDepth = pat.getDepth();
SmallVector<NestedMatch, 8> matches;
- pat.match(&f, &matches);
+ pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {
diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp
index 1f2ab69409e..3c1a1b3b481 100644
--- a/mlir/lib/Transforms/ViewFunctionGraph.cpp
+++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp
@@ -53,13 +53,13 @@ std::string DOTGraphTraits<Function *>::getNodeLabel(Block *Block, Function *) {
} // end namespace llvm
-void mlir::viewGraph(Function &function, const llvm::Twine &name,
+void mlir::viewGraph(Function function, const llvm::Twine &name,
bool shortNames, const llvm::Twine &title,
llvm::GraphProgram::Name program) {
llvm::ViewGraph(&function, name, shortNames, title, program);
}
-llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
+llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames, const llvm::Twine &title) {
return llvm::WriteGraph(os, &function, shortNames, title);
}
OpenPOWER on IntegriCloud