summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-06-18 18:26:26 -0700
committerMehdi Amini <joker.eph@gmail.com>2019-06-19 23:07:34 -0700
commit30bbd910565a1319bf121b0ef87031b8217cf1c2 (patch)
tree71c44332f8984385c2d6d573e69fffe60d11828c
parent18743a33ac0caf08b69f097383c176a8a653e4f5 (diff)
downloadbcm5719-llvm-30bbd910565a1319bf121b0ef87031b8217cf1c2.tar.gz
bcm5719-llvm-30bbd910565a1319bf121b0ef87031b8217cf1c2.zip
Simplify usages of SplatElementsAttr now that it inherits from DenseElementsAttr.
PiperOrigin-RevId: 253910543
-rw-r--r--mlir/include/mlir/IR/Attributes.h19
-rw-r--r--mlir/include/mlir/IR/Builders.h1
-rw-r--r--mlir/include/mlir/IR/Matchers.h2
-rw-r--r--mlir/include/mlir/Quantizer/Support/Statistics.h1
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp4
-rw-r--r--mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp38
-rw-r--r--mlir/lib/IR/Attributes.cpp22
-rw-r--r--mlir/lib/IR/Builders.cpp6
-rw-r--r--mlir/lib/StandardOps/Ops.cpp6
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp9
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp2
-rw-r--r--mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp2
12 files changed, 11 insertions, 101 deletions
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 76e57e527e4..04ba52b8c65 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -820,25 +820,6 @@ public:
class SplatElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
- using ValueType = Attribute;
-
- /// 'type' must be a vector or tensor with static shape.
- static SplatElementsAttr get(ShapedType type, Attribute elt);
- Attribute getValue() const { return getSplatValue(); }
-
- /// Generates a new SplatElementsAttr by mapping each int value to a new
- /// underlying APInt. The new values can represent either a integer or float.
- /// This ElementsAttr should contain integers.
- SplatElementsAttr
- mapValues(Type newElementType,
- llvm::function_ref<APInt(const APInt &)> mapping) const;
-
- /// Generates a new SplatElementsAttr by mapping each float value to a new
- /// underlying APInt. The new values can represent either a integer or float.
- /// This ElementsAttr should contain floats.
- SplatElementsAttr
- mapValues(Type newElementType,
- llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 80bdbc35cf4..c04ca7ad214 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -113,7 +113,6 @@ public:
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(StringRef value);
- ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 61796ff09ab..4ea1ce2c621 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -104,7 +104,7 @@ struct constant_int_op_binder {
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
- .match(splatAttr.getValue());
+ .match(splatAttr.getSplatValue());
}
}
return false;
diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h
index d4641d66cf2..c6f059efd79 100644
--- a/mlir/include/mlir/Quantizer/Support/Statistics.h
+++ b/mlir/include/mlir/Quantizer/Support/Statistics.h
@@ -73,7 +73,6 @@ public:
/// DenseFPElementsAttr
/// OpaqueElementsAttr (with Float based type)
/// SparseElementAttr (with Float based type)
-/// SplatElementsAttr
class AttributeTensorStatistics : public AbstractTensorStatistics {
public:
AttributeTensorStatistics(Attribute attr) : attr(attr) {}
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 0c8ba3171aa..9dcc6df6bea 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -80,8 +80,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
}
// Is the constant value a type expressed in a way that we support?
- if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
- !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
+ !value.isa<SparseElementsAttr>()) {
return matchFailure();
}
diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
index 850f1224cbb..7cfedf9412d 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -69,36 +69,6 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
/// Converts a real expressed SplatElementsAttr to a corresponding
/// SplatElementsAttr containing quantized storage values assuming the given
/// quantizedElementType and converter.
-static SplatElementsAttr
-convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
- QuantizedType quantizedElementType,
- const UniformQuantizedValueConverter &converter) {
- // Since the splat just references a single primitive value, use the
- // function for converting primitives.
- // NOTE: When implementing per-channel, we will need to promote the
- // splat to a dense and handle channels individually.
- Type unusedPrimitiveType;
- auto elementAttr =
- convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
- converter, unusedPrimitiveType);
- if (!elementAttr) {
- return nullptr;
- }
-
- // Cast from an expressed-type-based type to storage-type-based type,
- // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
- ShapedType newSplatType =
- quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
- .dyn_cast_or_null<ShapedType>();
- if (!newSplatType) {
- return nullptr;
- }
- return SplatElementsAttr::get(newSplatType, elementAttr);
-}
-
-/// Converts a real expressed SplatElementsAttr to a corresponding
-/// SplatElementsAttr containing quantized storage values assuming the given
-/// quantizedElementType and converter.
static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
QuantizedType quantizedElementType,
@@ -134,13 +104,7 @@ Attribute quantizeAttrUniform(Attribute realValue,
const UniformQuantizedValueConverter &converter,
Type &outConvertedType) {
// Fork to handle different variants of constants supported.
- if (realValue.isa<SplatElementsAttr>()) {
- // Splatted tensor or vector constant.
- auto converted = convertSplatElementsAttr(
- realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
- outConvertedType = converted.getType();
- return converted;
- } else if (realValue.isa<DenseFPElementsAttr>()) {
+ if (realValue.isa<DenseFPElementsAttr>()) {
// Dense tensor or vector constant.
auto converted = convertDenseFPElementsAttr(
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index ce33508830c..f4a6cf11bca 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -941,28 +941,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
}
//===----------------------------------------------------------------------===//
-// SplatElementsAttr
-//===----------------------------------------------------------------------===//
-
-SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
- return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>();
-}
-
-SplatElementsAttr SplatElementsAttr::mapValues(
- Type newElementType,
- llvm::function_ref<APInt(const APInt &)> mapping) const {
- return DenseElementsAttr::mapValues(newElementType, mapping)
- .cast<SplatElementsAttr>();
-}
-
-SplatElementsAttr SplatElementsAttr::mapValues(
- Type newElementType,
- llvm::function_ref<APInt(const APFloat &)> mapping) const {
- return DenseElementsAttr::mapValues(newElementType, mapping)
- .cast<SplatElementsAttr>();
-}
-
-//===----------------------------------------------------------------------===//
// NamedAttributeList
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 72eaa91211e..e2c3a55421b 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -180,10 +180,6 @@ FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
}
-ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
- return SplatElementsAttr::get(type, elt);
-}
-
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values) {
return DenseElementsAttr::get(type, values);
@@ -255,7 +251,7 @@ Attribute Builder::getZeroAttr(Type type) {
auto element = getZeroAttr(vtType.getElementType());
if (!element)
return {};
- return getSplatElementsAttr(vtType, element);
+ return getDenseElementsAttr(vtType, element);
}
default:
break;
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index 7e4ce29d146..9a4a3f26d65 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -182,11 +182,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
return {};
auto elementResult = constFoldBinaryOp<AttrElementT>(
- {lhs.getValue(), rhs.getValue()}, calculate);
+ {lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
if (!elementResult)
return {};
- return SplatElementsAttr::get(lhs.getType(), elementResult);
+ return DenseElementsAttr::get(lhs.getType(), elementResult);
}
return {};
}
@@ -1614,7 +1614,7 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
- return splatAggregate.getValue();
+ return splatAggregate.getSplatValue();
// Otherwise, collect the constant indices into the aggregate.
SmallVector<uint64_t, 8> indices;
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ef9cbe82eb2..36d04a9ae6c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -92,18 +92,11 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
auto *child = getLLVMConstant(vectorType->getElementType(),
- splatAttr.getValue(), loc);
+ splatAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
}
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
- if (denseAttr.isSplat()) {
- auto *child = getLLVMConstant(vectorType->getElementType(),
- denseAttr.getSplatValue(), loc);
- return llvm::ConstantVector::getSplat(vectorType->getNumElements(),
- child);
- }
-
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 6819a4ef62a..2204c42dec1 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -388,7 +388,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
- auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
+ auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue());
res.push_back(NamedAttribute(a.first, attr));
} else {
res.push_back(a);
diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp
index d2b551f0296..d10623e3d1d 100644
--- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp
@@ -128,7 +128,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
- auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>(
+ auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
&ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
Type returnedType;
OpenPOWER on IntegriCloud