41 files changed, 995 insertions(+), 808 deletions(-)
diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index 7f6d799ceb7..6820ee8ad3f 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -51,7 +51,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt); /// whether indices[dim] is independent of the value `input`. // For now we assume no layout map or identity layout map in the MemRef. // TODO(ntv): support more than identity layout map. -bool isAccessInvariant(const MLValue &input, MemRefType *memRefType, +bool isAccessInvariant(const MLValue &input, MemRefType memRefType, llvm::ArrayRef<MLValue *> indices, unsigned dim); /// Checks whether all the LoadOp and StoreOp matched have access indexing diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index b84d20fe021..7c3039742c8 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -250,9 +250,9 @@ public: TypeAttr() = default; /* implicit */ TypeAttr(Attribute::ImplType *ptr); - static TypeAttr get(Type *type, MLIRContext *context); + static TypeAttr get(Type type, MLIRContext *context); - Type *getValue() const; + Type getValue() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Type; } @@ -277,7 +277,7 @@ public: Function *getValue() const; - FunctionType *getType() const; + FunctionType getType() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Function; } @@ -294,7 +294,7 @@ public: ElementsAttr() = default; /* implicit */ ElementsAttr(Attribute::ImplType *ptr); - VectorOrTensorType *getType() const; + VectorOrTensorType getType() const; /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(Kind kind) { @@ -313,7 +313,7 @@ public: SplatElementsAttr() = default; /* implicit */ SplatElementsAttr(Attribute::ImplType *ptr); - static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt); + static SplatElementsAttr get(VectorOrTensorType type, Attribute elt); Attribute getValue() const; /// Method for support type inquiry through isa, cast and dyn_cast. @@ -335,12 +335,12 @@ public: /// width specified by the element type (note all float type are 64 bits). /// When the value is retrieved, the bits are read from the storage and extend /// to 64 bits if necessary. - static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data); + static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data); // TODO: Read the data from the attribute list and compress them // to a character array. Then call the above method to construct the // attribute. 
- static DenseElementsAttr get(VectorOrTensorType *type, + static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<Attribute> values); void getValues(SmallVectorImpl<Attribute> &values) const; @@ -410,7 +410,7 @@ public: OpaqueElementsAttr() = default; /* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr); - static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes); + static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes); StringRef getValue() const; @@ -440,7 +440,7 @@ public: SparseElementsAttr() = default; /* implicit */ SparseElementsAttr(Attribute::ImplType *ptr); - static SparseElementsAttr get(VectorOrTensorType *type, + static SparseElementsAttr get(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values); diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h index c55d09c1ca6..cfae6af542c 100644 --- a/mlir/include/mlir/IR/BasicBlock.h +++ b/mlir/include/mlir/IR/BasicBlock.h @@ -64,10 +64,10 @@ public: bool args_empty() const { return arguments.empty(); } /// Add one value to the operand list. - BBArgument *addArgument(Type *type); + BBArgument *addArgument(Type type); /// Add one argument to the argument list for each type specified in the list. - llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types); + llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types); unsigned getNumArguments() const { return arguments.size(); } BBArgument *getArgument(unsigned i) { return arguments[i]; } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 2e48008c651..46952e2b2a4 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -68,29 +68,28 @@ public: unsigned column); // Types. - FloatType *getBF16Type(); - FloatType *getF16Type(); - FloatType *getF32Type(); - FloatType *getF64Type(); - - OtherType *getIndexType(); - OtherType *getTFControlType(); - OtherType *getTFStringType(); - OtherType *getTFResourceType(); - OtherType *getTFVariantType(); - OtherType *getTFComplex64Type(); - OtherType *getTFComplex128Type(); - OtherType *getTFF32REFType(); - - IntegerType *getIntegerType(unsigned width); - FunctionType *getFunctionType(ArrayRef<Type *> inputs, - ArrayRef<Type *> results); - MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition = {}, - unsigned memorySpace = 0); - VectorType *getVectorType(ArrayRef<int> shape, Type *elementType); - RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType); - UnrankedTensorType *getTensorType(Type *elementType); + FloatType getBF16Type(); + FloatType getF16Type(); + FloatType getF32Type(); + FloatType getF64Type(); + + OtherType getIndexType(); + OtherType getTFControlType(); + OtherType getTFStringType(); + OtherType getTFResourceType(); + OtherType getTFVariantType(); + OtherType getTFComplex64Type(); + OtherType getTFComplex128Type(); + OtherType getTFF32REFType(); + + IntegerType getIntegerType(unsigned width); + FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results); + MemRefType getMemRefType(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition = {}, + unsigned memorySpace = 0); + VectorType getVectorType(ArrayRef<int> shape, Type elementType); + RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType); + UnrankedTensorType getTensorType(Type elementType); // Attributes. 
@@ -102,15 +101,15 @@ public: ArrayAttr getArrayAttr(ArrayRef<Attribute> value); AffineMapAttr getAffineMapAttr(AffineMap map); IntegerSetAttr getIntegerSetAttr(IntegerSet set); - TypeAttr getTypeAttr(Type *type); + TypeAttr getTypeAttr(Type type); FunctionAttr getFunctionAttr(const Function *value); - ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt); - ElementsAttr getDenseElementsAttr(VectorOrTensorType *type, + ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt); + ElementsAttr getDenseElementsAttr(VectorOrTensorType type, ArrayRef<char> data); - ElementsAttr getSparseElementsAttr(VectorOrTensorType *type, + ElementsAttr getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values); - ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes); + ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes); // Affine expressions and affine maps. AffineExpr getAffineDimExpr(unsigned position); @@ -366,7 +365,7 @@ public: /// Creates an operation given the fields. OperationStmt *createOperation(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> types, + ArrayRef<Type> types, ArrayRef<NamedAttribute> attrs); /// Create operation of specific op type at the current insertion point. diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 88d4d812ba0..5d810a91e4e 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -96,7 +96,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands, public: /// Builds a constant op with the specified attribute value and result type. static void build(Builder *builder, OperationState *result, Attribute value, - Type *type); + Type type); Attribute getValue() const { return getAttr("value"); } @@ -123,7 +123,7 @@ class ConstantFloatOp : public ConstantOp { public: /// Builds a constant float op producing a float of the specified type. static void build(Builder *builder, OperationState *result, - const APFloat &value, FloatType *type); + const APFloat &value, FloatType type); APFloat getValue() const { return getAttrOfType<FloatAttr>("value").getValue(); @@ -150,7 +150,7 @@ public: /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. static void build(Builder *builder, OperationState *result, int64_t value, - Type *type); + Type type); int64_t getValue() const { return getAttrOfType<IntegerAttr>("value").getValue(); diff --git a/mlir/include/mlir/IR/CFGFunction.h b/mlir/include/mlir/IR/CFGFunction.h index f3c1da37908..fb20a6b1ef7 100644 --- a/mlir/include/mlir/IR/CFGFunction.h +++ b/mlir/include/mlir/IR/CFGFunction.h @@ -27,7 +27,7 @@ namespace mlir { // blocks, each of which includes instructions. class CFGFunction : public Function { public: - CFGFunction(Location *location, StringRef name, FunctionType *type, + CFGFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs = {}); ~CFGFunction(); diff --git a/mlir/include/mlir/IR/CFGValue.h b/mlir/include/mlir/IR/CFGValue.h index 939073c2382..45b36c1bd05 100644 --- a/mlir/include/mlir/IR/CFGValue.h +++ b/mlir/include/mlir/IR/CFGValue.h @@ -66,7 +66,7 @@ public: } protected: - CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {} + CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {} }; /// Basic block arguments are CFG Values. 
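Taken together, the Builders.h and BuiltinOps.h hunks above mean call sites now receive types by value rather than through pointers. The following sketch is illustrative only (the function name, the OperationState handling, and the concrete shapes are made up); it just exercises the post-patch signatures declared above.

#include "llvm/ADT/APFloat.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

// Sketch of a call site after this patch: Builder factories hand back Type
// values ('FloatType f32', not 'FloatType *f32'), which are then passed by
// value into op build hooks such as ConstantFloatOp::build.
void exampleBuilderUse(mlir::Builder &b, mlir::OperationState *state) {
  mlir::FloatType f32 = b.getF32Type();
  mlir::MemRefType memref =
      b.getMemRefType(/*shape=*/{4, 4}, /*elementType=*/f32);
  (void)memref;

  llvm::APFloat value(1.0f);
  mlir::ConstantFloatOp::build(&b, state, value, f32); // FloatType by value
}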
@@ -87,7 +87,7 @@ public: private: friend class BasicBlock; // For access to private constructor. - BBArgument(Type *type, BasicBlock *owner) + BBArgument(Type type, BasicBlock *owner) : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {} /// The owner of this operand. @@ -99,7 +99,7 @@ private: /// Instruction results are CFG Values. class InstResult : public CFGValue { public: - InstResult(Type *type, OperationInst *owner) + InstResult(Type type, OperationInst *owner) : CFGValue(CFGValueKind::InstResult, type), owner(owner) {} static bool classof(const SSAValue *value) { diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index d42f52851a8..04acc59b36b 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ilist.h" @@ -55,7 +56,7 @@ public: Identifier getName() const { return nameAndKind.getPointer(); } /// Return the type of this function. - FunctionType *getType() const { return type; } + FunctionType getType() const { return type; } /// Returns all of the attributes on this function. ArrayRef<NamedAttribute> getAttrs() const; @@ -93,7 +94,7 @@ public: void emitNote(const Twine &message) const; protected: - Function(Kind kind, Location *location, StringRef name, FunctionType *type, + Function(Kind kind, Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs = {}); ~Function(); @@ -108,7 +109,7 @@ private: Location *location; /// The type of the function. - FunctionType *const type; + FunctionType type; /// This holds general named attributes for the function. AttributeListStorage *attrs; @@ -121,7 +122,7 @@ private: /// defined in some other module. class ExtFunction : public Function { public: - ExtFunction(Location *location, StringRef name, FunctionType *type, + ExtFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs = {}); /// Methods for support type inquiry through isa, cast, and dyn_cast. diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index e74c5616a45..6d5a1ca8b73 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -202,7 +202,7 @@ public: /// Create a new OperationInst with the specified fields. static OperationInst *create(Location *location, OperationName name, ArrayRef<CFGValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context); diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h index 104692f538d..cf0deb9d33c 100644 --- a/mlir/include/mlir/IR/MLFunction.h +++ b/mlir/include/mlir/IR/MLFunction.h @@ -41,7 +41,7 @@ class MLFunction final public: /// Creates a new MLFunction with the specific type. static MLFunction *create(Location *location, StringRef name, - FunctionType *type, + FunctionType type, ArrayRef<NamedAttribute> attrs = {}); /// Destroys this statement and its subclass data. @@ -52,7 +52,7 @@ public: //===--------------------------------------------------------------------===// /// Returns number of arguments. - unsigned getNumArguments() const { return getType()->getInputs().size(); } + unsigned getNumArguments() const { return getType().getInputs().size(); } /// Gets argument. 
MLFuncArgument *getArgument(unsigned idx) { @@ -103,13 +103,13 @@ public: } private: - MLFunction(Location *location, StringRef name, FunctionType *type, + MLFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs = {}); // This stuff is used by the TrailingObjects template. friend llvm::TrailingObjects<MLFunction, MLFuncArgument>; size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const { - return getType()->getInputs().size(); + return getType().getInputs().size(); } // Internal functions to get argument list used by getArgument() methods. diff --git a/mlir/include/mlir/IR/MLValue.h b/mlir/include/mlir/IR/MLValue.h index 1961da13d5e..0c6c0b22696 100644 --- a/mlir/include/mlir/IR/MLValue.h +++ b/mlir/include/mlir/IR/MLValue.h @@ -73,7 +73,7 @@ public: } protected: - MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {} + MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {} }; /// This is the value defined by an argument of an ML function. @@ -93,7 +93,7 @@ public: private: friend class MLFunction; // For access to private constructor. - MLFuncArgument(Type *type, MLFunction *owner) + MLFuncArgument(Type type, MLFunction *owner) : MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {} /// The owner of this operand. @@ -105,7 +105,7 @@ private: /// This is a value defined by a result of an operation instruction. class StmtResult : public MLValue { public: - StmtResult(Type *type, OperationStmt *owner) + StmtResult(Type type, OperationStmt *owner) : MLValue(MLValueKind::StmtResult, type), owner(owner) {} static bool classof(const SSAValue *value) { diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 06013a72097..ad97dd2865c 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -71,13 +71,13 @@ struct constant_int_op_binder { bool match(Operation *op) { if (auto constOp = op->dyn_cast<ConstantOp>()) { - auto *type = constOp->getResult()->getType(); + auto type = constOp->getResult()->getType(); auto attr = constOp->getAttr("value"); - if (isa<IntegerType>(type)) { + if (type.isa<IntegerType>()) { return attr_value_binder<IntegerAttr>(bind_value).match(attr); } - if (isa<VectorOrTensorType>(type)) { + if (type.isa<VectorOrTensorType>()) { if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { return attr_value_binder<IntegerAttr>(bind_value) .match(splatAttr.getValue()); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 821beb29624..c2bc35728cb 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -493,7 +493,7 @@ public: return this->getOperation()->getResult(0); } - Type *getType() const { return getResult()->getType(); } + Type getType() const { return getResult()->getType(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. 
When this returns @@ -539,7 +539,7 @@ public: return this->getOperation()->getResult(i); } - Type *getType(unsigned i) const { return getResult(i)->getType(); } + Type getType(unsigned i) const { return getResult(i)->getType(); } static bool verifyTrait(const Operation *op) { return impl::verifyNResults(op, N); @@ -565,7 +565,7 @@ public: return this->getOperation()->getResult(i); } - Type *getType(unsigned i) const { return getResult(i)->getType(); } + Type getType(unsigned i) const { return getResult(i)->getType(); } static bool verifyTrait(const Operation *op) { return impl::verifyAtLeastNResults(op, N); @@ -803,7 +803,7 @@ protected: // which avoids them being template instantiated/duplicated. namespace impl { void buildCastOp(Builder *builder, OperationState *result, SSAValue *source, - Type *destType); + Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); void printCastOp(const Operation *op, OpAsmPrinter *p); } // namespace impl @@ -819,7 +819,7 @@ class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult, OpTrait::HasNoSideEffect, Traits...> { public: static void build(Builder *builder, OperationState *result, SSAValue *source, - Type *destType) { + Type destType) { impl::buildCastOp(builder, result, source, destType); } static bool parse(OpAsmParser *parser, OperationState *result) { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 09ec3f94136..ae8df556177 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -67,7 +67,7 @@ public: printOperand(*it); } } - virtual void printType(const Type *type) = 0; + virtual void printType(Type type) = 0; virtual void printFunctionReference(const Function *func) = 0; virtual void printAttribute(Attribute attr) = 0; virtual void printAffineMap(AffineMap map) = 0; @@ -95,8 +95,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) { return p; } -inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) { - p.printType(&type); +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { + p.printType(type); return p; } @@ -163,20 +163,20 @@ public: virtual bool parseComma() = 0; /// Parse a colon followed by a type. - virtual bool parseColonType(Type *&result) = 0; + virtual bool parseColonType(Type &result) = 0; /// Parse a type of a specific kind, e.g. a FunctionType. - template <typename TypeType> bool parseColonType(TypeType *&result) { + template <typename TypeType> bool parseColonType(TypeType &result) { llvm::SMLoc loc; getCurrentLocation(&loc); // Parse any kind of type. - Type *type; + Type type; if (parseColonType(type)) return true; // Check for the right kind of attribute. - result = dyn_cast<TypeType>(type); + result = type.dyn_cast<TypeType>(); if (!result) { emitError(loc, "invalid kind of type specified"); return true; @@ -186,15 +186,15 @@ public: } /// Parse a colon followed by a type list, which must have at least one type. - virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0; + virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0; /// Parse a keyword followed by a type. - virtual bool parseKeywordType(const char *keyword, Type *&result) = 0; + virtual bool parseKeywordType(const char *keyword, Type &result) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and /// chain through || operators. 
- bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) { + bool addTypeToList(Type type, SmallVectorImpl<Type> &result) { result.push_back(type); return false; } @@ -202,7 +202,7 @@ public: /// Add the specified types to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and /// chain through || operators. - bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) { + bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) { result.append(types.begin(), types.end()); return false; } @@ -288,13 +288,13 @@ public: /// Resolve an operand to an SSA value, emitting an error and returning true /// on failure. - virtual bool resolveOperand(const OperandType &operand, Type *type, + virtual bool resolveOperand(const OperandType &operand, Type type, SmallVectorImpl<SSAValue *> &result) = 0; /// Resolve a list of operands to SSA values, emitting an error and returning /// true on failure, or appending the results to the list on success. /// This method should be used when all operands have the same type. - virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type, + virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type, SmallVectorImpl<SSAValue *> &result) { for (auto elt : operands) if (resolveOperand(elt, type, result)) @@ -306,7 +306,7 @@ public: /// emitting an error and returning true on failure, or appending the results /// to the list on success. virtual bool resolveOperands(ArrayRef<OperandType> operands, - ArrayRef<Type *> types, llvm::SMLoc loc, + ArrayRef<Type> types, llvm::SMLoc loc, SmallVectorImpl<SSAValue *> &result) { if (operands.size() != types.size()) return emitError(loc, Twine(operands.size()) + @@ -321,7 +321,7 @@ public: } /// Resolve a parse function name and a type into a function reference. - virtual bool resolveFunctionName(StringRef name, FunctionType *type, + virtual bool resolveFunctionName(StringRef name, FunctionType type, llvm::SMLoc loc, Function *&result) = 0; /// Emit a diagnostic at the specified location and return true. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index ddf1aee68c5..6833e45f632 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -25,6 +25,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/Types.h" #include "llvm/ADT/PointerUnion.h" #include <memory> @@ -191,7 +192,7 @@ struct OperationState { OperationName name; SmallVector<SSAValue *, 4> operands; /// Types of the results of this operation. 
- SmallVector<Type *, 4> types; + SmallVector<Type, 4> types; SmallVector<NamedAttribute, 4> attributes; public: @@ -202,7 +203,7 @@ public: : context(context), location(location), name(name) {} OperationState(MLIRContext *context, Location *location, StringRef name, - ArrayRef<SSAValue *> operands, ArrayRef<Type *> types, + ArrayRef<SSAValue *> operands, ArrayRef<Type> types, ArrayRef<NamedAttribute> attributes = {}) : context(context), location(location), name(name, context), operands(operands.begin(), operands.end()), @@ -213,7 +214,7 @@ public: operands.append(newOperands.begin(), newOperands.end()); } - void addTypes(ArrayRef<Type *> newTypes) { + void addTypes(ArrayRef<Type> newTypes) { types.append(newTypes.begin(), newTypes.end()); } diff --git a/mlir/include/mlir/IR/SSAValue.h b/mlir/include/mlir/IR/SSAValue.h index 93db6fa99ca..ab16c981eda 100644 --- a/mlir/include/mlir/IR/SSAValue.h +++ b/mlir/include/mlir/IR/SSAValue.h @@ -25,7 +25,6 @@ #include "mlir/IR/Types.h" #include "mlir/IR/UseDefLists.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/PointerIntPair.h" namespace mlir { class Function; @@ -51,7 +50,7 @@ public: SSAValueKind getKind() const { return typeAndKind.getInt(); } - Type *getType() const { return typeAndKind.getPointer(); } + Type getType() const { return typeAndKind.getPointer(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns @@ -93,9 +92,10 @@ public: void dump() const; protected: - SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {} + SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {} + private: - const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind; + const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind; }; inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) { @@ -127,7 +127,7 @@ public: inline use_range getUses() const; protected: - SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {} + SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {} }; // Utility functions for iterating through SSAValue uses. diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 7e7a49ffa15..8a5a9b5e192 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -44,7 +44,7 @@ public: /// Create a new OperationStmt with the specific fields. static OperationStmt *create(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context); @@ -329,7 +329,7 @@ public: //===--------------------------------------------------------------------===// /// Return the context this operation is associated with. 
- MLIRContext *getContext() const { return getType()->getContext(); } + MLIRContext *getContext() const { return getType().getContext(); } using Statement::dump; using Statement::print; diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 3d0afdf607d..493f607964b 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -20,6 +20,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMapInfo.h" namespace mlir { class AffineMap; @@ -28,6 +29,22 @@ class IntegerType; class FloatType; class OtherType; +namespace detail { + +class TypeStorage; +class IntegerTypeStorage; +class FloatTypeStorage; +struct OtherTypeStorage; +struct FunctionTypeStorage; +struct VectorOrTensorTypeStorage; +struct VectorTypeStorage; +struct TensorTypeStorage; +struct RankedTensorTypeStorage; +struct UnrankedTensorTypeStorage; +struct MemRefTypeStorage; + +} // namespace detail + /// Instances of the Type class are immutable, uniqued, immortal, and owned by /// MLIRContext. As such, they are passed around by raw non-const pointer. /// @@ -68,11 +85,34 @@ public: MemRef, }; + using ImplType = detail::TypeStorage; + + Type() : type(nullptr) {} + /* implicit */ Type(const ImplType *type) + : type(const_cast<ImplType *>(type)) {} + + Type(const Type &other) : type(other.type) {} + Type &operator=(Type other) { + type = other.type; + return *this; + } + + bool operator==(Type other) const { return type == other.type; } + bool operator!=(Type other) const { return !(*this == other); } + explicit operator bool() const { return type; } + + bool operator!() const { return type == nullptr; } + + template <typename U> bool isa() const; + template <typename U> U dyn_cast() const; + template <typename U> U dyn_cast_or_null() const; + template <typename U> U cast() const; + /// Return the classification for this type. - Kind getKind() const { return kind; } + Kind getKind() const; /// Return the LLVMContext in which this type was uniqued. - MLIRContext *getContext() const { return context; } + MLIRContext *getContext() const; // Convenience predicates. This is only for 'other' and floating point types, // derived types should use isa/dyn_cast. @@ -97,56 +137,42 @@ public: unsigned getBitWidth() const; // Convenience factories. 
- static IntegerType *getInteger(unsigned width, MLIRContext *ctx); - static FloatType *getBF16(MLIRContext *ctx); - static FloatType *getF16(MLIRContext *ctx); - static FloatType *getF32(MLIRContext *ctx); - static FloatType *getF64(MLIRContext *ctx); - static OtherType *getIndex(MLIRContext *ctx); - static OtherType *getTFControl(MLIRContext *ctx); - static OtherType *getTFString(MLIRContext *ctx); - static OtherType *getTFResource(MLIRContext *ctx); - static OtherType *getTFVariant(MLIRContext *ctx); - static OtherType *getTFComplex64(MLIRContext *ctx); - static OtherType *getTFComplex128(MLIRContext *ctx); - static OtherType *getTFF32REF(MLIRContext *ctx); + static IntegerType getInteger(unsigned width, MLIRContext *ctx); + static FloatType getBF16(MLIRContext *ctx); + static FloatType getF16(MLIRContext *ctx); + static FloatType getF32(MLIRContext *ctx); + static FloatType getF64(MLIRContext *ctx); + static OtherType getIndex(MLIRContext *ctx); + static OtherType getTFControl(MLIRContext *ctx); + static OtherType getTFString(MLIRContext *ctx); + static OtherType getTFResource(MLIRContext *ctx); + static OtherType getTFVariant(MLIRContext *ctx); + static OtherType getTFComplex64(MLIRContext *ctx); + static OtherType getTFComplex128(MLIRContext *ctx); + static OtherType getTFF32REF(MLIRContext *ctx); /// Print the current type. void print(raw_ostream &os) const; void dump() const; -protected: - explicit Type(Kind kind, MLIRContext *context) - : context(context), kind(kind), subclassData(0) {} - explicit Type(Kind kind, MLIRContext *context, unsigned subClassData) - : Type(kind, context) { - setSubclassData(subClassData); - } - - ~Type() {} + friend ::llvm::hash_code hash_value(Type arg); - unsigned getSubclassData() const { return subclassData; } + unsigned getSubclassData() const; + void setSubclassData(unsigned val); - void setSubclassData(unsigned val) { - subclassData = val; - // Ensure we don't have any accidental truncation. - assert(getSubclassData() == val && "Subclass data too large for field"); + /// Methods for supporting PointerLikeTypeTraits. + const void *getAsOpaquePointer() const { + return static_cast<const void *>(type); + } + static Type getFromOpaquePointer(const void *pointer) { + return Type((ImplType *)(pointer)); } -private: - Type(const Type&) = delete; - void operator=(const Type&) = delete; - /// This refers to the MLIRContext in which this type was uniqued. - MLIRContext *const context; - - /// Classification of the subclass, used for type checking. - Kind kind : 8; - - // Space for subclasses to store data. - unsigned subclassData : 24; +protected: + ImplType *type; }; -inline raw_ostream &operator<<(raw_ostream &os, const Type &type) { +inline raw_ostream &operator<<(raw_ostream &os, Type type) { type.print(os); return os; } @@ -154,148 +180,138 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) { /// Integer types can have arbitrary bitwidth up to a large fixed limit. class IntegerType : public Type { public: - static IntegerType *get(unsigned width, MLIRContext *context); + using ImplType = detail::IntegerTypeStorage; + IntegerType() = default; + /* implicit */ IntegerType(Type::ImplType *ptr); + + static IntegerType get(unsigned width, MLIRContext *context); /// Return the bitwidth of this integer type. - unsigned getWidth() const { - return width; - } + unsigned getWidth() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. 
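The net effect of the Types.h rewrite above is that Type behaves like MLIR's other value-semantic handles (AffineMap, Attribute): default-constructible as a null handle, compared with ==, and cast through member templates instead of free isa/dyn_cast on pointers. A minimal sketch using only the factories and members declared above (the function name is illustrative):

#include <cassert>

#include "mlir/IR/Types.h"

// Illustrative only: exercises the value-semantic Type API declared above.
void exampleTypeUsage(mlir::MLIRContext *ctx) {
  mlir::Type f32 = mlir::Type::getF32(ctx);      // returned by value
  mlir::Type i32 = mlir::IntegerType::get(32, ctx);

  assert(f32 != i32);                            // operator==/operator!=
  mlir::Type empty;                              // default-constructed == null
  assert(!empty && "null Type handles are falsy");

  // Member-template casts replace free isa/dyn_cast on Type*.
  if (auto intTy = i32.dyn_cast<mlir::IntegerType>())
    assert(intTy.getWidth() == 32);
}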
- static bool classof(const Type *type) { - return type->getKind() == Kind::Integer; - } + static bool kindof(Kind kind) { return kind == Kind::Integer; } /// Integer representation maximal bitwidth. static constexpr unsigned kMaxWidth = 4096; -private: - unsigned width; - IntegerType(unsigned width, MLIRContext *context); - ~IntegerType() = delete; }; -inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) { +inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) { return IntegerType::get(width, ctx); } /// Return true if this is an integer type with the specified width. inline bool Type::isInteger(unsigned width) const { - if (auto *intTy = dyn_cast<IntegerType>(this)) - return intTy->getWidth() == width; + if (auto intTy = dyn_cast<IntegerType>()) + return intTy.getWidth() == width; return false; } class FloatType : public Type { public: - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *type) { - return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE && - type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE; - } + using ImplType = detail::FloatTypeStorage; + FloatType() = default; + /* implicit */ FloatType(Type::ImplType *ptr); - static FloatType *get(Kind kind, MLIRContext *context); + static FloatType get(Kind kind, MLIRContext *context); -private: - FloatType(Kind kind, MLIRContext *context); - ~FloatType() = delete; + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(Kind kind) { + return kind >= Kind::FIRST_FLOATING_POINT_TYPE && + kind <= Kind::LAST_FLOATING_POINT_TYPE; + } }; -inline FloatType *Type::getBF16(MLIRContext *ctx) { +inline FloatType Type::getBF16(MLIRContext *ctx) { return FloatType::get(Kind::BF16, ctx); } -inline FloatType *Type::getF16(MLIRContext *ctx) { +inline FloatType Type::getF16(MLIRContext *ctx) { return FloatType::get(Kind::F16, ctx); } -inline FloatType *Type::getF32(MLIRContext *ctx) { +inline FloatType Type::getF32(MLIRContext *ctx) { return FloatType::get(Kind::F32, ctx); } -inline FloatType *Type::getF64(MLIRContext *ctx) { +inline FloatType Type::getF64(MLIRContext *ctx) { return FloatType::get(Kind::F64, ctx); } /// This is a type for the random collection of special base types. class OtherType : public Type { public: + using ImplType = detail::OtherTypeStorage; + OtherType() = default; + /* implicit */ OtherType(Type::ImplType *ptr); + + static OtherType get(Kind kind, MLIRContext *context); + /// Methods for support type inquiry through isa, cast, and dyn_cast. 
- static bool classof(const Type *type) { - return type->getKind() >= Kind::FIRST_OTHER_TYPE && - type->getKind() <= Kind::LAST_OTHER_TYPE; + static bool kindof(Kind kind) { + return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE; } - static OtherType *get(Kind kind, MLIRContext *context); - -private: - OtherType(Kind kind, MLIRContext *context); - ~OtherType() = delete; }; -inline OtherType *Type::getIndex(MLIRContext *ctx) { +inline OtherType Type::getIndex(MLIRContext *ctx) { return OtherType::get(Kind::Index, ctx); } -inline OtherType *Type::getTFControl(MLIRContext *ctx) { +inline OtherType Type::getTFControl(MLIRContext *ctx) { return OtherType::get(Kind::TFControl, ctx); } -inline OtherType *Type::getTFResource(MLIRContext *ctx) { +inline OtherType Type::getTFResource(MLIRContext *ctx) { return OtherType::get(Kind::TFResource, ctx); } -inline OtherType *Type::getTFString(MLIRContext *ctx) { +inline OtherType Type::getTFString(MLIRContext *ctx) { return OtherType::get(Kind::TFString, ctx); } -inline OtherType *Type::getTFVariant(MLIRContext *ctx) { +inline OtherType Type::getTFVariant(MLIRContext *ctx) { return OtherType::get(Kind::TFVariant, ctx); } -inline OtherType *Type::getTFComplex64(MLIRContext *ctx) { +inline OtherType Type::getTFComplex64(MLIRContext *ctx) { return OtherType::get(Kind::TFComplex64, ctx); } -inline OtherType *Type::getTFComplex128(MLIRContext *ctx) { +inline OtherType Type::getTFComplex128(MLIRContext *ctx) { return OtherType::get(Kind::TFComplex128, ctx); } -inline OtherType *Type::getTFF32REF(MLIRContext *ctx) { +inline OtherType Type::getTFF32REF(MLIRContext *ctx) { return OtherType::get(Kind::TFF32REF, ctx); } /// Function types map from a list of inputs to a list of results. class FunctionType : public Type { public: - static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results, - MLIRContext *context); + using ImplType = detail::FunctionTypeStorage; + FunctionType() = default; + /* implicit */ FunctionType(Type::ImplType *ptr); + + static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results, + MLIRContext *context); // Input types. unsigned getNumInputs() const { return getSubclassData(); } - Type *getInput(unsigned i) const { return getInputs()[i]; } + Type getInput(unsigned i) const { return getInputs()[i]; } - ArrayRef<Type*> getInputs() const { - return ArrayRef<Type *>(inputsAndResults, getNumInputs()); - } + ArrayRef<Type> getInputs() const; // Result types. - unsigned getNumResults() const { return numResults; } + unsigned getNumResults() const; - Type *getResult(unsigned i) const { return getResults()[i]; } + Type getResult(unsigned i) const { return getResults()[i]; } - ArrayRef<Type*> getResults() const { - return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults); - } + ArrayRef<Type> getResults() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *type) { - return type->getKind() == Kind::Function; - } - -private: - unsigned numResults; - Type *const *inputsAndResults; - - FunctionType(Type *const *inputsAndResults, unsigned numInputs, - unsigned numResults, MLIRContext *context); - ~FunctionType() = delete; + static bool kindof(Kind kind) { return kind == Kind::Function; } }; /// This is a common base class between Vector, UnrankedTensor, and RankedTensor /// types, because many operations work on values of these aggregate types. 
class VectorOrTensorType : public Type { public: - Type *getElementType() const { return elementType; } + using ImplType = detail::VectorOrTensorTypeStorage; + VectorOrTensorType() = default; + /* implicit */ VectorOrTensorType(Type::ImplType *ptr); + + Type getElementType() const; /// If this is ranked tensor or vector type, return the number of elements. If /// it is an unranked tensor or vector, abort. @@ -319,56 +335,40 @@ public: int getDimSize(unsigned i) const; /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *type) { - return type->getKind() == Kind::Vector || - type->getKind() == Kind::RankedTensor || - type->getKind() == Kind::UnrankedTensor; + static bool kindof(Kind kind) { + return kind == Kind::Vector || kind == Kind::RankedTensor || + kind == Kind::UnrankedTensor; } - -public: - Type *elementType; - - VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType, - unsigned subClassData = 0); }; /// Vector types represent multi-dimensional SIMD vectors, and have a fixed /// known constant shape with one or more dimension. class VectorType : public VectorOrTensorType { public: - static VectorType *get(ArrayRef<int> shape, Type *elementType); - - ArrayRef<int> getShape() const { - return ArrayRef<int>(shapeElements, getSubclassData()); - } + using ImplType = detail::VectorTypeStorage; + VectorType() = default; + /* implicit */ VectorType(Type::ImplType *ptr); - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *type) { - return type->getKind() == Kind::Vector; - } + static VectorType get(ArrayRef<int> shape, Type elementType); -private: - const int *shapeElements; - Type *elementType; + ArrayRef<int> getShape() const; - VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context); - ~VectorType() = delete; + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(Kind kind) { return kind == Kind::Vector; } }; /// Tensor types represent multi-dimensional arrays, and have two variants: /// RankedTensorType and UnrankedTensorType. class TensorType : public VectorOrTensorType { public: + using ImplType = detail::TensorTypeStorage; + TensorType() = default; + /* implicit */ TensorType(Type::ImplType *ptr); /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *type) { - return type->getKind() == Kind::RankedTensor || - type->getKind() == Kind::UnrankedTensor; + static bool kindof(Kind kind) { + return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor; } - -protected: - TensorType(Kind kind, Type *elementType, MLIRContext *context); - ~TensorType() {} }; /// Ranked tensor types represent multi-dimensional arrays that have a shape @@ -376,40 +376,30 @@ protected: /// integer or unknown (represented -1). 
class RankedTensorType : public TensorType { public: - static RankedTensorType *get(ArrayRef<int> shape, - Type *elementType); + using ImplType = detail::RankedTensorTypeStorage; + RankedTensorType() = default; + /* implicit */ RankedTensorType(Type::ImplType *ptr); - ArrayRef<int> getShape() const { - return ArrayRef<int>(shapeElements, getSubclassData()); - } - - static bool classof(const Type *type) { - return type->getKind() == Kind::RankedTensor; - } + static RankedTensorType get(ArrayRef<int> shape, Type elementType); -private: - const int *shapeElements; + ArrayRef<int> getShape() const; - RankedTensorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context); - ~RankedTensorType() = delete; + static bool kindof(Kind kind) { return kind == Kind::RankedTensor; } }; /// Unranked tensor types represent multi-dimensional arrays that have an /// unknown shape. class UnrankedTensorType : public TensorType { public: - static UnrankedTensorType *get(Type *elementType); + using ImplType = detail::UnrankedTensorTypeStorage; + UnrankedTensorType() = default; + /* implicit */ UnrankedTensorType(Type::ImplType *ptr); - ArrayRef<int> getShape() const { return ArrayRef<int>(); } + static UnrankedTensorType get(Type elementType); - static bool classof(const Type *type) { - return type->getKind() == Kind::UnrankedTensor; - } + ArrayRef<int> getShape() const { return ArrayRef<int>(); } -private: - UnrankedTensorType(Type *elementType, MLIRContext *context); - ~UnrankedTensorType() = delete; + static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; } }; /// MemRef types represent a region of memory that have a shape with a fixed @@ -418,62 +408,96 @@ private: /// affine map composition, represented as an array AffineMap pointers. class MemRefType : public Type { public: + using ImplType = detail::MemRefTypeStorage; + MemRefType() = default; + /* implicit */ MemRefType(Type::ImplType *ptr); + /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. - static MemRefType *get(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace); + static MemRefType get(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace); unsigned getRank() const { return getShape().size(); } /// Returns an array of memref shape dimension sizes. - ArrayRef<int> getShape() const { - return ArrayRef<int>(shapeElements, getSubclassData()); - } + ArrayRef<int> getShape() const; /// Return the size of the specified dimension, or -1 if unspecified. int getDimSize(unsigned i) const { return getShape()[i]; } /// Returns the elemental type for this memref shape. - Type *getElementType() const { return elementType; } + Type getElementType() const; /// Returns an array of affine map pointers representing the memref affine /// map composition. ArrayRef<AffineMap> getAffineMaps() const; /// Returns the memory space in which data referred to by this memref resides. - unsigned getMemorySpace() const { return memorySpace; } + unsigned getMemorySpace() const; /// Returns the number of dimensions with dynamic size. unsigned getNumDynamicDims() const; - static bool classof(const Type *type) { - return type->getKind() == Kind::MemRef; - } - -private: - /// The type of each scalar element of the memref. - Type *elementType; - /// An array of integers which stores the shape dimension sizes. 
- const int *shapeElements; - /// The number of affine maps in the 'affineMapList' array. - const unsigned numAffineMaps; - /// List of affine maps in the memref's layout/index map composition. - AffineMap const *affineMapList; - /// Memory space in which data referenced by memref resides. - const unsigned memorySpace; - - MemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapList, unsigned memorySpace, - MLIRContext *context); - ~MemRefType() = delete; + static bool kindof(Kind kind) { return kind == Kind::MemRef; } }; +// Make Type hashable. +inline ::llvm::hash_code hash_value(Type arg) { + return ::llvm::hash_value(arg.type); +} + +template <typename U> bool Type::isa() const { + assert(type && "isa<> used on a null type."); + return U::kindof(getKind()); +} +template <typename U> U Type::dyn_cast() const { + return isa<U>() ? U(type) : U(nullptr); +} +template <typename U> U Type::dyn_cast_or_null() const { + return (type && isa<U>()) ? U(type) : U(nullptr); +} +template <typename U> U Type::cast() const { + assert(isa<U>()); + return U(type); +} + /// Return true if the specified element type is ok in a tensor. -static bool isValidTensorElementType(Type *type) { - return isa<FloatType>(type) || isa<VectorType>(type) || - isa<IntegerType>(type) || isa<OtherType>(type); +static bool isValidTensorElementType(Type type) { + return type.isa<FloatType>() || type.isa<VectorType>() || + type.isa<IntegerType>() || type.isa<OtherType>(); } + } // end namespace mlir +namespace llvm { + +// Type hash just like pointers. +template <> struct DenseMapInfo<mlir::Type> { + static mlir::Type getEmptyKey() { + auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); + return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); + } + static mlir::Type getTombstoneKey() { + auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); + return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer)); + } + static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); } + static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } +}; + +/// We align TypeStorage by 8, so allow LLVM to steal the low bits. +template <> struct PointerLikeTypeTraits<mlir::Type> { +public: + static inline void *getAsVoidPointer(mlir::Type I) { + return const_cast<void *>(I.getAsOpaquePointer()); + } + static inline mlir::Type getFromVoidPointer(void *P) { + return mlir::Type::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // namespace llvm + #endif // MLIR_IR_TYPES_H diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index b733bad2658..c0fe4cfd17c 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -104,15 +104,15 @@ class AllocOp : public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> { public: /// The result of an alloc is always a MemRefType. - MemRefType *getType() const { - return cast<MemRefType>(getResult()->getType()); + MemRefType getType() const { + return getResult()->getType().cast<MemRefType>(); } static StringRef getOperationName() { return "alloc"; } // Hooks to customize behavior of this op. 
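The DenseMapInfo and PointerLikeTypeTraits specializations added to Types.h above are what make the SSAValue.h change legal (storing Type inside a PointerIntPair with the kind packed into the low bits) and also let Type key a DenseMap directly. A small sketch of both uses, assuming a Type value is already in hand (names are illustrative):

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerIntPair.h"
#include "mlir/IR/Types.h"

// Illustrative only: Type now satisfies DenseMapInfo and exposes 3 free low
// bits via PointerLikeTypeTraits, mirroring SSAValue::typeAndKind above.
void exampleTypeAsHandle(mlir::Type ty) {
  llvm::DenseMap<mlir::Type, unsigned> bitwidths;
  bitwidths[ty] = 32;

  llvm::PointerIntPair<mlir::Type, 3, unsigned> typeAndTag(ty, /*IntVal=*/1);
  (void)typeAndTag.getPointer(); // recovers the Type value
  (void)typeAndTag.getInt();     // recovers the packed tag
}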
static void build(Builder *builder, OperationState *result, - MemRefType *memrefType, ArrayRef<SSAValue *> operands = {}); + MemRefType memrefType, ArrayRef<SSAValue *> operands = {}); bool verify() const; static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; @@ -276,7 +276,7 @@ public: const SSAValue *getSrcMemRef() const { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() const { - return cast<MemRefType>(getSrcMemRef()->getType())->getRank(); + return getSrcMemRef()->getType().cast<MemRefType>().getRank(); } // Returns the source memerf indices for this DMA operation. llvm::iterator_range<Operation::const_operand_iterator> @@ -291,13 +291,13 @@ public: } // Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() const { - return cast<MemRefType>(getDstMemRef()->getType())->getRank(); + return getDstMemRef()->getType().cast<MemRefType>().getRank(); } unsigned getSrcMemorySpace() const { - return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace(); + return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace(); } unsigned getDstMemorySpace() const { - return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace(); + return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace(); } // Returns the destination memref indices for this DMA operation. @@ -387,7 +387,7 @@ public: // Returns the rank (number of indices) of the tag memref. unsigned getTagMemRefRank() const { - return cast<MemRefType>(getTagMemRef()->getType())->getRank(); + return getTagMemRef()->getType().cast<MemRefType>().getRank(); } // Returns the number of elements transferred in the associated DMA operation. @@ -460,8 +460,8 @@ public: SSAValue *getMemRef() { return getOperand(0); } const SSAValue *getMemRef() const { return getOperand(0); } void setMemRef(SSAValue *value) { setOperand(0, value); } - MemRefType *getMemRefType() const { - return cast<MemRefType>(getMemRef()->getType()); + MemRefType getMemRefType() const { + return getMemRef()->getType().cast<MemRefType>(); } llvm::iterator_range<Operation::operand_iterator> getIndices() { @@ -508,8 +508,8 @@ public: static StringRef getOperationName() { return "memref_cast"; } /// The result of a memref_cast is always a memref. - MemRefType *getType() const { - return cast<MemRefType>(getResult()->getType()); + MemRefType getType() const { + return getResult()->getType().cast<MemRefType>(); } bool verify() const; @@ -583,8 +583,8 @@ public: SSAValue *getMemRef() { return getOperand(1); } const SSAValue *getMemRef() const { return getOperand(1); } void setMemRef(SSAValue *value) { setOperand(1, value); } - MemRefType *getMemRefType() const { - return cast<MemRefType>(getMemRef()->getType()); + MemRefType getMemRefType() const { + return getMemRef()->getType().cast<MemRefType>(); } llvm::iterator_range<Operation::operand_iterator> getIndices() { @@ -671,8 +671,8 @@ public: static StringRef getOperationName() { return "tensor_cast"; } /// The result of a tensor_cast is always a tensor. 
- TensorType *getType() const { - return cast<TensorType>(getResult()->getType()); + TensorType getType() const { + return getResult()->getType().cast<TensorType>(); } bool verify() const; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 1b3c24fd9f9..1904a636647 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { return tripCountExpr.getLargestKnownDivisor(); } -bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType, +bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType, ArrayRef<MLValue *> indices, unsigned dim) { - assert(indices.size() == memRefType->getRank()); + assert(indices.size() == memRefType.getRank()); assert(dim < indices.size()); - auto layoutMap = memRefType->getAffineMaps(); - assert(memRefType->getAffineMaps().size() <= 1); + auto layoutMap = memRefType.getAffineMaps(); + assert(memRefType.getAffineMaps().size() <= 1); // TODO(ntv): remove dependency on Builder once we support non-identity // layout map. - Builder b(memRefType->getContext()); + Builder b(memRefType.getContext()); assert(layoutMap.empty() || layoutMap[0] == b.getMultiDimIdentityMap(indices.size())); (void)layoutMap; @@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input, using namespace functional; auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); }, memoryOp->getIndices()); - auto *memRefType = memoryOp->getMemRefType(); + auto memRefType = memoryOp->getMemRefType(); for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) { if (fastestVaryingDim == (numIndices - 1) - d) { continue; @@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input, template <typename LoadOrStoreOpPointer> static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { - auto *memRefType = memoryOp->getMemRefType(); - return isa<VectorType>(memRefType->getElementType()); + auto memRefType = memoryOp->getMemRefType(); + return memRefType.getElementType().template isa<VectorType>(); } bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index bfbcb169cfe..0dd030d5b45 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() { // Verify that the argument list of the function and the arg list of the first // block line up. - auto fnInputTypes = fn.getType()->getInputs(); + auto fnInputTypes = fn.getType().getInputs(); if (fnInputTypes.size() != firstBB->getNumArguments()) return failure("first block of cfgfunc must have " + Twine(fnInputTypes.size()) + @@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands, bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) { // Verify that the return operands match the results of the function. 
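Every op accessor updated above follows the same mechanical rewrite: a free cast<...> applied to the pointer returned by getType() becomes a member cast() on the returned Type value, with '.' member access after it. Shown once, side by side, in an illustrative helper:

#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"

// Illustrative only: the casting idiom this patch applies throughout.
unsigned exampleMemRefRank(const mlir::SSAValue *value) {
  // Before: cast<MemRefType>(value->getType())->getRank()
  // After:  member cast on the by-value Type, member access with '.'
  return value->getType().cast<mlir::MemRefType>().getRank();
}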
- auto results = fn.getType()->getResults(); + auto results = fn.getType().getResults(); if (inst.getNumOperands() != results.size()) return failure("return has " + Twine(inst.getNumOperands()) + " operands, but enclosing function returns " + diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 454a28a6558..cb5e96f0086 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -122,7 +122,7 @@ private: void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); void visitOperationStmt(const OperationStmt *opStmt); - void visitType(const Type *type); + void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const Operation *op); @@ -135,16 +135,16 @@ private: } // end anonymous namespace // TODO Support visiting other types/instructions when implemented. -void ModuleState::visitType(const Type *type) { - if (auto *funcType = dyn_cast<FunctionType>(type)) { +void ModuleState::visitType(Type type) { + if (auto funcType = type.dyn_cast<FunctionType>()) { // Visit input and result types for functions. - for (auto *input : funcType->getInputs()) + for (auto input : funcType.getInputs()) visitType(input); - for (auto *result : funcType->getResults()) + for (auto result : funcType.getResults()) visitType(result); - } else if (auto *memref = dyn_cast<MemRefType>(type)) { + } else if (auto memref = type.dyn_cast<MemRefType>()) { // Visit affine maps in memref type. - for (auto map : memref->getAffineMaps()) { + for (auto map : memref.getAffineMaps()) { recordAffineMapReference(map); } } @@ -271,7 +271,7 @@ public: void print(const Module *module); void printFunctionReference(const Function *func); void printAttribute(Attribute attr); - void printType(const Type *type); + void printType(Type type); void print(const Function *fn); void print(const ExtFunction *fn); void print(const CFGFunction *fn); @@ -290,7 +290,7 @@ protected: void printFunctionAttributes(const Function *fn); void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs = {}); - void printFunctionResultType(const FunctionType *type); + void printFunctionResultType(FunctionType type); void printAffineMapId(int affineMapId) const; void printAffineMapReference(AffineMap affineMap); void printIntegerSetId(int integerSetId) const; @@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) { } void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { - auto *type = attr.getType(); - auto shape = type->getShape(); - auto rank = type->getRank(); + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); SmallVector<Attribute, 16> elements; attr.getValues(elements); @@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { os << ']'; } -void ModulePrinter::printType(const Type *type) { - switch (type->getKind()) { +void ModulePrinter::printType(Type type) { + switch (type.getKind()) { case Type::Kind::Index: os << "index"; return; @@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) { return; case Type::Kind::Integer: { - auto *integer = cast<IntegerType>(type); - os << 'i' << integer->getWidth(); + auto integer = type.cast<IntegerType>(); + os << 'i' << integer.getWidth(); return; } case Type::Kind::Function: { - auto *func = cast<FunctionType>(type); + auto func = type.cast<FunctionType>(); os << '('; - interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); + interleaveComma(func.getInputs(), 
[&](Type type) { printType(type); }); os << ") -> "; - auto results = func->getResults(); + auto results = func.getResults(); if (results.size() == 1) - os << *results[0]; + os << results[0]; else { os << '('; - interleaveComma(results, [&](Type *type) { printType(type); }); + interleaveComma(results, [&](Type type) { printType(type); }); os << ')'; } return; } case Type::Kind::Vector: { - auto *v = cast<VectorType>(type); + auto v = type.cast<VectorType>(); os << "vector<"; - for (auto dim : v->getShape()) + for (auto dim : v.getShape()) os << dim << 'x'; - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::RankedTensor: { - auto *v = cast<RankedTensorType>(type); + auto v = type.cast<RankedTensorType>(); os << "tensor<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::UnrankedTensor: { - auto *v = cast<UnrankedTensorType>(type); + auto v = type.cast<UnrankedTensorType>(); os << "tensor<*x"; - printType(v->getElementType()); + printType(v.getElementType()); os << '>'; return; } case Type::Kind::MemRef: { - auto *v = cast<MemRefType>(type); + auto v = type.cast<MemRefType>(); os << "memref<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - printType(v->getElementType()); - for (auto map : v->getAffineMaps()) { + printType(v.getElementType()); + for (auto map : v.getAffineMaps()) { os << ", "; printAffineMapReference(map); } // Only print the memory space if it is the non-default one. - if (v->getMemorySpace()) - os << ", " << v->getMemorySpace(); + if (v.getMemorySpace()) + os << ", " << v.getMemorySpace(); os << '>'; return; } @@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { // Function printing //===----------------------------------------------------------------------===// -void ModulePrinter::printFunctionResultType(const FunctionType *type) { - switch (type->getResults().size()) { +void ModulePrinter::printFunctionResultType(FunctionType type) { + switch (type.getResults().size()) { case 0: break; case 1: os << " -> "; - printType(type->getResults()[0]); + printType(type.getResults()[0]); break; default: os << " -> ("; - interleaveComma(type->getResults(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getResults(), + [&](Type eltType) { printType(eltType); }); os << ')'; break; } @@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) { auto type = fn->getType(); os << "@" << fn->getName() << '('; - interleaveComma(type->getInputs(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); }); os << ')'; printFunctionResultType(type); @@ -937,7 +936,7 @@ public: // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } - void printType(const Type *type) { ModulePrinter::printType(type); } + void printType(Type type) { ModulePrinter::printType(type); } void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); } void printAffineMap(AffineMap map) { return ModulePrinter::printAffineMapReference(map); @@ -974,10 +973,10 @@ protected: if (auto *op = value->getDefiningOperation()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) { // i1 constants get special names. 
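Since printType and the OpAsmPrinter stream operator now take Type by value (OpImplementation.h above), custom print hooks can stream a type without taking its address. A one-line sketch (the hook name is illustrative):

#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SSAValue.h"

// Illustrative only: operator<<(OpAsmPrinter &, Type) now takes Type by value.
void examplePrintValueType(const mlir::SSAValue &value, mlir::OpAsmPrinter &p) {
  p << value.getType(); // forwards to p.printType(...)
}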
- if (intOp->getType()->isInteger(1)) { + if (intOp->getType().isInteger(1)) { specialName << (intOp->getValue() ? "true" : "false"); } else { - specialName << 'c' << intOp->getValue() << '_' << *intOp->getType(); + specialName << 'c' << intOp->getValue() << '_' << intOp->getType(); } } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { specialName << 'c' << intOp->getValue(); @@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); } void Type::print(raw_ostream &os) const { ModuleState state(getContext()); - ModulePrinter(os, state).printType(this); + ModulePrinter(os, state).printType(*this); } void Type::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index a0e9afb4fd3..63ad544fa48 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage { /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { - Type *value; + Type value; }; /// An attribute representing a reference to a function. @@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage { /// A base attribute representing a reference to a vector or tensor constant. struct ElementsAttributeStorage : public AttributeStorage { - VectorOrTensorType *type; + VectorOrTensorType type; }; /// An attribute representing a reference to a vector or tensor constant, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 34312b84a0b..58b5b90d43d 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const { TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -Type *TypeAttr::getValue() const { - return static_cast<ImplType *>(attr)->value; -} +Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} @@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } -FunctionType *FunctionAttr::getType() const { return getValue()->getType(); } +FunctionType FunctionAttr::getType() const { return getValue()->getType(); } ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -VectorOrTensorType *ElementsAttr::getType() const { +VectorOrTensorType ElementsAttr::getType() const { return static_cast<ImplType *>(attr)->type; } @@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos, void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth; - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto context = getType().getContext(); values.reserve(elementNum); if (bitsWidth == 64) { ArrayRef<int64_t> vs( @@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr) : DenseElementsAttr(ptr) {} void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto 
context = getType().getContext(); ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()), getRawData().size() / 8}); values.reserve(elementNum); diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp index bb8ac75d91a..29a5ce12e4a 100644 --- a/mlir/lib/IR/BasicBlock.cpp +++ b/mlir/lib/IR/BasicBlock.cpp @@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() { // Argument list management. //===----------------------------------------------------------------------===// -BBArgument *BasicBlock::addArgument(Type *type) { +BBArgument *BasicBlock::addArgument(Type type) { auto *arg = new BBArgument(type, this); arguments.push_back(arg); return arg; } /// Add one argument to the argument list for each type specified in the list. -auto BasicBlock::addArguments(ArrayRef<Type *> types) +auto BasicBlock::addArguments(ArrayRef<Type> types) -> llvm::iterator_range<args_iterator> { arguments.reserve(arguments.size() + types.size()); auto initialSize = arguments.size(); - for (auto *type : types) { + for (auto type : types) { addArgument(type); } return {arguments.data() + initialSize, arguments.data() + arguments.size()}; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 22d749a6c8c..906b580d9af 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename, // Types. //===----------------------------------------------------------------------===// -FloatType *Builder::getBF16Type() { return Type::getBF16(context); } +FloatType Builder::getBF16Type() { return Type::getBF16(context); } -FloatType *Builder::getF16Type() { return Type::getF16(context); } +FloatType Builder::getF16Type() { return Type::getF16(context); } -FloatType *Builder::getF32Type() { return Type::getF32(context); } +FloatType Builder::getF32Type() { return Type::getF32(context); } -FloatType *Builder::getF64Type() { return Type::getF64(context); } +FloatType Builder::getF64Type() { return Type::getF64(context); } -OtherType *Builder::getIndexType() { return Type::getIndex(context); } +OtherType Builder::getIndexType() { return Type::getIndex(context); } -OtherType *Builder::getTFControlType() { return Type::getTFControl(context); } +OtherType Builder::getTFControlType() { return Type::getTFControl(context); } -OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } +OtherType Builder::getTFResourceType() { return Type::getTFResource(context); } -OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); } +OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); } -OtherType *Builder::getTFComplex64Type() { +OtherType Builder::getTFComplex64Type() { return Type::getTFComplex64(context); } -OtherType *Builder::getTFComplex128Type() { +OtherType Builder::getTFComplex128Type() { return Type::getTFComplex128(context); } -OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); } +OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); } -OtherType *Builder::getTFStringType() { return Type::getTFString(context); } +OtherType Builder::getTFStringType() { return Type::getTFString(context); } -IntegerType *Builder::getIntegerType(unsigned width) { +IntegerType Builder::getIntegerType(unsigned width) { return Type::getInteger(width, context); } -FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs, - ArrayRef<Type *> results) { +FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, + 
ArrayRef<Type> results) { return FunctionType::get(inputs, results, context); } -MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { +MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } -VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) { +VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) { return VectorType::get(shape, elementType); } -RankedTensorType *Builder::getTensorType(ArrayRef<int> shape, - Type *elementType) { +RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) { return RankedTensorType::get(shape, elementType); } -UnrankedTensorType *Builder::getTensorType(Type *elementType) { +UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } @@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { return IntegerSetAttr::get(set); } -TypeAttr Builder::getTypeAttr(Type *type) { +TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type, context); } @@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) { return FunctionAttr::get(value, context); } -ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, ArrayRef<char> data) { return DenseElementsAttr::get(type, data); } -ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes) { return OpaqueElementsAttr::get(type, bytes); } @@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) { OperationStmt *MLFuncBuilder::createOperation(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> types, + ArrayRef<Type> types, ArrayRef<NamedAttribute> attrs) { auto *op = OperationStmt::create(location, name, operands, types, attrs, getContext()); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 542e67eaefd..e4bca037c4e 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser, numDims = opInfos.size(); // Parse the optional symbol operands. 
- auto *affineIntTy = parser->getBuilder().getIndexType(); + auto affineIntTy = parser->getBuilder().getIndexType(); if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::OptionalSquare) || parser->resolveOperands(opInfos, affineIntTy, operands)) @@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result, bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); - auto *affineIntTy = builder.getIndexType(); + auto affineIntTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; @@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants, /// Builds a constant op with the specified attribute value and result type. void ConstantOp::build(Builder *builder, OperationState *result, - Attribute value, Type *type) { + Attribute value, Type type) { result->addAttribute("value", value); result->types.push_back(type); } @@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const { p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); if (!getValue().isa<FunctionAttr>()) - *p << " : " << *getType(); + *p << " : " << getType(); } bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { Attribute valueAttr; - Type *type; + Type type; if (parser->parseAttribute(valueAttr, "value", result->attributes) || parser->parseOptionalAttributeDict(result->attributes)) @@ -208,33 +208,33 @@ bool ConstantOp::verify() const { if (!value) return emitOpError("requires a 'value' attribute"); - auto *type = this->getType(); - if (isa<IntegerType>(type) || type->isIndex()) { + auto type = this->getType(); + if (type.isa<IntegerType>() || type.isIndex()) { if (!value.isa<IntegerAttr>()) return emitOpError( "requires 'value' to be an integer for an integer result type"); return false; } - if (isa<FloatType>(type)) { + if (type.isa<FloatType>()) { if (!value.isa<FloatAttr>()) return emitOpError("requires 'value' to be a floating point constant"); return false; } - if (isa<VectorOrTensorType>(type)) { + if (type.isa<VectorOrTensorType>()) { if (!value.isa<ElementsAttr>()) return emitOpError("requires 'value' to be a vector/tensor constant"); return false; } - if (type->isTFString()) { + if (type.isTFString()) { if (!value.isa<StringAttr>()) return emitOpError("requires 'value' to be a string constant"); return false; } - if (isa<FunctionType>(type)) { + if (type.isa<FunctionType>()) { if (!value.isa<FunctionAttr>()) return emitOpError("requires 'value' to be a function reference"); return false; @@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands, } void ConstantFloatOp::build(Builder *builder, OperationState *result, - const APFloat &value, FloatType *type) { + const APFloat &value, FloatType type) { ConstantOp::build(builder, result, builder->getFloatAttr(value), type); } bool ConstantFloatOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<FloatType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. 
bool ConstantIntOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<IntegerType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<IntegerType>(); } void ConstantIntOp::build(Builder *builder, OperationState *result, @@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result, /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, Type *type) { - assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type"); + int64_t value, Type type) { + assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); ConstantOp::build(builder, result, builder->getIntegerAttr(value), type); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); + return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState *result, @@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result, bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type *, 2> types; + SmallVector<Type, 2> types; llvm::SMLoc loc; return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || (!opInfo.empty() && parser->parseColonTypeList(types)) || @@ -330,7 +330,7 @@ bool ReturnOp::verify() const { // The operand number and types must match the function signature. MLFunction *function = cast<MLFunction>(block); - const auto &results = function->getType()->getResults(); + const auto &results = function->getType().getResults(); if (stmt->getNumOperands() != results.size()) return emitOpError("has " + Twine(stmt->getNumOperands()) + " operands, but enclosing function returns " + diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index efeb16b61db..70c0e1259b3 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -28,8 +28,8 @@ using namespace mlir; Function::Function(Kind kind, Location *location, StringRef name, - FunctionType *type, ArrayRef<NamedAttribute> attrs) - : nameAndKind(Identifier::get(name, type->getContext()), kind), + FunctionType type, ArrayRef<NamedAttribute> attrs) + : nameAndKind(Identifier::get(name, type.getContext()), kind), location(location), type(type) { this->attrs = AttributeListStorage::get(attrs, getContext()); } @@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const { return {}; } -MLIRContext *Function::getContext() const { return getType()->getContext(); } +MLIRContext *Function::getContext() const { return getType().getContext(); } /// Delete this object. void Function::destroy() { @@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const { // ExtFunction implementation. //===----------------------------------------------------------------------===// -ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, +ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::ExtFunc, location, name, type, attrs) {} @@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, // CFGFunction implementation. 
//===----------------------------------------------------------------------===// -CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type, +CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::CFGFunc, location, name, type, attrs) {} @@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() { /// Create a new MLFunction with the specific fields. MLFunction *MLFunction::create(Location *location, StringRef name, - FunctionType *type, + FunctionType type, ArrayRef<NamedAttribute> attrs) { - const auto &argTypes = type->getInputs(); + const auto &argTypes = type.getInputs(); auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size()); void *rawMem = malloc(byteSize); @@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name, return function; } -MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type, +MLFunction::MLFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::MLFunc, location, name, type, attrs), StmtBlock(StmtBlockKind::MLFunc) {} diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index 422636bf2e3..d2f49ddfc6e 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const { /// Create a new OperationInst with the specified fields. OperationInst *OperationInst::create(Location *location, OperationName name, ArrayRef<CFGValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(), @@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name, OperationInst *OperationInst::clone() const { SmallVector<CFGValue *, 8> operands; - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; // Put together the operands and results. for (auto *operand : getOperands()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0a2e9416842..8811f7b9f78 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -21,6 +21,7 @@ #include "AttributeDetail.h" #include "AttributeListStorage.h" #include "IntegerSetDetail.h" +#include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -44,11 +45,11 @@ using namespace mlir::detail; using namespace llvm; namespace { -struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { +struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> { // Functions are uniqued based on their inputs and results. 
- using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>; - using DenseMapInfo<FunctionType *>::getHashValue; - using DenseMapInfo<FunctionType *>::isEqual; + using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>; + using DenseMapInfo<FunctionTypeStorage *>::getHashValue; + using DenseMapInfo<FunctionTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( @@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) { + static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; return lhs == KeyTy(rhs->getInputs(), rhs->getResults()); @@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { } }; -struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> { +struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> { // Vectors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<VectorType *>::getHashValue; - using DenseMapInfo<VectorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<VectorTypeStorage *>::getHashValue; + using DenseMapInfo<VectorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const VectorType *rhs) { + static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> { +struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> { // Ranked tensors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<RankedTensorType *>::getHashValue; - using DenseMapInfo<RankedTensorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue; + using DenseMapInfo<RankedTensorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) { + static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> { +struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> { // MemRefs are uniqued based on their element type, shape, affine map // composition, and memory space. 
- using KeyTy = - std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; - using DenseMapInfo<MemRefType *>::getHashValue; - using DenseMapInfo<MemRefType *>::isEqual; + using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; + using DenseMapInfo<MemRefTypeStorage *>::getHashValue; + using DenseMapInfo<MemRefTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(std::get<0>(key)), + DenseMapInfo<Type>::getHashValue(std::get<0>(key)), hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()), hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), std::get<3>(key)); } - static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) { + static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(), - rhs->getAffineMaps(), rhs->getMemorySpace()); + return lhs == std::make_tuple(rhs->elementType, rhs->getShape(), + rhs->getAffineMaps(), rhs->memorySpace); } }; @@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> { }; struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>; + using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>; using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue; using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual; @@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { }; struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, StringRef>; + using KeyTy = std::pair<VectorOrTensorType, StringRef>; using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue; using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual; @@ -295,13 +295,14 @@ public: llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers; // Uniquing table for 'other' types. - OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - - int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr}; + OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - + int(Type::Kind::FIRST_OTHER_TYPE) + 1] = { + nullptr}; // Uniquing table for 'float' types. - FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - - int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = { - nullptr}; + FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - + int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = + {nullptr}; // Affine map uniquing. using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; @@ -324,26 +325,26 @@ public: DenseMap<int64_t, AffineConstantExprStorage *> constExprs; /// Integer type uniquing. - DenseMap<unsigned, IntegerType *> integers; + DenseMap<unsigned, IntegerTypeStorage *> integers; /// Function type uniquing. - using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>; + using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>; FunctionTypeSet functions; /// Vector type uniquing. - using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>; + using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>; VectorTypeSet vectors; /// Ranked tensor type uniquing. 
using RankedTensorTypeSet = - DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>; + DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>; RankedTensorTypeSet rankedTensors; /// Unranked tensor type uniquing. - DenseMap<Type *, UnrankedTensorType *> unrankedTensors; + DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors; /// MemRef type uniquing. - using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>; + using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>; MemRefTypeSet memrefs; // Attribute uniquing. @@ -355,13 +356,12 @@ public: ArrayAttrSet arrayAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs; DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs; - DenseMap<Type *, TypeAttributeStorage *> typeAttrs; + DenseMap<Type, TypeAttributeStorage *> typeAttrs; using AttributeListSet = DenseSet<AttributeListStorage *, AttributeListKeyInfo>; AttributeListSet attributeLists; DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs; - DenseMap<std::pair<VectorOrTensorType *, Attribute>, - SplatElementsAttributeStorage *> + DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *> splatElementsAttrs; using DenseElementsAttrSet = DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>; @@ -369,7 +369,7 @@ public: using OpaqueElementsAttrSet = DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>; OpaqueElementsAttrSet opaqueElementsAttrs; - DenseMap<std::tuple<Type *, Attribute, Attribute>, + DenseMap<std::tuple<Type, Attribute, Attribute>, SparseElementsAttributeStorage *> sparseElementsAttrs; @@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line, // Type uniquing //===----------------------------------------------------------------------===// -IntegerType *IntegerType::get(unsigned width, MLIRContext *context) { +IntegerType IntegerType::get(unsigned width, MLIRContext *context) { + assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); auto &impl = context->getImpl(); auto *&result = impl.integers[width]; if (!result) { - result = impl.allocator.Allocate<IntegerType>(); - new (result) IntegerType(width, context); + result = impl.allocator.Allocate<IntegerTypeStorage>(); + new (result) IntegerTypeStorage{{Kind::Integer, context}, width}; } return result; } -FloatType *FloatType::get(Kind kind, MLIRContext *context) { +FloatType FloatType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE && kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind"); auto &impl = context->getImpl(); @@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. - auto *ptr = impl.allocator.Allocate<FloatType>(); + auto *ptr = impl.allocator.Allocate<FloatTypeStorage>(); // Initialize the memory using placement new. - new (ptr) FloatType(kind, context); + new (ptr) FloatTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -OtherType *OtherType::get(Kind kind, MLIRContext *context) { +OtherType OtherType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE && "Not an 'other' type kind"); auto &impl = context->getImpl(); @@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. 
- auto *ptr = impl.allocator.Allocate<OtherType>(); + auto *ptr = impl.allocator.Allocate<OtherTypeStorage>(); // Initialize the memory using placement new. - new (ptr) OtherType(kind, context); + new (ptr) OtherTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -FunctionType *FunctionType::get(ArrayRef<Type *> inputs, - ArrayRef<Type *> results, - MLIRContext *context) { +FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results, + MLIRContext *context) { auto &impl = context->getImpl(); // Look to see if we already have this function type. @@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<FunctionType>(); + auto *result = impl.allocator.Allocate<FunctionTypeStorage>(); // Copy the inputs and results into the bump pointer. - SmallVector<Type *, 16> types; + SmallVector<Type, 16> types; types.reserve(inputs.size() + results.size()); types.append(inputs.begin(), inputs.end()); types.append(results.begin(), results.end()); - auto typesList = impl.copyInto(ArrayRef<Type *>(types)); + auto typesList = impl.copyInto(ArrayRef<Type>(types)); // Initialize the memory using placement new. - new (result) - FunctionType(typesList.data(), inputs.size(), results.size(), context); + new (result) FunctionTypeStorage{ + {Kind::Function, context, static_cast<unsigned int>(inputs.size())}, + static_cast<unsigned int>(results.size()), + typesList.data()}; // Cache and return it. return *existing.first = result; } -VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { +VectorType VectorType::get(ArrayRef<int> shape, Type elementType) { assert(!shape.empty() && "vector types must have at least one dimension"); - assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) && + assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) && "vectors elements must be primitives"); assert(!std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; }) && "vector types must have static shape"); - auto *context = elementType->getContext(); + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this vector type. @@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<VectorType>(); + auto *result = impl.allocator.Allocate<VectorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) VectorType(shape, elementType, context); + new (result) VectorTypeStorage{ + {{Kind::Vector, context, static_cast<unsigned int>(shape.size())}, + elementType}, + shape.data()}; // Cache and return it. return *existing.first = result; } -RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, - Type *elementType) { - auto *context = elementType->getContext(); +RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this ranked tensor type. @@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, return *existing.first; // On the first use, we allocate them into the bump pointer. 
- auto *result = impl.allocator.Allocate<RankedTensorType>(); + auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) RankedTensorType(shape, elementType, context); + new (result) RankedTensorTypeStorage{ + {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())}, + elementType}}, + shape.data()}; // Cache and return it. return *existing.first = result; } -UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { - auto *context = elementType->getContext(); +UnrankedTensorType UnrankedTensorType::get(Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this unranked tensor type. @@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { return result; // On the first use, we allocate them into the bump pointer. - result = impl.allocator.Allocate<UnrankedTensorType>(); + result = impl.allocator.Allocate<UnrankedTensorTypeStorage>(); // Initialize the memory using placement new. - new (result) UnrankedTensorType(elementType, context); + new (result) UnrankedTensorTypeStorage{ + {{{Kind::UnrankedTensor, context}, elementType}}}; return result; } -MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { - auto *context = elementType->getContext(); +MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Drop the unbounded identity maps from the composition. @@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<MemRefType>(); + auto *result = impl.allocator.Allocate<MemRefTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); @@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, impl.copyInto(ArrayRef<AffineMap>(affineMapComposition)); // Initialize the memory using placement new. - new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace, - context); + new (result) MemRefTypeStorage{ + {Kind::MemRef, context, static_cast<unsigned int>(shape.size())}, + elementType, + shape.data(), + static_cast<unsigned int>(affineMapComposition.size()), + affineMapComposition.data(), + memorySpace}; // Cache and return it. return *existing.first = result; } @@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { return result; } -TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { +TypeAttr TypeAttr::get(Type type, MLIRContext *context) { auto *&result = context->getImpl().typeAttrs[type]; if (result) return result; @@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs, return *existing.first = result; } -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, +SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, Attribute elt) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. 
auto *&result = impl.splatElementsAttrs[{type, elt}]; @@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, return result; } -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, +DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, ArrayRef<char> data) { - auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); + auto bitsRequired = (long)type.getBitWidth() * type.getNumElements(); (void)bitsRequired; assert((bitsRequired <= data.size() * 8L) && "Input data bit size should be larger than that type requires"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. DenseElementsAttrInfo::KeyTy key({type, data}); @@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first; // Otherwise, allocate a new one, unique it and return it. - auto *eltType = type->getElementType(); - switch (eltType->getKind()) { + auto eltType = type.getElementType(); + switch (eltType.getKind()) { case Type::Kind::BF16: case Type::Kind::F16: case Type::Kind::F32: @@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } case Type::Kind::Integer: { - auto width = ::cast<IntegerType>(eltType)->getWidth(); + auto width = eltType.cast<IntegerType>().getWidth(); auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>(); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); std::uninitialized_copy(data.begin(), data.end(), copy); @@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, } } -OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, +OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, StringRef bytes) { - assert(isValidTensorElementType(type->getElementType()) && + assert(isValidTensorElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. OpaqueElementsAttrInfo::KeyTy key({type, bytes}); @@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type, +SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. auto key = std::make_tuple(type, indices, values); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2ed09b83b53..0722421c8ba 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, } bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { - auto *type = op->getResult(0)->getType(); + auto type = op->getResult(0)->getType(); for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { if (op->getResult(i)->getType() != type) return op->emitOpError( @@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { /// If this is a vector type, or a tensor type, return the scalar element type /// that it is built around, otherwise return the type unmodified. 
-static Type *getTensorOrVectorElementType(Type *type) { - if (auto *vec = dyn_cast<VectorType>(type)) - return vec->getElementType(); +static Type getTensorOrVectorElementType(Type type) { + if (auto vec = type.dyn_cast<VectorType>()) + return vec.getElementType(); // Look through tensor<vector<...>> to find the underlying element type. - if (auto *tensor = dyn_cast<TensorType>(type)) - return getTensorOrVectorElementType(tensor->getElementType()); + if (auto tensor = type.dyn_cast<TensorType>()) + return getTensorOrVectorElementType(tensor.getElementType()); return type; } bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<FloatType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) return op->emitOpError("requires a floating point type"); } @@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>()) return op->emitOpError("requires an integer type"); } return false; @@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> ops; - Type *type; + Type type; return parser->parseOperandList(ops, 2) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type) || @@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p->printOptionalAttrDict(op->getAttrs()); - *p << " : " << *op->getResult(0)->getType(); + *p << " : " << op->getResult(0)->getType(); } //===----------------------------------------------------------------------===// @@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { //===----------------------------------------------------------------------===// void impl::buildCastOp(Builder *builder, OperationState *result, - SSAValue *source, Type *destType) { + SSAValue *source, Type destType) { result->addOperands(source); result->addTypes(destType); } bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType srcInfo; - Type *srcType, *dstType; + Type srcType, dstType; return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || parser->resolveOperand(srcInfo, srcType, result->operands) || parser->parseKeywordType("to", dstType) || @@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " - << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType(); + << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index e9c46d6ec5e..698089a1c67 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block, /// Create a new OperationStmt with the specific fields. 
OperationStmt *OperationStmt::create(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(), @@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) - return getResult(0)->getType()->getContext(); + return getResult(0)->getType().getContext(); if (getNumOperands()) - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); // In the very odd case where we have no operands or results, fall back to // doing a find. @@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const { if (operands.empty()) return findFunction()->getContext(); - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); } //===----------------------------------------------------------------------===// @@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, operands.push_back(remapOperand(opValue)); if (auto *opStmt = dyn_cast<OperationStmt>(this)) { - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; resultTypes.reserve(opStmt->getNumResults()); for (auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h new file mode 100644 index 00000000000..c22e87a283e --- /dev/null +++ b/mlir/lib/IR/TypeDetail.h @@ -0,0 +1,126 @@ +//===- TypeDetail.h - MLIR Type storage details -------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of Type. +// +//===----------------------------------------------------------------------===// +#ifndef TYPEDETAIL_H_ +#define TYPEDETAIL_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +class AffineMap; +class MLIRContext; + +namespace detail { + +/// Base storage class appearing in a Type. +struct alignas(8) TypeStorage { + TypeStorage(Type::Kind kind, MLIRContext *context) + : context(context), kind(kind), subclassData(0) {} + TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData) + : context(context), kind(kind), subclassData(subclassData) {} + + unsigned getSubclassData() const { return subclassData; } + + void setSubclassData(unsigned val) { + subclassData = val; + // Ensure we don't have any accidental truncation. + assert(getSubclassData() == val && "Subclass data too large for field"); + } + + /// This refers to the MLIRContext in which this type was uniqued. + MLIRContext *const context; + + /// Classification of the subclass, used for type checking.
+ Type::Kind kind : 8; + + /// Space for subclasses to store data. + unsigned subclassData : 24; +}; + +struct IntegerTypeStorage : public TypeStorage { + unsigned width; +}; + +struct FloatTypeStorage : public TypeStorage {}; + +struct OtherTypeStorage : public TypeStorage {}; + +struct FunctionTypeStorage : public TypeStorage { + ArrayRef<Type> getInputs() const { + return ArrayRef<Type>(inputsAndResults, subclassData); + } + ArrayRef<Type> getResults() const { + return ArrayRef<Type>(inputsAndResults + subclassData, numResults); + } + + unsigned numResults; + Type const *inputsAndResults; +}; + +struct VectorOrTensorTypeStorage : public TypeStorage { + Type elementType; +}; + +struct VectorTypeStorage : public VectorOrTensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct TensorTypeStorage : public VectorOrTensorTypeStorage {}; + +struct RankedTensorTypeStorage : public TensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct UnrankedTensorTypeStorage : public TensorTypeStorage {}; + +struct MemRefTypeStorage : public TypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + ArrayRef<AffineMap> getAffineMaps() const { + return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + } + + /// The type of each scalar element of the memref. + Type elementType; + /// An array of integers which stores the shape dimension sizes. + const int *shapeElements; + /// The number of affine maps in the 'affineMapList' array. + const unsigned numAffineMaps; + /// List of affine maps in the memref's layout/index map composition. + AffineMap const *affineMapList; + /// Memory space in which data referenced by memref resides. + const unsigned memorySpace; +}; + +} // namespace detail +} // namespace mlir +#endif // TYPEDETAIL_H_ diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 0ad3f4728fe..1a716956608 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -16,10 +16,17 @@ // ============================================================================= #include "mlir/IR/Types.h" +#include "TypeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/STLExtras.h" #include "llvm/Support/raw_ostream.h" + using namespace mlir; +using namespace mlir::detail; + +Type::Kind Type::getKind() const { return type->kind; } + +MLIRContext *Type::getContext() const { return type->context; } unsigned Type::getBitWidth() const { switch (getKind()) { @@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const { case Type::Kind::F64: return 64; case Type::Kind::Integer: - return cast<IntegerType>(this)->getWidth(); + return cast<IntegerType>().getWidth(); case Type::Kind::Vector: case Type::Kind::RankedTensor: case Type::Kind::UnrankedTensor: - return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth(); + return cast<VectorOrTensorType>().getElementType().getBitWidth(); // TODO: Handle more types. 
default: llvm_unreachable("unexpected type"); } } -IntegerType::IntegerType(unsigned width, MLIRContext *context) - : Type(Kind::Integer, context), width(width) { - assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); +unsigned Type::getSubclassData() const { return type->getSubclassData(); } +void Type::setSubclassData(unsigned val) { type->setSubclassData(val); } + +IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {} + +unsigned IntegerType::getWidth() const { + return static_cast<ImplType *>(type)->width; } -FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {} + +OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {} -OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {} -FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, - unsigned numResults, MLIRContext *context) - : Type(Kind::Function, context, numInputs), numResults(numResults), - inputsAndResults(inputsAndResults) {} +ArrayRef<Type> FunctionType::getInputs() const { + return static_cast<ImplType *>(type)->getInputs(); +} + +unsigned FunctionType::getNumResults() const { + return static_cast<ImplType *>(type)->numResults; +} + +ArrayRef<Type> FunctionType::getResults() const { + return static_cast<ImplType *>(type)->getResults(); +} -VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, - Type *elementType, unsigned subClassData) - : Type(kind, context, subClassData), elementType(elementType) {} +VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {} + +Type VectorOrTensorType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} unsigned VectorOrTensorType::getNumElements() const { switch (getKind()) { @@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const { ArrayRef<int> VectorOrTensorType::getShape() const { switch (getKind()) { case Kind::Vector: - return cast<VectorType>(this)->getShape(); + return cast<VectorType>().getShape(); case Kind::RankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); case Kind::UnrankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); default: llvm_unreachable("not a VectorOrTensorType"); } @@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const { return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); } -VectorType::VectorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()), - shapeElements(shape.data()) {} +VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} -TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) - : VectorOrTensorType(kind, context, elementType) { - assert(isValidTensorElementType(elementType)); +ArrayRef<int> VectorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : TensorType(Kind::RankedTensor, elementType, context), - shapeElements(shape.data()) { - setSubclassData(shape.size()); +TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} + +RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +ArrayRef<int> 
RankedTensorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) - : TensorType(Kind::UnrankedTensor, elementType, context) {} +UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {} -MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapList, unsigned memorySpace, - MLIRContext *context) - : Type(Kind::MemRef, context, shape.size()), elementType(elementType), - shapeElements(shape.data()), numAffineMaps(affineMapList.size()), - affineMapList(affineMapList.data()), memorySpace(memorySpace) {} +ArrayRef<int> MemRefType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); +} + +Type MemRefType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} ArrayRef<AffineMap> MemRefType::getAffineMaps() const { - return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + return static_cast<ImplType *>(type)->getAffineMaps(); +} + +unsigned MemRefType::getMemorySpace() const { + return static_cast<ImplType *>(type)->memorySpace; } unsigned MemRefType::getNumDynamicDims() const { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7974c7c71a4..ceb893165f0 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -182,19 +182,19 @@ public: // as the results of their action. // Type parsing. - VectorType *parseVectorType(); + VectorType parseVectorType(); ParseResult parseXInDimensionList(); ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions); - Type *parseTensorType(); - Type *parseMemRefType(); - Type *parseFunctionType(); - Type *parseType(); - ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements); - ParseResult parseTypeList(SmallVectorImpl<Type *> &elements); + Type parseTensorType(); + Type parseMemRefType(); + Type parseFunctionType(); + Type parseType(); + ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); + ParseResult parseTypeList(SmallVectorImpl<Type> &elements); // Attribute parsing. Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, - FunctionType *type); + FunctionType type); Attribute parseAttribute(); ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); @@ -206,9 +206,9 @@ public: AffineMap parseAffineMapReference(); IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetReference(); - DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type); - DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector); - VectorOrTensorType *parseVectorOrTensorType(); + DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type); + DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector); + VectorOrTensorType parseVectorOrTensorType(); private: // The Parser is subclassed and reinstantiated. 
Do not add additional @@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil( /// float-type ::= `f16` | `bf16` | `f32` | `f64` /// other-type ::= `index` | `tf_control` /// -Type *Parser::parseType() { +Type Parser::parseType() { switch (getToken().getKind()) { default: return (emitError("expected type"), nullptr); @@ -368,7 +368,7 @@ Type *Parser::parseType() { /// vector-type ::= `vector` `<` const-dimension-list primitive-type `>` /// const-dimension-list ::= (integer-literal `x`)+ /// -VectorType *Parser::parseVectorType() { +VectorType Parser::parseVectorType() { consumeToken(Token::kw_vector); if (parseToken(Token::less, "expected '<' in vector type")) @@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; - if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) + if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) return (emitError(typeLoc, "invalid vector element type"), nullptr); return VectorType::get(dimensions, elementType); @@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) { /// tensor-type ::= `tensor` `<` dimension-list element-type `>` /// dimension-list ::= dimension-list-ranked | `*x` /// -Type *Parser::parseTensorType() { +Type Parser::parseTensorType() { consumeToken(Token::kw_tensor); if (parseToken(Token::less, "expected '<' in tensor type")) @@ -485,7 +485,7 @@ Type *Parser::parseTensorType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) return nullptr; @@ -505,7 +505,7 @@ Type *Parser::parseTensorType() { /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map /// memory-space ::= integer-literal /* | TODO: address-space-id */ /// -Type *Parser::parseMemRefType() { +Type Parser::parseMemRefType() { consumeToken(Token::kw_memref); if (parseToken(Token::less, "expected '<' in memref type")) @@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType) return nullptr; - if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) && - !isa<VectorType>(elementType)) + if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() && + !elementType.isa<VectorType>()) return (emitError(typeLoc, "invalid memref element type"), nullptr); // Parse semi-affine-map-composition. 
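Taken together, the new TypeDetail.h and the rewritten Types.cpp above show the core technique of this patch: Type becomes a value-semantic handle over a uniqued, immutable storage object, with accessors such as getKind(), getContext() and getWidth() forwarding to that storage and isa/dyn_cast/cast exposed as member templates on the handle. Below is a minimal, self-contained sketch of that pattern (compile with -std=c++17); Handle, Storage, IntStorage and IntHandle are simplified stand-ins, not the actual MLIR classes.

  #include <cassert>
  #include <cstdio>

  // Hypothetical storage hierarchy: one uniqued, immutable object per distinct
  // type, normally bump-allocated and owned by the context.
  struct Storage {
    enum class Kind { Integer, Float } kind;
  };
  struct IntStorage : Storage {
    unsigned width;
  };

  // Value-semantic handle: a trivially copyable pointer wrapper, compared by
  // the identity of its uniqued storage, with member-template casts replacing
  // the free-standing isa<>/dyn_cast<>/cast<> calls used on Type* before.
  class Handle {
  public:
    explicit Handle(Storage *impl = nullptr) : impl(impl) {}
    Storage::Kind getKind() const { return impl->kind; }
    void *getAsOpaquePointer() const { return impl; }
    explicit operator bool() const { return impl != nullptr; }
    bool operator==(Handle other) const { return impl == other.impl; }

    template <typename T> bool isa() const { return T::kindof(getKind()); }
    template <typename T> T dyn_cast() const { return isa<T>() ? T(impl) : T(); }

  protected:
    Storage *impl;
  };

  class IntHandle : public Handle {
  public:
    using Handle::Handle;
    static bool kindof(Storage::Kind k) { return k == Storage::Kind::Integer; }
    unsigned getWidth() const { return static_cast<IntStorage *>(impl)->width; }
  };

  int main() {
    IntStorage i32{{Storage::Kind::Integer}, 32}; // normally uniqued by a context
    Handle t(&i32);
    if (auto intTy = t.dyn_cast<IntHandle>())     // mirrors type.dyn_cast<IntegerType>()
      std::printf("i%u\n", intTy.getWidth());
    assert(t == Handle(&i32));                    // identity comparison of uniqued types
  }

Because the handle is just a pointer in disguise, it can be passed and returned by value everywhere, which is exactly why the signatures throughout this patch drop the pointer and the const qualifier.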
@@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() { /// /// function-type ::= type-list-parens `->` type-list /// -Type *Parser::parseFunctionType() { +Type Parser::parseFunctionType() { assert(getToken().is(Token::l_paren)); - SmallVector<Type *, 4> arguments, results; + SmallVector<Type, 4> arguments, results; if (parseTypeList(arguments) || parseToken(Token::arrow, "expected '->' in function type") || parseTypeList(results)) @@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() { /// /// type-list-no-parens ::= type (`,` type)* /// -ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) { +ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { auto parseElt = [&]() -> ParseResult { auto elt = parseType(); elements.push_back(elt); @@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) { /// type-list-parens ::= `(` `)` /// | `(` type-list-no-parens `)` /// -ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) { +ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) { auto parseElt = [&]() -> ParseResult { auto elt = parseType(); elements.push_back(elt); @@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) { namespace { class TensorLiteralParser { public: - TensorLiteralParser(Parser &p, Type *eltTy) - : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {} + TensorLiteralParser(Parser &p, Type eltTy) + : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {} ParseResult parse() { return parseList(shape); } @@ -676,7 +676,7 @@ private: } Parser &p; - Type *eltTy; + Type eltTy; size_t currBitPos; size_t bitsWidth; SmallVector<int, 4> shape; @@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) { if (!result) return p.emitError("expected tensor element"); // check result matches the element type. - switch (eltTy->getKind()) { + switch (eltTy.getKind()) { case Type::Kind::BF16: case Type::Kind::F16: case Type::Kind::F32: @@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) { /// synthesizing a forward reference) or emit an error and return null on /// failure. Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, - FunctionType *type) { + FunctionType type) { Identifier name = builder.getIdentifier(nameStr.drop_front()); // See if the function has already been defined in the module. 
@@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::colon, "expected ':' and function type")) return nullptr; auto typeLoc = getToken().getLoc(); - Type *type = parseType(); + Type type = parseType(); if (!type) return nullptr; - auto *fnType = dyn_cast<FunctionType>(type); + auto fnType = type.dyn_cast<FunctionType>(); if (!fnType) return (emitError(typeLoc, "expected function type"), nullptr); @@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() { consumeToken(Token::kw_opaque); if (parseToken(Token::less, "expected '<' after 'opaque'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; auto val = getToken().getStringValue(); @@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "expected '<' after 'splat'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; switch (getToken().getKind()) { @@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; @@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; switch (getToken().getKind()) { case Token::l_square: { /// Parse indices - auto *indicesEltType = builder.getIntegerType(32); + auto indicesEltType = builder.getIntegerType(32); auto indices = - parseDenseElementsAttr(indicesEltType, isa<VectorType>(type)); + parseDenseElementsAttr(indicesEltType, type.isa<VectorType>()); if (parseToken(Token::comma, "expected ','")) return nullptr; /// Parse values. - auto *valuesEltType = type->getElementType(); + auto valuesEltType = type.getElementType(); auto values = - parseDenseElementsAttr(valuesEltType, isa<VectorType>(type)); + parseDenseElementsAttr(valuesEltType, type.isa<VectorType>()); /// Sanity check. 
- auto *indicesType = indices.getType(); - auto *valuesType = values.getType(); - auto sameShape = (indicesType->getRank() == 1) || - (type->getRank() == indicesType->getDimSize(1)); + auto indicesType = indices.getType(); + auto valuesType = values.getType(); + auto sameShape = (indicesType.getRank() == 1) || + (type.getRank() == indicesType.getDimSize(1)); auto sameElementNum = - indicesType->getDimSize(0) == valuesType->getDimSize(0); + indicesType.getDimSize(0) == valuesType.getDimSize(0); if (!sameShape || !sameElementNum) { std::string str; llvm::raw_string_ostream s(str); s << "expected shape (["; - interleaveComma(type->getShape(), s); + interleaveComma(type.getShape(), s); s << "]); inferred shape of indices literal (["; - interleaveComma(indicesType->getShape(), s); + interleaveComma(indicesType.getShape(), s); s << "]); inferred shape of values literal (["; - interleaveComma(valuesType->getShape(), s); + interleaveComma(valuesType.getShape(), s); s << "])"; return (emitError(s.str()), nullptr); } @@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() { nullptr); } default: { - if (Type *type = parseType()) + if (Type type = parseType()) return builder.getTypeAttr(type); return nullptr; } @@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() { /// /// This method returns a constructed dense elements attribute with the shape /// from the parsing result. -DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) { +DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) { TensorLiteralParser literalParser(*this, eltType); if (literalParser.parse()) return nullptr; - VectorOrTensorType *type; + VectorOrTensorType type; if (isVector) { type = builder.getVectorType(literalParser.getShape(), eltType); } else { @@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) { /// This method compares the shapes from the parsing result and that from the /// input argument. It returns a constructed dense elements attribute if both /// match. -DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) { - auto *eltTy = type->getElementType(); +DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { + auto eltTy = type.getElementType(); TensorLiteralParser literalParser(*this, eltTy); if (literalParser.parse()) return nullptr; - if (literalParser.getShape() != type->getShape()) { + if (literalParser.getShape() != type.getShape()) { std::string str; llvm::raw_string_ostream s(str); s << "inferred shape of elements literal (["; interleaveComma(literalParser.getShape(), s); s << "]) does not match type (["; - interleaveComma(type->getShape(), s); + interleaveComma(type.getShape(), s); s << "])"; return (emitError(s.str()), nullptr); } @@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) { /// vector-or-tensor-type ::= vector-type | tensor-type /// /// This method also checks the type has static shape and ranked. 
-VectorOrTensorType *Parser::parseVectorOrTensorType() { - auto *type = dyn_cast<VectorOrTensorType>(parseType()); +VectorOrTensorType Parser::parseVectorOrTensorType() { + auto type = parseType().dyn_cast<VectorOrTensorType>(); if (!type) { return (emitError("expected elements literal has a tensor or vector type"), nullptr); @@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() { if (parseToken(Token::comma, "expected ','")) return nullptr; - if (!type->hasStaticShape() || type->getRank() == -1) { + if (!type.hasStaticShape() || type.getRank() == -1) { return (emitError("tensor literals must be ranked and have static shape"), nullptr); } @@ -1834,7 +1834,7 @@ public: /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. - SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type); + SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type); /// Register a definition of a value with the symbol table. ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value); @@ -1845,11 +1845,11 @@ public: template <typename ResultType> ResultType parseSSADefOrUseAndType( - const std::function<ResultType(SSAUseInfo, Type *)> &action); + const std::function<ResultType(SSAUseInfo, Type)> &action); SSAValue *parseSSAUseAndType() { return parseSSADefOrUseAndType<SSAValue *>( - [&](SSAUseInfo useInfo, Type *type) -> SSAValue * { + [&](SSAUseInfo useInfo, Type type) -> SSAValue * { return resolveSSAUse(useInfo, type); }); } @@ -1880,7 +1880,7 @@ private: /// their first reference, to allow checking for use of undefined values. DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders; - SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type); + SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type); /// Return true if this is a forward reference. bool isForwardReferencePlaceholder(SSAValue *value) { @@ -1891,7 +1891,7 @@ private: /// Create and remember a new placeholder for a forward reference. SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, - Type *type) { + Type type) { // Forward references are always created as instructions, even in ML // functions, because we just need something with a def/use chain. // @@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, /// Given an unbound reference to an SSA value and its type, return the value /// it specifies. This returns null on failure. -SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) { +SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = values[useInfo.name]; // If we have already seen a value of this name, return it. 
@@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { /// ssa-use-and-type ::= ssa-use `:` type template <typename ResultType> ResultType FunctionParser::parseSSADefOrUseAndType( - const std::function<ResultType(SSAUseInfo, Type *)> &action) { + const std::function<ResultType(SSAUseInfo, Type)> &action) { SSAUseInfo useInfo; if (parseSSAUse(useInfo) || parseToken(Token::colon, "expected ':' and type for SSA operand")) return nullptr; - auto *type = parseType(); + auto type = parseType(); if (!type) return nullptr; @@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList( if (valueIDs.empty()) return ParseSuccess; - SmallVector<Type *, 4> types; + SmallVector<Type, 4> types; if (parseToken(Token::colon, "expected ':' in operand list") || parseTypeListNoParens(types)) return ParseFailure; @@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation( auto type = parseType(); if (!type) return nullptr; - auto fnType = dyn_cast<FunctionType>(type); + auto fnType = type.dyn_cast<FunctionType>(); if (!fnType) return (emitError(typeLoc, "expected function type"), nullptr); - result.addTypes(fnType->getResults()); + result.addTypes(fnType.getResults()); // Check that we have the right number of types for the operands. - auto operandTypes = fnType->getInputs(); + auto operandTypes = fnType.getInputs(); if (operandTypes.size() != operandInfos.size()) { auto plural = "s"[operandInfos.size() == 1]; return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) + @@ -2253,17 +2253,17 @@ public: return parser.parseToken(Token::comma, "expected ','"); } - bool parseColonType(Type *&result) override { + bool parseColonType(Type &result) override { return parser.parseToken(Token::colon, "expected ':'") || !(result = parser.parseType()); } - bool parseColonTypeList(SmallVectorImpl<Type *> &result) override { + bool parseColonTypeList(SmallVectorImpl<Type> &result) override { if (parser.parseToken(Token::colon, "expected ':'")) return true; do { - if (auto *type = parser.parseType()) + if (auto type = parser.parseType()) result.push_back(type); else return true; @@ -2273,7 +2273,7 @@ public: } /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type *&result) override { + bool parseKeywordType(const char *keyword, Type &result) override { if (parser.getTokenSpelling() != keyword) return parser.emitError("expected '" + Twine(keyword) + "'"); parser.consumeToken(); @@ -2396,7 +2396,7 @@ public: } /// Resolve a parse function name and a type into a function reference. 
- virtual bool resolveFunctionName(StringRef name, FunctionType *type, + virtual bool resolveFunctionName(StringRef name, FunctionType type, llvm::SMLoc loc, Function *&result) { result = parser.resolveFunctionReference(name, loc, type); return result == nullptr; @@ -2410,7 +2410,7 @@ public: llvm::SMLoc getNameLoc() const override { return nameLoc; } - bool resolveOperand(const OperandType &operand, Type *type, + bool resolveOperand(const OperandType &operand, Type type, SmallVectorImpl<SSAValue *> &result) override { FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; @@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList( return ParseSuccess; return parseCommaSeparatedList([&]() -> ParseResult { - auto type = parseSSADefOrUseAndType<Type *>( - [&](SSAUseInfo useInfo, Type *type) -> Type * { + auto type = parseSSADefOrUseAndType<Type>( + [&](SSAUseInfo useInfo, Type type) -> Type { BBArgument *arg = owner->addArgument(type); if (addDefinition(useInfo, arg)) - return nullptr; + return {}; return type; }); return type ? ParseSuccess : ParseFailure; @@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands, " symbol count must match"); // Resolve SSA uses. - Type *indexType = builder.getIndexType(); + Type indexType = builder.getIndexType(); for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { SSAValue *sval = resolveSSAUse(opInfo[i], indexType); if (!sval) @@ -3187,9 +3187,9 @@ private: ParseResult parseAffineStructureDef(); // Functions. - ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, + ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames); - ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type, + ParseResult parseFunctionSignature(StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> *argNames); ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs); ParseResult parseExtFunc(); @@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() { /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ /// ParseResult -ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, +ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames) { consumeToken(Token::l_paren); @@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, /// type-list)? /// ParseResult -ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, +ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> *argNames) { if (getToken().isNot(Token::at_identifier)) return emitError("expected a function identifier like '@foo'"); @@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, if (getToken().isNot(Token::l_paren)) return emitError("expected '(' in function signature"); - SmallVector<Type *, 4> argTypes; + SmallVector<Type, 4> argTypes; ParseResult parseResult; if (argNames) @@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, return ParseFailure; // Parse the return type if present. 
- SmallVector<Type *, 4> results; + SmallVector<Type, 4> results; if (consumeIf(Token::arrow)) { if (parseTypeList(results)) return ParseFailure; @@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() { auto loc = getToken().getLoc(); StringRef name; - FunctionType *type = nullptr; + FunctionType type; if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; @@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() { auto loc = getToken().getLoc(); StringRef name; - FunctionType *type = nullptr; + FunctionType type; if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; @@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() { consumeToken(Token::kw_mlfunc); StringRef name; - FunctionType *type = nullptr; + FunctionType type; SmallVector<StringRef, 4> argNames; auto loc = getToken().getLoc(); diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index b60d209e1f5..e2bdfd7a18b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// void AllocOp::build(Builder *builder, OperationState *result, - MemRefType *memrefType, ArrayRef<SSAValue *> operands) { + MemRefType memrefType, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->types.push_back(memrefType); } void AllocOp::print(OpAsmPrinter *p) const { - MemRefType *type = getType(); + MemRefType type = getType(); *p << "alloc"; // Print dynamic dimension operands. printDimAndSymbolList(operand_begin(), operand_end(), - type->getNumDynamicDims(), p); + type.getNumDynamicDims(), p); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); - *p << " : " << *type; + *p << " : " << type; } bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { - MemRefType *type; + MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a // memref type. @@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. - if (numDimOperands != type->getNumDynamicDims()) { + if (numDimOperands != type.getNumDynamicDims()) { return parser->emitError(parser->getNameLoc(), "dimension operand count does not equal memref " "dynamic dimension count"); @@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { } bool AllocOp::verify() const { - auto *memRefType = dyn_cast<MemRefType>(getResult()->getType()); + auto memRefType = getResult()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("result must be a memref"); unsigned numSymbols = 0; - if (!memRefType->getAffineMaps().empty()) { - AffineMap affineMap = memRefType->getAffineMaps()[0]; + if (!memRefType.getAffineMaps().empty()) { + AffineMap affineMap = memRefType.getAffineMaps()[0]; // Store number of symbols used in affine map (used in subsequent check). numSymbols = affineMap.getNumSymbols(); // TODO(zinenko): this check does not belong to AllocOp, or any other op but @@ -195,10 +195,10 @@ bool AllocOp::verify() const { // Remove when we can emit errors directly from *Type::get(...) functions. // // Verify that the layout affine map matches the rank of the memref. 
- if (affineMap.getNumDims() != memRefType->getRank()) + if (affineMap.getNumDims() != memRefType.getRank()) return emitOpError("affine map dimension count must equal memref rank"); } - unsigned numDynamicDims = memRefType->getNumDynamicDims(); + unsigned numDynamicDims = memRefType.getNumDynamicDims(); // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. @@ -208,7 +208,7 @@ bool AllocOp::verify() const { } // Verify that all operands are of type Index. for (auto *operand : getOperands()) { - if (!operand->getType()->isIndex()) + if (!operand->getType().isIndex()) return emitOpError("requires operands to be of type Index"); } return false; @@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern { // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector<int, 4> newShapeConstants; - newShapeConstants.reserve(memrefType->getRank()); + newShapeConstants.reserve(memrefType.getRank()); SmallVector<SSAValue *, 4> newOperands; SmallVector<SSAValue *, 4> droppedOperands; unsigned dynamicDimPos = 0; - for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) { - int dimSize = memrefType->getDimSize(dim); + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); @@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern { } // Create new memref type (which will have fewer dynamic dimensions). - auto *newMemRefType = MemRefType::get( - newShapeConstants, memrefType->getElementType(), - memrefType->getAffineMaps(), memrefType->getMemorySpace()); - assert(newOperands.size() == newMemRefType->getNumDynamicDims()); + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(newOperands.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. 
auto newAlloc = @@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType()->getResults()); + result->addTypes(callee->getType().getResults()); } bool CallOp::parse(OpAsmParser *parser, OperationState *result) { StringRef calleeName; llvm::SMLoc calleeLoc; - FunctionType *calleeType = nullptr; + FunctionType calleeType; SmallVector<OpAsmParser::OperandType, 4> operands; Function *callee = nullptr; if (parser->parseFunctionName(calleeName, calleeLoc) || @@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(calleeType) || parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || - parser->addTypesToList(calleeType->getResults(), result->types) || - parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, + parser->addTypesToList(calleeType.getResults(), result->types) || + parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, result->operands)) return true; @@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const { p->printOperands(getOperands()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallOp::verify() const { @@ -338,20 +338,20 @@ bool CallOp::verify() const { return emitOpError("requires a 'callee' function attribute"); // Verify that the operand and result types match the callee. - auto *fnType = fnAttr.getValue()->getType(); - if (fnType->getNumInputs() != getNumOperands()) + auto fnType = fnAttr.getValue()->getType(); + if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -364,14 +364,14 @@ bool CallOp::verify() const { void CallIndirectOp::build(Builder *builder, OperationState *result, SSAValue *callee, ArrayRef<SSAValue *> operands) { - auto *fnType = cast<FunctionType>(callee->getType()); + auto fnType = callee->getType().cast<FunctionType>(); result->operands.push_back(callee); result->addOperands(operands); - result->addTypes(fnType->getResults()); + result->addTypes(fnType.getResults()); } bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { - FunctionType *calleeType = nullptr; + FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector<OpAsmParser::OperandType, 4> operands; @@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || 
parser->parseColonType(calleeType) || parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, result->operands) || - parser->addTypesToList(calleeType->getResults(), result->types); + parser->addTypesToList(calleeType.getResults(), result->types); } void CallIndirectOp::print(OpAsmPrinter *p) const { @@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const { p->printOperands(++operandRange.begin(), operandRange.end()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallIndirectOp::verify() const { // The callee must be a function. - auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); + auto fnType = getCallee()->getType().dyn_cast<FunctionType>(); if (!fnType) return emitOpError("callee must have function type"); // Verify that the operand and result types match the callee. - if (fnType->getNumInputs() != getNumOperands() - 1) + if (fnType.getNumInputs() != getNumOperands() - 1) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i + 1)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i + 1)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result, } void DeallocOp::print(OpAsmPrinter *p) const { - *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); + *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); } bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; - MemRefType *type; + MemRefType type; return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands); } bool DeallocOp::verify() const { - if (!isa<MemRefType>(getMemRef()->getType())) + if (!getMemRef()->getType().isa<MemRefType>()) return emitOpError("operand must be a memref"); return false; } @@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result, void DimOp::print(OpAsmPrinter *p) const { *p << "dim " << *getOperand() << ", " << getIndex(); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); - *p << " : " << *getOperand()->getType(); + *p << " : " << getOperand()->getType(); } bool DimOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; - Type *type; + Type type; return parser->parseOperand(operandInfo) || parser->parseComma() || parser->parseAttribute(indexAttr, "index", result->attributes) || @@ -496,15 +496,15 @@ bool DimOp::verify() const { return emitOpError("requires an integer attribute named 'index'"); uint64_t index = (uint64_t)indexAttr.getValue(); - auto *type = 
getOperand()->getType(); - if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { - if (index >= tensorType->getRank()) + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + if (index >= tensorType.getRank()) return emitOpError("index is out of range"); - } else if (auto *memrefType = dyn_cast<MemRefType>(type)) { - if (index >= memrefType->getRank()) + } else if (auto memrefType = type.dyn_cast<MemRefType>()) { + if (index >= memrefType.getRank()) return emitOpError("index is out of range"); - } else if (isa<UnrankedTensorType>(type)) { + } else if (type.isa<UnrankedTensorType>()) { // ok, assumed to be in-range. } else { return emitOpError("requires an operand with tensor or memref type"); @@ -516,12 +516,12 @@ bool DimOp::verify() const { Attribute DimOp::constantFold(ArrayRef<Attribute> operands, MLIRContext *context) const { // Constant fold dim when the size along the index referred to is a constant. - auto *opType = getOperand()->getType(); + auto opType = getOperand()->getType(); int indexSize = -1; - if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) { - indexSize = tensorType->getShape()[getIndex()]; - } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) { - indexSize = memrefType->getShape()[getIndex()]; + if (auto tensorType = opType.dyn_cast<RankedTensorType>()) { + indexSize = tensorType.getShape()[getIndex()]; + } else if (auto memrefType = opType.dyn_cast<MemRefType>()) { + indexSize = memrefType.getShape()[getIndex()]; } if (indexSize >= 0) @@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getSrcMemRef()->getType(); - *p << ", " << *getDstMemRef()->getType(); - *p << ", " << *getTagMemRef()->getType(); + *p << " : " << getSrcMemRef()->getType(); + *p << ", " << getDstMemRef()->getType(); + *p << ", " << getTagMemRef()->getType(); } // Parse DmaStartOp. @@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; - SmallVector<Type *, 3> types; - auto *indexType = parser->getBuilder().getIndexType(); + SmallVector<Type, 3> types; + auto indexType = parser->getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). @@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { return true; // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || - dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) + if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() || + dstIndexInfos.size() != types[1].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "memref rank not equal to indices count"); - if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) + if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << "], "; p->printOperand(getNumElements()); - *p << " : " << *getTagMemRef()->getType(); + *p << " : " << getTagMemRef()->getType(); } // Parse DmaWaitOp. 
@@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; - Type *type; - auto *indexType = parser->getBuilder().getIndexType(); + Type type; + auto indexType = parser->getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. @@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperand(numElementsInfo, indexType, result->operands)) return true; - if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) + if (tagIndexInfos.size() != type.cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results, void ExtractElementOp::build(Builder *builder, OperationState *result, SSAValue *aggregate, ArrayRef<SSAValue *> indices) { - auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); + auto aggregateType = aggregate->getType().cast<VectorOrTensorType>(); result->addOperands(aggregate); result->addOperands(indices); - result->types.push_back(aggregateType->getElementType()); + result->types.push_back(aggregateType.getElementType()); } void ExtractElementOp::print(OpAsmPrinter *p) const { @@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getAggregate()->getType(); + *p << " : " << getAggregate()->getType(); } bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - VectorOrTensorType *type; + VectorOrTensorType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(aggregateInfo) || @@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(aggregateInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool ExtractElementOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected an aggregate to index into"); - auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); + auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>(); if (!aggregateType) return emitOpError("first operand must be a vector or tensor"); - if (getType() != aggregateType->getElementType()) + if (getType() != aggregateType.getElementType()) return emitOpError("result type must match element type of aggregate"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to extract_element must have 'index' type"); // Verify the # indices match if we have a ranked type. 
- auto aggregateRank = aggregateType->getRank(); + auto aggregateRank = aggregateType.getRank(); if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) return emitOpError("incorrect number of indices for extract_element"); @@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const { void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, ArrayRef<SSAValue *> indices) { - auto *memrefType = cast<MemRefType>(memref->getType()); + auto memrefType = memref->getType().cast<MemRefType>(); result->addOperands(memref); result->addOperands(indices); - result->types.push_back(memrefType->getElementType()); + result->types.push_back(memrefType.getElementType()); } void LoadOp::print(OpAsmPrinter *p) const { @@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *type; + MemRefType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(memrefInfo) || @@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool LoadOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected a memref to load from"); - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("first operand must be a memref"); - if (getType() != memRefType->getElementType()) + if (getType() != memRefType.getElementType()) return emitOpError("result type must match element type of memref"); - if (memRefType->getRank() != getNumOperands() - 1) + if (memRefType.getRank() != getNumOperands() - 1) return emitOpError("incorrect number of indices for load"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. 
@@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool MemRefCastOp::verify() const { - auto *opType = dyn_cast<MemRefType>(getOperand()->getType()); - auto *resType = dyn_cast<MemRefType>(getType()); + auto opType = getOperand()->getType().dyn_cast<MemRefType>(); + auto resType = getType().dyn_cast<MemRefType>(); if (!opType || !resType) return emitOpError("requires input and result types to be memrefs"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); - if (opType->getAffineMaps() != resType->getAffineMaps()) + if (opType.getAffineMaps() != resType.getAffineMaps()) return emitOpError("requires input and result mappings to be the same"); - if (opType->getMemorySpace() != resType->getMemorySpace()) + if (opType.getMemorySpace() != resType.getMemorySpace()) return emitOpError( "requires input and result memory spaces to be the same"); // They must have the same rank, and any specified dimensions must match. - if (opType->getRank() != resType->getRank()) + if (opType.getRank() != resType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opType->getRank(); i != e; ++i) { - int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i); + for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { + int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } @@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *memrefType; + MemRefType memrefType; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(storeValueInfo) || parser->parseComma() || @@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType->getElementType(), + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands); @@ -950,19 +950,19 @@ bool StoreOp::verify() const { return emitOpError("expected a value to store and a memref"); // Second operand is a memref type. - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("second operand must be a memref"); // First operand must have same type as memref element type. 
- if (getValueToStore()->getType() != memRefType->getElementType()) + if (getValueToStore()->getType() != memRefType.getElementType()) return emitOpError("first operand must have same type memref element type"); - if (getNumOperands() != 2 + memRefType->getRank()) + if (getNumOperands() != 2 + memRefType.getRank()) return emitOpError("store index operand count not equal to memref rank"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool TensorCastOp::verify() const { - auto *opType = dyn_cast<TensorType>(getOperand()->getType()); - auto *resType = dyn_cast<TensorType>(getType()); + auto opType = getOperand()->getType().dyn_cast<TensorType>(); + auto resType = getType().dyn_cast<TensorType>(); if (!opType || !resType) return emitOpError("requires input and result types to be tensors"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); // If the source or destination are unranked, then the cast is valid. - auto *opRType = dyn_cast<RankedTensorType>(opType); - auto *resRType = dyn_cast<RankedTensorType>(resType); + auto opRType = opType.dyn_cast<RankedTensorType>(); + auto resRType = resType.dyn_cast<RankedTensorType>(); if (!opRType || !resRType) return false; // If they are both ranked, they have to have the same rank, and any specified // dimensions must match. - if (opRType->getRank() != resRType->getRank()) + if (opRType.getRank() != resRType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { - int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); + for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { + int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 81994ddfab4..15dd89bb758 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> { SmallVector<SSAValue *, 8> existingConstants; // Operation statements that were folded and that need to be erased. 
std::vector<OperationStmt *> opStmtsToErase; - using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>; + using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>; bool foldOperation(Operation *op, SmallVectorImpl<SSAValue *> &existingConstants, @@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) { auto &inst = *instIt++; - auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { builder.setInsertionPoint(&inst); return builder.create<ConstantOp>(inst.getLoc(), value, type); }; @@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { // Override the walker's operation statement visit for constant folding. void ConstantFold::visitOperationStmt(OperationStmt *stmt) { - auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { MLFuncBuilder builder(stmt); return builder.create<ConstantOp>(stmt->getLoc(), value, type); }; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d96d65b5fb7..90421819d82 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { bInner.setInsertionPoint(forStmt, forStmt->begin()); // Doubles the shape with a leading dimension extent of 2. - auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * { + auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { // Add the leading dimension in the shape for the double buffer. - ArrayRef<int> shape = oldMemRefType->getShape(); + ArrayRef<int> shape = oldMemRefType.getShape(); SmallVector<int, 4> shapeSizes(shape.begin(), shape.end()); shapeSizes.insert(shapeSizes.begin(), 2); - auto *newMemRefType = - bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {}, - oldMemRefType->getMemorySpace()); + auto newMemRefType = + bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {}, + oldMemRefType.getMemorySpace()); return newMemRefType; }; - auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType())); + auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>()); // Create and place the alloc at the top level. MLFuncBuilder topBuilder(forStmt->getFunction()); - auto *newMemRef = cast<MLValue>( + auto newMemRef = cast<MLValue>( topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType) ->getResult()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index cdf5b7166a0..4ec89425189 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -78,7 +78,7 @@ private: /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. 
- DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants; + DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants; }; }; // end anonymous namespace diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index edd8ce85317..ad9d6dcb769 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, ArrayRef<MLValue *> extraIndices, AffineMap indexRemap) { - unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank(); + unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank(); + unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank(); (void)newMemRefRank; if (indexRemap) { assert(indexRemap.getNumInputs() == oldMemRefRank); @@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, } // Assert same elemental type. - assert(cast<MemRefType>(oldMemRef->getType())->getElementType() == - cast<MemRefType>(newMemRef->getType())->getElementType()); + assert(oldMemRef->getType().cast<MemRefType>().getElementType() == + newMemRef->getType().cast<MemRefType>().getElementType()); // Check if memref was used in a non-deferencing context. for (const StmtOperand &use : oldMemRef->getUses()) { @@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, opStmt->operand_end()); // Result types don't change. Both memref's are of the same elemental type. - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; resultTypes.reserve(opStmt->getNumResults()); for (const auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d7a1f531cef..511afa95993 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches, /// sizes specified by vectorSize. The MemRef lives in the same memory space as /// tmpl. The MemRef should be promoted to a closer memory address space in a /// later pass. 
-static MemRefType *getVectorizedMemRefType(MemRefType *tmpl, - ArrayRef<int> vectorSizes) { - auto *elementType = tmpl->getElementType(); - assert(!dyn_cast<VectorType>(elementType) && +static MemRefType getVectorizedMemRefType(MemRefType tmpl, + ArrayRef<int> vectorSizes) { + auto elementType = tmpl.getElementType(); + assert(!elementType.dyn_cast<VectorType>() && "Can't vectorize an already vector type"); - assert(tmpl->getAffineMaps().empty() && + assert(tmpl.getAffineMaps().empty() && "Unsupported non-implicit identity map"); return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {}, - tmpl->getMemorySpace()); + tmpl.getMemorySpace()); } /// Creates an unaligned load with the following semantics: @@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() { template <typename LoadOrStoreOpPointer> static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp, ArrayRef<int> vectorSize) { - auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType()); - auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); + auto memRefType = + memoryOp->getMemRef()->getType().template cast<MemRefType>(); + auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); // Materialize a MemRef with 1 vector. auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
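The hunks above all apply one mechanical recipe: IR types stop being passed around as Type * pointers and become small value classes that wrap a single uniqued storage pointer, so call sites switch from -> accessors and LLVM's free isa<>/dyn_cast<>/cast<> functions to . accessors and member-template casts, and locals change from "FunctionType *type = nullptr" to a default-constructed "FunctionType type". The fragment below is a minimal stand-alone sketch of that value-semantics pattern, not code from this patch: the storage and wrapper names (TypeStorage, MemRefStorage, the Kind enum) are invented for illustration, and real IR types are additionally uniqued in a context and hashed by their storage pointer.

#include <cassert>
#include <vector>

// Uniqued storage object; a real IR context would allocate and unique these.
struct TypeStorage {
  enum class Kind { Integer, MemRef };
  Kind kind;
  explicit TypeStorage(Kind k) : kind(k) {}
};

struct MemRefStorage : TypeStorage {
  std::vector<int> shape;
  explicit MemRefStorage(std::vector<int> s)
      : TypeStorage(Kind::MemRef), shape(std::move(s)) {}
};

// Value-semantic handle: cheap to copy, compared by storage identity.
class Type {
public:
  using ImplType = TypeStorage;
  Type() = default;
  /* implicit */ Type(ImplType *ptr) : type(ptr) {}

  explicit operator bool() const { return type != nullptr; }
  bool operator==(Type other) const { return type == other.type; }
  bool operator!=(Type other) const { return type != other.type; }

  // Member-template casts replace the free isa<>/dyn_cast<>/cast<> calls.
  template <typename U> bool isa() const { return type && U::kindof(type->kind); }
  template <typename U> U dyn_cast() const { return isa<U>() ? U(type) : U(); }
  template <typename U> U cast() const {
    assert(isa<U>() && "invalid cast");
    return U(type);
  }

protected:
  ImplType *type = nullptr;
};

class IntegerType : public Type {
public:
  using Type::Type;
  static bool kindof(TypeStorage::Kind k) { return k == TypeStorage::Kind::Integer; }
};

class MemRefType : public Type {
public:
  using Type::Type;
  static bool kindof(TypeStorage::Kind k) { return k == TypeStorage::Kind::MemRef; }
  // Accessors delegate to the storage object, like the patched getShape().
  const std::vector<int> &getShape() const {
    return static_cast<MemRefStorage *>(type)->shape;
  }
  int getRank() const { return static_cast<int>(getShape().size()); }
};

int main() {
  MemRefStorage storage({2, 4});   // stand-in for context-owned storage
  Type t = MemRefType(&storage);   // handles are passed and stored by value
  if (auto m = t.dyn_cast<MemRefType>())
    assert(m.getRank() == 2);
  assert(!t.isa<IntegerType>());
  return 0;
}

Because the handle is just a wrapped pointer, it can still be stored by value in containers and keyed by identity, which is why changes such as DenseMap<std::pair<Attribute, Type>, Operation *> in GreedyPatternRewriteDriver.cpp need no further adjustment.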

