diff options
6 files changed, 449 insertions, 73 deletions
diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h index bf3002f092e..8e5a7ceb15f 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h @@ -15,15 +15,52 @@ // limitations under the License. // ============================================================================= -#ifndef LINALG_CONVERTTOLLVMDIALECT_H_ -#define LINALG_CONVERTTOLLVMDIALECT_H_ +#ifndef LINALG1_CONVERTTOLLVMDIALECT_H_ +#define LINALG1_CONVERTTOLLVMDIALECT_H_ + +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Allocator.h" + +#include <memory> namespace mlir { +class DialectConversion; +class DialectOpConversion; +class MLIRContext; class Module; +class Type; +namespace LLVM { +class LLVMType; +} // end namespace LLVM } // end namespace mlir namespace linalg { +/// Convert the given Linalg dialect type `t` into an LLVM IR dialect type. +/// Keep all other types unmodified. +mlir::Type convertLinalgType(mlir::Type t); + +/// Allocate the conversion patterns for RangeOp, ViewOp and SliceOp from the +/// Linalg dialect to the LLVM IR dialect. The converters are allocated in the +/// `allocator` using the provided `context`. The latter must have the LLVM IR +/// dialect registered. +/// This function can be used to apply multiple conversion patterns in the same +/// pass. It does not have to be called explicitly before the conversion. +llvm::DenseSet<mlir::DialectOpConversion *> +allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator, + mlir::MLIRContext *context); + +/// Create a DialectConversion from the Linalg dialect to the LLVM IR dialect. +/// The conversion is set up to convert types and function signatures using +/// `convertLinalgType` and obtains operation converters by calling `initer`. +std::unique_ptr<mlir::DialectConversion> makeLinalgToLLVMLowering( + std::function<llvm::DenseSet<mlir::DialectOpConversion *>( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + initer); + +/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations +/// to the LLVM IR dialect types and operations in the given `module`. This is +/// the main entry point to the conversion. void convertToLLVM(mlir::Module &module); } // end namespace linalg -#endif // LINALG_CONVERTTOLLVMDIALECT_H_ +#endif // LINALG1_CONVERTTOLLVMDIALECT_H_ diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h b/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h new file mode 100644 index 00000000000..577981b85ed --- /dev/null +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/LLVMIntrinsics.h @@ -0,0 +1,41 @@ +//===- LLVMIntrinsics.h - declarative builders for LLVM dialect -*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG1_LLVMINTRINSICS_H_ +#define LINALG1_LLVMINTRINSICS_H_ + +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/LLVMIR/LLVMDialect.h" + +// Expose some LLVM IR instructions to declarative builders. +namespace intrinsics { +using undef = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::UndefOp>; +using insertvalue = + mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::InsertValueOp>; +using extractvalue = + mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ExtractValueOp>; +using constant = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ConstantOp>; +using add = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::AddOp>; +using sub = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::SubOp>; +using mul = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::MulOp>; +using load = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::LoadOp>; +using store = mlir::edsc::intrinsics::OperationBuilder<mlir::LLVM::StoreOp>; +using gep = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::GEPOp>; +} // end namespace intrinsics + +#endif // LINALG1_LLVMINTRINSICS_H_ diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index d7cb9189fa4..39c6a596f8e 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -40,12 +40,8 @@ #include "linalg1/Common.h" #include "linalg1/ConvertToLLVMDialect.h" -#include "linalg1/RangeOp.h" -#include "linalg1/RangeType.h" -#include "linalg1/SliceOp.h" -#include "linalg1/Types.h" -#include "linalg1/ViewOp.h" -#include "linalg1/ViewType.h" +#include "linalg1/LLVMIntrinsics.h" +#include "linalg1/Ops.h" using namespace mlir; @@ -55,9 +51,9 @@ using namespace mlir; // bitwidth (analogous to intptr_t in C); // - an Integer type is converted into an LLVM integer type of the same width; // - an F32 type is converted into an LLVM float type -// - a Memref, Range, or View is converted into an LLVM structure type -// containing the respective dynamic values. -LLVM::LLVMType convertType(Type t) { +// - a Range or View is converted into an LLVM structure type containing the +// respective dynamic values. +Type linalg::convertLinalgType(Type t) { auto *context = t.getContext(); auto *dialect = static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm")); @@ -78,30 +74,6 @@ LLVM::LLVMType convertType(Type t) { return LLVM::LLVMType::get(context, floatTy); } - // Memref descriptor contains the pointer to the data buffer, followed by - // as many 64-bit integers as the memref has dynamic sizes. These integers - // store the actual value of the dynamic size. - // - // template <typename Elem, size_t NumDynamicRanks> - // struct { - // Elem *ptr; - // int64_t dynRank_0, dynRank_1, ... dynRank_#NumDynamicRanks - // }; - if (auto memrefTy = t.dyn_cast<MemRefType>()) { - auto *elementTy = - convertType(memrefTy.getElementType()).getUnderlyingType(); - if (memrefTy.hasStaticShape()) - return LLVM::LLVMType::get(context, elementTy->getPointerTo()); - - int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits(); - auto *sizeTy = llvm::IntegerType::get(dialect->getLLVMContext(), width); - SmallVector<llvm::Type *, 4> types(1 + memrefTy.getNumDynamicDims(), - sizeTy); - types[0] = elementTy->getPointerTo(); - return LLVM::LLVMType::get( - context, llvm::StructType::get(dialect->getLLVMContext(), types)); - } - // Range descriptor contains the range bounds and the step as 64-bit integers. // // struct { @@ -139,7 +111,8 @@ LLVM::LLVMType convertType(Type t) { // int64_t strides[Rank]; // }; if (auto viewTy = t.dyn_cast<linalg::ViewType>()) { - auto *elemTy = convertType(viewTy.getElementType()) + auto *elemTy = linalg::convertLinalgType(viewTy.getElementType()) + .cast<LLVM::LLVMType>() .getUnderlyingType() ->getPointerTo(); auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext()); @@ -148,7 +121,8 @@ LLVM::LLVMType convertType(Type t) { return LLVM::LLVMType::get(context, structTy); } - llvm_unreachable("unsupported type"); + // All other types are kept as is. + return t; } // Create an array attribute containing integer attributes with values provided @@ -162,17 +136,6 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder, return builder.getArrayAttr(attrs); } -// Expose some LLVM IR instructions to declarative builders. -namespace intrinsics { -using undef = edsc::intrinsics::ValueBuilder<LLVM::UndefOp>; -using insertvalue = edsc::intrinsics::ValueBuilder<LLVM::InsertValueOp>; -using extractvalue = edsc::intrinsics::ValueBuilder<LLVM::ExtractValueOp>; -using constant = edsc::intrinsics::ValueBuilder<LLVM::ConstantOp>; -using add = edsc::intrinsics::ValueBuilder<LLVM::AddOp>; -using sub = edsc::intrinsics::ValueBuilder<LLVM::SubOp>; -using mul = edsc::intrinsics::ValueBuilder<LLVM::MulOp>; -} // end namespace intrinsics - // RangeOp creates a new range descriptor. class RangeOpConversion : public DialectOpConversion { public: @@ -188,7 +151,8 @@ public: SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override { auto rangeOp = op->cast<linalg::RangeOp>(); - auto rangeDescriptorType = convertType(rangeOp.getResult()->getType()); + auto rangeDescriptorType = + linalg::convertLinalgType(rangeOp.getResult()->getType()); using namespace intrinsics; auto context = edsc::ScopedContext(rewriter, op->getLoc()); @@ -219,10 +183,10 @@ public: SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override { auto viewOp = op->cast<linalg::ViewOp>(); - auto viewDescriptorType = convertType(viewOp.getViewType()); + auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType()); auto memrefType = viewOp.getSupportingMemRef()->getType().cast<MemRefType>(); - auto int64Ty = convertType(rewriter.getIntegerType(64)); + auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64)); // Helper function to create an integer array attribute out of a list of // values. @@ -258,10 +222,11 @@ public: if (type.hasStaticShape()) return memref; - auto elementTy = LLVM::LLVMType::get(type.getContext(), - convertType(type.getElementType()) - .getUnderlyingType() - ->getPointerTo()); + auto elementTy = LLVM::LLVMType::get( + type.getContext(), linalg::convertLinalgType(type.getElementType()) + .cast<LLVM::LLVMType>() + .getUnderlyingType() + ->getPointerTo()); return intrinsics::extractvalue(elementTy, memref, pos(0)); }; @@ -350,12 +315,14 @@ public: SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override { auto sliceOp = op->cast<linalg::SliceOp>(); - auto newViewDescriptorType = convertType(sliceOp.getViewType()); - auto elementType = - rewriter.getType<LLVM::LLVMType>(convertType(sliceOp.getElementType()) - .getUnderlyingType() - ->getPointerTo()); - auto int64Ty = convertType(rewriter.getIntegerType(64)); + auto newViewDescriptorType = + linalg::convertLinalgType(sliceOp.getViewType()); + auto elementType = rewriter.getType<LLVM::LLVMType>( + linalg::convertLinalgType(sliceOp.getElementType()) + .cast<LLVM::LLVMType>() + .getUnderlyingType() + ->getPointerTo()); + auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64)); auto pos = [&rewriter](ArrayRef<int> values) { return makePositionAttr(rewriter, values); @@ -450,24 +417,33 @@ public: } }; +llvm::DenseSet<mlir::DialectOpConversion *> +linalg::allocateDescriptorConverters(llvm::BumpPtrAllocator *allocator, + mlir::MLIRContext *context) { + return ConversionListBuilder<DropConsumer, RangeOpConversion, + SliceOpConversion, + ViewOpConversion>::build(allocator, context); +} + +namespace { // The conversion class from Linalg to LLVMIR. class Lowering : public DialectConversion { public: - Lowering() {} + explicit Lowering(std::function<llvm::DenseSet<mlir::DialectOpConversion *>( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + conversions) + : setup(conversions) {} protected: // Initialize the list of converters. llvm::DenseSet<DialectOpConversion *> initConverters(MLIRContext *context) override { - converterSotrage.Reset(); - return ConversionListBuilder<DropConsumer, RangeOpConversion, - SliceOpConversion, - ViewOpConversion>::build(&converterSotrage, - context); + converterStorage.Reset(); + return setup(&converterStorage, context); } // This gets called for block and region arguments, and attributes. - Type convertType(Type t) override { return ::convertType(t); } + Type convertType(Type t) override { return linalg::convertLinalgType(t); } // This gets called for function signatures. Convert function arguments and // results to the LLVM types, but keep the outer function type as built-in @@ -483,12 +459,12 @@ protected: SmallVector<Type, 4> argTypes; argTypes.reserve(t.getNumInputs()); for (auto ty : t.getInputs()) - argTypes.push_back(convertType(ty)); + argTypes.push_back(linalg::convertLinalgType(ty)); SmallVector<Type, 1> resultTypes; resultTypes.reserve(t.getNumResults()); for (auto ty : t.getResults()) - resultTypes.push_back(convertType(ty)); + resultTypes.push_back(linalg::convertLinalgType(ty)); assert(t.getNumResults() <= 1 && "NYI: multi-result functions"); return FunctionType::get(argTypes, resultTypes, t.getContext()); @@ -496,8 +472,21 @@ protected: private: // Storage for individual converters. - llvm::BumpPtrAllocator converterSotrage; + llvm::BumpPtrAllocator converterStorage; + + // Conversion setup. + std::function<llvm::DenseSet<mlir::DialectOpConversion *>( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + setup; }; +} // end anonymous namespace + +std::unique_ptr<mlir::DialectConversion> linalg::makeLinalgToLLVMLowering( + std::function<llvm::DenseSet<mlir::DialectOpConversion *>( + llvm::BumpPtrAllocator *, mlir::MLIRContext *context)> + initer) { + return llvm::make_unique<Lowering>(initer); +} void linalg::convertToLLVM(mlir::Module &module) { // Remove affine constructs if any by using an existing pass. @@ -509,7 +498,7 @@ void linalg::convertToLLVM(mlir::Module &module) { // Convert Linalg ops to the LLVM IR dialect using the converter defined // above. - auto r = Lowering().convert(&module); + auto r = Lowering(allocateDescriptorConverters).convert(&module); (void)r; assert(succeeded(r) && "conversion failed"); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp new file mode 100644 index 00000000000..bdb9d69f0c0 --- /dev/null +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -0,0 +1,111 @@ +//===- Conversion.cpp - Linalg to LLVM conversion driver ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +// RUN: %p/conversion | FileCheck %s + +#include "TestHarness.h" + +#include "linalg3/ConvertToLLVMDialect.h" + +#include "linalg1/Common.h" +#include "linalg2/Intrinsics.h" +#include "linalg3/Ops.h" +#include "linalg3/Transforms.h" +#include "mlir/IR/OpImplementation.h" + +using llvm::StringRef; + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace linalg; +using namespace linalg::common; +using namespace linalg::intrinsics; + +Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { + MLIRContext *context = module.getContext(); + auto dynamic2DMemRefType = floatMemRefType<2>(context); + mlir::Function *f = linalg::common::makeFunction( + module, name, + {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); + + ScopedContext scope(f); + // clang-format off + ValueHandle + M = dim(f->getArgument(0), 0), + N = dim(f->getArgument(2), 1), + K = dim(f->getArgument(0), 1), + rM = range(constant_index(0), M, constant_index(1)), + rN = range(constant_index(0), N, constant_index(1)), + rK = range(constant_index(0), K, constant_index(1)), + vA = view(f->getArgument(0), {rM, rK}), + vB = view(f->getArgument(1), {rK, rN}), + vC = view(f->getArgument(2), {rM, rN}); + matmul(vA, vB, vC); + ret(); + // clang-format on + + return f; +} + +TEST_FUNC(foo) { + MLIRContext context; + Module module(&context); + mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + lowerToLoops(f); + + convertLinalg3ToLLVM(module); + + // clang-format off + // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*"> + // CHECK-NEXT: {{.*}} = llvm.load {{.*}} : !llvm<"float*"> + // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*"> + // CHECK-NEXT: {{.*}} = llvm.load {{.*}} : !llvm<"float*"> + // CHECK: %159 = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm<"i64"> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm<"i64">) -> !llvm<"float*"> + // CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*"> + // clang-format on + module.print(llvm::outs()); +} + +int main() { + RUN_TESTS(); + return 0; +} diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h new file mode 100644 index 00000000000..8f122e05d2f --- /dev/null +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h @@ -0,0 +1,29 @@ +//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_CONVERTTOLLVMDIALECT_H_ +#define LINALG3_CONVERTTOLLVMDIALECT_H_ + +namespace mlir { +class Module; +} // end namespace mlir + +namespace linalg { +void convertLinalg3ToLLVM(mlir::Module &module); +} // end namespace linalg + +#endif // LINALG3_CONVERTTOLLVMDIALECT_H_ diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp new file mode 100644 index 00000000000..2417ecf75f3 --- /dev/null +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -0,0 +1,169 @@ +//===- ConvertToLLVMDialect.cpp - conversion from Linalg to LLVM dialect --===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/LLVMIR/Transforms.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "linalg1/ConvertToLLVMDialect.h" +#include "linalg1/LLVMIntrinsics.h" + +#include "linalg3/ConvertToLLVMDialect.h" +#include "linalg3/Ops.h" + +using namespace mlir; + +// Create an array attribute containing integer attributes with values provided +// in `position`. +static ArrayAttr makePositionAttr(FuncBuilder &builder, + ArrayRef<int> position) { + SmallVector<Attribute, 4> attrs; + attrs.reserve(position.size()); + for (auto p : position) + attrs.push_back(builder.getIntegerAttr(builder.getIntegerType(64), p)); + return builder.getArrayAttr(attrs); +} + +namespace { +// Common functionality for Linalg LoadOp and StoreOp conversion to the +// LLVM IR Dialect. +template <typename Op> +class LoadStoreOpConversion : public DialectOpConversion { +public: + explicit LoadStoreOpConversion(MLIRContext *context) + : DialectOpConversion(Op::getOperationName(), 1, context) {} + using Base = LoadStoreOpConversion<Op>; + + // Match the Op specified as template argument. + PatternMatchResult match(Operation *op) const override { + if (op->isa<Op>()) + return matchSuccess(); + return matchFailure(); + } + + // Compute the pointer to an element of the buffer underlying the view given + // current view indices. Use the base offset and strides stored in the view + // descriptor to emit IR iteratively computing the actual offset, followed by + // a getelementptr. + Value *obtainDataPtr(Operation *op, Value *viewDescriptor, + ArrayRef<Value *> indices, FuncBuilder &rewriter) const { + auto loadOp = op->cast<Op>(); + auto elementType = + loadOp.getViewType().template cast<linalg::ViewType>().getElementType(); + auto *llvmPtrType = linalg::convertLinalgType(elementType) + .template cast<LLVM::LLVMType>() + .getUnderlyingType() + ->getPointerTo(); + elementType = rewriter.getType<LLVM::LLVMType>(llvmPtrType); + auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64)); + + auto pos = [&rewriter](ArrayRef<int> values) { + return makePositionAttr(rewriter, values); + }; + + using namespace intrinsics; + + // Linearize subscripts as: + // base_offset + SUM_i index_i * stride_i. + Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1)); + for (int i = 0, e = loadOp.getRank(); i < e; ++i) { + Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i})); + Value *additionalOffset = mul(indices[i], stride); + offset = add(offset, additionalOffset); + } + Value *base = extractvalue(elementType, viewDescriptor, pos(0)); + return gep(elementType, base, ArrayRef<Value *>{offset}); + } +}; + +// A load is converted into the actual address computation, getelementptr and +// an LLVM IR load. +class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> { + using Base::Base; + SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, + FuncBuilder &rewriter) const override { + auto edscContext = edsc::ScopedContext(rewriter, op->getLoc()); + auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin()); + Value *viewDescriptor = operands[0]; + ArrayRef<Value *> indices = operands.drop_front(); + Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); + Value *element = intrinsics::load(elementType, ptr); + return {element}; + } +}; + +// A store is converted into the actual address computation, getelementptr and +// an LLVM IR store. +class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> { + using Base::Base; + SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, + FuncBuilder &rewriter) const override { + auto edscContext = edsc::ScopedContext(rewriter, op->getLoc()); + Value *viewDescriptor = operands[1]; + Value *data = operands[0]; + ArrayRef<Value *> indices = operands.drop_front(2); + Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); + intrinsics::store(data, ptr); + return {}; + } +}; + +} // end anonymous namespace + +// Helper function that allocates the descriptor converters and adds load/store +// coverters to the list. +static llvm::DenseSet<mlir::DialectOpConversion *> +allocateConversions(llvm::BumpPtrAllocator *allocator, + mlir::MLIRContext *context) { + auto conversions = linalg::allocateDescriptorConverters(allocator, context); + auto additional = + ConversionListBuilder<LoadOpConversion, StoreOpConversion>::build( + allocator, context); + conversions.insert(additional.begin(), additional.end()); + return conversions; +} + +void linalg::convertLinalg3ToLLVM(Module &module) { + // Remove affine constructs if any by using an existing pass. + PassManager pm; + pm.addPass(createLowerAffinePass()); + auto rr = pm.run(&module); + (void)rr; + assert(succeeded(rr) && "affine loop lowering failed"); + + auto lowering = makeLinalgToLLVMLowering(allocateConversions); + auto r = lowering->convert(&module); + (void)r; + assert(succeeded(r) && "conversion failed"); + + // Convert the remaining standard MLIR operations to the LLVM IR dialect using + // the default converter. + auto converter = createStdToLLVMConverter(); + r = converter->convert(&module); + (void)r; + assert(succeeded(r) && "second conversion failed"); +} |

