summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2018-10-30 14:59:22 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 13:45:54 -0700
commit4c465a181db49c436f62da303e8fdd3ed317fee7 (patch)
treefb190912d0714222d6e336e19d5b8ea16342fb6e /mlir/lib/IR
parent75376b8e33c67a42e3dca2c597197e0622b6eaa2 (diff)
downloadbcm5719-llvm-4c465a181db49c436f62da303e8fdd3ed317fee7.tar.gz
bcm5719-llvm-4c465a181db49c436f62da303e8fdd3ed317fee7.zip
Implement value type abstraction for types.
This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast. PiperOrigin-RevId: 219372163
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp91
-rw-r--r--mlir/lib/IR/AttributeDetail.h5
-rw-r--r--mlir/lib/IR/Attributes.cpp16
-rw-r--r--mlir/lib/IR/BasicBlock.cpp6
-rw-r--r--mlir/lib/IR/Builders.cpp55
-rw-r--r--mlir/lib/IR/BuiltinOps.cpp38
-rw-r--r--mlir/lib/IR/Function.cpp16
-rw-r--r--mlir/lib/IR/Instructions.cpp4
-rw-r--r--mlir/lib/IR/MLIRContext.cpp207
-rw-r--r--mlir/lib/IR/Operation.cpp26
-rw-r--r--mlir/lib/IR/Statement.cpp10
-rw-r--r--mlir/lib/IR/TypeDetail.h126
-rw-r--r--mlir/lib/IR/Types.cpp101
13 files changed, 431 insertions, 270 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 454a28a6558..cb5e96f0086 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -122,7 +122,7 @@ private:
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
- void visitType(const Type *type);
+ void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const Operation *op);
@@ -135,16 +135,16 @@ private:
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
-void ModuleState::visitType(const Type *type) {
- if (auto *funcType = dyn_cast<FunctionType>(type)) {
+void ModuleState::visitType(Type type) {
+ if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
- for (auto *input : funcType->getInputs())
+ for (auto input : funcType.getInputs())
visitType(input);
- for (auto *result : funcType->getResults())
+ for (auto result : funcType.getResults())
visitType(result);
- } else if (auto *memref = dyn_cast<MemRefType>(type)) {
+ } else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
- for (auto map : memref->getAffineMaps()) {
+ for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map);
}
}
@@ -271,7 +271,7 @@ public:
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(Attribute attr);
- void printType(const Type *type);
+ void printType(Type type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
@@ -290,7 +290,7 @@ protected:
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
- void printFunctionResultType(const FunctionType *type);
+ void printFunctionResultType(FunctionType type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
@@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
- auto *type = attr.getType();
- auto shape = type->getShape();
- auto rank = type->getRank();
+ auto type = attr.getType();
+ auto shape = type.getShape();
+ auto rank = type.getRank();
SmallVector<Attribute, 16> elements;
attr.getValues(elements);
@@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
os << ']';
}
-void ModulePrinter::printType(const Type *type) {
- switch (type->getKind()) {
+void ModulePrinter::printType(Type type) {
+ switch (type.getKind()) {
case Type::Kind::Index:
os << "index";
return;
@@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
return;
case Type::Kind::Integer: {
- auto *integer = cast<IntegerType>(type);
- os << 'i' << integer->getWidth();
+ auto integer = type.cast<IntegerType>();
+ os << 'i' << integer.getWidth();
return;
}
case Type::Kind::Function: {
- auto *func = cast<FunctionType>(type);
+ auto func = type.cast<FunctionType>();
os << '(';
- interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
+ interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> ";
- auto results = func->getResults();
+ auto results = func.getResults();
if (results.size() == 1)
- os << *results[0];
+ os << results[0];
else {
os << '(';
- interleaveComma(results, [&](Type *type) { printType(type); });
+ interleaveComma(results, [&](Type type) { printType(type); });
os << ')';
}
return;
}
case Type::Kind::Vector: {
- auto *v = cast<VectorType>(type);
+ auto v = type.cast<VectorType>();
os << "vector<";
- for (auto dim : v->getShape())
+ for (auto dim : v.getShape())
os << dim << 'x';
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::RankedTensor: {
- auto *v = cast<RankedTensorType>(type);
+ auto v = type.cast<RankedTensorType>();
os << "tensor<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::UnrankedTensor: {
- auto *v = cast<UnrankedTensorType>(type);
+ auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x";
- printType(v->getElementType());
+ printType(v.getElementType());
os << '>';
return;
}
case Type::Kind::MemRef: {
- auto *v = cast<MemRefType>(type);
+ auto v = type.cast<MemRefType>();
os << "memref<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- printType(v->getElementType());
- for (auto map : v->getAffineMaps()) {
+ printType(v.getElementType());
+ for (auto map : v.getAffineMaps()) {
os << ", ";
printAffineMapReference(map);
}
// Only print the memory space if it is the non-default one.
- if (v->getMemorySpace())
- os << ", " << v->getMemorySpace();
+ if (v.getMemorySpace())
+ os << ", " << v.getMemorySpace();
os << '>';
return;
}
@@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
// Function printing
//===----------------------------------------------------------------------===//
-void ModulePrinter::printFunctionResultType(const FunctionType *type) {
- switch (type->getResults().size()) {
+void ModulePrinter::printFunctionResultType(FunctionType type) {
+ switch (type.getResults().size()) {
case 0:
break;
case 1:
os << " -> ";
- printType(type->getResults()[0]);
+ printType(type.getResults()[0]);
break;
default:
os << " -> (";
- interleaveComma(type->getResults(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getResults(),
+ [&](Type eltType) { printType(eltType); });
os << ')';
break;
}
@@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleaveComma(type->getInputs(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
@@ -937,7 +936,7 @@ public:
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
- void printType(const Type *type) { ModulePrinter::printType(type); }
+ void printType(Type type) { ModulePrinter::printType(type); }
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
@@ -974,10 +973,10 @@ protected:
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
- if (intOp->getType()->isInteger(1)) {
+ if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
- specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
+ specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
@@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
void Type::print(raw_ostream &os) const {
ModuleState state(getContext());
- ModulePrinter(os, state).printType(this);
+ ModulePrinter(os, state).printType(*this);
}
void Type::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index a0e9afb4fd3..63ad544fa48 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
- Type *value;
+ Type value;
};
/// An attribute representing a reference to a function.
@@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage {
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
- VectorOrTensorType *type;
+ VectorOrTensorType type;
};
/// An attribute representing a reference to a vector or tensor constant,
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 34312b84a0b..58b5b90d43d 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const {
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
-Type *TypeAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
@@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
-FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
-VectorOrTensorType *ElementsAttr::getType() const {
+VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
@@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
- auto elementNum = getType()->getNumElements();
- auto context = getType()->getContext();
+ auto elementNum = getType().getNumElements();
+ auto context = getType().getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
@@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
- auto elementNum = getType()->getNumElements();
- auto context = getType()->getContext();
+ auto elementNum = getType().getNumElements();
+ auto context = getType().getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp
index bb8ac75d91a..29a5ce12e4a 100644
--- a/mlir/lib/IR/BasicBlock.cpp
+++ b/mlir/lib/IR/BasicBlock.cpp
@@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() {
// Argument list management.
//===----------------------------------------------------------------------===//
-BBArgument *BasicBlock::addArgument(Type *type) {
+BBArgument *BasicBlock::addArgument(Type type) {
auto *arg = new BBArgument(type, this);
arguments.push_back(arg);
return arg;
}
/// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type *> types)
+auto BasicBlock::addArguments(ArrayRef<Type> types)
-> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
- for (auto *type : types) {
+ for (auto type : types) {
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 22d749a6c8c..906b580d9af 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename,
// Types.
//===----------------------------------------------------------------------===//
-FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
+FloatType Builder::getBF16Type() { return Type::getBF16(context); }
-FloatType *Builder::getF16Type() { return Type::getF16(context); }
+FloatType Builder::getF16Type() { return Type::getF16(context); }
-FloatType *Builder::getF32Type() { return Type::getF32(context); }
+FloatType Builder::getF32Type() { return Type::getF32(context); }
-FloatType *Builder::getF64Type() { return Type::getF64(context); }
+FloatType Builder::getF64Type() { return Type::getF64(context); }
-OtherType *Builder::getIndexType() { return Type::getIndex(context); }
+OtherType Builder::getIndexType() { return Type::getIndex(context); }
-OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
-OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
+OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
-OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
-OtherType *Builder::getTFComplex64Type() {
+OtherType Builder::getTFComplex64Type() {
return Type::getTFComplex64(context);
}
-OtherType *Builder::getTFComplex128Type() {
+OtherType Builder::getTFComplex128Type() {
return Type::getTFComplex128(context);
}
-OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
+OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
-OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+OtherType Builder::getTFStringType() { return Type::getTFString(context); }
-IntegerType *Builder::getIntegerType(unsigned width) {
+IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}
-FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
- ArrayRef<Type *> results) {
+FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
+ ArrayRef<Type> results) {
return FunctionType::get(inputs, results, context);
}
-MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace) {
+MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
+ ArrayRef<AffineMap> affineMapComposition,
+ unsigned memorySpace) {
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
-VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
+VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
return VectorType::get(shape, elementType);
}
-RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
- Type *elementType) {
+RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
return RankedTensorType::get(shape, elementType);
}
-UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+UnrankedTensorType Builder::getTensorType(Type elementType) {
return UnrankedTensorType::get(elementType);
}
@@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
return IntegerSetAttr::get(set);
}
-TypeAttr Builder::getTypeAttr(Type *type) {
+TypeAttr Builder::getTypeAttr(Type type) {
return TypeAttr::get(type, context);
}
@@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}
-ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
-ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
-ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values);
}
-ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
}
@@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
OperationStmt *MLFuncBuilder::createOperation(Location *location,
OperationName name,
ArrayRef<MLValue *> operands,
- ArrayRef<Type *> types,
+ ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs) {
auto *op = OperationStmt::create(location, name, operands, types, attrs,
getContext());
diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp
index 542e67eaefd..e4bca037c4e 100644
--- a/mlir/lib/IR/BuiltinOps.cpp
+++ b/mlir/lib/IR/BuiltinOps.cpp
@@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
numDims = opInfos.size();
// Parse the optional symbol operands.
- auto *affineIntTy = parser->getBuilder().getIndexType();
+ auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
@@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
- auto *affineIntTy = builder.getIndexType();
+ auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
@@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result,
- Attribute value, Type *type) {
+ Attribute value, Type type) {
result->addAttribute("value", value);
result->types.push_back(type);
}
@@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const {
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!getValue().isa<FunctionAttr>())
- *p << " : " << *getType();
+ *p << " : " << getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
- Type *type;
+ Type type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes))
@@ -208,33 +208,33 @@ bool ConstantOp::verify() const {
if (!value)
return emitOpError("requires a 'value' attribute");
- auto *type = this->getType();
- if (isa<IntegerType>(type) || type->isIndex()) {
+ auto type = this->getType();
+ if (type.isa<IntegerType>() || type.isIndex()) {
if (!value.isa<IntegerAttr>())
return emitOpError(
"requires 'value' to be an integer for an integer result type");
return false;
}
- if (isa<FloatType>(type)) {
+ if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
- if (isa<VectorOrTensorType>(type)) {
+ if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
- if (type->isTFString()) {
+ if (type.isTFString()) {
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");
return false;
}
- if (isa<FunctionType>(type)) {
+ if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
@@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
- const APFloat &value, FloatType *type) {
+ const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- isa<FloatType>(op->getResult(0)->getType());
+ op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- isa<IntegerType>(op->getResult(0)->getType());
+ op->getResult(0)->getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState *result,
@@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result,
- int64_t value, Type *type) {
- assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
+ int64_t value, Type type) {
+ assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
}
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Operation *op) {
- return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
+ return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState *result,
@@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result,
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type *, 2> types;
+ SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
@@ -330,7 +330,7 @@ bool ReturnOp::verify() const {
// The operand number and types must match the function signature.
MLFunction *function = cast<MLFunction>(block);
- const auto &results = function->getType()->getResults();
+ const auto &results = function->getType().getResults();
if (stmt->getNumOperands() != results.size())
return emitOpError("has " + Twine(stmt->getNumOperands()) +
" operands, but enclosing function returns " +
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index efeb16b61db..70c0e1259b3 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -28,8 +28,8 @@
using namespace mlir;
Function::Function(Kind kind, Location *location, StringRef name,
- FunctionType *type, ArrayRef<NamedAttribute> attrs)
- : nameAndKind(Identifier::get(name, type->getContext()), kind),
+ FunctionType type, ArrayRef<NamedAttribute> attrs)
+ : nameAndKind(Identifier::get(name, type.getContext()), kind),
location(location), type(type) {
this->attrs = AttributeListStorage::get(attrs, getContext());
}
@@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
return {};
}
-MLIRContext *Function::getContext() const { return getType()->getContext(); }
+MLIRContext *Function::getContext() const { return getType().getContext(); }
/// Delete this object.
void Function::destroy() {
@@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const {
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
-ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::ExtFunc, location, name, type, attrs) {}
@@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
-CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::CFGFunc, location, name, type, attrs) {}
@@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() {
/// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location *location, StringRef name,
- FunctionType *type,
+ FunctionType type,
ArrayRef<NamedAttribute> attrs) {
- const auto &argTypes = type->getInputs();
+ const auto &argTypes = type.getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
@@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
return function;
}
-MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp
index 422636bf2e3..d2f49ddfc6e 100644
--- a/mlir/lib/IR/Instructions.cpp
+++ b/mlir/lib/IR/Instructions.cpp
@@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const {
/// Create a new OperationInst with the specified fields.
OperationInst *OperationInst::create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands,
- ArrayRef<Type *> resultTypes,
+ ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name,
OperationInst *OperationInst::clone() const {
SmallVector<CFGValue *, 8> operands;
- SmallVector<Type *, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
// Put together the operands and results.
for (auto *operand : getOperands())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 0a2e9416842..8811f7b9f78 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -21,6 +21,7 @@
#include "AttributeDetail.h"
#include "AttributeListStorage.h"
#include "IntegerSetDetail.h"
+#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@@ -44,11 +45,11 @@ using namespace mlir::detail;
using namespace llvm;
namespace {
-struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
+struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
// Functions are uniqued based on their inputs and results.
- using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
- using DenseMapInfo<FunctionType *>::getHashValue;
- using DenseMapInfo<FunctionType *>::isEqual;
+ using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+ using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
+ using DenseMapInfo<FunctionTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
@@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
}
};
-struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
+struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
// Vectors are uniqued based on their element type and shape.
- using KeyTy = std::pair<Type *, ArrayRef<int>>;
- using DenseMapInfo<VectorType *>::getHashValue;
- using DenseMapInfo<VectorType *>::isEqual;
+ using KeyTy = std::pair<Type, ArrayRef<int>>;
+ using DenseMapInfo<VectorTypeStorage *>::getHashValue;
+ using DenseMapInfo<VectorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(key.first),
+ DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+ return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
-struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
// Ranked tensors are uniqued based on their element type and shape.
- using KeyTy = std::pair<Type *, ArrayRef<int>>;
- using DenseMapInfo<RankedTensorType *>::getHashValue;
- using DenseMapInfo<RankedTensorType *>::isEqual;
+ using KeyTy = std::pair<Type, ArrayRef<int>>;
+ using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
+ using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(key.first),
+ DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+ return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
-struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
+struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
// MemRefs are uniqued based on their element type, shape, affine map
// composition, and memory space.
- using KeyTy =
- std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
- using DenseMapInfo<MemRefType *>::getHashValue;
- using DenseMapInfo<MemRefType *>::isEqual;
+ using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
+ using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
+ using DenseMapInfo<MemRefTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
+ DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
std::get<3>(key));
}
- static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
- rhs->getAffineMaps(), rhs->getMemorySpace());
+ return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
+ rhs->getAffineMaps(), rhs->memorySpace);
}
};
@@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
};
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
- using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
+ using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
@@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
- using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
+ using KeyTy = std::pair<VectorOrTensorType, StringRef>;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
@@ -295,13 +295,14 @@ public:
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
// Uniquing table for 'other' types.
- OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
- int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
+ OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
+ int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
+ nullptr};
// Uniquing table for 'float' types.
- FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
- int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
- nullptr};
+ FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
+ int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
+ {nullptr};
// Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@@ -324,26 +325,26 @@ public:
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
/// Integer type uniquing.
- DenseMap<unsigned, IntegerType *> integers;
+ DenseMap<unsigned, IntegerTypeStorage *> integers;
/// Function type uniquing.
- using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
+ using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
FunctionTypeSet functions;
/// Vector type uniquing.
- using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
+ using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
VectorTypeSet vectors;
/// Ranked tensor type uniquing.
using RankedTensorTypeSet =
- DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
+ DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
RankedTensorTypeSet rankedTensors;
/// Unranked tensor type uniquing.
- DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
+ DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
/// MemRef type uniquing.
- using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
+ using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
MemRefTypeSet memrefs;
// Attribute uniquing.
@@ -355,13 +356,12 @@ public:
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
- DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
+ DenseMap<Type, TypeAttributeStorage *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
- DenseMap<std::pair<VectorOrTensorType *, Attribute>,
- SplatElementsAttributeStorage *>
+ DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
splatElementsAttrs;
using DenseElementsAttrSet =
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@@ -369,7 +369,7 @@ public:
using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs;
- DenseMap<std::tuple<Type *, Attribute, Attribute>,
+ DenseMap<std::tuple<Type, Attribute, Attribute>,
SparseElementsAttributeStorage *>
sparseElementsAttrs;
@@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line,
// Type uniquing
//===----------------------------------------------------------------------===//
-IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
+IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
+ assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
auto &impl = context->getImpl();
auto *&result = impl.integers[width];
if (!result) {
- result = impl.allocator.Allocate<IntegerType>();
- new (result) IntegerType(width, context);
+ result = impl.allocator.Allocate<IntegerTypeStorage>();
+ new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
}
return result;
}
-FloatType *FloatType::get(Kind kind, MLIRContext *context) {
+FloatType FloatType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
auto &impl = context->getImpl();
@@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) {
return entry;
// On the first use, we allocate them into the bump pointer.
- auto *ptr = impl.allocator.Allocate<FloatType>();
+ auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
// Initialize the memory using placement new.
- new (ptr) FloatType(kind, context);
+ new (ptr) FloatTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
-OtherType *OtherType::get(Kind kind, MLIRContext *context) {
+OtherType OtherType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
"Not an 'other' type kind");
auto &impl = context->getImpl();
@@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) {
return entry;
// On the first use, we allocate them into the bump pointer.
- auto *ptr = impl.allocator.Allocate<OtherType>();
+ auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
// Initialize the memory using placement new.
- new (ptr) OtherType(kind, context);
+ new (ptr) OtherTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
-FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
- ArrayRef<Type *> results,
- MLIRContext *context) {
+FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+ MLIRContext *context) {
auto &impl = context->getImpl();
// Look to see if we already have this function type.
@@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<FunctionType>();
+ auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
// Copy the inputs and results into the bump pointer.
- SmallVector<Type *, 16> types;
+ SmallVector<Type, 16> types;
types.reserve(inputs.size() + results.size());
types.append(inputs.begin(), inputs.end());
types.append(results.begin(), results.end());
- auto typesList = impl.copyInto(ArrayRef<Type *>(types));
+ auto typesList = impl.copyInto(ArrayRef<Type>(types));
// Initialize the memory using placement new.
- new (result)
- FunctionType(typesList.data(), inputs.size(), results.size(), context);
+ new (result) FunctionTypeStorage{
+ {Kind::Function, context, static_cast<unsigned int>(inputs.size())},
+ static_cast<unsigned int>(results.size()),
+ typesList.data()};
// Cache and return it.
return *existing.first = result;
}
-VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
+VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
- assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
+ assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
"vectors elements must be primitives");
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
return i < 0;
}) && "vector types must have static shape");
- auto *context = elementType->getContext();
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this vector type.
@@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<VectorType>();
+ auto *result = impl.allocator.Allocate<VectorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
- new (result) VectorType(shape, elementType, context);
+ new (result) VectorTypeStorage{
+ {{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
+ elementType},
+ shape.data()};
// Cache and return it.
return *existing.first = result;
}
-RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
- Type *elementType) {
- auto *context = elementType->getContext();
+RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this ranked tensor type.
@@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<RankedTensorType>();
+ auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
- new (result) RankedTensorType(shape, elementType, context);
+ new (result) RankedTensorTypeStorage{
+ {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
+ elementType}},
+ shape.data()};
// Cache and return it.
return *existing.first = result;
}
-UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
- auto *context = elementType->getContext();
+UnrankedTensorType UnrankedTensorType::get(Type elementType) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
@@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
return result;
// On the first use, we allocate them into the bump pointer.
- result = impl.allocator.Allocate<UnrankedTensorType>();
+ result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
// Initialize the memory using placement new.
- new (result) UnrankedTensorType(elementType, context);
+ new (result) UnrankedTensorTypeStorage{
+ {{{Kind::UnrankedTensor, context}, elementType}}};
return result;
}
-MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace) {
- auto *context = elementType->getContext();
+MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
+ ArrayRef<AffineMap> affineMapComposition,
+ unsigned memorySpace) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Drop the unbounded identity maps from the composition.
@@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<MemRefType>();
+ auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
@@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
// Initialize the memory using placement new.
- new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
- context);
+ new (result) MemRefTypeStorage{
+ {Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
+ elementType,
+ shape.data(),
+ static_cast<unsigned int>(affineMapComposition.size()),
+ affineMapComposition.data(),
+ memorySpace};
// Cache and return it.
return *existing.first = result;
}
@@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
return result;
}
-TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
+TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type];
if (result)
return result;
@@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
@@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
return result;
}
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) {
- auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
+ auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
(void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires");
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
DenseElementsAttrInfo::KeyTy key({type, data});
@@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
- auto *eltType = type->getElementType();
- switch (eltType->getKind()) {
+ auto eltType = type.getElementType();
+ switch (eltType.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
@@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result;
}
case Type::Kind::Integer: {
- auto width = ::cast<IntegerType>(eltType)->getWidth();
+ auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
@@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
}
}
-OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
+OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
StringRef bytes) {
- assert(isValidTensorElementType(type->getElementType()) &&
+ assert(isValidTensorElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result;
}
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto key = std::make_tuple(type, indices, values);
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 2ed09b83b53..0722421c8ba 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
}
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
- auto *type = op->getResult(0)->getType();
+ auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type)
return op->emitOpError(
@@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
/// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified.
-static Type *getTensorOrVectorElementType(Type *type) {
- if (auto *vec = dyn_cast<VectorType>(type))
- return vec->getElementType();
+static Type getTensorOrVectorElementType(Type type) {
+ if (auto vec = type.dyn_cast<VectorType>())
+ return vec.getElementType();
// Look through tensor<vector<...>> to find the underlying element type.
- if (auto *tensor = dyn_cast<TensorType>(type))
- return getTensorOrVectorElementType(tensor->getElementType());
+ if (auto tensor = type.dyn_cast<TensorType>())
+ return getTensorOrVectorElementType(tensor.getElementType());
return type;
}
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
for (auto *result : op->getResults()) {
- if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
+ if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
return op->emitOpError("requires a floating point type");
}
@@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
for (auto *result : op->getResults()) {
- if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
+ if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
return op->emitOpError("requires an integer type");
}
return false;
@@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result,
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
- Type *type;
+ Type type;
return parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
@@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
<< *op->getOperand(1);
p->printOptionalAttrDict(op->getAttrs());
- *p << " : " << *op->getResult(0)->getType();
+ *p << " : " << op->getResult(0)->getType();
}
//===----------------------------------------------------------------------===//
@@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
//===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result,
- SSAValue *source, Type *destType) {
+ SSAValue *source, Type destType) {
result->addOperands(source);
result->addTypes(destType);
}
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo;
- Type *srcType, *dstType;
+ Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
@@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
- << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
+ << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
}
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index e9c46d6ec5e..698089a1c67 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block,
/// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location *location, OperationName name,
ArrayRef<MLValue *> operands,
- ArrayRef<Type *> resultTypes,
+ ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const {
// If we have a result or operand type, that is a constant time way to get
// to the context.
if (getNumResults())
- return getResult(0)->getType()->getContext();
+ return getResult(0)->getType().getContext();
if (getNumOperands())
- return getOperand(0)->getType()->getContext();
+ return getOperand(0)->getType().getContext();
// In the very odd case where we have no operands or results, fall back to
// doing a find.
@@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const {
if (operands.empty())
return findFunction()->getContext();
- return getOperand(0)->getType()->getContext();
+ return getOperand(0)->getType().getContext();
}
//===----------------------------------------------------------------------===//
@@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
operands.push_back(remapOperand(opValue));
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
- SmallVector<Type *, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
new file mode 100644
index 00000000000..c22e87a283e
--- /dev/null
+++ b/mlir/lib/IR/TypeDetail.h
@@ -0,0 +1,126 @@
+//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Type.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TYPEDETAIL_H_
+#define TYPEDETAIL_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLIRContext;
+
+namespace detail {
+
+/// Base storage class appearing in a Type.
+struct alignas(8) TypeStorage {
+ TypeStorage(Type::Kind kind, MLIRContext *context)
+ : context(context), kind(kind), subclassData(0) {}
+ TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
+ : context(context), kind(kind), subclassData(subclassData) {}
+
+ unsigned getSubclassData() const { return subclassData; }
+
+ void setSubclassData(unsigned val) {
+ subclassData = val;
+ // Ensure we don't have any accidental truncation.
+ assert(getSubclassData() == val && "Subclass data too large for field");
+ }
+
+ /// This refers to the MLIRContext in which this type was uniqued.
+ MLIRContext *const context;
+
+ /// Classification of the subclass, used for type checking.
+ Type::Kind kind : 8;
+
+ /// Space for subclasses to store data.
+ unsigned subclassData : 24;
+};
+
+struct IntegerTypeStorage : public TypeStorage {
+ unsigned width;
+};
+
+struct FloatTypeStorage : public TypeStorage {};
+
+struct OtherTypeStorage : public TypeStorage {};
+
+struct FunctionTypeStorage : public TypeStorage {
+ ArrayRef<Type> getInputs() const {
+ return ArrayRef<Type>(inputsAndResults, subclassData);
+ }
+ ArrayRef<Type> getResults() const {
+ return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
+ }
+
+ unsigned numResults;
+ Type const *inputsAndResults;
+};
+
+struct VectorOrTensorTypeStorage : public TypeStorage {
+ Type elementType;
+};
+
+struct VectorTypeStorage : public VectorOrTensorTypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ const int *shapeElements;
+};
+
+struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
+
+struct RankedTensorTypeStorage : public TensorTypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ const int *shapeElements;
+};
+
+struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
+
+struct MemRefTypeStorage : public TypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ ArrayRef<AffineMap> getAffineMaps() const {
+ return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+ }
+
+ /// The type of each scalar element of the memref.
+ Type elementType;
+ /// An array of integers which stores the shape dimension sizes.
+ const int *shapeElements;
+ /// The number of affine maps in the 'affineMapList' array.
+ const unsigned numAffineMaps;
+ /// List of affine maps in the memref's layout/index map composition.
+ AffineMap const *affineMapList;
+ /// Memory space in which data referenced by memref resides.
+ const unsigned memorySpace;
+};
+
+} // namespace detail
+} // namespace mlir
+#endif // TYPEDETAIL_H_
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 0ad3f4728fe..1a716956608 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -16,10 +16,17 @@
// =============================================================================
#include "mlir/IR/Types.h"
+#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
+
using namespace mlir;
+using namespace mlir::detail;
+
+Type::Kind Type::getKind() const { return type->kind; }
+
+MLIRContext *Type::getContext() const { return type->context; }
unsigned Type::getBitWidth() const {
switch (getKind()) {
@@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const {
case Type::Kind::F64:
return 64;
case Type::Kind::Integer:
- return cast<IntegerType>(this)->getWidth();
+ return cast<IntegerType>().getWidth();
case Type::Kind::Vector:
case Type::Kind::RankedTensor:
case Type::Kind::UnrankedTensor:
- return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
+ return cast<VectorOrTensorType>().getElementType().getBitWidth();
// TODO: Handle more types.
default:
llvm_unreachable("unexpected type");
}
}
-IntegerType::IntegerType(unsigned width, MLIRContext *context)
- : Type(Kind::Integer, context), width(width) {
- assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+ return static_cast<ImplType *>(type)->width;
}
-FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
+
+OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
-OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
-FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
- unsigned numResults, MLIRContext *context)
- : Type(Kind::Function, context, numInputs), numResults(numResults),
- inputsAndResults(inputsAndResults) {}
+ArrayRef<Type> FunctionType::getInputs() const {
+ return static_cast<ImplType *>(type)->getInputs();
+}
+
+unsigned FunctionType::getNumResults() const {
+ return static_cast<ImplType *>(type)->numResults;
+}
+
+ArrayRef<Type> FunctionType::getResults() const {
+ return static_cast<ImplType *>(type)->getResults();
+}
-VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
- Type *elementType, unsigned subClassData)
- : Type(kind, context, subClassData), elementType(elementType) {}
+VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
+
+Type VectorOrTensorType::getElementType() const {
+ return static_cast<ImplType *>(type)->elementType;
+}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
@@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
- return cast<VectorType>(this)->getShape();
+ return cast<VectorType>().getShape();
case Kind::RankedTensor:
- return cast<RankedTensorType>(this)->getShape();
+ return cast<RankedTensorType>().getShape();
case Kind::UnrankedTensor:
- return cast<RankedTensorType>(this)->getShape();
+ return cast<RankedTensorType>().getShape();
default:
llvm_unreachable("not a VectorOrTensorType");
}
@@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const {
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
-VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
- MLIRContext *context)
- : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
- shapeElements(shape.data()) {}
+VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
- : VectorOrTensorType(kind, context, elementType) {
- assert(isValidTensorElementType(elementType));
+ArrayRef<int> VectorType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
}
-RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
- MLIRContext *context)
- : TensorType(Kind::RankedTensor, elementType, context),
- shapeElements(shape.data()) {
- setSubclassData(shape.size());
+TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
+
+RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
+
+ArrayRef<int> RankedTensorType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
}
-UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
- : TensorType(Kind::UnrankedTensor, elementType, context) {}
+UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
+
+MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
-MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
- MLIRContext *context)
- : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
- shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
- affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
+ArrayRef<int> MemRefType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
+}
+
+Type MemRefType::getElementType() const {
+ return static_cast<ImplType *>(type)->elementType;
+}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
- return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+ return static_cast<ImplType *>(type)->getAffineMaps();
+}
+
+unsigned MemRefType::getMemorySpace() const {
+ return static_cast<ImplType *>(type)->memorySpace;
}
unsigned MemRefType::getNumDynamicDims() const {
OpenPOWER on IntegriCloud