diff options
Diffstat (limited to 'mlir/lib/StandardOps/StandardOps.cpp')
| -rw-r--r-- | mlir/lib/StandardOps/StandardOps.cpp | 226 |
1 files changed, 113 insertions, 113 deletions
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index b60d209e1f5..e2bdfd7a18b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// void AllocOp::build(Builder *builder, OperationState *result, - MemRefType *memrefType, ArrayRef<SSAValue *> operands) { + MemRefType memrefType, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->types.push_back(memrefType); } void AllocOp::print(OpAsmPrinter *p) const { - MemRefType *type = getType(); + MemRefType type = getType(); *p << "alloc"; // Print dynamic dimension operands. printDimAndSymbolList(operand_begin(), operand_end(), - type->getNumDynamicDims(), p); + type.getNumDynamicDims(), p); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); - *p << " : " << *type; + *p << " : " << type; } bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { - MemRefType *type; + MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a // memref type. @@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. - if (numDimOperands != type->getNumDynamicDims()) { + if (numDimOperands != type.getNumDynamicDims()) { return parser->emitError(parser->getNameLoc(), "dimension operand count does not equal memref " "dynamic dimension count"); @@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { } bool AllocOp::verify() const { - auto *memRefType = dyn_cast<MemRefType>(getResult()->getType()); + auto memRefType = getResult()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("result must be a memref"); unsigned numSymbols = 0; - if (!memRefType->getAffineMaps().empty()) { - AffineMap affineMap = memRefType->getAffineMaps()[0]; + if (!memRefType.getAffineMaps().empty()) { + AffineMap affineMap = memRefType.getAffineMaps()[0]; // Store number of symbols used in affine map (used in subsequent check). numSymbols = affineMap.getNumSymbols(); // TODO(zinenko): this check does not belong to AllocOp, or any other op but @@ -195,10 +195,10 @@ bool AllocOp::verify() const { // Remove when we can emit errors directly from *Type::get(...) functions. // // Verify that the layout affine map matches the rank of the memref. - if (affineMap.getNumDims() != memRefType->getRank()) + if (affineMap.getNumDims() != memRefType.getRank()) return emitOpError("affine map dimension count must equal memref rank"); } - unsigned numDynamicDims = memRefType->getNumDynamicDims(); + unsigned numDynamicDims = memRefType.getNumDynamicDims(); // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. @@ -208,7 +208,7 @@ bool AllocOp::verify() const { } // Verify that all operands are of type Index. for (auto *operand : getOperands()) { - if (!operand->getType()->isIndex()) + if (!operand->getType().isIndex()) return emitOpError("requires operands to be of type Index"); } return false; @@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern { // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector<int, 4> newShapeConstants; - newShapeConstants.reserve(memrefType->getRank()); + newShapeConstants.reserve(memrefType.getRank()); SmallVector<SSAValue *, 4> newOperands; SmallVector<SSAValue *, 4> droppedOperands; unsigned dynamicDimPos = 0; - for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) { - int dimSize = memrefType->getDimSize(dim); + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); @@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern { } // Create new memref type (which will have fewer dynamic dimensions). - auto *newMemRefType = MemRefType::get( - newShapeConstants, memrefType->getElementType(), - memrefType->getAffineMaps(), memrefType->getMemorySpace()); - assert(newOperands.size() == newMemRefType->getNumDynamicDims()); + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(newOperands.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = @@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType()->getResults()); + result->addTypes(callee->getType().getResults()); } bool CallOp::parse(OpAsmParser *parser, OperationState *result) { StringRef calleeName; llvm::SMLoc calleeLoc; - FunctionType *calleeType = nullptr; + FunctionType calleeType; SmallVector<OpAsmParser::OperandType, 4> operands; Function *callee = nullptr; if (parser->parseFunctionName(calleeName, calleeLoc) || @@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(calleeType) || parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || - parser->addTypesToList(calleeType->getResults(), result->types) || - parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, + parser->addTypesToList(calleeType.getResults(), result->types) || + parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, result->operands)) return true; @@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const { p->printOperands(getOperands()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallOp::verify() const { @@ -338,20 +338,20 @@ bool CallOp::verify() const { return emitOpError("requires a 'callee' function attribute"); // Verify that the operand and result types match the callee. - auto *fnType = fnAttr.getValue()->getType(); - if (fnType->getNumInputs() != getNumOperands()) + auto fnType = fnAttr.getValue()->getType(); + if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -364,14 +364,14 @@ bool CallOp::verify() const { void CallIndirectOp::build(Builder *builder, OperationState *result, SSAValue *callee, ArrayRef<SSAValue *> operands) { - auto *fnType = cast<FunctionType>(callee->getType()); + auto fnType = callee->getType().cast<FunctionType>(); result->operands.push_back(callee); result->addOperands(operands); - result->addTypes(fnType->getResults()); + result->addTypes(fnType.getResults()); } bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { - FunctionType *calleeType = nullptr; + FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector<OpAsmParser::OperandType, 4> operands; @@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(calleeType) || parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, result->operands) || - parser->addTypesToList(calleeType->getResults(), result->types); + parser->addTypesToList(calleeType.getResults(), result->types); } void CallIndirectOp::print(OpAsmPrinter *p) const { @@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const { p->printOperands(++operandRange.begin(), operandRange.end()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallIndirectOp::verify() const { // The callee must be a function. - auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); + auto fnType = getCallee()->getType().dyn_cast<FunctionType>(); if (!fnType) return emitOpError("callee must have function type"); // Verify that the operand and result types match the callee. - if (fnType->getNumInputs() != getNumOperands() - 1) + if (fnType.getNumInputs() != getNumOperands() - 1) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i + 1)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i + 1)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result, } void DeallocOp::print(OpAsmPrinter *p) const { - *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); + *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); } bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; - MemRefType *type; + MemRefType type; return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands); } bool DeallocOp::verify() const { - if (!isa<MemRefType>(getMemRef()->getType())) + if (!getMemRef()->getType().isa<MemRefType>()) return emitOpError("operand must be a memref"); return false; } @@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result, void DimOp::print(OpAsmPrinter *p) const { *p << "dim " << *getOperand() << ", " << getIndex(); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); - *p << " : " << *getOperand()->getType(); + *p << " : " << getOperand()->getType(); } bool DimOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; - Type *type; + Type type; return parser->parseOperand(operandInfo) || parser->parseComma() || parser->parseAttribute(indexAttr, "index", result->attributes) || @@ -496,15 +496,15 @@ bool DimOp::verify() const { return emitOpError("requires an integer attribute named 'index'"); uint64_t index = (uint64_t)indexAttr.getValue(); - auto *type = getOperand()->getType(); - if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { - if (index >= tensorType->getRank()) + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + if (index >= tensorType.getRank()) return emitOpError("index is out of range"); - } else if (auto *memrefType = dyn_cast<MemRefType>(type)) { - if (index >= memrefType->getRank()) + } else if (auto memrefType = type.dyn_cast<MemRefType>()) { + if (index >= memrefType.getRank()) return emitOpError("index is out of range"); - } else if (isa<UnrankedTensorType>(type)) { + } else if (type.isa<UnrankedTensorType>()) { // ok, assumed to be in-range. } else { return emitOpError("requires an operand with tensor or memref type"); @@ -516,12 +516,12 @@ bool DimOp::verify() const { Attribute DimOp::constantFold(ArrayRef<Attribute> operands, MLIRContext *context) const { // Constant fold dim when the size along the index referred to is a constant. - auto *opType = getOperand()->getType(); + auto opType = getOperand()->getType(); int indexSize = -1; - if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) { - indexSize = tensorType->getShape()[getIndex()]; - } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) { - indexSize = memrefType->getShape()[getIndex()]; + if (auto tensorType = opType.dyn_cast<RankedTensorType>()) { + indexSize = tensorType.getShape()[getIndex()]; + } else if (auto memrefType = opType.dyn_cast<MemRefType>()) { + indexSize = memrefType.getShape()[getIndex()]; } if (indexSize >= 0) @@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getSrcMemRef()->getType(); - *p << ", " << *getDstMemRef()->getType(); - *p << ", " << *getTagMemRef()->getType(); + *p << " : " << getSrcMemRef()->getType(); + *p << ", " << getDstMemRef()->getType(); + *p << ", " << getTagMemRef()->getType(); } // Parse DmaStartOp. @@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; - SmallVector<Type *, 3> types; - auto *indexType = parser->getBuilder().getIndexType(); + SmallVector<Type, 3> types; + auto indexType = parser->getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). @@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { return true; // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || - dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) + if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() || + dstIndexInfos.size() != types[1].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "memref rank not equal to indices count"); - if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) + if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << "], "; p->printOperand(getNumElements()); - *p << " : " << *getTagMemRef()->getType(); + *p << " : " << getTagMemRef()->getType(); } // Parse DmaWaitOp. @@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; - Type *type; - auto *indexType = parser->getBuilder().getIndexType(); + Type type; + auto indexType = parser->getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. @@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperand(numElementsInfo, indexType, result->operands)) return true; - if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) + if (tagIndexInfos.size() != type.cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results, void ExtractElementOp::build(Builder *builder, OperationState *result, SSAValue *aggregate, ArrayRef<SSAValue *> indices) { - auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); + auto aggregateType = aggregate->getType().cast<VectorOrTensorType>(); result->addOperands(aggregate); result->addOperands(indices); - result->types.push_back(aggregateType->getElementType()); + result->types.push_back(aggregateType.getElementType()); } void ExtractElementOp::print(OpAsmPrinter *p) const { @@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getAggregate()->getType(); + *p << " : " << getAggregate()->getType(); } bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - VectorOrTensorType *type; + VectorOrTensorType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(aggregateInfo) || @@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(aggregateInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool ExtractElementOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected an aggregate to index into"); - auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); + auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>(); if (!aggregateType) return emitOpError("first operand must be a vector or tensor"); - if (getType() != aggregateType->getElementType()) + if (getType() != aggregateType.getElementType()) return emitOpError("result type must match element type of aggregate"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to extract_element must have 'index' type"); // Verify the # indices match if we have a ranked type. - auto aggregateRank = aggregateType->getRank(); + auto aggregateRank = aggregateType.getRank(); if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) return emitOpError("incorrect number of indices for extract_element"); @@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const { void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, ArrayRef<SSAValue *> indices) { - auto *memrefType = cast<MemRefType>(memref->getType()); + auto memrefType = memref->getType().cast<MemRefType>(); result->addOperands(memref); result->addOperands(indices); - result->types.push_back(memrefType->getElementType()); + result->types.push_back(memrefType.getElementType()); } void LoadOp::print(OpAsmPrinter *p) const { @@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *type; + MemRefType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(memrefInfo) || @@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool LoadOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected a memref to load from"); - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("first operand must be a memref"); - if (getType() != memRefType->getElementType()) + if (getType() != memRefType.getElementType()) return emitOpError("result type must match element type of memref"); - if (memRefType->getRank() != getNumOperands() - 1) + if (memRefType.getRank() != getNumOperands() - 1) return emitOpError("incorrect number of indices for load"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool MemRefCastOp::verify() const { - auto *opType = dyn_cast<MemRefType>(getOperand()->getType()); - auto *resType = dyn_cast<MemRefType>(getType()); + auto opType = getOperand()->getType().dyn_cast<MemRefType>(); + auto resType = getType().dyn_cast<MemRefType>(); if (!opType || !resType) return emitOpError("requires input and result types to be memrefs"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); - if (opType->getAffineMaps() != resType->getAffineMaps()) + if (opType.getAffineMaps() != resType.getAffineMaps()) return emitOpError("requires input and result mappings to be the same"); - if (opType->getMemorySpace() != resType->getMemorySpace()) + if (opType.getMemorySpace() != resType.getMemorySpace()) return emitOpError( "requires input and result memory spaces to be the same"); // They must have the same rank, and any specified dimensions must match. - if (opType->getRank() != resType->getRank()) + if (opType.getRank() != resType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opType->getRank(); i != e; ++i) { - int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i); + for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { + int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } @@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *memrefType; + MemRefType memrefType; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(storeValueInfo) || parser->parseComma() || @@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType->getElementType(), + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands); @@ -950,19 +950,19 @@ bool StoreOp::verify() const { return emitOpError("expected a value to store and a memref"); // Second operand is a memref type. - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("second operand must be a memref"); // First operand must have same type as memref element type. - if (getValueToStore()->getType() != memRefType->getElementType()) + if (getValueToStore()->getType() != memRefType.getElementType()) return emitOpError("first operand must have same type memref element type"); - if (getNumOperands() != 2 + memRefType->getRank()) + if (getNumOperands() != 2 + memRefType.getRank()) return emitOpError("store index operand count not equal to memref rank"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool TensorCastOp::verify() const { - auto *opType = dyn_cast<TensorType>(getOperand()->getType()); - auto *resType = dyn_cast<TensorType>(getType()); + auto opType = getOperand()->getType().dyn_cast<TensorType>(); + auto resType = getType().dyn_cast<TensorType>(); if (!opType || !resType) return emitOpError("requires input and result types to be tensors"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); // If the source or destination are unranked, then the cast is valid. - auto *opRType = dyn_cast<RankedTensorType>(opType); - auto *resRType = dyn_cast<RankedTensorType>(resType); + auto opRType = opType.dyn_cast<RankedTensorType>(); + auto resRType = resType.dyn_cast<RankedTensorType>(); if (!opRType || !resRType) return false; // If they are both ranked, they have to have the same rank, and any specified // dimensions must match. - if (opRType->getRank() != resRType->getRank()) + if (opRType.getRank() != resRType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { - int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); + for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { + int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } |

