summaryrefslogtreecommitdiffstats
path: root/mlir/examples
diff options
context:
space:
mode:
authorSana Damani <sdamani@gatech.edu>2019-10-16 12:08:55 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-16 12:19:39 -0700
commit3940b90d84d7239f2bc849068df97f1d248554fe (patch)
treefdd5fb940df322d6cd358b552ebf7c3d84d72339 /mlir/examples
parente88dbc8c955a9c17d5ef444c716633752ced338c (diff)
downloadbcm5719-llvm-3940b90d84d7239f2bc849068df97f1d248554fe.tar.gz
bcm5719-llvm-3940b90d84d7239f2bc849068df97f1d248554fe.zip
Update Chapter 4 of the Toy tutorial
This Chapter now introduces and makes use of the Interface concept in MLIR to demonstrate ShapeInference. END_PUBLIC Closes tensorflow/mlir#191 PiperOrigin-RevId: 275085151
Diffstat (limited to 'mlir/examples')
-rw-r--r--mlir/examples/toy/Ch4/CMakeLists.txt15
-rw-r--r--mlir/examples/toy/Ch4/include/CMakeLists.txt1
-rw-r--r--mlir/examples/toy/Ch4/include/toy/AST.h15
-rw-r--r--mlir/examples/toy/Ch4/include/toy/CMakeLists.txt9
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Dialect.h321
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Lexer.h2
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Ops.td285
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Passes.h7
-rw-r--r--mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td38
-rw-r--r--mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp61
-rw-r--r--mlir/examples/toy/Ch4/mlir/Dialect.cpp190
-rw-r--r--mlir/examples/toy/Ch4/mlir/MLIRGen.cpp413
-rw-r--r--mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp330
-rw-r--r--mlir/examples/toy/Ch4/mlir/ToyCombine.cpp127
-rw-r--r--mlir/examples/toy/Ch4/mlir/ToyCombine.td73
-rw-r--r--mlir/examples/toy/Ch4/mlir/ToyDialect.cpp387
-rw-r--r--mlir/examples/toy/Ch4/toyc.cpp89
17 files changed, 1005 insertions, 1358 deletions
diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt
index 11972e567f1..dde70db25b6 100644
--- a/mlir/examples/toy/Ch4/CMakeLists.txt
+++ b/mlir/examples/toy/Ch4/CMakeLists.txt
@@ -1,16 +1,29 @@
+add_subdirectory(include)
+
set(LLVM_LINK_COMPONENTS
Support
)
+set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
+mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
+add_public_tablegen_target(ToyCh4CombineIncGen)
+
add_toy_chapter(toyc-ch4
toyc.cpp
parser/AST.cpp
mlir/MLIRGen.cpp
- mlir/ToyDialect.cpp
+ mlir/Dialect.cpp
+ mlir/DeadFunctionEliminationPass.cpp
mlir/ShapeInferencePass.cpp
mlir/ToyCombine.cpp
)
+
+add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
+add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch4
PRIVATE
MLIRAnalysis
diff --git a/mlir/examples/toy/Ch4/include/CMakeLists.txt b/mlir/examples/toy/Ch4/include/CMakeLists.txt
new file mode 100644
index 00000000000..37c89d0bae9
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h
index 456a32309c4..2ad3392c11a 100644
--- a/mlir/examples/toy/Ch4/include/toy/AST.h
+++ b/mlir/examples/toy/Ch4/include/toy/AST.h
@@ -33,10 +33,9 @@
namespace toy {
-/// A variable
+/// A variable type with shape information.
struct VarType {
- enum { TY_FLOAT, TY_INT } elt_ty;
- std::vector<int> shape;
+ std::vector<int64_t> shape;
};
/// Base class for all expression nodes.
@@ -50,9 +49,7 @@ public:
Expr_Var,
Expr_BinOp,
Expr_Call,
- Expr_Print, // builtin
- Expr_If,
- Expr_For,
+ Expr_Print,
};
ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};
-///
+/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};
-///
+/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
@@ -136,7 +133,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};
-///
+/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;
diff --git a/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt
new file mode 100644
index 00000000000..798d0df1d8d
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+add_public_tablegen_target(ToyCh4OpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen)
diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h
index b0838870b5a..da61191c6c0 100644
--- a/mlir/examples/toy/Ch4/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h
@@ -16,7 +16,7 @@
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
-// See g3doc/Tutorials/Toy/Ch-3.md for more information.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//
@@ -25,325 +25,30 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/StandardTypes.h"
namespace mlir {
-class Builder;
-}
-
namespace toy {
/// This is the definition of the Toy dialect. A dialect inherits from
-/// mlir::Dialect and register custom operations and types (in its constructor).
-/// It can also overriding general behavior of dialects exposed as virtual
-/// method, for example regarding verification and parsing/printing.
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);
- /// Parse a type registered to this dialect. Overriding this method is
- /// required for dialects that have custom types.
- /// Technically this is only needed to be able to round-trip to textual IR.
- mlir::Type parseType(llvm::StringRef tyData,
- mlir::Location loc) const override;
-
- /// Print a type registered to this dialect. Overriding this method is
- /// only required for dialects that have custom types.
- /// Technically this is only needed to be able to round-trip to textual IR.
- void printType(mlir::Type type, llvm::raw_ostream &os) const override;
-};
-
-////////////////////////////////////////////////////////////////////////////////
-/////////////////////// Custom Types for the Dialect ///////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-namespace detail {
-struct ToyArrayTypeStorage;
-}
-
-/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
-enum ToyTypeKind {
- // The enum starts at the range reserved for this dialect.
- TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
- TOY_ARRAY,
+ /// Provide a utility accessor to the dialect namespace. This is used by
+ /// several utilities for casting between dialects.
+ static llvm::StringRef getDialectNamespace() { return "toy"; }
};
-/// Type for Toy arrays.
-/// In MLIR Types are reference to immutable and uniqued objects owned by the
-/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
-/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
-/// provides the public facade API to interact with the type.
-class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
- detail::ToyArrayTypeStorage> {
-public:
- using Base::Base;
-
- /// Returns the dimensions for this array, or and empty range for a generic
- /// array.
- llvm::ArrayRef<int64_t> getShape();
-
- /// Predicate to test if this array is generic (shape haven't been inferred
- /// yet).
- bool isGeneric() { return getShape().empty(); }
-
- /// Return the rank of this array (0 if it is generic).
- int getRank() { return getShape().size(); }
-
- /// Return the type of individual elements in the array.
- mlir::Type getElementType();
-
- /// Get the unique instance of this Type from the context.
- /// A ToyArrayType is only defined by the shape of the array.
- static ToyArrayType get(mlir::MLIRContext *context,
- llvm::ArrayRef<int64_t> shape = {});
-
- /// Support method to enable LLVM-style RTTI type casting.
- static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
-};
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Constant operation turns a literal into an SSA value. The data is attached
-/// to the operation as an attribute. For example:
-///
-/// %0 = "toy.constant"()
-/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
-/// : () -> !toy.array<2, 3>
-///
-/// An operation inherits from `class Op` and specifies optional traits. Here we
-/// indicate that `toy.constant` does not have any operands and returns a single
-/// result. The traits provide some utilities methods for the operation, for
-/// instance we will be able to use `getResult()`, but `getOperand()` won't be
-/// available.
-class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
-public:
- /// This is the name used by MLIR to match an operation to this class during
- /// parsing.
- static llvm::StringRef getOperationName() { return "toy.constant"; }
-
- /// The operation can have extra verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<PrintOp>(...)
- /// This method populates the `state` that MLIR uses to create operations.
- /// The `toy.constant` operation does not have arguments but attaches a
- /// constant array as an attribute and returns it as an SSA value.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- llvm::ArrayRef<int64_t> shape,
- mlir::DenseElementsAttr value);
-
- /// Similar to the one above, but takes a single float and returns a
- /// !toy.array<1>.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::FloatAttr value);
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// Generic calls represent calls to a user defined function that needs to
-/// be specialized for the shape of its arguments. The callee name is attached
-/// as a literal string as an attribute. The arguments list must match the
-/// arguments expected by the callee. For example:
-///
-/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
-/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
-///
-/// This is only valid if a function named "my_func" exists and takes two
-/// arguments.
-class GenericCallOp
- : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
- mlir::OpTrait::OneResult> {
-public:
- /// MLIR will use this to register the operation with the parser/printer.
- static llvm::StringRef getOperationName() { return "toy.generic_call"; }
-
- /// Operations can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to the builder to allow:
- /// mlir::Builder::create<GenericCallOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.generic_call` operation accepts a callee name and a list of
- /// arguments for the call.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- llvm::StringRef callee,
- llvm::ArrayRef<mlir::Value *> arguments);
-
- /// Return the name of the callee.
- llvm::StringRef getCalleeName();
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// Return operations terminate blocks (and functions as well). They take a
-/// single argument and the type must match the function return type.
-class ReturnOp
- : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
- mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
-public:
- static llvm::StringRef getOperationName() { return "toy.return"; }
-
- /// Operations can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<PrintOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.return` operation accepts an optional single array as an argument
- /// and does not have any returned value.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value = nullptr);
-
- /// Return true if there is a returned value.
- bool hasOperand() { return 0 != getNumOperands(); }
-
- /// Helper to return the optional operand. Caller must check if the operand
- /// is present before calling this.
- mlir::Value *getOperand() { return getOperation()->getOperand(0); }
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// The print builtin takes a single array argument and does not return any.
-class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
- mlir::OpTrait::ZeroResult> {
-public:
- static llvm::StringRef getOperationName() { return "toy.print"; }
-
- /// Operations can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<PrintOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.print` operation accepts a single array as argument and does
- /// not have any returned value.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value);
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
-public:
- static llvm::StringRef getOperationName() { return "toy.transpose"; }
-
- /// Operation can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<TransposeOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.transpose` operation accepts a single array as argument and
- /// returns the transposed array as its only result.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value);
-
- // Register our patterns for rewrite by the Canonicalization framework.
- static void
- getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
- mlir::MLIRContext *context);
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// Reshape operation is transforming its input array into a new array with the
-/// same number of elements but different shapes. For example:
-///
-/// %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
-///
-class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
-public:
- static llvm::StringRef getOperationName() { return "toy.reshape"; }
-
- /// Operation can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<ReshapeOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.reshape` operation accepts a single array as argument and
- /// returns the array with the specified reshapedType as its only result.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value, ToyArrayType reshapedType);
-
- // Register our patterns for rewrite by the Canonicalization framework.
- static void
- getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
- mlir::MLIRContext *context);
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// Binary operation implementing a multiplication. For two-dimensional array
-/// a matrix multiplication is implemented, while for one dimensional array a
-/// dot product is performed.
-class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
-public:
- static llvm::StringRef getOperationName() { return "toy.mul"; }
-
- /// Operation can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<PrintOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.mul` operation accepts two operands as argument and returns
- /// a single value.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *lhs, mlir::Value *rhs);
-
- /// Convenience accessor for LHS of the expression.
- mlir::Value *getLHS() { return getOperand(0); }
-
- /// Convenience accessor for RHS of the expression.
- mlir::Value *getRHS() { return getOperand(1); }
-
- /// Inherit constructor.
- using Op::Op;
-};
-
-/// Element wise addition of two arrays. The shape must match.
-class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
- mlir::OpTrait::OneResult,
- mlir::OpTrait::HasNoSideEffect> {
-public:
- static llvm::StringRef getOperationName() { return "toy.add"; }
-
- /// Operation can add custom verification beyond the traits they define.
- mlir::LogicalResult verify();
-
- /// Interface to mlir::Builder::create<PrintOp>(...)
- /// This method populate the `state` that MLIR use to create operations.
- /// The `toy.mul` operation accepts two operands as argument and returns
- /// a single value.
- static void build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *lhs, mlir::Value *rhs);
-
- /// Convenience accessor for LHS of the expression.
- mlir::Value *getLHS() { return getOperand(0); }
-
- /// Convenience accessor for RHS of the expression.
- mlir::Value *getRHS() { return getOperand(1); }
-
- /// Inherit constructor.
- using Op::Op;
-};
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
} // end namespace toy
+} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h
index d73adb9706b..21f92614912 100644
--- a/mlir/examples/toy/Ch4/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h
@@ -31,7 +31,7 @@ namespace toy {
/// Structure definition a location in a file.
struct Location {
- std::shared_ptr<std::string> file; ///< filename
+ std::shared_ptr<std::string> file; ///< filename.
int line; ///< line number.
int col; ///< column number.
};
diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
new file mode 100644
index 00000000000..f0140d70f9b
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -0,0 +1,285 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef TOY_OPS
+#else
+#define TOY_OPS
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+include "toy/ShapeInferenceInterface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+ let name = "toy";
+ let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+// * The parent dialect of the operation.
+// * The mnemonic for the operation, or the name without the dialect prefix.
+// * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inherting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+ // Provide a summary and description for this operation. This can be used to
+ // auto-generate documenatation of the operations within our dialect.
+ let summary = "constant";
+ let description = [{
+ Constant operation turns a literal into an SSA value. The data is attached
+ to the operation as an attribute. For example:
+
+ ```mlir
+ %0 = "toy.constant"()
+ { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+ : () -> tensor<2x3xf64>
+ ```
+ }];
+
+ // The constant operation takes an attribute as the only input.
+ let arguments = (ins F64ElementsAttr:$value);
+
+ // The constant operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the constant operation. These method populates
+ // the `state` that MLIR uses to create operations, i.e. these are used when
+ // using `builder.create<ConstantOp>(...)`.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "DenseElementsAttr value", [{
+ build(builder, result, value.getType(), value);
+ }]>,
+
+ // Build a constant with a given constant floating-point value.
+ OpBuilder<"Builder *builder, OperationState &result, double value", [{
+ buildConstantOp(builder, result, value);
+ }]>
+ ];
+
+ // Invoke a static verify method to verify this constant operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def AddOp : Toy_Op<"add", [NoSideEffect]> {
+ let summary = "element-wise addition operation";
+ let description = [{
+ The "add" operation performs element-wise addition between two tensors.
+ The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building an AddOp with from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildAddOp(b, result, lhs, rhs);
+ }]
+ >];
+ let extraClassDeclaration = [{
+ void inferShapes() {
+ getResult()->setType(getOperand(0)->getType());
+ return;
+ }
+ }];
+}
+
+def GenericCallOp : Toy_Op<"generic_call"> {
+ let summary = "generic call operation";
+ let description = [{
+ Generic calls represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments. The callee name is attached
+ as a symbol reference via an attribute. The arguments list must match the
+ arguments expected by the callee. For example:
+
+ ```mlir
+ %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+ : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+ ```
+
+ This is only valid if a function named "my_func" exists and takes two
+ arguments.
+ }];
+
+ // The generic call operation takes a symbol reference attribute as the
+ // callee, and inputs for the call.
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+ // The generic call operation returns a single value of TensorType.
+ let results = (outs F64Tensor);
+
+ // Add custom build methods for the generic call operation.
+ let builders = [
+ // Build a constant with a given constant tensor value.
+ OpBuilder<"Builder *builder, OperationState &result, "
+ "StringRef callee, ArrayRef<Value *> arguments", [{
+ buildGenericCallOp(builder, result, callee, arguments);
+ }]>
+ ];
+}
+
+def MulOp : Toy_Op<"mul", [NoSideEffect]> {
+ let summary = "element-wise multiplication operation";
+ let description = [{
+ The "mul" operation performs element-wise multiplication between two
+ tensors. The shapes of the tensor operands are expected to match.
+ }];
+
+ let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+ let results = (outs F64Tensor);
+
+ // Allow building a MulOp with from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+ buildMulOp(b, result, lhs, rhs);
+ }]
+ >];
+ let extraClassDeclaration = [{
+ void inferShapes() {
+ auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
+ auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
+ auto lhsRank = lhs.getShape().size();
+ auto rhsRank = rhs.getShape().size();
+ if (lhsRank != rhsRank) {
+ return;
+ }
+ SmallVector<int64_t, 2> dims;
+ if (lhsRank == 1) {
+ // dot product, result shape is <1>
+ dims.push_back(1);
+ } else {
+ if (lhsRank != 2) {
+ return;
+ }
+ dims.push_back(lhs.getShape()[0]);
+ dims.push_back(rhs.getShape()[1]);
+ }
+ getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
+ return;
+ }
+ }];
+}
+
+def PrintOp : Toy_Op<"print"> {
+ let summary = "print operation";
+ let description = [{
+ The "print" builtin operation prints a given input tensor, and produces
+ no results.
+ }];
+
+ // The print operation takes an input tensor to print.
+ let arguments = (ins F64Tensor:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+ let summary = "tensor reshape operation";
+ let description = [{
+ Reshape operation is transforming its input tensor into a new tensor with
+ the same number of elements but different shapes. For example:
+
+ ```mlir
+ %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+ ```
+ }];
+
+ let arguments = (ins F64Tensor:$input);
+ let hasCanonicalizer = 1;
+
+ // We expect that the reshape operation returns a statically shaped tensor.
+ let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+ let summary = "return operation";
+ let description = [{
+ The "return" operation represents a return operation within a function.
+ The operation takes an optional tensor operand and produces no results.
+ The operand type must match the signature of the function that contains
+ the operation. For example:
+
+ ```mlir
+ func @foo() -> tensor<2xf64> {
+ ...
+ toy.return %0 : tensor<2xf64>
+ }
+ ```
+ }];
+
+ // The return operation takes an optional input operand to return. This
+ // value must match the return type of the enclosing function.
+ let arguments = (ins Variadic<F64Tensor>:$input);
+
+ // Allow building a ReturnOp with no return operand.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ // Provide extra utility definitions on the c++ operation class definition.
+ let extraClassDeclaration = [{
+ bool hasOperand() { return getNumOperands() != 0; }
+ }];
+
+ // Invoke a static verify method to verify this return operation.
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
+ let summary = "transpose operation";
+
+ let arguments = (ins F64Tensor:$input);
+ let results = (outs F64Tensor);
+ let hasCanonicalizer = 1;
+
+ // Allow building a TransposeOp with from the two input operands.
+ let builders = [
+ OpBuilder<"Builder *b, OperationState &result, Value *input", [{
+ buildTransposeOp(b, result, input);
+ }]
+ >];
+ let extraClassDeclaration = [{
+ void inferShapes() {
+ SmallVector<int64_t, 2> dims;
+ auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+ dims.insert(dims.end(), arrayTy.getShape().begin(),
+ arrayTy.getShape().end());
+ if (dims.size() == 2)
+ std::swap(dims[0], dims[1]);
+ getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+ return;
+ }
+ }];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch4/include/toy/Passes.h b/mlir/examples/toy/Ch4/include/toy/Passes.h
index 93cf0d5ba15..8c8365d6882 100644
--- a/mlir/examples/toy/Ch4/include/toy/Passes.h
+++ b/mlir/examples/toy/Ch4/include/toy/Passes.h
@@ -26,10 +26,11 @@
namespace mlir {
class Pass;
-} // namespace mlir
namespace toy {
-std::unique_ptr<mlir::Pass> createShapeInferencePass();
-} // namespace toy
+std::unique_ptr<Pass> createShapeInferencePass();
+std::unique_ptr<Pass> createDeadFunctionEliminationPass();
+} // end namespace toy
+} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_PASSES_H
diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
new file mode 100644
index 00000000000..2040cc44fdf
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
@@ -0,0 +1,38 @@
+//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Shape Inference Op Interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+#define SHAPE_INFERENCE_INTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+ let methods = [
+ InterfaceMethod<"Infer output shape for the current operation.",
+ "void", "inferShapes", (ins), [{}]>
+ ];
+}
+
+#endif // SHAPE_INFERENCE_INTERFACE
diff --git a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp
new file mode 100644
index 00000000000..e7e64ce5b3d
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp
@@ -0,0 +1,61 @@
+//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing dead function
+// elimination. This is required as a post-processing step after function
+// inlining.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "toy/Passes.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace {
+class DeadFunctionEliminationPass
+ : public mlir::ModulePass<DeadFunctionEliminationPass> {
+public:
+ void runOnModule() override {
+ std::string str = "main";
+ auto module = getModule();
+ for (auto &f : module) {
+ // eliminate dead functions that are not main
+ if (str.find(f.getName().getStringRef()) == std::string::npos)
+ f.erase();
+ }
+ }
+};
+} // namespace
+
+/// Create a pass that eliminates inlined functions in toy.
+std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
+ return std::make_unique<DeadFunctionEliminationPass>();
+}
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
new file mode 100644
index 00000000000..63eee4eefb8
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -0,0 +1,190 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::toy;
+
+//===----------------------------------------------------------------------===//
+// ToyInlinerInterface
+//===----------------------------------------------------------------------===//
+
+/// This class defines the interface for handling inlining with Toy
+/// operations.
+struct ToyInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// All operations within toy can be inlined.
+ bool isLegalToInline(Operation *, Region *,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Transformation Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Handle the given inlined terminator(toy.return) by replacing it with a new
+ /// operation as necessary.
+ void handleTerminator(Operation *op,
+ ArrayRef<Value *> valuesToRepl) const final {
+ // Only "toy.return" needs to be handled here.
+ auto returnOp = cast<ReturnOp>(op);
+
+ // Replace the values directly with the return operands.
+ assert(returnOp.getNumOperands() == valuesToRepl.size());
+ for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+ valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+ >();
+ addInterfaces<ToyInlinerInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
+ double value) {
+ auto dataType = builder->getTensorType({}, builder->getF64Type());
+ auto dataAttribute = DenseElementsAttr::get(dataType, value);
+ ConstantOp::build(builder, state, dataType, dataAttribute);
+}
+
+/// Verifier for constant operation.
+static mlir::LogicalResult verify(ConstantOp op) {
+ // If the return type of the constant is not an unranked tensor, the shape
+ // must match the shape of the attribute holding the data.
+ auto resultType = op.getResult()->getType().cast<RankedTensorType>();
+ if (!resultType)
+ return success();
+
+ auto attrType = op.value().getType().cast<mlir::TensorType>();
+ if (attrType.getRank() != resultType.getRank()) {
+ return op.emitOpError(
+ "return type must match the one of the attached value "
+ "attribute: ")
+ << attrType.getRank() << " != " << resultType.getRank();
+ }
+ for (int dim = 0; dim < attrType.getRank(); ++dim) {
+ if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+ return op.emitOpError(
+ "return type shape mismatches its attribute at dimension ")
+ << dim << ": " << attrType.getShape()[dim]
+ << " != " << resultType.getShape()[dim];
+ }
+ }
+ return mlir::success();
+}
+
+static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ state.addTypes(builder->getTensorType(builder->getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+static void buildGenericCallOp(mlir::Builder *builder,
+ mlir::OperationState &state, StringRef callee,
+ ArrayRef<mlir::Value *> arguments) {
+ // Generic call always returns an unranked Tensor initially.
+ state.addTypes(builder->getTensorType(builder->getF64Type()));
+ state.addOperands(arguments);
+ state.addAttribute("callee", builder->getSymbolRefAttr(callee));
+}
+
+static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
+ mlir::Value *lhs, mlir::Value *rhs) {
+ state.addTypes(builder->getTensorType(builder->getF64Type()));
+ state.addOperands({lhs, rhs});
+}
+
+static mlir::LogicalResult verify(ReturnOp op) {
+ // We know that the parent operation is a function, because of the 'HasParent'
+ // trait attached to the operation definition.
+ auto function = cast<FuncOp>(op.getParentOp());
+
+ /// ReturnOps can only have a single optional operand.
+ if (op.getNumOperands() > 1)
+ return op.emitOpError() << "expects at most 1 return operand";
+
+ // The operand number and types must match the function signature.
+ const auto &results = function.getType().getResults();
+ if (op.getNumOperands() != results.size())
+ return op.emitOpError()
+ << "does not return the same number of values ("
+ << op.getNumOperands() << ") as the enclosing function ("
+ << results.size() << ")";
+
+ // If the operation does not have an input, we are done.
+ if (!op.hasOperand())
+ return mlir::success();
+
+ auto inputType = *op.operand_type_begin();
+ auto resultType = results.front();
+
+ // Check that the result type of the function matches the operand type.
+ if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
+ resultType.isa<mlir::UnrankedTensorType>())
+ return mlir::success();
+
+ return op.emitError() << "type of return operand ("
+ << *op.operand_type_begin()
+ << ") doesn't match function result type ("
+ << results.front() << ")";
+}
+
+static void buildTransposeOp(mlir::Builder *builder,
+ mlir::OperationState &state, mlir::Value *value) {
+ state.addTypes(builder->getTensorType(builder->getF64Type()));
+ state.addOperands(value);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "toy/Ops.cpp.inc"
diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
index c66335a70d5..ace52aff2bf 100644
--- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
@@ -25,30 +25,30 @@
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
+using namespace mlir::toy;
using namespace toy;
+
+using llvm::ArrayRef;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
+using llvm::makeArrayRef;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-using std::make_unique;
namespace {
@@ -57,56 +57,43 @@ namespace {
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
-///
-/// At this point we take advantage of the "raw" MLIR APIs to create operations
-/// that haven't been registered in any way with MLIR. These operations are
-/// unknown to MLIR, custom passes could operate by string-matching the name of
-/// these operations, but no other type checking or semantic is associated with
-/// them natively by MLIR.
class MLIRGenImpl {
public:
- MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
+ MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
- /// Module.
- mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
+ /// Module operation.
+ mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
+ theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->push_back(func);
+ theModule.push_back(func);
}
- // FIXME: (in the next chapter...) without registering a dialect in MLIR,
- // this won't do much, but it should at least check some structural
- // properties.
- if (failed(mlir::verify(*theModule))) {
- emitError(mlir::UnknownLoc::get(&context), "module verification error");
+ // Verify the module after we have finished constructing it, this will check
+ // the structural properties of the IR and invoke any specific verifiers we
+ // have on the Toy operations.
+ if (failed(mlir::verify(theModule))) {
+ theModule.emitError("Module verification error");
return nullptr;
}
- return std::move(theModule);
+ return theModule;
}
private:
- /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
- /// the ownership of many internal structure of the IR and provide a level
- /// of "uniquing" across multiple modules (types for instance).
- mlir::MLIRContext &context;
-
- /// A "module" matches a source file: it contains a list of functions.
- mlir::OwningModuleRef theModule;
+ /// A "module" matches a Toy source file: containing a list of functions.
+ mlir::ModuleOp theModule;
- /// The builder is a helper class to create IR inside a function. It is
- /// re-initialized every time we enter a function and kept around as a
- /// convenience for emitting individual operations.
- /// The builder is stateful, in particular it keeps an "insertion point":
- /// this is where the next operations will be introduced.
- std::unique_ptr<mlir::OpBuilder> builder;
+ /// The builder is a helper class to create IR inside a function. The builder
+ /// is stateful, in particular it keeeps an "insertion point": this is where
+ /// the next operations will be introduced.
+ mlir::OpBuilder builder;
/// The symbol table maps a variable name to a value in the current scope.
/// Entering a function creates a new scope, and the function arguments are
@@ -116,37 +103,35 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
- return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
- loc.line, loc.col, &context);
+ return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+ loc.col);
}
- /// Declare a variable in the current scope, return true if the variable
+ /// Declare a variable in the current scope, return success if the variable
/// wasn't declared yet.
- bool declare(llvm::StringRef var, mlir::Value *value) {
- if (symbolTable.count(var)) {
- return false;
- }
+ mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
+ if (symbolTable.count(var))
+ return mlir::failure();
symbolTable.insert(var, value);
- return true;
+ return mlir::success();
}
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
+ auto location = loc(proto.loc());
+
// This is a generic function, the return type will be inferred later.
- llvm::SmallVector<mlir::Type, 4> ret_types;
- // Arguments type is uniformly a generic array.
+ // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
- auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
- auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
- func_type, /* attrs = */ {});
+ auto func_type = builder.getFunctionType(arg_types, llvm::None);
+ auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function.getNumArguments())
- function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
-
+ function.setAttr("toy.generic", builder.getUnitAttr());
return function;
}
@@ -165,29 +150,39 @@ private:
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
auto &protoArgs = funcAST.getProto()->getArgs();
+
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
llvm::zip(protoArgs, entryBlock.getArguments())) {
- declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+ if (failed(declare(std::get<0>(name_value)->getName(),
+ std::get<1>(name_value))))
+ return nullptr;
}
- // Create a builder for the function, it will be used throughout the codegen
- // to create operations in this function.
- builder = std::make_unique<mlir::OpBuilder>(function.getBody());
+ // Set the insertion point in the builder to the beginning of the function
+ // body, it will be used throughout the codegen to create operations in this
+ // function.
+ builder.setInsertionPointToStart(&entryBlock);
// Emit the body of the function.
- if (!mlirGen(*funcAST.getBody())) {
+ if (mlir::failed(mlirGen(*funcAST.getBody()))) {
function.erase();
return nullptr;
}
- // Implicitly return void if no return statement was emited.
+ // Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
- if (function.getBlocks().back().back().getName().getStringRef() !=
- "toy.return") {
- ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
- mlirGen(fakeRet);
+ ReturnOp returnOp;
+ if (!entryBlock.empty())
+ returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+ if (!returnOp) {
+ builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+ } else if (returnOp.hasOperand()) {
+ // Otherwise, if this return operation has an operand then add a result to
+ // the function.
+ function.setType(builder.getFunctionType(function.getType().getInputs(),
+ getType(VarType{})));
}
return function;
@@ -206,11 +201,11 @@ private:
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
- mlir::Value *L = mlirGen(*binop.getLHS());
- if (!L)
+ mlir::Value *lhs = mlirGen(*binop.getLHS());
+ if (!lhs)
return nullptr;
- mlir::Value *R = mlirGen(*binop.getRHS());
- if (!R)
+ mlir::Value *rhs = mlirGen(*binop.getRHS());
+ if (!rhs)
return nullptr;
auto location = loc(binop.loc());
@@ -218,123 +213,112 @@ private:
// support '+' and '*'.
switch (binop.getOp()) {
case '+':
- return builder->create<AddOp>(location, L, R).getResult();
- break;
+ return builder.create<AddOp>(location, lhs, rhs);
case '*':
- return builder->create<MulOp>(location, L, R).getResult();
- default:
- emitError(location, "error: invalid binary operator '")
- << binop.getOp() << "'";
- return nullptr;
+ return builder.create<MulOp>(location, lhs, rhs);
}
+
+ emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+ return nullptr;
}
- // This is a reference to a variable in an expression. The variable is
- // expected to have been declared and so should have a value in the symbol
- // table, otherwise emit an error and return nullptr.
+ /// This is a reference to a variable in an expression. The variable is
+ /// expected to have been declared and so should have a value in the symbol
+ /// table, otherwise emit an error and return nullptr.
mlir::Value *mlirGen(VariableExprAST &expr) {
- if (symbolTable.count(expr.getName()))
- return symbolTable.lookup(expr.getName());
- emitError(loc(expr.loc()), "error: unknown variable '")
+ if (auto *variable = symbolTable.lookup(expr.getName()))
+ return variable;
+
+ emitError(loc(expr.loc()), "Error: unknown variable '")
<< expr.getName() << "'";
return nullptr;
}
- // Emit a return operation, return true on success.
- bool mlirGen(ReturnExprAST &ret) {
+ /// Emit a return operation. This will return failure if any generation fails.
+ mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
auto location = loc(ret.loc());
- // `return` takes an optional expression, we need to account for it here.
- if (!ret.getExpr().hasValue()) {
- builder->create<ReturnOp>(location);
- return true;
+
+ // 'return' takes an optional expression, handle that case here.
+ mlir::Value *expr = nullptr;
+ if (ret.getExpr().hasValue()) {
+ if (!(expr = mlirGen(*ret.getExpr().getValue())))
+ return mlir::failure();
}
- auto *expr = mlirGen(*ret.getExpr().getValue());
- if (!expr)
- return false;
- builder->create<ReturnOp>(location, expr);
- return true;
+
+ // Otherwise, this return operation has zero operands.
+ builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
+ : ArrayRef<mlir::Value *>());
+ return mlir::success();
}
- // Emit a literal/constant array. It will be emitted as a flattened array of
- // data in an Attribute attached to a `toy.constant` operation.
- // See documentation on [Attributes](LangRef.md#attributes) for more details.
- // Here is an excerpt:
- //
- // Attributes are the mechanism for specifying constant data in MLIR in
- // places where a variable is never allowed [...]. They consist of a name
- // and a [concrete attribute value](#attribute-values). It is possible to
- // attach attributes to operations, functions, and function arguments. The
- // set of expected attributes, their structure, and their interpretation
- // are all contextually dependent on what they are attached to.
- //
- // Example, the source level statement:
- // var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
- // will be converted to:
- // %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
- // [[1.000000e+00, 2.000000e+00, 3.000000e+00],
- // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
- //
+ /// Emit a literal/constant array. It will be emitted as a flattened array of
+ /// data in an Attribute attached to a `toy.constant` operation.
+ /// See documentation on [Attributes](LangRef.md#attributes) for more details.
+ /// Here is an excerpt:
+ ///
+ /// Attributes are the mechanism for specifying constant data in MLIR in
+ /// places where a variable is never allowed [...]. They consist of a name
+ /// and a concrete attribute value. The set of expected attributes, their
+ /// structure, and their interpretation are all contextually dependent on
+ /// what they are attached to.
+ ///
+ /// Example, the source level statement:
+ /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+ /// will be converted to:
+ /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+ /// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+ /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+ ///
mlir::Value *mlirGen(LiteralExprAST &lit) {
- auto location = loc(lit.loc());
- // The attribute is a vector with an attribute per element (number) in the
- // array, see `collectData()` below for more details.
- std::vector<mlir::Attribute> data;
+ auto type = getType(lit.getDims());
+
+ // The attribute is a vector with a floating point value per element
+ // (number) in the array, see `collectData()` below for more details.
+ std::vector<double> data;
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
std::multiplies<int>()));
collectData(lit, data);
- // FIXME: using a tensor type is a HACK here.
- // Can we do differently without registering a dialect? Using a string blob?
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto dataType = builder->getTensorType(lit.getDims(), elementType);
+ // The type of this attribute is tensor of 64-bit floating-point with the
+ // shape of the literal.
+ mlir::Type elementType = builder.getF64Type();
+ auto dataType = builder.getTensorType(lit.getDims(), elementType);
- // This is the actual attribute that actually hold the list of values for
- // this array literal.
- auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
- .cast<mlir::DenseElementsAttr>();
+ // This is the actual attribute that holds the list of values for this
+ // tensor literal.
+ auto dataAttribute =
+ mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
- // Build the MLIR op `toy.constant`, only boilerplate below.
- return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
- .getResult();
+ // Build the MLIR op `toy.constant`.
+ return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
- // Recursive helper function to accumulate the data that compose an array
- // literal. It flattens the nested structure in the supplied vector. For
- // example with this array:
- // [[1, 2], [3, 4]]
- // we will generate:
- // [ 1, 2, 3, 4 ]
- // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
- // Attributes are the way MLIR attaches constant to operations and functions.
- void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
+ /// Recursive helper function to accumulate the data that compose an array
+ /// literal. It flattens the nested structure in the supplied vector. For
+ /// example with this array:
+ /// [[1, 2], [3, 4]]
+ /// we will generate:
+ /// [ 1, 2, 3, 4 ]
+ /// Individual numbers are represented as doubles.
+ /// Attributes are the way MLIR attaches constant to operations.
+ void collectData(ExprAST &expr, std::vector<double> &data) {
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
for (auto &value : lit->getValues())
collectData(*value, data);
return;
}
+
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto attr = mlir::FloatAttr::getChecked(
- elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
- data.push_back(attr);
+ data.push_back(cast<NumberExprAST>(expr).getValue());
}
- // Emit a call expression. It emits specific operations for the `transpose`
- // builtin. Other identifiers are assumed to be user-defined functions.
+ /// Emit a call expression. It emits specific operations for the `transpose`
+ /// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
+ llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());
- std::string callee = call.getCallee();
- if (callee == "transpose") {
- if (call.getArgs().size() != 1) {
- emitError(location, "MLIR codegen encountered an error: toy.transpose "
- "does not accept multiple arguments");
- return nullptr;
- }
- mlir::Value *arg = mlirGen(*call.getArgs()[0]);
- return builder->create<TransposeOp>(location, arg).getResult();
- }
- // Codegen the operands first
+ // Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
@@ -342,34 +326,41 @@ private:
return nullptr;
operands.push_back(arg);
}
- // Calls to user-defined function are mapped to a custom call that takes
- // the callee name as an attribute.
- return builder->create<GenericCallOp>(location, call.getCallee(), operands)
- .getResult();
+
+ // Builting calls have their custom operation, meaning this is a
+ // straightforward emission.
+ if (callee == "transpose") {
+ if (call.getArgs().size() != 1) {
+ emitError(location, "MLIR codegen encountered an error: toy.transpose "
+ "does not accept multiple arguments");
+ return nullptr;
+ }
+ return builder.create<TransposeOp>(location, operands[0]);
+ }
+
+ // Otherwise this is a call to a user-defined function. Calls to ser-defined
+ // functions are mapped to a custom call that takes the callee name as an
+ // attribute.
+ return builder.create<GenericCallOp>(location, callee, operands);
}
- // Emit a call expression. It emits specific operations for two builtins:
- // transpose(x) and print(x). Other identifiers are assumed to be user-defined
- // functions. Return false on failure.
- bool mlirGen(PrintExprAST &call) {
+ /// Emit a print expression. It emits specific operations for two builtins:
+ /// transpose(x) and print(x).
+ mlir::LogicalResult mlirGen(PrintExprAST &call) {
auto *arg = mlirGen(*call.getArg());
if (!arg)
- return false;
- auto location = loc(call.loc());
- builder->create<PrintOp>(location, arg);
- return true;
+ return mlir::failure();
+
+ builder.create<PrintOp>(loc(call.loc()), arg);
+ return mlir::success();
}
- // Emit a constant for a single number (FIXME: semantic? broadcast?)
+ /// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
- auto location = loc(num.loc());
- mlir::Type elementType = mlir::FloatType::getF64(&context);
- auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
- loc(num.loc()));
- return builder->create<ConstantOp>(location, attr).getResult();
+ return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}
- // Dispatch codegen for the right expression subclass using RTTI.
+ /// Dispatch codegen for the right expression subclass using RTTI.
mlir::Value *mlirGen(ExprAST &expr) {
switch (expr.getKind()) {
case toy::ExprAST::Expr_BinOp:
@@ -390,77 +381,75 @@ private:
}
}
- // Handle a variable declaration, we'll codegen the expression that forms the
- // initializer and record the value in the symbol table before returning it.
- // Future expressions will be able to reference this variable through symbol
- // table lookup.
+ /// Handle a variable declaration, we'll codegen the expression that forms the
+ /// initializer and record the value in the symbol table before returning it.
+ /// Future expressions will be able to reference this variable through symbol
+ /// table lookup.
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
- mlir::Value *value = nullptr;
- auto location = loc(vardecl.loc());
- if (auto init = vardecl.getInitVal()) {
- value = mlirGen(*init);
- if (!value)
- return nullptr;
- // We have the initializer value, but in case the variable was declared
- // with specific shape, we emit a "reshape" operation. It will get
- // optimized out later as needed.
- if (!vardecl.getType().shape.empty()) {
- value = builder
- ->create<ReshapeOp>(
- location, value,
- getType(vardecl.getType()).cast<ToyArrayType>())
- .getResult();
- }
- } else {
+ auto init = vardecl.getInitVal();
+ if (!init) {
emitError(loc(vardecl.loc()),
- "missing initializer in variable declaration");
+ "Missing initializer in variable declaration");
+ return nullptr;
+ }
+
+ mlir::Value *value = mlirGen(*init);
+ if (!value)
return nullptr;
+
+ // We have the initializer value, but in case the variable was declared
+ // with specific shape, we emit a "reshape" operation. It will get
+ // optimized out later as needed.
+ if (!vardecl.getType().shape.empty()) {
+ value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+ getType(vardecl.getType()), value);
}
- // Register the value in the symbol table
- declare(vardecl.getName(), value);
+
+ // Register the value in the symbol table.
+ if (failed(declare(vardecl.getName(), value)))
+ return nullptr;
return value;
}
- /// Codegen a list of expression, return false if one of them hit an error.
- bool mlirGen(ExprASTList &blockAST) {
- ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+ /// Codegen a list of expression, return failure if one of them hit an error.
+ mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+ ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested
// expressions.
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
if (!mlirGen(*vardecl))
- return false;
+ return mlir::failure();
continue;
}
- if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
- if (!mlirGen(*ret))
- return false;
- return true;
- }
+ if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+ return mlirGen(*ret);
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
- if (!mlirGen(*print))
- return false;
+ if (mlir::failed(mlirGen(*print)))
+ return mlir::success();
continue;
}
+
// Generic expression dispatch codegen.
if (!mlirGen(*expr))
- return false;
+ return mlir::failure();
}
- return true;
+ return mlir::success();
}
- /// Build a type from a list of shape dimensions. Types are `array` followed
- /// by an optional dimension list, example: array<2, 2>
- /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
- /// !toy.array<2, 2>
- template <typename T> mlir::Type getType(T shape) {
- SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
- return ToyArrayType::get(&context, shape64);
+ /// Build a tensor type from a list of shape dimensions.
+ mlir::Type getType(ArrayRef<int64_t> shape) {
+ // If the shape is empty, then this type is unranked.
+ if (shape.empty())
+ return builder.getTensorType(builder.getF64Type());
+
+ // Otherwise, we use the given shape.
+ return builder.getTensorType(shape, builder.getF64Type());
}
- /// Build an MLIR type from a Toy AST variable type
- /// (forward to the generic getType(T) above).
+ /// Build an MLIR type from a Toy AST variable type (forward to the generic
+ /// getType above).
mlir::Type getType(const VarType &type) { return getType(type.shape); }
};
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 1600b99ec01..b8b091a62c5 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -1,4 +1,4 @@
-//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
+//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,22 +15,14 @@
// limitations under the License.
// =============================================================================
//
-// This file implements a Module level pass performing interprocedural
+// This file implements a Function level pass performing interprocedural
// propagation of array shapes through function specialization.
//
//===----------------------------------------------------------------------===//
-#include "toy/Dialect.h"
-
-#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/DenseSet.h"
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
@@ -39,48 +31,26 @@
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
-#define DEBUG_TYPE "toy-shape-inference"
+#define DEBUG_TYPE "shape-inference"
-using namespace toy;
using llvm::MutableArrayRef;
+using llvm::raw_ostream;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;
-
-/// Create a mangled name for function specialization. We will simply append the
-/// shape of the arguments to the function name. For example, calling
-///
-/// "toy.generic_call"(%1, %3) {callee: "foo"}
-/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
-///
-/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
-/// have provided a function with a similar name, but we will claim this as a
-/// feature: this allows the user to provide custom specializations!
-static std::string mangle(StringRef funcName,
- MutableArrayRef<mlir::OpOperand> operands) {
- std::string mangledName;
- mangledName.reserve(funcName.size() + operands.size() * 6);
- mangledName = funcName;
- for (auto &operand : operands) {
- auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
- mangledName += "_";
- mlir::interleave(
- arrayTy.getShape(),
- [&](int64_t dim) { mangledName += Twine(dim).str(); },
- [&]() { mangledName += "x"; });
- }
- return mangledName;
-}
+using namespace mlir;
namespace {
-/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
-/// whole. MLIR also supports FunctionPass which are restricted to modify a
-/// single function at a time. This pass couldn't be a function pass due the
-/// nature of its interprocedural transformations.
+// clang-format off
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+
+/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
+/// shape inference.
///
-/// The algorithm has two levels, first intra-procedurally:
+/// Algorithm:
///
/// 1) Build a worklist containing all the operations that are returning
/// a generic Toy array: these are the operations that need shape
@@ -94,132 +64,25 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded and we infer the
/// return type for the function from the return operation.
///
-/// There is a twist though: when a call to a generic function is encountered,
-/// shape inference requires the return type of the callee to be inferred first.
-/// At this point we need to run specialize the callee by cloning it. Here is
-/// the inter-procedural flow:
-///
-/// 1) Keep a worklist of function to process. Start with function "main".
-/// 2) While the worklist isn't empty:
-/// a) Take the last inserted function in the worklist.
-/// b) Run the intra-procedural shape inference on this function.
-/// c) If the intra-procedural shape inference can't complete, it returns
-/// a Function that needs to be inferred first. In this case, queue this
-/// new function and continue. Otherwise the inference succeeded and we
-/// can pop from the queue.
-///
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
- // One entry in the inter-procedural worklist. It keeps track of the
- // function to process, the mangled name for this specialization, and the
- // types of the arguments on which to specialize.
- struct FunctionToSpecialize {
- mlir::FuncOp function;
- std::string mangledName;
- SmallVector<mlir::Type, 4> argumentsType;
- };
-
- void runOnModule() override {
- auto module = getModule();
- auto main = module.lookupSymbol<mlir::FuncOp>("main");
- if (!main) {
- emitError(mlir::UnknownLoc::get(module.getContext()),
- "shape inference failed: can't find a main function\n");
- signalPassFailure();
- return;
- }
-
- /// Inter-procedural loop, initialize with `main` and iterate until we
- /// successfully infer the full reachable call-graph from main.
- SmallVector<FunctionToSpecialize, 8> worklist;
- worklist.push_back({main, "", {}});
- while (!worklist.empty()) {
- if (failed(specialize(worklist)))
- return;
- }
-
- // Delete any generic function left
- // FIXME: we may want this as a separate pass.
- for (mlir::FuncOp function :
- llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
- if (auto genericAttr =
- function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
- if (genericAttr.getValue())
- function.erase();
- }
+ bool returnsGenericArray(Operation *op) {
+ if (op->getNumResults() == 1) {
+ if (!op->getResult(0)->getType().isa<ShapedType>())
+ return true;
}
+ return false;
}
- /// Run inference on a function. If a mangledName is provided, we need to
- /// specialize the function: to this end clone it first.
- mlir::LogicalResult
- specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
- FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
- mlir::FuncOp f = functionToSpecialize.function;
-
- // Check if cloning for specialization is needed (usually anything but main)
- // We will create a new function with the concrete types for the parameters
- // and clone the body into it.
- if (!functionToSpecialize.mangledName.empty()) {
- if (getModule().lookupSymbol<mlir::FuncOp>(
- functionToSpecialize.mangledName)) {
- funcWorklist.pop_back();
- // Function already specialized, move on.
- return mlir::success();
- }
- // Create a new function with a generic array return type, it will be
- // updated when the inference for the function body completes.
- auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
- {ToyArrayType::get(&getContext())},
- &getContext());
- auto newFunction =
- mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
- type, f.getDialectAttrs());
- getModule().push_back(newFunction);
-
- // Clone the function body
- mlir::BlockAndValueMapping mapper;
- f.cloneInto(newFunction, mapper);
- LLVM_DEBUG({
- llvm::dbgs() << "====== Cloned : \n";
- f.dump();
- llvm::dbgs() << "====== Into : \n";
- newFunction.dump();
- });
- f = newFunction;
- f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
- // Remap the entry-block arguments
- // FIXME: this seems like a bug in `cloneInto()` above?
- auto &entryBlock = f.getBlocks().front();
- int blockArgSize = entryBlock.getArguments().size();
- assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
- entryBlock.addArguments(f.getType().getInputs());
- auto argList = entryBlock.getArguments();
- for (int argNum = 0; argNum < blockArgSize; ++argNum) {
- argList[0]->replaceAllUsesWith(argList[blockArgSize]);
- entryBlock.eraseArgument(0);
- }
- assert(succeeded(mlir::verify(f)));
- }
- LLVM_DEBUG(llvm::dbgs()
- << "Run shape inference on : '" << f.getName() << "'\n");
-
- auto *toyDialect = getContext().getRegisteredDialect("toy");
- if (!toyDialect) {
- emitError(mlir::UnknownLoc::get(&getContext()),
- "Toy dialect is not registered");
- signalPassFailure();
- return mlir::failure();
- }
+ void runOnFunction() override {
+ auto f = getFunction();
// Populate the worklist with the operations that need shape inference:
- // these are the Toy operations that return a generic array.
+ // these are operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f.walk([&](mlir::Operation *op) {
- if (op->getDialect() == toyDialect) {
- if (op->getNumResults() == 1 &&
- op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
- opWorklist.insert(op);
+ if (returnsGenericArray(op)) {
+ opWorklist.insert(op);
}
});
@@ -228,154 +91,31 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
- return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
- return !ty.cast<ToyArrayType>().isGeneric();
- });
+ auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
+ return this->returnsGenericArray(op);
});
+
if (nextop == opWorklist.end())
break; // failure: no operations can be inferred.
- mlir::Operation *op = *nextop;
+ Operation *op = *nextop;
opWorklist.erase(op);
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-
- // The add operation is trivial: propagate the input type as is.
- if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
- op->getResult(0)->setType(op->getOperand(0)->getType());
- continue;
- }
-
- // Transpose is easy: just invert the dimensions.
- if (auto transpose = llvm::dyn_cast<TransposeOp>(op)) {
- SmallVector<int64_t, 2> dims;
- auto arrayTy = transpose.getOperand()->getType().cast<ToyArrayType>();
- dims.insert(dims.end(), arrayTy.getShape().begin(),
- arrayTy.getShape().end());
- transpose.getResult()->setType(ToyArrayType::get(&getContext(), dims));
- continue;
- }
-
- // Multiplication is a bit trickier, handle rank 1 as dot product and rank
- // 2 as matrix multiplications.
- // We need to be careful about rank mismatch here: the verifier could
- // catch it but shape inference earlier in the pass could generate an
- // invalid IR (from an invalid Toy input of course) and we wouldn't want
- // to crash here.
- if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
- auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
- auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
- auto lhsRank = lhs.getShape().size();
- auto rhsRank = rhs.getShape().size();
- if (lhsRank != rhsRank) {
- return mulOp.emitOpError(
- "shape mismatch: LHS and RHS must have the same "
- "rank for multiplication, got " +
- Twine(lhsRank) + " vs " + Twine(lhsRank));
- }
- SmallVector<int64_t, 2> dims;
- if (lhsRank == 1) {
- // dot product, result shape is <1>
- dims.push_back(1);
- } else if (lhsRank != 2) {
- return op->emitOpError(
- "shape mismatch: expect rank 1 or 2 for mul operands, got " +
- Twine(lhsRank));
- } else {
- dims.push_back(lhs.getShape()[0]);
- dims.push_back(rhs.getShape()[1]);
- }
- op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
- continue;
- }
-
- // Process calls: lookup the callee after mangling the name with the
- // argument shapes. If the callee does not exist, we stop the inference
- // for this function, queue the callee in the inter-procedural work list,
- // and return. The current function stays in the work list and will
- // restart after the callee is processed.
- if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
- auto calleeName = callOp.getCalleeName();
- auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
- if (!callee) {
- f.emitError("shape inference failed, call to unknown '")
- << calleeName << "'";
- signalPassFailure();
- return mlir::failure();
- }
- auto mangledName = mangle(calleeName, op->getOpOperands());
- LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
- << "', mangled: '" << mangledName << "'\n");
- auto mangledCallee =
- getModule().lookupSymbol<mlir::FuncOp>(mangledName);
- if (!mangledCallee) {
- // Can't find the target, this is where we queue the request for the
- // callee and stop the inference for the current function now.
- funcWorklist.push_back({callee, std::move(mangledName),
- llvm::to_vector<4>(op->getOperandTypes())});
- return mlir::success();
- }
- // Found a specialized callee! Let's turn this into a normal call
- // operation.
- SmallVector<mlir::Value *, 8> operands(op->getOperands());
- mlir::OpBuilder builder(op);
- auto newCall =
- builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
- if (newCall.getNumResults()) {
- op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
- op->erase();
- continue;
- }
- }
+ auto shapeOp = dyn_cast<ShapeInference>(op);
+ shapeOp.inferShapes();
}
- // Done with inference on this function, removing it from the worklist.
- funcWorklist.pop_back();
- // Mark the function as non-generic now that inference has succeeded
- f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
- std::string str;
- llvm::raw_string_ostream errorMsg(str);
- errorMsg << "shape inference failed, " << opWorklist.size()
- << " operations couldn't be inferred\n";
- for (auto *ope : opWorklist)
- errorMsg << " - " << *ope << "\n";
- f.emitError(errorMsg.str());
signalPassFailure();
- return mlir::failure();
- }
-
- // Finally, update the return type of the function based on the argument to
- // the return operation.
- for (auto &block : f.getBlocks()) {
- auto ret = llvm::cast<ReturnOp>(block.getTerminator());
- if (!ret)
- continue;
- if (ret.getNumOperands() &&
- f.getType().getResult(0) == ret.getOperand()->getType())
- // type match, we're done
- break;
- SmallVector<mlir::Type, 1> retTy;
- if (ret.getNumOperands())
- retTy.push_back(ret.getOperand()->getType());
- std::vector<mlir::Type> argumentsType;
- for (auto arg : f.getArguments())
- argumentsType.push_back(arg->getType());
- auto newType =
- mlir::FunctionType::get(argumentsType, retTy, &getContext());
- f.setType(newType);
- assert(succeeded(mlir::verify(f)));
- break;
+ auto diag = f.emitError("Shape inference failed, ")
+ << opWorklist.size() << " operations couldn't be inferred\n";
}
- return mlir::success();
}
};
} // end anonymous namespace
-namespace toy {
-std::unique_ptr<mlir::Pass> createShapeInferencePass() {
+/// Create a Shape Inference pass.
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>();
}
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index b89cb85ff06..1b9dcd20291 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -15,24 +15,25 @@
// limitations under the License.
// =============================================================================
//
-// This file implements a simple combiner for optimizing pattern in the Toy
-// dialect.
+// This file implements a set of simple combiners for optimizing operations in
+// the Toy dialect.
//
//===----------------------------------------------------------------------===//
-#include "toy/Dialect.h"
-
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-
+#include "toy/Dialect.h"
#include <numeric>
-
-namespace toy {
+using namespace mlir;
+using namespace toy;
namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "ToyCombine.inc"
+} // end anonymous namespace
-/// Fold transpose(transpose(x) -> transpose(x)
+/// This is an example of a c++ rewrite pattern for the TransposeOp. It
+/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
@@ -40,9 +41,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
- /// This method is attempting to match a pattern and rewrite it. The rewriter
- /// argument is the orchestrator of the sequence of rewrites. It is expected
- /// to interact with it to perform any changes to the IR from here.
+ /// This method attempts to match a pattern and rewrite it. The rewriter
+ /// argument is the orchestrator of the sequence of rewrites. The pattern is
+ /// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
@@ -50,106 +51,28 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
// If the input is defined by another Transpose, bingo!
if (!transposeInputOp)
return matchFailure();
- // Use the rewriter to perform the replacement
+ // Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
}
};
-/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
- using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- mlir::PatternMatchResult
- matchAndRewrite(ReshapeOp reshape,
- mlir::PatternRewriter &rewriter) const override {
- // Look through the input of the current reshape.
- ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
- reshape.getOperand()->getDefiningOp());
- // If the input is defined by another constant, bingo!
- if (!constantOp)
- return matchFailure();
-
- auto reshapeType = reshape.getType().cast<ToyArrayType>();
- if (auto valueAttr =
- constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
- // FIXME Check matching of element count!
- // auto oldType = constantOp.getType();
- auto newType = rewriter.getTensorType(
- reshapeType.getShape(), valueAttr.getType().getElementType());
- auto newAttr = valueAttr.reshape(newType);
- rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
- newAttr);
- } else if (auto valueAttr =
- constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
- // Broadcast
- auto dataSize = std::accumulate(reshapeType.getShape().begin(),
- reshapeType.getShape().end(), 1,
- std::multiplies<int>());
- std::vector<mlir::Attribute> data(dataSize, valueAttr);
- auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
- reshapeType.getElementType());
- auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
- rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
- newAttr);
- } else {
- llvm_unreachable("Unsupported Constant format");
- }
- return matchSuccess();
- }
-};
-
-/// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
- using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- mlir::PatternMatchResult
- matchAndRewrite(ReshapeOp op,
- mlir::PatternRewriter &rewriter) const override {
- // Look through the input of the current reshape.
- mlir::Value *reshapeInput = op.getOperand();
-
- // If the input is defined by another reshape, bingo!
- if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
- return matchFailure();
-
- // Use the rewriter to perform the replacement
- rewriter.replaceOp(op, {reshapeInput});
- return matchSuccess();
- }
-};
-
-/// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
- using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- mlir::PatternMatchResult
- matchAndRewrite(ReshapeOp op,
- mlir::PatternRewriter &rewriter) const override {
- if (op.getOperand()->getType() != op.getType())
- return matchFailure();
- rewriter.replaceOp(op, {op.getOperand()});
- return matchSuccess();
- }
-};
-
-} // end anonymous namespace.
-
-// Register our patterns for rewrite by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+/// Register our patterns as "canonicalization" patterns on the TransposeOp so
+/// that they can be picked up by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context);
}
-// Register our patterns for rewrite by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
- SimplifyNullReshape>(context);
+/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
+/// that they can be picked up by the Canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
-
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.td b/mlir/examples/toy/Ch4/mlir/ToyCombine.td
new file mode 100644
index 00000000000..97b9be4c353
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.td
@@ -0,0 +1,73 @@
+//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for Toy using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+#ifndef OP_BASE
+include "toy/Ops.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+/// dag sourcePattern, list<dag> resultPatterns,
+/// list<dag> additionalConstraints = [],
+/// dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = x
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+ (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x'
+def ReshapeConstant :
+ NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+ (ReshapeOp:$res (ConstantOp $arg)),
+ (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+ (ReshapeOp:$res $arg), (replaceWithValue $arg),
+ [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp
deleted file mode 100644
index f77754e368f..00000000000
--- a/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements the dialect for the Toy IR: custom type parsing and
-// operation verification.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/Regex.h"
-#include "llvm/Support/raw_ostream.h"
-
-using llvm::ArrayRef;
-using llvm::raw_ostream;
-using llvm::raw_string_ostream;
-using llvm::SmallVector;
-using llvm::StringRef;
-using llvm::Twine;
-
-namespace toy {
-namespace detail {
-
-/// This class holds the implementation of the ToyArrayType.
-/// It is intended to be uniqued based on its content and owned by the context.
-struct ToyArrayTypeStorage : public mlir::TypeStorage {
- /// This defines how we unique this type in the context: our key contains
- /// only the shape, a more complex type would have multiple entries in the
- /// tuple here.
- /// The element of the tuples usually matches 1-1 the arguments from the
- /// public `get()` method arguments from the facade.
- using KeyTy = std::tuple<ArrayRef<int64_t>>;
- static unsigned hashKey(const KeyTy &key) {
- return llvm::hash_combine(std::get<0>(key));
- }
- /// When the key hash hits an existing type, we compare the shape themselves
- /// to confirm we have the right type.
- bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
-
- /// This is a factory method to create our type storage. It is only
- /// invoked after looking up the type in the context using the key and not
- /// finding it.
- static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
- const KeyTy &key) {
- // Copy the shape array into the bumpptr allocator owned by the context.
- ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
-
- // Allocate the instance for the ToyArrayTypeStorage itself
- auto *storage = allocator.allocate<ToyArrayTypeStorage>();
- // Initialize the instance using placement new.
- return new (storage) ToyArrayTypeStorage(shape);
- }
-
- ArrayRef<int64_t> getShape() const { return shape; }
-
-private:
- ArrayRef<int64_t> shape;
-
- /// Constructor is only invoked from the `construct()` method above.
- ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
-};
-
-} // namespace detail
-
-mlir::Type ToyArrayType::getElementType() {
- return mlir::FloatType::getF64(getContext());
-}
-
-ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
- ArrayRef<int64_t> shape) {
- return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
-}
-
-ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
-
-/// Dialect creation, the instance will be owned by the context. This is the
-/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
- addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
- MulOp, AddOp, ReturnOp>();
- addTypes<ToyArrayType>();
-}
-
-/// Parse a type registered to this dialect, we expect only Toy arrays.
-mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
- // Sanity check: we only support array or array<...>
- if (!tyData.startswith("array")) {
- emitError(loc, "invalid Toy type '" + tyData + "', array expected");
- return nullptr;
- }
- // Drop the "array" prefix from the type name, we expect either an empty
- // string or just the shape.
- tyData = tyData.drop_front(StringRef("array").size());
- // This is the generic array case without shape, early return it.
- if (tyData.empty())
- return ToyArrayType::get(getContext());
-
- // Use a regex to parse the shape (for efficient we should store this regex in
- // the dialect itself).
- SmallVector<StringRef, 4> matches;
- auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
- if (!shapeRegex.match(tyData, &matches)) {
- emitError(loc, "invalid toy array shape '" + tyData + "'");
- return nullptr;
- }
- SmallVector<int64_t, 4> shape;
- // Iterate through the captures, skip the first one which is the full string.
- for (auto dimStr :
- llvm::make_range(std::next(matches.begin()), matches.end())) {
- if (dimStr.startswith(","))
- continue; // POSIX misses non-capturing groups.
- if (dimStr.empty())
- continue; // '*' makes it an optional group capture
- // Convert the capture to an integer
- unsigned long long dim;
- if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
- emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
- return mlir::Type();
- }
- shape.push_back(dim);
- }
- // Finally we collected all the dimensions in the shape,
- // create the array type.
- return ToyArrayType::get(getContext(), shape);
-}
-
-/// Print a Toy array type, for example `array<2, 3, 4>`
-void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
- auto arrayTy = type.dyn_cast<ToyArrayType>();
- if (!arrayTy) {
- os << "unknown toy type";
- return;
- }
- os << "array";
- if (!arrayTy.getShape().empty()) {
- os << "<";
- mlir::interleaveComma(arrayTy.getShape(), os);
- os << ">";
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Helper to verify that the result of an operation is a Toy array type.
-template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
- if (!op->getResult()->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its argument, got "
- << op->getResult()->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-/// Helper to verify that the two operands of a binary operation are Toy
-/// arrays..
-template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
- if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its LHS, got "
- << op->getOperand(0)->getType();
- return op->emitOpError(os.str());
- }
- if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its LHS, got "
- << op->getOperand(0)->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
- ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
- state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
- auto dataAttribute = builder->getNamedAttr("value", value);
- state.attributes.push_back(dataAttribute);
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::FloatAttr value) {
- // Broadcast and forward to the other build factory
- mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
- auto dataType = builder->getTensorType({1}, elementType);
- auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
- .cast<mlir::DenseElementsAttr>();
-
- ConstantOp::build(builder, state, {1}, dataAttribute);
-}
-
-/// Verifier for constant operation.
-mlir::LogicalResult ConstantOp::verify() {
- // Ensure that the return type is a Toy array
- if (failed(verifyToyReturnArray(this)))
- return mlir::failure();
-
- // We expect the constant itself to be stored as an attribute.
- auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
- if (!dataAttr) {
- return emitOpError(
- "missing valid `value` DenseElementsAttribute on toy.constant()");
- }
- auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
- if (!attrType) {
- return emitOpError(
- "missing valid `value` DenseElementsAttribute on toy.constant()");
- }
-
- // If the return type of the constant is not a generic array, the shape must
- // match the shape of the attribute holding the data.
- auto resultType = getResult()->getType().cast<ToyArrayType>();
- if (!resultType.isGeneric()) {
- if (attrType.getRank() != resultType.getRank()) {
- return emitOpError("The rank of the toy.constant return type must match "
- "the one of the attached value attribute: " +
- Twine(attrType.getRank()) +
- " != " + Twine(resultType.getRank()));
- }
- for (int dim = 0; dim < attrType.getRank(); ++dim) {
- if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
- std::string msg;
- raw_string_ostream os(msg);
- return emitOpError(
- "Shape mismatch between toy.constant return type and its "
- "attribute at dimension " +
- Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
- " != " + Twine(resultType.getShape()[dim]));
- }
- }
- }
- return mlir::success();
-}
-
-void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
- StringRef callee, ArrayRef<mlir::Value *> arguments) {
- // Generic call always returns a generic ToyArray initially
- state.types.push_back(ToyArrayType::get(builder->getContext()));
- state.operands.assign(arguments.begin(), arguments.end());
- auto calleeAttr = builder->getStringAttr(callee);
- state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
-}
-
-mlir::LogicalResult GenericCallOp::verify() {
- // Verify that every operand is a Toy Array
- for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
- if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its " << opId << " operand, got "
- << getOperand(opId)->getType();
- return emitOpError(os.str());
- }
- }
- return mlir::success();
-}
-
-/// Return the name of the callee.
-StringRef GenericCallOp::getCalleeName() {
- return getAttr("callee").cast<mlir::StringAttr>().getValue();
-}
-
-template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
- if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
- std::string msg;
- raw_string_ostream os(msg);
- os << "expects a Toy Array for its argument, got "
- << op->getOperand()->getType();
- return op->emitOpError(os.str());
- }
- return mlir::success();
-}
-
-void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value) {
- // Return does not return any value and has an optional single argument
- if (value)
- state.operands.push_back(value);
-}
-
-mlir::LogicalResult ReturnOp::verify() {
- if (getNumOperands() > 1)
- return emitOpError("expects zero or one operand, got " +
- Twine(getNumOperands()));
- if (hasOperand() && failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value) {
- // Print does not return any value and has a single argument
- state.operands.push_back(value);
-}
-
-mlir::LogicalResult PrintOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value) {
- state.types.push_back(ToyArrayType::get(builder->getContext()));
- state.operands.push_back(value);
-}
-
-mlir::LogicalResult TransposeOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *value, ToyArrayType reshapedType) {
- state.types.push_back(reshapedType);
- state.operands.push_back(value);
-}
-
-mlir::LogicalResult ReshapeOp::verify() {
- if (failed(verifyToySingleOperand(this)))
- return mlir::failure();
- auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
- if (!retTy)
- return emitOpError("toy.reshape is expected to produce a Toy array");
- if (retTy.isGeneric())
- return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
- "got a generic one.");
- return mlir::success();
-}
-
-void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *lhs, mlir::Value *rhs) {
- state.types.push_back(ToyArrayType::get(builder->getContext()));
- state.operands.push_back(lhs);
- state.operands.push_back(rhs);
-}
-
-mlir::LogicalResult AddOp::verify() {
- if (failed(verifyToyBinOperands(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
- mlir::Value *lhs, mlir::Value *rhs) {
- state.types.push_back(ToyArrayType::get(builder->getContext()));
- state.operands.push_back(lhs);
- state.operands.push_back(rhs);
-}
-
-mlir::LogicalResult MulOp::verify() {
- if (failed(verifyToyBinOperands(this)))
- return mlir::failure();
- return mlir::success();
-}
-
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index 1c084e0918b..6f75269b9be 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -80,54 +80,63 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule();
}
-mlir::LogicalResult optimize(mlir::ModuleOp module) {
- mlir::PassManager pm(module.getContext());
- pm.addPass(mlir::createCanonicalizerPass());
- pm.addPass(createShapeInferencePass());
- pm.addPass(mlir::createCanonicalizerPass());
- // Apply any generic pass manager command line options.
- applyPassManagerCLOptions(pm);
-
- return pm.run(module);
+int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+ // Handle '.toy' input to the compiler.
+ if (inputType != InputType::MLIR &&
+ !llvm::StringRef(inputFilename).endswith(".mlir")) {
+ auto moduleAST = parseInputFile(inputFilename);
+ module = mlirGen(context, *moduleAST);
+ return !module ? 1 : 0;
+ }
+
+ // Otherwise, the input is '.mlir'.
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+ llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+ if (std::error_code EC = fileOrErr.getError()) {
+ llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+ return -1;
+ }
+
+ // Parse the input mlir.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+ module = mlir::parseSourceFile(sourceMgr, &context);
+ if (!module) {
+ llvm::errs() << "Error can't load file " << inputFilename << "\n";
+ return 3;
+ }
+ return 0;
}
int dumpMLIR() {
- // Register our Dialect with MLIR
- mlir::registerDialect<ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
- if (inputType == InputType::MLIR ||
- llvm::StringRef(inputFilename).endswith(".mlir")) {
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
- llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
- if (std::error_code EC = fileOrErr.getError()) {
- llvm::errs() << "Could not open input file: " << EC.message() << "\n";
- return -1;
- }
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
- module = mlir::parseSourceFile(sourceMgr, &context);
- if (!module) {
- llvm::errs() << "Error can't load file " << inputFilename << "\n";
- return 3;
- }
- if (failed(mlir::verify(*module))) {
- llvm::errs() << "Error verifying MLIR module\n";
- return 4;
- }
- } else {
- auto moduleAST = parseInputFile(inputFilename);
- module = mlirGen(context, *moduleAST);
- }
- if (!module)
- return 1;
+ if (int error = loadMLIR(context, module))
+ return error;
+
if (EnableOpt) {
- if (failed(optimize(*module))) {
- llvm::errs() << "Module optimization failed\n";
- return 7;
- }
+ mlir::PassManager pm(&context);
+ // Apply any generic pass manager command line options and run the pipeline.
+ applyPassManagerCLOptions(pm);
+
+ // Add a run of the canonicalizer to optimize the mlir module.
+ pm.addPass(mlir::createCanonicalizerPass());
+
+ // Inline all functions into main and then delete them.
+ pm.addPass(mlir::createInlinerPass());
+ pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
+
+ // Now that there is only one function, we can infer the shapes of each of
+ // the operations.
+ pm.addPass(mlir::toy::createShapeInferencePass());
+
+ if (mlir::failed(pm.run(*module)))
+ return 4;
}
+
module->dump();
return 0;
}
OpenPOWER on IntegriCloud