author    Uday Bondhugula <bondhugula@google.com>  2018-10-04 17:15:30 -0700
committer jpienaar <jpienaar@google.com>           2019-03-29 13:23:19 -0700
commit    6cfdb756b165ebd32068506fc70d623f86bb80b3 (patch)
tree      2ac173f35c2b0b7913a243d0b5eed17c78c5a09b /mlir
parent    b55b4076011419c8d8d8cac58c8fda7631067bb2 (diff)
Introduce memref replacement/rewrite support: replace an existing memref
with a new one (of a potentially different rank/shape) with an optional
index remapping.

- Introduce Utils::replaceAllMemRefUsesWith.
- Use this for DMA double buffering.

(This CL also adds a few temporary utilities / code that will be done away
with once: 1) abstract DMA ops are added, 2) a memref dereferencing
side-effect / trait is available on ops, 3) b/117159533 is resolved (memref
index computation slices).)

PiperOrigin-RevId: 215831373
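As a concrete sketch of the rewrite (mirroring the double-buffering test
updated below; names schematic), every dereferencing use of the old memref is
redirected to the new one, with the extra index prepended:

    // Before:
    //   %Ah = alloc() : memref<32xf32, (d0) -> (d0), 1>
    //   %v  = load %Ah[%i] : memref<32xf32, (d0) -> (d0), 1>
    // After replaceAllMemRefUsesWith(%Ah, %Ah2, {%i mod 2}):
    //   %Ah2 = alloc() : memref<2x32xf32>
    //   %m   = affine_apply (d0) -> (d0 mod 2) (%i)
    //   %v   = load %Ah2[%m, %i] : memref<2x32xf32>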
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/IR/Builders.h                   |  12
-rw-r--r--  mlir/include/mlir/Transforms/Utils.h              |  49
-rw-r--r--  mlir/lib/IR/Builders.cpp                          |  12
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp      | 235
-rw-r--r--  mlir/lib/Transforms/Utils.cpp                     | 165
-rw-r--r--  mlir/test/Transforms/pipeline-data-transfer.mlir  | 131
6 files changed, 518 insertions, 86 deletions
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index a37b4b2d138..334422f05c8 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -343,15 +343,21 @@ public:
return MLFuncBuilder(forStmt, forStmt->end());
}
- /// Get the current insertion point of the builder.
+ /// Returns the current insertion point of the builder.
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
- /// Get the current block of the builder.
+ /// Returns the current block of the builder.
StmtBlock *getBlock() const { return block; }
- /// Create an operation given the fields represented as an OperationState.
+ /// Creates an operation given the fields represented as an OperationState.
OperationStmt *createOperation(const OperationState &state);
+ /// Creates an operation given the fields.
+ OperationStmt *createOperation(Location *location, Identifier name,
+ ArrayRef<MLValue *> operands,
+ ArrayRef<Type *> types,
+ ArrayRef<NamedAttribute> attrs);
+
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Location *location, Args... args) {
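
A minimal usage sketch of the new createOperation overload, mirroring how
Utils.cpp in this patch rebuilds an operation from explicit fields (opStmt,
operands, and resultTypes are assumed to already be in scope):

    // Rebuild an op at opStmt's position with the same name and attributes
    // but a new operand list; results keep their original types.
    MLFuncBuilder builder(opStmt);
    OperationStmt *newOp = builder.createOperation(
        opStmt->getLoc(), opStmt->getName(), operands, resultTypes,
        opStmt->getAttrs());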
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
new file mode 100644
index 00000000000..64e9b72b2ad
--- /dev/null
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -0,0 +1,49 @@
+//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_UTILS_H
+#define MLIR_TRANSFORMS_UTILS_H
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLValue;
+class SSAValue;
+
+/// Replace all uses of oldMemRef with newMemRef while optionally remapping the
+/// old memref's indices using the supplied affine map and adding any additional
+/// indices. The new memref could be of a different shape or rank. Returns true
+/// on success and false if the replacement is not possible (whenever a memref
+/// is used as an operand in a non-dereferencing scenario).
+/// Additional indices are added at the start.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+ llvm::ArrayRef<SSAValue *> extraIndices,
+ AffineMap *indexRemap = nullptr);
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_UTILS_H
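
A minimal call sketch, assuming oldMemRef/newMemRef are MLValue* with the same
elemental type and extraIdx is the single result of an affine_apply (as in the
double-buffering use below):

    // Prepend one extra index and apply no remapping of the old indices;
    // fails (returns false) if the memref escapes into a non-dereferencing op.
    if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                  /*extraIndices=*/{extraIdx},
                                  /*indexRemap=*/nullptr))
      return false;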
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8ca0fe8318f..deac801f822 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -312,6 +312,18 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
return op;
}
+/// Create an operation given the fields.
+OperationStmt *MLFuncBuilder::createOperation(Location *location,
+ Identifier name,
+ ArrayRef<MLValue *> operands,
+ ArrayRef<Type *> types,
+ ArrayRef<NamedAttribute> attrs) {
+ auto *op = OperationStmt::create(location, name, operands, types, attrs,
+ getContext());
+ block->getStatements().insert(insertPoint, op);
+ return op;
+}
+
ForStmt *MLFuncBuilder::createFor(Location *location,
ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap,
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 96d706e98ac..5aaac1c6c29 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -21,10 +21,13 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@@ -43,27 +46,237 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// For testing purposes, this just runs on the first statement of the MLFunction
-// if that statement is a for stmt, and shifts the second half of its body by
-// one.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract ops,
+// or op traits for them, are added. TODO(b/117228571)
+static bool isDmaStartStmt(const OperationStmt &stmt) {
+ return stmt.getName().strref().contains("dma.in.start") ||
+ stmt.getName().strref().contains("dma.out.start");
+}
+
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static bool isDmaFinishStmt(const OperationStmt &stmt) {
+ return stmt.getName().strref().contains("dma.finish");
+}
+
+/// Given a DMA start operation, returns the operand position of either the
+/// source or the destination memref - whichever resides at the higher (i.e.,
+/// faster) level of the memory hierarchy, as given by its memory space id.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
+ assert(isDmaStartStmt(dmaStartStmt));
+ unsigned srcDmaPos = 0;
+ // The destination memref operand follows the source memref and its rank
+ // indices.
+ unsigned destDmaPos =
+ cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+
+ if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
+ ->getMemorySpace() >
+ cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
+ ->getMemorySpace())
+ return srcDmaPos;
+ return destDmaPos;
+}
+
+// Returns the position of the tag memref operand given a DMA statement.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+ assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
+ if (isDmaStartStmt(dmaStmt)) {
+ // Second to last operand.
+ return dmaStmt.getNumOperands() - 2;
+ }
+ // First operand for a dma finish statement.
+ return 0;
+}
+
+/// Doubles the buffer of the supplied memref: allocates a replacement with a
+/// leading dimension of extent two and replaces all uses of the old memref.
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+ MLFuncBuilder bInner(forStmt, forStmt->begin());
+
+ // Doubles the shape with a leading dimension extent of 2.
+ auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+ // Add the leading dimension in the shape for the double buffer.
+ ArrayRef<int> shape = origMemRefType->getShape();
+ SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
+ shapeSizes.insert(shapeSizes.begin(), 2);
+
+ // Preserve the original elemental type; the replacement below asserts
+ // that the old and new memrefs have the same elemental type.
+ auto *newMemRefType =
+ bInner.getMemRefType(shapeSizes, origMemRefType->getElementType());
+ return newMemRefType;
+ };
+
+ auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+
+ // Create and place the alloc at the top level.
+ auto *func = forStmt->getFunction();
+ MLFuncBuilder topBuilder(func, func->begin());
+ auto *newMemRef = cast<MLValue>(
+ topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
+ ->getResult());
+
+ auto d0 = bInner.getDimExpr(0);
+ auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
+ auto ivModTwoOp =
+ bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+ if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
+ return false;
+ // We don't need ivModTwoOp anymore - it is cloned by
+ // replaceAllMemRefUsesWith wherever the memref replacement happens. Once
+ // b/117159533 is addressed, we'll eventually only need to pass
+ // ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
+ cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
+ return true;
+}
+
+// For testing purposes, this just runs on the first for statement of an
+// MLFunction at the top level.
+// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
+// the other TODOs listed inside are dealt with.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
- auto *forStmt = dyn_cast<ForStmt>(&f->front());
+
+ ForStmt *forStmt = nullptr;
+ for (auto &stmt : *f) {
+ if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
+ break;
+ }
+ }
if (!forStmt)
- return PassResult::Failure;
+ return PassResult::Success;
unsigned numStmts = forStmt->getStatements().size();
+
if (numStmts == 0)
return PassResult::Success;
- std::vector<uint64_t> delays(numStmts);
- for (unsigned i = 0; i < numStmts; i++)
- delays[i] = (i < numStmts / 2) ? 0 : 1;
+ SmallVector<OperationStmt *, 4> dmaStartStmts;
+ SmallVector<OperationStmt *, 4> dmaFinishStmts;
+ for (auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ if (!opStmt)
+ continue;
+ if (isDmaStartStmt(*opStmt)) {
+ dmaStartStmts.push_back(opStmt);
+ } else if (isDmaFinishStmt(*opStmt)) {
+ dmaFinishStmts.push_back(opStmt);
+ }
+ }
+
+ // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
+ // subscript check utilities). Assume for now that start/finish are matched in
+ // the order they appear.
+ if (dmaStartStmts.size() != dmaFinishStmts.size())
+ return PassResult::Failure;
+
+ // Double the buffers for the higher memory space memref's.
+ // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
+ // memref.
+ // TODO(bondhugula): check whether double-buffering is even necessary.
+ // TODO(bondhugula): make this work with different layouts: assuming here that
+ // the dimension we are adding here for the double buffering is the outermost
+ // dimension.
+ // Identify memref's to replace by scanning through all DMA start statements.
+ // A DMA start statement has two memref's - the one from the higher level of
+ // memory hierarchy is the one to double buffer.
+ for (auto *dmaStartStmt : dmaStartStmts) {
+ MLValue *oldMemRef = cast<MLValue>(
+ dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+ if (!doubleBuffer(oldMemRef, forStmt))
+ return PassResult::Failure;
+ }
+
+ // Double the buffers for tag memref's.
+ for (auto *dmaFinishStmt : dmaFinishStmts) {
+ MLValue *oldTagMemRef = cast<MLValue>(
+ dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
+ if (!doubleBuffer(oldTagMemRef, forStmt))
+ return PassResult::Failure;
+ }
+
+ // Collect all compute ops.
+ std::vector<const Statement *> computeOps;
+ computeOps.reserve(forStmt->getStatements().size());
+ // Store delay for statement for later lookup for AffineApplyOp's.
+ DenseMap<const Statement *, unsigned> opDelayMap;
+ for (const auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ if (!opStmt) {
+ // All for and if stmt's are treated as pure compute operations.
+ // TODO(bondhugula): check whether such statements do not have any DMAs
+ // nested within.
+ opDelayMap[&stmt] = 1;
+ } else if (isDmaStartStmt(*opStmt)) {
+ // DMA starts are not shifted.
+ opDelayMap[&stmt] = 0;
+ } else if (isDmaFinishStmt(*opStmt)) {
+ // DMA finish op shifted by one.
+ opDelayMap[&stmt] = 1;
+ } else if (!opStmt->is<AffineApplyOp>()) {
+ // Compute op shifted by one.
+ opDelayMap[&stmt] = 1;
+ computeOps.push_back(&stmt);
+ }
+ // Shifts for affine apply op's determined later.
+ }
+
+ // Get the ancestor of a 'stmt' that lies in forStmt's block.
+ auto getAncestorInForBlock =
+ [&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
+ // Traverse up the statement hierarchy starting from the owner of operand to
+ // find the ancestor statement that resides in the block of 'forStmt'.
+ while (stmt != nullptr && stmt->getBlock() != &block) {
+ stmt = stmt->getParentStmt();
+ }
+ return stmt;
+ };
+
+ // Determine delays for affine apply ops: look up the delay from their
+ // consumer ops. This code will be thrown away once we have a way to obtain
+ // indices through a composed affine_apply op. See TODO(b/117159533). Such a
+ // composed affine_apply will be used exclusively by a given memref
+ // dereferencing op.
+ for (const auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ // Skip statements that aren't affine apply ops.
+ if (!opStmt || !opStmt->is<AffineApplyOp>())
+ continue;
+ // Traverse uses of each result of the affine apply op.
+ for (auto *res : opStmt->getResults()) {
+ for (auto &use : res->getUses()) {
+ auto *ancestorInForBlock =
+ getAncestorInForBlock(use.getOwner(), *forStmt);
+ assert(ancestorInForBlock &&
+ "traversing parent should reach forStmt block");
+ auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
+ if (!opCheck || opCheck->is<AffineApplyOp>())
+ continue;
+ assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
+ if (opDelayMap.find(&stmt) != opDelayMap.end()) {
+ // This is where we enforce all uses of this affine_apply to have
+ // the same shifts - so that we know what shift to use for the
+ // affine_apply to preserve semantics.
+ assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
+ } else {
+ // Obtain delay from its consumer.
+ opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
+ }
+ }
+ }
+ }
+
+ // Gather the delays stored in the map, in statement order.
+ std::vector<uint64_t> delays(forStmt->getStatements().size());
+ unsigned s = 0;
+ for (const auto &stmt : *forStmt) {
+ delays[s++] = opDelayMap[&stmt];
+ }
- if (!checkDominancePreservationOnShift(*forStmt, delays))
+ if (!checkDominancePreservationOnShift(*forStmt, delays)) {
// Violates SSA dominance.
return PassResult::Failure;
+ }
if (stmtBodySkew(forStmt, delays))
return PassResult::Failure;
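
Schematically, the delays computed above (0 for DMA starts, 1 for DMA finishes
and compute) make stmtBodySkew overlap iteration i's incoming DMA with
iteration i-1's compute; a simplified sketch of the steady-state body
(operands elided, per the updated test below):

    for %i = 1 to 7 {
      "dma.in.start"(... %i ...)     // start the transfer for iteration i
      "dma.finish"(... %i - 1 ...)   // wait on the previous transfer
      ... compute on iteration %i - 1's data ...
    }
    // plus a prologue (iteration 0's dma.in.start) and an epilogue (the final
    // iteration's finish + compute), as in the CHECK lines of the test.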
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
new file mode 100644
index 00000000000..d189ac26703
--- /dev/null
+++ b/mlir/lib/Transforms/Utils.cpp
@@ -0,0 +1,165 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+ if (op.is<LoadOp>() || op.is<StoreOp>() ||
+ op.getName().strref().contains("dma.in.start") ||
+ op.getName().strref().contains("dma.out.start") ||
+ op.getName().strref().contains("dma.finish")) {
+ return true;
+ }
+ return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+ ArrayRef<SSAValue *> extraIndices,
+ AffineMap *indexRemap) {
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+ if (indexRemap) {
+ assert(indexRemap->getNumInputs() == oldMemRefRank);
+ assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+ cast<MemRefType>(newMemRef->getType())->getElementType());
+
+ // Check if the memref is used in a non-dereferencing context.
+ for (const StmtOperand &use : oldMemRef->getUses()) {
+ auto *opStmt = cast<OperationStmt>(use.getOwner());
+ // Failure: memref used in a non-dereferencing op (it potentially escapes);
+ // no replacement in these cases.
+ if (!isMemRefDereferencingOp(*opStmt))
+ return false;
+ }
+
+ // Walk all uses of the old memref; each statement using it gets replaced.
+ for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+ StmtOperand &use = *(it++);
+ auto *opStmt = cast<OperationStmt>(use.getOwner());
+ assert(isMemRefDereferencingOp(*opStmt) &&
+ "memref deferencing op expected");
+
+ auto getMemRefOperandPos = [&]() -> unsigned {
+ unsigned i;
+ for (i = 0; i < opStmt->getNumOperands(); i++) {
+ if (opStmt->getOperand(i) == oldMemRef)
+ break;
+ }
+ assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+ return i;
+ };
+ unsigned memRefOperandPos = getMemRefOperandPos();
+
+ // Construct the new operation statement using this memref.
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+ // Insert the non-memref operands.
+ operands.insert(operands.end(), opStmt->operand_begin(),
+ opStmt->operand_begin() + memRefOperandPos);
+ operands.push_back(newMemRef);
+
+ MLFuncBuilder builder(opStmt);
+ // Normally, we could just use extraIndices as operands, but we clone
+ // them so that each op gets its own "private" index. See b/117159533.
+ for (auto *extraIndex : extraIndices) {
+ OperationStmt::OperandMapTy operandMap;
+ // TODO(mlir-team): An operation/SSA value should provide a method to
+ // return the position of an SSA result in its defining
+ // operation.
+ assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+ "single result op's expected to generate these indices");
+ // TODO: actually check if this is a result of an affine_apply op.
+ assert((cast<MLValue>(extraIndex)->isValidDim() ||
+ cast<MLValue>(extraIndex)->isValidSymbol()) &&
+ "invalid memory op index");
+ auto *clonedExtraIndex =
+ cast<OperationStmt>(
+ builder.clone(*extraIndex->getDefiningStmt(), operandMap))
+ ->getResult(0);
+ operands.push_back(cast<MLValue>(clonedExtraIndex));
+ }
+
+ // Construct new indices. The indices of a memref come right after it, i.e.,
+ // at position memRefOperandPos + 1.
+ SmallVector<SSAValue *, 4> indices(
+ opStmt->operand_begin() + memRefOperandPos + 1,
+ opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+ if (indexRemap) {
+ auto remapOp =
+ builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+ // Remapped indices.
+ for (auto *index : remapOp->getOperation()->getResults())
+ operands.push_back(cast<MLValue>(index));
+ } else {
+ // No remapping specified.
+ for (auto *index : indices)
+ operands.push_back(cast<MLValue>(index));
+ }
+
+ // Insert the remaining operands unmodified.
+ operands.insert(operands.end(),
+ opStmt->operand_begin() + memRefOperandPos + 1 +
+ oldMemRefRank,
+ opStmt->operand_end());
+
+ // Result types don't change. Both memref's are of the same elemental type.
+ SmallVector<Type *, 8> resultTypes;
+ resultTypes.reserve(opStmt->getNumResults());
+ for (const auto *result : opStmt->getResults())
+ resultTypes.push_back(result->getType());
+
+ // Create the new operation.
+ auto *repOp =
+ builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs());
+ // Replace all uses of the old op's results with the new op's results.
+ unsigned r = 0;
+ for (auto *res : opStmt->getResults()) {
+ res->replaceAllUsesWith(repOp->getResult(r++));
+ }
+ opStmt->eraseFromBlock();
+ }
+ return true;
+}
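
The indexRemap path is unused by the pipelining pass; a hypothetical sketch of
what it supports, with an invented tiling map purely for illustration (the
remap's input count must equal the old rank, and its result count plus the
number of extra indices must equal the new rank, per the asserts above):

    // oldMemRef : memref<256xf32>, newMemRef : memref<8x32xf32>, no extra
    // indices, indexRemap = (d0) -> (d0 floordiv 32, d0 mod 32).
    // A dereferencing use
    //   %v = load %old[%i] : memref<256xf32>
    // is rewritten to apply the remap and index the new memref:
    //   %r = affine_apply (d0) -> (d0 floordiv 32, d0 mod 32) (%i)
    //   %v = load %new[%r#0, %r#1] : memref<8x32xf32>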
diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir
index a866e0474c1..7bef1e04c2c 100644
--- a/mlir/test/Transforms/pipeline-data-transfer.mlir
+++ b/mlir/test/Transforms/pipeline-data-transfer.mlir
@@ -1,79 +1,66 @@
// RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
+// CHECK: #map0 = (d0) -> (d0 mod 2)
+// CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
+// CHECK-NEXT: mlfunc @loop_nest_dma() {
+// CHECK-NEXT: %c8 = constant 8 : affineint
+// CHECK-NEXT: %c0 = constant 0 : affineint
+// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
+// CHECK-NEXT: %1 = alloc() : memref<2x32xf32>
+// CHECK-NEXT: %2 = alloc() : memref<256xf32, (d0) -> (d0)>
+// CHECK-NEXT: %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
+// CHECK-NEXT: %4 = alloc() : memref<1xf32>
+// CHECK-NEXT: %c0_0 = constant 0 : affineint
+// CHECK-NEXT: %c128 = constant 128 : affineint
+// CHECK-NEXT: %5 = affine_apply #map0(%c0)
+// CHECK-NEXT: %6 = affine_apply #map0(%c0)
+// CHECK-NEXT: "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT: for %i0 = 1 to 7 {
+// CHECK-NEXT: %7 = affine_apply #map0(%i0)
+// CHECK-NEXT: %8 = affine_apply #map0(%i0)
+// CHECK-NEXT: "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT: %9 = affine_apply #map1(%i0)
+// CHECK-NEXT: %10 = affine_apply #map0(%9)
+// CHECK-NEXT: %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT: %12 = affine_apply #map0(%9)
+// CHECK-NEXT: %13 = load %1[%12, %9] : memref<2x32xf32>
+// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
+// CHECK-NEXT: %15 = affine_apply #map0(%9)
+// CHECK-NEXT: store %14, %1[%15, %9] : memref<2x32xf32>
+// CHECK-NEXT: for %i1 = 0 to 127 {
+// CHECK-NEXT: "do_more_compute"(%9, %i1) : (affineint, affineint) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: %16 = affine_apply #map1(%c8)
+// CHECK-NEXT: %17 = affine_apply #map0(%16)
+// CHECK-NEXT: %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT: %19 = affine_apply #map0(%16)
+// CHECK-NEXT: %20 = load %1[%19, %16] : memref<2x32xf32>
+// CHECK-NEXT: %21 = "compute"(%20) : (f32) -> f32
+// CHECK-NEXT: %22 = affine_apply #map0(%16)
+// CHECK-NEXT: store %21, %1[%22, %16] : memref<2x32xf32>
+// CHECK-NEXT: for %i2 = 0 to 127 {
+// CHECK-NEXT: "do_more_compute"(%16, %i2) : (affineint, affineint) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+mlfunc @loop_nest_dma() {
-// CHECK-LABEL: mlfunc @loop_nest_simple() {
-// CHECK: %c8 = constant 8 : affineint
-// CHECK-NEXT: %c0 = constant 0 : affineint
-// CHECK-NEXT: %0 = "foo"(%c0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = 1 to 7 {
-// CHECK-NEXT: %1 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = affine_apply #map0(%i0)
-// CHECK-NEXT: %3 = "bar"(%2) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %4 = affine_apply #map0(%c8)
-// CHECK-NEXT: %5 = "bar"(%4) : (affineint) -> affineint
-// CHECK-NEXT: return
-mlfunc @loop_nest_simple() {
- for %i = 0 to 7 {
- %y = "foo"(%i) : (affineint) -> affineint
- %x = "bar"(%i) : (affineint) -> affineint
- }
- return
-}
+ %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+ %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
-// CHECK-LABEL: mlfunc @loop_nest_dma() {
-// CHECK: %c8 = constant 8 : affineint
-// CHECK-NEXT: %c0 = constant 0 : affineint
-// CHECK-NEXT: %0 = affine_apply #map1(%c0)
-// CHECK-NEXT: %1 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = 1 to 7 {
-// CHECK-NEXT: %3 = affine_apply #map1(%i0)
-// CHECK-NEXT: %4 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT: %5 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT: %6 = affine_apply #map0(%i0)
-// CHECK-NEXT: %7 = affine_apply #map1(%6)
-// CHECK-NEXT: %8 = "dma.wait"(%7) : (affineint) -> affineint
-// CHECK-NEXT: %9 = "compute1"(%7) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %10 = affine_apply #map0(%c8)
-// CHECK-NEXT: %11 = affine_apply #map1(%10)
-// CHECK-NEXT: %12 = "dma.wait"(%11) : (affineint) -> affineint
-// CHECK-NEXT: %13 = "compute1"(%11) : (affineint) -> affineint
-// CHECK-NEXT: return
-mlfunc @loop_nest_dma() {
- for %i = 0 to 7 {
- %pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
- "dma.enqueue"(%pingpong) : (affineint) -> affineint
- "dma.enqueue"(%pingpong) : (affineint) -> affineint
- %pongping = affine_apply (d0) -> (d0 mod 2) (%i)
- "dma.wait"(%pongping) : (affineint) -> affineint
- "compute1"(%pongping) : (affineint) -> affineint
- }
- return
-}
+ %tag = alloc() : memref<1 x f32>
+
+ %zero = constant 0 : affineint
+ %size = constant 128 : affineint
-// CHECK-LABEL: mlfunc @loop_nest_bound_map(%arg0 : affineint) {
-// CHECK: %0 = affine_apply #map2()[%arg0]
-// CHECK-NEXT: %1 = "foo"(%0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = "bar"(%0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = #map3()[%arg0] to #map4()[%arg0] {
-// CHECK-NEXT: %3 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %4 = "bar"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %5 = affine_apply #map0(%i0)
-// CHECK-NEXT: %6 = "foo_bar"(%5) : (affineint) -> affineint
-// CHECK-NEXT: %7 = "bar_foo"(%5) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %8 = affine_apply #map5()[%arg0]
-// CHECK-NEXT: %9 = affine_apply #map0(%8)
-// CHECK-NEXT: %10 = "foo_bar"(%9) : (affineint) -> affineint
-// CHECK-NEXT: %11 = "bar_foo"(%9) : (affineint) -> affineint
-// CHECK-NEXT: return
-mlfunc @loop_nest_bound_map(%N : affineint) {
- for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
- "foo"(%i) : (affineint) -> affineint
- "bar"(%i) : (affineint) -> affineint
- "foo_bar"(%i) : (affineint) -> (affineint)
- "bar_foo"(%i) : (affineint) -> (affineint)
+ for %i = 0 to 7 {
+ "dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, affineint, memref<32 x f32, (d0)->(d0), 1>, affineint, affineint, memref<1 x f32>, affineint) -> ()
+ "dma.finish"(%tag, %zero) : (memref<1 x f32>, affineint) -> affineint
+ %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+ %r = "compute"(%v) : (f32) -> (f32)
+ store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+ for %j = 0 to 127 {
+ "do_more_compute"(%i, %j) : (affineint, affineint) -> ()
+ }
}
return
}