diff options
| author | Diego Caballero <diego.caballero@intel.com> | 2019-11-19 10:15:36 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-19 11:09:39 -0800 |
| commit | dd5a7cb48833d4abf93a063f40b7cf5baae940ef (patch) | |
| tree | 7e30e57c57499894dd1f75b12aa28fdaf298915f /mlir/test/lib/TestDialect | |
| parent | 06fb797b4090526b906d7af44715a826faed5d3a (diff) | |
| download | bcm5719-llvm-dd5a7cb48833d4abf93a063f40b7cf5baae940ef.tar.gz bcm5719-llvm-dd5a7cb48833d4abf93a063f40b7cf5baae940ef.zip | |
Add getRemappedValue to ConversionPatternRewriter
This method is needed for N->1 conversion patterns to retrieve remapped
Values used in the original N operations.
Closes tensorflow/mlir#237
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/237 from dcaballe:dcaballe/getRemappedValue 1f64fadcf2b203f7b336ff0c5838b116ae3625db
PiperOrigin-RevId: 281321881
Diffstat (limited to 'mlir/test/lib/TestDialect')
| -rw-r--r-- | mlir/test/lib/TestDialect/TestPatterns.cpp | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 936d7632967..5ef03606dbe 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -435,3 +435,66 @@ static mlir::PassRegistration<TestLegalizePatternDriver> return std::make_unique<TestLegalizePatternDriver>( legalizerConversionMode); }); + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter::getRemappedValue testing. This method is used +// to get the remapped value of a original value that was replaced using +// ConversionPatternRewriter. +namespace { +/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with +/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original +/// operand twice. +/// +/// Example: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0) +/// is replaced with: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) +struct OneVResOneVOperandOp1Converter + : public OpConversionPattern<OneVResOneVOperandOp1> { + using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; + + PatternMatchResult + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto origOps = op.getOperands(); + assert(std::distance(origOps.begin(), origOps.end()) == 1 && + "One operand expected"); + Value *origOp = *origOps.begin(); + SmallVector<Value *, 2> remappedOperands; + // Replicate the remapped original operand twice. Note that we don't used + // the remapped 'operand' since the goal is testing 'getRemappedValue'. + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + + SmallVector<Type, 1> resultTypes(op.getResultTypes()); + rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes, + remappedOperands); + return matchSuccess(); + } +}; + +struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + patterns.insert<OneVResOneVOperandOp1Converter>(&getContext()); + + mlir::ConversionTarget target(getContext()); + target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>(); + // We make OneVResOneVOperandOp1 legal only when it has more that one + // operand. This will trigger the conversion that will replace one-operand + // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. + target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( + [](Operation *op) -> bool { + return std::distance(op->operand_begin(), op->operand_end()) > 1; + }); + + if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +static PassRegistration<TestRemappedValue> remapped_value_pass( + "test-remapped-value", + "Test public remapped value mechanism in ConversionPatternRewriter"); |

