diff options
42 files changed, 545 insertions, 572 deletions
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 5c562e390af..3469133a8ba 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -33,8 +33,8 @@ list of [Functions](#functions), and there are two types of function definitions, a "[CFG Function](#cfg-functions)" and an "[ML Function](#ml-functions)". Both kinds of functions are represented as a composition of [operations](#operations), but represent control flow in -different ways: A CFG Function control flow using a CFG of -[BasicBlocks](#basic-blocks), which contain instructions and end with +different ways: A CFG Function control flow using a CFG of [Blocks](#blocks), +which contain instructions and end with [control flow terminator statements](#terminator-instructions) (like branches). ML Functions represents control flow with a nest of affine loops and if conditions, and are said to contain statements. Both types of functions can call @@ -65,7 +65,7 @@ Here's an example of an MLIR module: // result using a TensorFlow op. The dimensions of A and B are partially // known. The shapes are assumed to match. cfgfunc @mul(tensor<100x?xf32>, tensor<?x50xf32>) -> (tensor<100x50xf32>) { -// Basic block bb0. %A and %B come from function arguments. +// Block bb0. %A and %B come from function arguments. bb0(%A: tensor<100x?xf32>, %B: tensor<?x50xf32>): // Compute the inner dimension of %A using the dim operation. %n = dim %A, 1 : tensor<100x?xf32> @@ -606,9 +606,8 @@ function-type ::= type-list-parens `->` type-list MLIR supports first-class functions: the [`constant` operation](#'constant'-operation) produces the address of a function as an SSA value. This SSA value may be passed to and returned from functions, -merged across control flow boundaries with -[basic block arguments](#basic-blocks), and called with the -[`call_indirect` instruction](#'call_indirect'-operation). +merged across control flow boundaries with [block arguments](#blocks), and +called with the [`call_indirect` instruction](#'call_indirect'-operation). Function types are also used to indicate the arguments and results of [operations](#operations). @@ -916,7 +915,7 @@ Syntax: ``` {.ebnf} cfg-func ::= `cfgfunc` function-signature - (`attributes` attribute-dict)? `{` basic-block+ `}` + (`attributes` attribute-dict)? `{` block+ `}` ``` A simple CFG function that returns its argument twice looks like this: @@ -935,14 +934,14 @@ TensorFlow dataflow graph, where the instructions are TensorFlow "ops" producing values of Tensor type. It can also represent scalar math, and can be used as a way to lower [ML Functions](#ml-functions) before late code generation. -#### Basic Blocks {#basic-blocks} +#### Blocks {#blocks} Syntax: ``` {.ebnf} -basic-block ::= bb-label operation* terminator-stmt -bb-label ::= bb-id bb-arg-list? `:` -bb-id ::= bare-id +block ::= bb-label operation* terminator-stmt +bb-label ::= bb-id bb-arg-list? `:` +bb-id ::= bare-id ssa-id-and-type ::= ssa-id `:` type // Non-empty list of names and types. @@ -954,14 +953,14 @@ bb-arg-list ::= `(` ssa-id-and-type-list? `)` A [basic block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list of operation instructions without control flow (calls are not considered control flow for this purpose) that are executed from top to bottom. The last -instruction in a basic block is a -[terminator instruction](#terminator-instructions), which ends the block. +instruction in a block is a [terminator instruction](#terminator-instructions), +which ends the block. -Basic blocks in MLIR take a list of arguments, which represent SSA PHI nodes in -a functional notation. The arguments are defined by the block, and values are -provided for these basic block arguments by branches that go to the block. +Blocks in MLIR take a list of arguments, which represent SSA PHI nodes in a +functional notation. The arguments are defined by the block, and values are +provided for these block arguments by branches that go to the block. -Here is a simple example function showing branches, returns, and basic block +Here is a simple example function showing branches, returns, and block arguments: ```mlir {.mlir} @@ -987,13 +986,13 @@ bb4(%d : i64, %e : i64): } ``` -**Context:** The "basic block argument" representation eliminates a number of -special cases from the IR compared to traditional "PHI nodes are instructions" -SSA IRs (like LLVM). For example, the +**Context:** The "block argument" representation eliminates a number of special +cases from the IR compared to traditional "PHI nodes are instructions" SSA IRs +(like LLVM). For example, the [parallel copy semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) of SSA is immediately apparent, and function arguments are no longer a special case: they become arguments to the entry block -[[more rationale](Rationale.md#basic-block-arguments-vs-phi-nodes)]. +[[more rationale](Rationale.md#block-arguments-vs-phi-nodes)]. Control flow within a CFG function is implemented with unconditional branches, conditional branches, and a return statement. @@ -1014,9 +1013,9 @@ terminator-stmt ::= `br` bb-id branch-use-list? branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` ``` -The `br` terminator statement represents an unconditional jump to a target basic +The `br` terminator statement represents an unconditional jump to a target block. The count and types of operands to the branch must align with the -arguments in the target basic block. +arguments in the target block. The MLIR branch instruction is not allowed to target the entry block for a function. @@ -1040,7 +1039,7 @@ for a function. The two destinations of the conditional branch instruction are allowed to be the same. The following example illustrates a CFG function with a conditional branch -instruction that targets the same basic block: +instruction that targets the same block: ```mlir {.mlir} cfgfunc @select(%a : i32, %b :i32, %flag : i1) -> i32 { @@ -1318,8 +1317,8 @@ operation ::= ssa-id `=` `call_indirect` ssa-use The `call_indirect` operation represents an indirect call to a value of function type. Functions are first class types in MLIR, and may be passed as arguments -and merged together with basic block arguments. The operands and result types of -the call must match the specified function type. +and merged together with block arguments. The operands and result types of the +call must match the specified function type. Function values can be created with the [`constant` operation](#'constant'-operation). diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 791fe31fce9..17cbd1d15c1 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -171,15 +171,14 @@ type - memref<8x%Nxf32>. We went for the current approach in MLIR because it simplifies the design --- types remain immutable when the values of symbols change. -### Basic Block Arguments vs PHI nodes {#basic-block-arguments-vs-phi-nodes} +### Block Arguments vs PHI nodes {#block-arguments-vs-phi-nodes} -MLIR CFG Functions represent SSA using -"[basic block arguments](LangRef.md#basic-blocks)" rather than -[PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in LLVM. This -choice is representationally identical (the same constructs can be represented -in either form) but basic block arguments have several advantages: +MLIR CFG Functions represent SSA using "[block arguments](LangRef.md#blocks)" +rather than [PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in +LLVM. This choice is representationally identical (the same constructs can be +represented in either form) but block arguments have several advantages: -1. LLVM PHI nodes always have to be kept at the top of a basic block, and +1. LLVM PHI nodes always have to be kept at the top of a block, and transformations frequently have to manually skip over them. This is defined away with BB arguments. 1. LLVM has a separate function Argument node. This is defined away with BB @@ -202,7 +201,7 @@ in either form) but basic block arguments have several advantages: but SIL uses it extensively, e.g. in the [switch_enum instruction](https://github.com/apple/swift/blob/master/docs/SIL.rst#switch-enum). -For more context, basic block arguments were previously used in the Swift +For more context, block arguments were previously used in the Swift [SIL Intermediate Representation](https://github.com/apple/swift/blob/master/docs/SIL.rst), and described in [a talk on YouTube](https://www.youtube.com/watch?v=Ntj8ab-5cvE). The section of @@ -474,7 +473,7 @@ for (i=0; i <N; i++) // non-affine loop bound for k loop for (k=0; k<pow(2,j); k++) for (l=0; l<N; l++) { - // basic block loop body + // block loop body ... } ``` diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 6d3853e396d..172ddc743b4 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -21,15 +21,15 @@ #include "mlir/IR/FunctionGraphTraits.h" #include "llvm/Support/GenericDomTree.h" -extern template class llvm::DominatorTreeBase<mlir::BasicBlock, false>; -extern template class llvm::DominatorTreeBase<mlir::BasicBlock, true>; -extern template class llvm::DomTreeNodeBase<mlir::BasicBlock>; +extern template class llvm::DominatorTreeBase<mlir::Block, false>; +extern template class llvm::DominatorTreeBase<mlir::Block, true>; +extern template class llvm::DomTreeNodeBase<mlir::Block>; namespace llvm { namespace DomTreeBuilder { -using MLIRDomTree = llvm::DomTreeBase<mlir::BasicBlock>; -using MLIRPostDomTree = llvm::PostDomTreeBase<mlir::BasicBlock>; +using MLIRDomTree = llvm::DomTreeBase<mlir::Block>; +using MLIRPostDomTree = llvm::PostDomTreeBase<mlir::Block>; // extern template void Calculate<MLIRDomTree>(MLIRDomTree &DT); // extern template void Calculate<MLIRPostDomTree>(MLIRPostDomTree &DT); @@ -38,9 +38,9 @@ using MLIRPostDomTree = llvm::PostDomTreeBase<mlir::BasicBlock>; } // namespace llvm namespace mlir { -using DominatorTreeBase = llvm::DominatorTreeBase<BasicBlock, false>; -using PostDominatorTreeBase = llvm::DominatorTreeBase<BasicBlock, true>; -using DominanceInfoNode = llvm::DomTreeNodeBase<BasicBlock>; +using DominatorTreeBase = llvm::DominatorTreeBase<Block, false>; +using PostDominatorTreeBase = llvm::DominatorTreeBase<Block, true>; +using DominanceInfoNode = llvm::DomTreeNodeBase<Block>; /// A class for computing basic dominance information. class DominanceInfo : public DominatorTreeBase { diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/Block.h index 916834dfbdc..985d0fdb075 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/Block.h @@ -1,4 +1,4 @@ -//===- StmtBlock.h ----------------------------------------------*- C++ -*-===// +//===- Block.h - MLIR Block and BlockList Classes ---------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -15,53 +15,53 @@ // limitations under the License. // ============================================================================= // -// This file defines StmtBlock and *Stmt classes that extend Statement. +// This file defines Block and BlockList classes. // //===----------------------------------------------------------------------===// -#ifndef MLIR_IR_STMTBLOCK_H -#define MLIR_IR_STMTBLOCK_H +#ifndef MLIR_IR_BLOCK_H +#define MLIR_IR_BLOCK_H #include "mlir/IR/Statement.h" #include "llvm/ADT/PointerUnion.h" namespace mlir { class IfStmt; -class StmtBlockList; +class BlockList; template <typename BlockType> class PredecessorIterator; template <typename BlockType> class SuccessorIterator; -/// Blocks represents an ordered list of Instructions. -class StmtBlock - : public IRObjectWithUseList, - public llvm::ilist_node_with_parent<StmtBlock, StmtBlockList> { +/// `Block` represents an ordered list of `Instruction`s. +class Block : public IRObjectWithUseList, + public llvm::ilist_node_with_parent<Block, BlockList> { public: - explicit StmtBlock() {} - ~StmtBlock(); + explicit Block() {} + ~Block(); void clear() { - // Clear statements in the reverse order so that uses are destroyed + // Clear instructions in the reverse order so that uses are destroyed // before their defs. while (!empty()) - statements.pop_back(); + instructions.pop_back(); } - StmtBlockList *getParent() const { return parent; } + /// Blocks are maintained in a list by BlockList type. + BlockList *getParent() const { return parent; } - /// Returns the closest surrounding statement that contains this block or - /// nullptr if this is a top-level statement block. - Statement *getContainingStmt(); + /// Returns the closest surrounding instruction that contains this block or + /// nullptr if this is a top-level block. + Instruction *getContainingInst(); - const Statement *getContainingStmt() const { - return const_cast<StmtBlock *>(this)->getContainingStmt(); + const Instruction *getContainingInst() const { + return const_cast<Block *>(this)->getContainingInst(); } - /// Returns the function that this statement block is part of. The function - /// is determined by traversing the chain of parent statements. + /// Returns the function that this block is part of, even if the block is + /// nested under an IfStmt or ForStmt. Function *getFunction(); const Function *getFunction() const { - return const_cast<StmtBlock *>(this)->getFunction(); + return const_cast<Block *>(this)->getFunction(); } //===--------------------------------------------------------------------===// @@ -97,47 +97,46 @@ public: const BlockArgument *getArgument(unsigned i) const { return arguments[i]; } //===--------------------------------------------------------------------===// - // Statement list management + // Instruction list management //===--------------------------------------------------------------------===// - /// This is the list of statements in the block. - using StmtListType = llvm::iplist<Statement>; - StmtListType &getStatements() { return statements; } - const StmtListType &getStatements() const { return statements; } - - // Iteration over the statements in the block. - using iterator = StmtListType::iterator; - using const_iterator = StmtListType::const_iterator; - using reverse_iterator = StmtListType::reverse_iterator; - using const_reverse_iterator = StmtListType::const_reverse_iterator; - - iterator begin() { return statements.begin(); } - iterator end() { return statements.end(); } - const_iterator begin() const { return statements.begin(); } - const_iterator end() const { return statements.end(); } - reverse_iterator rbegin() { return statements.rbegin(); } - reverse_iterator rend() { return statements.rend(); } - const_reverse_iterator rbegin() const { return statements.rbegin(); } - const_reverse_iterator rend() const { return statements.rend(); } - - bool empty() const { return statements.empty(); } - void push_back(Statement *stmt) { statements.push_back(stmt); } - void push_front(Statement *stmt) { statements.push_front(stmt); } - - Statement &back() { return statements.back(); } - const Statement &back() const { - return const_cast<StmtBlock *>(this)->back(); - } - Statement &front() { return statements.front(); } - const Statement &front() const { - return const_cast<StmtBlock *>(this)->front(); + /// This is the list of instructions in the block. + using InstListType = llvm::iplist<Instruction>; + InstListType &getInstructions() { return instructions; } + const InstListType &getInstructions() const { return instructions; } + + // Iteration over the instructions in the block. + using iterator = InstListType::iterator; + using const_iterator = InstListType::const_iterator; + using reverse_iterator = InstListType::reverse_iterator; + using const_reverse_iterator = InstListType::const_reverse_iterator; + + iterator begin() { return instructions.begin(); } + iterator end() { return instructions.end(); } + const_iterator begin() const { return instructions.begin(); } + const_iterator end() const { return instructions.end(); } + reverse_iterator rbegin() { return instructions.rbegin(); } + reverse_iterator rend() { return instructions.rend(); } + const_reverse_iterator rbegin() const { return instructions.rbegin(); } + const_reverse_iterator rend() const { return instructions.rend(); } + + bool empty() const { return instructions.empty(); } + void push_back(Instruction *inst) { instructions.push_back(inst); } + void push_front(Instruction *inst) { instructions.push_front(inst); } + + Instruction &back() { return instructions.back(); } + const Instruction &back() const { return const_cast<Block *>(this)->back(); } + Instruction &front() { return instructions.front(); } + const Instruction &front() const { + return const_cast<Block *>(this)->front(); } - /// Returns the statement's position in this block or -1 if the statement is - /// not present. - int64_t findStmtPosInBlock(const Statement &stmt) const { + /// Returns the instructions's position in this block or -1 if the instruction + /// is not present. + /// TODO: This is needlessly inefficient, and should not be API on Block. + int64_t findInstPositionInBlock(const Instruction &stmt) const { int64_t j = 0; - for (const auto &s : statements) { + for (const auto &s : instructions) { if (&s == &stmt) return j; j++; @@ -145,12 +144,14 @@ public: return -1; } - /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the - /// ancestor statement of 'stmt' that lies in this block. Returns nullptr if + /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the + /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. - const Statement *findAncestorStmtInBlock(const Statement &stmt) const; - Statement *findAncestorStmtInBlock(Statement *stmt) { - return const_cast<Statement *>(findAncestorStmtInBlock(*stmt)); + /// TODO: This is very specific functionality that should live somewhere else. + const Instruction *findAncestorInstInBlock(const Instruction &inst) const; + /// TODO: This const overload is wrong. + Instruction *findAncestorInstInBlock(Instruction *inst) { + return const_cast<Instruction *>(findAncestorInstInBlock(*inst)); } //===--------------------------------------------------------------------===// @@ -162,7 +163,7 @@ public: OperationInst *getTerminator(); const OperationInst *getTerminator() const { - return const_cast<StmtBlock *>(this)->getTerminator(); + return const_cast<Block *>(this)->getTerminator(); } //===--------------------------------------------------------------------===// @@ -170,12 +171,12 @@ public: //===--------------------------------------------------------------------===// // Predecessor iteration. - using const_pred_iterator = PredecessorIterator<const StmtBlock>; + using const_pred_iterator = PredecessorIterator<const Block>; const_pred_iterator pred_begin() const; const_pred_iterator pred_end() const; llvm::iterator_range<const_pred_iterator> getPredecessors() const; - using pred_iterator = PredecessorIterator<StmtBlock>; + using pred_iterator = PredecessorIterator<Block>; pred_iterator pred_begin(); pred_iterator pred_end(); llvm::iterator_range<pred_iterator> getPredecessors(); @@ -189,26 +190,26 @@ public: /// Note that if a block has duplicate predecessors from a single block (e.g. /// if you have a conditional branch with the same block as the true/false /// destinations) is not considered to be a single predecessor. - StmtBlock *getSinglePredecessor(); + Block *getSinglePredecessor(); - const StmtBlock *getSinglePredecessor() const { - return const_cast<StmtBlock *>(this)->getSinglePredecessor(); + const Block *getSinglePredecessor() const { + return const_cast<Block *>(this)->getSinglePredecessor(); } // Indexed successor access. unsigned getNumSuccessors() const; - const StmtBlock *getSuccessor(unsigned i) const { - return const_cast<StmtBlock *>(this)->getSuccessor(i); + const Block *getSuccessor(unsigned i) const { + return const_cast<Block *>(this)->getSuccessor(i); } - StmtBlock *getSuccessor(unsigned i); + Block *getSuccessor(unsigned i); // Successor iteration. - using const_succ_iterator = SuccessorIterator<const StmtBlock>; + using const_succ_iterator = SuccessorIterator<const Block>; const_succ_iterator succ_begin() const; const_succ_iterator succ_end() const; llvm::iterator_range<const_succ_iterator> getSuccessors() const; - using succ_iterator = SuccessorIterator<StmtBlock>; + using succ_iterator = SuccessorIterator<Block>; succ_iterator succ_begin(); succ_iterator succ_end(); llvm::iterator_range<succ_iterator> getSuccessors(); @@ -226,18 +227,18 @@ public: /// Note that all instructions BEFORE the specified iterator stay as part of /// the original basic block, an unconditional branch is added to the original /// block (going to the new block), and the rest of the instructions in the - /// original block are moved to the new BB, including the old terminator. The - /// newly formed Block is returned. + /// original block are moved to the new block, including the old terminator. + /// The newly formed Block is returned. /// /// This function invalidates the specified iterator. - StmtBlock *splitBasicBlock(iterator splitBefore); - StmtBlock *splitBasicBlock(Instruction *splitBeforeInst) { - return splitBasicBlock(iterator(splitBeforeInst)); + Block *splitBlock(iterator splitBefore); + Block *splitBlock(Instruction *splitBeforeInst) { + return splitBlock(iterator(splitBeforeInst)); } - /// getSublistAccess() - Returns pointer to member of statement list - static StmtListType StmtBlock::*getSublistAccess(Statement *) { - return &StmtBlock::statements; + /// Returns pointer to member of instruction list. + static InstListType Block::*getSublistAccess(Instruction *) { + return &Block::instructions; } void print(raw_ostream &os) const; @@ -249,42 +250,41 @@ public: void printAsOperand(raw_ostream &os, bool printType = true); private: - /// This is the parent function/IfStmt/ForStmt that owns this block. - StmtBlockList *parent = nullptr; + /// This is the parent object that owns this block. + BlockList *parent = nullptr; - /// This is the list of statements in the block. - StmtListType statements; + /// This is the list of instructions in the block. + InstListType instructions; /// This is the list of arguments to the block. std::vector<BlockArgument *> arguments; - StmtBlock(const StmtBlock &) = delete; - void operator=(const StmtBlock &) = delete; + Block(const Block &) = delete; + void operator=(const Block &) = delete; - friend struct llvm::ilist_traits<StmtBlock>; + friend struct llvm::ilist_traits<Block>; }; } // end namespace mlir //===----------------------------------------------------------------------===// -// ilist_traits for StmtBlock +// ilist_traits for Block //===----------------------------------------------------------------------===// namespace llvm { template <> -struct ilist_traits<::mlir::StmtBlock> - : public ilist_alloc_traits<::mlir::StmtBlock> { - using StmtBlock = ::mlir::StmtBlock; - using block_iterator = simple_ilist<::mlir::StmtBlock>::iterator; - - void addNodeToList(StmtBlock *block); - void removeNodeFromList(StmtBlock *block); - void transferNodesFromList(ilist_traits<StmtBlock> &otherList, +struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> { + using Block = ::mlir::Block; + using block_iterator = simple_ilist<::mlir::Block>::iterator; + + void addNodeToList(Block *block); + void removeNodeFromList(Block *block); + void transferNodesFromList(ilist_traits<Block> &otherList, block_iterator first, block_iterator last); private: - mlir::StmtBlockList *getContainingBlockList(); + mlir::BlockList *getContainingBlockList(); }; } // end namespace llvm @@ -292,12 +292,12 @@ namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it /// is part of - a Function or IfStmt or ForStmt. -class StmtBlockList { +class BlockList { public: - explicit StmtBlockList(Function *container); - explicit StmtBlockList(Statement *container); + explicit BlockList(Function *container); + explicit BlockList(Instruction *container); - using BlockListType = llvm::iplist<StmtBlock>; + using BlockListType = llvm::iplist<Block>; BlockListType &getBlocks() { return blocks; } const BlockListType &getBlocks() const { return blocks; } @@ -317,50 +317,39 @@ public: const_reverse_iterator rend() const { return blocks.rend(); } bool empty() const { return blocks.empty(); } - void push_back(StmtBlock *block) { blocks.push_back(block); } - void push_front(StmtBlock *block) { blocks.push_front(block); } + void push_back(Block *block) { blocks.push_back(block); } + void push_front(Block *block) { blocks.push_front(block); } - StmtBlock &back() { return blocks.back(); } - const StmtBlock &back() const { - return const_cast<StmtBlockList *>(this)->back(); - } + Block &back() { return blocks.back(); } + const Block &back() const { return const_cast<BlockList *>(this)->back(); } - StmtBlock &front() { return blocks.front(); } - const StmtBlock &front() const { - return const_cast<StmtBlockList *>(this)->front(); - } + Block &front() { return blocks.front(); } + const Block &front() const { return const_cast<BlockList *>(this)->front(); } /// getSublistAccess() - Returns pointer to member of block list. - static BlockListType StmtBlockList::*getSublistAccess(StmtBlock *) { - return &StmtBlockList::blocks; + static BlockListType BlockList::*getSublistAccess(Block *) { + return &BlockList::blocks; } - /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is + /// A BlockList is part of a Function or and IfStmt/ForStmt. If it is /// part of an IfStmt/ForStmt, then return it, otherwise return null. - Statement *getContainingStmt(); - const Statement *getContainingStmt() const { - return const_cast<StmtBlockList *>(this)->getContainingStmt(); + Instruction *getContainingInst(); + const Instruction *getContainingInst() const { + return const_cast<BlockList *>(this)->getContainingInst(); } - /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is + /// A BlockList is part of a Function or and IfStmt/ForStmt. If it is /// part of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { - return const_cast<StmtBlockList *>(this)->getContainingFunction(); - } - - // TODO(clattner): This is only to help ML -> CFG migration, remove in the - // near future. This makes StmtBlockList work more like BasicBlock did. - Function *getFunction(); - const Function *getFunction() const { - return const_cast<StmtBlockList *>(this)->getFunction(); + return const_cast<BlockList *>(this)->getContainingFunction(); } private: BlockListType blocks; /// This is the object we are part of. - llvm::PointerUnion<Function *, Statement *> container; + llvm::PointerUnion<Function *, Instruction *> container; }; //===----------------------------------------------------------------------===// @@ -369,7 +358,7 @@ private: /// Implement a predecessor iterator as a forward iterator. This works by /// walking the use lists of the blocks. The entries on this list are the -/// StmtBlockOperands that are embedded into terminator instructions. From the +/// BlockOperands that are embedded into terminator instructions. From the /// operand, we can get the terminator that contains it, and it's parent block /// is the predecessor. template <typename BlockType> @@ -378,7 +367,7 @@ class PredecessorIterator std::forward_iterator_tag, BlockType *> { public: - PredecessorIterator(StmtBlockOperand *firstOperand) + PredecessorIterator(BlockOperand *firstOperand) : bbUseIterator(firstOperand) {} PredecessorIterator &operator=(const PredecessorIterator &rhs) { @@ -406,33 +395,32 @@ public: } private: - using BBUseIterator = ValueUseIterator<StmtBlockOperand, OperationInst>; + using BBUseIterator = ValueUseIterator<BlockOperand, OperationInst>; BBUseIterator bbUseIterator; }; -inline auto StmtBlock::pred_begin() const -> const_pred_iterator { - return const_pred_iterator((StmtBlockOperand *)getFirstUse()); +inline auto Block::pred_begin() const -> const_pred_iterator { + return const_pred_iterator((BlockOperand *)getFirstUse()); } -inline auto StmtBlock::pred_end() const -> const_pred_iterator { +inline auto Block::pred_end() const -> const_pred_iterator { return const_pred_iterator(nullptr); } -inline auto StmtBlock::getPredecessors() const +inline auto Block::getPredecessors() const -> llvm::iterator_range<const_pred_iterator> { return {pred_begin(), pred_end()}; } -inline auto StmtBlock::pred_begin() -> pred_iterator { - return pred_iterator((StmtBlockOperand *)getFirstUse()); +inline auto Block::pred_begin() -> pred_iterator { + return pred_iterator((BlockOperand *)getFirstUse()); } -inline auto StmtBlock::pred_end() -> pred_iterator { +inline auto Block::pred_end() -> pred_iterator { return pred_iterator(nullptr); } -inline auto StmtBlock::getPredecessors() - -> llvm::iterator_range<pred_iterator> { +inline auto Block::getPredecessors() -> llvm::iterator_range<pred_iterator> { return {pred_begin(), pred_end()}; } @@ -440,7 +428,7 @@ inline auto StmtBlock::getPredecessors() // Successors //===----------------------------------------------------------------------===// -/// This template implments the successor iterators for StmtBlock. +/// This template implements the successor iterators for Block. template <typename BlockType> class SuccessorIterator final : public IndexedAccessorIterator<SuccessorIterator<BlockType>, BlockType, @@ -468,30 +456,31 @@ public: unsigned getSuccessorIndex() const { return this->index; } }; -inline auto StmtBlock::succ_begin() const -> const_succ_iterator { +inline auto Block::succ_begin() const -> const_succ_iterator { return const_succ_iterator(this, 0); } -inline auto StmtBlock::succ_end() const -> const_succ_iterator { +inline auto Block::succ_end() const -> const_succ_iterator { return const_succ_iterator(this, getNumSuccessors()); } -inline auto StmtBlock::getSuccessors() const +inline auto Block::getSuccessors() const -> llvm::iterator_range<const_succ_iterator> { return {succ_begin(), succ_end()}; } -inline auto StmtBlock::succ_begin() -> succ_iterator { +inline auto Block::succ_begin() -> succ_iterator { return succ_iterator(this, 0); } -inline auto StmtBlock::succ_end() -> succ_iterator { +inline auto Block::succ_end() -> succ_iterator { return succ_iterator(this, getNumSuccessors()); } -inline auto StmtBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> { +inline auto Block::getSuccessors() -> llvm::iterator_range<succ_iterator> { return {succ_begin(), succ_end()}; } } // end namespace mlir -#endif // MLIR_IR_STMTBLOCK_H + +#endif // MLIR_IR_BLOCK_H diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1ad533b0983..5c1331e880d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -178,11 +178,11 @@ public: setInsertionPoint(stmt); } - FuncBuilder(StmtBlock *block) : FuncBuilder(block->getFunction()) { + FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) { setInsertionPoint(block, block->end()); } - FuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint) + FuncBuilder(Block *block, Block::iterator insertPoint) : FuncBuilder(block->getFunction()) { setInsertionPoint(block, insertPoint); } @@ -195,11 +195,11 @@ public: /// current insertion point a builder refers to is being removed. void clearInsertionPoint() { this->block = nullptr; - insertPoint = StmtBlock::iterator(); + insertPoint = Block::iterator(); } /// Set the insertion point to the specified location. - void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) { + void setInsertionPoint(Block *block, Block::iterator insertPoint) { // TODO: check that insertPoint is in this rather than some other block. this->block = block; this->insertPoint = insertPoint; @@ -208,31 +208,31 @@ public: /// Sets the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. void setInsertionPoint(Statement *stmt) { - setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt)); + setInsertionPoint(stmt->getBlock(), Block::iterator(stmt)); } /// Sets the insertion point to the start of the specified block. - void setInsertionPointToStart(StmtBlock *block) { + void setInsertionPointToStart(Block *block) { setInsertionPoint(block, block->begin()); } /// Sets the insertion point to the end of the specified block. - void setInsertionPointToEnd(StmtBlock *block) { + void setInsertionPointToEnd(Block *block) { setInsertionPoint(block, block->end()); } /// Return the block the current insertion point belongs to. Note that the /// the insertion point is not necessarily the end of the block. - BasicBlock *getInsertionBlock() const { return block; } + Block *getInsertionBlock() const { return block; } /// Returns the current insertion point of the builder. - StmtBlock::iterator getInsertionPoint() const { return insertPoint; } + Block::iterator getInsertionPoint() const { return insertPoint; } /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the /// current function. - StmtBlock *createBlock(StmtBlock *insertBefore = nullptr); + Block *createBlock(Block *insertBefore = nullptr); /// Returns a builder for the body of a for Stmt. static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) { @@ -240,7 +240,7 @@ public: } /// Returns the current block of the builder. - StmtBlock *getBlock() const { return block; } + Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. OperationInst *createOperation(const OperationState &state); @@ -286,7 +286,7 @@ public: Statement *clone(const Statement &stmt, OperationInst::OperandMapTy &operandMapping) { Statement *cloneStmt = stmt.clone(operandMapping, getContext()); - block->getStatements().insert(insertPoint, cloneStmt); + block->getInstructions().insert(insertPoint, cloneStmt); return cloneStmt; } @@ -305,8 +305,8 @@ public: private: Function *function; - StmtBlock *block = nullptr; - StmtBlock::iterator insertPoint; + Block *block = nullptr; + Block::iterator insertPoint; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 3ccfe4f9f2d..e608a704f99 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -99,7 +99,7 @@ class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands, public: static StringRef getOperationName() { return "br"; } - static void build(Builder *builder, OperationState *result, BasicBlock *dest, + static void build(Builder *builder, OperationState *result, Block *dest, ArrayRef<Value *> operands = {}); // Hooks to customize behavior of this op. @@ -108,11 +108,11 @@ public: bool verify() const; /// Return the block this branch jumps to. - BasicBlock *getDest(); - const BasicBlock *getDest() const { + Block *getDest(); + const Block *getDest() const { return const_cast<BranchOp *>(this)->getDest(); } - void setDest(BasicBlock *block); + void setDest(Block *block); /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); @@ -147,8 +147,8 @@ public: static StringRef getOperationName() { return "cond_br"; } static void build(Builder *builder, OperationState *result, Value *condition, - BasicBlock *trueDest, ArrayRef<Value *> trueOperands, - BasicBlock *falseDest, ArrayRef<Value *> falseOperands); + Block *trueDest, ArrayRef<Value *> trueOperands, + Block *falseDest, ArrayRef<Value *> falseOperands); // Hooks to customize behavior of this op. static bool parse(OpAsmParser *parser, OperationState *result); @@ -160,14 +160,14 @@ public: const Value *getCondition() const { return getOperand(0); } /// Return the destination if the condition is true. - BasicBlock *getTrueDest(); - const BasicBlock *getTrueDest() const { + Block *getTrueDest(); + const Block *getTrueDest() const { return const_cast<CondBranchOp *>(this)->getTrueDest(); } /// Return the destination if the condition is false. - BasicBlock *getFalseDest(); - const BasicBlock *getFalseDest() const { + Block *getFalseDest(); + const Block *getFalseDest() const { return const_cast<CondBranchOp *>(this)->getFalseDest(); } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 5b52a5de7e7..b79b64b68b5 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -25,9 +25,9 @@ #define MLIR_IR_FUNCTION_H #include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" -#include "mlir/IR/StmtBlock.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ilist.h" @@ -38,7 +38,6 @@ class FunctionType; class MLIRContext; class Module; template <typename ObjectType, typename ElementType> class ArgumentIterator; -using BasicBlock = StmtBlock; /// NamedAttribute is used for function attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -82,11 +81,11 @@ public: // Body Handling //===--------------------------------------------------------------------===// - StmtBlockList &getBlockList() { return blocks; } - const StmtBlockList &getBlockList() const { return blocks; } + BlockList &getBlockList() { return blocks; } + const BlockList &getBlockList() const { return blocks; } /// This is the list of blocks in the function. - using BlockListType = llvm::iplist<BasicBlock>; + using BlockListType = llvm::iplist<Block>; BlockListType &getBlocks() { return blocks.getBlocks(); } const BlockListType &getBlocks() const { return blocks.getBlocks(); } @@ -106,29 +105,25 @@ public: const_reverse_iterator rend() const { return blocks.rend(); } bool empty() const { return blocks.empty(); } - void push_back(BasicBlock *block) { blocks.push_back(block); } - void push_front(BasicBlock *block) { blocks.push_front(block); } + void push_back(Block *block) { blocks.push_back(block); } + void push_front(Block *block) { blocks.push_front(block); } - BasicBlock &back() { return blocks.back(); } - const BasicBlock &back() const { - return const_cast<Function *>(this)->back(); - } + Block &back() { return blocks.back(); } + const Block &back() const { return const_cast<Function *>(this)->back(); } - BasicBlock &front() { return blocks.front(); } - const BasicBlock &front() const { - return const_cast<Function *>(this)->front(); - } + Block &front() { return blocks.front(); } + const Block &front() const { return const_cast<Function *>(this)->front(); } /// Return the 'return' statement of this Function. const OperationInst *getReturnStmt() const; OperationInst *getReturnStmt(); // These should only be used on MLFunctions. - StmtBlock *getBody() { + Block *getBody() { assert(isML()); return &blocks.front(); } - const StmtBlock *getBody() const { + const Block *getBody() const { return const_cast<Function *>(this)->getBody(); } @@ -218,7 +213,7 @@ private: AttributeListStorage *attrs; /// The contents of the body. - StmtBlockList blocks; + BlockList blocks; void operator=(const Function &) = delete; friend struct llvm::ilist_traits<Function>; diff --git a/mlir/include/mlir/IR/FunctionGraphTraits.h b/mlir/include/mlir/IR/FunctionGraphTraits.h index 54305c90d25..6ba50e7ca9e 100644 --- a/mlir/include/mlir/IR/FunctionGraphTraits.h +++ b/mlir/include/mlir/IR/FunctionGraphTraits.h @@ -28,9 +28,9 @@ #include "llvm/ADT/GraphTraits.h" namespace llvm { -template <> struct GraphTraits<mlir::BasicBlock *> { - using ChildIteratorType = mlir::BasicBlock::succ_iterator; - using Node = mlir::BasicBlock; +template <> struct GraphTraits<mlir::Block *> { + using ChildIteratorType = mlir::Block::succ_iterator; + using Node = mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(NodeRef bb) { return bb; } @@ -41,9 +41,9 @@ template <> struct GraphTraits<mlir::BasicBlock *> { static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } }; -template <> struct GraphTraits<const mlir::BasicBlock *> { - using ChildIteratorType = mlir::BasicBlock::const_succ_iterator; - using Node = const mlir::BasicBlock; +template <> struct GraphTraits<const mlir::Block *> { + using ChildIteratorType = mlir::Block::const_succ_iterator; + using Node = const mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(NodeRef bb) { return bb; } @@ -54,9 +54,9 @@ template <> struct GraphTraits<const mlir::BasicBlock *> { static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } }; -template <> struct GraphTraits<Inverse<mlir::BasicBlock *>> { - using ChildIteratorType = mlir::BasicBlock::pred_iterator; - using Node = mlir::BasicBlock; +template <> struct GraphTraits<Inverse<mlir::Block *>> { + using ChildIteratorType = mlir::Block::pred_iterator; + using Node = mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) { return inverseGraph.Graph; @@ -69,9 +69,9 @@ template <> struct GraphTraits<Inverse<mlir::BasicBlock *>> { } }; -template <> struct GraphTraits<Inverse<const mlir::BasicBlock *>> { - using ChildIteratorType = mlir::BasicBlock::const_pred_iterator; - using Node = const mlir::BasicBlock; +template <> struct GraphTraits<Inverse<const mlir::Block *>> { + using ChildIteratorType = mlir::Block::const_pred_iterator; + using Node = const mlir::Block; using NodeRef = Node *; static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) { @@ -86,9 +86,9 @@ template <> struct GraphTraits<Inverse<const mlir::BasicBlock *>> { }; template <> -struct GraphTraits<mlir::Function *> : public GraphTraits<mlir::BasicBlock *> { +struct GraphTraits<mlir::Function *> : public GraphTraits<mlir::Block *> { using GraphType = mlir::Function *; - using NodeRef = mlir::BasicBlock *; + using NodeRef = mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -103,9 +103,9 @@ struct GraphTraits<mlir::Function *> : public GraphTraits<mlir::BasicBlock *> { template <> struct GraphTraits<const mlir::Function *> - : public GraphTraits<const mlir::BasicBlock *> { + : public GraphTraits<const mlir::Block *> { using GraphType = const mlir::Function *; - using NodeRef = const mlir::BasicBlock *; + using NodeRef = const mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -120,7 +120,7 @@ struct GraphTraits<const mlir::Function *> template <> struct GraphTraits<Inverse<mlir::Function *>> - : public GraphTraits<Inverse<mlir::BasicBlock *>> { + : public GraphTraits<Inverse<mlir::Block *>> { using GraphType = Inverse<mlir::Function *>; using NodeRef = NodeRef; @@ -137,7 +137,7 @@ struct GraphTraits<Inverse<mlir::Function *>> template <> struct GraphTraits<Inverse<const mlir::Function *>> - : public GraphTraits<Inverse<const mlir::BasicBlock *>> { + : public GraphTraits<Inverse<const mlir::Block *>> { using GraphType = Inverse<const mlir::Function *>; using NodeRef = NodeRef; @@ -153,10 +153,9 @@ struct GraphTraits<Inverse<const mlir::Function *>> }; template <> -struct GraphTraits<mlir::StmtBlockList *> - : public GraphTraits<mlir::BasicBlock *> { - using GraphType = mlir::StmtBlockList *; - using NodeRef = mlir::BasicBlock *; +struct GraphTraits<mlir::BlockList *> : public GraphTraits<mlir::Block *> { + using GraphType = mlir::BlockList *; + using NodeRef = mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -170,10 +169,10 @@ struct GraphTraits<mlir::StmtBlockList *> }; template <> -struct GraphTraits<const mlir::StmtBlockList *> - : public GraphTraits<const mlir::BasicBlock *> { - using GraphType = const mlir::StmtBlockList *; - using NodeRef = const mlir::BasicBlock *; +struct GraphTraits<const mlir::BlockList *> + : public GraphTraits<const mlir::Block *> { + using GraphType = const mlir::BlockList *; + using NodeRef = const mlir::Block *; static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } @@ -187,9 +186,9 @@ struct GraphTraits<const mlir::StmtBlockList *> }; template <> -struct GraphTraits<Inverse<mlir::StmtBlockList *>> - : public GraphTraits<Inverse<mlir::BasicBlock *>> { - using GraphType = Inverse<mlir::StmtBlockList *>; +struct GraphTraits<Inverse<mlir::BlockList *>> + : public GraphTraits<Inverse<mlir::Block *>> { + using GraphType = Inverse<mlir::BlockList *>; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } @@ -204,9 +203,9 @@ struct GraphTraits<Inverse<mlir::StmtBlockList *>> }; template <> -struct GraphTraits<Inverse<const mlir::StmtBlockList *>> - : public GraphTraits<Inverse<const mlir::BasicBlock *>> { - using GraphType = Inverse<const mlir::StmtBlockList *>; +struct GraphTraits<Inverse<const mlir::BlockList *>> + : public GraphTraits<Inverse<const mlir::Block *>> { + using GraphType = Inverse<const mlir::BlockList *>; using NodeRef = NodeRef; static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 4e840409a27..e1b90b6e39e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -752,14 +752,14 @@ public: return this->getInstruction()->getNumSuccessorOperands(index); } - const BasicBlock *getSuccessor(unsigned index) const { + const Block *getSuccessor(unsigned index) const { return this->getInstruction()->getSuccessor(index); } - BasicBlock *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) { return this->getInstruction()->getSuccessor(index); } - void setSuccessor(BasicBlock *block, unsigned index) { + void setSuccessor(Block *block, unsigned index) { return this->getInstruction()->setSuccessor(block, index); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 9ebc55b2ae8..587eabdee96 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -264,7 +264,7 @@ public: virtual bool parseOperand(OperandType &result) = 0; /// Parse a single operation successor and it's operand list. - virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + virtual bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value *> &operands) = 0; /// These are the supported delimiters around operand lists, used by diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 2bc75a2a40d..15c882b90f7 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -31,6 +31,7 @@ #include <memory> namespace mlir { +class Block; class Dialect; class OperationInst; class OperationState; @@ -39,10 +40,8 @@ class OpAsmParserResult; class OpAsmPrinter; class Pattern; class RewritePattern; -class StmtBlock; class Type; class Value; -using BasicBlock = StmtBlock; /// This is a vector that owns the patterns inside of it. using OwningPatternList = std::vector<std::unique_ptr<Pattern>>; @@ -209,7 +208,7 @@ struct OperationState { SmallVector<Type, 4> types; SmallVector<NamedAttribute, 4> attributes; /// Successors of this operation and their respective operands. - SmallVector<StmtBlock *, 1> successors; + SmallVector<Block *, 1> successors; public: OperationState(MLIRContext *context, Location location, StringRef name) @@ -221,7 +220,7 @@ public: OperationState(MLIRContext *context, Location location, StringRef name, ArrayRef<Value *> operands, ArrayRef<Type> types, ArrayRef<NamedAttribute> attributes, - ArrayRef<StmtBlock *> successors = {}) + ArrayRef<Block *> successors = {}) : context(context), location(location), name(name, context), operands(operands.begin(), operands.end()), types(types.begin(), types.end()), @@ -248,7 +247,7 @@ public: attributes.push_back({name, attr}); } - void addSuccessor(StmtBlock *successor, ArrayRef<Value *> succOperands) { + void addSuccessor(Block *successor, ArrayRef<Value *> succOperands) { successors.push_back(successor); // Insert a sentinal operand to mark a barrier between successor operands. operands.push_back(nullptr); diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index 48135514dcf..9ca5530f33c 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -28,13 +28,13 @@ #include "llvm/ADT/ilist_node.h" namespace mlir { +class Block; class Location; -class StmtBlock; class ForStmt; class MLIRContext; -/// The operand of a Terminator contains a StmtBlock. -using StmtBlockOperand = IROperandImpl<StmtBlock, OperationInst>; +/// Terminator operations can have Block operands to represent successors. +using BlockOperand = IROperandImpl<Block, OperationInst>; } // namespace mlir @@ -55,7 +55,7 @@ template <> struct ilist_traits<::mlir::Statement> { stmt_iterator first, stmt_iterator last); private: - mlir::StmtBlock *getContainingBlock(); + mlir::Block *getContainingBlock(); }; } // end namespace llvm @@ -66,9 +66,9 @@ template <typename ObjectType, typename ElementType> class OperandIterator; /// Statement is a basic unit of execution within an ML function. /// Statements can be nested within for and if statements effectively /// forming a tree. Child statements are organized into statement blocks -/// represented by a 'StmtBlock' class. +/// represented by a 'Block' class. class Statement : public IROperandOwner, - public llvm::ilist_node_with_parent<Statement, StmtBlock> { + public llvm::ilist_node_with_parent<Statement, Block> { public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, @@ -95,7 +95,7 @@ public: Statement *clone(MLIRContext *context) const; /// Returns the statement block that contains this statement. - StmtBlock *getBlock() const { return block; } + Block *getBlock() const { return block; } /// Returns the closest surrounding statement that contains this statement /// or nullptr if this is a top-level statement. @@ -121,7 +121,7 @@ public: /// Unlink this operation instruction from its current basic block and insert /// it right before `iterator` in the specified basic block. - void moveBefore(StmtBlock *block, llvm::iplist<Statement>::iterator iterator); + void moveBefore(Block *block, llvm::iplist<Statement>::iterator iterator); // Returns whether the Statement is a terminator. bool isTerminator() const; @@ -198,7 +198,7 @@ protected: private: /// The statement block that containts this statement. - StmtBlock *block = nullptr; + Block *block = nullptr; // allow ilist_traits access to 'block' field. friend struct llvm::ilist_traits<Statement>; diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index d04ebd776b9..aa4157714a7 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -23,10 +23,10 @@ #define MLIR_IR_STATEMENTS_H #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Block.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Statement.h" -#include "mlir/IR/StmtBlock.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/TrailingObjects.h" @@ -46,14 +46,14 @@ class Function; /// class OperationInst final : public Statement, - private llvm::TrailingObjects<OperationInst, InstResult, StmtBlockOperand, + private llvm::TrailingObjects<OperationInst, InstResult, BlockOperand, unsigned, InstOperand> { public: /// Create a new OperationInst with the specific fields. static OperationInst * create(Location location, OperationName name, ArrayRef<Value *> operands, ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, - ArrayRef<StmtBlock *> successors, MLIRContext *context); + ArrayRef<Block *> successors, MLIRContext *context); /// Return the context this operation is associated with. MLIRContext *getContext() const; @@ -229,11 +229,11 @@ public: // Terminators //===--------------------------------------------------------------------===// - MutableArrayRef<StmtBlockOperand> getBlockOperands() { + MutableArrayRef<BlockOperand> getBlockOperands() { assert(isTerminator() && "Only terminators have a block operands list"); - return {getTrailingObjects<StmtBlockOperand>(), numSuccs}; + return {getTrailingObjects<BlockOperand>(), numSuccs}; } - ArrayRef<StmtBlockOperand> getBlockOperands() const { + ArrayRef<BlockOperand> getBlockOperands() const { return const_cast<OperationInst *>(this)->getBlockOperands(); } @@ -248,14 +248,14 @@ public: return getTrailingObjects<unsigned>()[index]; } - StmtBlock *getSuccessor(unsigned index) { + Block *getSuccessor(unsigned index) { assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } - const StmtBlock *getSuccessor(unsigned index) const { + const Block *getSuccessor(unsigned index) const { return const_cast<OperationInst *>(this)->getSuccessor(index); } - void setSuccessor(BasicBlock *block, unsigned index); + void setSuccessor(Block *block, unsigned index); /// Erase a specific operand from the operand list of the successor at /// 'index'. @@ -404,7 +404,7 @@ private: void eraseOperand(unsigned index); // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects<OperationInst, InstResult, StmtBlockOperand, + friend llvm::TrailingObjects<OperationInst, InstResult, BlockOperand, unsigned, InstOperand>; size_t numTrailingObjects(OverloadToken<InstOperand>) const { return numOperands; @@ -412,7 +412,7 @@ private: size_t numTrailingObjects(OverloadToken<InstResult>) const { return numResults; } - size_t numTrailingObjects(OverloadToken<StmtBlockOperand>) const { + size_t numTrailingObjects(OverloadToken<BlockOperand>) const { return numSuccs; } size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; } @@ -515,7 +515,7 @@ public: AffineMap ubMap, int64_t step); ~ForStmt() { - // Explicitly erase statements instead of relying of 'StmtBlock' destructor + // Explicitly erase statements instead of relying of 'Block' destructor // since child statements need to be destroyed before the Value that this // for stmt represents is destroyed. Affine maps are immortal objects and // don't need to be deleted. @@ -534,10 +534,10 @@ public: using const_operand_range = llvm::iterator_range<const_operand_iterator>; /// Get the body of the ForStmt. - StmtBlock *getBody() { return &body.front(); } + Block *getBody() { return &body.front(); } /// Get the body of the ForStmt. - const StmtBlock *getBody() const { return &body.front(); } + const Block *getBody() const { return &body.front(); } //===--------------------------------------------------------------------===// // Bounds and step @@ -664,8 +664,8 @@ public: } private: - // The StmtBlock for the body. - StmtBlockList body; + // The Block for the body. + BlockList body; // Affine map for the lower bound. AffineMap lbMap; @@ -746,18 +746,18 @@ public: // Then, else, condition. //===--------------------------------------------------------------------===// - StmtBlock *getThen() { return &thenClause.front(); } - const StmtBlock *getThen() const { return &thenClause.front(); } - StmtBlock *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const StmtBlock *getElse() const { + Block *getThen() { return &thenClause.front(); } + const Block *getThen() const { return &thenClause.front(); } + Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } + const Block *getElse() const { return elseClause ? &elseClause->front() : nullptr; } bool hasElse() const { return elseClause != nullptr; } - StmtBlock *createElse() { + Block *createElse() { assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new StmtBlockList(this); - elseClause->push_back(new StmtBlock()); + elseClause = new BlockList(this); + elseClause->push_back(new Block()); return &elseClause->front(); } @@ -823,9 +823,9 @@ public: private: // it is always present. - StmtBlockList thenClause; + BlockList thenClause; // 'else' clause of the if statement. 'nullptr' if there is no else clause. - StmtBlockList *elseClause; + BlockList *elseClause; // The integer set capturing the conditional guard. IntegerSet set; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 75184dc7a3f..2213fe79852 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -27,10 +27,10 @@ #include "mlir/Support/LLVM.h" namespace mlir { +class Block; class Function; class OperationInst; class Statement; -class StmtBlock; class Value; using Instruction = Statement; @@ -136,18 +136,18 @@ public: return const_cast<BlockArgument *>(this)->getFunction(); } - StmtBlock *getOwner() { return owner; } - const StmtBlock *getOwner() const { return owner; } + Block *getOwner() { return owner; } + const Block *getOwner() const { return owner; } private: - friend class StmtBlock; // For access to private constructor. - BlockArgument(Type type, StmtBlock *owner) + friend class Block; // For access to private constructor. + BlockArgument(Type type, Block *owner) : Value(Value::Kind::BlockArgument, type), owner(owner) {} /// The owner of this operand. /// TODO: can encode this more efficiently to avoid the space hit of this /// through bitpacking shenanigans. - StmtBlock *const owner; + Block *const owner; }; /// This is a value defined by a result of an operation instruction. diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h index d214a96f335..2694433d5a0 100644 --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -66,7 +66,7 @@ bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor); bool promoteIfSingleIteration(ForStmt *forStmt); /// Promotes all single iteration ForStmt's in the Function, i.e., moves -/// their body into the containing StmtBlock. +/// their body into the containing Block. void promoteSingleIterationLoops(Function *f); /// Returns the lower bound of the cleanup loop when unrolling a loop diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index e28c2e87651..12af803fdad 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -878,15 +878,15 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, return numCommonLoops; } -// Returns StmtBlock common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. -static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - const FlatAffineConstraints &srcDomain, - unsigned numCommonLoops) { +// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. +static Block *getCommonBlock(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, + unsigned numCommonLoops) { if (numCommonLoops == 0) { auto *block = srcAccess.opStmt->getBlock(); - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); + while (block->getContainingInst()) { + block = block->getContainingInst()->getBlock(); } return block; } @@ -906,14 +906,14 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { - // Get StmtBlock common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. + // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'. auto *commonBlock = - getCommonStmtBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); + getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); // Check the dominance relationship between the respective ancestors of the - // src and dst in the StmtBlock of the innermost among the common loops. - auto *srcStmt = commonBlock->findAncestorStmtInBlock(*srcAccess.opStmt); + // src and dst in the Block of the innermost among the common loops. + auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt); assert(srcStmt != nullptr); - auto *dstStmt = commonBlock->findAncestorStmtInBlock(*dstAccess.opStmt); + auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt); assert(dstStmt != nullptr); return mlir::properlyDominates(*srcStmt, *dstStmt); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 0ebbec9c025..0c8db07dbb4 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -25,9 +25,9 @@ #include "llvm/Support/GenericDomTreeConstruction.h" using namespace mlir; -template class llvm::DominatorTreeBase<BasicBlock, false>; -template class llvm::DominatorTreeBase<BasicBlock, true>; -template class llvm::DomTreeNodeBase<BasicBlock>; +template class llvm::DominatorTreeBase<Block, false>; +template class llvm::DominatorTreeBase<Block, true>; +template class llvm::DomTreeNodeBase<Block>; /// Compute the immediate-dominators map. DominanceInfo::DominanceInfo(Function *function) : DominatorTreeBase() { @@ -57,8 +57,8 @@ bool DominanceInfo::properlyDominates(const Instruction *a, return true; // Otherwise, do a linear scan to determine whether B comes after A. - auto aIter = BasicBlock::const_iterator(a); - auto bIter = BasicBlock::const_iterator(b); + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); auto fIter = aBlock->begin(); while (bIter != fIter) { --bIter; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index caeeccb677f..dd14f38df55 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -309,7 +309,7 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) { bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, ArrayRef<uint64_t> shifts) { auto *forBody = forStmt.getBody(); - assert(shifts.size() == forBody->getStatements().size()); + assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &stmt : *forBody) { // A for or if stmt does not produce any def/results (that are used @@ -323,8 +323,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, // This is a naive way. If performance becomes an issue, a map can // be used to store 'shifts' - to look up the shift for a statement in // constant time. - if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner())) - if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)]) + if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)]) return false; } } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e17c27ac941..f6191418f54 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,8 +44,8 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) { if (a.getBlock() == b.getBlock()) { // Do a linear scan to determine whether b comes after a. - auto aIter = StmtBlock::const_iterator(a); - auto bIter = StmtBlock::const_iterator(b); + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); auto aBlockStart = a.getBlock()->begin(); while (bIter != aBlockStart) { --bIter; @@ -56,7 +56,7 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) { } // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b)) + if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b)) // a and bAncestor are in the same block; check if the former dominates it. return dominates(a, *bAncestor); @@ -333,26 +333,26 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer<LoadOp> loadOp, template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp, bool emitError); -// Returns in 'positions' the StmtBlock positions of 'stmt' in each ancestor -// StmtBlock from the StmtBlock containing statement, stopping at 'limitBlock'. -static void findStmtPosition(const Statement *stmt, StmtBlock *limitBlock, +// Returns in 'positions' the Block positions of 'stmt' in each ancestor +// Block from the Block containing statement, stopping at 'limitBlock'. +static void findStmtPosition(const Statement *stmt, Block *limitBlock, SmallVectorImpl<unsigned> *positions) { - StmtBlock *block = stmt->getBlock(); + Block *block = stmt->getBlock(); while (block != limitBlock) { - int stmtPosInBlock = block->findStmtPosInBlock(*stmt); + int stmtPosInBlock = block->findInstPositionInBlock(*stmt); assert(stmtPosInBlock >= 0); positions->push_back(stmtPosInBlock); - stmt = block->getContainingStmt(); + stmt = block->getContainingInst(); block = stmt->getBlock(); } std::reverse(positions->begin(), positions->end()); } -// Returns the Statement in a possibly nested set of StmtBlocks, where the +// Returns the Statement in a possibly nested set of Blocks, where the // position of the statement is represented by 'positions', which has a -// StmtBlock position for each level of nesting. +// Block position for each level of nesting. static Statement *getStmtAtPosition(ArrayRef<unsigned> positions, - unsigned level, StmtBlock *block) { + unsigned level, Block *block) { unsigned i = 0; for (auto &stmt : *block) { if (i != positions[level]) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 43c29dbb6ac..4cad531ecaa 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -59,7 +59,7 @@ public: return fn.emitError(message); } - bool failure(const Twine &message, const BasicBlock &bb) { + bool failure(const Twine &message, const Block &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) if (auto *op = dyn_cast<OperationInst>(&bb.front())) @@ -153,7 +153,7 @@ struct CFGFuncVerifier : public Verifier { : Verifier(fn), fn(fn), domInfo(const_cast<Function *>(&fn)) {} bool verify(); - bool verifyBlock(const BasicBlock &block); + bool verifyBlock(const Block &block); bool verifyInstOperands(const Instruction &inst); }; } // end anonymous namespace @@ -214,7 +214,7 @@ bool CFGFuncVerifier::verifyInstOperands(const Instruction &inst) { return false; } -bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) { +bool CFGFuncVerifier::verifyBlock(const Block &block) { if (!block.getTerminator()) return failure("basic block with no terminator", block); @@ -287,12 +287,12 @@ bool MLFuncVerifier::verifyDominance() { // This recursive function walks the statement list pushing scopes onto the // stack as it goes, and popping them to remove them from the table. - std::function<bool(const StmtBlock &block)> walkBlock; - walkBlock = [&](const StmtBlock &block) -> bool { + std::function<bool(const Block &block)> walkBlock; + walkBlock = [&](const Block &block) -> bool { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt())) + if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingInst())) liveValues.insert(forStmt, true); for (auto &stmt : block) { @@ -340,10 +340,10 @@ bool MLFuncVerifier::verifyDominance() { bool MLFuncVerifier::verifyReturn() { // TODO: fold return verification in the pass that verifies all statements. const char missingReturnMsg[] = "ML function must end with return statement"; - if (fn.getBody()->getStatements().empty()) + if (fn.getBody()->getInstructions().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getBody()->getStatements().back(); + const auto &stmt = fn.getBody()->getInstructions().back(); if (const auto *op = dyn_cast<OperationInst>(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2ff7220f8ee..daaaee7010c 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -180,7 +180,7 @@ void ModuleState::visitExtFunction(const Function *fn) { void ModuleState::visitCFGFunction(const Function *fn) { visitType(fn->getType()); for (auto &block : *fn) { - for (auto &op : block.getStatements()) { + for (auto &op : block.getInstructions()) { if (auto *opInst = dyn_cast<OperationInst>(&op)) visitOperation(opInst); else { @@ -914,7 +914,7 @@ public: void print(const OperationInst *inst); void print(const ForStmt *stmt); void print(const IfStmt *stmt); - void print(const StmtBlock *block); + void print(const Block *block); void printOperation(const OperationInst *op); void printDefaultOp(const OperationInst *op); @@ -944,11 +944,11 @@ public: enum { nameSentinel = ~0U }; - void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); } + void printBlockName(const Block *block) { os << "bb" << getBlockID(block); } - unsigned getBBID(const BasicBlock *block) { - auto it = basicBlockIDs.find(block); - assert(it != basicBlockIDs.end() && "Block not in this function?"); + unsigned getBlockID(const Block *block) { + auto it = blockIDs.find(block); + assert(it != blockIDs.end() && "Block not in this function?"); return it->second; } @@ -964,7 +964,7 @@ public: protected: void numberValueID(const Value *value); - void numberValuesInBlock(const StmtBlock &block); + void numberValuesInBlock(const Block &block); void printValueID(const Value *value, bool printResultNo = true) const; private: @@ -976,7 +976,7 @@ private: DenseMap<const Value *, StringRef> valueNames; /// This is the block ID for each block in the current function. - DenseMap<const BasicBlock *, unsigned> basicBlockIDs; + DenseMap<const Block *, unsigned> blockIDs; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. @@ -1007,10 +1007,10 @@ FunctionPrinter::FunctionPrinter(const Function *function, } /// Number all of the SSA values in the specified block list. -void FunctionPrinter::numberValuesInBlock(const StmtBlock &block) { +void FunctionPrinter::numberValuesInBlock(const Block &block) { // Each block gets a unique ID, and all of the instructions within it get // numbered as well. - basicBlockIDs[&block] = nextBlockID++; + blockIDs[&block] = nextBlockID++; for (auto *arg : block.getArguments()) numberValueID(arg); @@ -1154,6 +1154,7 @@ void FunctionPrinter::printMLFunctionSignature() { os << " : "; printType(arg->getType()); } + os << ')'; printFunctionResultType(type); } @@ -1174,11 +1175,11 @@ void FunctionPrinter::printOtherFunctionSignature() { printFunctionResultType(type); } -void FunctionPrinter::print(const StmtBlock *block) { +void FunctionPrinter::print(const Block *block) { // Print the block label and argument list, unless we are in an ML function. if (!block->getFunction()->isML()) { os.indent(currentIndent); - printBBName(block); + printBlockName(block); // Print the argument list if non-empty. if (!block->args_empty()) { @@ -1201,13 +1202,13 @@ void FunctionPrinter::print(const StmtBlock *block) { os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { os << "\t// pred: "; - printBBName(pred); + printBlockName(pred); } else { // We want to print the predecessors in increasing numeric order, not in // whatever order the use-list is in, so gather and sort them. SmallVector<unsigned, 4> predIDs; for (auto *pred : block->getPredecessors()) - predIDs.push_back(getBBID(pred)); + predIDs.push_back(getBlockID(pred)); llvm::array_pod_sort(predIDs.begin(), predIDs.end()); os << "\t// " << predIDs.size() << " preds: "; @@ -1218,7 +1219,8 @@ void FunctionPrinter::print(const StmtBlock *block) { } currentIndent += indentWidth; - for (auto &stmt : block->getStatements()) { + + for (auto &stmt : block->getInstructions()) { print(&stmt); os << '\n'; } @@ -1358,10 +1360,9 @@ void FunctionPrinter::printDefaultOp(const OperationInst *op) { void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, unsigned index) { - printBBName(term->getSuccessor(index)); + printBlockName(term->getSuccessor(index)); auto succOperands = term->getSuccessorOperands(index); - if (succOperands.begin() == succOperands.end()) return; @@ -1516,7 +1517,7 @@ void Instruction::dump() const { llvm::errs() << "\n"; } -void BasicBlock::print(raw_ostream &os) const { +void Block::print(raw_ostream &os) const { auto *function = getFunction(); if (!function) { os << "<<UNLINKED BLOCK>>\n"; @@ -1528,17 +1529,17 @@ void BasicBlock::print(raw_ostream &os) const { FunctionPrinter(function, modulePrinter).print(this); } -void BasicBlock::dump() const { print(llvm::errs()); } +void Block::dump() const { print(llvm::errs()); } /// Print out the name of the basic block without printing its body. -void StmtBlock::printAsOperand(raw_ostream &os, bool printType) { +void Block::printAsOperand(raw_ostream &os, bool printType) { if (!getFunction()) { os << "<<UNLINKED BLOCK>>\n"; return; } ModuleState state(getFunction()->getContext()); ModulePrinter modulePrinter(os, state); - FunctionPrinter(getFunction(), modulePrinter).printBBName(this); + FunctionPrinter(getFunction(), modulePrinter).printBlockName(this); } void Function::print(raw_ostream &os) const { diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/Block.cpp index b551b1121a7..c7e84194c35 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/Block.cpp @@ -1,4 +1,4 @@ -//===- StmtBlock.cpp - MLIR Statement Instruction Classes -----------------===// +//===- Block.cpp - MLIR Block and BlockList Classes -----------------------===// // // Copyright 2019 The MLIR Authors. // @@ -15,12 +15,12 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/StmtBlock.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" using namespace mlir; -StmtBlock::~StmtBlock() { +Block::~Block() { clear(); llvm::DeleteContainerPointers(arguments); @@ -28,13 +28,13 @@ StmtBlock::~StmtBlock() { /// Returns the closest surrounding statement that contains this block or /// nullptr if this is a top-level statement block. -Statement *StmtBlock::getContainingStmt() { - return parent ? parent->getContainingStmt() : nullptr; +Statement *Block::getContainingInst() { + return parent ? parent->getContainingInst() : nullptr; } -Function *StmtBlock::getFunction() { - StmtBlock *block = this; - while (auto *stmt = block->getContainingStmt()) { +Function *Block::getFunction() { + Block *block = this; + while (auto *stmt = block->getContainingInst()) { block = stmt->getBlock(); if (!block) return nullptr; @@ -44,34 +44,34 @@ Function *StmtBlock::getFunction() { return nullptr; } -/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor -/// statement of 'stmt' that lies in this block. Returns nullptr if the latter -/// fails. -const Statement * -StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const { +/// Returns 'inst' if 'inst' lies in this block, or otherwise finds the +/// ancestor instruction of 'inst' that lies in this block. Returns nullptr if +/// the latter fails. +const Instruction * +Block::findAncestorInstInBlock(const Instruction &inst) const { // Traverse up the statement hierarchy starting from the owner of operand to // find the ancestor statement that resides in the block of 'forStmt'. - const auto *currStmt = &stmt; - while (currStmt->getBlock() != this) { - currStmt = currStmt->getParentStmt(); - if (!currStmt) + const auto *currInst = &inst; + while (currInst->getBlock() != this) { + currInst = currInst->getParentStmt(); + if (!currInst) return nullptr; } - return currStmt; + return currInst; } //===----------------------------------------------------------------------===// // Argument list management. //===----------------------------------------------------------------------===// -BlockArgument *StmtBlock::addArgument(Type type) { +BlockArgument *Block::addArgument(Type type) { auto *arg = new BlockArgument(type, this); arguments.push_back(arg); return arg; } /// Add one argument to the argument list for each type specified in the list. -auto StmtBlock::addArguments(ArrayRef<Type> types) +auto Block::addArguments(ArrayRef<Type> types) -> llvm::iterator_range<args_iterator> { arguments.reserve(arguments.size() + types.size()); auto initialSize = arguments.size(); @@ -81,7 +81,7 @@ auto StmtBlock::addArguments(ArrayRef<Type> types) return {arguments.data() + initialSize, arguments.data() + arguments.size()}; } -void StmtBlock::eraseArgument(unsigned index) { +void Block::eraseArgument(unsigned index) { assert(index < arguments.size()); // Delete the argument. @@ -100,12 +100,12 @@ void StmtBlock::eraseArgument(unsigned index) { // Terminator management //===----------------------------------------------------------------------===// -OperationInst *StmtBlock::getTerminator() { +OperationInst *Block::getTerminator() { if (empty()) return nullptr; // Check if the last instruction is a terminator. - auto &backInst = statements.back(); + auto &backInst = back(); auto *opStmt = dyn_cast<OperationInst>(&backInst); if (!opStmt || !opStmt->isTerminator()) return nullptr; @@ -113,14 +113,14 @@ OperationInst *StmtBlock::getTerminator() { } /// Return true if this block has no predecessors. -bool StmtBlock::hasNoPredecessors() const { return pred_begin() == pred_end(); } +bool Block::hasNoPredecessors() const { return pred_begin() == pred_end(); } // Indexed successor access. -unsigned StmtBlock::getNumSuccessors() const { +unsigned Block::getNumSuccessors() const { return getTerminator()->getNumSuccessors(); } -StmtBlock *StmtBlock::getSuccessor(unsigned i) { +Block *Block::getSuccessor(unsigned i) { return getTerminator()->getSuccessor(i); } @@ -130,7 +130,7 @@ StmtBlock *StmtBlock::getSuccessor(unsigned i) { /// Note that multiple edges from a single block (e.g. if you have a cond /// branch with the same block as the true/false destinations) is not /// considered to be a single predecessor. -StmtBlock *StmtBlock::getSinglePredecessor() { +Block *Block::getSinglePredecessor() { auto it = pred_begin(); if (it == pred_end()) return nullptr; @@ -143,9 +143,9 @@ StmtBlock *StmtBlock::getSinglePredecessor() { // Other //===----------------------------------------------------------------------===// -/// Unlink this BasicBlock from its Function and delete it. -void BasicBlock::eraseFromFunction() { - assert(getFunction() && "BasicBlock has no parent"); +/// Unlink this Block from its Function and delete it. +void Block::eraseFromFunction() { + assert(getFunction() && "Block has no parent"); getFunction()->getBlocks().erase(this); } @@ -156,21 +156,21 @@ void BasicBlock::eraseFromFunction() { /// the original basic block, an unconditional branch is added to the original /// block (going to the new block), and the rest of the instructions in the /// original block are moved to the new BB, including the old terminator. The -/// newly formed BasicBlock is returned. +/// newly formed Block is returned. /// /// This function invalidates the specified iterator. -BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) { +Block *Block::splitBlock(iterator splitBefore) { // Start by creating a new basic block, and insert it immediate after this // one in the containing function. - auto newBB = new BasicBlock(); + auto newBB = new Block(); getFunction()->getBlocks().insert(++Function::iterator(this), newBB); auto branchLoc = splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc(); // Move all of the operations from the split point to the end of the function // into the new block. - newBB->getStatements().splice(newBB->end(), getStatements(), splitBefore, - end()); + newBB->getInstructions().splice(newBB->end(), getInstructions(), splitBefore, + end()); // Create an unconditional branch to the new block, and move our terminator // to the new block. @@ -179,58 +179,54 @@ BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) { } //===----------------------------------------------------------------------===// -// StmtBlockList +// BlockList //===----------------------------------------------------------------------===// -StmtBlockList::StmtBlockList(Function *container) : container(container) {} +BlockList::BlockList(Function *container) : container(container) {} -StmtBlockList::StmtBlockList(Statement *container) : container(container) {} +BlockList::BlockList(Statement *container) : container(container) {} -Function *StmtBlockList::getFunction() { return getContainingFunction(); } - -Statement *StmtBlockList::getContainingStmt() { +Statement *BlockList::getContainingInst() { return container.dyn_cast<Statement *>(); } -Function *StmtBlockList::getContainingFunction() { +Function *BlockList::getContainingFunction() { return container.dyn_cast<Function *>(); } -StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() { - size_t Offset(size_t( - &((StmtBlockList *)nullptr->*StmtBlockList::getSublistAccess(nullptr)))); - iplist<StmtBlock> *Anchor(static_cast<iplist<StmtBlock> *>(this)); - return reinterpret_cast<StmtBlockList *>(reinterpret_cast<char *>(Anchor) - - Offset); +BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() { + size_t Offset( + size_t(&((BlockList *)nullptr->*BlockList::getSublistAccess(nullptr)))); + iplist<Block> *Anchor(static_cast<iplist<Block> *>(this)); + return reinterpret_cast<BlockList *>(reinterpret_cast<char *>(Anchor) - + Offset); } /// This is a trait method invoked when a basic block is added to a function. /// We keep the function pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::addNodeToList(StmtBlock *block) { +void llvm::ilist_traits<::mlir::Block>::addNodeToList(Block *block) { assert(!block->parent && "already in a function!"); block->parent = getContainingBlockList(); } /// This is a trait method invoked when an instruction is removed from a /// function. We keep the function pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::removeNodeFromList( - StmtBlock *block) { +void llvm::ilist_traits<::mlir::Block>::removeNodeFromList(Block *block) { assert(block->parent && "not already in a function!"); block->parent = nullptr; } /// This is a trait method invoked when an instruction is moved from one block /// to another. We keep the block pointer up to date. -void llvm::ilist_traits<::mlir::StmtBlock>::transferNodesFromList( - ilist_traits<StmtBlock> &otherList, block_iterator first, - block_iterator last) { +void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( + ilist_traits<Block> &otherList, block_iterator first, block_iterator last) { // If we are transferring instructions within the same function, the parent // pointer doesn't need to be updated. auto *curParent = getContainingBlockList(); if (curParent == otherList.getContainingBlockList()) return; - // Update the 'parent' member of each StmtBlock. + // Update the 'parent' member of each Block. for (; first != last; ++first) first->parent = curParent; } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 81a3b7c2950..a9eb6fe8c8a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -275,8 +275,8 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { /// 'insertBefore' basic block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the /// current function. -StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) { - StmtBlock *b = new StmtBlock(); +Block *FuncBuilder::createBlock(Block *insertBefore) { + Block *b = new Block(); // If we are supposed to insert before a specific block, do so, otherwise add // the block to the end of the function. @@ -294,7 +294,7 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { auto *op = OperationInst::create(state.location, state.name, state.operands, state.types, state.attributes, state.successors, context); - block->getStatements().insert(insertPoint, op); + block->getInstructions().insert(insertPoint, op); return op; } @@ -303,7 +303,7 @@ ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands, AffineMap ubMap, int64_t step) { auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getStatements().insert(insertPoint, stmt); + block->getInstructions().insert(insertPoint, stmt); return stmt; } @@ -317,6 +317,6 @@ ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands, IntegerSet set) { auto *stmt = IfStmt::create(location, operands, set); - block->getStatements().insert(insertPoint, stmt); + block->getInstructions().insert(insertPoint, stmt); return stmt; } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 51596a9f09e..a0264fc11b0 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -167,13 +167,13 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants, // BranchOp //===----------------------------------------------------------------------===// -void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest, +void BranchOp::build(Builder *builder, OperationState *result, Block *dest, ArrayRef<Value *> operands) { result->addSuccessor(dest, operands); } bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { - BasicBlock *dest; + Block *dest; SmallVector<Value *, 4> destOperands; if (parser->parseSuccessorAndUseList(dest, destOperands)) return true; @@ -193,9 +193,9 @@ bool BranchOp::verify() const { return false; } -BasicBlock *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } +Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); } -void BranchOp::setDest(BasicBlock *block) { +void BranchOp::setDest(Block *block) { return getInstruction()->setSuccessor(block, 0); } @@ -208,8 +208,8 @@ void BranchOp::eraseOperand(unsigned index) { //===----------------------------------------------------------------------===// void CondBranchOp::build(Builder *builder, OperationState *result, - Value *condition, BasicBlock *trueDest, - ArrayRef<Value *> trueOperands, BasicBlock *falseDest, + Value *condition, Block *trueDest, + ArrayRef<Value *> trueOperands, Block *falseDest, ArrayRef<Value *> falseOperands) { result->addOperands(condition); result->addSuccessor(trueDest, trueOperands); @@ -218,7 +218,7 @@ void CondBranchOp::build(Builder *builder, OperationState *result, bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector<Value *, 4> destOperands; - BasicBlock *dest; + Block *dest; OpAsmParser::OperandType condInfo; // Parse the condition. @@ -263,11 +263,11 @@ bool CondBranchOp::verify() const { return false; } -BasicBlock *CondBranchOp::getTrueDest() { +Block *CondBranchOp::getTrueDest() { return getInstruction()->getSuccessor(trueIndex); } -BasicBlock *CondBranchOp::getFalseDest() { +Block *CondBranchOp::getFalseDest() { return getInstruction()->getSuccessor(falseIndex); } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 0e777c65f23..cbe84e10247 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -37,7 +37,7 @@ Function::Function(Kind kind, Location location, StringRef name, // TODO(clattner): Unify this behavior. if (kind == Kind::MLFunc) { // The body of an ML Function always has one block. - auto *entry = new StmtBlock(); + auto *entry = new Block(); blocks.push_back(entry); // Initialize the arguments. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 23e54b3638e..ccd7d65f7c8 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -245,7 +245,7 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { static bool verifyBBArguments( llvm::iterator_range<OperationInst::const_operand_iterator> operands, - const BasicBlock *destBB, const OperationInst *op) { + const Block *destBB, const OperationInst *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -277,11 +277,11 @@ static bool verifyTerminatorSuccessors(const OperationInst *op) { bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { // Verify that the operation is at the end of the respective parent block. if (op->getFunction()->isML()) { - StmtBlock *block = op->getBlock(); - if (!block || block->getContainingStmt() || &block->back() != op) + Block *block = op->getBlock(); + if (!block || block->getContainingInst() || &block->back() != op) return op->emitOpError("must be the last statement in the ML function"); } else { - const BasicBlock *block = op->getBlock(); + const Block *block = op->getBlock(); if (!block || &block->back() != op) return op->emitOpError( "must be the last instruction in the parent basic block."); diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 96b44600460..6bd9944bb65 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -49,7 +49,7 @@ template <> unsigned InstOperand::getOperandNumber() const { } /// Return which operand this is in the operand list. -template <> unsigned StmtBlockOperand::getOperandNumber() const { +template <> unsigned BlockOperand::getOperandNumber() const { return this - &getOwner()->getBlockOperands()[0]; } @@ -79,7 +79,7 @@ void Statement::destroy() { } Statement *Statement::getParentStmt() const { - return block ? block->getContainingStmt() : nullptr; + return block ? block->getContainingInst() : nullptr; } Function *Statement::getFunction() const { @@ -191,12 +191,10 @@ void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) { stmt->destroy(); } -StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() { - size_t Offset( - size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr)))); +Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() { + size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this)); - return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) - - Offset); + return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); } /// This is a trait method invoked when a statement is added to a block. We @@ -221,7 +219,7 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( stmt_iterator last) { // If we are transferring statements within the same block, the block // pointer doesn't need to be updated. - StmtBlock *curParent = getContainingBlock(); + Block *curParent = getContainingBlock(); if (curParent == otherList.getContainingBlock()) return; @@ -230,11 +228,11 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( first->block = curParent; } -/// Remove this statement (and its descendants) from its StmtBlock and delete +/// Remove this statement (and its descendants) from its Block and delete /// all of them. void Statement::erase() { assert(getBlock() && "Statement has no block"); - getBlock()->getStatements().erase(this); + getBlock()->getInstructions().erase(this); } /// Unlink this statement from its current block and insert it right before @@ -246,10 +244,10 @@ void Statement::moveBefore(Statement *existingStmt) { /// Unlink this operation instruction from its current basic block and insert /// it right before `iterator` in the specified basic block. -void Statement::moveBefore(StmtBlock *block, +void Statement::moveBefore(Block *block, llvm::iplist<Statement>::iterator iterator) { - block->getStatements().splice(iterator, getBlock()->getStatements(), - getIterator()); + block->getInstructions().splice(iterator, getBlock()->getInstructions(), + getIterator()); } /// This drops all operand uses from this instruction, which is an essential @@ -273,7 +271,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, ArrayRef<Value *> operands, ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, - ArrayRef<StmtBlock *> successors, + ArrayRef<Block *> successors, MLIRContext *context) { unsigned numSuccessors = successors.size(); @@ -282,7 +280,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, unsigned numOperands = operands.size() - numSuccessors; auto byteSize = - totalSizeToAlloc<InstResult, StmtBlockOperand, unsigned, InstOperand>( + totalSizeToAlloc<InstResult, BlockOperand, unsigned, InstOperand>( resultTypes.size(), numSuccessors, numSuccessors, numOperands); void *rawMem = malloc(byteSize); @@ -340,7 +338,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, } new (&instBlockOperands[currentSuccNum]) - StmtBlockOperand(stmt, successors[currentSuccNum]); + BlockOperand(stmt, successors[currentSuccNum]); *succOperandCountIt = 0; ++currentSuccNum; continue; @@ -382,7 +380,7 @@ OperationInst::~OperationInst() { // Explicitly run the destructors for the successors. if (isTerminator()) for (auto &successor : getBlockOperands()) - successor.~StmtBlockOperand(); + successor.~BlockOperand(); } /// Return true if there are no users of any results of this operation. @@ -420,7 +418,7 @@ MLIRContext *OperationInst::getContext() const { bool OperationInst::isReturn() const { return isa<ReturnOp>(); } -void OperationInst::setSuccessor(BasicBlock *block, unsigned index) { +void OperationInst::setSuccessor(Block *block, unsigned index) { assert(index < getNumSuccessors()); getBlockOperands()[index].set(block); } @@ -559,7 +557,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, body(this), lbMap(lbMap), ubMap(ubMap), step(step) { // The body of a for stmt always has one block. - body.push_back(new StmtBlock()); + body.push_back(new Block()); operands.reserve(numOperands); } @@ -679,7 +677,7 @@ IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set) operands.reserve(numOperands); // The then of an 'if' stmt always has one block. - thenClause.push_back(new StmtBlock()); + thenClause.push_back(new Block()); } IfStmt::~IfStmt() { @@ -736,7 +734,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, }; SmallVector<Value *, 8> operands; - SmallVector<StmtBlock *, 2> successors; + SmallVector<Block *, 2> successors; if (auto *opStmt = dyn_cast<OperationInst>(this)) { operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); @@ -758,8 +756,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, successors.reserve(opStmt->getNumSuccessors()); for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e; ++succ) { - successors.push_back( - const_cast<StmtBlock *>(opStmt->getSuccessor(succ))); + successors.push_back(const_cast<Block *>(opStmt->getSuccessor(succ))); // Add sentinel to delineate successor operands. operands.push_back(nullptr); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 9b67ef8b150..6cc1aba72b3 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1921,7 +1921,7 @@ public: parseCustomOperation(const CreateOperationFunction &createOpFunc); /// Parse a single operation successor and it's operand list. - virtual bool parseSuccessorAndUseList(BasicBlock *&dest, + virtual bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value *> &operands) = 0; protected: @@ -2398,7 +2398,7 @@ public: return false; } - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value *> &operands) override { // Defer successor parsing to the function parsers. return parser.parseSuccessorAndUseList(dest, operands); @@ -2570,13 +2570,13 @@ public: ParseResult parseFunctionBody(); - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value *> &operands); private: Function *function; - llvm::StringMap<std::pair<BasicBlock *, SMLoc>> blocksByName; - DenseMap<BasicBlock *, SMLoc> forwardRef; + llvm::StringMap<std::pair<Block *, SMLoc>> blocksByName; + DenseMap<Block *, SMLoc> forwardRef; /// This builder intentionally shadows the builder in the base class, with a /// more specific builder type. @@ -2585,10 +2585,10 @@ private: /// Get the basic block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. - BasicBlock *getBlockNamed(StringRef name, SMLoc loc) { + Block *getBlockNamed(StringRef name, SMLoc loc) { auto &blockAndLoc = blocksByName[name]; if (!blockAndLoc.first) { - blockAndLoc.first = new BasicBlock(); + blockAndLoc.first = new Block(); forwardRef[blockAndLoc.first] = loc; function->push_back(blockAndLoc.first); blockAndLoc.second = loc; @@ -2597,9 +2597,9 @@ private: return blockAndLoc.first; } - // Define the basic block with the specified name. Returns the BasicBlock* or + // Define the basic block with the specified name. Returns the Block* or // nullptr in the case of redefinition. - BasicBlock *defineBlockNamed(StringRef name, SMLoc loc) { + Block *defineBlockNamed(StringRef name, SMLoc loc) { auto &blockAndLoc = blocksByName[name]; if (!blockAndLoc.first) { blockAndLoc.first = builder.createBlock(); @@ -2621,10 +2621,10 @@ private: } ParseResult - parseOptionalBasicBlockArgList(SmallVectorImpl<BlockArgument *> &results, - BasicBlock *owner); + parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, + Block *owner); - ParseResult parseBasicBlock(); + ParseResult parseBlock(); }; } // end anonymous namespace @@ -2634,7 +2634,7 @@ private: /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// bool CFGFunctionParser::parseSuccessorAndUseList( - BasicBlock *&dest, SmallVectorImpl<Value *> &operands) { + Block *&dest, SmallVectorImpl<Value *> &operands) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::bare_identifier)) return emitError("expected basic block name"); @@ -2656,8 +2656,8 @@ bool CFGFunctionParser::parseSuccessorAndUseList( /// /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* /// -ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList( - SmallVectorImpl<BlockArgument *> &results, BasicBlock *owner) { +ParseResult CFGFunctionParser::parseOptionalBlockArgList( + SmallVectorImpl<BlockArgument *> &results, Block *owner) { if (getToken().is(Token::r_brace)) return ParseSuccess; @@ -2684,12 +2684,12 @@ ParseResult CFGFunctionParser::parseFunctionBody() { // Parse the list of blocks. while (!consumeIf(Token::r_brace)) - if (parseBasicBlock()) + if (parseBlock()) return ParseFailure; // Verify that all referenced blocks were defined. if (!forwardRef.empty()) { - SmallVector<std::pair<const char *, BasicBlock *>, 4> errors; + SmallVector<std::pair<const char *, Block *>, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. for (auto entry : forwardRef) errors.push_back({entry.second.getPointer(), entry.first}); @@ -2721,7 +2721,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() { /// bb-id ::= bare-id /// bb-arg-list ::= `(` ssa-id-and-type-list? `)` /// -ParseResult CFGFunctionParser::parseBasicBlock() { +ParseResult CFGFunctionParser::parseBlock() { SMLoc nameLoc = getToken().getLoc(); auto name = getTokenSpelling(); if (parseToken(Token::bare_identifier, "expected basic block name")) @@ -2736,7 +2736,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() { // If an argument list is present, parse it. if (consumeIf(Token::l_paren)) { SmallVector<BlockArgument *, 8> bbArgs; - if (parseOptionalBasicBlockArgList(bbArgs, block) || + if (parseOptionalBlockArgList(bbArgs, block) || parseToken(Token::r_paren, "expected ')' to end argument list")) return ParseFailure; } @@ -2794,11 +2794,11 @@ private: ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); - ParseResult parseElseClause(StmtBlock *elseClause); - ParseResult parseStatements(StmtBlock *block); - ParseResult parseStmtBlock(StmtBlock *block); + ParseResult parseElseClause(Block *elseClause); + ParseResult parseStatements(Block *block); + ParseResult parseBlock(Block *block); - bool parseSuccessorAndUseList(BasicBlock *&dest, + bool parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value *> &operands) { assert(false && "MLFunctions do not have terminators with successors."); return true; @@ -2810,7 +2810,7 @@ ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); // Parse statements in this function. - if (parseStmtBlock(function->getBody())) + if (parseBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); @@ -2874,7 +2874,7 @@ ParseResult MLFunctionParser::parseForStmt() { // If parsing of the for statement body fails, // MLIR contains for statement with those nested statements that have been // successfully parsed. - if (parseStmtBlock(forStmt->getBody())) + if (parseBlock(forStmt->getBody())) return ParseFailure; // Reset insertion point to the current block. @@ -3118,12 +3118,12 @@ ParseResult MLFunctionParser::parseIfStmt() { IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), operands, set); - StmtBlock *thenClause = ifStmt->getThen(); + Block *thenClause = ifStmt->getThen(); // When parsing of an if statement body fails, the IR contains // the if statement with the portion of the body that has been // successfully parsed. - if (parseStmtBlock(thenClause)) + if (parseBlock(thenClause)) return ParseFailure; if (consumeIf(Token::kw_else)) { @@ -3138,19 +3138,19 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseSuccess; } -ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { +ParseResult MLFunctionParser::parseElseClause(Block *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); return parseIfStmt(); } - return parseStmtBlock(elseClause); + return parseBlock(elseClause); } /// /// Parse a list of statements ending with `return` or `}` /// -ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { +ParseResult MLFunctionParser::parseStatements(Block *block) { auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; @@ -3188,7 +3188,7 @@ ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { /// /// Parse `{` ml-stmt* `}` /// -ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) { +ParseResult MLFunctionParser::parseBlock(Block *block) { if (parseToken(Token::l_brace, "expected '{' before statement list") || parseStatements(block) || parseToken(Token::r_brace, "expected '}' after statement list")) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index e9942ff824b..0f130e19e26 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -52,7 +52,7 @@ public: bool runOnModule(Module &m, llvm::Module &llvmModule); private: - bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false); + bool convertBlock(const Block &bb, bool ignoreArguments = false); bool convertCFGFunction(const Function &cfgFunc, llvm::Function &llvmFunc); bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule); bool convertInstruction(const OperationInst &inst); @@ -142,7 +142,7 @@ private: llvm::DenseMap<const Function *, llvm::Function *> functionMapping; llvm::DenseMap<const Value *, llvm::Value *> valueMapping; - llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping; + llvm::DenseMap<const Block *, llvm::BasicBlock *> blockMapping; llvm::LLVMContext &llvmContext; llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder; llvm::IntegerType *indexType; @@ -742,8 +742,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) { return inst.emitError("unsupported operation"); } -bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, - bool ignoreArguments) { +bool ModuleLowerer::convertBlock(const Block &bb, bool ignoreArguments) { builder.SetInsertPoint(blockMapping[&bb]); // Before traversing instructions, make block arguments available through @@ -780,8 +779,7 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, // Get the SSA value passed to the current block from the terminator instruction // of its predecessor. -static const Value *getPHISourceValue(const BasicBlock *current, - const BasicBlock *pred, +static const Value *getPHISourceValue(const Block *current, const Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); if (terminator.isa<BranchOp>()) { @@ -804,7 +802,7 @@ void ModuleLowerer::connectPHINodes(const Function &cfgFunc) { // to the arguments of the LLVM function. for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit; ++it) { - const BasicBlock *bb = &*it; + const Block *bb = &*it; llvm::BasicBlock *llvmBB = blockMapping[bb]; auto phis = llvmBB->phis(); auto numArguments = bb->getNumArguments(); @@ -837,7 +835,7 @@ bool ModuleLowerer::convertCFGFunction(const Function &cfgFunc, // Then, convert blocks one by one. for (auto indexedBB : llvm::enumerate(cfgFunc)) { const auto &bb = indexedBB.value(); - if (convertBasicBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0)) + if (convertBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0)) return true; } @@ -872,7 +870,7 @@ bool ModuleLowerer::convertFunctions(const Module &mlirModule, // arguments of the first block are those of the function. assert(!functionPtr->getBlocks().empty() && "expected at least one basic block in a Function"); - const BasicBlock &firstBlock = *functionPtr->begin(); + const Block &firstBlock = *functionPtr->begin(); for (auto arg : llvm::enumerate(llvmFunc->args())) { valueMapping[firstBlock.getArgument(arg.index())] = &arg.value(); } diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 04f7cfdc3e9..a5b45ba4098 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -182,7 +182,7 @@ struct CFGCSE : public CSEImpl { // Check to see if we need to process this node. if (!currentNode->processed) { currentNode->processed = true; - simplifyBasicBlock(currentNode->node->getBlock()); + simplifyBlock(currentNode->node->getBlock()); // Otherwise, check to see if we need to process a child node. } else if (currentNode->childIterator != currentNode->node->end()) { auto *childNode = *(currentNode->childIterator++); @@ -199,7 +199,7 @@ struct CFGCSE : public CSEImpl { eraseDeadOperations(); } - void simplifyBasicBlock(BasicBlock *bb) { + void simplifyBlock(Block *bb) { for (auto &i : *bb) if (auto *opInst = dyn_cast<OperationInst>(&i)) simplifyOperation(opInst); diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 8c69fa61578..c97b83f8485 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -45,8 +45,8 @@ struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> { std::vector<OperationInst *> affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} - using StmtListType = llvm::iplist<Statement>; - void walk(StmtListType::iterator Start, StmtListType::iterator End); + using InstListType = llvm::iplist<Statement>; + void walk(InstListType::iterator Start, InstListType::iterator End); void visitOperationInst(OperationInst *stmt); PassResult runOnMLFunction(Function *f) override; using StmtWalker<ComposeAffineMaps>::walk; @@ -62,8 +62,8 @@ FunctionPass *mlir::createComposeAffineMapsPass() { return new ComposeAffineMaps(); } -void ComposeAffineMaps::walk(StmtListType::iterator Start, - StmtListType::iterator End) { +void ComposeAffineMaps::walk(InstListType::iterator Start, + InstListType::iterator End) { while (Start != End) { walk(&(*Start)); // Increment iterator after walk as visit function can mutate stmt list diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 270a25dd339..821f35ca539 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -50,7 +50,7 @@ public: private: Value *getConstantIndexValue(int64_t value); - void visitStmtBlock(StmtBlock *stmtBlock); + void visitBlock(Block *Block); Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range<OperationInst::result_iterator> values); @@ -117,8 +117,8 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) { } // Visit all statements in the given statement block. -void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { - for (auto &stmt : *stmtBlock) +void FunctionConverter::visitBlock(Block *Block) { + for (auto &stmt : *Block) this->visit(&stmt); } @@ -214,13 +214,13 @@ Value *FunctionConverter::buildMinMaxReductionSeq( void FunctionConverter::visitForStmt(ForStmt *forStmt) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). - BasicBlock *loopInsertionPoint = builder.getInsertionBlock(); + Block *loopInsertionPoint = builder.getInsertionBlock(); // Create blocks so that they appear in more human-readable order in the // output. - BasicBlock *loopInitBlock = builder.createBlock(); - BasicBlock *loopConditionBlock = builder.createBlock(); - BasicBlock *loopBodyFirstBlock = builder.createBlock(); + Block *loopInitBlock = builder.createBlock(); + Block *loopConditionBlock = builder.createBlock(); + Block *loopBodyFirstBlock = builder.createBlock(); // At the loop insertion location, branch immediately to the loop init block. builder.setInsertionPointToEnd(loopInsertionPoint); @@ -238,7 +238,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); - visitStmtBlock(forStmt->getBody()); + visitBlock(forStmt->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit @@ -254,7 +254,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { nextIvValue); // Create post-loop block here so that it appears after all loop body blocks. - BasicBlock *postLoopBlock = builder.createBlock(); + Block *postLoopBlock = builder.createBlock(); builder.setInsertionPointToEnd(loopInitBlock); // Compute loop bounds using affine_apply after remapping its operands. @@ -378,15 +378,15 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. - BasicBlock *ifInsertionBlock = builder.getInsertionBlock(); - SmallVector<BasicBlock *, 4> ifConditionExtraBlocks; + Block *ifInsertionBlock = builder.getInsertionBlock(); + SmallVector<Block *, 4> ifConditionExtraBlocks; unsigned numConstraints = integerSet.getNumConstraints(); ifConditionExtraBlocks.reserve(numConstraints - 1); for (unsigned i = 0, e = numConstraints - 1; i < e; ++i) { ifConditionExtraBlocks.push_back(builder.createBlock()); } - BasicBlock *thenBlock = builder.createBlock(); - BasicBlock *elseBlock = builder.createBlock(); + Block *thenBlock = builder.createBlock(); + Block *elseBlock = builder.createBlock(); builder.setInsertionPointToEnd(ifInsertionBlock); // Implement short-circuit logic. For each affine expression in the 'if' @@ -405,7 +405,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { ifConditionExtraBlocks)) { AffineExpr constraintExpr = std::get<0>(tuple); bool isEquality = std::get<1>(tuple); - BasicBlock *nextBlock = std::get<2>(tuple); + Block *nextBlock = std::get<2>(tuple); // Build and apply an affine map. auto affineMap = @@ -429,19 +429,19 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Recursively traverse the 'then' block. builder.setInsertionPointToEnd(thenBlock); - visitStmtBlock(ifStmt->getThen()); - BasicBlock *lastThenBlock = builder.getInsertionBlock(); + visitBlock(ifStmt->getThen()); + Block *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); if (ifStmt->hasElse()) - visitStmtBlock(ifStmt->getElse()); - BasicBlock *lastElseBlock = builder.getInsertionBlock(); + visitBlock(ifStmt->getElse()); + Block *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the // 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions // to the continuation block. - BasicBlock *continuationBlock = builder.createBlock(); + Block *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 925c50abfec..69344819ed8 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -176,7 +176,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, FuncBuilder prologue(forStmt); // DMAs for write regions are going to be inserted just after the for loop. FuncBuilder epilogue(forStmt->getBlock(), - std::next(StmtBlock::iterator(forStmt))); + std::next(Block::iterator(forStmt))); FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. @@ -382,7 +382,7 @@ static unsigned getNestingDepth(const Statement &stmt) { return depth; } -// TODO(bondhugula): make this run on a StmtBlock instead of a 'for' stmt. +// TODO(bondhugula): make this run on a Block instead of a 'for' stmt. void DmaGeneration::runOnForStmt(ForStmt *forStmt) { // For now (for testing purposes), we'll run this on the outermost among 'for' // stmt's with unit stride, i.e., right at the top of the tile if tiling has diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 2ddd613d6af..d31337437ad 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -343,7 +343,7 @@ public: // Intializes the data dependence graph by walking statements in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. -// TODO(andydavis) Add support for taking a StmtBlock arg to construct the +// TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index d6c1eed3a0c..109953f2296 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -58,8 +58,9 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } // Move the loop body of ForStmt 'src' from 'src' into the specified location in // destination's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest, - StmtBlock::iterator loc) { - dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements()); + Block::iterator loc) { + dest->getBody()->getInstructions().splice(loc, + src->getBody()->getInstructions()); } // Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. @@ -164,8 +165,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto *pointLoop = b.createFor(loc, 0, 0); - pointLoop->getBody()->getStatements().splice( - pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + pointLoop->getBody()->getInstructions().splice( + pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; topLoop = pointLoop; @@ -178,9 +179,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, FuncBuilder b(topLoop); // Loop bounds will be set later. auto *tileSpaceLoop = b.createFor(loc, 0, 0); - tileSpaceLoop->getBody()->getStatements().splice( - tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), - topLoop); + tileSpaceLoop->getBody()->getInstructions().splice( + tileSpaceLoop->getBody()->begin(), + topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; topLoop = tileSpaceLoop; } @@ -222,7 +223,7 @@ static void getTileableBands(Function *f, ForStmt *currStmt = root; do { band.push_back(currStmt); - } while (currStmt->getBody()->getStatements().size() == 1 && + } while (currStmt->getBody()->getInstructions().size() == 1 && (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin()))); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index c3651e53593..15ea0f841cc 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -91,9 +91,9 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { std::vector<ForStmt *> loops; // This method specialized to encode custom return logic. - using StmtListType = llvm::iplist<Statement>; - bool walkPostOrder(StmtListType::iterator Start, - StmtListType::iterator End) { + using InstListType = llvm::iplist<Statement>; + bool walkPostOrder(InstListType::iterator Start, + InstListType::iterator End) { bool hasInnerLoops = false; // We need to walk all elements since all innermost loops need to be // gathered as opposed to determining whether this list has any inner diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7ed9be19644..60e8d154f98 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -130,13 +130,13 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // tree). class JamBlockGatherer : public StmtWalker<JamBlockGatherer> { public: - using StmtListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Statement>; // Store iterators to the first and last stmt of each sub-block found. - std::vector<std::pair<StmtBlock::iterator, StmtBlock::iterator>> subBlocks; + std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; // This is a linear time walk. - void walk(StmtListType::iterator Start, StmtListType::iterator End) { + void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; while (it != End && !isa<ForStmt>(it)) @@ -194,7 +194,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { DenseMap<const Value *, Value *> operandMap; // Insert the cleanup loop right after 'forStmt'. FuncBuilder builder(forStmt->getBlock(), - std::next(StmtBlock::iterator(forStmt))); + std::next(Block::iterator(forStmt))); auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); cleanupForStmt->setLowerBoundMap( getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder)); diff --git a/mlir/lib/Transforms/LowerAffineApply.cpp b/mlir/lib/Transforms/LowerAffineApply.cpp index 52146fdb5b7..747733de41e 100644 --- a/mlir/lib/Transforms/LowerAffineApply.cpp +++ b/mlir/lib/Transforms/LowerAffineApply.cpp @@ -52,7 +52,7 @@ PassResult LowerAffineApply::runOnMLFunction(Function *f) { } PassResult LowerAffineApply::runOnCFGFunction(Function *f) { - for (BasicBlock &bb : *f) { + for (Block &bb : *f) { // Handle iterators with care because we erase in the same loop. // In particular, step to the next element before erasing the current one. for (auto it = bb.begin(); it != bb.end();) { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 4d24191dcb2..51577009abb 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // Get the ML function builder. // We need access to the Function builder stored internally in the // MLFunctionLoweringRewriter general rewriting API does not provide - // ML-specific functions (ForStmt and StmtBlock manipulation). While we could + // ML-specific functions (ForStmt and Block manipulation). While we could // forward them or define a whole rewriting chain based on MLFunctionBuilder // instead of Builer, the code for it would be duplicate boilerplate. As we // go towards unifying ML and CFG functions, this separation will disappear. diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index a0964a67fa6..c8a6ced4ed1 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -345,7 +345,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } // Get shifts stored in map. - std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size()); + std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size()); unsigned s = 0; for (auto &stmt : *forStmt->getBody()) { assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 7def4fe2f09..03b4bb29e19 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -108,7 +108,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { } else { const AffineBound lb = forStmt->getLowerBound(); SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt)); + FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt)); auto affineApplyOp = builder.create<AffineApplyOp>( forStmt->getLoc(), lb.getMap(), lbOperands); forStmt->replaceAllUsesWith(affineApplyOp->getResult(0)); @@ -116,14 +116,14 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { } // Move the loop body statements to the loop's containing block. auto *block = forStmt->getBlock(); - block->getStatements().splice(StmtBlock::iterator(forStmt), - forStmt->getBody()->getStatements()); + block->getInstructions().splice(Block::iterator(forStmt), + forStmt->getBody()->getInstructions()); forStmt->erase(); return true; } /// Promotes all single iteration for stmt's in the Function, i.e., moves -/// their body into the containing StmtBlock. +/// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> { @@ -223,7 +223,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, int64_t step = forStmt->getStep(); - unsigned numChildStmts = forStmt->getBody()->getStatements().size(); + unsigned numChildStmts = forStmt->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -379,7 +379,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { DenseMap<const Value *, Value *> operandMap; - FuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); + FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt)); auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); assert(clLbMap && @@ -408,7 +408,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Keep a pointer to the last statement in the original block so that we know // what to clone (since we are doing this in-place). - StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 2ce8af3613a..50a3cf5a595 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -28,16 +28,16 @@ template <> struct llvm::DOTGraphTraits<const Function *> : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - static std::string getNodeLabel(const BasicBlock *basicBlock, - const Function *); + static std::string getNodeLabel(const Block *Block, const Function *); }; -std::string llvm::DOTGraphTraits<const Function *>::getNodeLabel( - const BasicBlock *basicBlock, const Function *) { +std::string +llvm::DOTGraphTraits<const Function *>::getNodeLabel(const Block *Block, + const Function *) { // Reuse the print output for the node labels. std::string outStreamStr; raw_string_ostream os(outStreamStr); - basicBlock->print(os); + Block->print(os); std::string &outStr = os.str(); if (outStr[0] == '\n') |

