Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
-rw-r--r--  mlir/lib/Dialect/StandardOps/Ops.cpp  3000
1 file changed, 3000 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
new file mode 100644
index 00000000000..831c78a4521
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -0,0 +1,3000 @@
+//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+
+#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// StandardOpsDialect Interfaces
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class defines the interface for handling inlining with standard
+/// operations.
+struct StdInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// All operations within standard ops can be inlined.
+ bool isLegalToInline(Operation *, Region *,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Transformation Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Handle the given inlined terminator by replacing it with a new operation
+ /// as necessary.
+ void handleTerminator(Operation *op, Block *newDest) const final {
+ // Only "std.return" needs to be handled here.
+ auto returnOp = dyn_cast<ReturnOp>(op);
+ if (!returnOp)
+ return;
+
+ // Replace the return with a branch to the dest.
+ OpBuilder builder(op);
+ builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
+ op->erase();
+ }
+
+ /// Handle the given inlined terminator by replacing it with a new operation
+ /// as necessary.
+ void handleTerminator(Operation *op,
+ ArrayRef<Value> valuesToRepl) const final {
+ // Only "std.return" needs to be handled here.
+ auto returnOp = cast<ReturnOp>(op);
+
+ // Replace the values directly with the return operands.
+ assert(returnOp.getNumOperands() == valuesToRepl.size());
+ for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+ valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// StandardOpsDialect
+//===----------------------------------------------------------------------===//
+
+/// A custom unary operation printer that omits the "std." prefix from the
+/// operation names.
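+/// For example (illustrative IR, assuming a unary op that uses this printer):
+///   %0 = exp %arg : f32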
+static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
+ assert(op->getNumOperands() == 1 && "unary op should have one operand");
+ assert(op->getNumResults() == 1 && "unary op should have one result");
+
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+ << *op->getOperand(0);
+ p.printOptionalAttrDict(op->getAttrs());
+ p << " : " << op->getOperand(0)->getType();
+}
+
+/// A custom binary operation printer that omits the "std." prefix from the
+/// operation names.
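+/// For example (illustrative): when all types match, "std.addi" prints as
+///   %0 = addi %lhs, %rhs : i32
+/// while mismatched operand/result types fall back to the generic form.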
+static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
+ assert(op->getNumOperands() == 2 && "binary op should have two operands");
+ assert(op->getNumResults() == 1 && "binary op should have one result");
+
+ // If not all the operand and result types are the same, just use the
+ // generic assembly form to avoid omitting information in printing.
+ auto resultType = op->getResult(0)->getType();
+ if (op->getOperand(0)->getType() != resultType ||
+ op->getOperand(1)->getType() != resultType) {
+ p.printGenericOp(op);
+ return;
+ }
+
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+ << *op->getOperand(0) << ", " << *op->getOperand(1);
+ p.printOptionalAttrDict(op->getAttrs());
+
+ // Now we can output only one type for all operands and the result.
+ p << " : " << op->getResult(0)->getType();
+}
+
+/// A custom cast operation printer that omits the "std." prefix from the
+/// operation names.
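+/// For example (illustrative):
+///   %1 = memref_cast %0 : memref<4xf32> to memref<?xf32>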
+static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+ << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
+ << op->getResult(0)->getType();
+}
+
+/// A custom cast operation verifier.
+template <typename T> static LogicalResult verifyCastOp(T op) {
+ auto opType = op.getOperand()->getType();
+ auto resType = op.getType();
+ if (!T::areCastCompatible(opType, resType))
+ return op.emitError("operand type ") << opType << " and result type "
+ << resType << " are cast incompatible";
+
+ return success();
+}
+
+StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<DmaStartOp, DmaWaitOp,
+#define GET_OP_LIST
+#include "mlir/Dialect/StandardOps/Ops.cpp.inc"
+ >();
+ addInterfaces<StdInlinerInterface>();
+}
+
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
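+/// Prints the dimension operands in parentheses, followed by any symbol
+/// operands in square brackets, e.g. "(%d0, %d1)[%s0]" (illustrative).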
+void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
+ Operation::operand_iterator end,
+ unsigned numDims, OpAsmPrinter &p) {
+ Operation::operand_range operands(begin, end);
+ p << '(' << operands.take_front(numDims) << ')';
+ if (operands.size() != numDims)
+ p << '[' << operands.drop_front(numDims) << ']';
+}
+
+// Parses a dimension and symbol list, and sets 'numDims' to the number of
+// dimension operands parsed.
+// Returns failure on parse error, success otherwise.
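+// For example (illustrative), the list "(%d0, %d1)[%s0]" parses to three
+// index operands with 'numDims' set to 2.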
+ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
+ SmallVectorImpl<Value> &operands,
+ unsigned &numDims) {
+ SmallVector<OpAsmParser::OperandType, 8> opInfos;
+ if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
+ return failure();
+ // Store number of dimensions for validation by caller.
+ numDims = opInfos.size();
+
+ // Parse the optional symbol operands.
+ auto indexTy = parser.getBuilder().getIndexType();
+ if (parser.parseOperandList(opInfos,
+ OpAsmParser::Delimiter::OptionalSquare) ||
+ parser.resolveOperands(opInfos, indexTy, operands))
+ return failure();
+ return success();
+}
+
+/// Matches a ConstantIndexOp.
+/// TODO: This should probably just be a general matcher that uses m_Constant
+/// and checks the operation for an index type.
+static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
+ return detail::op_matcher<ConstantIndexOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+/// This is a common utility used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
+/// into the root operation directly.
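+/// For example (illustrative):
+///   %0 = memref_cast %m : memref<4xf32> to memref<?xf32>
+///   dealloc %0 : memref<?xf32>
+/// folds to "dealloc %m : memref<4xf32>".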
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
+ }
+ return success(folded);
+}
+
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
+ return constFoldBinaryOp<FloatAttr>(
+ operands, [](APFloat a, APFloat b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
+ /// addi(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AllocOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, AllocOp op) {
+ p << "alloc";
+
+ // Print dynamic dimension operands.
+ MemRefType type = op.getType();
+ printDimAndSymbolList(op.operand_begin(), op.operand_end(),
+ type.getNumDynamicDims(), p);
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
+ p << " : " << type;
+}
+
+static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) {
+ MemRefType type;
+
+ // Parse the dimension operands and optional symbol operands, followed by a
+ // memref type.
+ unsigned numDimOperands;
+ if (parseDimAndSymbolList(parser, result.operands, numDimOperands) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type))
+ return failure();
+
+  // Check the number of dimension operands against the number of dynamic
+  // ('?') dimensions in the memref type.
+ // Note: this check remains here (instead of in verify()), because the
+ // partition between dim operands and symbol operands is lost after parsing.
+ // Verification still checks that the total number of operands matches
+ // the number of symbols in the affine map, plus the number of dynamic
+ // dimensions in the memref.
+ if (numDimOperands != type.getNumDynamicDims())
+ return parser.emitError(parser.getNameLoc())
+ << "dimension operand count does not equal memref dynamic dimension "
+ "count";
+ result.types.push_back(type);
+ return success();
+}
+
+static LogicalResult verify(AllocOp op) {
+ auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return op.emitOpError("result must be a memref");
+
+ unsigned numSymbols = 0;
+ if (!memRefType.getAffineMaps().empty()) {
+ // Store number of symbols used in affine map (used in subsequent check).
+ AffineMap affineMap = memRefType.getAffineMaps()[0];
+ numSymbols = affineMap.getNumSymbols();
+ }
+
+ // Check that the total number of operands matches the number of symbols in
+ // the affine map, plus the number of dynamic dimensions specified in the
+ // memref type.
+ unsigned numDynamicDims = memRefType.getNumDynamicDims();
+ if (op.getNumOperands() != numDynamicDims + numSymbols)
+ return op.emitOpError(
+ "operand count does not equal dimension plus symbol operand count");
+
+ // Verify that all operands are of type Index.
+ for (auto operandType : op.getOperandTypes())
+ if (!operandType.isIndex())
+ return op.emitOpError("requires operands to be of type Index");
+ return success();
+}
+
+namespace {
+/// Fold constant dimensions into an alloc operation.
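+/// For example (illustrative), given %c4 = constant 4 : index:
+///   %0 = alloc(%c4, %n) : memref<?x?xf32>
+/// is rewritten to allocate with a static first dimension,
+///   %1 = alloc(%n) : memref<4x?xf32>
+/// followed by a memref_cast back to the original memref<?x?xf32> type.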
+struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
+    // Check to see if any dimension operands are constants. If so, we can
+    // substitute and drop them.
+ if (llvm::none_of(alloc.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return matchFailure();
+
+ auto memrefType = alloc.getType();
+
+ // Ok, we have one or more constant operands. Collect the non-constant ones
+ // and keep track of the resultant memref type to build.
+ SmallVector<int64_t, 4> newShapeConstants;
+ newShapeConstants.reserve(memrefType.getRank());
+ SmallVector<Value, 4> newOperands;
+ SmallVector<Value, 4> droppedOperands;
+
+ unsigned dynamicDimPos = 0;
+ for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+ int64_t dimSize = memrefType.getDimSize(dim);
+      // If this is already a static dimension, keep it.
+ if (dimSize != -1) {
+ newShapeConstants.push_back(dimSize);
+ continue;
+ }
+ auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic shape dimension will be folded.
+ newShapeConstants.push_back(constantIndexOp.getValue());
+        // Record the constant operand so it can be erased if it becomes dead.
+ droppedOperands.push_back(constantIndexOp);
+ } else {
+ // Dynamic shape dimension not folded; copy operand from old memref.
+ newShapeConstants.push_back(-1);
+ newOperands.push_back(alloc.getOperand(dynamicDimPos));
+ }
+ dynamicDimPos++;
+ }
+
+ // Create new memref type (which will have fewer dynamic dimensions).
+ auto newMemRefType = MemRefType::get(
+ newShapeConstants, memrefType.getElementType(),
+ memrefType.getAffineMaps(), memrefType.getMemorySpace());
+ assert(static_cast<int64_t>(newOperands.size()) ==
+ newMemRefType.getNumDynamicDims());
+
+ // Create and insert the alloc op for the new memref.
+ auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType,
+ newOperands, IntegerAttr());
+ // Insert a cast so we have the same type as the old alloc.
+ auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
+ alloc.getType());
+
+ rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
+ return matchSuccess();
+ }
+};
+
+/// Fold alloc operations with no uses. Alloc has side effects on the heap,
+/// but can still be deleted if it has zero uses.
+struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
+ if (alloc.use_empty()) {
+ rewriter.eraseOp(alloc);
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+} // end anonymous namespace.
+
+void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// BranchOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Simplify a branch to a block that has a single predecessor. This effectively
+/// merges the two blocks.
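+/// For example (illustrative), when ^bb1 has no other predecessors:
+///   br ^bb1(%x : i32)
+/// ^bb1(%arg : i32):
+///   ...
+/// the blocks are merged and %arg is replaced by %x.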
+struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
+ using OpRewritePattern<BranchOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(BranchOp op,
+ PatternRewriter &rewriter) const override {
+ // Check that the successor block has a single predecessor.
+ Block *succ = op.getDest();
+ Block *opParent = op.getOperation()->getBlock();
+ if (succ == opParent || !has_single_element(succ->getPredecessors()))
+ return matchFailure();
+
+ // Merge the successor into the current block and erase the branch.
+ rewriter.mergeBlocks(succ, opParent, op.getOperands());
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace.
+
+static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
+ Block *dest;
+ SmallVector<Value, 4> destOperands;
+ if (parser.parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ result.addSuccessor(dest, destOperands);
+ return success();
+}
+
+static void print(OpAsmPrinter &p, BranchOp op) {
+ p << "br ";
+ p.printSuccessorAndUseList(op.getOperation(), 0);
+}
+
+Block *BranchOp::getDest() { return getSuccessor(0); }
+
+void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
+
+void BranchOp::eraseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(0, index);
+}
+
+void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SimplifyBrToBlockWithSinglePred>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
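+// Parses a direct call of the form (illustrative):
+//   %0 = call @callee(%arg) : (i32) -> i32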
+static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
+ FlatSymbolRefAttr calleeAttr;
+ FunctionType calleeType;
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ auto calleeLoc = parser.getNameLoc();
+ if (parser.parseAttribute(calleeAttr, "callee", result.attributes) ||
+ parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(calleeType) ||
+ parser.addTypesToList(calleeType.getResults(), result.types) ||
+ parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
+ result.operands))
+ return failure();
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, CallOp op) {
+ p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+ p << " : " << op.getCalleeType();
+}
+
+static LogicalResult verify(CallOp op) {
+ // Check that the callee attribute was specified.
+ auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
+ if (!fnAttr)
+ return op.emitOpError("requires a 'callee' symbol reference attribute");
+ auto fn =
+ op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
+ if (!fn)
+ return op.emitOpError() << "'" << fnAttr.getValue()
+ << "' does not reference a valid function";
+
+ // Verify that the operand and result types match the callee.
+ auto fnType = fn.getType();
+ if (fnType.getNumInputs() != op.getNumOperands())
+ return op.emitOpError("incorrect number of operands for callee");
+
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+ if (op.getOperand(i)->getType() != fnType.getInput(i))
+ return op.emitOpError("operand type mismatch");
+
+ if (fnType.getNumResults() != op.getNumResults())
+ return op.emitOpError("incorrect number of results for callee");
+
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+ if (op.getResult(i)->getType() != fnType.getResult(i))
+ return op.emitOpError("result type mismatch");
+
+ return success();
+}
+
+FunctionType CallOp::getCalleeType() {
+ SmallVector<Type, 4> resultTypes(getResultTypes());
+ SmallVector<Type, 8> argTypes(getOperandTypes());
+ return FunctionType::get(argTypes, resultTypes, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// CallIndirectOp
+//===----------------------------------------------------------------------===//
+namespace {
+/// Fold indirect calls that have a constant function as the callee operand.
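+/// For example (illustrative):
+///   %f = constant @callee : (i32) -> i32
+///   %0 = call_indirect %f(%x) : (i32) -> i32
+/// is rewritten to the direct form:
+///   %0 = call @callee(%x) : (i32) -> i32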
+struct SimplifyIndirectCallWithKnownCallee
+ : public OpRewritePattern<CallIndirectOp> {
+ using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
+ PatternRewriter &rewriter) const override {
+ // Check that the callee is a constant callee.
+ SymbolRefAttr calledFn;
+ if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
+ return matchFailure();
+
+ // Replace with a direct call.
+ SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+ indirectCall.getArgOperands());
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace.
+
+static ParseResult parseCallIndirectOp(OpAsmParser &parser,
+ OperationState &result) {
+ FunctionType calleeType;
+ OpAsmParser::OperandType callee;
+ llvm::SMLoc operandsLoc;
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ return failure(
+ parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
+ parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(calleeType) ||
+ parser.resolveOperand(callee, calleeType, result.operands) ||
+ parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
+ result.operands) ||
+ parser.addTypesToList(calleeType.getResults(), result.types));
+}
+
+static void print(OpAsmPrinter &p, CallIndirectOp op) {
+ p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+ p << " : " << op.getCallee()->getType();
+}
+
+static LogicalResult verify(CallIndirectOp op) {
+ // The callee must be a function.
+ auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
+ if (!fnType)
+ return op.emitOpError("callee must have function type");
+
+ // Verify that the operand and result types match the callee.
+ if (fnType.getNumInputs() != op.getNumOperands() - 1)
+ return op.emitOpError("incorrect number of operands for callee");
+
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+ if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
+ return op.emitOpError("operand type mismatch");
+
+ if (fnType.getNumResults() != op.getNumResults())
+ return op.emitOpError("incorrect number of results for callee");
+
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+ if (op.getResult(i)->getType() != fnType.getResult(i))
+ return op.emitOpError("result type mismatch");
+
+ return success();
+}
+
+void CallIndirectOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SimplifyIndirectCallWithKnownCallee>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// General helpers for comparison ops
+//===----------------------------------------------------------------------===//
+
+// Return the type of the same shape (scalar, vector or tensor) containing i1.
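+// For example (illustrative): i32 -> i1, tensor<4xf32> -> tensor<4xi1>,
+// vector<8xi64> -> vector<8xi1>. Returns a null Type for unsupported types.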
+static Type getCheckedI1SameShape(Builder *build, Type type) {
+ auto i1Type = build->getI1Type();
+ if (type.isIntOrIndexOrFloat())
+ return i1Type;
+ if (auto tensorType = type.dyn_cast<RankedTensorType>())
+ return RankedTensorType::get(tensorType.getShape(), i1Type);
+ if (type.isa<UnrankedTensorType>())
+ return UnrankedTensorType::get(i1Type);
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return VectorType::get(vectorType.getShape(), i1Type);
+ return Type();
+}
+
+static Type getI1SameShape(Builder *build, Type type) {
+ Type res = getCheckedI1SameShape(build, type);
+ assert(res && "expected type with valid i1 shape");
+ return res;
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+static void buildCmpIOp(Builder *build, OperationState &result,
+ CmpIPredicate predicate, Value lhs, Value rhs) {
+ result.addOperands({lhs, rhs});
+ result.types.push_back(getI1SameShape(build, lhs->getType()));
+ result.addAttribute(
+ CmpIOp::getPredicateAttrName(),
+ build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
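+// Parses a cmpi operation of the form (illustrative):
+//   %0 = cmpi "slt", %a, %b : i32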
+static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ SmallVector<NamedAttribute, 4> attrs;
+ Attribute predicateNameAttr;
+ Type type;
+ if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
+ attrs) ||
+ parser.parseComma() || parser.parseOperandList(ops, 2) ||
+ parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
+ parser.resolveOperands(ops, type, result.operands))
+ return failure();
+
+ if (!predicateNameAttr.isa<StringAttr>())
+ return parser.emitError(parser.getNameLoc(),
+ "expected string comparison predicate attribute");
+
+ // Rewrite string attribute to an enum value.
+ StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
+ Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
+ if (!predicate.hasValue())
+ return parser.emitError(parser.getNameLoc())
+ << "unknown comparison predicate \"" << predicateName << "\"";
+
+ auto builder = parser.getBuilder();
+ Type i1Type = getCheckedI1SameShape(&builder, type);
+ if (!i1Type)
+ return parser.emitError(parser.getNameLoc(),
+ "expected type with valid i1 shape");
+
+ attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
+ result.attributes = attrs;
+
+ result.addTypes({i1Type});
+ return success();
+}
+
+static void print(OpAsmPrinter &p, CmpIOp op) {
+ p << "cmpi ";
+
+ auto predicateValue =
+ op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
+ p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
+ << '"' << ", " << op.lhs() << ", " << op.rhs();
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
+ p << " : " << op.lhs()->getType();
+}
+
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
+// comparison predicates.
+static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
+ const APInt &rhs) {
+ switch (predicate) {
+ case CmpIPredicate::eq:
+ return lhs.eq(rhs);
+ case CmpIPredicate::ne:
+ return lhs.ne(rhs);
+ case CmpIPredicate::slt:
+ return lhs.slt(rhs);
+ case CmpIPredicate::sle:
+ return lhs.sle(rhs);
+ case CmpIPredicate::sgt:
+ return lhs.sgt(rhs);
+ case CmpIPredicate::sge:
+ return lhs.sge(rhs);
+ case CmpIPredicate::ult:
+ return lhs.ult(rhs);
+ case CmpIPredicate::ule:
+ return lhs.ule(rhs);
+ case CmpIPredicate::ugt:
+ return lhs.ugt(rhs);
+ case CmpIPredicate::uge:
+ return lhs.uge(rhs);
+ default:
+ llvm_unreachable("unknown comparison predicate");
+ }
+}
+
+// Constant folding hook for comparisons.
+OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "cmpi takes two arguments");
+
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs || !rhs)
+ return {};
+
+ auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+ return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+// Returns an array of mnemonics for CmpFPredicates, indexed by predicate
+// value.
+static inline const char *const *getCmpFPredicateNames() {
+ static const char *predicateNames[] = {
+ /*AlwaysFalse*/ "false",
+ /*OEQ*/ "oeq",
+ /*OGT*/ "ogt",
+ /*OGE*/ "oge",
+ /*OLT*/ "olt",
+ /*OLE*/ "ole",
+ /*ONE*/ "one",
+ /*ORD*/ "ord",
+ /*UEQ*/ "ueq",
+ /*UGT*/ "ugt",
+ /*UGE*/ "uge",
+ /*ULT*/ "ult",
+ /*ULE*/ "ule",
+ /*UNE*/ "une",
+ /*UNO*/ "uno",
+ /*AlwaysTrue*/ "true",
+ };
+ static_assert(std::extent<decltype(predicateNames)>::value ==
+ (size_t)CmpFPredicate::NumPredicates,
+ "wrong number of predicate names");
+ return predicateNames;
+}
+
+// Returns a value of the predicate corresponding to the given mnemonic.
+// Returns NumPredicates (one-past-end) if there is no such mnemonic.
+CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
+ return llvm::StringSwitch<CmpFPredicate>(name)
+ .Case("false", CmpFPredicate::AlwaysFalse)
+ .Case("oeq", CmpFPredicate::OEQ)
+ .Case("ogt", CmpFPredicate::OGT)
+ .Case("oge", CmpFPredicate::OGE)
+ .Case("olt", CmpFPredicate::OLT)
+ .Case("ole", CmpFPredicate::OLE)
+ .Case("one", CmpFPredicate::ONE)
+ .Case("ord", CmpFPredicate::ORD)
+ .Case("ueq", CmpFPredicate::UEQ)
+ .Case("ugt", CmpFPredicate::UGT)
+ .Case("uge", CmpFPredicate::UGE)
+ .Case("ult", CmpFPredicate::ULT)
+ .Case("ule", CmpFPredicate::ULE)
+ .Case("une", CmpFPredicate::UNE)
+ .Case("uno", CmpFPredicate::UNO)
+ .Case("true", CmpFPredicate::AlwaysTrue)
+ .Default(CmpFPredicate::NumPredicates);
+}
+
+static void buildCmpFOp(Builder *build, OperationState &result,
+ CmpFPredicate predicate, Value lhs, Value rhs) {
+ result.addOperands({lhs, rhs});
+ result.types.push_back(getI1SameShape(build, lhs->getType()));
+ result.addAttribute(
+ CmpFOp::getPredicateAttrName(),
+ build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
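+// Parses a cmpf operation of the form (illustrative):
+//   %0 = cmpf "oge", %a, %b : f32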
+static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ SmallVector<NamedAttribute, 4> attrs;
+ Attribute predicateNameAttr;
+ Type type;
+ if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
+ attrs) ||
+ parser.parseComma() || parser.parseOperandList(ops, 2) ||
+ parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
+ parser.resolveOperands(ops, type, result.operands))
+ return failure();
+
+ if (!predicateNameAttr.isa<StringAttr>())
+ return parser.emitError(parser.getNameLoc(),
+ "expected string comparison predicate attribute");
+
+ // Rewrite string attribute to an enum value.
+ StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
+ auto predicate = CmpFOp::getPredicateByName(predicateName);
+ if (predicate == CmpFPredicate::NumPredicates)
+ return parser.emitError(parser.getNameLoc(),
+ "unknown comparison predicate \"" + predicateName +
+ "\"");
+
+ auto builder = parser.getBuilder();
+ Type i1Type = getCheckedI1SameShape(&builder, type);
+ if (!i1Type)
+ return parser.emitError(parser.getNameLoc(),
+ "expected type with valid i1 shape");
+
+ attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+ result.attributes = attrs;
+
+ result.addTypes({i1Type});
+ return success();
+}
+
+static void print(OpAsmPrinter &p, CmpFOp op) {
+ p << "cmpf ";
+
+ auto predicateValue =
+ op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
+ assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
+ predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
+ "unknown predicate index");
+ p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
+ << ", " << op.rhs();
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
+ p << " : " << op.lhs()->getType();
+}
+
+static LogicalResult verify(CmpFOp op) {
+ auto predicateAttr =
+ op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
+ if (!predicateAttr)
+ return op.emitOpError("requires an integer attribute named 'predicate'");
+ auto predicate = predicateAttr.getInt();
+ if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
+ predicate >= (int64_t)CmpFPredicate::NumPredicates)
+ return op.emitOpError("'predicate' attribute value out of range");
+
+ return success();
+}
+
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+// comparison predicates.
+static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
+ const APFloat &rhs) {
+ auto cmpResult = lhs.compare(rhs);
+ switch (predicate) {
+ case CmpFPredicate::AlwaysFalse:
+ return false;
+ case CmpFPredicate::OEQ:
+ return cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::OGT:
+ return cmpResult == APFloat::cmpGreaterThan;
+ case CmpFPredicate::OGE:
+ return cmpResult == APFloat::cmpGreaterThan ||
+ cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::OLT:
+ return cmpResult == APFloat::cmpLessThan;
+ case CmpFPredicate::OLE:
+ return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::ONE:
+ return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
+ case CmpFPredicate::ORD:
+ return cmpResult != APFloat::cmpUnordered;
+ case CmpFPredicate::UEQ:
+ return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::UGT:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpGreaterThan;
+ case CmpFPredicate::UGE:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpGreaterThan ||
+ cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::ULT:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpLessThan;
+ case CmpFPredicate::ULE:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::UNE:
+ return cmpResult != APFloat::cmpEqual;
+ case CmpFPredicate::UNO:
+ return cmpResult == APFloat::cmpUnordered;
+ case CmpFPredicate::AlwaysTrue:
+ return true;
+ default:
+ llvm_unreachable("unknown comparison predicate");
+ }
+}
+
+// Constant folding hook for comparisons.
+OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "cmpf takes two arguments");
+
+ auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
+ auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+
+  // TODO(gcmn): We could do some intelligent things here if only one of the
+  // operands is known and it is inf or NaN.
+ if (!lhs || !rhs)
+ return {};
+
+ auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+ return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// cond_br true, ^bb1, ^bb2 -> br ^bb1
+/// cond_br false, ^bb1, ^bb2 -> br ^bb2
+///
+struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ if (matchPattern(condbr.getCondition(), m_NonZero())) {
+ // True branch taken.
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
+ condbr.getTrueOperands());
+ return matchSuccess();
+ } else if (matchPattern(condbr.getCondition(), m_Zero())) {
+ // False branch taken.
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
+ condbr.getFalseOperands());
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+} // end anonymous namespace.
+
+static ParseResult parseCondBranchOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Value, 4> destOperands;
+ Block *dest;
+ OpAsmParser::OperandType condInfo;
+
+ // Parse the condition.
+ Type int1Ty = parser.getBuilder().getI1Type();
+ if (parser.parseOperand(condInfo) || parser.parseComma() ||
+ parser.resolveOperand(condInfo, int1Ty, result.operands)) {
+ return parser.emitError(parser.getNameLoc(),
+ "expected condition type was boolean (i1)");
+ }
+
+ // Parse the true successor.
+ if (parser.parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ result.addSuccessor(dest, destOperands);
+
+ // Parse the false successor.
+ destOperands.clear();
+ if (parser.parseComma() ||
+ parser.parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ result.addSuccessor(dest, destOperands);
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, CondBranchOp op) {
+ p << "cond_br " << op.getCondition() << ", ";
+ p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
+ p << ", ";
+ p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
+}
+
+void CondBranchOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SimplifyConstCondBranchPred>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ConstantOp &op) {
+ p << "constant ";
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
+
+ if (op.getAttrs().size() > 1)
+ p << ' ';
+ p << op.getValue();
+
+ // If the value is a symbol reference, print a trailing type.
+ if (op.getValue().isa<SymbolRefAttr>())
+ p << " : " << op.getType();
+}
+
+static ParseResult parseConstantOp(OpAsmParser &parser,
+ OperationState &result) {
+ Attribute valueAttr;
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseAttribute(valueAttr, "value", result.attributes))
+ return failure();
+
+ // If the attribute is a symbol reference, then we expect a trailing type.
+ Type type;
+ if (!valueAttr.isa<SymbolRefAttr>())
+ type = valueAttr.getType();
+ else if (parser.parseColonType(type))
+ return failure();
+
+ // Add the attribute type to the list.
+ return parser.addTypeToList(type, result.types);
+}
+
+/// The constant op requires an attribute, and furthermore requires that it
+/// matches the return type.
+static LogicalResult verify(ConstantOp &op) {
+ auto value = op.getValue();
+ if (!value)
+ return op.emitOpError("requires a 'value' attribute");
+
+ auto type = op.getType();
+ if (!value.getType().isa<NoneType>() && type != value.getType())
+ return op.emitOpError() << "requires attribute's type (" << value.getType()
+ << ") to match op's return type (" << type << ")";
+
+ if (type.isa<IndexType>() || value.isa<BoolAttr>())
+ return success();
+
+ if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
+ // If the type has a known bitwidth we verify that the value can be
+ // represented with the given bitwidth.
+ auto bitwidth = type.cast<IntegerType>().getWidth();
+ auto intVal = intAttr.getValue();
+ if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
+ return op.emitOpError("requires 'value' to be an integer within the "
+ "range of the integer result type");
+ return success();
+ }
+
+ if (type.isa<FloatType>()) {
+ if (!value.isa<FloatAttr>())
+ return op.emitOpError("requires 'value' to be a floating point constant");
+ return success();
+ }
+
+ if (type.isa<ShapedType>()) {
+ if (!value.isa<ElementsAttr>())
+ return op.emitOpError("requires 'value' to be a shaped constant");
+ return success();
+ }
+
+ if (type.isa<FunctionType>()) {
+ auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
+ if (!fnAttr)
+ return op.emitOpError("requires 'value' to be a function reference");
+
+ // Try to find the referenced function.
+ auto fn =
+ op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
+ if (!fn)
+ return op.emitOpError("reference to undefined function 'bar'");
+
+ // Check that the referenced function has the correct type.
+ if (fn.getType() != type)
+ return op.emitOpError("reference to function with mismatched type");
+
+ return success();
+ }
+
+ if (type.isa<NoneType>() && value.isa<UnitAttr>())
+ return success();
+
+ return op.emitOpError("unsupported 'value' attribute: ") << value;
+}
+
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return getValue();
+}
+
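+/// Suggest pretty SSA names for constant results, e.g. (illustrative) %true
+/// and %false for i1 constants, %c42_i32 for integers, %f for function
+/// references, and %cst otherwise.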
+void ConstantOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ Type type = getType();
+ if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
+ IntegerType intTy = type.dyn_cast<IntegerType>();
+
+ // Sugar i1 constants with 'true' and 'false'.
+ if (intTy && intTy.getWidth() == 1)
+ return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
+
+ // Otherwise, build a complex name with the value and type.
+ SmallString<32> specialNameBuffer;
+ llvm::raw_svector_ostream specialName(specialNameBuffer);
+ specialName << 'c' << intCst.getInt();
+ if (intTy)
+ specialName << '_' << type;
+ setNameFn(getResult(), specialName.str());
+
+ } else if (type.isa<FunctionType>()) {
+ setNameFn(getResult(), "f");
+ } else {
+ setNameFn(getResult(), "cst");
+ }
+}
+
+/// Returns true if a constant operation can be built with the given value and
+/// result type.
+bool ConstantOp::isBuildableWith(Attribute value, Type type) {
+ // SymbolRefAttr can only be used with a function type.
+ if (value.isa<SymbolRefAttr>())
+ return type.isa<FunctionType>();
+ // Otherwise, the attribute must have the same type as 'type'.
+ if (value.getType() != type)
+ return false;
+ // Finally, check that the attribute kind is handled.
+ return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
+ value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
+ value.isa<UnitAttr>();
+}
+
+void ConstantFloatOp::build(Builder *builder, OperationState &result,
+ const APFloat &value, FloatType type) {
+ ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
+}
+
+bool ConstantFloatOp::classof(Operation *op) {
+ return ConstantOp::classof(op) &&
+ op->getResult(0)->getType().isa<FloatType>();
+}
+
+/// ConstantIntOp only matches values whose result type is an IntegerType.
+bool ConstantIntOp::classof(Operation *op) {
+ return ConstantOp::classof(op) &&
+ op->getResult(0)->getType().isa<IntegerType>();
+}
+
+void ConstantIntOp::build(Builder *builder, OperationState &result,
+ int64_t value, unsigned width) {
+ Type type = builder->getIntegerType(width);
+ ConstantOp::build(builder, result, type,
+ builder->getIntegerAttr(type, value));
+}
+
+/// Build a constant int op producing an integer with the specified type,
+/// which must be an integer type.
+void ConstantIntOp::build(Builder *builder, OperationState &result,
+ int64_t value, Type type) {
+ assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
+ ConstantOp::build(builder, result, type,
+ builder->getIntegerAttr(type, value));
+}
+
+/// ConstantIndexOp only matches values whose result type is Index.
+bool ConstantIndexOp::classof(Operation *op) {
+ return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex();
+}
+
+void ConstantIndexOp::build(Builder *builder, OperationState &result,
+ int64_t value) {
+ Type type = builder->getIndexType();
+ ConstantOp::build(builder, result, type,
+ builder->getIntegerAttr(type, value));
+}
+
+//===----------------------------------------------------------------------===//
+// DeallocOp
+//===----------------------------------------------------------------------===//
+namespace {
+/// Fold Dealloc operations that are deallocating an AllocOp that is only used
+/// by other Dealloc operations.
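+/// For example (illustrative), if the only user of
+///   %0 = alloc() : memref<4xf32>
+/// is "dealloc %0 : memref<4xf32>", the dealloc can be erased (and
+/// SimplifyDeadAlloc can then remove the alloc itself).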
+struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(DeallocOp dealloc,
+ PatternRewriter &rewriter) const override {
+ // Check that the memref operand's defining operation is an AllocOp.
+ Value memref = dealloc.memref();
+ if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
+ return matchFailure();
+
+ // Check that all of the uses of the AllocOp are other DeallocOps.
+ for (auto *user : memref->getUsers())
+ if (!isa<DeallocOp>(user))
+ return matchFailure();
+
+ // Erase the dealloc operation.
+ rewriter.eraseOp(dealloc);
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace.
+
+static void print(OpAsmPrinter &p, DeallocOp op) {
+ p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
+}
+
+static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType memrefInfo;
+ MemRefType type;
+
+ return failure(parser.parseOperand(memrefInfo) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(memrefInfo, type, result.operands));
+}
+
+static LogicalResult verify(DeallocOp op) {
+ if (!op.memref()->getType().isa<MemRefType>())
+ return op.emitOpError("operand must be a memref");
+ return success();
+}
+
+void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SimplifyDeadDealloc>(context);
+}
+
+LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dealloc(memrefcast) -> dealloc
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, DimOp op) {
+ p << "dim " << *op.getOperand() << ", " << op.getIndex();
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+ p << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType operandInfo;
+ IntegerAttr indexAttr;
+ Type type;
+ Type indexType = parser.getBuilder().getIndexType();
+
+ return failure(
+ parser.parseOperand(operandInfo) || parser.parseComma() ||
+ parser.parseAttribute(indexAttr, indexType, "index", result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(operandInfo, type, result.operands) ||
+ parser.addTypeToList(indexType, result.types));
+}
+
+static LogicalResult verify(DimOp op) {
+ // Check that we have an integer index operand.
+ auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
+ if (!indexAttr)
+ return op.emitOpError("requires an integer attribute named 'index'");
+ int64_t index = indexAttr.getValue().getSExtValue();
+
+ auto type = op.getOperand()->getType();
+ if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+ if (index >= tensorType.getRank())
+ return op.emitOpError("index is out of range");
+ } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+ if (index >= memrefType.getRank())
+ return op.emitOpError("index is out of range");
+
+ } else if (type.isa<UnrankedTensorType>()) {
+ // ok, assumed to be in-range.
+ } else {
+ return op.emitOpError("requires an operand with tensor or memref type");
+ }
+
+ return success();
+}
+
+OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+ // Constant fold dim when the size along the index referred to is a constant.
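+  // For example (illustrative), "dim %t, 0 : tensor<4x?xf32>" folds to the
+  // constant index 4.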
+ auto opType = memrefOrTensor()->getType();
+ int64_t indexSize = -1;
+ if (auto tensorType = opType.dyn_cast<RankedTensorType>())
+ indexSize = tensorType.getShape()[getIndex()];
+ else if (auto memrefType = opType.dyn_cast<MemRefType>())
+ indexSize = memrefType.getShape()[getIndex()];
+
+ if (!ShapedType::isDynamic(indexSize))
+ return IntegerAttr::get(IndexType::get(getContext()), indexSize);
+
+ // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
+ auto memrefType = opType.dyn_cast<MemRefType>();
+ if (!memrefType)
+ return {};
+
+ // The size at getIndex() is now a dynamic size of a memref.
+ auto memref = memrefOrTensor()->getDefiningOp();
+ if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
+ return *(alloc.getDynamicSizes().begin() +
+ memrefType.getDynamicDimIndex(getIndex()));
+
+ if (auto view = dyn_cast_or_null<ViewOp>(memref))
+ return *(view.getDynamicSizes().begin() +
+ memrefType.getDynamicDimIndex(getIndex()));
+
+  // A subview op carries one size operand per dimension when it carries any,
+  // so the size at getIndex() can be read off directly.
+ if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
+ auto sizes = subview.sizes();
+ if (!sizes.empty())
+ return *(sizes.begin() + getIndex());
+ }
+
+ /// dim(memrefcast) -> dim
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// SignedDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // Don't fold if it would overflow or if it requires a division by zero.
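+  // For example, folding the signed minimum value divided by -1 would
+  // overflow, so such cases are left unfolded.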
+ bool overflowOrDiv0 = false;
+ auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+ if (overflowOrDiv0 || !b) {
+ overflowOrDiv0 = true;
+ return a;
+ }
+ return a.sdiv_ov(b, overflowOrDiv0);
+ });
+ return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// UnsignedDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // Don't fold if it would require a division by zero.
+ bool div0 = false;
+ auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+ if (div0 || !b) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
+ return div0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// DmaStartOp
+//===----------------------------------------------------------------------===//
+
+void DmaStartOp::build(Builder *builder, OperationState &result,
+ Value srcMemRef, ValueRange srcIndices, Value destMemRef,
+ ValueRange destIndices, Value numElements,
+ Value tagMemRef, ValueRange tagIndices, Value stride,
+ Value elementsPerStride) {
+ result.addOperands(srcMemRef);
+ result.addOperands(srcIndices);
+ result.addOperands(destMemRef);
+ result.addOperands(destIndices);
+ result.addOperands({numElements, tagMemRef});
+ result.addOperands(tagIndices);
+ if (stride)
+ result.addOperands({stride, elementsPerStride});
+}
+
+void DmaStartOp::print(OpAsmPrinter &p) {
+ p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
+ << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
+ << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
+ if (isStrided())
+ p << ", " << *getStride() << ", " << *getNumElementsPerStride();
+
+ p.printOptionalAttrDict(getAttrs());
+ p << " : " << getSrcMemRef()->getType();
+ p << ", " << getDstMemRef()->getType();
+ p << ", " << getTagMemRef()->getType();
+}
+
+// Parse DmaStartOp.
+// Ex:
+// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
+//     %tag[%index], %stride, %num_elt_per_stride
+//     : memref<3076 x f32, 0>,
+// memref<1024 x f32, 2>,
+// memref<1 x i32>
+//
+ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType srcMemRefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
+ OpAsmParser::OperandType dstMemRefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
+ OpAsmParser::OperandType numElementsInfo;
+ OpAsmParser::OperandType tagMemrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
+ SmallVector<OpAsmParser::OperandType, 2> strideInfo;
+
+ SmallVector<Type, 3> types;
+ auto indexType = parser.getBuilder().getIndexType();
+
+ // Parse and resolve the following list of operands:
+ // *) source memref followed by its indices (in square brackets).
+ // *) destination memref followed by its indices (in square brackets).
+  // *) number of elements to transfer.
+ if (parser.parseOperand(srcMemRefInfo) ||
+ parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
+ parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() || parser.parseOperand(numElementsInfo) ||
+ parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
+ parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
+ return failure();
+
+ // Parse optional stride and elements per stride.
+ if (parser.parseTrailingOperandList(strideInfo))
+ return failure();
+
+ bool isStrided = strideInfo.size() == 2;
+ if (!strideInfo.empty() && !isStrided) {
+ return parser.emitError(parser.getNameLoc(),
+ "expected two stride related operands");
+ }
+
+ if (parser.parseColonTypeList(types))
+ return failure();
+ if (types.size() != 3)
+ return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
+
+ if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
+ parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
+ parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
+ parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
+ // size should be an index.
+ parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
+ parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
+ // tag indices should be index.
+ parser.resolveOperands(tagIndexInfos, indexType, result.operands))
+ return failure();
+
+ auto memrefType0 = types[0].dyn_cast<MemRefType>();
+ if (!memrefType0)
+ return parser.emitError(parser.getNameLoc(),
+ "expected source to be of memref type");
+
+ auto memrefType1 = types[1].dyn_cast<MemRefType>();
+ if (!memrefType1)
+ return parser.emitError(parser.getNameLoc(),
+ "expected destination to be of memref type");
+
+ auto memrefType2 = types[2].dyn_cast<MemRefType>();
+ if (!memrefType2)
+ return parser.emitError(parser.getNameLoc(),
+ "expected tag to be of memref type");
+
+ if (isStrided) {
+ if (parser.resolveOperands(strideInfo, indexType, result.operands))
+ return failure();
+ }
+
+ // Check that source/destination index list size matches associated rank.
+ if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
+ static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
+ return parser.emitError(parser.getNameLoc(),
+ "memref rank not equal to indices count");
+ if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
+ return parser.emitError(parser.getNameLoc(),
+ "tag memref rank not equal to indices count");
+
+ return success();
+}
+
+LogicalResult DmaStartOp::verify() {
+  // DMAs are only supported between different memory spaces.
+ if (getSrcMemorySpace() == getDstMemorySpace())
+ return emitOpError("DMA should be between different memory spaces");
+
+ if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+ getDstMemRefRank() + 3 + 1 &&
+ getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
+ getDstMemRefRank() + 3 + 1 + 2) {
+ return emitOpError("incorrect number of operands");
+ }
+ return success();
+}
+
+LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dma_start(memrefcast) -> dma_start
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// DmaWaitOp
+//===----------------------------------------------------------------------===//
+
+void DmaWaitOp::build(Builder *builder, OperationState &result, Value tagMemRef,
+ ValueRange tagIndices, Value numElements) {
+ result.addOperands(tagMemRef);
+ result.addOperands(tagIndices);
+ result.addOperands(numElements);
+}
+
+void DmaWaitOp::print(OpAsmPrinter &p) {
+ p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
+ << getNumElements();
+ p.printOptionalAttrDict(getAttrs());
+ p << " : " << getTagMemRef()->getType();
+}
+
+// Parse DmaWaitOp.
+// Ex:
+// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
+//
+ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType tagMemrefInfo;
+ SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
+ Type type;
+ auto indexType = parser.getBuilder().getIndexType();
+ OpAsmParser::OperandType numElementsInfo;
+
+ // Parse tag memref, its indices, and dma size.
+ if (parser.parseOperand(tagMemrefInfo) ||
+ parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() || parser.parseOperand(numElementsInfo) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
+ parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
+ parser.resolveOperand(numElementsInfo, indexType, result.operands))
+ return failure();
+
+ auto memrefType = type.dyn_cast<MemRefType>();
+ if (!memrefType)
+ return parser.emitError(parser.getNameLoc(),
+ "expected tag to be of memref type");
+
+ if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
+ return parser.emitError(parser.getNameLoc(),
+ "tag memref rank not equal to indices count");
+
+ return success();
+}
+
+LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dma_wait(memrefcast) -> dma_wait
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ExtractElementOp op) {
+ p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
+ p << ']';
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getAggregate()->getType();
+}
+
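+// Parse ExtractElementOp.
+// Eg (operand names illustrative):
+// %0 = extract_element %t[%i, %j] : tensor<4x4xi32>
+//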
+static ParseResult parseExtractElementOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType aggregateInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ ShapedType type;
+
+ auto indexTy = parser.getBuilder().getIndexType();
+ return failure(
+ parser.parseOperand(aggregateInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(aggregateInfo, type, result.operands) ||
+ parser.resolveOperands(indexInfo, indexTy, result.operands) ||
+ parser.addTypeToList(type.getElementType(), result.types));
+}
+
+static LogicalResult verify(ExtractElementOp op) {
+ auto aggregateType = op.getAggregate()->getType().cast<ShapedType>();
+
+  // This check should be expressible with tablegen type constraints.
+ if (op.getType() != aggregateType.getElementType())
+ return op.emitOpError("result type must match element type of aggregate");
+
+ // Verify the # indices match if we have a ranked type.
+ if (aggregateType.hasRank() &&
+ aggregateType.getRank() != op.getNumOperands() - 1)
+ return op.emitOpError("incorrect number of indices for extract_element");
+
+ return success();
+}
+
+OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
+ assert(!operands.empty() && "extract_element takes at least one operand");
+
+ // The aggregate operand must be a known constant.
+ Attribute aggregate = operands.front();
+ if (!aggregate)
+ return {};
+
+ // If this is a splat elements attribute, simply return the value. All of the
+ // elements of a splat attribute are the same.
+ if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
+ return splatAggregate.getSplatValue();
+
+ // Otherwise, collect the constant indices into the aggregate.
+ SmallVector<uint64_t, 8> indices;
+  for (Attribute index : llvm::drop_begin(operands, 1)) {
+    if (!index || !index.isa<IntegerAttr>())
+      return {};
+    indices.push_back(index.cast<IntegerAttr>().getInt());
+ }
+
+ // If this is an elements attribute, query the value at the given indices.
+ auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
+ if (elementsAttr && elementsAttr.isValidIndex(indices))
+ return elementsAttr.getValue(indices);
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+// Index cast is applicable from index to integer and backwards.
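+// Eg (illustrative):
+//  %1 = index_cast %0 : index to i64
+//  %3 = index_cast %2 : i32 to index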
+bool IndexCastOp::areCastCompatible(Type a, Type b) {
+ return (a.isIndex() && b.isa<IntegerType>()) ||
+ (a.isa<IntegerType>() && b.isIndex());
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, LoadOp op) {
+ p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getMemRefType();
+}
+
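+// Parse LoadOp.
+// Eg (operand names illustrative):
+// %1 = load %A[%i, %j] : memref<8x8xf32>
+//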
+static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType memrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ MemRefType type;
+
+ auto indexTy = parser.getBuilder().getIndexType();
+ return failure(
+ parser.parseOperand(memrefInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(memrefInfo, type, result.operands) ||
+ parser.resolveOperands(indexInfo, indexTy, result.operands) ||
+ parser.addTypeToList(type.getElementType(), result.types));
+}
+
+static LogicalResult verify(LoadOp op) {
+ if (op.getType() != op.getMemRefType().getElementType())
+ return op.emitOpError("result type must match element type of memref");
+
+ if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
+ return op.emitOpError("incorrect number of indices for load");
+
+ return success();
+}
+
+OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
+ /// load(memrefcast) -> load
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
+}
+
+//===----------------------------------------------------------------------===//
+// MemRefCastOp
+//===----------------------------------------------------------------------===//
+
+bool MemRefCastOp::areCastCompatible(Type a, Type b) {
+ auto aT = a.dyn_cast<MemRefType>();
+ auto bT = b.dyn_cast<MemRefType>();
+
+ auto uaT = a.dyn_cast<UnrankedMemRefType>();
+ auto ubT = b.dyn_cast<UnrankedMemRefType>();
+
+ if (aT && bT) {
+ if (aT.getElementType() != bT.getElementType())
+ return false;
+ if (aT.getAffineMaps() != bT.getAffineMaps()) {
+ int64_t aOffset, bOffset;
+ SmallVector<int64_t, 4> aStrides, bStrides;
+ if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
+ failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+ aStrides.size() != bStrides.size())
+ return false;
+
+ // Strides along a dimension/offset are compatible if the value in the
+ // source memref is static and the value in the target memref is the
+ // same. They are also compatible if either one is dynamic (see
+ // description of MemRefCastOp for details).
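+      // Eg (illustrative): a static offset of 16 is compatible with 16 or
+      // with a dynamic offset, but not with a static offset of 8; the same
+      // rule applies to each stride.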
+ auto checkCompatible = [](int64_t a, int64_t b) {
+ return (a == MemRefType::getDynamicStrideOrOffset() ||
+ b == MemRefType::getDynamicStrideOrOffset() || a == b);
+ };
+ if (!checkCompatible(aOffset, bOffset))
+ return false;
+ for (auto aStride : enumerate(aStrides))
+ if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ return false;
+ }
+ if (aT.getMemorySpace() != bT.getMemorySpace())
+ return false;
+
+ // They must have the same rank, and any specified dimensions must match.
+ if (aT.getRank() != bT.getRank())
+ return false;
+
+ for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
+ int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
+ if (aDim != -1 && bDim != -1 && aDim != bDim)
+ return false;
+ }
+ return true;
+ } else {
+ if (!aT && !uaT)
+ return false;
+ if (!bT && !ubT)
+ return false;
+ // Unranked to unranked casting is unsupported
+ if (uaT && ubT)
+ return false;
+
+ auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
+ auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
+ if (aEltType != bEltType)
+ return false;
+
+ auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
+ auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
+ if (aMemSpace != bMemSpace)
+ return false;
+
+ return true;
+ }
+
+ return false;
+}
+
+OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// MulFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
+ return constFoldBinaryOp<FloatAttr>(
+ operands, [](APFloat a, APFloat b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
+ /// muli(x, 0) -> 0
+ if (matchPattern(rhs(), m_Zero()))
+ return rhs();
+  /// muli(x, 1) -> x
+  if (matchPattern(rhs(), m_One()))
+    return lhs();
+
+ // TODO: Handle the overflow case.
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// PrefetchOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, PrefetchOp op) {
+ p << PrefetchOp::getOperationName() << " " << *op.memref() << '[';
+ p.printOperands(op.indices());
+ p << ']' << ", " << (op.isWrite() ? "write" : "read");
+ p << ", locality<" << op.localityHint();
+ p << ">, " << (op.isDataCache() ? "data" : "instr");
+ p.printOptionalAttrDict(
+ op.getAttrs(),
+ /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
+ p << " : " << op.getMemRefType();
+}
+
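+// Parse PrefetchOp.
+// Eg (operand names illustrative):
+// prefetch %A[%i], read, locality<3>, data : memref<400xf32>
+//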
+static ParseResult parsePrefetchOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType memrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ IntegerAttr localityHint;
+ MemRefType type;
+ StringRef readOrWrite, cacheType;
+
+ auto indexTy = parser.getBuilder().getIndexType();
+ auto i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseOperand(memrefInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
+ parser.parseComma() || parser.parseKeyword("locality") ||
+ parser.parseLess() ||
+ parser.parseAttribute(localityHint, i32Type, "localityHint",
+ result.attributes) ||
+ parser.parseGreater() || parser.parseComma() ||
+ parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
+ parser.resolveOperand(memrefInfo, type, result.operands) ||
+ parser.resolveOperands(indexInfo, indexTy, result.operands))
+ return failure();
+
+ if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+ return parser.emitError(parser.getNameLoc(),
+ "rw specifier has to be 'read' or 'write'");
+ result.addAttribute(
+ PrefetchOp::getIsWriteAttrName(),
+ parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+
+ if (!cacheType.equals("data") && !cacheType.equals("instr"))
+ return parser.emitError(parser.getNameLoc(),
+ "cache type has to be 'data' or 'instr'");
+
+ result.addAttribute(
+ PrefetchOp::getIsDataCacheAttrName(),
+ parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+
+ return success();
+}
+
+static LogicalResult verify(PrefetchOp op) {
+ if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
+ return op.emitOpError("too few indices");
+
+ return success();
+}
+
+LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // prefetch(memrefcast) -> prefetch
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, RankOp op) {
+ p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
+}
+
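+// Parse RankOp.
+// Eg (operand name illustrative):
+// %0 = rank %t : tensor<4x?xf32>
+//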
+static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType operandInfo;
+ Type type;
+ Type indexType = parser.getBuilder().getIndexType();
+ return failure(parser.parseOperand(operandInfo) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(operandInfo, type, result.operands) ||
+ parser.addTypeToList(indexType, result.types));
+}
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+ // Constant fold rank when the rank of the tensor is known.
+ auto type = getOperand()->getType();
+ if (auto tensorType = type.dyn_cast<RankedTensorType>())
+ return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+ return IntegerAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// SignedRemIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_signed takes two operands");
+
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!rhs)
+ return {};
+ auto rhsValue = rhs.getValue();
+
+ // x % 1 = 0
+ if (rhsValue.isOneValue())
+ return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+
+ // Don't fold if it requires division by zero.
+ if (rhsValue.isNullValue())
+ return {};
+
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs)
+ return {};
+ return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+}
+
+//===----------------------------------------------------------------------===//
+// UnsignedRemIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_unsigned takes two operands");
+
+ auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+ if (!rhs)
+ return {};
+ auto rhsValue = rhs.getValue();
+
+ // x % 1 = 0
+ if (rhsValue.isOneValue())
+ return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+
+ // Don't fold if it requires division by zero.
+ if (rhsValue.isNullValue())
+ return {};
+
+ auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+ if (!lhs)
+ return {};
+ return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> opInfo;
+ SmallVector<Type, 2> types;
+ llvm::SMLoc loc = parser.getCurrentLocation();
+ return failure(parser.parseOperandList(opInfo) ||
+ (!opInfo.empty() && parser.parseColonTypeList(types)) ||
+ parser.resolveOperands(opInfo, types, loc, result.operands));
+}
+
+static void print(OpAsmPrinter &p, ReturnOp op) {
+ p << "return";
+ if (op.getNumOperands() != 0)
+ p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
+}
+
+static LogicalResult verify(ReturnOp op) {
+ auto function = cast<FuncOp>(op.getParentOp());
+
+ // The operand number and types must match the function signature.
+ const auto &results = function.getType().getResults();
+ if (op.getNumOperands() != results.size())
+ return op.emitOpError("has ")
+ << op.getNumOperands()
+ << " operands, but enclosing function returns " << results.size();
+
+ for (unsigned i = 0, e = results.size(); i != e; ++i)
+ if (op.getOperand(i)->getType() != results[i])
+ return op.emitError()
+ << "type of return operand " << i << " ("
+ << op.getOperand(i)->getType()
+ << ") doesn't match function result type (" << results[i] << ")";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SIToFPOp
+//===----------------------------------------------------------------------===//
+
+// sitofp is applicable from integer types to float types.
+bool SIToFPOp::areCastCompatible(Type a, Type b) {
+ return a.isa<IntegerType>() && b.isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  Type type;
+ if (parser.parseOperandList(ops, 3) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type))
+ return failure();
+
+ auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type);
+ if (!i1Type)
+ return parser.emitError(parser.getNameLoc(),
+ "expected type with valid i1 shape");
+
+ SmallVector<Type, 3> types = {i1Type, type, type};
+ return failure(parser.resolveOperands(ops, types, parser.getNameLoc(),
+ result.operands) ||
+ parser.addTypeToList(type, result.types));
+}
+
+static void print(OpAsmPrinter &p, SelectOp op) {
+ p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
+ p.printOptionalAttrDict(op.getAttrs());
+}
+
+static LogicalResult verify(SelectOp op) {
+ auto trueType = op.getTrueValue()->getType();
+ auto falseType = op.getFalseValue()->getType();
+
+ if (trueType != falseType)
+ return op.emitOpError(
+ "requires 'true' and 'false' arguments to be of the same type");
+
+ return success();
+}
+
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+ auto condition = getCondition();
+
+ // select true, %0, %1 => %0
+ if (matchPattern(condition, m_One()))
+ return getTrueValue();
+
+ // select false, %0, %1 => %1
+ if (matchPattern(condition, m_Zero()))
+ return getFalseValue();
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// SignExtendIOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SignExtendIOp op) {
+  // Get the scalar type, which is either directly the type of the operand
+  // or the element type of the operand's vector/tensor type.
+ auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto dstType = getElementTypeOrSelf(op.getType());
+
+ // For now, index is forbidden for the source and the destination type.
+ if (srcType.isa<IndexType>())
+ return op.emitError() << srcType << " is not a valid operand type";
+ if (dstType.isa<IndexType>())
+ return op.emitError() << dstType << " is not a valid result type";
+
+ if (srcType.cast<IntegerType>().getWidth() >=
+ dstType.cast<IntegerType>().getWidth())
+ return op.emitError("result type ")
+ << dstType << " must be wider than operand type " << srcType;
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, SplatOp op) {
+ p << "splat " << *op.getOperand();
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getType();
+}
+
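+// Parse SplatOp.
+// Eg (operand name illustrative):
+// %0 = splat %x : vector<8xf32>
+//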
+static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType splatValueInfo;
+ ShapedType shapedType;
+
+ return failure(parser.parseOperand(splatValueInfo) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(shapedType) ||
+ parser.resolveOperand(splatValueInfo,
+ shapedType.getElementType(),
+ result.operands) ||
+ parser.addTypeToList(shapedType, result.types));
+}
+
+static LogicalResult verify(SplatOp op) {
+  // TODO: we could replace this with a trait.
+  if (op.getOperand()->getType() !=
+      op.getType().cast<ShapedType>().getElementType())
+    return op.emitError("operand must be of the element type of the result");
+
+ return success();
+}
+
+// Constant folding hook for SplatOp.
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 1 && "splat takes one operand");
+
+ auto constOperand = operands.front();
+ if (!constOperand ||
+ (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+ return {};
+
+ auto shapedType = getType().cast<ShapedType>();
+ assert(shapedType.getElementType() == constOperand.getType() &&
+ "incorrect input attribute type for folding");
+
+  // SplatElementsAttr::get treats a single value for the second arg as being
+  // a splat.
+ return SplatElementsAttr::get(shapedType, {constOperand});
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, StoreOp op) {
+ p << "store " << *op.getValueToStore();
+ p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getMemRefType();
+}
+
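+// Parse StoreOp.
+// Eg (operand names illustrative):
+// store %v, %A[%i, %j] : memref<8x8xf32>
+//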
+static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType storeValueInfo;
+ OpAsmParser::OperandType memrefInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ MemRefType memrefType;
+
+ auto indexTy = parser.getBuilder().getIndexType();
+ return failure(
+ parser.parseOperand(storeValueInfo) || parser.parseComma() ||
+ parser.parseOperand(memrefInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(memrefType) ||
+ parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
+ result.operands) ||
+ parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+ parser.resolveOperands(indexInfo, indexTy, result.operands));
+}
+
+static LogicalResult verify(StoreOp op) {
+ // First operand must have same type as memref element type.
+ if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
+    return op.emitOpError(
+        "first operand must have same type as memref element type");
+
+ if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
+ return op.emitOpError("store index operand count not equal to memref rank");
+
+ return success();
+}
+
+LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// store(memrefcast) -> store
+ return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// SubFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
+ return constFoldBinaryOp<FloatAttr>(
+ operands, [](APFloat a, APFloat b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
+ // subi(x,x) -> 0
+ if (getOperand(0) == getOperand(1))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+ /// and(x, 0) -> 0
+ if (matchPattern(rhs(), m_Zero()))
+ return rhs();
+ /// and(x,x) -> x
+ if (lhs() == rhs())
+ return rhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+ /// or(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+ /// or(x,x) -> x
+ if (lhs() == rhs())
+ return rhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// XOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+ /// xor(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+ /// xor(x,x) -> 0
+ if (lhs() == rhs())
+ return Builder(getContext()).getZeroAttr(getType());
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a ^ b; });
+}
+
+//===----------------------------------------------------------------------===//
+// TensorCastOp
+//===----------------------------------------------------------------------===//
+
+bool TensorCastOp::areCastCompatible(Type a, Type b) {
+ auto aT = a.dyn_cast<TensorType>();
+ auto bT = b.dyn_cast<TensorType>();
+ if (!aT || !bT)
+ return false;
+
+ if (aT.getElementType() != bT.getElementType())
+ return false;
+
+ return succeeded(verifyCompatibleShape(aT, bT));
+}
+
+OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
+ return impl::foldCastOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
+ if (auto memref = type.dyn_cast<MemRefType>())
+ return RankedTensorType::get(memref.getShape(), memref.getElementType());
+ return b.getNoneType();
+}
+
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, TensorLoadOp op) {
+ p << "tensor_load " << *op.getOperand();
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getOperand()->getType();
+}
+
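+// Parse TensorLoadOp.
+// Eg (operand name illustrative):
+// %1 = tensor_load %m : memref<4x4xf32>
+//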
+static ParseResult parseTensorLoadOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType op;
+ Type type;
+ return failure(parser.parseOperand(op) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperand(op, type, result.operands) ||
+ parser.addTypeToList(
+ getTensorTypeFromMemRefType(parser.getBuilder(), type),
+ result.types));
+}
+
+//===----------------------------------------------------------------------===//
+// TensorStoreOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, TensorStoreOp op) {
+ p << "tensor_store " << *op.tensor() << ", " << *op.memref();
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.memref()->getType();
+}
+
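+// Parse TensorStoreOp.
+// Eg (operand names illustrative):
+// tensor_store %t, %m : memref<4x4xf32>
+//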
+static ParseResult parseTensorStoreOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ Type type;
+ llvm::SMLoc loc = parser.getCurrentLocation();
+ return failure(
+ parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(type) ||
+ parser.resolveOperands(
+ ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type},
+ loc, result.operands));
+}
+
+//===----------------------------------------------------------------------===//
+// TruncateIOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(TruncateIOp op) {
+ auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto dstType = getElementTypeOrSelf(op.getType());
+
+ if (srcType.isa<IndexType>())
+ return op.emitError() << srcType << " is not a valid operand type";
+ if (dstType.isa<IndexType>())
+ return op.emitError() << dstType << " is not a valid result type";
+
+ if (srcType.cast<IntegerType>().getWidth() <=
+ dstType.cast<IntegerType>().getWidth())
+ return op.emitError("operand type ")
+ << srcType << " must be wider than result type " << dstType;
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
+
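+// Parse ViewOp.
+// Eg (operand names and layout map illustrative):
+// %1 = view %0[%offset][%size0, %size1]
+//   : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
+//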
+static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType srcInfo;
+ SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
+ SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
+ auto indexType = parser.getBuilder().getIndexType();
+ Type srcType, dstType;
+ llvm::SMLoc offsetLoc;
+ if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
+ parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
+ return failure();
+
+ if (offsetInfo.size() > 1)
+ return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
+
+ return failure(
+ parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(srcType) ||
+ parser.resolveOperand(srcInfo, srcType, result.operands) ||
+ parser.resolveOperands(offsetInfo, indexType, result.operands) ||
+ parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", dstType) ||
+ parser.addTypeToList(dstType, result.types));
+}
+
+static void print(OpAsmPrinter &p, ViewOp op) {
+ p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
+ auto dynamicOffset = op.getDynamicOffset();
+ if (dynamicOffset != nullptr)
+ p.printOperand(dynamicOffset);
+ p << "][" << op.getDynamicSizes() << ']';
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
+}
+
+Value ViewOp::getDynamicOffset() {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto result =
+ succeeded(mlir::getStridesAndOffset(getType(), strides, offset));
+ assert(result);
+ if (result && offset == MemRefType::getDynamicStrideOrOffset())
+ return getOperand(1);
+ return nullptr;
+}
+
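+// Verifies that 'strides' is consistent with the dynamic sizes of
+// 'memrefType' under a row-major strided layout: once some size is dynamic,
+// every stride to its left must be dynamic as well.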
+static LogicalResult verifyDynamicStrides(MemRefType memrefType,
+ ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ unsigned rank = memrefType.getRank();
+ assert(rank == strides.size());
+ bool dynamicStrides = false;
+ for (int i = rank - 2; i >= 0; --i) {
+ // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+ if (ShapedType::isDynamic(shape[i + 1]))
+ dynamicStrides = true;
+ // If stride at dim 'i' is not dynamic, return error.
+ if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
+ return failure();
+ }
+ return success();
+}
+
+static LogicalResult verify(ViewOp op) {
+ auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
+ auto viewType = op.getResult()->getType().cast<MemRefType>();
+
+ // The base memref should have identity layout map (or none).
+ if (baseType.getAffineMaps().size() > 1 ||
+ (baseType.getAffineMaps().size() == 1 &&
+ !baseType.getAffineMaps()[0].isIdentity()))
+ return op.emitError("unsupported map for base memref type ") << baseType;
+
+ // The base memref and the view memref should be in the same memory space.
+ if (baseType.getMemorySpace() != viewType.getMemorySpace())
+ return op.emitError("different memory spaces specified for base memref "
+ "type ")
+ << baseType << " and view memref type " << viewType;
+
+ // Verify that the result memref type has a strided layout map.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(viewType, strides, offset)))
+ return op.emitError("result type ") << viewType << " is not strided";
+
+ // Verify that we have the correct number of operands for the result type.
+ unsigned memrefOperandCount = 1;
+ unsigned numDynamicDims = viewType.getNumDynamicDims();
+ unsigned dynamicOffsetCount =
+ offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0;
+ if (op.getNumOperands() !=
+ memrefOperandCount + numDynamicDims + dynamicOffsetCount)
+ return op.emitError("incorrect number of operands for type ") << viewType;
+
+ // Verify dynamic strides symbols were added to correct dimensions based
+ // on dynamic sizes.
+ if (failed(verifyDynamicStrides(viewType, strides)))
+ return op.emitError("incorrect dynamic strides in view memref type ")
+ << viewType;
+ return success();
+}
+
+namespace {
+
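+/// Pattern to fold constant offset and size operands of a ViewOp into the
+/// result memref type.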
+struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
+ using OpRewritePattern<ViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(ViewOp viewOp,
+ PatternRewriter &rewriter) const override {
+ // Return if none of the operands are constants.
+ if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return matchFailure();
+
+ // Get result memref type.
+ auto memrefType = viewOp.getType();
+ if (memrefType.getAffineMaps().size() != 1)
+ return matchFailure();
+ auto map = memrefType.getAffineMaps()[0];
+
+ // Get offset from old memref view type 'memRefType'.
+ int64_t oldOffset;
+ SmallVector<int64_t, 4> oldStrides;
+ if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
+ return matchFailure();
+
+ SmallVector<Value, 4> newOperands;
+ SmallVector<Value, 4> droppedOperands;
+
+ // Fold dynamic offset operand if it is produced by a constant.
+ auto dynamicOffset = viewOp.getDynamicOffset();
+ int64_t newOffset = oldOffset;
+ unsigned dynamicOffsetOperandCount = 0;
+ if (dynamicOffset != nullptr) {
+ auto *defOp = dynamicOffset->getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic offset will be folded into the map.
+ newOffset = constantIndexOp.getValue();
+ droppedOperands.push_back(dynamicOffset);
+ } else {
+ // Unable to fold dynamic offset. Add it to 'newOperands' list.
+ newOperands.push_back(dynamicOffset);
+ dynamicOffsetOperandCount = 1;
+ }
+ }
+
+ // Fold any dynamic dim operands which are produced by a constant.
+ SmallVector<int64_t, 4> newShapeConstants;
+ newShapeConstants.reserve(memrefType.getRank());
+
+ unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart();
+ unsigned rank = memrefType.getRank();
+ for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+ int64_t dimSize = memrefType.getDimSize(dim);
+      // If this is already a static dimension, keep it.
+ if (!ShapedType::isDynamic(dimSize)) {
+ newShapeConstants.push_back(dimSize);
+ continue;
+ }
+ auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic shape dimension will be folded.
+ newShapeConstants.push_back(constantIndexOp.getValue());
+        // Record the constant so it can be removed later if it goes dead.
+ droppedOperands.push_back(constantIndexOp);
+ } else {
+ // Dynamic shape dimension not folded; copy operand from old memref.
+ newShapeConstants.push_back(dimSize);
+ newOperands.push_back(viewOp.getOperand(dynamicDimPos));
+ }
+ dynamicDimPos++;
+ }
+
+ // Compute new strides based on 'newShapeConstants'.
+ SmallVector<int64_t, 4> newStrides(rank);
+ newStrides[rank - 1] = 1;
+ bool dynamicStrides = false;
+ for (int i = rank - 2; i >= 0; --i) {
+ if (ShapedType::isDynamic(newShapeConstants[i + 1]))
+ dynamicStrides = true;
+ if (dynamicStrides)
+ newStrides[i] = MemRefType::getDynamicStrideOrOffset();
+ else
+ newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
+ }
+
+ // Regenerate strided layout map with 'newStrides' and 'newOffset'.
+ map = makeStridedLinearLayoutMap(newStrides, newOffset,
+ rewriter.getContext());
+
+ // Create new memref type with constant folded dims and/or offset/strides.
+ auto newMemRefType =
+ MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
+ memrefType.getMemorySpace());
+ assert(static_cast<int64_t>(newOperands.size()) ==
+ dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
+
+ // Create new ViewOp.
+ auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
+ viewOp.getOperand(0), newOperands);
+ // Insert a cast so we have the same type as the old memref type.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
+ newViewOp, viewOp.getType());
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ViewOpShapeFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// SubViewOp
+//===----------------------------------------------------------------------===//
+
+// Returns a MemRefType with dynamic sizes and offset and the same strides as
+// the `memRefType` passed as argument.
+// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep
+// sizes and offset static.
+static Type inferSubViewResultType(MemRefType memRefType) {
+ auto rank = memRefType.getRank();
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ Type elementType = memRefType.getElementType();
+ auto res = getStridesAndOffset(memRefType, strides, offset);
+ assert(succeeded(res) && "SubViewOp expected strided memref type");
+ (void)res;
+
+ // Assume sizes and offset are fully dynamic for now until canonicalization
+ // occurs on the ranges. Typed strides don't change though.
+ offset = MemRefType::getDynamicStrideOrOffset();
+  // Overwrite strides because the verifier will not pass otherwise.
+ // TODO(b/144419106): don't force degrade the strides to fully dynamic.
+ for (auto &stride : strides)
+ stride = MemRefType::getDynamicStrideOrOffset();
+ auto stridedLayout =
+ makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
+ SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
+ return MemRefType::get(sizes, elementType, stridedLayout,
+ memRefType.getMemorySpace());
+}
+
+void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
+ ValueRange offsets, ValueRange sizes,
+ ValueRange strides, Type resultType,
+ ArrayRef<NamedAttribute> attrs) {
+ if (!resultType)
+ resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
+  auto segmentAttr = b->getI32VectorAttr(
+      {1, static_cast<int32_t>(offsets.size()),
+       static_cast<int32_t>(sizes.size()),
+       static_cast<int32_t>(strides.size())});
+ build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
+ result.addAttributes(attrs);
+}
+
+void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
+ Value source) {
+ build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+ resultType);
+}
+
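+// Parse SubViewOp.
+// Eg (operand names and layout maps illustrative):
+// %1 = subview %0[%i, %j][%s0, %s1][%x, %y]
+//   : memref<8x8xf32, (d0, d1) -> (d0 * 8 + d1)>
+//     to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+//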
+static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType srcInfo;
+ SmallVector<OpAsmParser::OperandType, 4> offsetsInfo;
+ SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
+ SmallVector<OpAsmParser::OperandType, 4> stridesInfo;
+ auto indexType = parser.getBuilder().getIndexType();
+ Type srcType, dstType;
+ if (parser.parseOperand(srcInfo) ||
+ parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
+ return failure();
+ }
+
+ auto builder = parser.getBuilder();
+ result.addAttribute(
+ SubViewOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(offsetsInfo.size()),
+                                static_cast<int32_t>(sizesInfo.size()),
+                                static_cast<int32_t>(stridesInfo.size())}));
+
+ return failure(
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(srcType) ||
+ parser.resolveOperand(srcInfo, srcType, result.operands) ||
+ parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
+ parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+ parser.resolveOperands(stridesInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", dstType) ||
+ parser.addTypeToList(dstType, result.types));
+}
+
+static void print(OpAsmPrinter &p, SubViewOp op) {
+ p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
+ << "][" << op.sizes() << "][" << op.strides() << ']';
+
+ SmallVector<StringRef, 1> elidedAttrs = {
+ SubViewOp::getOperandSegmentSizeAttr()};
+ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+ p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
+}
+
+static LogicalResult verify(SubViewOp op) {
+ auto baseType = op.getBaseMemRefType().cast<MemRefType>();
+ auto subViewType = op.getType();
+
+ // The rank of the base and result subview must match.
+ if (baseType.getRank() != subViewType.getRank()) {
+    return op.emitError(
+        "expected rank of result type to match rank of base type");
+ }
+
+ // The base memref and the view memref should be in the same memory space.
+ if (baseType.getMemorySpace() != subViewType.getMemorySpace())
+ return op.emitError("different memory spaces specified for base memref "
+ "type ")
+ << baseType << " and subview memref type " << subViewType;
+
+ // Verify that the base memref type has a strided layout map.
+ int64_t baseOffset;
+ SmallVector<int64_t, 4> baseStrides;
+ if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
+ return op.emitError("base type ") << subViewType << " is not strided";
+
+ // Verify that the result memref type has a strided layout map.
+ int64_t subViewOffset;
+ SmallVector<int64_t, 4> subViewStrides;
+ if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
+ return op.emitError("result type ") << subViewType << " is not strided";
+
+ // Num offsets should either be zero or rank of memref.
+ if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic offsets specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
+
+ // Num sizes should either be zero or rank of memref.
+ if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic sizes specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
+
+ // Num strides should either be zero or rank of memref.
+ if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic strides specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
+
+ // Verify that if the shape of the subview type is static, then sizes are not
+ // dynamic values, and vice versa.
+ if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
+ (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
+ return op.emitError("invalid to specify dynamic sizes when subview result "
+ "type is statically shaped and viceversa");
+ }
+
+  // Verify that if dynamic sizes are specified, then the result memref type
+  // has fully dynamic dimensions.
+ if (op.getNumSizes() > 0) {
+ if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
+ return dim != ShapedType::kDynamicSize;
+ })) {
+      // TODO: This is based on the assumption that the number of size
+      // arguments is either 0, or the rank of the result type. It is possible
+      // to have more fine-grained verification where only particular
+      // dimensions are dynamic. That probably needs further changes to the
+      // shape op specification.
+ return op.emitError("expected shape of result type to be fully dynamic "
+ "when sizes are specified");
+ }
+ }
+
+ // Verify that if dynamic offsets are specified or base memref has dynamic
+ // offset or base memref has dynamic strides, then the subview offset is
+ // dynamic.
+ if ((op.getNumOffsets() > 0 ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset())) &&
+ subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result memref layout map to have dynamic offset");
+ }
+
+  // For now, verify that if dynamic strides are specified, then the result
+  // memref type has only dynamic strides.
+ if (op.getNumStrides() > 0) {
+ if (llvm::any_of(subViewStrides, [](int64_t stride) {
+ return stride != MemRefType::getDynamicStrideOrOffset();
+ })) {
+ return op.emitError("expected result type to have dynamic strides");
+ }
+ }
+
+  // If the base memref has a dynamic stride along some dimension, then the
+  // corresponding stride of the subview must also be dynamic.
+ assert(baseStrides.size() == subViewStrides.size());
+ for (auto stride : enumerate(baseStrides)) {
+ if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
+ subViewStrides[stride.index()] !=
+ MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result type to have dynamic stride along a dimension if "
+ "the base memref type has dynamic stride along that dimension");
+ }
+ }
+ return success();
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
+ return os << "range " << *range.offset << ":" << *range.size << ":"
+ << *range.stride;
+}
+
+SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
+ SmallVector<Range, 8> res;
+ unsigned rank = getType().getRank();
+ res.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
+ *(strides().begin() + i)});
+ return res;
+}
+
+LogicalResult
+SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
+  // If dynamic stride operands are present, return failure.
+ if (getNumStrides())
+ return failure();
+
+  // When static, the stride operands can be retrieved by taking the strides of
+  // the result of the subview op and dividing them by the strides of the base
+  // memref.
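+  // Eg (illustrative): base strides [128, 1] and result strides [256, 2]
+  // yield static subview strides [2, 2].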
+ int64_t resultOffset, baseOffset;
+ SmallVector<int64_t, 2> resultStrides, baseStrides;
+ if (failed(
+ getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
+ llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
+ failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+ return failure();
+
+ assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
+ baseStrides.size() == resultStrides.size() &&
+ "base and result memrefs must have the same rank");
+ assert(!llvm::is_contained(resultStrides,
+ MemRefType::getDynamicStrideOrOffset()) &&
+ "strides of subview op must be static, when there are no dynamic "
+ "strides specified");
+ staticStrides.resize(getType().getRank());
+ for (auto resultStride : enumerate(resultStrides)) {
+ auto baseStride = baseStrides[resultStride.index()];
+ // The result stride is expected to be a multiple of the base stride. Abort
+ // if that is not the case.
+ if (resultStride.value() < baseStride ||
+ resultStride.value() % baseStride != 0)
+ return failure();
+ staticStrides[resultStride.index()] = resultStride.value() / baseStride;
+ }
+ return success();
+}
+
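+// Returns true if 'memrefType' has a static shape and a strided layout with
+// static offset and strides.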
+static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
+ if (memrefType.getNumDynamicDims() > 0)
+ return false;
+ // Get offset and strides.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(memrefType, strides, offset)))
+ return false;
+ // Return 'false' if any of offset or strides is dynamic.
+ if (offset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
+ return false;
+ return true;
+}
+
+namespace {
+
+/// Pattern to rewrite a subview op with constant size arguments.
+class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ MemRefType subViewType = subViewOp.getType();
+    // Follow an all-or-nothing approach for shapes for now. If all the size
+    // operands are constants, fold them into the type of the result memref.
+ if (subViewType.hasStaticShape() ||
+ llvm::any_of(subViewOp.sizes(), [](Value operand) {
+ return !matchPattern(operand, m_ConstantIndex());
+ })) {
+ return matchFailure();
+ }
+ SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
+ for (auto size : llvm::enumerate(subViewOp.sizes())) {
+ auto defOp = size.value()->getDefiningOp();
+ assert(defOp);
+ staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
+ }
+ MemRefType newMemRefType = MemRefType::get(
+ staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
+ subViewType.getMemorySpace());
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(
+ subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType());
+ return matchSuccess();
+ }
+};
+
+// Pattern to rewrite a subview op with constant stride arguments.
+class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ if (subViewOp.getNumStrides() == 0) {
+ return matchFailure();
+ }
+    // Follow an all-or-nothing approach for strides for now. If all the
+    // stride operands are constants, fold them into the strides of the result
+    // memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ llvm::any_of(subViewOp.strides(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
+ return matchFailure();
+ }
+
+ SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
+ for (auto stride : llvm::enumerate(subViewOp.strides())) {
+ auto defOp = stride.value()->getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[stride.index()] > 0);
+ staticStrides[stride.index()] =
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()];
+ }
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ staticStrides, resultOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
+ layoutMap, subViewType.getMemorySpace());
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(
+ subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType());
+ return matchSuccess();
+ }
+};
+
+// Pattern to rewrite a subview op with constant offset arguments.
+class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ if (subViewOp.getNumOffsets() == 0) {
+ return matchFailure();
+ }
+    // Follow an all-or-nothing approach for offsets for now. If all the
+    // offset operands are constants, fold them into the offset of the result
+    // memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::any_of(subViewOp.offsets(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
+ return matchFailure();
+ }
+
+ auto staticOffset = baseOffset;
+ for (auto offset : llvm::enumerate(subViewOp.offsets())) {
+ auto defOp = offset.value()->getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[offset.index()] > 0);
+ staticOffset +=
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()];
+ }
+
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ resultStrides, staticOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
+ layoutMap, subViewType.getMemorySpace());
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
+ subViewOp.sizes(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(
+ subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType());
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
+ SubViewOpOffsetFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// ZeroExtendIOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ZeroExtendIOp op) {
+ auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto dstType = getElementTypeOrSelf(op.getType());
+
+ if (srcType.isa<IndexType>())
+ return op.emitError() << srcType << " is not a valid operand type";
+ if (dstType.isa<IndexType>())
+ return op.emitError() << dstType << " is not a valid result type";
+
+ if (srcType.cast<IntegerType>().getWidth() >=
+ dstType.cast<IntegerType>().getWidth())
+ return op.emitError("result type ")
+ << dstType << " must be wider than operand type " << srcType;
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FPExtOp
+//===----------------------------------------------------------------------===//
+
+bool FPExtOp::areCastCompatible(Type a, Type b) {
+ if (auto fa = a.dyn_cast<FloatType>())
+ if (auto fb = b.dyn_cast<FloatType>())
+ return fa.getWidth() < fb.getWidth();
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// FPTruncOp
+//===----------------------------------------------------------------------===//
+
+bool FPTruncOp::areCastCompatible(Type a, Type b) {
+ if (auto fa = a.dyn_cast<FloatType>())
+ if (auto fb = b.dyn_cast<FloatType>())
+ return fa.getWidth() > fb.getWidth();
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/StandardOps/Ops.cpp.inc"