summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-12-10 15:17:25 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:25:07 -0700
commitb9f53dc0bde323c6418bd4037e3bda93b6c33316 (patch)
treeb233f97427d494ea755235594bae53d16270dcee /mlir/lib
parentd59a95a05c4fdf15a5b676d852f6b790a931494e (diff)
downloadbcm5719-llvm-b9f53dc0bde323c6418bd4037e3bda93b6c33316.tar.gz
bcm5719-llvm-b9f53dc0bde323c6418bd4037e3bda93b6c33316.zip
Update/Fix LoopUtils::stmtBodySkew to handle loop step.
- The loop step wasn't handled, and there was neither a TODO nor an assertion noting this; fix it.
- Rename 'delay' to 'shift' for consistency/readability.
- Other readability changes.
- Remove the duplicate attribute print for DmaStartOp; fix the misplaced attribute print for DmaWaitOp.
- Add a build method for AddFOp (unrelated to this CL, but add it anyway).
PiperOrigin-RevId: 224892958
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/StandardOps/StandardOps.cpp10
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp32
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp120
3 files changed, 88 insertions, 74 deletions
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index 69a3370fafd..7304223d6b0 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -79,6 +79,13 @@ struct MemRefCastFolder : public RewritePattern {
// AddFOp
//===----------------------------------------------------------------------===//
+void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
+ SSAValue *rhs) {
+ assert(lhs->getType() == rhs->getType());
+ result->addOperands({lhs, rhs});
+ result->types.push_back(lhs->getType());
+}
+
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "addf takes two operands");
@@ -718,7 +725,6 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
*p << " : " << getSrcMemRef()->getType();
*p << ", " << getDstMemRef()->getType();
*p << ", " << getTagMemRef()->getType();
- p->printOptionalAttrDict(getAttrs());
}
// Parse DmaStartOp.
@@ -846,8 +852,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices());
*p << "], ";
p->printOperand(getNumElements());
- *p << " : " << getTagMemRef()->getType();
p->printOptionalAttrDict(getAttrs());
+ *p << " : " << getTagMemRef()->getType();
}
// Parse DmaWaitOp.
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index a3455de2039..b656af0d69d 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -318,15 +318,15 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
startWaitPairs.clear();
findMatchingStartFinishStmts(forStmt, startWaitPairs);
- // Store delay for statement for later lookup for AffineApplyOp's.
- DenseMap<const Statement *, unsigned> stmtDelayMap;
+ // Store shift for statement for later lookup for AffineApplyOp's.
+ DenseMap<const Statement *, unsigned> stmtShiftMap;
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
assert(dmaStartStmt->isa<DmaStartOp>());
- stmtDelayMap[dmaStartStmt] = 0;
+ stmtShiftMap[dmaStartStmt] = 0;
// Set shifts for DMA start stmt's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
- stmtDelayMap[slice] = 0;
+ stmtShiftMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
@@ -334,39 +334,39 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
- stmtDelayMap[stmt] = 0;
+ stmtShiftMap[stmt] = 0;
}
}
}
// Everything else (including compute ops and dma finish) is shifted by one.
for (const auto &stmt : *forStmt) {
- if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
- stmtDelayMap[&stmt] = 1;
+ if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
+ stmtShiftMap[&stmt] = 1;
}
}
- // Get delays stored in map.
- std::vector<uint64_t> delays(forStmt->getStatements().size());
+ // Get shifts stored in map.
+ std::vector<uint64_t> shifts(forStmt->getStatements().size());
unsigned s = 0;
for (auto &stmt : *forStmt) {
- assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
- delays[s++] = stmtDelayMap[&stmt];
+ assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
+ shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(
- // Tagging statements with delays for debugging purposes.
+ // Tagging statements with shifts for debugging purposes.
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
MLFuncBuilder b(opStmt);
- opStmt->setAttr(b.getIdentifier("delay"),
- b.getIntegerAttr(delays[s - 1]));
+ opStmt->setAttr(b.getIdentifier("shift"),
+ b.getIntegerAttr(shifts[s - 1]));
});
}
- if (!isStmtwiseShiftValid(*forStmt, delays)) {
+ if (!isStmtwiseShiftValid(*forStmt, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
- if (stmtBodySkew(forStmt, delays)) {
+ if (stmtBodySkew(forStmt, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
return success();
}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 55cb64b8ba9..791997e7ff1 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -137,45 +137,50 @@ void mlir::promoteSingleIterationLoops(MLFunction *f) {
fsw.walkPostOrder(f);
}
-/// Generates a for 'stmt' with the specified lower and upper bounds while
-/// generating the right IV remappings for the delayed statements. The
+/// Generates a 'for' stmt with the specified lower and upper bounds while
+/// generating the right IV remappings for the shifted statements. The
/// statement blocks that go into the loop are specified in stmtGroupQueue
/// starting from the specified offset, and in that order; the first element of
-/// the pair specifies the delay applied to that group of statements. Returns
+/// the pair specifies the shift applied to that group of statements; note that
+/// the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
static ForStmt *
-generateLoop(AffineMap lb, AffineMap ub,
+generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
- auto *loopChunk =
- b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
+ assert(lbMap.getNumInputs() == lbOperands.size());
+ assert(ubMap.getNumInputs() == ubOperands.size());
+
+ auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, srcForStmt->getStep());
+
OperationStmt::OperandMapTy operandMap;
for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
it != e; ++it) {
- auto elt = *it;
- // All 'same delay' statements get added with the operands being remapped
- // (to results of cloned statements).
- // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
- // TODO(bondhugula): check if srcForStmt is actually used in elt.second
- // instead of just checking if it's used at all.
- if (!srcForStmt->use_empty() && elt.first != 0) {
+ uint64_t shift = it->first;
+ auto stmts = it->second;
+ // All 'same shift' statements get added with their operands being remapped
+ // to results of cloned statements, and their uses of the IV remapped.
+ // Generate the remapping if the shift is not zero: remappedIV = newIV -
+ // shift.
+ if (!srcForStmt->use_empty() && shift != 0) {
auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
- auto *oldIV =
- b.create<AffineApplyOp>(
- srcForStmt->getLoc(),
- b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
- loopChunk)
- ->getResult(0);
- operandMap[srcForStmt] = cast<MLValue>(oldIV);
+ auto *ivRemap = b.create<AffineApplyOp>(
+ srcForStmt->getLoc(),
+ b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
+ srcForStmt->getStep() * shift)),
+ loopChunk)
+ ->getResult(0);
+ operandMap[srcForStmt] = cast<MLValue>(ivRemap);
} else {
- operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
+ operandMap[srcForStmt] = loopChunk;
}
- for (auto *stmt : elt.second) {
+ for (auto *stmt : stmts) {
loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
}
}
@@ -185,12 +190,13 @@ generateLoop(AffineMap lb, AffineMap ub,
}
/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise delays. The delays are with respect to the original execution
-/// order. A delay of zero for each statement will lead to no change.
+/// statement-wise shifts. The shifts are with respect to the original execution
+/// order, and are multiplied by the loop 'step' before being applied. A shift
+/// of zero for each statement will lead to no change.
// The skewing of statements with respect to one another can be used for example
// to allow overlap of asynchronous operations (such as DMA communication) with
// computation, or just relative shifting of statements for better register
-// reuse, locality or parallelism. As such, the delays are typically expected to
+// reuse, locality or parallelism. As such, the shifts are typically expected to
// be at most of the order of the number of statements. This method should not
// be used as a substitute for loop distribution/fission.
// This method uses an algorithm that runs in time linear in the number of statements in
@@ -198,7 +204,7 @@ generateLoop(AffineMap lb, AffineMap ub,
// asserts preservation of SSA dominance. A check for that as well as that for
// memory-based dependence preservation rests with the users of this
// method.
-UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
+UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
if (forStmt->getStatements().empty())
return UtilResult::Success;
@@ -214,30 +220,32 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
}
uint64_t tripCount = mayBeConstTripCount.getValue();
- assert(isStmtwiseShiftValid(*forStmt, delays) &&
+ assert(isStmtwiseShiftValid(*forStmt, shifts) &&
"shifts will lead to an invalid transformation\n");
+ int64_t step = forStmt->getStep();
+
unsigned numChildStmts = forStmt->getStatements().size();
- // Do a linear time (counting) sort for the delays.
- uint64_t maxDelay = 0;
+ // Do a linear time (counting) sort for the shifts.
+ uint64_t maxShift = 0;
for (unsigned i = 0; i < numChildStmts; i++) {
- maxDelay = std::max(maxDelay, delays[i]);
+ maxShift = std::max(maxShift, shifts[i]);
}
- // Such large delays are not the typical use case.
- if (maxDelay >= numChildStmts) {
- LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
+ // Such large shifts are not the typical use case.
+ if (maxShift >= numChildStmts) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";);
return UtilResult::Success;
}
- // An array of statement groups sorted by delay amount; each group has all
- // statements with the same delay in the order in which they appear in the
+ // An array of statement groups sorted by shift amount; each group has all
+ // statements with the same shift in the order in which they appear in the
// body of the 'for' stmt.
- std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
+ std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
unsigned pos = 0;
for (auto &stmt : *forStmt) {
- auto delay = delays[pos++];
- sortedStmtGroups[delay].push_back(&stmt);
+ auto shift = shifts[pos++];
+ sortedStmtGroups[shift].push_back(&stmt);
}
// Unless the shifts have a specific pattern (which actually would be the
@@ -248,45 +256,45 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
ForStmt *prologue = nullptr;
ForStmt *epilogue = nullptr;
- // Do a sweep over the sorted delays while storing open groups in a
+ // Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
- // of statements is paired with its delay.
+ // of statements is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
auto origLbMap = forStmt->getLowerBoundMap();
- uint64_t lbDelay = 0;
+ uint64_t lbShift = 0;
MLFuncBuilder b(forStmt);
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
- // If nothing is delayed by d, continue.
+ // If nothing is shifted by d, continue.
if (sortedStmtGroups[d].empty())
continue;
if (!stmtGroupQueue.empty()) {
assert(d >= 1 &&
"Queue expected to be empty when the first block is found");
// The interval for which the loop needs to be generated here is:
- // ( lbDelay, min(lbDelay + tripCount, d)) and the body of the
+ // [lbShift, min(lbShift + tripCount, d)) and the body of the
+ // loop needs to have all statements in stmtGroupQueue in that order.
ForStmt *res;
- if (lbDelay + tripCount < d) {
- res =
- generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
- b.getShiftedAffineMap(origLbMap, lbDelay + tripCount),
- stmtGroupQueue, 0, forStmt, &b);
+ if (lbShift + tripCount * step < d * step) {
+ res = generateLoop(
+ b.getShiftedAffineMap(origLbMap, lbShift),
+ b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
+ stmtGroupQueue, 0, forStmt, &b);
// Entire loop for the queued stmt groups generated, empty it.
stmtGroupQueue.clear();
- lbDelay += tripCount;
+ lbShift += tripCount * step;
} else {
- res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+ res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue,
0, forStmt, &b);
- lbDelay = d;
+ lbShift = d * step;
}
if (!prologue && res)
prologue = res;
epilogue = res;
} else {
// Start of first interval.
- lbDelay = d;
+ lbShift = d * step;
}
// Augment the list of statements that get into the current open interval.
stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
@@ -295,11 +303,11 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
// Those statements groups left in the queue now need to be processed (FIFO)
// and their loops completed.
for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
- uint64_t ubDelay = stmtGroupQueue[i].first + tripCount;
- epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
- b.getShiftedAffineMap(origLbMap, ubDelay),
+ uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step;
+ epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+ b.getShiftedAffineMap(origLbMap, ubShift),
stmtGroupQueue, i, forStmt, &b);
- lbDelay = ubDelay;
+ lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}
OpenPOWER on IntegriCloud