//===- Utils.cpp ---- Misc utilities for code and data transformation -----===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements miscellaneous transformation routines for non-loop IR // structures. // //===----------------------------------------------------------------------===// #include "mlir/Transforms/Utils.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(Operation &op) { if (isa(op) || isa(op) || isa(op) || isa(op)) return true; return false; } /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { if (auto loadOp = dyn_cast(op)) return loadOp.getAffineMapAttrForMemRef(memref); else if (auto storeOp = dyn_cast(op)) return storeOp.getAffineMapAttrForMemRef(memref); else if (auto dmaStart = dyn_cast(op)) return dmaStart.getAffineMapAttrForMemRef(memref); assert(isa(op)); return cast(op).getAffineMapAttrForMemRef(memref); } bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, Operation *domInstFilter, Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast().getRank(); (void)newMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); } // Assert same elemental type. assert(oldMemRef->getType().cast().getElementType() == newMemRef->getType().cast().getElementType()); std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) domInfo = std::make_unique( domInstFilter->getParentOfType()); if (postDomInstFilter) postDomInfo = std::make_unique( postDomInstFilter->getParentOfType()); // The ops where memref replacement succeeds are replaced with new ones. SmallVector opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) { // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) continue; // Skip this use if it's not post-dominated by postDomInstFilter. if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, opInst)) continue; // Skip dealloc's - no replacement is necessary, and a replacement doesn't // hurt dealloc's. if (isa(opInst)) continue; // Check if the memref was used in a non-deferencing context. It is fine for // the memref to be used in a non-deferencing way outside of the region // where this replacement is happening. if (!isMemRefDereferencingOp(*opInst)) // Failure: memref used in a non-deferencing op (potentially escapes); no // replacement in these cases. return false; auto getMemRefOperandPos = [&]() -> unsigned { unsigned i, e; for (i = 0, e = opInst->getNumOperands(); i < e; i++) { if (opInst->getOperand(i) == oldMemRef) break; } assert(i < opInst->getNumOperands() && "operand guaranteed to be found"); return i; }; OpBuilder builder(opInst); unsigned memRefOperandPos = getMemRefOperandPos(); NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(opInst, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( opInst->operand_begin() + memRefOperandPos + 1, opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); SmallVector affineApplyOps; // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. SmallVector oldMemRefOperands; oldMemRefOperands.reserve(oldMemRefRank); if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { for (auto resultExpr : oldMap.getResults()) { auto singleResMap = builder.getAffineMap( oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); auto afOp = builder.create(opInst->getLoc(), singleResMap, oldMapOperands); oldMemRefOperands.push_back(afOp); affineApplyOps.push_back(afOp); } } else { oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end()); } // Construct new indices as a remap of the old ones if a remapping has been // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. SmallVector remapOperands; remapOperands.reserve(extraOperands.size() + oldMemRefRank); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); if (indexRemap && indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { // Remapped indices. for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = builder.getAffineMap( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); auto afOp = builder.create(opInst->getLoc(), singleResMap, remapOperands); remapOutputs.push_back(afOp); affineApplyOps.push_back(afOp); } } else { // No remapping specified. remapOutputs.append(remapOperands.begin(), remapOperands.end()); } SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. for (auto *extraIndex : extraIndices) { assert(extraIndex->getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && "invalid memory op index"); newMapOperands.push_back(extraIndex); } // Append 'remapOutputs' to 'newMapOperands'. newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); // Create new fully composed AffineMap for new op to be created. assert(newMapOperands.size() == newMemRefRank); auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here. fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); newMap = simplifyAffineMap(newMap); canonicalizeMapAndOperands(&newMap, &newMapOperands); // Remove any affine.apply's that became dead as a result of composition. for (auto *value : affineApplyOps) if (value->use_empty()) value->getDefiningOp()->erase(); // Construct the new operation using this memref. OperationState state(opInst->getLoc(), opInst->getName()); state.setOperandListToResizable(opInst->hasResizableOperandsList()); state.operands.reserve(opInst->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. state.operands.append(opInst->operand_begin(), opInst->operand_begin() + memRefOperandPos); // Insert the new memref value. state.operands.push_back(newMemRef); // Insert the new memref map operands. state.operands.append(newMapOperands.begin(), newMapOperands.end()); // Insert the remaining operands unmodified. state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs, opInst->operand_end()); // Result types don't change. Both memref's are of the same elemental type. state.types.reserve(opInst->getNumResults()); for (auto *result : opInst->getResults()) state.types.push_back(result->getType()); // Add attribute for 'newMap', other Attributes do not change. auto newMapAttr = builder.getAffineMapAttr(newMap); for (auto namedAttr : opInst->getAttrs()) { if (namedAttr.first == oldMapAttrPair.first) { state.attributes.push_back({namedAttr.first, newMapAttr}); } else { state.attributes.push_back(namedAttr); } } // Create the new operation. auto *repOp = builder.createOperation(state); opInst->replaceAllUsesWith(repOp); // Collect and erase at the end since one of these op's could be // domInstFilter or postDomInstFilter as well! opsToErase.push_back(opInst); } for (auto *opInst : opsToErase) opInst->erase(); return true; } /// Given an operation, inserts one or more single result affine /// apply operations, results of which are exclusively used by this operation /// operation. The operands of these newly created affine apply ops are /// guaranteed to be loop iterators or terminal symbols of a function. /// /// Before /// /// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// "compute"(%idx) /// /// After /// /// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) /// "compute"(%idx_) /// /// This allows applying different transformations on send and compute (for eg. /// different shifts/delays). /// /// Returns nullptr either if none of opInst's operands were the result of an /// affine.apply and thus there was no affine computation slice to create, or if /// all the affine.apply op's supplying operands to this opInst did not have any /// uses besides this opInst; otherwise returns the list of affine.apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( Operation *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); for (auto *operand : opInst->getOperands()) if (isa_and_nonnull(operand->getDefiningOp())) subOperands.push_back(operand); // Gather sequence of AffineApplyOps reachable from 'subOperands'. SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) return; // Check if all uses of the affine apply op's lie only in this op op, in // which case there would be nothing to do. bool localized = true; for (auto *op : affineApplyOps) { for (auto *result : op->getResults()) { for (auto *user : result->getUsers()) { if (user != opInst) { localized = false; break; } } } } if (localized) return; OpBuilder builder(opInst); SmallVector composedOpOperands(subOperands); auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); // Create an affine.apply for each of the map results. sliceOps->reserve(composedMap.getNumResults()); for (auto resultExpr : composedMap.getResults()) { auto singleResMap = builder.getAffineMap( composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); sliceOps->push_back(builder.create( opInst->getLoc(), singleResMap, composedOpOperands)); } // Construct the new operands that include the results from the composed // affine apply op above instead of existing ones (subOperands). So, they // differ from opInst's operands only for those operands in 'subOperands', for // which they will be replaced by the corresponding one from 'sliceOps'. SmallVector newOperands(opInst->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. unsigned j, f; for (j = 0, f = subOperands.size(); j < f; j++) { if (newOperands[i] == subOperands[j]) break; } if (j < subOperands.size()) { newOperands[i] = (*sliceOps)[j]; } } for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { opInst->setOperand(idx, newOperands[idx]); } }