//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/Linalg/IR/LinalgTypes.h" #include "mlir/Linalg/Passes.h" #include "mlir/Linalg/Utils/Intrinsics.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetVector.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace 
mlir::LLVM; using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using add = ValueBuilder; using addi = ValueBuilder; using bitcast = ValueBuilder; using cmpi = ValueBuilder; using constant = ValueBuilder; using extractvalue = ValueBuilder; using gep = ValueBuilder; using insertvalue = ValueBuilder; using llvm_call = OperationBuilder; using llvm_icmp = ValueBuilder; using llvm_load = ValueBuilder; using llvm_store = OperationBuilder; using llvm_select = ValueBuilder; using mul = ValueBuilder; using sub = ValueBuilder; using undef = ValueBuilder; using llvm_alloca = ValueBuilder; using llvm_return = OperationBuilder; template static LLVMType getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { return lowering.convertType(containerType.getElementType()) .template cast() .getPointerTo(); } // Convert the given type to the LLVM IR Dialect type. The following // conversions are supported: // - an Index type is converted into an LLVM integer type with pointer // bitwidth (analogous to intptr_t in C); // - an Integer type is converted into an LLVM integer type of the same width; // - an F32 type is converted into an LLVM float type // - a Buffer, Range or View is converted into an LLVM structure type // containing the respective dynamic values. static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { auto *context = t.getContext(); auto int64Ty = lowering.convertType(IntegerType::get(64, context)) .cast(); // A buffer descriptor contains the pointer to a flat region of storage and // the size of the region. // // template // struct { // Elem *ptr; // int64_t size; // }; if (auto bufferType = t.dyn_cast()) { auto ptrTy = getPtrToElementType(bufferType, lowering); return LLVMType::getStructTy(ptrTy, int64Ty); } // Range descriptor contains the range bounds and the step as 64-bit integers. 
// // struct { // int64_t min; // int64_t max; // int64_t step; // }; if (t.isa()) return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); // View descriptor contains the pointer to the data buffer, followed by a // 64-bit integer containing the distance between the beginning of the buffer // and the first element to be accessed through the view, followed by two // arrays, each containing as many 64-bit integers as the rank of the View. // The first array represents the size, in number of original elements, of the // view along the given dimension. When taking the view, the size is the // difference between the upper and the lower bound of the range. The second // array represents the "stride" (in tensor abstraction sense), i.e. the // number of consecutive elements of the underlying buffer that separate two // consecutive elements addressable through the view along the given // dimension. When taking the view, the strides are constructed as products // of the original sizes along the trailing dimensions, multiplied by the view // step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1}, // i.e. the view of a complete memref, will have strides N and 1. A view with // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3. // // template // struct { // Elem *ptr; // int64_t offset; // int64_t sizes[Rank]; // int64_t strides[Rank]; // }; if (auto viewType = t.dyn_cast()) { auto ptrTy = getPtrToElementType(viewType, lowering); auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank()); return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy); } return Type(); } // Create an array attribute containing integer attributes with values provided // in `position`. static ArrayAttr positionAttr(Builder &builder, ArrayRef position) { SmallVector attrs; attrs.reserve(position.size()); for (auto p : position) attrs.push_back(builder.getI64IntegerAttr(p)); return builder.getArrayAttr(attrs); } // BufferAllocOp creates a new `!linalg.buffer` value. 
class BufferAllocOpConversion : public LLVMOpLowering {
public:
  explicit BufferAllocOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferAllocOp::getOperationName(), context,
                       lowering_) {}

  // Emits `malloc(numElements * elementSize)`, bitcasts the result to the
  // element pointer type and packs {ptr, numElements} into a buffer
  // descriptor. The `malloc` declaration is added to the module on first use.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type idxTy = IndexType::get(op->getContext());
    LLVMType i8PtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    Type i64Ty = lowering.convertType(rewriter.getIntegerType(64));

    // Declare `malloc : (i64) -> i8*` unless the module already has it.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp mallocDecl = module.lookupSymbol<FuncOp>("malloc");
    if (!mallocDecl) {
      mallocDecl = FuncOp::create(rewriter.getUnknownLoc(), "malloc",
                                  rewriter.getFunctionType(i64Ty, i8PtrTy));
      module.push_back(mallocDecl);
    }

    // Byte size of one element: a vector counts each of its scalars, and every
    // scalar is rounded up to a whole number of bytes.
    auto allocOp = cast<BufferAllocOp>(op);
    Type elementType = allocOp.getElementType();
    uint64_t elementSize;
    if (auto vecTy = elementType.dyn_cast<VectorType>()) {
      uint64_t scalarBytes =
          llvm::divideCeil(vecTy.getElementTypeBitWidth(), 8);
      elementSize = vecTy.getNumElements() * scalarBytes;
    } else {
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    }

    Type resultTy = allocOp.getResult()->getType();
    auto bufferType = resultTy.cast<BufferType>();
    LLVMType elementPtrTy = getPtrToElementType(bufferType, lowering);
    Type descTy = convertLinalgType(resultTy, lowering);

    edsc::ScopedContext scope(rewriter, op->getLoc());
    // A statically-sized buffer carries its element count in its type;
    // otherwise the count is the first operand.
    Value *numElements;
    if (auto staticSize = bufferType.getBufferSize())
      numElements =
          constant(i64Ty, IntegerAttr::get(idxTy, *staticSize)).getValue();
    else
      numElements = operands[0];

    Value *numBytes =
        mul(numElements, constant(i64Ty, IntegerAttr::get(idxTy, elementSize)));
    Value *rawPtr =
        llvm_call(i8PtrTy, rewriter.getSymbolRefAttr(mallocDecl), numBytes)
            .getOperation()
            ->getResult(0);
    Value *dataPtr = bitcast(elementPtrTy, rawPtr);

    // Descriptor layout: {data pointer, element count}.
    Value *desc = undef(descTy);
    desc = insertvalue(descTy, desc, dataPtr, positionAttr(rewriter, 0));
    desc = insertvalue(descTy, desc, numElements, positionAttr(rewriter, 1));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// Lowers BufferDeallocOp, which produces no value, to a `free` call on the
// data pointer stored in the buffer descriptor.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
  explicit BufferDeallocOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                       lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    LLVMType i8PtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();

    // Declare `free : (i8*) -> ()` unless the module already has it.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp freeDecl = module.lookupSymbol<FuncOp>("free");
    if (!freeDecl) {
      freeDecl = FuncOp::create(rewriter.getUnknownLoc(), "free",
                                rewriter.getFunctionType(i8PtrTy, {}));
      module.push_back(freeDecl);
    }

    auto bufferType =
        cast<BufferDeallocOp>(op).getOperand()->getType().cast<BufferType>();
    LLVMType elementPtrTy = getPtrToElementType(bufferType, lowering);

    // Field 0 of the buffer descriptor is the typed data pointer; `free`
    // expects it as an i8*.
    edsc::ScopedContext scope(rewriter, op->getLoc());
    Value *dataPtr =
        extractvalue(elementPtrTy, operands[0], positionAttr(rewriter, 0));
    Value *rawPtr = bitcast(i8PtrTy, dataPtr);
    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeDecl), rawPtr);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};

// Lowers BufferSizeOp, which creates a new `index` value, by reading the size
// field of the buffer descriptor.
class BufferSizeOpConversion : public LLVMOpLowering { public: BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); edsc::ScopedContext context(rewriter, op->getLoc()); rewriter.replaceOp( op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))}); return matchSuccess(); } }; // DimOp creates a new `index` value. class DimOpConversion : public LLVMOpLowering { public: explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); auto indexTy = lowering.convertType(rewriter.getIndexType()); edsc::ScopedContext context(rewriter, op->getLoc()); rewriter.replaceOp( op, {extractvalue( indexTy, operands[0], positionAttr(rewriter, {2, static_cast(dimOp.getIndex())}))}); return matchSuccess(); } }; namespace { // Common functionality for Linalg LoadOp and StoreOp conversion to the // LLVM IR Dialect. template class LoadStoreOpConversion : public LLVMOpLowering { public: explicit LoadStoreOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(Op::getOperationName(), context, lowering_) {} using Base = LoadStoreOpConversion; // Compute the pointer to an element of the buffer underlying the view given // current view indices. Use the base offset and strides stored in the view // descriptor to emit IR iteratively computing the actual offset, followed by // a getelementptr. This must be called under an edsc::ScopedContext. 
Value *obtainDataPtr(Operation *op, Value *viewDescriptor, ArrayRef indices, ConversionPatternRewriter &rewriter) const { auto loadOp = cast(op); auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { return positionAttr(rewriter, values); }; // Linearize subscripts as: // base_offset + SUM_i index_i * stride_i. Value *base = extractvalue(elementTy, viewDescriptor, pos(0)); Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1)); for (int i = 0, e = loadOp.getRank(); i < e; ++i) { Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i})); Value *additionalOffset = mul(indices[i], stride); offset = add(offset, additionalOffset); } return gep(elementTy, base, offset); } }; } // namespace // A load is converted into the actual address computation, getelementptr and // an LLVM IR load. class LoadOpConversion : public LoadStoreOpConversion { using Base::Base; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); auto elementTy = lowering.convertType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; ArrayRef indices = operands.drop_front(); auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); rewriter.replaceOp(op, {llvm_load(elementTy, ptr)}); return matchSuccess(); } }; // RangeOp creates a new range descriptor. 
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type descTy =
        convertLinalgType(cast<RangeOp>(op).getResult()->getType(), lowering);
    edsc::ScopedContext scope(rewriter, op->getLoc());

    // Operands are (min, max, step); descriptor field k receives operand k.
    Value *desc = undef(descTy);
    for (int field = 0; field < 3; ++field)
      desc = insertvalue(descTy, desc, operands[field],
                         positionAttr(rewriter, field));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// Lowers SliceOp to a new view descriptor that aliases the base view's data.
// The base offset is advanced by `lower_bound_d * stride_d` for every
// indexing d, where the lower bound is the range's min or the scalar index
// itself. Range indexings produce a result dimension with size `max - min` and
// stride `stride_d * step`; scalar indexings are projected away.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SliceOpOperandAdaptor adaptor(operands);
    auto sliceOp = cast<SliceOp>(op);
    Type resultDescTy = convertLinalgType(sliceOp.getViewType(), lowering);
    auto baseViewTy = sliceOp.getBaseViewType();
    Type i64Ty = lowering.convertType(rewriter.getIntegerType(64));
    auto at = [&rewriter](ArrayRef<int> path) {
      return positionAttr(rewriter, path);
    };
    Value *baseDesc = adaptor.view();
    int baseRank = baseViewTy.getRank();

    edsc::ScopedContext scope(rewriter, op->getLoc());

    // The slice shares the base view's data pointer.
    Value *desc = undef(resultDescTy);
    Value *dataPtr = extractvalue(getPtrToElementType(baseViewTy, lowering),
                                  baseDesc, at(0));
    desc = insertvalue(resultDescTy, desc, dataPtr, at(0));

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> baseStrides;
    baseStrides.reserve(baseRank);
    for (int dim = 0; dim < baseRank; ++dim) {
      Value *stride = extractvalue(i64Ty, baseDesc, at({3, dim}));
      baseStrides.push_back(stride);
    }

    // New base offset: old offset + SUM_d lower_bound_d * stride_d.
    Value *offset = extractvalue(i64Ty, baseDesc, at(1));
    for (int dim = 0; dim < baseRank; ++dim) {
      Value *indexing = adaptor.indexings()[dim];
      Value *lowerBound = indexing;
      if (sliceOp.indexing(dim)->getType().isa<RangeType>())
        lowerBound = extractvalue(i64Ty, indexing, at(0));
      offset = add(offset, mul(lowerBound, baseStrides[dim]));
    }
    desc = insertvalue(resultDescTy, desc, offset, at(1));

    // Sizes and strides of the surviving (range-indexed) dimensions, packed
    // contiguously into the result descriptor.
    int resultDim = 0;
    for (auto it : llvm::enumerate(sliceOp.indexings())) {
      if (!it.value()->getType().isa<RangeType>())
        continue;
      int srcDim = it.index();
      Value *range = adaptor.indexings()[srcDim];
      Value *lo = extractvalue(i64Ty, range, at(0));
      Value *hi = extractvalue(i64Ty, range, at(1));
      Value *step = extractvalue(i64Ty, range, at(2));
      Value *extent = sub(hi, lo);
      Value *scaledStride = mul(baseStrides[srcDim], step);
      desc = insertvalue(resultDescTy, desc, extent, at({2, resultDim}));
      desc = insertvalue(resultDescTy, desc, scaledStride, at({3, resultDim}));
      ++resultDim;
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// A store is converted into the actual address computation, getelementptr and
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion { using Base::Base; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); Value *data = operands[0]; Value *viewDescriptor = operands[1]; ArrayRef indices = operands.drop_front(2); Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); llvm_store(data, ptr); rewriter.replaceOp(op, llvm::None); return matchSuccess(); } }; class ViewOpConversion : public LLVMOpLowering { public: explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef values) { return positionAttr(rewriter, values); }; // First operand to `view` is the buffer descriptor. Value *bufferDescriptor = operands[0]; // Declare the descriptor of the view. edsc::ScopedContext context(rewriter, op->getLoc()); Value *desc = undef(viewDescriptorTy); // Copy the buffer pointer from the old descriptor to the new one. Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0)); desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0)); // Zero base offset. auto indexTy = rewriter.getIndexType(); Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0)); desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1)); // Compute and insert view sizes (max - min along the range). 
int numRanges = llvm::size(viewOp.ranges()); Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1)); for (int i = numRanges - 1; i >= 0; --i) { // Update stride. Value *rangeDescriptor = operands[1 + i]; Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); Value *stride = mul(runningStride, step); desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i})); // Update size. Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); Value *size = sub(max, min); desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i})); // Update stride for the next dimension. if (i > 0) runningStride = mul(runningStride, max); } rewriter.replaceOp(op, desc); return matchSuccess(); } }; // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { auto implFnName = (libFn.getName().str() + "_impl"); auto module = libFn.getParentOfType(); if (auto f = module.lookupSymbol(implFnName)) { return f; } SmallVector fnArgTypes; for (auto t : libFn.getType().getInputs()) { assert(t && t.isa() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast().getPointerTo()); } auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. auto implFnDefn = FuncOp::create(libFn.getLoc(), implFnName, implFnType); module.push_back(implFnDefn); return implFnDefn; } // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. 
template static FuncOp getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering, ConversionPatternRewriter &rewriter) { auto linalgOp = cast(op); auto fnName = linalgOp.getLibraryCallName(); if (fnName.empty()) { op->emitWarning("No library call defined for: ") << *op; return FuncOp(); } auto module = op->getParentOfType(); if (auto f = module.lookupSymbol(fnName)) { return f; } // Get the Function type consistent with LLVM Lowering. SmallVector inputTypes; for (auto operand : op->getOperands()) inputTypes.push_back(lowering.convertType(operand->getType())); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType); module.push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } static void getLLVMLibraryCallDefinition(FuncOp fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); // Generate the function body. OpBuilder builder(fn.addEntryBlock()); edsc::ScopedContext scope(builder, fn.getLoc()); SmallVector implFnArgs; // Create a constant 1. auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()), IntegerAttr::get(IndexType::get(fn.getContext()), 1)); for (auto arg : fn.getArguments()) { // Allocate a stack for storing the argument value. The stack is passed to // the implementation function. auto alloca = llvm_alloca(arg->getType().cast().getPointerTo(), one) .getValue(); implFnArgs.push_back(alloca); llvm_store(arg, alloca); } llvm_call(ArrayRef(), builder.getSymbolRefAttr(implFn), implFnArgs); llvm_return{ArrayRef()}; } namespace { // The conversion class from Linalg to LLVMIR. 
class LinalgTypeConverter : public LLVMTypeConverter { using LLVMTypeConverter::LLVMTypeConverter; public: Type convertType(Type t) override { if (auto result = LLVMTypeConverter::convertType(t)) return result; return convertLinalgType(t, *this); } void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); } ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations.getArrayRef(); } private: /// List of library functions declarations needed during dialect conversion llvm::SetVector libraryFnDeclarations; }; } // end anonymous namespace // LinalgOpConversion creates a new call to the // `LinalgOp::getLibraryCallName()` function. // The implementation of the function can be either in the same module or in an // externally linked library. template class LinalgOpConversion : public LLVMOpLowering { public: explicit LinalgOpConversion(MLIRContext *context, LinalgTypeConverter &lowering_) : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); if (!f) return matchFailure(); static_cast(lowering).addLibraryFnDeclaration(f); auto fAttr = rewriter.getSymbolRefAttr(f); auto named = rewriter.getNamedAttr("callee", fAttr); rewriter.replaceOpWithNewOp(op, operands, ArrayRef{named}); return matchSuccess(); } }; /// Populate the given list with patterns that convert from Linalg to LLVM. 
// Registers every Linalg-to-LLVM conversion pattern on `patterns`, each
// constructed with (ctx, converter).
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  // NOTE(review): the template argument list of `insert` (and of the
  // LinalgOpConversion instantiations) appears truncated in this copy of the
  // file; restore the exact pattern list from version control before relying
  // on it.
  patterns
      .insert, LinalgOpConversion, LinalgOpConversion, LoadOpConversion,
      RangeOpConversion, SliceOpConversion, StoreOpConversion,
      ViewOpConversion>(ctx, converter);
}

namespace {
// Module pass lowering Linalg operations (together with affine, loop and
// standard operations) to the LLVM IR dialect.
struct LowerLinalgToLLVMPass : public ModulePass {
  void runOnModule();
};
} // namespace

// Rewrites every SubViewOp in `f` into a SliceOp of the same view. For each
// range, the upper bound is clamped to the view's size along that dimension,
// i.e. max' = min(dim(view, d), max); min and step are kept unchanged.
//
// This is currently written as a standalone function because the lowering to
// affine will look different than lowering to LLVM and it is still unclear how
// everything will be eventually structured.
static void lowerLinalgSubViewOps(FuncOp &f) {
  f.walk([&](SubViewOp op) {
    OpBuilder b(op);
    ScopedContext scope(b, op.getLoc());
    auto *view = op.getView();
    SmallVector ranges;
    for (auto en : llvm::enumerate(op.getRanges())) {
      using edsc::op::operator<;
      using linalg::intrinsics::dim;
      // NOTE: despite its name, `rank` is the index of the dimension being
      // processed, not the rank of the view.
      unsigned rank = en.index();
      auto sliceRange = en.value();
      auto size = dim(view, rank);
      ValueHandle ub(sliceRange.max);
      // Clamp the upper bound: (size < ub) ? size : ub.
      auto max = edsc::intrinsics::select(size < ub, size, ub);
      ranges.push_back(range(sliceRange.min, max, sliceRange.step));
    }
    op.replaceAllUsesWith(slice(view, ranges));
    op.erase();
  });
}

// Pass driver. It first rewrites SubViewOps into SliceOps, then runs a partial
// conversion with the affine->std, loop->std, std->LLVM and Linalg->LLVM
// patterns. Finally it emits bodies for the library functions that were
// declared while converting Linalg ops.
void LowerLinalgToLLVMPass::runOnModule() {
  auto module = getModule();
  for (auto f : module.getOps())
    lowerLinalgSubViewOps(f);

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect();
  // Function legality is decided by the converter's signature check on the
  // function type.
  target.addDynamicallyLegalOp(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  if (failed(applyPartialConversion(module, target, patterns, &converter))) {
    // NOTE(review): execution continues after the failure is signalled, so
    // library function bodies are still emitted below; consider returning
    // early here.
    signalPassFailure();
  }

  // Emit the function body of any Library function that was declared.
  for (auto fn : converter.getLibraryFnDeclarations()) {
    getLLVMLibraryCallDefinition(fn, converter);
  }
}

// Factory for the pass. NOTE(review): the template arguments of the return
// type and of make_unique appear truncated in this copy of the file.
std::unique_ptr mlir::linalg::createLowerLinalgToLLVMPass() {
  return llvm::make_unique();
}

// Registers the pass under the `linalg-lower-to-llvm-dialect` command-line
// name.
static PassRegistration
    pass("linalg-lower-to-llvm-dialect",
         "Lower the operations from the linalg dialect into the LLVM dialect");