diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopInvariantCodeMotion.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopInvariantCodeMotion.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index c4c1184fa82..48e97f44436 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -70,7 +70,7 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar, static bool isMemRefDereferencingOp(Operation &op) { // TODO(asabne): Support DMA Ops. - if (isa<LoadOp>(op) || isa<StoreOp>(op)) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { return true; } return false; @@ -94,23 +94,25 @@ bool isOpLoopInvariant(Operation &op, Value *indVar, // If the body of a predicated region has a for loop, we don't hoist the // 'affine.if'. return false; - } else if (isa<DmaStartOp>(op) || isa<DmaWaitOp>(op)) { + } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) { // TODO(asabne): Support DMA ops. return false; } else if (!isa<ConstantOp>(op)) { if (isMemRefDereferencingOp(op)) { - Value *memref = isa<LoadOp>(op) ? cast<LoadOp>(op).getMemRef() - : cast<StoreOp>(op).getMemRef(); + Value *memref = isa<AffineLoadOp>(op) + ? cast<AffineLoadOp>(op).getMemRef() + : cast<AffineStoreOp>(op).getMemRef(); for (auto *user : memref->getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. - if (isa<DmaStartOp>(op) || isa<DmaWaitOp>(op)) { + if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) { return false; } // If the memref used by the load/store is used in a store elsewhere in // the loop nest, we do not hoist. Similarly, if the memref used in a // load is also being stored too, we do not hoist the load. - if (isa<StoreOp>(user) || (isa<LoadOp>(user) && isa<StoreOp>(op))) { + if (isa<AffineStoreOp>(user) || + (isa<AffineLoadOp>(user) && isa<AffineStoreOp>(op))) { if (&op != user) { SmallVector<AffineForOp, 8> userIVs; getLoopIVs(*user, &userIVs); |

