| author | Uday Bondhugula <bondhugula@google.com> | 2018-10-04 17:15:30 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:23:19 -0700 |
| commit | 6cfdb756b165ebd32068506fc70d623f86bb80b3 | |
| tree | 2ac173f35c2b0b7913a243d0b5eed17c78c5a09b /mlir | |
| parent | b55b4076011419c8d8d8cac58c8fda7631067bb2 | |
Introduce memref replacement/rewrite support: replace an existing memref with a
new one (of a potentially different rank/shape), with an optional index
remapping.
- introduce Utils::replaceAllMemRefUsesWith
- use this for DMA double buffering
(This CL also adds a few temporary utilities / code that will be done away with
once:
1) abstract DMA ops are added,
2) a memref dereferencing side-effect / trait is available on ops, and
3) b/117159533 is resolved (memref index computation slices).)
PiperOrigin-RevId: 215831373
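
The replacement utility described above rewrites every dereferencing use of the old memref: any extra indices are prepended, and the old indices are either passed through unchanged or run through the remapping map, so that the number of extra indices plus the (possibly remapped) old index count equals the new memref's rank. The following standalone C++ sketch is not MLIR code; it only illustrates that index-rewriting rule, using the double-buffering case from this CL (an extra leading `i mod 2` index) as the example.

```cpp
// Standalone sketch (illustrative names, not the MLIR API) of the index
// rewriting rule applied at each dereferencing use of the old memref.
#include <cassert>
#include <functional>
#include <vector>

using Indices = std::vector<long>;
// Optional remapping, standing in for the AffineMap 'indexRemap' argument.
using IndexRemap = std::function<Indices(const Indices &)>;

Indices rewriteIndices(const Indices &oldIndices, const Indices &extraIndices,
                       const IndexRemap &indexRemap, unsigned newRank) {
  Indices result = extraIndices;  // extra indices are added at the start
  Indices mapped = indexRemap ? indexRemap(oldIndices) : oldIndices;
  result.insert(result.end(), mapped.begin(), mapped.end());
  assert(result.size() == newRank && "rank mismatch after rewriting");
  return result;
}

int main() {
  // Double buffering: memref<32xf32> -> memref<2x32xf32>, extra index i mod 2.
  long i = 5;
  Indices newIdx = rewriteIndices(/*oldIndices=*/{i},
                                  /*extraIndices=*/{i % 2},
                                  /*indexRemap=*/nullptr, /*newRank=*/2);
  assert(newIdx[0] == 1 && newIdx[1] == 5);  // A[5] becomes A_db[1][5]
  return 0;
}
```

When no `indexRemap` is supplied, this matches the assertion in the utility that `oldMemRefRank + extraIndices.size() == newMemRefRank`.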
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/IR/Builders.h | 12 |
| -rw-r--r-- | mlir/include/mlir/Transforms/Utils.h | 49 |
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 12 |
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 235 |
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 165 |
| -rw-r--r-- | mlir/test/Transforms/pipeline-data-transfer.mlir | 131 |
6 files changed, 518 insertions, 86 deletions
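
For context on the schedule the updated PipelineDataTransfer pass produces: DMA start statements keep a shift of 0, while DMA finishes and compute statements are shifted by one iteration, and the buffer fed by the DMA is doubled with a leading dimension of 2 selected by `i mod 2`. The hypothetical host-side C++ below sketches that overlap on plain arrays; `dmaStart`, `dmaFinish`, and `compute` are stand-ins for the corresponding ops, not MLIR APIs.

```cpp
// Sketch of double-buffered pipelining: iteration i's DMA fills one half of
// the doubled buffer while iteration i-1's data is computed from the other.
#include <array>
#include <cstdio>

constexpr int kTileSize = 4;
using Tile = std::array<float, kTileSize>;

void dmaStart(const float *src, Tile &dst) {  // stand-in for "dma.in.start"
  for (int j = 0; j < kTileSize; ++j) dst[j] = src[j];
}
void dmaFinish() {}                           // stand-in for "dma.finish"
float compute(const Tile &t) {
  float s = 0;
  for (float v : t) s += v;
  return s;
}

int main() {
  float A[8 * kTileSize];
  for (int k = 0; k < 8 * kTileSize; ++k) A[k] = static_cast<float>(k);

  Tile buf[2];                       // doubled buffer, leading dimension of 2
  dmaStart(&A[0], buf[0]);           // prologue: iteration 0's DMA
  for (int i = 1; i < 8; ++i) {
    dmaStart(&A[i * kTileSize], buf[i % 2]);  // shift 0: DMA for iteration i
    dmaFinish();                              // shift 1: wait on iteration i-1
    std::printf("%f\n", compute(buf[(i - 1) % 2]));
  }
  dmaFinish();                       // epilogue: finish and compute iteration 7
  std::printf("%f\n", compute(buf[7 % 2]));
  return 0;
}
```

The prologue/steady-state/epilogue shape above is what the transformed test case below checks for: one DMA hoisted before the loop, the loop running from 1 to 7, and the last wait-plus-compute emitted after the loop.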
```diff
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index a37b4b2d138..334422f05c8 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -343,15 +343,21 @@ public:
     return MLFuncBuilder(forStmt, forStmt->end());
   }
 
-  /// Get the current insertion point of the builder.
+  /// Returns the current insertion point of the builder.
   StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
 
-  /// Get the current block of the builder.
+  /// Returns the current block of the builder.
   StmtBlock *getBlock() const { return block; }
 
-  /// Create an operation given the fields represented as an OperationState.
+  /// Creates an operation given the fields represented as an OperationState.
   OperationStmt *createOperation(const OperationState &state);
 
+  /// Creates an operation given the fields.
+  OperationStmt *createOperation(Location *location, Identifier name,
+                                 ArrayRef<MLValue *> operands,
+                                 ArrayRef<Type *> types,
+                                 ArrayRef<NamedAttribute> attrs);
+
   /// Create operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpPointer<OpTy> create(Location *location, Args... args) {
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
new file mode 100644
index 00000000000..64e9b72b2ad
--- /dev/null
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -0,0 +1,49 @@
+//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_UTILS_H
+#define MLIR_TRANSFORMS_UTILS_H
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLValue;
+class SSAValue;
+
+/// Replace all uses of oldMemRef with newMemRef while optionally remapping the
+/// old memref's indices using the supplied affine map and adding any additional
+/// indices. The new memref could be of a different shape or rank. Returns true
+/// on success and false if the replacement is not possible (whenever a memref
+/// is used as an operand in a non-dereferencing scenario).
+/// Additional indices are added at the start.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+                              llvm::ArrayRef<SSAValue *> extraIndices,
+                              AffineMap *indexRemap = nullptr);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_UTILS_H
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 8ca0fe8318f..deac801f822 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -312,6 +312,18 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
   return op;
 }
 
+/// Create an operation given the fields.
+OperationStmt *MLFuncBuilder::createOperation(Location *location,
+                                              Identifier name,
+                                              ArrayRef<MLValue *> operands,
+                                              ArrayRef<Type *> types,
+                                              ArrayRef<NamedAttribute> attrs) {
+  auto *op = OperationStmt::create(location, name, operands, types, attrs,
+                                   getContext());
+  block->getStatements().insert(insertPoint, op);
+  return op;
+}
+
 ForStmt *MLFuncBuilder::createFor(Location *location,
                                   ArrayRef<MLValue *> lbOperands,
                                   AffineMap *lbMap,
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 96d706e98ac..5aaac1c6c29 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -21,10 +21,13 @@
 
 #include "mlir/Transforms/Passes.h"
 
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
 
 using namespace mlir;
 
@@ -43,27 +46,237 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
   return new PipelineDataTransfer();
 }
 
-// For testing purposes, this just runs on the first statement of the MLFunction
-// if that statement is a for stmt, and shifts the second half of its body by
-// one.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
+// op traits for it are added. TODO(b/117228571)
+static bool isDmaStartStmt(const OperationStmt &stmt) {
+  return stmt.getName().strref().contains("dma.in.start") ||
+         stmt.getName().strref().contains("dma.out.start");
+}
+
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static bool isDmaFinishStmt(const OperationStmt &stmt) {
+  return stmt.getName().strref().contains("dma.finish");
+}
+
+/// Given a DMA start operation, returns the operand position of either the
+/// source or destination memref depending on the one that is at the higher
+/// level of the memory hierarchy.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
+  assert(isDmaStartStmt(dmaStartStmt));
+  unsigned srcDmaPos = 0;
+  unsigned destDmaPos =
+      cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+
+  if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
+          ->getMemorySpace() >
+      cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
+          ->getMemorySpace())
+    return srcDmaPos;
+  return destDmaPos;
+}
+
+// Returns the position of the tag memref operand given a DMA statement.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+  assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
+  if (isDmaStartStmt(dmaStmt)) {
+    // Second to last operand.
+    return dmaStmt.getNumOperands() - 2;
+  }
+  // First operand for a dma finish statement.
+  return 0;
+}
+
+/// Doubles the buffer of the supplied memref.
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+  MLFuncBuilder bInner(forStmt, forStmt->begin());
+  bInner.setInsertionPoint(forStmt, forStmt->begin());
+
+  // Doubles the shape with a leading dimension extent of 2.
+  auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+    // Add the leading dimension in the shape for the double buffer.
+    ArrayRef<int> shape = origMemRefType->getShape();
+    SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
+    shapeSizes.insert(shapeSizes.begin(), 2);
+
+    auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+    return newMemRefType;
+  };
+
+  auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+
+  // Create and place the alloc at the top level.
+  auto *func = forStmt->getFunction();
+  MLFuncBuilder topBuilder(func, func->begin());
+  auto *newMemRef = cast<MLValue>(
+      topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
+          ->getResult());
+
+  auto d0 = bInner.getDimExpr(0);
+  auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
+  auto ivModTwoOp =
+      bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
+    return false;
+  // We don't need ivMod2Op any more - this is cloned by
+  // replaceAllMemRefUsesWith wherever the memref replacement happens. Once
+  // b/117159533 is addressed, we'll eventually only need to pass
+  // ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
+  cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
+  return true;
+}
+
+// For testing purposes, this just runs on the first for statement of an
+// MLFunction at the top level.
+// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
+// the other TODOs listed inside are dealt with.
 PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
   if (f->empty())
     return PassResult::Success;
-  auto *forStmt = dyn_cast<ForStmt>(&f->front());
+
+  ForStmt *forStmt = nullptr;
+  for (auto &stmt : *f) {
+    if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
+      break;
+    }
+  }
   if (!forStmt)
-    return PassResult::Failure;
+    return PassResult::Success;
 
   unsigned numStmts = forStmt->getStatements().size();
+
   if (numStmts == 0)
     return PassResult::Success;
 
-  std::vector<uint64_t> delays(numStmts);
-  for (unsigned i = 0; i < numStmts; i++)
-    delays[i] = (i < numStmts / 2) ? 0 : 1;
+  SmallVector<OperationStmt *, 4> dmaStartStmts;
+  SmallVector<OperationStmt *, 4> dmaFinishStmts;
+  for (auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    if (!opStmt)
+      continue;
+    if (isDmaStartStmt(*opStmt)) {
+      dmaStartStmts.push_back(opStmt);
+    } else if (isDmaFinishStmt(*opStmt)) {
+      dmaFinishStmts.push_back(opStmt);
+    }
+  }
+
+  // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
+  // subscript check utilities). Assume for now that start/finish are matched in
+  // the order they appear.
+  if (dmaStartStmts.size() != dmaFinishStmts.size())
+    return PassResult::Failure;
+
+  // Double the buffers for the higher memory space memref's.
+  // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
+  // memref.
+  // TODO(bondhugula): check whether double-buffering is even necessary.
+  // TODO(bondhugula): make this work with different layouts: assuming here that
+  // the dimension we are adding here for the double buffering is the outermost
+  // dimension.
+  // Identify memref's to replace by scanning through all DMA start statements.
+  // A DMA start statement has two memref's - the one from the higher level of
+  // memory hierarchy is the one to double buffer.
+  for (auto *dmaStartStmt : dmaStartStmts) {
+    MLValue *oldMemRef = cast<MLValue>(
+        dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+    if (!doubleBuffer(oldMemRef, forStmt))
+      return PassResult::Failure;
+  }
+
+  // Double the buffers for tag memref's.
+  for (auto *dmaFinishStmt : dmaFinishStmts) {
+    MLValue *oldTagMemRef = cast<MLValue>(
+        dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
+    if (!doubleBuffer(oldTagMemRef, forStmt))
+      return PassResult::Failure;
+  }
+
+  // Collect all compute ops.
+  std::vector<const Statement *> computeOps;
+  computeOps.reserve(forStmt->getStatements().size());
+  // Store delay for statement for later lookup for AffineApplyOp's.
+  DenseMap<const Statement *, unsigned> opDelayMap;
+  for (const auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    if (!opStmt) {
+      // All for and if stmt's are treated as pure compute operations.
+      // TODO(bondhugula): check whether such statements do not have any DMAs
+      // nested within.
+      opDelayMap[&stmt] = 1;
+    } else if (isDmaStartStmt(*opStmt)) {
+      // DMA starts are not shifted.
+      opDelayMap[&stmt] = 0;
+    } else if (isDmaFinishStmt(*opStmt)) {
+      // DMA finish op shifted by one.
+      opDelayMap[&stmt] = 1;
+    } else if (!opStmt->is<AffineApplyOp>()) {
+      // Compute op shifted by one.
+      opDelayMap[&stmt] = 1;
+      computeOps.push_back(&stmt);
+    }
+    // Shifts for affine apply op's determined later.
+  }
+
+  // Get the ancestor of a 'stmt' that lies in forStmt's block.
+  auto getAncestorInForBlock =
+      [&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
+    // Traverse up the statement hierarchy starting from the owner of operand to
+    // find the ancestor statement that resides in the block of 'forStmt'.
+    while (stmt != nullptr && stmt->getBlock() != &block) {
+      stmt = stmt->getParentStmt();
+    }
+    return stmt;
+  };
+
+  // Determine delays for affine apply op's: look up delay from its consumer op.
+  // This code will be thrown away once we have a way to obtain indices through
+  // a composed affine_apply op. See TODO(b/117159533). Such a composed
+  // affine_apply will be used exclusively by a given memref dereferencing op.
+  for (const auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    // Skip statements that aren't affine apply ops.
+    if (!opStmt || !opStmt->is<AffineApplyOp>())
+      continue;
+    // Traverse uses of each result of the affine apply op.
+    for (auto *res : opStmt->getResults()) {
+      for (auto &use : res->getUses()) {
+        auto *ancestorInForBlock =
+            getAncestorInForBlock(use.getOwner(), *forStmt);
+        assert(ancestorInForBlock &&
+               "traversing parent should reach forStmt block");
+        auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
+        if (!opCheck || opCheck->is<AffineApplyOp>())
+          continue;
+        assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
+        if (opDelayMap.find(&stmt) != opDelayMap.end()) {
+          // This is where we enforce all uses of this affine_apply to have
+          // the same shifts - so that we know what shift to use for the
+          // affine_apply to preserve semantics.
+          assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
+        } else {
+          // Obtain delay from its consumer.
+          opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
+        }
+      }
+    }
+  }
+
+  // Get delays stored in map.
+  std::vector<uint64_t> delays(forStmt->getStatements().size());
+  unsigned s = 0;
+  for (const auto &stmt : *forStmt) {
+    delays[s++] = opDelayMap[&stmt];
+  }
 
-  if (!checkDominancePreservationOnShift(*forStmt, delays))
+  if (!checkDominancePreservationOnShift(*forStmt, delays)) {
     // Violates SSA dominance.
     return PassResult::Failure;
+  }
 
   if (stmtBodySkew(forStmt, delays))
     return PassResult::Failure;
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
new file mode 100644
index 00000000000..d189ac26703
--- /dev/null
+++ b/mlir/lib/Transforms/Utils.cpp
@@ -0,0 +1,165 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+  if (op.is<LoadOp>() || op.is<StoreOp>() ||
+      op.getName().strref().contains("dma.in.start") ||
+      op.getName().strref().contains("dma.out.start") ||
+      op.getName().strref().contains("dma.finish")) {
+    return true;
+  }
+  return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+                                    ArrayRef<SSAValue *> extraIndices,
+                                    AffineMap *indexRemap) {
+  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+  if (indexRemap) {
+    assert(indexRemap->getNumInputs() == oldMemRefRank);
+    assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+         cast<MemRefType>(newMemRef->getType())->getElementType());
+
+  // Check if memref was used in a non-dereferencing context.
+  for (const StmtOperand &use : oldMemRef->getUses()) {
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    // Failure: memref used in a non-dereferencing op (potentially escapes); no
+    // replacement in these cases.
+    if (!isMemRefDereferencingOp(*opStmt))
+      return false;
+  }
+
+  // Walk all uses of old memref. Statement using the memref gets replaced.
+  for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+    StmtOperand &use = *(it++);
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    assert(isMemRefDereferencingOp(*opStmt) &&
+           "memref dereferencing op expected");
+
+    auto getMemRefOperandPos = [&]() -> unsigned {
+      unsigned i;
+      for (i = 0; i < opStmt->getNumOperands(); i++) {
+        if (opStmt->getOperand(i) == oldMemRef)
+          break;
+      }
+      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+      return i;
+    };
+    unsigned memRefOperandPos = getMemRefOperandPos();
+
+    // Construct the new operation statement using this memref.
+    SmallVector<MLValue *, 8> operands;
+    operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+    // Insert the non-memref operands.
+    operands.insert(operands.end(), opStmt->operand_begin(),
+                    opStmt->operand_begin() + memRefOperandPos);
+    operands.push_back(newMemRef);
+
+    MLFuncBuilder builder(opStmt);
+    // Normally, we could just use extraIndices as operands, but we will
+    // clone it so that each op gets its own "private" index. See b/117159533.
+    for (auto *extraIndex : extraIndices) {
+      OperationStmt::OperandMapTy operandMap;
+      // TODO(mlir-team): An operation/SSA value should provide a method to
+      // return the position of an SSA result in its defining
+      // operation.
+      assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+             "single result op's expected to generate these indices");
+      // TODO: actually check if this is a result of an affine_apply op.
+      assert((cast<MLValue>(extraIndex)->isValidDim() ||
+              cast<MLValue>(extraIndex)->isValidSymbol()) &&
+             "invalid memory op index");
+      auto *clonedExtraIndex =
+          cast<OperationStmt>(
+              builder.clone(*extraIndex->getDefiningStmt(), operandMap))
+              ->getResult(0);
+      operands.push_back(cast<MLValue>(clonedExtraIndex));
+    }
+
+    // Construct new indices. The indices of a memref come right after it, i.e.,
+    // at position memRefOperandPos + 1.
+    SmallVector<SSAValue *, 4> indices(
+        opStmt->operand_begin() + memRefOperandPos + 1,
+        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+    if (indexRemap) {
+      auto remapOp =
+          builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+      // Remapped indices.
+      for (auto *index : remapOp->getOperation()->getResults())
+        operands.push_back(cast<MLValue>(index));
+    } else {
+      // No remapping specified.
+      for (auto *index : indices)
+        operands.push_back(cast<MLValue>(index));
+    }
+
+    // Insert the remaining operands unmodified.
+    operands.insert(operands.end(),
+                    opStmt->operand_begin() + memRefOperandPos + 1 +
+                        oldMemRefRank,
+                    opStmt->operand_end());
+
+    // Result types don't change. Both memref's are of the same elemental type.
+    SmallVector<Type *, 8> resultTypes;
+    resultTypes.reserve(opStmt->getNumResults());
+    for (const auto *result : opStmt->getResults())
+      resultTypes.push_back(result->getType());
+
+    // Create the new operation.
+    auto *repOp =
+        builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+                                resultTypes, opStmt->getAttrs());
+    // Replace old memref's dereferencing op's uses.
+    unsigned r = 0;
+    for (auto *res : opStmt->getResults()) {
+      res->replaceAllUsesWith(repOp->getResult(r++));
+    }
+    opStmt->eraseFromBlock();
+  }
+  return true;
+}
diff --git a/mlir/test/Transforms/pipeline-data-transfer.mlir b/mlir/test/Transforms/pipeline-data-transfer.mlir
index a866e0474c1..7bef1e04c2c 100644
--- a/mlir/test/Transforms/pipeline-data-transfer.mlir
+++ b/mlir/test/Transforms/pipeline-data-transfer.mlir
@@ -1,79 +1,66 @@
 // RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
 
+// CHECK: #map0 = (d0) -> (d0 mod 2)
+// CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
+// CHECK-NEXT: mlfunc @loop_nest_dma() {
+// CHECK-NEXT:   %c8 = constant 8 : affineint
+// CHECK-NEXT:   %c0 = constant 0 : affineint
+// CHECK-NEXT:   %0 = alloc() : memref<2x1xf32>
+// CHECK-NEXT:   %1 = alloc() : memref<2x32xf32>
+// CHECK-NEXT:   %2 = alloc() : memref<256xf32, (d0) -> (d0)>
+// CHECK-NEXT:   %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
+// CHECK-NEXT:   %4 = alloc() : memref<1xf32>
+// CHECK-NEXT:   %c0_0 = constant 0 : affineint
+// CHECK-NEXT:   %c128 = constant 128 : affineint
+// CHECK-NEXT:   %5 = affine_apply #map0(%c0)
+// CHECK-NEXT:   %6 = affine_apply #map0(%c0)
+// CHECK-NEXT:   "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT:   for %i0 = 1 to 7 {
+// CHECK-NEXT:     %7 = affine_apply #map0(%i0)
+// CHECK-NEXT:     %8 = affine_apply #map0(%i0)
+// CHECK-NEXT:     "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT:     %9 = affine_apply #map1(%i0)
+// CHECK-NEXT:     %10 = affine_apply #map0(%9)
+// CHECK-NEXT:     %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT:     %12 = affine_apply #map0(%9)
+// CHECK-NEXT:     %13 = load %1[%12, %9] : memref<2x32xf32>
+// CHECK-NEXT:     %14 = "compute"(%13) : (f32) -> f32
+// CHECK-NEXT:     %15 = affine_apply #map0(%9)
+// CHECK-NEXT:     store %14, %1[%15, %9] : memref<2x32xf32>
+// CHECK-NEXT:     for %i1 = 0 to 127 {
+// CHECK-NEXT:       "do_more_compute"(%9, %i1) : (affineint, affineint) -> ()
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %16 = affine_apply #map1(%c8)
+// CHECK-NEXT:   %17 = affine_apply #map0(%16)
+// CHECK-NEXT:   %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT:   %19 = affine_apply #map0(%16)
+// CHECK-NEXT:   %20 = load %1[%19, %16] : memref<2x32xf32>
+// CHECK-NEXT:   %21 = "compute"(%20) : (f32) -> f32
+// CHECK-NEXT:   %22 = affine_apply #map0(%16)
+// CHECK-NEXT:   store %21, %1[%22, %16] : memref<2x32xf32>
+// CHECK-NEXT:   for %i2 = 0 to 127 {
+// CHECK-NEXT:     "do_more_compute"(%16, %i2) : (affineint, affineint) -> ()
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+mlfunc @loop_nest_dma() {
-// CHECK-LABEL: mlfunc @loop_nest_simple() {
-// CHECK:       %c8 = constant 8 : affineint
-// CHECK-NEXT:  %c0 = constant 0 : affineint
-// CHECK-NEXT:  %0 = "foo"(%c0) : (affineint) -> affineint
-// CHECK-NEXT:  for %i0 = 1 to 7 {
-// CHECK-NEXT:    %1 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT:    %2 = affine_apply #map0(%i0)
-// CHECK-NEXT:    %3 = "bar"(%2) : (affineint) -> affineint
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %4 = affine_apply #map0(%c8)
-// CHECK-NEXT:  %5 = "bar"(%4) : (affineint) -> affineint
-// CHECK-NEXT:  return
-mlfunc @loop_nest_simple() {
-  for %i = 0 to 7 {
-    %y = "foo"(%i) : (affineint) -> affineint
-    %x = "bar"(%i) : (affineint) -> affineint
-  }
-  return
-}
+  %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+  %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
 
-// CHECK-LABEL: mlfunc @loop_nest_dma() {
-// CHECK:       %c8 = constant 8 : affineint
-// CHECK-NEXT:  %c0 = constant 0 : affineint
-// CHECK-NEXT:  %0 = affine_apply #map1(%c0)
-// CHECK-NEXT:  %1 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT:  %2 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT:  for %i0 = 1 to 7 {
-// CHECK-NEXT:    %3 = affine_apply #map1(%i0)
-// CHECK-NEXT:    %4 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT:    %5 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT:    %6 = affine_apply #map0(%i0)
-// CHECK-NEXT:    %7 = affine_apply #map1(%6)
-// CHECK-NEXT:    %8 = "dma.wait"(%7) : (affineint) -> affineint
-// CHECK-NEXT:    %9 = "compute1"(%7) : (affineint) -> affineint
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %10 = affine_apply #map0(%c8)
-// CHECK-NEXT:  %11 = affine_apply #map1(%10)
-// CHECK-NEXT:  %12 = "dma.wait"(%11) : (affineint) -> affineint
-// CHECK-NEXT:  %13 = "compute1"(%11) : (affineint) -> affineint
-// CHECK-NEXT:  return
-mlfunc @loop_nest_dma() {
-  for %i = 0 to 7 {
-    %pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
-    "dma.enqueue"(%pingpong) : (affineint) -> affineint
-    "dma.enqueue"(%pingpong) : (affineint) -> affineint
-    %pongping = affine_apply (d0) -> (d0 mod 2) (%i)
-    "dma.wait"(%pongping) : (affineint) -> affineint
-    "compute1"(%pongping) : (affineint) -> affineint
-  }
-  return
-}
+  %tag = alloc() : memref<1 x f32>
+
+  %zero = constant 0 : affineint
+  %size = constant 128 : affineint
 
-// CHECK-LABEL: mlfunc @loop_nest_bound_map(%arg0 : affineint) {
-// CHECK:       %0 = affine_apply #map2()[%arg0]
-// CHECK-NEXT:  %1 = "foo"(%0) : (affineint) -> affineint
-// CHECK-NEXT:  %2 = "bar"(%0) : (affineint) -> affineint
-// CHECK-NEXT:  for %i0 = #map3()[%arg0] to #map4()[%arg0] {
-// CHECK-NEXT:    %3 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT:    %4 = "bar"(%i0) : (affineint) -> affineint
-// CHECK-NEXT:    %5 = affine_apply #map0(%i0)
-// CHECK-NEXT:    %6 = "foo_bar"(%5) : (affineint) -> affineint
-// CHECK-NEXT:    %7 = "bar_foo"(%5) : (affineint) -> affineint
-// CHECK-NEXT:  }
-// CHECK-NEXT:  %8 = affine_apply #map5()[%arg0]
-// CHECK-NEXT:  %9 = affine_apply #map0(%8)
-// CHECK-NEXT:  %10 = "foo_bar"(%9) : (affineint) -> affineint
-// CHECK-NEXT:  %11 = "bar_foo"(%9) : (affineint) -> affineint
-// CHECK-NEXT:  return
-mlfunc @loop_nest_bound_map(%N : affineint) {
-  for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
-    "foo"(%i) : (affineint) -> affineint
-    "bar"(%i) : (affineint) -> affineint
-    "foo_bar"(%i) : (affineint) -> (affineint)
-    "bar_foo"(%i) : (affineint) -> (affineint)
+  for %i = 0 to 7 {
+    "dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, affineint, memref<32 x f32, (d0)->(d0), 1>, affineint, affineint, memref<1 x f32>, affineint) -> ()
+    "dma.finish"(%tag, %zero) : (memref<1 x f32>, affineint) -> affineint
+    %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+    %r = "compute"(%v) : (f32) -> (f32)
+    store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+    for %j = 0 to 127 {
+      "do_more_compute"(%i, %j) : (affineint, affineint) -> ()
+    }
   }
   return
 }
```

