summaryrefslogtreecommitdiffstats
path: root/mlir/examples
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-10-14 21:12:50 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-14 21:13:45 -0700
commit300112e135f82b1f6faf8aa2ef266a27f07234de (patch)
tree01bce114b72da69b8fd88d727a8fcae2db1dd5dd /mlir/examples
parentf29731d17f469722c73e33b6d503be0ab39cf907 (diff)
downloadbcm5719-llvm-300112e135f82b1f6faf8aa2ef266a27f07234de.tar.gz
bcm5719-llvm-300112e135f82b1f6faf8aa2ef266a27f07234de.zip
Merge Ch3 of the Toy tutorial into chapter 2.
This effectively rewrites Ch.2 to introduce dialects, operations, and registration instead of deferring to Ch.3. This allows for introducing the best practices up front(using ODS, registering operations, etc.), and limits the opaque API to the chapter document instead of the code. PiperOrigin-RevId: 274724289
Diffstat (limited to 'mlir/examples')
-rw-r--r--mlir/examples/toy/Ch2/CMakeLists.txt9
-rw-r--r--mlir/examples/toy/Ch2/include/toy/Dialect.h53
-rw-r--r--mlir/examples/toy/Ch2/include/toy/Ops.td241
-rw-r--r--mlir/examples/toy/Ch2/mlir/Dialect.cpp156
-rw-r--r--mlir/examples/toy/Ch2/mlir/MLIRGen.cpp151
-rw-r--r--mlir/examples/toy/Ch2/toyc.cpp4
6 files changed, 526 insertions, 88 deletions
diff --git a/mlir/examples/toy/Ch2/CMakeLists.txt b/mlir/examples/toy/Ch2/CMakeLists.txt
index 12099634122..21d74dab530 100644
--- a/mlir/examples/toy/Ch2/CMakeLists.txt
+++ b/mlir/examples/toy/Ch2/CMakeLists.txt
@@ -1,3 +1,9 @@
+
+set(LLVM_TARGET_DEFINITIONS include/toy/Ops.td)
+mlir_tablegen(include/toy/Ops.h.inc -gen-op-decls)
+mlir_tablegen(include/toy/Ops.cpp.inc -gen-op-defs)
+add_public_tablegen_target(ToyCh2OpsIncGen)
+
set(LLVM_LINK_COMPONENTS
Support
)
@@ -6,8 +12,11 @@ add_toy_chapter(toyc-ch2
toyc.cpp
parser/AST.cpp
mlir/MLIRGen.cpp
+ mlir/Dialect.cpp
)
include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
+add_dependencies(toyc-ch2 ToyCh2OpsIncGen)
target_link_libraries(toyc-ch2
PRIVATE
MLIRAnalysis
diff --git a/mlir/examples/toy/Ch2/include/toy/Dialect.h b/mlir/examples/toy/Ch2/include/toy/Dialect.h
new file mode 100644
index 00000000000..91dd631d2ff
--- /dev/null
+++ b/mlir/examples/toy/Ch2/include/toy/Dialect.h
@@ -0,0 +1,53 @@
+//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+
+namespace mlir {
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
+class ToyDialect : public mlir::Dialect {
+public:
+ explicit ToyDialect(mlir::MLIRContext *ctx);
+
+ /// Provide a utility accessor to the dialect namespace. This is used by
+ /// several utilities for casting between dialects.
+ static llvm::StringRef getDialectNamespace() { return "toy"; }
+};
+
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td
new file mode 100644
index 00000000000..59a4ff52281
--- /dev/null
+++ b/mlir/examples/toy/Ch2/include/toy/Ops.td
@@ -0,0 +1,241 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef TOY_OPS
+#else
+#define TOY_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+ let name = "toy";
+ let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+// * The parent dialect of the operation.
+// * The mnemonic for the operation, or the name without the dialect prefix.
+// * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+ // Provide a summary and description for this operation. This can be used to
+ // auto-generate documentation of the operations within our dialect.
+ let summary = "constant";
+ let description = [{
+ Constant operation turns a literal into an SSA value. The data is attached
+ to the operation as an attribute. For example:
+
+ ```mlir
+ %0 = "toy.constant"()
+ { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+ : () -> tensor<2x3xf64>
+ ```
+ }];
+
+ // The constant operation takes an attribute as the only input.
+ let arguments = (ins F64ElementsAttr:$value);
+
+ // The constant operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the constant operation. These methods populate
+ // the `state` that MLIR uses to create operations, i.e. these are used when
+ // using `builder.create<ConstantOp>(...)`.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "DenseElementsAttr value", [{
+ build(builder, result, value.getType(), value);
+ }]>,
+
+ // Build a constant with a given constant floating-point value.
+ OpBuilder<"Builder *builder, OperationState &result, double value", [{
+ buildConstantOp(builder, result, value);
+ }]>
+ ];
+
+ // Invoke a static verify method to verify this constant operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def AddOp : Toy_Op<"add", [NoSideEffect]> {
+ let summary = "element-wise addition operation";
+ let description = [{
+ The "add" operation performs element-wise addition between two tensors.
+ The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building an AddOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildAddOp(b, result, lhs, rhs);
+ }]
+ >];
+}
+
+def GenericCallOp : Toy_Op<"generic_call"> {
+ let summary = "generic call operation";
+ let description = [{
+ Generic calls represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments. The callee name is attached
+ as a symbol reference via an attribute. The arguments list must match the
+ arguments expected by the callee. For example:
+
+ ```mlir
+ %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+ : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ ```
+
+ This is only valid if a function named "my_func" exists and takes two
+ arguments.
+ }];
+
+ // The generic call operation takes a symbol reference attribute as the
+ // callee, and inputs for the call.
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+ // The generic call operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the generic call operation.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "StringRef callee, ArrayRef<Value *> arguments", [{
+ buildGenericCallOp(builder, result, callee, arguments);
+ }]>
+ ];
+}
+
+def MulOp : Toy_Op<"mul", [NoSideEffect]> {
+ let summary = "element-wise multiplication operation";
+ let description = [{
+ The "mul" operation performs element-wise multiplication between two
+ tensors. The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building a MulOp from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildMulOp(b, result, lhs, rhs);
+ }]
+ >];
+}
+
+def PrintOp : Toy_Op<"print"> {
+ let summary = "print operation";
+ let description = [{
+ The "print" builtin operation prints a given input tensor, and produces
+ no results.
+ }];
+
+ // The print operation takes an input tensor to print.
+ let arguments = (ins F64Tensor:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+ let summary = "tensor reshape operation";
+ let description = [{
+ Reshape operation is transforming its input tensor into a new tensor with
+ the same number of elements but different shapes. For example:
+
+ ```mlir
+ %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+ ```
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+
+ // We expect that the reshape operation returns a statically shaped tensor.
+ let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+ let summary = "return operation";
+ let description = [{
+ The "return" operation represents a return operation within a function.
+ The operation takes an optional tensor operand and produces no results.
+ The operand type must match the signature of the function that contains
+ the operation. For example:
+
+ ```mlir
+ func @foo() -> tensor<2xf64> {
+ ...
+ toy.return %0 : tensor<2xf64>
+ }
+ ```
+ }];
+
+ // The return operation takes an optional input operand to return. This
+ // value must match the return type of the enclosing function.
+ let arguments = (ins Variadic<F64Tensor>:$input);
+
+ // Allow building a ReturnOp with no return operand.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ // Provide extra utility definitions on the c++ operation class definition.
+ let extraClassDeclaration = [{
+ bool hasOperand() { return getNumOperands() != 0; }
+ }];
+
+ // Invoke a static verify method to verify this return operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
+ let summary = "transpose operation";
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor);
+
+ // Allow building a TransposeOp from the input operand.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *input", [{
+ buildTransposeOp(b, result, input);
+ }]
+ >];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
new file mode 100644
index 00000000000..dd2e7846144
--- /dev/null
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -0,0 +1,156 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::toy;
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+static void buildConstantOp(mlir::Builder *builder,
+ mlir::OperationState &result, double value) {
+ auto dataType = builder->getTensorType({}, builder->getF64Type());
+ auto dataAttribute = DenseElementsAttr::get(dataType, value);
+ ConstantOp::build(builder, result, dataType, dataAttribute);
+}
+
+/// Verifier for the constant operation. This corresponds to the `::verify(...)`
+/// in the op definition.
+static mlir::LogicalResult verify(ConstantOp op) {
+ // If the return type of the constant is not an unranked tensor, the shape
+ // must match the shape of the attribute holding the data.
+ auto resultType = op.getResult()->getType().cast<mlir::RankedTensorType>();
+ if (!resultType)
+ return success();
+
+ // Check that the rank of the attribute type matches the rank of the constant
+ // result type.
+ auto attrType = op.value().getType().cast<mlir::TensorType>();
+ if (attrType.getRank() != resultType.getRank()) {
+ return op.emitOpError(
+ "return type must match the one of the attached value "
+ "attribute: ")
+ << attrType.getRank() << " != " << resultType.getRank();
+ }
+
+ // Check that each of the dimensions match between the two types.
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
+ if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+ return op.emitOpError(
+ "return type shape mismatches its attribute at dimension ")
+ << dim << ": " << attrType.getShape()[dim]
+ << " != " << resultType.getShape()[dim];
+ }
+ }
+ return mlir::success();
+}
+
+static void buildAddOp(mlir::Builder *builder, mlir::OperationState &result,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ result.addTypes(builder->getTensorType(builder->getF64Type()));
+ result.addOperands({lhs, rhs});
+}
+
+static void buildGenericCallOp(mlir::Builder *builder,
+ mlir::OperationState &result, StringRef callee,
+ ArrayRef<mlir::Value *> arguments) {
+ // Generic call always returns an unranked Tensor initially.
+ result.addTypes(builder->getTensorType(builder->getF64Type()));
+ result.addOperands(arguments);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+}
+
+static void buildMulOp(mlir::Builder *builder, mlir::OperationState &result,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ result.addTypes(builder->getTensorType(builder->getF64Type()));
+ result.addOperands({lhs, rhs});
+}
+
+static mlir::LogicalResult verify(ReturnOp op) {
+ // We know that the parent operation is a function, because of the 'HasParent'
+ // trait attached to the operation definition.
+ auto function = cast<FuncOp>(op.getParentOp());
+
+ /// ReturnOps can only have a single optional operand.
+ if (op.getNumOperands() > 1)
+ return op.emitOpError() << "expects at most 1 return operand";
+
+ // The operand number and types must match the function signature.
+ const auto &results = function.getType().getResults();
+ if (op.getNumOperands() != results.size())
+ return op.emitOpError()
+ << "does not return the same number of values ("
+ << op.getNumOperands() << ") as the enclosing function ("
+ << results.size() << ")";
+
+ // If the operation does not have an input, we are done.
+ if (!op.hasOperand())
+ return mlir::success();
+
+ auto inputType = *op.operand_type_begin();
+ auto resultType = results.front();
+
+ // Check that the result type of the function matches the operand type.
+ if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
+ resultType.isa<mlir::UnrankedTensorType>())
+ return mlir::success();
+
+ return op.emitError() << "type of return operand ("
+ << *op.operand_type_begin()
+ << ") doesn't match function result type ("
+ << results.front() << ")";
+}
+
+static void buildTransposeOp(mlir::Builder *builder,
+ mlir::OperationState &result, mlir::Value *value) {
+ result.addTypes(builder->getTensorType(builder->getF64Type()));
+ result.addOperands(value);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "toy/Ops.cpp.inc"
diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
index ab5a5848403..5f12d0a8798 100644
--- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
@@ -22,27 +22,29 @@
#include "toy/MLIRGen.h"
#include "toy/AST.h"
+#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
+using namespace mlir::toy;
using namespace toy;
+
+using llvm::ArrayRef;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
+using llvm::makeArrayRef;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
@@ -55,23 +57,16 @@ namespace {
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
-///
-/// At this point we take advantage of the "raw" MLIR APIs to create operations
-/// that haven't been registered in any way with MLIR. These operations are
-/// unknown to MLIR, custom passes could operate by string-matching the name of
-/// these operations, but no other type checking or semantics are associated
-/// with them natively by MLIR.
class MLIRGenImpl {
public:
- MLIRGenImpl(mlir::MLIRContext &context)
- : context(context), builder(&context) {}
+ MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module operation.
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
@@ -80,9 +75,9 @@ public:
theModule.push_back(func);
}
- // FIXME: (in the next chapter...) without registering a dialect in MLIR,
- // this won't do much, but it should at least check some structural
- // properties of the generated MLIR module.
+ // Verify the module after we have finished constructing it, this will check
+ // the structural properties of the IR and invoke any specific verifiers we
+ // have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("module verification error");
return nullptr;
@@ -92,11 +87,6 @@ public:
}
private:
- /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
- /// ownership of many internal structures of the IR and provides a level of
- /// "uniquing" across multiple modules (types for instance).
- mlir::MLIRContext &context;
-
/// A "module" matches a Toy source file: containing a list of functions.
mlir::ModuleOp theModule;
@@ -129,14 +119,14 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
+ auto location = loc(proto.loc());
+
// This is a generic function, the return type will be inferred later.
- llvm::SmallVector<mlir::Type, 4> ret_types;
- // Arguments type is uniformly a generic array.
+ // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
- auto func_type = builder.getFunctionType(arg_types, ret_types);
- auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto func_type = builder.getFunctionType(arg_types, llvm::None);
+ auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
// Mark the function as generic: it'll require type specialization for every
// call site.
@@ -183,10 +173,16 @@ private:
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function.getBody().back().back().getName().getStringRef() !=
- "toy.return") {
- ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
- mlirGen(fakeRet);
+ ReturnOp returnOp;
+ if (!entryBlock.empty())
+ returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+ if (!returnOp) {
+ builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+ } else if (returnOp.hasOperand()) {
+ // Otherwise, if this return operation has an operand then add a result to
+ // the function.
+ function.setType(builder.getFunctionType(function.getType().getInputs(),
+ getType(VarType{})));
}
return function;
@@ -205,36 +201,25 @@ private:
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
- mlir::Value *L = mlirGen(*binop.getLHS());
- if (!L)
+ mlir::Value *lhs = mlirGen(*binop.getLHS());
+ if (!lhs)
return nullptr;
- mlir::Value *R = mlirGen(*binop.getRHS());
- if (!R)
+ mlir::Value *rhs = mlirGen(*binop.getRHS());
+ if (!rhs)
return nullptr;
auto location = loc(binop.loc());
// Derive the operation name from the binary operator. At the moment we only
// support '+' and '*'.
- const char *op_name = nullptr;
switch (binop.getOp()) {
case '+':
- op_name = "toy.add";
- break;
+ return builder.create<AddOp>(location, lhs, rhs);
case '*':
- op_name = "toy.mul";
- break;
- default:
- emitError(location, "error: invalid binary operator '")
- << binop.getOp() << "'";
- return nullptr;
+ return builder.create<MulOp>(location, lhs, rhs);
}
- // Build the MLIR operation from the name and the two operands. The return
- // type is always a generic array for binary operators.
- mlir::OperationState result(location, op_name);
- result.addTypes(getType(VarType{}));
- result.addOperands({L, R});
- return builder.createOperation(result)->getResult(0);
+ emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+ return nullptr;
}
/// This is a reference to a variable in an expression. The variable is
@@ -251,17 +236,18 @@ private:
/// Emit a return operation. This will return failure if any generation fails.
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
- mlir::OperationState result(loc(ret.loc()), "toy.return");
+ auto location = loc(ret.loc());
- // `return` takes an optional expression, we need to account for it here.
+ // 'return' takes an optional expression, handle that case here.
+ mlir::Value *expr = nullptr;
if (ret.getExpr().hasValue()) {
- auto *expr = mlirGen(*ret.getExpr().getValue());
- if (!expr)
+ if (!(expr = mlirGen(*ret.getExpr().getValue())))
return mlir::failure();
- result.addOperands(expr);
}
- builder.createOperation(result);
+ // Otherwise, this return operation has zero operands.
+ builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
+ : ArrayRef<mlir::Value *>());
return mlir::success();
}
@@ -303,11 +289,9 @@ private:
auto dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
- // Build the MLIR op `toy.constant`, only boilerplate below.
- mlir::OperationState result(loc(lit.loc()), "toy.constant");
- result.addTypes(type);
- result.addAttribute("value", dataAttribute);
- return builder.createOperation(result)->getResult(0);
+ // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
+ // method.
+ return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
/// Recursive helper function to accumulate the data that compose an array
@@ -333,6 +317,7 @@ private:
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
llvm::StringRef callee = call.getCallee();
+ auto location = loc(call.loc());
// Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
@@ -346,20 +331,18 @@ private:
// Builtin calls have their custom operation, meaning this is a
// straightforward emission.
if (callee == "transpose") {
- mlir::OperationState result(loc(call.loc()), "toy.transpose");
- result.addTypes(getType(VarType{}));
- result.operands = std::move(operands);
- return builder.createOperation(result)->getResult(0);
+ if (call.getArgs().size() != 1) {
+ emitError(location, "MLIR codegen encountered an error: toy.transpose "
+ "does not accept multiple arguments");
+ return nullptr;
+ }
+ return builder.create<TransposeOp>(location, operands[0]);
}
- // Otherwise this is a call to a user-defined function. Calls to
- // user-defined functions are mapped to a custom call that takes the callee
- // name as an attribute.
- mlir::OperationState result(loc(call.loc()), "toy.generic_call");
- result.addTypes(getType(VarType{}));
- result.operands = std::move(operands);
- result.addAttribute("callee", builder.getSymbolRefAttr(callee));
- return builder.createOperation(result)->getResult(0);
+ // Otherwise this is a call to a user-defined function. Calls to user-defined
+ // functions are mapped to a custom call that takes the callee name as an
+ // attribute.
+ return builder.create<GenericCallOp>(location, callee, operands);
}
/// Emit a print expression. It emits specific operations for two builtins:
@@ -369,19 +352,13 @@ private:
if (!arg)
return mlir::failure();
- mlir::OperationState result(loc(call.loc()), "toy.print");
- result.addOperands(arg);
- builder.createOperation(result);
+ builder.create<PrintOp>(loc(call.loc()), arg);
return mlir::success();
}
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
- mlir::OperationState result(loc(num.loc()), "toy.constant");
- mlir::Type elementType = builder.getF64Type();
- result.addTypes(builder.getTensorType({}, elementType));
- result.addAttribute("value", builder.getF64FloatAttr(num.getValue()));
- return builder.createOperation(result)->getResult(0);
+ return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}
/// Dispatch codegen for the right expression subclass using RTTI.
@@ -425,13 +402,11 @@ private:
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
- mlir::OperationState result(loc(vardecl.loc()), "toy.reshape");
- result.addTypes(getType(vardecl.getType()));
- result.addOperands(value);
- value = builder.createOperation(result)->getResult(0);
+ value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+ getType(vardecl.getType()), value);
}
- // Register the value in the symbol table
+ // Register the value in the symbol table.
if (failed(declare(vardecl.getName(), value)))
return nullptr;
return value;
@@ -439,7 +414,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
- ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+ ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested
@@ -465,7 +440,7 @@ private:
}
/// Build a tensor type from a list of shape dimensions.
- mlir::Type getType(llvm::ArrayRef<int64_t> shape) {
+ mlir::Type getType(ArrayRef<int64_t> shape) {
// If the shape is empty, then this type is unranked.
if (shape.empty())
return builder.getTensorType(builder.getF64Type());
@@ -474,8 +449,8 @@ private:
return builder.getTensorType(shape, builder.getF64Type());
}
- /// Build an MLIR type from a Toy AST variable type
- /// (forward to the generic getType(T) above).
+ /// Build an MLIR type from a Toy AST variable type (forward to the generic
+ /// getType above).
mlir::Type getType(const VarType &type) { return getType(type.shape); }
};
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index 7280ccb0959..547ac9e65b9 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "toy/Dialect.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
#include <memory>
@@ -75,6 +76,9 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
}
int dumpMLIR() {
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+
mlir::MLIRContext context;
// Handle '.toy' input to the compiler.
OpenPOWER on IntegriCloud