summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp43
1 files changed, 31 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index f5dac8aced1..2f97b62280a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -90,6 +90,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+ assert(copyOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
@@ -98,7 +100,7 @@ public:
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
- IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+ IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
@@ -112,11 +114,13 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+ assert(fillOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
- IndexedValueType O(fillOp.getOutput(0));
+ IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
@@ -128,10 +132,12 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
+ assert(dotOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
- C(dotOp.getOutput(0));
+ C(dotOp.getOutputBuffer(0));
// Emit scalar form.
C() = C() + A(r_i) * B(r_i);
}
@@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatvecOp matvecOp) {
+ assert(matvecOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
- C(matvecOp.getOutput(0));
+ C(matvecOp.getOutputBuffer(0));
// Emit scalar form.
C(i) = C(i) + A(i, r_j) * B(r_j);
}
@@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatmulOp matmulOp) {
+ assert(matmulOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
- C(matmulOp.getOutput(0));
+ C(matmulOp.getOutputBuffer(0));
// Emit scalar form.
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
}
@@ -169,6 +179,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
+ assert(convOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto maps = loopToOperandRangesMaps(convOp);
@@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
GenericOp genericOp) {
+ assert(genericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -237,7 +251,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
+ indexedValues[nInputs + i] =
+ std_load(genericOp.getOutputBuffer(i), indexing);
}
auto funcOp = genericOp.getFunction();
@@ -250,7 +265,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+ std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
}
return;
}
@@ -273,8 +288,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
- indexing);
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ genericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
IndexedGenericOp indexedGenericOp) {
+ assert(indexedGenericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -339,7 +356,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] =
- std_load(indexedGenericOp.getOutput(i), indexing);
+ std_load(indexedGenericOp.getOutputBuffer(i), indexing);
}
if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -351,7 +368,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+ std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
indexing);
}
return;
@@ -376,7 +393,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
- indexedGenericOp.getOutput(i), indexing);
+ indexedGenericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
auto linalgOp = cast<ConcreteOpTy>(op);
+ assert(linalgOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto invertedMap =
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
if (!invertedMap) {
OpenPOWER on IntegriCloud