diff options
| author | Alexander Belyaev <pifon@google.com> | 2019-10-11 05:13:18 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-10-11 05:13:55 -0700 |
| commit | 00d2a37e32067a6b41d16d605dfeb8637cc4cfbb (patch) | |
| tree | 7d9480f9f6b8bea3779f7c5f109ddb972ef9a063 /mlir/lib/Conversion/StandardToLLVM | |
| parent | 304e44a6b0eab92114761a50d36bbe6cc371ec10 (diff) | |
| download | bcm5719-llvm-00d2a37e32067a6b41d16d605dfeb8637cc4cfbb.tar.gz bcm5719-llvm-00d2a37e32067a6b41d16d605dfeb8637cc4cfbb.zip | |
Add unary ops and ExpOp to Standard Dialect.
PiperOrigin-RevId: 274152154
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 104 |
1 files changed, 79 insertions, 25 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 76de499592e..206dde773e3 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis, return res; } +template <typename SourceOp, unsigned OpCount> struct OpCountValidator { + static_assert( + std::is_base_of< + typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>, + SourceOp>::value, + "wrong operand count"); +}; + +template <typename SourceOp> struct OpCountValidator<SourceOp, 1> { + static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value, + "expected a single operand"); +}; + +template <typename SourceOp, unsigned OpCount> void ValidateOpCount() { + OpCountValidator<SourceOp, OpCount>(); +} + // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect -// Ops for binary ops with one result. This supports higher-dimensional vector +// Ops for N-ary ops with one result. This supports higher-dimensional vector // types. -template <typename SourceOp, typename TargetOp> -struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { +template <typename SourceOp, typename TargetOp, unsigned OpCount> +struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; - using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>; + using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>; // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) const override { - static_assert( - std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value, - "expected binary op"); + ValidateOpCount<SourceOp, OpCount>(); static_assert( std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, "expected single result op"); static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, SourceOp>::value, - "expected single result op"); + "expected same operands and result type"); // Cannot convert ops if their operands are not of LLVM type. for (Value *operand : operands) { @@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { arraySizes.push_back(llvmTy.getArrayNumElements()); llvmTy = llvmTy.getArrayElementType(); } - assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type"); + assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type"); auto llvmVectorTy = llvmTy; // Iteratively extract a position coordinates with basis `arraySize` from a @@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors - Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmVectorTy, operands[0], position); - Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmVectorTy, operands[1], position); + SmallVector<Value *, OpCount> extractedOperands; + for (unsigned i = 0; i < OpCount; ++i) { + extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( + loc, llvmVectorTy, operands[i], position)); + } Value *newVal = rewriter.create<TargetOp>( - loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS}, - op->getAttrs()); + loc, llvmVectorTy, extractedOperands, op->getAttrs()); desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal, position); } @@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { } }; +template <typename SourceOp, typename TargetOp> +using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>; +template <typename SourceOp, typename TargetOp> +using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>; + // Specific lowerings. // FIXME: this should be tablegen'ed. +struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> { + using Super::Super; +}; struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> { using Super::Super; }; @@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed + // clang-format off patterns.insert< - AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, - BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, - CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, - DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, - DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering, - MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, - RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, - SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering, - SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, TruncateIOpLowering, XOrOpLowering, + AddFOpLowering, + AddIOpLowering, + AllocOpLowering, + AndOpLowering, + BranchOpLowering, + CallIndirectOpLowering, + CallOpLowering, + CmpFOpLowering, + CmpIOpLowering, + CondBranchOpLowering, + ConstLLVMOpLowering, + DeallocOpLowering, + DimOpLowering, + DivFOpLowering, + DivISOpLowering, + DivIUOpLowering, + ExpOpLowering, + FPExtLowering, + FPTruncLowering, + FuncOpConversion, + IndexCastOpLowering, + LoadOpLowering, + MemRefCastOpLowering, + MulFOpLowering, + MulIOpLowering, + OrOpLowering, + RemFOpLowering, + RemISOpLowering, + RemIUOpLowering, + ReturnOpLowering, + SIToFPLowering, + SelectOpLowering, + SignExtendIOpLowering, + SplatOpLowering, + StoreOpLowering, + SubFOpLowering, + SubIOpLowering, + TruncateIOpLowering, + XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); + // clang-format on } // Convert types using the stored LLVM IR module. |

