diff options
Diffstat (limited to 'mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 68 |
1 files changed, 33 insertions, 35 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 1b70df6f8bd..2a034fd15c5 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -111,23 +111,21 @@ public: BaseViewConversionHelper(Type type) : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} - BaseViewConversionHelper(ValuePtr v) : d(v) {} + BaseViewConversionHelper(Value v) : d(v) {} /// Wrappers around MemRefDescriptor that use EDSC builder and location. - ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } - void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); } - ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); } - void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); } - ValuePtr offset() { return d.offset(rewriter(), loc()); } - void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); } - ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); } - void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); } - ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); } - void setStride(unsigned i, ValuePtr v) { - d.setStride(rewriter(), loc(), i, v); - } - - operator ValuePtr() { return d; } + Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } + Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } + void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } + Value offset() { return d.offset(rewriter(), loc()); } + void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } + Value size(unsigned i) { return d.size(rewriter(), loc(), i); } + void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } + Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } + + operator Value() { return d; } private: OpBuilder &rewriter() { return ScopedContext::getBuilder(); } @@ -144,7 +142,7 @@ public: : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast<RangeOp>(op); auto rangeDescriptorTy = @@ -154,7 +152,7 @@ public: // Fill in an aggregate value of the descriptor. RangeOpOperandAdaptor adaptor(operands); - ValuePtr desc = llvm_undef(rangeDescriptorTy); + Value desc = llvm_undef(rangeDescriptorTy); desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); @@ -177,7 +175,7 @@ public: : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); SliceOpOperandAdaptor adaptor(operands); @@ -191,7 +189,7 @@ public: BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); // TODO(ntv): extract sizes and emit asserts. - SmallVector<ValuePtr, 4> strides(memRefType.getRank()); + SmallVector<Value, 4> strides(memRefType.getRank()); for (int i = 0, e = memRefType.getRank(); i < e; ++i) strides[i] = baseDesc.stride(i); @@ -200,10 +198,10 @@ public: }; // Compute base offset. - ValuePtr baseOffset = baseDesc.offset(); + Value baseOffset = baseDesc.offset(); for (int i = 0, e = memRefType.getRank(); i < e; ++i) { - ValuePtr indexing = adaptor.indexings()[i]; - ValuePtr min = indexing; + Value indexing = adaptor.indexings()[i]; + Value min = indexing; if (sliceOp.indexing(i)->getType().isa<RangeType>()) min = extractvalue(int64Ty, indexing, pos(0)); baseOffset = add(baseOffset, mul(min, strides[i])); @@ -220,29 +218,29 @@ public: if (sliceOp.getViewType().getRank() == 0) return rewriter.replaceOp(op, {desc}), matchSuccess(); - ValuePtr zero = + Value zero = constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); // Compute and insert view sizes (max - min along the range) and strides. // Skip the non-range operands as they will be projected away from the view. int numNewDims = 0; for (auto en : llvm::enumerate(sliceOp.indexings())) { - ValuePtr indexing = en.value(); + Value indexing = en.value(); if (indexing->getType().isa<RangeType>()) { int rank = en.index(); - ValuePtr rangeDescriptor = adaptor.indexings()[rank]; - ValuePtr min = extractvalue(int64Ty, rangeDescriptor, pos(0)); - ValuePtr max = extractvalue(int64Ty, rangeDescriptor, pos(1)); - ValuePtr step = extractvalue(int64Ty, rangeDescriptor, pos(2)); - ValuePtr baseSize = baseDesc.size(rank); + Value rangeDescriptor = adaptor.indexings()[rank]; + Value min = extractvalue(int64Ty, rangeDescriptor, pos(0)); + Value max = extractvalue(int64Ty, rangeDescriptor, pos(1)); + Value step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value baseSize = baseDesc.size(rank); // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); - ValuePtr size = sub(max, min); + Value size = sub(max, min); // Bound lower by zero. size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); - ValuePtr stride = mul(strides[rank], step); + Value stride = mul(strides[rank], step); desc.setSize(numNewDims, size); desc.setStride(numNewDims, stride); ++numNewDims; @@ -268,7 +266,7 @@ public: : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. edsc::ScopedContext context(rewriter, op->getLoc()); @@ -311,7 +309,7 @@ public: : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, + matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); return matchSuccess(); @@ -446,7 +444,7 @@ public: op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); auto indexedGenericOp = cast<IndexedGenericOp>(op); auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector<ValuePtr, 4> operands; + SmallVector<Value, 4> operands; operands.reserve(numLoops + op.getNumOperands()); for (unsigned i = 0; i < numLoops; ++i) { operands.push_back(zero); @@ -470,7 +468,7 @@ public: PatternMatchResult matchAndRewrite(CopyOp op, PatternRewriter &rewriter) const override { - ValuePtr in = op.input(), out = op.output(); + Value in = op.input(), out = op.output(); // If either inputPerm or outputPerm are non-identities, insert transposes. auto inputPerm = op.inputPermutation(); |