diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/IR/MLIRContext.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardTypes.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/IR/TypeDetail.h | 34 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 28 | ||||
| -rw-r--r-- | mlir/lib/Parser/TokenKinds.def | 1 |
7 files changed, 94 insertions, 1 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4454f69b7fc..b9ca89dfb73 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -805,6 +805,13 @@ void ModulePrinter::printType(Type type) { os << '>'; return; } + case StandardTypes::Tuple: { + auto tuple = type.cast<TupleType>(); + os << "tuple<"; + interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); + os << '>'; + return; + } } } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 56d0ad059fa..6f1936b951f 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -96,6 +96,10 @@ UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } +TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) { + return TupleType::get(elementTypes, context); +} + //===----------------------------------------------------------------------===// // Attributes. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 140dfa6b3eb..7ea757229bb 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -105,7 +105,8 @@ namespace { struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType, - VectorType, RankedTensorType, UnrankedTensorType, MemRefType>(); + VectorType, RankedTensorType, UnrankedTensorType, MemRefType, + TupleType>(); } }; diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 0d46aa59e05..b9da3b92285 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -366,3 +366,21 @@ unsigned MemRefType::getMemorySpace() const { unsigned MemRefType::getNumDynamicDims() const { return llvm::count_if(getShape(), [](int64_t i) { return i < 0; }); } + +/// TupleType + +/// Get or create a new TupleType with the provided element types. Assumes the +/// arguments define a well-formed type. +TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) { + return Base::get(context, StandardTypes::Tuple, elementTypes); +} + +/// Return the elements types for this tuple. +ArrayRef<Type> TupleType::getTypes() const { + return static_cast<ImplType *>(type)->getTypes(); +} + +/// Return the number of element types. +unsigned TupleType::size() const { + return static_cast<ImplType *>(type)->size(); +} diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 91762df53d6..c55f7956334 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -26,6 +26,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +#include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -255,6 +256,39 @@ struct MemRefTypeStorage : public TypeStorage { const unsigned memorySpace; }; +/// A type representing a collection of other types. +struct TupleTypeStorage final + : public TypeStorage, + public llvm::TrailingObjects<TupleTypeStorage, Type> { + using KeyTy = ArrayRef<Type>; + + TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {} + + /// Construction. + static TupleTypeStorage *construct(TypeStorageAllocator &allocator, + const ArrayRef<Type> &key) { + // Allocate a new storage instance. + auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size()); + auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage)); + auto result = ::new (rawMem) TupleTypeStorage(key.size()); + + // Copy in the element types into the trailing storage. + std::uninitialized_copy(key.begin(), key.end(), + result->getTrailingObjects<Type>()); + return result; + } + + bool operator==(const KeyTy &key) const { return key == getTypes(); } + + /// Return the number of held types. + unsigned size() const { return getSubclassData(); } + + /// Return the held types. + ArrayRef<Type> getTypes() const { + return {getTrailingObjects<Type>(), size()}; + } +}; + } // namespace detail } // namespace mlir #endif // TYPEDETAIL_H_ diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index d01727395a9..4905beebf90 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -185,6 +185,7 @@ public: bool allowDynamic); Type parseExtendedType(); Type parseTensorType(); + Type parseTupleType(); Type parseMemRefType(); Type parseFunctionType(); Type parseNonFunctionType(); @@ -319,6 +320,7 @@ ParseResult Parser::parseCommaSeparatedListUntil( /// | vector-type /// | tensor-type /// | memref-type +/// | tuple-type /// /// index-type ::= `index` /// float-type ::= `f16` | `bf16` | `f32` | `f64` @@ -331,6 +333,8 @@ Type Parser::parseNonFunctionType() { return parseMemRefType(); case Token::kw_tensor: return parseTensorType(); + case Token::kw_tuple: + return parseTupleType(); case Token::kw_vector: return parseVectorType(); // integer-type @@ -567,6 +571,30 @@ Type Parser::parseTensorType() { return RankedTensorType::getChecked(dimensions, elementType, typeLocation); } +/// Parse a tuple type. +/// +/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` +/// +Type Parser::parseTupleType() { + consumeToken(Token::kw_tuple); + + // Parse the '<'. + if (parseToken(Token::less, "expected '<' in tuple type")) + return nullptr; + + // Check for an empty tuple by directly parsing '>'. + if (consumeIf(Token::greater)) + return TupleType::get(getContext()); + + // Parse the element types and the '>'. + SmallVector<Type, 4> types; + if (parseTypeListNoParens(types) || + parseToken(Token::greater, "expected '>' in tuple type")) + return nullptr; + + return TupleType::get(types, getContext()); +} + /// Parse a memref type. /// /// memref-type ::= `memref` `<` dimension-list-ranked element-type diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index ec00f98b3f5..f58fa9cef41 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -111,6 +111,7 @@ TOK_KEYWORD(step) TOK_KEYWORD(tensor) TOK_KEYWORD(to) TOK_KEYWORD(true) +TOK_KEYWORD(tuple) TOK_KEYWORD(type) TOK_KEYWORD(sparse) TOK_KEYWORD(vector) |

