summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorBen Vanik <benvanik@google.com>2019-11-24 18:50:54 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-24 19:23:38 -0800
commitd2284f1f0ba937ed0da8996957eb3e4557243f64 (patch)
tree82237fa830559e953b1a94fe46f05e832e42c14e /mlir/lib
parentae821fe626666ce21efb4a3c5de0adf8ac553493 (diff)
downloadbcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.tar.gz
bcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.zip
Support folding of StandardOps with DenseElementsAttr.
PiperOrigin-RevId: 282270243
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp44
1 files changed, 30 insertions, 14 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index c399cce1110..5426709ca98 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -244,25 +244,41 @@ template <class AttrElementT,
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
+ if (!operands[0] || !operands[1])
+ return {};
+ if (operands[0].getType() != operands[1].getType())
+ return {};
- if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
- auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
- if (!rhs || lhs.getType() != rhs.getType())
- return {};
+ if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
+ auto lhs = operands[0].cast<AttrElementT>();
+ auto rhs = operands[1].cast<AttrElementT>();
return AttrElementT::get(lhs.getType(),
calculate(lhs.getValue(), rhs.getValue()));
- } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
- auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
- if (!rhs || lhs.getType() != rhs.getType())
- return {};
-
- auto elementResult = constFoldBinaryOp<AttrElementT>(
- {lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
- if (!elementResult)
- return {};
-
+ } else if (operands[0].isa<SplatElementsAttr>() &&
+ operands[1].isa<SplatElementsAttr>()) {
+ // Both operands are splats so we can avoid expanding the values out and
+ // just fold based on the splat value.
+ auto lhs = operands[0].cast<SplatElementsAttr>();
+ auto rhs = operands[1].cast<SplatElementsAttr>();
+
+ auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
+ rhs.getSplatValue<ElementValueT>());
return DenseElementsAttr::get(lhs.getType(), elementResult);
+ } else if (operands[0].isa<ElementsAttr>() &&
+ operands[1].isa<ElementsAttr>()) {
+ // Operands are ElementsAttr-derived; perform an element-wise fold by
+ // expanding the values.
+ auto lhs = operands[0].cast<ElementsAttr>();
+ auto rhs = operands[1].cast<ElementsAttr>();
+
+ auto lhsIt = lhs.getValues<ElementValueT>().begin();
+ auto rhsIt = rhs.getValues<ElementValueT>().begin();
+ SmallVector<ElementValueT, 4> elementResults;
+ elementResults.reserve(lhs.getNumElements());
+ for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
+ elementResults.push_back(calculate(*lhsIt, *rhsIt));
+ return DenseElementsAttr::get(lhs.getType(), elementResults);
}
return {};
}
OpenPOWER on IntegriCloud