| Field | Value |
|---|---|
| author | River Riddle <riverriddle@google.com>, 2018-10-30 14:59:22 -0700 |
| committer | jpienaar <jpienaar@google.com>, 2019-03-29 13:45:54 -0700 |
| commit | 4c465a181db49c436f62da303e8fdd3ed317fee7 (patch) |
| tree | fb190912d0714222d6e336e19d5b8ea16342fb6e /mlir/lib/IR |
| parent | 75376b8e33c67a42e3dca2c597197e0622b6eaa2 (diff) |
Implement value type abstraction for types.
This is done by changing Type into a POD interface that wraps an underlying storage pointer, with in-class support for isa/dyn_cast/cast.
PiperOrigin-RevId: 219372163
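
The commit message describes the core pattern: Type becomes a small value object that wraps a pointer to uniqued storage and exposes isa/dyn_cast/cast as member templates. Below is a minimal, self-contained C++ sketch of that pattern; every name in it (`sketch::Type`, `TypeStorage`, `IntegerType::kindof`, `getWidth`) is an illustrative stand-in, not the actual MLIR API, which the diff further down defines in TypeDetail.h and Types.cpp.

```cpp
#include <cassert>
#include <cstdint>

namespace sketch {

// Uniqued, context-owned storage that the value handle points at.
struct TypeStorage {
  enum class Kind : uint8_t { Integer, Float };
  Kind kind;
  unsigned data; // e.g. the bit width of an integer type
};

// The value handle: pointer-sized, cheap to copy, compared by pointer identity.
class Type {
public:
  using ImplType = TypeStorage;

  Type() : impl(nullptr) {}
  explicit Type(ImplType *impl) : impl(impl) {}

  TypeStorage::Kind getKind() const { return impl->kind; }
  bool operator==(Type other) const { return impl == other.impl; }
  explicit operator bool() const { return impl != nullptr; }

  // In-class support for the isa/dyn_cast/cast idiom.
  template <typename U> bool isa() const { return U::kindof(getKind()); }
  template <typename U> U dyn_cast() const {
    return isa<U>() ? U(impl) : U(nullptr);
  }
  template <typename U> U cast() const {
    assert(isa<U>() && "invalid cast");
    return U(impl);
  }

protected:
  ImplType *impl;
};

// A derived handle simply reinterprets the same storage pointer.
class IntegerType : public Type {
public:
  using Type::Type;
  static bool kindof(TypeStorage::Kind kind) {
    return kind == TypeStorage::Kind::Integer;
  }
  unsigned getWidth() const { return impl->data; }
};

} // namespace sketch

int main() {
  // In MLIR the storage would be uniqued inside MLIRContext; one instance is
  // stack-allocated here purely for demonstration.
  sketch::TypeStorage storage{sketch::TypeStorage::Kind::Integer, /*data=*/32};
  sketch::Type type(&storage);

  if (auto intTy = type.dyn_cast<sketch::IntegerType>())
    assert(intTy.getWidth() == 32);
  assert(type.isa<sketch::IntegerType>());
  return 0;
}
```

Because the handle holds only a pointer, copying a Type is cheap and equality reduces to pointer comparison once storage is uniqued per context, which is what the MLIRContext.cpp changes below implement.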
Diffstat (limited to 'mlir/lib/IR')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 91 |
| -rw-r--r-- | mlir/lib/IR/AttributeDetail.h | 5 |
| -rw-r--r-- | mlir/lib/IR/Attributes.cpp | 16 |
| -rw-r--r-- | mlir/lib/IR/BasicBlock.cpp | 6 |
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 55 |
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 38 |
| -rw-r--r-- | mlir/lib/IR/Function.cpp | 16 |
| -rw-r--r-- | mlir/lib/IR/Instructions.cpp | 4 |
| -rw-r--r-- | mlir/lib/IR/MLIRContext.cpp | 207 |
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 26 |
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 10 |
| -rw-r--r-- | mlir/lib/IR/TypeDetail.h | 126 |
| -rw-r--r-- | mlir/lib/IR/Types.cpp | 101 |
13 files changed, 431 insertions, 270 deletions
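
A reading aid for the diff that follows: most hunks are mechanical call-site updates. Types change from `Type *` pointers to `Type` values, so `->` becomes `.` and the free-function `isa<T>(type)` / `dyn_cast<T>(type)` / `cast<T>(type)` templates become member calls. The pair below paraphrases the ModuleState::visitType hunk from AsmPrinter.cpp; the names visitTypeBefore and visitTypeAfter are hypothetical, chosen only to show the two forms side by side, and each form would compile only against the MLIR headers before or after this change, respectively.

```cpp
// Before this change: types were passed and inspected through pointers
// (paraphrased from the removed lines of the AsmPrinter.cpp hunk).
static void visitTypeBefore(const Type *type) {
  if (auto *funcType = dyn_cast<FunctionType>(type)) {
    for (auto *input : funcType->getInputs())
      visitTypeBefore(input);
  }
}

// After this change: types are lightweight values and casting is a member
// template (paraphrased from the added lines of the same hunk).
static void visitTypeAfter(Type type) {
  if (auto funcType = type.dyn_cast<FunctionType>()) {
    for (auto input : funcType.getInputs())
      visitTypeAfter(input);
  }
}
```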
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 454a28a6558..cb5e96f0086 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -122,7 +122,7 @@ private: void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); void visitOperationStmt(const OperationStmt *opStmt); - void visitType(const Type *type); + void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const Operation *op); @@ -135,16 +135,16 @@ private: } // end anonymous namespace // TODO Support visiting other types/instructions when implemented. -void ModuleState::visitType(const Type *type) { - if (auto *funcType = dyn_cast<FunctionType>(type)) { +void ModuleState::visitType(Type type) { + if (auto funcType = type.dyn_cast<FunctionType>()) { // Visit input and result types for functions. - for (auto *input : funcType->getInputs()) + for (auto input : funcType.getInputs()) visitType(input); - for (auto *result : funcType->getResults()) + for (auto result : funcType.getResults()) visitType(result); - } else if (auto *memref = dyn_cast<MemRefType>(type)) { + } else if (auto memref = type.dyn_cast<MemRefType>()) { // Visit affine maps in memref type. - for (auto map : memref->getAffineMaps()) { + for (auto map : memref.getAffineMaps()) { recordAffineMapReference(map); } } @@ -271,7 +271,7 @@ public: void print(const Module *module); void printFunctionReference(const Function *func); void printAttribute(Attribute attr); - void printType(const Type *type); + void printType(Type type); void print(const Function *fn); void print(const ExtFunction *fn); void print(const CFGFunction *fn); @@ -290,7 +290,7 @@ protected: void printFunctionAttributes(const Function *fn); void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs = {}); - void printFunctionResultType(const FunctionType *type); + void printFunctionResultType(FunctionType type); void printAffineMapId(int affineMapId) const; void printAffineMapReference(AffineMap affineMap); void printIntegerSetId(int integerSetId) const; @@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) { } void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { - auto *type = attr.getType(); - auto shape = type->getShape(); - auto rank = type->getRank(); + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); SmallVector<Attribute, 16> elements; attr.getValues(elements); @@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { os << ']'; } -void ModulePrinter::printType(const Type *type) { - switch (type->getKind()) { +void ModulePrinter::printType(Type type) { + switch (type.getKind()) { case Type::Kind::Index: os << "index"; return; @@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) { return; case Type::Kind::Integer: { - auto *integer = cast<IntegerType>(type); - os << 'i' << integer->getWidth(); + auto integer = type.cast<IntegerType>(); + os << 'i' << integer.getWidth(); return; } case Type::Kind::Function: { - auto *func = cast<FunctionType>(type); + auto func = type.cast<FunctionType>(); os << '('; - interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); + interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); os << ") -> "; - auto results = func->getResults(); + auto results = func.getResults(); if (results.size() == 1) - os << *results[0]; + os << results[0]; else { os << '('; - interleaveComma(results, [&](Type 
*type) { printType(type); }); + interleaveComma(results, [&](Type type) { printType(type); }); os << ')'; } return; } case Type::Kind::Vector: { - auto *v = cast<VectorType>(type); + auto v = type.cast<VectorType>(); os << "vector<"; - for (auto dim : v->getShape()) + for (auto dim : v.getShape()) os << dim << 'x'; - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::RankedTensor: { - auto *v = cast<RankedTensorType>(type); + auto v = type.cast<RankedTensorType>(); os << "tensor<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::UnrankedTensor: { - auto *v = cast<UnrankedTensorType>(type); + auto v = type.cast<UnrankedTensorType>(); os << "tensor<*x"; - printType(v->getElementType()); + printType(v.getElementType()); os << '>'; return; } case Type::Kind::MemRef: { - auto *v = cast<MemRefType>(type); + auto v = type.cast<MemRefType>(); os << "memref<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - printType(v->getElementType()); - for (auto map : v->getAffineMaps()) { + printType(v.getElementType()); + for (auto map : v.getAffineMaps()) { os << ", "; printAffineMapReference(map); } // Only print the memory space if it is the non-default one. - if (v->getMemorySpace()) - os << ", " << v->getMemorySpace(); + if (v.getMemorySpace()) + os << ", " << v.getMemorySpace(); os << '>'; return; } @@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { // Function printing //===----------------------------------------------------------------------===// -void ModulePrinter::printFunctionResultType(const FunctionType *type) { - switch (type->getResults().size()) { +void ModulePrinter::printFunctionResultType(FunctionType type) { + switch (type.getResults().size()) { case 0: break; case 1: os << " -> "; - printType(type->getResults()[0]); + printType(type.getResults()[0]); break; default: os << " -> ("; - interleaveComma(type->getResults(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getResults(), + [&](Type eltType) { printType(eltType); }); os << ')'; break; } @@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) { auto type = fn->getType(); os << "@" << fn->getName() << '('; - interleaveComma(type->getInputs(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); }); os << ')'; printFunctionResultType(type); @@ -937,7 +936,7 @@ public: // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } - void printType(const Type *type) { ModulePrinter::printType(type); } + void printType(Type type) { ModulePrinter::printType(type); } void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); } void printAffineMap(AffineMap map) { return ModulePrinter::printAffineMapReference(map); @@ -974,10 +973,10 @@ protected: if (auto *op = value->getDefiningOperation()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) { // i1 constants get special names. - if (intOp->getType()->isInteger(1)) { + if (intOp->getType().isInteger(1)) { specialName << (intOp->getValue() ? 
"true" : "false"); } else { - specialName << 'c' << intOp->getValue() << '_' << *intOp->getType(); + specialName << 'c' << intOp->getValue() << '_' << intOp->getType(); } } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { specialName << 'c' << intOp->getValue(); @@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); } void Type::print(raw_ostream &os) const { ModuleState state(getContext()); - ModulePrinter(os, state).printType(this); + ModulePrinter(os, state).printType(*this); } void Type::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index a0e9afb4fd3..63ad544fa48 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage { /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { - Type *value; + Type value; }; /// An attribute representing a reference to a function. @@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage { /// A base attribute representing a reference to a vector or tensor constant. struct ElementsAttributeStorage : public AttributeStorage { - VectorOrTensorType *type; + VectorOrTensorType type; }; /// An attribute representing a reference to a vector or tensor constant, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 34312b84a0b..58b5b90d43d 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const { TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -Type *TypeAttr::getValue() const { - return static_cast<ImplType *>(attr)->value; -} +Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} @@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } -FunctionType *FunctionAttr::getType() const { return getValue()->getType(); } +FunctionType FunctionAttr::getType() const { return getValue()->getType(); } ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -VectorOrTensorType *ElementsAttr::getType() const { +VectorOrTensorType ElementsAttr::getType() const { return static_cast<ImplType *>(attr)->type; } @@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos, void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth; - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto context = getType().getContext(); values.reserve(elementNum); if (bitsWidth == 64) { ArrayRef<int64_t> vs( @@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr) : DenseElementsAttr(ptr) {} void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto context = getType().getContext(); ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()), 
getRawData().size() / 8}); values.reserve(elementNum); diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp index bb8ac75d91a..29a5ce12e4a 100644 --- a/mlir/lib/IR/BasicBlock.cpp +++ b/mlir/lib/IR/BasicBlock.cpp @@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() { // Argument list management. //===----------------------------------------------------------------------===// -BBArgument *BasicBlock::addArgument(Type *type) { +BBArgument *BasicBlock::addArgument(Type type) { auto *arg = new BBArgument(type, this); arguments.push_back(arg); return arg; } /// Add one argument to the argument list for each type specified in the list. -auto BasicBlock::addArguments(ArrayRef<Type *> types) +auto BasicBlock::addArguments(ArrayRef<Type> types) -> llvm::iterator_range<args_iterator> { arguments.reserve(arguments.size() + types.size()); auto initialSize = arguments.size(); - for (auto *type : types) { + for (auto type : types) { addArgument(type); } return {arguments.data() + initialSize, arguments.data() + arguments.size()}; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 22d749a6c8c..906b580d9af 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename, // Types. //===----------------------------------------------------------------------===// -FloatType *Builder::getBF16Type() { return Type::getBF16(context); } +FloatType Builder::getBF16Type() { return Type::getBF16(context); } -FloatType *Builder::getF16Type() { return Type::getF16(context); } +FloatType Builder::getF16Type() { return Type::getF16(context); } -FloatType *Builder::getF32Type() { return Type::getF32(context); } +FloatType Builder::getF32Type() { return Type::getF32(context); } -FloatType *Builder::getF64Type() { return Type::getF64(context); } +FloatType Builder::getF64Type() { return Type::getF64(context); } -OtherType *Builder::getIndexType() { return Type::getIndex(context); } +OtherType Builder::getIndexType() { return Type::getIndex(context); } -OtherType *Builder::getTFControlType() { return Type::getTFControl(context); } +OtherType Builder::getTFControlType() { return Type::getTFControl(context); } -OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } +OtherType Builder::getTFResourceType() { return Type::getTFResource(context); } -OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); } +OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); } -OtherType *Builder::getTFComplex64Type() { +OtherType Builder::getTFComplex64Type() { return Type::getTFComplex64(context); } -OtherType *Builder::getTFComplex128Type() { +OtherType Builder::getTFComplex128Type() { return Type::getTFComplex128(context); } -OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); } +OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); } -OtherType *Builder::getTFStringType() { return Type::getTFString(context); } +OtherType Builder::getTFStringType() { return Type::getTFString(context); } -IntegerType *Builder::getIntegerType(unsigned width) { +IntegerType Builder::getIntegerType(unsigned width) { return Type::getInteger(width, context); } -FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs, - ArrayRef<Type *> results) { +FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, + ArrayRef<Type> results) { return FunctionType::get(inputs, results, context); } -MemRefType 
*Builder::getMemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { +MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } -VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) { +VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) { return VectorType::get(shape, elementType); } -RankedTensorType *Builder::getTensorType(ArrayRef<int> shape, - Type *elementType) { +RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) { return RankedTensorType::get(shape, elementType); } -UnrankedTensorType *Builder::getTensorType(Type *elementType) { +UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } @@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { return IntegerSetAttr::get(set); } -TypeAttr Builder::getTypeAttr(Type *type) { +TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type, context); } @@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) { return FunctionAttr::get(value, context); } -ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, ArrayRef<char> data) { return DenseElementsAttr::get(type, data); } -ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes) { return OpaqueElementsAttr::get(type, bytes); } @@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) { OperationStmt *MLFuncBuilder::createOperation(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> types, + ArrayRef<Type> types, ArrayRef<NamedAttribute> attrs) { auto *op = OperationStmt::create(location, name, operands, types, attrs, getContext()); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 542e67eaefd..e4bca037c4e 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser, numDims = opInfos.size(); // Parse the optional symbol operands. 
- auto *affineIntTy = parser->getBuilder().getIndexType(); + auto affineIntTy = parser->getBuilder().getIndexType(); if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::OptionalSquare) || parser->resolveOperands(opInfos, affineIntTy, operands)) @@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result, bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); - auto *affineIntTy = builder.getIndexType(); + auto affineIntTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; @@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants, /// Builds a constant op with the specified attribute value and result type. void ConstantOp::build(Builder *builder, OperationState *result, - Attribute value, Type *type) { + Attribute value, Type type) { result->addAttribute("value", value); result->types.push_back(type); } @@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const { p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); if (!getValue().isa<FunctionAttr>()) - *p << " : " << *getType(); + *p << " : " << getType(); } bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { Attribute valueAttr; - Type *type; + Type type; if (parser->parseAttribute(valueAttr, "value", result->attributes) || parser->parseOptionalAttributeDict(result->attributes)) @@ -208,33 +208,33 @@ bool ConstantOp::verify() const { if (!value) return emitOpError("requires a 'value' attribute"); - auto *type = this->getType(); - if (isa<IntegerType>(type) || type->isIndex()) { + auto type = this->getType(); + if (type.isa<IntegerType>() || type.isIndex()) { if (!value.isa<IntegerAttr>()) return emitOpError( "requires 'value' to be an integer for an integer result type"); return false; } - if (isa<FloatType>(type)) { + if (type.isa<FloatType>()) { if (!value.isa<FloatAttr>()) return emitOpError("requires 'value' to be a floating point constant"); return false; } - if (isa<VectorOrTensorType>(type)) { + if (type.isa<VectorOrTensorType>()) { if (!value.isa<ElementsAttr>()) return emitOpError("requires 'value' to be a vector/tensor constant"); return false; } - if (type->isTFString()) { + if (type.isTFString()) { if (!value.isa<StringAttr>()) return emitOpError("requires 'value' to be a string constant"); return false; } - if (isa<FunctionType>(type)) { + if (type.isa<FunctionType>()) { if (!value.isa<FunctionAttr>()) return emitOpError("requires 'value' to be a function reference"); return false; @@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands, } void ConstantFloatOp::build(Builder *builder, OperationState *result, - const APFloat &value, FloatType *type) { + const APFloat &value, FloatType type) { ConstantOp::build(builder, result, builder->getFloatAttr(value), type); } bool ConstantFloatOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<FloatType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. 
bool ConstantIntOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<IntegerType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<IntegerType>(); } void ConstantIntOp::build(Builder *builder, OperationState *result, @@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result, /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, Type *type) { - assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type"); + int64_t value, Type type) { + assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); ConstantOp::build(builder, result, builder->getIntegerAttr(value), type); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); + return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState *result, @@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result, bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type *, 2> types; + SmallVector<Type, 2> types; llvm::SMLoc loc; return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || (!opInfo.empty() && parser->parseColonTypeList(types)) || @@ -330,7 +330,7 @@ bool ReturnOp::verify() const { // The operand number and types must match the function signature. MLFunction *function = cast<MLFunction>(block); - const auto &results = function->getType()->getResults(); + const auto &results = function->getType().getResults(); if (stmt->getNumOperands() != results.size()) return emitOpError("has " + Twine(stmt->getNumOperands()) + " operands, but enclosing function returns " + diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index efeb16b61db..70c0e1259b3 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -28,8 +28,8 @@ using namespace mlir; Function::Function(Kind kind, Location *location, StringRef name, - FunctionType *type, ArrayRef<NamedAttribute> attrs) - : nameAndKind(Identifier::get(name, type->getContext()), kind), + FunctionType type, ArrayRef<NamedAttribute> attrs) + : nameAndKind(Identifier::get(name, type.getContext()), kind), location(location), type(type) { this->attrs = AttributeListStorage::get(attrs, getContext()); } @@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const { return {}; } -MLIRContext *Function::getContext() const { return getType()->getContext(); } +MLIRContext *Function::getContext() const { return getType().getContext(); } /// Delete this object. void Function::destroy() { @@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const { // ExtFunction implementation. //===----------------------------------------------------------------------===// -ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, +ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::ExtFunc, location, name, type, attrs) {} @@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, // CFGFunction implementation. 
//===----------------------------------------------------------------------===// -CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type, +CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::CFGFunc, location, name, type, attrs) {} @@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() { /// Create a new MLFunction with the specific fields. MLFunction *MLFunction::create(Location *location, StringRef name, - FunctionType *type, + FunctionType type, ArrayRef<NamedAttribute> attrs) { - const auto &argTypes = type->getInputs(); + const auto &argTypes = type.getInputs(); auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size()); void *rawMem = malloc(byteSize); @@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name, return function; } -MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type, +MLFunction::MLFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::MLFunc, location, name, type, attrs), StmtBlock(StmtBlockKind::MLFunc) {} diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index 422636bf2e3..d2f49ddfc6e 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const { /// Create a new OperationInst with the specified fields. OperationInst *OperationInst::create(Location *location, OperationName name, ArrayRef<CFGValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(), @@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name, OperationInst *OperationInst::clone() const { SmallVector<CFGValue *, 8> operands; - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; // Put together the operands and results. for (auto *operand : getOperands()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0a2e9416842..8811f7b9f78 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -21,6 +21,7 @@ #include "AttributeDetail.h" #include "AttributeListStorage.h" #include "IntegerSetDetail.h" +#include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -44,11 +45,11 @@ using namespace mlir::detail; using namespace llvm; namespace { -struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { +struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> { // Functions are uniqued based on their inputs and results. 
- using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>; - using DenseMapInfo<FunctionType *>::getHashValue; - using DenseMapInfo<FunctionType *>::isEqual; + using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>; + using DenseMapInfo<FunctionTypeStorage *>::getHashValue; + using DenseMapInfo<FunctionTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( @@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) { + static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; return lhs == KeyTy(rhs->getInputs(), rhs->getResults()); @@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { } }; -struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> { +struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> { // Vectors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<VectorType *>::getHashValue; - using DenseMapInfo<VectorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<VectorTypeStorage *>::getHashValue; + using DenseMapInfo<VectorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const VectorType *rhs) { + static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> { +struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> { // Ranked tensors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<RankedTensorType *>::getHashValue; - using DenseMapInfo<RankedTensorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue; + using DenseMapInfo<RankedTensorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) { + static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> { +struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> { // MemRefs are uniqued based on their element type, shape, affine map // composition, and memory space. 
- using KeyTy = - std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; - using DenseMapInfo<MemRefType *>::getHashValue; - using DenseMapInfo<MemRefType *>::isEqual; + using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; + using DenseMapInfo<MemRefTypeStorage *>::getHashValue; + using DenseMapInfo<MemRefTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(std::get<0>(key)), + DenseMapInfo<Type>::getHashValue(std::get<0>(key)), hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()), hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), std::get<3>(key)); } - static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) { + static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(), - rhs->getAffineMaps(), rhs->getMemorySpace()); + return lhs == std::make_tuple(rhs->elementType, rhs->getShape(), + rhs->getAffineMaps(), rhs->memorySpace); } }; @@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> { }; struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>; + using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>; using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue; using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual; @@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { }; struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, StringRef>; + using KeyTy = std::pair<VectorOrTensorType, StringRef>; using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue; using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual; @@ -295,13 +295,14 @@ public: llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers; // Uniquing table for 'other' types. - OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - - int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr}; + OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - + int(Type::Kind::FIRST_OTHER_TYPE) + 1] = { + nullptr}; // Uniquing table for 'float' types. - FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - - int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = { - nullptr}; + FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - + int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = + {nullptr}; // Affine map uniquing. using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; @@ -324,26 +325,26 @@ public: DenseMap<int64_t, AffineConstantExprStorage *> constExprs; /// Integer type uniquing. - DenseMap<unsigned, IntegerType *> integers; + DenseMap<unsigned, IntegerTypeStorage *> integers; /// Function type uniquing. - using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>; + using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>; FunctionTypeSet functions; /// Vector type uniquing. - using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>; + using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>; VectorTypeSet vectors; /// Ranked tensor type uniquing. 
using RankedTensorTypeSet = - DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>; + DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>; RankedTensorTypeSet rankedTensors; /// Unranked tensor type uniquing. - DenseMap<Type *, UnrankedTensorType *> unrankedTensors; + DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors; /// MemRef type uniquing. - using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>; + using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>; MemRefTypeSet memrefs; // Attribute uniquing. @@ -355,13 +356,12 @@ public: ArrayAttrSet arrayAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs; DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs; - DenseMap<Type *, TypeAttributeStorage *> typeAttrs; + DenseMap<Type, TypeAttributeStorage *> typeAttrs; using AttributeListSet = DenseSet<AttributeListStorage *, AttributeListKeyInfo>; AttributeListSet attributeLists; DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs; - DenseMap<std::pair<VectorOrTensorType *, Attribute>, - SplatElementsAttributeStorage *> + DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *> splatElementsAttrs; using DenseElementsAttrSet = DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>; @@ -369,7 +369,7 @@ public: using OpaqueElementsAttrSet = DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>; OpaqueElementsAttrSet opaqueElementsAttrs; - DenseMap<std::tuple<Type *, Attribute, Attribute>, + DenseMap<std::tuple<Type, Attribute, Attribute>, SparseElementsAttributeStorage *> sparseElementsAttrs; @@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line, // Type uniquing //===----------------------------------------------------------------------===// -IntegerType *IntegerType::get(unsigned width, MLIRContext *context) { +IntegerType IntegerType::get(unsigned width, MLIRContext *context) { + assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); auto &impl = context->getImpl(); auto *&result = impl.integers[width]; if (!result) { - result = impl.allocator.Allocate<IntegerType>(); - new (result) IntegerType(width, context); + result = impl.allocator.Allocate<IntegerTypeStorage>(); + new (result) IntegerTypeStorage{{Kind::Integer, context}, width}; } return result; } -FloatType *FloatType::get(Kind kind, MLIRContext *context) { +FloatType FloatType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE && kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind"); auto &impl = context->getImpl(); @@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. - auto *ptr = impl.allocator.Allocate<FloatType>(); + auto *ptr = impl.allocator.Allocate<FloatTypeStorage>(); // Initialize the memory using placement new. - new (ptr) FloatType(kind, context); + new (ptr) FloatTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -OtherType *OtherType::get(Kind kind, MLIRContext *context) { +OtherType OtherType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE && "Not an 'other' type kind"); auto &impl = context->getImpl(); @@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. 
- auto *ptr = impl.allocator.Allocate<OtherType>(); + auto *ptr = impl.allocator.Allocate<OtherTypeStorage>(); // Initialize the memory using placement new. - new (ptr) OtherType(kind, context); + new (ptr) OtherTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -FunctionType *FunctionType::get(ArrayRef<Type *> inputs, - ArrayRef<Type *> results, - MLIRContext *context) { +FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results, + MLIRContext *context) { auto &impl = context->getImpl(); // Look to see if we already have this function type. @@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<FunctionType>(); + auto *result = impl.allocator.Allocate<FunctionTypeStorage>(); // Copy the inputs and results into the bump pointer. - SmallVector<Type *, 16> types; + SmallVector<Type, 16> types; types.reserve(inputs.size() + results.size()); types.append(inputs.begin(), inputs.end()); types.append(results.begin(), results.end()); - auto typesList = impl.copyInto(ArrayRef<Type *>(types)); + auto typesList = impl.copyInto(ArrayRef<Type>(types)); // Initialize the memory using placement new. - new (result) - FunctionType(typesList.data(), inputs.size(), results.size(), context); + new (result) FunctionTypeStorage{ + {Kind::Function, context, static_cast<unsigned int>(inputs.size())}, + static_cast<unsigned int>(results.size()), + typesList.data()}; // Cache and return it. return *existing.first = result; } -VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { +VectorType VectorType::get(ArrayRef<int> shape, Type elementType) { assert(!shape.empty() && "vector types must have at least one dimension"); - assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) && + assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) && "vectors elements must be primitives"); assert(!std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; }) && "vector types must have static shape"); - auto *context = elementType->getContext(); + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this vector type. @@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<VectorType>(); + auto *result = impl.allocator.Allocate<VectorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) VectorType(shape, elementType, context); + new (result) VectorTypeStorage{ + {{Kind::Vector, context, static_cast<unsigned int>(shape.size())}, + elementType}, + shape.data()}; // Cache and return it. return *existing.first = result; } -RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, - Type *elementType) { - auto *context = elementType->getContext(); +RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this ranked tensor type. @@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, return *existing.first; // On the first use, we allocate them into the bump pointer. 
- auto *result = impl.allocator.Allocate<RankedTensorType>(); + auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) RankedTensorType(shape, elementType, context); + new (result) RankedTensorTypeStorage{ + {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())}, + elementType}}, + shape.data()}; // Cache and return it. return *existing.first = result; } -UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { - auto *context = elementType->getContext(); +UnrankedTensorType UnrankedTensorType::get(Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this unranked tensor type. @@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { return result; // On the first use, we allocate them into the bump pointer. - result = impl.allocator.Allocate<UnrankedTensorType>(); + result = impl.allocator.Allocate<UnrankedTensorTypeStorage>(); // Initialize the memory using placement new. - new (result) UnrankedTensorType(elementType, context); + new (result) UnrankedTensorTypeStorage{ + {{{Kind::UnrankedTensor, context}, elementType}}}; return result; } -MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { - auto *context = elementType->getContext(); +MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Drop the unbounded identity maps from the composition. @@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<MemRefType>(); + auto *result = impl.allocator.Allocate<MemRefTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); @@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, impl.copyInto(ArrayRef<AffineMap>(affineMapComposition)); // Initialize the memory using placement new. - new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace, - context); + new (result) MemRefTypeStorage{ + {Kind::MemRef, context, static_cast<unsigned int>(shape.size())}, + elementType, + shape.data(), + static_cast<unsigned int>(affineMapComposition.size()), + affineMapComposition.data(), + memorySpace}; // Cache and return it. return *existing.first = result; } @@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { return result; } -TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { +TypeAttr TypeAttr::get(Type type, MLIRContext *context) { auto *&result = context->getImpl().typeAttrs[type]; if (result) return result; @@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs, return *existing.first = result; } -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, +SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, Attribute elt) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. 
auto *&result = impl.splatElementsAttrs[{type, elt}]; @@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, return result; } -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, +DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, ArrayRef<char> data) { - auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); + auto bitsRequired = (long)type.getBitWidth() * type.getNumElements(); (void)bitsRequired; assert((bitsRequired <= data.size() * 8L) && "Input data bit size should be larger than that type requires"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. DenseElementsAttrInfo::KeyTy key({type, data}); @@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first; // Otherwise, allocate a new one, unique it and return it. - auto *eltType = type->getElementType(); - switch (eltType->getKind()) { + auto eltType = type.getElementType(); + switch (eltType.getKind()) { case Type::Kind::BF16: case Type::Kind::F16: case Type::Kind::F32: @@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } case Type::Kind::Integer: { - auto width = ::cast<IntegerType>(eltType)->getWidth(); + auto width = eltType.cast<IntegerType>().getWidth(); auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>(); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); std::uninitialized_copy(data.begin(), data.end(), copy); @@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, } } -OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, +OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, StringRef bytes) { - assert(isValidTensorElementType(type->getElementType()) && + assert(isValidTensorElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. OpaqueElementsAttrInfo::KeyTy key({type, bytes}); @@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type, +SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. auto key = std::make_tuple(type, indices, values); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2ed09b83b53..0722421c8ba 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, } bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { - auto *type = op->getResult(0)->getType(); + auto type = op->getResult(0)->getType(); for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { if (op->getResult(i)->getType() != type) return op->emitOpError( @@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { /// If this is a vector type, or a tensor type, return the scalar element type /// that it is built around, otherwise return the type unmodified. 
-static Type *getTensorOrVectorElementType(Type *type) { - if (auto *vec = dyn_cast<VectorType>(type)) - return vec->getElementType(); +static Type getTensorOrVectorElementType(Type type) { + if (auto vec = type.dyn_cast<VectorType>()) + return vec.getElementType(); // Look through tensor<vector<...>> to find the underlying element type. - if (auto *tensor = dyn_cast<TensorType>(type)) - return getTensorOrVectorElementType(tensor->getElementType()); + if (auto tensor = type.dyn_cast<TensorType>()) + return getTensorOrVectorElementType(tensor.getElementType()); return type; } bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<FloatType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) return op->emitOpError("requires a floating point type"); } @@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>()) return op->emitOpError("requires an integer type"); } return false; @@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> ops; - Type *type; + Type type; return parser->parseOperandList(ops, 2) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type) || @@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p->printOptionalAttrDict(op->getAttrs()); - *p << " : " << *op->getResult(0)->getType(); + *p << " : " << op->getResult(0)->getType(); } //===----------------------------------------------------------------------===// @@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { //===----------------------------------------------------------------------===// void impl::buildCastOp(Builder *builder, OperationState *result, - SSAValue *source, Type *destType) { + SSAValue *source, Type destType) { result->addOperands(source); result->addTypes(destType); } bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType srcInfo; - Type *srcType, *dstType; + Type srcType, dstType; return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || parser->resolveOperand(srcInfo, srcType, result->operands) || parser->parseKeywordType("to", dstType) || @@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " - << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType(); + << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index e9c46d6ec5e..698089a1c67 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block, /// Create a new OperationStmt with the specific fields. 
OperationStmt *OperationStmt::create(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(), @@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) - return getResult(0)->getType()->getContext(); + return getResult(0)->getType().getContext(); if (getNumOperands()) - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); // In the very odd case where we have no operands or results, fall back to // doing a find. @@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const { if (operands.empty()) return findFunction()->getContext(); - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); } //===----------------------------------------------------------------------===// @@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, operands.push_back(remapOperand(opValue)); if (auto *opStmt = dyn_cast<OperationStmt>(this)) { - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; resultTypes.reserve(opStmt->getNumResults()); for (auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h new file mode 100644 index 00000000000..c22e87a283e --- /dev/null +++ b/mlir/lib/IR/TypeDetail.h @@ -0,0 +1,126 @@ +//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of Type. +// +//===----------------------------------------------------------------------===// +#ifndef TYPEDETAIL_H_ +#define TYPEDETAIL_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +class AffineMap; +class MLIRContext; + +namespace detail { + +/// Base storage class appearing in a Type. +struct alignas(8) TypeStorage { + TypeStorage(Type::Kind kind, MLIRContext *context) + : context(context), kind(kind), subclassData(0) {} + TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData) + : context(context), kind(kind), subclassData(subclassData) {} + + unsigned getSubclassData() const { return subclassData; } + + void setSubclassData(unsigned val) { + subclassData = val; + // Ensure we don't have any accidental truncation. + assert(getSubclassData() == val && "Subclass data too large for field"); + } + + /// This refers to the MLIRContext in which this type was uniqued. + MLIRContext *const context; + + /// Classification of the subclass, used for type checking. 
+ Type::Kind kind : 8; + + /// Space for subclasses to store data. + unsigned subclassData : 24; +}; + +struct IntegerTypeStorage : public TypeStorage { + unsigned width; +}; + +struct FloatTypeStorage : public TypeStorage {}; + +struct OtherTypeStorage : public TypeStorage {}; + +struct FunctionTypeStorage : public TypeStorage { + ArrayRef<Type> getInputs() const { + return ArrayRef<Type>(inputsAndResults, subclassData); + } + ArrayRef<Type> getResults() const { + return ArrayRef<Type>(inputsAndResults + subclassData, numResults); + } + + unsigned numResults; + Type const *inputsAndResults; +}; + +struct VectorOrTensorTypeStorage : public TypeStorage { + Type elementType; +}; + +struct VectorTypeStorage : public VectorOrTensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct TensorTypeStorage : public VectorOrTensorTypeStorage {}; + +struct RankedTensorTypeStorage : public TensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct UnrankedTensorTypeStorage : public TensorTypeStorage {}; + +struct MemRefTypeStorage : public TypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + ArrayRef<AffineMap> getAffineMaps() const { + return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + } + + /// The type of each scalar element of the memref. + Type elementType; + /// An array of integers which stores the shape dimension sizes. + const int *shapeElements; + /// The number of affine maps in the 'affineMapList' array. + const unsigned numAffineMaps; + /// List of affine maps in the memref's layout/index map composition. + AffineMap const *affineMapList; + /// Memory space in which data referenced by memref resides. + const unsigned memorySpace; +}; + +} // namespace detail +} // namespace mlir +#endif // TYPEDETAIL_H_ diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 0ad3f4728fe..1a716956608 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -16,10 +16,17 @@ // ============================================================================= #include "mlir/IR/Types.h" +#include "TypeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/STLExtras.h" #include "llvm/Support/raw_ostream.h" + using namespace mlir; +using namespace mlir::detail; + +Type::Kind Type::getKind() const { return type->kind; } + +MLIRContext *Type::getContext() const { return type->context; } unsigned Type::getBitWidth() const { switch (getKind()) { @@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const { case Type::Kind::F64: return 64; case Type::Kind::Integer: - return cast<IntegerType>(this)->getWidth(); + return cast<IntegerType>().getWidth(); case Type::Kind::Vector: case Type::Kind::RankedTensor: case Type::Kind::UnrankedTensor: - return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth(); + return cast<VectorOrTensorType>().getElementType().getBitWidth(); // TODO: Handle more types. 
default: llvm_unreachable("unexpected type"); } } -IntegerType::IntegerType(unsigned width, MLIRContext *context) - : Type(Kind::Integer, context), width(width) { - assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); +unsigned Type::getSubclassData() const { return type->getSubclassData(); } +void Type::setSubclassData(unsigned val) { type->setSubclassData(val); } + +IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {} + +unsigned IntegerType::getWidth() const { + return static_cast<ImplType *>(type)->width; } -FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {} + +OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {} -OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {} -FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, - unsigned numResults, MLIRContext *context) - : Type(Kind::Function, context, numInputs), numResults(numResults), - inputsAndResults(inputsAndResults) {} +ArrayRef<Type> FunctionType::getInputs() const { + return static_cast<ImplType *>(type)->getInputs(); +} + +unsigned FunctionType::getNumResults() const { + return static_cast<ImplType *>(type)->numResults; +} + +ArrayRef<Type> FunctionType::getResults() const { + return static_cast<ImplType *>(type)->getResults(); +} -VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, - Type *elementType, unsigned subClassData) - : Type(kind, context, subClassData), elementType(elementType) {} +VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {} + +Type VectorOrTensorType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} unsigned VectorOrTensorType::getNumElements() const { switch (getKind()) { @@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const { ArrayRef<int> VectorOrTensorType::getShape() const { switch (getKind()) { case Kind::Vector: - return cast<VectorType>(this)->getShape(); + return cast<VectorType>().getShape(); case Kind::RankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); case Kind::UnrankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); default: llvm_unreachable("not a VectorOrTensorType"); } @@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const { return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); } -VectorType::VectorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()), - shapeElements(shape.data()) {} +VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} -TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) - : VectorOrTensorType(kind, context, elementType) { - assert(isValidTensorElementType(elementType)); +ArrayRef<int> VectorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : TensorType(Kind::RankedTensor, elementType, context), - shapeElements(shape.data()) { - setSubclassData(shape.size()); +TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} + +RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +ArrayRef<int> 
RankedTensorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) - : TensorType(Kind::UnrankedTensor, elementType, context) {} +UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {} -MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapList, unsigned memorySpace, - MLIRContext *context) - : Type(Kind::MemRef, context, shape.size()), elementType(elementType), - shapeElements(shape.data()), numAffineMaps(affineMapList.size()), - affineMapList(affineMapList.data()), memorySpace(memorySpace) {} +ArrayRef<int> MemRefType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); +} + +Type MemRefType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} ArrayRef<AffineMap> MemRefType::getAffineMaps() const { - return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + return static_cast<ImplType *>(type)->getAffineMaps(); +} + +unsigned MemRefType::getMemorySpace() const { + return static_cast<ImplType *>(type)->memorySpace; } unsigned MemRefType::getNumDynamicDims() const { |
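
The MLIRContext.cpp hunks above all follow the same uniquing scheme: each `get` factory looks up a storage object keyed by the type's parameters and allocates it on first use, so two value handles built from the same parameters wrap the same pointer. Below is a minimal, self-contained sketch of that scheme, using `std::map` and `std::deque` in place of MLIR's DenseMap and bump-pointer allocator; the names are illustrative, not MLIR's.

```cpp
#include <cassert>
#include <deque>
#include <map>

// Sketch of per-context type uniquing: one storage object per bit width,
// owned by the context, so equal parameters yield pointer-equal storage.
struct IntegerTypeStorage {
  unsigned width;
};

class Context {
public:
  IntegerTypeStorage *getIntegerStorage(unsigned width) {
    auto it = integerTypes.find(width);
    if (it != integerTypes.end())
      return it->second;
    // First use: allocate the storage and cache it for later lookups.
    storageArena.push_back(IntegerTypeStorage{width});
    return integerTypes[width] = &storageArena.back();
  }

private:
  std::deque<IntegerTypeStorage> storageArena; // stable element addresses
  std::map<unsigned, IntegerTypeStorage *> integerTypes;
};

int main() {
  Context ctx;
  auto *a = ctx.getIntegerStorage(32);
  auto *b = ctx.getIntegerStorage(32);
  auto *c = ctx.getIntegerStorage(64);
  assert(a == b); // same parameters -> same uniqued storage
  assert(a != c);
  return 0;
}
```

With storage uniqued this way, the pointer comparison used by the value handles' equality operator is equivalent to structural equality of the types.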

