diff options
| author | Jose Ignacio Gomez <jigomez@ucm.es> | 2019-12-05 15:14:22 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-05 15:14:59 -0800 |
| commit | f60bbb6c3b407b25367ce5bc5637b6edaf8c9e16 (patch) | |
| tree | f105f9a3bb05cbb390ae01113b7e8a8cd1759120 | |
| parent | da53000fb4191a3c1cef31d0b2faf4757a5dcfec (diff) | |
| download | bcm5719-llvm-f60bbb6c3b407b25367ce5bc5637b6edaf8c9e16.tar.gz bcm5719-llvm-f60bbb6c3b407b25367ce5bc5637b6edaf8c9e16.zip | |
[Linalg] Add permutation information to tiling
This patch closes issue tensorflow/mlir#271.
It adds an optional permutation map to declarative tiling transformations.
The map is expressed as a list of integers.
Closes tensorflow/mlir#288
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/288 from tetuante:issue271 2df2938d6a1f01b3bc404ded08dea2dd1e10b588
PiperOrigin-RevId: 284064151
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td | 8 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h | 15 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 28 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/AffineMap.h | 9 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 47 | ||||
| -rw-r--r-- | mlir/lib/IR/AffineMap.cpp | 14 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/tile_permute_patterns.mlir | 70 | ||||
| -rw-r--r-- | mlir/test/lib/DeclarativeTransforms/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td | 57 | ||||
| -rw-r--r-- | mlir/test/lib/Transforms/CMakeLists.txt | 2 | ||||
| -rw-r--r-- | mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp | 64 |
12 files changed, 304 insertions, 24 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 8bc0eaf2097..f558fa5da48 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -57,9 +57,13 @@ class TileAndFuseLinalgOp< // In the future, tile sizes should be derived from op properties + machine // description but we do not need to wait on this to start having useful // patterns. -class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall< +// `permutation` is an optional parameter to specify the ordering of the +// tiled loops. If provided, it must be a list of integers with the same number +// of elements as `sizes`. +class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall< "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # - StrJoinInt<sizes>.result # "}, \"" # value # "\")))" # + StrJoinInt<sizes>.result # "}, \"" # value # "\", {" # + StrJoinInt<permutation>.result # "})))" # " return matchFailure();">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index 966b8f93135..89615e113c7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -58,11 +58,20 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { // success. //////////////////////////////////////////////////////////////////////////////// -// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to -// `linalgMarker`. +/// Tiles `op` by `sizes` permuting the looops according to `permutation` +/// and sets the attribute `kLinalgTransformMarker` to `linalgMarker`. +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes, - StringRef linalgMarker); + StringRef linalgMarker, + ArrayRef<unsigned> permutation); // Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets // the attribute `kLinalgTransformMarker` to `linalgMarker`. diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 91c7082b264..8dc78458c87 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -134,23 +134,43 @@ struct TiledLinalgOp { }; /// Performs standalone tiling of a single LinalgOp by `tileSizes`. -/// Returns a struct containing the tiled loops and the cloned op if successful, -/// llvm::None otherwise. +/// and permute the loop nest according to `permutation` +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. +/// Returns a struct containing the tiled loops in the specified order +/// and the cloned op if successful, llvm::None otherwise. /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes, + ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr); /// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. -/// Returns a struct containing the tiled loops and the cloned op if successful, -/// llvm::None otherwise. +/// and permute the loop nest according to `permutation` +/// The permutation is expressed as a list of integers that specify +/// the new ordering of the loop nest. The length of `permutation` +/// must be equal to the length of `tileSizes`. +/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with +/// `permutation = [1,2,0]`. All values in `permutation` must be +/// integers, in the range 0..`tileSizes.size()` without duplications +/// (i.e. `[1,1,2]` is an invalid permutation). An empty list +/// states for the identity permutation. +/// Returns a struct containing the tiled loops in the specified order +/// and the cloned op if successful, llvm::None otherwise. /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, + ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr); template <typename... Args> diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 9b30f15628a..e42173d5a2b 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -65,6 +65,15 @@ public: static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context); + /// Returns an AffineMap representing a permutation. + /// The permutation is expressed as a non-empty vector of integers. + /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with + /// `permutation = [1,2,0]`. All values in `permutation` must be + /// integers, in the range 0..`permutation.size()-1` without duplications + /// (i.e. `[1,1,2]` is an invalid permutation). + static AffineMap getPermutationMap(ArrayRef<unsigned> permutation, + MLIRContext *context); + MLIRContext *getContext() const; explicit operator bool() { return map != nullptr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 0e4aaa7ac83..1b4509ffc11 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -33,11 +33,11 @@ using namespace mlir::linalg; const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(PatternRewriter &rewriter, - Operation *op, - ArrayRef<int64_t> sizes, - StringRef linalgMarker) { - auto tileRes = tileLinalgOperation(rewriter, op, sizes); +LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes, + StringRef linalgMarker, ArrayRef<unsigned> permutation) { + assert(permutation.empty() || permutation.size() == sizes.size()); + auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation); if (!tileRes) return failure(); tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 09a1ba6b332..2c84eeecbba 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -215,10 +215,17 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, return res; } -llvm::Optional<TiledLinalgOp> -mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef<Value *> tileSizes, - OperationFolder *folder) { +void applyPermutationToLoopRanges(SmallVector<SubViewOp::Range, 4> &loopRanges, + ArrayRef<unsigned> permutation) { + SmallVector<SubViewOp::Range, 4> auxVec(loopRanges.size()); + for (unsigned i = 0; i < permutation.size(); ++i) + auxVec[i] = loopRanges[permutation[i]]; + loopRanges = auxVec; +} + +llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( + OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes, + ArrayRef<unsigned> permutation, OperationFolder *folder) { // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. @@ -226,6 +233,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, op.getNumWindowLoops() == tileSizes.size() && "expected matching number of tile sizes and loops"); + + // If permutation is empty, use the identity. Build the permutation map + // otherwise. + auto invPermutationMap = AffineMap::getMultiDimIdentityMap( + tileSizes.size(), ScopedContext::getContext()); + if (!permutation.empty()) + invPermutationMap = inversePermutation( + AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); ScopedContext scope(b, op.getLoc()); @@ -239,6 +255,8 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, auto loopRanges = makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, viewSizes, tileSizes, folder); + if (!permutation.empty()) + applyPermutationToLoopRanges(loopRanges, permutation); // 3. Create the tiled loops. LinalgOp res = op; @@ -248,6 +266,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end()); + + // If we have to apply a permutation to the tiled loop nest, we have to + // reorder the induction variables This permutation is the right one + // assuming that loopRanges have previously been permuted by + // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of + // that one: (d0,d1,d2)->(d2,d0,d1) + if (!permutation.empty()) + ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder); + auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder); auto operands = getAssumedNonViewOperands(op); @@ -264,10 +291,9 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, return TiledLinalgOp{res, loops}; } -llvm::Optional<TiledLinalgOp> -mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef<int64_t> tileSizes, - OperationFolder *folder) { +llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp( + OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes, + ArrayRef<unsigned> permutation, OperationFolder *folder) { if (tileSizes.empty()) return llvm::None; @@ -297,14 +323,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, tileSizeValues.push_back(constant_index(folder, 0)); } - return tileLinalgOp(b, op, tileSizeValues, folder); + return tileLinalgOp(b, op, tileSizeValues, permutation, folder); } static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) { OpBuilder b(f); OperationFolder folder(f.getContext()); f.walk([tileSizes, &b, &folder](LinalgOp op) { - auto opLoopsPair = tileLinalgOp(b, op, tileSizes, &folder); + auto opLoopsPair = + tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder); // If tiling occurred successfully, erase old op. if (opLoopsPair) op.erase(); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index e56d0e83f65..98357b1348b 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -106,6 +106,20 @@ AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { {getAffineConstantExpr(val, context)}); } +/// Returns an AffineMap representing a permutation. +AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, + MLIRContext *context) { + assert(!permutation.empty() && + "Cannot create permutation map from empty permutation vector"); + SmallVector<AffineExpr, 4> affExprs; + for (auto index : permutation) + affExprs.push_back(getAffineDimExpr(index, context)); + auto m = std::max_element(permutation.begin(), permutation.end()); + auto permutationMap = AffineMap::get(*m + 1, 0, affExprs); + assert(permutationMap.isPermutation() && "Invalid permutation vector"); + return permutationMap; +} + AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, MLIRContext *context) { SmallVector<AffineExpr, 4> dimExprs; diff --git a/mlir/test/Dialect/Linalg/tile_permute_patterns.mlir b/mlir/test/Dialect/Linalg/tile_permute_patterns.mlir new file mode 100644 index 00000000000..4844f20afa2 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile_permute_patterns.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s -test-linalg-tile-and-permute-patterns | FileCheck %s + +// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0) +// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) + +func @dot(%x: memref<?xf32, offset: ?, strides: [1]>, + %y: memref<?xf32, offset: ?, strides: [1]>, + %v: memref<f32>) { + linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>, + memref<?xf32, offset: ?, strides: [1]>, + memref<f32> + return +} +// CHECK-LABEL: func @dot +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c8:.*]] = constant 8 : index +// CHECK-DAG : %[[c8000:.*]] = constant 8000 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { +// CHECK : linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32> + +func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, + %x: memref<?xf32, offset: ?, strides: [1]>, + %y: memref<?xf32, offset: ?, strides: [1]>) { + linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>, + memref<?xf32, offset: ?, strides: [1]>, + memref<?xf32, offset: ?, strides: [1]> + return +} +// CHECK-LABEL: func @matvec +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c5:.*]] = constant 5 : index +// CHECK-DAG : %[[c6:.*]] = constant 6 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] +// CHECK : linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]> + +func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, + %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, + %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) { + linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, + memref<?x?xf32, offset: ?, strides: [?, 1]>, + memref<?x?xf32, offset: ?, strides: [?, 1]> + return +} +// CHECK-LABEL: func @matmul +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c2:.*]] = constant 2 : index +// CHECK-DAG : %[[c3:.*]] = constant 3 : index +// CHECK-DAG : %[[c4:.*]] = constant 4 : index +// CHECK-DAG : %[[c20:.*]] = constant 20 : index +// CHECK-DAG : %[[c30:.*]] = constant 30 : index +// CHECK-DAG : %[[c40:.*]] = constant 40 : index +// CHECK-DAG : %[[c200:.*]] = constant 200 : index +// CHECK-DAG : %[[c300:.*]] = constant 300 : index +// CHECK-DAG : %[[c400:.*]] = constant 400 : index +// CHECK-DAG : %[[c2000:.*]] = constant 2000 : index +// CHECK-DAG : %[[c3000:.*]] = constant 3000 : index +// CHECK-DAG : %[[c4000:.*]] = constant 4000 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { +// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> + diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt index 06e81a098f4..1ee62d82129 100644 --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -1,3 +1,7 @@ set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td) mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) + +set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td) +mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td new file mode 100644 index 00000000000..6d7bfffdf71 --- /dev/null +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td @@ -0,0 +1,57 @@ +//===- TestLinalgTilePermutePatterns.td - Test patterns --*- tablegen ----*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the pattern definition file for declarative Linalg transformations +// tests. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LINALG_TILEPERMUTE_PATTERNS +#define TEST_LINALG_TILEPERMUTE_PATTERNS + +include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" + +//===----------------------------------------------------------------------===// +// Linalg tiling and permutation patterns. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[2000, 3000, 4000], "L2", [1,2,0]> $op), + [(Constraint<Or<[HasNoLinalgTransformMarker, + HasLinalgTransformMarker<"MEM">]>> $op)]>; +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[200, 300, 400], "L1", [1,0,2]> $op), + [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>; +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[20, 30, 40], "REG"> $op), + [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>; + + +def : Pattern<(MatvecOp:$op $A, $b, $c), + [(TileLinalgOp<[5, 6], "L1", [1,0]> $op)], + [(Constraint<HasNoLinalgTransformMarker> $op)]>; + +def : Pattern<(DotOp:$op $a, $b, $c), + [(TileLinalgOp<[8000], "L1"> $op)], + [(Constraint<Or<[HasNoLinalgTransformMarker, + HasLinalgTransformMarker<"MEM">, + HasLinalgTransformMarker<"L3">, + HasLinalgTransformMarker<"L2">]>> $op)]>; +def : Pattern<(DotOp:$op $a, $b, $c), + [(TileLinalgOp<[8], "REG"> $op)], + [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>; + +#endif // TEST_LINALG_TILEPERMUTE_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 8bc9c736187..8a7933451b8 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_llvm_library(MLIRTestTransforms TestLoopFusion.cpp TestInlining.cpp TestLinalgTransforms.cpp + TestLinalgTilePermuteTransforms.cpp TestLoopMapping.cpp TestLoopParametricTiling.cpp TestOpaqueLoc.cpp @@ -21,6 +22,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../DeclarativeTransforms) include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms) add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen) add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen) +add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen) target_link_libraries(MLIRTestTransforms MLIRAffineOps MLIRAnalysis diff --git a/mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp new file mode 100644 index 00000000000..ec7fa4e71b4 --- /dev/null +++ b/mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp @@ -0,0 +1,64 @@ +//===- TestLinalgTilePermuteTransforms.cpp - Test Linalg tile + permute ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements logic for testing Linalg transformations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace mlir { +namespace linalg { +namespace { +#include "TestLinalgTilePermutePatterns.h.inc" +} // end namespace +} // end namespace linalg +} // end namespace mlir + +namespace { +struct TestLinalgTilePermuteTransforms + : public FunctionPass<TestLinalgTilePermuteTransforms> { + void runOnFunction() override; +}; +} // end anonymous namespace + +/// Apply transformations specified as patterns. +void TestLinalgTilePermuteTransforms::runOnFunction() { + OwningRewritePatternList patterns; + auto funcOp = getFunction(); + + // Add the generated patterns to the list. + linalg::populateWithGenerated(&getContext(), &patterns); + applyPatternsGreedily(funcOp, patterns); + + // Drop the marker. + funcOp.walk([](LinalgOp op) { + op.removeAttr(LinalgTransforms::kLinalgTransformMarker); + }); +} + +static PassRegistration<TestLinalgTilePermuteTransforms> + pass("test-linalg-tile-and-permute-patterns", + "Test Linalg transformation with permutation patterns by applying " + "them greedily."); |

