diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-18 09:28:48 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-18 09:29:20 -0800 |
| commit | 4562e389a43caa2e30ebf277c12743edafe6a0ac (patch) | |
| tree | 1901855666adba9be9576e864877fe191e197085 /mlir/lib | |
| parent | 24ab8362f2099ed42f2e05f09fb9323ad0c5ab27 (diff) | |
| download | bcm5719-llvm-4562e389a43caa2e30ebf277c12743edafe6a0ac.tar.gz bcm5719-llvm-4562e389a43caa2e30ebf277c12743edafe6a0ac.zip | |
NFC: Remove unnecessary 'llvm::' prefix from uses of llvm symbols declared in `mlir` namespace.
Aside from being cleaner, this also makes the codebase more consistent.
PiperOrigin-RevId: 286206974
Diffstat (limited to 'mlir/lib')
74 files changed, 379 insertions, 413 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 9cf7fa897bf..97868a56524 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -619,7 +619,7 @@ static void computeDirectionVector( const FlatAffineConstraints &srcDomain, const FlatAffineConstraints &dstDomain, unsigned loopDepth, FlatAffineConstraints *dependenceDomain, - llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) { + SmallVector<DependenceComponent, 2> *dependenceComponents) { // Find the number of common loops shared by src and dst accesses. SmallVector<AffineForOp, 4> commonLoops; unsigned numCommonLoops = @@ -772,8 +772,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { DependenceResult mlir::checkMemrefAccessDependence( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, - llvm::SmallVector<DependenceComponent, 2> *dependenceComponents, - bool allowRAR) { + SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) { LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: " << Twine(loopDepth) << " between:\n";); LLVM_DEBUG(srcAccess.opInst->dump();); @@ -865,7 +864,7 @@ DependenceResult mlir::checkMemrefAccessDependence( /// rooted at 'forOp' at loop depths in range [1, maxLoopDepth]. void mlir::getDependenceComponents( AffineForOp forOp, unsigned maxLoopDepth, - std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) { + std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) { // Collect all load and store ops in loop nest rooted at 'forOp'. SmallVector<Operation *, 8> loadAndStoreOpInsts; forOp.getOperation()->walk([&](Operation *opInst) { @@ -883,7 +882,7 @@ void mlir::getDependenceComponents( MemRefAccess dstAccess(dstOpInst); FlatAffineConstraints dependenceConstraints; - llvm::SmallVector<DependenceComponent, 2> depComps; + SmallVector<DependenceComponent, 2> depComps; // TODO(andydavis,bondhugula) Explore whether it would be profitable // to pre-compute and store deps instead of repeatedly checking. DependenceResult result = checkMemrefAccessDependence( diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 7f6da8eb418..d678355880e 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" @@ -34,7 +35,6 @@ using namespace mlir; using llvm::SmallDenseMap; using llvm::SmallDenseSet; -using llvm::SmallPtrSet; namespace { @@ -73,10 +73,11 @@ private: // Flattens the expressions in map. Returns failure if 'expr' was unable to be // flattened (i.e., semi-affine expressions not handled yet). -static LogicalResult getFlattenedAffineExprs( - ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, - std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, - FlatAffineConstraints *localVarCst) { +static LogicalResult +getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, + unsigned numSymbols, + std::vector<SmallVector<int64_t, 8>> *flattenedExprs, + FlatAffineConstraints *localVarCst) { if (exprs.empty()) { localVarCst->reset(numDims, numSymbols); return success(); @@ -109,7 +110,7 @@ static LogicalResult getFlattenedAffineExprs( LogicalResult mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, - llvm::SmallVectorImpl<int64_t> *flattenedExpr, + SmallVectorImpl<int64_t> *flattenedExpr, FlatAffineConstraints *localVarCst) { std::vector<SmallVector<int64_t, 8>> flattenedExprs; LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, @@ -121,7 +122,7 @@ mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, /// Flattens the expressions in map. Returns failure if 'expr' was unable to be /// flattened (i.e., semi-affine expressions not handled yet). LogicalResult mlir::getFlattenedAffineExprs( - AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, + AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, FlatAffineConstraints *localVarCst) { if (map.getNumResults() == 0) { localVarCst->reset(map.getNumDims(), map.getNumSymbols()); @@ -133,7 +134,7 @@ LogicalResult mlir::getFlattenedAffineExprs( } LogicalResult mlir::getFlattenedAffineExprs( - IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, + IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, FlatAffineConstraints *localVarCst) { if (set.getNumConstraints() == 0) { localVarCst->reset(set.getNumDims(), set.getNumSymbols()); diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 1d88d09d269..a81116579ce 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -97,7 +97,7 @@ void mlir::buildTripCountMapAndOperands( // being an analysis utility, it shouldn't. Replace with a version that just // works with analysis structures (FlatAffineConstraints) and thus doesn't // update the IR. -llvm::Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) { +Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) { SmallVector<Value *, 4> operands; AffineMap map; buildTripCountMapAndOperands(forOp, &map, &operands); @@ -197,9 +197,9 @@ static bool isAccessIndexInvariant(Value *iv, Value *index) { return !(AffineValueMap(composeOp).isFunctionOf(0, iv)); } -llvm::DenseSet<Value *> -mlir::getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices) { - llvm::DenseSet<Value *> res; +DenseSet<Value *> mlir::getInvariantAccesses(Value *iv, + ArrayRef<Value *> indices) { + DenseSet<Value *> res; for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { auto *val = indices[idx]; if (isAccessIndexInvariant(iv, val)) { diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index f01ec56ddb1..1c9f6211a84 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -27,7 +27,7 @@ using namespace mlir; namespace { struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { - explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {} + explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. void runOnModule() override; @@ -37,7 +37,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { private: llvm::StringMap<int64_t> opCount; - llvm::raw_ostream &os; + raw_ostream &os; }; } // namespace diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp index d0351e9bcf9..80a579d163f 100644 --- a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -94,7 +94,7 @@ static void checkDependences(ArrayRef<Operation *> loadsAndStores) { getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { FlatAffineConstraints dependenceConstraints; - llvm::SmallVector<DependenceComponent, 2> dependenceComponents; + SmallVector<DependenceComponent, 2> dependenceComponents; DependenceResult result = checkMemrefAccessDependence( srcAccess, dstAccess, d, &dependenceConstraints, &dependenceComponents); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index e06e88b92f1..23bfa303708 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -94,7 +94,7 @@ private: Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName); if (funcOp) - return llvm::cast<LLVMFuncOp>(*funcOp); + return cast<LLVMFuncOp>(*funcOp); mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 468973665f8..78fe15dff50 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -370,13 +370,13 @@ private: [&] { Value *shflValue = rewriter.create<LLVM::ExtractValueOp>( loc, type, shfl, rewriter.getIndexArrayAttr(0)); - return llvm::SmallVector<Value *, 1>{ + return SmallVector<Value *, 1>{ accumFactory(loc, value, shflValue, rewriter)}; }, [&] { return llvm::makeArrayRef(value); }); value = rewriter.getInsertionBlock()->getArgument(0); } - return llvm::SmallVector<Value *, 1>{value}; + return SmallVector<Value *, 1>{value}; }, // Generate a reduction over the entire warp. This is a specialization // of the above reduction with unconditional accumulation. @@ -394,7 +394,7 @@ private: /*return_value_and_is_valid=*/UnitAttr()); value = accumFactory(loc, value, shflValue, rewriter); } - return llvm::SmallVector<Value *, 1>{value}; + return SmallVector<Value *, 1>{value}; }); return rewriter.getInsertionBlock()->getArgument(0); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5d6a92fee92..5bb18458725 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1603,15 +1603,14 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { - rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::ArrayRef<Value *>(), - llvm::ArrayRef<Block *>(), - op->getAttrs()); + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, ArrayRef<Value *>(), ArrayRef<Block *>(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( - op, llvm::ArrayRef<Value *>(operands.front()), - llvm::ArrayRef<Block *>(), op->getAttrs()); + op, ArrayRef<Value *>(operands.front()), ArrayRef<Block *>(), + op->getAttrs()); return matchSuccess(); } @@ -1626,9 +1625,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { op->getLoc(), packedType, packed, operands[i], rewriter.getI64ArrayAttr(i)); } - rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed), - llvm::ArrayRef<Block *>(), - op->getAttrs()); + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, llvm::makeArrayRef(packed), ArrayRef<Block *>(), op->getAttrs()); return matchSuccess(); } }; @@ -1971,7 +1969,7 @@ static void ensureDistinctSuccessors(Block &bb) { auto *terminator = bb.getTerminator(); // Find repeated successors with arguments. - llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions; + llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions; for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) { Block *successor = terminator->getSuccessor(i); // Blocks with no arguments are safe even if they appear multiple times diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp index d4f362d685d..721e7092cfc 100644 --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -155,16 +155,16 @@ void coalesceCopy(TransferOpTy transfer, /// Emits remote memory accesses that are clipped to the boundaries of the /// MemRef. template <typename TransferOpTy> -llvm::SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer, - edsc::MemRefView &view, - ArrayRef<edsc::IndexHandle> ivs) { +SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer, + edsc::MemRefView &view, + ArrayRef<edsc::IndexHandle> ivs) { using namespace mlir::edsc; using namespace edsc::op; using edsc::intrinsics::select; IndexHandle zero(index_t(0)), one(index_t(1)); - llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices()); - llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs( + SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices()); + SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs( memRefAccess.size(), edsc::IndexHandle()); // Indices accessing to remote memory are clipped and their expressions are diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index e58f6f8d6ed..8c8c67d1595 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -616,9 +616,8 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, // A symbol may appear as a dim in affine.apply operations. This function // canonicalizes dims that are valid symbols into actual symbols. template <class MapOrSet> -static void -canonicalizePromotedSymbols(MapOrSet *mapOrSet, - llvm::SmallVectorImpl<Value *> *operands) { +static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, + SmallVectorImpl<Value *> *operands) { if (!mapOrSet || operands->empty()) return; @@ -662,7 +661,7 @@ canonicalizePromotedSymbols(MapOrSet *mapOrSet, template <class MapOrSet> static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, - llvm::SmallVectorImpl<Value *> *operands) { + SmallVectorImpl<Value *> *operands) { static_assert(std::is_same<MapOrSet, AffineMap>::value || std::is_same<MapOrSet, IntegerSet>::value, "Argument must be either of AffineMap or IntegerSet type"); @@ -738,13 +737,13 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, *operands = resultOperands; } -void mlir::canonicalizeMapAndOperands( - AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) { +void mlir::canonicalizeMapAndOperands(AffineMap *map, + SmallVectorImpl<Value *> *operands) { canonicalizeMapOrSetAndOperands<AffineMap>(map, operands); } -void mlir::canonicalizeSetAndOperands( - IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) { +void mlir::canonicalizeSetAndOperands(IntegerSet *set, + SmallVectorImpl<Value *> *operands) { canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index f0eeba0891a..955e2ecc88c 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -35,7 +35,7 @@ inline quant::UniformQuantizedType getUniformElementType(Type t) { } inline bool hasStorageBitWidth(quant::QuantizedType t, - llvm::ArrayRef<unsigned> checkWidths) { + ArrayRef<unsigned> checkWidths) { unsigned w = t.getStorageType().getIntOrFloatBitWidth(); for (unsigned checkWidth : checkWidths) { if (w == checkWidth) diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 1c20be6a453..7324b96a7e1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -237,7 +237,7 @@ KernelDim3 LaunchOp::getBlockSizeOperandValues() { return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; } -llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { +iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { auto args = body().getBlocks().front().getArguments(); return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes); } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 416a37b3270..0a6a5915633 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -69,7 +69,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, gpu::LaunchFuncOp launch) { OpBuilder kernelBuilder(kernelFunc.getBody()); auto &firstBlock = kernelFunc.getBody().front(); - llvm::SmallVector<Value *, 8> newLaunchArgs; + SmallVector<Value *, 8> newLaunchArgs; BlockAndValueMapping map; for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) { map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i)); @@ -195,7 +195,7 @@ private: SymbolTable symbolTable(kernelModule); symbolTable.insert(kernelFunc); - llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; + SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; while (!symbolDefWorklist.empty()) { if (Optional<SymbolTable::UseRange> symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 9ac564599db..abbc4e0ae45 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1227,7 +1227,7 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser, auto *body = result.addRegion(); return parser.parseOptionalRegion( - *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes); + *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); } // Print the LLVMFuncOp. Collects argument and result types and passes them to @@ -1499,7 +1499,7 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { /// Get an LLVMType with an llvm type that may cause changes to the underlying /// llvm context when constructed. LLVMType LLVMType::getLocked(LLVMDialect *dialect, - llvm::function_ref<llvm::Type *()> typeBuilder) { + function_ref<llvm::Type *()> typeBuilder) { // Lock access to the llvm context and build the type. llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex); return get(dialect->getContext(), typeBuilder()); diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 77e3a1e392f..ba96186da38 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -44,7 +44,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices, Operation *mlir::edsc::makeLinalgGenericOp( ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs, ArrayRef<StructuredIndexed> outputs, - llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder, + function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder, ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 6adfeb592ef..0fd29cdc6e0 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -632,7 +632,7 @@ namespace linalg { } // namespace linalg } // namespace mlir -static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap, +static AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank, MLIRContext *context) { if (maybeMap) return maybeMap.getValue(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 74000212373..f4364928af8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -100,7 +100,7 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value *consumedView, - llvm::function_ref<bool(Operation *)> isaOpType) { + function_ref<bool(Operation *)> isaOpType) { LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp); if (!consumer) return false; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 435aa7245ba..4d8a24cb6cb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -315,7 +315,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, return res; } -llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( +Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes, ArrayRef<unsigned> permutation, OperationFolder *folder) { // 1. Enforce the convention that "tiling by zero" skips tiling a particular @@ -389,7 +389,7 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( return TiledLinalgOp{res, loops}; } -llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( +Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, ArrayRef<unsigned> permutation, OperationFolder *folder) { if (tileSizes.empty()) diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp index ec3c7f3433a..510e13e8028 100644 --- a/mlir/lib/Dialect/SDBM/SDBM.cpp +++ b/mlir/lib/Dialect/SDBM/SDBM.cpp @@ -88,11 +88,11 @@ namespace { struct SDBMBuilderResult { // Positions in the matrix of the variables taken with the "+" sign in the // difference expression, 0 if it is a constant rather than a variable. - llvm::SmallVector<unsigned, 2> positivePos; + SmallVector<unsigned, 2> positivePos; // Positions in the matrix of the variables taken with the "-" sign in the // difference expression, 0 if it is a constant rather than a variable. - llvm::SmallVector<unsigned, 2> negativePos; + SmallVector<unsigned, 2> negativePos; // Constant value in the difference expression. int64_t value = 0; @@ -184,13 +184,12 @@ public: return lhs; } - SDBMBuilder(llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> - &pointExprToStripe, - llvm::function_ref<unsigned(SDBMInputExpr)> callback) + SDBMBuilder(DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe, + function_ref<unsigned(SDBMInputExpr)> callback) : pointExprToStripe(pointExprToStripe), linearPosition(callback) {} - llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> &pointExprToStripe; - llvm::function_ref<unsigned(SDBMInputExpr)> linearPosition; + DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe; + function_ref<unsigned(SDBMInputExpr)> linearPosition; }; } // namespace @@ -239,7 +238,7 @@ SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) { // expression. Keep track of those in pointExprToStripe. // There may also be multiple stripe expressions equal to the same variable. // Introduce a temporary variable for each of those. - llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> pointExprToStripe; + DenseMap<SDBMExpr, SmallVector<unsigned, 2>> pointExprToStripe; unsigned numTemporaries = 0; auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe, @@ -512,7 +511,7 @@ void SDBM::getSDBMExpressions(SDBMDialect *dialect, } } -void SDBM::print(llvm::raw_ostream &os) { +void SDBM::print(raw_ostream &os) { unsigned numVariables = getNumVariables(); // Helper function that prints the name of the variable given its linearized diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 8f6b59d8e45..8cdd9c8566e 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -89,7 +89,7 @@ public: : subExprs(exprs.begin(), exprs.end()) {} AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b) : subExprs({a, b}) {} - llvm::SmallVector<AffineExprMatcher, 0> subExprs; + SmallVector<AffineExprMatcher, 0> subExprs; AffineExpr matched; }; } // namespace @@ -311,7 +311,7 @@ AffineExpr SDBMExpr::getAsAffineExpr() const { // LHS if the constant becomes zero. Otherwise, construct a sum expression. template <typename Result> Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated, - llvm::function_ref<Result(SDBMDirectExpr)> builder) { + function_ref<Result(SDBMDirectExpr)> builder) { SDBMDialect *dialect = expr.getDialect(); if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) { if (negated) diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp index e2d5332777d..5db478d388b 100644 --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -33,10 +33,9 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType, return structType; } - llvm::SmallVector<Type, 4> memberTypes; - llvm::SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo; - llvm::SmallVector<spirv::StructType::MemberDecorationInfo, 4> - memberDecorations; + SmallVector<Type, 4> memberTypes; + SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo; + SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; VulkanLayoutUtils::Size structMemberOffset = 0; VulkanLayoutUtils::Size maxMemberAlignment = 1; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index c99e7ca8b20..def8ee810fe 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -149,7 +149,7 @@ Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect, DialectAsmParser &parser); static bool isValidSPIRVIntType(IntegerType type) { - return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}), + return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}), type.getWidth()); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 140470b8df8..0df4525bac6 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -80,7 +80,7 @@ static LogicalResult extractValueFromConstOp(Operation *op, template <typename Ty> static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, - llvm::function_ref<StringRef(Ty)> stringifyFn) { + function_ref<StringRef(Ty)> stringifyFn) { if (enumValues.empty()) { return nullptr; } @@ -399,7 +399,7 @@ static unsigned getBitWidth(Type type) { /// emits errors with the given loc on failure. static Type getElementType(Type type, ArrayRef<int32_t> indices, - llvm::function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { + function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { if (indices.empty()) { emitErrorFn("expected at least one index for spv.CompositeExtract"); return nullptr; @@ -423,7 +423,7 @@ getElementType(Type type, ArrayRef<int32_t> indices, static Type getElementType(Type type, Attribute indices, - llvm::function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { + function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>(); if (!indicesArrayAttr) { emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); @@ -2317,7 +2317,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { auto &op = *moduleOp.getOperation(); auto *dialect = op.getDialect(); auto &body = op.getRegion(0).front(); - llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> + DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> entryPoints; SymbolTable table(moduleOp); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index e60805aca1b..df9cb47a562 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -2366,7 +2366,7 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { auto functionName = getFunctionSymbol(functionID); - llvm::SmallVector<Value *, 4> arguments; + SmallVector<Value *, 4> arguments; for (auto operand : llvm::drop_begin(operands, 3)) { auto *value = getValue(operand); if (!value) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 7db7111e086..4baac53b89f 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -69,7 +69,7 @@ static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, /// serialization of the merge block and the continue block, if exists, until /// after all other blocks have been processed. static LogicalResult visitInPrettyBlockOrder( - Block *headerBlock, llvm::function_ref<LogicalResult(Block *)> blockHandler, + Block *headerBlock, function_ref<LogicalResult(Block *)> blockHandler, bool skipHeader = false, ArrayRef<Block *> skipBlocks = {}) { llvm::df_iterator_default_set<Block *, 4> doneBlocks; doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); @@ -301,7 +301,7 @@ private: /// instruction if this is a SPIR-V selection/loop header block. LogicalResult processBlock(Block *block, bool omitLabel = false, - llvm::function_ref<void()> actionBeforeTerminator = nullptr); + function_ref<void()> actionBeforeTerminator = nullptr); /// Emits OpPhi instructions for the given block if it has block arguments. LogicalResult emitPhiForBlockArguments(Block *block); @@ -457,7 +457,7 @@ private: /// placed inside `functions`) here. And then after emitting all blocks, we /// replace the dummy <id> 0 with the real result <id> by overwriting /// `functions[offset]`. - DenseMap<Value *, llvm::SmallVector<size_t, 1>> deferredPhiValues; + DenseMap<Value *, SmallVector<size_t, 1>> deferredPhiValues; }; } // namespace @@ -1341,7 +1341,7 @@ uint32_t Serializer::getOrCreateBlockID(Block *block) { LogicalResult Serializer::processBlock(Block *block, bool omitLabel, - llvm::function_ref<void()> actionBeforeTerminator) { + function_ref<void()> actionBeforeTerminator) { LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); LLVM_DEBUG(block->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); @@ -1773,7 +1773,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { auto funcName = op.callee(); uint32_t resTypeID = 0; - llvm::SmallVector<Type, 1> resultTypes(op.getResultTypes()); + SmallVector<Type, 1> resultTypes(op.getResultTypes()); if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), resTypeID))) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp index 655f559b765..e9b4f23cca4 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -80,7 +80,7 @@ static TranslateToMLIRRegistration fromBinary( // Serialization registration //===----------------------------------------------------------------------===// -LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) { +LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { if (!module) return failure(); @@ -105,7 +105,7 @@ LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) { } static TranslateFromMLIRRegistration - toBinary("serialize-spirv", [](ModuleOp module, llvm::raw_ostream &output) { + toBinary("serialize-spirv", [](ModuleOp module, raw_ostream &output) { return serializeModule(module, output); }); @@ -113,8 +113,8 @@ static TranslateFromMLIRRegistration // Round-trip registration //===----------------------------------------------------------------------===// -LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, - llvm::raw_ostream &output, MLIRContext *context) { +LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { // Parse an MLIR module from the source manager. auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!srcModule) @@ -147,9 +147,8 @@ LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, return mlir::success(); } -static TranslateRegistration - roundtrip("test-spirv-roundtrip", - [](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, - MLIRContext *context) { - return roundTripModule(sourceMgr, output, context); - }); +static TranslateRegistration roundtrip( + "test-spirv-roundtrip", + [](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { + return roundTripModule(sourceMgr, output, context); + }); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 3189e42d061..b2b3ba5f509 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2297,7 +2297,7 @@ static void print(OpAsmPrinter &p, ViewOp op) { Value *ViewOp::getDynamicOffset() { int64_t offset; - llvm::SmallVector<int64_t, 4> strides; + SmallVector<int64_t, 4> strides; auto result = succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); assert(result); @@ -2341,7 +2341,7 @@ static LogicalResult verify(ViewOp op) { // Verify that the result memref type has a strided layout map. int64_t offset; - llvm::SmallVector<int64_t, 4> strides; + SmallVector<int64_t, 4> strides; if (failed(getStridesAndOffset(viewType, strides, offset))) return op.emitError("result type ") << viewType << " is not strided"; @@ -2383,7 +2383,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { // Get offset from old memref view type 'memRefType'. int64_t oldOffset; - llvm::SmallVector<int64_t, 4> oldStrides; + SmallVector<int64_t, 4> oldStrides; if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return matchFailure(); @@ -2585,13 +2585,13 @@ static LogicalResult verify(SubViewOp op) { // Verify that the base memref type has a strided layout map. int64_t baseOffset; - llvm::SmallVector<int64_t, 4> baseStrides; + SmallVector<int64_t, 4> baseStrides; if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) return op.emitError("base type ") << subViewType << " is not strided"; // Verify that the result memref type has a strided layout map. int64_t subViewOffset; - llvm::SmallVector<int64_t, 4> subViewStrides; + SmallVector<int64_t, 4> subViewStrides; if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) return op.emitError("result type ") << subViewType << " is not strided"; @@ -2677,8 +2677,7 @@ static LogicalResult verify(SubViewOp op) { return success(); } -llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, - SubViewOp::Range &range) { +raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { return os << "range " << *range.offset << ":" << *range.size << ":" << *range.stride; } @@ -2734,7 +2733,7 @@ static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) { return false; // Get offset and strides. int64_t offset; - llvm::SmallVector<int64_t, 4> strides; + SmallVector<int64_t, 4> strides; if (failed(getStridesAndOffset(memrefType, strides, offset))) return false; // Return 'false' if any of offset or strides is dynamic. diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index 9945b6ae4c2..0ac07c2c4f5 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -112,8 +112,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { // Returns the type kind if the given type is a vector or ranked tensor type. // Returns llvm::None otherwise. - auto getCompositeTypeKind = - [](Type type) -> llvm::Optional<StandardTypes::Kind> { + auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> { if (type.isa<VectorType>() || type.isa<RankedTensorType>()) return static_cast<StandardTypes::Kind>(type.getKind()); return llvm::None; @@ -122,7 +121,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { // Make sure the composite type, if has, is consistent. auto compositeKind1 = getCompositeTypeKind(type1); auto compositeKind2 = getCompositeTypeKind(type2); - llvm::Optional<StandardTypes::Kind> resultCompositeKind; + Optional<StandardTypes::Kind> resultCompositeKind; if (compositeKind1 && compositeKind2) { // Disallow mixing vector and tensor. diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index c4d3e9d993d..64cacb28720 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -513,11 +513,11 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( // Generates slices of 'vectorType' according to 'sizes' and 'strides, and // calls 'fn' with linear index and indices for each slice. -static void generateTransferOpSlices( - VectorType vectorType, TupleType tupleType, ArrayRef<int64_t> sizes, - ArrayRef<int64_t> strides, ArrayRef<Value *> indices, - PatternRewriter &rewriter, - llvm::function_ref<void(unsigned, ArrayRef<Value *>)> fn) { +static void +generateTransferOpSlices(VectorType vectorType, TupleType tupleType, + ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, + ArrayRef<Value *> indices, PatternRewriter &rewriter, + function_ref<void(unsigned, ArrayRef<Value *>)> fn) { // Compute strides w.r.t. to slice counts in each dimension. auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes); assert(maybeDimSliceCounts.hasValue()); diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 9d7ca8ca99b..2956066a035 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -142,21 +142,21 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) { return res; } -static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, - ArrayRef<ValueHandle> ubs, - int64_t step) { +static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, + ArrayRef<ValueHandle> ubs, + int64_t step) { if (lbs.size() != 1 || ubs.size() != 1) - return llvm::Optional<ValueHandle>(); + return Optional<ValueHandle>(); auto *lbDef = lbs.front().getValue()->getDefiningOp(); auto *ubDef = ubs.front().getValue()->getDefiningOp(); if (!lbDef || !ubDef) - return llvm::Optional<ValueHandle>(); + return Optional<ValueHandle>(); auto lbConst = dyn_cast<ConstantIndexOp>(lbDef); auto ubConst = dyn_cast<ConstantIndexOp>(ubDef); if (!lbConst || !ubConst) - return llvm::Optional<ValueHandle>(); + return Optional<ValueHandle>(); return ValueHandle::create<AffineForOp>(lbConst.getValue(), ubConst.getValue(), step); @@ -194,7 +194,7 @@ mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle, return result; } -void mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) { +void mlir::edsc::LoopBuilder::operator()(function_ref<void(void)> fun) { // Call to `exit` must be explicit and asymmetric (cannot happen in the // destructor) because of ordering wrt comma operator. /// The particular use case concerns nested blocks: @@ -236,7 +236,7 @@ mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( } void mlir::edsc::AffineLoopNestBuilder::operator()( - llvm::function_ref<void(void)> fun) { + function_ref<void(void)> fun) { if (fun) fun(); // Iterate on the calling operator() on all the loops in the nest. @@ -281,7 +281,7 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args) { assert(!*bh && "BlockHandle already captures a block, use " "the explicit BockBuilder(bh, Append())({}) syntax instead."); - llvm::SmallVector<Type, 8> types; + SmallVector<Type, 8> types; for (auto *a : args) { assert(!a->hasValue() && "Expected delayed ValueHandle that has not yet captured."); @@ -296,7 +296,7 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, /// Only serves as an ordering point between entering nested block and creating /// stmts. -void mlir::edsc::BlockBuilder::operator()(llvm::function_ref<void(void)> fun) { +void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) { // Call to `exit` must be explicit and asymmetric (cannot happen in the // destructor) because of ordering wrt comma operator. if (fun) @@ -328,7 +328,7 @@ categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, static ValueHandle createBinaryIndexHandle( ValueHandle lhs, ValueHandle rhs, - llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { + function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { MLIRContext *context = ScopedContext::getContext(); unsigned numDims = 0, numSymbols = 0; AffineExpr d0, d1; @@ -352,7 +352,7 @@ static ValueHandle createBinaryIndexHandle( template <typename IOp, typename FOp> static ValueHandle createBinaryHandle( ValueHandle lhs, ValueHandle rhs, - llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { + function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { auto thisType = lhs.getValue()->getType(); auto thatType = rhs.getValue()->getType(); assert(thisType == thatType && "cannot mix types in operators"); diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index b88a1fdf4ef..46199c29c14 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -34,7 +34,7 @@ using namespace mlir; mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, int64_list_t sizes) { auto t = mlir::MemRefType::get( - llvm::ArrayRef<int64_t>(sizes.values, sizes.n), + ArrayRef<int64_t>(sizes.values, sizes.n), mlir::Type::getFromOpaquePointer(elemType), {mlir::AffineMap::getMultiDimIdentityMap( sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))}, @@ -44,7 +44,7 @@ mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, mlir_type_list_t outputs) { - llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n); + SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n); for (unsigned i = 0; i < inputs.n; ++i) { ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]); } diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 2913c436ad5..bbee80ac4e9 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -64,7 +64,7 @@ using llvm::orc::ThreadSafeModule; using llvm::orc::TMOwningSimpleCompiler; // Wrap a string into an llvm::StringError. -static inline Error make_string_error(const llvm::Twine &message) { +static inline Error make_string_error(const Twine &message) { return llvm::make_error<StringError>(message.str(), llvm::inconvertibleErrorCode()); } @@ -89,7 +89,7 @@ std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) { return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef()); } -void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) { +void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) { // Set up the output file. std::string errorMessage; auto file = openOutputFile(outputFilename, &errorMessage); @@ -105,7 +105,7 @@ void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) { file->keep(); } -void ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) { +void ExecutionEngine::dumpToObjectFile(StringRef filename) { cache->dumpToObjectFile(filename); } @@ -136,7 +136,7 @@ static std::string makePackedFunctionName(StringRef name) { void packFunctionArguments(Module *module) { auto &ctx = module->getContext(); llvm::IRBuilder<> builder(ctx); - llvm::DenseSet<llvm::Function *> interfaceFunctions; + DenseSet<llvm::Function *> interfaceFunctions; for (auto &func : module->getFunctionList()) { if (func.isDeclaration()) { continue; @@ -152,8 +152,7 @@ void packFunctionArguments(Module *module) { /*isVarArg=*/false); auto newName = makePackedFunctionName(func.getName()); auto funcCst = module->getOrInsertFunction(newName, newType); - llvm::Function *interfaceFunc = - llvm::cast<llvm::Function>(funcCst.getCallee()); + llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee()); interfaceFunctions.insert(interfaceFunc); // Extract the arguments from the type-erased argument list and cast them to @@ -162,11 +161,11 @@ void packFunctionArguments(Module *module) { bb->insertInto(interfaceFunc); builder.SetInsertPoint(bb); llvm::Value *argList = interfaceFunc->arg_begin(); - llvm::SmallVector<llvm::Value *, 8> args; + SmallVector<llvm::Value *, 8> args; args.reserve(llvm::size(func.args())); for (auto &indexedArg : llvm::enumerate(func.args())) { llvm::Value *argIndex = llvm::Constant::getIntegerValue( - builder.getInt64Ty(), llvm::APInt(64, indexedArg.index())); + builder.getInt64Ty(), APInt(64, indexedArg.index())); llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex); llvm::Value *argPtr = builder.CreateLoad(argPtrPtr); argPtr = builder.CreateBitCast( @@ -181,7 +180,7 @@ void packFunctionArguments(Module *module) { // Assuming the result is one value, potentially of type `void`. if (!result->getType()->isVoidTy()) { llvm::Value *retIndex = llvm::Constant::getIntegerValue( - builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args()))); + builder.getInt64Ty(), APInt(64, llvm::size(func.args()))); llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex); llvm::Value *retPtr = builder.CreateLoad(retPtrPtr); retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); @@ -220,7 +219,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create( llvm::raw_svector_ostream os(buffer); WriteBitcodeToFile(*llvmModule, os); } - llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()), + llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()), "cloned module buffer"); auto expectedModule = parseBitcodeFile(bufferRef, *ctx); if (!expectedModule) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 19599a8a62e..009c1a1485c 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -866,9 +866,10 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, // Flattens the expressions in map. Returns true on success or false // if 'expr' was unable to be flattened (i.e., semi-affine expressions not // handled yet). -static bool getFlattenedAffineExprs( - ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, - std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { +static bool +getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, + unsigned numSymbols, + std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { if (exprs.empty()) { return true; } @@ -894,9 +895,9 @@ static bool getFlattenedAffineExprs( // Flattens 'expr' into 'flattenedExpr'. Returns true on success or false // if 'expr' was unable to be flattened (semi-affine expressions not handled // yet). -bool mlir::getFlattenedAffineExpr( - AffineExpr expr, unsigned numDims, unsigned numSymbols, - llvm::SmallVectorImpl<int64_t> *flattenedExpr) { +bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + SmallVectorImpl<int64_t> *flattenedExpr) { std::vector<SmallVector<int64_t, 8>> flattenedExprs; bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs); @@ -908,7 +909,7 @@ bool mlir::getFlattenedAffineExpr( /// if 'expr' was unable to be flattened (i.e., semi-affine expressions not /// handled yet). bool mlir::getFlattenedAffineExprs( - AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { + AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { if (map.getNumResults() == 0) { return true; } @@ -917,8 +918,7 @@ bool mlir::getFlattenedAffineExprs( } bool mlir::getFlattenedAffineExprs( - IntegerSet set, - std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { + IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) { if (set.getNumConstraints() == 0) { return true; } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 98357b1348b..6cfef363985 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -48,7 +48,7 @@ public: } private: - llvm::Optional<int64_t> constantFoldImpl(AffineExpr expr) { + Optional<int64_t> constantFoldImpl(AffineExpr expr) { switch (expr.getKind()) { case AffineExprKind::Add: return constantFoldBinExpr( @@ -83,8 +83,8 @@ private: } // TODO: Change these to operate on APInts too. - llvm::Optional<int64_t> constantFoldBinExpr(AffineExpr expr, - int64_t (*op)(int64_t, int64_t)) { + Optional<int64_t> constantFoldBinExpr(AffineExpr expr, + int64_t (*op)(int64_t, int64_t)) { auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) @@ -324,7 +324,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { for (auto m : maps) numResults += m ? m.getNumResults() : 0; unsigned numDims = 0; - llvm::SmallVector<AffineExpr, 8> results; + SmallVector<AffineExpr, 8> results; results.reserve(numResults); for (auto m : maps) { if (!m) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 0ea447ed324..e1903d560b1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -437,9 +437,9 @@ public: void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); - void printAffineExpr( - AffineExpr expr, - llvm::function_ref<void(unsigned, bool)> printValueName = nullptr); + void + printAffineExpr(AffineExpr expr, + function_ref<void(unsigned, bool)> printValueName = nullptr); void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); @@ -463,7 +463,7 @@ protected: }; void printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, - llvm::function_ref<void(unsigned, bool)> printValueName = nullptr); + function_ref<void(unsigned, bool)> printValueName = nullptr); /// The output stream for the printer. raw_ostream &os; @@ -1175,13 +1175,13 @@ void ModulePrinter::printDialectType(Type type) { //===----------------------------------------------------------------------===// void ModulePrinter::printAffineExpr( - AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) { + AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } void ModulePrinter::printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, - llvm::function_ref<void(unsigned, bool)> printValueName) { + function_ref<void(unsigned, bool)> printValueName) { const char *binopSpelling = nullptr; switch (expr.getKind()) { case AffineExprKind::SymbolId: { diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index b546643837b..bb35a63bf5d 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -405,9 +405,9 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { }); } -ElementsAttr ElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { +ElementsAttr +ElementsAttr::mapValues(Type newElementType, + function_ref<APInt(const APInt &)> mapping) const { switch (getKind()) { case StandardAttributes::DenseElements: return cast<DenseElementsAttr>().mapValues(newElementType, mapping); @@ -416,9 +416,9 @@ ElementsAttr ElementsAttr::mapValues( } } -ElementsAttr ElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { +ElementsAttr +ElementsAttr::mapValues(Type newElementType, + function_ref<APInt(const APFloat &)> mapping) const { switch (getKind()) { case StandardAttributes::DenseElements: return cast<DenseElementsAttr>().mapValues(newElementType, mapping); @@ -798,15 +798,14 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { return getRaw(newType, getRawData(), isSplat()); } -DenseElementsAttr DenseElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { +DenseElementsAttr +DenseElementsAttr::mapValues(Type newElementType, + function_ref<APInt(const APInt &)> mapping) const { return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); } DenseElementsAttr DenseElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { + Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); } @@ -855,8 +854,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, } DenseElementsAttr DenseFPElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { + Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { llvm::SmallVector<char, 8> elementData; auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); @@ -875,8 +873,7 @@ bool DenseFPElementsAttr::classof(Attribute attr) { //===----------------------------------------------------------------------===// DenseElementsAttr DenseIntElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { + Type newElementType, function_ref<APInt(const APInt &)> mapping) const { llvm::SmallVector<char, 8> elementData; auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 63e85802b73..4dac32ae0c0 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -159,7 +159,7 @@ BlockArgument *Block::addArgument(Type type) { /// Add one argument to the argument list for each type specified in the list. auto Block::addArguments(ArrayRef<Type> types) - -> llvm::iterator_range<args_iterator> { + -> iterator_range<args_iterator> { arguments.reserve(arguments.size() + types.size()); auto initialSize = arguments.size(); for (auto type : types) { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 70a802cd856..59e16a48865 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -104,7 +104,7 @@ void DiagnosticArgument::print(raw_ostream &os) const { static StringRef twineToStrRef(const Twine &val, std::vector<std::unique_ptr<char[]>> &strings) { // Allocate memory to hold this string. - llvm::SmallString<64> data; + SmallString<64> data; auto strRef = val.toStringRef(data); strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()])); memcpy(&strings.back()[0], strRef.data(), strRef.size()); @@ -157,7 +157,7 @@ std::string Diagnostic::str() const { /// Attaches a note to this diagnostic. A new location may be optionally /// provided, if not, then the location defaults to the one specified for this /// diagnostic. Notes may not be attached to other notes. -Diagnostic &Diagnostic::attachNote(llvm::Optional<Location> noteLoc) { +Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) { // We don't allow attaching notes to notes. assert(severity != DiagnosticSeverity::Note && "cannot attach a note to a note"); @@ -285,9 +285,8 @@ void DiagnosticEngine::emit(Diagnostic diag) { /// Helper function used to emit a diagnostic with an optionally empty twine /// message. If the message is empty, then it is not inserted into the /// diagnostic. -static InFlightDiagnostic emitDiag(Location location, - DiagnosticSeverity severity, - const llvm::Twine &message) { +static InFlightDiagnostic +emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) { auto &diagEngine = location->getContext()->getDiagEngine(); auto diag = diagEngine.emit(location, severity); if (!message.isTriviallyEmpty()) @@ -374,7 +373,7 @@ struct SourceMgrDiagnosticHandlerImpl { } // end namespace mlir /// Return a processable FileLineColLoc from the given location. -static llvm::Optional<FileLineColLoc> getFileLineColLoc(Location loc) { +static Optional<FileLineColLoc> getFileLineColLoc(Location loc) { switch (loc->getKind()) { case StandardAttributes::NameLocation: return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc()); @@ -405,7 +404,7 @@ static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) { SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx, - llvm::raw_ostream &os) + raw_ostream &os) : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os), impl(new SourceMgrDiagnosticHandlerImpl()) { setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); }); @@ -556,8 +555,7 @@ struct SourceMgrDiagnosticVerifierHandlerImpl { SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {} /// Returns the expected diagnostics for the given source file. - llvm::Optional<MutableArrayRef<ExpectedDiag>> - getExpectedDiags(StringRef bufName); + Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName); /// Computes the expected diagnostics for the given source buffer. MutableArrayRef<ExpectedDiag> @@ -592,7 +590,7 @@ static StringRef getDiagKindStr(DiagnosticSeverity kind) { } /// Returns the expected diagnostics for the given source file. -llvm::Optional<MutableArrayRef<ExpectedDiag>> +Optional<MutableArrayRef<ExpectedDiag>> SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) { auto expectedDiags = expectedDiagsPerFile.find(bufName); if (expectedDiags != expectedDiagsPerFile.end()) @@ -681,7 +679,7 @@ SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags( } SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( - llvm::SourceMgr &srcMgr, MLIRContext *ctx, llvm::raw_ostream &out) + llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out) : SourceMgrDiagnosticHandler(srcMgr, ctx, out), impl(new SourceMgrDiagnosticVerifierHandlerImpl()) { // Compute the expected diagnostics for each of the current files in the diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index e5e854260f3..b51c77f34c2 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -40,10 +40,10 @@ FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, OperationState state(location, "func"); Builder builder(location->getContext()); FuncOp::build(&builder, state, name, type, attrs); - return llvm::cast<FuncOp>(Operation::create(state)); + return cast<FuncOp>(Operation::create(state)); } FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - llvm::iterator_range<dialect_attr_iterator> attrs) { + iterator_range<dialect_attr_iterator> attrs) { SmallVector<NamedAttribute, 8> attrRef(attrs); return create(location, name, type, llvm::makeArrayRef(attrRef)); } @@ -204,7 +204,7 @@ FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { } // Create the new function. - FuncOp newFunc = llvm::cast<FuncOp>(getOperation()->cloneWithoutRegions()); + FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); newFunc.setType(newType); /// Set the argument attributes for arguments that aren't being replaced. diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp index 66c0d8af6d3..9cec216468d 100644 --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -213,7 +213,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, // Parse the optional function body. auto *body = result.addRegion(); return parser.parseOptionalRegion( - *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes); + *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); } // Print a function result list. diff --git a/mlir/lib/IR/IntegerSet.cpp b/mlir/lib/IR/IntegerSet.cpp index e5715877649..ce50fa7cc5b 100644 --- a/mlir/lib/IR/IntegerSet.cpp +++ b/mlir/lib/IR/IntegerSet.cpp @@ -73,8 +73,7 @@ MLIRContext *IntegerSet::getContext() const { /// Walk all of the AffineExpr's in this set. Each node in an expression /// tree is visited in postorder. -void IntegerSet::walkExprs( - llvm::function_ref<void(AffineExpr)> callback) const { +void IntegerSet::walkExprs(function_ref<void(AffineExpr)> callback) const { for (auto expr : getConstraints()) expr.walk(callback); } diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp index 79e04521e9c..c52a55b20fe 100644 --- a/mlir/lib/IR/Module.cpp +++ b/mlir/lib/IR/Module.cpp @@ -38,7 +38,7 @@ ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) { OperationState state(loc, "module"); Builder builder(loc->getContext()); ModuleOp::build(&builder, state, name); - return llvm::cast<ModuleOp>(Operation::create(state)); + return cast<ModuleOp>(Operation::create(state)); } ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index fd747a98a40..9df10791046 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1174,7 +1174,7 @@ Value *impl::foldCastOp(Operation *op) { /// terminator operation to insert. void impl::ensureRegionTerminator( Region ®ion, Location loc, - llvm::function_ref<Operation *()> buildTerminatorOp) { + function_ref<Operation *()> buildTerminatorOp) { if (region.empty()) region.push_back(new Block); diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index c588e567bc3..6cec021b6a1 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -129,7 +129,7 @@ void Region::dropAllReferences() { /// is used to point to the operation containing the region, the actual error is /// reported at the operation with an offending use. static bool isIsolatedAbove(Region ®ion, Region &limit, - llvm::Optional<Location> noteLoc) { + Optional<Location> noteLoc) { assert(limit.isAncestor(®ion) && "expected isolation limit to be an ancestor of the given region"); @@ -174,7 +174,7 @@ static bool isIsolatedAbove(Region ®ion, Region &limit, return true; } -bool Region::isIsolatedFromAbove(llvm::Optional<Location> noteLoc) { +bool Region::isIsolatedFromAbove(Optional<Location> noteLoc) { return isIsolatedAbove(*this, *this, noteLoc); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 8a47c5b0b41..7c494e219e8 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -375,7 +375,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, // Drop identity maps from the composition. // This may lead to the composition becoming empty, which is interpreted as an // implicit identity. - llvm::SmallVector<AffineMap, 2> cleanedAffineMapComposition; + SmallVector<AffineMap, 2> cleanedAffineMapComposition; for (const auto &map : affineMapComposition) { if (map.isIdentity()) continue; @@ -417,7 +417,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const { } LogicalResult UnrankedMemRefType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Type elementType, + Optional<Location> loc, MLIRContext *context, Type elementType, unsigned memorySpace) { // Check that memref is formed from allowed types. if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) @@ -647,8 +647,9 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) { } /// Verify the construction of an integer type. -LogicalResult ComplexType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Type elementType) { +LogicalResult ComplexType::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *context, + Type elementType) { if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) return emitOptionalError(loc, "invalid element type for complex"); return success(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index ddc8d0191f5..1a02745e90c 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -349,7 +349,7 @@ public: /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. ParseResult parseAffineMapOfSSAIds(AffineMap &map, - llvm::function_ref<ParseResult(bool)> parseElement); + function_ref<ParseResult(bool)> parseElement); private: /// The Parser is subclassed and reinstantiated. Do not add additional @@ -832,7 +832,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, /// parsing failed, nullptr is returned. The number of bytes read from the input /// string is returned in 'numRead'. template <typename T, typename ParserFn> -static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, +static T parseSymbol(StringRef inputStr, MLIRContext *context, SymbolState &symbolState, ParserFn &&parserFn, size_t *numRead = nullptr) { SourceMgr sourceMgr; @@ -1866,7 +1866,7 @@ private: /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] /// parseList([[1, 2], 3]) -> Failure /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure - ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims); + ParseResult parseList(SmallVectorImpl<int64_t> &dims); Parser &p; @@ -1877,7 +1877,7 @@ private: std::vector<std::pair<bool, Token>> storage; /// A flag that indicates the type of elements that have been parsed. - llvm::Optional<ElementKind> knownEltKind; + Optional<ElementKind> knownEltKind; }; } // namespace @@ -2032,13 +2032,11 @@ ParseResult TensorLiteralParser::parseElement() { /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] /// parseList([[1, 2], 3]) -> Failure /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure -ParseResult -TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) { +ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { p.consumeToken(Token::l_square); - auto checkDims = - [&](const llvm::SmallVectorImpl<int64_t> &prevDims, - const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult { + auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, + const SmallVectorImpl<int64_t> &newDims) -> ParseResult { if (prevDims == newDims) return success(); return p.emitError("tensor literal is invalid; ranks are not consistent " @@ -2046,10 +2044,10 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) { }; bool first = true; - llvm::SmallVector<int64_t, 4> newDims; + SmallVector<int64_t, 4> newDims; unsigned size = 0; auto parseCommaSeparatedList = [&]() -> ParseResult { - llvm::SmallVector<int64_t, 4> thisDims; + SmallVector<int64_t, 4> thisDims; if (p.getToken().getKind() == Token::l_square) { if (parseList(thisDims)) return failure(); @@ -2275,7 +2273,7 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) { return failure(); } - llvm::SmallVector<Location, 4> locations; + SmallVector<Location, 4> locations; auto parseElt = [&] { LocationAttr newLoc; if (parseLocationInstance(newLoc)) @@ -2411,7 +2409,7 @@ namespace { class AffineParser : public Parser { public: AffineParser(ParserState &state, bool allowParsingSSAIds = false, - llvm::function_ref<ParseResult(bool)> parseElement = nullptr) + function_ref<ParseResult(bool)> parseElement = nullptr) : Parser(state), allowParsingSSAIds(allowParsingSSAIds), parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} @@ -2454,7 +2452,7 @@ private: private: bool allowParsingSSAIds; - llvm::function_ref<ParseResult(bool)> parseElement; + function_ref<ParseResult(bool)> parseElement; unsigned numDimOperands; unsigned numSymbolOperands; SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; @@ -3048,8 +3046,9 @@ ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to /// parse SSA value uses encountered while parsing affine expressions. -ParseResult Parser::parseAffineMapOfSSAIds( - AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) { +ParseResult +Parser::parseAffineMapOfSSAIds(AffineMap &map, + function_ref<ParseResult(bool)> parseElement) { return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) .parseAffineMapOfSSAIds(map); } @@ -3113,7 +3112,7 @@ public: /// Return the location of the value identified by its name and number if it /// has been already reference. - llvm::Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { + Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { auto &values = isolatedNameScopes.back().values; if (!values.count(name) || number >= values[name].size()) return {}; @@ -4781,8 +4780,8 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr, /// parsing failed, nullptr is returned. The number of bytes read from the input /// string is returned in 'numRead'. template <typename T, typename ParserFn> -static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, - size_t &numRead, ParserFn &&parserFn) { +static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, + ParserFn &&parserFn) { SymbolState aliasState; return parseSymbol<T>( inputStr, context, aliasState, @@ -4795,35 +4794,33 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, &numRead); } -Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) { +Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { size_t numRead = 0; return parseAttribute(attrStr, context, numRead); } -Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type) { +Attribute mlir::parseAttribute(StringRef attrStr, Type type) { size_t numRead = 0; return parseAttribute(attrStr, type, numRead); } -Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context, +Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, size_t &numRead) { return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { return parser.parseAttribute(); }); } -Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type, - size_t &numRead) { +Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { return parseSymbol<Attribute>( attrStr, type.getContext(), numRead, [type](Parser &parser) { return parser.parseAttribute(type); }); } -Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) { +Type mlir::parseType(StringRef typeStr, MLIRContext *context) { size_t numRead = 0; return parseType(typeStr, context, numRead); } -Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context, - size_t &numRead) { +Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { return parseSymbol<Type>(typeStr, context, numRead, [](Parser &parser) { return parser.parseType(); }); } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index cb5194acf21..f893c7babf9 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -725,7 +725,7 @@ void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) { } /// See PassInstrumentation::runBeforeAnalysis for details. -void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, +void PassInstrumentor::runBeforeAnalysis(StringRef name, AnalysisID *id, Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : impl->instrumentations) @@ -733,7 +733,7 @@ void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, } /// See PassInstrumentation::runAfterAnalysis for details. -void PassInstrumentor::runAfterAnalysis(llvm::StringRef name, AnalysisID *id, +void PassInstrumentor::runAfterAnalysis(StringRef name, AnalysisID *id, Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index 932bf98f61e..c29e0d08869 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -105,7 +105,7 @@ struct PassManagerOptions { }; } // end anonymous namespace -static llvm::ManagedStatic<llvm::Optional<PassManagerOptions>> options; +static llvm::ManagedStatic<Optional<PassManagerOptions>> options; /// Add an IR printing instrumentation if enabled by any 'print-ir' flags. void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) { diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 397fef3ef5d..1a321d666c4 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -27,8 +27,7 @@ using namespace mlir; using namespace detail; /// Static mapping of all of the registered passes. -static llvm::ManagedStatic<llvm::DenseMap<const PassID *, PassInfo>> - passRegistry; +static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry; /// Static mapping of all of the registered pass pipelines. static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> @@ -138,7 +137,7 @@ private: /// A functor used to emit errors found during pipeline handling. The first /// parameter corresponds to the raw location within the pipeline string. This /// should always return failure. - using ErrorHandlerT = function_ref<LogicalResult(const char *, llvm::Twine)>; + using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; /// A struct to capture parsed pass pipeline names. /// @@ -189,7 +188,7 @@ LogicalResult TextualPipeline::initialize(StringRef text, pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer( text, "MLIR Textual PassPipeline Parser"), llvm::SMLoc()); - auto errorHandler = [&](const char *rawLoc, llvm::Twine msg) { + auto errorHandler = [&](const char *rawLoc, Twine msg) { pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc), llvm::SourceMgr::DK_Error, msg); return failure(); @@ -401,7 +400,7 @@ namespace { /// The name for the command line option used for parsing the textual pass /// pipeline. -static constexpr llvm::StringLiteral passPipelineArg = "pass-pipeline"; +static constexpr StringLiteral passPipelineArg = "pass-pipeline"; /// Adds command line option for each registered pass or pass pipeline, as well /// as textual pass pipelines. diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp index 3c46b0bf255..530697421ef 100644 --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -23,7 +23,7 @@ using namespace mlir; using namespace mlir::detail; -constexpr llvm::StringLiteral kPassStatsDescription = +constexpr StringLiteral kPassStatsDescription = "... Pass statistics report ..."; namespace { diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index dd193a4d9a9..113b65a09b5 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -29,7 +29,7 @@ using namespace mlir; using namespace mlir::detail; -constexpr llvm::StringLiteral kPassTimingDescription = +constexpr StringLiteral kPassTimingDescription = "... Pass execution timing report ..."; namespace { @@ -182,11 +182,10 @@ struct PassTiming : public PassInstrumentation { void runAfterPassFailed(Pass *pass, Operation *op) override { runAfterPass(pass, op); } - void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, - Operation *) override { + void runBeforeAnalysis(StringRef name, AnalysisID *id, Operation *) override { startAnalysisTimer(name, id); } - void runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) override; + void runAfterAnalysis(StringRef, AnalysisID *, Operation *) override; /// Print and clear the timing results. void print(); @@ -195,7 +194,7 @@ struct PassTiming : public PassInstrumentation { void startPassTimer(Pass *pass); /// Start a new timer for the given analysis. - void startAnalysisTimer(llvm::StringRef name, AnalysisID *id); + void startAnalysisTimer(StringRef name, AnalysisID *id); /// Pop the last active timer for the current thread. Timer *popLastActiveTimer() { @@ -301,7 +300,7 @@ void PassTiming::startPassTimer(Pass *pass) { } /// Start a new timer for the given analysis. -void PassTiming::startAnalysisTimer(llvm::StringRef name, AnalysisID *id) { +void PassTiming::startAnalysisTimer(StringRef name, AnalysisID *id) { Timer *timer = getTimer(id, TimerKind::PassOrAnalysis, [name] { return "(A) " + name.str(); }); timer->start(); @@ -330,12 +329,12 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) { } /// Stop a timer. -void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) { +void PassTiming::runAfterAnalysis(StringRef, AnalysisID *, Operation *) { popLastActiveTimer()->stop(); } /// Utility to print the timer heading information. -static void printTimerHeader(llvm::raw_ostream &os, TimeRecord total) { +static void printTimerHeader(raw_ostream &os, TimeRecord total) { os << "===" << std::string(73, '-') << "===\n"; // Figure out how many spaces to description name. unsigned padding = (80 - kPassTimingDescription.size()) / 2; diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp index cfed2a2647c..d38c76255f0 100644 --- a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -157,15 +157,15 @@ Type CAGAnchorNode::getTransformedType() { getOriginalType()); } -void CAGNode::printLabel(llvm::raw_ostream &os) const { +void CAGNode::printLabel(raw_ostream &os) const { os << "Node<" << static_cast<const void *>(this) << ">"; } -void CAGAnchorNode::printLabel(llvm::raw_ostream &os) const { +void CAGAnchorNode::printLabel(raw_ostream &os) const { getUniformMetadata().printSummary(os); } -void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const { +void CAGOperandAnchor::printLabel(raw_ostream &os) const { os << "Operand<"; op->getName().print(os); os << "," << operandIdx; @@ -173,7 +173,7 @@ void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const { CAGAnchorNode::printLabel(os); } -void CAGResultAnchor::printLabel(llvm::raw_ostream &os) const { +void CAGResultAnchor::printLabel(raw_ostream &os) const { os << "Result<"; getOp()->getName().print(os); os << ">"; diff --git a/mlir/lib/Quantizer/Support/Metadata.cpp b/mlir/lib/Quantizer/Support/Metadata.cpp index 3661f52b52f..89478c4209d 100644 --- a/mlir/lib/Quantizer/Support/Metadata.cpp +++ b/mlir/lib/Quantizer/Support/Metadata.cpp @@ -24,7 +24,7 @@ using namespace mlir; using namespace mlir::quantizer; -void CAGUniformMetadata::printSummary(llvm::raw_ostream &os) const { +void CAGUniformMetadata::printSummary(raw_ostream &os) const { if (requiredRange.hasValue()) { os << "\n[" << requiredRange.getValue().first << "," << requiredRange.getValue().second << "]"; diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index 788c2f67e27..d155875cfe3 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -28,11 +28,12 @@ using namespace mlir::quantizer; // AttributeTensorStatistics implementation //===----------------------------------------------------------------------===// -static void -collectElementsStatisticsDim(ElementsAttr attr, unsigned numElements, - ArrayRef<int64_t> shape, - llvm::SmallVectorImpl<uint64_t> &indices, - uint64_t dim, TensorAxisStatistics &statistics) { +static void collectElementsStatisticsDim(ElementsAttr attr, + unsigned numElements, + ArrayRef<int64_t> shape, + SmallVectorImpl<uint64_t> &indices, + uint64_t dim, + TensorAxisStatistics &statistics) { // Recursive terminating condition. if (dim >= shape.size()) return; @@ -71,7 +72,7 @@ static bool getElementsStatistics(ElementsAttr attr, if (!elementTy.isa<FloatType>()) return false; - llvm::SmallVector<uint64_t, 4> indices; + SmallVector<uint64_t, 4> indices; indices.resize(sType.getRank()); ArrayRef<int64_t> shape = sType.getShape(); @@ -97,8 +98,7 @@ bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const { namespace mlir { namespace quantizer { -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const TensorAxisStatistics &stats) { +raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats) { os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean << ", variance=" << stats.variance << "]"; diff --git a/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/mlir/lib/Quantizer/Support/UniformConstraints.cpp index c43ecdfb5c2..1a800dad4ac 100644 --- a/mlir/lib/Quantizer/Support/UniformConstraints.cpp +++ b/mlir/lib/Quantizer/Support/UniformConstraints.cpp @@ -118,7 +118,7 @@ public: } private: - void printLabel(llvm::raw_ostream &os) const override { + void printLabel(raw_ostream &os) const override { os << "PropagateExplicitScale"; } void propagate(SolverContext &solverContext, @@ -127,7 +127,7 @@ private: // Get scale/zp from all parents. for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { - auto parentAnchor = llvm::cast<CAGAnchorNode>(*it); + auto parentAnchor = cast<CAGAnchorNode>(*it); auto selectedType = parentAnchor->getUniformMetadata().selectedType; if (auto uqType = selectedType.dyn_cast_or_null<UniformQuantizedType>()) { scaleZp.assertValue( @@ -139,7 +139,7 @@ private: // Propagate to children. if (scaleZp.hasValue()) { for (auto it = begin(), e = end(); it != e; ++it) { - auto childAnchor = llvm::cast<CAGAnchorNode>(*it); + auto childAnchor = cast<CAGAnchorNode>(*it); if (modified(childAnchor->getUniformMetadata() .explicitScaleZeroPoint.mergeFrom(scaleZp))) { childAnchor->markDirty(); @@ -163,9 +163,7 @@ public: } private: - void printLabel(llvm::raw_ostream &os) const override { - os << "SolveUniform"; - } + void printLabel(raw_ostream &os) const override { os << "SolveUniform"; } void propagate(SolverContext &solverContext, const TargetConfiguration &config) override { @@ -176,7 +174,7 @@ private: ClusteredFacts clusteredFacts; Type originalElementType; for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { - auto parentAnchor = llvm::cast<CAGAnchorNode>(*it); + auto parentAnchor = cast<CAGAnchorNode>(*it); auto metadata = parentAnchor->getUniformMetadata(); // TODO: Possibly use a location that fuses all involved parents. fusedLoc = parentAnchor->getOp()->getLoc(); @@ -226,7 +224,7 @@ private: // Apply it to all parents. for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { - auto parentAnchor = llvm::cast<CAGAnchorNode>(*it); + auto parentAnchor = cast<CAGAnchorNode>(*it); auto &metadata = parentAnchor->getUniformMetadata(); if (metadata.selectedType != selectedType) { metadata.selectedType = selectedType; diff --git a/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/mlir/lib/Quantizer/Support/UniformSolvers.cpp index 341df5bf888..bd2fe686ee1 100644 --- a/mlir/lib/Quantizer/Support/UniformSolvers.cpp +++ b/mlir/lib/Quantizer/Support/UniformSolvers.cpp @@ -16,9 +16,8 @@ // ============================================================================= #include "mlir/Quantizer/Support/UniformSolvers.h" - +#include "mlir/Support/LLVM.h" #include "llvm/Support/raw_ostream.h" - #include <cmath> using namespace mlir; @@ -131,14 +130,13 @@ double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const { namespace mlir { namespace quantizer { -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const UniformStorageParams &p) { +raw_ostream &operator<<(raw_ostream &os, const UniformStorageParams &p) { os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}"; return os; } -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const UniformParamsFromMinMaxSolver &s) { +raw_ostream &operator<<(raw_ostream &os, + const UniformParamsFromMinMaxSolver &s) { os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){"; os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> "; if (!s.isSatisfied()) { diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index 7c449e32c4c..511df0a463f 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -162,7 +162,7 @@ void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, // operands). // Apply result types. for (auto *node : cag) { - auto anchorNode = llvm::dyn_cast<CAGResultAnchor>(node); + auto anchorNode = dyn_cast<CAGResultAnchor>(node); if (!anchorNode) continue; if (Type newType = anchorNode->getTransformedType()) @@ -171,7 +171,7 @@ void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, // Apply operand types. for (auto *node : cag) { - auto anchorNode = llvm::dyn_cast<CAGOperandAnchor>(node); + auto anchorNode = dyn_cast<CAGOperandAnchor>(node); if (!anchorNode) continue; if (Type newType = anchorNode->getTransformedType()) diff --git a/mlir/lib/Support/JitRunner.cpp b/mlir/lib/Support/JitRunner.cpp index 8914681cdd9..dcd23437401 100644 --- a/mlir/lib/Support/JitRunner.cpp +++ b/mlir/lib/Support/JitRunner.cpp @@ -122,14 +122,14 @@ static void initializeLLVM() { llvm::InitializeNativeTargetAsmPrinter(); } -static inline Error make_string_error(const llvm::Twine &message) { +static inline Error make_string_error(const Twine &message) { return llvm::make_error<llvm::StringError>(message.str(), llvm::inconvertibleErrorCode()); } -static llvm::Optional<unsigned> getCommandLineOptLevel() { - llvm::Optional<unsigned> optLevel; - llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ +static Optional<unsigned> getCommandLineOptLevel() { + Optional<unsigned> optLevel; + SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ optO0, optO1, optO2, optO3}; // Determine if there is an optimization flag present. @@ -217,7 +217,7 @@ static Error compileAndExecuteSingleFloatReturnFunction( // the MLIR module to the ExecutionEngine. int mlir::JitRunnerMain( int argc, char **argv, - llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { + function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { llvm::InitLLVM y(argc, argv); initializeLLVM(); @@ -225,8 +225,8 @@ int mlir::JitRunnerMain( llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); - llvm::Optional<unsigned> optLevel = getCommandLineOptLevel(); - llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ + Optional<unsigned> optLevel = getCommandLineOptLevel(); + SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ optO0, optO1, optO2, optO3}; unsigned optCLIPosition = 0; // Determine if there is an optimization flag present, and its CLI position @@ -240,7 +240,7 @@ int mlir::JitRunnerMain( } // Generate vector of pass information, plus the index at which we should // insert any optimization passes in that vector (optPosition). - llvm::SmallVector<const llvm::PassInfo *, 4> passes; + SmallVector<const llvm::PassInfo *, 4> passes; unsigned optPosition = 0; for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) { passes.push_back(llvmPasses[i]); diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index 3672c0f3759..cae4dce143f 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -39,7 +39,7 @@ struct StorageUniquerImpl { unsigned hashValue; /// An equality function for comparing with an existing storage instance. - llvm::function_ref<bool(const BaseStorage *)> isEqual; + function_ref<bool(const BaseStorage *)> isEqual; }; /// A utility wrapper object representing a hashed storage object. This class @@ -52,8 +52,8 @@ struct StorageUniquerImpl { /// Get or create an instance of a complex derived type. BaseStorage * getOrCreate(unsigned kind, unsigned hashValue, - llvm::function_ref<bool(const BaseStorage *)> isEqual, - llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { + function_ref<bool(const BaseStorage *)> isEqual, + function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { LookupKey lookupKey{kind, hashValue, isEqual}; // Check for an existing instance in read-only mode. @@ -83,7 +83,7 @@ struct StorageUniquerImpl { /// Get or create an instance of a simple derived type. BaseStorage * getOrCreate(unsigned kind, - llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { + function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { // Check for an existing instance in read-only mode. { llvm::sys::SmartScopedReader<true> typeLock(mutex); @@ -107,8 +107,8 @@ struct StorageUniquerImpl { /// Erase an instance of a complex derived type. void erase(unsigned kind, unsigned hashValue, - llvm::function_ref<bool(const BaseStorage *)> isEqual, - llvm::function_ref<void(BaseStorage *)> cleanupFn) { + function_ref<bool(const BaseStorage *)> isEqual, + function_ref<void(BaseStorage *)> cleanupFn) { LookupKey lookupKey{kind, hashValue, isEqual}; // Acquire a writer-lock so that we can safely erase the type instance. @@ -127,9 +127,9 @@ struct StorageUniquerImpl { //===--------------------------------------------------------------------===// /// Utility to create and initialize a storage instance. - BaseStorage *initializeStorage( - unsigned kind, - llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { + BaseStorage * + initializeStorage(unsigned kind, + function_ref<BaseStorage *(StorageAllocator &)> ctorFn) { BaseStorage *storage = ctorFn(allocator); storage->kind = kind; return storage; @@ -162,11 +162,11 @@ struct StorageUniquerImpl { }; // Unique types with specific hashing or storage constraints. - using StorageTypeSet = llvm::DenseSet<HashedStorage, StorageKeyInfo>; + using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>; StorageTypeSet storageTypes; // Unique types with just the kind. - llvm::DenseMap<unsigned, BaseStorage *> simpleTypes; + DenseMap<unsigned, BaseStorage *> simpleTypes; // Allocator to use when constructing derived type instances. StorageUniquer::StorageAllocator allocator; @@ -184,7 +184,7 @@ StorageUniquer::~StorageUniquer() {} /// complex storage. auto StorageUniquer::getImpl( unsigned kind, unsigned hashValue, - llvm::function_ref<bool(const BaseStorage *)> isEqual, + function_ref<bool(const BaseStorage *)> isEqual, std::function<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * { return impl->getOrCreate(kind, hashValue, isEqual, ctorFn); } @@ -199,9 +199,8 @@ auto StorageUniquer::getImpl( /// Implementation for erasing an instance of a derived type with complex /// storage. -void StorageUniquer::eraseImpl( - unsigned kind, unsigned hashValue, - llvm::function_ref<bool(const BaseStorage *)> isEqual, - std::function<void(BaseStorage *)> cleanupFn) { +void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue, + function_ref<bool(const BaseStorage *)> isEqual, + std::function<void(BaseStorage *)> cleanupFn) { impl->erase(kind, hashValue, isEqual, cleanupFn); } diff --git a/mlir/lib/Support/TranslateClParser.cpp b/mlir/lib/Support/TranslateClParser.cpp index dae0437813f..115c0c03f50 100644 --- a/mlir/lib/Support/TranslateClParser.cpp +++ b/mlir/lib/Support/TranslateClParser.cpp @@ -35,9 +35,9 @@ using namespace mlir; // Storage for the translation function wrappers that survive the parser. -static llvm::SmallVector<TranslateFunction, 16> wrapperStorage; +static SmallVector<TranslateFunction, 16> wrapperStorage; -static LogicalResult printMLIROutput(ModuleOp module, llvm::raw_ostream &os) { +static LogicalResult printMLIROutput(ModuleOp module, raw_ostream &os) { if (failed(verify(module))) return failure(); module.print(os); @@ -57,7 +57,7 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt) for (const auto &kv : toMLIRRegistry) { TranslateSourceMgrToMLIRFunction function = kv.second; TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr, - llvm::raw_ostream &output, + raw_ostream &output, MLIRContext *context) { OwningModuleRef module = function(sourceMgr, context); if (!module) @@ -72,7 +72,7 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt) for (const auto &kv : fromMLIRRegistry) { TranslateFromMLIRFunction function = kv.second; TranslateFunction wrapper = [function](llvm::SourceMgr &sourceMgr, - llvm::raw_ostream &output, + raw_ostream &output, MLIRContext *context) { auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!module) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index b3bea0f036b..e69dce7b59b 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -34,8 +34,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(ModuleOp m) { return LLVM::ModuleTranslation::translateModule<>(m); } -static TranslateFromMLIRRegistration registration( - "mlir-to-llvmir", [](ModuleOp module, llvm::raw_ostream &output) { +static TranslateFromMLIRRegistration + registration("mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module); if (!llvmModule) return failure(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 166ec899776..83c486979d6 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -105,12 +105,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) { } static TranslateFromMLIRRegistration - registration("mlir-to-nvvmir", - [](ModuleOp module, llvm::raw_ostream &output) { - auto llvmModule = mlir::translateModuleToNVVMIR(module); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); - }); + registration("mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) { + auto llvmModule = mlir::translateModuleToNVVMIR(module); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }); diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index 31ba4a27ca0..c06e1cadbc4 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -60,11 +60,11 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder, llvm::Type::getInt64Ty(module->getContext()), // return type. llvm::Type::getInt32Ty(module->getContext()), // parameter type. false); // no variadic arguments. - llvm::Function *fn = llvm::dyn_cast<llvm::Function>( + llvm::Function *fn = dyn_cast<llvm::Function>( module->getOrInsertFunction(fn_name, function_type).getCallee()); llvm::Value *fn_op0 = llvm::ConstantInt::get( llvm::Type::getInt32Ty(module->getContext()), parameter); - return builder.CreateCall(fn, llvm::ArrayRef<llvm::Value *>(fn_op0)); + return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fn_op0)); } class ModuleTranslation : public LLVM::ModuleTranslation { @@ -111,12 +111,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) { } static TranslateFromMLIRRegistration - registration("mlir-to-rocdlir", - [](ModuleOp module, llvm::raw_ostream &output) { - auto llvmModule = mlir::translateModuleToROCDLIR(module); - if (!llvmModule) - return failure(); - - llvmModule->print(output, nullptr); - return success(); - }); + registration("mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) { + auto llvmModule = mlir::translateModuleToROCDLIR(module); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 7a7964d71d3..086c3a831fc 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -473,7 +473,7 @@ LogicalResult ModuleTranslation::convertFunctions() { for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( function.getName(), - llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType())); + cast<llvm::FunctionType>(function.getType().getUnderlyingType())); assert(isa<llvm::Function>(llvmFuncCst.getCallee())); functionMapping[function.getName()] = cast<llvm::Function>(llvmFuncCst.getCallee()); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 4b4575a5e50..37c918fe9be 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -35,7 +35,7 @@ using namespace mlir::detail; /// If 'target' is nonnull, operations that are recursively legal have their /// regions pre-filtered to avoid considering them for legalization. static LogicalResult -computeConversionSet(llvm::iterator_range<Region::iterator> region, +computeConversionSet(iterator_range<Region::iterator> region, Location regionLoc, std::vector<Operation *> &toConvert, ConversionTarget *target = nullptr) { if (llvm::empty(region)) @@ -537,9 +537,8 @@ struct ConversionPatternRewriterImpl { Region::iterator before); /// Notifies that the blocks of a region were cloned into another. - void - notifyRegionWasClonedBefore(llvm::iterator_range<Region::iterator> &blocks, - Location origRegionLoc); + void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks, + Location origRegionLoc); /// Remap the given operands to those with potentially different types. void remapValues(Operation::operand_range operands, @@ -742,7 +741,7 @@ void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( } void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( - llvm::iterator_range<Region::iterator> &blocks, Location origRegionLoc) { + iterator_range<Region::iterator> &blocks, Location origRegionLoc) { for (Block &block : blocks) blockActions.push_back(BlockAction::getCreate(&block)); @@ -986,7 +985,7 @@ private: void computeLegalizationGraphBenefit(); /// The current set of patterns that have been applied. - llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns; + SmallPtrSet<RewritePattern *, 8> appliedPatterns; /// The set of legality information for operations transitively supported by /// the target. @@ -1572,7 +1571,7 @@ void mlir::populateFuncOpTypeConversionPattern( /// 'convertSignatureArg' for each argument. This function should return a valid /// conversion for the signature on success, None otherwise. auto TypeConverter::convertBlockSignature(Block *block) - -> llvm::Optional<SignatureConversion> { + -> Optional<SignatureConversion> { SignatureConversion conversion(block->getNumArguments()); for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) if (failed(convertSignatureArg(i, block->getArgument(i)->getType(), diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index dbb5381ed70..9948a429616 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -79,11 +79,11 @@ struct ResolvedCall { /// Collect all of the callable operations within the given range of blocks. If /// `traverseNestedCGNodes` is true, this will also collect call operations /// inside of nested callgraph nodes. -static void collectCallOps(llvm::iterator_range<Region::iterator> blocks, +static void collectCallOps(iterator_range<Region::iterator> blocks, CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, bool traverseNestedCGNodes) { SmallVector<Block *, 8> worklist; - auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) { + auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { for (Block &block : blocks) worklist.push_back(&block); }; @@ -120,8 +120,8 @@ struct Inliner : public InlinerInterface { /// Process a set of blocks that have been inlined. This callback is invoked /// *before* inlined terminator operations have been processed. - void processInlinedBlocks( - llvm::iterator_range<Region::iterator> inlinedBlocks) final { + void + processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); } diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index 738524aa6ec..4932494a04b 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -50,7 +50,7 @@ public: // - the op has no side-effects. If sideEffecting is Never, sideeffects of this // op and its nested ops are ignored. static bool canBeHoisted(Operation *op, - llvm::function_ref<bool(Value *)> definedOutside, + function_ref<bool(Value *)> definedOutside, SideEffecting sideEffecting, SideEffectsInterface &interface) { // Check that dependencies are defined outside of loop. diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index 5faca1296a8..d4b7caae527 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -82,9 +82,8 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, //===----------------------------------------------------------------------===// LogicalResult OperationFolder::tryToFold( - Operation *op, - llvm::function_ref<void(Operation *)> processGeneratedConstants, - llvm::function_ref<void(Operation *)> preReplaceAction) { + Operation *op, function_ref<void(Operation *)> processGeneratedConstants, + function_ref<void(Operation *)> preReplaceAction) { // If this is a unique'd constant, return failure as we know that it has // already been folded. if (referencedDialects.count(op)) @@ -140,7 +139,7 @@ void OperationFolder::notifyRemoval(Operation *op) { /// `results` with the results of the folding. LogicalResult OperationFolder::tryToFold( Operation *op, SmallVectorImpl<Value *> &results, - llvm::function_ref<void(Operation *)> processGeneratedConstants) { + function_ref<void(Operation *)> processGeneratedConstants) { SmallVector<Attribute, 8> operandConstants; SmallVector<OpFoldResult, 8> foldResults; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index fd08c53b0dc..e8e6ae03338 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -35,7 +35,7 @@ using namespace mlir; /// Remap locations from the inlined blocks with CallSiteLoc locations with the /// provided caller location. static void -remapInlinedLocations(llvm::iterator_range<Region::iterator> inlinedBlocks, +remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, Location callerLoc) { DenseMap<Location, Location> mappedLocations; auto remapOpLoc = [&](Operation *op) { @@ -50,9 +50,8 @@ remapInlinedLocations(llvm::iterator_range<Region::iterator> inlinedBlocks, block.walk(remapOpLoc); } -static void -remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks, - BlockAndValueMapping &mapper) { +static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, + BlockAndValueMapping &mapper) { auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) if (auto *mappedOp = mapper.lookupOrNull(operand.get())) @@ -133,7 +132,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, ArrayRef<Value *> resultsToReplace, - llvm::Optional<Location> inlineLoc, + Optional<Location> inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. if (src->empty()) @@ -226,7 +225,7 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, ArrayRef<Value *> inlinedOperands, ArrayRef<Value *> resultsToReplace, - llvm::Optional<Location> inlineLoc, + Optional<Location> inlineLoc, bool shouldCloneInlinedRegion) { // We expect the region to have at least one block. if (src->empty()) diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 50248b01359..419df8d2705 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -532,11 +532,11 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) { // desired loop interchange would violate dependences by making the // dependence component lexicographically negative. static bool checkLoopInterchangeDependences( - const std::vector<llvm::SmallVector<DependenceComponent, 2>> &depCompsVec, + const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec, ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) { // Invert permutation map. unsigned maxLoopDepth = loops.size(); - llvm::SmallVector<unsigned, 4> loopPermMapInv; + SmallVector<unsigned, 4> loopPermMapInv; loopPermMapInv.resize(maxLoopDepth); for (unsigned i = 0; i < maxLoopDepth; ++i) loopPermMapInv[loopPermMap[i]] = i; @@ -547,7 +547,7 @@ static bool checkLoopInterchangeDependences( // Example 1: [-1, 1][0, 0] // Example 2: [0, 0][-1, 1] for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { - const llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; + const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; assert(depComps.size() >= maxLoopDepth); // Check if the first non-zero dependence component is positive. // This iterates through loops in the desired order. @@ -572,7 +572,7 @@ bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops, // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. assert(loopPermMap.size() == loops.size()); unsigned maxLoopDepth = loops.size(); - std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec; + std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap); } @@ -608,13 +608,13 @@ AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) { // Gather dependence components for dependences between all ops in loop nest // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. unsigned maxLoopDepth = loops.size(); - std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec; + std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec); // Mark loops as either parallel or sequential. - llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true); + SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true); for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) { - llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; + SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i]; assert(depComps.size() >= maxLoopDepth); for (unsigned j = 0; j < maxLoopDepth; ++j) { DependenceComponent &depComp = depComps[j]; @@ -632,7 +632,7 @@ AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) { // Compute permutation of loops that sinks sequential loops (and thus raises // parallel loops) while preserving relative order. - llvm::SmallVector<unsigned, 4> loopPermMap(maxLoopDepth); + SmallVector<unsigned, 4> loopPermMap(maxLoopDepth); unsigned nextSequentialLoop = numParallelLoops; unsigned nextParallelLoop = 0; for (unsigned i = 0; i < maxLoopDepth; ++i) { diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index ba77ceacf28..b91b189b381 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -36,14 +36,13 @@ void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement, } void mlir::visitUsedValuesDefinedAbove( - Region ®ion, Region &limit, - llvm::function_ref<void(OpOperand *)> callback) { + Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { assert(limit.isAncestor(®ion) && "expected isolation limit to be an ancestor of the given region"); // Collect proper ancestors of `limit` upfront to avoid traversing the region // tree for every value. - llvm::SmallPtrSet<Region *, 4> properAncestors; + SmallPtrSet<Region *, 4> properAncestors; for (auto *reg = limit.getParentRegion(); reg != nullptr; reg = reg->getParentRegion()) { properAncestors.insert(reg); @@ -58,8 +57,7 @@ void mlir::visitUsedValuesDefinedAbove( } void mlir::visitUsedValuesDefinedAbove( - llvm::MutableArrayRef<Region> regions, - llvm::function_ref<void(OpOperand *)> callback) { + MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { for (Region ®ion : regions) visitUsedValuesDefinedAbove(region, region, callback); } @@ -71,7 +69,7 @@ void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, }); } -void mlir::getUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions, +void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, llvm::SetVector<Value *> &values) { for (Region ®ion : regions) getUsedValuesDefinedAbove(region, region, values); @@ -352,7 +350,7 @@ static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) { /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any /// of the regions were simplified, failure otherwise. -LogicalResult mlir::simplifyRegions(llvm::MutableArrayRef<Region> regions) { +LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) { LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 036e53435ae..e3212d54e42 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -557,7 +557,7 @@ static llvm::cl::list<int> clFastestVaryingPattern( /// Forward declaration. static FilterFunctionType -isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> ¶llelLoops, +isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, int fastestVaryingMemRefDimension); /// Creates a vectorization pattern from the command line arguments. @@ -565,7 +565,7 @@ isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> ¶llelLoops, /// If the command line argument requests a pattern of higher order, returns an /// empty pattern list which will conservatively result in no vectorization. static std::vector<NestedPattern> -makePatterns(const llvm::DenseSet<Operation *> ¶llelLoops, int vectorRank, +makePatterns(const DenseSet<Operation *> ¶llelLoops, int vectorRank, ArrayRef<int64_t> fastestVaryingPattern) { using matcher::For; int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0]; @@ -842,8 +842,8 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, map(makePtrDynCaster<Value>(), indices), AffineMapAttr::get(permutationMap), // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 - state->folder->create<ConstantFloatOp>( - b, opInst->getLoc(), llvm::APFloat(0.0f), b.getF32Type())); + state->folder->create<ConstantFloatOp>(b, opInst->getLoc(), + APFloat(0.0f), b.getF32Type())); state->registerReplacement(opInst, transfer.getOperation()); } else { state->registerTerminal(opInst); @@ -889,7 +889,7 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, /// loop whose underlying load/store accesses are either invariant or all // varying along the `fastestVaryingMemRefDimension`. static FilterFunctionType -isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> ¶llelLoops, +isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, int fastestVaryingMemRefDimension) { return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) { auto loop = cast<AffineForOp>(forOp); @@ -1255,7 +1255,7 @@ void Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; - llvm::DenseSet<Operation *> parallelLoops; + DenseSet<Operation *> parallelLoops; f.walk([¶llelLoops](AffineForOp loop) { if (isLoopParallel(loop)) parallelLoops.insert(loop); @@ -1293,7 +1293,7 @@ void Vectorize::runOnFunction() { } std::unique_ptr<OpPassBase<FuncOp>> -mlir::createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize) { +mlir::createVectorizePass(ArrayRef<int64_t> virtualVectorSize) { return std::make_unique<Vectorize>(virtualVectorSize); } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 503a82bf82b..591562d0245 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -28,15 +28,17 @@ static llvm::cl::opt<int> elideIfLarger( llvm::cl::desc("Upper limit to emit elements attribute rather than elide"), llvm::cl::init(16)); +using namespace mlir; + namespace llvm { // Specialize GraphTraits to treat Block as a graph of Operations as nodes and // uses as edges. -template <> struct GraphTraits<mlir::Block *> { - using GraphType = mlir::Block *; - using NodeRef = mlir::Operation *; +template <> struct GraphTraits<Block *> { + using GraphType = Block *; + using NodeRef = Operation *; - using ChildIteratorType = mlir::UseIterator; + using ChildIteratorType = UseIterator; static ChildIteratorType child_begin(NodeRef n) { return ChildIteratorType(n); } @@ -46,49 +48,46 @@ template <> struct GraphTraits<mlir::Block *> { // Operation's destructor is private so use Operation* instead and use // mapped iterator. - static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; } - using nodes_iterator = - mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>; - static nodes_iterator nodes_begin(mlir::Block *b) { + static Operation *AddressOf(Operation &op) { return &op; } + using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>; + static nodes_iterator nodes_begin(Block *b) { return nodes_iterator(b->begin(), &AddressOf); } - static nodes_iterator nodes_end(mlir::Block *b) { + static nodes_iterator nodes_end(Block *b) { return nodes_iterator(b->end(), &AddressOf); } }; // Specialize DOTGraphTraits to produce more readable output. -template <> -struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits { +template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - static std::string getNodeLabel(mlir::Operation *op, mlir::Block *); + static std::string getNodeLabel(Operation *op, Block *); }; -std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op, - mlir::Block *b) { +std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) { // Reuse the print output for the node labels. std::string ostr; raw_string_ostream os(ostr); os << op->getName() << "\n"; - if (!op->getLoc().isa<mlir::UnknownLoc>()) { + if (!op->getLoc().isa<UnknownLoc>()) { os << op->getLoc() << "\n"; } // Print resultant types - mlir::interleaveComma(op->getResultTypes(), os); + interleaveComma(op->getResultTypes(), os); os << "\n"; for (auto attr : op->getAttrs()) { os << '\n' << attr.first << ": "; // Always emit splat attributes. - if (attr.second.isa<mlir::SplatElementsAttr>()) { + if (attr.second.isa<SplatElementsAttr>()) { attr.second.print(os); continue; } // Elide "big" elements attributes. - auto elements = attr.second.dyn_cast<mlir::ElementsAttr>(); + auto elements = attr.second.dyn_cast<ElementsAttr>(); if (elements && elements.getNumElements() > elideIfLarger) { os << std::string(elements.getType().getRank(), '[') << "..." << std::string(elements.getType().getRank(), ']') << " : " @@ -96,7 +95,7 @@ std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op, continue; } - auto array = attr.second.dyn_cast<mlir::ArrayAttr>(); + auto array = attr.second.dyn_cast<ArrayAttr>(); if (array && static_cast<int64_t>(array.size()) > elideIfLarger) { os << "[...]"; continue; @@ -114,14 +113,14 @@ namespace { // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { - explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(), - bool short_names = false, const llvm::Twine &title = "") +struct PrintOpPass : public ModulePass<PrintOpPass> { + explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false, + const Twine &title = "") : os(os), title(title.str()), short_names(short_names) {} - std::string getOpName(mlir::Operation &op) { - auto symbolAttr = op.getAttrOfType<mlir::StringAttr>( - mlir::SymbolTable::getSymbolAttrName()); + std::string getOpName(Operation &op) { + auto symbolAttr = + op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); if (symbolAttr) return symbolAttr.getValue(); ++unnamedOpCtr; @@ -129,22 +128,22 @@ struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { } // Print all the ops in a module. - void processModule(mlir::ModuleOp module) { - for (mlir::Operation &op : module) { + void processModule(ModuleOp module) { + for (Operation &op : module) { // Modules may actually be nested, recurse on nesting. - if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) { + if (auto nestedModule = dyn_cast<ModuleOp>(op)) { processModule(nestedModule); continue; } auto opName = getOpName(op); - for (mlir::Region ®ion : op.getRegions()) { + for (Region ®ion : op.getRegions()) { for (auto indexed_block : llvm::enumerate(region)) { // Suffix block number if there are more than 1 block. auto blockName = region.getBlocks().size() == 1 ? "" : ("__" + llvm::utostr(indexed_block.index())); llvm::WriteGraph(os, &indexed_block.value(), short_names, - llvm::Twine(title) + opName + blockName); + Twine(title) + opName + blockName); } } } @@ -153,29 +152,28 @@ struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { void runOnModule() override { processModule(getModule()); } private: - llvm::raw_ostream &os; + raw_ostream &os; std::string title; int unnamedOpCtr = 0; bool short_names; }; } // namespace -void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name, - bool shortNames, const llvm::Twine &title, - llvm::GraphProgram::Name program) { +void mlir::viewGraph(Block &block, const Twine &name, bool shortNames, + const Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&block, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block, - bool shortNames, const llvm::Twine &title) { +raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames, + const Twine &title) { return llvm::WriteGraph(os, &block, shortNames, title); } -std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> -mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames, - const llvm::Twine &title) { +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames, + const Twine &title) { return std::make_unique<PrintOpPass>(os, shortNames, title); } -static mlir::PassRegistration<PrintOpPass> pass("print-op-graph", - "Print op graph per region"); +static PassRegistration<PrintOpPass> pass("print-op-graph", + "Print op graph per region"); diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp index 57c2f31e6a4..db55415d62e 100644 --- a/mlir/lib/Transforms/ViewRegionGraph.cpp +++ b/mlir/lib/Transforms/ViewRegionGraph.cpp @@ -53,41 +53,40 @@ std::string DOTGraphTraits<Region *>::getNodeLabel(Block *Block, Region *) { } // end namespace llvm -void mlir::viewGraph(Region ®ion, const llvm::Twine &name, bool shortNames, - const llvm::Twine &title, - llvm::GraphProgram::Name program) { +void mlir::viewGraph(Region ®ion, const Twine &name, bool shortNames, + const Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(®ion, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Region ®ion, - bool shortNames, const llvm::Twine &title) { +raw_ostream &mlir::writeGraph(raw_ostream &os, Region ®ion, bool shortNames, + const Twine &title) { return llvm::WriteGraph(os, ®ion, shortNames, title); } -void mlir::Region::viewGraph(const llvm::Twine ®ionName) { +void mlir::Region::viewGraph(const Twine ®ionName) { ::mlir::viewGraph(*this, regionName); } void mlir::Region::viewGraph() { viewGraph("region"); } namespace { struct PrintCFGPass : public FunctionPass<PrintCFGPass> { - PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, - const llvm::Twine &title = "") + PrintCFGPass(raw_ostream &os = llvm::errs(), bool shortNames = false, + const Twine &title = "") : os(os), shortNames(shortNames), title(title.str()) {} void runOnFunction() override { mlir::writeGraph(os, getFunction().getBody(), shortNames, title); } private: - llvm::raw_ostream &os; + raw_ostream &os; bool shortNames; std::string title; }; } // namespace std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> -mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, bool shortNames, - const llvm::Twine &title) { +mlir::createPrintCFGGraphPass(raw_ostream &os, bool shortNames, + const Twine &title) { return std::make_unique<PrintCFGPass>(os, shortNames, title); } |

