//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"

using namespace mlir;

LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
  assert(llvmDialect && "LLVM IR dialect is not registered");
  module = &llvmDialect->getLLVMModule();
}

// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
  return module->getContext();
}

// Extract an LLVM IR type from the LLVM IR dialect type.
LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
  if (!wrappedLLVMType)
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return wrappedLLVMType;
}

LLVM::LLVMType LLVMTypeConverter::getIndexType() {
  return LLVM::LLVMType::getIntNTy(
      llvmDialect, module->getDataLayout().getPointerSizeInBits());
}

Type LLVMTypeConverter::convertIndexType(IndexType type) {
  return getIndexType();
}

Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}

Type LLVMTypeConverter::convertFloatType(FloatType type) {
  switch (type.getKind()) {
  case mlir::StandardTypes::F32:
    return LLVM::LLVMType::getFloatTy(llvmDialect);
  case mlir::StandardTypes::F64:
    return LLVM::LLVMType::getDoubleTy(llvmDialect);
  case mlir::StandardTypes::F16:
    return LLVM::LLVMType::getHalfTy(llvmDialect);
  case mlir::StandardTypes::BF16: {
    auto *mlirContext = llvmDialect->getContext();
    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
           Type();
  }
  default:
    llvm_unreachable("non-float type in convertFloatType");
  }
}

// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
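// For example (illustrative, assuming a 64-bit data layout):
//   (i16, f32) -> (i1)   converts to   !llvm<"i1 (i16, float)*">
// i.e. the converted value is a pointer to a function, not a bare function
// type.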
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
  SignatureConversion conversion(type.getNumInputs());
  LLVM::LLVMType converted =
      convertFunctionSignature(type, /*isVariadic=*/false, conversion);
  return converted.getPointerTo();
}

// Function types are converted to LLVM Function types by recursively
// converting argument and result types. If MLIR Function has zero results,
// the LLVM Function has one VoidType result. If MLIR Function has more than
// one result, they are packed into an LLVM StructType in their order of
// appearance.
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
    FunctionType type, bool isVariadic,
    LLVMTypeConverter::SignatureConversion &result) {
  // Convert argument types one by one and check for errors.
  for (auto &en : llvm::enumerate(type.getInputs()))
    if (failed(convertSignatureArg(en.index(), en.value(), result)))
      return {};

  SmallVector<LLVM::LLVMType, 8> argTypes;
  argTypes.reserve(llvm::size(result.getConvertedTypes()));
  for (Type type : result.getConvertedTypes())
    argTypes.push_back(unwrap(type));

  // If the function does not return anything, create the void result type;
  // if it returns one element, convert it; otherwise pack the result types
  // into a struct.
  LLVM::LLVMType resultType =
      type.getNumResults() == 0
          ? LLVM::LLVMType::getVoidTy(llvmDialect)
          : unwrap(packFunctionResults(type.getResults()));
  if (!resultType)
    return {};
  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}

// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
//   1. the pointer to the data buffer, followed by
//   2. a lowered `index`-type integer containing the distance between the
//   beginning of the buffer and the first element to be accessed through the
//   view, followed by
//   3. an array containing as many `index`-type integers as the rank of the
//   MemRef: the array represents the size, in number of elements, of the
//   memref along the given dimension. For constant MemRef dimensions, the
//   corresponding size entry is a constant whose runtime value must match
//   the static value, followed by
//   4. a second array containing as many `index`-type integers as the rank
//   of the MemRef: the second array represents the "stride" (in tensor
//   abstraction sense), i.e. the number of consecutive elements of the
//   underlying buffer.
// TODO(ntv, zinenko): add assertions for the static cases.
//
// template <typename Elem, size_t Rank>
// struct {
//   Elem *ptr;
//   int64_t offset;
//   int64_t sizes[Rank];   // omitted when rank == 0
//   int64_t strides[Rank]; // omitted when rank == 0
// };
constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
  assert(strideSuccess &&
         "Non-strided layout maps must have been normalized away");
  (void)strideSuccess;
  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
  auto indexTy = getIndexType();
  auto rank = type.getRank();
  if (rank > 0) {
    auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
    return LLVM::LLVMType::getStructTy(ptrTy, indexTy, arrayTy, arrayTy);
  }
  return LLVM::LLVMType::getStructTy(ptrTy, indexTy);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
// when n > 1. For example, `vector<4 x f32>` converts to
// `!llvm<"<4 x float>">` and `vector<4 x 8 x 16 x f32>` converts to
// `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
  auto elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto vectorType =
      LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
  return vectorType;
}

// Dispatch based on the actual type. Return null type on error.
Type LLVMTypeConverter::convertStandardType(Type type) {
  if (auto funcType = type.dyn_cast<FunctionType>())
    return convertFunctionType(funcType);
  if (auto intType = type.dyn_cast<IntegerType>())
    return convertIntegerType(intType);
  if (auto floatType = type.dyn_cast<FloatType>())
    return convertFloatType(floatType);
  if (auto indexType = type.dyn_cast<IndexType>())
    return convertIndexType(indexType);
  if (auto memRefType = type.dyn_cast<MemRefType>())
    return convertMemRefType(memRefType);
  if (auto vectorType = type.dyn_cast<VectorType>())
    return convertVectorType(vectorType);
  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
    return llvmType;

  return {};
}

// Convert the element type of the memref `t` to an LLVM type using
// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
// into the MLIR LLVM dialect type and return.
static Type getMemRefElementPtrType(MemRefType t,
                                    LLVMTypeConverter &lowering) {
  auto elementType = t.getElementType();
  auto converted = lowering.convertType(elementType);
  if (!converted)
    return {};
  return converted.cast<LLVM::LLVMType>().getPointerTo(t.getMemorySpace());
}

LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                               LLVMTypeConverter &lowering_,
                               PatternBenefit benefit)
    : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}

namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
template <typename SourceOp>
class LLVMLegalizationPattern : public LLVMOpLowering {
public:
  // Construct a conversion pattern.
  explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
                       lowering_),
        dialect(dialect_) {}

  // Get the LLVM IR dialect.
  LLVM::LLVMDialect &getDialect() const { return dialect; }
  // Get the LLVM context.
  llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
  // Get the LLVM module in which the types are constructed.
  llvm::Module &getModule() const { return dialect.getLLVMModule(); }

  // Get the MLIR type wrapping the LLVM integer type whose bit width is
  // defined by the pointer size used in the LLVM module.
  LLVM::LLVMType getIndexType() const {
    return LLVM::LLVMType::getIntNTy(
        &dialect, getModule().getDataLayout().getPointerSizeInBits());
  }

  LLVM::LLVMType getVoidType() const {
    return LLVM::LLVMType::getVoidTy(&dialect);
  }

  // Get the MLIR type wrapping the LLVM i8* type.
  LLVM::LLVMType getVoidPtrType() const {
    return LLVM::LLVMType::getInt8PtrTy(&dialect);
  }

  // Create an LLVM IR pseudo-operation defining the given index constant.
  Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                             uint64_t value) const {
    auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
    return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
  }

  // Extract raw data pointer value from a value representing a memref.
  static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
                                        Location loc, Value *memref,
                                        Type elementTypePtr) {
    return builder.create<LLVM::ExtractValueOp>(
        loc, elementTypePtr, memref,
        builder.getIndexArrayAttr(
            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
  }

protected:
  LLVM::LLVMDialect &dialect;
};

struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = cast<FuncOp>(op);
    FunctionType type = funcOp.getType();

    // Pack the result types into a struct.
    Type packedResult;
    if (type.getNumResults() != 0)
      if (!(packedResult = lowering.packFunctionResults(type.getResults())))
        return matchFailure();
    LLVM::LLVMType resultType = packedResult
                                    ? packedResult.cast<LLVM::LLVMType>()
                                    : LLVM::LLVMType::getVoidTy(&dialect);

    SmallVector<LLVM::LLVMType, 4> argTypes;
    argTypes.reserve(type.getNumInputs());
    SmallVector<unsigned, 4> promotedArgIndices;
    promotedArgIndices.reserve(type.getNumInputs());

    // Convert the original function arguments. Struct arguments are promoted
    // to pointer-to-struct arguments to allow calling external functions
    // with various ABIs (e.g. compiled from C/C++ on platform X).
    auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    for (auto en : llvm::enumerate(type.getInputs())) {
      auto t = en.value();
      auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
      if (!converted)
        return matchFailure();
      if (t.isa<MemRefType>()) {
        converted = converted.getPointerTo();
        promotedArgIndices.push_back(en.index());
      }
      argTypes.push_back(converted);
    }
    for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
      result.addInputs(idx, argTypes[idx]);

    auto llvmType = LLVM::LLVMType::getFunctionTy(
        resultType, argTypes, varargsAttr && varargsAttr.getValue());

    // Only retain those attributes that are not constructed by build.
    SmallVector<NamedAttribute, 4> attributes;
    for (const auto &attr : funcOp.getAttrs()) {
      if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
          attr.first.is(impl::getTypeAttrName()) ||
          attr.first.is("std.varargs"))
        continue;
      attributes.push_back(attr);
    }

    // Create an LLVM function.
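    // For example (illustrative), with memref arguments promoted to
    // pointer-to-descriptor as above, a function such as
    //   func @foo(%arg0: memref<?xf32>)
    // would become
    //   llvm.func @foo(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">)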
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        op->getLoc(), funcOp.getName(), llvmType, attributes);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    // Tell the rewriter to convert the region signature.
    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);

    // Insert loads from memref descriptor pointers in function bodies.
    if (!newFuncOp.getBody().empty()) {
      Block *firstBlock = &newFuncOp.getBody().front();
      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
      for (unsigned idx : promotedArgIndices) {
        BlockArgument *arg = firstBlock->getArgument(idx);
        Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
        rewriter.replaceUsesOfBlockArgument(arg, loaded);
      }
    }

    rewriter.eraseOp(op);
    return matchSuccess();
  }
};

//////////////// Support for Lowering operations on n-D vectors ////////////////
namespace {
// Helper struct to "unroll" operations on n-D vectors in terms of operations
// on 1-D LLVM vectors.
struct NDVectorTypeInfo {
  // LLVM array struct which encodes n-D vectors.
  LLVM::LLVMType llvmArrayTy;
  // LLVM vector type which encodes the inner 1-D vector type.
  LLVM::LLVMType llvmVectorTy;
  // Multiplicity of llvmArrayTy to llvmVectorTy.
  SmallVector<int64_t, 4> arraySizes;
};
} // namespace

// For >1-D vector types, extracts the necessary information to iterate over
// all 1-D subvectors in the underlying LLVM representation of the n-D vector.
// Iterates on the llvm array type until we hit a non-array type (which is
// asserted to be an llvm vector type).
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
                                                LLVMTypeConverter &converter) {
  assert(vectorType.getRank() > 1 && "expected >1-D vector type");
  NDVectorTypeInfo info;
  info.llvmArrayTy =
      converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
  if (!info.llvmArrayTy)
    return info;
  info.arraySizes.reserve(vectorType.getRank() - 1);
  auto llvmTy = info.llvmArrayTy;
  while (llvmTy.isArrayTy()) {
    info.arraySizes.push_back(llvmTy.getArrayNumElements());
    llvmTy = llvmTy.getArrayElementType();
  }
  if (!llvmTy.isVectorTy())
    return info;
  info.llvmVectorTy = llvmTy;
  return info;
}

// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is out of the range [0, P] where
// P is the product of all the basis coordinates.
//
// Prerequisites:
//   Basis is an array of nonnegative integers (signed type inherited from
//   vector shape type).
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
                                              unsigned linearIndex) {
  SmallVector<int64_t, 4> res;
  res.reserve(basis.size());
  for (unsigned basisElement : llvm::reverse(basis)) {
    res.push_back(linearIndex % basisElement);
    linearIndex = linearIndex / basisElement;
  }
  if (linearIndex > 0)
    return {};
  std::reverse(res.begin(), res.end());
  return res;
}

// Iterate over the linear index, convert it to a coordinate position in the
// n-D array, and call `fun` with the corresponding position attribute.
template <typename Lambda>
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
                     Lambda fun) {
  unsigned ub = 1;
  for (auto s : info.arraySizes)
    ub *= s;
  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
    auto coords = getCoordinates(info.arraySizes, linearIndex);
    // Linear index is out of bounds, we are done.
    if (coords.empty())
      break;
    assert(coords.size() == info.arraySizes.size());
    auto position = builder.getIndexArrayAttr(coords);
    fun(position);
  }
}
////////////// End Support for Lowering operations on n-D vectors //////////////

// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
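// For example (illustrative), `%r = select %cond, %a, %b : f32` maps directly
// to `llvm.select` with the same operands, while an op with several results
// is rewritten to produce a single struct value that is re-expanded with
// `llvm.extractvalue` below.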
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
  using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;

  // Convert the type of the result to an LLVM type, pass operands as is,
  // preserve attributes.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    unsigned numResults = op->getNumResults();

    Type packedType;
    if (numResults != 0) {
      packedType = this->lowering.packFunctionResults(
          llvm::to_vector<4>(op->getResultTypes()));
      assert(packedType && "type conversion failed, such operation should not "
                           "have been matched");
    }

    auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
                                           op->getAttrs());

    // If the operation produced 0 or 1 result, return them immediately.
    if (numResults == 0)
      return rewriter.eraseOp(op), this->matchSuccess();
    if (numResults == 1)
      return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
             this->matchSuccess();

    // Otherwise, it had been converted to an operation producing a structure.
    // Extract individual results from the structure and return them as list.
    SmallVector<Value *, 4> results;
    results.reserve(numResults);
    for (unsigned i = 0; i < numResults; ++i) {
      auto type = this->lowering.convertType(op->getResult(i)->getType());
      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
          op->getLoc(), type, newOp.getOperation()->getResult(0),
          rewriter.getIndexArrayAttr(i)));
    }
    rewriter.replaceOp(op, results);
    return this->matchSuccess();
  }
};

template <typename SourceOp, unsigned OpCount>
struct OpCountValidator {
  static_assert(
      std::is_base_of<
          typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
          SourceOp>::value,
      "wrong operand count");
};

template <typename SourceOp>
struct OpCountValidator<SourceOp, 1> {
  static_assert(
      std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
      "expected a single operand");
};

template <typename SourceOp, unsigned OpCount>
void ValidateOpCount() {
  OpCountValidator<SourceOp, OpCount>();
}

// Basic lowering implementation for rewriting from Standard Ops to LLVM
// Dialect Ops for N-ary ops with one result. This supports higher-dimensional
// vector types.
template <typename SourceOp, typename TargetOp, unsigned OpCount>
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
  using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;

  // Convert the type of the result to an LLVM type, pass operands as is,
  // preserve attributes.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ValidateOpCount<SourceOp, OpCount>();
    static_assert(
        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
        "expected single result op");
    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                  SourceOp>::value,
                  "expected same operands and result type");

    // Cannot convert ops if their operands are not of LLVM type.
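    // (This can happen during partial conversion when a producer op has not
    // been rewritten yet; failing the match lets the driver retry later.)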
    for (Value *operand : operands) {
      if (!operand || !operand->getType().isa<LLVM::LLVMType>())
        return this->matchFailure();
    }

    auto loc = op->getLoc();
    auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();

    if (!llvmArrayTy.isArrayTy()) {
      auto newOp = rewriter.create<TargetOp>(
          op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
      rewriter.replaceOp(op, newOp.getResult());
      return this->matchSuccess();
    }

    auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
    if (!vectorType)
      return this->matchFailure();
    auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
    if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
      return this->matchFailure();

    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      // For this unrolled `position` corresponding to the `linearIndex`^th
      // element, extract operand vectors.
      SmallVector<Value *, OpCount> extractedOperands;
      for (unsigned i = 0; i < OpCount; ++i) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, llvmVectorTy, operands[i], position));
      }
      Value *newVal = rewriter.create<TargetOp>(
          loc, llvmVectorTy, extractedOperands, op->getAttrs());
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
                                                  newVal, position);
    });
    rewriter.replaceOp(op, desc);
    return this->matchSuccess();
  }
};

template <typename SourceOp, typename TargetOp>
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
template <typename SourceOp, typename TargetOp>
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;

// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
  using Super::Super;
};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
  using Super::Super;
};
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
  using Super::Super;
};
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
  using Super::Super;
};
struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
  using Super::Super;
};
struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
  using Super::Super;
};
struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
  using Super::Super;
};
struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
  using Super::Super;
};
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
  using Super::Super;
};
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
  using Super::Super;
};
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
  using Super::Super;
};
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
  using Super::Super;
};
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
  using Super::Super;
};
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
  using Super::Super;
};
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
  using Super::Super;
};
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
  using Super::Super;
};
struct SelectOpLowering
    : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
  using Super::Super;
};
struct ConstLLVMOpLowering
    : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
  using Super::Super;
};

// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
  return type.getAffineMaps().empty() ||
         llvm::all_of(type.getAffineMaps(),
                      [](AffineMap map) { return map.isIdentity(); });
}

// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer.
// The memref descriptor is of the LLVM structure type where the first
// element is a pointer to the (typed) data buffer, and the remaining
// elements serve to store dynamic sizes of the memref using LLVM-converted
// `index` type.
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
  using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;

  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast<AllocOp>(op).getType();
    if (isSupportedMemRefType(type))
      return matchSuccess();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(type, strides, offset);
    if (failed(successStrides))
      return matchFailure();

    // Dynamic strides are ok if they can be deduced from dynamic sizes (which
    // is guaranteed when succeeded(successStrides)). A dynamic offset however
    // can never be alloc'ed.
    if (offset == MemRefType::getDynamicStrideOrOffset())
      return matchFailure();

    return matchSuccess();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto allocOp = cast<AllocOp>(op);
    MemRefType type = allocOp.getType();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value *, 4> sizes;
    sizes.reserve(type.getRank());
    unsigned i = 0;
    for (int64_t s : type.getShape())
      sizes.push_back(s == -1 ? operands[i++]
                              : createIndexConstant(rewriter, op->getLoc(), s));
    if (sizes.empty())
      sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));

    // Compute the total number of memref elements.
    Value *cumulativeSize = sizes.front();
    for (unsigned i = 1, e = sizes.size(); i < e; ++i)
      cumulativeSize = rewriter.create<LLVM::MulOp>(
          op->getLoc(), getIndexType(),
          ArrayRef<Value *>{cumulativeSize, sizes[i]});

    // Compute the size of an individual element. This emits the MLIR
    // equivalent of the following sizeof(...) implementation in LLVM IR:
    //   %0 = getelementptr %elementType* null, %indexType 1
    //   %1 = ptrtoint %elementType* %0 to %indexType
    // which is a common pattern of getting the size of a type in bytes.
    auto elementType = type.getElementType();
    auto convertedPtrType =
        lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
    auto nullPtr =
        rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType);
    auto one = createIndexConstant(rewriter, op->getLoc(), 1);
    auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType,
                                            ArrayRef<Value *>{nullPtr, one});
    auto elementSize =
        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep);
    cumulativeSize = rewriter.create<LLVM::MulOp>(
        op->getLoc(), getIndexType(),
        ArrayRef<Value *>{cumulativeSize, elementSize});

    // Insert the `malloc` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
    if (!mallocFunc) {
      OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
      mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
          rewriter.getUnknownLoc(), "malloc",
          LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
                                        /*isVarArg=*/false));
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
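    // The call below emits, e.g. (illustrative):
    //   %buf = llvm.call @malloc(%size) : (!llvm.i64) -> !llvm<"i8*">
    //   %typed = llvm.bitcast %buf : !llvm<"i8*"> to !llvm<"float*">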
    Value *allocated =
        rewriter
            .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
                                  rewriter.getSymbolRefAttr(mallocFunc),
                                  cumulativeSize)
            .getResult(0);
    auto structElementType = lowering.convertType(elementType);
    auto elementPtrType =
        structElementType.cast<LLVM::LLVMType>().getPointerTo(
            type.getMemorySpace());
    allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
                                                 ArrayRef<Value *>(allocated));

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(type, strides, offset);
    assert(succeeded(successStrides) && "unexpected non-strided memref");
    (void)successStrides;
    assert(offset != MemRefType::getDynamicStrideOrOffset() &&
           "unexpected dynamic offset");

    // 0-D memref corner case: they have size 1 ...
    assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
            (strides.size() == sizes.size())) &&
           "unexpected number of strides");

    // Create the MemRef descriptor.
    auto structType = lowering.convertType(type);
    Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
        op->getLoc(), structType, ArrayRef<Value *>{});
    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
        op->getLoc(), structType, memRefDescriptor, allocated,
        rewriter.getIndexArrayAttr(
            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
        op->getLoc(), structType, memRefDescriptor,
        createIndexConstant(rewriter, op->getLoc(), offset),
        rewriter.getIndexArrayAttr(
            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));

    if (type.getRank() == 0)
      // No size/stride descriptor in memref, return the descriptor value.
      return rewriter.replaceOp(op, memRefDescriptor);

    // Store all sizes in the descriptor. Only dynamic sizes are passed in as
    // operands to AllocOp.
    Value *runningStride = nullptr;
    // Iterate strides in reverse order, compute runningStride and
    // strideValues.
    auto nStrides = strides.size();
    SmallVector<Value *, 4> strideValues(nStrides, nullptr);
    for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
      int64_t index = nStrides - 1 - indexedStride.index();
      if (strides[index] == MemRefType::getDynamicStrideOrOffset())
        // Identity layout map is enforced in the match function, so we
        // compute:
        //   `runningStride *= sizes[index]`
        runningStride = runningStride
                            ? rewriter.create<LLVM::MulOp>(
                                  op->getLoc(), runningStride, sizes[index])
                            : createIndexConstant(rewriter, op->getLoc(), 1);
      else
        runningStride =
            createIndexConstant(rewriter, op->getLoc(), strides[index]);
      strideValues[index] = runningStride;
    }

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(sizes)) {
      int64_t index = indexedSize.index();
      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
          op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
          rewriter.getI64ArrayAttr(
              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
          op->getLoc(), structType, memRefDescriptor, strideValues[index],
          rewriter.getI64ArrayAttr(
              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
    }

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, memRefDescriptor);
  }
};

// A CallOp automatically promotes MemRefType to a sequence of alloca/store
// and passes the pointer to the MemRef across function boundaries.
template <typename CallOpType>
struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
  using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
  using Super = CallOpInterfaceLowering<CallOpType>;
  using Base = LLVMLegalizationPattern<CallOpType>;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OperandAdaptor<CallOpType> transformed(operands);
    auto callOp = cast<CallOpType>(op);

    // Pack the result types into a struct.
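    // For example (illustrative), a call returning (i32, f32) is rewritten
    // to return a single !llvm<"{ i32, float }"> value that is unpacked
    // below.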
    Type packedResult;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
    if (numResults != 0) {
      if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
        return this->matchFailure();
    }

    SmallVector<Value *, 4> opOperands(op->getOperands());
    auto promoted = this->lowering.promoteMemRefDescriptors(
        op->getLoc(), opOperands, operands, rewriter);
    auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
                                               promoted, op->getAttrs());

    // If < 2 results, packing did not do anything and we can just return.
    if (numResults < 2) {
      SmallVector<Value *, 4> results(newOp.getResults());
      rewriter.replaceOp(op, results);
      return this->matchSuccess();
    }

    // Otherwise, it had been converted to an operation producing a structure.
    // Extract individual results from the structure and return them as list.
    // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around
    // a particular interaction between MemRefType and CallOp lowering. Find a
    // way to avoid special casing.
    SmallVector<Value *, 4> results;
    results.reserve(numResults);
    for (unsigned i = 0; i < numResults; ++i) {
      auto type = this->lowering.convertType(op->getResult(i)->getType());
      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
          op->getLoc(), type, newOp.getOperation()->getResult(0),
          rewriter.getIndexArrayAttr(i)));
    }
    rewriter.replaceOp(op, results);

    return this->matchSuccess();
  }
};

struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  using Super::Super;
};

struct CallIndirectOpLowering
    : public CallOpInterfaceLowering<CallIndirectOp> {
  using Super::Super;
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to
// clean it up in any way.
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
  using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    OperandAdaptor<DeallocOp> transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc =
        op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
    if (!freeFunc) {
      OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
      freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
          rewriter.getUnknownLoc(), "free",
          LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
                                        /*isVarArg=*/false));
    }

    auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
    Type elementPtrType = type.getStructElementType(
        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
    Value *bufferPtr = extractMemRefElementPtr(
        rewriter, op->getLoc(), transformed.memref(), elementPtrType);
    Value *casted = rewriter.create<LLVM::BitcastOp>(
        op->getLoc(), getVoidPtrType(), bufferPtr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return matchSuccess();
  }
};

struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
  using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;

  PatternMatchResult match(Operation *op) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    MemRefType sourceType =
        memRefCastOp.getOperand()->getType().cast<MemRefType>();
    MemRefType targetType = memRefCastOp.getType();
    return (isSupportedMemRefType(targetType) &&
            isSupportedMemRefType(sourceType))
               ? matchSuccess()
               : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    OperandAdaptor<MemRefCastOp> transformed(operands);
    // memref_cast is defined for source and destination memref types with the
    // same element type, same mappings, same address space and same rank.
    // Therefore a simple bitcast suffices. If not, it is undefined behavior.
    auto targetStructType = lowering.convertType(memRefCastOp.getType());
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
                                                 transformed.source());
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
  using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dimOp = cast<DimOp>(op);
    OperandAdaptor<DimOp> transformed(operands);
    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();

    auto shape = type.getShape();
    int64_t index = dimOp.getIndex();
    // Extract dynamic size from the memref descriptor.
    if (ShapedType::isDynamic(shape[index]))
      rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
          op, getIndexType(), transformed.memrefOrTensor(),
          rewriter.getI64ArrayAttr(
              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
    else
      // Use constant for static size.
      rewriter.replaceOp(
          op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
    return matchSuccess();
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
  using Base = LoadStoreOpLowering<Derived>;

  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast<Derived>(op).getMemRefType();
    return isSupportedMemRefType(type) ? this->matchSuccess()
                                       : this->matchFailure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as
  // they appear in load/store operations and memref type declarations.
  Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                             ArrayRef<Value *> indices,
                             ArrayRef<Value *> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value *linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value *>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(),
          ArrayRef<Value *>{linearized, indices[i]});
    }
    return linearized;
  }

  // This is a strided getElementPtr variant that linearizes subscripts as:
  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
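  // For example (illustrative), accessing a memref<4x5xf32> with the
  // identity layout (strides [5, 1], offset 0) at indices (%i, %j) computes
  // the element address as `base + (0 + %i * 5 + %j * 1)`, in units of f32
  // elements.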
  Value *getStridedElementPtr(Location loc, Type elementTypePtr,
                              Value *memRefDescriptor,
                              ArrayRef<Value *> indices,
                              ArrayRef<int64_t> strides, int64_t offset,
                              ConversionPatternRewriter &rewriter) const {
    auto indexTy = this->getIndexType();
    Value *base = this->extractMemRefElementPtr(
        rewriter, loc, memRefDescriptor, elementTypePtr);
    Value *offsetValue =
        offset == MemRefType::getDynamicStrideOrOffset()
            ? rewriter.create<LLVM::ExtractValueOp>(
                  loc, indexTy, memRefDescriptor,
                  rewriter.getIndexArrayAttr(
                      LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
            : this->createIndexConstant(rewriter, loc, offset);
    for (int i = 0, e = indices.size(); i < e; ++i) {
      Value *stride;
      if (strides[i] != MemRefType::getDynamicStrideOrOffset()) {
        // Use static stride.
        auto attr =
            rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
        stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
      } else {
        // Use dynamic stride.
        stride = rewriter.create<LLVM::ExtractValueOp>(
            loc, indexTy, memRefDescriptor,
            rewriter.getIndexArrayAttr(
                {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
      }
      Value *additionalOffset =
          rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
      offsetValue =
          rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
    }
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base,
                                        offsetValue);
  }

  Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc,
                    ArrayRef<Value *> indices,
                    ConversionPatternRewriter &rewriter,
                    llvm::Module &module) const {
    auto ptrType = getMemRefElementPtrType(type, this->lowering);
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(type, strides, offset);
    assert(succeeded(successStrides) && "unexpected non-strided memref");
    (void)successStrides;
    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                                offset, rewriter);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
  using Base::Base;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loadOp = cast<LoadOp>(op);
    OperandAdaptor<LoadOp> transformed(operands);
    auto type = loadOp.getMemRefType();

    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
                                transformed.indices(), rewriter, getModule());
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
    return matchSuccess();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
  using Base::Base;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast<StoreOp>(op).getMemRefType();
    OperandAdaptor<StoreOp> transformed(operands);

    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
                                transformed.indices(), rewriter, getModule());
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                               dataPtr);
    return matchSuccess();
  }
};

// The lowering of index_cast becomes an integer conversion since index
// becomes an integer. If the bit width of the source and target integer
// types is the same, just erase the cast. If the target type is wider,
// sign-extend the value, otherwise truncate it.
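// For example (illustrative, with a 64-bit `index`):
//   index_cast %a : index to i32   becomes   llvm.trunc
//   index_cast %b : i32 to index   becomes   llvm.sext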
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
  using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    IndexCastOpOperandAdaptor transformed(operands);
    auto indexCastOp = cast<IndexCastOp>(op);

    auto targetType =
        this->lowering.convertType(indexCastOp.getResult()->getType())
            .cast<LLVM::LLVMType>();
    auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
    unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
    unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();

    if (targetBits == sourceBits)
      rewriter.replaceOp(op, transformed.in());
    else if (targetBits < sourceBits)
      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                 transformed.in());
    else
      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
                                                transformed.in());
    return matchSuccess();
  }
};

// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two
// enums share the numerical values so just cast.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  return static_cast<LLVMPredType>(pred);
}

struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
  using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto cmpiOp = cast<CmpIOp>(op);
    CmpIOpOperandAdaptor transformed(operands);

    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
        op, lowering.convertType(cmpiOp.getResult()->getType()),
        rewriter.getI64IntegerAttr(static_cast<int64_t>(
            convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
        transformed.lhs(), transformed.rhs());

    return matchSuccess();
  }
};

struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
  using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto cmpfOp = cast<CmpFOp>(op);
    CmpFOpOperandAdaptor transformed(operands);

    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
        op, lowering.convertType(cmpfOp.getResult()->getType()),
        rewriter.getI64IntegerAttr(static_cast<int64_t>(
            convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
        transformed.lhs(), transformed.rhs());

    return matchSuccess();
  }
};

struct SIToFPLowering
    : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
  using Super::Super;
};

struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> {
  using Super::Super;
};

struct FPTruncLowering
    : public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> {
  using Super::Super;
};

struct SignExtendIOpLowering
    : public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> {
  using Super::Super;
};

struct TruncateIOpLowering
    : public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> {
  using Super::Super;
};

struct ZeroExtendIOpLowering
    : public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> {
  using Super::Super;
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
    : public LLVMLegalizationPattern<SourceOp> {
  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
  using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
                  ArrayRef<Block *> destinations,
                  ArrayRef<ArrayRef<Value *>> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
                                          operands, op->getAttrs());
    return this->matchSuccess();
  }
};

// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure
// type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such a structure if
// necessary before returning it.
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
  using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    unsigned numArguments = op->getNumOperands();

    // If ReturnOp has 0 or 1 operand, create it and return immediately.
    if (numArguments == 0) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
          op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
          op->getAttrs());
      return matchSuccess();
    }
    if (numArguments == 1) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
          op, llvm::ArrayRef<Value *>(operands.front()),
          llvm::ArrayRef<Block *>(), op->getAttrs());
      return matchSuccess();
    }

    // Otherwise, we need to pack the arguments into an LLVM struct type
    // before returning.
    auto packedType = lowering.packFunctionResults(
        llvm::to_vector<4>(op->getOperandTypes()));

    Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
    for (unsigned i = 0; i < numArguments; ++i) {
      packed = rewriter.create<LLVM::InsertValueOp>(
          op->getLoc(), packedType, packed, operands[i],
          rewriter.getIndexArrayAttr(i));
    }
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed),
                                                llvm::ArrayRef<Block *>(),
                                                op->getAttrs());
    return matchSuccess();
  }
};

// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  using Super::Super;
};
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  using Super::Super;
};

// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Only splats to 1-D vector result types are lowered by this
// pattern.
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatOp = cast<SplatOp>(op);
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() != 1)
      return matchFailure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = lowering.convertType(splatOp.getType());
    Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    auto v = rewriter.create<LLVM::InsertElementOp>(
        op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t, 4> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef,
                                                       zeroAttrs);
    return matchSuccess();
  }
};

// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splats to 2+-D vector result types are lowered by
// SplatNdOpLowering; the 1-D case is handled by SplatOpLowering.
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatOp = cast<SplatOp>(op);
    OperandAdaptor<SplatOp> adaptor(operands);
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() == 1)
      return matchFailure();

    // First insert it into an undef vector so we can shuffle it.
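    // The 1-D splat below emits, e.g. for an inner <4 x float> vector
    // (illustrative):
    //   %0 = llvm.mlir.undef : !llvm<"<4 x float>">
    //   %1 = llvm.insertelement %v, %0[%c0] : !llvm<"<4 x float>">
    //   %2 = llvm.shufflevector %1, %1 [0, 0, 0, 0] : ...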
    auto loc = op->getLoc();
    auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
    auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
    if (!llvmArrayTy || !llvmVectorTy)
      return matchFailure();

    // Construct returned value.
    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);

    // Construct a 1-D vector with the splatted value that we insert in all
    // the places within the returned descriptor.
    Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, lowering.convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value *v = rewriter.create<LLVM::InsertElementOp>(
        loc, llvmVectorTy, vdesc, adaptor.input(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert to coordinate space and insert
    // the splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
  using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
                 ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
                 unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<int64_t> strides, Value *nextSize,
                   Value *runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto viewOp = cast<ViewOp>(op);
    ViewOpOperandAdaptor adaptor(operands);

    auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
    auto sourceElementTy =
        lowering.convertType(sourceMemRefType.getElementType())
            .dyn_cast<LLVM::LLVMType>();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        lowering.convertType(viewMemRefType.getElementType())
            .dyn_cast<LLVM::LLVMType>();
    auto targetDescTy =
        lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
    if (!targetDescTy)
      return op->emitWarning("Target descriptor type not converted to LLVM"),
             matchFailure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return op->emitWarning("cannot cast to non-strided shape"),
             matchFailure();

    // Create the descriptor.
    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value *sourceDescriptor = adaptor.source();
    Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc, targetElementTy.getPointerTo(),
        rewriter.create<LLVM::ExtractValueOp>(
            loc, sourceElementTy.getPointerTo(), sourceDescriptor,
            rewriter.getI64ArrayAttr(
                LLVMTypeConverter::kPtrPosInMemRefDescriptor)));
    desc = rewriter.create<LLVM::InsertValueOp>(
        loc, desc, bitcastPtr,
        rewriter.getI64ArrayAttr(
            LLVMTypeConverter::kPtrPosInMemRefDescriptor));

    // Offset.
    unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
    (void)numDynamicSizes;
    auto sizeAndOffsetOperands = adaptor.operands();
    assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
           offset != MemRefType::getDynamicStrideOrOffset());
    Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
                            ? createIndexConstant(rewriter, loc, offset)
                            // TODO(ntv): better adaptor.
                            : sizeAndOffsetOperands.back();
    desc = rewriter.create<LLVM::InsertValueOp>(
        loc, desc, baseOffset,
        rewriter.getI64ArrayAttr(
            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(op, desc), matchSuccess();

    // Update sizes and strides.
    if (strides.back() != 1)
      return op->emitWarning("cannot cast to non-contiguous shape"),
             matchFailure();
    Value *stride = nullptr, *nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
                            sizeAndOffsetOperands, i);
      desc = rewriter.create<LLVM::InsertValueOp>(
          loc, desc, size,
          rewriter.getI64ArrayAttr(
              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      desc = rewriter.create<LLVM::InsertValueOp>(
          loc, desc, stride,
          rewriter.getI64ArrayAttr(
              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
      nextSize = size;
    }

    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

} // namespace

static void ensureDistinctSuccessors(Block &bb) {
  auto *terminator = bb.getTerminator();

  // Find repeated successors with arguments.
  llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
  for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
    Block *successor = terminator->getSuccessor(i);
    // Blocks with no arguments are safe even if they appear multiple times
    // because they don't need PHI nodes.
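    // A successor with arguments appearing twice is problematic: the same
    // LLVM predecessor would need two different PHI entries, e.g.
    // (illustrative) `cond_br %c, ^bb1(%a : i32), ^bb1(%b : i32)`.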
    if (successor->getNumArguments() == 0)
      continue;
    successorPositions[successor].push_back(i);
  }

  // If a successor appears for the second or more time in the terminator,
  // create a new dummy block that unconditionally branches to the original
  // destination, and retarget the terminator to branch to this new block.
  // There is no need to pass arguments to the dummy block because it will be
  // dominated by the original block and can therefore use any values defined
  // in the original block.
  for (const auto &successor : successorPositions) {
    const auto &positions = successor.second;
    // Start from the second occurrence of a block in the successor list.
    for (auto position = std::next(positions.begin()), end = positions.end();
         position != end; ++position) {
      auto *dummyBlock = new Block();
      bb.getParent()->push_back(dummyBlock);
      auto builder = OpBuilder(dummyBlock);
      SmallVector<Value *, 8> operands(
          terminator->getSuccessorOperands(*position));
      builder.create<BranchOp>(terminator->getLoc(), successor.first,
                               operands);
      terminator->setSuccessor(dummyBlock, *position);
      for (int i = 0, e = terminator->getNumSuccessorOperands(*position);
           i < e; ++i)
        terminator->eraseSuccessorOperand(*position, i);
    }
  }
}

void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
  for (auto f : m.getOps<FuncOp>()) {
    for (auto &bb : f.getBlocks()) {
      ::ensureDistinctSuccessors(bb);
    }
  }
}

/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void mlir::populateStdToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // FIXME: this should be tablegen'ed
  // clang-format off
  patterns.insert<
      AddFOpLowering,
      AddIOpLowering,
      AllocOpLowering,
      AndOpLowering,
      BranchOpLowering,
      CallIndirectOpLowering,
      CallOpLowering,
      CmpFOpLowering,
      CmpIOpLowering,
      CondBranchOpLowering,
      ConstLLVMOpLowering,
      DeallocOpLowering,
      DimOpLowering,
      DivFOpLowering,
      DivISOpLowering,
      DivIUOpLowering,
      ExpOpLowering,
      FPExtLowering,
      FPTruncLowering,
      FuncOpConversion,
      IndexCastOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MulFOpLowering,
      MulIOpLowering,
      OrOpLowering,
      RemFOpLowering,
      RemISOpLowering,
      RemIUOpLowering,
      ReturnOpLowering,
      SIToFPLowering,
      SelectOpLowering,
      SignExtendIOpLowering,
      SplatOpLowering,
      SplatNdOpLowering,
      StoreOpLowering,
      SubFOpLowering,
      SubIOpLowering,
      TruncateIOpLowering,
      ViewOpLowering,
      XOrOpLowering,
      ZeroExtendIOpLowering>(*converter.getDialect(), converter);
  // clang-format on
}

// Convert types using the stored LLVM IR module.
Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }

// Create an LLVM IR structure type if there is more than one result.
Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
  assert(!types.empty() && "expected non-empty list of type");

  if (types.size() == 1)
    return convertType(types.front());

  SmallVector<LLVM::LLVMType, 8> resultTypes;
  resultTypes.reserve(types.size());
  for (auto t : types) {
    auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
    if (!converted)
      return {};
    resultTypes.push_back(converted);
  }
  return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}

Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc,
                                                     Value *operand,
                                                     OpBuilder &builder) {
  auto *context = builder.getContext();
  auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
  auto indexType = IndexType::get(context);
  // Alloca with proper alignment. We do not expect optimizations of this
  // alloca op and so we omit allocating at the entry block.
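  // The promotion emits, e.g. (illustrative):
  //   %1 = llvm.mlir.constant(1 : index) : !llvm.i64
  //   %2 = llvm.alloca %1 x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
  //   llvm.store %descriptor, %2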
  auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
  Value *one = builder.create<LLVM::ConstantOp>(
      loc, int64Ty, IntegerAttr::get(indexType, 1));
  Value *allocated =
      builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
  // Store into the alloca'ed descriptor.
  builder.create<LLVM::StoreOp>(loc, operand, allocated);
  return allocated;
}

SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
    Location loc, ArrayRef<Value *> opOperands, ArrayRef<Value *> operands,
    OpBuilder &builder) {
  SmallVector<Value *, 4> promotedOperands;
  promotedOperands.reserve(operands.size());
  for (auto it : llvm::zip(opOperands, operands)) {
    auto *operand = std::get<0>(it);
    auto *llvmOperand = std::get<1>(it);
    if (!operand->getType().isa<MemRefType>()) {
      promotedOperands.push_back(operand);
      continue;
    }
    promotedOperands.push_back(
        promoteOneMemRefDescriptor(loc, llvmOperand, builder));
  }
  return promotedOperands;
}

/// Create an instance of LLVMTypeConverter in the given context.
static std::unique_ptr<LLVMTypeConverter>
makeStandardToLLVMTypeConverter(MLIRContext *context) {
  return std::make_unique<LLVMTypeConverter>(context);
}

namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
  // By default, the patterns are those converting Standard operations to the
  // LLVMIR dialect.
  explicit LLVMLoweringPass(
      LLVMPatternListFiller patternListFiller =
          populateStdToLLVMConversionPatterns,
      LLVMTypeConverterMaker converterBuilder =
          makeStandardToLLVMTypeConverter)
      : patternListFiller(patternListFiller),
        typeConverterMaker(converterBuilder) {}

  // Run the dialect converter on the module.
  void runOnModule() override {
    if (!typeConverterMaker || !patternListFiller)
      return signalPassFailure();

    ModuleOp m = getModule();
    LLVM::ensureDistinctSuccessors(m);
    std::unique_ptr<LLVMTypeConverter> typeConverter =
        typeConverterMaker(&getContext());
    if (!typeConverter)
      return signalPassFailure();

    OwningRewritePatternList patterns;
    populateLoopToStdConversionPatterns(patterns, m.getContext());
    patternListFiller(*typeConverter, patterns);

    ConversionTarget target(getContext());
    target.addLegalDialect<LLVM::LLVMDialect>();
    if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
      signalPassFailure();
  }

  // Callback for creating a list of patterns. It is called every time in
  // runOnModule since applyPartialConversion consumes the list.
  LLVMPatternListFiller patternListFiller;

  // Callback for creating an instance of type converter. The converter
  // constructor needs an MLIRContext, which is not available until
  // runOnModule.
  LLVMTypeConverterMaker typeConverterMaker;
};
} // end namespace

std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
  return std::make_unique<LLVMLoweringPass>();
}

std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
                            LLVMTypeConverterMaker typeConverterMaker) {
  return std::make_unique<LLVMLoweringPass>(patternListFiller,
                                            typeConverterMaker);
}

static PassRegistration<LLVMLoweringPass>
    pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");
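// The pass registered above can be exercised from the command line, e.g.
// (illustrative invocation): `mlir-opt -lower-to-llvm input.mlir`.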