diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-12-16 13:32:02 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-16 13:42:38 -0800 |
| commit | 3c179b657583c4098d189a475d85f39ff230d924 (patch) | |
| tree | bd039a67c8c352618a3b87b720e7904ea3ebf2bf /mlir/include | |
| parent | 11e92875f07261c64205c8b72038abf0d65729a0 (diff) | |
| download | bcm5719-llvm-3c179b657583c4098d189a475d85f39ff230d924.tar.gz bcm5719-llvm-3c179b657583c4098d189a475d85f39ff230d924.zip | |
Add edsc::ops for pointwise, conv and dilated_conv
This CL adds more Linalg EDSC ops and tests to support building pointwise operations along with conv and dilated_conv.
This also fixes a bug in the existing linalg_matmul EDSC and beefs up the test.
The current set of ops is already enough to build an interesting, albeit simple, model used internally.
PiperOrigin-RevId: 285838012
Diffstat (limited to 'mlir/include')
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 145 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h | 35 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 4 | ||||
| -rw-r--r-- | mlir/include/mlir/EDSC/Intrinsics.h | 14 |
4 files changed, 186 insertions, 12 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index 00da1d68cf2..421342038c9 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -22,15 +22,17 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" namespace mlir { class BlockArgument; -namespace edsc { +namespace edsc { enum class IterType { Parallel, Reduction }; inline StringRef toString(IterType t) { @@ -38,7 +40,7 @@ inline StringRef toString(IterType t) { case IterType::Parallel: return getParallelIteratorTypeName(); case IterType::Reduction: - return getParallelIteratorTypeName(); + return getReductionIteratorTypeName(); default: llvm_unreachable("Unsupport IterType"); } @@ -78,20 +80,83 @@ inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {} Operation *makeLinalgGenericOp( ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs, ArrayRef<StructuredIndexed> outputs, - decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder, + llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder = + defaultRegionBuilder, ArrayRef<Value *> otherValues = {}, ArrayRef<Attribute> otherAttributes = {}); +namespace ops { +using edsc::StructuredIndexed; +using edsc::ValueHandle; +using edsc::intrinsics::linalg_yield; + //===----------------------------------------------------------------------===// // EDSC builders for linalg generic operations. //===----------------------------------------------------------------------===// +/// Build the body of a region to compute a multiply-accumulate, under the +/// current ScopedContext, at the current insert point. +void macRegionBuilder(ArrayRef<BlockArgument *> args); + /// TODO(ntv): In the future we should tie these implementations to something in /// Tablegen that generates the proper interfaces and the proper sugared named /// ops. -/// Build a linalg.generic that represents C = A * B in the current -/// ScopedContext. +/// Build a linalg.pointwise, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (i0, ..., in) = (par, ..., par) +/// | +/// | O...(some_subset...(i0, ..., in)) = +/// | some_pointwise_func...(I...(some_other_subset...(i0, ..., in))) +/// ``` +/// +/// This is a very generic entry point that can be configured in many ways to +/// build a perfect loop nest of parallel loops with arbitrarily complex +/// innermost loop code and whatever (explicit) broadcast semantics. +/// +/// This can be used with both out-of-place and in-place semantics. +/// The client is responsible for ensuring the region operations are compatible +/// with in-place semantics and parallelism. + +/// Unary pointwise operation (with broadcast) entry point. +using UnaryPointwiseOpBuilder = llvm::function_ref<Value *(ValueHandle)>; +Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, + StructuredIndexed I, StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = tanh(I)`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); + +/// Binary pointwise operation (with broadcast) entry point. +using BinaryPointwiseOpBuilder = + llvm::function_ref<Value *(ValueHandle, ValueHandle)>; +Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, + StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = I1 + I2`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = max(I!, I2)`. The client is responsible for specifying the +/// proper indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +// TODO(ntv): Implement more useful pointwise operations on a per-need basis. + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (m, n, k) = (par, par, seq) +/// | +/// | C(m, n) += A(m, k) * B(k, n) +/// ``` Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); template <typename Container> Operation *linalg_matmul(Container values) { @@ -99,6 +164,76 @@ template <typename Container> Operation *linalg_matmul(Container values) { return linalg_matmul(values[0], values[1], values[2]); } +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, f, [h, w, ...], [kh, kw, ...], c) = +/// | (par, par, [par, par, ...], [red, red, ...], red) +/// | +/// | O(batch, [h, w, ...], f) += +/// | I(batch, +/// | [ +/// | stride[0] * h + dilations[0] * kh, +/// | stride[1] * w + dilations[1] * kw, ... +/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, f) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. +Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef<int> strides = {}, + ArrayRef<int> dilations = {}); + +template <typename Container> +Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {}, + ArrayRef<int> dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations); +} + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = +/// | (par, par, par, [par, par, ...], [red, red, ...]) +/// | +/// | O(batch, [h, w, ...], c * depth_multiplier) += +/// | I(batch, +/// | [ +/// | stride[0] * h + dilations[0] * kh, +/// | stride[1] * w + dilations[1] * kw, ... +/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, depth_multiplier) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. +Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, + ValueHandle vO, int depth_multiplier = 1, + ArrayRef<int> strides = {}, + ArrayRef<int> dilations = {}); + +template <typename Container> +Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier, + ArrayRef<int> strides = {}, + ArrayRef<int> dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_dilated_conv_nhwc(values[0], values[1], values[2], + depth_multiplier, strides, dilations); +} + +} // namespace ops } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h new file mode 100644 index 00000000000..f1acab69a4d --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -0,0 +1,35 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using linalg_fill = OperationBuilder<linalg::FillOp>; +using linalg_yield = OperationBuilder<linalg::YieldOp>; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 4f9621c9912..1f24a903e41 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -247,13 +247,13 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> { } def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> { - let arguments = (ins AnyStridedMemRef:$input, + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ ArrayAttr indexing_maps(); ArrayAttr iterator_types() { - unsigned nPar = input()->getType().cast<ShapedType>().getRank(); + unsigned nPar = output()->getType().cast<ShapedType>().getRank(); MLIRContext *ctx = getContext(); SmallVector<Attribute, 8> iters( nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 68bd210fce5..6dbb3432bf6 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -154,22 +154,22 @@ template <typename Op> struct ValueBuilder : public ValueHandle { /// Folder-based template <typename... Args> - ValueBuilder(OperationFolder &folder, Args... args) + ValueBuilder(OperationFolder *folder, Args... args) : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {} - ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs) + ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs) : ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {} template <typename... Args> - ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args) + ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args) : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs), detail::unpack(args)...)) {} template <typename T, typename... Args> - ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs, + ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs, Args... args) : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} template <typename T1, typename T2, typename... Args> - ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs, + ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args) : ValueHandle(ValueHandle::create<Op>( folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), @@ -200,6 +200,7 @@ template <typename Op> struct OperationBuilder : public OperationHandle { OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {} }; +using addf = ValueBuilder<AddFOp>; using affine_apply = ValueBuilder<AffineApplyOp>; using affine_if = OperationBuilder<AffineIfOp>; using affine_load = ValueBuilder<AffineLoadOp>; @@ -212,11 +213,14 @@ using constant_int = ValueBuilder<ConstantIntOp>; using dealloc = OperationBuilder<DeallocOp>; using dim = ValueBuilder<DimOp>; using muli = ValueBuilder<MulIOp>; +using mulf = ValueBuilder<MulFOp>; +using memref_cast = ValueBuilder<MemRefCastOp>; using ret = OperationBuilder<ReturnOp>; using select = ValueBuilder<SelectOp>; using std_load = ValueBuilder<LoadOp>; using std_store = OperationBuilder<StoreOp>; using subi = ValueBuilder<SubIOp>; +using tanh = ValueBuilder<TanhOp>; using view = ValueBuilder<ViewOp>; /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. |

