diff options
| -rw-r--r-- | mlir/include/mlir/IR/StandardOps.h | 121 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardOps.cpp | 135 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 47 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 7 | ||||
| -rw-r--r-- | mlir/test/Transforms/pipeline-data-transfer.mlir | 38 | ||||
| -rw-r--r-- | mlir/tools/mlir-opt/mlir-opt.cpp | 3 |
7 files changed, 295 insertions, 63 deletions
diff --git a/mlir/include/mlir/IR/StandardOps.h b/mlir/include/mlir/IR/StandardOps.h index 9aa6875f13a..018d75a0cbe 100644 --- a/mlir/include/mlir/IR/StandardOps.h +++ b/mlir/include/mlir/IR/StandardOps.h @@ -366,6 +366,127 @@ private: explicit DimOp(const Operation *state) : Op(state) {} }; +// DmaStartOp starts a non-blocking DMA operation that transfers data from a +// source memref to a destination memref. The source and destination memref need +// not be of the same dimensionality, but need to have the same elemental type. +// The operands include the source and destination memref's each followed by its +// indices, size of the data transfer in terms of the number of elements (of the +// elemental type of the memref), and a tag memref with its indices. The tag +// location is used by a DmaWaitOp to check for completion. The indices of the +// source memref, destination memref, and the tag memref have the same +// restrictions as any load/store in MLFunctions. +// +// For example, a DmaStartOp operation that transfers one 8x128xf32 +// (%size = 1024) chunk of data from memref '%src' in HBM (memory space 0) +// at indices [%i, %j] to memref '%dst' in VMEM (memory space 2) at +// indices [%k, %l], would be specified as follows: +// +// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> +// %idx = constant 0 : index +// dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] : +// memref<40 x 8 x vector<8x128xf32>, (d0) -> (d0), 0>, +// memref<2 x 4 x vector<8x128xf32>, (d0) -> (d0), 2>, +// memref<1 x i32>, (d0) -> (d0), 4> +// +// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. +class DmaStartOp + : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { +public: + // Returns the source MemRefType for this DMA operation. + const SSAValue *getSrcMemRef() const { return getOperand(0); } + // Returns the rank (number of indices) of the source MemRefType. + unsigned getSrcMemRefRank() const { + return cast<MemRefType>(getSrcMemRef()->getType())->getRank(); + } + // Returns the source memerf indices for this DMA operation. + llvm::iterator_range<Operation::const_operand_iterator> + getSrcIndices() const { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; + } + + // Returns the destination MemRefType for this DMA operations. + const SSAValue *getDstMemRef() const { + return getOperand(1 + getSrcMemRefRank()); + } + // Returns the rank (number of indices) of the destination MemRefType. + unsigned getDstMemRefRank() const { + return cast<MemRefType>(getDstMemRef()->getType())->getRank(); + } + unsigned getSrcMemorySpace() const { + return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace(); + } + unsigned getDstMemorySpace() const { + return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace(); + } + + // Returns the destination memref indices for this DMA operation. + llvm::iterator_range<Operation::const_operand_iterator> + getDstIndices() const { + return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, + getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank()}; + } + + // Returns the number of elements being transferred by this DMA operation. + const SSAValue *getNumElements() const { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); + } + + // Returns the Tag MemRef for this DMA operation. + const SSAValue *getTagMemRef() const { + return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); + } + // Returns the tag memref index for this DMA operation. + llvm::iterator_range<Operation::const_operand_iterator> + getTagIndices() const { + return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + + getDstMemRefRank() + 1 + 1, + getOperation()->operand_end()}; + } + + static StringRef getOperationName() { return "dma_start"; } + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + +protected: + friend class ::mlir::Operation; + explicit DmaStartOp(const Operation *state) : Op(state) {} +}; + +// DmaWaitOp blocks until the completion of a DMA operation associated with the +// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index +// with the same restrictions as any load/store index in MLFunctions. For +// example: +// +// dma_start %src[%i, %j], %dst[%k, %l], %tag[%index] : +// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>, +// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2> +// memref<1 x i32>, (d0) -> (d0), 4> +// ... +// ... +// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4> +// +class DmaWaitOp + : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { +public: + static StringRef getOperationName() { return "dma_wait"; } + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + + // Returns the Tag MemRef associated with the DMA operation being waited on. + const SSAValue *getTagMemRef() const { return getOperand(0); } + // Returns the tag memref index for this DMA operation. + llvm::iterator_range<Operation::const_operand_iterator> + getTagIndices() const { + return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; + } + +protected: + friend class ::mlir::Operation; + explicit DmaWaitOp(const Operation *state) : Op(state) {} +}; + /// The "extract_element" op reads a tensor or vector and returns one element /// from it specified by an index list. The output of extract is a new value /// with the same type as the elements of the tensor or vector. The arity of diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 81a6f3f6d34..e6d50f06c5f 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -239,8 +239,11 @@ bool OpTrait::impl::verifyOneOperand(const Operation *op) { } bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { - if (op->getNumOperands() != numOperands) - return op->emitOpError("expected " + Twine(numOperands) + " operands"); + if (op->getNumOperands() != numOperands) { + return op->emitOpError("expected " + Twine(numOperands) + + " operands, but found " + + Twine(op->getNumOperands())); + } return false; } diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp index 355d8507b1e..c7c32e5210c 100644 --- a/mlir/lib/IR/StandardOps.cpp +++ b/mlir/lib/IR/StandardOps.cpp @@ -612,6 +612,133 @@ Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands, return nullptr; } +// --------------------------------------------------------------------------- +// DmaStartOp +// --------------------------------------------------------------------------- + +void DmaStartOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << ' ' << *getSrcMemRef() << '['; + p->printOperands(getSrcIndices()); + *p << "], " << *getDstMemRef() << '['; + p->printOperands(getDstIndices()); + *p << "], " << *getNumElements(); + *p << ", " << *getTagMemRef() << '['; + p->printOperands(getTagIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << *getSrcMemRef()->getType(); + *p << ", " << *getDstMemRef()->getType(); + *p << ", " << *getTagMemRef()->getType(); +} + +// Parse DmaStartOp. +// EX: +// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, +// %tag[%index] : +// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>, +// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>, +// memref<1 x i32, (d0) -> (d0), 4> +// +bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; + OpAsmParser::OperandType dstMemRefInfo; + SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; + OpAsmParser::OperandType numElementsInfo; + OpAsmParser::OperandType tagMemrefInfo; + SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; + + SmallVector<Type *, 3> types; + auto *indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) source memref followed by its indices (in square brackets). + // *) destination memref followed by its indices (in square brackets). + // *) dma size in KiB. + if (parser->parseOperand(srcMemRefInfo) || + parser->parseOperandList(srcIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || + parser->parseOperandList(dstIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseComma() || parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseColonTypeList(types)) + return true; + + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcIndexInfos, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstIndexInfos, indexType, result->operands) || + // size should be an index. + parser->resolveOperand(numElementsInfo, indexType, result->operands) || + parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || + // tag indices should be index. + parser->resolveOperands(tagIndexInfos, indexType, result->operands)) + return true; + + // Check that source/destination index list size matches associated rank. + if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || + dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) + return parser->emitError(parser->getNameLoc(), + "memref rank not equal to indices count"); + + if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + // These should be verified in verify(). TODO(b/116737205). + if (tagIndexInfos.size() != 1) + return parser->emitError(parser->getNameLoc(), + "only 1-d tag memref supported"); + + return false; +} + +// --------------------------------------------------------------------------- +// DmaWaitOp +// --------------------------------------------------------------------------- +// Parse DmaWaitOp. +// Eg: +// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4> +// +void DmaWaitOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << ' '; + // Print operands. + p->printOperand(getTagMemRef()); + *p << '['; + p->printOperands(getTagIndices()); + *p << ']'; + *p << " : " << *getTagMemRef()->getType(); +} + +bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType tagMemrefInfo; + SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; + Type *type; + auto *indexType = parser->getBuilder().getIndexType(); + + // Parse tag memref and index. + if (parser->parseOperand(tagMemrefInfo) || + parser->parseOperandList(tagIndexInfos, -1, + OpAsmParser::Delimiter::Square) || + parser->parseColonType(type) || + parser->resolveOperand(tagMemrefInfo, type, result->operands) || + parser->resolveOperands(tagIndexInfos, indexType, result->operands)) + return true; + + if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) + return parser->emitError(parser->getNameLoc(), + "tag memref rank not equal to indices count"); + + return false; +} + //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// @@ -796,7 +923,7 @@ bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { void ReturnOp::print(OpAsmPrinter *p) const { *p << "return"; if (getNumOperands() > 0) { - *p << " "; + *p << ' '; p->printOperands(operand_begin(), operand_end()); *p << " : "; interleave(operand_begin(), operand_end(), @@ -982,8 +1109,8 @@ Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands, /// Install the standard operations in the specified operation set. void mlir::registerStandardOperations(OperationSet &opSet) { opSet.addOperations<AddFOp, AddIOp, AffineApplyOp, AllocOp, CallOp, - CallIndirectOp, ConstantOp, DeallocOp, DimOp, - ExtractElementOp, LoadOp, MulFOp, MulIOp, ReturnOp, - ShapeCastOp, StoreOp, SubFOp, SubIOp>( + CallIndirectOp, ConstantOp, DeallocOp, DimOp, DmaStartOp, + DmaWaitOp, ExtractElementOp, LoadOp, MulFOp, MulIOp, + ReturnOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>( /*prefix=*/""); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 1eeb9a9aa5c..0d025f5678f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or -// op traits for it are added. TODO(b/117228571) -static bool isDmaStartStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.in.start") || - stmt.getName().strref().contains("dma.out.start"); -} - -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are -// added. TODO(b/117228571) -static bool isDmaFinishStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.finish"); -} - /// Given a DMA start operation, returns the operand position of either the /// source or destination memref depending on the one that is at the higher /// level of the memory hierarchy. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { - assert(isDmaStartStmt(dmaStartStmt)); +static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) { unsigned srcDmaPos = 0; - unsigned destDmaPos = - cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1; + unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1; - if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType()) - ->getMemorySpace() > - cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType()) - ->getMemorySpace()) + if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace()) return srcDmaPos; return destDmaPos; } @@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { - assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt)); - if (isDmaStartStmt(dmaStmt)) { +static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { + assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>()); + if (dmaStmt.is<DmaStartOp>()) { // Second to last operand. return dmaStmt.getNumOperands() - 2; } @@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { return 0; } -/// Doubles the buffer of the supplied memref. +/// Doubles the buffer of the supplied memref while replacing all uses of the +/// old memref. Returns false if such a replacement cannot be performed. static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); @@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { return true; } -// For testing purposes, this just runs on the first for statement of an +// For testing purposes, this just runs on the first 'for' statement of an // MLFunction at the top level. // TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when // the other TODOs listed inside are dealt with. @@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; - if (isDmaStartStmt(*opStmt)) { + if (opStmt->is<DmaStartOp>()) { dmaStartStmts.push_back(opStmt); - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { dmaFinishStmts.push_back(opStmt); } } @@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // A DMA start statement has two memref's - the one from the higher level of // memory hierarchy is the one to double buffer. for (auto *dmaStartStmt : dmaStartStmts) { - MLValue *oldMemRef = cast<MLValue>( - dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt))); + MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( + getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>()))); if (!doubleBuffer(oldMemRef, forStmt)) return PassResult::Failure; } @@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // TODO(bondhugula): check whether such statements do not have any DMAs // nested within. opDelayMap[&stmt] = 1; - } else if (isDmaStartStmt(*opStmt)) { + } else if (opStmt->is<DmaStartOp>()) { // DMA starts are not shifted. opDelayMap[&stmt] = 0; - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { // DMA finish op shifted by one. opDelayMap[&stmt] = 1; } else if (!opStmt->is<AffineApplyOp>()) { diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp index cc1c7973858..0262eb94bd7 100644 --- a/mlir/lib/Transforms/Utils.cpp +++ b/mlir/lib/Transforms/Utils.cpp @@ -33,12 +33,9 @@ using namespace mlir; // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(const Operation &op) { - if (op.is<LoadOp>() || op.is<StoreOp>() || - op.getName().strref().contains("dma.in.start") || - op.getName().strref().contains("dma.out.start") || - op.getName().strref().contains("dma.finish")) { + if (op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() || + op.is<DmaWaitOp>()) return true; - } return false; } diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir index c061c090f93..13abd43a2dc 100644 --- a/mlir/test/Transforms/pipeline-data-transfer.mlir +++ b/mlir/test/Transforms/pipeline-data-transfer.mlir @@ -13,33 +13,33 @@ // CHECK-NEXT: %c128 = constant 128 : index // CHECK-NEXT: %5 = affine_apply #map0(%c0) // CHECK-NEXT: %6 = affine_apply #map0(%c0) -// CHECK-NEXT: "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, index, memref<2x32xf32>, index, index, index, memref<2x1xf32>, index, index) -> () +// CHECK-NEXT: dma_start %2[%c0], %1[%5, %c0], %c128, %0[%6, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32> // CHECK-NEXT: for %i0 = 1 to 7 { // CHECK-NEXT: %7 = affine_apply #map0(%i0) // CHECK-NEXT: %8 = affine_apply #map0(%i0) -// CHECK-NEXT: "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, index, memref<2x32xf32>, index, index, index, memref<2x1xf32>, index, index) -> () +// CHECK-NEXT: dma_start %2[%i0], %1[%7, %i0], %c128, %0[%8, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32> // CHECK-NEXT: %9 = affine_apply #map1(%i0) // CHECK-NEXT: %10 = affine_apply #map0(%9) -// CHECK-NEXT: %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, index, index) -> index -// CHECK-NEXT: %12 = affine_apply #map0(%9) -// CHECK-NEXT: %13 = load %1[%12, %9] : memref<2x32xf32> -// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32 -// CHECK-NEXT: %15 = affine_apply #map0(%9) -// CHECK-NEXT: store %14, %1[%15, %9] : memref<2x32xf32> +// CHECK-NEXT: dma_wait %0[%10, %c0_0] : memref<2x1xf32> +// CHECK-NEXT: %11 = affine_apply #map0(%9) +// CHECK-NEXT: %12 = load %1[%11, %9] : memref<2x32xf32> +// CHECK-NEXT: %13 = "compute"(%12) : (f32) -> f32 +// CHECK-NEXT: %14 = affine_apply #map0(%9) +// CHECK-NEXT: store %13, %1[%14, %9] : memref<2x32xf32> // CHECK-NEXT: for %i1 = 0 to 127 { // CHECK-NEXT: "do_more_compute"(%9, %i1) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: %16 = affine_apply #map1(%c8) -// CHECK-NEXT: %17 = affine_apply #map0(%16) -// CHECK-NEXT: %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, index, index) -> index -// CHECK-NEXT: %19 = affine_apply #map0(%16) -// CHECK-NEXT: %20 = load %1[%19, %16] : memref<2x32xf32> -// CHECK-NEXT: %21 = "compute"(%20) : (f32) -> f32 -// CHECK-NEXT: %22 = affine_apply #map0(%16) -// CHECK-NEXT: store %21, %1[%22, %16] : memref<2x32xf32> +// CHECK-NEXT: %15 = affine_apply #map1(%c8) +// CHECK-NEXT: %16 = affine_apply #map0(%15) +// CHECK-NEXT: dma_wait %0[%16, %c0_0] : memref<2x1xf32> +// CHECK-NEXT: %17 = affine_apply #map0(%15) +// CHECK-NEXT: %18 = load %1[%17, %15] : memref<2x32xf32> +// CHECK-NEXT: %19 = "compute"(%18) : (f32) -> f32 +// CHECK-NEXT: %20 = affine_apply #map0(%15) +// CHECK-NEXT: store %19, %1[%20, %15] : memref<2x32xf32> // CHECK-NEXT: for %i2 = 0 to 127 { -// CHECK-NEXT: "do_more_compute"(%16, %i2) : (index, index) -> () +// CHECK-NEXT: "do_more_compute"(%15, %i2) : (index, index) -> () // CHECK-NEXT: } // CHECK-NEXT: return mlfunc @loop_nest_dma() { @@ -53,8 +53,8 @@ mlfunc @loop_nest_dma() { %size = constant 128 : index for %i = 0 to 7 { - "dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, index, memref<32 x f32, (d0)->(d0), 1>, index, index, memref<1 x f32>, index) -> () - "dma.finish"(%tag, %zero) : (memref<1 x f32>, index) -> index + dma_start %A[%i], %Ah[%i], %size, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32> + dma_wait %tag[%zero] : memref<1 x f32> %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> %r = "compute"(%v) : (f32) -> (f32) store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1> diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 1b395ce8723..6299ec37e99 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -212,7 +212,8 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) { return OptFailure; // Verify that the result of the pass is still valid. - module->verify(); + if (module->verify()) + return OptFailure; } // Print the output. |

