| field | value | date |
|---|---|---|
| author | Chris Lattner <clattner@google.com> | 2018-12-27 14:35:10 -0800 |
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:40:06 -0700 |
| commit | 3f190312f8f7f09b5910bc77e80268402732ce6b (patch) | |
| tree | 1ac0989c889d04e1acb0370952ed3bf1f141e17d /mlir/lib | |
| parent | 776b035646d809d8b31662363e797f4d7f26c223 (diff) | |
Merge SSAValue, CFGValue, and MLValue together into a single Value class, which
is the new base of the SSA value hierarchy. This CL also standardizes all the
nomenclature and comments to use 'Value' where appropriate. This also eliminates a large number of cast<MLValue>(x)'s, which is very soothing.
This is step 11/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 227064624
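
For orientation before reading the diff: below is a minimal sketch of the class structure this CL collapses, reconstructed from the signatures that appear in the hunks that follow. It is illustrative only — stub types and elided members, not the actual MLIR headers:

```cpp
// Stand-in for mlir::Type, just so the sketch is self-contained.
class Type {};

// Before this CL: a three-way hierarchy. CFG functions and ML functions each
// had their own value subclass, so generic code needed casts such as
// cast<MLValue>(x) on SSAValue pointers.
class SSAValue {
public:
  Type getType() const;
};
class CFGValue : public SSAValue {}; // values inside CFG functions
class MLValue : public SSAValue {   // values inside ML functions
public:
  bool isValidDim() const;
  bool isValidSymbol() const;
};

// After this CL: a single base class for the whole SSA value hierarchy. The
// kind enum moves into Value (see the Value::Kind uses in AsmPrinter.cpp
// below), and the formerly MLValue-only predicates become plain Value
// methods (see Statement.cpp).
class Value {
public:
  enum class Kind { BlockArgument, StmtResult, ForStmt };
  Kind getKind() const;
  Type getType() const;
  bool isValidDim() const;
  bool isValidSymbol() const;
};
```

Most of the churn below is then mechanical: `ArrayRef<MLValue *>` and `ArrayRef<SSAValue *>` signatures become `ArrayRef<Value *>`, and call sites such as `cast<MLValue>(const_cast<SSAValue *>(operand))` collapse to `const_cast<Value *>(operand)`.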
Diffstat (limited to 'mlir/lib')
33 files changed, 473 insertions, 554 deletions
```diff
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index bdc2c7ec286..04ef715d011 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -489,11 +489,11 @@ bool mlir::getFlattenedAffineExprs(
 // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
 // the AffineApplyOp into any user AffineApplyOps.
 void mlir::getReachableAffineApplyOps(
-    ArrayRef<MLValue *> operands,
+    ArrayRef<Value *> operands,
     SmallVectorImpl<OperationStmt *> &affineApplyOps) {
   struct State {
     // The ssa value for this node in the DFS traversal.
-    MLValue *value;
+    Value *value;
     // The operand index of 'value' to explore next during DFS traversal.
     unsigned operandIndex;
   };
@@ -557,8 +557,8 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
 //  setExprStride(ArrayRef<int64_t> expr, int64_t stride)
 bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
                        FlatAffineConstraints *domain) {
-  SmallVector<MLValue *, 4> indices(forStmts.begin(), forStmts.end());
-  // Reset while associated MLValues in 'indices' to the domain.
+  SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
+  // Reset while associated Values in 'indices' to the domain.
   domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
   for (auto *forStmt : forStmts) {
     // Add constraints from forStmt's bounds.
@@ -583,10 +583,10 @@ static bool getStmtIndexSet(const Statement *stmt,
   return getIndexSet(loops, indexSet);
 }
 
-// ValuePositionMap manages the mapping from MLValues which represent dimension
+// ValuePositionMap manages the mapping from Values which represent dimension
 // and symbol identifiers from 'src' and 'dst' access functions to positions
-// in new space where some MLValues are kept separate (using addSrc/DstValue)
-// and some MLValues are merged (addSymbolValue).
+// in new space where some Values are kept separate (using addSrc/DstValue)
+// and some Values are merged (addSymbolValue).
 // Position lookups return the absolute position in the new space which
 // has the following format:
 //
@@ -595,7 +595,7 @@ static bool getStmtIndexSet(const Statement *stmt,
 // Note: access function non-IV dimension identifiers (that have 'dimension'
 // positions in the access function position space) are assigned as symbols
 // in the output position space. Convienience access functions which lookup
-// an MLValue in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
+// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
 // the common case of resolving positions for all access function operands.
 //
 // TODO(andydavis) Generalize this: could take a template parameter for
@@ -603,25 +603,25 @@ static bool getStmtIndexSet(const Statement *stmt,
 // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
 class ValuePositionMap {
 public:
-  void addSrcValue(const MLValue *value) {
+  void addSrcValue(const Value *value) {
     if (addValueAt(value, &srcDimPosMap, numSrcDims))
       ++numSrcDims;
   }
-  void addDstValue(const MLValue *value) {
+  void addDstValue(const Value *value) {
     if (addValueAt(value, &dstDimPosMap, numDstDims))
       ++numDstDims;
   }
-  void addSymbolValue(const MLValue *value) {
+  void addSymbolValue(const Value *value) {
     if (addValueAt(value, &symbolPosMap, numSymbols))
       ++numSymbols;
   }
-  unsigned getSrcDimOrSymPos(const MLValue *value) const {
+  unsigned getSrcDimOrSymPos(const Value *value) const {
     return getDimOrSymPos(value, srcDimPosMap, 0);
   }
-  unsigned getDstDimOrSymPos(const MLValue *value) const {
+  unsigned getDstDimOrSymPos(const Value *value) const {
     return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
   }
-  unsigned getSymPos(const MLValue *value) const {
+  unsigned getSymPos(const Value *value) const {
     auto it = symbolPosMap.find(value);
     assert(it != symbolPosMap.end());
     return numSrcDims + numDstDims + it->second;
@@ -633,8 +633,7 @@ public:
   unsigned getNumSymbols() const { return numSymbols; }
 
 private:
-  bool addValueAt(const MLValue *value,
-                  DenseMap<const MLValue *, unsigned> *posMap,
+  bool addValueAt(const Value *value, DenseMap<const Value *, unsigned> *posMap,
                   unsigned position) {
     auto it = posMap->find(value);
     if (it == posMap->end()) {
@@ -643,8 +642,8 @@ private:
     }
     return false;
   }
-  unsigned getDimOrSymPos(const MLValue *value,
-                          const DenseMap<const MLValue *, unsigned> &dimPosMap,
+  unsigned getDimOrSymPos(const Value *value,
+                          const DenseMap<const Value *, unsigned> &dimPosMap,
                           unsigned dimPosOffset) const {
     auto it = dimPosMap.find(value);
     if (it != dimPosMap.end()) {
@@ -658,25 +657,25 @@ private:
   unsigned numSrcDims = 0;
   unsigned numDstDims = 0;
   unsigned numSymbols = 0;
-  DenseMap<const MLValue *, unsigned> srcDimPosMap;
-  DenseMap<const MLValue *, unsigned> dstDimPosMap;
-  DenseMap<const MLValue *, unsigned> symbolPosMap;
+  DenseMap<const Value *, unsigned> srcDimPosMap;
+  DenseMap<const Value *, unsigned> dstDimPosMap;
+  DenseMap<const Value *, unsigned> symbolPosMap;
 };
 
-// Builds a map from MLValue to identifier position in a new merged identifier
+// Builds a map from Value to identifier position in a new merged identifier
 // list, which is the result of merging dim/symbol lists from src/dst
 // iteration domains. The format of the new merged list is as follows:
 //
 //   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
 //
-// This method populates 'valuePosMap' with mappings from operand MLValues in
+// This method populates 'valuePosMap' with mappings from operand Values in
 // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
 // to the position of these values in the merged list.
 static void buildDimAndSymbolPositionMaps(
     const FlatAffineConstraints &srcDomain,
     const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
     const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
-  auto updateValuePosMap = [&](ArrayRef<MLValue *> values, bool isSrc) {
+  auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
       auto *value = values[i];
       if (!isa<ForStmt>(values[i]))
@@ -688,7 +687,7 @@ static void buildDimAndSymbolPositionMaps(
     }
   };
 
-  SmallVector<MLValue *, 4> srcValues, destValues;
+  SmallVector<Value *, 4> srcValues, destValues;
   srcDomain.getIdValues(&srcValues);
   dstDomain.getIdValues(&destValues);
@@ -702,17 +701,10 @@ static void buildDimAndSymbolPositionMaps(
   updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
 }
 
-static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
-                       const MLValue *value) {
-  auto it = posMap.find(value);
-  assert(it != posMap.end());
-  return it->second;
-}
-
 // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
 // 'dependenceDomain'.
 // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
-// srcDomain/dstDomain MLValue maps.
+// srcDomain/dstDomain Value maps.
 static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
                                  const FlatAffineConstraints &dstDomain,
                                  const ValuePositionMap &valuePosMap,
@@ -790,10 +782,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   unsigned numResults = srcMap.getNumResults();
 
   unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
-  ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
+  ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
 
   unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
-  ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
+  ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
 
   std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
   std::vector<SmallVector<int64_t, 8>> destFlatExprs;
@@ -848,7 +840,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   }
 
   // Add equality constraints for any operands that are defined by constant ops.
-  auto addEqForConstOperands = [&](ArrayRef<const MLValue *> operands) {
+  auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
       if (isa<ForStmt>(operands[i]))
         continue;
@@ -1095,7 +1087,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 //    upper/lower loop bounds for each ForStmt in the loop nest associated
 //    with each access.
 // *) Build dimension and symbol position maps for each access, which map
-//    MLValues from access functions and iteration domains to their position
+//    Values from access functions and iteration domains to their position
 //    in the merged constraint system built by this method.
 //
 // This method builds a constraint system with the following column format:
@@ -1202,7 +1194,7 @@ bool mlir::checkMemrefAccessDependence(
     return false;
   }
   // Build dim and symbol position maps for each access from access operand
-  // MLValue to position in merged contstraint system.
+  // Value to position in merged contstraint system.
   ValuePositionMap valuePosMap;
   buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
                                 dstAccessMap, &valuePosMap);
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c22c5ec95bc..bfdaceff7e7 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/MLValue.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseSet.h"
@@ -238,23 +237,23 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
 AffineValueMap::AffineValueMap(const AffineApplyOp &op)
     : map(op.getAffineMap()) {
   for (auto *operand : op.getOperands())
-    operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand)));
+    operands.push_back(const_cast<Value *>(operand));
   for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
-    results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i))));
+    results.push_back(const_cast<Value *>(op.getResult(i)));
 }
 
-AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
+AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands)
     : map(map) {
-  for (MLValue *operand : operands) {
+  for (Value *operand : operands) {
     this->operands.push_back(operand);
   }
 }
 
-void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) {
+void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands) {
   this->operands.clear();
   this->results.clear();
   this->map.reset(map);
-  for (MLValue *operand : operands) {
+  for (Value *operand : operands) {
     this->operands.push_back(operand);
   }
 }
@@ -275,7 +274,7 @@ void AffineValueMap::forwardSubstituteSingle(const AffineApplyOp &inputOp,
 
 // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
 // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
-static bool findIndex(MLValue *valueToMatch, ArrayRef<MLValue *> valuesToSearch,
+static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
                       unsigned indexStart, unsigned *indexOfMatch) {
   unsigned size = valuesToSearch.size();
   for (unsigned i = indexStart; i < size; ++i) {
@@ -324,8 +323,7 @@ void AffineValueMap::forwardSubstitute(
     for (unsigned j = 0; j < inputNumResults; ++j) {
       if (!inputResultsToSubstitute[j])
         continue;
-      if (operands[i] ==
-          cast<MLValue>(const_cast<SSAValue *>(inputOp.getResult(j)))) {
+      if (operands[i] == const_cast<Value *>(inputOp.getResult(j))) {
         currOperandToInputResult[i] = j;
         inputResultsUsed.insert(j);
       }
@@ -365,7 +363,7 @@ void AffineValueMap::forwardSubstitute(
   }
 
   // Build new output operands list and map update.
-  SmallVector<MLValue *, 4> outputOperands;
+  SmallVector<Value *, 4> outputOperands;
   unsigned outputOperandPosition = 0;
   AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
@@ -385,8 +383,7 @@ void AffineValueMap::forwardSubstitute(
     if (inputPositionsUsed.count(i) == 0)
       continue;
     // Check if input operand has a dup in current operand list.
-    auto *inputOperand =
-        cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
+    auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
     unsigned outputIndex;
     if (findIndex(inputOperand, outputOperands, /*indexStart=*/0,
                   &outputIndex)) {
@@ -418,8 +415,7 @@ void AffineValueMap::forwardSubstitute(
       continue;
     unsigned inputSymbolPosition = i - inputNumDims;
     // Check if input operand has a dup in current operand list.
-    auto *inputOperand =
-        cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
+    auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
     // Find output operand index of 'inputOperand' dup.
     unsigned outputIndex;
     // Start at index 'outputNumDims' so that only symbol operands are searched.
@@ -451,7 +447,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
 /// This method uses the invariant that operands are always positionally aligned
 /// with the AffineDimExpr in the underlying AffineMap.
-bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
+bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
   unsigned index;
   findIndex(value, operands, /*indexStart=*/0, &index);
   auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
@@ -460,12 +456,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
   return expr.isFunctionOfDim(index);
 }
 
-SSAValue *AffineValueMap::getOperand(unsigned i) const {
-  return static_cast<SSAValue *>(operands[i]);
+Value *AffineValueMap::getOperand(unsigned i) const {
+  return static_cast<Value *>(operands[i]);
 }
 
-ArrayRef<MLValue *> AffineValueMap::getOperands() const {
-  return ArrayRef<MLValue *>(operands);
+ArrayRef<Value *> AffineValueMap::getOperands() const {
+  return ArrayRef<Value *>(operands);
 }
 
 AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
@@ -546,7 +542,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
                                   unsigned newNumReservedCols,
                                   unsigned newNumDims, unsigned newNumSymbols,
                                   unsigned newNumLocals,
-                                  ArrayRef<MLValue *> idArgs) {
+                                  ArrayRef<Value *> idArgs) {
   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
          "minimum 1 column");
   numReservedCols = newNumReservedCols;
@@ -570,7 +566,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
 
 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
                                   unsigned newNumLocals,
-                                  ArrayRef<MLValue *> idArgs) {
+                                  ArrayRef<Value *> idArgs) {
   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
         newNumSymbols, newNumLocals, idArgs);
 }
@@ -597,17 +593,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) {
   addId(IdKind::Local, pos);
 }
 
-void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) {
+void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
   addId(IdKind::Dimension, pos, id);
 }
 
-void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) {
+void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
   addId(IdKind::Symbol, pos, id);
 }
 
 /// Adds a dimensional identifier. The added column is initialized to
 /// zero.
-void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) {
+void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
   if (kind == IdKind::Dimension) {
     assert(pos <= getNumDimIds());
   } else if (kind == IdKind::Symbol) {
@@ -755,7 +751,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
   // Dims and symbols.
   for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
     unsigned loc;
-    bool ret = findId(*cast<MLValue>(vMap->getOperand(i)), &loc);
+    bool ret = findId(*vMap->getOperand(i), &loc);
     assert(ret && "value map's id can't be found");
     (void)ret;
     // We need to negate 'eq[r]' since the newly added dimension is going to
@@ -1231,7 +1227,7 @@ void FlatAffineConstraints::addUpperBound(ArrayRef<int64_t> expr,
   }
 }
 
-bool FlatAffineConstraints::findId(const MLValue &id, unsigned *pos) const {
+bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
   unsigned i = 0;
   for (const auto &mayBeId : ids) {
     if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
@@ -1253,8 +1249,8 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
 bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
   unsigned pos;
   // Pre-condition for this method.
-  if (!findId(*cast<MLValue>(&forStmt), &pos)) {
-    assert(0 && "MLValue not found");
+  if (!findId(forStmt, &pos)) {
+    assert(0 && "Value not found");
     return false;
   }
@@ -1270,7 +1266,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
       unsigned loc;
       if (!findId(*operand, &loc)) {
         if (operand->isValidSymbol()) {
-          addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand));
+          addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
           loc = getNumDimIds() + getNumSymbolIds() - 1;
           // Check if the symbol is a constant.
           if (auto *opStmt = operand->getDefiningStmt()) {
@@ -1279,7 +1275,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
           }
         }
       } else {
-        addDimId(getNumDimIds(), const_cast<MLValue *>(operand));
+        addDimId(getNumDimIds(), const_cast<Value *>(operand));
         loc = getNumDimIds() - 1;
       }
     }
@@ -1352,7 +1348,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
 
 /// Sets the specified identifer to a constant value; asserts if the id is not
 /// found.
-void FlatAffineConstraints::setIdToConstant(const MLValue &id, int64_t val) {
+void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
   unsigned pos;
   if (!findId(id, &pos))
     // This is a pre-condition for this method.
@@ -1572,7 +1568,7 @@ void FlatAffineConstraints::print(raw_ostream &os) const {
       if (ids[i] == None)
         os << "None ";
       else
-        os << "MLValue ";
+        os << "Value ";
     }
     os << " const)\n";
   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
@@ -1779,7 +1775,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
   unsigned newNumDims = dimsSymbols.first;
   unsigned newNumSymbols = dimsSymbols.second;
 
-  SmallVector<Optional<MLValue *>, 8> newIds;
+  SmallVector<Optional<Value *>, 8> newIds;
   newIds.reserve(numIds - 1);
   newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos);
   newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end());
@@ -1942,7 +1938,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
   normalizeConstraintsByGCD();
 }
 
-void FlatAffineConstraints::projectOut(MLValue *id) {
+void FlatAffineConstraints::projectOut(Value *id) {
   unsigned pos;
   bool ret = findId(*id, &pos);
   assert(ret);
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index b3faaf3eae0..1a28eb138f4 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -70,7 +70,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
 }
 
 /// Return true if value A properly dominates instruction B.
-bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) {
+bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
   if (auto *aInst = a->getDefiningInst())
     return properlyDominates(aInst, b);
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index f20b8bb19e5..7213ba5986a 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -124,14 +124,14 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
   return tripCountExpr.getLargestKnownDivisor();
 }
 
-bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
+bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
   assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
   assert(index.getType().isa<IndexType>() && "index must be of IndexType");
   SmallVector<OperationStmt *, 4> affineApplyOps;
-  getReachableAffineApplyOps({const_cast<MLValue *>(&index)}, affineApplyOps);
+  getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
 
   if (affineApplyOps.empty()) {
-    // Pointer equality test because of MLValue pointer semantics.
+    // Pointer equality test because of Value pointer semantics.
     return &index != &iv;
   }
@@ -155,13 +155,13 @@ bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
   }
   assert(idx < std::numeric_limits<unsigned>::max());
   return !AffineValueMap(*composeOp)
-              .isFunctionOf(idx, &const_cast<MLValue &>(iv));
+              .isFunctionOf(idx, &const_cast<Value &>(iv));
 }
 
-llvm::DenseSet<const MLValue *>
-mlir::getInvariantAccesses(const MLValue &iv,
-                           llvm::ArrayRef<const MLValue *> indices) {
-  llvm::DenseSet<const MLValue *> res;
+llvm::DenseSet<const Value *>
+mlir::getInvariantAccesses(const Value &iv,
+                           llvm::ArrayRef<const Value *> indices) {
+  llvm::DenseSet<const Value *> res;
   for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
     auto *val = indices[idx];
     if (isAccessInvariant(iv, *val)) {
@@ -191,7 +191,7 @@ mlir::getInvariantAccesses(const MLValue &iv,
 ///
 // TODO(ntv): check strides.
 template <typename LoadOrStoreOp>
-static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
+static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp,
                                unsigned fastestVaryingDim) {
   static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value ||
                     std::is_same<LoadOrStoreOp, StoreOp>::value,
@@ -220,7 +220,7 @@ static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
     if (fastestVaryingDim == (numIndices - 1) - d++) {
       continue;
     }
-    if (!isAccessInvariant(iv, cast<MLValue>(*index))) {
+    if (!isAccessInvariant(iv, *index)) {
       return false;
     }
   }
@@ -316,7 +316,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
     // outside).
     if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
       for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
-        const MLValue *result = opStmt->getResult(i);
+        const Value *result = opStmt->getResult(i);
         for (const StmtOperand &use : result->getUses()) {
           // If an ancestor statement doesn't lie in the block of forStmt, there
           // is no shift to check.
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index 2e3df2d61f4..7c57a66310a 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -70,7 +70,7 @@ static void addMemRefAccessIndices(
     MemRefType memrefType, MemRefAccess *access) {
   access->indices.reserve(memrefType.getRank());
   for (auto *index : opIndices) {
-    access->indices.push_back(cast<MLValue>(const_cast<SSAValue *>(index)));
+    access->indices.push_back(const_cast<mlir::Value *>(index));
   }
 }
 
@@ -79,13 +79,13 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt,
                             MemRefAccess *access) {
   access->opStmt = loadOrStoreOpStmt;
   if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
-    access->memref = cast<MLValue>(loadOp->getMemRef());
+    access->memref = loadOp->getMemRef();
     addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
                            access);
   } else {
     assert(loadOrStoreOpStmt->isa<StoreOp>());
     auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
-    access->memref = cast<MLValue>(storeOp->getMemRef());
+    access->memref = storeOp->getMemRef();
     addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
                            access);
   }
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 0c6cfea7ccd..7d397647bc9 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -150,21 +150,21 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
   OpPointer<LoadOp> loadOp;
   OpPointer<StoreOp> storeOp;
   unsigned rank;
-  SmallVector<MLValue *, 4> indices;
+  SmallVector<Value *, 4> indices;
 
   if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
     rank = loadOp->getMemRefType().getRank();
     for (auto *index : loadOp->getIndices()) {
-      indices.push_back(cast<MLValue>(index));
+      indices.push_back(index);
     }
-    region->memref = cast<MLValue>(loadOp->getMemRef());
+    region->memref = loadOp->getMemRef();
     region->setWrite(false);
   } else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
     rank = storeOp->getMemRefType().getRank();
     for (auto *index : storeOp->getIndices()) {
-      indices.push_back(cast<MLValue>(index));
+      indices.push_back(index);
     }
-    region->memref = cast<MLValue>(storeOp->getMemRef());
+    region->memref = storeOp->getMemRef();
     region->setWrite(true);
   } else {
     return false;
@@ -201,7 +201,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
       return false;
     } else {
       // Has to be a valid symbol.
-      auto *symbol = cast<MLValue>(accessValueMap.getOperand(i));
+      auto *symbol = accessValueMap.getOperand(i);
       assert(symbol->isValidSymbol());
       // Check if the symbol is a constant.
       if (auto *opStmt = symbol->getDefiningStmt()) {
@@ -405,7 +405,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
 
   // Solve for src IVs in terms of dst IVs, symbols and constants.
   SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
-  std::vector<SmallVector<MLValue *, 2>> srcIvOperands(srcLoopNestSize);
+  std::vector<SmallVector<Value *, 2>> srcIvOperands(srcLoopNestSize);
   for (unsigned i = 0; i < srcLoopNestSize; ++i) {
     // Skip IVs which are greater than requested loop depth.
     if (i >= srcLoopDepth) {
@@ -442,7 +442,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
       srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
     }
     // TODO(andydavis) Add symbols from the access function. Ideally, we
-    // should be able to query the constaint system for the MLValue associated
+    // should be able to query the constaint system for the Value associated
     // with a symbol identifiers in 'nonZeroSymbolIds'.
   }
 
@@ -454,7 +454,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
   // of the loop at 'dstLoopDepth' in 'dstLoopNest'.
   auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
   MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
-  DenseMap<const MLValue *, MLValue *> operandMap;
+  DenseMap<const Value *, Value *> operandMap;
   auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
 
   // Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index bfef98d76da..ec19194f2fa 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -108,7 +108,7 @@ static AffineMap makePermutationMap(
     const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
   using functional::makePtrDynCaster;
   using functional::map;
-  auto unwrappedIndices = map(makePtrDynCaster<SSAValue, MLValue>(), indices);
+  auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
   SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
                                   getAffineConstantExpr(0, context));
   for (auto kvp : enclosingLoopToVectorDim) {
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index a04cee7512d..e7abb899a11 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -277,7 +277,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
 /// Walk all of the code in this MLFunc and verify that the operands of any
 /// operations are properly dominated by their definitions.
 bool MLFuncVerifier::verifyDominance() {
-  using HashTable = llvm::ScopedHashTable<const SSAValue *, bool>;
+  using HashTable = llvm::ScopedHashTable<const Value *, bool>;
   HashTable liveValues;
   HashTable::ScopeTy topScope(liveValues);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4778564cb4d..c44ce4d4d6c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -38,7 +38,6 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
-
 using namespace mlir;
 
 void Identifier::print(raw_ostream &os) const { os << str(); }
@@ -967,7 +966,7 @@ public:
   void printFunctionAttributes(const Function *func) {
     return ModulePrinter::printFunctionAttributes(func);
   }
-  void printOperand(const SSAValue *value) { printValueID(value); }
+  void printOperand(const Value *value) { printValueID(value); }
 
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<const char *> elidedAttrs = {}) {
@@ -977,7 +976,7 @@ public:
   enum { nameSentinel = ~0U };
 
 protected:
-  void numberValueID(const SSAValue *value) {
+  void numberValueID(const Value *value) {
     assert(!valueIDs.count(value) && "Value numbered multiple times");
 
     SmallString<32> specialNameBuffer;
@@ -1004,7 +1003,7 @@ protected:
     if (specialNameBuffer.empty()) {
       switch (value->getKind()) {
-      case SSAValueKind::BlockArgument:
+      case Value::Kind::BlockArgument:
         // If this is an argument to the function, give it an 'arg' name.
         if (auto *block = cast<BlockArgument>(value)->getOwner())
           if (auto *fn = block->getFunction())
         // Otherwise number it normally.
         valueIDs[value] = nextValueID++;
         return;
-      case SSAValueKind::StmtResult:
+      case Value::Kind::StmtResult:
         // This is an uninteresting result, give it a boring number and be
         // done with it.
         valueIDs[value] = nextValueID++;
         return;
-      case SSAValueKind::ForStmt:
+      case Value::Kind::ForStmt:
         specialName << 'i' << nextLoopID++;
         break;
       }
@@ -1052,7 +1051,7 @@ protected:
     }
   }
 
-  void printValueID(const SSAValue *value, bool printResultNo = true) const {
+  void printValueID(const Value *value, bool printResultNo = true) const {
     int resultNo = -1;
     auto lookupValue = value;
@@ -1093,8 +1092,8 @@ protected:
 private:
   /// This is the value ID for each SSA value in the current function. If this
   /// returns ~0, then the valueID has an entry in valueNames.
-  DenseMap<const SSAValue *, unsigned> valueIDs;
-  DenseMap<const SSAValue *, StringRef> valueNames;
+  DenseMap<const Value *, unsigned> valueIDs;
+  DenseMap<const Value *, StringRef> valueNames;
 
   /// This keeps track of all of the non-numeric names that are in flight,
   /// allowing us to check for duplicates.
@@ -1135,7 +1134,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
   os << "\"(";
 
   interleaveComma(op->getOperands(),
-                  [&](const SSAValue *value) { printValueID(value); });
+                  [&](const Value *value) { printValueID(value); });
 
   os << ')';
   auto attrs = op->getAttrs();
@@ -1144,16 +1143,15 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
   // Print the type signature of the operation.
   os << " : (";
   interleaveComma(op->getOperands(),
-                  [&](const SSAValue *value) { printType(value->getType()); });
+                  [&](const Value *value) { printType(value->getType()); });
   os << ") -> ";
 
   if (op->getNumResults() == 1) {
     printType(op->getResult(0)->getType());
   } else {
     os << '(';
-    interleaveComma(op->getResults(), [&](const SSAValue *result) {
-      printType(result->getType());
-    });
+    interleaveComma(op->getResults(),
+                    [&](const Value *result) { printType(result->getType()); });
     os << ')';
   }
 }
@@ -1297,11 +1295,10 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) {
   os << '(';
   interleaveComma(range,
-                  [this](const SSAValue *operand) { printValueID(operand); });
+                  [this](const Value *operand) { printValueID(operand); });
   os << " : ";
-  interleaveComma(range, [this](const SSAValue *operand) {
-    printType(operand->getType());
-  });
+  interleaveComma(
+      range, [this](const Value *operand) { printType(operand->getType()); });
   os << ')';
 }
@@ -1576,20 +1573,20 @@ void IntegerSet::print(raw_ostream &os) const {
   ModulePrinter(os, state).printIntegerSet(*this);
 }
 
-void SSAValue::print(raw_ostream &os) const {
+void Value::print(raw_ostream &os) const {
   switch (getKind()) {
-  case SSAValueKind::BlockArgument:
+  case Value::Kind::BlockArgument:
     // TODO: Improve this.
     os << "<block argument>\n";
     return;
-  case SSAValueKind::StmtResult:
+  case Value::Kind::StmtResult:
     return getDefiningStmt()->print(os);
-  case SSAValueKind::ForStmt:
+  case Value::Kind::ForStmt:
     return cast<ForStmt>(this)->print(os);
   }
 }
 
-void SSAValue::dump() const { print(llvm::errs()); }
+void Value::dump() const { print(llvm::errs()); }
 
 void Instruction::print(raw_ostream &os) const {
   auto *function = getFunction();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 0732448fb87..0b88216f66f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -281,7 +281,7 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
   // If we are supposed to insert before a specific block, do so, otherwise add
   // the block to the end of the function.
   if (insertBefore)
-    function->getBlocks().insert(CFGFunction::iterator(insertBefore), b);
+    function->getBlocks().insert(Function::iterator(insertBefore), b);
   else
     function->push_back(b);
@@ -291,16 +291,9 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
 
 /// Create an operation given the fields represented as an OperationState.
 OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
-  SmallVector<CFGValue *, 8> operands;
-  operands.reserve(state.operands.size());
-  // Allow null operands as they act as sentinal barriers between successor
-  // operand lists.
-  for (auto elt : state.operands)
-    operands.push_back(cast_or_null<CFGValue>(elt));
-
-  auto *op =
-      OperationInst::create(state.location, state.name, operands, state.types,
-                            state.attributes, state.successors, context);
+  auto *op = OperationInst::create(state.location, state.name, state.operands,
+                                   state.types, state.attributes,
+                                   state.successors, context);
   block->getStatements().insert(insertPoint, op);
   return op;
 }
@@ -311,23 +304,17 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
 
 /// Create an operation given the fields represented as an OperationState.
 OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
-  SmallVector<MLValue *, 8> operands;
-  operands.reserve(state.operands.size());
-  for (auto elt : state.operands)
-    operands.push_back(cast<MLValue>(elt));
-
-  auto *op =
-      OperationStmt::create(state.location, state.name, operands, state.types,
-                            state.attributes, state.successors, context);
+  auto *op = OperationStmt::create(state.location, state.name, state.operands,
+                                   state.types, state.attributes,
+                                   state.successors, context);
   block->getStatements().insert(insertPoint, op);
   return op;
 }
 
 ForStmt *MLFuncBuilder::createFor(Location location,
-                                  ArrayRef<MLValue *> lbOperands,
-                                  AffineMap lbMap,
-                                  ArrayRef<MLValue *> ubOperands,
-                                  AffineMap ubMap, int64_t step) {
+                                  ArrayRef<Value *> lbOperands, AffineMap lbMap,
+                                  ArrayRef<Value *> ubOperands, AffineMap ubMap,
+                                  int64_t step) {
   auto *stmt =
       ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
   block->getStatements().insert(insertPoint, stmt);
@@ -341,7 +328,7 @@ ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
   return createFor(location, {}, lbMap, {}, ubMap, step);
 }
 
-IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<MLValue *> operands,
+IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
                                 IntegerSet set) {
   auto *stmt = IfStmt::create(location, operands, set);
   block->getStatements().insert(insertPoint, stmt);
diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp
index cdf98ca4bee..50ab254dd76 100644
--- a/mlir/lib/IR/BuiltinOps.cpp
+++ b/mlir/lib/IR/BuiltinOps.cpp
@@ -20,8 +20,8 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/SSAValue.h"
 #include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -54,7 +54,7 @@ void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
 // dimension operands parsed.
 // Returns 'false' on success and 'true' on error.
 bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
-                                 SmallVector<SSAValue *, 4> &operands,
+                                 SmallVector<Value *, 4> &operands,
                                  unsigned &numDims) {
   SmallVector<OpAsmParser::OperandType, 8> opInfos;
   if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
@@ -76,7 +76,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
 //===----------------------------------------------------------------------===//
 
 void AffineApplyOp::build(Builder *builder, OperationState *result,
-                          AffineMap map, ArrayRef<SSAValue *> operands) {
+                          AffineMap map, ArrayRef<Value *> operands) {
   result->addOperands(operands);
   result->types.append(map.getNumResults(), builder->getIndexType());
   result->addAttribute("map", builder->getAffineMapAttr(map));
@@ -133,24 +133,22 @@ bool AffineApplyOp::verify() const {
 }
 
 // The result of the affine apply operation can be used as a dimension id if it
-// is a CFG value or if it is an MLValue, and all the operands are valid
+// is a CFG value or if it is an Value, and all the operands are valid
 // dimension ids.
 bool AffineApplyOp::isValidDim() const {
   for (auto *op : getOperands()) {
-    if (auto *v = dyn_cast<MLValue>(op))
-      if (!v->isValidDim())
-        return false;
+    if (!op->isValidDim())
+      return false;
   }
   return true;
 }
 
 // The result of the affine apply operation can be used as a symbol if it is
-// a CFG value or if it is an MLValue, and all the operands are symbols.
+// a CFG value or if it is an Value, and all the operands are symbols.
 bool AffineApplyOp::isValidSymbol() const {
   for (auto *op : getOperands()) {
-    if (auto *v = dyn_cast<MLValue>(op))
-      if (!v->isValidSymbol())
-        return false;
+    if (!op->isValidSymbol())
+      return false;
   }
   return true;
 }
@@ -170,13 +168,13 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
 //===----------------------------------------------------------------------===//
 
 void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest,
-                     ArrayRef<SSAValue *> operands) {
+                     ArrayRef<Value *> operands) {
   result->addSuccessor(dest, operands);
 }
 
 bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
   BasicBlock *dest;
-  SmallVector<SSAValue *, 4> destOperands;
+  SmallVector<Value *, 4> destOperands;
   if (parser->parseSuccessorAndUseList(dest, destOperands))
     return true;
   result->addSuccessor(dest, destOperands);
@@ -212,17 +210,16 @@ void BranchOp::eraseOperand(unsigned index) {
 //===----------------------------------------------------------------------===//
 
 void CondBranchOp::build(Builder *builder, OperationState *result,
-                         SSAValue *condition, BasicBlock *trueDest,
-                         ArrayRef<SSAValue *> trueOperands,
-                         BasicBlock *falseDest,
-                         ArrayRef<SSAValue *> falseOperands) {
+                         Value *condition, BasicBlock *trueDest,
+                         ArrayRef<Value *> trueOperands, BasicBlock *falseDest,
+                         ArrayRef<Value *> falseOperands) {
   result->addOperands(condition);
   result->addSuccessor(trueDest, trueOperands);
   result->addSuccessor(falseDest, falseOperands);
 }
 
 bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
-  SmallVector<SSAValue *, 4> destOperands;
+  SmallVector<Value *, 4> destOperands;
   BasicBlock *dest;
   OpAsmParser::OperandType condInfo;
@@ -446,7 +443,7 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
 //===----------------------------------------------------------------------===//
 
 void ReturnOp::build(Builder *builder, OperationState *result,
-                     ArrayRef<SSAValue *> results) {
+                     ArrayRef<Value *> results) {
   result->addOperands(results);
 }
@@ -465,9 +462,10 @@ void ReturnOp::print(OpAsmPrinter *p) const {
     *p << ' ';
     p->printOperands(operand_begin(), operand_end());
     *p << " : ";
-    interleave(operand_begin(), operand_end(),
-               [&](const SSAValue *e) { p->printType(e->getType()); },
-               [&]() { *p << ", "; });
+    interleave(
+        operand_begin(), operand_end(),
+        [&](const Value *e) { p->printType(e->getType()); },
+        [&]() { *p << ", "; });
   }
 }
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 3a537d03e8f..6f22b854fbf 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -23,7 +23,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Statements.h"
-
 using namespace mlir;
 
 /// Form the OperationName for an op with the specified string. This either is
@@ -96,13 +95,13 @@ unsigned Operation::getNumOperands() const {
   return llvm::cast<OperationStmt>(this)->getNumOperands();
 }
 
-SSAValue *Operation::getOperand(unsigned idx) {
+Value *Operation::getOperand(unsigned idx) {
   return llvm::cast<OperationStmt>(this)->getOperand(idx);
 }
 
-void Operation::setOperand(unsigned idx, SSAValue *value) {
+void Operation::setOperand(unsigned idx, Value *value) {
   auto *stmt = llvm::cast<OperationStmt>(this);
-  stmt->setOperand(idx, llvm::cast<MLValue>(value));
+  stmt->setOperand(idx, value);
 }
 
 /// Return the number of results this operation has.
@@ -111,7 +110,7 @@ unsigned Operation::getNumResults() const {
 }
 
 /// Return the indicated result.
-SSAValue *Operation::getResult(unsigned idx) {
+Value *Operation::getResult(unsigned idx) {
   return llvm::cast<OperationStmt>(this)->getResult(idx);
 }
@@ -585,8 +584,8 @@ bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
 // These functions are out-of-line implementations of the methods in BinaryOp,
 // which avoids them being template instantiated/duplicated.
 
-void impl::buildBinaryOp(Builder *builder, OperationState *result,
-                         SSAValue *lhs, SSAValue *rhs) {
+void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
+                         Value *rhs) {
   assert(lhs->getType() == rhs->getType());
   result->addOperands({lhs, rhs});
   result->types.push_back(lhs->getType());
@@ -613,8 +612,8 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
 // CastOp implementation
 //===----------------------------------------------------------------------===//
 
-void impl::buildCastOp(Builder *builder, OperationState *result,
-                       SSAValue *source, Type destType) {
+void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
+                       Type destType) {
   result->addOperands(source);
   result->addTypes(destType);
 }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index a8b6aa1e738..9e4d8bb180c 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -16,8 +16,8 @@
 // =============================================================================
 
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/SSAValue.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/Value.h"
 using namespace mlir;
 
 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
@@ -77,8 +77,8 @@ PatternRewriter::~PatternRewriter() {
 /// clients can specify a list of other nodes that this replacement may make
 /// (perhaps transitively) dead. If any of those ops are dead, this will
 /// remove them as well.
-void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
-                                ArrayRef<SSAValue *> valuesToRemoveIfDead) {
+void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
+                                ArrayRef<Value *> valuesToRemoveIfDead) {
   // Notify the rewriter subclass that we're about to replace this root.
   notifyRootReplaced(op);
@@ -97,15 +97,14 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
 /// op and newOp are known to have the same number of results, replace the
 /// uses of op with uses of newOp
 void PatternRewriter::replaceOpWithResultsOfAnotherOp(
-    Operation *op, Operation *newOp,
-    ArrayRef<SSAValue *> valuesToRemoveIfDead) {
+    Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
   assert(op->getNumResults() == newOp->getNumResults() &&
          "replacement op doesn't match results of original op");
   if (op->getNumResults() == 1)
     return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
-  SmallVector<SSAValue *, 8> newResults(newOp->getResults().begin(),
-                                        newOp->getResults().end());
+  SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
+                                     newOp->getResults().end());
   return replaceOp(op, newResults, valuesToRemoveIfDead);
 }
@@ -118,7 +117,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(
 /// should remove if they are dead at this point.
 ///
 void PatternRewriter::updatedRootInPlace(
-    Operation *op, ArrayRef<SSAValue *> valuesToRemoveIfDead) {
+    Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
   // Notify the rewriter subclass that we're about to replace this root.
   notifyRootUpdated(op);
diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp
index 9a26149ea1d..09825093fde 100644
--- a/mlir/lib/IR/SSAValue.cpp
+++ b/mlir/lib/IR/SSAValue.cpp
@@ -1,4 +1,4 @@
-//===- SSAValue.cpp - MLIR SSAValue Classes ------------===//
+//===- SSAValue.cpp - MLIR ValueClasses ------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,15 +15,15 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/SSAValue.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/Value.h"
 using namespace mlir;
 
 /// If this value is the result of an Instruction, return the instruction
 /// that defines it.
-OperationInst *SSAValue::getDefiningInst() {
+OperationInst *Value::getDefiningInst() {
   if (auto *result = dyn_cast<InstResult>(this))
     return result->getOwner();
   return nullptr;
@@ -31,13 +31,13 @@
 /// If this value is the result of an OperationStmt, return the statement
 /// that defines it.
-OperationStmt *SSAValue::getDefiningStmt() {
+OperationStmt *Value::getDefiningStmt() {
   if (auto *result = dyn_cast<StmtResult>(this))
     return result->getOwner();
   return nullptr;
 }
 
-Operation *SSAValue::getDefiningOperation() {
+Operation *Value::getDefiningOperation() {
   if (auto *inst = getDefiningInst())
     return inst;
   if (auto *stmt = getDefiningStmt())
@@ -45,14 +45,14 @@
   return nullptr;
 }
 
-/// Return the function that this SSAValue is defined in.
-Function *SSAValue::getFunction() {
+/// Return the function that this Valueis defined in.
+Function *Value::getFunction() {
   switch (getKind()) {
-  case SSAValueKind::BlockArgument:
+  case Value::Kind::BlockArgument:
     return cast<BlockArgument>(this)->getFunction();
-  case SSAValueKind::StmtResult:
+  case Value::Kind::StmtResult:
     return getDefiningStmt()->getFunction();
-  case SSAValueKind::ForStmt:
+  case Value::Kind::ForStmt:
     return cast<ForStmt>(this)->getFunction();
   }
 }
@@ -90,15 +90,6 @@ MLIRContext *IROperandOwner::getContext() const {
 }
 
 //===----------------------------------------------------------------------===//
-// MLValue implementation.
-//===----------------------------------------------------------------------===//
-
-/// Return the function that this MLValue is defined in.
-MLFunction *MLValue::getFunction() {
-  return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
-}
-
-//===----------------------------------------------------------------------===//
 // BlockArgument implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index 2a47eb56a28..63c2b26425f 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -85,18 +85,16 @@ MLFunction *Statement::getFunction() const {
   return block ? block->getFunction() : nullptr;
 }
 
-MLValue *Statement::getOperand(unsigned idx) {
-  return getStmtOperand(idx).get();
-}
+Value *Statement::getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
 
-const MLValue *Statement::getOperand(unsigned idx) const {
+const Value *Statement::getOperand(unsigned idx) const {
   return getStmtOperand(idx).get();
 }
 
-// MLValue can be used as a dimension id if it is valid as a symbol, or
+// Value can be used as a dimension id if it is valid as a symbol, or
 // it is an induction variable, or it is a result of affine apply operation
 // with dimension id arguments.
-bool MLValue::isValidDim() const {
+bool Value::isValidDim() const {
   if (auto *stmt = getDefiningStmt()) {
     // Top level statement or constant operation is ok.
     if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@@ -111,10 +109,10 @@
   return true;
 }
 
-// MLValue can be used as a symbol if it is a constant, or it is defined at
+// Value can be used as a symbol if it is a constant, or it is defined at
 // the top level, or it is a result of affine apply operation with symbol
 // arguments.
-bool MLValue::isValidSymbol() const {
+bool Value::isValidSymbol() const {
   if (auto *stmt = getDefiningStmt()) {
     // Top level statement or constant operation is ok.
     if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@@ -129,7 +127,7 @@
   return isa<BlockArgument>(this);
 }
 
-void Statement::setOperand(unsigned idx, MLValue *value) {
+void Statement::setOperand(unsigned idx, Value *value) {
   getStmtOperand(idx).set(value);
 }
@@ -271,7 +269,7 @@ void Statement::dropAllReferences() {
 /// Create a new OperationStmt with the specific fields.
 OperationStmt *OperationStmt::create(Location location, OperationName name,
-                                     ArrayRef<MLValue *> operands,
+                                     ArrayRef<Value *> operands,
                                      ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
                                      ArrayRef<StmtBlock *> successors,
@@ -420,8 +418,8 @@ void OperationInst::eraseOperand(unsigned index) {
 // ForStmt
 //===----------------------------------------------------------------------===//
 
-ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
-                         AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
+ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
+                         AffineMap lbMap, ArrayRef<Value *> ubOperands,
                          AffineMap ubMap, int64_t step) {
   assert(lbOperands.size() == lbMap.getNumInputs() &&
          "lower bound operand count does not match the affine map");
@@ -444,9 +442,9 @@ ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
 
 ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
                  AffineMap ubMap, int64_t step)
-    : Statement(Kind::For, location),
-      MLValue(MLValueKind::ForStmt,
-              Type::getIndex(lbMap.getResult(0).getContext())),
+    : Statement(Statement::Kind::For, location),
+      Value(Value::Kind::ForStmt,
+            Type::getIndex(lbMap.getResult(0).getContext())),
       body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
   // The body of a for stmt always has one block.
@@ -462,11 +460,11 @@ const AffineBound ForStmt::getUpperBound() const {
   return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
 }
 
-void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
+void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
   assert(lbOperands.size() == map.getNumInputs());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");
 
-  SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
+  SmallVector<Value *, 4> ubOperands(getUpperBoundOperands());
 
   operands.clear();
   operands.reserve(lbOperands.size() + ubMap.getNumInputs());
@@ -479,11 +477,11 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
   this->lbMap = map;
 }
 
-void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
+void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
   assert(ubOperands.size() == map.getNumInputs());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");
 
-  SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
+  SmallVector<Value *, 4> lbOperands(getLowerBoundOperands());
 
   operands.clear();
   operands.reserve(lbOperands.size() + ubOperands.size());
@@ -553,7 +551,7 @@ bool ForStmt::matchingBoundOperandList() const {
   unsigned numOperands = lbMap.getNumInputs();
   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
-    // Compare MLValue *'s.
+    // Compare Value *'s.
     if (getOperand(i) != getOperand(numOperands + i))
       return false;
   }
@@ -581,7 +579,7 @@ IfStmt::~IfStmt() {
   // allocated through MLIRContext's bump pointer allocator.
 }
 
-IfStmt *IfStmt::create(Location location, ArrayRef<MLValue *> operands,
+IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
                        IntegerSet set) {
   unsigned numOperands = operands.size();
   assert(numOperands == set.getNumOperands() &&
@@ -617,16 +615,16 @@ MLIRContext *IfStmt::getContext() const {
 /// them alone if no entry is present). Replaces references to cloned
 /// sub-statements to the corresponding statement that is copied, and adds
 /// those mappings to the map.
-Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
+Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
                             MLIRContext *context) const {
   // If the specified value is in operandMap, return the remapped value.
   // Otherwise return the value itself.
-  auto remapOperand = [&](const MLValue *value) -> MLValue * {
+  auto remapOperand = [&](const Value *value) -> Value * {
     auto it = operandMap.find(value);
-    return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
+    return it != operandMap.end() ? it->second : const_cast<Value *>(value);
   };
 
-  SmallVector<MLValue *, 8> operands;
+  SmallVector<Value *, 8> operands;
   SmallVector<StmtBlock *, 2> successors;
   if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
     operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
@@ -683,10 +681,9 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
     auto ubMap = forStmt->getUpperBoundMap();
 
     auto *newFor = ForStmt::create(
-        getLoc(),
-        ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
-        ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
-        forStmt->getStep());
+        getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
+        lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
+        ubMap, forStmt->getStep());
 
     // Remember the induction variable mapping.
     operandMap[forStmt] = newFor;
@@ -716,6 +713,6 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
 }
 
 Statement *Statement::clone(MLIRContext *context) const {
-  DenseMap<const MLValue *, MLValue *> operandMap;
+  DenseMap<const Value *, Value *> operandMap;
   return clone(operandMap, context);
 }
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d58d687ee0c..9852b69e91b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -42,7 +42,6 @@
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
 #include <algorithm>
-
 using namespace mlir;
 using llvm::MemoryBuffer;
 using llvm::SMLoc;
@@ -1890,10 +1889,10 @@ public:
 
   /// Given a reference to an SSA value and its type, return a reference. This
   /// returns null on failure.
-  SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
+  Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
 
   /// Register a definition of a value with the symbol table.
-  ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
+  ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
 
   // SSA parsing productions.
   ParseResult parseSSAUse(SSAUseInfo &result);
@@ -1903,9 +1902,9 @@ public:
   ResultType parseSSADefOrUseAndType(
       const std::function<ResultType(SSAUseInfo, Type)> &action);
 
-  SSAValue *parseSSAUseAndType() {
-    return parseSSADefOrUseAndType<SSAValue *>(
-        [&](SSAUseInfo useInfo, Type type) -> SSAValue * {
+  Value *parseSSAUseAndType() {
+    return parseSSADefOrUseAndType<Value *>(
+        [&](SSAUseInfo useInfo, Type type) -> Value * {
           return resolveSSAUse(useInfo, type);
         });
   }
@@ -1920,9 +1919,8 @@ public:
   Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
 
   /// Parse a single operation successor and it's operand list.
-  virtual bool
-  parseSuccessorAndUseList(BasicBlock *&dest,
-                           SmallVectorImpl<SSAValue *> &operands) = 0;
+  virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
+                                        SmallVectorImpl<Value *> &operands) = 0;
 
 protected:
   FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
 
 private:
   Kind kind;
 
   /// This keeps track of all of the SSA values we are tracking, indexed by
   /// their name. This has one entry per result number.
-  llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
+  llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
 
   /// These are all of the placeholders we've made along with the location of
   /// their first reference, to allow checking for use of undefined values.
-  DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
+  DenseMap<Value *, SMLoc> forwardReferencePlaceholders;
 
-  SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
+  Value *createForwardReferencePlaceholder(SMLoc loc, Type type);
 
   /// Return true if this is a forward reference.
-  bool isForwardReferencePlaceholder(SSAValue *value) {
+  bool isForwardReferencePlaceholder(Value *value) {
     return forwardReferencePlaceholders.count(value);
   }
 };
 } // end anonymous namespace
 
 /// Create and remember a new placeholder for a forward reference.
-SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
-                                                            Type type) {
+Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) {
   // Forward references are always created as instructions, even in ML
   // functions, because we just need something with a def/use chain.
   //
@@ -1969,7 +1966,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
 
 /// Given an unbound reference to an SSA value and its type, return the value
 /// it specifies. This returns null on failure.
-SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
+Value *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
   auto &entries = values[useInfo.name];
 
   // If we have already seen a value of this name, return it.
@@ -2010,7 +2007,7 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
 }
 
 /// Register a definition of a value with the symbol table.
-ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
+ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) {
   auto &entries = values[useInfo.name];
 
   // Make sure there is a slot for this value.
@@ -2046,7 +2043,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
   // Check for any forward references that are left. If we find any, error
   // out.
   if (!forwardReferencePlaceholders.empty()) {
-    SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
+    SmallVector<std::pair<const char *, Value *>, 4> errors;
     // Iteration over the map isn't deterministic, so sort by source location.
     for (auto entry : forwardReferencePlaceholders)
       errors.push_back({entry.second.getPointer(), entry.first});
@@ -2399,9 +2396,8 @@ public:
     return false;
   }
 
-  bool
-  parseSuccessorAndUseList(BasicBlock *&dest,
-                           SmallVectorImpl<SSAValue *> &operands) override {
+  bool parseSuccessorAndUseList(BasicBlock *&dest,
+                                SmallVectorImpl<Value *> &operands) override {
    // Defer successor parsing to the function parsers.
    return parser.parseSuccessorAndUseList(dest, operands);
   }
@@ -2493,7 +2489,7 @@ public:
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
   bool resolveOperand(const OperandType &operand, Type type,
-                      SmallVectorImpl<SSAValue *> &result) override {
+                      SmallVectorImpl<Value *> &result) override {
     FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                               operand.location};
     if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
@@ -2573,7 +2569,7 @@ public:
   ParseResult parseFunctionBody();
 
   bool parseSuccessorAndUseList(BasicBlock *&dest,
-                                SmallVectorImpl<SSAValue *> &operands);
+                                SmallVectorImpl<Value *> &operands);
 
 private:
   CFGFunction *function;
@@ -2636,7 +2632,7 @@ private:
 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
 ///
 bool CFGFunctionParser::parseSuccessorAndUseList(
-    BasicBlock *&dest, SmallVectorImpl<SSAValue *> &operands) {
+    BasicBlock *&dest, SmallVectorImpl<Value *> &operands) {
   // Verify branch is identifier and get the matching block.
   if (!getToken().is(Token::bare_identifier))
     return emitError("expected basic block name");
@@ -2790,10 +2786,10 @@ private:
   ParseResult parseForStmt();
   ParseResult parseIntConstant(int64_t &val);
-  ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
+  ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
                                     unsigned numDims, unsigned numOperands,
                                     const char *affineStructName);
-  ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
+  ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
                          bool isLower);
   ParseResult parseIfStmt();
   ParseResult parseElseClause(StmtBlock *elseClause);
@@ -2801,7 +2797,7 @@ private:
   ParseResult parseStmtBlock(StmtBlock *block);
 
   bool parseSuccessorAndUseList(BasicBlock *&dest,
-                                SmallVectorImpl<SSAValue *> &operands) {
+                                SmallVectorImpl<Value *> &operands) {
     assert(false && "MLFunctions do not have terminators with successors.");
     return true;
   }
@@ -2838,7 +2834,7 @@ ParseResult MLFunctionParser::parseForStmt() {
     return ParseFailure;
 
   // Parse lower bound.
-  SmallVector<MLValue *, 4> lbOperands;
+  SmallVector<Value *, 4> lbOperands;
   AffineMap lbMap;
   if (parseBound(lbOperands, lbMap, /*isLower*/ true))
     return ParseFailure;
@@ -2847,7 +2843,7 @@
     return ParseFailure;
 
   // Parse upper bound.
-  SmallVector<MLValue *, 4> ubOperands;
+  SmallVector<Value *, 4> ubOperands;
   AffineMap ubMap;
   if (parseBound(ubOperands, ubMap, /*isLower*/ false))
     return ParseFailure;
@@ -2913,7 +2909,7 @@ ParseResult MLFunctionParser::parseIntConstant(int64_t &val) {
 /// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
 ///
 ParseResult
-MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
+MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
                                         unsigned numDims, unsigned numOperands,
                                         const char *affineStructName) {
   if (parseToken(Token::l_paren, "expected '('"))
@@ -2942,18 +2938,17 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
   // Resolve SSA uses.
   Type indexType = builder.getIndexType();
   for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
-    SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
+    Value *sval = resolveSSAUse(opInfo[i], indexType);
     if (!sval)
       return ParseFailure;
 
-    auto *v = cast<MLValue>(sval);
-    if (i < numDims && !v->isValidDim())
+    if (i < numDims && !sval->isValidDim())
       return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
                                           "' cannot be used as a dimension id");
-    if (i >= numDims && !v->isValidSymbol())
+    if (i >= numDims && !sval->isValidSymbol())
      return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
                                          "' cannot be used as a symbol");
-    operands.push_back(v);
+    operands.push_back(sval);
   }
 
   return ParseSuccess;
@@ -2965,7 +2960,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
 ///                   shorthand-bound
 ///   upper-bound ::= `min`? affine-map dim-and-symbol-use-list
 ///                 | shorthand-bound
 ///   shorthand-bound ::= ssa-id | `-`? integer-literal
 ///
-ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
+ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
                                          AffineMap &map, bool isLower) {
   // 'min' / 'max' prefixes are syntactic sugar. Ignore them.
   if (isLower)
@@ -3003,7 +2998,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
   // TODO: improve error message when SSA value is not an affine integer.
   // Currently it is 'use of value ...
```
expects different type than prior uses' if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) - operands.push_back(cast<MLValue>(value)); + operands.push_back(value); else return ParseFailure; @@ -3113,7 +3108,7 @@ ParseResult MLFunctionParser::parseIfStmt() { if (!set) return ParseFailure; - SmallVector<MLValue *, 4> operands; + SmallVector<Value *, 4> operands; if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), "integer set")) return ParseFailure; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 9613c56daf0..7611c6e741b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -23,8 +23,8 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -78,8 +78,8 @@ struct MemRefCastFolder : public RewritePattern { // AddFOp //===----------------------------------------------------------------------===// -void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs, - SSAValue *rhs) { +void AddFOp::build(Builder *builder, OperationState *result, Value *lhs, + Value *rhs) { assert(lhs->getType() == rhs->getType()); result->addOperands({lhs, rhs}); result->types.push_back(lhs->getType()); @@ -146,7 +146,7 @@ void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void AllocOp::build(Builder *builder, OperationState *result, - MemRefType memrefType, ArrayRef<SSAValue *> operands) { + MemRefType memrefType, ArrayRef<Value *> operands) { result->addOperands(operands); result->types.push_back(memrefType); } @@ -247,8 +247,8 @@ struct SimplifyAllocConst : public RewritePattern { // and keep track of the resultant memref type to build. 
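// Illustrative aside, not taken from this patch: the net effect of
// SimplifyAllocConst on the IR, as a hypothetical before/after (%n assumed):
//   %c4 = constant 4 : index
//   %m  = alloc(%c4, %n) : memref<?x?xf32>
// becomes
//   %m  = alloc(%n) : memref<4x?xf32>
// i.e. constant dynamic-dimension operands are folded into the static shape.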
SmallVector<int, 4> newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); - SmallVector<SSAValue *, 4> newOperands; - SmallVector<SSAValue *, 4> droppedOperands; + SmallVector<Value *, 4> newOperands; + SmallVector<Value *, 4> droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { @@ -301,7 +301,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void CallOp::build(Builder *builder, OperationState *result, Function *callee, - ArrayRef<SSAValue *> operands) { + ArrayRef<Value *> operands) { result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); result->addTypes(callee->getType().getResults()); @@ -370,7 +370,7 @@ bool CallOp::verify() const { //===----------------------------------------------------------------------===// void CallIndirectOp::build(Builder *builder, OperationState *result, - SSAValue *callee, ArrayRef<SSAValue *> operands) { + Value *callee, ArrayRef<Value *> operands) { auto fnType = callee->getType().cast<FunctionType>(); result->operands.push_back(callee); result->addOperands(operands); @@ -507,7 +507,7 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { } void CmpIOp::build(Builder *build, OperationState *result, - CmpIPredicate predicate, SSAValue *lhs, SSAValue *rhs) { + CmpIPredicate predicate, Value *lhs, Value *rhs) { result->addOperands({lhs, rhs}); result->types.push_back(getI1SameShape(build, lhs->getType())); result->addAttribute(getPredicateAttrName(), @@ -580,8 +580,7 @@ bool CmpIOp::verify() const { // DeallocOp //===----------------------------------------------------------------------===// -void DeallocOp::build(Builder *builder, OperationState *result, - SSAValue *memref) { +void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) { result->addOperands(memref); } @@ -615,7 +614,7 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void DimOp::build(Builder *builder, OperationState *result, - SSAValue *memrefOrTensor, unsigned index) { + Value *memrefOrTensor, unsigned index) { result->addOperands(memrefOrTensor); auto type = builder->getIndexType(); result->addAttribute("index", builder->getIntegerAttr(type, index)); @@ -689,11 +688,11 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands, // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState *result, - SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices, - SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices, - SSAValue *numElements, SSAValue *tagMemRef, - ArrayRef<SSAValue *> tagIndices, SSAValue *stride, - SSAValue *elementsPerStride) { + Value *srcMemRef, ArrayRef<Value *> srcIndices, + Value *destMemRef, ArrayRef<Value *> destIndices, + Value *numElements, Value *tagMemRef, + ArrayRef<Value *> tagIndices, Value *stride, + Value *elementsPerStride) { result->addOperands(srcMemRef); result->addOperands(srcIndices); result->addOperands(destMemRef); @@ -836,8 +835,8 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // --------------------------------------------------------------------------- void DmaWaitOp::build(Builder *builder, OperationState *result, - SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices, - SSAValue 
*numElements) { + Value *tagMemRef, ArrayRef<Value *> tagIndices, + Value *numElements) { result->addOperands(tagMemRef); result->addOperands(tagIndices); result->addOperands(numElements); @@ -896,8 +895,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void ExtractElementOp::build(Builder *builder, OperationState *result, - SSAValue *aggregate, - ArrayRef<SSAValue *> indices) { + Value *aggregate, ArrayRef<Value *> indices) { auto aggregateType = aggregate->getType().cast<VectorOrTensorType>(); result->addOperands(aggregate); result->addOperands(indices); @@ -955,8 +953,8 @@ bool ExtractElementOp::verify() const { // LoadOp //===----------------------------------------------------------------------===// -void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, - ArrayRef<SSAValue *> indices) { +void LoadOp::build(Builder *builder, OperationState *result, Value *memref, + ArrayRef<Value *> indices) { auto memrefType = memref->getType().cast<MemRefType>(); result->addOperands(memref); result->addOperands(indices); @@ -1130,9 +1128,8 @@ void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // SelectOp //===----------------------------------------------------------------------===// -void SelectOp::build(Builder *builder, OperationState *result, - SSAValue *condition, SSAValue *trueValue, - SSAValue *falseValue) { +void SelectOp::build(Builder *builder, OperationState *result, Value *condition, + Value *trueValue, Value *falseValue) { result->addOperands({condition, trueValue, falseValue}); result->addTypes(trueValue->getType()); } @@ -1201,8 +1198,8 @@ Attribute SelectOp::constantFold(ArrayRef<Attribute> operands, //===----------------------------------------------------------------------===// void StoreOp::build(Builder *builder, OperationState *result, - SSAValue *valueToStore, SSAValue *memref, - ArrayRef<SSAValue *> indices) { + Value *valueToStore, Value *memref, + ArrayRef<Value *> indices) { result->addOperands(valueToStore); result->addOperands(memref); result->addOperands(indices); diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 3b9f5ed1b3a..02b4c4674ab 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -72,10 +72,10 @@ static bool verifyPermutationMap(AffineMap permutationMap, } void VectorTransferReadOp::build(Builder *builder, OperationState *result, - VectorType vectorType, SSAValue *srcMemRef, - ArrayRef<SSAValue *> srcIndices, + VectorType vectorType, Value *srcMemRef, + ArrayRef<Value *> srcIndices, AffineMap permutationMap, - Optional<SSAValue *> paddingValue) { + Optional<Value *> paddingValue) { result->addOperands(srcMemRef); result->addOperands(srcIndices); if (paddingValue) { @@ -100,21 +100,20 @@ VectorTransferReadOp::getIndices() const { return {begin, end}; } -Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() { +Optional<Value *> VectorTransferReadOp::getPaddingValue() { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { return None; } - return Optional<SSAValue *>( - getOperand(Offsets::FirstIndexOffset + memRefRank)); + return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank)); } -Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const { +Optional<const Value *> 
VectorTransferReadOp::getPaddingValue() const { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { return None; } - return Optional<const SSAValue *>( + return Optional<const Value *>( getOperand(Offsets::FirstIndexOffset + memRefRank)); } @@ -136,7 +135,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) const { // Construct the FunctionType and print it. llvm::SmallVector<Type, 8> inputs{getMemRefType()}; // Must have at least one actual index, see verify. - const SSAValue *firstIndex = *(getIndices().begin()); + const Value *firstIndex = *(getIndices().begin()); Type indexType = firstIndex->getType(); inputs.append(getMemRefType().getRank(), indexType); if (optionalPaddingValue) { @@ -295,8 +294,8 @@ bool VectorTransferReadOp::verify() const { // VectorTransferWriteOp //===----------------------------------------------------------------------===// void VectorTransferWriteOp::build(Builder *builder, OperationState *result, - SSAValue *srcVector, SSAValue *dstMemRef, - ArrayRef<SSAValue *> dstIndices, + Value *srcVector, Value *dstMemRef, + ArrayRef<Value *> dstIndices, AffineMap permutationMap) { result->addOperands({srcVector, dstMemRef}); result->addOperands(dstIndices); @@ -457,7 +456,7 @@ bool VectorTransferWriteOp::verify() const { // VectorTypeCastOp //===----------------------------------------------------------------------===// void VectorTypeCastOp::build(Builder *builder, OperationState *result, - SSAValue *srcVector, Type dstType) { + Value *srcVector, Type dstType) { result->addOperands(srcVector); result->addTypes(dstType); } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 5c325dbd95d..a4d474dc24a 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -111,7 +111,7 @@ private: /// descriptor and get the pointer to the element indexed by the linearized /// subscript. Return nullptr on errors. llvm::Value *emitMemRefElementAccess( - const SSAValue *memRef, const Operation &op, + const Value *memRef, const Operation &op, llvm::iterator_range<Operation::const_operand_iterator> opIndices); /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create @@ -136,12 +136,12 @@ private: /// Create a single LLVM value of struct type that includes the list of /// given MLIR values. The `values` list must contain at least 2 elements. - llvm::Value *packValues(ArrayRef<const SSAValue *> values); + llvm::Value *packValues(ArrayRef<const Value *> values); /// Extract a list of `num` LLVM values from a `value` of struct type. 
SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num); llvm::DenseMap<const Function *, llvm::Function *> functionMapping; - llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping; + llvm::DenseMap<const Value *, llvm::Value *> valueMapping; llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping; llvm::LLVMContext &llvmContext; llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder; @@ -316,7 +316,7 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { } llvm::Value *ModuleLowerer::emitMemRefElementAccess( - const SSAValue *memRef, const Operation &op, + const Value *memRef, const Operation &op, llvm::iterator_range<Operation::const_operand_iterator> opIndices) { auto type = memRef->getType().dyn_cast<MemRefType>(); assert(type && "expected memRef value to have a MemRef type"); @@ -340,7 +340,7 @@ llvm::Value *ModuleLowerer::emitMemRefElementAccess( // Obtain the list of access subscripts as values and linearize it given the // list of sizes. auto indices = functional::map( - [this](const SSAValue *value) { return valueMapping.lookup(value); }, + [this](const Value *value) { return valueMapping.lookup(value); }, opIndices); auto subscript = linearizeSubscripts(indices, sizes); @@ -460,11 +460,11 @@ llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) { } // Create an undef struct value and insert individual values into it. -llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) { +llvm::Value *ModuleLowerer::packValues(ArrayRef<const Value *> values) { assert(values.size() > 1 && "cannot pack less than 2 values"); auto types = - functional::map([](const SSAValue *v) { return v->getType(); }, values); + functional::map([](const Value *v) { return v->getType(); }, values); llvm::Type *packedType = getPackedResultType(types); llvm::Value *packed = llvm::UndefValue::get(packedType); @@ -641,7 +641,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { return false; } if (auto dimOp = inst.dyn_cast<DimOp>()) { - const SSAValue *container = dimOp->getOperand(); + const Value *container = dimOp->getOperand(); MemRefType type = container->getType().dyn_cast<MemRefType>(); if (!type) return dimOp->emitError("only memref types are supported"); @@ -672,7 +672,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { if (auto callOp = inst.dyn_cast<CallOp>()) { auto operands = functional::map( - [this](const SSAValue *value) { return valueMapping.lookup(value); }, + [this](const Value *value) { return valueMapping.lookup(value); }, callOp->getOperands()); auto numResults = callOp->getNumResults(); llvm::Value *result = @@ -779,10 +779,9 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. 
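// Illustrative sketch of how the two helpers pair up when lowering a
// multi-result operation (hypothetical mlir Values `r0`/`r1` that already
// have entries in valueMapping; not code from this commit):
llvm::Value *packed = packValues({r0, r1});     // one struct via insertvalue
auto scalars = unpackValues(packed, /*num=*/2); // back to scalars via extractvalue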
-static const SSAValue *getPHISourceValue(const BasicBlock *current, - const BasicBlock *pred, - unsigned numArguments, - unsigned index) { +static const Value *getPHISourceValue(const BasicBlock *current, + const BasicBlock *pred, + unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa<BranchOp>()) { return terminator.getOperand(index); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index d4a50a05989..53e633f53cd 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -30,13 +30,12 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. - SmallVector<SSAValue *, 8> existingConstants; + SmallVector<Value *, 8> existingConstants; // Operation statements that were folded and that need to be erased. std::vector<OperationStmt *> opStmtsToErase; - using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>; + using ConstantFactoryType = std::function<Value *(Attribute, Type)>; - bool foldOperation(Operation *op, - SmallVectorImpl<SSAValue *> &existingConstants, + bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants, ConstantFactoryType constantFactory); void visitOperationStmt(OperationStmt *stmt); void visitForStmt(ForStmt *stmt); @@ -54,9 +53,8 @@ char ConstantFold::passID = 0; /// /// This returns false if the operation was successfully folded. bool ConstantFold::foldOperation(Operation *op, - SmallVectorImpl<SSAValue *> &existingConstants, + SmallVectorImpl<Value *> &existingConstants, ConstantFactoryType constantFactory) { - // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. if (auto constant = op->dyn_cast<ConstantOp>()) { @@ -114,7 +112,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { if (!inst) continue; - auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> Value * { builder.setInsertionPoint(inst); return builder.create<ConstantOp>(inst->getLoc(), value, type); }; @@ -142,7 +140,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { // Override the walker's operation statement visit for constant folding. void ConstantFold::visitOperationStmt(OperationStmt *stmt) { - auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> Value * { MLFuncBuilder builder(stmt); return builder.create<ConstantOp>(stmt->getLoc(), value, type); }; diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 4423891a4bf..ab8ee28ba7c 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -50,28 +50,28 @@ public: void visitOperationStmt(OperationStmt *opStmt); private: - CFGValue *getConstantIndexValue(int64_t value); + Value *getConstantIndexValue(int64_t value); void visitStmtBlock(StmtBlock *stmtBlock); - CFGValue *buildMinMaxReductionSeq( + Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range<Operation::result_iterator> values); CFGFunction *cfgFunc; CFGFuncBuilder builder; - // Mapping between original MLValues and lowered CFGValues. - llvm::DenseMap<const MLValue *, CFGValue *> valueRemapping; + // Mapping between original Values and lowered Values. 
+ llvm::DenseMap<const Value *, Value *> valueRemapping; }; } // end anonymous namespace -// Return a vector of OperationStmt's arguments as SSAValues. For each -// statement operand, represented as MLValue, look up its CFGValue counterpart in +// Return a vector of OperationStmt's arguments as Values. For each +// statement operand, represented as Value, look up its Value counterpart in // the valueRemapping table. -static llvm::SmallVector<SSAValue *, 4> +static llvm::SmallVector<mlir::Value *, 4> operandsAs(Statement *opStmt, - const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) { - llvm::SmallVector<SSAValue *, 4> operands; - for (const MLValue *operand : opStmt->getOperands()) { + const llvm::DenseMap<const Value *, Value *> &valueRemapping) { + llvm::SmallVector<Value *, 4> operands; + for (const Value *operand : opStmt->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } @@ -81,8 +81,8 @@ operandsAs(Statement *opStmt, // Convert an operation statement into an operation instruction. // // The operation description (name, number and types of operands or results) -// remains the same but the values must be updated to be CFGValues. Update the -// mapping MLValue->CFGValue as the conversion is performed. The operation +// remains the same but the values must be updated to be Values. Update the +// mapping Value->Value as the conversion is performed. The operation // instruction is appended to the current block (end of SESE region). void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { // Set up basic operation state (context, name, operands). @@ -90,11 +90,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { opStmt->getName()); state.addOperands(operandsAs(opStmt, valueRemapping)); - // Set up operation return types. The corresponding SSAValues will become + // Set up operation return types. The corresponding Values will become // available after the operation is created. - state.addTypes( - functional::map([](SSAValue *result) { return result->getType(); }, - opStmt->getResults())); + state.addTypes(functional::map( + [](Value *result) { return result->getType(); }, opStmt->getResults())); // Copy attributes. for (auto attr : opStmt->getAttrs()) { @@ -112,10 +111,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { } } -// Create a CFGValue for the given integer constant of index type. -CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) { +// Create a Value for the given integer constant of index type. +Value *FunctionConverter::getConstantIndexValue(int64_t value) { auto op = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value); - return cast<CFGValue>(op->getResult()); + return op->getResult(); } // Visit all statements in the given statement block. @@ -135,18 +134,18 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { // Multiple values are scanned in a linear sequence. This creates data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes.
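// Illustrative aside (hypothetical SSA names, approximate syntax): for three
// values and an SGT predicate, the emitted chain looks like
//   %c0 = cmpi "sgt", %v0, %v1 : index
//   %m0 = select %c0, %v0, %v1 : index
//   %c1 = cmpi "sgt", %m0, %v2 : index
//   %m1 = select %c1, %m0, %v2 : index   // max of all three values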
-CFGValue *FunctionConverter::buildMinMaxReductionSeq( +Value *FunctionConverter::buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range<Operation::result_iterator> values) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); - CFGValue *value = cast<CFGValue>(*valueIt++); + Value *value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt); auto selectOp = builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt); - value = cast<CFGValue>(selectOp->getResult()); + value = selectOp->getResult(); } return value; @@ -231,9 +230,9 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // The loop condition block has an argument for the loop induction variable. // Create it upfront and make the loop induction variable -> basic block // argument remapping available to the following instructions. ForStatement - // is-a MLValue corresponding to the loop induction variable. + // is-a Value corresponding to the loop induction variable. builder.setInsertionPoint(loopConditionBlock); - CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType()); + Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); valueRemapping.insert(std::make_pair(forStmt, iv)); // Recursively construct the loop body region. @@ -251,7 +250,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv); - CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult(0)); + Value *nextIvValue = stepOp->getResult(0); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); @@ -260,20 +259,19 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPoint(loopInitBlock); // Compute loop bounds using affine_apply after remapping its operands.
- auto remapOperands = [this](const SSAValue *value) -> SSAValue * { - const MLValue *mlValue = dyn_cast<MLValue>(value); - return valueRemapping.lookup(mlValue); + auto remapOperands = [this](const Value *value) -> Value * { + return valueRemapping.lookup(value); }; auto operands = functional::map(remapOperands, forStmt->getLowerBoundOperands()); auto lbAffineApply = builder.create<AffineApplyOp>( forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); - CFGValue *lowerBound = buildMinMaxReductionSeq( + Value *lowerBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); auto ubAffineApply = builder.create<AffineApplyOp>( forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); - CFGValue *upperBound = buildMinMaxReductionSeq( + Value *upperBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, lowerBound); @@ -281,10 +279,10 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPoint(loopConditionBlock); auto comparisonOp = builder.create<CmpIOp>( forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); - auto comparisonResult = cast<CFGValue>(comparisonOp->getResult()); + auto comparisonResult = comparisonOp->getResult(); builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult, - loopBodyFirstBlock, ArrayRef<SSAValue *>(), - postLoopBlock, ArrayRef<SSAValue *>()); + loopBodyFirstBlock, ArrayRef<Value *>(), + postLoopBlock, ArrayRef<Value *>()); // Finally, make sure building can continue by setting the post-loop block // (end of loop SESE region) as the insertion point. @@ -401,7 +399,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // If the test succeeds, jump to the next block testing the next // conjunct of the condition in a similar way. When all conjuncts have been // handled, jump to the 'then' block instead. - SSAValue *zeroConstant = getConstantIndexValue(0); + Value *zeroConstant = getConstantIndexValue(0); ifConditionExtraBlocks.push_back(thenBlock); for (auto tuple : llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(), @@ -416,16 +414,16 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create<AffineApplyOp>( ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); - SSAValue *affResult = affineApplyOp->getResult(0); + Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create<CmpIOp>( ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(), - nextBlock, /*trueArgs*/ ArrayRef<SSAValue *>(), + nextBlock, /*trueArgs*/ ArrayRef<Value *>(), elseBlock, - /*falseArgs*/ ArrayRef<SSAValue *>()); + /*falseArgs*/ ArrayRef<Value *>()); builder.setInsertionPoint(nextBlock); } ifConditionExtraBlocks.pop_back(); @@ -468,10 +466,10 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // of the current region. The SESE invariant allows us to easily handle nested // structures of arbitrary complexity. // -// During the conversion, we maintain a mapping between the MLValues present in -// the original function and their CFGValue images in the function under -// construction.
When an MLValue is used, it gets replaced with the -// corresponding CFGValue that has been defined previously. The value flow +// During the conversion, we maintain a mapping between the Values present in +// the original function and their Value images in the function under +// construction. When a Value is used, it gets replaced with the +// corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { auto outerBlock = builder.createBlock(); @@ -482,8 +480,8 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { outerBlock->addArguments(mlFunc->getType().getInputs()); assert(mlFunc->getNumArguments() == outerBlock->getNumArguments()); for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) { - const MLValue *mlArgument = mlFunc->getArgument(i); - CFGValue *cfgArgument = outerBlock->getArgument(i); + const Value *mlArgument = mlFunc->getArgument(i); + Value *cfgArgument = outerBlock->getArgument(i); valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 62cf55e37d9..917cd3d0c13 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -76,7 +76,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> { // Map from original memref's to the DMA buffers that their accesses are // replaced with. - DenseMap<SSAValue *, SSAValue *> fastBufferMap; + DenseMap<Value *, Value *> fastBufferMap; // Slow memory space associated with DMAs. const unsigned slowMemorySpace; @@ -195,11 +195,11 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, // Indices to use for the DmaStart op. // Indices for the original memref being DMAed from/to. - SmallVector<SSAValue *, 4> memIndices; + SmallVector<Value *, 4> memIndices; // Indices for the faster buffer being DMAed into/from. - SmallVector<SSAValue *, 4> bufIndices; + SmallVector<Value *, 4> bufIndices; - SSAValue *zeroIndex = top.create<ConstantIndexOp>(loc, 0); + Value *zeroIndex = top.create<ConstantIndexOp>(loc, 0); unsigned rank = memRefType.getRank(); SmallVector<int, 4> fastBufferShape; @@ -226,10 +226,10 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, // DMA generation is being done. const FlatAffineConstraints *cst = region.getConstraints(); auto ids = cst->getIds(); - SmallVector<SSAValue *, 8> outerIVs; + SmallVector<Value *, 8> outerIVs; for (unsigned i = rank, e = ids.size(); i < e; i++) { auto id = cst->getIds()[i]; - assert(id.hasValue() && "MLValue id expected"); + assert(id.hasValue() && "Value id expected"); outerIVs.push_back(id.getValue()); } @@ -253,15 +253,15 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, // Set DMA start location for this dimension in the lower memory space // memref. if (auto caf = offset.dyn_cast<AffineConstantExpr>()) { - memIndices.push_back(cast<MLValue>( - top.create<ConstantIndexOp>(loc, caf.getValue())->getResult())); + memIndices.push_back( + top.create<ConstantIndexOp>(loc, caf.getValue())->getResult()); } else { // The coordinate for the start location is just the lower bound along the // corresponding dimension on the memory region (stored in 'offset').
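// Illustrative recap of the two arms above, with hypothetical names
// (`isConstOffset`, `cstVal`, `offsetMap`; not code from this commit): each
// dimension's start coordinate is either a constant index or an affine
// function of the outer loop IVs.
Value *start = isConstOffset
    ? top.create<ConstantIndexOp>(loc, cstVal)->getResult()
    : b->create<AffineApplyOp>(loc, offsetMap, outerIVs)->getResult(0);
memIndices.push_back(start); // the fast-buffer side simply uses zeroIndex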
auto map = top.getAffineMap( cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {}); - memIndices.push_back(cast<MLValue>( - b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0))); + memIndices.push_back( + b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0)); } // The fast buffer is DMAed into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); @@ -272,7 +272,7 @@ } // The faster memory space buffer. - SSAValue *fastMemRef; + Value *fastMemRef; // Check if a buffer was already created. // TODO(bondhugula): union across all memory ops per buffer. For now assuming @@ -321,8 +321,8 @@ return false; } - SSAValue *stride = nullptr; - SSAValue *numEltPerStride = nullptr; + Value *stride = nullptr; + Value *numEltPerStride = nullptr; if (!strideInfos.empty()) { stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride); numEltPerStride = @@ -362,7 +362,7 @@ } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); // *Only* those uses within the body of 'forStmt' are replaced. - replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef), + replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, /*domStmtFilter=*/&*forStmt->getBody()->begin()); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index e3609496cc5..c86eec3d276 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -83,22 +83,22 @@ FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { - access->memref = cast<MLValue>(loadOp->getMemRef()); + access->memref = loadOp->getMemRef(); access->opStmt = loadOrStoreOpStmt; auto loadMemrefType = loadOp->getMemRefType(); access->indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { - access->indices.push_back(cast<MLValue>(index)); + access->indices.push_back(index); } } else { assert(loadOrStoreOpStmt->isa<StoreOp>()); auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>(); access->opStmt = loadOrStoreOpStmt; - access->memref = cast<MLValue>(storeOp->getMemRef()); + access->memref = storeOp->getMemRef(); auto storeMemrefType = storeOp->getMemRefType(); access->indices.reserve(storeMemrefType.getRank()); for (auto *index : storeOp->getIndices()) { - access->indices.push_back(cast<MLValue>(index)); + access->indices.push_back(index); } } } @@ -178,20 +178,20 @@ public: Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} // Returns the load op count for 'memref'. - unsigned getLoadOpCount(MLValue *memref) { + unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; for (auto *loadOpStmt : loads) { - if (memref == cast<MLValue>(loadOpStmt->cast<LoadOp>()->getMemRef())) + if (memref == loadOpStmt->cast<LoadOp>()->getMemRef()) ++loadOpCount; } return loadOpCount; } // Returns the store op count for 'memref'.
- unsigned getStoreOpCount(MLValue *memref) { + unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; for (auto *storeOpStmt : stores) { - if (memref == cast<MLValue>(storeOpStmt->cast<StoreOp>()->getMemRef())) + if (memref == storeOpStmt->cast<StoreOp>()->getMemRef()) ++storeOpCount; } return storeOpCount; @@ -203,7 +203,7 @@ public: // The id of the node at the other end of the edge. unsigned id; // The memref on which this edge represents a dependence. - MLValue *memref; + Value *memref; }; // Map from node id to Node. @@ -227,13 +227,13 @@ public: } // Adds an edge from node 'srcId' to node 'dstId' for 'memref'. - void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + void addEdge(unsigned srcId, unsigned dstId, Value *memref) { outEdges[srcId].push_back({dstId, memref}); inEdges[dstId].push_back({srcId, memref}); } // Removes an edge from node 'srcId' to node 'dstId' for 'memref'. - void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) { + void removeEdge(unsigned srcId, unsigned dstId, Value *memref) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); // Remove 'srcId' from 'inEdges[dstId]'. @@ -253,7 +253,7 @@ public: } // Returns the input edge count for node 'id' and 'memref'. - unsigned getInEdgeCount(unsigned id, MLValue *memref) { + unsigned getInEdgeCount(unsigned id, Value *memref) { unsigned inEdgeCount = 0; if (inEdges.count(id) > 0) for (auto &inEdge : inEdges[id]) @@ -263,7 +263,7 @@ public: } // Returns the output edge count for node 'id' and 'memref'. - unsigned getOutEdgeCount(unsigned id, MLValue *memref) { + unsigned getOutEdgeCount(unsigned id, Value *memref) { unsigned outEdgeCount = 0; if (outEdges.count(id) > 0) for (auto &outEdge : outEdges[id]) @@ -347,7 +347,7 @@ public: // dependence graph at a different depth. bool MemRefDependenceGraph::init(MLFunction *f) { unsigned id = 0; - DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses; + DenseMap<Value *, SetVector<unsigned>> memrefAccesses; for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { // Create graph node 'id' to represent top-level 'forStmt' and record @@ -360,12 +360,12 @@ bool MemRefDependenceGraph::init(MLFunction *f) { Node node(id++, &stmt); for (auto *opStmt : collector.loadOpStmts) { node.loads.push_back(opStmt); - auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef()); + auto *memref = opStmt->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); } for (auto *opStmt : collector.storeOpStmts) { node.stores.push_back(opStmt); - auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef()); + auto *memref = opStmt->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); } nodes.insert({node.id, node}); @@ -375,7 +375,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // Create graph node for top-level load op. Node node(id++, &stmt); node.loads.push_back(opStmt); - auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef()); + auto *memref = opStmt->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } @@ -383,7 +383,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { // Create graph node for top-level store op. 
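// Illustrative sketch of the kind of query the fusion logic can now phrase
// directly on Values (hypothetical `graph`, `srcId`, `dstId`; not from the
// patch): a candidate pair has exactly one edge on 'memref' leaving the
// producer node and exactly one entering the consumer node.
bool singlePair = graph.getOutEdgeCount(srcId, memref) == 1 &&
                  graph.getInEdgeCount(dstId, memref) == 1;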
Node node(id++, &stmt); node.stores.push_back(opStmt); - auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef()); + auto *memref = opStmt->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } @@ -477,8 +477,7 @@ public: SmallVector<OperationStmt *, 4> loads = dstNode->loads; while (!loads.empty()) { auto *dstLoadOpStmt = loads.pop_back_val(); - auto *memref = - cast<MLValue>(dstLoadOpStmt->cast<LoadOp>()->getMemRef()); + auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef(); // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. if (dstNode->getLoadOpCount(memref) != 1) continue; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index b5c12865790..5f49ed217a2 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -85,10 +85,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, for (unsigned i = 0; i < width; i++) { auto lbOperands = origLoops[i]->getLowerBoundOperands(); auto ubOperands = origLoops[i]->getUpperBoundOperands(); - SmallVector<MLValue *, 4> newLbOperands(lbOperands.begin(), - lbOperands.end()); - SmallVector<MLValue *, 4> newUbOperands(ubOperands.begin(), - ubOperands.end()); + SmallVector<Value *, 4> newLbOperands(lbOperands.begin(), lbOperands.end()); + SmallVector<Value *, 4> newUbOperands(ubOperands.begin(), ubOperands.end()); newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap()); newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap()); newLoops[i]->setStep(tileSizes[i]); @@ -112,8 +110,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, // Construct the upper bound map; the operands are the original operands // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. - SmallVector<MLValue *, 4> ubOperands( - origLoops[i]->getUpperBoundOperands()); + SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands()); ubOperands.push_back(newLoops[i]); auto origUbMap = origLoops[i]->getUpperBoundMap(); @@ -191,8 +188,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector<MLValue *, 6> origLoopIVs(band.begin(), band.end()); - SmallVector<Optional<MLValue *>, 6> ids(band.begin(), band.end()); + SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end()); + SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index ffff1c5b615..2a121529ed9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -191,7 +191,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - DenseMap<const MLValue *, MLValue *> operandMap; + DenseMap<const Value *, Value *> operandMap; // Insert the cleanup loop right after 'forStmt'. 
MLFuncBuilder builder(forStmt->getBlock(), std::next(StmtBlock::iterator(forStmt))); @@ -219,7 +219,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Unroll and jam (appends unrollJamFactor-1 additional copies). for (unsigned i = 1; i < unrollJamFactor; i++) { - DenseMap<const MLValue *, MLValue *> operandMapping; + DenseMap<const Value *, Value *> operandMapping; // If the induction variable is used, create a remapping to the value for // this unrolled instance. @@ -230,7 +230,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { auto *ivUnroll = builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) ->getResult(0); - operandMapping[forStmt] = cast<MLValue>(ivUnroll); + operandMapping[forStmt] = ivUnroll; } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index fd07619a165..013b5080367 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -29,17 +29,14 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" @@ -62,26 +59,26 @@ using namespace mlir; #define DEBUG_TYPE "lower-vector-transfers" -/// Creates the SSAValue for the sum of `v` and `w` without building a +/// Creates the Value for the sum of `v` and `w` without building a /// full-fledged AffineMap for all indices. /// /// Prerequisites: /// `v` and `w` must be of IndexType. -static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) { +static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) { assert(v->getType().isa<IndexType>() && "v must be of IndexType"); assert(w->getType().isa<IndexType>() && "w must be of IndexType"); auto *context = b->getContext(); auto d0 = getAffineDimExpr(0, context); auto d1 = getAffineDimExpr(1, context); auto map = AffineMap::get(2, 0, {d0 + d1}, {}); - return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w}) + return b->create<AffineApplyOp>(loc, map, ArrayRef<mlir::Value *>{v, w}) ->getResult(0); } namespace { struct LowerVectorTransfersState : public MLFuncGlobalLoweringState { // Top of the function constant zero index. - SSAValue *zero; + Value *zero; }; } // namespace @@ -131,7 +128,8 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // case of GPUs. if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) { b.create<StoreOp>(vecView->getLoc(), transfer->getVector(), - vecView->getResult(), ArrayRef<SSAValue *>{state->zero}); + vecView->getResult(), + ArrayRef<mlir::Value *>{state->zero}); } // 3. Emit the loop-nest. @@ -140,7 +138,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // TODO(ntv): Handle broadcast / slice properly.
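// Illustrative usage of the `add` helper defined above (hypothetical index
// values `iv` and `offset`; not code from this commit):
Value *next = add(&b, loc, iv, offset); // emits affine_apply (d0 + d1)(iv, offset)
// One tiny two-dimensional map suffices, instead of re-deriving a full
// composed map for every access index.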
auto permutationMap = transfer->getPermutationMap(); SetVector<ForStmt *> loops; - SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices()); + SmallVector<Value *, 8> accessIndices(transfer->getIndices()); for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) { auto composed = composeWithUnboundedMap( getAffineDimExpr(it.index(), b.getContext()), permutationMap); @@ -168,17 +166,16 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // b. write scalar to local. auto scalarLoad = b.create<LoadOp>(transfer->getLoc(), transfer->getMemRef(), accessIndices); - b.create<StoreOp>( - transfer->getLoc(), scalarLoad->getResult(), - tmpScalarAlloc->getResult(), - functional::map([](SSAValue *val) { return val; }, loops)); + b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(), + tmpScalarAlloc->getResult(), + functional::map([](Value *val) { return val; }, loops)); } else { // VectorTransferWriteOp. // a. read scalar from local; // b. write scalar to remote. auto scalarLoad = b.create<LoadOp>( transfer->getLoc(), tmpScalarAlloc->getResult(), - functional::map([](SSAValue *val) { return val; }, loops)); + functional::map([](Value *val) { return val; }, loops)); b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(), transfer->getMemRef(), accessIndices); } @@ -186,11 +183,11 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // 5. Read the vector from local storage in case of a vector_transfer_read. // TODO(ntv): This vector_load operation should be further lowered in the // case of GPUs. - llvm::SmallVector<SSAValue *, 1> newResults = {}; + llvm::SmallVector<Value *, 1> newResults = {}; if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) { b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation())); auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(), - ArrayRef<SSAValue *>{state->zero}) + ArrayRef<Value *>{state->zero}) ->getResult(); newResults.push_back(vector); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index faea9953d86..a12c563fe1a 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -32,9 +32,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" -#include "mlir/IR/MLValue.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -192,7 +190,7 @@ struct MaterializationState { VectorType superVectorType; VectorType hwVectorType; SmallVector<unsigned, 8> hwVectorInstance; - DenseMap<const MLValue *, MLValue *> *substitutionsMap; + DenseMap<const Value *, Value *> *substitutionsMap; }; struct MaterializeVectorsPass : public FunctionPass { @@ -250,9 +248,9 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex, static OperationStmt * instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, - DenseMap<const MLValue *, MLValue *> *substitutionsMap); + DenseMap<const Value *, Value *> *substitutionsMap); -/// Not all SSAValue belong to a program slice scoped within the immediately +/// Not all Values belong to a program slice scoped within the immediately /// enclosing loop. /// One simple example is constants defined outside the innermost loop scope. 
/// For such cases the substitutionsMap has no entry and we allow an additional @@ -261,17 +259,16 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, /// indices and will need to be extended in the future. /// /// If substitution fails, returns nullptr. -static MLValue * -substitute(SSAValue *v, VectorType hwVectorType, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - auto it = substitutionsMap->find(cast<MLValue>(v)); +static Value *substitute(Value *v, VectorType hwVectorType, + DenseMap<const Value *, Value *> *substitutionsMap) { + auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { auto *opStmt = cast<OperationStmt>(v->getDefiningOperation()); if (opStmt->isa<ConstantOp>()) { MLFuncBuilder b(opStmt); auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); - auto res = substitutionsMap->insert( - std::make_pair(cast<MLValue>(v), cast<MLValue>(inst->getResult(0)))); + auto res = + substitutionsMap->insert(std::make_pair(v, inst->getResult(0))); assert(res.second && "Insertion failed"); return res.first->second; } @@ -336,10 +333,10 @@ substitute(SSAValue *v, VectorType hwVectorType, /// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. -static SmallVector<SSAValue *, 8> +static SmallVector<mlir::Value *, 8> reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, - ArrayRef<SSAValue *> memrefIndices) { + ArrayRef<Value *> memrefIndices) { auto vectorShape = hwVectorType.getShape(); assert(hwVectorInstance.size() >= vectorShape.size()); @@ -380,7 +377,7 @@ reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, // TODO(ntv): support a concrete map and composition. auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(), affineMap, memrefIndices); - return SmallVector<SSAValue *, 8>{app->getResults()}; + return SmallVector<mlir::Value *, 8>{app->getResults()}; } /// Returns attributes with the following substitutions applied: @@ -402,21 +399,21 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { /// Creates an instantiated version of `opStmt`. /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no -/// affine reindexing. Just substitute their SSAValue* operands and be done. For -/// this case the actual instance is irrelevant. Just use the SSA values in +/// affine reindexing. Just substitute their Value operands and be done. For +/// this case the actual instance is irrelevant. Just use the values in /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. static OperationStmt * instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + DenseMap<const Value *, Value *> *substitutionsMap) { assert(!opStmt->isa<VectorTransferReadOp>() && "Should call the function specialized for VectorTransferReadOp"); assert(!opStmt->isa<VectorTransferWriteOp>() && "Should call the function specialized for VectorTransferWriteOp"); bool fail = false; auto operands = map( - [hwVectorType, substitutionsMap, &fail](SSAValue *v) -> SSAValue * { + [hwVectorType, substitutionsMap, &fail](Value *v) -> Value * { auto *res = fail ? 
nullptr : substitute(v, hwVectorType, substitutionsMap); fail |= !res; @@ -481,9 +478,9 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, static OperationStmt * instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - SmallVector<SSAValue *, 8> indices = - map(makePtrDynCaster<SSAValue>(), read->getIndices()); + DenseMap<const Value *, Value *> *substitutionsMap) { + SmallVector<Value *, 8> indices = + map(makePtrDynCaster<Value>(), read->getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b->create<VectorTransferReadOp>( @@ -501,9 +498,9 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, static OperationStmt * instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - SmallVector<SSAValue *, 8> indices = - map(makePtrDynCaster<SSAValue>(), write->getIndices()); + DenseMap<const Value *, Value *> *substitutionsMap) { + SmallVector<Value *, 8> indices = + map(makePtrDynCaster<Value>(), write->getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b->create<VectorTransferWriteOp>( @@ -555,8 +552,8 @@ static bool instantiateMaterialization(Statement *stmt, } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); - state->substitutionsMap->insert(std::make_pair( - cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0)))); + state->substitutionsMap->insert( + std::make_pair(read->getResult(), clone->getResult(0))); return false; } // The only op with 0 results reaching this point must, by construction, be @@ -571,8 +568,8 @@ static bool instantiateMaterialization(Statement *stmt, if (!clone) { return true; } - state->substitutionsMap->insert(std::make_pair( - cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); + state->substitutionsMap->insert( + std::make_pair(opStmt->getResult(0), clone->getResult(0))); return false; } @@ -610,7 +607,7 @@ static bool emitSlice(MaterializationState *state, // Fresh RAII instanceIndices and substitutionsMap. MaterializationState scopedState = *state; scopedState.hwVectorInstance = delinearize(idx, *ratio); - DenseMap<const MLValue *, MLValue *> substitutionMap; + DenseMap<const Value *, Value *> substitutionMap; scopedState.substitutionsMap = &substitutionMap; // Slices are topologically sorted, so we can just clone them in order. for (auto *stmt : *slice) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 13d3ea92307..de1952ca0f5 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -32,7 +32,6 @@ #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" - #define DEBUG_TYPE "pipeline-data-transfer" using namespace mlir; @@ -80,7 +79,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' statement modulo 2. Returns false if such /// a replacement cannot be performed.
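// Illustrative before/after for the shape rewrite described above
// (hypothetical IR, not taken from the patch):
//   %buf = alloc() : memref<256xf32, 1>
// becomes
//   %buf = alloc() : memref<2x256xf32, 1>
// and an access %buf[%i] becomes %buf[%iv mod 2, %i], so iterations N and
// N+1 operate on alternating halves of the buffer.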
-static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
   auto *forBody = forStmt->getBody();
   MLFuncBuilder bInner(forBody, forBody->begin());
   bInner.setInsertionPoint(forBody, forBody->begin());
@@ -103,7 +102,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
   // Put together alloc operands for the dynamic dimensions of the memref.
   MLFuncBuilder bOuter(forStmt);
-  SmallVector<SSAValue *, 4> allocOperands;
+  SmallVector<Value *, 4> allocOperands;
   unsigned dynamicDimCount = 0;
   for (auto dimSize : oldMemRefType.getShape()) {
     if (dimSize == -1)
@@ -114,7 +113,7 @@
   // Create and place the alloc right before the 'for' statement.
   // TODO(mlir-team): we are assuming scoped allocation here, and aren't
   // inserting a dealloc -- this isn't the right thing.
-  SSAValue *newMemRef =
+  Value *newMemRef =
       bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
 
   // Create 'iv mod 2' value to index the leading dimension.
@@ -126,8 +125,8 @@
 
   // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
   // non-dereferencing uses of the memref.
-  if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
-                                ivModTwoOp->getResult(0), AffineMap::Null(), {},
+  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
+                                AffineMap::Null(), {},
                                 &*forStmt->getBody()->begin())) {
     LLVM_DEBUG(llvm::dbgs()
                    << "memref replacement for double buffering failed\n";);
@@ -225,8 +224,7 @@ static void findMatchingStartFinishStmts(
       continue;
 
     // We only double buffer if the buffer is not live out of the loop.
-    const MLValue *memref =
-        cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
+    auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
     bool escapingUses = false;
     for (const auto &use : memref->getUses()) {
       if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
@@ -280,8 +278,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
   // dimension.
   for (auto &pair : startWaitPairs) {
     auto *dmaStartStmt = pair.first;
-    MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
-        dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
+    Value *oldMemRef = dmaStartStmt->getOperand(
+        dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
     if (!doubleBuffer(oldMemRef, forStmt)) {
       // Normally, double buffering should not fail because we already checked
       // that there are no uses outside.
@@ -302,8 +300,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
   // Double the buffers for tag memrefs.
   for (auto &pair : startWaitPairs) {
     auto *dmaFinishStmt = pair.second;
-    MLValue *oldTagMemRef = cast<MLValue>(
-        dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
+    Value *oldTagMemRef =
+        dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
     if (!doubleBuffer(oldTagMemRef, forStmt)) {
       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
       return success();
@@ -332,7 +330,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
     // If a slice wasn't created, the reachable affine_apply ops from its
     // operands are the ones that go with it.
     SmallVector<OperationStmt *, 4> affineApplyStmts;
-    SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
+    SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
     getReachableAffineApplyOps(operands, affineApplyStmts);
     for (const auto *stmt : affineApplyStmts) {
       stmtShiftMap[stmt] = 0;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 0af7e52b5b1..9d955fb6a81 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -217,7 +217,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
 
         // If we already have a canonicalized version of this constant, just
         // reuse it. Otherwise create a new one.
-        SSAValue *cstValue;
+        Value *cstValue;
         auto it = uniquedConstants.find({resultConstants[i], res->getType()});
         if (it != uniquedConstants.end())
           cstValue = it->second->getResult(0);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 5a5617f3fb1..e8fc5e7ca14 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -31,7 +31,6 @@
 #include "mlir/StandardOps/StandardOps.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Debug.h"
-
 #define DEBUG_TYPE "LoopUtils"
 
 using namespace mlir;
@@ -108,8 +107,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
     forStmt->replaceAllUsesWith(constOp);
   } else {
     const AffineBound lb = forStmt->getLowerBound();
-    SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
-                                          lb.operand_end());
+    SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
     MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
     auto affineApplyOp = builder.create<AffineApplyOp>(
         forStmt->getLoc(), lb.getMap(), lbOperands);
@@ -149,8 +147,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
             const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
                 &stmtGroupQueue,
             unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
-  SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
-  SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
+  SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
+  SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
 
   assert(lbMap.getNumInputs() == lbOperands.size());
   assert(ubMap.getNumInputs() == ubOperands.size());
@@ -176,7 +174,7 @@
                                      srcForStmt->getStep() * shift)),
                  loopChunk)
              ->getResult(0);
-      operandMap[srcForStmt] = cast<MLValue>(ivRemap);
+      operandMap[srcForStmt] = ivRemap;
     } else {
       operandMap[srcForStmt] = loopChunk;
     }
@@ -380,7 +378,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
 
   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
   if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
-    DenseMap<const MLValue *, MLValue *> operandMap;
+    DenseMap<const Value *, Value *> operandMap;
     MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
     auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
@@ -414,7 +412,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
 
   // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
   for (unsigned i = 1; i < unrollFactor; i++) {
-    DenseMap<const MLValue *, MLValue *> operandMap;
+    DenseMap<const Value *, Value *> operandMap;
 
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
@@ -425,7 +423,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
       auto *ivUnroll =
          builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
              ->getResult(0);
-      operandMap[forStmt] = cast<MLValue>(ivUnroll);
+      operandMap[forStmt] = ivUnroll;
     }
 
     // Clone the original body of 'forStmt'.
diff --git a/mlir/lib/Transforms/Utils/LoweringUtils.cpp b/mlir/lib/Transforms/Utils/LoweringUtils.cpp
index c8ac881dba7..8457ce4ce28 100644
--- a/mlir/lib/Transforms/Utils/LoweringUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoweringUtils.cpp
@@ -32,17 +32,17 @@
 using namespace mlir;
 
 namespace {
 // Visit affine expressions recursively and build the sequence of instructions
-// that correspond to it. Visitation functions return an SSAValue of the
+// that correspond to it. Visitation functions return a Value of the
 // expression subtree they visited or `nullptr` on error.
 class AffineApplyExpander
-    : public AffineExprVisitor<AffineApplyExpander, SSAValue *> {
+    : public AffineExprVisitor<AffineApplyExpander, Value *> {
 public:
   // This internal class expects arguments to be non-null; checks must be
   // performed at the call site.
   AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op)
       : builder(*builder), applyOp(*op), loc(op->getLoc()) {}
 
-  template <typename OpTy> SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) {
+  template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
     auto lhs = visit(expr.getLHS());
     auto rhs = visit(expr.getRHS());
     if (!lhs || !rhs)
@@ -51,33 +51,33 @@ public:
     return op->getResult();
   }
 
-  SSAValue *visitAddExpr(AffineBinaryOpExpr expr) {
+  Value *visitAddExpr(AffineBinaryOpExpr expr) {
     return buildBinaryExpr<AddIOp>(expr);
   }
 
-  SSAValue *visitMulExpr(AffineBinaryOpExpr expr) {
+  Value *visitMulExpr(AffineBinaryOpExpr expr) {
     return buildBinaryExpr<MulIOp>(expr);
   }
 
   // TODO(zinenko): implement when the standard operators are made available.
-  SSAValue *visitModExpr(AffineBinaryOpExpr) {
+  Value *visitModExpr(AffineBinaryOpExpr) {
     builder.getContext()->emitError(loc, "unsupported binary operator: mod");
     return nullptr;
   }
 
-  SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) {
+  Value *visitFloorDivExpr(AffineBinaryOpExpr) {
     builder.getContext()->emitError(loc,
                                     "unsupported binary operator: floor_div");
     return nullptr;
   }
 
-  SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) {
+  Value *visitCeilDivExpr(AffineBinaryOpExpr) {
     builder.getContext()->emitError(loc,
                                     "unsupported binary operator: ceil_div");
     return nullptr;
   }
 
-  SSAValue *visitConstantExpr(AffineConstantExpr expr) {
+  Value *visitConstantExpr(AffineConstantExpr expr) {
     auto valueAttr =
         builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
     auto op =
@@ -85,7 +85,7 @@ public:
     return op->getResult();
   }
 
-  SSAValue *visitDimExpr(AffineDimExpr expr) {
+  Value *visitDimExpr(AffineDimExpr expr) {
     assert(expr.getPosition() < applyOp.getNumOperands() &&
            "affine dim position out of range");
     // FIXME: this assumes a certain order of AffineApplyOp operands, the
@@ -93,7 +93,7 @@ public:
     return applyOp.getOperand(expr.getPosition());
   }
 
-  SSAValue *visitSymbolExpr(AffineSymbolExpr expr) {
+  Value *visitSymbolExpr(AffineSymbolExpr expr) {
     // FIXME: this assumes a certain order of AffineApplyOp operands, the
     // cleaner interface would be to separate them at the op level.
     assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() <
@@ -114,8 +114,8 @@ private:
 
 // Given an affine expression `expr` extracted from `op`, build the sequence of
 // primitive instructions that correspond to the affine expression in the
 // `builder`.
-static SSAValue *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
-                                  AffineApplyOp *op) {
+static mlir::Value *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
+                                     AffineApplyOp *op) {
   auto expander = AffineApplyExpander(builder, op);
   return expander.visit(expr);
 }
@@ -127,7 +127,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
   FuncBuilder builder(op->getOperation());
   auto affineMap = op->getAffineMap();
   for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
-    SSAValue *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
+    Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
     if (!expanded)
       return true;
     op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded);
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 2818e8c2e4f..624a8a758b5 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -31,7 +31,6 @@
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseMap.h"
-
 using namespace mlir;
 
 /// Return true if this operation dereferences one or more memrefs.
@@ -61,13 +60,12 @@ static bool isMemRefDereferencingOp(const Operation &op) {
 // extra operands, note that 'indexRemap' would just be applied to the existing
 // indices (%i, %j).
 //
-// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// TODO(mlir-team): extend this for Value / CFGFunctions. Can also be easily
 // extended to add additional indices at any position.
-bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
-                                    MLValue *newMemRef,
-                                    ArrayRef<SSAValue *> extraIndices,
+bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
+                                    ArrayRef<Value *> extraIndices,
                                     AffineMap indexRemap,
-                                    ArrayRef<SSAValue *> extraOperands,
+                                    ArrayRef<Value *> extraOperands,
                                     const Statement *domStmtFilter) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
@@ -128,16 +126,15 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
     // operation.
     assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
            "single result op's expected to generate these indices");
-    assert((cast<MLValue>(extraIndex)->isValidDim() ||
-            cast<MLValue>(extraIndex)->isValidSymbol()) &&
+    assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
            "invalid memory op index");
-    state.operands.push_back(cast<MLValue>(extraIndex));
+    state.operands.push_back(extraIndex);
   }
 
   // Construct new indices as a remap of the old ones if a remapping has been
   // provided. The indices of a memref come right after it, i.e.,
   // at position memRefOperandPos + 1.
-  SmallVector<SSAValue *, 4> remapOperands;
+  SmallVector<Value *, 4> remapOperands;
   remapOperands.reserve(oldMemRefRank + extraOperands.size());
   remapOperands.insert(remapOperands.end(), extraOperands.begin(),
                        extraOperands.end());
@@ -149,11 +146,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
                                         remapOperands);
     // Remapped indices.
     for (auto *index : remapOp->getOperation()->getResults())
-      state.operands.push_back(cast<MLValue>(index));
+      state.operands.push_back(index);
   } else {
     // No remapping specified.
     for (auto *index : remapOperands)
-      state.operands.push_back(cast<MLValue>(index));
+      state.operands.push_back(index);
   }
 
   // Insert the remaining operands unmodified.
@@ -191,9 +188,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
 // composed AffineApplyOp are returned in output parameter 'results'.
 OperationStmt *
 mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
-                                  ArrayRef<MLValue *> operands,
+                                  ArrayRef<Value *> operands,
                                   ArrayRef<OperationStmt *> affineApplyOps,
-                                  SmallVectorImpl<SSAValue *> *results) {
+                                  SmallVectorImpl<Value *> *results) {
   // Create identity map with same number of dimensions as number of operands.
   auto map = builder->getMultiDimIdentityMap(operands.size());
   // Initialize AffineValueMap with identity map.
@@ -208,7 +205,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
   // Compose affine maps from all ancestor AffineApplyOps.
   // Create new AffineApplyOp from 'valueMap'.
   unsigned numOperands = valueMap.getNumOperands();
-  SmallVector<SSAValue *, 4> outOperands(numOperands);
+  SmallVector<Value *, 4> outOperands(numOperands);
   for (unsigned i = 0; i < numOperands; ++i) {
     outOperands[i] = valueMap.getOperand(i);
   }
@@ -252,7 +249,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
 /// otherwise.
 OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
   // Collect all operands that are results of affine apply ops.
-  SmallVector<MLValue *, 4> subOperands;
+  SmallVector<Value *, 4> subOperands;
   subOperands.reserve(opStmt->getNumOperands());
   for (auto *operand : opStmt->getOperands()) {
     auto *defStmt = operand->getDefiningStmt();
@@ -285,7 +282,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
     return nullptr;
 
   FuncBuilder builder(opStmt);
-  SmallVector<SSAValue *, 4> results;
+  SmallVector<Value *, 4> results;
   auto *affineApplyStmt = createComposedAffineApplyOp(
       &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
   assert(results.size() == subOperands.size() &&
@@ -295,7 +292,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
   // affine apply op above instead of existing ones (subOperands). So, they
   // differ from opStmt's operands only for those operands in 'subOperands', for
   // which they will be replaced by the corresponding one from 'results'.
-  SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
+  SmallVector<Value *, 4> newOperands(opStmt->getOperands());
   for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
     // Replace the subOperands from among the new operands.
     unsigned j, f;
@@ -304,7 +301,7 @@
         break;
     }
     if (j < subOperands.size()) {
-      newOperands[i] = cast<MLValue>(results[j]);
+      newOperands[i] = results[j];
     }
   }
 
@@ -326,7 +323,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
   // into any uses which are AffineApplyOps.
   for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
        ++resultIndex) {
-    const MLValue *result = opStmt->getResult(resultIndex);
+    const Value *result = opStmt->getResult(resultIndex);
     for (auto it = result->use_begin(); it != result->use_end();) {
       StmtOperand &use = *(it++);
       auto *useStmt = use.getOwner();
@@ -347,7 +344,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
 
       // Create new AffineApplyOp from 'valueMap'.
       unsigned numOperands = valueMap.getNumOperands();
-      SmallVector<SSAValue *, 4> operands(numOperands);
+      SmallVector<Value *, 4> operands(numOperands);
       for (unsigned i = 0; i < numOperands; ++i) {
         operands[i] = valueMap.getOperand(i);
       }
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index aa80f47b826..9fe002c8fcb 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -27,8 +27,6 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/MLValue.h"
-#include "mlir/IR/SSAValue.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
@@ -740,8 +738,8 @@ struct VectorizationState {
   DenseSet<OperationStmt *> vectorizedSet;
   // Map of old scalar OperationStmt to new vectorized OperationStmt.
   DenseMap<OperationStmt *, OperationStmt *> vectorizationMap;
-  // Map of old scalar MLValue to new vectorized MLValue.
-  DenseMap<const MLValue *, MLValue *> replacementMap;
+  // Map of old scalar Value to new vectorized Value.
+  DenseMap<const Value *, Value *> replacementMap;
   // The strategy drives which loop to vectorize by which amount.
   const VectorizationStrategy *strategy;
   // Use-def roots. These represent the starting points for the worklist in the
@@ -761,7 +759,7 @@ struct VectorizationState {
   void registerTerminator(OperationStmt *stmt);
 
 private:
-  void registerReplacement(const SSAValue *key, SSAValue *value);
+  void registerReplacement(const Value *key, Value *value);
 };
 
 } // end namespace
@@ -802,12 +800,9 @@ void VectorizationState::finishVectorizationPattern() {
   }
 }
 
-void VectorizationState::registerReplacement(const SSAValue *key,
-                                             SSAValue *value) {
-  assert(replacementMap.count(cast<MLValue>(key)) == 0 &&
-         "replacement already registered");
-  replacementMap.insert(
-      std::make_pair(cast<MLValue>(key), cast<MLValue>(value)));
+void VectorizationState::registerReplacement(const Value *key, Value *value) {
+  assert(replacementMap.count(key) == 0 && "replacement already registered");
+  replacementMap.insert(std::make_pair(key, value));
 }
 
 ////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
@@ -825,7 +820,7 @@ void VectorizationState::registerReplacement(const SSAValue *key,
 /// Such special cases force us to delay the vectorization of the stores
 /// until the last step. Here we merely register the store operation.
 template <typename LoadOrStoreOpPointer>
-static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
+static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
                                     VectorizationState *state) {
   auto memRefType =
       memoryOp->getMemRef()->getType().template cast<MemRefType>();
@@ -850,8 +845,7 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
     MLFuncBuilder b(opStmt);
     auto transfer = b.create<VectorTransferReadOp>(
         opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
-        map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()),
-        permutationMap);
+        map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
     state->registerReplacement(opStmt,
                                cast<OperationStmt>(transfer->getOperation()));
   } else {
@@ -970,8 +964,8 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
 /// element type.
 /// If `type` is not a valid vector type or if the scalar constant is not a
 /// valid vector element type, returns nullptr.
-static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
-                                  Type type) {
+static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
+                                Type type) {
   if (!type || !type.isa<VectorType>() ||
       !VectorType::isValidElementType(constant.getType())) {
     return nullptr;
@@ -988,7 +982,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
       {make_pair(Identifier::get("value", b.getContext()), attr)});
 
   auto *splat = cast<OperationStmt>(b.createOperation(state));
-  return cast<MLValue>(splat->getResult(0));
+  return splat->getResult(0);
 }
 
 /// Returns a uniqu'ed VectorType.
@@ -996,7 +990,7 @@
 /// vectorizedSet, just returns the type of `v`.
 /// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
 /// and of elemental type the type of `v`.
-static Type getVectorType(SSAValue *v, const VectorizationState &state) {
+static Type getVectorType(Value *v, const VectorizationState &state) {
   if (!VectorType::isValidElementType(v->getType())) {
     return Type();
   }
@@ -1028,23 +1022,23 @@ static Type getVectorType(SSAValue *v, const VectorizationState &state) {
 /// vectorization is possible with the above logic. Returns nullptr otherwise.
 ///
 /// TODO(ntv): handle more complex cases.
-static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
-                                 VectorizationState *state) {
+static Value *vectorizeOperand(Value *operand, Statement *stmt,
+                               VectorizationState *state) {
   LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
   LLVM_DEBUG(operand->print(dbgs()));
   auto *definingStatement = cast<OperationStmt>(operand->getDefiningStmt());
   // 1. If this value has already been vectorized this round, we are done.
   if (state->vectorizedSet.count(definingStatement) > 0) {
     LLVM_DEBUG(dbgs() << " -> already vector operand");
-    return cast<MLValue>(operand);
+    return operand;
   }
   // 1.b. Delayed on-demand replacement of a use.
   //    Note that we cannot just call replaceAllUsesWith because it may result
   //    in ops with mixed types, for ops whose operands have not all yet
   //    been vectorized. This would be invalid IR.
-  auto it = state->replacementMap.find(cast<MLValue>(operand));
+  auto it = state->replacementMap.find(operand);
   if (it != state->replacementMap.end()) {
-    auto *res = cast<MLValue>(it->second);
+    auto *res = it->second;
     LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
     LLVM_DEBUG(res->print(dbgs()));
     return res;
@@ -1089,7 +1083,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
     auto *memRef = store->getMemRef();
     auto *value = store->getValueToStore();
     auto *vectorValue = vectorizeOperand(value, opStmt, state);
-    auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
+    auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
     MLFuncBuilder b(opStmt);
     auto permutationMap =
         makePermutationMap(opStmt, state->strategy->loopToVectorDim);
@@ -1104,14 +1098,14 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
     return res;
   }
 
-  auto types = map([state](SSAValue *v) { return getVectorType(v, *state); },
+  auto types = map([state](Value *v) { return getVectorType(v, *state); },
                    opStmt->getResults());
-  auto vectorizeOneOperand = [opStmt, state](SSAValue *op) -> SSAValue * {
+  auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
     return vectorizeOperand(op, opStmt, state);
   };
   auto operands = map(vectorizeOneOperand, opStmt->getOperands());
   // Check whether a single operand is null. If so, vectorization failed.
-  bool success = llvm::all_of(operands, [](SSAValue *op) { return op; });
+  bool success = llvm::all_of(operands, [](Value *op) { return op; });
   if (!success) {
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
     return nullptr;
@@ -1207,7 +1201,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
       continue;
     }
     MLFuncBuilder builder(loop); // builder to insert in place of loop
-    DenseMap<const MLValue *, MLValue *> nomap;
+    DenseMap<const Value *, Value *> nomap;
     ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
     auto fail = doVectorize(m, &state);
     /// Sets up error handling for this root loop. This is how the root match
@@ -1229,8 +1223,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
     // Form the root operations that have been set in the replacementMap.
     // For now, these roots are the loads for which vector_transfer_read
     // operations have been inserted.
-    auto getDefiningOperation = [](const MLValue *val) {
-      return const_cast<MLValue *>(val)->getDefiningOperation();
+    auto getDefiningOperation = [](const Value *val) {
+      return const_cast<Value *>(val)->getDefiningOperation();
     };
     using ReferenceTy = decltype(*(state.replacementMap.begin()));
     auto getKey = [](ReferenceTy it) { return it.first; };
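The mechanical pattern behind most of these hunks is the same: maps and helpers that were keyed on the MLValue subclass forced every caller holding a base SSAValue* to cast before each lookup or insert; with a single Value class the key type and the type callers hold coincide, and the casts vanish. The sketch below is a toy model of that before/after shape, not MLIR code: the SSAValue/MLValue/Value structs are stand-ins, and std::map stands in for llvm::DenseMap.

```cpp
// Toy model (hypothetical classes) of the substitution-map pattern touched by
// this commit. Compile with any C++14 compiler; asserts exercise both shapes.
#include <cassert>
#include <map>

namespace before {
struct SSAValue { virtual ~SSAValue() = default; }; // old common base
struct MLValue : SSAValue {};                       // ML-side subclass

MLValue *substitute(SSAValue *v,
                    std::map<const MLValue *, MLValue *> *substitutionsMap) {
  // The map is keyed on the subclass, so callers holding the base type must
  // downcast first -- the moral equivalent of cast<MLValue>(v) in the diff.
  auto *mlv = static_cast<MLValue *>(v);
  auto it = substitutionsMap->find(mlv);
  return it == substitutionsMap->end() ? nullptr : it->second;
}
} // namespace before

namespace after {
struct Value { virtual ~Value() = default; }; // single merged class

Value *substitute(Value *v,
                  std::map<const Value *, Value *> *substitutionsMap) {
  auto it = substitutionsMap->find(v); // no cast needed anywhere
  return it == substitutionsMap->end() ? nullptr : it->second;
}
} // namespace after

int main() {
  before::MLValue c, d;
  std::map<const before::MLValue *, before::MLValue *> oldSubst{{&c, &d}};
  before::SSAValue *base = &c; // callers typically held the base type
  assert(before::substitute(base, &oldSubst) == &d);

  after::Value a, b;
  std::map<const after::Value *, after::Value *> subst{{&a, &b}};
  assert(after::substitute(&a, &subst) == &b);
  return 0;
}
```

The design point is visible purely in the types: once the map's key type and the type flowing through the pass are the same class, every cast that existed only to satisfy the type system becomes dead weight and can be deleted, which is what the bulk of this diff does.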

