summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/StandardOps.h121
-rw-r--r--mlir/lib/IR/Operation.cpp7
-rw-r--r--mlir/lib/IR/StandardOps.cpp135
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp47
-rw-r--r--mlir/lib/Transforms/Utils.cpp7
-rw-r--r--mlir/test/Transforms/pipeline-data-transfer.mlir38
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp3
7 files changed, 295 insertions, 63 deletions
diff --git a/mlir/include/mlir/IR/StandardOps.h b/mlir/include/mlir/IR/StandardOps.h
index 9aa6875f13a..018d75a0cbe 100644
--- a/mlir/include/mlir/IR/StandardOps.h
+++ b/mlir/include/mlir/IR/StandardOps.h
@@ -366,6 +366,127 @@ private:
explicit DimOp(const Operation *state) : Op(state) {}
};
+// DmaStartOp starts a non-blocking DMA operation that transfers data from a
+// source memref to a destination memref. The source and destination memref need
+// not be of the same dimensionality, but need to have the same elemental type.
+// The operands include the source and destination memref's each followed by its
+// indices, size of the data transfer in terms of the number of elements (of the
+// elemental type of the memref), and a tag memref with its indices. The tag
+// location is used by a DmaWaitOp to check for completion. The indices of the
+// source memref, destination memref, and the tag memref have the same
+// restrictions as any load/store in MLFunctions.
+//
+// For example, a DmaStartOp operation that transfers one 8x128xf32
+// (%size = 1024) chunk of data from memref '%src' in HBM (memory space 0)
+// at indices [%i, %j] to memref '%dst' in VMEM (memory space 2) at
+// indices [%k, %l], would be specified as follows:
+//
+// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
+// %idx = constant 0 : index
+// dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
+// memref<40 x 8 x vector<8x128xf32>, (d0) -> (d0), 0>,
+// memref<2 x 4 x vector<8x128xf32>, (d0) -> (d0), 2>,
+// memref<1 x i32>, (d0) -> (d0), 4>
+//
+// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
+class DmaStartOp
+ : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+ // Returns the source MemRefType for this DMA operation.
+ const SSAValue *getSrcMemRef() const { return getOperand(0); }
+ // Returns the rank (number of indices) of the source MemRefType.
+ unsigned getSrcMemRefRank() const {
+ return cast<MemRefType>(getSrcMemRef()->getType())->getRank();
+ }
+ // Returns the source memerf indices for this DMA operation.
+ llvm::iterator_range<Operation::const_operand_iterator>
+ getSrcIndices() const {
+ return {getOperation()->operand_begin() + 1,
+ getOperation()->operand_begin() + 1 + getSrcMemRefRank()};
+ }
+
+ // Returns the destination MemRefType for this DMA operations.
+ const SSAValue *getDstMemRef() const {
+ return getOperand(1 + getSrcMemRefRank());
+ }
+ // Returns the rank (number of indices) of the destination MemRefType.
+ unsigned getDstMemRefRank() const {
+ return cast<MemRefType>(getDstMemRef()->getType())->getRank();
+ }
+ unsigned getSrcMemorySpace() const {
+ return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace();
+ }
+ unsigned getDstMemorySpace() const {
+ return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace();
+ }
+
+ // Returns the destination memref indices for this DMA operation.
+ llvm::iterator_range<Operation::const_operand_iterator>
+ getDstIndices() const {
+ return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1,
+ getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+ getDstMemRefRank()};
+ }
+
+ // Returns the number of elements being transferred by this DMA operation.
+ const SSAValue *getNumElements() const {
+ return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
+ }
+
+ // Returns the Tag MemRef for this DMA operation.
+ const SSAValue *getTagMemRef() const {
+ return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+ }
+ // Returns the tag memref index for this DMA operation.
+ llvm::iterator_range<Operation::const_operand_iterator>
+ getTagIndices() const {
+ return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+ getDstMemRefRank() + 1 + 1,
+ getOperation()->operand_end()};
+ }
+
+ static StringRef getOperationName() { return "dma_start"; }
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p) const;
+
+protected:
+ friend class ::mlir::Operation;
+ explicit DmaStartOp(const Operation *state) : Op(state) {}
+};
+
+// DmaWaitOp blocks until the completion of a DMA operation associated with the
+// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
+// with the same restrictions as any load/store index in MLFunctions. For
+// example:
+//
+// dma_start %src[%i, %j], %dst[%k, %l], %tag[%index] :
+// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
+// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>
+// memref<1 x i32>, (d0) -> (d0), 4>
+// ...
+// ...
+// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
+//
+class DmaWaitOp
+ : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
+public:
+ static StringRef getOperationName() { return "dma_wait"; }
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p) const;
+
+ // Returns the Tag MemRef associated with the DMA operation being waited on.
+ const SSAValue *getTagMemRef() const { return getOperand(0); }
+ // Returns the tag memref index for this DMA operation.
+ llvm::iterator_range<Operation::const_operand_iterator>
+ getTagIndices() const {
+ return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+ }
+
+protected:
+ friend class ::mlir::Operation;
+ explicit DmaWaitOp(const Operation *state) : Op(state) {}
+};
+
/// The "extract_element" op reads a tensor or vector and returns one element
/// from it specified by an index list. The output of extract is a new value
/// with the same type as the elements of the tensor or vector. The arity of
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 81a6f3f6d34..e6d50f06c5f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -239,8 +239,11 @@ bool OpTrait::impl::verifyOneOperand(const Operation *op) {
}
bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) {
- if (op->getNumOperands() != numOperands)
- return op->emitOpError("expected " + Twine(numOperands) + " operands");
+ if (op->getNumOperands() != numOperands) {
+ return op->emitOpError("expected " + Twine(numOperands) +
+ " operands, but found " +
+ Twine(op->getNumOperands()));
+ }
return false;
}
diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp
index 355d8507b1e..c7c32e5210c 100644
--- a/mlir/lib/IR/StandardOps.cpp
+++ b/mlir/lib/IR/StandardOps.cpp
@@ -612,6 +612,133 @@ Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
return nullptr;
}
+// ---------------------------------------------------------------------------
+// DmaStartOp
+// ---------------------------------------------------------------------------
+
+void DmaStartOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName() << ' ' << *getSrcMemRef() << '[';
+ p->printOperands(getSrcIndices());
+ *p << "], " << *getDstMemRef() << '[';
+ p->printOperands(getDstIndices());
+ *p << "], " << *getNumElements();
+ *p << ", " << *getTagMemRef() << '[';
+ p->printOperands(getTagIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getSrcMemRef()->getType();
+ *p << ", " << *getDstMemRef()->getType();
+ *p << ", " << *getTagMemRef()->getType();
+}
+
+// Parse DmaStartOp.
+// EX:
+// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
+// %tag[%index] :
+// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
+// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>,
+// memref<1 x i32, (d0) -> (d0), 4>
+//
+bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType srcMemRefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
+ OpAsmParser::OperandType dstMemRefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
+ OpAsmParser::OperandType numElementsInfo;
+ OpAsmParser::OperandType tagMemrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
+
+ SmallVector<Type *, 3> types;
+ auto *indexType = parser->getBuilder().getIndexType();
+
+ // Parse and resolve the following list of operands:
+ // *) source memref followed by its indices (in square brackets).
+ // *) destination memref followed by its indices (in square brackets).
+ // *) dma size in KiB.
+ if (parser->parseOperand(srcMemRefInfo) ||
+ parser->parseOperandList(srcIndexInfos, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
+ parser->parseOperandList(dstIndexInfos, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseComma() || parser->parseOperand(numElementsInfo) ||
+ parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
+ parser->parseOperandList(tagIndexInfos, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseColonTypeList(types))
+ return true;
+
+ if (types.size() != 3)
+ return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
+
+ if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
+ parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
+ parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
+ parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
+ // size should be an index.
+ parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
+ parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
+ // tag indices should be index.
+ parser->resolveOperands(tagIndexInfos, indexType, result->operands))
+ return true;
+
+ // Check that source/destination index list size matches associated rank.
+ if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
+ dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
+ return parser->emitError(parser->getNameLoc(),
+ "memref rank not equal to indices count");
+
+ if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
+ return parser->emitError(parser->getNameLoc(),
+ "tag memref rank not equal to indices count");
+
+ // These should be verified in verify(). TODO(b/116737205).
+ if (tagIndexInfos.size() != 1)
+ return parser->emitError(parser->getNameLoc(),
+ "only 1-d tag memref supported");
+
+ return false;
+}
+
+// ---------------------------------------------------------------------------
+// DmaWaitOp
+// ---------------------------------------------------------------------------
+// Parse DmaWaitOp.
+// Eg:
+// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
+//
+void DmaWaitOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName() << ' ';
+ // Print operands.
+ p->printOperand(getTagMemRef());
+ *p << '[';
+ p->printOperands(getTagIndices());
+ *p << ']';
+ *p << " : " << *getTagMemRef()->getType();
+}
+
+bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType tagMemrefInfo;
+ SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
+ Type *type;
+ auto *indexType = parser->getBuilder().getIndexType();
+
+ // Parse tag memref and index.
+ if (parser->parseOperand(tagMemrefInfo) ||
+ parser->parseOperandList(tagIndexInfos, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
+ parser->resolveOperands(tagIndexInfos, indexType, result->operands))
+ return true;
+
+ if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
+ return parser->emitError(parser->getNameLoc(),
+ "tag memref rank not equal to indices count");
+
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
@@ -796,7 +923,7 @@ bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
void ReturnOp::print(OpAsmPrinter *p) const {
*p << "return";
if (getNumOperands() > 0) {
- *p << " ";
+ *p << ' ';
p->printOperands(operand_begin(), operand_end());
*p << " : ";
interleave(operand_begin(), operand_end(),
@@ -982,8 +1109,8 @@ Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AddIOp, AffineApplyOp, AllocOp, CallOp,
- CallIndirectOp, ConstantOp, DeallocOp, DimOp,
- ExtractElementOp, LoadOp, MulFOp, MulIOp, ReturnOp,
- ShapeCastOp, StoreOp, SubFOp, SubIOp>(
+ CallIndirectOp, ConstantOp, DeallocOp, DimOp, DmaStartOp,
+ DmaWaitOp, ExtractElementOp, LoadOp, MulFOp, MulIOp,
+ ReturnOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>(
/*prefix=*/"");
}
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 1eeb9a9aa5c..0d025f5678f 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
-// op traits for it are added. TODO(b/117228571)
-static bool isDmaStartStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.in.start") ||
- stmt.getName().strref().contains("dma.out.start");
-}
-
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static bool isDmaFinishStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.finish");
-}
-
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
- assert(isDmaStartStmt(dmaStartStmt));
+static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
unsigned srcDmaPos = 0;
- unsigned destDmaPos =
- cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+ unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
- if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
- ->getMemorySpace() >
- cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
- ->getMemorySpace())
+ if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
return srcDmaPos;
return destDmaPos;
}
@@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
- assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
- if (isDmaStartStmt(dmaStmt)) {
+static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+ assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>());
+ if (dmaStmt.is<DmaStartOp>()) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
}
@@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
return 0;
}
-/// Doubles the buffer of the supplied memref.
+/// Doubles the buffer of the supplied memref while replacing all uses of the
+/// old memref. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
@@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
return true;
}
-// For testing purposes, this just runs on the first for statement of an
+// For testing purposes, this just runs on the first 'for' statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
@@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
- if (isDmaStartStmt(*opStmt)) {
+ if (opStmt->is<DmaStartOp>()) {
dmaStartStmts.push_back(opStmt);
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
}
}
@@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
for (auto *dmaStartStmt : dmaStartStmts) {
- MLValue *oldMemRef = cast<MLValue>(
- dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+ MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+ getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
if (!doubleBuffer(oldMemRef, forStmt))
return PassResult::Failure;
}
@@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// TODO(bondhugula): check whether such statements do not have any DMAs
// nested within.
opDelayMap[&stmt] = 1;
- } else if (isDmaStartStmt(*opStmt)) {
+ } else if (opStmt->is<DmaStartOp>()) {
// DMA starts are not shifted.
opDelayMap[&stmt] = 0;
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
// DMA finish op shifted by one.
opDelayMap[&stmt] = 1;
} else if (!opStmt->is<AffineApplyOp>()) {
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
index cc1c7973858..0262eb94bd7 100644
--- a/mlir/lib/Transforms/Utils.cpp
+++ b/mlir/lib/Transforms/Utils.cpp
@@ -33,12 +33,9 @@ using namespace mlir;
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(const Operation &op) {
- if (op.is<LoadOp>() || op.is<StoreOp>() ||
- op.getName().strref().contains("dma.in.start") ||
- op.getName().strref().contains("dma.out.start") ||
- op.getName().strref().contains("dma.finish")) {
+ if (op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() ||
+ op.is<DmaWaitOp>())
return true;
- }
return false;
}
diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir
index c061c090f93..13abd43a2dc 100644
--- a/mlir/test/Transforms/pipeline-data-transfer.mlir
+++ b/mlir/test/Transforms/pipeline-data-transfer.mlir
@@ -13,33 +13,33 @@
// CHECK-NEXT: %c128 = constant 128 : index
// CHECK-NEXT: %5 = affine_apply #map0(%c0)
// CHECK-NEXT: %6 = affine_apply #map0(%c0)
-// CHECK-NEXT: "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, index, memref<2x32xf32>, index, index, index, memref<2x1xf32>, index, index) -> ()
+// CHECK-NEXT: dma_start %2[%c0], %1[%5, %c0], %c128, %0[%6, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
// CHECK-NEXT: for %i0 = 1 to 7 {
// CHECK-NEXT: %7 = affine_apply #map0(%i0)
// CHECK-NEXT: %8 = affine_apply #map0(%i0)
-// CHECK-NEXT: "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, index, memref<2x32xf32>, index, index, index, memref<2x1xf32>, index, index) -> ()
+// CHECK-NEXT: dma_start %2[%i0], %1[%7, %i0], %c128, %0[%8, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
// CHECK-NEXT: %9 = affine_apply #map1(%i0)
// CHECK-NEXT: %10 = affine_apply #map0(%9)
-// CHECK-NEXT: %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, index, index) -> index
-// CHECK-NEXT: %12 = affine_apply #map0(%9)
-// CHECK-NEXT: %13 = load %1[%12, %9] : memref<2x32xf32>
-// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
-// CHECK-NEXT: %15 = affine_apply #map0(%9)
-// CHECK-NEXT: store %14, %1[%15, %9] : memref<2x32xf32>
+// CHECK-NEXT: dma_wait %0[%10, %c0_0] : memref<2x1xf32>
+// CHECK-NEXT: %11 = affine_apply #map0(%9)
+// CHECK-NEXT: %12 = load %1[%11, %9] : memref<2x32xf32>
+// CHECK-NEXT: %13 = "compute"(%12) : (f32) -> f32
+// CHECK-NEXT: %14 = affine_apply #map0(%9)
+// CHECK-NEXT: store %13, %1[%14, %9] : memref<2x32xf32>
// CHECK-NEXT: for %i1 = 0 to 127 {
// CHECK-NEXT: "do_more_compute"(%9, %i1) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
-// CHECK-NEXT: %16 = affine_apply #map1(%c8)
-// CHECK-NEXT: %17 = affine_apply #map0(%16)
-// CHECK-NEXT: %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, index, index) -> index
-// CHECK-NEXT: %19 = affine_apply #map0(%16)
-// CHECK-NEXT: %20 = load %1[%19, %16] : memref<2x32xf32>
-// CHECK-NEXT: %21 = "compute"(%20) : (f32) -> f32
-// CHECK-NEXT: %22 = affine_apply #map0(%16)
-// CHECK-NEXT: store %21, %1[%22, %16] : memref<2x32xf32>
+// CHECK-NEXT: %15 = affine_apply #map1(%c8)
+// CHECK-NEXT: %16 = affine_apply #map0(%15)
+// CHECK-NEXT: dma_wait %0[%16, %c0_0] : memref<2x1xf32>
+// CHECK-NEXT: %17 = affine_apply #map0(%15)
+// CHECK-NEXT: %18 = load %1[%17, %15] : memref<2x32xf32>
+// CHECK-NEXT: %19 = "compute"(%18) : (f32) -> f32
+// CHECK-NEXT: %20 = affine_apply #map0(%15)
+// CHECK-NEXT: store %19, %1[%20, %15] : memref<2x32xf32>
// CHECK-NEXT: for %i2 = 0 to 127 {
-// CHECK-NEXT: "do_more_compute"(%16, %i2) : (index, index) -> ()
+// CHECK-NEXT: "do_more_compute"(%15, %i2) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
mlfunc @loop_nest_dma() {
@@ -53,8 +53,8 @@ mlfunc @loop_nest_dma() {
%size = constant 128 : index
for %i = 0 to 7 {
- "dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, index, memref<32 x f32, (d0)->(d0), 1>, index, index, memref<1 x f32>, index) -> ()
- "dma.finish"(%tag, %zero) : (memref<1 x f32>, index) -> index
+ dma_start %A[%i], %Ah[%i], %size, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>
+ dma_wait %tag[%zero] : memref<1 x f32>
%v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
%r = "compute"(%v) : (f32) -> (f32)
store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1b395ce8723..6299ec37e99 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -212,7 +212,8 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
return OptFailure;
// Verify that the result of the pass is still valid.
- module->verify();
+ if (module->verify())
+ return OptFailure;
}
// Print the output.
OpenPOWER on IntegriCloud