diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h | 19 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 104 |
2 files changed, 63 insertions, 60 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index 8faa90cb134..306f2b9f309 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -30,21 +30,23 @@ namespace mlir { -/// Converts a function type according to the requirements of a SPIR-V entry -/// function. The arguments need to be converted to spv.GlobalVariables of -/// spv.ptr types so that they could be bound by the runtime. +/// Type conversion from stdandard types to SPIR-V types for shader interface. +/// +/// For composite types, this converter additionally performs type wrapping to +/// satisfy shader interface requirements: shader interface types must be +/// pointers to structs. class SPIRVTypeConverter final : public TypeConverter { public: using TypeConverter::TypeConverter; - /// Converts types to SPIR-V types using the basic type converter. - Type convertType(Type t) override; + /// Converts the given standard `type` to SPIR-V correspondance. + Type convertType(Type type) override; - /// Gets the index type equivalent in SPIR-V. - Type getIndexType(MLIRContext *context); + /// Gets the SPIR-V correspondance for the standard index type. + static Type getIndexType(MLIRContext *context); }; -/// Base class to define a conversion pattern to translate Ops into SPIR-V. +/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. template <typename SourceOp> class SPIRVOpLowering : public OpConversionPattern<SourceOp> { public: @@ -54,7 +56,6 @@ public: typeConverter(typeConverter) {} protected: - /// Type lowering class. SPIRVTypeConverter &typeConverter; }; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 3c571add56a..baa9ed305aa 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -68,8 +68,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, // Type Conversion //===----------------------------------------------------------------------===// -namespace { -Type convertIndexType(MLIRContext *context) { +Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { // Convert to 32-bit integers for now. Might need a way to control this in // future. // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To @@ -82,7 +81,7 @@ Type convertIndexType(MLIRContext *context) { // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. -Optional<int64_t> getTypeNumBytes(Type t) { +static Optional<int64_t> getTypeNumBytes(Type t) { if (auto integerType = t.dyn_cast<IntegerType>()) { return integerType.getWidth() / 8; } else if (auto floatType = t.dyn_cast<FloatType>()) { @@ -92,17 +91,17 @@ Optional<int64_t> getTypeNumBytes(Type t) { return llvm::None; } -Type typeConversionImpl(Type t) { - // Check if the type is SPIR-V supported. If so return the type. - if (spirv::SPIRVDialect::isValidType(t)) { - return t; +static Type convertStdType(Type type) { + // If the type is already valid in SPIR-V, directly return. + if (spirv::SPIRVDialect::isValidType(type)) { + return type; } - if (auto indexType = t.dyn_cast<IndexType>()) { - return convertIndexType(t.getContext()); + if (auto indexType = type.dyn_cast<IndexType>()) { + return SPIRVTypeConverter::getIndexType(type.getContext()); } - if (auto memRefType = t.dyn_cast<MemRefType>()) { + if (auto memRefType = type.dyn_cast<MemRefType>()) { // TODO(ravishankarm): For now only support default memory space. The memory // space description is not set is stone within MLIR, i.e. it depends on the // context it is being used. To map this to SPIR-V storage classes, we @@ -111,60 +110,65 @@ Type typeConversionImpl(Type t) { if (memRefType.getMemorySpace()) { return Type(); } - auto elementType = typeConversionImpl(memRefType.getElementType()); + + auto elementType = convertStdType(memRefType.getElementType()); if (!elementType) { return Type(); } + auto elementSize = getTypeNumBytes(elementType); if (!elementSize) { return Type(); } - // TODO(ravishankarm) : Handle dynamic shapes. - if (memRefType.hasStaticShape()) { - // Get the strides and offset - int64_t offset; - SmallVector<int64_t, 4> strides; - if (failed(getStridesAndOffset(memRefType, strides, offset)) || - offset == MemRefType::getDynamicStrideOrOffset() || - llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { - // TODO(ravishankarm) : Handle dynamic strides and offsets. - return Type(); - } - // Convert to a multi-dimensional spv.array if size is known. - auto shape = memRefType.getShape(); - assert(shape.size() == strides.size()); - for (int i = shape.size(); i > 0; --i) { - elementType = spirv::ArrayType::get( - elementType, shape[i - 1], strides[i - 1] * elementSize.getValue()); - } - // For the offset, need to wrap the array in a struct. - auto structType = - spirv::StructType::get(elementType, offset * elementSize.getValue()); - // For now initialize the storage class to StorageBuffer. This will be - // updated later based on whats passed in w.r.t to the ABI attributes. - return spirv::PointerType::get(structType, - spirv::StorageClass::StorageBuffer); + + if (!memRefType.hasStaticShape()) { + // TODO(ravishankarm) : Handle dynamic shapes. + return Type(); } + + // Get the strides and offset. + int64_t offset; + SmallVector<int64_t, 4> strides; + if (failed(getStridesAndOffset(memRefType, strides, offset)) || + offset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + // TODO(ravishankarm) : Handle dynamic strides and offsets. + return Type(); + } + + // Convert to a multi-dimensional spv.array if size is known. + auto shape = memRefType.getShape(); + assert(shape.size() == strides.size()); + Type arrayType = elementType; + // TODO(antiagainst): Introduce layout as part of the shader ABI to have + // better separate of concerns. + for (int i = shape.size(); i > 0; --i) { + arrayType = spirv::ArrayType::get( + arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue()); + } + + // For the offset, need to wrap the array in a struct. + auto structType = + spirv::StructType::get(arrayType, offset * elementSize.getValue()); + // For now initialize the storage class to StorageBuffer. This will be + // updated later based on whats passed in w.r.t to the ABI attributes. + return spirv::PointerType::get(structType, + spirv::StorageClass::StorageBuffer); } + return Type(); } -} // namespace - -Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); } -Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { - return convertType(IndexType::get(context)); -} +Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); } //===----------------------------------------------------------------------===// // Builtin Variables //===----------------------------------------------------------------------===// -namespace { /// Look through all global variables in `moduleOp` and check if there is a /// spv.globalVariable that has the same `builtin` attribute. -spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, - spirv::BuiltIn builtin) { +static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, + spirv::BuiltIn builtin) { for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) { if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase( stringifyDecoration(spirv::Decoration::BuiltIn)))) { @@ -178,15 +182,14 @@ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, } /// Gets name of global variable for a buitlin. -std::string getBuiltinVarName(spirv::BuiltIn builtin) { +static std::string getBuiltinVarName(spirv::BuiltIn builtin) { return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; } /// Gets or inserts a global variable for a builtin within a module. -spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, - Location loc, - spirv::BuiltIn builtin, - OpBuilder &builder) { +static spirv::GlobalVariableOp +getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, + spirv::BuiltIn builtin, OpBuilder &builder) { if (auto varOp = getBuiltinVariable(moduleOp, builtin)) { return varOp; } @@ -217,7 +220,6 @@ spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, builder.restoreInsertionPoint(ip); return newVarOp; } -} // namespace /// Gets the global variable associated with a builtin and add /// it if it doesnt exist. |

