Diffstat (limited to 'mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp')
-rw-r--r--  mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 68
1 file changed, 33 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 1b70df6f8bd..2a034fd15c5 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -111,23 +111,21 @@ public:
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
- BaseViewConversionHelper(ValuePtr v) : d(v) {}
+ BaseViewConversionHelper(Value v) : d(v) {}
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
- ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
- void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); }
- ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
- void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); }
- ValuePtr offset() { return d.offset(rewriter(), loc()); }
- void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); }
- ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); }
- void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); }
- ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
- void setStride(unsigned i, ValuePtr v) {
- d.setStride(rewriter(), loc(), i, v);
- }
-
- operator ValuePtr() { return d; }
+ Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
+ void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
+ Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+ void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
+ Value offset() { return d.offset(rewriter(), loc()); }
+ void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
+ Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
+ void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+ Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
+ void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+
+ operator Value() { return d; }
private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
@@ -144,7 +142,7 @@ public:
: LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
@@ -154,7 +152,7 @@ public:
// Fill in an aggregate value of the descriptor.
RangeOpOperandAdaptor adaptor(operands);
- ValuePtr desc = llvm_undef(rangeDescriptorTy);
+ Value desc = llvm_undef(rangeDescriptorTy);
desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
@@ -177,7 +175,7 @@ public:
: LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
@@ -191,7 +189,7 @@ public:
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
// TODO(ntv): extract sizes and emit asserts.
- SmallVector<ValuePtr, 4> strides(memRefType.getRank());
+ SmallVector<Value, 4> strides(memRefType.getRank());
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
strides[i] = baseDesc.stride(i);
@@ -200,10 +198,10 @@ public:
};
// Compute base offset.
- ValuePtr baseOffset = baseDesc.offset();
+ Value baseOffset = baseDesc.offset();
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
- ValuePtr indexing = adaptor.indexings()[i];
- ValuePtr min = indexing;
+ Value indexing = adaptor.indexings()[i];
+ Value min = indexing;
if (sliceOp.indexing(i)->getType().isa<RangeType>())
min = extractvalue(int64Ty, indexing, pos(0));
baseOffset = add(baseOffset, mul(min, strides[i]));
@@ -220,29 +218,29 @@ public:
if (sliceOp.getViewType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), matchSuccess();
- ValuePtr zero =
+ Value zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
// Compute and insert view sizes (max - min along the range) and strides.
// Skip the non-range operands as they will be projected away from the view.
int numNewDims = 0;
for (auto en : llvm::enumerate(sliceOp.indexings())) {
- ValuePtr indexing = en.value();
+ Value indexing = en.value();
if (indexing->getType().isa<RangeType>()) {
int rank = en.index();
- ValuePtr rangeDescriptor = adaptor.indexings()[rank];
- ValuePtr min = extractvalue(int64Ty, rangeDescriptor, pos(0));
- ValuePtr max = extractvalue(int64Ty, rangeDescriptor, pos(1));
- ValuePtr step = extractvalue(int64Ty, rangeDescriptor, pos(2));
- ValuePtr baseSize = baseDesc.size(rank);
+ Value rangeDescriptor = adaptor.indexings()[rank];
+ Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+ Value max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value baseSize = baseDesc.size(rank);
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
- ValuePtr size = sub(max, min);
+ Value size = sub(max, min);
// Bound lower by zero.
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
- ValuePtr stride = mul(strides[rank], step);
+ Value stride = mul(strides[rank], step);
desc.setSize(numNewDims, size);
desc.setStride(numNewDims, stride);
++numNewDims;
@@ -268,7 +266,7 @@ public:
: LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -311,7 +309,7 @@ public:
: LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return matchSuccess();
@@ -446,7 +444,7 @@ public:
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
auto indexedGenericOp = cast<IndexedGenericOp>(op);
auto numLoops = indexedGenericOp.getNumLoops();
- SmallVector<ValuePtr, 4> operands;
+ SmallVector<Value, 4> operands;
operands.reserve(numLoops + op.getNumOperands());
for (unsigned i = 0; i < numLoops; ++i) {
operands.push_back(zero);
@@ -470,7 +468,7 @@ public:
PatternMatchResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
- ValuePtr in = op.input(), out = op.output();
+ Value in = op.input(), out = op.output();
// If either inputPerm or outputPerm are non-identities, insert transposes.
auto inputPerm = op.inputPermutation();
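
Note on the migration shown above: the hunks replace the ValuePtr typedef with mlir::Value, which is passed and stored by value, in pattern signatures, locals, and containers. The sketch below mirrors the post-change signatures visible in this diff (ArrayRef<Value> operands, PatternMatchResult, matchSuccess()); FooOp and FooOpLowering are hypothetical names, and the base-class constructor arguments follow the MLIR snapshot this diff targets, so details may differ in later revisions.

// Sketch only: a hypothetical lowering pattern after the ValuePtr -> Value
// switch, modeled on the signatures in the hunks above.
class FooOpLowering : public LLVMOpLowering {
public:
  explicit FooOpLowering(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(FooOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Value is a lightweight value type; copy operands into a local vector
    // rather than carrying ValuePtr around.
    SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
    rewriter.replaceOp(op, newOperands);
    return matchSuccess();
  }
};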