summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToLLVM
diff options
context:
space:
mode:
authorAlexander Belyaev <pifon@google.com>2019-10-11 05:13:18 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-11 05:13:55 -0700
commit00d2a37e32067a6b41d16d605dfeb8637cc4cfbb (patch)
tree7d9480f9f6b8bea3779f7c5f109ddb972ef9a063 /mlir/lib/Conversion/StandardToLLVM
parent304e44a6b0eab92114761a50d36bbe6cc371ec10 (diff)
downloadbcm5719-llvm-00d2a37e32067a6b41d16d605dfeb8637cc4cfbb.tar.gz
bcm5719-llvm-00d2a37e32067a6b41d16d605dfeb8637cc4cfbb.zip
Add unary ops and ExpOp to Standard Dialect.
PiperOrigin-RevId: 274152154
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp104
1 files changed, 79 insertions, 25 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 76de499592e..206dde773e3 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
return res;
}
+template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+ static_assert(
+ std::is_base_of<
+ typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
+ SourceOp>::value,
+ "wrong operand count");
+};
+
+template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+ static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
+ "expected a single operand");
+};
+
+template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+ OpCountValidator<SourceOp, OpCount>();
+}
+
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
-// Ops for binary ops with one result. This supports higher-dimensional vector
+// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
-template <typename SourceOp, typename TargetOp>
-struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+template <typename SourceOp, typename TargetOp, unsigned OpCount>
+struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
- using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+ using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- static_assert(
- std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
- "expected binary op");
+ ValidateOpCount<SourceOp, OpCount>();
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
- "expected single result op");
+ "expected same operands and result type");
// Cannot convert ops if their operands are not of LLVM type.
for (Value *operand : operands) {
@@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
arraySizes.push_back(llvmTy.getArrayNumElements());
llvmTy = llvmTy.getArrayElementType();
}
- assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+ assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
auto llvmVectorTy = llvmTy;
// Iteratively extract a position coordinates with basis `arraySize` from a
@@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
- Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[0], position);
- Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[1], position);
+ SmallVector<Value *, OpCount> extractedOperands;
+ for (unsigned i = 0; i < OpCount; ++i) {
+ extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[i], position));
+ }
Value *newVal = rewriter.create<TargetOp>(
- loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
- op->getAttrs());
+ loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
}
@@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
}
};
+template <typename SourceOp, typename TargetOp>
+using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
+template <typename SourceOp, typename TargetOp>
+using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+
// Specific lowerings.
// FIXME: this should be tablegen'ed.
+struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
+ using Super::Super;
+};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
@@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
+ // clang-format off
patterns.insert<
- AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
- BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
- CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
- DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
- DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
- MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
- RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
- SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering,
- SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering,
- SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
+ AddFOpLowering,
+ AddIOpLowering,
+ AllocOpLowering,
+ AndOpLowering,
+ BranchOpLowering,
+ CallIndirectOpLowering,
+ CallOpLowering,
+ CmpFOpLowering,
+ CmpIOpLowering,
+ CondBranchOpLowering,
+ ConstLLVMOpLowering,
+ DeallocOpLowering,
+ DimOpLowering,
+ DivFOpLowering,
+ DivISOpLowering,
+ DivIUOpLowering,
+ ExpOpLowering,
+ FPExtLowering,
+ FPTruncLowering,
+ FuncOpConversion,
+ IndexCastOpLowering,
+ LoadOpLowering,
+ MemRefCastOpLowering,
+ MulFOpLowering,
+ MulIOpLowering,
+ OrOpLowering,
+ RemFOpLowering,
+ RemISOpLowering,
+ RemIUOpLowering,
+ ReturnOpLowering,
+ SIToFPLowering,
+ SelectOpLowering,
+ SignExtendIOpLowering,
+ SplatOpLowering,
+ StoreOpLowering,
+ SubFOpLowering,
+ SubIOpLowering,
+ TruncateIOpLowering,
+ XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+ // clang-format on
}
// Convert types using the stored LLVM IR module.
OpenPOWER on IntegriCloud