diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp | 39 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 34 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 116 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 70 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 3 |
6 files changed, 263 insertions, 9 deletions
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 4935d2da3fb..9208ce8ab6d 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -405,13 +405,37 @@ public: PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. SmallVector<Value *, 8> indices(op.getMapOperands()); - auto maybeExpandedMap = + auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); - if (!maybeExpandedMap) + if (!resultOperands) return matchFailure(); // Build std.load memref[expandedMap.results]. - rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *maybeExpandedMap); + rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands); + return matchSuccess(); + } +}; + +// Apply the affine map from an 'affine.prefetch' operation to its operands, and +// feed the results to a newly created 'std.prefetch' operation (which replaces +// the original 'affine.prefetch'). +class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> { +public: + using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffinePrefetchOp op, + PatternRewriter &rewriter) const override { + // Expand affine map from 'affinePrefetchOp'. + SmallVector<Value *, 8> indices(op.getMapOperands()); + auto resultOperands = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); + if (!resultOperands) + return matchFailure(); + + // Build std.prefetch memref[expandedMap.results]. + rewriter.replaceOpWithNewOp<PrefetchOp>( + op, op.memref(), *resultOperands, op.isWrite(), + op.localityHint().getZExtValue(), op.isDataCache()); return matchSuccess(); } }; @@ -506,11 +530,10 @@ public: void mlir::populateAffineToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns - .insert<AffineApplyLowering, AffineDmaStartLowering, - AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering, - AffineForLowering, AffineIfLowering, AffineTerminatorLowering>( - ctx); + patterns.insert< + AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering, + AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering, + AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx); } namespace { diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5bb18458725..897135a8bbb 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1462,6 +1462,39 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { } }; +// The prefetch operation is lowered in a way similar to the load operation +// except that the llvm.prefetch operation is used for replacement. +struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto prefetchOp = cast<PrefetchOp>(op); + OperandAdaptor<PrefetchOp> transformed(operands); + auto type = prefetchOp.getMemRefType(); + + Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); + + // Replace with llvm.prefetch. + auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); + auto isWrite = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.isWrite())); + auto localityHint = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue())); + auto isData = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); + + rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite, + localityHint, isData); + return matchSuccess(); + } +}; + // The lowering of index_cast becomes an integer conversion since index becomes // an integer. If the bit width of the source and target integer types is the // same, just erase the cast. If the target type is wider, sign-extend the @@ -2041,6 +2074,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( MulFOpLowering, MulIOpLowering, OrOpLowering, + PrefetchOpLowering, RemFOpLowering, RemISOpLowering, RemIUOpLowering, diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 8c8c67d1595..6768aa59de0 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -763,6 +763,7 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> { PatternMatchResult matchAndRewrite(AffineOpTy affineOp, PatternRewriter &rewriter) const override { static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value || + std::is_same<AffineOpTy, AffinePrefetchOp>::value || std::is_same<AffineOpTy, AffineStoreOp>::value || std::is_same<AffineOpTy, AffineApplyOp>::value, "affine load/store/apply op expected"); @@ -790,6 +791,15 @@ void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp( mapOperands); } template <> +void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp( + PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, + ArrayRef<Value *> mapOperands) const { + rewriter.replaceOpWithNewOp<AffinePrefetchOp>( + prefetch, prefetch.memref(), map, mapOperands, + prefetch.localityHint().getZExtValue(), prefetch.isWrite(), + prefetch.isDataCache()); +} +template <> void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp( PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, ArrayRef<Value *> mapOperands) const { @@ -2003,6 +2013,112 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) { } //===----------------------------------------------------------------------===// +// AffinePrefetchOp +//===----------------------------------------------------------------------===// + +// +// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> +// +static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + auto indexTy = builder.getIndexType(); + + MemRefType type; + OpAsmParser::OperandType memrefInfo; + IntegerAttr hintInfo; + auto i32Type = parser.getBuilder().getIntegerType(32); + StringRef readOrWrite, cacheType; + + AffineMapAttr mapAttr; + SmallVector<OpAsmParser::OperandType, 1> mapOperands; + if (parser.parseOperand(memrefInfo) || + parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, + AffinePrefetchOp::getMapAttrName(), + result.attributes) || + parser.parseComma() || parser.parseKeyword(&readOrWrite) || + parser.parseComma() || parser.parseKeyword("locality") || + parser.parseLess() || + parser.parseAttribute(hintInfo, i32Type, + AffinePrefetchOp::getLocalityHintAttrName(), + result.attributes) || + parser.parseGreater() || parser.parseComma() || + parser.parseKeyword(&cacheType) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands) || + parser.resolveOperands(mapOperands, indexTy, result.operands)) + return failure(); + + if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + return parser.emitError(parser.getNameLoc(), + "rw specifier has to be 'read' or 'write'"); + result.addAttribute( + AffinePrefetchOp::getIsWriteAttrName(), + parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + + if (!cacheType.equals("data") && !cacheType.equals("instr")) + return parser.emitError(parser.getNameLoc(), + "cache type has to be 'data' or 'instr'"); + + result.addAttribute( + AffinePrefetchOp::getIsDataCacheAttrName(), + parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + + return success(); +} + +void print(OpAsmPrinter &p, AffinePrefetchOp op) { + p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '['; + AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()); + if (mapAttr) { + SmallVector<Value *, 2> operands(op.getMapOperands()); + p.printAffineMapOfSSAIds(mapAttr, operands); + } + p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " + << "locality<" << op.localityHint() << ">, " + << (op.isDataCache() ? "data" : "instr"); + p.printOptionalAttrDict( + op.getAttrs(), + /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), + op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); + p << " : " << op.getMemRefType(); +} + +LogicalResult verify(AffinePrefetchOp op) { + auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()); + if (mapAttr) { + AffineMap map = mapAttr.getValue(); + if (map.getNumResults() != op.getMemRefType().getRank()) + return op.emitOpError("affine.prefetch affine map num results must equal" + " memref rank"); + if (map.getNumInputs() + 1 != op.getNumOperands()) + return op.emitOpError("too few operands"); + } else { + if (op.getNumOperands() != 1) + return op.emitOpError("too few operands"); + } + + for (auto *idx : op.getMapOperands()) { + if (!isValidAffineIndexOperand(idx)) + return op.emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffinePrefetchOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + // prefetch(memrefcast) -> prefetch + results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context); +} + +LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// prefetch(memrefcast) -> prefetch + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index b2b3ba5f509..d0fd1855f96 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1789,6 +1789,76 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { } //===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, PrefetchOp op) { + p << PrefetchOp::getOperationName() << " " << *op.memref() << '['; + p.printOperands(op.indices()); + p << ']' << ", " << (op.isWrite() ? "write" : "read"); + p << ", locality<" << op.localityHint(); + p << ">, " << (op.isDataCache() ? "data" : "instr"); + p.printOptionalAttrDict( + op.getAttrs(), + /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); + p << " : " << op.getMemRefType(); +} + +static ParseResult parsePrefetchOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType memrefInfo; + SmallVector<OpAsmParser::OperandType, 4> indexInfo; + IntegerAttr localityHint; + MemRefType type; + StringRef readOrWrite, cacheType; + + auto indexTy = parser.getBuilder().getIndexType(); + auto i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseOperand(memrefInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseKeyword(&readOrWrite) || + parser.parseComma() || parser.parseKeyword("locality") || + parser.parseLess() || + parser.parseAttribute(localityHint, i32Type, "localityHint", + result.attributes) || + parser.parseGreater() || parser.parseComma() || + parser.parseKeyword(&cacheType) || parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands)) + return failure(); + + if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + return parser.emitError(parser.getNameLoc(), + "rw specifier has to be 'read' or 'write'"); + result.addAttribute( + PrefetchOp::getIsWriteAttrName(), + parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + + if (!cacheType.equals("data") && !cacheType.equals("instr")) + return parser.emitError(parser.getNameLoc(), + "cache type has to be 'data' or 'instr'"); + + result.addAttribute( + PrefetchOp::getIsDataCacheAttrName(), + parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + + return success(); +} + +static LogicalResult verify(PrefetchOp op) { + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + return op.emitOpError("too few indices"); + + return success(); +} + +LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + // prefetch(memrefcast) -> prefetch + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 1a02745e90c..498a64d70c2 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3883,6 +3883,16 @@ public: return parser.parseToken(Token::equal, "expected '='"); } + /// Parse a '<' token. + ParseResult parseLess() override { + return parser.parseToken(Token::less, "expected '<'"); + } + + /// Parse a '>' token. + ParseResult parseGreater() override { + return parser.parseToken(Token::greater, "expected '>'"); + } + /// Parse a `(` token. ParseResult parseLParen() override { return parser.parseToken(Token::l_paren, "expected '('"); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 79a6d7a6902..57a92531163 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -49,7 +49,8 @@ static bool isMemRefDereferencingOp(Operation &op) { /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { return TypeSwitch<Operation *, NamedAttribute>(op) - .Case<AffineDmaStartOp, AffineLoadOp, AffineStoreOp, AffineDmaWaitOp>( + .Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp, + AffineDmaWaitOp>( [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); } |

