diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 95 |
1 files changed, 75 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 606160b9b14..3daeafe00ca 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -15,50 +15,84 @@ // limitations under the License. // ============================================================================= -#include "mlir/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/Support/Functional.h" using namespace mlir; using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; + +static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices, + unsigned &pos) { + for (auto sidx : structuredIndices) { + for (auto expr : sidx.getExprs()) { + expr.walk([&pos](AffineExpr e) { + if (auto d = e.dyn_cast<AffineDimExpr>()) + pos = std::max(pos, d.getPosition()); + }); + } + } +} Operation *mlir::edsc::makeLinalgGenericOp( - ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions, - ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews, - ArrayRef<StringRef> iteratorTypes, - decltype(defaultRegionBuilder) regionBuilder) { + ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs, + ArrayRef<StructuredIndexed> outputs, + decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues, + ArrayRef<Attribute> otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); + unsigned nInputs = inputs.size(); + unsigned nOutputs = outputs.size(); + unsigned rank = 0; + getMaxDimIndex(inputs, rank); + getMaxDimIndex(outputs, rank); SmallVector<AffineMap, 4> maps; - maps.reserve(mapExpressions.size()); - for (auto exprs : mapExpressions) - maps.push_back(AffineMap::get(indices.size(), 0, exprs)); + maps.reserve(nInputs + nOutputs); + for (auto in : inputs) + maps.push_back( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs())); + for (auto out : outputs) + maps.push_back( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs())); - SmallVector<Value *, 4> views; - views.reserve(inputViews.size() + outputViews.size()); - views.append(inputViews.begin(), inputViews.end()); - views.append(outputViews.begin(), outputViews.end()); + unsigned nViews = nInputs + nOutputs; + SmallVector<Value *, 4> values; + values.reserve(nViews); + values.append(inputs.begin(), inputs.end()); + values.append(outputs.begin(), outputs.end()); + auto iteratorStrTypes = functional::map(toString, iteratorTypes); + // clang-format off auto *op = edsc::ScopedContext::getBuilder() .create<linalg::GenericOp>( - edsc::ScopedContext::getLocation(), views, - IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()), - IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()), + edsc::ScopedContext::getLocation(), + values, + IntegerAttr::get(IntegerType::get(64, ctx), nInputs), + IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), builder.getAffineMapArrayAttr(maps), - builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/, - FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/ + builder.getStrArrayAttr(iteratorStrTypes), + StringAttr() /*doc*/, + FlatSymbolRefAttr() /*fun*/, + StringAttr() /*library_call*/ + /* TODO: other attributes in op */ ) .getOperation(); + // clang-format on using namespace edsc; SmallVector<Type, 4> blockTypes; - blockTypes.reserve(views.size()); - for (auto *v : views) - blockTypes.push_back(getElementTypeOrSelf(v)); + blockTypes.reserve(values.size()); + for (auto it : llvm::enumerate(values)) + blockTypes.push_back((it.index() < nViews) + ? getElementTypeOrSelf(it.value()) + : it.value()->getType()); assert(op->getRegions().front().empty()); op->getRegions().front().push_front(new Block); @@ -70,3 +104,24 @@ Operation *mlir::edsc::makeLinalgGenericOp( [&] { regionBuilder(b.getBlock()->getArguments()); }); return op; } + +using linalg_yield = OperationBuilder<linalg::YieldOp>; + +Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB, + ValueHandle vC) { + // clang-format off + AffineExpr m, n, k; + bindDims(ScopedContext::getContext(), m, n, k); + StructuredIndexed A(vA), B(vB), C(vC); + return makeLinalgGenericOp( + {IterType::Parallel, IterType::Parallel, IterType::Reduction}, + {A({m, n}), B({k, n})}, + {C({m, n})}, + [](ArrayRef<BlockArgument *> args) { + using edsc::op::operator*; + using edsc::op::operator+; + ValueHandle a(args[0]), b(args[1]), c(args[2]); + linalg_yield((c + a * b).getValue()); + }); + // clang-format on +} |

