summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r-- mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp 10
-rw-r--r-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp 187
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp 31
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp 43
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp 14
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp 11
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp 9
7 files changed, 182 insertions, 123 deletions
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 144afa4c5e1..109a35b7611 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -139,7 +139,11 @@ LinalgDependenceGraph::getDependencesInto(
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
- for (auto srcView : src.getOutputs()) { // W
+ assert(src.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(dst.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ for (auto srcView : src.getOutputBuffers()) { // W
// RAW graph
for (auto dstView : dst.getInputs()) { // R
if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
@@ -149,7 +153,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
}
}
// WAW graph
- for (auto dstView : dst.getOutputs()) { // W
+ for (auto dstView : dst.getOutputBuffers()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
addDependenceElem(DependenceType::WAW,
LinalgOpView{src.getOperation(), srcView},
@@ -167,7 +171,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
}
}
// WAR graph
- for (auto dstView : dst.getOutputs()) { // W
+ for (auto dstView : dst.getOutputBuffers()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
addDependenceElem(DependenceType::WAR,
LinalgOpView{src.getOperation(), srcView},
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 384d24957f4..7850dd60be1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -112,19 +112,20 @@ template <typename GenericOpType>
static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
- auto nViews = op.getNumInputsAndOutputs();
- auto nInputViews = op.getNumInputs();
- if (block.getNumArguments() != nViews)
- return op.emitOpError(
- "expected number of block arguments to match number of views");
+ auto nOperands = op.getNumOperands();
+ if (block.getNumArguments() != nOperands)
+ return op.emitOpError("expected number of block arguments to match number "
+ "of operands");
- for (unsigned i = 0; i < nViews; ++i) {
+ // Note: the number and type of yield values are checked in the YieldOp.
+ auto nInputViews = op.getNumInputs();
+ for (unsigned i = 0; i < nOperands; ++i) {
auto viewType = op.getShapedType(i);
if (viewType.getElementType() != block.getArgument(i).getType())
return op.emitOpError("expected block argument ")
- << i << " of the same type as elemental type of "
+ << (i + 1) << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
- << "view: " << viewType;
+ << "operand: " << viewType;
}
return success();
}
@@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nViews = op.getNumInputsAndOutputs();
- if (block.getNumArguments() != nViews + nLoops)
+ auto nOperands = op.getNumOperands();
+ if (block.getNumArguments() != nOperands + nLoops)
return op.emitOpError(
- "expected number of block arguments to match number of views + "
+ "expected number of block arguments to match number of operands + "
"number of loops");
- for (unsigned i = 0; i < nLoops; ++i) {
+ // Note: the number and type of yield values are checked in the YieldOp.
+ for (unsigned i = 0; i < nLoops; ++i)
if (!block.getArgument(i).getType().isIndex())
return op.emitOpError("expected block argument ")
- << i << " to be of IndexType";
- }
+ << (i + 1) << " to be an index";
- for (unsigned i = 0; i < nViews; ++i) {
+ for (unsigned i = 0; i < nOperands; ++i) {
unsigned memrefArgIndex = i + nLoops;
auto viewType = op.getShapedType(i);
if (viewType.getElementType() !=
block.getArgument(memrefArgIndex).getType())
return op.emitOpError("expected block argument ")
- << memrefArgIndex << " of the same type as elemental type of "
+ << (memrefArgIndex + 1)
+ << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
- << "view: " << viewType;
+ << "operand: " << viewType;
}
return success();
}
@@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
template <typename GenericOpType>
static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
+template <typename GenericOpType>
+LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) {
+ auto res = verifyFuncArgs(op, funType);
+ if (failed(res))
+ return res;
+
+ auto nInputs = op.getNumInputs();
+ auto nOutputs = op.getNumOutputs();
+ // linalg.generic output element types are exactly the function results.
+ for (unsigned idx = 0; idx < nOutputs; ++idx) {
+ ShapedType shapedType = op.getShapedType(nInputs + idx);
+ if (funType.getResult(idx) != shapedType.getElementType())
+ return op.emitOpError("expected function result ")
+ << (idx + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of output " << (idx + 1);
+ }
+ return success();
+}
+
template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
- auto nViews = op.getNumInputsAndOutputs();
- auto nInputViews = op.getNumInputs();
- if (funType.getNumInputs() != nViews)
- return op.emitOpError("expected fun arguments to match number of views");
- if (funType.getNumResults() != op.getNumOutputs())
+ auto nOperands = op.getNumOperands();
+ if (funType.getNumInputs() != nOperands)
return op.emitOpError(
- "expected fun results to match number of output views");
-
- for (auto en : llvm::enumerate(op.indexing_maps())) {
- auto idx = en.index();
- auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
- : op.getOutputShapedType(idx - nInputViews);
- if (funType.getInput(idx) != view.getElementType())
- return op.emitOpError("expected fun argument ")
- << idx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
-
- if (idx >= nInputViews) {
- auto resultIdx = idx - nInputViews;
- if (funType.getResult(resultIdx) != view.getElementType())
- return op.emitOpError("expected fun result ")
- << resultIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
- }
+ "expected function arguments to match number of operands");
+ if (funType.getNumResults() != op.getNumOutputs())
+ return op.emitOpError("expected function results(")
+ << funType.getNumResults() << ") to match number of outputs("
+ << op.getNumOutputs() << ")";
+
+ // linalg.generic operands element types are exactly the first function
+ // arguments.
+ for (unsigned idx = 0; idx < nOperands; ++idx) {
+ ShapedType shapedType = op.getShapedType(idx);
+ if (funType.getInput(idx) != shapedType.getElementType())
+ return op.emitOpError("expected function argument ")
+ << (idx + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of operand " << (idx + 1);
}
+
return success();
}
template <>
LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
auto nLoops = op.getNumLoops();
- auto nInputViews = op.getNumInputs();
auto nOutputs = op.getNumOutputs();
- auto nViews = op.getNumInputsAndOutputs();
- if (funType.getNumInputs() != nViews + nLoops)
- return op.emitOpError(
- "expected fun arguments to match number of views + number of loops");
+ auto nOperands = op.getNumOperands();
+ if (funType.getNumInputs() != nOperands + nLoops)
+ return op.emitOpError("expected function arguments to match number of "
+ "loops + number of operands");
if (funType.getNumResults() != nOutputs)
return op.emitOpError(
- "expected fun results to match number of output views");
- for (unsigned i = 0; i < nLoops; ++i) {
+ "expected function results to match number of outputs");
+ for (unsigned i = 0; i < nLoops; ++i)
if (!funType.getInput(i).isIndex())
- return op.emitOpError("expected fun argument ")
- << i << " to be of IndexType";
- }
- for (auto en : llvm::enumerate(op.indexing_maps())) {
- auto idx = en.index();
- auto funIdx = nLoops + idx;
- auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
- : op.getOutputShapedType(idx - nInputViews);
- if (funType.getInput(funIdx) != view.getElementType())
- return op.emitOpError("expected fun argument ")
- << funIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
-
- if (idx >= nInputViews) {
- auto resultIdx = idx - nInputViews;
- if (funType.getResult(resultIdx) != view.getElementType())
- return op.emitOpError("expected fun result ")
- << resultIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
- }
+ return op.emitOpError("expected function argument ")
+ << (i + 1) << " to be an index";
+
+ // linalg.generic operands element types are exactly the first function
+ // arguments.
+ for (unsigned idx = 0; idx < nOperands; ++idx) {
+ ShapedType shapedType = op.getShapedType(idx);
+ if (funType.getInput(idx + nLoops) != shapedType.getElementType())
+ return op.emitOpError("expected function argument ")
+ << (idx + nLoops + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of input " << (idx + 1);
}
+
return success();
}
@@ -231,9 +237,11 @@ template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nViews = op.getNumInputsAndOutputs();
- if (nViews != llvm::size(op.views()))
- return op.emitOpError("expected exactly ") << nViews << " view operands";
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
auto &region = op.region();
auto funOp = op.getFunction();
@@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
} else {
if (!funOp || !funOp.getType())
return op.emitOpError(
- "expected fun attribute to refer to a defined symbol");
- if (failed(verifyFuncArgs(op, funType)))
+ "expected function attribute to refer to a defined symbol");
+ if (failed(verifyFuncArgsGeneric(op, funType)))
return failure();
}
@@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
return op.emitOpError("expected the concatenation of maps in indexing_map "
"to be invertible");
- auto outputTensorTypes = op.getOutputTensorTypes();
- if (outputTensorTypes.size() != op.getNumResults())
- return op.emitOpError("expected #output tensor operands (")
- << outputTensorTypes.size() << ") to match #results ("
- << op.getNumResults() << ")";
-
- unsigned index = 0;
- for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) {
- auto resTy = std::get<0>(it);
- auto outOpTy = std::get<1>(it);
- if (resTy != outOpTy)
- return op.emitOpError("result #")
- << index << " must be " << outOpTy << ", but got " << resTy;
- ++index;
- }
-
return success();
}
@@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
template <typename GenericOpType>
static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
// The operand number and types must match the view element types.
- auto nOutputViews = genericOp.getNumOutputs();
- if (op.getNumOperands() != nOutputViews)
- return op.emitOpError("expected ")
- << nOutputViews << " operand to match enclosing linalg.generic op";
+ auto nOutputs = genericOp.getNumOutputs();
+ if (op.getNumOperands() != nOutputs)
+ return op.emitOpError("expected number of yield values (")
+ << op.getNumOperands() << ") to match the number of operands of the enclosing "
+ << "linalg.generic op (" << nOutputs << ")";
- for (unsigned i = 0; i != nOutputViews; ++i) {
+ for (unsigned i = 0; i != nOutputs; ++i) {
auto elementType = genericOp.getOutputShapedType(i).getElementType();
if (op.getOperand(i).getType() != elementType)
- return op.emitOpError("type of return operand ")
- << i << " (" << op.getOperand(i).getType()
- << ") doesn't match view element type (" << elementType << ")";
+ return op.emitOpError("type of yield operand ")
+ << (i + 1) << " (" << op.getOperand(i).getType()
+ << ") doesn't match "
+ << "the element type of the enclosing linalg.generic op ("
+ << elementType << ")";
}
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 043d9c0e7cd..6ad73ee759e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -67,12 +67,13 @@ static llvm::cl::list<unsigned> clTileSizes(
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
ArrayRef<SubViewOp::Range> loopRanges) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = loopToOperandRangesMaps(op);
SmallVector<Value, 8> clonedViews;
clonedViews.reserve(op.getNumInputsAndOutputs());
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
@@ -118,10 +119,11 @@ struct ViewDimension {
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = loopToOperandRangesMaps(op);
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
@@ -144,6 +146,10 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, unsigned producerIdx,
OperationFolder *folder) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto subView = dyn_cast_or_null<SubViewOp>(
consumer.getInput(consumerIdx).getDefiningOp());
auto slice =
@@ -197,6 +203,10 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
LinalgOp consumer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
if (producer.getNumOutputs() != 1) {
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
return false;
@@ -217,6 +227,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
LinalgOp consumer,
Value consumedView,
LinalgOp producer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
// Make some simple structural checks that alleviate the need for more
// complex analyses.
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
@@ -236,6 +250,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
LinalgOp consumer, Value consumedView,
LinalgOp producer) {
+ assert(producer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
return false;
// Check for any fusion-preventing dependence to any view read/written that
@@ -252,6 +270,8 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
const LinalgDependenceGraph &graph, OperationFolder *folder) {
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
<< *consumer.getOperation());
for (auto dependence : graph.getDependencesInto(
@@ -268,7 +288,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
// Consumer consumes this view, `isStructurallyFusableProducer` also checks
// whether it is a strict subview of the producer view.
auto producedView = dependence.dependentOpView.view;
- auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+ auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
<< " view: " << producedView
@@ -309,7 +329,10 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
// Save original Linalg ops, we only want to make a pass over those.
SmallVector<Operation *, 8> linalgOps;
- f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
+ f.walk([&](LinalgOp op) {
+ if (op.hasBufferSemantics())
+ linalgOps.push_back(op);
+ });
Aliases aliases;
LinalgDependenceGraph G(aliases, linalgOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index f5dac8aced1..2f97b62280a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -90,6 +90,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+ assert(copyOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
@@ -98,7 +100,7 @@ public:
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
- IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+ IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
@@ -112,11 +114,13 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+ assert(fillOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
- IndexedValueType O(fillOp.getOutput(0));
+ IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
@@ -128,10 +132,12 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
+ assert(dotOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
- C(dotOp.getOutput(0));
+ C(dotOp.getOutputBuffer(0));
// Emit scalar form.
C() = C() + A(r_i) * B(r_i);
}
@@ -142,10 +148,12 @@ class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatvecOp matvecOp) {
+ assert(matvecOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
- C(matvecOp.getOutput(0));
+ C(matvecOp.getOutputBuffer(0));
// Emit scalar form.
C(i) = C(i) + A(i, r_j) * B(r_j);
}
@@ -156,10 +164,12 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatmulOp matmulOp) {
+ assert(matmulOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
- C(matmulOp.getOutput(0));
+ C(matmulOp.getOutputBuffer(0));
// Emit scalar form.
C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
}
@@ -169,6 +179,8 @@ template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
+ assert(convOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto maps = loopToOperandRangesMaps(convOp);
@@ -219,6 +231,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
GenericOp genericOp) {
+ assert(genericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -237,7 +251,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
+ indexedValues[nInputs + i] =
+ std_load(genericOp.getOutputBuffer(i), indexing);
}
auto funcOp = genericOp.getFunction();
@@ -250,7 +265,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+ std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
}
return;
}
@@ -273,8 +288,8 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
- indexing);
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ genericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -314,6 +329,8 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
IndexedGenericOp indexedGenericOp) {
+ assert(indexedGenericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
@@ -339,7 +356,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] =
- std_load(indexedGenericOp.getOutput(i), indexing);
+ std_load(indexedGenericOp.getOutputBuffer(i), indexing);
}
if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -351,7 +368,7 @@ public:
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+ std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
indexing);
}
return;
@@ -376,7 +393,7 @@ public:
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
- indexedGenericOp.getOutput(i), indexing);
+ indexedGenericOp.getOutputBuffer(i), indexing);
}
}
};
@@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
auto linalgOp = cast<ConcreteOpTy>(op);
+ assert(linalgOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
auto invertedMap =
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
if (!invertedMap) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 9657daf9137..10c537ebd29 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -93,6 +93,8 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType) {
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
if (!consumer)
return false;
+ assert(consumer.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
@@ -171,7 +173,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
return false;
return true;
};
- if (!llvm::all_of(genericOp.getInputsAndOutputs(),
+ if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(),
isStaticMemRefWithIdentityLayout))
return failure();
return success();
@@ -188,6 +190,8 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
"DRR failure case must be a precondition");
auto genericOp = cast<linalg::GenericOp>(op);
+ assert(genericOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
edsc::ScopedContext scope(rewriter, op->getLoc());
using edsc::intrinsics::std_load;
using edsc::intrinsics::std_store;
@@ -195,7 +199,7 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
- auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
+ auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0));
auto vC = std_load(vectorMemRefC);
auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
genericOp.iterator_types());
@@ -262,7 +266,7 @@ LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
// Transformation applies to buffers only.
if (!linOp || !linOp.hasBufferSemantics())
return failure();
- if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) {
+ if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) {
return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
}))
return failure();
@@ -279,8 +283,10 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
"DRR failure case must be a precondition");
LinalgOp linOp = cast<LinalgOp>(op);
+ assert(linOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
SetVector<Value> subViews;
- for (auto it : linOp.getInputsAndOutputs())
+ for (auto it : linOp.getInputsAndOutputBuffers())
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index eb605699890..3caa8c8d1f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -155,6 +155,8 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
SetVector<Value> subViews,
bool dynamicBuffers,
OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+
// 1. Promote the specified views and use them in the new op.
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
@@ -164,7 +166,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
SmallVector<std::pair<Value, Value>, 8> writebackViews;
writebackViews.reserve(subViews.size());
unsigned promotedIdx = 0;
- for (auto view : op.getInputsAndOutputs()) {
+ for (auto view : op.getInputsAndOutputBuffers()) {
if (subViews.count(view) != 0) {
opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
writebackViews.emplace_back(std::make_pair(
@@ -187,7 +189,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
// WARNING: MUST use the old op to determine whether the operand view is an
// output.
bool isOutput =
- op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
+ op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
if (isOutput)
copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
}
@@ -203,11 +205,14 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
SmallVector<LinalgOp, 8> toErase;
OperationFolder folder(f.getContext());
f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
+ if (!op.hasBufferSemantics())
+ return;
+
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
SetVector<Value> subViews;
OpBuilder b(op);
- for (auto it : op.getInputsAndOutputs())
+ for (auto it : op.getInputsAndOutputBuffers())
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index bcf6576abbb..ed9553d710d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -173,6 +173,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
static void transformIndexedGenericOpIndices(
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
if (!indexedGenericOp)
return;
@@ -232,6 +233,8 @@ static SmallVector<Value, 4>
makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
ArrayRef<Value> ivs, ArrayRef<Value> tileSizes,
ArrayRef<Value> viewSizes, OperationFolder *folder) {
+ assert(linalgOp.hasBufferSemantics() &&
+ "expected linalg op with buffer semantics");
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
[](Value v) { return !isZero(v); })) &&
@@ -254,7 +257,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
SmallVector<Value, 4> res;
res.reserve(op->getNumOperands());
- auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
+ auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
++viewIndex) {
Value view = *(viewIteratorBegin + viewIndex);
@@ -309,6 +312,7 @@ Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation,
OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
@@ -383,6 +387,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
ArrayRef<unsigned> permutation, OperationFolder *folder) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
if (tileSizes.empty())
return llvm::None;
@@ -419,6 +424,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
OpBuilder b(f);
OperationFolder folder(f.getContext());
f.walk([tileSizes, &b, &folder](LinalgOp op) {
+ if (!op.hasBufferSemantics())
+ return;
auto opLoopsPair =
tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
// If tiling occurred successfully, erase old op.
OpenPOWER on IntegriCloud