summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-01-11 02:22:00 -0500
committerNicolas Vasilache <ntv@google.com>2020-01-14 17:25:28 -0500
commitf52d71736b10e87b1aa1880b777dc9462a0085ce (patch)
tree3eaa824f59037e0b987abd0c39094ec999e04c3c /mlir/lib/Dialect/Linalg/Transforms
parent8d07f8d98c48ee0a9dca450aaf4e1cabc621ff68 (diff)
downloadbcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.tar.gz
bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.zip
[mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors.
Summary: This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic. The changes are self-contained within MLIR and can be captured and tested independently of XLA. The linalg.generic and indexed_generic ops are updated as follows: to allow progressive lowering from the value world (a.k.a. tensor values) to the buffer world (a.k.a. memref values), a linalg.generic op accepts mixing input and output ranked tensor values with input and output memrefs. ``` %1 = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor<?x?xf32>, memref<?x?xf32, stride_specification> -> (tensor<?x?xf32>) ``` In this case, the number of outputs (args_out) must match the sum of (1) the number of output buffer operands and (2) the number of tensor return values. The semantics are that the linalg.indexed_generic op produces (i.e. allocates and fills) its return values. Tensor values must be legalized by a buffer allocation pass before most transformations can be applied. Such legalization moves tensor return values into output buffer operands and updates the region arguments accordingly. Transformations that create control flow around linalg.indexed_generic operations are not expected to mix with tensors, because SSA values do not escape naturally. Still, transformations and rewrites that take advantage of tensor SSA values are expected to be useful and will be added in the near future. Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72555
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp31
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp43
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp9
5 files changed, 84 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 043d9c0e7cd..6ad73ee759e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -67,12 +67,13 @@ static llvm::cl::list<unsigned> clTileSizes(
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
ArrayRef<SubViewOp::Range> loopRanges) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = loopToOperandRangesMaps(op);
SmallVector<Value, 8> clonedViews;
clonedViews.reserve(op.getNumInputsAndOutputs());
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
@@ -118,10 +119,11 @@ struct ViewDimension {
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = loopToOperandRangesMaps(op);
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
@@ -144,6 +146,10 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, unsigned producerIdx,
OperationFolder *folder) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto subView = dyn_cast_or_null<SubViewOp>(
consumer.getInput(consumerIdx).getDefiningOp());
auto slice =
@@ -197,6 +203,10 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
LinalgOp consumer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
if (producer.getNumOutputs() != 1) {
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
return false;
@@ -217,6 +227,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
LinalgOp consumer,
Value consumedView,
LinalgOp producer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
// Make some simple structural checks that alleviate the need for more
// complex analyses.
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
@@ -236,6 +250,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
LinalgOp consumer, Value consumedView,
LinalgOp producer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
return false;
// Check for any fusion-preventing dependence to any view read/written that
@@ -252,6 +270,8 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
const LinalgDependenceGraph &graph, OperationFolder *folder) {
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
<< *consumer.getOperation());
for (auto dependence : graph.getDependencesInto(
@@ -268,7 +288,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
// Consumer consumes this view, `isStructurallyFusableProducer` also checks
// whether it is a strict subview of the producer view.
auto producedView = dependence.dependentOpView.view;
- auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+ auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
<< " view: " << producedView
@@ -309,7 +329,10 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
// Save original Linalg ops, we only want to make a pass over those.
SmallVector<Operation *, 8> linalgOps;
- f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
+ f.walk([&](LinalgOp op) {
+ if (op.hasBufferSemantics())
+ linalgOps.push_back(op);
+ });
Aliases aliases;
LinalgDependenceGraph G(aliases, linalgOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index f5dac8aced1..2f97b62280a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -90,6 +90,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+ assert(copyOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
@@ -98,7 +100,7 @@ public:
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
- IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+ IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
@@ -112,11 +114,13 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+ assert(fillOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
- IndexedValueType O(fillOp.getOutput(0));
+ IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
@@ -128,10 +132,12 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
+ assert(dotOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
- C(dotOp.getOutput(0));
+ C(dotOp.getOutputBuffer(0));
// Emit scalar form.
C() = C() + A(r_i) * B(r_i);
}
@@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatvecOp matvecOp) {
+ assert(matvecOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
- C(matvecOp.getOutput(0));
+ C(matvecOp.getOutputBuffer(0));
// Emit scalar form.
C(i) = C(i) + A(i, r_j) * B(r_j);
}
@@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatmulOp matmulOp) {
+ assert(matmulOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
- C(matmulOp.getOutput(0));
+ C(matmulOp.getOutputBuffer(0));
// Emit scalar form.
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
}
@@ -169,6 +179,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
+ assert(convOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto maps = loopToOperandRangesMaps(convOp);
@@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
GenericOp genericOp) {
+ assert(genericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -237,7 +251,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
+ indexedValues[nInputs + i] =
+ std_load(genericOp.getOutputBuffer(i), indexing);
}
auto funcOp = genericOp.getFunction();
@@ -250,7 +265,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+ std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
}
return;
}
@@ -273,8 +288,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
- indexing);
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ genericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
IndexedGenericOp indexedGenericOp) {
+ assert(indexedGenericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -339,7 +356,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] =
- std_load(indexedGenericOp.getOutput(i), indexing);
+ std_load(indexedGenericOp.getOutputBuffer(i), indexing);
}
if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -351,7 +368,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+ std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
indexing);
}
return;
@@ -376,7 +393,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
- indexedGenericOp.getOutput(i), indexing);
+ indexedGenericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
auto linalgOp = cast<ConcreteOpTy>(op);
+ assert(linalgOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto invertedMap =
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
if (!invertedMap) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 9657daf9137..10c537ebd29 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -93,6 +93,8 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType) {
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
if (!consumer)
return false;
@@ -171,7 +173,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
return false;
return true;
};
- if (!llvm::all_of(genericOp.getInputsAndOutputs(),
+ if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
isStaticMemRefWithIdentityLayout))
return failure();
return success();
@@ -188,6 +190,8 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
"DRR failure case must be a precondition");
auto genericOp = cast<linalg::GenericOp>(op);
+ assert(genericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
edsc::ScopedContext scope(rewriter, op->getLoc());
using edsc::intrinsics::std_load;
using edsc::intrinsics::std_store;
@@ -195,7 +199,7 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
- auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
+ auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
auto vC = std_load(vectorMemRefC);
auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
genericOp.iterator_types());
@@ -262,7 +266,7 @@ LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
// Transformation applies to buffers only.
if (!linOp || !linOp.hasBufferSemantics())
return failure();
- if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) {
+ if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
}))
return failure();
@@ -279,8 +283,10 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
"DRR failure case must be a precondition");
LinalgOp linOp = cast<LinalgOp>(op);
+ assert(linOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
SetVector<Value> subViews;
- for (auto it : linOp.getInputsAndOutputs())
+ for (auto it : linOp.getInputsAndOutputBuffers())
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index eb605699890..3caa8c8d1f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -155,6 +155,8 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
SetVector<Value> subViews,
bool dynamicBuffers,
OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+
// 1. Promote the specified views and use them in the new op.
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
@@ -164,7 +166,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
SmallVector<std::pair<Value, Value>, 8> writebackViews;
writebackViews.reserve(subViews.size());
unsigned promotedIdx = 0;
- for (auto view : op.getInputsAndOutputs()) {
+ for (auto view : op.getInputsAndOutputBuffers()) {
if (subViews.count(view) != 0) {
opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
writebackViews.emplace_back(std::make_pair(
@@ -187,7 +189,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
// WARNING: MUST use the old op to determine whether the operand view is an
// output.
bool isOutput =
- op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
+ op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
if (isOutput)
copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
}
@@ -203,11 +205,14 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
SmallVector<LinalgOp, 8> toErase;
OperationFolder folder(f.getContext());
f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
+ if (!op.hasBufferSemantics())
+ return;
+
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
SetVector<Value> subViews;
OpBuilder b(op);
- for (auto it : op.getInputsAndOutputs())
+ for (auto it : op.getInputsAndOutputBuffers())
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index bcf6576abbb..ed9553d710d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -173,6 +173,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
static void transformIndexedGenericOpIndices(
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
if (!indexedGenericOp)
return;
@@ -232,6 +233,8 @@ static SmallVector<Value, 4>
makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
ArrayRef<Value> ivs, ArrayRef<Value> tileSizes,
ArrayRef<Value> viewSizes, OperationFolder *folder) {
+ assert(linalgOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
[](Value v) { return !isZero(v); })) &&
@@ -254,7 +257,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
SmallVector<Value, 4> res;
res.reserve(op->getNumOperands());
- auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
+ auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
++viewIndex) {
Value view = *(viewIteratorBegin + viewIndex);
@@ -309,6 +312,7 @@ Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation,
OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
@@ -383,6 +387,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
if (tileSizes.empty())
return llvm::None;
@@ -419,6 +424,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
OpBuilder b(f);
OperationFolder folder(f.getContext());
f.walk([tileSizes, &b, &folder](LinalgOp op) {
+ if (!op.hasBufferSemantics())
+ return;
auto opLoopsPair =
tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
// If tiling occurred successfully, erase old op.
OpenPOWER on IntegriCloud