summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/QuantOps/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-01 15:46:28 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-01 15:47:03 -0700
commite94a8bfca8f9dadda5d8a548c210ae39d97a45c1 (patch)
treef325c0ecfbd69e1bdddd766ac8f3d963b6e383c3 /mlir/lib/Dialect/QuantOps/IR
parent2ba4d802e030b51e78b7d29238ccc552ea19d1c4 (diff)
downloadbcm5719-llvm-e94a8bfca8f9dadda5d8a548c210ae39d97a45c1.tar.gz
bcm5719-llvm-e94a8bfca8f9dadda5d8a548c210ae39d97a45c1.zip
Refactor QuantOps TypeParser to use the DialectAsmParser methods directly.
This greatly simplifies the implementation and removes custom parser functionality. The necessary methods are added to the DialectAsmParser. PiperOrigin-RevId: 278015983
Diffstat (limited to 'mlir/lib/Dialect/QuantOps/IR')
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp642
1 files changed, 143 insertions, 499 deletions
diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
index 26212f69c3c..2bdde1f94f8 100644
--- a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -28,289 +28,76 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
-namespace mlir {
-namespace quant {
-
-/// Print a floating point value in a way that the parser will be able to
-/// round-trip losslessly.
-static void printStabilizedFloat(const APFloat &apValue, raw_ostream &os) {
- // We would like to output the FP constant value in exponential notation,
- // but we cannot do this if doing so will lose precision. Check here to
- // make sure that we only output it in exponential format if we can parse
- // the value back and get the same value.
- bool isInf = apValue.isInfinity();
- bool isNaN = apValue.isNaN();
- if (!isInf && !isNaN) {
- SmallString<128> strValue;
- apValue.toString(strValue, 6, 0, false);
-
- // Check to make sure that the stringized number is not some string like
- // "Inf" or NaN, that atof will accept, but the lexer will not. Check
- // that the string matches the "[-+]?[0-9]" regex.
- assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
- ((strValue[0] == '-' || strValue[0] == '+') &&
- (strValue[1] >= '0' && strValue[1] <= '9'))) &&
- "[-+]?[0-9] regex does not match!");
- // Reparse stringized version!
- if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
- os << strValue;
- return;
- }
- }
+using namespace mlir;
+using namespace quant;
- SmallVector<char, 16> str;
- apValue.toString(str);
- os << str;
-}
-
-namespace {
-
-enum class TokenKind {
- error,
- eof,
- l_brace,
- r_brace,
- l_angle,
- r_angle,
- colon,
- comma,
- alpha_ident,
- integer_literal,
- float_literal,
-};
-
-struct Token {
- TokenKind kind;
- StringRef spelling;
-};
-
-class Lexer {
-public:
- Lexer(StringRef source) : curBuffer(source), curPtr(curBuffer.begin()) {}
-
- Token lexToken();
-
-private:
- Token formToken(TokenKind kind, const char *tokStart) {
- return Token{kind, StringRef(tokStart, curPtr - tokStart)};
- }
-
- Token emitError(const char *loc, const Twine &message) {
- return formToken(TokenKind::error, loc);
- }
-
- bool isEnd() const { return curPtr == curBuffer.end(); }
-
- // Lexer implementation methods
- Token lexalpha_ident(const char *tokStart);
- Token lexNumber(const char *tokStart);
-
- StringRef curBuffer;
- const char *curPtr;
-};
-
-} // namespace
-
-Token Lexer::lexToken() {
- // Ignore whitespace.
- while (!isEnd()) {
- switch (*curPtr) {
- case ' ':
- case '\t':
- case '\n':
- case '\r':
- ++curPtr;
- continue;
- default:
- break;
- }
- break;
- }
+static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+ auto typeLoc = parser.getCurrentLocation();
+ IntegerType type;
- if (isEnd()) {
- return Token{TokenKind::eof, ""};
- }
+ // Parse storage type (alpha_ident, integer_literal).
+ StringRef identifier;
+ unsigned storageTypeWidth = 0;
+ if (failed(parser.parseOptionalKeyword(&identifier))) {
+ // If we didn't parse a keyword, this must be a signed type.
+ if (parser.parseType(type))
+ return nullptr;
+ isSigned = true;
+ storageTypeWidth = type.getWidth();
- const char *tokStart = curPtr;
- switch (*curPtr++) {
- default:
- if (isalpha(*tokStart)) {
- return lexalpha_ident(tokStart);
+ // Otherwise, this must be an unsigned integer (`u` integer-literal).
+ } else {
+ if (!identifier.consume_front("u")) {
+ parser.emitError(typeLoc, "illegal storage type prefix");
+ return nullptr;
}
- if (isdigit(*tokStart)) {
- return lexNumber(tokStart);
+ if (identifier.getAsInteger(10, storageTypeWidth)) {
+ parser.emitError(typeLoc, "expected storage type width");
+ return nullptr;
}
-
- return emitError(tokStart, "unexpected character");
-
- case '<':
- return formToken(TokenKind::l_angle, tokStart);
- case '>':
- return formToken(TokenKind::r_angle, tokStart);
- case '{':
- return formToken(TokenKind::l_brace, tokStart);
- case '}':
- return formToken(TokenKind::r_brace, tokStart);
- case ':':
- return formToken(TokenKind::colon, tokStart);
- case ',':
- return formToken(TokenKind::comma, tokStart);
- case '-':
- return lexNumber(tokStart);
- case '+':
- return lexNumber(tokStart);
- }
-}
-
-/// Lex a bare alpha identifier. Since this DSL often contains identifiers with
-/// trailing numeric components, this only matches alphas. It is up to the
-/// parser to handle identifiers that can be mixed alphanum.
-///
-/// alpha-ident ::= (letter)(letter)*
-Token Lexer::lexalpha_ident(const char *tokStart) {
- while (!isEnd() && isalpha(*curPtr)) {
- ++curPtr;
- }
- return formToken(TokenKind::alpha_ident, tokStart);
-}
-
-/// Lex a number.
-///
-/// integer-literal ::= [-+]?digit+
-/// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
-Token Lexer::lexNumber(const char *tokStart) {
- // Leading '+', '-' or digit has already been consumed.
- while (!isEnd() && isdigit(*curPtr)) {
- ++curPtr;
- }
- // If not a decimal point, treat as integer.
- if (isEnd() || *curPtr != '.') {
- return formToken(TokenKind::integer_literal, tokStart);
- }
- ++curPtr;
-
- // Skip over [0-9]*([eE][-+]?[0-9]+)?
- // Leading digits.
- while (!isEnd() && isdigit(*curPtr)) {
- ++curPtr;
+ isSigned = false;
+ type = parser.getBuilder().getIntegerType(storageTypeWidth);
}
- // [eE][-+]?[0-9]+
- if (!isEnd() && (*curPtr == 'e' || *curPtr == 'E')) {
- auto remaining = curBuffer.end() - curPtr;
- if (remaining > 2 && isdigit(curPtr[1])) {
- // Lookahead 2 for digit.
- curPtr += 2;
- while (!isEnd() && isdigit(*curPtr)) {
- ++curPtr;
- }
- } else if (remaining > 3 && (curPtr[1] == '-' || curPtr[1] == '+') &&
- isdigit(curPtr[2])) {
- // Lookahead 3 for [+-] digit.
- curPtr += 3;
- while (!isEnd() && isdigit(*curPtr)) {
- ++curPtr;
- }
- }
- }
- return formToken(TokenKind::float_literal, tokStart);
-} // end namespace
-
-// --- TypeParser ---
-namespace {
-
-class TypeParser {
-public:
- TypeParser(StringRef source, MLIRContext *context, Location location)
- : context(context), location(location), lexer(source),
- curToken(lexer.lexToken()) {}
-
- /// Attempts to parse the source as a type, returning the unknown
- /// type on error.
- Type parseType();
-
-private:
- /// Unconditionally consumes the current token.
- void consumeToken() {
- assert(curToken.kind != TokenKind::eof &&
- "should not advance past EOF or errors");
- curToken = lexer.lexToken();
+ if (storageTypeWidth == 0 ||
+ storageTypeWidth > QuantizedType::MaxStorageBits) {
+ parser.emitError(typeLoc, "illegal storage type size: ")
+ << storageTypeWidth;
+ return nullptr;
}
- /// Unconditionally consumes the current token, asserting that it is of the
- /// specified kind.
- void consumeToken(TokenKind kind) {
- assert(curToken.kind == kind && "consumed an unexpected token");
- consumeToken();
- }
+ return type;
+}
- /// Conditionally consumes a token if of the specified kind.
- /// Returns true if consumed.
- bool consumeIf(TokenKind kind) {
- if (curToken.kind == kind) {
- consumeToken();
- return true;
- }
- return false;
+static ParseResult parseStorageRange(DialectAsmParser &parser,
+ IntegerType storageType, bool isSigned,
+ int64_t &storageTypeMin,
+ int64_t &storageTypeMax) {
+ int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
+ isSigned, storageType.getWidth());
+ int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
+ isSigned, storageType.getWidth());
+ if (failed(parser.parseOptionalLess())) {
+ storageTypeMin = defaultIntegerMin;
+ storageTypeMax = defaultIntegerMax;
+ return success();
}
- /// Emits an error at the current location with a message.
- void emitError(const Twine &message) {
- // TODO: All errors show up at the beginning of the extended type location.
- // Figure out how to make this location relative to where the error occurred
- // in this instance.
- mlir::emitError(location, message);
+ // Explicit storage min and storage max.
+ llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
+ if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
+ parser.getCurrentLocation(&maxLoc) ||
+ parser.parseInteger(storageTypeMax) || parser.parseGreater())
+ return failure();
+ if (storageTypeMin < defaultIntegerMin) {
+ return parser.emitError(minLoc, "illegal storage type minimum: ")
+ << storageTypeMin;
}
-
- // Parsers.
- Type parseAnyType();
- Type parseUniformType();
- IntegerType parseStorageType(bool &isSigned);
- bool parseStorageRange(IntegerType storageType, bool isSigned,
- int64_t &storageTypeMin, int64_t &storageTypeMax);
- FloatType parseExpressedType();
- bool parseQuantParams(double &scale, int64_t &zeroPoint);
-
- MLIRContext *context;
- Location location;
- Lexer lexer;
-
- // The next token that has not yet been consumed.
- Token curToken;
-};
-
-} // namespace
-
-Type TypeParser::parseType() {
- // All types start with an identifier that we switch on.
- if (curToken.kind == TokenKind::alpha_ident) {
- StringRef typeNameSpelling = curToken.spelling;
- consumeToken();
-
- Type result;
- if (typeNameSpelling == "uniform") {
- result = parseUniformType();
- if (!result) {
- return nullptr;
- }
- } else if (typeNameSpelling == "any") {
- result = parseAnyType();
- if (!result) {
- return nullptr;
- }
- } else {
- return (emitError("unknown quantized type " + typeNameSpelling), nullptr);
- }
-
- // Make sure the entire input was consumed.
- if (curToken.kind != TokenKind::eof) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
- }
-
- return result;
- } else {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (storageTypeMax > defaultIntegerMax) {
+ return parser.emitError(maxLoc, "illegal storage type maximum: ")
+ << storageTypeMax;
}
+ return success();
}
/// Parses a UniformQuantizedType.
@@ -320,7 +107,7 @@ Type TypeParser::parseType() {
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
-Type TypeParser::parseAnyType() {
+static Type parseAnyType(DialectAsmParser &parser, Location loc) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@@ -328,13 +115,12 @@ Type TypeParser::parseAnyType() {
int64_t storageTypeMax;
// Type specification.
- if (!consumeIf(TokenKind::l_angle)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
- }
+ if (parser.parseLess())
+ return nullptr;
// Storage type.
bool isSigned = false;
- storageType = parseStorageType(isSigned);
+ storageType = parseStorageType(parser, isSigned);
if (!storageType) {
return nullptr;
}
@@ -343,25 +129,41 @@ Type TypeParser::parseAnyType() {
}
// Storage type range.
- if (parseStorageRange(storageType, isSigned, storageTypeMin,
+ if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
storageTypeMax)) {
return nullptr;
}
// Optional expressed type.
- if (consumeIf(TokenKind::colon)) {
- expressedType = parseExpressedType();
- if (!expressedType) {
+ if (succeeded(parser.parseOptionalColon())) {
+ if (parser.parseType(expressedType)) {
return nullptr;
}
}
- if (!consumeIf(TokenKind::r_angle)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseGreater()) {
+ return nullptr;
}
return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
- storageTypeMin, storageTypeMax, location);
+ storageTypeMin, storageTypeMax, loc);
+}
+
+static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
+ int64_t &zeroPoint) {
+ // scale[:zeroPoint]?
+ // scale.
+ if (parser.parseFloat(scale))
+ return failure();
+
+ // zero point.
+ zeroPoint = 0;
+ if (failed(parser.parseOptionalColon())) {
+ // Default zero point.
+ return success();
+ }
+
+ return parser.parseInteger(zeroPoint);
}
/// Parses a UniformQuantizedType.
@@ -379,7 +181,7 @@ Type TypeParser::parseAnyType() {
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= float-literal `:` integer-literal
/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
-Type TypeParser::parseUniformType() {
+static Type parseUniformType(DialectAsmParser &parser, Location loc) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@@ -391,13 +193,13 @@ Type TypeParser::parseUniformType() {
SmallVector<int64_t, 1> zeroPoints;
// Type specification.
- if (!consumeIf(TokenKind::l_angle)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseLess()) {
+ return nullptr;
}
// Storage type.
bool isSigned = false;
- storageType = parseStorageType(isSigned);
+ storageType = parseStorageType(parser, isSigned);
if (!storageType) {
return nullptr;
}
@@ -406,68 +208,60 @@ Type TypeParser::parseUniformType() {
}
// Storage type range.
- if (parseStorageRange(storageType, isSigned, storageTypeMin,
+ if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
storageTypeMax)) {
return nullptr;
}
// Expressed type.
- if (!consumeIf(TokenKind::colon)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
- }
- expressedType = parseExpressedType();
- if (!expressedType) {
+ if (parser.parseColon() || parser.parseType(expressedType)) {
return nullptr;
}
// Optionally parse quantized dimension for per-axis quantization.
- if (consumeIf(TokenKind::colon)) {
- if (curToken.kind != TokenKind::integer_literal) {
- return (emitError("expected quantized dimension"), nullptr);
- }
- if (curToken.spelling.getAsInteger(10, quantizedDimension)) {
- return (emitError("illegal quantized dimension: " + curToken.spelling),
- nullptr);
- }
- consumeToken(TokenKind::integer_literal);
+ if (succeeded(parser.parseOptionalColon())) {
+ if (parser.parseInteger(quantizedDimension))
+ return nullptr;
isPerAxis = true;
}
// Comma leading into range_spec.
- if (!consumeIf(TokenKind::comma)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseComma()) {
+ return nullptr;
}
// Parameter specification.
// For per-axis, ranges are in a {} delimitted list.
if (isPerAxis) {
- if (!consumeIf(TokenKind::l_brace)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseLBrace()) {
+ return nullptr;
}
}
// Parse scales/zeroPoints.
+ llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
do {
scales.resize(scales.size() + 1);
zeroPoints.resize(zeroPoints.size() + 1);
- if (parseQuantParams(scales.back(), zeroPoints.back())) {
+ if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
return nullptr;
}
- } while (isPerAxis && consumeIf(TokenKind::comma));
+ } while (isPerAxis && succeeded(parser.parseOptionalComma()));
if (isPerAxis) {
- if (!consumeIf(TokenKind::r_brace)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseRBrace()) {
+ return nullptr;
}
}
- if (!consumeIf(TokenKind::r_angle)) {
- return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ if (parser.parseGreater()) {
+ return nullptr;
}
if (!isPerAxis && scales.size() > 1) {
- return (emitError("multiple scales/zeroPoints provided, but "
- "quantizedDimension wasn't specified"),
+ return (parser.emitError(scaleZPLoc,
+ "multiple scales/zeroPoints provided, but "
+ "quantizedDimension wasn't specified"),
nullptr);
}
@@ -476,160 +270,34 @@ Type TypeParser::parseUniformType() {
ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
return UniformQuantizedPerAxisType::getChecked(
typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
- quantizedDimension, storageTypeMin, storageTypeMax, location);
- }
-
- return UniformQuantizedType::getChecked(
- typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
- storageTypeMin, storageTypeMax, location);
-}
-
-IntegerType TypeParser::parseStorageType(bool &isSigned) {
- // Parse storage type (alpha_ident, integer_literal).
- StringRef storageTypePrefix = curToken.spelling;
- unsigned storageTypeWidth;
- if (curToken.kind != TokenKind::alpha_ident) {
- return (emitError("expected storage type prefix"), nullptr);
+ quantizedDimension, storageTypeMin, storageTypeMax, loc);
}
- consumeToken();
- if (curToken.kind != TokenKind::integer_literal) {
- return (emitError("expected storage type width"), nullptr);
- }
- if (curToken.spelling.getAsInteger(10, storageTypeWidth) ||
- storageTypeWidth == 0 ||
- storageTypeWidth > QuantizedType::MaxStorageBits) {
- return (emitError("illegal storage type size: " + Twine(curToken.spelling)),
- nullptr);
- }
- consumeToken();
- if (storageTypePrefix == "i") {
- isSigned = true;
- return IntegerType::get(storageTypeWidth, context);
- } else if (storageTypePrefix == "u") {
- isSigned = false;
- return IntegerType::get(storageTypeWidth, context);
- } else {
- return (
- emitError("illegal storage type prefix: " + Twine(storageTypePrefix)),
- nullptr);
- }
-}
-
-bool TypeParser::parseStorageRange(IntegerType storageType, bool isSigned,
- int64_t &storageTypeMin,
- int64_t &storageTypeMax) {
-
- int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
- isSigned, storageType.getWidth());
- int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
- isSigned, storageType.getWidth());
- if (consumeIf(TokenKind::l_angle)) {
- // Explicit storage min and storage max.
- if (curToken.kind != TokenKind::integer_literal) {
- return (emitError("expected storage type minimum"), true);
- }
- if (curToken.spelling.getAsInteger(10, storageTypeMin) ||
- storageTypeMin < defaultIntegerMin) {
- return (emitError("illegal storage type minimum: " + curToken.spelling),
- true);
- }
- consumeToken(TokenKind::integer_literal);
-
- if (!consumeIf(TokenKind::colon)) {
- return (emitError("unrecognized token: " + curToken.spelling), true);
- }
-
- if (curToken.kind != TokenKind::integer_literal) {
- return (emitError("expected storage type maximum"), true);
- }
- if (curToken.spelling.getAsInteger(10, storageTypeMax) ||
- storageTypeMax > defaultIntegerMax) {
- return (emitError("illegal storage type maximum: " + curToken.spelling),
- true);
- }
- consumeToken(TokenKind::integer_literal);
-
- if (!consumeIf(TokenKind::r_angle)) {
- return (emitError("unrecognized token: " + curToken.spelling), true);
- }
- } else {
- storageTypeMin = defaultIntegerMin;
- storageTypeMax = defaultIntegerMax;
- }
-
- return false;
-}
-
-FloatType TypeParser::parseExpressedType() {
- // Expect an alpha_ident followed by integer literal that we concat back
- // together.
- StringRef prefix = curToken.spelling;
- if (!consumeIf(TokenKind::alpha_ident)) {
- return (emitError("expected expressed type"), nullptr);
- }
- StringRef suffix = curToken.spelling;
- if (!consumeIf(TokenKind::integer_literal)) {
- return (emitError("expected expressed type"), nullptr);
- }
-
- SmallVector<char, 4> holder;
- StringRef typeName = (Twine(prefix) + Twine(suffix)).toStringRef(holder);
- if (typeName == "f32")
- return FloatType::getF32(context);
- if (typeName == "f16")
- return FloatType::getF16(context);
- if (typeName == "bf16")
- return FloatType::getBF16(context);
- if (typeName == "f64")
- return FloatType::getF64(context);
-
- return (emitError("unrecognized expressed type: " + typeName), nullptr);
-}
-
-bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
- // scale[:zeroPoint]?
- // scale.
- StringRef scaleSpelling = curToken.spelling;
- if (!consumeIf(TokenKind::float_literal) ||
- scaleSpelling.getAsDouble(scale)) {
- return (
- emitError("expected valid uniform scale. got: " + Twine(scaleSpelling)),
- true);
- }
-
- // zero point.
- zeroPoint = 0;
- if (!consumeIf(TokenKind::colon)) {
- // Default zero point.
- return false;
- }
- StringRef zeroPointSpelling = curToken.spelling;
- if (!consumeIf(TokenKind::integer_literal) ||
- zeroPointSpelling.getAsInteger(10, zeroPoint)) {
- return (emitError("expected integer uniform zero point. got: " +
- Twine(zeroPointSpelling)),
- true);
- }
-
- return false;
+ return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType,
+ scales.front(), zeroPoints.front(),
+ storageTypeMin, storageTypeMax, loc);
}
/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
- TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
- Type parsedType = typeParser.parseType();
- if (parsedType == nullptr) {
- // Error.
- // TODO(laurenzo): Do something?
- return parsedType;
- }
- return parsedType;
+ // All types start with an identifier that we switch on.
+ StringRef typeNameSpelling;
+ if (failed(parser.parseKeyword(&typeNameSpelling)))
+ return nullptr;
+
+ if (typeNameSpelling == "uniform")
+ return parseUniformType(parser, loc);
+ if (typeNameSpelling == "any")
+ return parseAnyType(parser, loc);
+
+ parser.emitError(parser.getNameLoc(),
+ "unknown quantized type " + typeNameSpelling);
+ return nullptr;
}
-static void printStorageType(QuantizedType type, raw_ostream &out) {
+static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
unsigned storageWidth = type.getStorageTypeIntegralWidth();
bool isSigned = type.isSigned();
@@ -651,49 +319,31 @@ static void printStorageType(QuantizedType type, raw_ostream &out) {
}
}
-static void printExpressedType(QuantizedType type, raw_ostream &out) {
- // repr type
- Type expressedType = type.getExpressedType();
- if (expressedType.isF32()) {
- out << "f32";
- } else if (expressedType.isF64()) {
- out << "f64";
- } else if (expressedType.isF16()) {
- out << "f16";
- } else if (expressedType.isBF16()) {
- out << "bf16";
- } else {
- out << "unknown";
- }
-}
-
static void printQuantParams(double scale, int64_t zeroPoint,
- raw_ostream &out) {
- printStabilizedFloat(APFloat(scale), out);
+ DialectAsmPrinter &out) {
+ out << scale;
if (zeroPoint != 0) {
out << ":" << zeroPoint;
}
}
/// Helper that prints a UniformQuantizedType.
-static void printAnyQuantizedType(AnyQuantizedType type, raw_ostream &out) {
+static void printAnyQuantizedType(AnyQuantizedType type,
+ DialectAsmPrinter &out) {
out << "any<";
printStorageType(type, out);
- if (type.getExpressedType()) {
- out << ":";
- printExpressedType(type, out);
+ if (Type expressedType = type.getExpressedType()) {
+ out << ":" << expressedType;
}
out << ">";
}
/// Helper that prints a UniformQuantizedType.
static void printUniformQuantizedType(UniformQuantizedType type,
- raw_ostream &out) {
+ DialectAsmPrinter &out) {
out << "uniform<";
printStorageType(type, out);
- out << ":";
- printExpressedType(type, out);
- out << ", ";
+ out << ":" << type.getExpressedType() << ", ";
// scheme specific parameters
printQuantParams(type.getScale(), type.getZeroPoint(), out);
@@ -702,12 +352,10 @@ static void printUniformQuantizedType(UniformQuantizedType type,
/// Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
- raw_ostream &out) {
+ DialectAsmPrinter &out) {
out << "uniform<";
printStorageType(type, out);
- out << ":";
- printExpressedType(type, out);
- out << ":";
+ out << ":" << type.getExpressedType() << ":";
out << type.getQuantizedDimension();
out << ", ";
@@ -715,12 +363,12 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
ArrayRef<double> scales = type.getScales();
ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
out << "{";
- for (unsigned i = 0; i < scales.size(); ++i) {
- printQuantParams(scales[i], zeroPoints[i], out);
- if (i != scales.size() - 1) {
- out << ",";
- }
- }
+ interleave(
+ llvm::seq<size_t>(0, scales.size()), out,
+ [&](size_t index) {
+ printQuantParams(scales[index], zeroPoints[index], out);
+ },
+ ",");
out << "}>";
}
@@ -730,18 +378,14 @@ void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
default:
llvm_unreachable("Unhandled quantized type");
case QuantizationTypes::Any:
- printAnyQuantizedType(type.cast<AnyQuantizedType>(), os.getStream());
+ printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
break;
case QuantizationTypes::UniformQuantized:
- printUniformQuantizedType(type.cast<UniformQuantizedType>(),
- os.getStream());
+ printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
break;
case QuantizationTypes::UniformQuantizedPerAxis:
printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
- os.getStream());
+ os);
break;
}
}
-
-} // namespace quant
-} // namespace mlir
OpenPOWER on IntegriCloud