summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h19
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp104
2 files changed, 63 insertions, 60 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 8faa90cb134..306f2b9f309 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -30,21 +30,23 @@
namespace mlir {
-/// Converts a function type according to the requirements of a SPIR-V entry
-/// function. The arguments need to be converted to spv.GlobalVariables of
-/// spv.ptr types so that they could be bound by the runtime.
+/// Type conversion from stdandard types to SPIR-V types for shader interface.
+///
+/// For composite types, this converter additionally performs type wrapping to
+/// satisfy shader interface requirements: shader interface types must be
+/// pointers to structs.
class SPIRVTypeConverter final : public TypeConverter {
public:
using TypeConverter::TypeConverter;
- /// Converts types to SPIR-V types using the basic type converter.
- Type convertType(Type t) override;
+ /// Converts the given standard `type` to SPIR-V correspondance.
+ Type convertType(Type type) override;
- /// Gets the index type equivalent in SPIR-V.
- Type getIndexType(MLIRContext *context);
+ /// Gets the SPIR-V correspondance for the standard index type.
+ static Type getIndexType(MLIRContext *context);
};
-/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
template <typename SourceOp>
class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
public:
@@ -54,7 +56,6 @@ public:
typeConverter(typeConverter) {}
protected:
- /// Type lowering class.
SPIRVTypeConverter &typeConverter;
};
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 3c571add56a..baa9ed305aa 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -68,8 +68,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
// Type Conversion
//===----------------------------------------------------------------------===//
-namespace {
-Type convertIndexType(MLIRContext *context) {
+Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
// Convert to 32-bit integers for now. Might need a way to control this in
// future.
// TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
@@ -82,7 +81,7 @@ Type convertIndexType(MLIRContext *context) {
// TODO(ravishankarm): This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
-Optional<int64_t> getTypeNumBytes(Type t) {
+static Optional<int64_t> getTypeNumBytes(Type t) {
if (auto integerType = t.dyn_cast<IntegerType>()) {
return integerType.getWidth() / 8;
} else if (auto floatType = t.dyn_cast<FloatType>()) {
@@ -92,17 +91,17 @@ Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}
-Type typeConversionImpl(Type t) {
- // Check if the type is SPIR-V supported. If so return the type.
- if (spirv::SPIRVDialect::isValidType(t)) {
- return t;
+static Type convertStdType(Type type) {
+ // If the type is already valid in SPIR-V, directly return.
+ if (spirv::SPIRVDialect::isValidType(type)) {
+ return type;
}
- if (auto indexType = t.dyn_cast<IndexType>()) {
- return convertIndexType(t.getContext());
+ if (auto indexType = type.dyn_cast<IndexType>()) {
+ return SPIRVTypeConverter::getIndexType(type.getContext());
}
- if (auto memRefType = t.dyn_cast<MemRefType>()) {
+ if (auto memRefType = type.dyn_cast<MemRefType>()) {
// TODO(ravishankarm): For now only support default memory space. The memory
// space description is not set is stone within MLIR, i.e. it depends on the
// context it is being used. To map this to SPIR-V storage classes, we
@@ -111,60 +110,65 @@ Type typeConversionImpl(Type t) {
if (memRefType.getMemorySpace()) {
return Type();
}
- auto elementType = typeConversionImpl(memRefType.getElementType());
+
+ auto elementType = convertStdType(memRefType.getElementType());
if (!elementType) {
return Type();
}
+
auto elementSize = getTypeNumBytes(elementType);
if (!elementSize) {
return Type();
}
- // TODO(ravishankarm) : Handle dynamic shapes.
- if (memRefType.hasStaticShape()) {
- // Get the strides and offset
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
- offset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
- // TODO(ravishankarm) : Handle dynamic strides and offsets.
- return Type();
- }
- // Convert to a multi-dimensional spv.array if size is known.
- auto shape = memRefType.getShape();
- assert(shape.size() == strides.size());
- for (int i = shape.size(); i > 0; --i) {
- elementType = spirv::ArrayType::get(
- elementType, shape[i - 1], strides[i - 1] * elementSize.getValue());
- }
- // For the offset, need to wrap the array in a struct.
- auto structType =
- spirv::StructType::get(elementType, offset * elementSize.getValue());
- // For now initialize the storage class to StorageBuffer. This will be
- // updated later based on whats passed in w.r.t to the ABI attributes.
- return spirv::PointerType::get(structType,
- spirv::StorageClass::StorageBuffer);
+
+ if (!memRefType.hasStaticShape()) {
+ // TODO(ravishankarm) : Handle dynamic shapes.
+ return Type();
}
+
+ // Get the strides and offset.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
+ offset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+ // TODO(ravishankarm) : Handle dynamic strides and offsets.
+ return Type();
+ }
+
+ // Convert to a multi-dimensional spv.array if size is known.
+ auto shape = memRefType.getShape();
+ assert(shape.size() == strides.size());
+ Type arrayType = elementType;
+ // TODO(antiagainst): Introduce layout as part of the shader ABI to have
+ // better separate of concerns.
+ for (int i = shape.size(); i > 0; --i) {
+ arrayType = spirv::ArrayType::get(
+ arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
+ }
+
+ // For the offset, need to wrap the array in a struct.
+ auto structType =
+ spirv::StructType::get(arrayType, offset * elementSize.getValue());
+ // For now initialize the storage class to StorageBuffer. This will be
+ // updated later based on whats passed in w.r.t to the ABI attributes.
+ return spirv::PointerType::get(structType,
+ spirv::StorageClass::StorageBuffer);
}
+
return Type();
}
-} // namespace
-
-Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }
-Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
- return convertType(IndexType::get(context));
-}
+Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
-namespace {
/// Look through all global variables in `moduleOp` and check if there is a
/// spv.globalVariable that has the same `builtin` attribute.
-spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
- spirv::BuiltIn builtin) {
+static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+ spirv::BuiltIn builtin) {
for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
stringifyDecoration(spirv::Decoration::BuiltIn)))) {
@@ -178,15 +182,14 @@ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
}
/// Gets name of global variable for a buitlin.
-std::string getBuiltinVarName(spirv::BuiltIn builtin) {
+static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
}
/// Gets or inserts a global variable for a builtin within a module.
-spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
- Location loc,
- spirv::BuiltIn builtin,
- OpBuilder &builder) {
+static spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
+ spirv::BuiltIn builtin, OpBuilder &builder) {
if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
return varOp;
}
@@ -217,7 +220,6 @@ spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
builder.restoreInsertionPoint(ip);
return newVarOp;
}
-} // namespace
/// Gets the global variable associated with a builtin and add
/// it if it doesnt exist.
OpenPOWER on IntegriCloud