summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorBen Vanik <benvanik@google.com>2019-11-24 18:50:54 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-24 19:23:38 -0800
commitd2284f1f0ba937ed0da8996957eb3e4557243f64 (patch)
tree82237fa830559e953b1a94fe46f05e832e42c14e /mlir
parentae821fe626666ce21efb4a3c5de0adf8ac553493 (diff)
downloadbcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.tar.gz
bcm5719-llvm-d2284f1f0ba937ed0da8996957eb3e4557243f64.zip
Support folding of StandardOps with DenseElementsAttr.
PiperOrigin-RevId: 282270243
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp44
-rw-r--r--mlir/test/Transforms/constant-fold.mlir28
2 files changed, 58 insertions, 14 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index c399cce1110..5426709ca98 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -244,25 +244,41 @@ template <class AttrElementT,
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
+ if (!operands[0] || !operands[1])
+ return {};
+ if (operands[0].getType() != operands[1].getType())
+ return {};
- if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
- auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
- if (!rhs || lhs.getType() != rhs.getType())
- return {};
+ if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
+ auto lhs = operands[0].cast<AttrElementT>();
+ auto rhs = operands[1].cast<AttrElementT>();
return AttrElementT::get(lhs.getType(),
calculate(lhs.getValue(), rhs.getValue()));
- } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
- auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
- if (!rhs || lhs.getType() != rhs.getType())
- return {};
-
- auto elementResult = constFoldBinaryOp<AttrElementT>(
- {lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
- if (!elementResult)
- return {};
-
+ } else if (operands[0].isa<SplatElementsAttr>() &&
+ operands[1].isa<SplatElementsAttr>()) {
+ // Both operands are splats so we can avoid expanding the values out and
+ // just fold based on the splat value.
+ auto lhs = operands[0].cast<SplatElementsAttr>();
+ auto rhs = operands[1].cast<SplatElementsAttr>();
+
+ auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
+ rhs.getSplatValue<ElementValueT>());
return DenseElementsAttr::get(lhs.getType(), elementResult);
+ } else if (operands[0].isa<ElementsAttr>() &&
+ operands[1].isa<ElementsAttr>()) {
+ // Operands are ElementsAttr-derived; perform an element-wise fold by
+ // expanding the values.
+ auto lhs = operands[0].cast<ElementsAttr>();
+ auto rhs = operands[1].cast<ElementsAttr>();
+
+ auto lhsIt = lhs.getValues<ElementValueT>().begin();
+ auto rhsIt = rhs.getValues<ElementValueT>().begin();
+ SmallVector<ElementValueT, 4> elementResults;
+ elementResults.reserve(lhs.getNumElements());
+ for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
+ elementResults.push_back(calculate(*lhsIt, *rhsIt));
+ return DenseElementsAttr::get(lhs.getType(), elementResults);
}
return {};
}
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 6d1acffa24c..ac2e20955d2 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -50,6 +50,34 @@ func @addf_splat_tensor() -> tensor<4xf32> {
// -----
+// CHECK-LABEL: func @addf_dense_tensor
+func @addf_dense_tensor() -> tensor<4xf32> {
+ %0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
+ %1 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
+
+ // CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 5.{{0*}}e+00, 7.{{0*}}e+00, 9.{{0*}}e+00]> : tensor<4xf32>
+ %2 = addf %0, %1 : tensor<4xf32>
+
+ // CHECK-NEXT: return [[C]]
+ return %2 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @addf_dense_and_splat_tensors
+func @addf_dense_and_splat_tensors() -> tensor<4xf32> {
+ %0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
+ %1 = constant dense<1.5> : tensor<4xf32>
+
+ // CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 4.{{0*}}e+00, 5.{{0*}}e+00, 6.{{0*}}e+00]> : tensor<4xf32>
+ %2 = addf %0, %1 : tensor<4xf32>
+
+ // CHECK-NEXT: return [[C]]
+ return %2 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @simple_addi
func @simple_addi() -> i32 {
%0 = constant 1 : i32
OpenPOWER on IntegriCloud