summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp7
-rw-r--r--mlir/lib/IR/Builders.cpp4
-rw-r--r--mlir/lib/IR/MLIRContext.cpp3
-rw-r--r--mlir/lib/IR/StandardTypes.cpp18
-rw-r--r--mlir/lib/IR/TypeDetail.h34
-rw-r--r--mlir/lib/Parser/Parser.cpp28
-rw-r--r--mlir/lib/Parser/TokenKinds.def1
7 files changed, 94 insertions, 1 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4454f69b7fc..b9ca89dfb73 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -805,6 +805,13 @@ void ModulePrinter::printType(Type type) {
os << '>';
return;
}
+ case StandardTypes::Tuple: {
+ auto tuple = type.cast<TupleType>();
+ os << "tuple<";
+ interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
+ os << '>';
+ return;
+ }
}
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 56d0ad059fa..6f1936b951f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -96,6 +96,10 @@ UnrankedTensorType Builder::getTensorType(Type elementType) {
return UnrankedTensorType::get(elementType);
}
+TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
+ return TupleType::get(elementTypes, context);
+}
+
//===----------------------------------------------------------------------===//
// Attributes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 140dfa6b3eb..7ea757229bb 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -105,7 +105,8 @@ namespace {
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) {
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
- VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
+ VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
+ TupleType>();
}
};
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 0d46aa59e05..b9da3b92285 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -366,3 +366,21 @@ unsigned MemRefType::getMemorySpace() const {
unsigned MemRefType::getNumDynamicDims() const {
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
}
+
+/// TupleType
+
+/// Get or create a new TupleType with the provided element types. Assumes the
+/// arguments define a well-formed type.
+TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
+ return Base::get(context, StandardTypes::Tuple, elementTypes);
+}
+
+/// Return the elements types for this tuple.
+ArrayRef<Type> TupleType::getTypes() const {
+ return static_cast<ImplType *>(type)->getTypes();
+}
+
+/// Return the number of element types.
+unsigned TupleType::size() const {
+ return static_cast<ImplType *>(type)->size();
+}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 91762df53d6..c55f7956334 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -26,6 +26,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
+#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@@ -255,6 +256,39 @@ struct MemRefTypeStorage : public TypeStorage {
const unsigned memorySpace;
};
+/// A type representing a collection of other types.
+struct TupleTypeStorage final
+ : public TypeStorage,
+ public llvm::TrailingObjects<TupleTypeStorage, Type> {
+ using KeyTy = ArrayRef<Type>;
+
+ TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
+
+ /// Construction.
+ static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
+ const ArrayRef<Type> &key) {
+ // Allocate a new storage instance.
+ auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
+ auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));
+ auto result = ::new (rawMem) TupleTypeStorage(key.size());
+
+ // Copy in the element types into the trailing storage.
+ std::uninitialized_copy(key.begin(), key.end(),
+ result->getTrailingObjects<Type>());
+ return result;
+ }
+
+ bool operator==(const KeyTy &key) const { return key == getTypes(); }
+
+ /// Return the number of held types.
+ unsigned size() const { return getSubclassData(); }
+
+ /// Return the held types.
+ ArrayRef<Type> getTypes() const {
+ return {getTrailingObjects<Type>(), size()};
+ }
+};
+
} // namespace detail
} // namespace mlir
#endif // TYPEDETAIL_H_
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d01727395a9..4905beebf90 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -185,6 +185,7 @@ public:
bool allowDynamic);
Type parseExtendedType();
Type parseTensorType();
+ Type parseTupleType();
Type parseMemRefType();
Type parseFunctionType();
Type parseNonFunctionType();
@@ -319,6 +320,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
/// | vector-type
/// | tensor-type
/// | memref-type
+/// | tuple-type
///
/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
@@ -331,6 +333,8 @@ Type Parser::parseNonFunctionType() {
return parseMemRefType();
case Token::kw_tensor:
return parseTensorType();
+ case Token::kw_tuple:
+ return parseTupleType();
case Token::kw_vector:
return parseVectorType();
// integer-type
@@ -567,6 +571,30 @@ Type Parser::parseTensorType() {
return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
}
+/// Parse a tuple type.
+///
+/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
+///
+Type Parser::parseTupleType() {
+ consumeToken(Token::kw_tuple);
+
+ // Parse the '<'.
+ if (parseToken(Token::less, "expected '<' in tuple type"))
+ return nullptr;
+
+ // Check for an empty tuple by directly parsing '>'.
+ if (consumeIf(Token::greater))
+ return TupleType::get(getContext());
+
+ // Parse the element types and the '>'.
+ SmallVector<Type, 4> types;
+ if (parseTypeListNoParens(types) ||
+ parseToken(Token::greater, "expected '>' in tuple type"))
+ return nullptr;
+
+ return TupleType::get(types, getContext());
+}
+
/// Parse a memref type.
///
/// memref-type ::= `memref` `<` dimension-list-ranked element-type
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index ec00f98b3f5..f58fa9cef41 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -111,6 +111,7 @@ TOK_KEYWORD(step)
TOK_KEYWORD(tensor)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
+TOK_KEYWORD(tuple)
TOK_KEYWORD(type)
TOK_KEYWORD(sparse)
TOK_KEYWORD(vector)
OpenPOWER on IntegriCloud