diff options
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/CSE.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerAffine.cpp | 59 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Transforms/SimplifyAffineStructures.cpp | 37 |
6 files changed, 79 insertions, 67 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b79..c2e1636626d 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,6 +188,16 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast<ForInst>(i).getBody()); break; } + case Instruction::Kind::If: { + auto &ifInst = cast<IfInst>(i); + if (auto *elseBlock = ifInst.getElse()) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(elseBlock); + } + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(ifInst.getThen()); + break; + } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index eebbbe9daa7..cee0a08a63c 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,7 +19,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -100,16 +99,16 @@ public: SmallVector<ForInst *, 4> forInsts; SmallVector<OperationInst *, 4> loadOpInsts; SmallVector<OperationInst *, 4> storeOpInsts; - bool hasNonForRegion = false; + bool hasIfInst = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + void visitIfInst(IfInst *ifInst) { hasIfInst = true; } + void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) - hasNonForRegion = true; - else if (opInst->isa<LoadOp>()) + if (opInst->isa<LoadOp>()) loadOpInsts.push_back(opInst); - else if (opInst->isa<StoreOp>()) + if (opInst->isa<StoreOp>()) storeOpInsts.push_back(opInst); } }; @@ -411,8 +410,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if a non 'for' region was found (not currently supported). - if (collector.hasNonForRegion) + // Return false if IfInsts are found (not currently supported). + if (collector.hasIfInst) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -435,18 +434,19 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } + if (auto storeOp = opInst->dyn_cast<StoreOp>()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; } } + // Return false if IfInsts are found (not currently supported). + if (isa<IfInst>(&inst)) + return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d..39ef758833b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,6 +119,15 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } + bool walkIfInstPostOrder(IfInst *ifInst) { + bool hasInnerLoops = + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) + hasInnerLoops |= + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); + return hasInnerLoops; + } + bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f519..ab37ff63bad 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -247,7 +246,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerAffineIf(AffineIfOp *ifOp); + bool lowerIfInst(IfInst *ifInst); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -410,7 +409,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | <code before the AffineIfOp> | +// | <code before the IfInst> | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -454,11 +453,10 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | <code after the AffineIfOp> | +// | <code after the IfInst> | // +--------------------------------+ // -bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { - auto *ifInst = ifOp->getInstruction(); +bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -468,38 +466,22 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the AffineIfOp and add a + // continue blocks. Move the instructions over from the IfInst and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - // If the 'then' block is not empty, then splice the instructions. - auto &oldThenBlocks = ifOp->getThenBlocks(); - if (!oldThenBlocks.empty()) { - // We currently only handle one 'then' block. - if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) - return true; - - Block *oldThen = &oldThenBlocks.front(); - - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); - } - + auto *oldThen = ifInst->getThen(); + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); FuncBuilder builder(thenBlock); builder.create<BranchOp>(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - auto &oldElseBlocks = ifOp->getElseBlocks(); - if (!oldElseBlocks.empty()) { - // We currently only handle one 'else' block. - if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) - return true; - - auto *oldElse = &oldElseBlocks.front(); + if (auto *oldElse = ifInst->getElse()) { elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -511,7 +493,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifOp->getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -611,30 +593,29 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector<Instruction *, 8> instsToRewrite; - // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. - // We do this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the If and For instructions as well as AffineApplyOps. We do + // this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<ForInst>(inst)) + if (isa<IfInst>(inst) || isa<ForInst>(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast<OperationInst>(inst); - if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) + if (op && op->isa<AffineApplyOp>()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast<ForInst>(inst)) { + if (auto *ifInst = dyn_cast<IfInst>(inst)) { + if (lowerIfInst(ifInst)) + return failure(); + } else if (auto *forInst = dyn_cast<ForInst>(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast<OperationInst>(inst); - if (auto ifOp = op->dyn_cast<AffineIfOp>()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + if (lowerAffineApply(op->cast<AffineApplyOp>())) return failure(); - } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2744b1d624c..09d961f85cd 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -560,6 +559,9 @@ static bool instantiateMaterialization(Instruction *inst, if (isa<ForInst>(inst)) return inst->emitError("NYI path ForInst"); + if (isa<IfInst>(inst)) + return inst->emitError("NYI path IfInst"); + // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast<OperationInst>(inst); @@ -568,9 +570,6 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa<AffineApplyOp>()) { return false; } - if (opInst->getNumBlockLists() != 0) - return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c700..bd39e47786a 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,6 +28,7 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; +using llvm::report_fatal_error; namespace { @@ -41,6 +42,9 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); + static char passID; }; @@ -62,19 +66,28 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkOps([&](OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) { - auto simplified = simplifyIntegerSet(setAttr.getValue()); - opInst->setAttr(attr.first, IntegerSetAttr::get(simplified)); - } +void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { + auto set = ifInst->getCondition().getIntegerSet(); + ifInst->setIntegerSet(simplifyIntegerSet(set)); +} + +void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); } + } +} + +PassResult SimplifyAffineStructures::runOnFunction(Function *f) { + f->walkInsts([&](Instruction *inst) { + if (auto *opInst = dyn_cast<OperationInst>(inst)) + visitOperationInst(opInst); + if (auto *ifInst = dyn_cast<IfInst>(inst)) + visitIfInst(ifInst); }); return success(); |

