From f52d71736b10e87b1aa1880b777dc9462a0085ce Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Sat, 11 Jan 2020 02:22:00 -0500 Subject: [mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors. Summary: This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic. The changes are self-contained within MLIR and can be captured and tested independently of XLA. The linalg.generic and indexed_generic are updated to: To allow progressive lowering from the value world (a.k.a tensor values) to the buffer world (a.k.a memref values), a linalg.generic op accepts mixing input and output ranked tensor values with input and output memrefs. ``` %1 = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor, memref -> (tensor) ``` In this case, the number of outputs (args_out) must match the sum of (1) the number of output buffer operands and (2) the number of tensor return values. The semantics is that the linalg.indexed_generic op produces (i.e. allocates and fills) its return values. Tensor values must be legalized by a buffer allocation pass before most transformations can be applied. Such legalization moves tensor return values into output buffer operands and updates the region argument accordingly. Transformations that create control-flow around linalg.indexed_generic operations are not expected to mix with tensors because SSA values do not escape naturally. Still, transformations and rewrites that take advantage of tensor SSA values are expected to be useful and will be added in the near future. Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72555 --- mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp') diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 9657daf9137..10c537ebd29 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -93,6 +93,8 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value consumedView, function_ref isaOpType) { LinalgOp consumer = dyn_cast(consumerOp); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!consumer) return false; @@ -171,7 +173,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) { return false; return true; }; - if (!llvm::all_of(genericOp.getInputsAndOutputs(), + if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), isStaticMemRefWithIdentityLayout)) return failure(); return success(); @@ -188,6 +190,8 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); auto genericOp = cast(op); + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); using edsc::intrinsics::std_load; using edsc::intrinsics::std_store; @@ -195,7 +199,7 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, using vector_type_cast = edsc::intrinsics::ValueBuilder; auto vA = std_load(vector_type_cast(genericOp.getInput(0))); auto vB = std_load(vector_type_cast(genericOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0)); + auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0)); auto vC = std_load(vectorMemRefC); auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), genericOp.iterator_types()); @@ -262,7 +266,7 @@ LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { // Transformation applies to buffers only. if (!linOp || !linOp.hasBufferSemantics()) return failure(); - if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) { + if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) return failure(); @@ -279,8 +283,10 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); LinalgOp linOp = cast(op); + assert(linOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputs()) + for (auto it : linOp.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { -- cgit v1.2.3