summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp99
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp10
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h8
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp50
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp8
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp72
-rw-r--r--mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/EDSC/Builders.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp36
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp12
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Utils/Utils.cpp6
-rw-r--r--mlir/lib/Dialect/LoopOps/LoopOps.cpp19
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp4
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp160
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp8
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp2
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp166
-rw-r--r--mlir/lib/Dialect/Traits.cpp6
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp109
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorTransforms.cpp30
26 files changed, 425 insertions, 443 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 5f4cc2e1060..0da539ea9f0 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -108,8 +108,8 @@ static bool isFunctionRegion(Region *region) {
/// symbol.
bool mlir::isTopLevelValue(Value value) {
if (auto arg = value.dyn_cast<BlockArgument>())
- return isFunctionRegion(arg->getOwner()->getParent());
- return isFunctionRegion(value->getDefiningOp()->getParentRegion());
+ return isFunctionRegion(arg.getOwner()->getParent());
+ return isFunctionRegion(value.getDefiningOp()->getParentRegion());
}
// Value can be used as a dimension id if it is valid as a symbol, or
@@ -117,10 +117,10 @@ bool mlir::isTopLevelValue(Value value) {
// with dimension id arguments.
bool mlir::isValidDim(Value value) {
// The value must be an index type.
- if (!value->getType().isIndex())
+ if (!value.getType().isIndex())
return false;
- if (auto *op = value->getDefiningOp()) {
+ if (auto *op = value.getDefiningOp()) {
// Top level operation or constant operation is ok.
if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
return true;
@@ -134,7 +134,7 @@ bool mlir::isValidDim(Value value) {
return false;
}
// This value has to be a block argument for a FuncOp or an affine.for.
- auto *parentOp = value.cast<BlockArgument>()->getOwner()->getParentOp();
+ auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
}
@@ -162,11 +162,11 @@ static bool isDimOpValidSymbol(DimOp dimOp) {
// The dim op is also okay if its operand memref/tensor is a view/subview
// whose corresponding size is a valid symbol.
unsigned index = dimOp.getIndex();
- if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp()))
+ if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
- if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp()))
+ if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
- if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp()))
+ if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
return false;
}
@@ -177,10 +177,10 @@ static bool isDimOpValidSymbol(DimOp dimOp) {
// constraints.
bool mlir::isValidSymbol(Value value) {
// The value must be an index type.
- if (!value->getType().isIndex())
+ if (!value.getType().isIndex())
return false;
- if (auto *op = value->getDefiningOp()) {
+ if (auto *op = value.getDefiningOp()) {
// Top level operation or constant operation is ok.
if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
return true;
@@ -283,7 +283,7 @@ LogicalResult AffineApplyOp::verify() {
return emitOpError("operands must be of type 'index'");
}
- if (!getResult()->getType().isIndex())
+ if (!getResult().getType().isIndex())
return emitOpError("result must be of type 'index'");
// Verify that the map only produces one result.
@@ -332,7 +332,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
if (inserted) {
reorderedDims.push_back(v);
}
- return getAffineDimExpr(iterPos->second, v->getContext())
+ return getAffineDimExpr(iterPos->second, v.getContext())
.cast<AffineDimExpr>();
}
@@ -365,7 +365,7 @@ static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands) {
llvm::SetVector<unsigned> res;
for (auto en : llvm::enumerate(operands))
- if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
+ if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
res.insert(en.index());
return res;
}
@@ -487,7 +487,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
// 1. Only dispatch dims or symbols.
for (auto en : llvm::enumerate(operands)) {
auto t = en.value();
- assert(t->getType().isIndex());
+ assert(t.getType().isIndex());
bool isDim = (en.index() < map.getNumDims());
if (isDim) {
// a. The mathematical composition of AffineMap composes dims.
@@ -503,7 +503,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
// 2. Compose AffineApplyOps and dispatch dims or symbols.
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
auto t = operands[i];
- auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
+ auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
if (affineApply) {
// a. Compose affine.apply operations.
LLVM_DEBUG(affineApply.getOperation()->print(
@@ -588,7 +588,7 @@ static void composeAffineMapAndOperands(AffineMap *map,
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands) {
while (llvm::any_of(*operands, [](Value v) {
- return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
+ return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
})) {
composeAffineMapAndOperands(map, operands);
}
@@ -819,8 +819,8 @@ void AffineApplyOp::getCanonicalizationPatterns(
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
- if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
@@ -856,16 +856,16 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result,
}
void AffineDmaStartOp::print(OpAsmPrinter &p) {
- p << "affine.dma_start " << *getSrcMemRef() << '[';
+ p << "affine.dma_start " << getSrcMemRef() << '[';
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
- p << "], " << *getDstMemRef() << '[';
+ p << "], " << getDstMemRef() << '[';
p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
- p << "], " << *getTagMemRef() << '[';
+ p << "], " << getTagMemRef() << '[';
p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
- p << "], " << *getNumElements();
+ p << "], " << getNumElements();
if (isStrided()) {
- p << ", " << *getStride();
- p << ", " << *getNumElementsPerStride();
+ p << ", " << getStride();
+ p << ", " << getNumElementsPerStride();
}
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
<< getTagMemRefType();
@@ -951,11 +951,11 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
}
LogicalResult AffineDmaStartOp::verify() {
- if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
+ if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA source to be of memref type");
- if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
+ if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA destination to be of memref type");
- if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
+ if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
// DMAs from different memory spaces supported.
@@ -971,19 +971,19 @@ LogicalResult AffineDmaStartOp::verify() {
}
for (auto idx : getSrcIndices()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("src index must be a dimension or symbol identifier");
}
for (auto idx : getDstIndices()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("dst index must be a dimension or symbol identifier");
}
for (auto idx : getTagIndices()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("tag index must be a dimension or symbol identifier");
@@ -1012,12 +1012,12 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
}
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
- p << "affine.dma_wait " << *getTagMemRef() << '[';
+ p << "affine.dma_wait " << getTagMemRef() << '[';
SmallVector<Value, 2> operands(getTagIndices());
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
p << "], ";
p.printOperand(getNumElements());
- p << " : " << getTagMemRef()->getType();
+ p << " : " << getTagMemRef().getType();
}
// Parse AffineDmaWaitOp.
@@ -1056,10 +1056,10 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
}
LogicalResult AffineDmaWaitOp::verify() {
- if (!getOperand(0)->getType().isa<MemRefType>())
+ if (!getOperand(0).getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
for (auto idx : getTagIndices()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("index must be a dimension or symbol identifier");
@@ -1123,8 +1123,7 @@ static LogicalResult verify(AffineForOp op) {
// Check that the body defines as single block argument for the induction
// variable.
auto *body = op.getBody();
- if (body->getNumArguments() != 1 ||
- !body->getArgument(0)->getType().isIndex())
+ if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
return op.emitOpError(
"expected body to have a single index argument for the "
"induction variable");
@@ -1553,7 +1552,7 @@ bool AffineForOp::matchingBoundOperandList() {
Region &AffineForOp::getLoopBody() { return region(); }
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
- return !region().isAncestor(value->getParentRegion());
+ return !region().isAncestor(value.getParentRegion());
}
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
@@ -1571,9 +1570,9 @@ bool mlir::isForInductionVar(Value val) {
/// not an induction variable, then return nullptr.
AffineForOp mlir::getForInductionVarOwner(Value val) {
auto ivArg = val.dyn_cast<BlockArgument>();
- if (!ivArg || !ivArg->getOwner())
+ if (!ivArg || !ivArg.getOwner())
return AffineForOp();
- auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
+ auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
return dyn_cast<AffineForOp>(containingInst);
}
@@ -1744,7 +1743,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result,
result.addOperands(operands);
if (map)
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
- auto memrefType = operands[0]->getType().cast<MemRefType>();
+ auto memrefType = operands[0].getType().cast<MemRefType>();
result.types.push_back(memrefType.getElementType());
}
@@ -1753,14 +1752,14 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
- auto memrefType = memref->getType().cast<MemRefType>();
+ auto memrefType = memref.getType().cast<MemRefType>();
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
result.types.push_back(memrefType.getElementType());
}
void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
ValueRange indices) {
- auto memrefType = memref->getType().cast<MemRefType>();
+ auto memrefType = memref.getType().cast<MemRefType>();
auto rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
@@ -1789,7 +1788,7 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AffineLoadOp::print(OpAsmPrinter &p) {
- p << "affine.load " << *getMemRef() << '[';
+ p << "affine.load " << getMemRef() << '[';
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
p << ']';
@@ -1816,7 +1815,7 @@ LogicalResult AffineLoadOp::verify() {
}
for (auto idx : getMapOperands()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("index must be a dimension or symbol identifier");
@@ -1854,7 +1853,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result,
void AffineStoreOp::build(Builder *builder, OperationState &result,
Value valueToStore, Value memref,
ValueRange indices) {
- auto memrefType = memref->getType().cast<MemRefType>();
+ auto memrefType = memref.getType().cast<MemRefType>();
auto rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
@@ -1885,8 +1884,8 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AffineStoreOp::print(OpAsmPrinter &p) {
- p << "affine.store " << *getValueToStore();
- p << ", " << *getMemRef() << '[';
+ p << "affine.store " << getValueToStore();
+ p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
p << ']';
@@ -1896,7 +1895,7 @@ void AffineStoreOp::print(OpAsmPrinter &p) {
LogicalResult AffineStoreOp::verify() {
// First operand must have same type as memref element type.
- if (getValueToStore()->getType() != getMemRefType().getElementType())
+ if (getValueToStore().getType() != getMemRefType().getElementType())
return emitOpError("first operand must have same type memref element type");
auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
@@ -1914,7 +1913,7 @@ LogicalResult AffineStoreOp::verify() {
}
for (auto idx : getMapOperands()) {
- if (!idx->getType().isIndex())
+ if (!idx.getType().isIndex())
return emitOpError("index to store must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("index must be a dimension or symbol identifier");
@@ -2059,7 +2058,7 @@ static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
}
void print(OpAsmPrinter &p, AffinePrefetchOp op) {
- p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '[';
+ p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
SmallVector<Value, 2> operands(op.getMapOperands());
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index df6015de1b9..1dfa2460b61 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -47,8 +47,8 @@ static Value emitUniformPerLayerDequantize(Location loc, Value input,
return nullptr;
}
- Type storageType = elementType.castToStorageType(input->getType());
- Type realType = elementType.castToExpressedType(input->getType());
+ Type storageType = elementType.castToStorageType(input.getType());
+ Type realType = elementType.castToExpressedType(input.getType());
Type intermediateType =
castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
assert(storageType && "cannot cast to storage type");
@@ -90,7 +90,7 @@ emitUniformPerAxisDequantize(Location loc, Value input,
static Value emitDequantize(Location loc, Value input,
PatternRewriter &rewriter) {
- Type inputType = input->getType();
+ Type inputType = input.getType();
QuantizedType qElementType =
QuantizedType::getQuantizedElementType(inputType);
if (auto uperLayerElementType =
@@ -113,8 +113,8 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const override {
- Type inputType = op.arg()->getType();
- Type outputType = op.getResult()->getType();
+ Type inputType = op.arg().getType();
+ Type outputType = op.getResult().getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 8cea97c693c..5eb7492c424 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -53,11 +53,11 @@ struct UniformBinaryOpInfo {
UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
- lhsType(getUniformElementType(lhs->getType())),
- rhsType(getUniformElementType(rhs->getType())),
+ lhsType(getUniformElementType(lhs.getType())),
+ rhsType(getUniformElementType(rhs.getType())),
resultType(getUniformElementType(*op->result_type_begin())),
- lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
- rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
+ lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
+ rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
resultStorageType(
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 32d7fae65d9..e750d0fefff 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -110,7 +110,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// to encode target module" has landed.
// auto functionType = kernelFunc.getType();
// for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
- // if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
+ // if (getKernelOperand(i).getType() != functionType.getInput(i)) {
// return emitOpError("type of function argument ")
// << i << " does not match";
// }
@@ -137,7 +137,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
if (allReduce.body().front().getNumArguments() != 2)
return allReduce.emitError("expected two region arguments");
for (auto argument : allReduce.body().front().getArguments()) {
- if (argument->getType() != allReduce.getType())
+ if (argument.getType() != allReduce.getType())
return allReduce.emitError("incorrect region argument type");
}
unsigned yieldCount = 0;
@@ -145,7 +145,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
if (yield.getNumOperands() != 1)
return allReduce.emitError("expected one gpu.yield operand");
- if (yield.getOperand(0)->getType() != allReduce.getType())
+ if (yield.getOperand(0).getType() != allReduce.getType())
return allReduce.emitError("incorrect gpu.yield type");
++yieldCount;
}
@@ -157,8 +157,8 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
}
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
- auto type = shuffleOp.value()->getType();
- if (shuffleOp.result()->getType() != type) {
+ auto type = shuffleOp.value().getType();
+ if (shuffleOp.result().getType() != type) {
return shuffleOp.emitOpError()
<< "requires the same type for value operand and result";
}
@@ -170,10 +170,8 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
}
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
- p << ShuffleOp::getOperationName() << ' ';
- p.printOperands(op.getOperands());
- p << ' ' << op.mode() << " : ";
- p.printType(op.value()->getType());
+ p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
+ << op.mode() << " : " << op.value().getType();
}
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
@@ -201,14 +199,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
// LaunchOp
//===----------------------------------------------------------------------===//
-static SmallVector<Type, 4> getValueTypes(ValueRange values) {
- SmallVector<Type, 4> types;
- types.reserve(values.size());
- for (Value v : values)
- types.push_back(v->getType());
- return types;
-}
-
void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
Value gridSizeY, Value gridSizeZ, Value blockSizeX,
Value blockSizeY, Value blockSizeZ, ValueRange operands) {
@@ -224,7 +214,7 @@ void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
Block *body = new Block();
body->addArguments(
std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
- body->addArguments(getValueTypes(operands));
+ body->addArguments(llvm::to_vector<4>(operands.getTypes()));
kernelRegion->push_back(body);
}
@@ -309,10 +299,10 @@ LogicalResult verify(LaunchOp op) {
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
ValueRange operands, KernelDim3 ids) {
- p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
- p << *size.x << " = " << *operands[0] << ", ";
- p << *size.y << " = " << *operands[1] << ", ";
- p << *size.z << " = " << *operands[2] << ')';
+ p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
+ p << size.x << " = " << operands[0] << ", ";
+ p << size.y << " = " << operands[1] << ", ";
+ p << size.z << " = " << operands[2] << ')';
}
void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
@@ -335,8 +325,8 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
p << ' ' << op.getArgsKeyword() << '(';
Block *entryBlock = &op.body().front();
interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
- p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
- << " = " << *operands[i];
+ p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
+ << " = " << operands[i];
});
p << ") ";
}
@@ -486,14 +476,14 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
Value operand = operands[index];
- if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
+ if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp()))
continue;
found = true;
Value internalConstant =
- rewriter.clone(*operand->getDefiningOp())->getResult(0);
+ rewriter.clone(*operand.getDefiningOp())->getResult(0);
Value kernelArg = *std::next(kernelArgs.begin(), index);
- kernelArg->replaceAllUsesWith(internalConstant);
+ kernelArg.replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
@@ -740,7 +730,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
p << ' ' << keyword << '(';
interleaveComma(values, p,
- [&p](BlockArgument v) { p << *v << " : " << v->getType(); });
+ [&p](BlockArgument v) { p << v << " : " << v.getType(); });
p << ')';
}
@@ -790,7 +780,7 @@ static LogicalResult verifyAttributions(Operation *op,
ArrayRef<BlockArgument> attributions,
unsigned memorySpace) {
for (Value v : attributions) {
- auto type = v->getType().dyn_cast<MemRefType>();
+ auto type = v.getType().dyn_cast<MemRefType>();
if (!type)
return op->emitOpError() << "expected memref type in attribution";
@@ -814,7 +804,7 @@ LogicalResult GPUFuncOp::verifyBody() {
ArrayRef<Type> funcArgTypes = getType().getInputs();
for (unsigned i = 0; i < numFuncArguments; ++i) {
- Type blockArgType = front().getArgument(i)->getType();
+ Type blockArgType = front().getArgument(i).getType();
if (funcArgTypes[i] != blockArgType)
return emitOpError() << "expected body region argument #" << i
<< " to be of type " << funcArgTypes[i] << ", got "
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 2d00ac03d33..37f9c2e7b84 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -45,7 +45,7 @@ static void injectGpuIndexOperations(Location loc, Region &body) {
// Replace the leading 12 function args with the respective thread/block index
// operations. Iterate backwards since args are erased and indices change.
for (int i = 11; i >= 0; --i) {
- firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
+ firstBlock.getArgument(i).replaceAllUsesWith(indexOps[i]);
firstBlock.eraseArgument(i);
}
}
@@ -66,7 +66,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
}
for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
- auto operandOp = launch.getKernelOperand(i)->getDefiningOp();
+ auto operandOp = launch.getKernelOperand(i).getDefiningOp();
if (!operandOp || !isInliningBeneficiary(operandOp)) {
newLaunchArgs.push_back(launch.getKernelOperand(i));
continue;
@@ -77,7 +77,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
continue;
}
auto clone = kernelBuilder.clone(*operandOp, map);
- firstBlock.getArgument(i)->replaceAllUsesWith(clone->getResult(0));
+ firstBlock.getArgument(i).replaceAllUsesWith(clone->getResult(0));
firstBlock.eraseArgument(i);
}
if (newLaunchArgs.size() == launch.getNumKernelOperands())
@@ -88,7 +88,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
SmallVector<Type, 8> newArgumentTypes;
newArgumentTypes.reserve(firstBlock.getNumArguments());
for (auto value : firstBlock.getArguments()) {
- newArgumentTypes.push_back(value->getType());
+ newArgumentTypes.push_back(value.getType());
}
kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {}));
auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 71b7064ac63..5c96581741b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -35,16 +35,16 @@ using namespace mlir::LLVM;
//===----------------------------------------------------------------------===//
static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
- << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+ << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
- p << " : " << op.lhs()->getType();
+ p << " : " << op.lhs().getType();
}
static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
- << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+ << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
- p << " : " << op.lhs()->getType();
+ p << " : " << op.lhs().getType();
}
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
@@ -120,10 +120,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
- auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
+ auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
op.getContext());
- p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
+ p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
p.printOptionalAttrDict(op.getAttrs());
else
@@ -168,7 +168,7 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
- p << op.getOperationName() << ' ' << *op.base() << '['
+ p << op.getOperationName() << ' ' << op.base() << '['
<< op.getOperands().drop_front() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << funcTy;
@@ -212,9 +212,9 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
- p << op.getOperationName() << ' ' << *op.addr();
+ p << op.getOperationName() << ' ' << op.addr();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.addr()->getType();
+ p << " : " << op.addr().getType();
}
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
@@ -256,9 +256,9 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
- p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
+ p << op.getOperationName() << ' ' << op.value() << ", " << op.addr();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.addr()->getType();
+ p << " : " << op.addr().getType();
}
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
@@ -300,7 +300,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
if (isDirect)
p.printSymbolName(callee.getValue());
else
- p << *op.getOperand(0);
+ p << op.getOperand(0);
p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
@@ -408,17 +408,17 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
void LLVM::ExtractElementOp::build(Builder *b, OperationState &result,
Value vector, Value position,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>();
+ auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
auto llvmType = wrappedVectorType.getVectorElementType();
build(b, result, llvmType, vector, position);
result.addAttributes(attrs);
}
static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
- p << op.getOperationName() << ' ' << *op.vector() << "[" << *op.position()
- << " : " << op.position()->getType() << "]";
+ p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
+ << " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.vector()->getType();
+ p << " : " << op.vector().getType();
}
// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
@@ -450,9 +450,9 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
- p << op.getOperationName() << ' ' << *op.container() << op.position();
+ p << op.getOperationName() << ' ' << op.container() << op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
- p << " : " << op.container()->getType();
+ p << " : " << op.container().getType();
}
// Extract the type at `position` in the wrapped LLVM IR aggregate type
@@ -542,10 +542,10 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
- p << op.getOperationName() << ' ' << *op.value() << ", " << *op.vector()
- << "[" << *op.position() << " : " << op.position()->getType() << "]";
+ p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
+ << op.position() << " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.vector()->getType();
+ p << " : " << op.vector().getType();
}
// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
@@ -586,10 +586,10 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
- p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
+ p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
- p << " : " << op.container()->getType();
+ p << " : " << op.container().getType();
}
// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
@@ -629,10 +629,10 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static void printSelectOp(OpAsmPrinter &p, SelectOp &op) {
- p << op.getOperationName() << ' ' << *op.condition() << ", "
- << *op.trueValue() << ", " << *op.falseValue();
+ p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue()
+ << ", " << op.falseValue();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
+ p << " : " << op.condition().getType() << ", " << op.trueValue().getType();
}
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
@@ -686,7 +686,7 @@ static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
- p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
+ p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
p.printSuccessorAndUseList(op.getOperation(), 0);
p << ", ";
p.printSuccessorAndUseList(op.getOperation(), 1);
@@ -733,7 +733,7 @@ static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
if (op.getNumOperands() == 0)
return;
- p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
+ p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
}
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
@@ -761,7 +761,7 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
static void printUndefOp(OpAsmPrinter &p, UndefOp &op) {
p << op.getOperationName();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.res()->getType();
+ p << " : " << op.res().getType();
}
// <operation> ::= `llvm.mlir.undef` attribute-dict? : type
@@ -792,7 +792,7 @@ GlobalOp AddressOfOp::getGlobal() {
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
p << op.getOperationName() << " @" << op.global_name();
p.printOptionalAttrDict(op.getAttrs(), {"global_name"});
- p << " : " << op.getResult()->getType();
+ p << " : " << op.getResult().getType();
}
static ParseResult parseAddressOfOp(OpAsmParser &parser,
@@ -816,7 +816,7 @@ static LogicalResult verify(AddressOfOp op) {
"must reference a global defined by 'llvm.mlir.global'");
if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
- op.getResult()->getType())
+ op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referred global");
@@ -830,7 +830,7 @@ static LogicalResult verify(AddressOfOp op) {
static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
p << op.getOperationName() << '(' << op.value() << ')';
p.printOptionalAttrDict(op.getAttrs(), {"value"});
- p << " : " << op.res()->getType();
+ p << " : " << op.res().getType();
}
// <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type
@@ -1060,7 +1060,7 @@ static LogicalResult verify(GlobalOp op) {
void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
Value v2, ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>();
+ auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
auto vType = LLVMType::getVectorTy(
wrappedContainerType1.getVectorElementType(), mask.size());
build(b, result, vType, v1, v2, mask);
@@ -1068,10 +1068,10 @@ void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
}
static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
- p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
+ p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
<< op.mask();
p.printOptionalAttrDict(op.getAttrs(), {"mask"});
- p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
+ p << " : " << op.v1().getType() << ", " << op.v2().getType();
}
// <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
@@ -1329,7 +1329,7 @@ static LogicalResult verify(LLVMFuncOp op) {
unsigned numArguments = funcType->getNumParams();
Block &entryBlock = op.front();
for (unsigned i = 0; i < numArguments; ++i) {
- Type argType = entryBlock.getArgument(i)->getType();
+ Type argType = entryBlock.getArgument(i).getType();
auto argLLVMType = argType.dyn_cast<LLVMType>();
if (!argLLVMType)
return op.emitOpError("entry block argument #")
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 7644cc69218..144afa4c5e1 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -48,30 +48,30 @@ Value Aliases::find(Value v) {
auto it = aliases.find(v);
if (it != aliases.end()) {
- assert(it->getSecond()->getType().isa<MemRefType>() && "Memref expected");
+ assert(it->getSecond().getType().isa<MemRefType>() && "Memref expected");
return it->getSecond();
}
while (true) {
if (v.isa<BlockArgument>())
return v;
- if (auto alloc = dyn_cast_or_null<AllocOp>(v->getDefiningOp())) {
+ if (auto alloc = dyn_cast_or_null<AllocOp>(v.getDefiningOp())) {
if (isStrided(alloc.getType()))
return alloc.getResult();
}
- if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
+ if (auto slice = dyn_cast_or_null<SliceOp>(v.getDefiningOp())) {
auto it = aliases.insert(std::make_pair(v, find(slice.view())));
return it.first->second;
}
- if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
+ if (auto view = dyn_cast_or_null<ViewOp>(v.getDefiningOp())) {
auto it = aliases.insert(std::make_pair(v, view.source()));
return it.first->second;
}
- if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
+ if (auto view = dyn_cast_or_null<SubViewOp>(v.getDefiningOp())) {
v = view.source();
continue;
}
- llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
+ llvm::errs() << "View alias analysis reduces to: " << v << "\n";
llvm_unreachable("unsupported view alias case");
}
}
@@ -224,7 +224,7 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
auto *op = dependence.dependentOpView.op;
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
<< toStringRef(dt) << ": " << *src << " -> " << *op
- << " on " << *dependence.indexingView);
+ << " on " << dependence.indexingView);
res.push_back(op);
}
}
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index b35a8ed0fd8..9b850431113 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -88,7 +88,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
for (auto it : llvm::enumerate(values))
blockTypes.push_back((it.index() < nViews)
? getElementTypeOrSelf(it.value())
- : it.value()->getType());
+ : it.value().getType());
assert(op->getRegions().front().empty());
op->getRegions().front().push_front(new Block);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index efbb95e7319..cf27a817edb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -120,7 +120,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
for (unsigned i = 0; i < nViews; ++i) {
auto viewType = op.getShapedType(i);
- if (viewType.getElementType() != block.getArgument(i)->getType())
+ if (viewType.getElementType() != block.getArgument(i).getType())
return op.emitOpError("expected block argument ")
<< i << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
@@ -139,7 +139,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
"number of loops");
for (unsigned i = 0; i < nLoops; ++i) {
- if (!block.getArgument(i)->getType().isIndex())
+ if (!block.getArgument(i).getType().isIndex())
return op.emitOpError("expected block argument ")
<< i << " to be of IndexType";
}
@@ -148,7 +148,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
unsigned memrefArgIndex = i + nLoops;
auto viewType = op.getShapedType(i);
if (viewType.getElementType() !=
- block.getArgument(memrefArgIndex)->getType())
+ block.getArgument(memrefArgIndex).getType())
return op.emitOpError("expected block argument ")
<< memrefArgIndex << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
@@ -314,10 +314,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, RangeOp op) {
- p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":"
- << *op.step();
+ p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":"
+ << op.step();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getResult()->getType();
+ p << " : " << op.getResult().getType();
}
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
@@ -541,7 +541,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
result.addOperands(base);
result.addOperands(indexings);
- auto memRefType = base->getType().cast<MemRefType>();
+ auto memRefType = base.getType().cast<MemRefType>();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(memRefType, strides, offset);
@@ -560,7 +560,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
static void print(OpAsmPrinter &p, SliceOp op) {
auto indexings = op.indexings();
- p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
+ p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings
<< "] ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getBaseViewType();
@@ -599,7 +599,7 @@ static LogicalResult verify(SliceOp op) {
<< rank << " indexings, got " << llvm::size(op.indexings());
unsigned index = 0;
for (auto indexing : op.indexings()) {
- if (indexing->getType().isa<IndexType>())
+ if (indexing.getType().isa<IndexType>())
--rank;
++index;
}
@@ -618,7 +618,7 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
auto permutationMap = permutation.getValue();
assert(permutationMap);
- auto memRefType = view->getType().cast<MemRefType>();
+ auto memRefType = view.getType().cast<MemRefType>();
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
// Compute permuted sizes.
@@ -644,10 +644,10 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
}
static void print(OpAsmPrinter &p, TransposeOp op) {
- p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
+ p << op.getOperationName() << " " << op.view() << " " << op.permutation();
p.printOptionalAttrDict(op.getAttrs(),
{TransposeOp::getPermutationAttrName()});
- p << " : " << op.view()->getType();
+ p << " : " << op.view().getType();
}
static ParseResult parseTransposeOp(OpAsmParser &parser,
@@ -698,9 +698,9 @@ LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
for (unsigned i = 0; i != nOutputViews; ++i) {
auto elementType = genericOp.getOutputShapedType(i).getElementType();
- if (op.getOperand(i)->getType() != elementType)
+ if (op.getOperand(i).getType() != elementType)
return op.emitOpError("type of return operand ")
- << i << " (" << op.getOperand(i)->getType()
+ << i << " (" << op.getOperand(i).getType()
<< ") doesn't match view element type (" << elementType << ")";
}
return success();
@@ -765,7 +765,7 @@ static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
- auto fillType = op.value()->getType();
+ auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
return success();
@@ -816,9 +816,9 @@ verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
}
static LogicalResult verify(ConvOp op) {
- auto oType = op.output()->getType().cast<MemRefType>();
- auto fType = op.filter()->getType().cast<MemRefType>();
- auto iType = op.input()->getType().cast<MemRefType>();
+ auto oType = op.output().getType().cast<MemRefType>();
+ auto fType = op.filter().getType().cast<MemRefType>();
+ auto iType = op.input().getType().cast<MemRefType>();
if (oType.getElementType() != iType.getElementType() ||
oType.getElementType() != fType.getElementType())
return op.emitOpError("expects memref elemental types to match");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9df7bce0879..043d9c0e7cd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -133,8 +133,7 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
<< "\n");
- LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
- << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
return ViewDimension{view, static_cast<unsigned>(en2.index())};
}
}
@@ -146,9 +145,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, unsigned producerIdx,
OperationFolder *folder) {
auto subView = dyn_cast_or_null<SubViewOp>(
- consumer.getInput(consumerIdx)->getDefiningOp());
- auto slice = dyn_cast_or_null<SliceOp>(
- consumer.getInput(consumerIdx)->getDefiningOp());
+ consumer.getInput(consumerIdx).getDefiningOp());
+ auto slice =
+ dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
assert(subView || slice);
(void)subView;
(void)slice;
@@ -272,13 +271,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
- << " view: " << *producedView
+ << " view: " << producedView
<< " output index: " << producerIdx);
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
- auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp());
- auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp());
+ auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
if (!subView && !slice) {
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 4bc452afa36..9657daf9137 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -166,7 +166,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
// TODO(ntv): non-identity layout.
auto isStaticMemRefWithIdentityLayout = [](Value v) {
- auto m = v->getType().dyn_cast<MemRefType>();
+ auto m = v.getType().dyn_cast<MemRefType>();
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
return false;
return true;
@@ -281,7 +281,7 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
LinalgOp linOp = cast<LinalgOp>(op);
SetVector<Value> subViews;
for (auto it : linOp.getInputsAndOutputs())
- if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
promoteSubViewOperands(rewriter, linOp, subViews);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index b8b27958ff5..eb605699890 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -47,10 +47,10 @@ static llvm::cl::opt<bool> clPromoteDynamic(
llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
- auto *ctx = size->getContext();
+ auto *ctx = size.getContext();
auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
if (!dynamicBuffers)
- if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
+ if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
return alloc(
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
Value mul = muli(constant_index(width), size);
@@ -116,7 +116,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
res.reserve(subViews.size());
DenseMap<Value, PromotionInfo> promotionInfoMap;
for (auto v : subViews) {
- SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
+ SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
auto viewType = subView.getType();
// TODO(ntv): support more cases than just float.
if (!viewType.getElementType().isa<FloatType>())
@@ -128,7 +128,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
}
for (auto v : subViews) {
- SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
+ SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
auto info = promotionInfoMap.find(v);
if (info == promotionInfoMap.end())
continue;
@@ -146,7 +146,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
auto info = promotionInfoMap.find(v);
if (info == promotionInfoMap.end())
continue;
- copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView);
+ copy(cast<SubViewOp>(v.getDefiningOp()), info->second.partialLocalView);
}
return res;
}
@@ -208,7 +208,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
SetVector<Value> subViews;
OpBuilder b(op);
for (auto it : op.getInputsAndOutputs())
- if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3841392dbdb..b77f658aa2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -45,8 +45,8 @@ static llvm::cl::list<unsigned>
llvm::cl::cat(clOptionsCategory));
static bool isZero(Value v) {
- return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
- cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
+ return isa_and_nonnull<ConstantIndexOp>(v.getDefiningOp()) &&
+ cast<ConstantIndexOp>(v.getDefiningOp()).getValue() == 0;
}
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
@@ -201,8 +201,8 @@ void transformIndexedGenericOpIndices(
// variable and replace all uses of the previous value.
Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
pivs[rangeIndex->second]->getValue());
- for (auto &use : oldIndex->getUses()) {
- if (use.getOwner() == newIndex->getDefiningOp())
+ for (auto &use : oldIndex.getUses()) {
+ if (use.getOwner() == newIndex.getDefiningOp())
continue;
use.set(newIndex);
}
@@ -258,7 +258,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
++viewIndex) {
Value view = *(viewIteratorBegin + viewIndex);
- unsigned rank = view->getType().cast<MemRefType>().getRank();
+ unsigned rank = view.getType().cast<MemRefType>().getRank();
auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
// If the view is not tiled, we can use it as is.
if (!isTiled(map, tileSizes)) {
@@ -299,8 +299,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
// defined.
if (folder)
for (auto v : llvm::concat<Value>(lbs, subViewSizes))
- if (v->use_empty())
- v->getDefiningOp()->erase();
+ if (v.use_empty())
+ v.getDefiningOp()->erase();
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 66b82fd14e4..1b8a7be7a22 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -35,9 +35,9 @@ using namespace mlir::loop;
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
ValueHandle range) {
assert(range.getType() && "expected !linalg.range type");
- assert(range.getValue()->getDefiningOp() &&
+ assert(range.getValue().getDefiningOp() &&
"need operations to extract range parts");
- auto rangeOp = cast<RangeOp>(range.getValue()->getDefiningOp());
+ auto rangeOp = cast<RangeOp>(range.getValue().getDefiningOp());
auto lb = rangeOp.min();
auto ub = rangeOp.max();
auto step = rangeOp.step();
@@ -168,7 +168,7 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
res.reserve(nOperands);
for (unsigned i = 0; i < nOperands; ++i) {
res.push_back(op->getOperand(numViews + i));
- auto t = res.back()->getType();
+ auto t = res.back().getType();
(void)t;
assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) &&
"expected scalar or vector type");
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index acbab01df79..5452b3d4ab8 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -69,23 +69,22 @@ void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
}
LogicalResult verify(ForOp op) {
- if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp()))
+ if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
if (cst.getValue() <= 0)
return op.emitOpError("constant step operand must be positive");
// Check that the body defines as single block argument for the induction
// variable.
auto *body = op.getBody();
- if (body->getNumArguments() != 1 ||
- !body->getArgument(0)->getType().isIndex())
+ if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
return op.emitOpError("expected body to have a single index argument for "
"the induction variable");
return success();
}
static void print(OpAsmPrinter &p, ForOp op) {
- p << op.getOperationName() << " " << *op.getInductionVar() << " = "
- << *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step();
+ p << op.getOperationName() << " " << op.getInductionVar() << " = "
+ << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
p.printRegion(op.region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
@@ -126,11 +125,11 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
Region &ForOp::getLoopBody() { return region(); }
bool ForOp::isDefinedOutsideOfLoop(Value value) {
- return !region().isAncestor(value->getParentRegion());
+ return !region().isAncestor(value.getParentRegion());
}
LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
- for (auto *op : ops)
+ for (auto op : ops)
op->moveBefore(this->getOperation());
return success();
}
@@ -139,8 +138,8 @@ ForOp mlir::loop::getForInductionVarOwner(Value val) {
auto ivArg = val.dyn_cast<BlockArgument>();
if (!ivArg)
return ForOp();
- assert(ivArg->getOwner() && "unlinked block argument");
- auto *containingInst = ivArg->getOwner()->getParentOp();
+ assert(ivArg.getOwner() && "unlinked block argument");
+ auto *containingInst = ivArg.getOwner()->getParentOp();
return dyn_cast_or_null<ForOp>(containingInst);
}
@@ -205,7 +204,7 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, IfOp op) {
- p << IfOp::getOperationName() << " " << *op.condition();
+ p << IfOp::getOperationName() << " " << op.condition();
p.printRegion(op.thenRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index faeff246bd2..8ff6fbed587 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -36,8 +36,8 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
/// value of x if the casts invert each other.
- auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp());
- if (!srcScastOp || srcScastOp.arg()->getType() != getType())
+ auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
+ if (!srcScastOp || srcScastOp.arg().getType() != getType())
return OpFoldResult();
return srcScastOp.arg();
}
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 08a5ec59e8d..d62dd595985 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -52,7 +52,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
// Does the qbarrier convert to a quantized type. This will not be true
// if a quantized type has not yet been chosen or if the cast to an equivalent
// storage type is not supported.
- Type qbarrierResultType = qbarrier.getResult()->getType();
+ Type qbarrierResultType = qbarrier.getResult().getType();
QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
if (!quantizedElementType) {
@@ -66,7 +66,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
// type? This will not be true if the qbarrier is superfluous (converts
// from and to a quantized type).
if (!quantizedElementType.isCompatibleExpressedType(
- qbarrier.arg()->getType())) {
+ qbarrier.arg().getType())) {
return matchFailure();
}
@@ -86,7 +86,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
// When creating the new const op, use a fused location that combines the
// original const and the qbarrier that led to the quantization.
auto fusedLoc = FusedLoc::get(
- {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+ {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()},
rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index d6fd35418b3..98c36fb9b08 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -104,7 +104,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
// Replace the values directly with the return operands.
assert(valuesToRepl.size() == 1 &&
"spv.ReturnValue expected to only handle one result");
- valuesToRepl.front()->replaceAllUsesWith(retValOp.value());
+ valuesToRepl.front().replaceAllUsesWith(retValOp.value());
}
};
} // namespace
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ff20e091f91..a6f5d358d0c 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -167,8 +167,8 @@ printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
static LogicalResult verifyCastOp(Operation *op,
bool requireSameBitWidth = true) {
- Type operandType = op->getOperand(0)->getType();
- Type resultType = op->getResult(0)->getType();
+ Type operandType = op->getOperand(0).getType();
+ Type resultType = op->getResult(0).getType();
// ODS checks that result type and operand type have the same shape.
if (auto vectorType = operandType.dyn_cast<VectorType>()) {
@@ -271,8 +271,8 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
//
// TODO(ravishankarm): Check that the value type satisfies restrictions of
// SPIR-V OpLoad/OpStore operations
- if (val->getType() !=
- ptr->getType().cast<spirv::PointerType>().getPointeeType()) {
+ if (val.getType() !=
+ ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
@@ -497,11 +497,11 @@ static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
}
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
- if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
+ if (op->getOperand(0).getType() != op->getResult(0).getType()) {
return op->emitError("expected the same type for the first operand and "
"result, but provided ")
- << op->getOperand(0)->getType() << " and "
- << op->getResult(0)->getType();
+ << op->getOperand(0).getType() << " and "
+ << op->getResult(0).getType();
}
return success();
}
@@ -547,13 +547,12 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
printer << spirv::stringifyMemorySemantics(
static_cast<spirv::MemorySemantics>(
memorySemanticsAttr.getInt()))
- << "\" " << op->getOperands() << " : "
- << op->getOperand(0)->getType();
+ << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
}
// Verifies an atomic update op.
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
- auto ptrType = op->getOperand(0)->getType().cast<spirv::PointerType>();
+ auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
auto elementType = ptrType.getPointeeType();
if (!elementType.isa<IntegerType>())
return op->emitOpError(
@@ -561,7 +560,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< elementType;
if (op->getNumOperands() > 1) {
- auto valueType = op->getOperand(1)->getType();
+ auto valueType = op->getOperand(1).getType();
if (valueType != elementType)
return op->emitOpError("expected value to have the same type as the "
"pointer operand's pointee type ")
@@ -595,8 +594,8 @@ static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
}
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
- printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
- << unaryOp->getOperand(0)->getType();
+ printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
+ << unaryOp->getOperand(0).getType();
}
/// Result of a logical op must be a scalar or vector of boolean type.
@@ -634,7 +633,7 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
- << logicalOp->getOperand(0)->getType();
+ << logicalOp->getOperand(0).getType();
}
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
@@ -657,16 +656,16 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
Value base = op->getOperand(0);
Value shift = op->getOperand(1);
- printer << op->getName() << ' ' << *base << ", " << *shift << " : "
- << base->getType() << ", " << shift->getType();
+ printer << op->getName() << ' ' << base << ", " << shift << " : "
+ << base.getType() << ", " << shift.getType();
}
static LogicalResult verifyShiftOp(Operation *op) {
- if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
+ if (op->getOperand(0).getType() != op->getResult(0).getType()) {
return op->emitError("expected the same type for the first operand and "
"result, but provided ")
- << op->getOperand(0)->getType() << " and "
- << op->getResult(0)->getType();
+ << op->getOperand(0).getType() << " and "
+ << op->getResult(0).getType();
}
return success();
}
@@ -704,7 +703,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
}
index = 0;
if (resultType.isa<spirv::StructType>()) {
- Operation *op = indexSSA->getDefiningOp();
+ Operation *op = indexSSA.getDefiningOp();
if (!op) {
emitError(baseLoc, "'spv.AccessChain' op index must be an "
"integer spv.constant to access "
@@ -734,7 +733,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
void spirv::AccessChainOp::build(Builder *builder, OperationState &state,
Value basePtr, ValueRange indices) {
- auto type = getElementPtrType(basePtr->getType(), indices, state.location);
+ auto type = getElementPtrType(basePtr.getType(), indices, state.location);
assert(type && "Unable to deduce return type based on basePtr and indices");
build(builder, state, type, basePtr, indices);
}
@@ -768,14 +767,14 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
}
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
- printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
- << '[' << op.indices() << "] : " << op.base_ptr()->getType();
+ printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
+ << '[' << op.indices() << "] : " << op.base_ptr().getType();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
accessChainOp.indices().end());
- auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
+ auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
indices, accessChainOp.getLoc());
if (!resultType) {
return failure();
@@ -808,7 +807,7 @@ struct CombineChainedAccessChain
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
- accessChainOp.base_ptr()->getDefiningOp());
+ accessChainOp.base_ptr().getDefiningOp());
if (!parentAccessChainOp) {
return matchFailure();
@@ -868,7 +867,7 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
printer.printSymbolName(addressOfOp.variable());
// Print the type.
- printer << " : " << addressOfOp.pointer()->getType();
+ printer << " : " << addressOfOp.pointer().getType();
}
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
@@ -878,7 +877,7 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
if (!varOp) {
return addressOfOp.emitOpError("expected spv.globalVariable symbol");
}
- if (addressOfOp.pointer()->getType() != varOp.type()) {
+ if (addressOfOp.pointer().getType() != varOp.type()) {
return addressOfOp.emitOpError(
"result type mismatch with the referenced global variable's type");
}
@@ -926,7 +925,7 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
<< stringifyScope(atomOp.memory_scope()) << "\" \""
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
- << atomOp.getOperands() << " : " << atomOp.pointer()->getType();
+ << atomOp.getOperands() << " : " << atomOp.pointer().getType();
}
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
@@ -934,19 +933,19 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
// "The type of Value must be the same as Result Type. The type of the value
// pointed to by Pointer must be the same as Result Type. This type must also
// match the type of Comparator."
- if (atomOp.getType() != atomOp.value()->getType())
+ if (atomOp.getType() != atomOp.value().getType())
return atomOp.emitOpError("value operand must have the same type as the op "
"result, but found ")
- << atomOp.value()->getType() << " vs " << atomOp.getType();
+ << atomOp.value().getType() << " vs " << atomOp.getType();
- if (atomOp.getType() != atomOp.comparator()->getType())
+ if (atomOp.getType() != atomOp.comparator().getType())
return atomOp.emitOpError(
"comparator operand must have the same type as the op "
"result, but found ")
- << atomOp.comparator()->getType() << " vs " << atomOp.getType();
+ << atomOp.comparator().getType() << " vs " << atomOp.getType();
Type pointeeType =
- atomOp.pointer()->getType().cast<spirv::PointerType>().getPointeeType();
+ atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
if (atomOp.getType() != pointeeType)
return atomOp.emitOpError(
"pointer operand's pointee type must have the same "
@@ -966,8 +965,8 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
// TODO: The SPIR-V spec validation rules are different for different
// versions.
- auto operandType = bitcastOp.operand()->getType();
- auto resultType = bitcastOp.result()->getType();
+ auto operandType = bitcastOp.operand().getType();
+ auto resultType = bitcastOp.result().getType();
if (operandType == resultType) {
return bitcastOp.emitError(
"result type must be different from operand type");
@@ -1026,15 +1025,15 @@ static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) {
printer << spirv::BitFieldInsertOp::getOperationName() << ' '
<< bitFieldInsertOp.getOperands() << " : "
- << bitFieldInsertOp.base()->getType() << ", "
- << bitFieldInsertOp.offset()->getType() << ", "
- << bitFieldInsertOp.count()->getType();
+ << bitFieldInsertOp.base().getType() << ", "
+ << bitFieldInsertOp.offset().getType() << ", "
+ << bitFieldInsertOp.count().getType();
}
static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) {
- auto baseType = bitFieldOp.base()->getType();
- auto insertType = bitFieldOp.insert()->getType();
- auto resultType = bitFieldOp.getResult()->getType();
+ auto baseType = bitFieldOp.base().getType();
+ auto insertType = bitFieldOp.insert().getType();
+ auto resultType = bitFieldOp.getResult().getType();
if ((baseType != insertType) || (baseType != resultType)) {
return bitFieldOp.emitError("expected the same type for the base operand, "
@@ -1199,7 +1198,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp,
OpAsmPrinter &printer) {
printer << spirv::CompositeConstructOp::getOperationName() << " "
<< compositeConstructOp.constituents() << " : "
- << compositeConstructOp.getResult()->getType();
+ << compositeConstructOp.getResult().getType();
}
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
@@ -1214,11 +1213,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
}
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
- if (constituents[index]->getType() != cType.getElementType(index)) {
+ if (constituents[index].getType() != cType.getElementType(index)) {
return compositeConstructOp.emitError(
"operand type mismatch: expected operand type ")
<< cType.getElementType(index) << ", but provided "
- << constituents[index]->getType();
+ << constituents[index].getType();
}
}
@@ -1234,7 +1233,7 @@ void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state,
ArrayRef<int32_t> indices) {
auto indexAttr = builder->getI32ArrayAttr(indices);
auto elementType =
- getElementType(composite->getType(), indexAttr, state.location);
+ getElementType(composite.getType(), indexAttr, state.location);
if (!elementType) {
return;
}
@@ -1268,13 +1267,13 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
static void print(spirv::CompositeExtractOp compositeExtractOp,
OpAsmPrinter &printer) {
printer << spirv::CompositeExtractOp::getOperationName() << ' '
- << *compositeExtractOp.composite() << compositeExtractOp.indices()
- << " : " << compositeExtractOp.composite()->getType();
+ << compositeExtractOp.composite() << compositeExtractOp.indices()
+ << " : " << compositeExtractOp.composite().getType();
}
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
- auto resultType = getElementType(compExOp.composite()->getType(),
+ auto resultType = getElementType(compExOp.composite().getType(),
indicesArrayAttr, compExOp.getLoc());
if (!resultType)
return failure();
@@ -1321,21 +1320,21 @@ static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
auto objectType =
- getElementType(compositeInsertOp.composite()->getType(), indicesArrayAttr,
+ getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
compositeInsertOp.getLoc());
if (!objectType)
return failure();
- if (objectType != compositeInsertOp.object()->getType()) {
+ if (objectType != compositeInsertOp.object().getType()) {
return compositeInsertOp.emitOpError("object operand type should be ")
<< objectType << ", but found "
- << compositeInsertOp.object()->getType();
+ << compositeInsertOp.object().getType();
}
- if (compositeInsertOp.composite()->getType() != compositeInsertOp.getType()) {
+ if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
return compositeInsertOp.emitOpError("result type should be the same as "
"the composite type, but found ")
- << compositeInsertOp.composite()->getType() << " vs "
+ << compositeInsertOp.composite().getType() << " vs "
<< compositeInsertOp.getType();
}
@@ -1345,10 +1344,10 @@ static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
static void print(spirv::CompositeInsertOp compositeInsertOp,
OpAsmPrinter &printer) {
printer << spirv::CompositeInsertOp::getOperationName() << " "
- << *compositeInsertOp.object() << ", "
- << *compositeInsertOp.composite() << compositeInsertOp.indices()
- << " : " << compositeInsertOp.object()->getType() << " into "
- << compositeInsertOp.composite()->getType();
+ << compositeInsertOp.object() << ", " << compositeInsertOp.composite()
+ << compositeInsertOp.indices() << " : "
+ << compositeInsertOp.object().getType() << " into "
+ << compositeInsertOp.composite().getType();
}
//===----------------------------------------------------------------------===//
@@ -1707,12 +1706,12 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
}
for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
- if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) {
+ if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
return functionCallOp.emitOpError(
"operand type mismatch: expected operand type ")
<< functionType.getInput(i) << ", but provided "
- << functionCallOp.getOperand(i)->getType()
- << " for operand number " << i;
+ << functionCallOp.getOperand(i).getType() << " for operand number "
+ << i;
}
}
@@ -1724,10 +1723,10 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
}
if (functionCallOp.getNumResults() &&
- (functionCallOp.getResult(0)->getType() != functionType.getResult(0))) {
+ (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
return functionCallOp.emitOpError("result type mismatch: expected ")
<< functionType.getResult(0) << ", but provided "
- << functionCallOp.getResult(0)->getType();
+ << functionCallOp.getResult(0).getType();
}
return success();
@@ -1955,7 +1954,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
void spirv::LoadOp::build(Builder *builder, OperationState &state,
Value basePtr, IntegerAttr memory_access,
IntegerAttr alignment) {
- auto ptrType = basePtr->getType().cast<spirv::PointerType>();
+ auto ptrType = basePtr.getType().cast<spirv::PointerType>();
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
alignment);
}
@@ -1986,7 +1985,7 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
auto *op = loadOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+ loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
<< loadOp.ptr();
@@ -2414,7 +2413,7 @@ static ParseResult parseReferenceOfOp(OpAsmParser &parser,
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
printer << spirv::ReferenceOfOp::getOperationName() << ' ';
printer.printSymbolName(referenceOfOp.spec_const());
- printer << " : " << referenceOfOp.reference()->getType();
+ printer << " : " << referenceOfOp.reference().getType();
}
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
@@ -2424,7 +2423,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
if (!specConstOp) {
return referenceOfOp.emitOpError("expected spv.specConstant symbol");
}
- if (referenceOfOp.reference()->getType() !=
+ if (referenceOfOp.reference().getType() !=
specConstOp.default_value().getType()) {
return referenceOfOp.emitOpError("result type mismatch with the referenced "
"specialization constant's type");
@@ -2461,7 +2460,7 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
- << " : " << retValOp.value()->getType();
+ << " : " << retValOp.value().getType();
}
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
@@ -2472,7 +2471,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
"returns 1 value but enclosing function requires ")
<< numFnResults << " results";
- auto operandType = retValOp.value()->getType();
+ auto operandType = retValOp.value().getType();
auto fnResultType = funcOp.getType().getResult(0);
if (operandType != fnResultType)
return retValOp.emitOpError(" return value's type (")
@@ -2488,7 +2487,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
Value trueValue, Value falseValue) {
- build(builder, state, trueValue->getType(), cond, trueValue, falseValue);
+ build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
}
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
@@ -2514,19 +2513,18 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
- << " : " << op.condition()->getType() << ", "
- << op.result()->getType();
+ << " : " << op.condition().getType() << ", " << op.result().getType();
}
static LogicalResult verify(spirv::SelectOp op) {
- auto resultTy = op.result()->getType();
- if (op.true_value()->getType() != resultTy) {
+ auto resultTy = op.result().getType();
+ if (op.true_value().getType() != resultTy) {
return op.emitOpError("result type and true value type must be the same");
}
- if (op.false_value()->getType() != resultTy) {
+ if (op.false_value().getType() != resultTy) {
return op.emitOpError("result type and false value type must be the same");
}
- if (auto conditionTy = op.condition()->getType().dyn_cast<VectorType>()) {
+ if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
auto resultVectorTy = resultTy.dyn_cast<VectorType>();
if (!resultVectorTy) {
return op.emitOpError("result expected to be of vector type when "
@@ -2695,7 +2693,7 @@ struct ConvertSelectionOpToSelect
cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
auto selectOp = rewriter.create<spirv::SelectOp>(
- selectionOp.getLoc(), trueValue->getType(), brConditionalOp.condition(),
+ selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
trueValue, falseValue);
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
selectOp.getResult(), storeOpAttributes);
@@ -2773,7 +2771,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
// attributes and a valid type of the value.
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
- !isValidType(trueBrStoreOp.value()->getType())) {
+ !isValidType(trueBrStoreOp.value().getType())) {
return matchFailure();
}
@@ -2879,13 +2877,13 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
auto *op = storeOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+ storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
<< storeOp.ptr() << ", " << storeOp.value();
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
- printer << " : " << storeOp.value()->getType();
+ printer << " : " << storeOp.value().getType();
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
@@ -3025,7 +3023,7 @@ static LogicalResult verify(spirv::VariableOp varOp) {
"spv.globalVariable for module-level variables.");
}
- auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
+ auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
if (varOp.storage_class() != pointerType.getStorageClass())
return varOp.emitOpError(
"storage class must match result pointer's storage class");
@@ -3033,7 +3031,7 @@ static LogicalResult verify(spirv::VariableOp varOp) {
if (varOp.getNumOperands() != 0) {
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
// a global (module scope) OpVariable instruction".
- auto *initOp = varOp.getOperand(0)->getDefiningOp();
+ auto *initOp = varOp.getOperand(0).getDefiningOp();
if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant
isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
isa<spirv::AddressOfOp>(initOp)))
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4e030217160..25d5763dc98 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1775,7 +1775,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
<< " from block " << block << "\n");
if (!isFnEntryBlock(block)) {
for (BlockArgument blockArg : block->getArguments()) {
- auto newArg = newBlock->addArgument(blockArg->getType());
+ auto newArg = newBlock->addArgument(blockArg.getType());
mapper.map(blockArg, newArg);
LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
<< " to " << newArg << '\n');
@@ -1816,7 +1816,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
// make sure the old merge block has the same block argument list.
assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
for (BlockArgument blockArg : headerBlock->getArguments()) {
- mergeBlock->addArgument(blockArg->getType());
+ mergeBlock->addArgument(blockArg.getType());
}
// If the loop header block has block arguments, make sure the spv.branch op
@@ -2200,7 +2200,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
"spirv::BitcastOp, only ")
<< wordIndex << " of " << words.size() << " processed";
}
- if (resultTypes[0] == operands[0]->getType() &&
+ if (resultTypes[0] == operands[0].getType() &&
resultTypes[0].isa<IntegerType>()) {
// TODO(b/130356985): This check is added to ignore error in Op verification
// due to both signed and unsigned integers mapping to the same
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index d7971eb5e35..74a959b0ea5 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -507,10 +507,10 @@ void Serializer::printValueIDMap(raw_ostream &os) {
Value val = valueIDPair.first;
os << " " << val << " "
<< "id = " << valueIDPair.second << ' ';
- if (auto *op = val->getDefiningOp()) {
+ if (auto *op = val.getDefiningOp()) {
os << "from op '" << op->getName() << "'";
} else if (auto arg = val.dyn_cast<BlockArgument>()) {
- Block *block = arg->getOwner();
+ Block *block = arg.getOwner();
os << "from argument of block " << block << ' ';
os << " in op '" << block->getParentOp()->getName() << "'";
}
@@ -714,7 +714,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
// Declare the parameters.
for (auto arg : op.getArguments()) {
uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
return failure();
}
auto argValueID = getNextID();
@@ -1397,7 +1397,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// Get the type <id> and result <id> for this OpPhi instruction.
uint32_t phiTypeID = 0;
- if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID)))
+ if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
return failure();
uint32_t phiID = getNextID();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index 07621d6fa80..c94746ac22e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -100,7 +100,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
// Change the type for the direct users.
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
- return VulkanLayoutUtils::isLegalType(op.pointer()->getType());
+ return VulkanLayoutUtils::isLegalType(op.pointer().getType());
});
// TODO: Change the type for the indirect users such as spv.Load, spv.Store,
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index f9296196789..25e420ad9f8 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -79,7 +79,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
- valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+ valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
};
} // end anonymous namespace
@@ -96,9 +96,9 @@ static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
- << *op->getOperand(0);
+ << op->getOperand(0);
p.printOptionalAttrDict(op->getAttrs());
- p << " : " << op->getOperand(0)->getType();
+ p << " : " << op->getOperand(0).getType();
}
/// A custom binary operation printer that omits the "std." prefix from the
@@ -109,20 +109,20 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
// If not all the operand and result types are the same, just use the
// generic assembly form to avoid omitting information in printing.
- auto resultType = op->getResult(0)->getType();
- if (op->getOperand(0)->getType() != resultType ||
- op->getOperand(1)->getType() != resultType) {
+ auto resultType = op->getResult(0).getType();
+ if (op->getOperand(0).getType() != resultType ||
+ op->getOperand(1).getType() != resultType) {
p.printGenericOp(op);
return;
}
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
- << *op->getOperand(0) << ", " << *op->getOperand(1);
+ << op->getOperand(0) << ", " << op->getOperand(1);
p.printOptionalAttrDict(op->getAttrs());
// Now we can output only one type for all operands and the result.
- p << " : " << op->getResult(0)->getType();
+ p << " : " << op->getResult(0).getType();
}
/// A custom cast operation printer that omits the "std." prefix from the
@@ -130,13 +130,13 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
- << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
- << op->getResult(0)->getType();
+ << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
+ << op->getResult(0).getType();
}
/// A custom cast operation verifier.
template <typename T> static LogicalResult verifyCastOp(T op) {
- auto opType = op.getOperand()->getType();
+ auto opType = op.getOperand().getType();
auto resType = op.getType();
if (!T::areCastCompatible(opType, resType))
return op.emitError("operand type ") << opType << " and result type "
@@ -209,8 +209,8 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
- if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
@@ -281,7 +281,7 @@ static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) {
}
static LogicalResult verify(AllocOp op) {
- auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>();
+ auto memRefType = op.getResult().getType().dyn_cast<MemRefType>();
if (!memRefType)
return op.emitOpError("result must be a memref");
@@ -338,7 +338,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
+ auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
@@ -489,14 +489,14 @@ static LogicalResult verify(CallOp op) {
return op.emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
- if (op.getOperand(i)->getType() != fnType.getInput(i))
+ if (op.getOperand(i).getType() != fnType.getInput(i))
return op.emitOpError("operand type mismatch");
if (fnType.getNumResults() != op.getNumResults())
return op.emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
- if (op.getResult(i)->getType() != fnType.getResult(i))
+ if (op.getResult(i).getType() != fnType.getResult(i))
return op.emitOpError("result type mismatch");
return success();
@@ -553,12 +553,12 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, CallIndirectOp op) {
p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee()->getType();
+ p << " : " << op.getCallee().getType();
}
static LogicalResult verify(CallIndirectOp op) {
// The callee must be a function.
- auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
+ auto fnType = op.getCallee().getType().dyn_cast<FunctionType>();
if (!fnType)
return op.emitOpError("callee must have function type");
@@ -567,14 +567,14 @@ static LogicalResult verify(CallIndirectOp op) {
return op.emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
- if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
+ if (op.getOperand(i + 1).getType() != fnType.getInput(i))
return op.emitOpError("operand type mismatch");
if (fnType.getNumResults() != op.getNumResults())
return op.emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
- if (op.getResult(i)->getType() != fnType.getResult(i))
+ if (op.getResult(i).getType() != fnType.getResult(i))
return op.emitOpError("result type mismatch");
return success();
@@ -616,7 +616,7 @@ static Type getI1SameShape(Builder *build, Type type) {
static void buildCmpIOp(Builder *build, OperationState &result,
CmpIPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
- result.types.push_back(getI1SameShape(build, lhs->getType()));
+ result.types.push_back(getI1SameShape(build, lhs.getType()));
result.addAttribute(
CmpIOp::getPredicateAttrName(),
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -668,7 +668,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
<< '"' << ", " << op.lhs() << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
- p << " : " << op.lhs()->getType();
+ p << " : " << op.lhs().getType();
}
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
@@ -769,7 +769,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
static void buildCmpFOp(Builder *build, OperationState &result,
CmpFPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
- result.types.push_back(getI1SameShape(build, lhs->getType()));
+ result.types.push_back(getI1SameShape(build, lhs.getType()));
result.addAttribute(
CmpFOp::getPredicateAttrName(),
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
@@ -824,7 +824,7 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
<< ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
- p << " : " << op.lhs()->getType();
+ p << " : " << op.lhs().getType();
}
static LogicalResult verify(CmpFOp op) {
@@ -1123,14 +1123,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState &result,
}
bool ConstantFloatOp::classof(Operation *op) {
- return ConstantOp::classof(op) &&
- op->getResult(0)->getType().isa<FloatType>();
+ return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::classof(Operation *op) {
return ConstantOp::classof(op) &&
- op->getResult(0)->getType().isa<IntegerType>();
+ op->getResult(0).getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState &result,
@@ -1151,7 +1150,7 @@ void ConstantIntOp::build(Builder *builder, OperationState &result,
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::classof(Operation *op) {
- return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex();
+ return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState &result,
@@ -1174,11 +1173,11 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
PatternRewriter &rewriter) const override {
// Check that the memref operand's defining operation is an AllocOp.
Value memref = dealloc.memref();
- if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
+ if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
return matchFailure();
// Check that all of the uses of the AllocOp are other DeallocOps.
- for (auto *user : memref->getUsers())
+ for (auto *user : memref.getUsers())
if (!isa<DeallocOp>(user))
return matchFailure();
@@ -1190,7 +1189,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
} // end anonymous namespace.
static void print(OpAsmPrinter &p, DeallocOp op) {
- p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
+ p << "dealloc " << op.memref() << " : " << op.memref().getType();
}
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
@@ -1203,7 +1202,7 @@ static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
}
static LogicalResult verify(DeallocOp op) {
- if (!op.memref()->getType().isa<MemRefType>())
+ if (!op.memref().getType().isa<MemRefType>())
return op.emitOpError("operand must be a memref");
return success();
}
@@ -1224,9 +1223,9 @@ LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, DimOp op) {
- p << "dim " << *op.getOperand() << ", " << op.getIndex();
+ p << "dim " << op.getOperand() << ", " << op.getIndex();
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
- p << " : " << op.getOperand()->getType();
+ p << " : " << op.getOperand().getType();
}
static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
@@ -1251,7 +1250,7 @@ static LogicalResult verify(DimOp op) {
return op.emitOpError("requires an integer attribute named 'index'");
int64_t index = indexAttr.getValue().getSExtValue();
- auto type = op.getOperand()->getType();
+ auto type = op.getOperand().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (index >= tensorType.getRank())
return op.emitOpError("index is out of range");
@@ -1270,7 +1269,7 @@ static LogicalResult verify(DimOp op) {
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Constant fold dim when the size along the index referred to is a constant.
- auto opType = memrefOrTensor()->getType();
+ auto opType = memrefOrTensor().getType();
int64_t indexSize = -1;
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
indexSize = tensorType.getShape()[getIndex()];
@@ -1286,7 +1285,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
// The size at getIndex() is now a dynamic size of a memref.
- auto memref = memrefOrTensor()->getDefiningOp();
+ auto memref = memrefOrTensor().getDefiningOp();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(getIndex()));
@@ -1367,16 +1366,15 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
}
void DmaStartOp::print(OpAsmPrinter &p) {
- p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
- << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
- << ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
+ p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
+ << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
+ << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
if (isStrided())
- p << ", " << *getStride() << ", " << *getNumElementsPerStride();
+ p << ", " << getStride() << ", " << getNumElementsPerStride();
p.printOptionalAttrDict(getAttrs());
- p << " : " << getSrcMemRef()->getType();
- p << ", " << getDstMemRef()->getType();
- p << ", " << getTagMemRef()->getType();
+ p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
+ << ", " << getTagMemRef().getType();
}
// Parse DmaStartOp.
@@ -1506,7 +1504,7 @@ void DmaWaitOp::print(OpAsmPrinter &p) {
p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
<< getNumElements();
p.printOptionalAttrDict(getAttrs());
- p << " : " << getTagMemRef()->getType();
+ p << " : " << getTagMemRef().getType();
}
// Parse DmaWaitOp.
@@ -1553,10 +1551,10 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
- p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
+ p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
p << ']';
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getAggregate()->getType();
+ p << " : " << op.getAggregate().getType();
}
static ParseResult parseExtractElementOp(OpAsmParser &parser,
@@ -1577,7 +1575,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
}
static LogicalResult verify(ExtractElementOp op) {
- auto aggregateType = op.getAggregate()->getType().cast<ShapedType>();
+ auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
// This should be possible with tablegen type constraints
if (op.getType() != aggregateType.getElementType())
@@ -1634,7 +1632,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) {
- p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
+ p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -1781,7 +1779,7 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, PrefetchOp op) {
- p << PrefetchOp::getOperationName() << " " << *op.memref() << '[';
+ p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
p.printOperands(op.indices());
p << ']' << ", " << (op.isWrite() ? "write" : "read");
p << ", locality<" << op.localityHint();
@@ -1851,7 +1849,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, RankOp op) {
- p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
+ p << "rank " << op.getOperand() << " : " << op.getOperand().getType();
}
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
@@ -1866,7 +1864,7 @@ static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
// Constant fold rank when the rank of the tensor is known.
- auto type = getOperand()->getType();
+ auto type = getOperand().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
return IntegerAttr();
@@ -1954,10 +1952,10 @@ static LogicalResult verify(ReturnOp op) {
<< " operands, but enclosing function returns " << results.size();
for (unsigned i = 0, e = results.size(); i != e; ++i)
- if (op.getOperand(i)->getType() != results[i])
+ if (op.getOperand(i).getType() != results[i])
return op.emitError()
<< "type of return operand " << i << " ("
- << op.getOperand(i)->getType()
+ << op.getOperand(i).getType()
<< ") doesn't match function result type (" << results[i] << ")";
return success();
@@ -1997,13 +1995,13 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SelectOp op) {
- p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
+ p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
p.printOptionalAttrDict(op.getAttrs());
}
static LogicalResult verify(SelectOp op) {
- auto trueType = op.getTrueValue()->getType();
- auto falseType = op.getFalseValue()->getType();
+ auto trueType = op.getTrueValue().getType();
+ auto falseType = op.getFalseValue().getType();
if (trueType != falseType)
return op.emitOpError(
@@ -2032,7 +2030,7 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
static LogicalResult verify(SignExtendIOp op) {
// Get the scalar type (which is either directly the type of the operand
// or the vector's/tensor's element type.
- auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto srcType = getElementTypeOrSelf(op.getOperand().getType());
auto dstType = getElementTypeOrSelf(op.getType());
// For now, index is forbidden for the source and the destination type.
@@ -2054,7 +2052,7 @@ static LogicalResult verify(SignExtendIOp op) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, SplatOp op) {
- p << "splat " << *op.getOperand();
+ p << "splat " << op.getOperand();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
}
@@ -2074,7 +2072,7 @@ static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
static LogicalResult verify(SplatOp op) {
// TODO: we could replace this by a trait.
- if (op.getOperand()->getType() !=
+ if (op.getOperand().getType() !=
op.getType().cast<ShapedType>().getElementType())
return op.emitError("operand should be of elemental type of result type");
@@ -2103,8 +2101,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, StoreOp op) {
- p << "store " << *op.getValueToStore();
- p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
+ p << "store " << op.getValueToStore();
+ p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
@@ -2130,7 +2128,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
static LogicalResult verify(StoreOp op) {
// First operand must have same type as memref element type.
- if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
+ if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
return op.emitOpError(
"first operand must have same type memref element type");
@@ -2251,9 +2249,9 @@ static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TensorLoadOp op) {
- p << "tensor_load " << *op.getOperand();
+ p << "tensor_load " << op.getOperand();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getOperand()->getType();
+ p << " : " << op.getOperand().getType();
}
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
@@ -2274,9 +2272,9 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TensorStoreOp op) {
- p << "tensor_store " << *op.tensor() << ", " << *op.memref();
+ p << "tensor_store " << op.tensor() << ", " << op.memref();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.memref()->getType();
+ p << " : " << op.memref().getType();
}
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
@@ -2298,7 +2296,7 @@ static ParseResult parseTensorStoreOp(OpAsmParser &parser,
//===----------------------------------------------------------------------===//
static LogicalResult verify(TruncateIOp op) {
- auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto srcType = getElementTypeOrSelf(op.getOperand().getType());
auto dstType = getElementTypeOrSelf(op.getType());
if (srcType.isa<IndexType>())
@@ -2344,13 +2342,13 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, ViewOp op) {
- p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
+ p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
auto dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset);
p << "][" << op.getDynamicSizes() << ']';
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
+ p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
Value ViewOp::getDynamicOffset() {
@@ -2382,8 +2380,8 @@ static LogicalResult verifyDynamicStrides(MemRefType memrefType,
}
static LogicalResult verify(ViewOp op) {
- auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
- auto viewType = op.getResult()->getType().cast<MemRefType>();
+ auto baseType = op.getOperand(0).getType().cast<MemRefType>();
+ auto viewType = op.getResult().getType().cast<MemRefType>();
// The base memref should have identity layout map (or none).
if (baseType.getAffineMaps().size() > 1 ||
@@ -2453,7 +2451,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
int64_t newOffset = oldOffset;
unsigned dynamicOffsetOperandCount = 0;
if (dynamicOffset != nullptr) {
- auto *defOp = dynamicOffset->getDefiningOp();
+ auto *defOp = dynamicOffset.getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic offset will be folded into the map.
newOffset = constantIndexOp.getValue();
@@ -2478,7 +2476,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
+ auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
@@ -2590,7 +2588,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
ValueRange strides, Type resultType,
ArrayRef<NamedAttribute> attrs) {
if (!resultType)
- resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
+ resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
auto segmentAttr = b->getI32VectorAttr(
{1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
static_cast<int32_t>(strides.size())});
@@ -2637,13 +2635,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, SubViewOp op) {
- p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
+ p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets()
<< "][" << op.sizes() << "][" << op.strides() << ']';
SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
- p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
+ p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
static LogicalResult verify(SubViewOp op) {
@@ -2757,8 +2755,8 @@ static LogicalResult verify(SubViewOp op) {
}
raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
- return os << "range " << *range.offset << ":" << *range.size << ":"
- << *range.stride;
+ return os << "range " << range.offset << ":" << range.size << ":"
+ << range.stride;
}
SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
@@ -2827,7 +2825,7 @@ public:
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
for (auto size : llvm::enumerate(subViewOp.sizes())) {
- auto defOp = size.value()->getDefiningOp();
+ auto defOp = size.value().getDefiningOp();
assert(defOp);
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
}
@@ -2873,7 +2871,7 @@ public:
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
for (auto stride : llvm::enumerate(subViewOp.strides())) {
- auto defOp = stride.value()->getDefiningOp();
+ auto defOp = stride.value().getDefiningOp();
assert(defOp);
assert(baseStrides[stride.index()] > 0);
staticStrides[stride.index()] =
@@ -2924,7 +2922,7 @@ public:
auto staticOffset = baseOffset;
for (auto offset : llvm::enumerate(subViewOp.offsets())) {
- auto defOp = offset.value()->getDefiningOp();
+ auto defOp = offset.value().getDefiningOp();
assert(defOp);
assert(baseStrides[offset.index()] > 0);
staticOffset +=
@@ -2959,7 +2957,7 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static LogicalResult verify(ZeroExtendIOp op) {
- auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
+ auto srcType = getElementTypeOrSelf(op.getOperand().getType());
auto dstType = getElementTypeOrSelf(op.getType());
if (srcType.isa<IndexType>())
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 744500af663..11aea0936f2 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -163,9 +163,9 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
assert(op->getNumResults() == 1 &&
"only support broadcast check on one result");
- auto type1 = op->getOperand(0)->getType();
- auto type2 = op->getOperand(1)->getType();
- auto retType = op->getResult(0)->getType();
+ auto type1 = op->getOperand(0).getType();
+ auto type2 = op->getOperand(1).getType();
+ auto retType = op->getResult(0).getType();
// We forbid broadcasting vector and tensor.
if (hasBothVectorAndTensorType({type1, type2, retType}))
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 92a230eb5d1..8206d10962b 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -67,7 +67,7 @@ void vector::ContractionOp::build(Builder *builder, OperationState &result,
ArrayAttr indexingMaps,
ArrayAttr iteratorTypes) {
result.addOperands({lhs, rhs, acc});
- result.addTypes(acc->getType());
+ result.addTypes(acc.getType());
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
}
@@ -125,13 +125,13 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
attrs.push_back(attr);
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
- p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
- p << *op.rhs() << ", " << *op.acc();
+ p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
+ p << op.rhs() << ", " << op.acc();
if (op.masks().size() == 2)
p << ", " << op.masks();
p.printOptionalAttrDict(op.getAttrs(), attrNames);
- p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
+ p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
<< op.getResultType();
}
@@ -211,7 +211,7 @@ static LogicalResult verify(ContractionOp op) {
if (map.getNumDims() != numIterators)
return op.emitOpError("expected indexing map ")
<< index << " to have " << numIterators << " number of inputs";
- auto operandType = op.getOperand(index)->getType().cast<VectorType>();
+ auto operandType = op.getOperand(index).getType().cast<VectorType>();
unsigned rank = operandType.getShape().size();
if (map.getNumResults() != rank)
return op.emitOpError("expected indexing map ")
@@ -351,10 +351,10 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
- p << op.getOperationName() << " " << *op.vector() << "[" << *op.position()
- << " : " << op.position()->getType() << "]";
+ p << op.getOperationName() << " " << op.vector() << "[" << op.position()
+ << " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.vector()->getType();
+ p << " : " << op.vector().getType();
}
static ParseResult parseExtractElementOp(OpAsmParser &parser,
@@ -398,15 +398,15 @@ void vector::ExtractOp::build(Builder *builder, OperationState &result,
Value source, ArrayRef<int64_t> position) {
result.addOperands(source);
auto positionAttr = getVectorSubscriptAttr(*builder, position);
- result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
+ result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
positionAttr));
result.addAttribute(getPositionAttrName(), positionAttr);
}
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
- p << op.getOperationName() << " " << *op.vector() << op.position();
+ p << op.getOperationName() << " " << op.vector() << op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
- p << " : " << op.vector()->getType();
+ p << " : " << op.vector().getType();
}
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
@@ -495,13 +495,13 @@ static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
- p << op.getOperationName() << ' ' << *op.vector() << ", ";
+ p << op.getOperationName() << ' ' << op.vector() << ", ";
p << op.sizes() << ", " << op.strides();
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
ExtractSlicesOp::getStridesAttrName()});
- p << " : " << op.vector()->getType();
+ p << " : " << op.vector().getType();
p << " into " << op.getResultTupleType();
}
@@ -594,7 +594,7 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) {
- p << op.getOperationName() << " " << *op.source() << " : "
+ p << op.getOperationName() << " " << op.source() << " : "
<< op.getSourceType() << " to " << op.getVectorType();
}
@@ -642,15 +642,15 @@ void ShuffleOp::build(Builder *builder, OperationState &result, Value v1,
Value v2, ArrayRef<int64_t> mask) {
result.addOperands({v1, v2});
auto maskAttr = getVectorSubscriptAttr(*builder, mask);
- result.addTypes(v1->getType());
+ result.addTypes(v1.getType());
result.addAttribute(getMaskAttrName(), maskAttr);
}
static void print(OpAsmPrinter &p, ShuffleOp op) {
- p << op.getOperationName() << " " << *op.v1() << ", " << *op.v2() << " "
+ p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
<< op.mask();
p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
- p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
+ p << " : " << op.v1().getType() << ", " << op.v2().getType();
}
static LogicalResult verify(ShuffleOp op) {
@@ -725,10 +725,10 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, InsertElementOp op) {
- p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "["
- << *op.position() << " : " << op.position()->getType() << "]";
+ p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
+ << op.position() << " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.dest()->getType();
+ p << " : " << op.dest().getType();
}
static ParseResult parseInsertElementOp(OpAsmParser &parser,
@@ -766,12 +766,12 @@ void InsertOp::build(Builder *builder, OperationState &result, Value source,
Value dest, ArrayRef<int64_t> position) {
result.addOperands({source, dest});
auto positionAttr = getVectorSubscriptAttr(*builder, position);
- result.addTypes(dest->getType());
+ result.addTypes(dest.getType());
result.addAttribute(getPositionAttrName(), positionAttr);
}
static void print(OpAsmPrinter &p, InsertOp op) {
- p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
+ p << op.getOperationName() << " " << op.source() << ", " << op.dest()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
@@ -851,13 +851,13 @@ static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
- p << op.getOperationName() << ' ' << *op.vectors() << ", ";
+ p << op.getOperationName() << ' ' << op.vectors() << ", ";
p << op.sizes() << ", " << op.strides();
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
InsertSlicesOp::getStridesAttrName()});
- p << " : " << op.vectors()->getType();
+ p << " : " << op.vectors().getType();
p << " into " << op.getResultVectorType();
}
@@ -890,14 +890,13 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
result.addOperands({source, dest});
auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
- result.addTypes(dest->getType());
+ result.addTypes(dest.getType());
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
}
static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
- p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
- << " ";
+ p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
}
@@ -1049,10 +1048,10 @@ static LogicalResult verify(InsertStridedSliceOp op) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, OuterProductOp op) {
- p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+ p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
if (!op.acc().empty())
p << ", " << op.acc();
- p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
+ p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}
static ParseResult parseOuterProductOp(OpAsmParser &parser,
@@ -1103,7 +1102,7 @@ static LogicalResult verify(OuterProductOp op) {
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ReshapeOp op) {
- p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape()
+ p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape()
<< "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
SmallVector<StringRef, 2> elidedAttrs = {
ReshapeOp::getOperandSegmentSizeAttr(),
@@ -1193,18 +1192,18 @@ static LogicalResult verify(ReshapeOp op) {
// If all shape operands are produced by constant ops, verify that product
// of dimensions for input/output shape match.
auto isDefByConstant = [](Value operand) {
- return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
+ return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
};
if (llvm::all_of(op.input_shape(), isDefByConstant) &&
llvm::all_of(op.output_shape(), isDefByConstant)) {
int64_t numInputElements = 1;
for (auto operand : op.input_shape())
numInputElements *=
- cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+ cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
int64_t numOutputElements = 1;
for (auto operand : op.output_shape())
numOutputElements *=
- cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+ cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
if (numInputElements != numOutputElements)
return op.emitError("product of input and output shape sizes must match");
}
@@ -1245,7 +1244,7 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(
- inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
+ inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
offsetsAttr, sizesAttr, stridesAttr));
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
result.addAttribute(getSizesAttrName(), sizesAttr);
@@ -1253,9 +1252,9 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
}
static void print(OpAsmPrinter &p, StridedSliceOp op) {
- p << op.getOperationName() << " " << *op.vector();
+ p << op.getOperationName() << " " << op.vector();
p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
+ p << " : " << op.vector().getType() << " to " << op.getResult().getType();
}
static ParseResult parseStridedSliceOp(OpAsmParser &parser,
@@ -1305,7 +1304,7 @@ static LogicalResult verify(StridedSliceOp op) {
auto resultType = inferStridedSliceOpResultType(
op.getVectorType(), op.offsets(), op.sizes(), op.strides());
- if (op.getResult()->getType() != resultType) {
+ if (op.getResult().getType() != resultType) {
op.emitOpError("expected result type to be ") << resultType;
return failure();
}
@@ -1328,7 +1327,7 @@ public:
PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
PatternRewriter &rewriter) const override {
// Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
- auto defOp = stridedSliceOp.vector()->getDefiningOp();
+ auto defOp = stridedSliceOp.vector().getDefiningOp();
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
if (!constantMaskOp)
return matchFailure();
@@ -1365,7 +1364,7 @@ public:
// Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
- stridedSliceOp, stridedSliceOp.getResult()->getType(),
+ stridedSliceOp, stridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
return matchSuccess();
}
@@ -1503,7 +1502,7 @@ static LogicalResult verify(TransferReadOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- auto paddingType = op.padding()->getType();
+ auto paddingType = op.padding().getType();
auto permutationMap = op.permutation_map();
auto memrefElementType = memrefType.getElementType();
@@ -1540,8 +1539,8 @@ static LogicalResult verify(TransferReadOp op) {
// TransferWriteOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TransferWriteOp op) {
- p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
- << "[" << op.indices() << "]";
+ p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
+ << op.indices() << "]";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
}
@@ -1596,12 +1595,12 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) {
void TypeCastOp::build(Builder *builder, OperationState &result, Value source) {
result.addOperands(source);
result.addTypes(
- inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
+ inferVectorTypeCastResultType(source.getType().cast<MemRefType>()));
}
static void print(OpAsmPrinter &p, TypeCastOp op) {
- auto type = op.getOperand()->getType().cast<MemRefType>();
- p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
+ auto type = op.getOperand().getType().cast<MemRefType>();
+ p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to "
<< inferVectorTypeCastResultType(type);
}
@@ -1665,14 +1664,14 @@ static ParseResult parseTupleGetOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, TupleGetOp op) {
- p << op.getOperationName() << ' ' << *op.getOperand() << ", " << op.index();
+ p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
- p << " : " << op.getOperand()->getType();
+ p << " : " << op.getOperand().getType();
}
static LogicalResult verify(TupleGetOp op) {
- auto tupleType = op.getOperand()->getType().cast<TupleType>();
+ auto tupleType = op.getOperand().getType().cast<TupleType>();
if (op.getIndex() < 0 ||
op.getIndex() >= static_cast<int64_t>(tupleType.size()))
return op.emitOpError("tuple get index out of range");
@@ -1696,12 +1695,12 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
- << op.getResult()->getType();
+ << op.getResult().getType();
}
static LogicalResult verify(ConstantMaskOp &op) {
// Verify that array attr size matches the rank of the vector result.
- auto resultType = op.getResult()->getType().cast<VectorType>();
+ auto resultType = op.getResult().getType().cast<VectorType>();
if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
return op.emitOpError(
"must specify array attr of size equal vector result rank");
@@ -1749,7 +1748,7 @@ static void print(OpAsmPrinter &p, CreateMaskOp op) {
static LogicalResult verify(CreateMaskOp op) {
// Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() !=
- op.getResult()->getType().cast<VectorType>().getRank())
+ op.getResult().getType().cast<VectorType>().getRank())
return op.emitOpError(
"must specify an operand for each result vector dimension");
return success();
@@ -1768,7 +1767,7 @@ ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, PrintOp op) {
- p << op.getOperationName() << ' ' << *op.source() << " : "
+ p << op.getOperationName() << ' ' << op.source() << " : "
<< op.getPrintType();
}
@@ -1783,19 +1782,19 @@ public:
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
auto is_not_def_by_constant = [](Value operand) {
- return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
+ return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
return matchFailure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
for (auto operand : createMaskOp.operands()) {
- auto defOp = operand->getDefiningOp();
+ auto defOp = operand.getDefiningOp();
maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
}
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
- createMaskOp, createMaskOp.getResult()->getType(),
+ createMaskOp, createMaskOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
return matchSuccess();
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 9fcbd0cb921..d98f41e3d63 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -195,7 +195,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
auto tupleType =
generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
- initValue->getLoc(), tupleType, initValue, sizes, strides);
+ initValue.getLoc(), tupleType, initValue, sizes, strides);
}
}
@@ -232,7 +232,7 @@ static Value getOrCreateUnrolledVectorSlice(
if (valueSlice == nullptr) {
// Return tuple element at 'sliceLinearIndex'.
auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
- auto initValueType = initValue->getType().cast<VectorType>();
+ auto initValueType = initValue.getType().cast<VectorType>();
auto vectorType =
VectorType::get(state.unrolledShape, initValueType.getElementType());
// Initialize 'cache' with slice from 'initValue'.
@@ -311,7 +311,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
unsigned resultIndex,
ArrayRef<int64_t> targetShape,
PatternRewriter &builder) {
- auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
+ auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Expected a statically shaped result type");
@@ -379,7 +379,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
- vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
+ vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
vectorTupleValues[i] = caches[resultIndex][i];
}
TupleType tupleType = builder.getTupleType(vectorTupleTypes);
@@ -387,7 +387,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
vectorTupleValues);
// Create InsertSlicesOp(Tuple(result_vectors)).
- auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+ auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
@@ -411,7 +411,7 @@ static void getVectorContractionOpUnrollState(
vectors.resize(numIterators);
unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
for (unsigned i = 0; i < numIterators; ++i) {
- vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>();
+ vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
vectors[i].indexMap = iterationIndexMapList[i];
vectors[i].operandIndex = i;
vectors[i].isAcc = i == accOperandIndex ? true : false;
@@ -437,7 +437,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
std::vector<VectorState> &vectors,
unsigned &resultIndex) {
// Verify that operation and operands all have the same vector shape.
- auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>();
+ auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
assert(resultType && "Expected op with vector result type");
auto resultShape = resultType.getShape();
// Verify that all operands have the same vector type as result.
@@ -515,7 +515,7 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
getAffineConstantExpr(offsets[it.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
- it.value()->getLoc(), map, ArrayRef<Value>(it.value()));
+ it.value().getLoc(), map, ArrayRef<Value>(it.value()));
}
// Call 'fn' to generate slice 'i' at 'sliceIndices'.
fn(i, sliceIndices);
@@ -536,8 +536,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
Value xferReadResult = xferReadOp.getResult();
auto extractSlicesOp =
- dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
- if (!xferReadResult->hasOneUse() || !extractSlicesOp)
+ dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
+ if (!xferReadResult.hasOneUse() || !extractSlicesOp)
return matchFailure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
@@ -587,14 +587,14 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
if (!xferWriteOp.permutation_map().isIdentity())
return matchFailure();
// Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
- auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp();
+ auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
if (!insertSlicesOp)
return matchFailure();
// Get TupleOp operand of 'insertSlicesOp'.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
- insertSlicesOp.vectors()->getDefiningOp());
+ insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
return matchFailure();
@@ -634,19 +634,19 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
PatternRewriter &rewriter) const override {
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
- tupleGetOp.vectors()->getDefiningOp());
+ tupleGetOp.vectors().getDefiningOp());
if (!extractSlicesOp)
return matchFailure();
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
- extractSlicesOp.vector()->getDefiningOp());
+ extractSlicesOp.vector().getDefiningOp());
if (!insertSlicesOp)
return matchFailure();
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
- insertSlicesOp.vectors()->getDefiningOp());
+ insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
return matchFailure();
OpenPOWER on IntegriCloud