//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" using namespace mlir; LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : llvmDialect(ctx->getRegisteredDialect()) { assert(llvmDialect && "LLVM IR dialect is not registered"); module = &llvmDialect->getLLVMModule(); } // Get the LLVM context. 
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
  // The context is obtained from the cached LLVM module.
  return module->getContext();
}

// Extract an LLVM IR type from the LLVM IR dialect type. A null `type` yields
// a null result. If the `dyn_cast` below fails, an error is emitted at an
// unknown location and the falsy cast result is returned to the caller.
// NOTE(review): `dyn_cast()` appears without a template argument here; the
// `<...>` argument lists look to have been stripped from this text (see also
// the `SmallVector` declarations below) - restore from upstream.
LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  auto wrappedLLVMType = type.dyn_cast();
  if (!wrappedLLVMType)
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return wrappedLLVMType;
}

// Get the LLVM integer type whose bit width is the pointer size reported by
// the data layout of the cached LLVM module.
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
  return LLVM::LLVMType::getIntNTy(
      llvmDialect, module->getDataLayout().getPointerSizeInBits());
}

// The `index` type converts to the pointer-sized integer type; the argument
// itself is not inspected.
Type LLVMTypeConverter::convertIndexType(IndexType type) {
  return getIndexType();
}

// An integer type converts to the LLVM integer type of the same bit width.
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}

// F32, F64 and F16 convert to the LLVM float, double and half types
// respectively. BF16 emits an error at an unknown location and yields a null
// type. Any other kind is treated as unreachable.
Type LLVMTypeConverter::convertFloatType(FloatType type) {
  switch (type.getKind()) {
  case mlir::StandardTypes::F32:
    return LLVM::LLVMType::getFloatTy(llvmDialect);
  case mlir::StandardTypes::F64:
    return LLVM::LLVMType::getDoubleTy(llvmDialect);
  case mlir::StandardTypes::F16:
    return LLVM::LLVMType::getHalfTy(llvmDialect);
  case mlir::StandardTypes::BF16: {
    auto *mlirContext = llvmDialect->getContext();
    // Comma expression: report the error, then return a null type.
    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
           Type();
  }
  default:
    llvm_unreachable("non-float type in convertFloatType");
  }
}

// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
// The value returned is a pointer to the resulting LLVM function type. A null
// type is returned if an argument type or the result type fails to convert.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
  // Convert argument types one by one and check for errors.
  SmallVector argTypes;
  for (auto t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted)
      return {};
    argTypes.push_back(unwrap(converted));
  }

  // If function does not return anything, create the void result type,
  // if it returns one element, convert it, otherwise pack the result types into
  // a struct.
  LLVM::LLVMType resultType =
      type.getNumResults() == 0
          ? LLVM::LLVMType::getVoidTy(llvmDialect)
          : unwrap(packFunctionResults(type.getResults()));
  if (!resultType)
    return {};
  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
      .getPointerTo();
}

// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM structure type, where the first element of the structure type is a
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
// A null type is returned if the element type fails to convert.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto ptrType = elementType.getPointerTo();

  // One index-typed field is needed per dynamic dimension.
  unsigned numDynamicSizes = type.getNumDynamicDims();
  // If memref is statically-shaped we return the underlying pointer type.
  if (numDynamicSizes == 0)
    return ptrType;

  // Fill every field with the index type, then overwrite the first field with
  // the buffer pointer type.
  SmallVector types(numDynamicSizes + 1, getIndexType());
  types.front() = ptrType;

  return LLVM::LLVMType::getStructTy(llvmDialect, types);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
  auto elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  // The innermost dimension becomes the LLVM vector type.
  auto vectorType =
      LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
  auto shape = type.getShape();
  // Wrap the vector in one array type per remaining dimension, walking from
  // the second-innermost dimension outwards.
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
  return vectorType;
}

// Dispatch based on the actual type. Return null type on error.
// NOTE(review): every `dyn_cast()` below appears without its template
// argument, so which type each branch tests for can only be inferred from the
// local variable names; the `<...>` lists look to have been stripped from this
// text - restore from upstream.
Type LLVMTypeConverter::convertStandardType(Type type) {
  if (auto funcType = type.dyn_cast())
    return convertFunctionType(funcType);
  if (auto intType = type.dyn_cast())
    return convertIntegerType(intType);
  if (auto floatType = type.dyn_cast())
    return convertFloatType(floatType);
  if (auto indexType = type.dyn_cast())
    return convertIndexType(indexType);
  if (auto memRefType = type.dyn_cast())
    return convertMemRefType(memRefType);
  if (auto vectorType = type.dyn_cast())
    return convertVectorType(vectorType);
  if (auto llvmType = type.dyn_cast())
    return llvmType;

  return {};
}

// Convert the element type of the memref `t` to an LLVM type using
// `lowering`, get a pointer LLVM type pointing to the converted element type,
// wrap it into the MLIR LLVM dialect type and return. Returns a null type if
// the element type fails to convert.
static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
  auto elementType = t.getElementType();
  auto converted = lowering.convertType(elementType);
  if (!converted)
    return {};
  return converted.cast().getPointerTo();
}

// Construct a conversion pattern rooted at the operation named `rootOpName`,
// with benefit 1, keeping a reference to the type converter `lowering_`.
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                               LLVMTypeConverter &lowering_)
    : ConversionPattern(rootOpName, /*benefit=*/1, context),
      lowering(lowering_) {}

namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
// NOTE(review): the template parameter list after `template` is missing in
// this text; the body refers to a parameter named `SourceOp`.
template class LLVMLegalizationPattern : public LLVMOpLowering {
public:
  // Construct a conversion pattern. The root operation name is taken from
  // `SourceOp::getOperationName()` and the MLIR context from `dialect_`.
  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
                       lowering_),
        dialect(dialect_) {}

  // Get the LLVM IR dialect.
  LLVM::LLVMDialect &getDialect() const { return dialect; }
  // Get the LLVM context.
  llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
  // Get the LLVM module in which the types are constructed.
  llvm::Module &getModule() const { return dialect.getLLVMModule(); }

  // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
  // by the pointer size used in the LLVM module.
  LLVM::LLVMType getIndexType() const {
    return LLVM::LLVMType::getIntNTy(
        &dialect, getModule().getDataLayout().getPointerSizeInBits());
  }

  // Get the MLIR type wrapping the LLVM i8* type.
  LLVM::LLVMType getVoidPtrType() const {
    return LLVM::LLVMType::getInt8PtrTy(&dialect);
  }

  // Create an LLVM IR pseudo-operation defining the given index constant.
  // The attribute carries `value` with the builder's index type; the created
  // operation has the result type returned by getIndexType().
  Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                             uint64_t value) const {
    auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
    return builder.create(loc, getIndexType(), attr);
  }

  // Get the array attribute named "position" containing the given list of
  // integers as integer attribute elements. Each element is an integer
  // attribute of the builder's index type, in the order given by `values`.
  static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
                                       ArrayRef values) {
    SmallVector attrs;
    attrs.reserve(values.size());
    for (int64_t pos : values)
      attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
    return builder.getArrayAttr(attrs);
  }

  // Extract raw data pointer value from a value representing a memref.
  // When `hasStaticShape` is true the converted memref value is returned as
  // is; otherwise an operation is created that reads position 0 of
  // `convertedMemRefValue` with result type `elementTypePtr`.
  // NOTE(review): `buffer` is never assigned and the trailing `return buffer;`
  // cannot be reached because both branches of the `if` return - candidates
  // for removal.
  static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
                                        Location loc,
                                        Value *convertedMemRefValue,
                                        Type elementTypePtr,
                                        bool hasStaticShape) {
    Value *buffer;
    if (hasStaticShape)
      return convertedMemRefValue;
    else
      return builder.create(
          loc, elementTypePtr, convertedMemRefValue,
          getIntegerArrayAttr(builder, 0));
    return buffer;
  }

protected:
  LLVM::LLVMDialect &dialect;
};

// Rewrites a function operation so that its signature uses converted types:
// each input is converted through the type converter, and the results (if any)
// are packed into a single type. Fails to match if any argument or the packed
// result cannot be converted.
struct FuncOpConversion : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = cast(op);
    FunctionType type = funcOp.getType();

    // Convert the original function arguments.
    TypeConverter::SignatureConversion result(type.getNumInputs());
    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
      if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
        return matchFailure();

    // Pack the result types into a struct.
    Type packedResult;
    if (type.getNumResults() != 0) {
      if (!(packedResult = lowering.packFunctionResults(type.getResults())))
        return matchFailure();
    }

    // Create a new function with an updated signature. The body of the old
    // function is moved into the clone before the type is replaced.
    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    // The new type has either no result or the single packed result.
    newFuncOp.setType(FunctionType::get(
        result.getConvertedTypes(),
        packedResult ? ArrayRef(packedResult) : llvm::None,
        funcOp.getContext()));

    // Tell the rewriter to convert the region signature.
    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
    // The original function produces no values, so it is replaced by nothing.
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};

// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern {
  // NOTE(review): the template parameter list of this struct, the template
  // arguments of its base class and of `Super`, and the op type passed to
  // `rewriter.create(...)` are all missing in this text; the `<...>` lists
  // look to have been stripped - restore from upstream.
  using LLVMLegalizationPattern::LLVMLegalizationPattern;
  using Super = OneToOneLLVMOpLowering;

  // Convert the type of the result to an LLVM type, pass operands as is,
  // preserve attributes.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    unsigned numResults = op->getNumResults();

    // With at least one result, all result types are packed into one type.
    // `packedType` stays null for zero-result operations.
    Type packedType;
    if (numResults != 0) {
      packedType = this->lowering.packFunctionResults(
          llvm::to_vector<4>(op->getResultTypes()));
      assert(packedType && "type conversion failed, such operation should not "
                           "have been matched");
    }

    auto newOp = rewriter.create(op->getLoc(), packedType, operands,
                                 op->getAttrs());

    // If the operation produced 0 or 1 result, return them immediately.
    // (Comma expressions: perform the replacement, then report success.)
    if (numResults == 0)
      return rewriter.replaceOp(op, llvm::None), this->matchSuccess();
    if (numResults == 1)
      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
             this->matchSuccess();

    // Otherwise, it had been converted to an operation producing a structure.
    // Extract individual results from the structure and return them as list.
    SmallVector results;
    results.reserve(numResults);
    for (unsigned i = 0; i < numResults; ++i) {
      // Result `i` is read from position `i` of the single packed result,
      // using the converted type of the original i-th result.
      auto type = this->lowering.convertType(op->getResult(i)->getType());
      results.push_back(rewriter.create(
          op->getLoc(), type, newOp.getOperation()->getResult(0),
          this->getIntegerArrayAttr(rewriter, i)));
    }
    rewriter.replaceOp(op, results);
    return this->matchSuccess();
  }
};

// Specific lowerings.
// FIXME: this should be tablegen'ed.
// NOTE(review): each struct below derives from `OneToOneLLVMOpLowering`
// without template arguments in this text, so the source/target operation
// pairing of every lowering cannot be read from here - confirm each pairing
// against upstream before building.
struct AddIOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct SubIOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct MulIOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct DivISOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct DivIUOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct RemISOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct RemIUOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct AndOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct OrOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct XOrOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct AddFOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct SubFOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct MulFOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct DivFOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct RemFOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct SelectOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct CallOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct CallIndirectOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};
struct ConstLLVMOpLowering : public OneToOneLLVMOpLowering {
  using Super::Super;
};

// Check if the MemRefType `type` is supported by the lowering. We currently do
// not support memrefs with affine maps and non-default memory spaces.
// Returns true only when the type has no affine maps and its memory space
// is 0.
static bool isSupportedMemRefType(MemRefType type) {
  if (!type.getAffineMaps().empty())
    return false;
  if (type.getMemorySpace() != 0)
    return false;
  return true;
}

// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer.
// The memref descriptor is of the LLVM structure type where the first element
// is a pointer to the (typed) data buffer, and the remaining elements serve to
// store dynamic sizes of the memref using LLVM-converted `index` type.
// NOTE(review): throughout the patterns below, base-class template arguments,
// `cast()`/`dyn_cast()`/`isa()` targets, `ArrayRef`/`SmallVector` element
// types and the op types given to `rewriter.create(...)` /
// `replaceOpWithNewOp(...)` are missing in this text; the `<...>` lists look
// to have been stripped - restore from upstream. Comments naming an
// arithmetic step rely on the neighbouring original comments.
struct AllocOpLowering : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  // Match only allocations whose memref type is supported (see
  // isSupportedMemRefType).
  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast(op).getType();
    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef operands,
               ConversionPatternRewriter &rewriter) const override {
    auto allocOp = cast(op);
    MemRefType type = allocOp.getType();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
    // zero-dimensional memref, assume a scalar (size 1).
    // NOTE(review): `reserve` uses the operand count, while one entry is
    // pushed per dimension of the shape; reserving by rank looks intended.
    SmallVector sizes;
    auto numOperands = allocOp.getNumOperands();
    sizes.reserve(numOperands);
    // `i` walks the operands; it advances only on dynamic (-1) dimensions.
    unsigned i = 0;
    for (int64_t s : type.getShape())
      sizes.push_back(s == -1 ? operands[i++]
                              : createIndexConstant(rewriter, op->getLoc(), s));
    if (sizes.empty())
      sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));

    // Compute the total number of memref elements.
    Value *cumulativeSize = sizes.front();
    for (unsigned i = 1, e = sizes.size(); i < e; ++i)
      cumulativeSize = rewriter.create(
          op->getLoc(), getIndexType(), ArrayRef{cumulativeSize, sizes[i]});

    // Compute the total amount of bytes to allocate.
    auto elementType = type.getElementType();
    assert((elementType.isIntOrFloat() || elementType.isa()) &&
           "invalid memref element type");
    // Element size in bytes: bit width rounded up to whole bytes, multiplied
    // by the number of elements for the non-scalar case.
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    cumulativeSize = rewriter.create(
        op->getLoc(), getIndexType(),
        ArrayRef{
            cumulativeSize,
            createIndexConstant(rewriter, op->getLoc(), elementSize)});

    // Insert the `malloc` declaration if it is not already present.
    // The declared signature takes one index-typed value and returns the
    // void-pointer type.
    auto module = op->getParentOfType();
    FuncOp mallocFunc = module.lookupSymbol("malloc");
    if (!mallocFunc) {
      auto mallocType =
          rewriter.getFunctionType(getIndexType(), getVoidPtrType());
      mallocFunc =
          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
      module.push_back(mallocFunc);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Value *allocated =
        rewriter
            .create(op->getLoc(), getVoidPtrType(),
                    rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
            .getResult(0);
    // Convert the returned void pointer into a pointer to the converted
    // element type.
    auto structElementType = lowering.convertType(elementType);
    auto elementPtrType = structElementType.cast().getPointerTo();
    allocated = rewriter.create(op->getLoc(), elementPtrType,
                                ArrayRef(allocated));

    // Deal with static memrefs: with no dynamic-size operands the typed
    // pointer itself replaces the operation.
    if (numOperands == 0)
      return rewriter.replaceOp(op, allocated);

    // Create the MemRef descriptor.
    auto structType = lowering.convertType(type);
    Value *memRefDescriptor = rewriter.create(
        op->getLoc(), structType, ArrayRef{});

    // The buffer pointer goes to position 0 of the descriptor.
    memRefDescriptor = rewriter.create(
        op->getLoc(), structType, memRefDescriptor, allocated,
        getIntegerArrayAttr(rewriter, 0));

    // Store dynamically allocated sizes in the descriptor.  Dynamic sizes are
    // passed in as operands. Operand k is stored at position 1 + k.
    for (auto indexedSize : llvm::enumerate(operands)) {
      memRefDescriptor = rewriter.create(
          op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
          getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
    }

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, memRefDescriptor);
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    OperandAdaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    // The declared signature takes one void-pointer value and returns nothing.
    FuncOp freeFunc = op->getParentOfType().lookupSymbol("free");
    if (!freeFunc) {
      auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
      op->getParentOfType().push_back(freeFunc);
    }

    // The converted operand is either a bare pointer (treated as the
    // statically-shaped case) or a structure whose element 0 is the buffer
    // pointer; this is decided from the underlying LLVM type.
    auto type = transformed.memref()->getType().cast();
    auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
    Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
    Value *bufferPtr =
        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
                                elementPtrType, hasStaticShape);
    // Convert the buffer pointer to the void-pointer type expected by `free`.
    Value *casted = rewriter.create(
        op->getLoc(), getVoidPtrType(), bufferPtr);
    rewriter.replaceOpWithNewOp(
        op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return matchSuccess();
  }
};

// A memref cast is converted by carrying over the buffer pointer of the source
// and, when the target type has dynamic sizes, building a new descriptor whose
// dynamic sizes are copied from the source descriptor or materialized as
// constants.
struct MemRefCastOpLowering : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  // Match only when both the source and the target memref types are supported.
  PatternMatchResult match(Operation *op) const override {
    auto memRefCastOp = cast(op);
    MemRefType sourceType =
        memRefCastOp.getOperand()->getType().cast();
    MemRefType targetType = memRefCastOp.getType();
    return (isSupportedMemRefType(targetType) &&
            isSupportedMemRefType(sourceType))
               ? matchSuccess()
               : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef operands,
               ConversionPatternRewriter &rewriter) const override {
    auto memRefCastOp = cast(op);
    OperandAdaptor transformed(operands);
    auto targetType = memRefCastOp.getType();
    auto sourceType = memRefCastOp.getOperand()->getType().cast();

    // Copy the data buffer pointer.
    auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
    Value *buffer =
        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
                                elementTypePtr, sourceType.hasStaticShape());
    // Account for static memrefs as target types: the bare pointer is the
    // whole result.
    if (targetType.hasStaticShape())
      return rewriter.replaceOp(op, buffer);

    // Otherwise target type is dynamic memref, so create a proper descriptor.
    // Create the new MemRef descriptor.
    auto structType = lowering.convertType(targetType);
    Value *newDescriptor = rewriter.create(
        op->getLoc(), structType, ArrayRef{});
    // The buffer pointer goes to position 0 of the new descriptor.
    newDescriptor = rewriter.create(
        op->getLoc(), structType, newDescriptor, buffer,
        getIntegerArrayAttr(rewriter, 0));

    // Fill in the dynamic sizes of the new descriptor.  If the size was
    // dynamic, copy it from the old descriptor.  If the size was static, insert
    // the constant.  Note that the positions of dynamic sizes in the
    // descriptors start from 1 (the buffer pointer is at position zero).
    int64_t sourceDynamicDimIdx = 1;
    int64_t targetDynamicDimIdx = 1;
    for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
      // Ignore new static sizes (they will be known from the type).  If the
      // size was dynamic, update the index of dynamic types.
      if (targetType.getShape()[i] != -1) {
        if (sourceType.getShape()[i] == -1)
          ++sourceDynamicDimIdx;
        continue;
      }

      auto sourceSize = sourceType.getShape()[i];
      Value *size =
          sourceSize == -1
              ? rewriter.create(
                    op->getLoc(), getIndexType(),
                    transformed.source(), // NB: dynamic memref
                    getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
              : createIndexConstant(rewriter, op->getLoc(), sourceSize);
      newDescriptor = rewriter.create(
          op->getLoc(), structType, newDescriptor, size,
          getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
    }
    // Both counters started at 1, so after the loop each must have advanced
    // once per dynamic dimension of its type.
    assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
           "source dynamic dimensions were not processed");
    assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
           "target dynamic dimensions were not set up");

    rewriter.replaceOp(op, newDescriptor);
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  // Match only when the operand's memref type is supported.
  PatternMatchResult match(Operation *op) const override {
    auto dimOp = cast(op);
    MemRefType type = dimOp.getOperand()->getType().cast();
    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef operands,
               ConversionPatternRewriter &rewriter) const override {
    auto dimOp = cast(op);
    OperandAdaptor transformed(operands);
    MemRefType type = dimOp.getOperand()->getType().cast();

    auto shape = type.getShape();
    uint64_t index = dimOp.getIndex();
    // Extract dynamic size from the memref descriptor and define static size
    // as a constant.
    if (shape[index] == -1) {
      // Find the position of the dynamic dimension in the list of dynamic sizes
      // by counting the number of preceding dynamic dimensions.  Start from 1
      // because the buffer pointer is at position zero.
      int64_t position = 1;
      for (uint64_t i = 0; i < index; ++i) {
        if (shape[i] == -1)
          ++position;
      }
      rewriter.replaceOpWithNewOp(
          op, getIndexType(), transformed.memrefOrTensor(),
          getIntegerArrayAttr(rewriter, position));
    } else {
      rewriter.replaceOp(
          op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
    }
  }
};

// Common base for load and store operations on MemRefs.  Restricts the match
// to supported MemRef types.  Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template struct LoadStoreOpLowering : public LLVMLegalizationPattern {
  using LLVMLegalizationPattern::LLVMLegalizationPattern;
  using Base = LoadStoreOpLowering;

  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast(op).getMemRefType();
    return isSupportedMemRefType(type) ? this->matchSuccess()
                                       : this->matchFailure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  // Asserts that the two lists have equal length and are non-empty.
  Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                             ArrayRef indices,
                             ArrayRef allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value *linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      // Per the formula above, each step combines the running value first
      // with allocSizes[i] and then with indices[i]; allocSizes[0] is never
      // read. (The two op types are missing in this text.)
      linearized = builder.create(
          loc, this->getIndexType(),
          ArrayRef{linearized, allocSizes[i]});
      linearized = builder.create(
          loc, this->getIndexType(), ArrayRef{linearized, indices[i]});
    }
    return linearized;
  }

  // Given the MemRef type, a descriptor and a list of indices, extract the data
  // buffer pointer from the descriptor, convert multi-dimensional subscripts
  // into a linearized index (using dynamic size data from the descriptor if
  // necessary) and get the pointer to the buffer element identified by the
  // indices.
  Value *getElementPtr(Location loc, Type elementTypePtr,
                       ArrayRef shape, Value *memRefDescriptor,
                       ArrayRef indices,
                       ConversionPatternRewriter &rewriter) const {
    // Get the list of MemRef sizes.  Static sizes are defined as constants.
    // Dynamic sizes are extracted from the MemRef descriptor, where they start
    // from the position 1 (the buffer is at position 0).
    SmallVector sizes;
    unsigned dynamicSizeIdx = 1;
    for (int64_t s : shape) {
      if (s == -1) {
        Value *size = rewriter.create(
            loc, this->getIndexType(), memRefDescriptor,
            this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
        sizes.push_back(size);
      } else {
        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
      }
    }

    // The second and subsequent operands are access subscripts.  Obtain the
    // linearized address in the buffer.
    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);

    // Read the buffer pointer from position 0 of the descriptor, then offset
    // it by the linearized subscript.
    Value *dataPtr = rewriter.create(
        loc, elementTypePtr, memRefDescriptor,
        this->getIntegerArrayAttr(rewriter, 0));
    return rewriter.create(loc, elementTypePtr,
                           ArrayRef{dataPtr, subscript},
                           ArrayRef{});
  }

  // This is a getElementPtr variant, where the value is a direct raw pointer.
  // If a shape is empty, we are dealing with a zero-dimensional memref. Return
  // the pointer unmodified in this case.  Otherwise, linearize subscripts to
  // obtain the offset with respect to the base pointer.  Use this offset to
  // compute and return the element pointer.
  Value *getRawElementPtr(Location loc, Type elementTypePtr,
                          ArrayRef shape, Value *rawDataPtr,
                          ArrayRef indices,
                          ConversionPatternRewriter &rewriter) const {
    if (shape.empty())
      return rawDataPtr;

    // Every size is materialized as a constant; no descriptor is read here.
    SmallVector sizes;
    for (int64_t s : shape) {
      sizes.push_back(this->createIndexConstant(rewriter, loc, s));
    }

    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
    return rewriter.create(
        loc, elementTypePtr, ArrayRef{rawDataPtr, subscript},
        ArrayRef{});
  }

  // Compute the pointer to the element of `dataPtr` addressed by `indices`,
  // choosing the raw-pointer path for statically-shaped memrefs and the
  // descriptor path otherwise.
  // NOTE(review): the `module` parameter is not used in this body.
  Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
                    ArrayRef indices,
                    ConversionPatternRewriter &rewriter,
                    llvm::Module &module) const {
    auto ptrType = getMemRefElementPtrType(type, this->lowering);
    auto shape = type.getShape();
    if (type.hasStaticShape()) {
      // NB: If memref was statically-shaped, dataPtr is pointer to raw data.
      return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
    }
    return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering {
  // NOTE(review): the base-class template argument, the `cast()` target, the
  // `OperandAdaptor` argument and the op type given to `replaceOpWithNewOp`
  // are missing in this text; the `<...>` lists look to have been stripped -
  // restore from upstream.
  using Base::Base;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loadOp = cast(op);
    OperandAdaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    // Address of the accessed element, computed from the converted memref
    // operand and the converted indices.
    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
                                transformed.indices(), rewriter, getModule());
    // The replacement produces a value of the converted element type.
    auto elementType = lowering.convertType(type.getElementType());

    rewriter.replaceOpWithNewOp(op, elementType,
                                ArrayRef{dataPtr});
    return matchSuccess();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering {
  using Base::Base;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(op).getMemRefType();
    OperandAdaptor transformed(operands);

    // Address of the accessed element; the converted value operand is then
    // written through it by the replacement operation.
    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
                                transformed.indices(), rewriter, getModule());
    rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr);
    return matchSuccess();
  }
};

// The lowering of index_cast becomes an integer conversion since index becomes
// an integer.  If the bit width of the source and target integer types is the
// same, just erase the cast.  If the target type is wider, sign-extend the
// value, otherwise truncate it.
struct IndexCastOpLowering : public LLVMLegalizationPattern {
  // NOTE(review): the base-class template argument, the `cast()` targets and
  // the op types given to the two `replaceOpWithNewOp` calls are missing in
  // this text; the `<...>` lists look to have been stripped. Which call
  // truncates and which extends is taken from the comment preceding this
  // struct - confirm against upstream.
  using LLVMLegalizationPattern::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    IndexCastOpOperandAdaptor transformed(operands);
    auto indexCastOp = cast(op);

    // Target type: converted result type of the original operation.
    // Source type: type of the already-converted input operand.
    auto targetType =
        this->lowering.convertType(indexCastOp.getResult()->getType())
            .cast();
    auto sourceType = transformed.in()->getType().cast();
    unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
    unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();

    if (targetBits == sourceBits)
      // Same width: forward the converted input unchanged.
      rewriter.replaceOp(op, transformed.in());
    else if (targetBits < sourceBits)
      // Narrower target.
      rewriter.replaceOpWithNewOp(op, targetType, transformed.in());
    else
      // Wider target.
      rewriter.replaceOpWithNewOp(op, targetType, transformed.in());
    return matchSuccess();
  }
};

// Convert std.cmp predicate into the LLVM dialect CmpPredicate.  The two
// enums share the numerical values so just cast.
template static LLVMPredType convertCmpPredicate(StdPredType pred) { return static_cast(pred); } struct CmpIOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, lowering.convertType(cmpiOp.getResult()->getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return matchSuccess(); } }; struct CmpFOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, lowering.convertType(cmpfOp.getResult()->getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return matchSuccess(); } }; struct SIToFPLowering : public OneToOneLLVMOpLowering { using Super::Super; }; // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; using Super = OneToOneLLVMTerminatorLowering; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, ArrayRef> operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, properOperands, destinations, operands, op->getAttrs()); return this->matchSuccess(); } }; // Special lowering pattern for `ReturnOps`. 
Unlike all other operations, // `ReturnOp` interacts with the function signature and must have as many // operands as the function has return values. Because in LLVM IR, functions // can only return 0 or 1 value, we pack multiple values into a structure type. // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if // necessary before returning it struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp( op, llvm::ArrayRef(), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( op, llvm::ArrayRef(operands.front()), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); return matchSuccess(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. auto packedType = lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); Value *packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( op->getLoc(), packedType, packed, operands[i], getIntegerArrayAttr(rewriter, i)); } rewriter.replaceOpWithNewOp( op, llvm::makeArrayRef(packed), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); return matchSuccess(); } }; // FIXME: this should be tablegen'ed as well. struct BranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; struct CondBranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; } // namespace static void ensureDistinctSuccessors(Block &bb) { auto *terminator = bb.getTerminator(); // Find repeated successors with arguments. 
llvm::SmallDenseMap> successorPositions; for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) { Block *successor = terminator->getSuccessor(i); // Blocks with no arguments are safe even if they appear multiple times // because they don't need PHI nodes. if (successor->getNumArguments() == 0) continue; successorPositions[successor].push_back(i); } // If a successor appears for the second or more time in the terminator, // create a new dummy block that unconditionally branches to the original // destination, and retarget the terminator to branch to this new block. // There is no need to pass arguments to the dummy block because it will be // dominated by the original block and can therefore use any values defined in // the original block. for (const auto &successor : successorPositions) { const auto &positions = successor.second; // Start from the second occurrence of a block in the successor list. for (auto position = std::next(positions.begin()), end = positions.end(); position != end; ++position) { auto *dummyBlock = new Block(); bb.getParent()->push_back(dummyBlock); auto builder = OpBuilder(dummyBlock); SmallVector operands( terminator->getSuccessorOperands(*position)); builder.create(terminator->getLoc(), successor.first, operands); terminator->setSuccessor(dummyBlock, *position); for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e; ++i) terminator->eraseSuccessorOperand(*position, i); } } } void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { for (auto f : m.getOps()) { for (auto &bb : f.getBlocks()) { ::ensureDistinctSuccessors(bb); } } } /// Collect a set of patterns to convert from the Standard dialect to LLVM. 
void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed patterns.insert< AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering, MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter); } // Convert types using the stored LLVM IR module. Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } // Create an LLVM IR structure type if there is more than one result. Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { auto converted = convertType(t).dyn_cast(); if (!converted) return {}; resultTypes.push_back(converted); } return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } /// Create an instance of LLVMTypeConverter in the given context. static std::unique_ptr makeStandardToLLVMTypeConverter(MLIRContext *context) { return llvm::make_unique(context); } namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ModulePass { // By default, the patterns are those converting Standard operations to the // LLVMIR dialect. 
explicit LLVMLoweringPass( LLVMPatternListFiller patternListFiller = populateStdToLLVMConversionPatterns, LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) : patternListFiller(patternListFiller), typeConverterMaker(converterBuilder) {} // Run the dialect converter on the module. void runOnModule() override { if (!typeConverterMaker || !patternListFiller) return signalPassFailure(); ModuleOp m = getModule(); LLVM::ensureDistinctSuccessors(m); std::unique_ptr typeConverter = typeConverterMaker(&getContext()); if (!typeConverter) return signalPassFailure(); OwningRewritePatternList patterns; populateLoopToStdConversionPatterns(patterns, m.getContext()); patternListFiller(*typeConverter, patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addDynamicallyLegalOp([&](FuncOp op) { return typeConverter->isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(m, target, patterns, &*typeConverter))) signalPassFailure(); } // Callback for creating a list of patterns. It is called every time in // runOnModule since applyPartialConversion consumes the list. LLVMPatternListFiller patternListFiller; // Callback for creating an instance of type converter. The converter // constructor needs an MLIRContext, which is not available until runOnModule. LLVMTypeConverterMaker typeConverterMaker; }; } // end namespace std::unique_ptr mlir::createConvertToLLVMIRPass() { return llvm::make_unique(); } std::unique_ptr mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, LLVMTypeConverterMaker typeConverterMaker) { return llvm::make_unique(patternListFiller, typeConverterMaker); } static PassRegistration pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");