diff options
author | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000 |
---|---|---|
committer | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000 |
commit | 0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch) | |
tree | 31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/test/lib/TestDialect | |
parent | 6f635f90929da9545dd696071a829a1a42f84b30 (diff) | |
parent | 5b4a01d4a63cb66ab981e52548f940813393bf42 (diff) | |
download | bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip |
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/test/lib/TestDialect')
-rw-r--r-- | mlir/test/lib/TestDialect/CMakeLists.txt | 28 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 316 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.h | 53 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestOps.td | 1047 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestPatterns.cpp | 504 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/lit.local.cfg | 1 |
6 files changed, 1949 insertions, 0 deletions
diff --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt new file mode 100644 index 00000000000..e6a22833de4 --- /dev/null +++ b/mlir/test/lib/TestDialect/CMakeLists.txt @@ -0,0 +1,28 @@ +set(LLVM_OPTIONAL_SOURCES + TestDialect.cpp + TestPatterns.cpp +) + +set(LLVM_TARGET_DEFINITIONS TestOps.td) +mlir_tablegen(TestOps.h.inc -gen-op-decls) +mlir_tablegen(TestOps.cpp.inc -gen-op-defs) +mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) +mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(TestPatterns.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestOpsIncGen) + +add_llvm_library(MLIRTestDialect + TestDialect.cpp + TestPatterns.cpp +) +add_dependencies(MLIRTestDialect + MLIRTestOpsIncGen + MLIRIR + LLVMSupport + MLIRTypeInferOpInterfaceIncGen +) +target_link_libraries(MLIRTestDialect + MLIRDialect + MLIRIR + LLVMSupport +) diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp new file mode 100644 index 00000000000..21cf69ec1fa --- /dev/null +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -0,0 +1,316 @@ +//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// TestDialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { + +// Test support for interacting with the AsmPrinter. +struct TestOpAsmInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + + void getAsmResultNames(Operation *op, + OpAsmSetValueNameFn setNameFn) const final { + if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) + setNameFn(asmOp, "result"); + } + + void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const final { + auto op = block->getParentOp(); + auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); + if (!arrayAttr) + return; + auto args = block->getArguments(); + auto e = std::min(arrayAttr.size(), args.size()); + for (unsigned i = 0; i < e; ++i) { + if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>()) + setNameFn(args[i], strAttr.getValue()); + } + } +}; + +struct TestOpFolderDialectInterface : public OpFolderDialectInterface { + using OpFolderDialectInterface::OpFolderDialectInterface; + + /// Registered hook to check if the given region, which is attached to an + /// operation that is *not* isolated from above, should be used when + /// materializing constants. + bool shouldMaterializeInto(Region *region) const final { + // If this is a one region operation, then insert into it. + return isa<OneRegionOp>(region->getParentOp()); + } +}; + +/// This class defines the interface for handling inlining with standard +/// operations. +struct TestInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { + // Inlining into test dialect regions is legal. + return true; + } + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + bool shouldAnalyzeRecursively(Operation *op) const final { + // Analyze recursively if this is not a functional region operation, it + // froms a separate functional scope. + return !isa<FunctionalRegionOp>(op); + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, + ArrayRef<Value> valuesToRepl) const final { + // Only handle "test.return" here. + auto returnOp = dyn_cast<TestReturnOp>(op); + if (!returnOp) + return; + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempt to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + // Only allow conversion for i16/i32 types. + if (!(resultType.isInteger(16) || resultType.isInteger(32)) || + !(input->getType().isInteger(16) || input->getType().isInteger(32))) + return nullptr; + return builder.create<TestCastOp>(conversionLoc, resultType, input); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +TestDialect::TestDialect(MLIRContext *context) + : Dialect(getDialectName(), context) { + addOperations< +#define GET_OP_LIST +#include "TestOps.cpp.inc" + >(); + addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, + TestInlinerInterface>(); + allowUnknownOperations(); +} + +LogicalResult TestDialect::verifyOperationAttribute(Operation *op, + NamedAttribute namedAttr) { + if (namedAttr.first == "test.invalid_attr") + return op->emitError() << "invalid to use 'test.invalid_attr'"; + return success(); +} + +LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, + unsigned regionIndex, + unsigned argIndex, + NamedAttribute namedAttr) { + if (namedAttr.first == "test.invalid_attr") + return op->emitError() << "invalid to use 'test.invalid_attr'"; + return success(); +} + +LogicalResult +TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute namedAttr) { + if (namedAttr.first == "test.invalid_attr") + return op->emitError() << "invalid to use 'test.invalid_attr'"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Test IsolatedRegionOp - parse passthrough region arguments. +//===----------------------------------------------------------------------===// + +static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType argInfo; + Type argType = parser.getBuilder().getIndexType(); + + // Parse the input operand. + if (parser.parseOperand(argInfo) || + parser.resolveOperand(argInfo, argType, result.operands)) + return failure(); + + // Parse the body region, and reuse the operand info as the argument info. + Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, argType, + /*enableNameShadowing=*/true); +} + +static void print(OpAsmPrinter &p, IsolatedRegionOp op) { + p << "test.isolated_region "; + p.printOperand(op.getOperand()); + p.shadowRegionArgs(op.region(), op.getOperand()); + p.printRegion(op.region(), /*printEntryBlockArgs=*/false); +} + +//===----------------------------------------------------------------------===// +// Test parser. +//===----------------------------------------------------------------------===// + +static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, + OperationState &result) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); + return success(); +} + +static void print(OpAsmPrinter &p, WrappedKeywordOp op) { + p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); +} + +//===----------------------------------------------------------------------===// +// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. + +static ParseResult parseWrappingRegionOp(OpAsmParser &parser, + OperationState &result) { + if (parser.parseKeyword("wraps")) + return failure(); + + // Parse the wrapped op in a region + Region &body = *result.addRegion(); + body.push_back(new Block); + Block &block = body.back(); + Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); + if (!wrapped_op) + return failure(); + + // Create a return terminator in the inner region, pass as operand to the + // terminator the returned values from the wrapped operation. + SmallVector<Value, 8> return_operands(wrapped_op->getResults()); + OpBuilder builder(parser.getBuilder().getContext()); + builder.setInsertionPointToEnd(&block); + builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); + + // Get the results type for the wrapping op from the terminator operands. + Operation &return_op = body.back().back(); + result.types.append(return_op.operand_type_begin(), + return_op.operand_type_end()); + + // Use the location of the wrapped op for the "test.wrapping_region" op. + result.location = wrapped_op->getLoc(); + + return success(); +} + +static void print(OpAsmPrinter &p, WrappingRegionOp op) { + p << op.getOperationName() << " wraps "; + p.printGenericOp(&op.region().front().front()); +} + +//===----------------------------------------------------------------------===// +// Test PolyForOp - parse list of region arguments. +//===----------------------------------------------------------------------===// + +static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 4> ivsInfo; + // Parse list of region arguments without a delimiter. + if (parser.parseRegionArgumentList(ivsInfo)) + return failure(); + + // Parse the body region. + Region *body = result.addRegion(); + auto &builder = parser.getBuilder(); + SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); + return parser.parseRegion(*body, ivsInfo, argTypes); +} + +//===----------------------------------------------------------------------===// +// Test removing op with inner ops. +//===----------------------------------------------------------------------===// + +namespace { +struct TestRemoveOpWithInnerOps + : public OpRewritePattern<TestOpWithRegionPattern> { + using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return matchSuccess(); + } +}; +} // end anonymous namespace + +void TestOpWithRegionPattern::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<TestRemoveOpWithInnerOps>(context); +} + +OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { + return operand(); +} + +LogicalResult TestOpWithVariadicResultsAndFolder::fold( + ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { + for (Value input : this->operands()) { + results.push_back(input); + } + return success(); +} + +LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( + llvm::Optional<Location> location, ValueRange operands, + ArrayRef<NamedAttribute> attributes, RegionRange regions, + SmallVectorImpl<Type> &inferedReturnTypes) { + if (operands[0]->getType() != operands[1]->getType()) { + return emitOptionalError(location, "operand type mismatch ", + operands[0]->getType(), " vs ", + operands[1]->getType()); + } + inferedReturnTypes.assign({operands[0]->getType()}); + return success(); +} + +// Static initialization for Test dialect registration. +static mlir::DialectRegistration<mlir::TestDialect> testDialect; + +#include "TestOpEnums.cpp.inc" + +#define GET_OP_CLASSES +#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h new file mode 100644 index 00000000000..20db0f39b81 --- /dev/null +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -0,0 +1,53 @@ +//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a fake 'test' dialect that can be used for testing things +// that do not have a respective counterpart in the main source directories. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTDIALECT_H +#define MLIR_TESTDIALECT_H + +#include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/InferTypeOpInterface.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" + +#include "TestOpEnums.h.inc" + +namespace mlir { + +class TestDialect : public Dialect { +public: + /// Create the dialect in the given `context`. + TestDialect(MLIRContext *context); + + /// Get the canonical string name of the dialect. + static StringRef getDialectName() { return "test"; } + + LogicalResult verifyOperationAttribute(Operation *op, + NamedAttribute namedAttr) override; + LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex, + unsigned argIndex, + NamedAttribute namedAttr) override; + LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute namedAttr) override; +}; + +#define GET_OP_CLASSES +#include "TestOps.h.inc" + +} // end namespace mlir + +#endif // MLIR_TESTDIALECT_H diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td new file mode 100644 index 00000000000..dacb796de18 --- /dev/null +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -0,0 +1,1047 @@ +//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_OPS +#define TEST_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/InferTypeOpInterface.td" + +def TEST_Dialect : Dialect { + let name = "test"; + let cppNamespace = ""; +} + +class TEST_Op<string mnemonic, list<OpTrait> traits = []> : + Op<TEST_Dialect, mnemonic, traits>; + +//===----------------------------------------------------------------------===// +// Test Types +//===----------------------------------------------------------------------===// + +def ComplexF64 : Complex<F64>; +def ComplexOp : TEST_Op<"complex_f64"> { + let results = (outs ComplexF64); +} + +def ComplexTensorOp : TEST_Op<"complex_f64_tensor"> { + let results = (outs TensorOf<[ComplexF64]>); +} + +def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; + +def TupleOp : TEST_Op<"tuple_32_bit"> { + let results = (outs TupleOf<[I32, F32]>); +} + +def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { + let results = (outs NestedTupleOf<[I32, F32]>); +} + +def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> { + let arguments = (ins AnyStaticShapeMemRef:$x); +} + +def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> { + let results = (outs MemRefRankOf<[I8, F32], [0, 1]>); +} + +def NDTensorOfOp : TEST_Op<"nd_tensor_of"> { + let arguments = (ins + 0DTensorOf<[F32]>:$arg0, + 1DTensorOf<[F32]>:$arg1, + 2DTensorOf<[I16]>:$arg2, + 3DTensorOf<[I16]>:$arg3, + 4DTensorOf<[I16]>:$arg4 + ); +} + +def RankedTensorOp : TEST_Op<"ranked_tensor_op"> { + let arguments = (ins AnyRankedTensor:$input); +} + +def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> { + let arguments = (ins + TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0 + ); +} + +//===----------------------------------------------------------------------===// +// Test Operands +//===----------------------------------------------------------------------===// + +def SymbolScopeOp : TEST_Op<"symbol_scope", + [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { + let summary = "operation which defines a new symbol table"; + let regions = (region SizedRegion<1>:$region); +} + +def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { + let summary = "operation which defines a new symbol table without a " + "restriction on a terminator"; + let regions = (region SizedRegion<1>:$region); +} + +//===----------------------------------------------------------------------===// +// Test Operands +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicOperandOp : TEST_Op< + "mixed_normal_variadic_operand", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic<AnyTensor>:$input1, + AnyTensor:$input2, + Variadic<AnyTensor>:$input3 + ); +} + +//===----------------------------------------------------------------------===// +// Test Results +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicResults : TEST_Op< + "mixed_normal_variadic_result", [SameVariadicResultSize]> { + let results = (outs + Variadic<AnyTensor>:$output1, + AnyTensor:$output2, + Variadic<AnyTensor>:$output3 + ); +} + +//===----------------------------------------------------------------------===// +// Test Attributes +//===----------------------------------------------------------------------===// + +def NonNegIntAttrOp : TEST_Op<"non_negative_int_attr"> { + let arguments = (ins + NonNegativeI32Attr:$i32attr, + NonNegativeI64Attr:$i64attr + ); +} + +def PositiveIntAttrOp : TEST_Op<"positive_int_attr"> { + let arguments = (ins + PositiveI32Attr:$i32attr, + PositiveI64Attr:$i64attr + ); +} + +def TypeArrayAttrOp : TEST_Op<"type_array_attr"> { + let arguments = (ins TypeArrayAttr:$attr); +} +def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> { + let arguments = (ins StrAttr:$attr); + let printer = [{ p << getAttr("attr"); }]; + let parser = [{ + Attribute attr; + Type stringType = OpaqueType::get(Identifier::get("foo", + result.getContext()), "string", + result.getContext()); + return parser.parseAttribute(attr, stringType, "attr", result.attributes); + }]; +} + +def StrCaseA: StrEnumAttrCase<"A">; +def StrCaseB: StrEnumAttrCase<"B">; + +def SomeStrEnum: StrEnumAttr< + "SomeStrEnum", "", [StrCaseA, StrCaseB]>; + +def StrEnumAttrOp : TEST_Op<"str_enum_attr"> { + let arguments = (ins SomeStrEnum:$attr); + let results = (outs I32:$val); +} + +def I32Case5: I32EnumAttrCase<"case5", 5>; +def I32Case10: I32EnumAttrCase<"case10", 10>; + +def SomeI32Enum: I32EnumAttr< + "SomeI32Enum", "", [I32Case5, I32Case10]>; + +def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> { + let arguments = (ins SomeI32Enum:$attr); + let results = (outs I32:$val); +} + +def I64Case5: I64EnumAttrCase<"case5", 5>; +def I64Case10: I64EnumAttrCase<"case10", 10>; + +def SomeI64Enum: I64EnumAttr< + "SomeI64Enum", "", [I64Case5, I64Case10]>; + +def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { + let arguments = (ins SomeI64Enum:$attr); + let results = (outs I32:$val); +} + +def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> { + let arguments = (ins + RankedF32ElementsAttr<[2]>:$scalar_f32_attr, + RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr + ); +} + +// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>. +// This tests both matching and generating float elements attributes. +def UpdateFloatElementsAttr : Pat< + (FloatElementsAttrOp + ConstantAttr<RankedF32ElementsAttr<[2]>, "{3.0f, 4.0f}">:$f32attr, + $f64attr), + (FloatElementsAttrOp + ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr, + $f64attr)>; + +//===----------------------------------------------------------------------===// +// Test Attribute Constraints +//===----------------------------------------------------------------------===// + +def SymbolRefOp : TEST_Op<"symbol_ref_attr"> { + let arguments = (ins + Confined<FlatSymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol + ); +} + +//===----------------------------------------------------------------------===// +// Test Regions +//===----------------------------------------------------------------------===// + +def OneRegionOp : TEST_Op<"one_region_op", []> { + let regions = (region AnyRegion); +} + +def TwoRegionOp : TEST_Op<"two_region_op", []> { + let regions = (region AnyRegion, AnyRegion); +} + +def SizedRegionOp : TEST_Op<"sized_region_op", []> { + let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>); +} + +//===----------------------------------------------------------------------===// +// Test Call Interfaces +//===----------------------------------------------------------------------===// + +def ConversionCallOp : TEST_Op<"conversion_call_op", + [CallOpInterface]> { + let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee); + let results = (outs Variadic<AnyType>); + + let extraClassDeclaration = [{ + /// Get the argument operands to the called function. + operand_range getArgOperands() { return inputs(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType<FlatSymbolRefAttr>("callee"); + } + }]; +} + +def FunctionalRegionOp : TEST_Op<"functional_region_op", + [CallableOpInterface]> { + let regions = (region AnyRegion:$body); + let results = (outs FunctionType); + + let extraClassDeclaration = [{ + Region *getCallableRegion(CallInterfaceCallable) { return &body(); } + void getCallableRegions(SmallVectorImpl<Region *> &callables) { + callables.push_back(&body()); + } + ArrayRef<Type> getCallableResults(Region *) { + return getType().cast<FunctionType>().getResults(); + } + }]; +} + +//===----------------------------------------------------------------------===// +// Test Traits +//===----------------------------------------------------------------------===// + +def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type", + [SameOperandsElementType]> { + let arguments = (ins AnyType, AnyType); + let results = (outs AnyType); +} + +def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type", + [SameOperandsAndResultElementType]> { + let arguments = (ins Variadic<AnyType>); + let results = (outs Variadic<AnyType>); +} + +def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> { + let arguments = (ins Variadic<AnyShaped>); +} + +def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape", + [SameOperandsAndResultShape]> { + let arguments = (ins Variadic<AnyShaped>); + let results = (outs Variadic<AnyShaped>); +} + +def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type", + [SameOperandsAndResultType]> { + let arguments = (ins Variadic<AnyType>); + let results = (outs Variadic<AnyType>); +} + +def ArgAndResHaveFixedElementTypesOp : + TEST_Op<"arg_and_res_have_fixed_element_types", + [PredOpTrait<"fixed type combination", + And<[ElementTypeIsPred<"x", I32>, + ElementTypeIsPred<"y", F32>]>>, + ElementTypeIs<"res", I16>]> { + let arguments = (ins + AnyShaped:$x, AnyShaped:$y); + let results = (outs AnyShaped:$res); +} + +def OperandsHaveSameElementType : TEST_Op<"operands_have_same_element_type", [ + AllElementTypesMatch<["x", "y"]>]> { + let arguments = (ins AnyType:$x, AnyType:$y); +} + +def OperandZeroAndResultHaveSameElementType : TEST_Op< + "operand0_and_result_have_same_element_type", + [AllElementTypesMatch<["x", "res"]>]> { + let arguments = (ins AnyType:$x, AnyType:$y); + let results = (outs AnyType:$res); +} + +def OperandsHaveSameType : + TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> { + let arguments = (ins AnyType:$x, AnyType:$y); +} + +def OperandZeroAndResultHaveSameType : + TEST_Op<"operand0_and_result_have_same_type", + [AllTypesMatch<["x", "res"]>]> { + let arguments = (ins AnyType:$x, AnyType:$y); + let results = (outs AnyType:$res); +} + +def OperandsHaveSameRank : + TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> { + let arguments = (ins AnyShaped:$x, AnyShaped:$y); +} + +def OperandZeroAndResultHaveSameRank : + TEST_Op<"operand0_and_result_have_same_rank", + [AllRanksMatch<["x", "res"]>]> { + let arguments = (ins AnyShaped:$x, AnyShaped:$y); + let results = (outs AnyShaped:$res); +} + +def OperandZeroAndResultHaveSameShape : + TEST_Op<"operand0_and_result_have_same_shape", + [AllShapesMatch<["x", "res"]>]> { + let arguments = (ins AnyShaped:$x, AnyShaped:$y); + let results = (outs AnyShaped:$res); +} + +def OperandZeroAndResultHaveSameElementCount : + TEST_Op<"operand0_and_result_have_same_element_count", + [AllElementCountsMatch<["x", "res"]>]> { + let arguments = (ins AnyShaped:$x, AnyShaped:$y); + let results = (outs AnyShaped:$res); +} + +def FourEqualsFive : + TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>; + +def OperandRankEqualsResultSize : + TEST_Op<"operand_rank_equals_result_size", + [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result], + "operand rank equals result size">]> { + let arguments = (ins AnyShaped:$operand); + let results = (outs AnyShaped:$result); +} + +def IfFirstOperandIsNoneThenSoIsSecond : + TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait< + "has either both none type operands or first is not none", + Or<[ + And<[TypeIsPred<"x", NoneType>, TypeIsPred<"y", NoneType>]>, + Neg<TypeIsPred<"x", NoneType>>]>>]> { + let arguments = (ins AnyType:$x, AnyType:$y); +} + +def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> { + let arguments = (ins AnyTensor, AnyTensor); + let results = (outs AnyTensor); +} + +// There the "HasParent" trait. +def ParentOp : TEST_Op<"parent">; +def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>; + + +def TerminatorOp : TEST_Op<"finish", [Terminator]>; +def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator", + [SingleBlockImplicitTerminator<"TerminatorOp">]> { + let regions = (region SizedRegion<1>:$region); +} + +def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> { + let arguments = (ins I32ElementsAttr:$attr); +} + +def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ + DeclareOpInterfaceMethods<InferTypeOpInterface>]> { + let arguments = (ins AnyTensor, AnyTensor); + let results = (outs AnyTensor); +} + +def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>; + +def UpdateAttr : Pat<(I32ElementsAttrOp $attr), + (I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">), + [(IsNotScalar $attr)]>; + +def TestBranchOp : TEST_Op<"br", [Terminator]> { + let arguments = (ins Variadic<AnyType>:$operands); +} + +def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic<I32>:$a, + Variadic<I32>:$b, + I32:$c, + Variadic<I32>:$d, + I32ElementsAttr:$operand_segment_sizes + ); +} + +def AttrSizedResultOp : TEST_Op<"attr_sized_results", + [AttrSizedResultSegments]> { + let arguments = (ins + I32ElementsAttr:$result_segment_sizes + ); + let results = (outs + Variadic<I32>:$a, + Variadic<I32>:$b, + I32:$c, + Variadic<I32>:$d + ); +} + +//===----------------------------------------------------------------------===// +// Test Patterns +//===----------------------------------------------------------------------===// + +def OpA : TEST_Op<"op_a"> { + let arguments = (ins I32, I32Attr:$attr); + let results = (outs I32); +} + +def OpB : TEST_Op<"op_b"> { + let arguments = (ins I32, I32Attr:$attr); + let results = (outs I32); +} + +// Test named pattern. +def TestNamedPatternRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>; + +// Test with fused location. +def : Pat<(OpA (OpA $input, $attr), $bttr), (OpB $input, $bttr)>; + +// Test added benefit. +def OpD : TEST_Op<"op_d">, Arguments<(ins I32)>, Results<(outs I32)>; +def OpE : TEST_Op<"op_e">, Arguments<(ins I32)>, Results<(outs I32)>; +def OpF : TEST_Op<"op_f">, Arguments<(ins I32)>, Results<(outs I32)>; +def OpG : TEST_Op<"op_g">, Arguments<(ins I32)>, Results<(outs I32)>; +// Verify that bumping benefit results in selecting different op. +def : Pat<(OpD $input), (OpE $input)>; +def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>; +// Verify that patterns with more source nodes are selected before those with fewer. +def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>; +def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>; + +// Test patterns for zero-result op. +def OpH : TEST_Op<"op_h">, Arguments<(ins I32)>, Results<(outs)>; +def OpI : TEST_Op<"op_i">, Arguments<(ins I32)>, Results<(outs)>; +def : Pat<(OpH $input), (OpI $input)>; + +// Test patterns for zero-input op. +def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>; +def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>; +def : Pat<(OpJ), (OpK)>; + +// Test `$_` for ignoring op argument match. +def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> { + let arguments = (ins + AnyType:$a, AnyType:$b, AnyType:$c, + AnyAttr:$d, AnyAttr:$e, AnyAttr:$f); +} +def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> { + let arguments = (ins AnyType:$b, AnyAttr:$f); +} +def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f), + (TestIgnoreArgMatchDstOp $b, $f)>; + +def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> { + let arguments = (ins + I32:$input1, + I64Attr:$attr1, + I32:$input2, + I64Attr:$attr2 + ); +} + +def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> { + let arguments = (ins + I32:$input1, + I64Attr:$attr1, + I32:$input2, + I64Attr:$attr2 + ); +} + +def ManyArgsOp : TEST_Op<"many_arguments"> { + let arguments = (ins + I32:$input1, I32:$input2, I32:$input3, I32:$input4, I32:$input5, + I32:$input6, I32:$input7, I32:$input8, I32:$input9, + I64Attr:$attr1, I64Attr:$attr2, I64Attr:$attr3, I64Attr:$attr4, + I64Attr:$attr5, I64Attr:$attr6, I64Attr:$attr7, I64Attr:$attr8, + I64Attr:$attr9 + ); +} + +// Test that DRR does not blow up when seeing lots of arguments. +def : Pat<(ManyArgsOp + $input1, $input2, $input3, $input4, $input5, + $input6, $input7, $input8, $input9, + ConstantAttr<I64Attr, "42">, + $attr2, $attr3, $attr4, $attr5, $attr6, + $attr7, $attr8, $attr9), + (ManyArgsOp + $input1, $input2, $input3, $input4, $input5, + $input6, $input7, $input8, $input9, + ConstantAttr<I64Attr, "24">, + $attr2, $attr3, $attr4, $attr5, $attr6, + $attr7, $attr8, $attr9)>; + +// Test that we can capture and reference interleaved operands and attributes. +def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2), + (OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>; + +// Test NativeCodeCall. +def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> { + let arguments = (ins + I32:$input1, I32:$input2, + BoolAttr:$choice, + I64Attr:$attr1, I64Attr:$attr2 + ); + let results = (outs I32); +} +def OpNativeCodeCall2 : TEST_Op<"native_code_call2"> { + let arguments = (ins I32:$input, I64ArrayAttr:$attr); + let results = (outs I32); +} +// Native code call to invoke a C++ function +def CreateOperand: NativeCodeCall<"chooseOperand($0, $1, $2)">; +// Native code call to invoke a C++ expression +def CreateArrayAttr: NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">; +// Test that we can use NativeCodeCall to create operand and attribute. +// This pattern chooses between $input1 and $input2 according to $choice and +// it combines $attr1 and $attr2 into an array attribute. +def : Pat<(OpNativeCodeCall1 $input1, $input2, + ConstBoolAttrTrue:$choice, $attr1, $attr2), + (OpNativeCodeCall2 (CreateOperand $input1, $input2, $choice), + (CreateArrayAttr $attr1, $attr2))>; +// Note: the following is just for testing purpose. +// Should use the replaceWithValue directive instead. +def UseOpResult: NativeCodeCall<"$0">; +// Test that we can use NativeCodeCall to create result. +def : Pat<(OpNativeCodeCall1 $input1, $input2, + ConstBoolAttrFalse, $attr1, $attr2), + (UseOpResult $input2)>; + +def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> { + let arguments = (ins I32:$input); + let results = (outs I32); +} +// Test that NativeCodeCall is not ignored if it is not used to directly +// replace the matched root op. +def : Pattern<(OpNativeCodeCall3 $input), + [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>; + +// Test AllAttrConstraintsOf. +def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> { + let arguments = (ins I64ArrayAttr:$attr); + let results = (outs I32); +} +def OpAllAttrConstraint2 : TEST_Op<"all_attr_constraint_of2"> { + let arguments = (ins I64ArrayAttr:$attr); + let results = (outs I32); +} +def Constraint0 : AttrConstraint< + CPred<"$_self.cast<ArrayAttr>().getValue()[0]." + "cast<IntegerAttr>().getInt() == 0">, + "[0] == 0">; +def Constraint1 : AttrConstraint< + CPred<"$_self.cast<ArrayAttr>().getValue()[1]." + "cast<IntegerAttr>().getInt() == 1">, + "[1] == 1">; +def : Pat<(OpAllAttrConstraint1 + AllAttrConstraintsOf<[Constraint0, Constraint1]>:$attr), + (OpAllAttrConstraint2 $attr)>; + +// Op for testing RewritePattern removing op with inner ops. +def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> { + let regions = (region SizedRegion<1>:$region); + let hasCanonicalizer = 1; +} + +// Op for testing trivial removal via folding of op with inner ops and no uses. +def TestOpWithRegionFoldNoSideEffect : TEST_Op< + "op_with_region_fold_no_side_effect", [NoSideEffect]> { + let regions = (region SizedRegion<1>:$region); +} + +// Op for testing folding of outer op with inner ops. +def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> { + let arguments = (ins I32:$operand); + let results = (outs I32); + let regions = (region SizedRegion<1>:$region); + let hasFolder = 1; +} + +def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_folder"> { + let arguments = (ins Variadic<I32>:$operands); + let results = (outs Variadic<I32>); + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// Test Patterns (Symbol Binding) + +// Test symbol binding. +def OpSymbolBindingA : TEST_Op<"symbol_binding_a", []> { + let arguments = (ins I32:$operand, I64Attr:$attr); + let results = (outs I32); +} +def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> { + let arguments = (ins I32:$operand); + let results = (outs I32); + + let builders = [ + OpBuilder< + "Builder *builder, OperationState &state, Value operand", + [{ + state.types.assign({builder->getIntegerType(32)}); + state.addOperands({operand}); + }]> + ]; +} +def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> { + let arguments = (ins I32:$operand); + let results = (outs I32); + let builders = OpSymbolBindingB.builders; +} +def OpSymbolBindingD : TEST_Op<"symbol_binding_d", []> { + let arguments = (ins I32:$input1, I32:$input2, I64Attr:$attr); + let results = (outs I32); +} +def HasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">; +def : Pattern< + // Bind to source pattern op operand/attribute/result + (OpSymbolBindingA:$res_a $operand, $attr), [ + // Bind to auxiliary op result + (OpSymbolBindingC:$res_c (OpSymbolBindingB:$res_b $operand)), + + // Use bound symbols in resultant ops + (OpSymbolBindingD $res_b, $res_c, $attr)], + // Use bound symbols in additional constraints + [(HasOneUse $res_a)]>; + +def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> { + let arguments = (ins I32:$operand); +} + +// Test that we can bind to an op without results and reference it later. +def : Pat<(OpSymbolBindingNoResult:$op $operand), + (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>; + +//===----------------------------------------------------------------------===// +// Test Patterns (Attributes) + +// Test matching against op attributes. +def OpAttrMatch1 : TEST_Op<"match_op_attribute1"> { + let arguments = (ins + I32Attr:$required_attr, + OptionalAttr<I32Attr>:$optional_attr, + DefaultValuedAttr<I32Attr, "42">:$default_valued_attr, + I32Attr:$more_attr + ); + let results = (outs I32); +} +def OpAttrMatch2 : TEST_Op<"match_op_attribute2"> { + let arguments = OpAttrMatch1.arguments; + let results = (outs I32); +} +def MoreConstraint : AttrConstraint< + CPred<"$_self.cast<IntegerAttr>().getInt() == 4">, "more constraint">; +def : Pat<(OpAttrMatch1 $required, $optional, $default_valued, + MoreConstraint:$more), + (OpAttrMatch2 $required, $optional, $default_valued, $more)>; + +// Test unit attrs. +def OpAttrMatch3 : TEST_Op<"match_op_attribute3"> { + let arguments = (ins UnitAttr:$attr); + let results = (outs I32); +} +def OpAttrMatch4 : TEST_Op<"match_op_attribute4"> { + let arguments = (ins UnitAttr:$attr1, UnitAttr:$attr2); + let results = (outs I32); +} +def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>; + +// Test with constant attr. +def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>; +def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>; + +// Test string enum attribute in rewrites. +def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>; +// Test integer enum attribute in rewrites. +def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; +def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; + +//===----------------------------------------------------------------------===// +// Test Patterns (Multi-result Ops) + +def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>; +def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>; +def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>; +def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>; +def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; +def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; + +def MultiResultOpEnum: I64EnumAttr< + "MultiResultOpEnum", "Multi-result op kinds", [ + MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, + MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 + ]>; + +def ThreeResultOp : TEST_Op<"three_result"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs I32:$result1, F32:$result2, F32:$result3); +} + +def AnotherThreeResultOp : TEST_Op<"another_three_result"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs I32:$result1, F32:$result2, F32:$result3); +} + +def TwoResultOp : TEST_Op<"two_result"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs I32:$result1, F32:$result2); + + let builders = [ + OpBuilder< + "Builder *builder, OperationState &state, IntegerAttr kind", + [{ + auto i32 = builder->getIntegerType(32); + auto f32 = builder->getF32Type(); + state.types.assign({i32, f32}); + state.addAttribute("kind", kind); + }]> + ]; +} + +def AnotherTwoResultOp : TEST_Op<"another_two_result"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs F32:$result1, F32:$result2); +} + +def OneResultOp1 : TEST_Op<"one_result1"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs F32:$result1); +} + +def OneResultOp2 : TEST_Op<"one_result2"> { + let arguments = (ins MultiResultOpEnum:$kind); + let results = (outs I32:$result1); +} + +def OneResultOp3 : TEST_Op<"one_result3"> { + let arguments = (ins F32); + let results = (outs I32:$result1); +} + +// Test using multi-result op as a whole +def : Pat<(ThreeResultOp MultiResultOpKind1), + (AnotherThreeResultOp MultiResultOpKind1)>; + +// Test using multi-result op as a whole for partial replacement +def : Pattern<(ThreeResultOp MultiResultOpKind2), + [(TwoResultOp MultiResultOpKind2), + (OneResultOp1 MultiResultOpKind2)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind3), + [(OneResultOp2 MultiResultOpKind3), + (AnotherTwoResultOp MultiResultOpKind3)]>; + +// Test using results separately in a multi-result op +def : Pattern<(ThreeResultOp MultiResultOpKind4), + [(TwoResultOp:$res1__0 MultiResultOpKind4), + (OneResultOp1 MultiResultOpKind4), + (TwoResultOp:$res2__1 MultiResultOpKind4)]>; + +// Test referencing a single value in the value pack +// This rule only matches TwoResultOp if its second result has no use. +def : Pattern<(TwoResultOp:$res MultiResultOpKind5), + [(OneResultOp2 MultiResultOpKind5), + (OneResultOp1 MultiResultOpKind5)], + [(HasNoUseOf:$res__1)]>; + +// Test using auxiliary ops for replacing multi-result op +def : Pattern< + (ThreeResultOp MultiResultOpKind6), [ + // Auxiliary op generated to help building the final result but not + // directly used to replace the source op's results. + (TwoResultOp:$interm MultiResultOpKind6), + + (OneResultOp3 $interm__1), + (AnotherTwoResultOp MultiResultOpKind6) + ]>; + +//===----------------------------------------------------------------------===// +// Test Patterns (Variadic Ops) + +def OneVResOneVOperandOp1 : TEST_Op<"one_variadic_out_one_variadic_in1"> { + let arguments = (ins Variadic<I32>); + let results = (outs Variadic<I32>); +} +def OneVResOneVOperandOp2 : TEST_Op<"one_variadic_out_one_variadic_in2"> { + let arguments = (ins Variadic<I32>); + let results = (outs Variadic<I32>); +} + +// Rewrite an op with one variadic operand and one variadic result to +// another similar op. +def : Pat<(OneVResOneVOperandOp1 $inputs), (OneVResOneVOperandOp2 $inputs)>; + +def MixedVOperandOp1 : TEST_Op<"mixed_variadic_in1", + [SameVariadicOperandSize]> { + let arguments = (ins + Variadic<I32>:$input1, + F32:$input2, + Variadic<I32>:$input3 + ); +} + +def MixedVOperandOp2 : TEST_Op<"mixed_variadic_in2", + [SameVariadicOperandSize]> { + let arguments = (ins + Variadic<I32>:$input1, + F32:$input2, + Variadic<I32>:$input3 + ); +} + +// Rewrite an op with both variadic operands and normal operands. +def : Pat<(MixedVOperandOp1 $input1, $input2, $input3), + (MixedVOperandOp2 $input1, $input2, $input3)>; + +def MixedVResultOp1 : TEST_Op<"mixed_variadic_out1", [SameVariadicResultSize]> { + let results = (outs + Variadic<I32>:$output1, + F32:$output2, + Variadic<I32>:$output3 + ); +} + +def MixedVResultOp2 : TEST_Op<"mixed_variadic_out2", [SameVariadicResultSize]> { + let results = (outs + Variadic<I32>:$output1, + F32:$output2, + Variadic<I32>:$output3 + ); +} + +// Rewrite an op with both variadic results and normal results. +// Note that because we are generating the op with a top-level result pattern, +// we are able to deduce the correct result types for the generated op using +// the information from the matched root op. +def : Pat<(MixedVResultOp1), (MixedVResultOp2)>; + +def OneI32ResultOp : TEST_Op<"one_i32_out"> { + let results = (outs I32); +} + +def MixedVOperandOp3 : TEST_Op<"mixed_variadic_in3", + [SameVariadicOperandSize]> { + let arguments = (ins + I32:$input1, + Variadic<I32>:$input2, + Variadic<I32>:$input3, + I32Attr:$count + ); + + let results = (outs I32); +} + +def MixedVResultOp3 : TEST_Op<"mixed_variadic_out3", + [SameVariadicResultSize]> { + let arguments = (ins I32Attr:$count); + + let results = (outs + I32:$output1, + Variadic<I32>:$output2, + Variadic<I32>:$output3 + ); + + // We will use this op in a nested result pattern, where we cannot deduce the + // result type. So need to provide a builder not requiring result types. + let builders = [ + OpBuilder< + "Builder *builder, OperationState &state, IntegerAttr count", + [{ + auto i32Type = builder->getIntegerType(32); + state.addTypes(i32Type); // $output1 + SmallVector<Type, 4> types(count.getInt(), i32Type); + state.addTypes(types); // $output2 + state.addTypes(types); // $output3 + state.addAttribute("count", count); + }]> + ]; +} + +// Generates an op with variadic results using nested pattern. +def : Pat<(OneI32ResultOp), + (MixedVOperandOp3 + (MixedVResultOp3:$results__0 ConstantAttr<I32Attr, "2">), + (replaceWithValue $results__1), + (replaceWithValue $results__2), + ConstantAttr<I32Attr, "2">)>; + +//===----------------------------------------------------------------------===// +// Test Legalization +//===----------------------------------------------------------------------===// + +def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">; +def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">; + +def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure", + [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>; + +def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>; +def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>; +def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>; +def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; +def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; +def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; +def LegalOpA : TEST_Op<"legal_op_a">, + Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; +def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; + +// Check that smaller pattern depths are chosen, i.e. prioritize more direct +// mappings. +def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>; + +def : Pat<(ILLegalOpA), (ILLegalOpB)>; +def : Pat<(ILLegalOpB), (LegalOpA Test_LegalizerEnum_Failure)>; + +// Check that the higher benefit pattern is taken for multiple legalizations +// with the same depth. +def : Pat<(ILLegalOpC), (ILLegalOpD)>; +def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>; + +def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>; +def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>; + +// Check that patterns use the most up-to-date value when being replaced. +def TestRewriteOp : TEST_Op<"rewrite">, + Arguments<(ins AnyType)>, Results<(outs AnyType)>; +def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>; + +//===----------------------------------------------------------------------===// +// Test Type Legalization +//===----------------------------------------------------------------------===// + +def TestRegionBuilderOp : TEST_Op<"region_builder">; +def TestReturnOp : TEST_Op<"return", [Terminator]>, + Arguments<(ins Variadic<AnyType>)>; +def TestCastOp : TEST_Op<"cast">, + Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>; +def TestInvalidOp : TEST_Op<"invalid", [Terminator]>, + Arguments<(ins Variadic<AnyType>)>; +def TestTypeProducerOp : TEST_Op<"type_producer">, + Results<(outs AnyType)>; +def TestTypeConsumerOp : TEST_Op<"type_consumer">, + Arguments<(ins AnyType)>; +def TestValidOp : TEST_Op<"valid", [Terminator]>, + Arguments<(ins Variadic<AnyType>)>; + +//===----------------------------------------------------------------------===// +// Test parser. +//===----------------------------------------------------------------------===// + +def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> { + let arguments = (ins StrAttr:$keyword); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + +//===----------------------------------------------------------------------===// +// Test region argument list parsing. + +def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> { + let summary = "isolated region operation"; + let description = [{ + Test op with an isolated region, to test passthrough region arguments. Each + argument is of index type. + }]; + + let arguments = (ins Index); + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + +def WrappingRegionOp : TEST_Op<"wrapping_region", + [SingleBlockImplicitTerminator<"TestReturnOp">]> { + let summary = "wrapping region operation"; + let description = [{ + Test op wrapping another op in a region, to test calling + parseGenericOperation from the custom parser. + }]; + + let results = (outs Variadic<AnyType>); + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + +def PolyForOp : TEST_Op<"polyfor"> +{ + let summary = "polyfor operation"; + let description = [{ + Test op with multiple region arguments, each argument of index type. + }]; + + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +//===----------------------------------------------------------------------===// +// Test OpAsmInterface. + +def AsmInterfaceOp : TEST_Op<"asm_interface_op"> { + let results = (outs AnyType:$first, Variadic<AnyType>:$middle_results, + AnyType); +} + +def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> { + let results = (outs AnyType); +} + +#endif // TEST_OPS diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp new file mode 100644 index 00000000000..929c4a941a2 --- /dev/null +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -0,0 +1,504 @@ +//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +using namespace mlir; + +// Native function for testing NativeCodeCall +static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { + return choice.getValue() ? input1 : input2; +} + +static void createOpI(PatternRewriter &rewriter, Value input) { + rewriter.create<OpI>(rewriter.getUnknownLoc(), input); +} + +void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) { + // Turn the no result op to a one-result op. + rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand()->getType(), + op.operand()); +} + +namespace { +#include "TestPatterns.inc" +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Canonicalizer Driver. +//===----------------------------------------------------------------------===// + +namespace { +struct TestPatternDriver : public FunctionPass<TestPatternDriver> { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + populateWithGenerated(&getContext(), &patterns); + + // Verify named pattern is generated with expected name. + patterns.insert<TestNamedPatternRule>(&getContext()); + + applyPatternsGreedily(getFunction(), patterns); + } +}; +} // end anonymous namespace + +static mlir::PassRegistration<TestPatternDriver> + pass("test-patterns", "Run test dialect patterns"); + +//===----------------------------------------------------------------------===// +// ReturnType Driver. +//===----------------------------------------------------------------------===// + +struct ReturnTypeOpMatch : public RewritePattern { + ReturnTypeOpMatch(MLIRContext *ctx) + : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) { + } + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) { + SmallVector<Value, 4> values(op->getOperands()); + SmallVector<Type, 2> inferedReturnTypes; + if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values, + op->getAttrs(), op->getRegions(), + inferedReturnTypes))) + return matchFailure(); + SmallVector<Type, 1> resultTypes(op->getResultTypes()); + if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) + return op->emitOpError( + "inferred type incompatible with return type of operation"), + matchFailure(); + + // TODO(jpienaar): Split this out to make the test more focused. + // Create new op with unknown location to verify building with + // InferTypeOpInterface is triggered. + auto fop = op->getParentOfType<FuncOp>(); + if (values[0] == fop.getArgument(0)) { + // Use the 2nd function argument if the first function argument is used + // when constructing the new op so that a new return type is inferred. + values[0] = fop.getArgument(1); + values[1] = fop.getArgument(1); + // TODO(jpienaar): Expand to regions. + rewriter.create<OpWithInferTypeInterfaceOp>( + UnknownLoc::get(op->getContext()), values, op->getAttrs()); + } + } + return matchFailure(); + } +}; + +namespace { +struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + populateWithGenerated(&getContext(), &patterns); + patterns.insert<ReturnTypeOpMatch>(&getContext()); + applyPatternsGreedily(getFunction(), patterns); + } +}; +} // end anonymous namespace + +static mlir::PassRegistration<TestReturnTypeDriver> + rt_pass("test-return-type", "Run return type functions"); + +//===----------------------------------------------------------------------===// +// Legalization Driver. +//===----------------------------------------------------------------------===// + +namespace { +//===----------------------------------------------------------------------===// +// Region-Block Rewrite Testing + +/// This pattern is a simple pattern that inlines the first region of a given +/// operation into the parent region. +struct TestRegionRewriteBlockMovement : public ConversionPattern { + TestRegionRewriteBlockMovement(MLIRContext *ctx) + : ConversionPattern("test.region", 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Inline this region into the parent region. + auto &parentRegion = *op->getParentRegion(); + if (op->getAttr("legalizer.should_clone")) + rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, + parentRegion.end()); + else + rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, + parentRegion.end()); + + // Drop this operation. + rewriter.eraseOp(op); + return matchSuccess(); + } +}; +/// This pattern is a simple pattern that generates a region containing an +/// illegal operation. +struct TestRegionRewriteUndo : public RewritePattern { + TestRegionRewriteUndo(MLIRContext *ctx) + : RewritePattern("test.region_builder", 1, ctx) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + // Create the region operation with an entry block containing arguments. + OperationState newRegion(op->getLoc(), "test.region"); + newRegion.addRegion(); + auto *regionOp = rewriter.createOperation(newRegion); + auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); + entryBlock->addArgument(rewriter.getIntegerType(64)); + + // Add an explicitly illegal operation to ensure the conversion fails. + rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); + rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); + + // Drop this operation. + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Type-Conversion Rewrite Testing + +/// This patterns erases a region operation that has had a type conversion. +struct TestDropOpSignatureConversion : public ConversionPattern { + TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { + } + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + Region ®ion = op->getRegion(0); + Block *entry = ®ion.front(); + + // Convert the original entry arguments. + TypeConverter::SignatureConversion result(entry->getNumArguments()); + for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) + if (failed(converter.convertSignatureArg( + i, entry->getArgument(i)->getType(), result))) + return matchFailure(); + + // Convert the region signature and just drop the operation. + rewriter.applySignatureConversion(®ion, result); + rewriter.eraseOp(op); + return matchSuccess(); + } + + /// The type converter to use when rewriting the signature. + TypeConverter &converter; +}; +/// This pattern simply updates the operands of the given operation. +struct TestPassthroughInvalidOp : public ConversionPattern { + TestPassthroughInvalidOp(MLIRContext *ctx) + : ConversionPattern("test.invalid", 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, + llvm::None); + return matchSuccess(); + } +}; +/// This pattern handles the case of a split return value. +struct TestSplitReturnType : public ConversionPattern { + TestSplitReturnType(MLIRContext *ctx) + : ConversionPattern("test.return", 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Check for a return of F32. + if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) + return matchFailure(); + + // Check if the first operation is a cast operation, if it is we use the + // results directly. + auto *defOp = operands[0]->getDefiningOp(); + if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { + rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); + return matchSuccess(); + } + + // Otherwise, fail to match. + return matchFailure(); + } +}; + +//===----------------------------------------------------------------------===// +// Multi-Level Type-Conversion Rewrite Testing +struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { + TestChangeProducerTypeI32ToF32(MLIRContext *ctx) + : ConversionPattern("test.type_producer", 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // If the type is I32, change the type to F32. + if (!(*op->result_type_begin()).isInteger(32)) + return matchFailure(); + rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); + return matchSuccess(); + } +}; +struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { + TestChangeProducerTypeF32ToF64(MLIRContext *ctx) + : ConversionPattern("test.type_producer", 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // If the type is F32, change the type to F64. + if (!(*op->result_type_begin()).isF32()) + return matchFailure(); + rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); + return matchSuccess(); + } +}; +struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { + TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) + : ConversionPattern("test.type_producer", 10, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Always convert to B16, even though it is not a legal type. This tests + // that values are unmapped correctly. + rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); + return matchSuccess(); + } +}; +struct TestUpdateConsumerType : public ConversionPattern { + TestUpdateConsumerType(MLIRContext *ctx) + : ConversionPattern("test.type_consumer", 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Verify that the incoming operand has been successfully remapped to F64. + if (!operands[0]->getType().isF64()) + return matchFailure(); + rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Non-Root Replacement Rewrite Testing +/// This pattern generates an invalid operation, but replaces it before the +/// pattern is finished. This checks that we don't need to legalize the +/// temporary op. +struct TestNonRootReplacement : public RewritePattern { + TestNonRootReplacement(MLIRContext *ctx) + : RewritePattern("test.replace_non_root", 1, ctx) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + auto resultType = *op->result_type_begin(); + auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); + auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); + + rewriter.replaceOp(illegalOp, {legalOp}); + rewriter.replaceOp(op, {illegalOp}); + return matchSuccess(); + } +}; +} // namespace + +namespace { +struct TestTypeConverter : public TypeConverter { + using TypeConverter::TypeConverter; + + LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override { + // Drop I16 types. + if (t.isInteger(16)) + return success(); + + // Convert I64 to F64. + if (t.isInteger(64)) { + results.push_back(FloatType::getF64(t.getContext())); + return success(); + } + + // Split F32 into F16,F16. + if (t.isF32()) { + results.assign(2, FloatType::getF16(t.getContext())); + return success(); + } + + // Otherwise, convert the type directly. + results.push_back(t); + return success(); + } + + /// Override the hook to materialize a conversion. This is necessary because + /// we generate 1->N type mappings. + Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, + ArrayRef<Value> inputs, + Location loc) override { + return rewriter.create<TestCastOp>(loc, resultType, inputs); + } +}; + +struct TestLegalizePatternDriver + : public ModulePass<TestLegalizePatternDriver> { + /// The mode of conversion to use with the driver. + enum class ConversionMode { Analysis, Full, Partial }; + + TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} + + void runOnModule() override { + TestTypeConverter converter; + mlir::OwningRewritePatternList patterns; + populateWithGenerated(&getContext(), &patterns); + patterns + .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, + TestPassthroughInvalidOp, TestSplitReturnType, + TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, + TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, + TestNonRootReplacement>(&getContext()); + patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter); + mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), + converter); + + // Define the conversion target used for the test. + ConversionTarget target(getContext()); + target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); + target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>(); + target + .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); + target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { + // Don't allow F32 operands. + return llvm::none_of(op.getOperandTypes(), + [](Type type) { return type.isF32(); }); + }); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + + // Expect the type_producer/type_consumer operations to only operate on f64. + target.addDynamicallyLegalOp<TestTypeProducerOp>( + [](TestTypeProducerOp op) { return op.getType().isF64(); }); + target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { + return op.getOperand()->getType().isF64(); + }); + + // Check support for marking certain operations as recursively legal. + target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { + return static_cast<bool>( + op->getAttrOfType<UnitAttr>("test.recursively_legal")); + }); + + // Handle a partial conversion. + if (mode == ConversionMode::Partial) { + (void)applyPartialConversion(getModule(), target, patterns, &converter); + return; + } + + // Handle a full conversion. + if (mode == ConversionMode::Full) { + (void)applyFullConversion(getModule(), target, patterns, &converter); + return; + } + + // Otherwise, handle an analysis conversion. + assert(mode == ConversionMode::Analysis); + + // Analyze the convertible operations. + DenseSet<Operation *> legalizedOps; + if (failed(applyAnalysisConversion(getModule(), target, patterns, + legalizedOps, &converter))) + return signalPassFailure(); + + // Emit remarks for each legalizable operation. + for (auto *op : legalizedOps) + op->emitRemark() << "op '" << op->getName() << "' is legalizable"; + } + + /// The mode of conversion to use. + ConversionMode mode; +}; +} // end anonymous namespace + +static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> + legalizerConversionMode( + "test-legalize-mode", + llvm::cl::desc("The legalization mode to use with the test driver"), + llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), + llvm::cl::values( + clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, + "analysis", "Perform an analysis conversion"), + clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", + "Perform a full conversion"), + clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, + "partial", "Perform a partial conversion"))); + +static mlir::PassRegistration<TestLegalizePatternDriver> + legalizer_pass("test-legalize-patterns", + "Run test dialect legalization patterns", [] { + return std::make_unique<TestLegalizePatternDriver>( + legalizerConversionMode); + }); + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter::getRemappedValue testing. This method is used +// to get the remapped value of a original value that was replaced using +// ConversionPatternRewriter. +namespace { +/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with +/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original +/// operand twice. +/// +/// Example: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0) +/// is replaced with: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) +struct OneVResOneVOperandOp1Converter + : public OpConversionPattern<OneVResOneVOperandOp1> { + using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; + + PatternMatchResult + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto origOps = op.getOperands(); + assert(std::distance(origOps.begin(), origOps.end()) == 1 && + "One operand expected"); + Value origOp = *origOps.begin(); + SmallVector<Value, 2> remappedOperands; + // Replicate the remapped original operand twice. Note that we don't used + // the remapped 'operand' since the goal is testing 'getRemappedValue'. + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + + SmallVector<Type, 1> resultTypes(op.getResultTypes()); + rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes, + remappedOperands); + return matchSuccess(); + } +}; + +struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); + + mlir::ConversionTarget target(getContext()); + target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); + // We make OneVResOneVOperandOp1 legal only when it has more that one + // operand. This will trigger the conversion that will replace one-operand + // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. + target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( + [](Operation *op) -> bool { + return std::distance(op->operand_begin(), op->operand_end()) > 1; + }); + + if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +static PassRegistration<TestRemappedValue> remapped_value_pass( + "test-remapped-value", + "Test public remapped value mechanism in ConversionPatternRewriter"); diff --git a/mlir/test/lib/TestDialect/lit.local.cfg b/mlir/test/lib/TestDialect/lit.local.cfg new file mode 100644 index 00000000000..edb5b44b2e2 --- /dev/null +++ b/mlir/test/lib/TestDialect/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove('.td')
\ No newline at end of file |