diff options
Diffstat (limited to 'mlir/lib/Dialect/StandardOps/Ops.cpp')
-rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 3000 |
1 files changed, 3000 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp new file mode 100644 index 00000000000..831c78a4521 --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -0,0 +1,3000 @@ +//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Ops.h" + +#include "mlir/Dialect/CommonFolders.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" + +// Pull in all enum type definitions and utility function declarations. +#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// StandardOpsDialect Interfaces +//===----------------------------------------------------------------------===// +namespace { +/// This class defines the interface for handling inlining with standard +/// operations. +struct StdInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within standard ops can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only "std.return" needs to be handled here. + auto returnOp = dyn_cast<ReturnOp>(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, + ArrayRef<Value> valuesToRepl) const final { + // Only "std.return" needs to be handled here. + auto returnOp = cast<ReturnOp>(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// StandardOpsDialect +//===----------------------------------------------------------------------===// + +/// A custom unary operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumOperands() == 1 && "unary op should have one operand"); + assert(op->getNumResults() == 1 && "unary op should have one result"); + + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' + << *op->getOperand(0); + p.printOptionalAttrDict(op->getAttrs()); + p << " : " << op->getOperand(0)->getType(); +} + +/// A custom binary operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumOperands() == 2 && "binary op should have two operands"); + assert(op->getNumResults() == 1 && "binary op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing. + auto resultType = op->getResult(0)->getType(); + if (op->getOperand(0)->getType() != resultType || + op->getOperand(1)->getType() != resultType) { + p.printGenericOp(op); + return; + } + + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' + << *op->getOperand(0) << ", " << *op->getOperand(1); + p.printOptionalAttrDict(op->getAttrs()); + + // Now we can output only one type for all operands and the result. + p << " : " << op->getResult(0)->getType(); +} + +/// A custom cast operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' + << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " + << op->getResult(0)->getType(); +} + +/// A custom cast operation verifier. +template <typename T> static LogicalResult verifyCastOp(T op) { + auto opType = op.getOperand()->getType(); + auto resType = op.getType(); + if (!T::areCastCompatible(opType, resType)) + return op.emitError("operand type ") << opType << " and result type " + << resType << " are cast incompatible"; + + return success(); +} + +StandardOpsDialect::StandardOpsDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations<DmaStartOp, DmaWaitOp, +#define GET_OP_LIST +#include "mlir/Dialect/StandardOps/Ops.cpp.inc" + >(); + addInterfaces<StdInlinerInterface>(); +} + +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create<ConstantOp>(loc, type, value); +} + +void mlir::printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter &p) { + Operation::operand_range operands(begin, end); + p << '(' << operands.take_front(numDims) << ')'; + if (operands.size() != numDims) + p << '[' << operands.drop_front(numDims) << ']'; +} + +// Parses dimension and symbol list, and sets 'numDims' to the number of +// dimension operands parsed. +// Returns 'false' on success and 'true' on error. +ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, + SmallVectorImpl<Value> &operands, + unsigned &numDims) { + SmallVector<OpAsmParser::OperandType, 8> opInfos; + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + return failure(); + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto indexTy = parser.getBuilder().getIndexType(); + if (parser.parseOperandList(opInfos, + OpAsmParser::Delimiter::OptionalSquare) || + parser.resolveOperands(opInfos, indexTy, operands)) + return failure(); + return success(); +} + +/// Matches a ConstantIndexOp. +/// TODO: This should probably just be a general matcher that uses m_Constant +/// and checks the operation for an index type. +static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { + return detail::op_matcher<ConstantIndexOp>(); +} + +//===----------------------------------------------------------------------===// +// Common canonicalization pattern support logic +//===----------------------------------------------------------------------===// + +/// This is a common class used for patterns of the form +/// "someop(memrefcast) -> someop". It folds the source of any memref_cast +/// into the root operation directly. +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp()); + if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) { + operand.set(cast.getOperand()); + folded = true; + } + } + return success(folded); +} + +//===----------------------------------------------------------------------===// +// AddFOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) { + return constFoldBinaryOp<FloatAttr>( + operands, [](APFloat a, APFloat b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// AddIOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) { + /// addi(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, AllocOp op) { + p << "alloc"; + + // Print dynamic dimension operands. + MemRefType type = op.getType(); + printDimAndSymbolList(op.operand_begin(), op.operand_end(), + type.getNumDynamicDims(), p); + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); + p << " : " << type; +} + +static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) { + MemRefType type; + + // Parse the dimension operands and optional symbol operands, followed by a + // memref type. + unsigned numDimOperands; + if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return failure(); + + // Check numDynamicDims against number of question marks in memref type. + // Note: this check remains here (instead of in verify()), because the + // partition between dim operands and symbol operands is lost after parsing. + // Verification still checks that the total number of operands matches + // the number of symbols in the affine map, plus the number of dynamic + // dimensions in the memref. + if (numDimOperands != type.getNumDynamicDims()) + return parser.emitError(parser.getNameLoc()) + << "dimension operand count does not equal memref dynamic dimension " + "count"; + result.types.push_back(type); + return success(); +} + +static LogicalResult verify(AllocOp op) { + auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>(); + if (!memRefType) + return op.emitOpError("result must be a memref"); + + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) { + // Store number of symbols used in affine map (used in subsequent check). + AffineMap affineMap = memRefType.getAffineMaps()[0]; + numSymbols = affineMap.getNumSymbols(); + } + + // Check that the total number of operands matches the number of symbols in + // the affine map, plus the number of dynamic dimensions specified in the + // memref type. + unsigned numDynamicDims = memRefType.getNumDynamicDims(); + if (op.getNumOperands() != numDynamicDims + numSymbols) + return op.emitOpError( + "operand count does not equal dimension plus symbol operand count"); + + // Verify that all operands are of type Index. + for (auto operandType : op.getOperandTypes()) + if (!operandType.isIndex()) + return op.emitOpError("requires operands to be of type Index"); + return success(); +} + +namespace { +/// Fold constant dimensions into an alloc operation. +struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { + using OpRewritePattern<AllocOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { + // Check to see if any dimensions operands are constants. If so, we can + // substitute and drop them. + if (llvm::none_of(alloc.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return matchFailure(); + + auto memrefType = alloc.getType(); + + // Ok, we have one or more constant operands. Collect the non-constant ones + // and keep track of the resultant memref type to build. + SmallVector<int64_t, 4> newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + SmallVector<Value, 4> newOperands; + SmallVector<Value, 4> droppedOperands; + + unsigned dynamicDimPos = 0; + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (dimSize != -1) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp.getValue()); + // Record to check for zero uses later below. + droppedOperands.push_back(constantIndexOp); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(-1); + newOperands.push_back(alloc.getOperand(dynamicDimPos)); + } + dynamicDimPos++; + } + + // Create new memref type (which will have fewer dynamic dimensions). + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(static_cast<int64_t>(newOperands.size()) == + newMemRefType.getNumDynamicDims()); + + // Create and insert the alloc op for the new memref. + auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, + newOperands, IntegerAttr()); + // Insert a cast so we have the same type as the old alloc. + auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, + alloc.getType()); + + rewriter.replaceOp(alloc, {resultCast}, droppedOperands); + return matchSuccess(); + } +}; + +/// Fold alloc operations with no uses. Alloc has side effects on the heap, +/// but can still be deleted if it has zero uses. +struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { + using OpRewritePattern<AllocOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { + if (alloc.use_empty()) { + rewriter.eraseOp(alloc); + return matchSuccess(); + } + return matchFailure(); + } +}; +} // end anonymous namespace. + +void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context); +} + +//===----------------------------------------------------------------------===// +// BranchOp +//===----------------------------------------------------------------------===// + +namespace { +/// Simplify a branch to a block that has a single predecessor. This effectively +/// merges the two blocks. +struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> { + using OpRewritePattern<BranchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(BranchOp op, + PatternRewriter &rewriter) const override { + // Check that the successor block has a single predecessor. + Block *succ = op.getDest(); + Block *opParent = op.getOperation()->getBlock(); + if (succ == opParent || !has_single_element(succ->getPredecessors())) + return matchFailure(); + + // Merge the successor into the current block and erase the branch. + rewriter.mergeBlocks(succ, opParent, op.getOperands()); + rewriter.eraseOp(op); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { + Block *dest; + SmallVector<Value, 4> destOperands; + if (parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + return success(); +} + +static void print(OpAsmPrinter &p, BranchOp op) { + p << "br "; + p.printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getSuccessor(0); } + +void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SimplifyBrToBlockWithSinglePred>(context); +} + +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { + FlatSymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", result.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), result.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + result.operands)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &p, CallOp op) { + p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCalleeType(); +} + +static LogicalResult verify(CallOp op) { + // Check that the callee attribute was specified. + auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee"); + if (!fnAttr) + return op.emitOpError("requires a 'callee' symbol reference attribute"); + auto fn = + op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); + if (!fn) + return op.emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getType(); + if (fnType.getNumInputs() != op.getNumOperands()) + return op.emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (op.getOperand(i)->getType() != fnType.getInput(i)) + return op.emitOpError("operand type mismatch"); + + if (fnType.getNumResults() != op.getNumResults()) + return op.emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (op.getResult(i)->getType() != fnType.getResult(i)) + return op.emitOpError("result type mismatch"); + + return success(); +} + +FunctionType CallOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// CallIndirectOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold indirect calls that have a constant function as the callee operand. +struct SimplifyIndirectCallWithKnownCallee + : public OpRewritePattern<CallIndirectOp> { + using OpRewritePattern<CallIndirectOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, + PatternRewriter &rewriter) const override { + // Check that the callee is a constant callee. + SymbolRefAttr calledFn; + if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) + return matchFailure(); + + // Replace with a direct call. + SmallVector<Type, 8> callResults(indirectCall.getResultTypes()); + rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults, + indirectCall.getArgOperands()); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static ParseResult parseCallIndirectOp(OpAsmParser &parser, + OperationState &result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector<OpAsmParser::OperandType, 4> operands; + return failure( + parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.resolveOperand(callee, calleeType, result.operands) || + parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result.operands) || + parser.addTypesToList(calleeType.getResults(), result.types)); +} + +static void print(OpAsmPrinter &p, CallIndirectOp op) { + p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCallee()->getType(); +} + +static LogicalResult verify(CallIndirectOp op) { + // The callee must be a function. + auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>(); + if (!fnType) + return op.emitOpError("callee must have function type"); + + // Verify that the operand and result types match the callee. + if (fnType.getNumInputs() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) + return op.emitOpError("operand type mismatch"); + + if (fnType.getNumResults() != op.getNumResults()) + return op.emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (op.getResult(i)->getType() != fnType.getResult(i)) + return op.emitOpError("result type mismatch"); + + return success(); +} + +void CallIndirectOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<SimplifyIndirectCallWithKnownCallee>(context); +} + +//===----------------------------------------------------------------------===// +// General helpers for comparison ops +//===----------------------------------------------------------------------===// + +// Return the type of the same shape (scalar, vector or tensor) containing i1. +static Type getCheckedI1SameShape(Builder *build, Type type) { + auto i1Type = build->getI1Type(); + if (type.isIntOrIndexOrFloat()) + return i1Type; + if (auto tensorType = type.dyn_cast<RankedTensorType>()) + return RankedTensorType::get(tensorType.getShape(), i1Type); + if (type.isa<UnrankedTensorType>()) + return UnrankedTensorType::get(i1Type); + if (auto vectorType = type.dyn_cast<VectorType>()) + return VectorType::get(vectorType.getShape(), i1Type); + return Type(); +} + +static Type getI1SameShape(Builder *build, Type type) { + Type res = getCheckedI1SameShape(build, type); + assert(res && "expected type with valid i1 shape"); + return res; +} + +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +static void buildCmpIOp(Builder *build, OperationState &result, + CmpIPredicate predicate, Value lhs, Value rhs) { + result.addOperands({lhs, rhs}); + result.types.push_back(getI1SameShape(build, lhs->getType())); + result.addAttribute( + CmpIOp::getPredicateAttrName(), + build->getI64IntegerAttr(static_cast<int64_t>(predicate))); +} + +static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 2> ops; + SmallVector<NamedAttribute, 4> attrs; + Attribute predicateNameAttr; + Type type; + if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), + attrs) || + parser.parseComma() || parser.parseOperandList(ops, 2) || + parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands)) + return failure(); + + if (!predicateNameAttr.isa<StringAttr>()) + return parser.emitError(parser.getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); + Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName); + if (!predicate.hasValue()) + return parser.emitError(parser.getNameLoc()) + << "unknown comparison predicate \"" << predicateName << "\""; + + auto builder = parser.getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser.emitError(parser.getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate)); + result.attributes = attrs; + + result.addTypes({i1Type}); + return success(); +} + +static void print(OpAsmPrinter &p, CmpIOp op) { + p << "cmpi "; + + Builder b(op.getContext()); + auto predicateValue = + op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); + p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue)) + << '"' << ", " << op.lhs() << ", " << op.rhs(); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); + p << " : " << op.lhs()->getType(); +} + +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer +// comparison predicates. +static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, + const APInt &rhs) { + switch (predicate) { + case CmpIPredicate::eq: + return lhs.eq(rhs); + case CmpIPredicate::ne: + return lhs.ne(rhs); + case CmpIPredicate::slt: + return lhs.slt(rhs); + case CmpIPredicate::sle: + return lhs.sle(rhs); + case CmpIPredicate::sgt: + return lhs.sgt(rhs); + case CmpIPredicate::sge: + return lhs.sge(rhs); + case CmpIPredicate::ult: + return lhs.ult(rhs); + case CmpIPredicate::ule: + return lhs.ule(rhs); + case CmpIPredicate::ugt: + return lhs.ugt(rhs); + case CmpIPredicate::uge: + return lhs.uge(rhs); + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "cmpi takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); + auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); + if (!lhs || !rhs) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +// Returns an array of mnemonics for CmpFPredicates indexed by values thereof. +static inline const char *const *getCmpFPredicateNames() { + static const char *predicateNames[] = { + /*AlwaysFalse*/ "false", + /*OEQ*/ "oeq", + /*OGT*/ "ogt", + /*OGE*/ "oge", + /*OLT*/ "olt", + /*OLE*/ "ole", + /*ONE*/ "one", + /*ORD*/ "ord", + /*UEQ*/ "ueq", + /*UGT*/ "ugt", + /*UGE*/ "uge", + /*ULT*/ "ult", + /*ULE*/ "ule", + /*UNE*/ "une", + /*UNO*/ "uno", + /*AlwaysTrue*/ "true", + }; + static_assert(std::extent<decltype(predicateNames)>::value == + (size_t)CmpFPredicate::NumPredicates, + "wrong number of predicate names"); + return predicateNames; +} + +// Returns a value of the predicate corresponding to the given mnemonic. +// Returns NumPredicates (one-past-end) if there is no such mnemonic. +CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { + return llvm::StringSwitch<CmpFPredicate>(name) + .Case("false", CmpFPredicate::AlwaysFalse) + .Case("oeq", CmpFPredicate::OEQ) + .Case("ogt", CmpFPredicate::OGT) + .Case("oge", CmpFPredicate::OGE) + .Case("olt", CmpFPredicate::OLT) + .Case("ole", CmpFPredicate::OLE) + .Case("one", CmpFPredicate::ONE) + .Case("ord", CmpFPredicate::ORD) + .Case("ueq", CmpFPredicate::UEQ) + .Case("ugt", CmpFPredicate::UGT) + .Case("uge", CmpFPredicate::UGE) + .Case("ult", CmpFPredicate::ULT) + .Case("ule", CmpFPredicate::ULE) + .Case("une", CmpFPredicate::UNE) + .Case("uno", CmpFPredicate::UNO) + .Case("true", CmpFPredicate::AlwaysTrue) + .Default(CmpFPredicate::NumPredicates); +} + +static void buildCmpFOp(Builder *build, OperationState &result, + CmpFPredicate predicate, Value lhs, Value rhs) { + result.addOperands({lhs, rhs}); + result.types.push_back(getI1SameShape(build, lhs->getType())); + result.addAttribute( + CmpFOp::getPredicateAttrName(), + build->getI64IntegerAttr(static_cast<int64_t>(predicate))); +} + +static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 2> ops; + SmallVector<NamedAttribute, 4> attrs; + Attribute predicateNameAttr; + Type type; + if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), + attrs) || + parser.parseComma() || parser.parseOperandList(ops, 2) || + parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands)) + return failure(); + + if (!predicateNameAttr.isa<StringAttr>()) + return parser.emitError(parser.getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); + auto predicate = CmpFOp::getPredicateByName(predicateName); + if (predicate == CmpFPredicate::NumPredicates) + return parser.emitError(parser.getNameLoc(), + "unknown comparison predicate \"" + predicateName + + "\""); + + auto builder = parser.getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser.emitError(parser.getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); + result.attributes = attrs; + + result.addTypes({i1Type}); + return success(); +} + +static void print(OpAsmPrinter &p, CmpFOp op) { + p << "cmpf "; + + auto predicateValue = + op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt(); + assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && + predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && + "unknown predicate index"); + p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs() + << ", " << op.rhs(); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); + p << " : " << op.lhs()->getType(); +} + +static LogicalResult verify(CmpFOp op) { + auto predicateAttr = + op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); + if (!predicateAttr) + return op.emitOpError("requires an integer attribute named 'predicate'"); + auto predicate = predicateAttr.getInt(); + if (predicate < (int64_t)CmpFPredicate::FirstValidValue || + predicate >= (int64_t)CmpFPredicate::NumPredicates) + return op.emitOpError("'predicate' attribute value out of range"); + + return success(); +} + +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point +// comparison predicates. +static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, + const APFloat &rhs) { + auto cmpResult = lhs.compare(rhs); + switch (predicate) { + case CmpFPredicate::AlwaysFalse: + return false; + case CmpFPredicate::OEQ: + return cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OGT: + return cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::OGE: + return cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OLT: + return cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::OLE: + return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ONE: + return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; + case CmpFPredicate::ORD: + return cmpResult != APFloat::cmpUnordered; + case CmpFPredicate::UEQ: + return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UGT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::UGE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ULT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::ULE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UNE: + return cmpResult != APFloat::cmpEqual; + case CmpFPredicate::UNO: + return cmpResult == APFloat::cmpUnordered; + case CmpFPredicate::AlwaysTrue: + return true; + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "cmpf takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); + auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); + + // TODO(gcmn) We could actually do some intelligent things if we know only one + // of the operands, but it's inf or nan. + if (!lhs || !rhs) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// +// CondBranchOp +//===----------------------------------------------------------------------===// + +namespace { +/// cond_br true, ^bb1, ^bb2 -> br ^bb1 +/// cond_br false, ^bb1, ^bb2 -> br ^bb2 +/// +struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + if (matchPattern(condbr.getCondition(), m_NonZero())) { + // True branch taken. + rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), + condbr.getTrueOperands()); + return matchSuccess(); + } else if (matchPattern(condbr.getCondition(), m_Zero())) { + // False branch taken. + rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), + condbr.getFalseOperands()); + return matchSuccess(); + } + return matchFailure(); + } +}; +} // end anonymous namespace. + +static ParseResult parseCondBranchOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<Value, 4> destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser.getBuilder().getI1Type(); + if (parser.parseOperand(condInfo) || parser.parseComma() || + parser.resolveOperand(condInfo, int1Ty, result.operands)) { + return parser.emitError(parser.getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + return success(); +} + +static void print(OpAsmPrinter &p, CondBranchOp op) { + p << "cond_br " << op.getCondition() << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +void CondBranchOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<SimplifyConstCondBranchPred>(context); +} + +//===----------------------------------------------------------------------===// +// Constant*Op +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, ConstantOp &op) { + p << "constant "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); + + if (op.getAttrs().size() > 1) + p << ' '; + p << op.getValue(); + + // If the value is a symbol reference, print a trailing type. + if (op.getValue().isa<SymbolRefAttr>()) + p << " : " << op.getType(); +} + +static ParseResult parseConstantOp(OpAsmParser &parser, + OperationState &result) { + Attribute valueAttr; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(valueAttr, "value", result.attributes)) + return failure(); + + // If the attribute is a symbol reference, then we expect a trailing type. + Type type; + if (!valueAttr.isa<SymbolRefAttr>()) + type = valueAttr.getType(); + else if (parser.parseColonType(type)) + return failure(); + + // Add the attribute type to the list. + return parser.addTypeToList(type, result.types); +} + +/// The constant op requires an attribute, and furthermore requires that it +/// matches the return type. +static LogicalResult verify(ConstantOp &op) { + auto value = op.getValue(); + if (!value) + return op.emitOpError("requires a 'value' attribute"); + + auto type = op.getType(); + if (!value.getType().isa<NoneType>() && type != value.getType()) + return op.emitOpError() << "requires attribute's type (" << value.getType() + << ") to match op's return type (" << type << ")"; + + if (type.isa<IndexType>() || value.isa<BoolAttr>()) + return success(); + + if (auto intAttr = value.dyn_cast<IntegerAttr>()) { + // If the type has a known bitwidth we verify that the value can be + // represented with the given bitwidth. + auto bitwidth = type.cast<IntegerType>().getWidth(); + auto intVal = intAttr.getValue(); + if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) + return op.emitOpError("requires 'value' to be an integer within the " + "range of the integer result type"); + return success(); + } + + if (type.isa<FloatType>()) { + if (!value.isa<FloatAttr>()) + return op.emitOpError("requires 'value' to be a floating point constant"); + return success(); + } + + if (type.isa<ShapedType>()) { + if (!value.isa<ElementsAttr>()) + return op.emitOpError("requires 'value' to be a shaped constant"); + return success(); + } + + if (type.isa<FunctionType>()) { + auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>(); + if (!fnAttr) + return op.emitOpError("requires 'value' to be a function reference"); + + // Try to find the referenced function. + auto fn = + op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); + if (!fn) + return op.emitOpError("reference to undefined function 'bar'"); + + // Check that the referenced function has the correct type. + if (fn.getType() != type) + return op.emitOpError("reference to function with mismatched type"); + + return success(); + } + + if (type.isa<NoneType>() && value.isa<UnitAttr>()) + return success(); + + return op.emitOpError("unsupported 'value' attribute: ") << value; +} + +OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +void ConstantOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + Type type = getType(); + if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { + IntegerType intTy = type.dyn_cast<IntegerType>(); + + // Sugar i1 constants with 'true' and 'false'. + if (intTy && intTy.getWidth() == 1) + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + + // Otherwise, build a complex name with the value and type. + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst.getInt(); + if (intTy) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + + } else if (type.isa<FunctionType>()) { + setNameFn(getResult(), "f"); + } else { + setNameFn(getResult(), "cst"); + } +} + +/// Returns true if a constant operation can be built with the given value and +/// result type. +bool ConstantOp::isBuildableWith(Attribute value, Type type) { + // SymbolRefAttr can only be used with a function type. + if (value.isa<SymbolRefAttr>()) + return type.isa<FunctionType>(); + // Otherwise, the attribute must have the same type as 'type'. + if (value.getType() != type) + return false; + // Finally, check that the attribute kind is handled. + return value.isa<BoolAttr>() || value.isa<IntegerAttr>() || + value.isa<FloatAttr>() || value.isa<ElementsAttr>() || + value.isa<UnitAttr>(); +} + +void ConstantFloatOp::build(Builder *builder, OperationState &result, + const APFloat &value, FloatType type) { + ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); +} + +bool ConstantFloatOp::classof(Operation *op) { + return ConstantOp::classof(op) && + op->getResult(0)->getType().isa<FloatType>(); +} + +/// ConstantIntOp only matches values whose result type is an IntegerType. +bool ConstantIntOp::classof(Operation *op) { + return ConstantOp::classof(op) && + op->getResult(0)->getType().isa<IntegerType>(); +} + +void ConstantIntOp::build(Builder *builder, OperationState &result, + int64_t value, unsigned width) { + Type type = builder->getIntegerType(width); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// Build a constant int op producing an integer with the specified type, +/// which must be an integer type. +void ConstantIntOp::build(Builder *builder, OperationState &result, + int64_t value, Type type) { + assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +/// ConstantIndexOp only matches values whose result type is Index. +bool ConstantIndexOp::classof(Operation *op) { + return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); +} + +void ConstantIndexOp::build(Builder *builder, OperationState &result, + int64_t value) { + Type type = builder->getIndexType(); + ConstantOp::build(builder, result, type, + builder->getIntegerAttr(type, value)); +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold Dealloc operations that are deallocating an AllocOp that is only used +/// by other Dealloc operations. +struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { + using OpRewritePattern<DeallocOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(DeallocOp dealloc, + PatternRewriter &rewriter) const override { + // Check that the memref operand's defining operation is an AllocOp. + Value memref = dealloc.memref(); + if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp())) + return matchFailure(); + + // Check that all of the uses of the AllocOp are other DeallocOps. + for (auto *user : memref->getUsers()) + if (!isa<DeallocOp>(user)) + return matchFailure(); + + // Erase the dealloc operation. + rewriter.eraseOp(dealloc); + return matchSuccess(); + } +}; +} // end anonymous namespace. + +static void print(OpAsmPrinter &p, DeallocOp op) { + p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); +} + +static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType memrefInfo; + MemRefType type; + + return failure(parser.parseOperand(memrefInfo) || + parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands)); +} + +static LogicalResult verify(DeallocOp op) { + if (!op.memref()->getType().isa<MemRefType>()) + return op.emitOpError("operand must be a memref"); + return success(); +} + +void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SimplifyDeadDealloc>(context); +} + +LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// dealloc(memrefcast) -> dealloc + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// DimOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, DimOp op) { + p << "dim " << *op.getOperand() << ", " << op.getIndex(); + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); + p << " : " << op.getOperand()->getType(); +} + +static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType operandInfo; + IntegerAttr indexAttr; + Type type; + Type indexType = parser.getBuilder().getIndexType(); + + return failure( + parser.parseOperand(operandInfo) || parser.parseComma() || + parser.parseAttribute(indexAttr, indexType, "index", result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperand(operandInfo, type, result.operands) || + parser.addTypeToList(indexType, result.types)); +} + +static LogicalResult verify(DimOp op) { + // Check that we have an integer index operand. + auto indexAttr = op.getAttrOfType<IntegerAttr>("index"); + if (!indexAttr) + return op.emitOpError("requires an integer attribute named 'index'"); + int64_t index = indexAttr.getValue().getSExtValue(); + + auto type = op.getOperand()->getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + if (index >= tensorType.getRank()) + return op.emitOpError("index is out of range"); + } else if (auto memrefType = type.dyn_cast<MemRefType>()) { + if (index >= memrefType.getRank()) + return op.emitOpError("index is out of range"); + + } else if (type.isa<UnrankedTensorType>()) { + // ok, assumed to be in-range. + } else { + return op.emitOpError("requires an operand with tensor or memref type"); + } + + return success(); +} + +OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { + // Constant fold dim when the size along the index referred to is a constant. + auto opType = memrefOrTensor()->getType(); + int64_t indexSize = -1; + if (auto tensorType = opType.dyn_cast<RankedTensorType>()) + indexSize = tensorType.getShape()[getIndex()]; + else if (auto memrefType = opType.dyn_cast<MemRefType>()) + indexSize = memrefType.getShape()[getIndex()]; + + if (!ShapedType::isDynamic(indexSize)) + return IntegerAttr::get(IndexType::get(getContext()), indexSize); + + // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp. + auto memrefType = opType.dyn_cast<MemRefType>(); + if (!memrefType) + return {}; + + // The size at getIndex() is now a dynamic size of a memref. + auto memref = memrefOrTensor()->getDefiningOp(); + if (auto alloc = dyn_cast_or_null<AllocOp>(memref)) + return *(alloc.getDynamicSizes().begin() + + memrefType.getDynamicDimIndex(getIndex())); + + if (auto view = dyn_cast_or_null<ViewOp>(memref)) + return *(view.getDynamicSizes().begin() + + memrefType.getDynamicDimIndex(getIndex())); + + // The subview op here is expected to have rank dynamic sizes now. + if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) { + auto sizes = subview.sizes(); + if (!sizes.empty()) + return *(sizes.begin() + getIndex()); + } + + /// dim(memrefcast) -> dim + if (succeeded(foldMemRefCast(*this))) + return getResult(); + + return {}; +} + +//===----------------------------------------------------------------------===// +// SignedDivIOp +//===----------------------------------------------------------------------===// + +OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // Don't fold if it would overflow or if it requires a division by zero. + bool overflowOrDiv0 = false; + auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { + if (overflowOrDiv0 || !b) { + overflowOrDiv0 = true; + return a; + } + return a.sdiv_ov(b, overflowOrDiv0); + }); + return overflowOrDiv0 ? Attribute() : result; +} + +//===----------------------------------------------------------------------===// +// UnsignedDivIOp +//===----------------------------------------------------------------------===// + +OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // Don't fold if it would require a division by zero. + bool div0 = false; + auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { + if (div0 || !b) { + div0 = true; + return a; + } + return a.udiv(b); + }); + return div0 ? Attribute() : result; +} + +// --------------------------------------------------------------------------- +// DmaStartOp +// --------------------------------------------------------------------------- + +void DmaStartOp::build(Builder *builder, OperationState &result, + Value srcMemRef, ValueRange srcIndices, Value destMemRef, + ValueRange destIndices, Value numElements, + Value tagMemRef, ValueRange tagIndices, Value stride, + Value elementsPerStride) { + result.addOperands(srcMemRef); + result.addOperands(srcIndices); + result.addOperands(destMemRef); + result.addOperands(destIndices); + result.addOperands({numElements, tagMemRef}); + result.addOperands(tagIndices); + if (stride) + result.addOperands({stride, elementsPerStride}); +} + +void DmaStartOp::print(OpAsmPrinter &p) { + p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], " + << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements() + << ", " << *getTagMemRef() << '[' << getTagIndices() << ']'; + if (isStrided()) + p << ", " << *getStride() << ", " << *getNumElementsPerStride(); + + p.printOptionalAttrDict(getAttrs()); + p << " : " << getSrcMemRef()->getType(); + p << ", " << getDstMemRef()->getType(); + p << ", " << getTagMemRef()->getType(); +} + +// Parse DmaStartOp. +// Ex: +// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, +// %tag[%index], %stride, %num_elt_per_stride : +// : memref<3076 x f32, 0>, +// memref<1024 x f32, 2>, +// memref<1 x i32> +// +ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcMemRefInfo; + SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; + OpAsmParser::OperandType dstMemRefInfo; + SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; + OpAsmParser::OperandType numElementsInfo; + OpAsmParser::OperandType tagMemrefInfo; + SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; + SmallVector<OpAsmParser::OperandType, 2> strideInfo; + + SmallVector<Type, 3> types; + auto indexType = parser.getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) source memref followed by its indices (in square brackets). + // *) destination memref followed by its indices (in square brackets). + // *) dma size in KiB. + if (parser.parseOperand(srcMemRefInfo) || + parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseOperand(dstMemRefInfo) || + parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseOperand(numElementsInfo) || + parser.parseComma() || parser.parseOperand(tagMemrefInfo) || + parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) + return failure(); + + // Parse optional stride and elements per stride. + if (parser.parseTrailingOperandList(strideInfo)) + return failure(); + + bool isStrided = strideInfo.size() == 2; + if (!strideInfo.empty() && !isStrided) { + return parser.emitError(parser.getNameLoc(), + "expected two stride related operands"); + } + + if (parser.parseColonTypeList(types)) + return failure(); + if (types.size() != 3) + return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); + + if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || + parser.resolveOperands(srcIndexInfos, indexType, result.operands) || + parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || + parser.resolveOperands(dstIndexInfos, indexType, result.operands) || + // size should be an index. + parser.resolveOperand(numElementsInfo, indexType, result.operands) || + parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || + // tag indices should be index. + parser.resolveOperands(tagIndexInfos, indexType, result.operands)) + return failure(); + + auto memrefType0 = types[0].dyn_cast<MemRefType>(); + if (!memrefType0) + return parser.emitError(parser.getNameLoc(), + "expected source to be of memref type"); + + auto memrefType1 = types[1].dyn_cast<MemRefType>(); + if (!memrefType1) + return parser.emitError(parser.getNameLoc(), + "expected destination to be of memref type"); + + auto memrefType2 = types[2].dyn_cast<MemRefType>(); + if (!memrefType2) + return parser.emitError(parser.getNameLoc(), + "expected tag to be of memref type"); + + if (isStrided) { + if (parser.resolveOperands(strideInfo, indexType, result.operands)) + return failure(); + } + + // Check that source/destination index list size matches associated rank. + if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() || + static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank()) + return parser.emitError(parser.getNameLoc(), + "memref rank not equal to indices count"); + if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank()) + return parser.emitError(parser.getNameLoc(), + "tag memref rank not equal to indices count"); + + return success(); +} + +LogicalResult DmaStartOp::verify() { + // DMAs from different memory spaces supported. + if (getSrcMemorySpace() == getDstMemorySpace()) + return emitOpError("DMA should be between different memory spaces"); + + if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 && + getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + + getDstMemRefRank() + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + return success(); +} + +LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// dma_start(memrefcast) -> dma_start + return foldMemRefCast(*this); +} + +// --------------------------------------------------------------------------- +// DmaWaitOp +// --------------------------------------------------------------------------- + +void DmaWaitOp::build(Builder *builder, OperationState &result, Value tagMemRef, + ValueRange tagIndices, Value numElements) { + result.addOperands(tagMemRef); + result.addOperands(tagIndices); + result.addOperands(numElements); +} + +void DmaWaitOp::print(OpAsmPrinter &p) { + p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " + << getNumElements(); + p.printOptionalAttrDict(getAttrs()); + p << " : " << getTagMemRef()->getType(); +} + +// Parse DmaWaitOp. +// Eg: +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> +// +ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType tagMemrefInfo; + SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; + Type type; + auto indexType = parser.getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its indices, and dma size. + if (parser.parseOperand(tagMemrefInfo) || + parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseOperand(numElementsInfo) || + parser.parseColonType(type) || + parser.resolveOperand(tagMemrefInfo, type, result.operands) || + parser.resolveOperands(tagIndexInfos, indexType, result.operands) || + parser.resolveOperand(numElementsInfo, indexType, result.operands)) + return failure(); + + auto memrefType = type.dyn_cast<MemRefType>(); + if (!memrefType) + return parser.emitError(parser.getNameLoc(), + "expected tag to be of memref type"); + + if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank()) + return parser.emitError(parser.getNameLoc(), + "tag memref rank not equal to indices count"); + + return success(); +} + +LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// dma_wait(memrefcast) -> dma_wait + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, ExtractElementOp op) { + p << "extract_element " << *op.getAggregate() << '[' << op.getIndices(); + p << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getAggregate()->getType(); +} + +static ParseResult parseExtractElementOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType aggregateInfo; + SmallVector<OpAsmParser::OperandType, 4> indexInfo; + ShapedType type; + + auto indexTy = parser.getBuilder().getIndexType(); + return failure( + parser.parseOperand(aggregateInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperand(aggregateInfo, type, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands) || + parser.addTypeToList(type.getElementType(), result.types)); +} + +static LogicalResult verify(ExtractElementOp op) { + auto aggregateType = op.getAggregate()->getType().cast<ShapedType>(); + + // This should be possible with tablegen type constraints + if (op.getType() != aggregateType.getElementType()) + return op.emitOpError("result type must match element type of aggregate"); + + // Verify the # indices match if we have a ranked type. + if (aggregateType.hasRank() && + aggregateType.getRank() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of indices for extract_element"); + + return success(); +} + +OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { + assert(!operands.empty() && "extract_element takes at least one operand"); + + // The aggregate operand must be a known constant. + Attribute aggregate = operands.front(); + if (!aggregate) + return {}; + + // If this is a splat elements attribute, simply return the value. All of the + // elements of a splat attribute are the same. + if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) + return splatAggregate.getSplatValue(); + + // Otherwise, collect the constant indices into the aggregate. + SmallVector<uint64_t, 8> indices; + for (Attribute indice : llvm::drop_begin(operands, 1)) { + if (!indice || !indice.isa<IntegerAttr>()) + return {}; + indices.push_back(indice.cast<IntegerAttr>().getInt()); + } + + // If this is an elements attribute, query the value at the given indices. + auto elementsAttr = aggregate.dyn_cast<ElementsAttr>(); + if (elementsAttr && elementsAttr.isValidIndex(indices)) + return elementsAttr.getValue(indices); + return {}; +} + +//===----------------------------------------------------------------------===// +// IndexCastOp +//===----------------------------------------------------------------------===// + +// Index cast is applicable from index to integer and backwards. +bool IndexCastOp::areCastCompatible(Type a, Type b) { + return (a.isIndex() && b.isa<IntegerType>()) || + (a.isa<IntegerType>() && b.isIndex()); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, LoadOp op) { + p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getMemRefType(); +} + +static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType memrefInfo; + SmallVector<OpAsmParser::OperandType, 4> indexInfo; + MemRefType type; + + auto indexTy = parser.getBuilder().getIndexType(); + return failure( + parser.parseOperand(memrefInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands) || + parser.addTypeToList(type.getElementType(), result.types)); +} + +static LogicalResult verify(LoadOp op) { + if (op.getType() != op.getMemRefType().getElementType()) + return op.emitOpError("result type must match element type of memref"); + + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + return op.emitOpError("incorrect number of indices for load"); + + return success(); +} + +OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { + /// load(memrefcast) -> load + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + +//===----------------------------------------------------------------------===// +// MemRefCastOp +//===----------------------------------------------------------------------===// + +bool MemRefCastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast<MemRefType>(); + auto bT = b.dyn_cast<MemRefType>(); + + auto uaT = a.dyn_cast<UnrankedMemRefType>(); + auto ubT = b.dyn_cast<UnrankedMemRefType>(); + + if (aT && bT) { + if (aT.getElementType() != bT.getElementType()) + return false; + if (aT.getAffineMaps() != bT.getAffineMaps()) { + int64_t aOffset, bOffset; + SmallVector<int64_t, 4> aStrides, bStrides; + if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || + failed(getStridesAndOffset(bT, bStrides, bOffset)) || + aStrides.size() != bStrides.size()) + return false; + + // Strides along a dimension/offset are compatible if the value in the + // source memref is static and the value in the target memref is the + // same. They are also compatible if either one is dynamic (see + // description of MemRefCastOp for details). + auto checkCompatible = [](int64_t a, int64_t b) { + return (a == MemRefType::getDynamicStrideOrOffset() || + b == MemRefType::getDynamicStrideOrOffset() || a == b); + }; + if (!checkCompatible(aOffset, bOffset)) + return false; + for (auto aStride : enumerate(aStrides)) + if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) + return false; + } + if (aT.getMemorySpace() != bT.getMemorySpace()) + return false; + + // They must have the same rank, and any specified dimensions must match. + if (aT.getRank() != bT.getRank()) + return false; + + for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { + int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); + if (aDim != -1 && bDim != -1 && aDim != bDim) + return false; + } + return true; + } else { + if (!aT && !uaT) + return false; + if (!bT && !ubT) + return false; + // Unranked to unranked casting is unsupported + if (uaT && ubT) + return false; + + auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); + auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); + if (aEltType != bEltType) + return false; + + auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); + auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); + if (aMemSpace != bMemSpace) + return false; + + return true; + } + + return false; +} + +OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) { + return impl::foldCastOp(*this); +} + +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) { + return constFoldBinaryOp<FloatAttr>( + operands, [](APFloat a, APFloat b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// MulIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { + /// muli(x, 0) -> 0 + if (matchPattern(rhs(), m_Zero())) + return rhs(); + /// muli(x, 1) -> x + if (matchPattern(rhs(), m_One())) + return getOperand(0); + + // TODO: Handle the overflow case. + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, PrefetchOp op) { + p << PrefetchOp::getOperationName() << " " << *op.memref() << '['; + p.printOperands(op.indices()); + p << ']' << ", " << (op.isWrite() ? "write" : "read"); + p << ", locality<" << op.localityHint(); + p << ">, " << (op.isDataCache() ? "data" : "instr"); + p.printOptionalAttrDict( + op.getAttrs(), + /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); + p << " : " << op.getMemRefType(); +} + +static ParseResult parsePrefetchOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType memrefInfo; + SmallVector<OpAsmParser::OperandType, 4> indexInfo; + IntegerAttr localityHint; + MemRefType type; + StringRef readOrWrite, cacheType; + + auto indexTy = parser.getBuilder().getIndexType(); + auto i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseOperand(memrefInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseKeyword(&readOrWrite) || + parser.parseComma() || parser.parseKeyword("locality") || + parser.parseLess() || + parser.parseAttribute(localityHint, i32Type, "localityHint", + result.attributes) || + parser.parseGreater() || parser.parseComma() || + parser.parseKeyword(&cacheType) || parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands)) + return failure(); + + if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + return parser.emitError(parser.getNameLoc(), + "rw specifier has to be 'read' or 'write'"); + result.addAttribute( + PrefetchOp::getIsWriteAttrName(), + parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + + if (!cacheType.equals("data") && !cacheType.equals("instr")) + return parser.emitError(parser.getNameLoc(), + "cache type has to be 'data' or 'instr'"); + + result.addAttribute( + PrefetchOp::getIsDataCacheAttrName(), + parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + + return success(); +} + +static LogicalResult verify(PrefetchOp op) { + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + return op.emitOpError("too few indices"); + + return success(); +} + +LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + // prefetch(memrefcast) -> prefetch + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, RankOp op) { + p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); +} + +static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType operandInfo; + Type type; + Type indexType = parser.getBuilder().getIndexType(); + return failure(parser.parseOperand(operandInfo) || + parser.parseColonType(type) || + parser.resolveOperand(operandInfo, type, result.operands) || + parser.addTypeToList(indexType, result.types)); +} + +OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { + // Constant fold rank when the rank of the tensor is known. + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) + return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); + return IntegerAttr(); +} + +//===----------------------------------------------------------------------===// +// SignedRemIOp +//===----------------------------------------------------------------------===// + +OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "remi_signed takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); + if (!rhs) + return {}; + auto rhsValue = rhs.getValue(); + + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + + // Don't fold if it requires division by zero. + if (rhsValue.isNullValue()) + return {}; + + auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); +} + +//===----------------------------------------------------------------------===// +// UnsignedRemIOp +//===----------------------------------------------------------------------===// + +OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "remi_unsigned takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); + if (!rhs) + return {}; + auto rhsValue = rhs.getValue(); + + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + + // Don't fold if it requires division by zero. + if (rhsValue.isNullValue()) + return {}; + + auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, result.operands)); +} + +static void print(OpAsmPrinter &p, ReturnOp op) { + p << "return"; + if (op.getNumOperands() != 0) + p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); +} + +static LogicalResult verify(ReturnOp op) { + auto function = cast<FuncOp>(op.getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError("has ") + << op.getNumOperands() + << " operands, but enclosing function returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (op.getOperand(i)->getType() != results[i]) + return op.emitError() + << "type of return operand " << i << " (" + << op.getOperand(i)->getType() + << ") doesn't match function result type (" << results[i] << ")"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SIToFPOp +//===----------------------------------------------------------------------===// + +// sitofp is applicable from integer types to float types. +bool SIToFPOp::areCastCompatible(Type a, Type b) { + return a.isa<IntegerType>() && b.isa<FloatType>(); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 3> ops; + SmallVector<NamedAttribute, 4> attrs; + Type type; + if (parser.parseOperandList(ops, 3) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return failure(); + + auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type); + if (!i1Type) + return parser.emitError(parser.getNameLoc(), + "expected type with valid i1 shape"); + + SmallVector<Type, 3> types = {i1Type, type, type}; + return failure(parser.resolveOperands(ops, types, parser.getNameLoc(), + result.operands) || + parser.addTypeToList(type, result.types)); +} + +static void print(OpAsmPrinter &p, SelectOp op) { + p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType(); + p.printOptionalAttrDict(op.getAttrs()); +} + +static LogicalResult verify(SelectOp op) { + auto trueType = op.getTrueValue()->getType(); + auto falseType = op.getFalseValue()->getType(); + + if (trueType != falseType) + return op.emitOpError( + "requires 'true' and 'false' arguments to be of the same type"); + + return success(); +} + +OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { + auto condition = getCondition(); + + // select true, %0, %1 => %0 + if (matchPattern(condition, m_One())) + return getTrueValue(); + + // select false, %0, %1 => %1 + if (matchPattern(condition, m_Zero())) + return getFalseValue(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// SignExtendIOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SignExtendIOp op) { + // Get the scalar type (which is either directly the type of the operand + // or the vector's/tensor's element type. + auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto dstType = getElementTypeOrSelf(op.getType()); + + // For now, index is forbidden for the source and the destination type. + if (srcType.isa<IndexType>()) + return op.emitError() << srcType << " is not a valid operand type"; + if (dstType.isa<IndexType>()) + return op.emitError() << dstType << " is not a valid result type"; + + if (srcType.cast<IntegerType>().getWidth() >= + dstType.cast<IntegerType>().getWidth()) + return op.emitError("result type ") + << dstType << " must be wider than operand type " << srcType; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SplatOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, SplatOp op) { + p << "splat " << *op.getOperand(); + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getType(); +} + +static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType splatValueInfo; + ShapedType shapedType; + + return failure(parser.parseOperand(splatValueInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(shapedType) || + parser.resolveOperand(splatValueInfo, + shapedType.getElementType(), + result.operands) || + parser.addTypeToList(shapedType, result.types)); +} + +static LogicalResult verify(SplatOp op) { + // TODO: we could replace this by a trait. + if (op.getOperand()->getType() != + op.getType().cast<ShapedType>().getElementType()) + return op.emitError("operand should be of elemental type of result type"); + + return success(); +} + +// Constant folding hook for SplatOp. +OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 1 && "splat takes one operand"); + + auto constOperand = operands.front(); + if (!constOperand || + (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>())) + return {}; + + auto shapedType = getType().cast<ShapedType>(); + assert(shapedType.getElementType() == constOperand.getType() && + "incorrect input attribute type for folding"); + + // SplatElementsAttr::get treats single value for second arg as being a splat. + return SplatElementsAttr::get(shapedType, {constOperand}); +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, StoreOp op) { + p << "store " << *op.getValueToStore(); + p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getMemRefType(); +} + +static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + SmallVector<OpAsmParser::OperandType, 4> indexInfo; + MemRefType memrefType; + + auto indexTy = parser.getBuilder().getIndexType(); + return failure( + parser.parseOperand(storeValueInfo) || parser.parseComma() || + parser.parseOperand(memrefInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(memrefType) || + parser.resolveOperand(storeValueInfo, memrefType.getElementType(), + result.operands) || + parser.resolveOperand(memrefInfo, memrefType, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands)); +} + +static LogicalResult verify(StoreOp op) { + // First operand must have same type as memref element type. + if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) + return op.emitOpError( + "first operand must have same type memref element type"); + + if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) + return op.emitOpError("store index operand count not equal to memref rank"); + + return success(); +} + +LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, + SmallVectorImpl<OpFoldResult> &results) { + /// store(memrefcast) -> store + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// SubFOp +//===----------------------------------------------------------------------===// + +OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) { + return constFoldBinaryOp<FloatAttr>( + operands, [](APFloat a, APFloat b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// SubIOp +//===----------------------------------------------------------------------===// + +OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) { + // subi(x,x) -> 0 + if (getOperand(0) == getOperand(1)) + return Builder(getContext()).getZeroAttr(getType()); + + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + +OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { + /// and(x, 0) -> 0 + if (matchPattern(rhs(), m_Zero())) + return rhs(); + /// and(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a & b; }); +} + +//===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + +OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { + /// or(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// or(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a | b; }); +} + +//===----------------------------------------------------------------------===// +// XOrOp +//===----------------------------------------------------------------------===// + +OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) { + /// xor(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// xor(x,x) -> 0 + if (lhs() == rhs()) + return Builder(getContext()).getZeroAttr(getType()); + + return constFoldBinaryOp<IntegerAttr>(operands, + [](APInt a, APInt b) { return a ^ b; }); +} + +//===----------------------------------------------------------------------===// +// TensorCastOp +//===----------------------------------------------------------------------===// + +bool TensorCastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast<TensorType>(); + auto bT = b.dyn_cast<TensorType>(); + if (!aT || !bT) + return false; + + if (aT.getElementType() != bT.getElementType()) + return false; + + return succeeded(verifyCompatibleShape(aT, bT)); +} + +OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) { + return impl::foldCastOp(*this); +} + +//===----------------------------------------------------------------------===// +// Helpers for Tensor[Load|Store]Op +//===----------------------------------------------------------------------===// + +static Type getTensorTypeFromMemRefType(Builder &b, Type type) { + if (auto memref = type.dyn_cast<MemRefType>()) + return RankedTensorType::get(memref.getShape(), memref.getElementType()); + return b.getNoneType(); +} + +//===----------------------------------------------------------------------===// +// TensorLoadOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, TensorLoadOp op) { + p << "tensor_load " << *op.getOperand(); + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand()->getType(); +} + +static ParseResult parseTensorLoadOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType op; + Type type; + return failure(parser.parseOperand(op) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperand(op, type, result.operands) || + parser.addTypeToList( + getTensorTypeFromMemRefType(parser.getBuilder(), type), + result.types)); +} + +//===----------------------------------------------------------------------===// +// TensorStoreOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, TensorStoreOp op) { + p << "tensor_store " << *op.tensor() << ", " << *op.memref(); + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.memref()->getType(); +} + +static ParseResult parseTensorStoreOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<OpAsmParser::OperandType, 2> ops; + Type type; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure( + parser.parseOperandList(ops, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperands( + ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type}, + loc, result.operands)); +} + +//===----------------------------------------------------------------------===// +// TruncateIOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(TruncateIOp op) { + auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto dstType = getElementTypeOrSelf(op.getType()); + + if (srcType.isa<IndexType>()) + return op.emitError() << srcType << " is not a valid operand type"; + if (dstType.isa<IndexType>()) + return op.emitError() << dstType << " is not a valid result type"; + + if (srcType.cast<IntegerType>().getWidth() <= + dstType.cast<IntegerType>().getWidth()) + return op.emitError("operand type ") + << srcType << " must be wider than result type " << dstType; + + return success(); +} + +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// + +static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector<OpAsmParser::OperandType, 1> offsetInfo; + SmallVector<OpAsmParser::OperandType, 4> sizesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + llvm::SMLoc offsetLoc; + if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) + return failure(); + + if (offsetInfo.size() > 1) + return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand"; + + return failure( + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, ViewOp op) { + p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; + auto dynamicOffset = op.getDynamicOffset(); + if (dynamicOffset != nullptr) + p.printOperand(dynamicOffset); + p << "][" << op.getDynamicSizes() << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); +} + +Value ViewOp::getDynamicOffset() { + int64_t offset; + SmallVector<int64_t, 4> strides; + auto result = + succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); + assert(result); + if (result && offset == MemRefType::getDynamicStrideOrOffset()) + return getOperand(1); + return nullptr; +} + +static LogicalResult verifyDynamicStrides(MemRefType memrefType, + ArrayRef<int64_t> strides) { + ArrayRef<int64_t> shape = memrefType.getShape(); + unsigned rank = memrefType.getRank(); + assert(rank == strides.size()); + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. + if (ShapedType::isDynamic(shape[i + 1])) + dynamicStrides = true; + // If stride at dim 'i' is not dynamic, return error. + if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) + return failure(); + } + return success(); +} + +static LogicalResult verify(ViewOp op) { + auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); + auto viewType = op.getResult()->getType().cast<MemRefType>(); + + // The base memref should have identity layout map (or none). + if (baseType.getAffineMaps().size() > 1 || + (baseType.getAffineMaps().size() == 1 && + !baseType.getAffineMaps()[0].isIdentity())) + return op.emitError("unsupported map for base memref type ") << baseType; + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != viewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and view memref type " << viewType; + + // Verify that the result memref type has a strided layout map. + int64_t offset; + SmallVector<int64_t, 4> strides; + if (failed(getStridesAndOffset(viewType, strides, offset))) + return op.emitError("result type ") << viewType << " is not strided"; + + // Verify that we have the correct number of operands for the result type. + unsigned memrefOperandCount = 1; + unsigned numDynamicDims = viewType.getNumDynamicDims(); + unsigned dynamicOffsetCount = + offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0; + if (op.getNumOperands() != + memrefOperandCount + numDynamicDims + dynamicOffsetCount) + return op.emitError("incorrect number of operands for type ") << viewType; + + // Verify dynamic strides symbols were added to correct dimensions based + // on dynamic sizes. + if (failed(verifyDynamicStrides(viewType, strides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << viewType; + return success(); +} + +namespace { + +struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { + using OpRewritePattern<ViewOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ViewOp viewOp, + PatternRewriter &rewriter) const override { + // Return if none of the operands are constants. + if (llvm::none_of(viewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return matchFailure(); + + // Get result memref type. + auto memrefType = viewOp.getType(); + if (memrefType.getAffineMaps().size() != 1) + return matchFailure(); + auto map = memrefType.getAffineMaps()[0]; + + // Get offset from old memref view type 'memRefType'. + int64_t oldOffset; + SmallVector<int64_t, 4> oldStrides; + if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) + return matchFailure(); + + SmallVector<Value, 4> newOperands; + SmallVector<Value, 4> droppedOperands; + + // Fold dynamic offset operand if it is produced by a constant. + auto dynamicOffset = viewOp.getDynamicOffset(); + int64_t newOffset = oldOffset; + unsigned dynamicOffsetOperandCount = 0; + if (dynamicOffset != nullptr) { + auto *defOp = dynamicOffset->getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { + // Dynamic offset will be folded into the map. + newOffset = constantIndexOp.getValue(); + droppedOperands.push_back(dynamicOffset); + } else { + // Unable to fold dynamic offset. Add it to 'newOperands' list. + newOperands.push_back(dynamicOffset); + dynamicOffsetOperandCount = 1; + } + } + + // Fold any dynamic dim operands which are produced by a constant. + SmallVector<int64_t, 4> newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + + unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart(); + unsigned rank = memrefType.getRank(); + for (unsigned dim = 0, e = rank; dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (!ShapedType::isDynamic(dimSize)) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp.getValue()); + // Record to check for zero uses later below. + droppedOperands.push_back(constantIndexOp); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(dimSize); + newOperands.push_back(viewOp.getOperand(dynamicDimPos)); + } + dynamicDimPos++; + } + + // Compute new strides based on 'newShapeConstants'. + SmallVector<int64_t, 4> newStrides(rank); + newStrides[rank - 1] = 1; + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + if (ShapedType::isDynamic(newShapeConstants[i + 1])) + dynamicStrides = true; + if (dynamicStrides) + newStrides[i] = MemRefType::getDynamicStrideOrOffset(); + else + newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1]; + } + + // Regenerate strided layout map with 'newStrides' and 'newOffset'. + map = makeStridedLinearLayoutMap(newStrides, newOffset, + rewriter.getContext()); + + // Create new memref type with constant folded dims and/or offset/strides. + auto newMemRefType = + MemRefType::get(newShapeConstants, memrefType.getElementType(), {map}, + memrefType.getMemorySpace()); + assert(static_cast<int64_t>(newOperands.size()) == + dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims()); + + // Create new ViewOp. + auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, + viewOp.getOperand(0), newOperands); + // Insert a cast so we have the same type as the old memref type. + rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp, + newViewOp, viewOp.getType()); + return matchSuccess(); + } +}; + +} // end anonymous namespace + +void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<ViewOpShapeFolder>(context); +} + +//===----------------------------------------------------------------------===// +// SubViewOp +//===----------------------------------------------------------------------===// + +// Returns a MemRefType with dynamic sizes and offset and the same stride as the +// `memRefType` passed as argument. +// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep +// sizes and offset static. +static Type inferSubViewResultType(MemRefType memRefType) { + auto rank = memRefType.getRank(); + int64_t offset; + SmallVector<int64_t, 4> strides; + Type elementType = memRefType.getElementType(); + auto res = getStridesAndOffset(memRefType, strides, offset); + assert(succeeded(res) && "SubViewOp expected strided memref type"); + (void)res; + + // Assume sizes and offset are fully dynamic for now until canonicalization + // occurs on the ranges. Typed strides don't change though. + offset = MemRefType::getDynamicStrideOrOffset(); + // Overwrite strides because verifier will not pass. + // TODO(b/144419106): don't force degrade the strides to fully dynamic. + for (auto &stride : strides) + stride = MemRefType::getDynamicStrideOrOffset(); + auto stridedLayout = + makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); + SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize); + return MemRefType::get(sizes, elementType, stridedLayout, + memRefType.getMemorySpace()); +} + +void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, + ValueRange offsets, ValueRange sizes, + ValueRange strides, Type resultType, + ArrayRef<NamedAttribute> attrs) { + if (!resultType) + resultType = inferSubViewResultType(source->getType().cast<MemRefType>()); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()), + static_cast<int32_t>(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); + result.addAttributes(attrs); +} + +void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, + Value source) { + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); +} + +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector<OpAsmParser::OperandType, 4> offsetsInfo; + SmallVector<OpAsmParser::OperandType, 4> sizesInfo; + SmallVector<OpAsmParser::OperandType, 4> stridesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + if (parser.parseOperand(srcInfo) || + parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { + return failure(); + } + + auto builder = parser.getBuilder(); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()), + static_cast<int32_t>(sizesInfo.size()), + static_cast<int32_t>(stridesInfo.size())})); + + return failure( + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, SubViewOp op) { + p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets() + << "][" << op.sizes() << "][" << op.strides() << ']'; + + SmallVector<StringRef, 1> elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; + p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); + p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); +} + +static LogicalResult verify(SubViewOp op) { + auto baseType = op.getBaseMemRefType().cast<MemRefType>(); + auto subViewType = op.getType(); + + // The rank of the base and result subview must match. + if (baseType.getRank() != subViewType.getRank()) { + return op.emitError( + "expected rank of result type to match rank of base type "); + } + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and subview memref type " << subViewType; + + // Verify that the base memref type has a strided layout map. + int64_t baseOffset; + SmallVector<int64_t, 4> baseStrides; + if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) + return op.emitError("base type ") << subViewType << " is not strided"; + + // Verify that the result memref type has a strided layout map. + int64_t subViewOffset; + SmallVector<int64_t, 4> subViewStrides; + if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) + return op.emitError("result type ") << subViewType << " is not strided"; + + // Num offsets should either be zero or rank of memref. + if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { + return op.emitError("expected number of dynamic offsets specified to match " + "the rank of the result type ") + << subViewType; + } + + // Num sizes should either be zero or rank of memref. + if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { + return op.emitError("expected number of dynamic sizes specified to match " + "the rank of the result type ") + << subViewType; + } + + // Num strides should either be zero or rank of memref. + if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { + return op.emitError("expected number of dynamic strides specified to match " + "the rank of the result type ") + << subViewType; + } + + // Verify that if the shape of the subview type is static, then sizes are not + // dynamic values, and vice versa. + if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || + (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { + return op.emitError("invalid to specify dynamic sizes when subview result " + "type is statically shaped and viceversa"); + } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. + if (op.getNumSizes() > 0) { + if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { + return dim != ShapedType::kDynamicSize; + })) { + // TODO: This is based on the assumption that number of size arguments are + // either 0, or the rank of the result type. It is possible to have more + // fine-grained verification where only particular dimensions are + // dynamic. That probably needs further changes to the shape op + // specification. + return op.emitError("expected shape of result type to be fully dynamic " + "when sizes are specified"); + } + } + + // Verify that if dynamic offsets are specified or base memref has dynamic + // offset or base memref has dynamic strides, then the subview offset is + // dynamic. + if ((op.getNumOffsets() > 0 || + baseOffset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset())) && + subViewOffset != MemRefType::getDynamicStrideOrOffset()) { + return op.emitError( + "expected result memref layout map to have dynamic offset"); + } + + // For now, verify that if dynamic strides are specified, then all the result + // memref type have dynamic strides. + if (op.getNumStrides() > 0) { + if (llvm::any_of(subViewStrides, [](int64_t stride) { + return stride != MemRefType::getDynamicStrideOrOffset(); + })) { + return op.emitError("expected result type to have dynamic strides"); + } + } + + // If any of the base memref has dynamic stride, then the corresponding + // stride of the subview must also have dynamic stride. + assert(baseStrides.size() == subViewStrides.size()); + for (auto stride : enumerate(baseStrides)) { + if (stride.value() == MemRefType::getDynamicStrideOrOffset() && + subViewStrides[stride.index()] != + MemRefType::getDynamicStrideOrOffset()) { + return op.emitError( + "expected result type to have dynamic stride along a dimension if " + "the base memref type has dynamic stride along that dimension"); + } + } + return success(); +} + +raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { + return os << "range " << *range.offset << ":" << *range.size << ":" + << *range.stride; +} + +SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() { + SmallVector<Range, 8> res; + unsigned rank = getType().getRank(); + res.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); + return res; +} + +LogicalResult +SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) { + // If the strides are dynamic return failure. + if (getNumStrides()) + return failure(); + + // When static, the stride operands can be retrieved by taking the strides of + // the result of the subview op, and dividing the strides of the base memref. + int64_t resultOffset, baseOffset; + SmallVector<int64_t, 2> resultStrides, baseStrides; + if (failed( + getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || + llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || + failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) + return failure(); + + assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() && + baseStrides.size() == resultStrides.size() && + "base and result memrefs must have the same rank"); + assert(!llvm::is_contained(resultStrides, + MemRefType::getDynamicStrideOrOffset()) && + "strides of subview op must be static, when there are no dynamic " + "strides specified"); + staticStrides.resize(getType().getRank()); + for (auto resultStride : enumerate(resultStrides)) { + auto baseStride = baseStrides[resultStride.index()]; + // The result stride is expected to be a multiple of the base stride. Abort + // if that is not the case. + if (resultStride.value() < baseStride || + resultStride.value() % baseStride != 0) + return failure(); + staticStrides[resultStride.index()] = resultStride.value() / baseStride; + } + return success(); +} + +static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) { + if (memrefType.getNumDynamicDims() > 0) + return false; + // Get offset and strides. + int64_t offset; + SmallVector<int64_t, 4> strides; + if (failed(getStridesAndOffset(memrefType, strides, offset))) + return false; + // Return 'false' if any of offset or strides is dynamic. + if (offset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) + return false; + return true; +} + +namespace { + +/// Pattern to rewrite a subview op with constant size arguments. +class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> { +public: + using OpRewritePattern<SubViewOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + MemRefType subViewType = subViewOp.getType(); + // Follow all or nothing approach for shapes for now. If all the operands + // for sizes are constants then fold it into the type of the result memref. + if (subViewType.hasStaticShape() || + llvm::any_of(subViewOp.sizes(), [](Value operand) { + return !matchPattern(operand, m_ConstantIndex()); + })) { + return matchFailure(); + } + SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes()); + for (auto size : llvm::enumerate(subViewOp.sizes())) { + auto defOp = size.value()->getDefiningOp(); + assert(defOp); + staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue(); + } + MemRefType newMemRefType = MemRefType::get( + staticShape, subViewType.getElementType(), subViewType.getAffineMaps(), + subViewType.getMemorySpace()); + auto newSubViewOp = rewriter.create<SubViewOp>( + subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), + ArrayRef<Value>(), subViewOp.strides(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. + rewriter.replaceOpWithNewOp<MemRefCastOp>( + subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType()); + return matchSuccess(); + } +}; + +// Pattern to rewrite a subview op with constant stride arguments. +class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> { +public: + using OpRewritePattern<SubViewOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + if (subViewOp.getNumStrides() == 0) { + return matchFailure(); + } + // Follow all or nothing approach for strides for now. If all the operands + // for strides are constants then fold it into the strides of the result + // memref. + int64_t baseOffset, resultOffset; + SmallVector<int64_t, 4> baseStrides, resultStrides; + MemRefType subViewType = subViewOp.getType(); + if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, + baseOffset)) || + failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset()) || + llvm::any_of(subViewOp.strides(), [](Value stride) { + return !matchPattern(stride, m_ConstantIndex()); + })) { + return matchFailure(); + } + + SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides()); + for (auto stride : llvm::enumerate(subViewOp.strides())) { + auto defOp = stride.value()->getDefiningOp(); + assert(defOp); + assert(baseStrides[stride.index()] > 0); + staticStrides[stride.index()] = + cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()]; + } + AffineMap layoutMap = makeStridedLinearLayoutMap( + staticStrides, resultOffset, rewriter.getContext()); + MemRefType newMemRefType = + MemRefType::get(subViewType.getShape(), subViewType.getElementType(), + layoutMap, subViewType.getMemorySpace()); + auto newSubViewOp = rewriter.create<SubViewOp>( + subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), + subViewOp.sizes(), ArrayRef<Value>(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. + rewriter.replaceOpWithNewOp<MemRefCastOp>( + subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType()); + return matchSuccess(); + } +}; + +// Pattern to rewrite a subview op with constant offset arguments. +class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> { +public: + using OpRewritePattern<SubViewOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + if (subViewOp.getNumOffsets() == 0) { + return matchFailure(); + } + // Follow all or nothing approach for offsets for now. If all the operands + // for offsets are constants then fold it into the offset of the result + // memref. + int64_t baseOffset, resultOffset; + SmallVector<int64_t, 4> baseStrides, resultStrides; + MemRefType subViewType = subViewOp.getType(); + if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, + baseOffset)) || + failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset()) || + baseOffset == MemRefType::getDynamicStrideOrOffset() || + llvm::any_of(subViewOp.offsets(), [](Value stride) { + return !matchPattern(stride, m_ConstantIndex()); + })) { + return matchFailure(); + } + + auto staticOffset = baseOffset; + for (auto offset : llvm::enumerate(subViewOp.offsets())) { + auto defOp = offset.value()->getDefiningOp(); + assert(defOp); + assert(baseStrides[offset.index()] > 0); + staticOffset += + cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()]; + } + + AffineMap layoutMap = makeStridedLinearLayoutMap( + resultStrides, staticOffset, rewriter.getContext()); + MemRefType newMemRefType = + MemRefType::get(subViewType.getShape(), subViewType.getElementType(), + layoutMap, subViewType.getMemorySpace()); + auto newSubViewOp = rewriter.create<SubViewOp>( + subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(), + subViewOp.sizes(), subViewOp.strides(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. + rewriter.replaceOpWithNewOp<MemRefCastOp>( + subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType()); + return matchSuccess(); + } +}; + +} // end anonymous namespace + +void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder, + SubViewOpOffsetFolder>(context); +} + +//===----------------------------------------------------------------------===// +// ZeroExtendIOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ZeroExtendIOp op) { + auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto dstType = getElementTypeOrSelf(op.getType()); + + if (srcType.isa<IndexType>()) + return op.emitError() << srcType << " is not a valid operand type"; + if (dstType.isa<IndexType>()) + return op.emitError() << dstType << " is not a valid result type"; + + if (srcType.cast<IntegerType>().getWidth() >= + dstType.cast<IntegerType>().getWidth()) + return op.emitError("result type ") + << dstType << " must be wider than operand type " << srcType; + + return success(); +} + +//===----------------------------------------------------------------------===// +// FPExtOp +//===----------------------------------------------------------------------===// + +bool FPExtOp::areCastCompatible(Type a, Type b) { + if (auto fa = a.dyn_cast<FloatType>()) + if (auto fb = b.dyn_cast<FloatType>()) + return fa.getWidth() < fb.getWidth(); + return false; +} + +//===----------------------------------------------------------------------===// +// FPTruncOp +//===----------------------------------------------------------------------===// + +bool FPTruncOp::areCastCompatible(Type a, Type b) { + if (auto fa = a.dyn_cast<FloatType>()) + if (auto fb = b.dyn_cast<FloatType>()) + return fa.getWidth() > fb.getWidth(); + return false; +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/StandardOps/Ops.cpp.inc" |