diff options
| author | Alex Zinenko <zinenko@google.com> | 2019-12-16 05:16:35 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-16 05:17:14 -0800 |
| commit | 0684aa9a8bcb9823ccf3f55d4e180d8a4df13201 (patch) | |
| tree | b990b4286d45f25352ba30fba500e56596a3bdc6 /mlir/lib/Conversion/StandardToLLVM | |
| parent | 44fc7d72b3cb44147394e22f1f21ad36cca7bca8 (diff) | |
| download | bcm5719-llvm-0684aa9a8bcb9823ccf3f55d4e180d8a4df13201.tar.gz bcm5719-llvm-0684aa9a8bcb9823ccf3f55d4e180d8a4df13201.zip | |
Make memref promotion during std->LLVM lowering the default calling convention
During the conversion from the standard dialect to the LLVM dialect,
memref-typed arguments are promoted from registers to memory and passed into
functions by pointer. This had been introduced into the lowering to work around
the abesnce of calling convention modeling in MLIR to enable better
interoperability with LLVM IR generated from C, and has been exerciced for
several months. Make this promotion the default calling covention when
converting to the LLVM dialect. This adds the documentation, simplifies the
code and makes the conversion consistent across function operations and
function types used in other places, e.g. in high-order functions or
attributes, which would not follow the same rule previously.
PiperOrigin-RevId: 285751280
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 46 |
1 files changed, 17 insertions, 29 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 508868de54d..fa6512010c8 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -123,9 +123,15 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( FunctionType type, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Convert argument types one by one and check for errors. - for (auto &en : llvm::enumerate(type.getInputs())) - if (failed(convertSignatureArg(en.index(), en.value(), result))) + for (auto &en : llvm::enumerate(type.getInputs())) { + Type type = en.value(); + auto converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>(); + if (!converted) return {}; + if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) + converted = converted.getPointerTo(); + result.addInputs(en.index(), converted); + } SmallVector<LLVM::LLVMType, 8> argTypes; argTypes.reserve(llvm::size(result.getConvertedTypes())); @@ -522,41 +528,23 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { ConversionPatternRewriter &rewriter) const override { auto funcOp = cast<FuncOp>(op); FunctionType type = funcOp.getType(); - // Pack the result types into a struct. - Type packedResult; - if (type.getNumResults() != 0) - if (!(packedResult = lowering.packFunctionResults(type.getResults()))) - return matchFailure(); - LLVM::LLVMType resultType = packedResult - ? packedResult.cast<LLVM::LLVMType>() - : LLVM::LLVMType::getVoidTy(&dialect); - - SmallVector<LLVM::LLVMType, 4> argTypes; - argTypes.reserve(type.getNumInputs()); + + // Store the positions of memref-typed arguments so that we can emit loads + // from them to follow the calling convention. SmallVector<unsigned, 4> promotedArgIndices; promotedArgIndices.reserve(type.getNumInputs()); + for (auto en : llvm::enumerate(type.getInputs())) { + if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>()) + promotedArgIndices.push_back(en.index()); + } // Convert the original function arguments. Struct arguments are promoted to // pointer to struct arguments to allow calling external functions with // various ABIs (e.g. compiled from C/C++ on platform X). auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); - for (auto en : llvm::enumerate(type.getInputs())) { - auto t = en.value(); - auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>(); - if (!converted) - return matchFailure(); - if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) { - converted = converted.getPointerTo(); - promotedArgIndices.push_back(en.index()); - } - argTypes.push_back(converted); - } - for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx) - result.addInputs(idx, argTypes[idx]); - - auto llvmType = LLVM::LLVMType::getFunctionTy( - resultType, argTypes, varargsAttr && varargsAttr.getValue()); + auto llvmType = lowering.convertFunctionSignature( + funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); // Only retain those attributes that are not constructed by build. SmallVector<NamedAttribute, 4> attributes; |

