diff options
| author | Ben Vanik <benvanik@google.com> | 2019-11-24 18:50:54 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-24 19:23:38 -0800 |
| commit | d2284f1f0ba937ed0da8996957eb3e4557243f64 (patch) | |
| tree | 82237fa830559e953b1a94fe46f05e832e42c14e /mlir/lib | |
| parent | ae821fe626666ce21efb4a3c5de0adf8ac553493 (diff) | |
| download | bcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.tar.gz bcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.zip | |
Support folding of StandardOps with DenseElementsAttr.
PiperOrigin-RevId: 282270243
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index c399cce1110..5426709ca98 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -244,25 +244,41 @@ template <class AttrElementT, Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if (!operands[0] || !operands[1]) + return {}; + if (operands[0].getType() != operands[1].getType()) + return {}; - if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) { - auto rhs = operands[1].dyn_cast_or_null<AttrElementT>(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; + if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) { + auto lhs = operands[0].cast<AttrElementT>(); + auto rhs = operands[1].cast<AttrElementT>(); return AttrElementT::get(lhs.getType(), calculate(lhs.getValue(), rhs.getValue())); - } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { - auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>(); - if (!rhs || lhs.getType() != rhs.getType()) - return {}; - - auto elementResult = constFoldBinaryOp<AttrElementT>( - {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); - if (!elementResult) - return {}; - + } else if (operands[0].isa<SplatElementsAttr>() && + operands[1].isa<SplatElementsAttr>()) { + // Both operands are splats so we can avoid expanding the values out and + // just fold based on the splat value. + auto lhs = operands[0].cast<SplatElementsAttr>(); + auto rhs = operands[1].cast<SplatElementsAttr>(); + + auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(), + rhs.getSplatValue<ElementValueT>()); return DenseElementsAttr::get(lhs.getType(), elementResult); + } else if (operands[0].isa<ElementsAttr>() && + operands[1].isa<ElementsAttr>()) { + // Operands are ElementsAttr-derived; perform an element-wise fold by + // expanding the values. + auto lhs = operands[0].cast<ElementsAttr>(); + auto rhs = operands[1].cast<ElementsAttr>(); + + auto lhsIt = lhs.getValues<ElementValueT>().begin(); + auto rhsIt = rhs.getValues<ElementValueT>().begin(); + SmallVector<ElementValueT, 4> elementResults; + elementResults.reserve(lhs.getNumElements()); + for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) + elementResults.push_back(calculate(*lhsIt, *rhsIt)); + return DenseElementsAttr::get(lhs.getType(), elementResults); } return {}; } |

