summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp39
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp34
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp116
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp70
-rw-r--r--mlir/lib/Parser/Parser.cpp10
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp3
6 files changed, 263 insertions, 9 deletions
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 4935d2da3fb..9208ce8ab6d 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -405,13 +405,37 @@ public:
PatternRewriter &rewriter) const override {
// Expand affine map from 'affineLoadOp'.
SmallVector<Value *, 8> indices(op.getMapOperands());
- auto maybeExpandedMap =
+ auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
- if (!maybeExpandedMap)
+ if (!resultOperands)
return matchFailure();
// Build std.load memref[expandedMap.results].
- rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *maybeExpandedMap);
+ rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.prefetch' operation to its operands, and
+// feed the results to a newly created 'std.prefetch' operation (which replaces
+// the original 'affine.prefetch').
+class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
+public:
+ using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affinePrefetchOp'.
+ SmallVector<Value *, 8> indices(op.getMapOperands());
+ auto resultOperands =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!resultOperands)
+ return matchFailure();
+
+ // Build std.prefetch memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<PrefetchOp>(
+ op, op.memref(), *resultOperands, op.isWrite(),
+ op.localityHint().getZExtValue(), op.isDataCache());
return matchSuccess();
}
};
@@ -506,11 +530,10 @@ public:
void mlir::populateAffineToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns
- .insert<AffineApplyLowering, AffineDmaStartLowering,
- AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
- AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
- ctx);
+ patterns.insert<
+ AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering,
+ AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering,
+ AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx);
}
namespace {
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 5bb18458725..897135a8bbb 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1462,6 +1462,39 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
}
};
+// The prefetch operation is lowered in a way similar to the load operation
+// except that the llvm.prefetch operation is used for replacement.
+struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
+ using Base::Base;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto prefetchOp = cast<PrefetchOp>(op);
+ OperandAdaptor<PrefetchOp> transformed(operands);
+ auto type = prefetchOp.getMemRefType();
+
+ Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
+
+ // Replace with llvm.prefetch.
+ auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32));
+ auto isWrite = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
+ auto localityHint = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue()));
+ auto isData = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
+
+ rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
+ localityHint, isData);
+ return matchSuccess();
+ }
+};
+
// The lowering of index_cast becomes an integer conversion since index becomes
// an integer. If the bit width of the source and target integer types is the
// same, just erase the cast. If the target type is wider, sign-extend the
@@ -2041,6 +2074,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
MulFOpLowering,
MulIOpLowering,
OrOpLowering,
+ PrefetchOpLowering,
RemFOpLowering,
RemISOpLowering,
RemIUOpLowering,
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 8c8c67d1595..6768aa59de0 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -763,6 +763,7 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
+ std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
std::is_same<AffineOpTy, AffineApplyOp>::value,
"affine load/store/apply op expected");
@@ -790,6 +791,15 @@ void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
mapOperands);
}
template <>
+void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
+ PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
+ ArrayRef<Value *> mapOperands) const {
+ rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
+ prefetch, prefetch.memref(), map, mapOperands,
+ prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
+ prefetch.isDataCache());
+}
+template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
ArrayRef<Value *> mapOperands) const {
@@ -2003,6 +2013,112 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
+// AffinePrefetchOp
+//===----------------------------------------------------------------------===//
+
+//
+// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
+//
+static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ auto indexTy = builder.getIndexType();
+
+ MemRefType type;
+ OpAsmParser::OperandType memrefInfo;
+ IntegerAttr hintInfo;
+ auto i32Type = parser.getBuilder().getIntegerType(32);
+ StringRef readOrWrite, cacheType;
+
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ if (parser.parseOperand(memrefInfo) ||
+ parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+ AffinePrefetchOp::getMapAttrName(),
+ result.attributes) ||
+ parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
+ parser.parseComma() || parser.parseKeyword("locality") ||
+ parser.parseLess() ||
+ parser.parseAttribute(hintInfo, i32Type,
+ AffinePrefetchOp::getLocalityHintAttrName(),
+ result.attributes) ||
+ parser.parseGreater() || parser.parseComma() ||
+ parser.parseKeyword(&cacheType) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(memrefInfo, type, result.operands) ||
+ parser.resolveOperands(mapOperands, indexTy, result.operands))
+ return failure();
+
+ if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+ return parser.emitError(parser.getNameLoc(),
+ "rw specifier has to be 'read' or 'write'");
+ result.addAttribute(
+ AffinePrefetchOp::getIsWriteAttrName(),
+ parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+
+ if (!cacheType.equals("data") && !cacheType.equals("instr"))
+ return parser.emitError(parser.getNameLoc(),
+ "cache type has to be 'data' or 'instr'");
+
+ result.addAttribute(
+ AffinePrefetchOp::getIsDataCacheAttrName(),
+ parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+
+ return success();
+}
+
+void print(OpAsmPrinter &p, AffinePrefetchOp op) {
+ p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '[';
+ AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+ if (mapAttr) {
+ SmallVector<Value *, 2> operands(op.getMapOperands());
+ p.printAffineMapOfSSAIds(mapAttr, operands);
+ }
+ p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
+ << "locality<" << op.localityHint() << ">, "
+ << (op.isDataCache() ? "data" : "instr");
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
+ op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
+ p << " : " << op.getMemRefType();
+}
+
+LogicalResult verify(AffinePrefetchOp op) {
+ auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+ if (mapAttr) {
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != op.getMemRefType().getRank())
+ return op.emitOpError("affine.prefetch affine map num results must equal"
+ " memref rank");
+ if (map.getNumInputs() + 1 != op.getNumOperands())
+ return op.emitOpError("too few operands");
+ } else {
+ if (op.getNumOperands() != 1)
+ return op.emitOpError("too few operands");
+ }
+
+ for (auto *idx : op.getMapOperands()) {
+ if (!isValidAffineIndexOperand(idx))
+ return op.emitOpError("index must be a dimension or symbol identifier");
+ }
+ return success();
+}
+
+void AffinePrefetchOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // prefetch(memrefcast) -> prefetch
+ results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
+}
+
+LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// prefetch(memrefcast) -> prefetch
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index b2b3ba5f509..d0fd1855f96 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1789,6 +1789,76 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
+// PrefetchOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, PrefetchOp op) {
+ p << PrefetchOp::getOperationName() << " " << *op.memref() << '[';
+ p.printOperands(op.indices());
+ p << ']' << ", " << (op.isWrite() ? "write" : "read");
+ p << ", locality<" << op.localityHint();
+ p << ">, " << (op.isDataCache() ? "data" : "instr");
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
+ p << " : " << op.getMemRefType();
+}
+
+static ParseResult parsePrefetchOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType memrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ IntegerAttr localityHint;
+ MemRefType type;
+ StringRef readOrWrite, cacheType;
+
+ auto indexTy = parser.getBuilder().getIndexType();
+ auto i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseOperand(memrefInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
+ parser.parseComma() || parser.parseKeyword("locality") ||
+ parser.parseLess() ||
+ parser.parseAttribute(localityHint, i32Type, "localityHint",
+ result.attributes) ||
+ parser.parseGreater() || parser.parseComma() ||
+ parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
+ parser.resolveOperand(memrefInfo, type, result.operands) ||
+ parser.resolveOperands(indexInfo, indexTy, result.operands))
+ return failure();
+
+ if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+ return parser.emitError(parser.getNameLoc(),
+ "rw specifier has to be 'read' or 'write'");
+ result.addAttribute(
+ PrefetchOp::getIsWriteAttrName(),
+ parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+
+ if (!cacheType.equals("data") && !cacheType.equals("instr"))
+ return parser.emitError(parser.getNameLoc(),
+ "cache type has to be 'data' or 'instr'");
+
+ result.addAttribute(
+ PrefetchOp::getIsDataCacheAttrName(),
+ parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+
+ return success();
+}
+
+static LogicalResult verify(PrefetchOp op) {
+ if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
+ return op.emitOpError("too few indices");
+
+ return success();
+}
+
+LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // prefetch(memrefcast) -> prefetch
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 1a02745e90c..498a64d70c2 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3883,6 +3883,16 @@ public:
return parser.parseToken(Token::equal, "expected '='");
}
+ /// Parse a '<' token.
+ ParseResult parseLess() override {
+ return parser.parseToken(Token::less, "expected '<'");
+ }
+
+ /// Parse a '>' token.
+ ParseResult parseGreater() override {
+ return parser.parseToken(Token::greater, "expected '>'");
+ }
+
/// Parse a `(` token.
ParseResult parseLParen() override {
return parser.parseToken(Token::l_paren, "expected '('");
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 79a6d7a6902..57a92531163 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -49,7 +49,8 @@ static bool isMemRefDereferencingOp(Operation &op) {
/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
return TypeSwitch<Operation *, NamedAttribute>(op)
- .Case<AffineDmaStartOp, AffineLoadOp, AffineStoreOp, AffineDmaWaitOp>(
+ .Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp,
+ AffineDmaWaitOp>(
[=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}
OpenPOWER on IntegriCloud