Diffstat (limited to 'mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp')
-rw-r--r--    mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp    255
1 file changed, 133 insertions, 122 deletions
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index de183f8f76e..05bdf24a975 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -117,21 +117,22 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
   if (t.isa<RangeType>())
     return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
 
-  // View descriptor contains the pointer to the data buffer, followed by a
-  // 64-bit integer containing the distance between the beginning of the buffer
-  // and the first element to be accessed through the view, followed by two
-  // arrays, each containing as many 64-bit integers as the rank of the View.
-  // The first array represents the size, in number of original elements, of the
-  // view along the given dimension. When taking the view, the size is the
-  // difference between the upper and the lower bound of the range. The second
-  // array represents the "stride" (in tensor abstraction sense), i.e. the
-  // number of consecutive elements of the underlying buffer that separate two
-  // consecutive elements addressable through the view along the given
-  // dimension. When taking the view, the strides are constructed as products
-  // of the original sizes along the trailing dimensions, multiplied by the view
-  // step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
-  // i.e. the view of a complete memref, will have strides N and 1. A view with
-  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
+  // A linalg.view type converts to a *pointer to* a view descriptor. The view
+  // descriptor contains the pointer to the data buffer, followed by a 64-bit
+  // integer containing the distance between the beginning of the buffer and the
+  // first element to be accessed through the view, followed by two arrays, each
+  // containing as many 64-bit integers as the rank of the View. The first array
+  // represents the size, in number of original elements, of the view along the
+  // given dimension. When taking the view, the size is the difference between
+  // the upper and the lower bound of the range. The second array represents the
+  // "stride" (in tensor abstraction sense), i.e. the number of consecutive
+  // elements of the underlying buffer that separate two consecutive elements
+  // addressable through the view along the given dimension. When taking the
+  // view, the strides are constructed as products of the original sizes along
+  // the trailing dimensions, multiplied by the view step. For example, a view
+  // of a MxN memref with ranges {0:M:1}, {0:N:1}, i.e. the view of a complete
+  // memref, will have strides N and 1. A view with ranges {0:M:2}, {0:N:3}
+  // will have strides 2*N and 3.
   //
   // template <typename Elem, size_t Rank>
   // struct {
@@ -139,16 +140,24 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
   //   int64_t offset;
   //   int64_t sizes[Rank];
   //   int64_t strides[Rank];
-  // };
+  // } *;
   if (auto viewType = t.dyn_cast<ViewType>()) {
     auto ptrTy = getPtrToElementType(viewType, lowering);
     auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
-    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
+    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
+        .getPointerTo();
   }
   return Type();
 }
 
+static constexpr int kPtrPosInBuffer = 0;
+static constexpr int kSizePosInBuffer = 1;
+static constexpr int kPtrPosInView = 0;
+static constexpr int kOffsetPosInView = 1;
+static constexpr int kSizePosInView = 2;
+static constexpr int kStridePosInView = 3;
+
 // Create an array attribute containing integer attributes with values provided
 // in `position`.
 static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
@@ -192,10 +201,9 @@ public:
           llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
     else
       elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
-    auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
+    auto bufferType = allocOp.getBufferType();
     auto elementPtrType = getPtrToElementType(bufferType, lowering);
-    auto bufferDescriptorType =
-        convertLinalgType(allocOp.getResult()->getType(), lowering);
+    auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);
 
     // Emit IR for creating a new buffer descriptor with an underlying malloc.
     edsc::ScopedContext context(rewriter, op->getLoc());
@@ -212,11 +220,11 @@ public:
                         .getOperation()
                         ->getResult(0);
     allocated = bitcast(elementPtrType, allocated);
-    Value *desc = undef(bufferDescriptorType);
-    desc = insertvalue(bufferDescriptorType, desc, allocated,
-                       positionAttr(rewriter, 0));
-    desc = insertvalue(bufferDescriptorType, desc, size,
-                       positionAttr(rewriter, 1));
+    Value *desc = undef(bufferDescriptorTy);
+    desc = insertvalue(bufferDescriptorTy, desc, allocated,
+                       positionAttr(rewriter, kPtrPosInBuffer));
+    desc = insertvalue(bufferDescriptorTy, desc, size,
+                       positionAttr(rewriter, kSizePosInBuffer));
     rewriter.replaceOp(op, desc);
     return matchSuccess();
   }
@@ -246,13 +254,15 @@ public:
 
     // Get MLIR types for extracting element pointer.
     auto deallocOp = cast<BufferDeallocOp>(op);
-    auto elementPtrTy = getPtrToElementType(
-        deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
+    auto elementPtrTy =
+        getPtrToElementType(deallocOp.getBufferType(), lowering);
 
     // Emit MLIR for buffer_dealloc.
+    BufferDeallocOpOperandAdaptor adaptor(operands);
     edsc::ScopedContext context(rewriter, op->getLoc());
-    Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
-                                                    positionAttr(rewriter, 0)));
+    Value *casted =
+        bitcast(voidPtrTy, extractvalue(elementPtrTy, adaptor.buffer(),
+                                        positionAttr(rewriter, 0)));
     llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
@@ -270,8 +280,10 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
     edsc::ScopedContext context(rewriter, op->getLoc());
+    BufferSizeOpOperandAdaptor adaptor(operands);
     rewriter.replaceOp(
-        op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
+        op, {extractvalue(int64Ty, adaptor.buffer(),
+                          positionAttr(rewriter, kSizePosInBuffer))});
     return matchSuccess();
   }
 };
@@ -288,11 +300,11 @@ public:
     auto dimOp = cast<linalg::DimOp>(op);
     auto indexTy = lowering.convertType(rewriter.getIndexType());
     edsc::ScopedContext context(rewriter, op->getLoc());
-    rewriter.replaceOp(
-        op,
-        {extractvalue(
-            indexTy, operands[0],
-            positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
+    auto pos = positionAttr(
+        rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
+    linalg::DimOpOperandAdaptor adaptor(operands);
+    Value *viewDescriptor = llvm_load(adaptor.view());
+    rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
     return matchSuccess();
   }
 };
@@ -311,7 +323,7 @@ public:
   // current view indices. Use the base offset and strides stored in the view
   // descriptor to emit IR iteratively computing the actual offset, followed by
  // a getelementptr. This must be called under an edsc::ScopedContext.
-  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
+  Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
                        ArrayRef<Value *> indices,
                        ConversionPatternRewriter &rewriter) const {
     auto loadOp = cast<Op>(op);
@@ -323,10 +335,13 @@ public:
 
     // Linearize subscripts as:
     //   base_offset + SUM_i index_i * stride_i.
-    Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
-    Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
+    Value *viewDescriptor = llvm_load(viewDescriptorPtr);
+    Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
+    Value *offset =
+        extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
     for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
-      Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
+      Value *stride =
+          extractvalue(int64Ty, viewDescriptor, pos({kStridePosInView, i}));
       Value *additionalOffset = mul(indices[i], stride);
       offset = add(offset, additionalOffset);
     }
@@ -344,9 +359,8 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
     auto elementTy = lowering.convertType(*op->result_type_begin());
-    Value *viewDescriptor = operands[0];
-    ArrayRef<Value *> indices = operands.drop_front();
-    auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+    linalg::LoadOpOperandAdaptor adaptor(operands);
+    auto ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
     rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
     return matchSuccess();
   }
@@ -368,18 +382,23 @@ public:
     edsc::ScopedContext context(rewriter, op->getLoc());
 
     // Fill in an aggregate value of the descriptor.
+    RangeOpOperandAdaptor adaptor(operands);
     Value *desc = undef(rangeDescriptorTy);
-    desc = insertvalue(rangeDescriptorTy, desc, operands[0],
-                       positionAttr(rewriter, 0));
-    desc = insertvalue(rangeDescriptorTy, desc, operands[1],
-                       positionAttr(rewriter, 1));
-    desc = insertvalue(rangeDescriptorTy, desc, operands[2],
-                       positionAttr(rewriter, 2));
+    desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
+    desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
+    desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
     rewriter.replaceOp(op, desc);
     return matchSuccess();
   }
 };
 
+/// Conversion pattern that transforms a linalg.slice op into:
+///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
+///   2. A load of the ViewDescriptor from the pointer allocated in 1.
+///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
+///      and stride corresponding to the
+///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
+/// The linalg.slice op is replaced by the alloca'ed pointer.
 class SliceOpConversion : public LLVMOpLowering {
 public:
   explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
@@ -390,7 +409,8 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     SliceOpOperandAdaptor adaptor(operands);
     auto sliceOp = cast<SliceOp>(op);
-    auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
+    auto viewDescriptorPtrTy =
+        convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -399,26 +419,36 @@ public:
     auto pos = [&rewriter](ArrayRef<int> values) {
       return positionAttr(rewriter, values);
     };
-    // Helper function to obtain the ptr of the given `view`.
-    auto getViewPtr = [pos, this](ViewType type, Value *view) -> Value * {
-      auto elementPtrTy = getPtrToElementType(type, lowering);
-      return extractvalue(elementPtrTy, view, pos(0));
-    };
 
     edsc::ScopedContext context(rewriter, op->getLoc());
-    // Declare the view descriptor and insert data ptr.
-    Value *desc = undef(viewDescriptorTy);
-    desc = insertvalue(viewDescriptorTy, desc,
-                       getViewPtr(viewType, adaptor.view()), pos(0));
+    // Declare the view descriptor and insert data ptr *at the entry block of
+    // the function*, which is the preferred location for LLVM's analyses.
+    auto ip = rewriter.getInsertionPoint();
+    auto ib = rewriter.getInsertionBlock();
+    rewriter.setInsertionPointToStart(
+        &op->getParentOfType<FuncOp>().getBlocks().front());
+    Value *one =
+        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+    // Alloca with proper alignment.
+    Value *allocatedDesc =
+        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
+    Value *desc = llvm_load(allocatedDesc);
+    rewriter.setInsertionPoint(ib, ip);
+
+    Value *baseDesc = llvm_load(adaptor.view());
+
+    auto ptrPos = pos(kPtrPosInView);
+    auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
+    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
 
     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value *, 4> strides(viewType.getRank());
     for (int i = 0, e = viewType.getRank(); i < e; ++i) {
-      strides[i] = extractvalue(int64Ty, adaptor.view(), pos({3, i}));
+      strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
     }
 
     // Compute and insert base offset.
-    Value *baseOffset = extractvalue(int64Ty, adaptor.view(), pos(1));
+    Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
     for (int i = 0, e = viewType.getRank(); i < e; ++i) {
       Value *indexing = adaptor.indexings()[i];
       Value *min =
@@ -428,7 +458,7 @@ public:
       Value *product = mul(min, strides[i]);
       baseOffset = add(baseOffset, product);
     }
-    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
 
     // Compute and insert view sizes (max - min along the range) and strides.
     // Skip the non-range operands as they will be projected away from the view.
@@ -443,14 +473,15 @@ public:
         Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
         Value *size = sub(max, min);
         Value *stride = mul(strides[i], step);
-        desc = insertvalue(viewDescriptorTy, desc, size, pos({2, numNewDims}));
-        desc =
-            insertvalue(viewDescriptorTy, desc, stride, pos({3, numNewDims}));
+        desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
+        desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
        ++numNewDims;
       }
     }
 
-    rewriter.replaceOp(op, desc);
+    // Store back in alloca'ed region.
+    llvm_store(desc, allocatedDesc);
+    rewriter.replaceOp(op, allocatedDesc);
     return matchSuccess();
   }
 };
@@ -463,16 +494,21 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
-    Value *data = operands[0];
-    Value *viewDescriptor = operands[1];
-    ArrayRef<Value *> indices = operands.drop_front(2);
-    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
-    llvm_store(data, ptr);
+    linalg::StoreOpOperandAdaptor adaptor(operands);
+    Value *ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
+    llvm_store(adaptor.value(), ptr);
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
   }
 };
 
+/// Conversion pattern that transforms a linalg.view op into:
+///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
+///   2. A load of the ViewDescriptor from the pointer allocated in 1.
+///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
+///      and stride.
+///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.view op is replaced by the alloca'ed pointer.
 class ViewOpConversion : public LLVMOpLowering {
 public:
   explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
@@ -482,7 +518,9 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto viewOp = cast<ViewOp>(op);
-    auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
+    ViewOpOperandAdaptor adaptor(operands);
+    auto viewDescriptorPtrTy =
+        convertLinalgType(viewOp.getViewType(), lowering);
     auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -490,21 +528,34 @@ public:
       return positionAttr(rewriter, values);
     };
 
-    // First operand to `view` is the buffer descriptor.
-    Value *bufferDescriptor = operands[0];
+    Value *bufferDescriptor = adaptor.buffer();
+    auto bufferTy = getPtrToElementType(
+        viewOp.buffer()->getType().cast<BufferType>(), lowering);
 
     // Declare the descriptor of the view.
     edsc::ScopedContext context(rewriter, op->getLoc());
-    Value *desc = undef(viewDescriptorTy);
+    auto ip = rewriter.getInsertionPoint();
+    auto ib = rewriter.getInsertionBlock();
+    rewriter.setInsertionPointToStart(
+        &op->getParentOfType<FuncOp>().getBlocks().front());
+    Value *one =
+        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+    // Alloca for proper alignment.
+    Value *allocatedDesc =
+        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
+    Value *desc = llvm_load(allocatedDesc);
+    rewriter.setInsertionPoint(ib, ip);
 
     // Copy the buffer pointer from the old descriptor to the new one.
-    Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
-    desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
+    Value *bufferAsViewElementType =
+        bitcast(elementTy,
+                extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
+    desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));
 
     // Zero base offset.
     auto indexTy = rewriter.getIndexType();
     Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
-    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
 
     // Compute and insert view sizes (max - min along the range).
     int numRanges = llvm::size(viewOp.ranges());
@@ -514,18 +565,20 @@ public:
       Value *rangeDescriptor = operands[1 + i];
       Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
       Value *stride = mul(runningStride, step);
-      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
+      desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
       // Update size.
       Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
       Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
       Value *size = sub(max, min);
-      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
+      desc = insertvalue(desc, size, pos({kSizePosInView, i}));
       // Update stride for the next dimension.
       if (i > 0)
        runningStride = mul(runningStride, max);
     }
 
-    rewriter.replaceOp(op, desc);
+    // Store back in alloca'ed region.
+    llvm_store(desc, allocatedDesc);
+    rewriter.replaceOp(op, allocatedDesc);
     return matchSuccess();
   }
 };
@@ -585,32 +638,6 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
   return libFn;
 }
 
-static void getLLVMLibraryCallDefinition(FuncOp fn,
-                                         LLVMTypeConverter &lowering) {
-  // Generate the implementation function definition.
-  auto implFn = getLLVMLibraryCallImplDefinition(fn);
-
-  // Generate the function body.
-  OpBuilder builder(fn.addEntryBlock());
-  edsc::ScopedContext scope(builder, fn.getLoc());
-  SmallVector<Value *, 4> implFnArgs;
-
-  // Create a constant 1.
-  auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
-                      IntegerAttr::get(IndexType::get(fn.getContext()), 1));
-  for (auto arg : fn.getArguments()) {
-    // Allocate a stack for storing the argument value. The stack is passed to
-    // the implementation function.
-    auto alloca =
-        llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
-            .getValue();
-    implFnArgs.push_back(alloca);
-    llvm_store(arg, alloca);
-  }
-  llvm_call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
-  llvm_return{ArrayRef<Value *>()};
-}
-
 namespace {
 // The conversion class from Linalg to LLVMIR.
 class LinalgTypeConverter : public LLVMTypeConverter {
@@ -622,16 +649,6 @@ public:
       return result;
     return convertLinalgType(t, *this);
   }
-
-  void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }
-
-  ArrayRef<FuncOp> getLibraryFnDeclarations() {
-    return libraryFnDeclarations.getArrayRef();
-  }
-
-private:
-  /// List of library functions declarations needed during dialect conversion
-  llvm::SetVector<FuncOp> libraryFnDeclarations;
 };
 } // end anonymous namespace
 
@@ -652,7 +669,6 @@ public:
     auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
     if (!f)
       return matchFailure();
-    static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
 
     auto fAttr = rewriter.getSymbolRefAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
@@ -727,11 +743,6 @@ void LowerLinalgToLLVMPass::runOnModule() {
   if (failed(applyPartialConversion(module, target, patterns, &converter))) {
     signalPassFailure();
  }
-
-  // Emit the function body of any Library function that was declared.
-  for (auto fn : converter.getLibraryFnDeclarations()) {
-    getLLVMLibraryCallDefinition(fn, converter);
-  }
 }
 
 std::unique_ptr<ModulePassBase> mlir::linalg::createLowerLinalgToLLVMPass() {
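
The central change in this patch is that a linalg.view now lowers to a pointer to a function-entry alloca'ed view descriptor rather than to the descriptor value itself, while loads and stores still linearize subscripts as base_offset + SUM_i index_i * stride_i over that descriptor. The following standalone C++ sketch is not part of the patch; the names ViewDescriptor and linearIndex and the concrete M, N values are purely illustrative. It mirrors the descriptor layout and the stride arithmetic described in the convertLinalgType comment above, under the assumption of row-major storage of the underlying buffer.

// Illustrative sketch only: models the {ptr, offset, sizes[Rank], strides[Rank]}
// descriptor and the subscript linearization performed by obtainDataPtr.
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>

template <typename Elem, size_t Rank>
struct ViewDescriptor {
  Elem *ptr;                         // underlying data buffer
  int64_t offset;                    // distance from ptr to the first viewed element
  std::array<int64_t, Rank> sizes;   // extent of the view along each dimension
  std::array<int64_t, Rank> strides; // buffer elements between consecutive view elements
};

// base_offset + SUM_i index_i * stride_i, as emitted by obtainDataPtr.
template <typename Elem, size_t Rank>
int64_t linearIndex(const ViewDescriptor<Elem, Rank> &view,
                    const std::array<int64_t, Rank> &indices) {
  int64_t linear = view.offset;
  for (size_t i = 0; i < Rank; ++i)
    linear += indices[i] * view.strides[i];
  return linear;
}

int main() {
  constexpr int64_t M = 4, N = 6;
  std::array<float, M * N> buffer{};

  // View of the complete MxN buffer, i.e. ranges {0:M:1}, {0:N:1}:
  // strides are N and 1, as stated in the convertLinalgType comment.
  ViewDescriptor<float, 2> full{buffer.data(), /*offset=*/0, {M, N}, {N, 1}};

  // View with ranges {0:M:2}, {0:N:3}: strides become 2*N and 3.
  ViewDescriptor<float, 2> strided{buffer.data(), /*offset=*/0,
                                   {M / 2, N / 3}, {2 * N, 3}};

  // Element (1, 1) of the strided view aliases element (2, 3) of the full view.
  assert(linearIndex(strided, {1, 1}) == linearIndex(full, {2, 3}));
  std::cout << "strided(1,1) -> linear index " << linearIndex(strided, {1, 1})
            << "\n";
  return 0;
}

Running this prints linear index 15, i.e. 2*N + 3, which matches the "strides 2*N and 3" example given in the comment; the patch's descriptor-by-pointer convention only changes how ops exchange this structure (through an alloca'ed pointer plus load/store), not the arithmetic itself.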