summaryrefslogtreecommitdiffstats
path: root/mlir/test/lib/TestDialect
diff options
context:
space:
mode:
authorMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
committerMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
commit0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch)
tree31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/test/lib/TestDialect
parent6f635f90929da9545dd696071a829a1a42f84b30 (diff)
parent5b4a01d4a63cb66ab981e52548f940813393bf42 (diff)
downloadbcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz
bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/test/lib/TestDialect')
-rw-r--r--mlir/test/lib/TestDialect/CMakeLists.txt28
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp316
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.h53
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td1047
-rw-r--r--mlir/test/lib/TestDialect/TestPatterns.cpp504
-rw-r--r--mlir/test/lib/TestDialect/lit.local.cfg1
6 files changed, 1949 insertions, 0 deletions
diff --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt
new file mode 100644
index 00000000000..e6a22833de4
--- /dev/null
+++ b/mlir/test/lib/TestDialect/CMakeLists.txt
@@ -0,0 +1,28 @@
+set(LLVM_OPTIONAL_SOURCES
+ TestDialect.cpp
+ TestPatterns.cpp
+)
+
+set(LLVM_TARGET_DEFINITIONS TestOps.td)
+mlir_tablegen(TestOps.h.inc -gen-op-decls)
+mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(TestPatterns.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestOpsIncGen)
+
+add_llvm_library(MLIRTestDialect
+ TestDialect.cpp
+ TestPatterns.cpp
+)
+add_dependencies(MLIRTestDialect
+ MLIRTestOpsIncGen
+ MLIRIR
+ LLVMSupport
+ MLIRTypeInferOpInterfaceIncGen
+)
+target_link_libraries(MLIRTestDialect
+ MLIRDialect
+ MLIRIR
+ LLVMSupport
+)
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
new file mode 100644
index 00000000000..21cf69ec1fa
--- /dev/null
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -0,0 +1,316 @@
+//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TestDialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Test support for interacting with the AsmPrinter.
+struct TestOpAsmInterface : public OpAsmDialectInterface {
+ using OpAsmDialectInterface::OpAsmDialectInterface;
+
+ void getAsmResultNames(Operation *op,
+ OpAsmSetValueNameFn setNameFn) const final {
+ if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
+ setNameFn(asmOp, "result");
+ }
+
+ void getAsmBlockArgumentNames(Block *block,
+ OpAsmSetValueNameFn setNameFn) const final {
+ auto op = block->getParentOp();
+ auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
+ if (!arrayAttr)
+ return;
+ auto args = block->getArguments();
+ auto e = std::min(arrayAttr.size(), args.size());
+ for (unsigned i = 0; i < e; ++i) {
+ if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
+ setNameFn(args[i], strAttr.getValue());
+ }
+ }
+};
+
+struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
+ using OpFolderDialectInterface::OpFolderDialectInterface;
+
+ /// Registered hook to check if the given region, which is attached to an
+ /// operation that is *not* isolated from above, should be used when
+ /// materializing constants.
+ bool shouldMaterializeInto(Region *region) const final {
+ // If this is a one region operation, then insert into it.
+ return isa<OneRegionOp>(region->getParentOp());
+ }
+};
+
+/// This class defines the interface for handling inlining with standard
+/// operations.
+struct TestInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks
+ //===--------------------------------------------------------------------===//
+
+ bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
+ // Inlining into test dialect regions is legal.
+ return true;
+ }
+ bool isLegalToInline(Operation *, Region *,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+
+ bool shouldAnalyzeRecursively(Operation *op) const final {
+ // Analyze recursively if this is not a functional region operation, it
+ // froms a separate functional scope.
+ return !isa<FunctionalRegionOp>(op);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Transformation Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Handle the given inlined terminator by replacing it with a new operation
+ /// as necessary.
+ void handleTerminator(Operation *op,
+ ArrayRef<Value> valuesToRepl) const final {
+ // Only handle "test.return" here.
+ auto returnOp = dyn_cast<TestReturnOp>(op);
+ if (!returnOp)
+ return;
+
+ // Replace the values directly with the return operands.
+ assert(returnOp.getNumOperands() == valuesToRepl.size());
+ for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+ valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+ }
+
+ /// Attempt to materialize a conversion for a type mismatch between a call
+ /// from this dialect, and a callable region. This method should generate an
+ /// operation that takes 'input' as the only operand, and produces a single
+ /// result of 'resultType'. If a conversion can not be generated, nullptr
+ /// should be returned.
+ Operation *materializeCallConversion(OpBuilder &builder, Value input,
+ Type resultType,
+ Location conversionLoc) const final {
+ // Only allow conversion for i16/i32 types.
+ if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
+ !(input->getType().isInteger(16) || input->getType().isInteger(32)))
+ return nullptr;
+ return builder.create<TestCastOp>(conversionLoc, resultType, input);
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+TestDialect::TestDialect(MLIRContext *context)
+ : Dialect(getDialectName(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "TestOps.cpp.inc"
+ >();
+ addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
+ TestInlinerInterface>();
+ allowUnknownOperations();
+}
+
+LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
+ NamedAttribute namedAttr) {
+ if (namedAttr.first == "test.invalid_attr")
+ return op->emitError() << "invalid to use 'test.invalid_attr'";
+ return success();
+}
+
+LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute namedAttr) {
+ if (namedAttr.first == "test.invalid_attr")
+ return op->emitError() << "invalid to use 'test.invalid_attr'";
+ return success();
+}
+
+LogicalResult
+TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+ unsigned resultIndex,
+ NamedAttribute namedAttr) {
+ if (namedAttr.first == "test.invalid_attr")
+ return op->emitError() << "invalid to use 'test.invalid_attr'";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test IsolatedRegionOp - parse passthrough region arguments.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType argInfo;
+ Type argType = parser.getBuilder().getIndexType();
+
+ // Parse the input operand.
+ if (parser.parseOperand(argInfo) ||
+ parser.resolveOperand(argInfo, argType, result.operands))
+ return failure();
+
+ // Parse the body region, and reuse the operand info as the argument info.
+ Region *body = result.addRegion();
+ return parser.parseRegion(*body, argInfo, argType,
+ /*enableNameShadowing=*/true);
+}
+
+static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
+ p << "test.isolated_region ";
+ p.printOperand(op.getOperand());
+ p.shadowRegionArgs(op.region(), op.getOperand());
+ p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// Test parser.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
+ OperationState &result) {
+ StringRef keyword;
+ if (parser.parseKeyword(&keyword))
+ return failure();
+ result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
+ return success();
+}
+
+static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
+ p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
+}
+
+//===----------------------------------------------------------------------===//
+// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
+
+static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
+ OperationState &result) {
+ if (parser.parseKeyword("wraps"))
+ return failure();
+
+ // Parse the wrapped op in a region
+ Region &body = *result.addRegion();
+ body.push_back(new Block);
+ Block &block = body.back();
+ Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
+ if (!wrapped_op)
+ return failure();
+
+ // Create a return terminator in the inner region, pass as operand to the
+ // terminator the returned values from the wrapped operation.
+ SmallVector<Value, 8> return_operands(wrapped_op->getResults());
+ OpBuilder builder(parser.getBuilder().getContext());
+ builder.setInsertionPointToEnd(&block);
+ builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
+
+ // Get the results type for the wrapping op from the terminator operands.
+ Operation &return_op = body.back().back();
+ result.types.append(return_op.operand_type_begin(),
+ return_op.operand_type_end());
+
+ // Use the location of the wrapped op for the "test.wrapping_region" op.
+ result.location = wrapped_op->getLoc();
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, WrappingRegionOp op) {
+ p << op.getOperationName() << " wraps ";
+ p.printGenericOp(&op.region().front().front());
+}
+
+//===----------------------------------------------------------------------===//
+// Test PolyForOp - parse list of region arguments.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
+ // Parse list of region arguments without a delimiter.
+ if (parser.parseRegionArgumentList(ivsInfo))
+ return failure();
+
+ // Parse the body region.
+ Region *body = result.addRegion();
+ auto &builder = parser.getBuilder();
+ SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
+ return parser.parseRegion(*body, ivsInfo, argTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// Test removing op with inner ops.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestRemoveOpWithInnerOps
+ : public OpRewritePattern<TestOpWithRegionPattern> {
+ using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
+ PatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace
+
+void TestOpWithRegionPattern::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<TestRemoveOpWithInnerOps>(context);
+}
+
+OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
+ return operand();
+}
+
+LogicalResult TestOpWithVariadicResultsAndFolder::fold(
+ ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+ for (Value input : this->operands()) {
+ results.push_back(input);
+ }
+ return success();
+}
+
+LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
+ llvm::Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferedReturnTypes) {
+ if (operands[0]->getType() != operands[1]->getType()) {
+ return emitOptionalError(location, "operand type mismatch ",
+ operands[0]->getType(), " vs ",
+ operands[1]->getType());
+ }
+ inferedReturnTypes.assign({operands[0]->getType()});
+ return success();
+}
+
+// Static initialization for Test dialect registration.
+static mlir::DialectRegistration<mlir::TestDialect> testDialect;
+
+#include "TestOpEnums.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h
new file mode 100644
index 00000000000..20db0f39b81
--- /dev/null
+++ b/mlir/test/lib/TestDialect/TestDialect.h
@@ -0,0 +1,53 @@
+//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a fake 'test' dialect that can be used for testing things
+// that do not have a respective counterpart in the main source directories.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTDIALECT_H
+#define MLIR_TESTDIALECT_H
+
+#include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/InferTypeOpInterface.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+
+#include "TestOpEnums.h.inc"
+
+namespace mlir {
+
+class TestDialect : public Dialect {
+public:
+ /// Create the dialect in the given `context`.
+ TestDialect(MLIRContext *context);
+
+ /// Get the canonical string name of the dialect.
+ static StringRef getDialectName() { return "test"; }
+
+ LogicalResult verifyOperationAttribute(Operation *op,
+ NamedAttribute namedAttr) override;
+ LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute namedAttr) override;
+ LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+ unsigned resultIndex,
+ NamedAttribute namedAttr) override;
+};
+
+#define GET_OP_CLASSES
+#include "TestOps.h.inc"
+
+} // end namespace mlir
+
+#endif // MLIR_TESTDIALECT_H
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
new file mode 100644
index 00000000000..dacb796de18
--- /dev/null
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -0,0 +1,1047 @@
+//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_OPS
+#define TEST_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/InferTypeOpInterface.td"
+
+def TEST_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "";
+}
+
+class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<TEST_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Test Types
+//===----------------------------------------------------------------------===//
+
+def ComplexF64 : Complex<F64>;
+def ComplexOp : TEST_Op<"complex_f64"> {
+ let results = (outs ComplexF64);
+}
+
+def ComplexTensorOp : TEST_Op<"complex_f64_tensor"> {
+ let results = (outs TensorOf<[ComplexF64]>);
+}
+
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+
+def TupleOp : TEST_Op<"tuple_32_bit"> {
+ let results = (outs TupleOf<[I32, F32]>);
+}
+
+def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
+ let results = (outs NestedTupleOf<[I32, F32]>);
+}
+
+def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
+ let arguments = (ins AnyStaticShapeMemRef:$x);
+}
+
+def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> {
+ let results = (outs MemRefRankOf<[I8, F32], [0, 1]>);
+}
+
+def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
+ let arguments = (ins
+ 0DTensorOf<[F32]>:$arg0,
+ 1DTensorOf<[F32]>:$arg1,
+ 2DTensorOf<[I16]>:$arg2,
+ 3DTensorOf<[I16]>:$arg3,
+ 4DTensorOf<[I16]>:$arg4
+ );
+}
+
+def RankedTensorOp : TEST_Op<"ranked_tensor_op"> {
+ let arguments = (ins AnyRankedTensor:$input);
+}
+
+def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
+ let arguments = (ins
+ TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Operands
+//===----------------------------------------------------------------------===//
+
+def SymbolScopeOp : TEST_Op<"symbol_scope",
+ [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
+ let summary = "operation which defines a new symbol table";
+ let regions = (region SizedRegion<1>:$region);
+}
+
+def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
+ let summary = "operation which defines a new symbol table without a "
+ "restriction on a terminator";
+ let regions = (region SizedRegion<1>:$region);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Operands
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicOperandOp : TEST_Op<
+ "mixed_normal_variadic_operand", [SameVariadicOperandSize]> {
+ let arguments = (ins
+ Variadic<AnyTensor>:$input1,
+ AnyTensor:$input2,
+ Variadic<AnyTensor>:$input3
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Results
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicResults : TEST_Op<
+ "mixed_normal_variadic_result", [SameVariadicResultSize]> {
+ let results = (outs
+ Variadic<AnyTensor>:$output1,
+ AnyTensor:$output2,
+ Variadic<AnyTensor>:$output3
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Attributes
+//===----------------------------------------------------------------------===//
+
+def NonNegIntAttrOp : TEST_Op<"non_negative_int_attr"> {
+ let arguments = (ins
+ NonNegativeI32Attr:$i32attr,
+ NonNegativeI64Attr:$i64attr
+ );
+}
+
+def PositiveIntAttrOp : TEST_Op<"positive_int_attr"> {
+ let arguments = (ins
+ PositiveI32Attr:$i32attr,
+ PositiveI64Attr:$i64attr
+ );
+}
+
+def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
+ let arguments = (ins TypeArrayAttr:$attr);
+}
+def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
+ let arguments = (ins StrAttr:$attr);
+ let printer = [{ p << getAttr("attr"); }];
+ let parser = [{
+ Attribute attr;
+ Type stringType = OpaqueType::get(Identifier::get("foo",
+ result.getContext()), "string",
+ result.getContext());
+ return parser.parseAttribute(attr, stringType, "attr", result.attributes);
+ }];
+}
+
+def StrCaseA: StrEnumAttrCase<"A">;
+def StrCaseB: StrEnumAttrCase<"B">;
+
+def SomeStrEnum: StrEnumAttr<
+ "SomeStrEnum", "", [StrCaseA, StrCaseB]>;
+
+def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
+ let arguments = (ins SomeStrEnum:$attr);
+ let results = (outs I32:$val);
+}
+
+def I32Case5: I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+ "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
+def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
+ let arguments = (ins SomeI32Enum:$attr);
+ let results = (outs I32:$val);
+}
+
+def I64Case5: I64EnumAttrCase<"case5", 5>;
+def I64Case10: I64EnumAttrCase<"case10", 10>;
+
+def SomeI64Enum: I64EnumAttr<
+ "SomeI64Enum", "", [I64Case5, I64Case10]>;
+
+def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
+ let arguments = (ins SomeI64Enum:$attr);
+ let results = (outs I32:$val);
+}
+
+def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
+ let arguments = (ins
+ RankedF32ElementsAttr<[2]>:$scalar_f32_attr,
+ RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr
+ );
+}
+
+// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
+// This tests both matching and generating float elements attributes.
+def UpdateFloatElementsAttr : Pat<
+ (FloatElementsAttrOp
+ ConstantAttr<RankedF32ElementsAttr<[2]>, "{3.0f, 4.0f}">:$f32attr,
+ $f64attr),
+ (FloatElementsAttrOp
+ ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
+ $f64attr)>;
+
+//===----------------------------------------------------------------------===//
+// Test Attribute Constraints
+//===----------------------------------------------------------------------===//
+
+def SymbolRefOp : TEST_Op<"symbol_ref_attr"> {
+ let arguments = (ins
+ Confined<FlatSymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Regions
+//===----------------------------------------------------------------------===//
+
+def OneRegionOp : TEST_Op<"one_region_op", []> {
+ let regions = (region AnyRegion);
+}
+
+def TwoRegionOp : TEST_Op<"two_region_op", []> {
+ let regions = (region AnyRegion, AnyRegion);
+}
+
+def SizedRegionOp : TEST_Op<"sized_region_op", []> {
+ let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Call Interfaces
+//===----------------------------------------------------------------------===//
+
+def ConversionCallOp : TEST_Op<"conversion_call_op",
+ [CallOpInterface]> {
+ let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee);
+ let results = (outs Variadic<AnyType>);
+
+ let extraClassDeclaration = [{
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() { return inputs(); }
+
+ /// Return the callee of this operation.
+ CallInterfaceCallable getCallableForCallee() {
+ return getAttrOfType<FlatSymbolRefAttr>("callee");
+ }
+ }];
+}
+
+def FunctionalRegionOp : TEST_Op<"functional_region_op",
+ [CallableOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let results = (outs FunctionType);
+
+ let extraClassDeclaration = [{
+ Region *getCallableRegion(CallInterfaceCallable) { return &body(); }
+ void getCallableRegions(SmallVectorImpl<Region *> &callables) {
+ callables.push_back(&body());
+ }
+ ArrayRef<Type> getCallableResults(Region *) {
+ return getType().cast<FunctionType>().getResults();
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Test Traits
+//===----------------------------------------------------------------------===//
+
+def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
+ [SameOperandsElementType]> {
+ let arguments = (ins AnyType, AnyType);
+ let results = (outs AnyType);
+}
+
+def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
+ [SameOperandsAndResultElementType]> {
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs Variadic<AnyType>);
+}
+
+def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
+ let arguments = (ins Variadic<AnyShaped>);
+}
+
+def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
+ [SameOperandsAndResultShape]> {
+ let arguments = (ins Variadic<AnyShaped>);
+ let results = (outs Variadic<AnyShaped>);
+}
+
+def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
+ [SameOperandsAndResultType]> {
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs Variadic<AnyType>);
+}
+
+def ArgAndResHaveFixedElementTypesOp :
+ TEST_Op<"arg_and_res_have_fixed_element_types",
+ [PredOpTrait<"fixed type combination",
+ And<[ElementTypeIsPred<"x", I32>,
+ ElementTypeIsPred<"y", F32>]>>,
+ ElementTypeIs<"res", I16>]> {
+ let arguments = (ins
+ AnyShaped:$x, AnyShaped:$y);
+ let results = (outs AnyShaped:$res);
+}
+
+def OperandsHaveSameElementType : TEST_Op<"operands_have_same_element_type", [
+ AllElementTypesMatch<["x", "y"]>]> {
+ let arguments = (ins AnyType:$x, AnyType:$y);
+}
+
+def OperandZeroAndResultHaveSameElementType : TEST_Op<
+ "operand0_and_result_have_same_element_type",
+ [AllElementTypesMatch<["x", "res"]>]> {
+ let arguments = (ins AnyType:$x, AnyType:$y);
+ let results = (outs AnyType:$res);
+}
+
+def OperandsHaveSameType :
+ TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> {
+ let arguments = (ins AnyType:$x, AnyType:$y);
+}
+
+def OperandZeroAndResultHaveSameType :
+ TEST_Op<"operand0_and_result_have_same_type",
+ [AllTypesMatch<["x", "res"]>]> {
+ let arguments = (ins AnyType:$x, AnyType:$y);
+ let results = (outs AnyType:$res);
+}
+
+def OperandsHaveSameRank :
+ TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> {
+ let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+}
+
+def OperandZeroAndResultHaveSameRank :
+ TEST_Op<"operand0_and_result_have_same_rank",
+ [AllRanksMatch<["x", "res"]>]> {
+ let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+ let results = (outs AnyShaped:$res);
+}
+
+def OperandZeroAndResultHaveSameShape :
+ TEST_Op<"operand0_and_result_have_same_shape",
+ [AllShapesMatch<["x", "res"]>]> {
+ let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+ let results = (outs AnyShaped:$res);
+}
+
+def OperandZeroAndResultHaveSameElementCount :
+ TEST_Op<"operand0_and_result_have_same_element_count",
+ [AllElementCountsMatch<["x", "res"]>]> {
+ let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+ let results = (outs AnyShaped:$res);
+}
+
+def FourEqualsFive :
+ TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>;
+
+def OperandRankEqualsResultSize :
+ TEST_Op<"operand_rank_equals_result_size",
+ [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result],
+ "operand rank equals result size">]> {
+ let arguments = (ins AnyShaped:$operand);
+ let results = (outs AnyShaped:$result);
+}
+
+def IfFirstOperandIsNoneThenSoIsSecond :
+ TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
+ "has either both none type operands or first is not none",
+ Or<[
+ And<[TypeIsPred<"x", NoneType>, TypeIsPred<"y", NoneType>]>,
+ Neg<TypeIsPred<"x", NoneType>>]>>]> {
+ let arguments = (ins AnyType:$x, AnyType:$y);
+}
+
+def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
+ let arguments = (ins AnyTensor, AnyTensor);
+ let results = (outs AnyTensor);
+}
+
+// There the "HasParent" trait.
+def ParentOp : TEST_Op<"parent">;
+def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
+
+
+def TerminatorOp : TEST_Op<"finish", [Terminator]>;
+def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
+ [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+ let regions = (region SizedRegion<1>:$region);
+}
+
+def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
+ let arguments = (ins I32ElementsAttr:$attr);
+}
+
+def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let arguments = (ins AnyTensor, AnyTensor);
+ let results = (outs AnyTensor);
+}
+
+def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
+
+def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
+ (I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">),
+ [(IsNotScalar $attr)]>;
+
+def TestBranchOp : TEST_Op<"br", [Terminator]> {
+ let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ Variadic<I32>:$a,
+ Variadic<I32>:$b,
+ I32:$c,
+ Variadic<I32>:$d,
+ I32ElementsAttr:$operand_segment_sizes
+ );
+}
+
+def AttrSizedResultOp : TEST_Op<"attr_sized_results",
+ [AttrSizedResultSegments]> {
+ let arguments = (ins
+ I32ElementsAttr:$result_segment_sizes
+ );
+ let results = (outs
+ Variadic<I32>:$a,
+ Variadic<I32>:$b,
+ I32:$c,
+ Variadic<I32>:$d
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns
+//===----------------------------------------------------------------------===//
+
+def OpA : TEST_Op<"op_a"> {
+ let arguments = (ins I32, I32Attr:$attr);
+ let results = (outs I32);
+}
+
+def OpB : TEST_Op<"op_b"> {
+ let arguments = (ins I32, I32Attr:$attr);
+ let results = (outs I32);
+}
+
+// Test named pattern.
+def TestNamedPatternRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+
+// Test with fused location.
+def : Pat<(OpA (OpA $input, $attr), $bttr), (OpB $input, $bttr)>;
+
+// Test added benefit.
+def OpD : TEST_Op<"op_d">, Arguments<(ins I32)>, Results<(outs I32)>;
+def OpE : TEST_Op<"op_e">, Arguments<(ins I32)>, Results<(outs I32)>;
+def OpF : TEST_Op<"op_f">, Arguments<(ins I32)>, Results<(outs I32)>;
+def OpG : TEST_Op<"op_g">, Arguments<(ins I32)>, Results<(outs I32)>;
+// Verify that bumping benefit results in selecting different op.
+def : Pat<(OpD $input), (OpE $input)>;
+def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
+// Verify that patterns with more source nodes are selected before those with fewer.
+def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
+def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
+
+// Test patterns for zero-result op.
+def OpH : TEST_Op<"op_h">, Arguments<(ins I32)>, Results<(outs)>;
+def OpI : TEST_Op<"op_i">, Arguments<(ins I32)>, Results<(outs)>;
+def : Pat<(OpH $input), (OpI $input)>;
+
+// Test patterns for zero-input op.
+def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
+def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
+def : Pat<(OpJ), (OpK)>;
+
+// Test `$_` for ignoring op argument match.
+def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
+ let arguments = (ins
+ AnyType:$a, AnyType:$b, AnyType:$c,
+ AnyAttr:$d, AnyAttr:$e, AnyAttr:$f);
+}
+def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> {
+ let arguments = (ins AnyType:$b, AnyAttr:$f);
+}
+def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f),
+ (TestIgnoreArgMatchDstOp $b, $f)>;
+
+def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
+ let arguments = (ins
+ I32:$input1,
+ I64Attr:$attr1,
+ I32:$input2,
+ I64Attr:$attr2
+ );
+}
+
+def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
+ let arguments = (ins
+ I32:$input1,
+ I64Attr:$attr1,
+ I32:$input2,
+ I64Attr:$attr2
+ );
+}
+
+def ManyArgsOp : TEST_Op<"many_arguments"> {
+ let arguments = (ins
+ I32:$input1, I32:$input2, I32:$input3, I32:$input4, I32:$input5,
+ I32:$input6, I32:$input7, I32:$input8, I32:$input9,
+ I64Attr:$attr1, I64Attr:$attr2, I64Attr:$attr3, I64Attr:$attr4,
+ I64Attr:$attr5, I64Attr:$attr6, I64Attr:$attr7, I64Attr:$attr8,
+ I64Attr:$attr9
+ );
+}
+
+// Test that DRR does not blow up when seeing lots of arguments.
+def : Pat<(ManyArgsOp
+ $input1, $input2, $input3, $input4, $input5,
+ $input6, $input7, $input8, $input9,
+ ConstantAttr<I64Attr, "42">,
+ $attr2, $attr3, $attr4, $attr5, $attr6,
+ $attr7, $attr8, $attr9),
+ (ManyArgsOp
+ $input1, $input2, $input3, $input4, $input5,
+ $input6, $input7, $input8, $input9,
+ ConstantAttr<I64Attr, "24">,
+ $attr2, $attr3, $attr4, $attr5, $attr6,
+ $attr7, $attr8, $attr9)>;
+
+// Test that we can capture and reference interleaved operands and attributes.
+def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
+ (OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;
+
+// Test NativeCodeCall.
+def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
+ let arguments = (ins
+ I32:$input1, I32:$input2,
+ BoolAttr:$choice,
+ I64Attr:$attr1, I64Attr:$attr2
+ );
+ let results = (outs I32);
+}
+def OpNativeCodeCall2 : TEST_Op<"native_code_call2"> {
+ let arguments = (ins I32:$input, I64ArrayAttr:$attr);
+ let results = (outs I32);
+}
+// Native code call to invoke a C++ function
+def CreateOperand: NativeCodeCall<"chooseOperand($0, $1, $2)">;
+// Native code call to invoke a C++ expression
+def CreateArrayAttr: NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
+// Test that we can use NativeCodeCall to create operand and attribute.
+// This pattern chooses between $input1 and $input2 according to $choice and
+// it combines $attr1 and $attr2 into an array attribute.
+def : Pat<(OpNativeCodeCall1 $input1, $input2,
+ ConstBoolAttrTrue:$choice, $attr1, $attr2),
+ (OpNativeCodeCall2 (CreateOperand $input1, $input2, $choice),
+ (CreateArrayAttr $attr1, $attr2))>;
+// Note: the following is just for testing purpose.
+// Should use the replaceWithValue directive instead.
+def UseOpResult: NativeCodeCall<"$0">;
+// Test that we can use NativeCodeCall to create result.
+def : Pat<(OpNativeCodeCall1 $input1, $input2,
+ ConstBoolAttrFalse, $attr1, $attr2),
+ (UseOpResult $input2)>;
+
+def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
+ let arguments = (ins I32:$input);
+ let results = (outs I32);
+}
+// Test that NativeCodeCall is not ignored if it is not used to directly
+// replace the matched root op.
+def : Pattern<(OpNativeCodeCall3 $input),
+ [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
+
+// Test AllAttrConstraintsOf.
+def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
+ let arguments = (ins I64ArrayAttr:$attr);
+ let results = (outs I32);
+}
+def OpAllAttrConstraint2 : TEST_Op<"all_attr_constraint_of2"> {
+ let arguments = (ins I64ArrayAttr:$attr);
+ let results = (outs I32);
+}
+def Constraint0 : AttrConstraint<
+ CPred<"$_self.cast<ArrayAttr>().getValue()[0]."
+ "cast<IntegerAttr>().getInt() == 0">,
+ "[0] == 0">;
+def Constraint1 : AttrConstraint<
+ CPred<"$_self.cast<ArrayAttr>().getValue()[1]."
+ "cast<IntegerAttr>().getInt() == 1">,
+ "[1] == 1">;
+def : Pat<(OpAllAttrConstraint1
+ AllAttrConstraintsOf<[Constraint0, Constraint1]>:$attr),
+ (OpAllAttrConstraint2 $attr)>;
+
+// Op for testing RewritePattern removing op with inner ops.
+def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
+ let regions = (region SizedRegion<1>:$region);
+ let hasCanonicalizer = 1;
+}
+
+// Op for testing trivial removal via folding of op with inner ops and no uses.
+def TestOpWithRegionFoldNoSideEffect : TEST_Op<
+ "op_with_region_fold_no_side_effect", [NoSideEffect]> {
+ let regions = (region SizedRegion<1>:$region);
+}
+
+// Op for testing folding of outer op with inner ops.
+def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> {
+ let arguments = (ins I32:$operand);
+ let results = (outs I32);
+ let regions = (region SizedRegion<1>:$region);
+ let hasFolder = 1;
+}
+
+def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_folder"> {
+ let arguments = (ins Variadic<I32>:$operands);
+ let results = (outs Variadic<I32>);
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Symbol Binding)
+
+// Test symbol binding.
+def OpSymbolBindingA : TEST_Op<"symbol_binding_a", []> {
+ let arguments = (ins I32:$operand, I64Attr:$attr);
+ let results = (outs I32);
+}
+def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> {
+ let arguments = (ins I32:$operand);
+ let results = (outs I32);
+
+ let builders = [
+ OpBuilder<
+ "Builder *builder, OperationState &state, Value operand",
+ [{
+ state.types.assign({builder->getIntegerType(32)});
+ state.addOperands({operand});
+ }]>
+ ];
+}
+def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> {
+ let arguments = (ins I32:$operand);
+ let results = (outs I32);
+ let builders = OpSymbolBindingB.builders;
+}
+def OpSymbolBindingD : TEST_Op<"symbol_binding_d", []> {
+ let arguments = (ins I32:$input1, I32:$input2, I64Attr:$attr);
+ let results = (outs I32);
+}
+def HasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
+def : Pattern<
+ // Bind to source pattern op operand/attribute/result
+ (OpSymbolBindingA:$res_a $operand, $attr), [
+ // Bind to auxiliary op result
+ (OpSymbolBindingC:$res_c (OpSymbolBindingB:$res_b $operand)),
+
+ // Use bound symbols in resultant ops
+ (OpSymbolBindingD $res_b, $res_c, $attr)],
+ // Use bound symbols in additional constraints
+ [(HasOneUse $res_a)]>;
+
+def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
+ let arguments = (ins I32:$operand);
+}
+
+// Test that we can bind to an op without results and reference it later.
+def : Pat<(OpSymbolBindingNoResult:$op $operand),
+ (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Attributes)
+
+// Test matching against op attributes.
+def OpAttrMatch1 : TEST_Op<"match_op_attribute1"> {
+ let arguments = (ins
+ I32Attr:$required_attr,
+ OptionalAttr<I32Attr>:$optional_attr,
+ DefaultValuedAttr<I32Attr, "42">:$default_valued_attr,
+ I32Attr:$more_attr
+ );
+ let results = (outs I32);
+}
+def OpAttrMatch2 : TEST_Op<"match_op_attribute2"> {
+ let arguments = OpAttrMatch1.arguments;
+ let results = (outs I32);
+}
+def MoreConstraint : AttrConstraint<
+ CPred<"$_self.cast<IntegerAttr>().getInt() == 4">, "more constraint">;
+def : Pat<(OpAttrMatch1 $required, $optional, $default_valued,
+ MoreConstraint:$more),
+ (OpAttrMatch2 $required, $optional, $default_valued, $more)>;
+
+// Test unit attrs.
+def OpAttrMatch3 : TEST_Op<"match_op_attribute3"> {
+ let arguments = (ins UnitAttr:$attr);
+ let results = (outs I32);
+}
+def OpAttrMatch4 : TEST_Op<"match_op_attribute4"> {
+ let arguments = (ins UnitAttr:$attr1, UnitAttr:$attr2);
+ let results = (outs I32);
+}
+def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>;
+
+// Test with constant attr.
+def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>;
+def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
+
+// Test string enum attribute in rewrites.
+def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
+// Test integer enum attribute in rewrites.
+def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
+def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
+def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
+
+def MultiResultOpEnum: I64EnumAttr<
+ "MultiResultOpEnum", "Multi-result op kinds", [
+ MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+ MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
+ ]>;
+
+def ThreeResultOp : TEST_Op<"three_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def TwoResultOp : TEST_Op<"two_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2);
+
+ let builders = [
+ OpBuilder<
+ "Builder *builder, OperationState &state, IntegerAttr kind",
+ [{
+ auto i32 = builder->getIntegerType(32);
+ auto f32 = builder->getF32Type();
+ state.types.assign({i32, f32});
+ state.addAttribute("kind", kind);
+ }]>
+ ];
+}
+
+def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs F32:$result1, F32:$result2);
+}
+
+def OneResultOp1 : TEST_Op<"one_result1"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs F32:$result1);
+}
+
+def OneResultOp2 : TEST_Op<"one_result2"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1);
+}
+
+def OneResultOp3 : TEST_Op<"one_result3"> {
+ let arguments = (ins F32);
+ let results = (outs I32:$result1);
+}
+
+// Test using multi-result op as a whole
+def : Pat<(ThreeResultOp MultiResultOpKind1),
+ (AnotherThreeResultOp MultiResultOpKind1)>;
+
+// Test using multi-result op as a whole for partial replacement
+def : Pattern<(ThreeResultOp MultiResultOpKind2),
+ [(TwoResultOp MultiResultOpKind2),
+ (OneResultOp1 MultiResultOpKind2)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3),
+ [(OneResultOp2 MultiResultOpKind3),
+ (AnotherTwoResultOp MultiResultOpKind3)]>;
+
+// Test using results separately in a multi-result op
+def : Pattern<(ThreeResultOp MultiResultOpKind4),
+ [(TwoResultOp:$res1__0 MultiResultOpKind4),
+ (OneResultOp1 MultiResultOpKind4),
+ (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+
+// Test referencing a single value in the value pack
+// This rule only matches TwoResultOp if its second result has no use.
+def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
+ [(OneResultOp2 MultiResultOpKind5),
+ (OneResultOp1 MultiResultOpKind5)],
+ [(HasNoUseOf:$res__1)]>;
+
+// Test using auxiliary ops for replacing multi-result op
+def : Pattern<
+ (ThreeResultOp MultiResultOpKind6), [
+ // Auxiliary op generated to help building the final result but not
+ // directly used to replace the source op's results.
+ (TwoResultOp:$interm MultiResultOpKind6),
+
+ (OneResultOp3 $interm__1),
+ (AnotherTwoResultOp MultiResultOpKind6)
+ ]>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Variadic Ops)
+
+def OneVResOneVOperandOp1 : TEST_Op<"one_variadic_out_one_variadic_in1"> {
+ let arguments = (ins Variadic<I32>);
+ let results = (outs Variadic<I32>);
+}
+def OneVResOneVOperandOp2 : TEST_Op<"one_variadic_out_one_variadic_in2"> {
+ let arguments = (ins Variadic<I32>);
+ let results = (outs Variadic<I32>);
+}
+
+// Rewrite an op with one variadic operand and one variadic result to
+// another similar op.
+def : Pat<(OneVResOneVOperandOp1 $inputs), (OneVResOneVOperandOp2 $inputs)>;
+
+def MixedVOperandOp1 : TEST_Op<"mixed_variadic_in1",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins
+ Variadic<I32>:$input1,
+ F32:$input2,
+ Variadic<I32>:$input3
+ );
+}
+
+def MixedVOperandOp2 : TEST_Op<"mixed_variadic_in2",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins
+ Variadic<I32>:$input1,
+ F32:$input2,
+ Variadic<I32>:$input3
+ );
+}
+
+// Rewrite an op with both variadic operands and normal operands.
+def : Pat<(MixedVOperandOp1 $input1, $input2, $input3),
+ (MixedVOperandOp2 $input1, $input2, $input3)>;
+
+def MixedVResultOp1 : TEST_Op<"mixed_variadic_out1", [SameVariadicResultSize]> {
+ let results = (outs
+ Variadic<I32>:$output1,
+ F32:$output2,
+ Variadic<I32>:$output3
+ );
+}
+
+def MixedVResultOp2 : TEST_Op<"mixed_variadic_out2", [SameVariadicResultSize]> {
+ let results = (outs
+ Variadic<I32>:$output1,
+ F32:$output2,
+ Variadic<I32>:$output3
+ );
+}
+
+// Rewrite an op with both variadic results and normal results.
+// Note that because we are generating the op with a top-level result pattern,
+// we are able to deduce the correct result types for the generated op using
+// the information from the matched root op.
+def : Pat<(MixedVResultOp1), (MixedVResultOp2)>;
+
+def OneI32ResultOp : TEST_Op<"one_i32_out"> {
+ let results = (outs I32);
+}
+
+def MixedVOperandOp3 : TEST_Op<"mixed_variadic_in3",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins
+ I32:$input1,
+ Variadic<I32>:$input2,
+ Variadic<I32>:$input3,
+ I32Attr:$count
+ );
+
+ let results = (outs I32);
+}
+
+def MixedVResultOp3 : TEST_Op<"mixed_variadic_out3",
+ [SameVariadicResultSize]> {
+ let arguments = (ins I32Attr:$count);
+
+ let results = (outs
+ I32:$output1,
+ Variadic<I32>:$output2,
+ Variadic<I32>:$output3
+ );
+
+ // We will use this op in a nested result pattern, where we cannot deduce the
+ // result type. So need to provide a builder not requiring result types.
+ let builders = [
+ OpBuilder<
+ "Builder *builder, OperationState &state, IntegerAttr count",
+ [{
+ auto i32Type = builder->getIntegerType(32);
+ state.addTypes(i32Type); // $output1
+ SmallVector<Type, 4> types(count.getInt(), i32Type);
+ state.addTypes(types); // $output2
+ state.addTypes(types); // $output3
+ state.addAttribute("count", count);
+ }]>
+ ];
+}
+
+// Generates an op with variadic results using nested pattern.
+def : Pat<(OneI32ResultOp),
+ (MixedVOperandOp3
+ (MixedVResultOp3:$results__0 ConstantAttr<I32Attr, "2">),
+ (replaceWithValue $results__1),
+ (replaceWithValue $results__2),
+ ConstantAttr<I32Attr, "2">)>;
+
+//===----------------------------------------------------------------------===//
+// Test Legalization
+//===----------------------------------------------------------------------===//
+
+def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
+def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;
+
+def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
+ [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;
+
+def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>;
+def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>;
+def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>;
+def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>;
+def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
+def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
+def LegalOpA : TEST_Op<"legal_op_a">,
+ Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
+def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
+
+// Check that smaller pattern depths are chosen, i.e. prioritize more direct
+// mappings.
+def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>;
+
+def : Pat<(ILLegalOpA), (ILLegalOpB)>;
+def : Pat<(ILLegalOpB), (LegalOpA Test_LegalizerEnum_Failure)>;
+
+// Check that the higher benefit pattern is taken for multiple legalizations
+// with the same depth.
+def : Pat<(ILLegalOpC), (ILLegalOpD)>;
+def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;
+
+def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
+def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;
+
+// Check that patterns use the most up-to-date value when being replaced.
+def TestRewriteOp : TEST_Op<"rewrite">,
+ Arguments<(ins AnyType)>, Results<(outs AnyType)>;
+def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
+
+//===----------------------------------------------------------------------===//
+// Test Type Legalization
+//===----------------------------------------------------------------------===//
+
+def TestRegionBuilderOp : TEST_Op<"region_builder">;
+def TestReturnOp : TEST_Op<"return", [Terminator]>,
+ Arguments<(ins Variadic<AnyType>)>;
+def TestCastOp : TEST_Op<"cast">,
+ Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
+def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
+ Arguments<(ins Variadic<AnyType>)>;
+def TestTypeProducerOp : TEST_Op<"type_producer">,
+ Results<(outs AnyType)>;
+def TestTypeConsumerOp : TEST_Op<"type_consumer">,
+ Arguments<(ins AnyType)>;
+def TestValidOp : TEST_Op<"valid", [Terminator]>,
+ Arguments<(ins Variadic<AnyType>)>;
+
+//===----------------------------------------------------------------------===//
+// Test parser.
+//===----------------------------------------------------------------------===//
+
+def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> {
+ let arguments = (ins StrAttr:$keyword);
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+ let printer = [{ return ::print(p, *this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Test region argument list parsing.
+
+def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
+ let summary = "isolated region operation";
+ let description = [{
+ Test op with an isolated region, to test passthrough region arguments. Each
+ argument is of index type.
+ }];
+
+ let arguments = (ins Index);
+ let regions = (region SizedRegion<1>:$region);
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+ let printer = [{ return ::print(p, *this); }];
+}
+
+def WrappingRegionOp : TEST_Op<"wrapping_region",
+ [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+ let summary = "wrapping region operation";
+ let description = [{
+ Test op wrapping another op in a region, to test calling
+ parseGenericOperation from the custom parser.
+ }];
+
+ let results = (outs Variadic<AnyType>);
+ let regions = (region SizedRegion<1>:$region);
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+ let printer = [{ return ::print(p, *this); }];
+}
+
+def PolyForOp : TEST_Op<"polyfor">
+{
+ let summary = "polyfor operation";
+ let description = [{
+ Test op with multiple region arguments, each argument of index type.
+ }];
+
+ let regions = (region SizedRegion<1>:$region);
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmInterface.
+
+def AsmInterfaceOp : TEST_Op<"asm_interface_op"> {
+ let results = (outs AnyType:$first, Variadic<AnyType>:$middle_results,
+ AnyType);
+}
+
+def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> {
+ let results = (outs AnyType);
+}
+
+#endif // TEST_OPS
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
new file mode 100644
index 00000000000..929c4a941a2
--- /dev/null
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -0,0 +1,504 @@
+//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+using namespace mlir;
+
+// Native function for testing NativeCodeCall
+static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
+ return choice.getValue() ? input1 : input2;
+}
+
+static void createOpI(PatternRewriter &rewriter, Value input) {
+ rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
+}
+
+void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) {
+ // Turn the no result op to a one-result op.
+ rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand()->getType(),
+ op.operand());
+}
+
+namespace {
+#include "TestPatterns.inc"
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Canonicalizer Driver.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
+ void runOnFunction() override {
+ mlir::OwningRewritePatternList patterns;
+ populateWithGenerated(&getContext(), &patterns);
+
+ // Verify named pattern is generated with expected name.
+ patterns.insert<TestNamedPatternRule>(&getContext());
+
+ applyPatternsGreedily(getFunction(), patterns);
+ }
+};
+} // end anonymous namespace
+
+static mlir::PassRegistration<TestPatternDriver>
+ pass("test-patterns", "Run test dialect patterns");
+
+//===----------------------------------------------------------------------===//
+// ReturnType Driver.
+//===----------------------------------------------------------------------===//
+
+struct ReturnTypeOpMatch : public RewritePattern {
+ ReturnTypeOpMatch(MLIRContext *ctx)
+ : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
+ }
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
+ if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
+ SmallVector<Value, 4> values(op->getOperands());
+ SmallVector<Type, 2> inferedReturnTypes;
+ if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
+ op->getAttrs(), op->getRegions(),
+ inferedReturnTypes)))
+ return matchFailure();
+ SmallVector<Type, 1> resultTypes(op->getResultTypes());
+ if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+ return op->emitOpError(
+ "inferred type incompatible with return type of operation"),
+ matchFailure();
+
+ // TODO(jpienaar): Split this out to make the test more focused.
+ // Create new op with unknown location to verify building with
+ // InferTypeOpInterface is triggered.
+ auto fop = op->getParentOfType<FuncOp>();
+ if (values[0] == fop.getArgument(0)) {
+ // Use the 2nd function argument if the first function argument is used
+ // when constructing the new op so that a new return type is inferred.
+ values[0] = fop.getArgument(1);
+ values[1] = fop.getArgument(1);
+ // TODO(jpienaar): Expand to regions.
+ rewriter.create<OpWithInferTypeInterfaceOp>(
+ UnknownLoc::get(op->getContext()), values, op->getAttrs());
+ }
+ }
+ return matchFailure();
+ }
+};
+
+namespace {
+struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
+ void runOnFunction() override {
+ mlir::OwningRewritePatternList patterns;
+ populateWithGenerated(&getContext(), &patterns);
+ patterns.insert<ReturnTypeOpMatch>(&getContext());
+ applyPatternsGreedily(getFunction(), patterns);
+ }
+};
+} // end anonymous namespace
+
+static mlir::PassRegistration<TestReturnTypeDriver>
+ rt_pass("test-return-type", "Run return type functions");
+
+//===----------------------------------------------------------------------===//
+// Legalization Driver.
+//===----------------------------------------------------------------------===//
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Region-Block Rewrite Testing
+
+/// This pattern is a simple pattern that inlines the first region of a given
+/// operation into the parent region.
+struct TestRegionRewriteBlockMovement : public ConversionPattern {
+ TestRegionRewriteBlockMovement(MLIRContext *ctx)
+ : ConversionPattern("test.region", 1, ctx) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Inline this region into the parent region.
+ auto &parentRegion = *op->getParentRegion();
+ if (op->getAttr("legalizer.should_clone"))
+ rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
+ parentRegion.end());
+ else
+ rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
+ parentRegion.end());
+
+ // Drop this operation.
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+/// This pattern is a simple pattern that generates a region containing an
+/// illegal operation.
+struct TestRegionRewriteUndo : public RewritePattern {
+ TestRegionRewriteUndo(MLIRContext *ctx)
+ : RewritePattern("test.region_builder", 1, ctx) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
+ // Create the region operation with an entry block containing arguments.
+ OperationState newRegion(op->getLoc(), "test.region");
+ newRegion.addRegion();
+ auto *regionOp = rewriter.createOperation(newRegion);
+ auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
+ entryBlock->addArgument(rewriter.getIntegerType(64));
+
+ // Add an explicitly illegal operation to ensure the conversion fails.
+ rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
+ rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
+
+ // Drop this operation.
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Type-Conversion Rewrite Testing
+
+/// This patterns erases a region operation that has had a type conversion.
+struct TestDropOpSignatureConversion : public ConversionPattern {
+ TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+ : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
+ }
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Region &region = op->getRegion(0);
+ Block *entry = &region.front();
+
+ // Convert the original entry arguments.
+ TypeConverter::SignatureConversion result(entry->getNumArguments());
+ for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
+ if (failed(converter.convertSignatureArg(
+ i, entry->getArgument(i)->getType(), result)))
+ return matchFailure();
+
+ // Convert the region signature and just drop the operation.
+ rewriter.applySignatureConversion(&region, result);
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+
+ /// The type converter to use when rewriting the signature.
+ TypeConverter &converter;
+};
+/// This pattern simply updates the operands of the given operation.
+struct TestPassthroughInvalidOp : public ConversionPattern {
+ TestPassthroughInvalidOp(MLIRContext *ctx)
+ : ConversionPattern("test.invalid", 1, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
+ llvm::None);
+ return matchSuccess();
+ }
+};
+/// This pattern handles the case of a split return value.
+struct TestSplitReturnType : public ConversionPattern {
+ TestSplitReturnType(MLIRContext *ctx)
+ : ConversionPattern("test.return", 1, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Check for a return of F32.
+ if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
+ return matchFailure();
+
+ // Check if the first operation is a cast operation, if it is we use the
+ // results directly.
+ auto *defOp = operands[0]->getDefiningOp();
+ if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
+ rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
+ return matchSuccess();
+ }
+
+ // Otherwise, fail to match.
+ return matchFailure();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Multi-Level Type-Conversion Rewrite Testing
+struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
+ TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
+ : ConversionPattern("test.type_producer", 1, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // If the type is I32, change the type to F32.
+ if (!(*op->result_type_begin()).isInteger(32))
+ return matchFailure();
+ rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
+ return matchSuccess();
+ }
+};
+struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
+ TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
+ : ConversionPattern("test.type_producer", 1, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // If the type is F32, change the type to F64.
+ if (!(*op->result_type_begin()).isF32())
+ return matchFailure();
+ rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
+ return matchSuccess();
+ }
+};
+struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
+ TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
+ : ConversionPattern("test.type_producer", 10, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Always convert to B16, even though it is not a legal type. This tests
+ // that values are unmapped correctly.
+ rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
+ return matchSuccess();
+ }
+};
+struct TestUpdateConsumerType : public ConversionPattern {
+ TestUpdateConsumerType(MLIRContext *ctx)
+ : ConversionPattern("test.type_consumer", 1, ctx) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Verify that the incoming operand has been successfully remapped to F64.
+ if (!operands[0]->getType().isF64())
+ return matchFailure();
+ rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
+ return matchSuccess();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Non-Root Replacement Rewrite Testing
+/// This pattern generates an invalid operation, but replaces it before the
+/// pattern is finished. This checks that we don't need to legalize the
+/// temporary op.
+struct TestNonRootReplacement : public RewritePattern {
+ TestNonRootReplacement(MLIRContext *ctx)
+ : RewritePattern("test.replace_non_root", 1, ctx) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
+ auto resultType = *op->result_type_begin();
+ auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
+ auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
+
+ rewriter.replaceOp(illegalOp, {legalOp});
+ rewriter.replaceOp(op, {illegalOp});
+ return matchSuccess();
+ }
+};
+} // namespace
+
+namespace {
+struct TestTypeConverter : public TypeConverter {
+ using TypeConverter::TypeConverter;
+
+ LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
+ // Drop I16 types.
+ if (t.isInteger(16))
+ return success();
+
+ // Convert I64 to F64.
+ if (t.isInteger(64)) {
+ results.push_back(FloatType::getF64(t.getContext()));
+ return success();
+ }
+
+ // Split F32 into F16,F16.
+ if (t.isF32()) {
+ results.assign(2, FloatType::getF16(t.getContext()));
+ return success();
+ }
+
+ // Otherwise, convert the type directly.
+ results.push_back(t);
+ return success();
+ }
+
+ /// Override the hook to materialize a conversion. This is necessary because
+ /// we generate 1->N type mappings.
+ Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
+ ArrayRef<Value> inputs,
+ Location loc) override {
+ return rewriter.create<TestCastOp>(loc, resultType, inputs);
+ }
+};
+
+struct TestLegalizePatternDriver
+ : public ModulePass<TestLegalizePatternDriver> {
+ /// The mode of conversion to use with the driver.
+ enum class ConversionMode { Analysis, Full, Partial };
+
+ TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+
+ void runOnModule() override {
+ TestTypeConverter converter;
+ mlir::OwningRewritePatternList patterns;
+ populateWithGenerated(&getContext(), &patterns);
+ patterns
+ .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+ TestPassthroughInvalidOp, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement>(&getContext());
+ patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
+ mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
+ converter);
+
+ // Define the conversion target used for the test.
+ ConversionTarget target(getContext());
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+ target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
+ target
+ .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
+ target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
+ // Don't allow F32 operands.
+ return llvm::none_of(op.getOperandTypes(),
+ [](Type type) { return type.isF32(); });
+ });
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+
+ // Expect the type_producer/type_consumer operations to only operate on f64.
+ target.addDynamicallyLegalOp<TestTypeProducerOp>(
+ [](TestTypeProducerOp op) { return op.getType().isF64(); });
+ target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
+ return op.getOperand()->getType().isF64();
+ });
+
+ // Check support for marking certain operations as recursively legal.
+ target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
+ return static_cast<bool>(
+ op->getAttrOfType<UnitAttr>("test.recursively_legal"));
+ });
+
+ // Handle a partial conversion.
+ if (mode == ConversionMode::Partial) {
+ (void)applyPartialConversion(getModule(), target, patterns, &converter);
+ return;
+ }
+
+ // Handle a full conversion.
+ if (mode == ConversionMode::Full) {
+ (void)applyFullConversion(getModule(), target, patterns, &converter);
+ return;
+ }
+
+ // Otherwise, handle an analysis conversion.
+ assert(mode == ConversionMode::Analysis);
+
+ // Analyze the convertible operations.
+ DenseSet<Operation *> legalizedOps;
+ if (failed(applyAnalysisConversion(getModule(), target, patterns,
+ legalizedOps, &converter)))
+ return signalPassFailure();
+
+ // Emit remarks for each legalizable operation.
+ for (auto *op : legalizedOps)
+ op->emitRemark() << "op '" << op->getName() << "' is legalizable";
+ }
+
+ /// The mode of conversion to use.
+ ConversionMode mode;
+};
+} // end anonymous namespace
+
+static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
+ legalizerConversionMode(
+ "test-legalize-mode",
+ llvm::cl::desc("The legalization mode to use with the test driver"),
+ llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+ llvm::cl::values(
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+ "analysis", "Perform an analysis conversion"),
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
+ "Perform a full conversion"),
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+ "partial", "Perform a partial conversion")));
+
+static mlir::PassRegistration<TestLegalizePatternDriver>
+ legalizer_pass("test-legalize-patterns",
+ "Run test dialect legalization patterns", [] {
+ return std::make_unique<TestLegalizePatternDriver>(
+ legalizerConversionMode);
+ });
+
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriter::getRemappedValue testing. This method is used
+// to get the remapped value of a original value that was replaced using
+// ConversionPatternRewriter.
+namespace {
+/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
+/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
+/// operand twice.
+///
+/// Example:
+/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
+/// is replaced with:
+/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
+struct OneVResOneVOperandOp1Converter
+ : public OpConversionPattern<OneVResOneVOperandOp1> {
+ using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
+
+ PatternMatchResult
+ matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto origOps = op.getOperands();
+ assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
+ "One operand expected");
+ Value origOp = *origOps.begin();
+ SmallVector<Value, 2> remappedOperands;
+ // Replicate the remapped original operand twice. Note that we don't used
+ // the remapped 'operand' since the goal is testing 'getRemappedValue'.
+ remappedOperands.push_back(rewriter.getRemappedValue(origOp));
+ remappedOperands.push_back(rewriter.getRemappedValue(origOp));
+
+ SmallVector<Type, 1> resultTypes(op.getResultTypes());
+ rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
+ remappedOperands);
+ return matchSuccess();
+ }
+};
+
+struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
+ void runOnFunction() override {
+ mlir::OwningRewritePatternList patterns;
+ patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
+
+ mlir::ConversionTarget target(getContext());
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
+ // We make OneVResOneVOperandOp1 legal only when it has more that one
+ // operand. This will trigger the conversion that will replace one-operand
+ // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
+ target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
+ [](Operation *op) -> bool {
+ return std::distance(op->operand_begin(), op->operand_end()) > 1;
+ });
+
+ if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
+ signalPassFailure();
+ }
+ }
+};
+} // end anonymous namespace
+
+static PassRegistration<TestRemappedValue> remapped_value_pass(
+ "test-remapped-value",
+ "Test public remapped value mechanism in ConversionPatternRewriter");
diff --git a/mlir/test/lib/TestDialect/lit.local.cfg b/mlir/test/lib/TestDialect/lit.local.cfg
new file mode 100644
index 00000000000..edb5b44b2e2
--- /dev/null
+++ b/mlir/test/lib/TestDialect/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove('.td') \ No newline at end of file
OpenPOWER on IntegriCloud