summaryrefslogtreecommitdiffstats
path: root/mlir/lib/StandardOps/StandardOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/StandardOps/StandardOps.cpp')
-rw-r--r--mlir/lib/StandardOps/StandardOps.cpp226
1 files changed, 113 insertions, 113 deletions
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index b60d209e1f5..e2bdfd7a18b 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result,
- MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
+ MemRefType memrefType, ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->types.push_back(memrefType);
}
void AllocOp::print(OpAsmPrinter *p) const {
- MemRefType *type = getType();
+ MemRefType type = getType();
*p << "alloc";
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
- type->getNumDynamicDims(), p);
+ type.getNumDynamicDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
- *p << " : " << *type;
+ *p << " : " << type;
}
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
- MemRefType *type;
+ MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
@@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
// Verification still checks that the total number of operands matches
// the number of symbols in the affine map, plus the number of dynamic
// dimensions in the memref.
- if (numDimOperands != type->getNumDynamicDims()) {
+ if (numDimOperands != type.getNumDynamicDims()) {
return parser->emitError(parser->getNameLoc(),
"dimension operand count does not equal memref "
"dynamic dimension count");
@@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
}
bool AllocOp::verify() const {
- auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
+ auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("result must be a memref");
unsigned numSymbols = 0;
- if (!memRefType->getAffineMaps().empty()) {
- AffineMap affineMap = memRefType->getAffineMaps()[0];
+ if (!memRefType.getAffineMaps().empty()) {
+ AffineMap affineMap = memRefType.getAffineMaps()[0];
// Store number of symbols used in affine map (used in subsequent check).
numSymbols = affineMap.getNumSymbols();
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
@@ -195,10 +195,10 @@ bool AllocOp::verify() const {
// Remove when we can emit errors directly from *Type::get(...) functions.
//
// Verify that the layout affine map matches the rank of the memref.
- if (affineMap.getNumDims() != memRefType->getRank())
+ if (affineMap.getNumDims() != memRefType.getRank())
return emitOpError("affine map dimension count must equal memref rank");
}
- unsigned numDynamicDims = memRefType->getNumDynamicDims();
+ unsigned numDynamicDims = memRefType.getNumDynamicDims();
// Check that the total number of operands matches the number of symbols in
// the affine map, plus the number of dynamic dimensions specified in the
// memref type.
@@ -208,7 +208,7 @@ bool AllocOp::verify() const {
}
// Verify that all operands are of type Index.
for (auto *operand : getOperands()) {
- if (!operand->getType()->isIndex())
+ if (!operand->getType().isIndex())
return emitOpError("requires operands to be of type Index");
}
return false;
@@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern {
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType->getRank());
+ newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands;
unsigned dynamicDimPos = 0;
- for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
- int dimSize = memrefType->getDimSize(dim);
+ for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+ int dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it.
if (dimSize != -1) {
newShapeConstants.push_back(dimSize);
@@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern {
}
// Create new memref type (which will have fewer dynamic dimensions).
- auto *newMemRefType = MemRefType::get(
- newShapeConstants, memrefType->getElementType(),
- memrefType->getAffineMaps(), memrefType->getMemorySpace());
- assert(newOperands.size() == newMemRefType->getNumDynamicDims());
+ auto newMemRefType = MemRefType::get(
+ newShapeConstants, memrefType.getElementType(),
+ memrefType.getAffineMaps(), memrefType.getMemorySpace());
+ assert(newOperands.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
auto newAlloc =
@@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
- result->addTypes(callee->getType()->getResults());
+ result->addTypes(callee->getType().getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
- FunctionType *calleeType = nullptr;
+ FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
@@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
- parser->addTypesToList(calleeType->getResults(), result->types) ||
- parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
+ parser->addTypesToList(calleeType.getResults(), result->types) ||
+ parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return true;
@@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const {
p->printOperands(getOperands());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
+ *p << " : " << getCallee()->getType();
}
bool CallOp::verify() const {
@@ -338,20 +338,20 @@ bool CallOp::verify() const {
return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee.
- auto *fnType = fnAttr.getValue()->getType();
- if (fnType->getNumInputs() != getNumOperands())
+ auto fnType = fnAttr.getValue()->getType();
+ if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i)->getType() != fnType->getInput(i))
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
- if (fnType->getNumResults() != getNumResults())
+ if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@@ -364,14 +364,14 @@ bool CallOp::verify() const {
void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) {
- auto *fnType = cast<FunctionType>(callee->getType());
+ auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee);
result->addOperands(operands);
- result->addTypes(fnType->getResults());
+ result->addTypes(fnType.getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
- FunctionType *calleeType = nullptr;
+ FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
@@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
- parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+ parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) ||
- parser->addTypesToList(calleeType->getResults(), result->types);
+ parser->addTypesToList(calleeType.getResults(), result->types);
}
void CallIndirectOp::print(OpAsmPrinter *p) const {
@@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const {
p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
+ *p << " : " << getCallee()->getType();
}
bool CallIndirectOp::verify() const {
// The callee must be a function.
- auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+ auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
if (!fnType)
return emitOpError("callee must have function type");
// Verify that the operand and result types match the callee.
- if (fnType->getNumInputs() != getNumOperands() - 1)
+ if (fnType.getNumInputs() != getNumOperands() - 1)
return emitOpError("incorrect number of operands for callee");
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i + 1)->getType() != fnType->getInput(i))
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i + 1)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
- if (fnType->getNumResults() != getNumResults())
+ if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result,
}
void DeallocOp::print(OpAsmPrinter *p) const {
- *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
+ *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
}
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
- MemRefType *type;
+ MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands);
}
bool DeallocOp::verify() const {
- if (!isa<MemRefType>(getMemRef()->getType()))
+ if (!getMemRef()->getType().isa<MemRefType>())
return emitOpError("operand must be a memref");
return false;
}
@@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result,
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
- *p << " : " << *getOperand()->getType();
+ *p << " : " << getOperand()->getType();
}
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
- Type *type;
+ Type type;
return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
@@ -496,15 +496,15 @@ bool DimOp::verify() const {
return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr.getValue();
- auto *type = getOperand()->getType();
- if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
- if (index >= tensorType->getRank())
+ auto type = getOperand()->getType();
+ if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+ if (index >= tensorType.getRank())
return emitOpError("index is out of range");
- } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
- if (index >= memrefType->getRank())
+ } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+ if (index >= memrefType.getRank())
return emitOpError("index is out of range");
- } else if (isa<UnrankedTensorType>(type)) {
+ } else if (type.isa<UnrankedTensorType>()) {
// ok, assumed to be in-range.
} else {
return emitOpError("requires an operand with tensor or memref type");
@@ -516,12 +516,12 @@ bool DimOp::verify() const {
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant.
- auto *opType = getOperand()->getType();
+ auto opType = getOperand()->getType();
int indexSize = -1;
- if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
- indexSize = tensorType->getShape()[getIndex()];
- } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
- indexSize = memrefType->getShape()[getIndex()];
+ if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+ indexSize = tensorType.getShape()[getIndex()];
+ } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
+ indexSize = memrefType.getShape()[getIndex()];
}
if (indexSize >= 0)
@@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getSrcMemRef()->getType();
- *p << ", " << *getDstMemRef()->getType();
- *p << ", " << *getTagMemRef()->getType();
+ *p << " : " << getSrcMemRef()->getType();
+ *p << ", " << getDstMemRef()->getType();
+ *p << ", " << getTagMemRef()->getType();
}
// Parse DmaStartOp.
@@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
- SmallVector<Type *, 3> types;
- auto *indexType = parser->getBuilder().getIndexType();
+ SmallVector<Type, 3> types;
+ auto indexType = parser->getBuilder().getIndexType();
// Parse and resolve the following list of operands:
// *) source memref followed by its indices (in square brackets).
@@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return true;
// Check that source/destination index list size matches associated rank.
- if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
- dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
+ if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
+ dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"memref rank not equal to indices count");
- if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
+ if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices());
*p << "], ";
p->printOperand(getNumElements());
- *p << " : " << *getTagMemRef()->getType();
+ *p << " : " << getTagMemRef()->getType();
}
// Parse DmaWaitOp.
@@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
- Type *type;
- auto *indexType = parser->getBuilder().getIndexType();
+ Type type;
+ auto indexType = parser->getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its indices, and dma size.
@@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
- if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
+ if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate,
ArrayRef<SSAValue *> indices) {
- auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+ auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate);
result->addOperands(indices);
- result->types.push_back(aggregateType->getElementType());
+ result->types.push_back(aggregateType.getElementType());
}
void ExtractElementOp::print(OpAsmPrinter *p) const {
@@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getAggregate()->getType();
+ *p << " : " << getAggregate()->getType();
}
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- VectorOrTensorType *type;
+ VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) ||
@@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
+ parser->addTypeToList(type.getElementType(), result->types);
}
bool ExtractElementOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected an aggregate to index into");
- auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+ auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
if (!aggregateType)
return emitOpError("first operand must be a vector or tensor");
- if (getType() != aggregateType->getElementType())
+ if (getType() != aggregateType.getElementType())
return emitOpError("result type must match element type of aggregate");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
- auto aggregateRank = aggregateType->getRank();
+ auto aggregateRank = aggregateType.getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element");
@@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const {
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices) {
- auto *memrefType = cast<MemRefType>(memref->getType());
+ auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref);
result->addOperands(indices);
- result->types.push_back(memrefType->getElementType());
+ result->types.push_back(memrefType.getElementType());
}
void LoadOp::print(OpAsmPrinter *p) const {
@@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRefType();
+ *p << " : " << getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *type;
+ MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) ||
@@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
+ parser->addTypeToList(type.getElementType(), result->types);
}
bool LoadOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected a memref to load from");
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+ auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("first operand must be a memref");
- if (getType() != memRefType->getElementType())
+ if (getType() != memRefType.getElementType())
return emitOpError("result type must match element type of memref");
- if (memRefType->getRank() != getNumOperands() - 1)
+ if (memRefType.getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
bool MemRefCastOp::verify() const {
- auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
- auto *resType = dyn_cast<MemRefType>(getType());
+ auto opType = getOperand()->getType().dyn_cast<MemRefType>();
+ auto resType = getType().dyn_cast<MemRefType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be memrefs");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
- if (opType->getElementType() != resType->getElementType())
+ if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
- if (opType->getAffineMaps() != resType->getAffineMaps())
+ if (opType.getAffineMaps() != resType.getAffineMaps())
return emitOpError("requires input and result mappings to be the same");
- if (opType->getMemorySpace() != resType->getMemorySpace())
+ if (opType.getMemorySpace() != resType.getMemorySpace())
return emitOpError(
"requires input and result memory spaces to be the same");
// They must have the same rank, and any specified dimensions must match.
- if (opType->getRank() != resType->getRank())
+ if (opType.getRank() != resType.getRank())
return emitOpError("requires input and result ranks to match");
- for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
- int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
+ for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
+ int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}
@@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRefType();
+ *p << " : " << getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *memrefType;
+ MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
- parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
+ parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@@ -950,19 +950,19 @@ bool StoreOp::verify() const {
return emitOpError("expected a value to store and a memref");
// Second operand is a memref type.
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+ auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("second operand must be a memref");
// First operand must have same type as memref element type.
- if (getValueToStore()->getType() != memRefType->getElementType())
+ if (getValueToStore()->getType() != memRefType.getElementType())
return emitOpError("first operand must have same type memref element type");
- if (getNumOperands() != 2 + memRefType->getRank())
+ if (getNumOperands() != 2 + memRefType.getRank())
return emitOpError("store index operand count not equal to memref rank");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
bool TensorCastOp::verify() const {
- auto *opType = dyn_cast<TensorType>(getOperand()->getType());
- auto *resType = dyn_cast<TensorType>(getType());
+ auto opType = getOperand()->getType().dyn_cast<TensorType>();
+ auto resType = getType().dyn_cast<TensorType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be tensors");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
- if (opType->getElementType() != resType->getElementType())
+ if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
// If the source or destination are unranked, then the cast is valid.
- auto *opRType = dyn_cast<RankedTensorType>(opType);
- auto *resRType = dyn_cast<RankedTensorType>(resType);
+ auto opRType = opType.dyn_cast<RankedTensorType>();
+ auto resRType = resType.dyn_cast<RankedTensorType>();
if (!opRType || !resRType)
return false;
// If they are both ranked, they have to have the same rank, and any specified
// dimensions must match.
- if (opRType->getRank() != resRType->getRank())
+ if (opRType.getRank() != resRType.getRank())
return emitOpError("requires input and result ranks to match");
- for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
- int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+ for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
+ int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}
OpenPOWER on IntegriCloud