diff options
author | River Riddle <riverriddle@google.com> | 2019-06-18 18:26:26 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2019-06-19 23:07:34 -0700 |
commit | 30bbd910565a1319bf121b0ef87031b8217cf1c2 (patch) | |
tree | 71c44332f8984385c2d6d573e69fffe60d11828c | |
parent | 18743a33ac0caf08b69f097383c176a8a653e4f5 (diff) | |
download | bcm5719-llvm-30bbd910565a1319bf121b0ef87031b8217cf1c2.tar.gz bcm5719-llvm-30bbd910565a1319bf121b0ef87031b8217cf1c2.zip |
Simplify usages of SplatElementsAttr now that it inherits from DenseElementsAttr.
PiperOrigin-RevId: 253910543
-rw-r--r-- | mlir/include/mlir/IR/Attributes.h | 19 | ||||
-rw-r--r-- | mlir/include/mlir/IR/Builders.h | 1 | ||||
-rw-r--r-- | mlir/include/mlir/IR/Matchers.h | 2 | ||||
-rw-r--r-- | mlir/include/mlir/Quantizer/Support/Statistics.h | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 38 | ||||
-rw-r--r-- | mlir/lib/IR/Attributes.cpp | 22 | ||||
-rw-r--r-- | mlir/lib/IR/Builders.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/StandardOps/Ops.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 9 | ||||
-rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 2 | ||||
-rw-r--r-- | mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp | 2 |
12 files changed, 11 insertions, 101 deletions
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 76e57e527e4..04ba52b8c65 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -820,25 +820,6 @@ public: class SplatElementsAttr : public DenseElementsAttr { public: using DenseElementsAttr::DenseElementsAttr; - using ValueType = Attribute; - - /// 'type' must be a vector or tensor with static shape. - static SplatElementsAttr get(ShapedType type, Attribute elt); - Attribute getValue() const { return getSplatValue(); } - - /// Generates a new SplatElementsAttr by mapping each int value to a new - /// underlying APInt. The new values can represent either a integer or float. - /// This ElementsAttr should contain integers. - SplatElementsAttr - mapValues(Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const; - - /// Generates a new SplatElementsAttr by mapping each float value to a new - /// underlying APInt. The new values can represent either a integer or float. - /// This ElementsAttr should contain floats. - SplatElementsAttr - mapValues(Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const; /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 80bdbc35cf4..c04ca7ad214 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,7 +113,6 @@ public: TypeAttr getTypeAttr(Type type); FunctionAttr getFunctionAttr(Function *value); FunctionAttr getFunctionAttr(StringRef value); - ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt); ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<Attribute> values); ElementsAttr getDenseIntElementsAttr(ShapedType type, diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 61796ff09ab..4ea1ce2c621 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -104,7 +104,7 @@ struct constant_int_op_binder { if (type.isa<VectorType>() || type.isa<RankedTensorType>()) { if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { return attr_value_binder<IntegerAttr>(bind_value) - .match(splatAttr.getValue()); + .match(splatAttr.getSplatValue()); } } return false; diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h index d4641d66cf2..c6f059efd79 100644 --- a/mlir/include/mlir/Quantizer/Support/Statistics.h +++ b/mlir/include/mlir/Quantizer/Support/Statistics.h @@ -73,7 +73,6 @@ public: /// DenseFPElementsAttr /// OpaqueElementsAttr (with Float based type) /// SparseElementAttr (with Float based type) -/// SplatElementsAttr class AttributeTensorStatistics : public AbstractTensorStatistics { public: AttributeTensorStatistics(Attribute attr) : attr(attr) {} diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 0c8ba3171aa..9dcc6df6bea 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -80,8 +80,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, } // Is the constant value a type expressed in a way that we support? - if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() && - !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) { + if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() && + !value.isa<SparseElementsAttr>()) { return matchFailure(); } diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index 850f1224cbb..7cfedf9412d 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -69,36 +69,6 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, /// Converts a real expressed SplatElementsAttr to a corresponding /// SplatElementsAttr containing quantized storage values assuming the given /// quantizedElementType and converter. -static SplatElementsAttr -convertSplatElementsAttr(SplatElementsAttr realSplatAttr, - QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { - // Since the splat just references a single primitive value, use the - // function for converting primitives. - // NOTE: When implementing per-channel, we will need to promote the - // splat to a dense and handle channels individually. - Type unusedPrimitiveType; - auto elementAttr = - convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType, - converter, unusedPrimitiveType); - if (!elementAttr) { - return nullptr; - } - - // Cast from an expressed-type-based type to storage-type-based type, - // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newSplatType = - quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) - .dyn_cast_or_null<ShapedType>(); - if (!newSplatType) { - return nullptr; - } - return SplatElementsAttr::get(newSplatType, elementAttr); -} - -/// Converts a real expressed SplatElementsAttr to a corresponding -/// SplatElementsAttr containing quantized storage values assuming the given -/// quantizedElementType and converter. static SparseElementsAttr convertSparseElementsAttr(SparseElementsAttr realSparseAttr, QuantizedType quantizedElementType, @@ -134,13 +104,7 @@ Attribute quantizeAttrUniform(Attribute realValue, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. - if (realValue.isa<SplatElementsAttr>()) { - // Splatted tensor or vector constant. - auto converted = convertSplatElementsAttr( - realValue.cast<SplatElementsAttr>(), quantizedElementType, converter); - outConvertedType = converted.getType(); - return converted; - } else if (realValue.isa<DenseFPElementsAttr>()) { + if (realValue.isa<DenseFPElementsAttr>()) { // Dense tensor or vector constant. auto converted = convertDenseFPElementsAttr( realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index ce33508830c..f4a6cf11bca 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -941,28 +941,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { } //===----------------------------------------------------------------------===// -// SplatElementsAttr -//===----------------------------------------------------------------------===// - -SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { - return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast<SplatElementsAttr>(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast<SplatElementsAttr>(); -} - -//===----------------------------------------------------------------------===// // NamedAttributeList //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 72eaa91211e..e2c3a55421b 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -180,10 +180,6 @@ FunctionAttr Builder::getFunctionAttr(StringRef value) { return FunctionAttr::get(value, getContext()); } -ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) { - return SplatElementsAttr::get(type, elt); -} - ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef<Attribute> values) { return DenseElementsAttr::get(type, values); @@ -255,7 +251,7 @@ Attribute Builder::getZeroAttr(Type type) { auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; - return getSplatElementsAttr(vtType, element); + return getDenseElementsAttr(vtType, element); } default: break; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 7e4ce29d146..9a4a3f26d65 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -182,11 +182,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, return {}; auto elementResult = constFoldBinaryOp<AttrElementT>( - {lhs.getValue(), rhs.getValue()}, calculate); + {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); if (!elementResult) return {}; - return SplatElementsAttr::get(lhs.getType(), elementResult); + return DenseElementsAttr::get(lhs.getType(), elementResult); } return {}; } @@ -1614,7 +1614,7 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) - return splatAggregate.getValue(); + return splatAggregate.getSplatValue(); // Otherwise, collect the constant indices into the aggregate. SmallVector<uint64_t, 8> indices; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ef9cbe82eb2..36d04a9ae6c 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -92,18 +92,11 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { auto *vectorType = cast<llvm::VectorType>(llvmType); auto *child = getLLVMConstant(vectorType->getElementType(), - splatAttr.getValue(), loc); + splatAttr.getSplatValue(), loc); return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child); } if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) { auto *vectorType = cast<llvm::VectorType>(llvmType); - if (denseAttr.isSplat()) { - auto *child = getLLVMConstant(vectorType->getElementType(), - denseAttr.getSplatValue(), loc); - return llvm::ConstantVector::getSplat(vectorType->getNumElements(), - child); - } - SmallVector<llvm::Constant *, 8> constants; uint64_t numElements = vectorType->getNumElements(); constants.reserve(numElements); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6819a4ef62a..2204c42dec1 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -388,7 +388,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) { SmallVector<NamedAttribute, 1> res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) { - auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); + auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue()); res.push_back(NamedAttribute(a.first, attr)); } else { res.push_back(a); diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp index d2b551f0296..d10623e3d1d 100644 --- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp @@ -128,7 +128,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); - auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>( + auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>( &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); Type returnedType; |