summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Parser
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-10-03 12:33:47 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-03 12:34:36 -0700
commit218f0e611a624516da2043e9495bc6c0e2bcd8a5 (patch)
tree734921e9a9dcbf32cc729bb020e7ca2291125587 /mlir/lib/Parser
parent0b93c092b620f0d70987061ad67f621b9c69925b (diff)
downloadbcm5719-llvm-218f0e611a624516da2043e9495bc6c0e2bcd8a5.tar.gz
bcm5719-llvm-218f0e611a624516da2043e9495bc6c0e2bcd8a5.zip
Add syntactic sugar for strided memref parsing.
This CL implements the last remaining bit of the [strided memref proposal](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio). The syntax is a bit more explicit than what was originally proposed and resembles: `memref<?x?xf32, offset: 0 strides: [?, 1]>` Nonnegative strides and offsets are currently supported. Future extensions will include negative strides. This also gives a concrete example of syntactic sugar for the ([RFC] Proposed Changes to MemRef and Tensor MLIR Types)[https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/-wKHANzDNTg]. The underlying implementation still uses AffineMap layout. PiperOrigin-RevId: 272717437
Diffstat (limited to 'mlir/lib/Parser')
-rw-r--r--mlir/lib/Parser/Parser.cpp107
-rw-r--r--mlir/lib/Parser/TokenKinds.def2
2 files changed, 97 insertions, 12 deletions
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 33e1c6c2851..4cb3569137f 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -38,7 +38,6 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
-#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
@@ -211,6 +210,14 @@ public:
bool allowDynamic = true);
ParseResult parseXInDimensionList();
+ /// Parse strided layout specification.
+ ParseResult parseStridedLayout(int64_t &offset,
+ SmallVectorImpl<int64_t> &strides);
+
+ // Parse a brace-delimiter list of comma-separated integers with `?` as an
+ // unknown marker.
+ ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@@ -634,6 +641,40 @@ Type Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
+/// Parse the offset and strides from a strided layout specification.
+///
+/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
+///
+ParseResult Parser::parseStridedLayout(int64_t &offset,
+ SmallVectorImpl<int64_t> &strides) {
+ // Parse offset.
+ consumeToken(Token::kw_offset);
+ if (!consumeIf(Token::colon))
+ return emitError("expected colon after `offset` keyword");
+ auto maybeOffset = getToken().getUnsignedIntegerValue();
+ bool question = getToken().is(Token::question);
+ if (!maybeOffset && !question)
+ return emitError("invalid offset");
+ offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
+ : MemRefType::kDynamicStrideOrOffset;
+ consumeToken();
+
+ if (!consumeIf(Token::comma))
+ return emitError("expected comma after offset value");
+
+ // Parse stride list.
+ if (!consumeIf(Token::kw_strides))
+ return emitError("expected `strides` keyword after offset specification");
+ if (!consumeIf(Token::colon))
+ return emitError("expected colon after `strides` keyword");
+ if (failed(parseStrideList(strides)))
+ return emitError("invalid braces-enclosed stride list");
+ if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
+ return emitError("invalid memref stride");
+
+ return success();
+}
+
/// Parse a memref type.
///
/// memref-type ::= `memref` `<` dimension-list-ranked type
@@ -675,18 +716,28 @@ Type Parser::parseMemRefType() {
consumeToken(Token::integer);
parsedMemorySpace = true;
} else {
- // Parse affine map.
if (parsedMemorySpace)
- return emitError("affine map after memory space in memref type");
- auto affineMap = parseAttribute();
- if (!affineMap)
- return failure();
-
- // Verify that the parsed attribute is an affine map.
- if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
- affineMapComposition.push_back(affineMapAttr.getValue());
- else
- return emitError("expected affine map in memref type");
+ return emitError("expected memory space to be last in memref type");
+ if (getToken().is(Token::kw_offset)) {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(parseStridedLayout(offset, strides)))
+ return failure();
+ // Construct strided affine map.
+ auto map = makeStridedLinearLayoutMap(strides, offset,
+ elementType.getContext());
+ affineMapComposition.push_back(map);
+ } else {
+ // Parse affine map.
+ auto affineMap = parseAttribute();
+ if (!affineMap)
+ return failure();
+ // Verify that the parsed attribute is an affine map.
+ if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
+ affineMapComposition.push_back(affineMapAttr.getValue());
+ else
+ return emitError("expected affine map in memref type");
+ }
}
return success();
};
@@ -935,6 +986,38 @@ ParseResult Parser::parseXInDimensionList() {
return success();
}
+// Parse a comma-separated list of dimensions, possibly empty:
+// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
+ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
+ if (!consumeIf(Token::l_square))
+ return failure();
+ // Empty list early exit.
+ if (consumeIf(Token::r_square))
+ return success();
+ while (true) {
+ if (consumeIf(Token::question)) {
+ dimensions.push_back(MemRefType::kDynamicStrideOrOffset);
+ } else {
+ // This must be an integer value.
+ int64_t val;
+ if (getToken().getSpelling().getAsInteger(10, val))
+ return emitError("invalid integer value: ") << getToken().getSpelling();
+ // Make sure it is not the one value for `?`.
+ if (ShapedType::isDynamic(val))
+ return emitError("invalid integer value: ")
+ << getToken().getSpelling()
+ << ", use `?` to specify a dynamic dimension";
+ dimensions.push_back(val);
+ consumeToken(Token::integer);
+ }
+ if (!consumeIf(Token::comma))
+ break;
+ }
+ if (!consumeIf(Token::r_square))
+ return failure();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Attribute parsing.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index 32e9b120938..19cd343274d 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -110,10 +110,12 @@ TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(none)
+TOK_KEYWORD(offset)
TOK_KEYWORD(opaque)
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
+TOK_KEYWORD(strides)
TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
TOK_KEYWORD(to)
OpenPOWER on IntegriCloud