Diffstat (limited to 'mlir/lib/Transforms/Utils/Utils.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 469
1 file changed, 469 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644
index 00000000000..a6629183dee
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -0,0 +1,469 @@
//===- Utils.cpp ---- Misc utilities for code and data transformation ----===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

/// Return true if this operation dereferences one or more memrefs.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(Operation &op) {
  return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
         isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op);
}
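// For example (illustrative IR; "foo.consume" stands in for any op that
// forwards a memref rather than dereferencing it):
//
//   %v = affine.load %A[%i] : memref<128xf32>    // dereferences %A
//   "foo.consume"(%A) : (memref<128xf32>) -> ()  // does not; %A may escape
//
// Only the first kind of use is rewritten by the utilities below; the second
// kind makes a whole-memref replacement unsafe.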
/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
  return TypeSwitch<Operation *, NamedAttribute>(op)
      .Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp,
            AffineDmaWaitOp>(
          [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}

// Perform the replacement in `op`.
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  if (!isMemRefDereferencingOp(*op))
    // Failure: memref used in a non-dereferencing context (potentially
    // escapes); no replacement in these cases.
    return failure();

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If the memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO(mlir-team): extend this for the multiple-use case when needed
    // (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e., at position
  // memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.append(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (auto extraIndex : extraIndices) {
    assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
           "single result ops expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create a new fully composed AffineMap for the new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (auto value : affineApplyOps)
    if (value->use_empty())
      value->getDefiningOp()->erase();

  // Construct the new operation using this memref.
  OperationState state(op->getLoc(), op->getName());
  state.setOperandListToResizable(op->hasResizableOperandsList());
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memrefs are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result->getType());

  // Add the attribute for 'newMap'; other attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first)
      state.attributes.push_back({namedAttr.first, newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}
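// As a concrete illustration (hypothetical IR and value names, not from this
// patch): with extraIndices = {%tag} and indexRemap = (d0, d1) -> (d1, d0),
// the routine above rewrites
//
//   %v = affine.load %old[%i, %j] : memref<64x128xf32>
//
// into
//
//   %v = affine.load %new[%tag, %j, %i] : memref<2x128x64xf32>
//
// after composing, simplifying, and canonicalizing the maps involved. Note
// that newMemRefRank (3) equals indexRemap's result count (2) plus
// extraIndices.size() (1), as the asserts at the top require.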
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             Operation *domInstFilter,
                                             Operation *postDomInstFilter) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of the old memref; collect ops to perform the replacement.
  // We use a DenseSet since an operation could potentially have multiple uses
  // of a memref (although rare), and the replacement later is going to erase
  // ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef->getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip deallocs - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these deallocs.
    if (isa<DeallocOp>(op))
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isMemRefDereferencingOp(*op))
      // Failure: memref used in a non-dereferencing op (potentially escapes);
      // no replacement in these cases.
      return failure();

    // We'll first collect and then replace, since replacement erases the op
    // that has the use, and that op could be postDomInstFilter or
    // domInstFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
                                        indexRemap, extraOperands,
                                        symbolOperands)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}
/// Given an operation, inserts one or more single result affine apply
/// operations, results of which are exclusively used by this operation. The
/// operands of these newly created affine apply ops are guaranteed to be loop
/// iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Leaves `sliceOps` empty either if none of opInst's operands were the result
/// of an affine.apply (and thus there was no affine computation slice to
/// create), or if all the affine.apply ops supplying operands to opInst had no
/// uses besides opInst; otherwise, returns the list of affine.apply operations
/// created in the output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
      subOperands.push_back(operand);

  // Gather the sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in which
  // case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result->getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands, which include the results from the composed
  // affine apply ops above instead of the existing ones (subOperands). So,
  // they differ from opInst's operands only for those operands in
  // 'subOperands', for which they will be replaced by the corresponding one
  // from 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size())
      newOperands[i] = (*sliceOps)[j];
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
    opInst->setOperand(idx, newOperands[idx]);
}
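// A sketch of how a caller might use this (hypothetical names): after
// slicing, each AffineApplyOp in `sliceOps` feeds only `computeOp`, so a
// pipelining pass can shift those applies without disturbing the other users
// of the original index computation:
//
//   SmallVector<AffineApplyOp, 4> sliceOps;
//   createAffineComputationSlice(computeOp, &sliceOps);
//   for (auto applyOp : sliceOps)
//     shiftBy(applyOp, delay); // shiftBy is an assumed pass-local helper.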
// TODO: Currently works for static memrefs with a single layout map.
LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
  MemRefType memrefType = allocOp.getType();
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return success();

  auto layoutMaps = memrefType.getAffineMaps();
  OpBuilder b(allocOp);
  if (layoutMaps.size() != 1)
    return failure();

  AffineMap layoutMap = layoutMaps.front();

  // Nothing to do for identity layout maps.
  if (layoutMap == b.getMultiDimIdentityMap(rank))
    return success();

  // We don't do any checks for one-to-oneness; we assume that the map is
  // one-to-one.

  // TODO: Only for static memrefs for now.
  if (memrefType.getNumDynamicDims() > 0)
    return failure();

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  auto shape = memrefType.getShape();
  FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
  for (unsigned d = 0; d < rank; ++d) {
    fac.addConstantLowerBound(d, 0);
    fac.addConstantUpperBound(d, shape[d] - 1);
  }

  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    // TODO: semi-affine maps.
    return failure();

  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
  SmallVector<int64_t, 4> newShape(newRank);
  for (unsigned d = 0; d < newRank; ++d) {
    // The lower bound for the shape is always zero.
    auto ubConst = fac.getConstantUpperBound(d);
    // For a static memref and an affine map with no symbols, this is always
    // bounded.
    assert(ubConst.hasValue() && "should always have an upper bound");
    if (ubConst.getValue() < 0)
      // This is due to an invalid map that maps to a negative space.
      return failure();
    newShape[d] = ubConst.getValue() + 1;
  }
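  // Worked example (illustrative): for memref<16xf32, (d0) -> (d0 floordiv 4,
  // d0 mod 4)>, the constraint 0 <= d0 <= 15 composed with the layout map
  // yields constant upper bounds 3 and 3 on the two new dimensions, so
  // newShape = [4, 4] and the normalized type is memref<4x4xf32>.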
  auto oldMemRef = allocOp.getResult();
  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());

  auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
                                       b.getMultiDimIdentityMap(newRank));
  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);

  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands))) {
    // If it failed (due to escapes, for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any uses of the original alloc op and erase it. All remaining
  // uses have to be deallocs; replaceAllMemRefUsesWith above would've failed
  // otherwise.
  assert(std::all_of(oldMemRef->user_begin(), oldMemRef->user_end(),
                     [](Operation *op) { return isa<DeallocOp>(op); }));
  oldMemRef->replaceAllUsesWith(newAlloc);
  allocOp.erase();
  return success();
}