diff options
Diffstat (limited to 'mlir/lib/Dialect')
26 files changed, 425 insertions, 443 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 5f4cc2e1060..0da539ea9f0 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -108,8 +108,8 @@ static bool isFunctionRegion(Region *region) { /// symbol. bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast<BlockArgument>()) - return isFunctionRegion(arg->getOwner()->getParent()); - return isFunctionRegion(value->getDefiningOp()->getParentRegion()); + return isFunctionRegion(arg.getOwner()->getParent()); + return isFunctionRegion(value.getDefiningOp()->getParentRegion()); } // Value can be used as a dimension id if it is valid as a symbol, or @@ -117,10 +117,10 @@ bool mlir::isTopLevelValue(Value value) { // with dimension id arguments. bool mlir::isValidDim(Value value) { // The value must be an index type. - if (!value->getType().isIndex()) + if (!value.getType().isIndex()) return false; - if (auto *op = value->getDefiningOp()) { + if (auto *op = value.getDefiningOp()) { // Top level operation or constant operation is ok. if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) return true; @@ -134,7 +134,7 @@ bool mlir::isValidDim(Value value) { return false; } // This value has to be a block argument for a FuncOp or an affine.for. - auto *parentOp = value.cast<BlockArgument>()->getOwner()->getParentOp(); + auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp(); return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp); } @@ -162,11 +162,11 @@ static bool isDimOpValidSymbol(DimOp dimOp) { // The dim op is also okay if its operand memref/tensor is a view/subview // whose corresponding size is a valid symbol. unsigned index = dimOp.getIndex(); - if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp())) + if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp())) return isMemRefSizeValidSymbol<ViewOp>(viewOp, index); - if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp())) + if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp())) return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index); - if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp())) + if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp())) return isMemRefSizeValidSymbol<AllocOp>(allocOp, index); return false; } @@ -177,10 +177,10 @@ static bool isDimOpValidSymbol(DimOp dimOp) { // constraints. bool mlir::isValidSymbol(Value value) { // The value must be an index type. - if (!value->getType().isIndex()) + if (!value.getType().isIndex()) return false; - if (auto *op = value->getDefiningOp()) { + if (auto *op = value.getDefiningOp()) { // Top level operation or constant operation is ok. if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) return true; @@ -283,7 +283,7 @@ LogicalResult AffineApplyOp::verify() { return emitOpError("operands must be of type 'index'"); } - if (!getResult()->getType().isIndex()) + if (!getResult().getType().isIndex()) return emitOpError("result must be of type 'index'"); // Verify that the map only produces one result. @@ -332,7 +332,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getContext()) + return getAffineDimExpr(iterPos->second, v.getContext()) .cast<AffineDimExpr>(); } @@ -365,7 +365,7 @@ static llvm::SetVector<unsigned> indicesFromAffineApplyOp(ArrayRef<Value> operands) { llvm::SetVector<unsigned> res; for (auto en : llvm::enumerate(operands)) - if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp())) + if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp())) res.insert(en.index()); return res; } @@ -487,7 +487,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, // 1. Only dispatch dims or symbols. for (auto en : llvm::enumerate(operands)) { auto t = en.value(); - assert(t->getType().isIndex()); + assert(t.getType().isIndex()); bool isDim = (en.index() < map.getNumDims()); if (isDim) { // a. The mathematical composition of AffineMap composes dims. @@ -503,7 +503,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { auto t = operands[i]; - auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp()); + auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp()); if (affineApply) { // a. Compose affine.apply operations. LLVM_DEBUG(affineApply.getOperation()->print( @@ -588,7 +588,7 @@ static void composeAffineMapAndOperands(AffineMap *map, void mlir::fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl<Value> *operands) { while (llvm::any_of(*operands, [](Value v) { - return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp()); + return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); } @@ -819,8 +819,8 @@ void AffineApplyOp::getCanonicalizationPatterns( static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp()); - if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) { + auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp()); + if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { operand.set(cast.getOperand()); folded = true; } @@ -856,16 +856,16 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result, } void AffineDmaStartOp::print(OpAsmPrinter &p) { - p << "affine.dma_start " << *getSrcMemRef() << '['; + p << "affine.dma_start " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); - p << "], " << *getDstMemRef() << '['; + p << "], " << getDstMemRef() << '['; p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); - p << "], " << *getTagMemRef() << '['; + p << "], " << getTagMemRef() << '['; p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); - p << "], " << *getNumElements(); + p << "], " << getNumElements(); if (isStrided()) { - p << ", " << *getStride(); - p << ", " << *getNumElementsPerStride(); + p << ", " << getStride(); + p << ", " << getNumElementsPerStride(); } p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " << getTagMemRefType(); @@ -951,11 +951,11 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, } LogicalResult AffineDmaStartOp::verify() { - if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>()) + if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>()) return emitOpError("expected DMA source to be of memref type"); - if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>()) + if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>()) return emitOpError("expected DMA destination to be of memref type"); - if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>()) + if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>()) return emitOpError("expected DMA tag to be of memref type"); // DMAs from different memory spaces supported. @@ -971,19 +971,19 @@ LogicalResult AffineDmaStartOp::verify() { } for (auto idx : getSrcIndices()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("src index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("src index must be a dimension or symbol identifier"); } for (auto idx : getDstIndices()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("dst index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("dst index must be a dimension or symbol identifier"); } for (auto idx : getTagIndices()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("tag index to dma_start must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("tag index must be a dimension or symbol identifier"); @@ -1012,12 +1012,12 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result, } void AffineDmaWaitOp::print(OpAsmPrinter &p) { - p << "affine.dma_wait " << *getTagMemRef() << '['; + p << "affine.dma_wait " << getTagMemRef() << '['; SmallVector<Value, 2> operands(getTagIndices()); p.printAffineMapOfSSAIds(getTagMapAttr(), operands); p << "], "; p.printOperand(getNumElements()); - p << " : " << getTagMemRef()->getType(); + p << " : " << getTagMemRef().getType(); } // Parse AffineDmaWaitOp. @@ -1056,10 +1056,10 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, } LogicalResult AffineDmaWaitOp::verify() { - if (!getOperand(0)->getType().isa<MemRefType>()) + if (!getOperand(0).getType().isa<MemRefType>()) return emitOpError("expected DMA tag to be of memref type"); for (auto idx : getTagIndices()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("index must be a dimension or symbol identifier"); @@ -1123,8 +1123,7 @@ static LogicalResult verify(AffineForOp op) { // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); - if (body->getNumArguments() != 1 || - !body->getArgument(0)->getType().isIndex()) + if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body to have a single index argument for the " "induction variable"); @@ -1553,7 +1552,7 @@ bool AffineForOp::matchingBoundOperandList() { Region &AffineForOp::getLoopBody() { return region(); } bool AffineForOp::isDefinedOutsideOfLoop(Value value) { - return !region().isAncestor(value->getParentRegion()); + return !region().isAncestor(value.getParentRegion()); } LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) { @@ -1571,9 +1570,9 @@ bool mlir::isForInductionVar(Value val) { /// not an induction variable, then return nullptr. AffineForOp mlir::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast<BlockArgument>(); - if (!ivArg || !ivArg->getOwner()) + if (!ivArg || !ivArg.getOwner()) return AffineForOp(); - auto *containingInst = ivArg->getOwner()->getParent()->getParentOp(); + auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); return dyn_cast<AffineForOp>(containingInst); } @@ -1744,7 +1743,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, result.addOperands(operands); if (map) result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); - auto memrefType = operands[0]->getType().cast<MemRefType>(); + auto memrefType = operands[0].getType().cast<MemRefType>(); result.types.push_back(memrefType.getElementType()); } @@ -1753,14 +1752,14 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref, assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); - auto memrefType = memref->getType().cast<MemRefType>(); + auto memrefType = memref.getType().cast<MemRefType>(); result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref, ValueRange indices) { - auto memrefType = memref->getType().cast<MemRefType>(); + auto memrefType = memref.getType().cast<MemRefType>(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -1789,7 +1788,7 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) { } void AffineLoadOp::print(OpAsmPrinter &p) { - p << "affine.load " << *getMemRef() << '['; + p << "affine.load " << getMemRef() << '['; if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; @@ -1816,7 +1815,7 @@ LogicalResult AffineLoadOp::verify() { } for (auto idx : getMapOperands()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("index to load must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("index must be a dimension or symbol identifier"); @@ -1854,7 +1853,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result, void AffineStoreOp::build(Builder *builder, OperationState &result, Value valueToStore, Value memref, ValueRange indices) { - auto memrefType = memref->getType().cast<MemRefType>(); + auto memrefType = memref.getType().cast<MemRefType>(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. @@ -1885,8 +1884,8 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) { } void AffineStoreOp::print(OpAsmPrinter &p) { - p << "affine.store " << *getValueToStore(); - p << ", " << *getMemRef() << '['; + p << "affine.store " << getValueToStore(); + p << ", " << getMemRef() << '['; if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; @@ -1896,7 +1895,7 @@ void AffineStoreOp::print(OpAsmPrinter &p) { LogicalResult AffineStoreOp::verify() { // First operand must have same type as memref element type. - if (getValueToStore()->getType() != getMemRefType().getElementType()) + if (getValueToStore().getType() != getMemRefType().getElementType()) return emitOpError("first operand must have same type memref element type"); auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); @@ -1914,7 +1913,7 @@ LogicalResult AffineStoreOp::verify() { } for (auto idx : getMapOperands()) { - if (!idx->getType().isIndex()) + if (!idx.getType().isIndex()) return emitOpError("index to store must have 'index' type"); if (!isValidAffineIndexOperand(idx)) return emitOpError("index must be a dimension or symbol identifier"); @@ -2059,7 +2058,7 @@ static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, } void print(OpAsmPrinter &p, AffinePrefetchOp op) { - p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '['; + p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '['; AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()); if (mapAttr) { SmallVector<Value, 2> operands(op.getMapOperands()); diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index df6015de1b9..1dfa2460b61 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -47,8 +47,8 @@ static Value emitUniformPerLayerDequantize(Location loc, Value input, return nullptr; } - Type storageType = elementType.castToStorageType(input->getType()); - Type realType = elementType.castToExpressedType(input->getType()); + Type storageType = elementType.castToStorageType(input.getType()); + Type realType = elementType.castToExpressedType(input.getType()); Type intermediateType = castElementType(storageType, IntegerType::get(32, rewriter.getContext())); assert(storageType && "cannot cast to storage type"); @@ -90,7 +90,7 @@ emitUniformPerAxisDequantize(Location loc, Value input, static Value emitDequantize(Location loc, Value input, PatternRewriter &rewriter) { - Type inputType = input->getType(); + Type inputType = input.getType(); QuantizedType qElementType = QuantizedType::getQuantizedElementType(inputType); if (auto uperLayerElementType = @@ -113,8 +113,8 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> { PatternMatchResult matchAndRewrite(DequantizeCastOp op, PatternRewriter &rewriter) const override { - Type inputType = op.arg()->getType(); - Type outputType = op.getResult()->getType(); + Type inputType = op.arg().getType(); + Type outputType = op.getResult().getType(); QuantizedType inputElementType = QuantizedType::getQuantizedElementType(inputType); diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 8cea97c693c..5eb7492c424 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -53,11 +53,11 @@ struct UniformBinaryOpInfo { UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs, Optional<APFloat> clampMin, Optional<APFloat> clampMax) : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax), - lhsType(getUniformElementType(lhs->getType())), - rhsType(getUniformElementType(rhs->getType())), + lhsType(getUniformElementType(lhs.getType())), + rhsType(getUniformElementType(rhs.getType())), resultType(getUniformElementType(*op->result_type_begin())), - lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())), - rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())), + lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())), + rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())), resultStorageType( quant::QuantizedType::castToStorageType(*op->result_type_begin())) { } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 32d7fae65d9..e750d0fefff 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -110,7 +110,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, // to encode target module" has landed. // auto functionType = kernelFunc.getType(); // for (unsigned i = 0; i < numKernelFuncArgs; ++i) { - // if (getKernelOperand(i)->getType() != functionType.getInput(i)) { + // if (getKernelOperand(i).getType() != functionType.getInput(i)) { // return emitOpError("type of function argument ") // << i << " does not match"; // } @@ -137,7 +137,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { if (allReduce.body().front().getNumArguments() != 2) return allReduce.emitError("expected two region arguments"); for (auto argument : allReduce.body().front().getArguments()) { - if (argument->getType() != allReduce.getType()) + if (argument.getType() != allReduce.getType()) return allReduce.emitError("incorrect region argument type"); } unsigned yieldCount = 0; @@ -145,7 +145,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) { if (yield.getNumOperands() != 1) return allReduce.emitError("expected one gpu.yield operand"); - if (yield.getOperand(0)->getType() != allReduce.getType()) + if (yield.getOperand(0).getType() != allReduce.getType()) return allReduce.emitError("incorrect gpu.yield type"); ++yieldCount; } @@ -157,8 +157,8 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) { } static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) { - auto type = shuffleOp.value()->getType(); - if (shuffleOp.result()->getType() != type) { + auto type = shuffleOp.value().getType(); + if (shuffleOp.result().getType() != type) { return shuffleOp.emitOpError() << "requires the same type for value operand and result"; } @@ -170,10 +170,8 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) { } static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) { - p << ShuffleOp::getOperationName() << ' '; - p.printOperands(op.getOperands()); - p << ' ' << op.mode() << " : "; - p.printType(op.value()->getType()); + p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' ' + << op.mode() << " : " << op.value().getType(); } static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { @@ -201,14 +199,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { // LaunchOp //===----------------------------------------------------------------------===// -static SmallVector<Type, 4> getValueTypes(ValueRange values) { - SmallVector<Type, 4> types; - types.reserve(values.size()); - for (Value v : values) - types.push_back(v->getType()); - return types; -} - void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX, Value gridSizeY, Value gridSizeZ, Value blockSizeX, Value blockSizeY, Value blockSizeZ, ValueRange operands) { @@ -224,7 +214,7 @@ void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX, Block *body = new Block(); body->addArguments( std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType())); - body->addArguments(getValueTypes(operands)); + body->addArguments(llvm::to_vector<4>(operands.getTypes())); kernelRegion->push_back(body); } @@ -309,10 +299,10 @@ LogicalResult verify(LaunchOp op) { // where %size-* and %iter-* will correspond to the body region arguments. static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, ValueRange operands, KernelDim3 ids) { - p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in ("; - p << *size.x << " = " << *operands[0] << ", "; - p << *size.y << " = " << *operands[1] << ", "; - p << *size.z << " = " << *operands[2] << ')'; + p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in ("; + p << size.x << " = " << operands[0] << ", "; + p << size.y << " = " << operands[1] << ", "; + p << size.z << " = " << operands[2] << ')'; } void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { @@ -335,8 +325,8 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { p << ' ' << op.getArgsKeyword() << '('; Block *entryBlock = &op.body().front(); interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) { - p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i) - << " = " << *operands[i]; + p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i) + << " = " << operands[i]; }); p << ") "; } @@ -486,14 +476,14 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { for (unsigned i = operands.size(); i > 0; --i) { unsigned index = i - 1; Value operand = operands[index]; - if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) + if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp())) continue; found = true; Value internalConstant = - rewriter.clone(*operand->getDefiningOp())->getResult(0); + rewriter.clone(*operand.getDefiningOp())->getResult(0); Value kernelArg = *std::next(kernelArgs.begin(), index); - kernelArg->replaceAllUsesWith(internalConstant); + kernelArg.replaceAllUsesWith(internalConstant); launchOp.eraseKernelArgument(index); } @@ -740,7 +730,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword, p << ' ' << keyword << '('; interleaveComma(values, p, - [&p](BlockArgument v) { p << *v << " : " << v->getType(); }); + [&p](BlockArgument v) { p << v << " : " << v.getType(); }); p << ')'; } @@ -790,7 +780,7 @@ static LogicalResult verifyAttributions(Operation *op, ArrayRef<BlockArgument> attributions, unsigned memorySpace) { for (Value v : attributions) { - auto type = v->getType().dyn_cast<MemRefType>(); + auto type = v.getType().dyn_cast<MemRefType>(); if (!type) return op->emitOpError() << "expected memref type in attribution"; @@ -814,7 +804,7 @@ LogicalResult GPUFuncOp::verifyBody() { ArrayRef<Type> funcArgTypes = getType().getInputs(); for (unsigned i = 0; i < numFuncArguments; ++i) { - Type blockArgType = front().getArgument(i)->getType(); + Type blockArgType = front().getArgument(i).getType(); if (funcArgTypes[i] != blockArgType) return emitOpError() << "expected body region argument #" << i << " to be of type " << funcArgTypes[i] << ", got " diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 2d00ac03d33..37f9c2e7b84 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -45,7 +45,7 @@ static void injectGpuIndexOperations(Location loc, Region &body) { // Replace the leading 12 function args with the respective thread/block index // operations. Iterate backwards since args are erased and indices change. for (int i = 11; i >= 0; --i) { - firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]); + firstBlock.getArgument(i).replaceAllUsesWith(indexOps[i]); firstBlock.eraseArgument(i); } } @@ -66,7 +66,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i)); } for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) { - auto operandOp = launch.getKernelOperand(i)->getDefiningOp(); + auto operandOp = launch.getKernelOperand(i).getDefiningOp(); if (!operandOp || !isInliningBeneficiary(operandOp)) { newLaunchArgs.push_back(launch.getKernelOperand(i)); continue; @@ -77,7 +77,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, continue; } auto clone = kernelBuilder.clone(*operandOp, map); - firstBlock.getArgument(i)->replaceAllUsesWith(clone->getResult(0)); + firstBlock.getArgument(i).replaceAllUsesWith(clone->getResult(0)); firstBlock.eraseArgument(i); } if (newLaunchArgs.size() == launch.getNumKernelOperands()) @@ -88,7 +88,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc, SmallVector<Type, 8> newArgumentTypes; newArgumentTypes.reserve(firstBlock.getNumArguments()); for (auto value : firstBlock.getArguments()) { - newArgumentTypes.push_back(value->getType()); + newArgumentTypes.push_back(value.getType()); } kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {})); auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 71b7064ac63..5c96581741b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -35,16 +35,16 @@ using namespace mlir::LLVM; //===----------------------------------------------------------------------===// static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) - << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); + << "\" " << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); - p << " : " << op.lhs()->getType(); + p << " : " << op.lhs().getType(); } static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) - << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); + << "\" " << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); - p << " : " << op.lhs()->getType(); + p << " : " << op.lhs().getType(); } // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use @@ -120,10 +120,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy(); - auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()}, + auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()}, op.getContext()); - p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy; + p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) p.printOptionalAttrDict(op.getAttrs()); else @@ -168,7 +168,7 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) { SmallVector<Type, 8> types(op.getOperandTypes()); auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); - p << op.getOperationName() << ' ' << *op.base() << '[' + p << op.getOperationName() << ' ' << op.base() << '[' << op.getOperands().drop_front() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << funcTy; @@ -212,9 +212,9 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { - p << op.getOperationName() << ' ' << *op.addr(); + p << op.getOperationName() << ' ' << op.addr(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.addr()->getType(); + p << " : " << op.addr().getType(); } // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return @@ -256,9 +256,9 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { - p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr(); + p << op.getOperationName() << ' ' << op.value() << ", " << op.addr(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.addr()->getType(); + p << " : " << op.addr().getType(); } // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type @@ -300,7 +300,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) { if (isDirect) p.printSymbolName(callee.getValue()); else - p << *op.getOperand(0); + p << op.getOperand(0); p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; p.printOptionalAttrDict(op.getAttrs(), {"callee"}); @@ -408,17 +408,17 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { void LLVM::ExtractElementOp::build(Builder *b, OperationState &result, Value vector, Value position, ArrayRef<NamedAttribute> attrs) { - auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>(); + auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>(); auto llvmType = wrappedVectorType.getVectorElementType(); build(b, result, llvmType, vector, position); result.addAttributes(attrs); } static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { - p << op.getOperationName() << ' ' << *op.vector() << "[" << *op.position() - << " : " << op.position()->getType() << "]"; + p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() + << " : " << op.position().getType() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector()->getType(); + p << " : " << op.vector().getType(); } // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use @@ -450,9 +450,9 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { - p << op.getOperationName() << ' ' << *op.container() << op.position(); + p << op.getOperationName() << ' ' << op.container() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); - p << " : " << op.container()->getType(); + p << " : " << op.container().getType(); } // Extract the type at `position` in the wrapped LLVM IR aggregate type @@ -542,10 +542,10 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { - p << op.getOperationName() << ' ' << *op.value() << ", " << *op.vector() - << "[" << *op.position() << " : " << op.position()->getType() << "]"; + p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" + << op.position() << " : " << op.position().getType() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector()->getType(); + p << " : " << op.vector().getType(); } // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use @@ -586,10 +586,10 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { - p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container() + p << op.getOperationName() << ' ' << op.value() << ", " << op.container() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); - p << " : " << op.container()->getType(); + p << " : " << op.container().getType(); } // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use @@ -629,10 +629,10 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static void printSelectOp(OpAsmPrinter &p, SelectOp &op) { - p << op.getOperationName() << ' ' << *op.condition() << ", " - << *op.trueValue() << ", " << *op.falseValue(); + p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue() + << ", " << op.falseValue(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType(); + p << " : " << op.condition().getType() << ", " << op.trueValue().getType(); } // <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use @@ -686,7 +686,7 @@ static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) { - p << op.getOperationName() << ' ' << *op.getOperand(0) << ", "; + p << op.getOperationName() << ' ' << op.getOperand(0) << ", "; p.printSuccessorAndUseList(op.getOperation(), 0); p << ", "; p.printSuccessorAndUseList(op.getOperation(), 1); @@ -733,7 +733,7 @@ static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) { if (op.getNumOperands() == 0) return; - p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType(); + p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); } // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` @@ -761,7 +761,7 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { static void printUndefOp(OpAsmPrinter &p, UndefOp &op) { p << op.getOperationName(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.res()->getType(); + p << " : " << op.res().getType(); } // <operation> ::= `llvm.mlir.undef` attribute-dict? : type @@ -792,7 +792,7 @@ GlobalOp AddressOfOp::getGlobal() { static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) { p << op.getOperationName() << " @" << op.global_name(); p.printOptionalAttrDict(op.getAttrs(), {"global_name"}); - p << " : " << op.getResult()->getType(); + p << " : " << op.getResult().getType(); } static ParseResult parseAddressOfOp(OpAsmParser &parser, @@ -816,7 +816,7 @@ static LogicalResult verify(AddressOfOp op) { "must reference a global defined by 'llvm.mlir.global'"); if (global.getType().getPointerTo(global.addr_space().getZExtValue()) != - op.getResult()->getType()) + op.getResult().getType()) return op.emitOpError( "the type must be a pointer to the type of the referred global"); @@ -830,7 +830,7 @@ static LogicalResult verify(AddressOfOp op) { static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) { p << op.getOperationName() << '(' << op.value() << ')'; p.printOptionalAttrDict(op.getAttrs(), {"value"}); - p << " : " << op.res()->getType(); + p << " : " << op.res().getType(); } // <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type @@ -1060,7 +1060,7 @@ static LogicalResult verify(GlobalOp op) { void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1, Value v2, ArrayAttr mask, ArrayRef<NamedAttribute> attrs) { - auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>(); + auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>(); auto vType = LLVMType::getVectorTy( wrappedContainerType1.getVectorElementType(), mask.size()); build(b, result, vType, v1, v2, mask); @@ -1068,10 +1068,10 @@ void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1, } static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { - p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " " + p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " << op.mask(); p.printOptionalAttrDict(op.getAttrs(), {"mask"}); - p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); + p << " : " << op.v1().getType() << ", " << op.v2().getType(); } // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use @@ -1329,7 +1329,7 @@ static LogicalResult verify(LLVMFuncOp op) { unsigned numArguments = funcType->getNumParams(); Block &entryBlock = op.front(); for (unsigned i = 0; i < numArguments; ++i) { - Type argType = entryBlock.getArgument(i)->getType(); + Type argType = entryBlock.getArgument(i).getType(); auto argLLVMType = argType.dyn_cast<LLVMType>(); if (!argLLVMType) return op.emitOpError("entry block argument #") diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 7644cc69218..144afa4c5e1 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -48,30 +48,30 @@ Value Aliases::find(Value v) { auto it = aliases.find(v); if (it != aliases.end()) { - assert(it->getSecond()->getType().isa<MemRefType>() && "Memref expected"); + assert(it->getSecond().getType().isa<MemRefType>() && "Memref expected"); return it->getSecond(); } while (true) { if (v.isa<BlockArgument>()) return v; - if (auto alloc = dyn_cast_or_null<AllocOp>(v->getDefiningOp())) { + if (auto alloc = dyn_cast_or_null<AllocOp>(v.getDefiningOp())) { if (isStrided(alloc.getType())) return alloc.getResult(); } - if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) { + if (auto slice = dyn_cast_or_null<SliceOp>(v.getDefiningOp())) { auto it = aliases.insert(std::make_pair(v, find(slice.view()))); return it.first->second; } - if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) { + if (auto view = dyn_cast_or_null<ViewOp>(v.getDefiningOp())) { auto it = aliases.insert(std::make_pair(v, view.source())); return it.first->second; } - if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) { + if (auto view = dyn_cast_or_null<SubViewOp>(v.getDefiningOp())) { v = view.source(); continue; } - llvm::errs() << "View alias analysis reduces to: " << *v << "\n"; + llvm::errs() << "View alias analysis reduces to: " << v << "\n"; llvm_unreachable("unsupported view alias case"); } } @@ -224,7 +224,7 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences( auto *op = dependence.dependentOpView.op; LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << toStringRef(dt) << ": " << *src << " -> " << *op - << " on " << *dependence.indexingView); + << " on " << dependence.indexingView); res.push_back(op); } } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index b35a8ed0fd8..9b850431113 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -88,7 +88,7 @@ Operation *mlir::edsc::makeGenericLinalgOp( for (auto it : llvm::enumerate(values)) blockTypes.push_back((it.index() < nViews) ? getElementTypeOrSelf(it.value()) - : it.value()->getType()); + : it.value().getType()); assert(op->getRegions().front().empty()); op->getRegions().front().push_front(new Block); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index efbb95e7319..cf27a817edb 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -120,7 +120,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { for (unsigned i = 0; i < nViews; ++i) { auto viewType = op.getShapedType(i); - if (viewType.getElementType() != block.getArgument(i)->getType()) + if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") << i << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") @@ -139,7 +139,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { "number of loops"); for (unsigned i = 0; i < nLoops; ++i) { - if (!block.getArgument(i)->getType().isIndex()) + if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") << i << " to be of IndexType"; } @@ -148,7 +148,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != - block.getArgument(memrefArgIndex)->getType()) + block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") << memrefArgIndex << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") @@ -314,10 +314,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RangeOp op) { - p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":" - << *op.step(); + p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":" + << op.step(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getResult()->getType(); + p << " : " << op.getResult().getType(); } static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) { @@ -541,7 +541,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, result.addOperands(base); result.addOperands(indexings); - auto memRefType = base->getType().cast<MemRefType>(); + auto memRefType = base.getType().cast<MemRefType>(); int64_t offset; SmallVector<int64_t, 4> strides; auto res = getStridesAndOffset(memRefType, strides, offset); @@ -560,7 +560,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, static void print(OpAsmPrinter &p, SliceOp op) { auto indexings = op.indexings(); - p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings + p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings << "] "; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getBaseViewType(); @@ -599,7 +599,7 @@ static LogicalResult verify(SliceOp op) { << rank << " indexings, got " << llvm::size(op.indexings()); unsigned index = 0; for (auto indexing : op.indexings()) { - if (indexing->getType().isa<IndexType>()) + if (indexing.getType().isa<IndexType>()) --rank; ++index; } @@ -618,7 +618,7 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result, auto permutationMap = permutation.getValue(); assert(permutationMap); - auto memRefType = view->getType().cast<MemRefType>(); + auto memRefType = view.getType().cast<MemRefType>(); auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); // Compute permuted sizes. @@ -644,10 +644,10 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result, } static void print(OpAsmPrinter &p, TransposeOp op) { - p << op.getOperationName() << " " << *op.view() << " " << op.permutation(); + p << op.getOperationName() << " " << op.view() << " " << op.permutation(); p.printOptionalAttrDict(op.getAttrs(), {TransposeOp::getPermutationAttrName()}); - p << " : " << op.view()->getType(); + p << " : " << op.view().getType(); } static ParseResult parseTransposeOp(OpAsmParser &parser, @@ -698,9 +698,9 @@ LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { for (unsigned i = 0; i != nOutputViews; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); - if (op.getOperand(i)->getType() != elementType) + if (op.getOperand(i).getType() != elementType) return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i)->getType() + << i << " (" << op.getOperand(i).getType() << ") doesn't match view element type (" << elementType << ")"; } return success(); @@ -765,7 +765,7 @@ static ParseResult parseLinalgStructuredOp(OpAsmParser &parser, static LogicalResult verify(FillOp op) { auto viewType = op.getOutputShapedType(0); - auto fillType = op.value()->getType(); + auto fillType = op.value().getType(); if (viewType.getElementType() != fillType) return op.emitOpError("expects fill type to match view elemental type"); return success(); @@ -816,9 +816,9 @@ verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) { } static LogicalResult verify(ConvOp op) { - auto oType = op.output()->getType().cast<MemRefType>(); - auto fType = op.filter()->getType().cast<MemRefType>(); - auto iType = op.input()->getType().cast<MemRefType>(); + auto oType = op.output().getType().cast<MemRefType>(); + auto fType = op.filter().getType().cast<MemRefType>(); + auto iType = op.input().getType().cast<MemRefType>(); if (oType.getElementType() != iType.getElementType() || oType.getElementType() != fType.getElementType()) return op.emitOpError("expects memref elemental types to match"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 9df7bce0879..043d9c0e7cd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -133,8 +133,7 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth << "\n"); - LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view - << "\n"); + LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); return ViewDimension{view, static_cast<unsigned>(en2.index())}; } } @@ -146,9 +145,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, unsigned producerIdx, OperationFolder *folder) { auto subView = dyn_cast_or_null<SubViewOp>( - consumer.getInput(consumerIdx)->getDefiningOp()); - auto slice = dyn_cast_or_null<SliceOp>( - consumer.getInput(consumerIdx)->getDefiningOp()); + consumer.getInput(consumerIdx).getDefiningOp()); + auto slice = + dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp()); assert(subView || slice); (void)subView; (void)slice; @@ -272,13 +271,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() - << " view: " << *producedView + << " view: " << producedView << " output index: " << producerIdx); // Must be a subview or a slice to guarantee there are loops we can fuse // into. - auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp()); - auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp()); + auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp()); + auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp()); if (!subView && !slice) { LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 4bc452afa36..9657daf9137 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -166,7 +166,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) { // TODO(ntv): non-identity layout. auto isStaticMemRefWithIdentityLayout = [](Value v) { - auto m = v->getType().dyn_cast<MemRefType>(); + auto m = v.getType().dyn_cast<MemRefType>(); if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) return false; return true; @@ -281,7 +281,7 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, LinalgOp linOp = cast<LinalgOp>(op); SetVector<Value> subViews; for (auto it : linOp.getInputsAndOutputs()) - if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { promoteSubViewOperands(rewriter, linOp, subViews); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index b8b27958ff5..eb605699890 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -47,10 +47,10 @@ static llvm::cl::opt<bool> clPromoteDynamic( llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) { - auto *ctx = size->getContext(); + auto *ctx = size.getContext(); auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); if (!dynamicBuffers) - if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp())) + if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp())) return alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); Value mul = muli(constant_index(width), size); @@ -116,7 +116,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, res.reserve(subViews.size()); DenseMap<Value, PromotionInfo> promotionInfoMap; for (auto v : subViews) { - SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + SubViewOp subView = cast<SubViewOp>(v.getDefiningOp()); auto viewType = subView.getType(); // TODO(ntv): support more cases than just float. if (!viewType.getElementType().isa<FloatType>()) @@ -128,7 +128,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, } for (auto v : subViews) { - SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + SubViewOp subView = cast<SubViewOp>(v.getDefiningOp()); auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; @@ -146,7 +146,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; - copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView); + copy(cast<SubViewOp>(v.getDefiningOp()), info->second.partialLocalView); } return res; } @@ -208,7 +208,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { SetVector<Value> subViews; OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) - if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 3841392dbdb..b77f658aa2f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -45,8 +45,8 @@ static llvm::cl::list<unsigned> llvm::cl::cat(clOptionsCategory)); static bool isZero(Value v) { - return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) && - cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0; + return isa_and_nonnull<ConstantIndexOp>(v.getDefiningOp()) && + cast<ConstantIndexOp>(v.getDefiningOp()).getValue() == 0; } using LoopIndexToRangeIndexMap = DenseMap<int, int>; @@ -201,8 +201,8 @@ void transformIndexedGenericOpIndices( // variable and replace all uses of the previous value. Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, pivs[rangeIndex->second]->getValue()); - for (auto &use : oldIndex->getUses()) { - if (use.getOwner() == newIndex->getDefiningOp()) + for (auto &use : oldIndex.getUses()) { + if (use.getOwner() == newIndex.getDefiningOp()) continue; use.set(newIndex); } @@ -258,7 +258,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { Value view = *(viewIteratorBegin + viewIndex); - unsigned rank = view->getType().cast<MemRefType>().getRank(); + unsigned rank = view.getType().cast<MemRefType>().getRank(); auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; // If the view is not tiled, we can use it as is. if (!isTiled(map, tileSizes)) { @@ -299,8 +299,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // defined. if (folder) for (auto v : llvm::concat<Value>(lbs, subViewSizes)) - if (v->use_empty()) - v->getDefiningOp()->erase(); + if (v.use_empty()) + v.getDefiningOp()->erase(); return res; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 66b82fd14e4..1b8a7be7a22 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -35,9 +35,9 @@ using namespace mlir::loop; mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, ValueHandle range) { assert(range.getType() && "expected !linalg.range type"); - assert(range.getValue()->getDefiningOp() && + assert(range.getValue().getDefiningOp() && "need operations to extract range parts"); - auto rangeOp = cast<RangeOp>(range.getValue()->getDefiningOp()); + auto rangeOp = cast<RangeOp>(range.getValue().getDefiningOp()); auto lb = rangeOp.min(); auto ub = rangeOp.max(); auto step = rangeOp.step(); @@ -168,7 +168,7 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) { res.reserve(nOperands); for (unsigned i = 0; i < nOperands; ++i) { res.push_back(op->getOperand(numViews + i)); - auto t = res.back()->getType(); + auto t = res.back().getType(); (void)t; assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) && "expected scalar or vector type"); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp index acbab01df79..5452b3d4ab8 100644 --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -69,23 +69,22 @@ void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub, } LogicalResult verify(ForOp op) { - if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp())) + if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be positive"); // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); - if (body->getNumArguments() != 1 || - !body->getArgument(0)->getType().isIndex()) + if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex()) return op.emitOpError("expected body to have a single index argument for " "the induction variable"); return success(); } static void print(OpAsmPrinter &p, ForOp op) { - p << op.getOperationName() << " " << *op.getInductionVar() << " = " - << *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step(); + p << op.getOperationName() << " " << op.getInductionVar() << " = " + << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); @@ -126,11 +125,11 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) { Region &ForOp::getLoopBody() { return region(); } bool ForOp::isDefinedOutsideOfLoop(Value value) { - return !region().isAncestor(value->getParentRegion()); + return !region().isAncestor(value.getParentRegion()); } LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) { - for (auto *op : ops) + for (auto op : ops) op->moveBefore(this->getOperation()); return success(); } @@ -139,8 +138,8 @@ ForOp mlir::loop::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast<BlockArgument>(); if (!ivArg) return ForOp(); - assert(ivArg->getOwner() && "unlinked block argument"); - auto *containingInst = ivArg->getOwner()->getParentOp(); + assert(ivArg.getOwner() && "unlinked block argument"); + auto *containingInst = ivArg.getOwner()->getParentOp(); return dyn_cast_or_null<ForOp>(containingInst); } @@ -205,7 +204,7 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, IfOp op) { - p << IfOp::getOperationName() << " " << *op.condition(); + p << IfOp::getOperationName() << " " << op.condition(); p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index faeff246bd2..8ff6fbed587 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -36,8 +36,8 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context) OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) { /// Matches x -> [scast -> scast] -> y, replacing the second scast with the /// value of x if the casts invert each other. - auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp()); - if (!srcScastOp || srcScastOp.arg()->getType() != getType()) + auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp()); + if (!srcScastOp || srcScastOp.arg().getType() != getType()) return OpFoldResult(); return srcScastOp.arg(); } diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 08a5ec59e8d..d62dd595985 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -52,7 +52,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, // Does the qbarrier convert to a quantized type. This will not be true // if a quantized type has not yet been chosen or if the cast to an equivalent // storage type is not supported. - Type qbarrierResultType = qbarrier.getResult()->getType(); + Type qbarrierResultType = qbarrier.getResult().getType(); QuantizedType quantizedElementType = QuantizedType::getQuantizedElementType(qbarrierResultType); if (!quantizedElementType) { @@ -66,7 +66,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, // type? This will not be true if the qbarrier is superfluous (converts // from and to a quantized type). if (!quantizedElementType.isCompatibleExpressedType( - qbarrier.arg()->getType())) { + qbarrier.arg().getType())) { return matchFailure(); } @@ -86,7 +86,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, // When creating the new const op, use a fused location that combines the // original const and the qbarrier that led to the quantization. auto fusedLoc = FusedLoc::get( - {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()}, + {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}, rewriter.getContext()); auto newConstOp = rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index d6fd35418b3..98c36fb9b08 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -104,7 +104,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { // Replace the values directly with the return operands. assert(valuesToRepl.size() == 1 && "spv.ReturnValue expected to only handle one result"); - valuesToRepl.front()->replaceAllUsesWith(retValOp.value()); + valuesToRepl.front().replaceAllUsesWith(retValOp.value()); } }; } // namespace diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ff20e091f91..a6f5d358d0c 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -167,8 +167,8 @@ printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer, static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth = true) { - Type operandType = op->getOperand(0)->getType(); - Type resultType = op->getResult(0)->getType(); + Type operandType = op->getOperand(0).getType(); + Type resultType = op->getResult(0).getType(); // ODS checks that result type and operand type have the same shape. if (auto vectorType = operandType.dyn_cast<VectorType>()) { @@ -271,8 +271,8 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, // // TODO(ravishankarm): Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations - if (val->getType() != - ptr->getType().cast<spirv::PointerType>().getPointeeType()) { + if (val.getType() != + ptr.getType().cast<spirv::PointerType>().getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -497,11 +497,11 @@ static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) { } static LogicalResult verifyBitFieldExtractOp(Operation *op) { - if (op->getOperand(0)->getType() != op->getResult(0)->getType()) { + if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " "result, but provided ") - << op->getOperand(0)->getType() << " and " - << op->getResult(0)->getType(); + << op->getOperand(0).getType() << " and " + << op->getResult(0).getType(); } return success(); } @@ -547,13 +547,12 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { printer << spirv::stringifyMemorySemantics( static_cast<spirv::MemorySemantics>( memorySemanticsAttr.getInt())) - << "\" " << op->getOperands() << " : " - << op->getOperand(0)->getType(); + << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); } // Verifies an atomic update op. static LogicalResult verifyAtomicUpdateOp(Operation *op) { - auto ptrType = op->getOperand(0)->getType().cast<spirv::PointerType>(); + auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>(); auto elementType = ptrType.getPointeeType(); if (!elementType.isa<IntegerType>()) return op->emitOpError( @@ -561,7 +560,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) { << elementType; if (op->getNumOperands() > 1) { - auto valueType = op->getOperand(1)->getType(); + auto valueType = op->getOperand(1).getType(); if (valueType != elementType) return op->emitOpError("expected value to have the same type as the " "pointer operand's pointee type ") @@ -595,8 +594,8 @@ static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) { } static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { - printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : " - << unaryOp->getOperand(0)->getType(); + printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : " + << unaryOp->getOperand(0).getType(); } /// Result of a logical op must be a scalar or vector of boolean type. @@ -634,7 +633,7 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : " - << logicalOp->getOperand(0)->getType(); + << logicalOp->getOperand(0).getType(); } static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { @@ -657,16 +656,16 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { static void printShiftOp(Operation *op, OpAsmPrinter &printer) { Value base = op->getOperand(0); Value shift = op->getOperand(1); - printer << op->getName() << ' ' << *base << ", " << *shift << " : " - << base->getType() << ", " << shift->getType(); + printer << op->getName() << ' ' << base << ", " << shift << " : " + << base.getType() << ", " << shift.getType(); } static LogicalResult verifyShiftOp(Operation *op) { - if (op->getOperand(0)->getType() != op->getResult(0)->getType()) { + if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " "result, but provided ") - << op->getOperand(0)->getType() << " and " - << op->getResult(0)->getType(); + << op->getOperand(0).getType() << " and " + << op->getResult(0).getType(); } return success(); } @@ -704,7 +703,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { } index = 0; if (resultType.isa<spirv::StructType>()) { - Operation *op = indexSSA->getDefiningOp(); + Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spv.AccessChain' op index must be an " "integer spv.constant to access " @@ -734,7 +733,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { void spirv::AccessChainOp::build(Builder *builder, OperationState &state, Value basePtr, ValueRange indices) { - auto type = getElementPtrType(basePtr->getType(), indices, state.location); + auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); } @@ -768,14 +767,14 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, } static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { - printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr() - << '[' << op.indices() << "] : " << op.base_ptr()->getType(); + printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr() + << '[' << op.indices() << "] : " << op.base_ptr().getType(); } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { SmallVector<Value, 4> indices(accessChainOp.indices().begin(), accessChainOp.indices().end()); - auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(), + auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(), indices, accessChainOp.getLoc()); if (!resultType) { return failure(); @@ -808,7 +807,7 @@ struct CombineChainedAccessChain PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp, PatternRewriter &rewriter) const override { auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>( - accessChainOp.base_ptr()->getDefiningOp()); + accessChainOp.base_ptr().getDefiningOp()); if (!parentAccessChainOp) { return matchFailure(); @@ -868,7 +867,7 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) { printer.printSymbolName(addressOfOp.variable()); // Print the type. - printer << " : " << addressOfOp.pointer()->getType(); + printer << " : " << addressOfOp.pointer().getType(); } static LogicalResult verify(spirv::AddressOfOp addressOfOp) { @@ -878,7 +877,7 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) { if (!varOp) { return addressOfOp.emitOpError("expected spv.globalVariable symbol"); } - if (addressOfOp.pointer()->getType() != varOp.type()) { + if (addressOfOp.pointer().getType() != varOp.type()) { return addressOfOp.emitOpError( "result type mismatch with the referenced global variable's type"); } @@ -926,7 +925,7 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp, << stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " - << atomOp.getOperands() << " : " << atomOp.pointer()->getType(); + << atomOp.getOperands() << " : " << atomOp.pointer().getType(); } static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { @@ -934,19 +933,19 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { // "The type of Value must be the same as Result Type. The type of the value // pointed to by Pointer must be the same as Result Type. This type must also // match the type of Comparator." - if (atomOp.getType() != atomOp.value()->getType()) + if (atomOp.getType() != atomOp.value().getType()) return atomOp.emitOpError("value operand must have the same type as the op " "result, but found ") - << atomOp.value()->getType() << " vs " << atomOp.getType(); + << atomOp.value().getType() << " vs " << atomOp.getType(); - if (atomOp.getType() != atomOp.comparator()->getType()) + if (atomOp.getType() != atomOp.comparator().getType()) return atomOp.emitOpError( "comparator operand must have the same type as the op " "result, but found ") - << atomOp.comparator()->getType() << " vs " << atomOp.getType(); + << atomOp.comparator().getType() << " vs " << atomOp.getType(); Type pointeeType = - atomOp.pointer()->getType().cast<spirv::PointerType>().getPointeeType(); + atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType(); if (atomOp.getType() != pointeeType) return atomOp.emitOpError( "pointer operand's pointee type must have the same " @@ -966,8 +965,8 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) { static LogicalResult verify(spirv::BitcastOp bitcastOp) { // TODO: The SPIR-V spec validation rules are different for different // versions. - auto operandType = bitcastOp.operand()->getType(); - auto resultType = bitcastOp.result()->getType(); + auto operandType = bitcastOp.operand().getType(); + auto resultType = bitcastOp.result().getType(); if (operandType == resultType) { return bitcastOp.emitError( "result type must be different from operand type"); @@ -1026,15 +1025,15 @@ static void print(spirv::BitFieldInsertOp bitFieldInsertOp, OpAsmPrinter &printer) { printer << spirv::BitFieldInsertOp::getOperationName() << ' ' << bitFieldInsertOp.getOperands() << " : " - << bitFieldInsertOp.base()->getType() << ", " - << bitFieldInsertOp.offset()->getType() << ", " - << bitFieldInsertOp.count()->getType(); + << bitFieldInsertOp.base().getType() << ", " + << bitFieldInsertOp.offset().getType() << ", " + << bitFieldInsertOp.count().getType(); } static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) { - auto baseType = bitFieldOp.base()->getType(); - auto insertType = bitFieldOp.insert()->getType(); - auto resultType = bitFieldOp.getResult()->getType(); + auto baseType = bitFieldOp.base().getType(); + auto insertType = bitFieldOp.insert().getType(); + auto resultType = bitFieldOp.getResult().getType(); if ((baseType != insertType) || (baseType != resultType)) { return bitFieldOp.emitError("expected the same type for the base operand, " @@ -1199,7 +1198,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, OpAsmPrinter &printer) { printer << spirv::CompositeConstructOp::getOperationName() << " " << compositeConstructOp.constituents() << " : " - << compositeConstructOp.getResult()->getType(); + << compositeConstructOp.getResult().getType(); } static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { @@ -1214,11 +1213,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { } for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { - if (constituents[index]->getType() != cType.getElementType(index)) { + if (constituents[index].getType() != cType.getElementType(index)) { return compositeConstructOp.emitError( "operand type mismatch: expected operand type ") << cType.getElementType(index) << ", but provided " - << constituents[index]->getType(); + << constituents[index].getType(); } } @@ -1234,7 +1233,7 @@ void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state, ArrayRef<int32_t> indices) { auto indexAttr = builder->getI32ArrayAttr(indices); auto elementType = - getElementType(composite->getType(), indexAttr, state.location); + getElementType(composite.getType(), indexAttr, state.location); if (!elementType) { return; } @@ -1268,13 +1267,13 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser, static void print(spirv::CompositeExtractOp compositeExtractOp, OpAsmPrinter &printer) { printer << spirv::CompositeExtractOp::getOperationName() << ' ' - << *compositeExtractOp.composite() << compositeExtractOp.indices() - << " : " << compositeExtractOp.composite()->getType(); + << compositeExtractOp.composite() << compositeExtractOp.indices() + << " : " << compositeExtractOp.composite().getType(); } static LogicalResult verify(spirv::CompositeExtractOp compExOp) { auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>(); - auto resultType = getElementType(compExOp.composite()->getType(), + auto resultType = getElementType(compExOp.composite().getType(), indicesArrayAttr, compExOp.getLoc()); if (!resultType) return failure(); @@ -1321,21 +1320,21 @@ static ParseResult parseCompositeInsertOp(OpAsmParser &parser, static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) { auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>(); auto objectType = - getElementType(compositeInsertOp.composite()->getType(), indicesArrayAttr, + getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr, compositeInsertOp.getLoc()); if (!objectType) return failure(); - if (objectType != compositeInsertOp.object()->getType()) { + if (objectType != compositeInsertOp.object().getType()) { return compositeInsertOp.emitOpError("object operand type should be ") << objectType << ", but found " - << compositeInsertOp.object()->getType(); + << compositeInsertOp.object().getType(); } - if (compositeInsertOp.composite()->getType() != compositeInsertOp.getType()) { + if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) { return compositeInsertOp.emitOpError("result type should be the same as " "the composite type, but found ") - << compositeInsertOp.composite()->getType() << " vs " + << compositeInsertOp.composite().getType() << " vs " << compositeInsertOp.getType(); } @@ -1345,10 +1344,10 @@ static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) { static void print(spirv::CompositeInsertOp compositeInsertOp, OpAsmPrinter &printer) { printer << spirv::CompositeInsertOp::getOperationName() << " " - << *compositeInsertOp.object() << ", " - << *compositeInsertOp.composite() << compositeInsertOp.indices() - << " : " << compositeInsertOp.object()->getType() << " into " - << compositeInsertOp.composite()->getType(); + << compositeInsertOp.object() << ", " << compositeInsertOp.composite() + << compositeInsertOp.indices() << " : " + << compositeInsertOp.object().getType() << " into " + << compositeInsertOp.composite().getType(); } //===----------------------------------------------------------------------===// @@ -1707,12 +1706,12 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { } for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { - if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) { + if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) { return functionCallOp.emitOpError( "operand type mismatch: expected operand type ") << functionType.getInput(i) << ", but provided " - << functionCallOp.getOperand(i)->getType() - << " for operand number " << i; + << functionCallOp.getOperand(i).getType() << " for operand number " + << i; } } @@ -1724,10 +1723,10 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { } if (functionCallOp.getNumResults() && - (functionCallOp.getResult(0)->getType() != functionType.getResult(0))) { + (functionCallOp.getResult(0).getType() != functionType.getResult(0))) { return functionCallOp.emitOpError("result type mismatch: expected ") << functionType.getResult(0) << ", but provided " - << functionCallOp.getResult(0)->getType(); + << functionCallOp.getResult(0).getType(); } return success(); @@ -1955,7 +1954,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) { void spirv::LoadOp::build(Builder *builder, OperationState &state, Value basePtr, IntegerAttr memory_access, IntegerAttr alignment) { - auto ptrType = basePtr->getType().cast<spirv::PointerType>(); + auto ptrType = basePtr.getType().cast<spirv::PointerType>(); build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, alignment); } @@ -1986,7 +1985,7 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) { auto *op = loadOp.getOperation(); SmallVector<StringRef, 4> elidedAttrs; StringRef sc = stringifyStorageClass( - loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); + loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass()); printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" " << loadOp.ptr(); @@ -2414,7 +2413,7 @@ static ParseResult parseReferenceOfOp(OpAsmParser &parser, static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) { printer << spirv::ReferenceOfOp::getOperationName() << ' '; printer.printSymbolName(referenceOfOp.spec_const()); - printer << " : " << referenceOfOp.reference()->getType(); + printer << " : " << referenceOfOp.reference().getType(); } static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { @@ -2424,7 +2423,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { if (!specConstOp) { return referenceOfOp.emitOpError("expected spv.specConstant symbol"); } - if (referenceOfOp.reference()->getType() != + if (referenceOfOp.reference().getType() != specConstOp.default_value().getType()) { return referenceOfOp.emitOpError("result type mismatch with the referenced " "specialization constant's type"); @@ -2461,7 +2460,7 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser, static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) { printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value() - << " : " << retValOp.value()->getType(); + << " : " << retValOp.value().getType(); } static LogicalResult verify(spirv::ReturnValueOp retValOp) { @@ -2472,7 +2471,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { "returns 1 value but enclosing function requires ") << numFnResults << " results"; - auto operandType = retValOp.value()->getType(); + auto operandType = retValOp.value().getType(); auto fnResultType = funcOp.getType().getResult(0); if (operandType != fnResultType) return retValOp.emitOpError(" return value's type (") @@ -2488,7 +2487,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond, Value trueValue, Value falseValue) { - build(builder, state, trueValue->getType(), cond, trueValue, falseValue); + build(builder, state, trueValue.getType(), cond, trueValue, falseValue); } static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) { @@ -2514,19 +2513,18 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) { static void print(spirv::SelectOp op, OpAsmPrinter &printer) { printer << spirv::SelectOp::getOperationName() << " " << op.getOperands() - << " : " << op.condition()->getType() << ", " - << op.result()->getType(); + << " : " << op.condition().getType() << ", " << op.result().getType(); } static LogicalResult verify(spirv::SelectOp op) { - auto resultTy = op.result()->getType(); - if (op.true_value()->getType() != resultTy) { + auto resultTy = op.result().getType(); + if (op.true_value().getType() != resultTy) { return op.emitOpError("result type and true value type must be the same"); } - if (op.false_value()->getType() != resultTy) { + if (op.false_value().getType() != resultTy) { return op.emitOpError("result type and false value type must be the same"); } - if (auto conditionTy = op.condition()->getType().dyn_cast<VectorType>()) { + if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) { auto resultVectorTy = resultTy.dyn_cast<VectorType>(); if (!resultVectorTy) { return op.emitOpError("result expected to be of vector type when " @@ -2695,7 +2693,7 @@ struct ConvertSelectionOpToSelect cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs(); auto selectOp = rewriter.create<spirv::SelectOp>( - selectionOp.getLoc(), trueValue->getType(), brConditionalOp.condition(), + selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(), trueValue, falseValue); rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue, selectOp.getResult(), storeOpAttributes); @@ -2773,7 +2771,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( // attributes and a valid type of the value. if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || - !isValidType(trueBrStoreOp.value()->getType())) { + !isValidType(trueBrStoreOp.value().getType())) { return matchFailure(); } @@ -2879,13 +2877,13 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) { auto *op = storeOp.getOperation(); SmallVector<StringRef, 4> elidedAttrs; StringRef sc = stringifyStorageClass( - storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); + storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass()); printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" " << storeOp.ptr() << ", " << storeOp.value(); printMemoryAccessAttribute(storeOp, printer, elidedAttrs); - printer << " : " << storeOp.value()->getType(); + printer << " : " << storeOp.value().getType(); printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -3025,7 +3023,7 @@ static LogicalResult verify(spirv::VariableOp varOp) { "spv.globalVariable for module-level variables."); } - auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>(); + auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>(); if (varOp.storage_class() != pointerType.getStorageClass()) return varOp.emitOpError( "storage class must match result pointer's storage class"); @@ -3033,7 +3031,7 @@ static LogicalResult verify(spirv::VariableOp varOp) { if (varOp.getNumOperands() != 0) { // SPIR-V spec: "Initializer must be an <id> from a constant instruction or // a global (module scope) OpVariable instruction". - auto *initOp = varOp.getOperand(0)->getDefiningOp(); + auto *initOp = varOp.getOperand(0).getDefiningOp(); if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant isa<spirv::ReferenceOfOp>(initOp) || // for spec constant isa<spirv::AddressOfOp>(initOp))) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 4e030217160..25d5763dc98 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1775,7 +1775,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { << " from block " << block << "\n"); if (!isFnEntryBlock(block)) { for (BlockArgument blockArg : block->getArguments()) { - auto newArg = newBlock->addArgument(blockArg->getType()); + auto newArg = newBlock->addArgument(blockArg.getType()); mapper.map(blockArg, newArg); LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg << " to " << newArg << '\n'); @@ -1816,7 +1816,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // make sure the old merge block has the same block argument list. assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); for (BlockArgument blockArg : headerBlock->getArguments()) { - mergeBlock->addArgument(blockArg->getType()); + mergeBlock->addArgument(blockArg.getType()); } // If the loop header block has block arguments, make sure the spv.branch op @@ -2200,7 +2200,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) { "spirv::BitcastOp, only ") << wordIndex << " of " << words.size() << " processed"; } - if (resultTypes[0] == operands[0]->getType() && + if (resultTypes[0] == operands[0].getType() && resultTypes[0].isa<IntegerType>()) { // TODO(b/130356985): This check is added to ignore error in Op verification // due to both signed and unsigned integers mapping to the same diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index d7971eb5e35..74a959b0ea5 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -507,10 +507,10 @@ void Serializer::printValueIDMap(raw_ostream &os) { Value val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; - if (auto *op = val->getDefiningOp()) { + if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; } else if (auto arg = val.dyn_cast<BlockArgument>()) { - Block *block = arg->getOwner(); + Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; } @@ -714,7 +714,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { // Declare the parameters. for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { return failure(); } auto argValueID = getNextID(); @@ -1397,7 +1397,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // Get the type <id> and result <id> for this OpPhi instruction. uint32_t phiTypeID = 0; - if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID))) + if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) return failure(); uint32_t phiID = getNextID(); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index 07621d6fa80..c94746ac22e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -100,7 +100,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { // Change the type for the direct users. target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) { - return VulkanLayoutUtils::isLegalType(op.pointer()->getType()); + return VulkanLayoutUtils::isLegalType(op.pointer().getType()); }); // TODO: Change the type for the indirect users such as spv.Load, spv.Store, diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index f9296196789..25e420ad9f8 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -79,7 +79,7 @@ struct StdInlinerInterface : public DialectInlinerInterface { // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } }; } // end anonymous namespace @@ -96,9 +96,9 @@ static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << *op->getOperand(0); + << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getOperand(0)->getType(); + p << " : " << op->getOperand(0).getType(); } /// A custom binary operation printer that omits the "std." prefix from the @@ -109,20 +109,20 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. - auto resultType = op->getResult(0)->getType(); - if (op->getOperand(0)->getType() != resultType || - op->getOperand(1)->getType() != resultType) { + auto resultType = op->getResult(0).getType(); + if (op->getOperand(0).getType() != resultType || + op->getOperand(1).getType() != resultType) { p.printGenericOp(op); return; } int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << *op->getOperand(0) << ", " << *op->getOperand(1); + << op->getOperand(0) << ", " << op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. - p << " : " << op->getResult(0)->getType(); + p << " : " << op->getResult(0).getType(); } /// A custom cast operation printer that omits the "std." prefix from the @@ -130,13 +130,13 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " - << op->getResult(0)->getType(); + << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to " + << op->getResult(0).getType(); } /// A custom cast operation verifier. template <typename T> static LogicalResult verifyCastOp(T op) { - auto opType = op.getOperand()->getType(); + auto opType = op.getOperand().getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) return op.emitError("operand type ") << opType << " and result type " @@ -209,8 +209,8 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp()); - if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) { + auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp()); + if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { operand.set(cast.getOperand()); folded = true; } @@ -281,7 +281,7 @@ static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) { } static LogicalResult verify(AllocOp op) { - auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>(); + auto memRefType = op.getResult().getType().dyn_cast<MemRefType>(); if (!memRefType) return op.emitOpError("result must be a memref"); @@ -338,7 +338,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); + auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); @@ -489,14 +489,14 @@ static LogicalResult verify(CallOp op) { return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i)->getType() != fnType.getInput(i)) + if (op.getOperand(i).getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i)->getType() != fnType.getResult(i)) + if (op.getResult(i).getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); @@ -553,12 +553,12 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser, static void print(OpAsmPrinter &p, CallIndirectOp op) { p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee()->getType(); + p << " : " << op.getCallee().getType(); } static LogicalResult verify(CallIndirectOp op) { // The callee must be a function. - auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>(); + auto fnType = op.getCallee().getType().dyn_cast<FunctionType>(); if (!fnType) return op.emitOpError("callee must have function type"); @@ -567,14 +567,14 @@ static LogicalResult verify(CallIndirectOp op) { return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) + if (op.getOperand(i + 1).getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i)->getType() != fnType.getResult(i)) + if (op.getResult(i).getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); @@ -616,7 +616,7 @@ static Type getI1SameShape(Builder *build, Type type) { static void buildCmpIOp(Builder *build, OperationState &result, CmpIPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(build, lhs->getType())); + result.types.push_back(getI1SameShape(build, lhs.getType())); result.addAttribute( CmpIOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast<int64_t>(predicate))); @@ -668,7 +668,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) { << '"' << ", " << op.lhs() << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); - p << " : " << op.lhs()->getType(); + p << " : " << op.lhs().getType(); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer @@ -769,7 +769,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { static void buildCmpFOp(Builder *build, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(build, lhs->getType())); + result.types.push_back(getI1SameShape(build, lhs.getType())); result.addAttribute( CmpFOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast<int64_t>(predicate))); @@ -824,7 +824,7 @@ static void print(OpAsmPrinter &p, CmpFOp op) { << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); - p << " : " << op.lhs()->getType(); + p << " : " << op.lhs().getType(); } static LogicalResult verify(CmpFOp op) { @@ -1123,14 +1123,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState &result, } bool ConstantFloatOp::classof(Operation *op) { - return ConstantOp::classof(op) && - op->getResult(0)->getType().isa<FloatType>(); + return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. bool ConstantIntOp::classof(Operation *op) { return ConstantOp::classof(op) && - op->getResult(0)->getType().isa<IntegerType>(); + op->getResult(0).getType().isa<IntegerType>(); } void ConstantIntOp::build(Builder *builder, OperationState &result, @@ -1151,7 +1150,7 @@ void ConstantIntOp::build(Builder *builder, OperationState &result, /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::classof(Operation *op) { - return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); + return ConstantOp::classof(op) && op->getResult(0).getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState &result, @@ -1174,11 +1173,11 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value memref = dealloc.memref(); - if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp())) + if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp())) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. - for (auto *user : memref->getUsers()) + for (auto *user : memref.getUsers()) if (!isa<DeallocOp>(user)) return matchFailure(); @@ -1190,7 +1189,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { } // end anonymous namespace. static void print(OpAsmPrinter &p, DeallocOp op) { - p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); + p << "dealloc " << op.memref() << " : " << op.memref().getType(); } static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { @@ -1203,7 +1202,7 @@ static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { } static LogicalResult verify(DeallocOp op) { - if (!op.memref()->getType().isa<MemRefType>()) + if (!op.memref().getType().isa<MemRefType>()) return op.emitOpError("operand must be a memref"); return success(); } @@ -1224,9 +1223,9 @@ LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, DimOp op) { - p << "dim " << *op.getOperand() << ", " << op.getIndex(); + p << "dim " << op.getOperand() << ", " << op.getIndex(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); - p << " : " << op.getOperand()->getType(); + p << " : " << op.getOperand().getType(); } static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) { @@ -1251,7 +1250,7 @@ static LogicalResult verify(DimOp op) { return op.emitOpError("requires an integer attribute named 'index'"); int64_t index = indexAttr.getValue().getSExtValue(); - auto type = op.getOperand()->getType(); + auto type = op.getOperand().getType(); if (auto tensorType = type.dyn_cast<RankedTensorType>()) { if (index >= tensorType.getRank()) return op.emitOpError("index is out of range"); @@ -1270,7 +1269,7 @@ static LogicalResult verify(DimOp op) { OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { // Constant fold dim when the size along the index referred to is a constant. - auto opType = memrefOrTensor()->getType(); + auto opType = memrefOrTensor().getType(); int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast<RankedTensorType>()) indexSize = tensorType.getShape()[getIndex()]; @@ -1286,7 +1285,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { return {}; // The size at getIndex() is now a dynamic size of a memref. - auto memref = memrefOrTensor()->getDefiningOp(); + auto memref = memrefOrTensor().getDefiningOp(); if (auto alloc = dyn_cast_or_null<AllocOp>(memref)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(getIndex())); @@ -1367,16 +1366,15 @@ void DmaStartOp::build(Builder *builder, OperationState &result, } void DmaStartOp::print(OpAsmPrinter &p) { - p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], " - << *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements() - << ", " << *getTagMemRef() << '[' << getTagIndices() << ']'; + p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " + << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() + << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; if (isStrided()) - p << ", " << *getStride() << ", " << *getNumElementsPerStride(); + p << ", " << getStride() << ", " << getNumElementsPerStride(); p.printOptionalAttrDict(getAttrs()); - p << " : " << getSrcMemRef()->getType(); - p << ", " << getDstMemRef()->getType(); - p << ", " << getTagMemRef()->getType(); + p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() + << ", " << getTagMemRef().getType(); } // Parse DmaStartOp. @@ -1506,7 +1504,7 @@ void DmaWaitOp::print(OpAsmPrinter &p) { p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " << getNumElements(); p.printOptionalAttrDict(getAttrs()); - p << " : " << getTagMemRef()->getType(); + p << " : " << getTagMemRef().getType(); } // Parse DmaWaitOp. @@ -1553,10 +1551,10 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ExtractElementOp op) { - p << "extract_element " << *op.getAggregate() << '[' << op.getIndices(); + p << "extract_element " << op.getAggregate() << '[' << op.getIndices(); p << ']'; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getAggregate()->getType(); + p << " : " << op.getAggregate().getType(); } static ParseResult parseExtractElementOp(OpAsmParser &parser, @@ -1577,7 +1575,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, } static LogicalResult verify(ExtractElementOp op) { - auto aggregateType = op.getAggregate()->getType().cast<ShapedType>(); + auto aggregateType = op.getAggregate().getType().cast<ShapedType>(); // This should be possible with tablegen type constraints if (op.getType() != aggregateType.getElementType()) @@ -1634,7 +1632,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, LoadOp op) { - p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']'; + p << "load " << op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -1781,7 +1779,7 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, PrefetchOp op) { - p << PrefetchOp::getOperationName() << " " << *op.memref() << '['; + p << PrefetchOp::getOperationName() << " " << op.memref() << '['; p.printOperands(op.indices()); p << ']' << ", " << (op.isWrite() ? "write" : "read"); p << ", locality<" << op.localityHint(); @@ -1851,7 +1849,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RankOp op) { - p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); + p << "rank " << op.getOperand() << " : " << op.getOperand().getType(); } static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { @@ -1866,7 +1864,7 @@ static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { // Constant fold rank when the rank of the tensor is known. - auto type = getOperand()->getType(); + auto type = getOperand().getType(); if (auto tensorType = type.dyn_cast<RankedTensorType>()) return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); return IntegerAttr(); @@ -1954,10 +1952,10 @@ static LogicalResult verify(ReturnOp op) { << " operands, but enclosing function returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i)->getType() != results[i]) + if (op.getOperand(i).getType() != results[i]) return op.emitError() << "type of return operand " << i << " (" - << op.getOperand(i)->getType() + << op.getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")"; return success(); @@ -1997,13 +1995,13 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SelectOp op) { - p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType(); + p << "select " << op.getOperands() << " : " << op.getTrueValue().getType(); p.printOptionalAttrDict(op.getAttrs()); } static LogicalResult verify(SelectOp op) { - auto trueType = op.getTrueValue()->getType(); - auto falseType = op.getFalseValue()->getType(); + auto trueType = op.getTrueValue().getType(); + auto falseType = op.getFalseValue().getType(); if (trueType != falseType) return op.emitOpError( @@ -2032,7 +2030,7 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { static LogicalResult verify(SignExtendIOp op) { // Get the scalar type (which is either directly the type of the operand // or the vector's/tensor's element type. - auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); // For now, index is forbidden for the source and the destination type. @@ -2054,7 +2052,7 @@ static LogicalResult verify(SignExtendIOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, SplatOp op) { - p << "splat " << *op.getOperand(); + p << "splat " << op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getType(); } @@ -2074,7 +2072,7 @@ static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) { static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. - if (op.getOperand()->getType() != + if (op.getOperand().getType() != op.getType().cast<ShapedType>().getElementType()) return op.emitError("operand should be of elemental type of result type"); @@ -2103,8 +2101,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, StoreOp op) { - p << "store " << *op.getValueToStore(); - p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']'; + p << "store " << op.getValueToStore(); + p << ", " << op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } @@ -2130,7 +2128,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { static LogicalResult verify(StoreOp op) { // First operand must have same type as memref element type. - if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) + if (op.getValueToStore().getType() != op.getMemRefType().getElementType()) return op.emitOpError( "first operand must have same type memref element type"); @@ -2251,9 +2249,9 @@ static Type getTensorTypeFromMemRefType(Builder &b, Type type) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TensorLoadOp op) { - p << "tensor_load " << *op.getOperand(); + p << "tensor_load " << op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getOperand()->getType(); + p << " : " << op.getOperand().getType(); } static ParseResult parseTensorLoadOp(OpAsmParser &parser, @@ -2274,9 +2272,9 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TensorStoreOp op) { - p << "tensor_store " << *op.tensor() << ", " << *op.memref(); + p << "tensor_store " << op.tensor() << ", " << op.memref(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.memref()->getType(); + p << " : " << op.memref().getType(); } static ParseResult parseTensorStoreOp(OpAsmParser &parser, @@ -2298,7 +2296,7 @@ static ParseResult parseTensorStoreOp(OpAsmParser &parser, //===----------------------------------------------------------------------===// static LogicalResult verify(TruncateIOp op) { - auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa<IndexType>()) @@ -2344,13 +2342,13 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, ViewOp op) { - p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; + p << op.getOperationName() << ' ' << op.getOperand(0) << '['; auto dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); p << "][" << op.getDynamicSizes() << ']'; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } Value ViewOp::getDynamicOffset() { @@ -2382,8 +2380,8 @@ static LogicalResult verifyDynamicStrides(MemRefType memrefType, } static LogicalResult verify(ViewOp op) { - auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); - auto viewType = op.getResult()->getType().cast<MemRefType>(); + auto baseType = op.getOperand(0).getType().cast<MemRefType>(); + auto viewType = op.getResult().getType().cast<MemRefType>(); // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || @@ -2453,7 +2451,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { int64_t newOffset = oldOffset; unsigned dynamicOffsetOperandCount = 0; if (dynamicOffset != nullptr) { - auto *defOp = dynamicOffset->getDefiningOp(); + auto *defOp = dynamicOffset.getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic offset will be folded into the map. newOffset = constantIndexOp.getValue(); @@ -2478,7 +2476,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp(); + auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); @@ -2590,7 +2588,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, ValueRange strides, Type resultType, ArrayRef<NamedAttribute> attrs) { if (!resultType) - resultType = inferSubViewResultType(source->getType().cast<MemRefType>()); + resultType = inferSubViewResultType(source.getType().cast<MemRefType>()); auto segmentAttr = b->getI32VectorAttr( {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()), static_cast<int32_t>(strides.size())}); @@ -2637,13 +2635,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, SubViewOp op) { - p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets() + p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets() << "][" << op.sizes() << "][" << op.strides() << ']'; SmallVector<StringRef, 1> elidedAttrs = { SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); - p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } static LogicalResult verify(SubViewOp op) { @@ -2757,8 +2755,8 @@ static LogicalResult verify(SubViewOp op) { } raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { - return os << "range " << *range.offset << ":" << *range.size << ":" - << *range.stride; + return os << "range " << range.offset << ":" << range.size << ":" + << range.stride; } SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() { @@ -2827,7 +2825,7 @@ public: } SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes()); for (auto size : llvm::enumerate(subViewOp.sizes())) { - auto defOp = size.value()->getDefiningOp(); + auto defOp = size.value().getDefiningOp(); assert(defOp); staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue(); } @@ -2873,7 +2871,7 @@ public: SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides()); for (auto stride : llvm::enumerate(subViewOp.strides())) { - auto defOp = stride.value()->getDefiningOp(); + auto defOp = stride.value().getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); staticStrides[stride.index()] = @@ -2924,7 +2922,7 @@ public: auto staticOffset = baseOffset; for (auto offset : llvm::enumerate(subViewOp.offsets())) { - auto defOp = offset.value()->getDefiningOp(); + auto defOp = offset.value().getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); staticOffset += @@ -2959,7 +2957,7 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static LogicalResult verify(ZeroExtendIOp op) { - auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); + auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa<IndexType>()) diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index 744500af663..11aea0936f2 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -163,9 +163,9 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { assert(op->getNumResults() == 1 && "only support broadcast check on one result"); - auto type1 = op->getOperand(0)->getType(); - auto type2 = op->getOperand(1)->getType(); - auto retType = op->getResult(0)->getType(); + auto type1 = op->getOperand(0).getType(); + auto type2 = op->getOperand(1).getType(); + auto retType = op->getResult(0).getType(); // We forbid broadcasting vector and tensor. if (hasBothVectorAndTensorType({type1, type2, retType})) diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 92a230eb5d1..8206d10962b 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -67,7 +67,7 @@ void vector::ContractionOp::build(Builder *builder, OperationState &result, ArrayAttr indexingMaps, ArrayAttr iteratorTypes) { result.addOperands({lhs, rhs, acc}); - result.addTypes(acc->getType()); + result.addTypes(acc.getType()); result.addAttribute(getIndexingMapsAttrName(), indexingMaps); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); } @@ -125,13 +125,13 @@ static void print(OpAsmPrinter &p, ContractionOp op) { attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); - p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", "; - p << *op.rhs() << ", " << *op.acc(); + p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", "; + p << op.rhs() << ", " << op.acc(); if (op.masks().size() == 2) p << ", " << op.masks(); p.printOptionalAttrDict(op.getAttrs(), attrNames); - p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into " + p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into " << op.getResultType(); } @@ -211,7 +211,7 @@ static LogicalResult verify(ContractionOp op) { if (map.getNumDims() != numIterators) return op.emitOpError("expected indexing map ") << index << " to have " << numIterators << " number of inputs"; - auto operandType = op.getOperand(index)->getType().cast<VectorType>(); + auto operandType = op.getOperand(index).getType().cast<VectorType>(); unsigned rank = operandType.getShape().size(); if (map.getNumResults() != rank) return op.emitOpError("expected indexing map ") @@ -351,10 +351,10 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { - p << op.getOperationName() << " " << *op.vector() << "[" << *op.position() - << " : " << op.position()->getType() << "]"; + p << op.getOperationName() << " " << op.vector() << "[" << op.position() + << " : " << op.position().getType() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector()->getType(); + p << " : " << op.vector().getType(); } static ParseResult parseExtractElementOp(OpAsmParser &parser, @@ -398,15 +398,15 @@ void vector::ExtractOp::build(Builder *builder, OperationState &result, Value source, ArrayRef<int64_t> position) { result.addOperands(source); auto positionAttr = getVectorSubscriptAttr(*builder, position); - result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(), + result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(), positionAttr)); result.addAttribute(getPositionAttrName(), positionAttr); } static void print(OpAsmPrinter &p, vector::ExtractOp op) { - p << op.getOperationName() << " " << *op.vector() << op.position(); + p << op.getOperationName() << " " << op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); - p << " : " << op.vector()->getType(); + p << " : " << op.vector().getType(); } static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { @@ -495,13 +495,13 @@ static ParseResult parseExtractSlicesOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, ExtractSlicesOp op) { - p << op.getOperationName() << ' ' << *op.vector() << ", "; + p << op.getOperationName() << ' ' << op.vector() << ", "; p << op.sizes() << ", " << op.strides(); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(), ExtractSlicesOp::getStridesAttrName()}); - p << " : " << op.vector()->getType(); + p << " : " << op.vector().getType(); p << " into " << op.getResultTupleType(); } @@ -594,7 +594,7 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << *op.source() << " : " + p << op.getOperationName() << " " << op.source() << " : " << op.getSourceType() << " to " << op.getVectorType(); } @@ -642,15 +642,15 @@ void ShuffleOp::build(Builder *builder, OperationState &result, Value v1, Value v2, ArrayRef<int64_t> mask) { result.addOperands({v1, v2}); auto maskAttr = getVectorSubscriptAttr(*builder, mask); - result.addTypes(v1->getType()); + result.addTypes(v1.getType()); result.addAttribute(getMaskAttrName(), maskAttr); } static void print(OpAsmPrinter &p, ShuffleOp op) { - p << op.getOperationName() << " " << *op.v1() << ", " << *op.v2() << " " + p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " " << op.mask(); p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()}); - p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); + p << " : " << op.v1().getType() << ", " << op.v2().getType(); } static LogicalResult verify(ShuffleOp op) { @@ -725,10 +725,10 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, InsertElementOp op) { - p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "[" - << *op.position() << " : " << op.position()->getType() << "]"; + p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "[" + << op.position() << " : " << op.position().getType() << "]"; p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.dest()->getType(); + p << " : " << op.dest().getType(); } static ParseResult parseInsertElementOp(OpAsmParser &parser, @@ -766,12 +766,12 @@ void InsertOp::build(Builder *builder, OperationState &result, Value source, Value dest, ArrayRef<int64_t> position) { result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(*builder, position); - result.addTypes(dest->getType()); + result.addTypes(dest.getType()); result.addAttribute(getPositionAttrName(), positionAttr); } static void print(OpAsmPrinter &p, InsertOp op) { - p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() + p << op.getOperationName() << " " << op.source() << ", " << op.dest() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); p << " : " << op.getSourceType() << " into " << op.getDestVectorType(); @@ -851,13 +851,13 @@ static ParseResult parseInsertSlicesOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, InsertSlicesOp op) { - p << op.getOperationName() << ' ' << *op.vectors() << ", "; + p << op.getOperationName() << ' ' << op.vectors() << ", "; p << op.sizes() << ", " << op.strides(); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(), InsertSlicesOp::getStridesAttrName()}); - p << " : " << op.vectors()->getType(); + p << " : " << op.vectors().getType(); p << " into " << op.getResultVectorType(); } @@ -890,14 +890,13 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result, result.addOperands({source, dest}); auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets); auto stridesAttr = getVectorSubscriptAttr(*builder, strides); - result.addTypes(dest->getType()); + result.addTypes(dest.getType()); result.addAttribute(getOffsetsAttrName(), offsetsAttr); result.addAttribute(getStridesAttrName(), stridesAttr); } static void print(OpAsmPrinter &p, InsertStridedSliceOp op) { - p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() - << " "; + p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " "; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType(); } @@ -1049,10 +1048,10 @@ static LogicalResult verify(InsertStridedSliceOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, OuterProductOp op) { - p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); + p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); if (!op.acc().empty()) p << ", " << op.acc(); - p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); + p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); } static ParseResult parseOuterProductOp(OpAsmParser &parser, @@ -1103,7 +1102,7 @@ static LogicalResult verify(OuterProductOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ReshapeOp op) { - p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape() + p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape() << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes(); SmallVector<StringRef, 2> elidedAttrs = { ReshapeOp::getOperandSegmentSizeAttr(), @@ -1193,18 +1192,18 @@ static LogicalResult verify(ReshapeOp op) { // If all shape operands are produced by constant ops, verify that product // of dimensions for input/output shape match. auto isDefByConstant = [](Value operand) { - return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp()); + return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp()); }; if (llvm::all_of(op.input_shape(), isDefByConstant) && llvm::all_of(op.output_shape(), isDefByConstant)) { int64_t numInputElements = 1; for (auto operand : op.input_shape()) numInputElements *= - cast<ConstantIndexOp>(operand->getDefiningOp()).getValue(); + cast<ConstantIndexOp>(operand.getDefiningOp()).getValue(); int64_t numOutputElements = 1; for (auto operand : op.output_shape()) numOutputElements *= - cast<ConstantIndexOp>(operand->getDefiningOp()).getValue(); + cast<ConstantIndexOp>(operand.getDefiningOp()).getValue(); if (numInputElements != numOutputElements) return op.emitError("product of input and output shape sizes must match"); } @@ -1245,7 +1244,7 @@ void StridedSliceOp::build(Builder *builder, OperationState &result, auto sizesAttr = getVectorSubscriptAttr(*builder, sizes); auto stridesAttr = getVectorSubscriptAttr(*builder, strides); result.addTypes( - inferStridedSliceOpResultType(source->getType().cast<VectorType>(), + inferStridedSliceOpResultType(source.getType().cast<VectorType>(), offsetsAttr, sizesAttr, stridesAttr)); result.addAttribute(getOffsetsAttrName(), offsetsAttr); result.addAttribute(getSizesAttrName(), sizesAttr); @@ -1253,9 +1252,9 @@ void StridedSliceOp::build(Builder *builder, OperationState &result, } static void print(OpAsmPrinter &p, StridedSliceOp op) { - p << op.getOperationName() << " " << *op.vector(); + p << op.getOperationName() << " " << op.vector(); p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector()->getType() << " to " << op.getResult()->getType(); + p << " : " << op.vector().getType() << " to " << op.getResult().getType(); } static ParseResult parseStridedSliceOp(OpAsmParser &parser, @@ -1305,7 +1304,7 @@ static LogicalResult verify(StridedSliceOp op) { auto resultType = inferStridedSliceOpResultType( op.getVectorType(), op.offsets(), op.sizes(), op.strides()); - if (op.getResult()->getType() != resultType) { + if (op.getResult().getType() != resultType) { op.emitOpError("expected result type to be ") << resultType; return failure(); } @@ -1328,7 +1327,7 @@ public: PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp, PatternRewriter &rewriter) const override { // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp. - auto defOp = stridedSliceOp.vector()->getDefiningOp(); + auto defOp = stridedSliceOp.vector().getDefiningOp(); auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); if (!constantMaskOp) return matchFailure(); @@ -1365,7 +1364,7 @@ public: // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region. rewriter.replaceOpWithNewOp<ConstantMaskOp>( - stridedSliceOp, stridedSliceOp.getResult()->getType(), + stridedSliceOp, stridedSliceOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); return matchSuccess(); } @@ -1503,7 +1502,7 @@ static LogicalResult verify(TransferReadOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - auto paddingType = op.padding()->getType(); + auto paddingType = op.padding().getType(); auto permutationMap = op.permutation_map(); auto memrefElementType = memrefType.getElementType(); @@ -1540,8 +1539,8 @@ static LogicalResult verify(TransferReadOp op) { // TransferWriteOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref() - << "[" << op.indices() << "]"; + p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" + << op.indices() << "]"; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getVectorType() << ", " << op.getMemRefType(); } @@ -1596,12 +1595,12 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) { void TypeCastOp::build(Builder *builder, OperationState &result, Value source) { result.addOperands(source); result.addTypes( - inferVectorTypeCastResultType(source->getType().cast<MemRefType>())); + inferVectorTypeCastResultType(source.getType().cast<MemRefType>())); } static void print(OpAsmPrinter &p, TypeCastOp op) { - auto type = op.getOperand()->getType().cast<MemRefType>(); - p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to " + auto type = op.getOperand().getType().cast<MemRefType>(); + p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to " << inferVectorTypeCastResultType(type); } @@ -1665,14 +1664,14 @@ static ParseResult parseTupleGetOp(OpAsmParser &parser, } static void print(OpAsmPrinter &p, TupleGetOp op) { - p << op.getOperationName() << ' ' << *op.getOperand() << ", " << op.index(); + p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()}); - p << " : " << op.getOperand()->getType(); + p << " : " << op.getOperand().getType(); } static LogicalResult verify(TupleGetOp op) { - auto tupleType = op.getOperand()->getType().cast<TupleType>(); + auto tupleType = op.getOperand().getType().cast<TupleType>(); if (op.getIndex() < 0 || op.getIndex() >= static_cast<int64_t>(tupleType.size())) return op.emitOpError("tuple get index out of range"); @@ -1696,12 +1695,12 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, ConstantMaskOp op) { p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : " - << op.getResult()->getType(); + << op.getResult().getType(); } static LogicalResult verify(ConstantMaskOp &op) { // Verify that array attr size matches the rank of the vector result. - auto resultType = op.getResult()->getType().cast<VectorType>(); + auto resultType = op.getResult().getType().cast<VectorType>(); if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank()) return op.emitOpError( "must specify array attr of size equal vector result rank"); @@ -1749,7 +1748,7 @@ static void print(OpAsmPrinter &p, CreateMaskOp op) { static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector each dimension. if (op.getNumOperands() != - op.getResult()->getType().cast<VectorType>().getRank()) + op.getResult().getType().cast<VectorType>().getRank()) return op.emitOpError( "must specify an operand for each result vector dimension"); return success(); @@ -1768,7 +1767,7 @@ ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) { } static void print(OpAsmPrinter &p, PrintOp op) { - p << op.getOperationName() << ' ' << *op.source() << " : " + p << op.getOperationName() << ' ' << op.source() << " : " << op.getPrintType(); } @@ -1783,19 +1782,19 @@ public: PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. auto is_not_def_by_constant = [](Value operand) { - return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp()); + return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp()); }; if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant)) return matchFailure(); // Gather constant mask dimension sizes. SmallVector<int64_t, 4> maskDimSizes; for (auto operand : createMaskOp.operands()) { - auto defOp = operand->getDefiningOp(); + auto defOp = operand.getDefiningOp(); maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue()); } // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp<ConstantMaskOp>( - createMaskOp, createMaskOp.getResult()->getType(), + createMaskOp, createMaskOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); return matchSuccess(); } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 9fcbd0cb921..d98f41e3d63 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -195,7 +195,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue, auto tupleType = generateExtractSlicesOpResultType(vectorType, sizes, strides, builder); state.slicesTuple = builder.create<vector::ExtractSlicesOp>( - initValue->getLoc(), tupleType, initValue, sizes, strides); + initValue.getLoc(), tupleType, initValue, sizes, strides); } } @@ -232,7 +232,7 @@ static Value getOrCreateUnrolledVectorSlice( if (valueSlice == nullptr) { // Return tuple element at 'sliceLinearIndex'. auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex); - auto initValueType = initValue->getType().cast<VectorType>(); + auto initValueType = initValue.getType().cast<VectorType>(); auto vectorType = VectorType::get(state.unrolledShape, initValueType.getElementType()); // Initialize 'cache' with slice from 'initValue'. @@ -311,7 +311,7 @@ static Value unrollSingleResultStructuredOp(Operation *op, unsigned resultIndex, ArrayRef<int64_t> targetShape, PatternRewriter &builder) { - auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>(); + auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>(); if (!shapedType || !shapedType.hasStaticShape()) assert(false && "Expected a statically shaped result type"); @@ -379,7 +379,7 @@ static Value unrollSingleResultStructuredOp(Operation *op, SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances); SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances); for (unsigned i = 0; i < resultValueState.numInstances; ++i) { - vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>(); + vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>(); vectorTupleValues[i] = caches[resultIndex][i]; } TupleType tupleType = builder.getTupleType(vectorTupleTypes); @@ -387,7 +387,7 @@ static Value unrollSingleResultStructuredOp(Operation *op, vectorTupleValues); // Create InsertSlicesOp(Tuple(result_vectors)). - auto resultVectorType = op->getResult(0)->getType().cast<VectorType>(); + auto resultVectorType = op->getResult(0).getType().cast<VectorType>(); SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape); SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1); @@ -411,7 +411,7 @@ static void getVectorContractionOpUnrollState( vectors.resize(numIterators); unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex(); for (unsigned i = 0; i < numIterators; ++i) { - vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>(); + vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>(); vectors[i].indexMap = iterationIndexMapList[i]; vectors[i].operandIndex = i; vectors[i].isAcc = i == accOperandIndex ? true : false; @@ -437,7 +437,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape, std::vector<VectorState> &vectors, unsigned &resultIndex) { // Verify that operation and operands all have the same vector shape. - auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>(); + auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>(); assert(resultType && "Expected op with vector result type"); auto resultShape = resultType.getShape(); // Verify that all operands have the same vector type as result. @@ -515,7 +515,7 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType, getAffineConstantExpr(offsets[it.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); sliceIndices[it.index()] = rewriter.create<AffineApplyOp>( - it.value()->getLoc(), map, ArrayRef<Value>(it.value())); + it.value().getLoc(), map, ArrayRef<Value>(it.value())); } // Call 'fn' to generate slice 'i' at 'sliceIndices'. fn(i, sliceIndices); @@ -536,8 +536,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. Value xferReadResult = xferReadOp.getResult(); auto extractSlicesOp = - dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin()); - if (!xferReadResult->hasOneUse() || !extractSlicesOp) + dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin()); + if (!xferReadResult.hasOneUse() || !extractSlicesOp) return matchFailure(); // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. @@ -587,14 +587,14 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> { if (!xferWriteOp.permutation_map().isIdentity()) return matchFailure(); // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'. - auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp(); + auto *vectorDefOp = xferWriteOp.vector().getDefiningOp(); auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp); if (!insertSlicesOp) return matchFailure(); // Get TupleOp operand of 'insertSlicesOp'. auto tupleOp = dyn_cast_or_null<vector::TupleOp>( - insertSlicesOp.vectors()->getDefiningOp()); + insertSlicesOp.vectors().getDefiningOp()); if (!tupleOp) return matchFailure(); @@ -634,19 +634,19 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> { PatternRewriter &rewriter) const override { // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp. auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>( - tupleGetOp.vectors()->getDefiningOp()); + tupleGetOp.vectors().getDefiningOp()); if (!extractSlicesOp) return matchFailure(); // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp. auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>( - extractSlicesOp.vector()->getDefiningOp()); + extractSlicesOp.vector().getDefiningOp()); if (!insertSlicesOp) return matchFailure(); // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp. auto tupleOp = dyn_cast_or_null<vector::TupleOp>( - insertSlicesOp.vectors()->getDefiningOp()); + insertSlicesOp.vectors().getDefiningOp()); if (!tupleOp) return matchFailure(); |