Diffstat (limited to 'mlir/lib/Dialect/Traits.cpp')
 mlir/lib/Dialect/Traits.cpp | 211 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 211 insertions(+), 0 deletions(-)
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
new file mode 100644
index 00000000000..3aea206c07e
--- /dev/null
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -0,0 +1,211 @@
+//===- Traits.cpp - Common op traits shared by dialects -------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
+                                        ArrayRef<int64_t> shape2,
+                                        SmallVectorImpl<int64_t> &resultShape) {
+  // To compute the broadcasted result shape, we compare the operand shapes
+  // element-wise, starting from the trailing dimensions and working
+  // backward. Two dimensions are compatible when
+  //   1. they are equal, or
+  //   2. one of them is 1.
+  // The result shape has, at every dimension index, the maximum of the two
+  // input dimensions.
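+  // For example, broadcasting [3, 1, 5] with [4, 5] compares 5 vs. 5 and
+  // 1 vs. 4 and keeps the leading 3, yielding [3, 4, 5]. With unknown
+  // dimensions, [-1, 3] with [2, 1] yields [2, 3], following the rule for
+  // dynamic dimensions below.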
+
+  resultShape.clear();
+  if (shape1.size() > shape2.size()) {
+    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
+  } else {
+    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
+  }
+
+  auto i1 = shape1.rbegin(), e1 = shape1.rend();
+  auto i2 = shape2.rbegin(), e2 = shape2.rend();
+  auto iR = resultShape.rbegin();
+
+  // Check that each dimension is consistent.
+  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
+    if (*i1 == -1 || *i2 == -1) {
+      // One or both dimensions are unknown. Follow TensorFlow behavior:
+      // - If either dimension is greater than 1, we assume that the program
+      //   is correct and that the other dimension will be broadcast to match
+      //   it.
+      // - If either dimension is 1, the other dimension is the output.
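+      // E.g., -1 vs. 4 gives 4; -1 vs. 1 gives -1 (the unknown side wins);
+      // -1 vs. -1 gives -1.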
+      if (*i1 > 1) {
+        *iR = *i1;
+      } else if (*i2 > 1) {
+        *iR = *i2;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else if (*i2 == 1) {
+        *iR = *i1;
+      } else {
+        *iR = -1;
+      }
+    } else {
+      if (*i1 == *i2 || *i2 == 1) {
+        *iR = *i1;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else {
+        // This dimension of the two operand types is incompatible.
+        resultShape.clear();
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
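+// Illustrative use of the function above: broadcasting {2, 1, 4} with {3, 4},
+//
+//   SmallVector<int64_t, 4> shape;
+//   bool ok = OpTrait::util::getBroadcastedShape({2, 1, 4}, {3, 4}, shape);
+//
+// succeeds with ok == true and shape == {2, 3, 4}; incompatible dimensions,
+// e.g. {2, 3} vs. {4, 3}, make it return false and clear the result.
+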
+/// Returns the shape of the given type. Scalars are treated as having a
+/// shape of zero dimensions.
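+/// E.g., getShape(tensor<2x?xf32>) is {2, -1}, and getShape(f32) is {}.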
+static ArrayRef<int64_t> getShape(Type type) {
+  if (auto sType = type.dyn_cast<ShapedType>())
+    return sType.getShape();
+  return {};
+}
+
+/// Returns the result broadcasted type of the two given types by following
+/// NumPy broadcast semantics. The returned type may have a dynamic shape if
+/// either of the input types has a dynamic shape. Returns a null type if the
+/// two given types are not broadcast-compatible.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
+  // Returns the scalar type out of the given type.
+  auto getScalarType = [](Type type) -> Type {
+    if (auto shapedType = type.dyn_cast<ShapedType>())
+      return shapedType.getElementType();
+    return type;
+  };
+
+  // Make sure the underlying scalar types are the same.
+  auto scalarType = getScalarType(type1);
+  if (scalarType != getScalarType(type2))
+    return {};
+
+  // If one of the types is an unranked tensor, then the other type must not
+  // be a vector, and the result is an unranked tensor.
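+  // For example, broadcasting tensor<*xf32> with tensor<2x3xf32> yields
+  // tensor<*xf32>, while tensor<*xf32> with vector<4xf32> yields a null type.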
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (type1.isa<VectorType>() || type2.isa<VectorType>())
+      return {};
+    return UnrankedTensorType::get(scalarType);
+  }
+
+  // Returns the type kind if the given type is a vector or ranked tensor
+  // type. Returns llvm::None otherwise.
+  auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
+    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
+      return static_cast<StandardTypes::Kind>(type.getKind());
+    return llvm::None;
+  };
+
+  // Make sure the composite kind, if any, is consistent.
+  auto compositeKind1 = getCompositeTypeKind(type1);
+  auto compositeKind2 = getCompositeTypeKind(type2);
+  Optional<StandardTypes::Kind> resultCompositeKind;
+
+  if (compositeKind1 && compositeKind2) {
+    // Disallow mixing vector and tensor.
+    if (compositeKind1 != compositeKind2)
+      return {};
+    resultCompositeKind = compositeKind1;
+  } else if (compositeKind1) {
+    resultCompositeKind = compositeKind1;
+  } else if (compositeKind2) {
+    resultCompositeKind = compositeKind2;
+  }
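+  // E.g., vector<2xf32> with tensor<2xf32> is rejected just above, while
+  // vector<2xf32> combined with the scalar f32 keeps the vector kind.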
+
+  // Compute the broadcasted shape from the shapes of the two types.
+  SmallVector<int64_t, 4> resultShape;
+  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return {};
+
+  // Compose the final broadcasted type.
+  if (resultCompositeKind == StandardTypes::Vector)
+    return VectorType::get(resultShape, scalarType);
+  if (resultCompositeKind == StandardTypes::RankedTensor)
+    return RankedTensorType::get(resultShape, scalarType);
+  return scalarType;
+}
+
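+// Illustrative results for getBroadcastedType: broadcasting tensor<1x4xf32>
+// with tensor<3x1xf32> gives tensor<3x4xf32>; vector<2xf32> with f32 gives
+// vector<2xf32>; and f32 with f32 gives plain f32.
+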
+/// Returns true if the given list of types contains both vector types and
+/// tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
+}
+
+/// Returns true if the given shapes have the same rank and each pair of
+/// dimensions is compatible, i.e. equal or at least one of them dynamic (-1).
+/// E.g., {2, -1} and {2, 3} are compatible, while {2, 3} and {4, 3} are not.
+static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
+                                ArrayRef<int64_t> shape2) {
+  auto isCompatible = [](int64_t dim1, int64_t dim2) {
+    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
+  };
+  if (shape1.size() != shape2.size())
+    return false;
+  for (const auto &p : llvm::zip(shape1, shape2))
+    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+      return false;
+  return true;
+}
+
+LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
+  assert(op->getNumOperands() == 2 &&
+         "only support broadcast check on two operands");
+  assert(op->getNumResults() == 1 &&
+         "only support broadcast check on one result");
+
+  auto type1 = op->getOperand(0)->getType();
+  auto type2 = op->getOperand(1)->getType();
+  auto retType = op->getResult(0)->getType();
+
+  // We forbid broadcasting vector and tensor.
+  if (hasBothVectorAndTensorType({type1, type2, retType}))
+    return op->emitError("cannot broadcast vector with tensor");
+
+  // An unranked result type is compatible with any broadcasted shape.
+  if (retType.isa<UnrankedTensorType>())
+    return success();
+
+  bool isUnranked1 = type1.isa<UnrankedTensorType>();
+  bool isUnranked2 = type2.isa<UnrankedTensorType>();
+
+  // If both operands are unranked, then all result shapes are possible.
+  if (isUnranked1 && isUnranked2)
+    return success();
+
+  // If one of the operands is unranked, then the known dimensions in the
+  // result should be compatible with the other, ranked operand.
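+  // E.g., with operands tensor<*xf32> and tensor<2x3xf32>, a result type of
+  // tensor<5x2x3xf32> is accepted (its trailing dimensions {2, 3} match),
+  // while tensor<4x3xf32> is rejected.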
+  if (isUnranked1 || isUnranked2) {
+    // The result must have rank no lower than that of the ranked operand, and
+    // its trailing dimensions must be compatible with that operand's shape.
+    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
+    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
+    if (!areCompatibleShapes(actualSuffix, shape))
+      return op->emitOpError()
+             << "result type " << retType
+             << " has shape incompatible with a ranked operand type";
+    return success();
+  }
+
+  // If both operands are ranked, then the computed broadcasted shape must be
+  // compatible with the result shape.
+  SmallVector<int64_t, 4> resultShape;
+  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+  if (!areCompatibleShapes(resultShape, getShape(retType)))
+    return op->emitOpError() << "result type " << retType
+                             << " does not have shape compatible with the one "
+                                "computed from the operand types";
+
+  return success();
+}
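+
+// For illustration, assuming a hypothetical "test.add" op that attaches this
+// trait, verification accepts
+//
+//   %0 = "test.add"(%lhs, %rhs)
+//          : (tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32>
+//
+// and rejects a result type such as tensor<2x4xf32>, whose shape is not
+// compatible with the broadcasted shape {3, 4}.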