//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;

template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}
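// The aliases above bind the EDSC `ValueBuilder`/`OperationBuilder` templates
// to short names so that descriptor manipulation in the patterns below reads
// like straight-line code: within an `edsc::ScopedContext`, each alias
// creates the corresponding op at the current insertion point. A minimal
// sketch of the idiom (values and types here are illustrative, not taken
// from this file):
//
//   edsc::ScopedContext context(rewriter, loc);
//   Value *d = llvm_undef(descTy);       // creates llvm.mlir.undef
//   d = insertvalue(d, v, pos);          // creates llvm.insertvalue
//   Value *x = add(a, mul(b, c));        // creates llvm.mul, then llvm.add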
// Convert the given type to the LLVM IR Dialect type. The following
// conversions are supported:
//   - an Index type is converted into an LLVM integer type with pointer
//     bitwidth (analogous to intptr_t in C);
//   - an Integer type is converted into an LLVM integer type of the same
//     width;
//   - an F32 type is converted into an LLVM float type;
//   - a Buffer or a Range is converted into an LLVM structure type
//     containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
  auto *context = t.getContext();
  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();

  // A buffer descriptor contains the pointer to a flat region of storage and
  // the size of the region.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   void *baseAlloc;
  //   Elem *ptr;
  //   int64_t size;
  // };
  if (auto bufferType = t.dyn_cast<BufferType>()) {
    auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto ptrTy = getPtrToElementType(bufferType, lowering);
    return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
  }

  // A range descriptor contains the range bounds and the step as 64-bit
  // integers.
  //
  // struct {
  //   int64_t min;
  //   int64_t max;
  //   int64_t step;
  // };
  if (t.isa<RangeType>())
    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);

  return Type();
}

static constexpr int kBasePtrPosInBuffer = 0;
static constexpr int kPtrPosInBuffer = 1;
static constexpr int kSizePosInBuffer = 2;
static constexpr int kPtrPosInView = 0;
static constexpr int kOffsetPosInView = 1;
static constexpr int kSizePosInView = 2;
static constexpr int kStridePosInView = 3;

namespace {
/// Factor out the common information for all view conversions:
///   1. common types (standard and LLVM dialects)
///   2. the `pos` helper method
///   3. view descriptor construction `desc`.
class BaseViewConversionHelper {
public:
  BaseViewConversionHelper(Location loc, MemRefType memRefType,
                           ConversionPatternRewriter &rewriter,
                           LLVMTypeConverter &lowering)
      : zeroDMemRef(memRefType.getRank() == 0),
        elementTy(getPtrToElementType(memRefType, lowering)),
        int64Ty(
            lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
        desc(nullptr), rewriter(rewriter) {
    assert(isStrided(memRefType) && "expected strided memref type");
    viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
    desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
  }

  ArrayAttr pos(ArrayRef<int64_t> values) const {
    return rewriter.getI64ArrayAttr(values);
  }

  bool zeroDMemRef;
  LLVMType elementTy, int64Ty, viewDescriptorTy;
  Value *desc;
  ConversionPatternRewriter &rewriter;
};
} // namespace

// RangeOp creates a new range descriptor.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    auto rangeDescriptorTy =
        convertLinalgType(rangeOp.getResult()->getType(), lowering);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    RangeOpOperandAdaptor adaptor(operands);
    Value *desc = llvm_undef(rangeDescriptorTy);
    desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
    desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
    desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
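// For example, a range operation such as (a sketch; the exact assembly
// syntax may differ slightly across MLIR versions):
//
//   %r = linalg.range %min:%max:%step : !linalg.range
//
// lowers to a chain of insertions into an undef descriptor of type
// { i64, i64, i64 }:
//
//   %0 = llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
//   %1 = llvm.insertvalue %min, %0[0] : !llvm<"{ i64, i64, i64 }">
//   %2 = llvm.insertvalue %max, %1[1] : !llvm<"{ i64, i64, i64 }">
//   %r = llvm.insertvalue %step, %2[2] : !llvm<"{ i64, i64, i64 }">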
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the new view descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the fully populated view descriptor.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SliceOpOperandAdaptor adaptor(operands);
    Value *baseDesc = adaptor.view();

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();

    BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
                                    rewriter, lowering);
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
    Value *desc = helper.desc;

    edsc::ScopedContext context(rewriter, op->getLoc());

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] =
          extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i}));

    // Compute base offset.
    Value *baseOffset =
        extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView));
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value *indexing = adaptor.indexings()[i];
      Value *min = indexing;
      if (sliceOp.indexing(i)->getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, helper.pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }

    // Insert base pointer.
    auto ptrPos = helper.pos(kPtrPosInView);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // Insert base offset.
    desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));

    // Corner case, no sizes or strides: early return the descriptor.
    if (helper.zeroDMemRef)
      return rewriter.replaceOp(op, desc), matchSuccess();

    Value *zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the
    // view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value *indexing = en.value();
      if (indexing->getType().isa<RangeType>()) {
        int rank = en.index();
        Value *rangeDescriptor = adaptor.indexings()[rank];
        Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
        Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
        Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
        Value *baseSize =
            extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank}));
        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value *size = sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        Value *stride = mul(strides[rank], step);
        desc =
            insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims}));
        desc = insertvalue(desc, stride,
                           helper.pos({kStridePosInView, numNewDims}));
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
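// To summarize the pattern above: writing `min_i`, `max_i` and `step_i` for
// the components of the i-th range indexing, and `offset`, `size_i`,
// `stride_i` for the fields of the parent descriptor, the new descriptor is
// computed as:
//
//   offset'   = offset + sum_i(min_i * stride_i)
//   size_i'   = max(0, min(max_i, size_i) - min_i)
//   stride_i' = stride_i * step_i
//
// Non-range (rank-reducing) indexings contribute only to `offset'`; the
// corresponding dimensions are projected away from the resulting view.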
/// Conversion pattern that transforms a linalg.transpose op into:
///   1. An "undef" value for the new view descriptor.
///   2. Updates to the descriptor to copy the data ptr and offset from the
///      original descriptor and to write the size and stride entries at
///      their permuted positions.
/// The linalg.transpose op is replaced by the fully populated view descriptor.
class TransposeOpConversion : public LLVMOpLowering {
public:
  explicit TransposeOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Unpack the converted view operand.
    TransposeOpOperandAdaptor adaptor(operands);
    Value *baseDesc = adaptor.view();

    auto transposeOp = cast<TransposeOp>(op);
    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(op, baseDesc), matchSuccess();

    BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
                                    rewriter, lowering);
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
    Value *desc = helper.desc;

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Copy the base pointer from the old descriptor to the new one.
    ArrayAttr ptrPos = helper.pos(kPtrPosInView);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // Copy the offset from the old descriptor to the new one.
    ArrayAttr offPos = helper.pos(kOffsetPosInView);
    desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      Value *size = extractvalue(int64Ty, baseDesc,
                                 helper.pos({kSizePosInView, sourcePos}));
      desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos}));
      Value *stride = extractvalue(int64Ty, baseDesc,
                                   helper.pos({kStridePosInView, sourcePos}));
      desc = insertvalue(desc, stride,
                         helper.pos({kStridePosInView, targetPos}));
    }

    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public LLVMOpLowering {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return matchSuccess();
  }
};

// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
                                             PatternRewriter &rewriter) {
  auto linalgOp = cast<LinalgOp>(op);
  auto fnName = linalgOp.getLibraryCallName();
  if (fnName.empty()) {
    op->emitWarning("No library call defined for: ") << *op;
    return {};
  }

  // fnName is a dynamic std::string; unique it via a SymbolRefAttr.
  SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
  auto module = op->getParentOfType<ModuleOp>();
  if (module.lookupSymbol(fnName)) {
    return fnNameAttr;
  }

  SmallVector<Type, 4> inputTypes(op->getOperandTypes());
  assert(op->getNumResults() == 0 &&
         "Library call for linalg operation can be generated only for ops "
         "that have void return types");
  auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());

  OpBuilder::InsertionGuard guard(rewriter);
  // Insert before module terminator.
  rewriter.setInsertionPoint(module.getBody(),
                             std::prev(module.getBody()->end()));
  rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
                          ArrayRef<NamedAttribute>{});
  return fnNameAttr;
}
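// As an illustration, lowering a matmul on views through its library call
// inserts a declaration like the following before the module terminator.
// The exact name is whatever `getLibraryCallName()` returns; the name and
// memref layout shown here are only an assumed example:
//
//   func @linalg_matmul_viewxxf32_viewxxf32_viewxxf32(
//       memref<?x?xf32, offset: ?, strides: [?, 1]>,
//       memref<?x?xf32, offset: ?, strides: [?, 1]>,
//       memref<?x?xf32, offset: ?, strides: [?, 1]>)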
namespace {
// The conversion class from Linalg to LLVMIR.
class LinalgTypeConverter : public LLVMTypeConverter {
  using LLVMTypeConverter::LLVMTypeConverter;

public:
  Type convertType(Type t) override {
    if (auto result = LLVMTypeConverter::convertType(t))
      return result;
    return convertLinalgType(t, *this);
  }
};
} // end anonymous namespace

// LinalgOpConversion<LinalgOp> creates a new call to the
// `LinalgOp::getLibraryCallName()` function.
// The implementation of the function can be either in the same module or in an
// externally linked library.
template <typename LinalgOp>
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
public:
  using OpRewritePattern<LinalgOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LinalgOp op,
                                     PatternRewriter &rewriter) const override {
    auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
    if (!libraryCallName)
      return this->matchFailure();

    SmallVector<Value *, 4> operands(op.getOperands().begin(),
                                     op.getOperands().end());
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
                                              ArrayRef<Type>{}, operands);
    return this->matchSuccess();
  }
};

/// Conversion pattern specialization for CopyOp. This kicks in when both input
/// and output permutations are left unspecified or are the identity.
template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      return matchFailure();
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      return matchFailure();

    auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
    if (!libraryCallName)
      return matchFailure();

    SmallVector<Value *, 4> operands(op.getOperands().begin(),
                                     op.getOperands().end());
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
                                              ArrayRef<Type>{}, operands);
    return matchSuccess();
  }
};

/// A non-conversion rewrite pattern that converts a CopyOp with permutations
/// into a sequence of TransposeOps and a permutation-free CopyOp. This
/// interplays with TransposeOpConversion and LinalgOpConversion<CopyOp> to
/// create a path to the LLVM dialect.
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    Value *in = op.input(), *out = op.output();

    // If either inputPerm or outputPerm are non-identities, insert transposes.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
                                                AffineMapAttr::get(*inputPerm));
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      out = rewriter.create<linalg::TransposeOp>(
          op.getLoc(), out, AffineMapAttr::get(*outputPerm));

    // If nothing was transposed, fail and let the conversion kick in.
    if (in == op.input() && out == op.output())
      return matchFailure();

    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
    return matchSuccess();
  }
};
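// For example, a copy with a transposed input access (a sketch with the
// attribute syntax abbreviated, not verbatim IR):
//
//   linalg.copy(%in, %out) {inputPermutation = (i, j) -> (j, i)}
//
// is rewritten into a linalg.transpose feeding a permutation-free copy:
//
//   %t = linalg.transpose %in ((i, j) -> (j, i))
//   linalg.copy(%t, %out)
//
// which LinalgOpConversion<CopyOp> can then lower to a library call.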
/// A non-conversion rewrite pattern that converts SubViewOp into RangeOps and
/// a SliceOp.
class SubViewOpConversion : public OpRewritePattern<mlir::linalg::SubViewOp> {
public:
  using OpRewritePattern<mlir::linalg::SubViewOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(mlir::linalg::SubViewOp op,
                                     PatternRewriter &rewriter) const override {
    auto *view = op.getView();
    SmallVector<Value *, 8> ranges;
    for (auto sliceRange : op.getRanges())
      ranges.push_back(rewriter.create<RangeOp>(
          op.getLoc(), sliceRange.min, sliceRange.max, sliceRange.step));
    rewriter.replaceOpWithNewOp<SliceOp>(op, view, ranges);
    return matchSuccess();
  }
};

/// Populate the given list with patterns that convert from Linalg to Standard.
static void
populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *ctx) {
  // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
  // attribute values such as kernel striding and dilation.
  patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
                  LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
                  LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>,
                  SubViewOpConversion>(ctx);
}

/// Populate the given list with patterns that convert from Linalg to LLVM.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
                  YieldOpConversion>(ctx, converter);
}

namespace {
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule() override;
};
} // namespace

void LowerLinalgToLLVMPass::runOnModule() {
  auto module = getModule();

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateLinalgToStandardConversionPatterns(patterns, &getContext());
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, patterns, &converter)))
    signalPassFailure();
}

std::unique_ptr<OpPassBase<ModuleOp>>
mlir::linalg::createLowerLinalgToLLVMPass() {
  return std::make_unique<LowerLinalgToLLVMPass>();
}

static PassRegistration<LowerLinalgToLLVMPass>
    pass("convert-linalg-to-llvm",
         "Lower the operations from the linalg dialect into the LLVM dialect");
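// Usage sketch: this file only defines and registers the pass; a driver along
// the following lines (an assumption, not part of this file) would run it on
// a module:
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::linalg::createLowerLinalgToLLVMPass());
//   if (failed(pm.run(module)))
//     /* handle failure */;
//
// or, equivalently, from the command line:
//
//   mlir-opt foo.mlir -convert-linalg-to-llvm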