diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 91 | ||||
| -rw-r--r-- | mlir/lib/IR/AttributeDetail.h | 5 | ||||
| -rw-r--r-- | mlir/lib/IR/Attributes.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/IR/BasicBlock.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 55 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 38 | ||||
| -rw-r--r-- | mlir/lib/IR/Function.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/IR/Instructions.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/IR/MLIRContext.cpp | 207 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 26 | ||||
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/TypeDetail.h | 126 | ||||
| -rw-r--r-- | mlir/lib/IR/Types.cpp | 101 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 178 | ||||
| -rw-r--r-- | mlir/lib/StandardOps/StandardOps.cpp | 226 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ConstantFold.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 14 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 21 |
22 files changed, 670 insertions, 508 deletions
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 1b3c24fd9f9..1904a636647 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { return tripCountExpr.getLargestKnownDivisor(); } -bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType, +bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType, ArrayRef<MLValue *> indices, unsigned dim) { - assert(indices.size() == memRefType->getRank()); + assert(indices.size() == memRefType.getRank()); assert(dim < indices.size()); - auto layoutMap = memRefType->getAffineMaps(); - assert(memRefType->getAffineMaps().size() <= 1); + auto layoutMap = memRefType.getAffineMaps(); + assert(memRefType.getAffineMaps().size() <= 1); // TODO(ntv): remove dependency on Builder once we support non-identity // layout map. - Builder b(memRefType->getContext()); + Builder b(memRefType.getContext()); assert(layoutMap.empty() || layoutMap[0] == b.getMultiDimIdentityMap(indices.size())); (void)layoutMap; @@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input, using namespace functional; auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); }, memoryOp->getIndices()); - auto *memRefType = memoryOp->getMemRefType(); + auto memRefType = memoryOp->getMemRefType(); for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) { if (fastestVaryingDim == (numIndices - 1) - d) { continue; @@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input, template <typename LoadOrStoreOpPointer> static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { - auto *memRefType = memoryOp->getMemRefType(); - return isa<VectorType>(memRefType->getElementType()); + auto memRefType = memoryOp->getMemRefType(); + return memRefType.getElementType().template isa<VectorType>(); } bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index bfbcb169cfe..0dd030d5b45 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() { // Verify that the argument list of the function and the arg list of the first // block line up. - auto fnInputTypes = fn.getType()->getInputs(); + auto fnInputTypes = fn.getType().getInputs(); if (fnInputTypes.size() != firstBB->getNumArguments()) return failure("first block of cfgfunc must have " + Twine(fnInputTypes.size()) + @@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands, bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) { // Verify that the return operands match the results of the function. - auto results = fn.getType()->getResults(); + auto results = fn.getType().getResults(); if (inst.getNumOperands() != results.size()) return failure("return has " + Twine(inst.getNumOperands()) + " operands, but enclosing function returns " + diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 454a28a6558..cb5e96f0086 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -122,7 +122,7 @@ private: void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); void visitOperationStmt(const OperationStmt *opStmt); - void visitType(const Type *type); + void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const Operation *op); @@ -135,16 +135,16 @@ private: } // end anonymous namespace // TODO Support visiting other types/instructions when implemented. -void ModuleState::visitType(const Type *type) { - if (auto *funcType = dyn_cast<FunctionType>(type)) { +void ModuleState::visitType(Type type) { + if (auto funcType = type.dyn_cast<FunctionType>()) { // Visit input and result types for functions. - for (auto *input : funcType->getInputs()) + for (auto input : funcType.getInputs()) visitType(input); - for (auto *result : funcType->getResults()) + for (auto result : funcType.getResults()) visitType(result); - } else if (auto *memref = dyn_cast<MemRefType>(type)) { + } else if (auto memref = type.dyn_cast<MemRefType>()) { // Visit affine maps in memref type. - for (auto map : memref->getAffineMaps()) { + for (auto map : memref.getAffineMaps()) { recordAffineMapReference(map); } } @@ -271,7 +271,7 @@ public: void print(const Module *module); void printFunctionReference(const Function *func); void printAttribute(Attribute attr); - void printType(const Type *type); + void printType(Type type); void print(const Function *fn); void print(const ExtFunction *fn); void print(const CFGFunction *fn); @@ -290,7 +290,7 @@ protected: void printFunctionAttributes(const Function *fn); void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs = {}); - void printFunctionResultType(const FunctionType *type); + void printFunctionResultType(FunctionType type); void printAffineMapId(int affineMapId) const; void printAffineMapReference(AffineMap affineMap); void printIntegerSetId(int integerSetId) const; @@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) { } void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { - auto *type = attr.getType(); - auto shape = type->getShape(); - auto rank = type->getRank(); + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); SmallVector<Attribute, 16> elements; attr.getValues(elements); @@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { os << ']'; } -void ModulePrinter::printType(const Type *type) { - switch (type->getKind()) { +void ModulePrinter::printType(Type type) { + switch (type.getKind()) { case Type::Kind::Index: os << "index"; return; @@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) { return; case Type::Kind::Integer: { - auto *integer = cast<IntegerType>(type); - os << 'i' << integer->getWidth(); + auto integer = type.cast<IntegerType>(); + os << 'i' << integer.getWidth(); return; } case Type::Kind::Function: { - auto *func = cast<FunctionType>(type); + auto func = type.cast<FunctionType>(); os << '('; - interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); + interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); os << ") -> "; - auto results = func->getResults(); + auto results = func.getResults(); if (results.size() == 1) - os << *results[0]; + os << results[0]; else { os << '('; - interleaveComma(results, [&](Type *type) { printType(type); }); + interleaveComma(results, [&](Type type) { printType(type); }); os << ')'; } return; } case Type::Kind::Vector: { - auto *v = cast<VectorType>(type); + auto v = type.cast<VectorType>(); os << "vector<"; - for (auto dim : v->getShape()) + for (auto dim : v.getShape()) os << dim << 'x'; - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::RankedTensor: { - auto *v = cast<RankedTensorType>(type); + auto v = type.cast<RankedTensorType>(); os << "tensor<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::UnrankedTensor: { - auto *v = cast<UnrankedTensorType>(type); + auto v = type.cast<UnrankedTensorType>(); os << "tensor<*x"; - printType(v->getElementType()); + printType(v.getElementType()); os << '>'; return; } case Type::Kind::MemRef: { - auto *v = cast<MemRefType>(type); + auto v = type.cast<MemRefType>(); os << "memref<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - printType(v->getElementType()); - for (auto map : v->getAffineMaps()) { + printType(v.getElementType()); + for (auto map : v.getAffineMaps()) { os << ", "; printAffineMapReference(map); } // Only print the memory space if it is the non-default one. - if (v->getMemorySpace()) - os << ", " << v->getMemorySpace(); + if (v.getMemorySpace()) + os << ", " << v.getMemorySpace(); os << '>'; return; } @@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { // Function printing //===----------------------------------------------------------------------===// -void ModulePrinter::printFunctionResultType(const FunctionType *type) { - switch (type->getResults().size()) { +void ModulePrinter::printFunctionResultType(FunctionType type) { + switch (type.getResults().size()) { case 0: break; case 1: os << " -> "; - printType(type->getResults()[0]); + printType(type.getResults()[0]); break; default: os << " -> ("; - interleaveComma(type->getResults(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getResults(), + [&](Type eltType) { printType(eltType); }); os << ')'; break; } @@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) { auto type = fn->getType(); os << "@" << fn->getName() << '('; - interleaveComma(type->getInputs(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); }); os << ')'; printFunctionResultType(type); @@ -937,7 +936,7 @@ public: // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } - void printType(const Type *type) { ModulePrinter::printType(type); } + void printType(Type type) { ModulePrinter::printType(type); } void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); } void printAffineMap(AffineMap map) { return ModulePrinter::printAffineMapReference(map); @@ -974,10 +973,10 @@ protected: if (auto *op = value->getDefiningOperation()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) { // i1 constants get special names. - if (intOp->getType()->isInteger(1)) { + if (intOp->getType().isInteger(1)) { specialName << (intOp->getValue() ? "true" : "false"); } else { - specialName << 'c' << intOp->getValue() << '_' << *intOp->getType(); + specialName << 'c' << intOp->getValue() << '_' << intOp->getType(); } } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { specialName << 'c' << intOp->getValue(); @@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); } void Type::print(raw_ostream &os) const { ModuleState state(getContext()); - ModulePrinter(os, state).printType(this); + ModulePrinter(os, state).printType(*this); } void Type::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index a0e9afb4fd3..63ad544fa48 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage { /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { - Type *value; + Type value; }; /// An attribute representing a reference to a function. @@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage { /// A base attribute representing a reference to a vector or tensor constant. struct ElementsAttributeStorage : public AttributeStorage { - VectorOrTensorType *type; + VectorOrTensorType type; }; /// An attribute representing a reference to a vector or tensor constant, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 34312b84a0b..58b5b90d43d 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const { TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -Type *TypeAttr::getValue() const { - return static_cast<ImplType *>(attr)->value; -} +Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} @@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const { return static_cast<ImplType *>(attr)->value; } -FunctionType *FunctionAttr::getType() const { return getValue()->getType(); } +FunctionType FunctionAttr::getType() const { return getValue()->getType(); } ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -VectorOrTensorType *ElementsAttr::getType() const { +VectorOrTensorType ElementsAttr::getType() const { return static_cast<ImplType *>(attr)->type; } @@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos, void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth; - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto context = getType().getContext(); values.reserve(elementNum); if (bitsWidth == 64) { ArrayRef<int64_t> vs( @@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr) : DenseElementsAttr(ptr) {} void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { - auto elementNum = getType()->getNumElements(); - auto context = getType()->getContext(); + auto elementNum = getType().getNumElements(); + auto context = getType().getContext(); ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()), getRawData().size() / 8}); values.reserve(elementNum); diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp index bb8ac75d91a..29a5ce12e4a 100644 --- a/mlir/lib/IR/BasicBlock.cpp +++ b/mlir/lib/IR/BasicBlock.cpp @@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() { // Argument list management. //===----------------------------------------------------------------------===// -BBArgument *BasicBlock::addArgument(Type *type) { +BBArgument *BasicBlock::addArgument(Type type) { auto *arg = new BBArgument(type, this); arguments.push_back(arg); return arg; } /// Add one argument to the argument list for each type specified in the list. -auto BasicBlock::addArguments(ArrayRef<Type *> types) +auto BasicBlock::addArguments(ArrayRef<Type> types) -> llvm::iterator_range<args_iterator> { arguments.reserve(arguments.size() + types.size()); auto initialSize = arguments.size(); - for (auto *type : types) { + for (auto type : types) { addArgument(type); } return {arguments.data() + initialSize, arguments.data() + arguments.size()}; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 22d749a6c8c..906b580d9af 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename, // Types. //===----------------------------------------------------------------------===// -FloatType *Builder::getBF16Type() { return Type::getBF16(context); } +FloatType Builder::getBF16Type() { return Type::getBF16(context); } -FloatType *Builder::getF16Type() { return Type::getF16(context); } +FloatType Builder::getF16Type() { return Type::getF16(context); } -FloatType *Builder::getF32Type() { return Type::getF32(context); } +FloatType Builder::getF32Type() { return Type::getF32(context); } -FloatType *Builder::getF64Type() { return Type::getF64(context); } +FloatType Builder::getF64Type() { return Type::getF64(context); } -OtherType *Builder::getIndexType() { return Type::getIndex(context); } +OtherType Builder::getIndexType() { return Type::getIndex(context); } -OtherType *Builder::getTFControlType() { return Type::getTFControl(context); } +OtherType Builder::getTFControlType() { return Type::getTFControl(context); } -OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } +OtherType Builder::getTFResourceType() { return Type::getTFResource(context); } -OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); } +OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); } -OtherType *Builder::getTFComplex64Type() { +OtherType Builder::getTFComplex64Type() { return Type::getTFComplex64(context); } -OtherType *Builder::getTFComplex128Type() { +OtherType Builder::getTFComplex128Type() { return Type::getTFComplex128(context); } -OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); } +OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); } -OtherType *Builder::getTFStringType() { return Type::getTFString(context); } +OtherType Builder::getTFStringType() { return Type::getTFString(context); } -IntegerType *Builder::getIntegerType(unsigned width) { +IntegerType Builder::getIntegerType(unsigned width) { return Type::getInteger(width, context); } -FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs, - ArrayRef<Type *> results) { +FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, + ArrayRef<Type> results) { return FunctionType::get(inputs, results, context); } -MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { +MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } -VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) { +VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) { return VectorType::get(shape, elementType); } -RankedTensorType *Builder::getTensorType(ArrayRef<int> shape, - Type *elementType) { +RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) { return RankedTensorType::get(shape, elementType); } -UnrankedTensorType *Builder::getTensorType(Type *elementType) { +UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } @@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { return IntegerSetAttr::get(set); } -TypeAttr Builder::getTypeAttr(Type *type) { +TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type, context); } @@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) { return FunctionAttr::get(value, context); } -ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, ArrayRef<char> data) { return DenseElementsAttr::get(type, data); } -ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type, +ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes) { return OpaqueElementsAttr::get(type, bytes); } @@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) { OperationStmt *MLFuncBuilder::createOperation(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> types, + ArrayRef<Type> types, ArrayRef<NamedAttribute> attrs) { auto *op = OperationStmt::create(location, name, operands, types, attrs, getContext()); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 542e67eaefd..e4bca037c4e 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser, numDims = opInfos.size(); // Parse the optional symbol operands. - auto *affineIntTy = parser->getBuilder().getIndexType(); + auto affineIntTy = parser->getBuilder().getIndexType(); if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::OptionalSquare) || parser->resolveOperands(opInfos, affineIntTy, operands)) @@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result, bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); - auto *affineIntTy = builder.getIndexType(); + auto affineIntTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; @@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants, /// Builds a constant op with the specified attribute value and result type. void ConstantOp::build(Builder *builder, OperationState *result, - Attribute value, Type *type) { + Attribute value, Type type) { result->addAttribute("value", value); result->types.push_back(type); } @@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const { p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); if (!getValue().isa<FunctionAttr>()) - *p << " : " << *getType(); + *p << " : " << getType(); } bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { Attribute valueAttr; - Type *type; + Type type; if (parser->parseAttribute(valueAttr, "value", result->attributes) || parser->parseOptionalAttributeDict(result->attributes)) @@ -208,33 +208,33 @@ bool ConstantOp::verify() const { if (!value) return emitOpError("requires a 'value' attribute"); - auto *type = this->getType(); - if (isa<IntegerType>(type) || type->isIndex()) { + auto type = this->getType(); + if (type.isa<IntegerType>() || type.isIndex()) { if (!value.isa<IntegerAttr>()) return emitOpError( "requires 'value' to be an integer for an integer result type"); return false; } - if (isa<FloatType>(type)) { + if (type.isa<FloatType>()) { if (!value.isa<FloatAttr>()) return emitOpError("requires 'value' to be a floating point constant"); return false; } - if (isa<VectorOrTensorType>(type)) { + if (type.isa<VectorOrTensorType>()) { if (!value.isa<ElementsAttr>()) return emitOpError("requires 'value' to be a vector/tensor constant"); return false; } - if (type->isTFString()) { + if (type.isTFString()) { if (!value.isa<StringAttr>()) return emitOpError("requires 'value' to be a string constant"); return false; } - if (isa<FunctionType>(type)) { + if (type.isa<FunctionType>()) { if (!value.isa<FunctionAttr>()) return emitOpError("requires 'value' to be a function reference"); return false; @@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands, } void ConstantFloatOp::build(Builder *builder, OperationState *result, - const APFloat &value, FloatType *type) { + const APFloat &value, FloatType type) { ConstantOp::build(builder, result, builder->getFloatAttr(value), type); } bool ConstantFloatOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<FloatType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. bool ConstantIntOp::isClassFor(const Operation *op) { return ConstantOp::isClassFor(op) && - isa<IntegerType>(op->getResult(0)->getType()); + op->getResult(0)->getType().isa<IntegerType>(); } void ConstantIntOp::build(Builder *builder, OperationState *result, @@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result, /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, Type *type) { - assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type"); + int64_t value, Type type) { + assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); ConstantOp::build(builder, result, builder->getIntegerAttr(value), type); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); + return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState *result, @@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result, bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type *, 2> types; + SmallVector<Type, 2> types; llvm::SMLoc loc; return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || (!opInfo.empty() && parser->parseColonTypeList(types)) || @@ -330,7 +330,7 @@ bool ReturnOp::verify() const { // The operand number and types must match the function signature. MLFunction *function = cast<MLFunction>(block); - const auto &results = function->getType()->getResults(); + const auto &results = function->getType().getResults(); if (stmt->getNumOperands() != results.size()) return emitOpError("has " + Twine(stmt->getNumOperands()) + " operands, but enclosing function returns " + diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index efeb16b61db..70c0e1259b3 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -28,8 +28,8 @@ using namespace mlir; Function::Function(Kind kind, Location *location, StringRef name, - FunctionType *type, ArrayRef<NamedAttribute> attrs) - : nameAndKind(Identifier::get(name, type->getContext()), kind), + FunctionType type, ArrayRef<NamedAttribute> attrs) + : nameAndKind(Identifier::get(name, type.getContext()), kind), location(location), type(type) { this->attrs = AttributeListStorage::get(attrs, getContext()); } @@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const { return {}; } -MLIRContext *Function::getContext() const { return getType()->getContext(); } +MLIRContext *Function::getContext() const { return getType().getContext(); } /// Delete this object. void Function::destroy() { @@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const { // ExtFunction implementation. //===----------------------------------------------------------------------===// -ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, +ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::ExtFunc, location, name, type, attrs) {} @@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, // CFGFunction implementation. //===----------------------------------------------------------------------===// -CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type, +CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::CFGFunc, location, name, type, attrs) {} @@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() { /// Create a new MLFunction with the specific fields. MLFunction *MLFunction::create(Location *location, StringRef name, - FunctionType *type, + FunctionType type, ArrayRef<NamedAttribute> attrs) { - const auto &argTypes = type->getInputs(); + const auto &argTypes = type.getInputs(); auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size()); void *rawMem = malloc(byteSize); @@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name, return function; } -MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type, +MLFunction::MLFunction(Location *location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) : Function(Kind::MLFunc, location, name, type, attrs), StmtBlock(StmtBlockKind::MLFunc) {} diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index 422636bf2e3..d2f49ddfc6e 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const { /// Create a new OperationInst with the specified fields. OperationInst *OperationInst::create(Location *location, OperationName name, ArrayRef<CFGValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(), @@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name, OperationInst *OperationInst::clone() const { SmallVector<CFGValue *, 8> operands; - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; // Put together the operands and results. for (auto *operand : getOperands()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0a2e9416842..8811f7b9f78 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -21,6 +21,7 @@ #include "AttributeDetail.h" #include "AttributeListStorage.h" #include "IntegerSetDetail.h" +#include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -44,11 +45,11 @@ using namespace mlir::detail; using namespace llvm; namespace { -struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { +struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> { // Functions are uniqued based on their inputs and results. - using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>; - using DenseMapInfo<FunctionType *>::getHashValue; - using DenseMapInfo<FunctionType *>::isEqual; + using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>; + using DenseMapInfo<FunctionTypeStorage *>::getHashValue; + using DenseMapInfo<FunctionTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( @@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) { + static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; return lhs == KeyTy(rhs->getInputs(), rhs->getResults()); @@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { } }; -struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> { +struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> { // Vectors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<VectorType *>::getHashValue; - using DenseMapInfo<VectorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<VectorTypeStorage *>::getHashValue; + using DenseMapInfo<VectorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const VectorType *rhs) { + static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> { +struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> { // Ranked tensors are uniqued based on their element type and shape. - using KeyTy = std::pair<Type *, ArrayRef<int>>; - using DenseMapInfo<RankedTensorType *>::getHashValue; - using DenseMapInfo<RankedTensorType *>::isEqual; + using KeyTy = std::pair<Type, ArrayRef<int>>; + using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue; + using DenseMapInfo<RankedTensorTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(key.first), + DenseMapInfo<Type>::getHashValue(key.first), hash_combine_range(key.second.begin(), key.second.end())); } - static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) { + static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); + return lhs == KeyTy(rhs->elementType, rhs->getShape()); } }; -struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> { +struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> { // MemRefs are uniqued based on their element type, shape, affine map // composition, and memory space. - using KeyTy = - std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; - using DenseMapInfo<MemRefType *>::getHashValue; - using DenseMapInfo<MemRefType *>::isEqual; + using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; + using DenseMapInfo<MemRefTypeStorage *>::getHashValue; + using DenseMapInfo<MemRefTypeStorage *>::isEqual; static unsigned getHashValue(KeyTy key) { return hash_combine( - DenseMapInfo<Type *>::getHashValue(std::get<0>(key)), + DenseMapInfo<Type>::getHashValue(std::get<0>(key)), hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()), hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), std::get<3>(key)); } - static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) { + static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(), - rhs->getAffineMaps(), rhs->getMemorySpace()); + return lhs == std::make_tuple(rhs->elementType, rhs->getShape(), + rhs->getAffineMaps(), rhs->memorySpace); } }; @@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> { }; struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>; + using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>; using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue; using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual; @@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { }; struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> { - using KeyTy = std::pair<VectorOrTensorType *, StringRef>; + using KeyTy = std::pair<VectorOrTensorType, StringRef>; using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue; using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual; @@ -295,13 +295,14 @@ public: llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers; // Uniquing table for 'other' types. - OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - - int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr}; + OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - + int(Type::Kind::FIRST_OTHER_TYPE) + 1] = { + nullptr}; // Uniquing table for 'float' types. - FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - - int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = { - nullptr}; + FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - + int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = + {nullptr}; // Affine map uniquing. using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; @@ -324,26 +325,26 @@ public: DenseMap<int64_t, AffineConstantExprStorage *> constExprs; /// Integer type uniquing. - DenseMap<unsigned, IntegerType *> integers; + DenseMap<unsigned, IntegerTypeStorage *> integers; /// Function type uniquing. - using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>; + using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>; FunctionTypeSet functions; /// Vector type uniquing. - using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>; + using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>; VectorTypeSet vectors; /// Ranked tensor type uniquing. using RankedTensorTypeSet = - DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>; + DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>; RankedTensorTypeSet rankedTensors; /// Unranked tensor type uniquing. - DenseMap<Type *, UnrankedTensorType *> unrankedTensors; + DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors; /// MemRef type uniquing. - using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>; + using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>; MemRefTypeSet memrefs; // Attribute uniquing. @@ -355,13 +356,12 @@ public: ArrayAttrSet arrayAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs; DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs; - DenseMap<Type *, TypeAttributeStorage *> typeAttrs; + DenseMap<Type, TypeAttributeStorage *> typeAttrs; using AttributeListSet = DenseSet<AttributeListStorage *, AttributeListKeyInfo>; AttributeListSet attributeLists; DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs; - DenseMap<std::pair<VectorOrTensorType *, Attribute>, - SplatElementsAttributeStorage *> + DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *> splatElementsAttrs; using DenseElementsAttrSet = DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>; @@ -369,7 +369,7 @@ public: using OpaqueElementsAttrSet = DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>; OpaqueElementsAttrSet opaqueElementsAttrs; - DenseMap<std::tuple<Type *, Attribute, Attribute>, + DenseMap<std::tuple<Type, Attribute, Attribute>, SparseElementsAttributeStorage *> sparseElementsAttrs; @@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line, // Type uniquing //===----------------------------------------------------------------------===// -IntegerType *IntegerType::get(unsigned width, MLIRContext *context) { +IntegerType IntegerType::get(unsigned width, MLIRContext *context) { + assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); auto &impl = context->getImpl(); auto *&result = impl.integers[width]; if (!result) { - result = impl.allocator.Allocate<IntegerType>(); - new (result) IntegerType(width, context); + result = impl.allocator.Allocate<IntegerTypeStorage>(); + new (result) IntegerTypeStorage{{Kind::Integer, context}, width}; } return result; } -FloatType *FloatType::get(Kind kind, MLIRContext *context) { +FloatType FloatType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE && kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind"); auto &impl = context->getImpl(); @@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. - auto *ptr = impl.allocator.Allocate<FloatType>(); + auto *ptr = impl.allocator.Allocate<FloatTypeStorage>(); // Initialize the memory using placement new. - new (ptr) FloatType(kind, context); + new (ptr) FloatTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -OtherType *OtherType::get(Kind kind, MLIRContext *context) { +OtherType OtherType::get(Kind kind, MLIRContext *context) { assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE && "Not an 'other' type kind"); auto &impl = context->getImpl(); @@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) { return entry; // On the first use, we allocate them into the bump pointer. - auto *ptr = impl.allocator.Allocate<OtherType>(); + auto *ptr = impl.allocator.Allocate<OtherTypeStorage>(); // Initialize the memory using placement new. - new (ptr) OtherType(kind, context); + new (ptr) OtherTypeStorage{{kind, context}}; // Cache and return it. return entry = ptr; } -FunctionType *FunctionType::get(ArrayRef<Type *> inputs, - ArrayRef<Type *> results, - MLIRContext *context) { +FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results, + MLIRContext *context) { auto &impl = context->getImpl(); // Look to see if we already have this function type. @@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<FunctionType>(); + auto *result = impl.allocator.Allocate<FunctionTypeStorage>(); // Copy the inputs and results into the bump pointer. - SmallVector<Type *, 16> types; + SmallVector<Type, 16> types; types.reserve(inputs.size() + results.size()); types.append(inputs.begin(), inputs.end()); types.append(results.begin(), results.end()); - auto typesList = impl.copyInto(ArrayRef<Type *>(types)); + auto typesList = impl.copyInto(ArrayRef<Type>(types)); // Initialize the memory using placement new. - new (result) - FunctionType(typesList.data(), inputs.size(), results.size(), context); + new (result) FunctionTypeStorage{ + {Kind::Function, context, static_cast<unsigned int>(inputs.size())}, + static_cast<unsigned int>(results.size()), + typesList.data()}; // Cache and return it. return *existing.first = result; } -VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { +VectorType VectorType::get(ArrayRef<int> shape, Type elementType) { assert(!shape.empty() && "vector types must have at least one dimension"); - assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) && + assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) && "vectors elements must be primitives"); assert(!std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; }) && "vector types must have static shape"); - auto *context = elementType->getContext(); + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this vector type. @@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<VectorType>(); + auto *result = impl.allocator.Allocate<VectorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) VectorType(shape, elementType, context); + new (result) VectorTypeStorage{ + {{Kind::Vector, context, static_cast<unsigned int>(shape.size())}, + elementType}, + shape.data()}; // Cache and return it. return *existing.first = result; } -RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, - Type *elementType) { - auto *context = elementType->getContext(); +RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this ranked tensor type. @@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<RankedTensorType>(); + auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); // Initialize the memory using placement new. - new (result) RankedTensorType(shape, elementType, context); + new (result) RankedTensorTypeStorage{ + {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())}, + elementType}}, + shape.data()}; // Cache and return it. return *existing.first = result; } -UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { - auto *context = elementType->getContext(); +UnrankedTensorType UnrankedTensorType::get(Type elementType) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Look to see if we already have this unranked tensor type. @@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { return result; // On the first use, we allocate them into the bump pointer. - result = impl.allocator.Allocate<UnrankedTensorType>(); + result = impl.allocator.Allocate<UnrankedTensorTypeStorage>(); // Initialize the memory using placement new. - new (result) UnrankedTensorType(elementType, context); + new (result) UnrankedTensorTypeStorage{ + {{{Kind::UnrankedTensor, context}, elementType}}}; return result; } -MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapComposition, - unsigned memorySpace) { - auto *context = elementType->getContext(); +MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType, + ArrayRef<AffineMap> affineMapComposition, + unsigned memorySpace) { + auto *context = elementType.getContext(); auto &impl = context->getImpl(); // Drop the unbounded identity maps from the composition. @@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, return *existing.first; // On the first use, we allocate them into the bump pointer. - auto *result = impl.allocator.Allocate<MemRefType>(); + auto *result = impl.allocator.Allocate<MemRefTypeStorage>(); // Copy the shape into the bump pointer. shape = impl.copyInto(shape); @@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, impl.copyInto(ArrayRef<AffineMap>(affineMapComposition)); // Initialize the memory using placement new. - new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace, - context); + new (result) MemRefTypeStorage{ + {Kind::MemRef, context, static_cast<unsigned int>(shape.size())}, + elementType, + shape.data(), + static_cast<unsigned int>(affineMapComposition.size()), + affineMapComposition.data(), + memorySpace}; // Cache and return it. return *existing.first = result; } @@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { return result; } -TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { +TypeAttr TypeAttr::get(Type type, MLIRContext *context) { auto *&result = context->getImpl().typeAttrs[type]; if (result) return result; @@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs, return *existing.first = result; } -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, +SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, Attribute elt) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. auto *&result = impl.splatElementsAttrs[{type, elt}]; @@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, return result; } -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, +DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, ArrayRef<char> data) { - auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); + auto bitsRequired = (long)type.getBitWidth() * type.getNumElements(); (void)bitsRequired; assert((bitsRequired <= data.size() * 8L) && "Input data bit size should be larger than that type requires"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. DenseElementsAttrInfo::KeyTy key({type, data}); @@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first; // Otherwise, allocate a new one, unique it and return it. - auto *eltType = type->getElementType(); - switch (eltType->getKind()) { + auto eltType = type.getElementType(); + switch (eltType.getKind()) { case Type::Kind::BF16: case Type::Kind::F16: case Type::Kind::F32: @@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } case Type::Kind::Integer: { - auto width = ::cast<IntegerType>(eltType)->getWidth(); + auto width = eltType.cast<IntegerType>().getWidth(); auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>(); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); std::uninitialized_copy(data.begin(), data.end(), copy); @@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, } } -OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, +OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, StringRef bytes) { - assert(isValidTensorElementType(type->getElementType()) && + assert(isValidTensorElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. OpaqueElementsAttrInfo::KeyTy key({type, bytes}); @@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, return *existing.first = result; } -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type, +SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { - auto &impl = type->getContext()->getImpl(); + auto &impl = type.getContext()->getImpl(); // Look to see if we already have this. auto key = std::make_tuple(type, indices, values); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2ed09b83b53..0722421c8ba 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, } bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { - auto *type = op->getResult(0)->getType(); + auto type = op->getResult(0)->getType(); for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { if (op->getResult(i)->getType() != type) return op->emitOpError( @@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { /// If this is a vector type, or a tensor type, return the scalar element type /// that it is built around, otherwise return the type unmodified. -static Type *getTensorOrVectorElementType(Type *type) { - if (auto *vec = dyn_cast<VectorType>(type)) - return vec->getElementType(); +static Type getTensorOrVectorElementType(Type type) { + if (auto vec = type.dyn_cast<VectorType>()) + return vec.getElementType(); // Look through tensor<vector<...>> to find the underlying element type. - if (auto *tensor = dyn_cast<TensorType>(type)) - return getTensorOrVectorElementType(tensor->getElementType()); + if (auto tensor = type.dyn_cast<TensorType>()) + return getTensorOrVectorElementType(tensor.getElementType()); return type; } bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<FloatType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) return op->emitOpError("requires a floating point type"); } @@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { for (auto *result : op->getResults()) { - if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType()))) + if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>()) return op->emitOpError("requires an integer type"); } return false; @@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { SmallVector<OpAsmParser::OperandType, 2> ops; - Type *type; + Type type; return parser->parseOperandList(ops, 2) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type) || @@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p->printOptionalAttrDict(op->getAttrs()); - *p << " : " << *op->getResult(0)->getType(); + *p << " : " << op->getResult(0)->getType(); } //===----------------------------------------------------------------------===// @@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { //===----------------------------------------------------------------------===// void impl::buildCastOp(Builder *builder, OperationState *result, - SSAValue *source, Type *destType) { + SSAValue *source, Type destType) { result->addOperands(source); result->addTypes(destType); } bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType srcInfo; - Type *srcType, *dstType; + Type srcType, dstType; return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || parser->resolveOperand(srcInfo, srcType, result->operands) || parser->parseKeywordType("to", dstType) || @@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " - << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType(); + << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index e9c46d6ec5e..698089a1c67 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block, /// Create a new OperationStmt with the specific fields. OperationStmt *OperationStmt::create(Location *location, OperationName name, ArrayRef<MLValue *> operands, - ArrayRef<Type *> resultTypes, + ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, MLIRContext *context) { auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(), @@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) - return getResult(0)->getType()->getContext(); + return getResult(0)->getType().getContext(); if (getNumOperands()) - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); // In the very odd case where we have no operands or results, fall back to // doing a find. @@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const { if (operands.empty()) return findFunction()->getContext(); - return getOperand(0)->getType()->getContext(); + return getOperand(0)->getType().getContext(); } //===----------------------------------------------------------------------===// @@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, operands.push_back(remapOperand(opValue)); if (auto *opStmt = dyn_cast<OperationStmt>(this)) { - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; resultTypes.reserve(opStmt->getNumResults()); for (auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h new file mode 100644 index 00000000000..c22e87a283e --- /dev/null +++ b/mlir/lib/IR/TypeDetail.h @@ -0,0 +1,126 @@ +//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of Type. +// +//===----------------------------------------------------------------------===// +#ifndef TYPEDETAIL_H_ +#define TYPEDETAIL_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +class AffineMap; +class MLIRContext; + +namespace detail { + +/// Base storage class appearing in a Type. +struct alignas(8) TypeStorage { + TypeStorage(Type::Kind kind, MLIRContext *context) + : context(context), kind(kind), subclassData(0) {} + TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData) + : context(context), kind(kind), subclassData(subclassData) {} + + unsigned getSubclassData() const { return subclassData; } + + void setSubclassData(unsigned val) { + subclassData = val; + // Ensure we don't have any accidental truncation. + assert(getSubclassData() == val && "Subclass data too large for field"); + } + + /// This refers to the MLIRContext in which this type was uniqued. + MLIRContext *const context; + + /// Classification of the subclass, used for type checking. + Type::Kind kind : 8; + + /// Space for subclasses to store data. + unsigned subclassData : 24; +}; + +struct IntegerTypeStorage : public TypeStorage { + unsigned width; +}; + +struct FloatTypeStorage : public TypeStorage {}; + +struct OtherTypeStorage : public TypeStorage {}; + +struct FunctionTypeStorage : public TypeStorage { + ArrayRef<Type> getInputs() const { + return ArrayRef<Type>(inputsAndResults, subclassData); + } + ArrayRef<Type> getResults() const { + return ArrayRef<Type>(inputsAndResults + subclassData, numResults); + } + + unsigned numResults; + Type const *inputsAndResults; +}; + +struct VectorOrTensorTypeStorage : public TypeStorage { + Type elementType; +}; + +struct VectorTypeStorage : public VectorOrTensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct TensorTypeStorage : public VectorOrTensorTypeStorage {}; + +struct RankedTensorTypeStorage : public TensorTypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + const int *shapeElements; +}; + +struct UnrankedTensorTypeStorage : public TensorTypeStorage {}; + +struct MemRefTypeStorage : public TypeStorage { + ArrayRef<int> getShape() const { + return ArrayRef<int>(shapeElements, getSubclassData()); + } + + ArrayRef<AffineMap> getAffineMaps() const { + return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + } + + /// The type of each scalar element of the memref. + Type elementType; + /// An array of integers which stores the shape dimension sizes. + const int *shapeElements; + /// The number of affine maps in the 'affineMapList' array. + const unsigned numAffineMaps; + /// List of affine maps in the memref's layout/index map composition. + AffineMap const *affineMapList; + /// Memory space in which data referenced by memref resides. + const unsigned memorySpace; +}; + +} // namespace detail +} // namespace mlir +#endif // TYPEDETAIL_H_ diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 0ad3f4728fe..1a716956608 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -16,10 +16,17 @@ // ============================================================================= #include "mlir/IR/Types.h" +#include "TypeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/STLExtras.h" #include "llvm/Support/raw_ostream.h" + using namespace mlir; +using namespace mlir::detail; + +Type::Kind Type::getKind() const { return type->kind; } + +MLIRContext *Type::getContext() const { return type->context; } unsigned Type::getBitWidth() const { switch (getKind()) { @@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const { case Type::Kind::F64: return 64; case Type::Kind::Integer: - return cast<IntegerType>(this)->getWidth(); + return cast<IntegerType>().getWidth(); case Type::Kind::Vector: case Type::Kind::RankedTensor: case Type::Kind::UnrankedTensor: - return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth(); + return cast<VectorOrTensorType>().getElementType().getBitWidth(); // TODO: Handle more types. default: llvm_unreachable("unexpected type"); } } -IntegerType::IntegerType(unsigned width, MLIRContext *context) - : Type(Kind::Integer, context), width(width) { - assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); +unsigned Type::getSubclassData() const { return type->getSubclassData(); } +void Type::setSubclassData(unsigned val) { type->setSubclassData(val); } + +IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {} + +unsigned IntegerType::getWidth() const { + return static_cast<ImplType *>(type)->width; } -FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {} + +OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {} -OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {} +FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {} -FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, - unsigned numResults, MLIRContext *context) - : Type(Kind::Function, context, numInputs), numResults(numResults), - inputsAndResults(inputsAndResults) {} +ArrayRef<Type> FunctionType::getInputs() const { + return static_cast<ImplType *>(type)->getInputs(); +} + +unsigned FunctionType::getNumResults() const { + return static_cast<ImplType *>(type)->numResults; +} + +ArrayRef<Type> FunctionType::getResults() const { + return static_cast<ImplType *>(type)->getResults(); +} -VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, - Type *elementType, unsigned subClassData) - : Type(kind, context, subClassData), elementType(elementType) {} +VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {} + +Type VectorOrTensorType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} unsigned VectorOrTensorType::getNumElements() const { switch (getKind()) { @@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const { ArrayRef<int> VectorOrTensorType::getShape() const { switch (getKind()) { case Kind::Vector: - return cast<VectorType>(this)->getShape(); + return cast<VectorType>().getShape(); case Kind::RankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); case Kind::UnrankedTensor: - return cast<RankedTensorType>(this)->getShape(); + return cast<RankedTensorType>().getShape(); default: llvm_unreachable("not a VectorOrTensorType"); } @@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const { return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); } -VectorType::VectorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()), - shapeElements(shape.data()) {} +VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} -TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) - : VectorOrTensorType(kind, context, elementType) { - assert(isValidTensorElementType(elementType)); +ArrayRef<int> VectorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType, - MLIRContext *context) - : TensorType(Kind::RankedTensor, elementType, context), - shapeElements(shape.data()) { - setSubclassData(shape.size()); +TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {} + +RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +ArrayRef<int> RankedTensorType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); } -UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) - : TensorType(Kind::UnrankedTensor, elementType, context) {} +UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {} + +MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {} -MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType, - ArrayRef<AffineMap> affineMapList, unsigned memorySpace, - MLIRContext *context) - : Type(Kind::MemRef, context, shape.size()), elementType(elementType), - shapeElements(shape.data()), numAffineMaps(affineMapList.size()), - affineMapList(affineMapList.data()), memorySpace(memorySpace) {} +ArrayRef<int> MemRefType::getShape() const { + return static_cast<ImplType *>(type)->getShape(); +} + +Type MemRefType::getElementType() const { + return static_cast<ImplType *>(type)->elementType; +} ArrayRef<AffineMap> MemRefType::getAffineMaps() const { - return ArrayRef<AffineMap>(affineMapList, numAffineMaps); + return static_cast<ImplType *>(type)->getAffineMaps(); +} + +unsigned MemRefType::getMemorySpace() const { + return static_cast<ImplType *>(type)->memorySpace; } unsigned MemRefType::getNumDynamicDims() const { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7974c7c71a4..ceb893165f0 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -182,19 +182,19 @@ public: // as the results of their action. // Type parsing. - VectorType *parseVectorType(); + VectorType parseVectorType(); ParseResult parseXInDimensionList(); ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions); - Type *parseTensorType(); - Type *parseMemRefType(); - Type *parseFunctionType(); - Type *parseType(); - ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements); - ParseResult parseTypeList(SmallVectorImpl<Type *> &elements); + Type parseTensorType(); + Type parseMemRefType(); + Type parseFunctionType(); + Type parseType(); + ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); + ParseResult parseTypeList(SmallVectorImpl<Type> &elements); // Attribute parsing. Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, - FunctionType *type); + FunctionType type); Attribute parseAttribute(); ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); @@ -206,9 +206,9 @@ public: AffineMap parseAffineMapReference(); IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetReference(); - DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type); - DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector); - VectorOrTensorType *parseVectorOrTensorType(); + DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type); + DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector); + VectorOrTensorType parseVectorOrTensorType(); private: // The Parser is subclassed and reinstantiated. Do not add additional @@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil( /// float-type ::= `f16` | `bf16` | `f32` | `f64` /// other-type ::= `index` | `tf_control` /// -Type *Parser::parseType() { +Type Parser::parseType() { switch (getToken().getKind()) { default: return (emitError("expected type"), nullptr); @@ -368,7 +368,7 @@ Type *Parser::parseType() { /// vector-type ::= `vector` `<` const-dimension-list primitive-type `>` /// const-dimension-list ::= (integer-literal `x`)+ /// -VectorType *Parser::parseVectorType() { +VectorType Parser::parseVectorType() { consumeToken(Token::kw_vector); if (parseToken(Token::less, "expected '<' in vector type")) @@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; - if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) + if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) return (emitError(typeLoc, "invalid vector element type"), nullptr); return VectorType::get(dimensions, elementType); @@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) { /// tensor-type ::= `tensor` `<` dimension-list element-type `>` /// dimension-list ::= dimension-list-ranked | `*x` /// -Type *Parser::parseTensorType() { +Type Parser::parseTensorType() { consumeToken(Token::kw_tensor); if (parseToken(Token::less, "expected '<' in tensor type")) @@ -485,7 +485,7 @@ Type *Parser::parseTensorType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) return nullptr; @@ -505,7 +505,7 @@ Type *Parser::parseTensorType() { /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map /// memory-space ::= integer-literal /* | TODO: address-space-id */ /// -Type *Parser::parseMemRefType() { +Type Parser::parseMemRefType() { consumeToken(Token::kw_memref); if (parseToken(Token::less, "expected '<' in memref type")) @@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() { // Parse the element type. auto typeLoc = getToken().getLoc(); - auto *elementType = parseType(); + auto elementType = parseType(); if (!elementType) return nullptr; - if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) && - !isa<VectorType>(elementType)) + if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() && + !elementType.isa<VectorType>()) return (emitError(typeLoc, "invalid memref element type"), nullptr); // Parse semi-affine-map-composition. @@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() { /// /// function-type ::= type-list-parens `->` type-list /// -Type *Parser::parseFunctionType() { +Type Parser::parseFunctionType() { assert(getToken().is(Token::l_paren)); - SmallVector<Type *, 4> arguments, results; + SmallVector<Type, 4> arguments, results; if (parseTypeList(arguments) || parseToken(Token::arrow, "expected '->' in function type") || parseTypeList(results)) @@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() { /// /// type-list-no-parens ::= type (`,` type)* /// -ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) { +ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { auto parseElt = [&]() -> ParseResult { auto elt = parseType(); elements.push_back(elt); @@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) { /// type-list-parens ::= `(` `)` /// | `(` type-list-no-parens `)` /// -ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) { +ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) { auto parseElt = [&]() -> ParseResult { auto elt = parseType(); elements.push_back(elt); @@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) { namespace { class TensorLiteralParser { public: - TensorLiteralParser(Parser &p, Type *eltTy) - : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {} + TensorLiteralParser(Parser &p, Type eltTy) + : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {} ParseResult parse() { return parseList(shape); } @@ -676,7 +676,7 @@ private: } Parser &p; - Type *eltTy; + Type eltTy; size_t currBitPos; size_t bitsWidth; SmallVector<int, 4> shape; @@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) { if (!result) return p.emitError("expected tensor element"); // check result matches the element type. - switch (eltTy->getKind()) { + switch (eltTy.getKind()) { case Type::Kind::BF16: case Type::Kind::F16: case Type::Kind::F32: @@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) { /// synthesizing a forward reference) or emit an error and return null on /// failure. Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, - FunctionType *type) { + FunctionType type) { Identifier name = builder.getIdentifier(nameStr.drop_front()); // See if the function has already been defined in the module. @@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::colon, "expected ':' and function type")) return nullptr; auto typeLoc = getToken().getLoc(); - Type *type = parseType(); + Type type = parseType(); if (!type) return nullptr; - auto *fnType = dyn_cast<FunctionType>(type); + auto fnType = type.dyn_cast<FunctionType>(); if (!fnType) return (emitError(typeLoc, "expected function type"), nullptr); @@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() { consumeToken(Token::kw_opaque); if (parseToken(Token::less, "expected '<' after 'opaque'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; auto val = getToken().getStringValue(); @@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "expected '<' after 'splat'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; switch (getToken().getKind()) { @@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; @@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() { if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; - auto *type = parseVectorOrTensorType(); + auto type = parseVectorOrTensorType(); if (!type) return nullptr; switch (getToken().getKind()) { case Token::l_square: { /// Parse indices - auto *indicesEltType = builder.getIntegerType(32); + auto indicesEltType = builder.getIntegerType(32); auto indices = - parseDenseElementsAttr(indicesEltType, isa<VectorType>(type)); + parseDenseElementsAttr(indicesEltType, type.isa<VectorType>()); if (parseToken(Token::comma, "expected ','")) return nullptr; /// Parse values. - auto *valuesEltType = type->getElementType(); + auto valuesEltType = type.getElementType(); auto values = - parseDenseElementsAttr(valuesEltType, isa<VectorType>(type)); + parseDenseElementsAttr(valuesEltType, type.isa<VectorType>()); /// Sanity check. - auto *indicesType = indices.getType(); - auto *valuesType = values.getType(); - auto sameShape = (indicesType->getRank() == 1) || - (type->getRank() == indicesType->getDimSize(1)); + auto indicesType = indices.getType(); + auto valuesType = values.getType(); + auto sameShape = (indicesType.getRank() == 1) || + (type.getRank() == indicesType.getDimSize(1)); auto sameElementNum = - indicesType->getDimSize(0) == valuesType->getDimSize(0); + indicesType.getDimSize(0) == valuesType.getDimSize(0); if (!sameShape || !sameElementNum) { std::string str; llvm::raw_string_ostream s(str); s << "expected shape (["; - interleaveComma(type->getShape(), s); + interleaveComma(type.getShape(), s); s << "]); inferred shape of indices literal (["; - interleaveComma(indicesType->getShape(), s); + interleaveComma(indicesType.getShape(), s); s << "]); inferred shape of values literal (["; - interleaveComma(valuesType->getShape(), s); + interleaveComma(valuesType.getShape(), s); s << "])"; return (emitError(s.str()), nullptr); } @@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() { nullptr); } default: { - if (Type *type = parseType()) + if (Type type = parseType()) return builder.getTypeAttr(type); return nullptr; } @@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() { /// /// This method returns a constructed dense elements attribute with the shape /// from the parsing result. -DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) { +DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) { TensorLiteralParser literalParser(*this, eltType); if (literalParser.parse()) return nullptr; - VectorOrTensorType *type; + VectorOrTensorType type; if (isVector) { type = builder.getVectorType(literalParser.getShape(), eltType); } else { @@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) { /// This method compares the shapes from the parsing result and that from the /// input argument. It returns a constructed dense elements attribute if both /// match. -DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) { - auto *eltTy = type->getElementType(); +DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { + auto eltTy = type.getElementType(); TensorLiteralParser literalParser(*this, eltTy); if (literalParser.parse()) return nullptr; - if (literalParser.getShape() != type->getShape()) { + if (literalParser.getShape() != type.getShape()) { std::string str; llvm::raw_string_ostream s(str); s << "inferred shape of elements literal (["; interleaveComma(literalParser.getShape(), s); s << "]) does not match type (["; - interleaveComma(type->getShape(), s); + interleaveComma(type.getShape(), s); s << "])"; return (emitError(s.str()), nullptr); } @@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) { /// vector-or-tensor-type ::= vector-type | tensor-type /// /// This method also checks the type has static shape and ranked. -VectorOrTensorType *Parser::parseVectorOrTensorType() { - auto *type = dyn_cast<VectorOrTensorType>(parseType()); +VectorOrTensorType Parser::parseVectorOrTensorType() { + auto type = parseType().dyn_cast<VectorOrTensorType>(); if (!type) { return (emitError("expected elements literal has a tensor or vector type"), nullptr); @@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() { if (parseToken(Token::comma, "expected ','")) return nullptr; - if (!type->hasStaticShape() || type->getRank() == -1) { + if (!type.hasStaticShape() || type.getRank() == -1) { return (emitError("tensor literals must be ranked and have static shape"), nullptr); } @@ -1834,7 +1834,7 @@ public: /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. - SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type); + SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type); /// Register a definition of a value with the symbol table. ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value); @@ -1845,11 +1845,11 @@ public: template <typename ResultType> ResultType parseSSADefOrUseAndType( - const std::function<ResultType(SSAUseInfo, Type *)> &action); + const std::function<ResultType(SSAUseInfo, Type)> &action); SSAValue *parseSSAUseAndType() { return parseSSADefOrUseAndType<SSAValue *>( - [&](SSAUseInfo useInfo, Type *type) -> SSAValue * { + [&](SSAUseInfo useInfo, Type type) -> SSAValue * { return resolveSSAUse(useInfo, type); }); } @@ -1880,7 +1880,7 @@ private: /// their first reference, to allow checking for use of undefined values. DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders; - SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type); + SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type); /// Return true if this is a forward reference. bool isForwardReferencePlaceholder(SSAValue *value) { @@ -1891,7 +1891,7 @@ private: /// Create and remember a new placeholder for a forward reference. SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, - Type *type) { + Type type) { // Forward references are always created as instructions, even in ML // functions, because we just need something with a def/use chain. // @@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, /// Given an unbound reference to an SSA value and its type, return the value /// it specifies. This returns null on failure. -SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) { +SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = values[useInfo.name]; // If we have already seen a value of this name, return it. @@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { /// ssa-use-and-type ::= ssa-use `:` type template <typename ResultType> ResultType FunctionParser::parseSSADefOrUseAndType( - const std::function<ResultType(SSAUseInfo, Type *)> &action) { + const std::function<ResultType(SSAUseInfo, Type)> &action) { SSAUseInfo useInfo; if (parseSSAUse(useInfo) || parseToken(Token::colon, "expected ':' and type for SSA operand")) return nullptr; - auto *type = parseType(); + auto type = parseType(); if (!type) return nullptr; @@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList( if (valueIDs.empty()) return ParseSuccess; - SmallVector<Type *, 4> types; + SmallVector<Type, 4> types; if (parseToken(Token::colon, "expected ':' in operand list") || parseTypeListNoParens(types)) return ParseFailure; @@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation( auto type = parseType(); if (!type) return nullptr; - auto fnType = dyn_cast<FunctionType>(type); + auto fnType = type.dyn_cast<FunctionType>(); if (!fnType) return (emitError(typeLoc, "expected function type"), nullptr); - result.addTypes(fnType->getResults()); + result.addTypes(fnType.getResults()); // Check that we have the right number of types for the operands. - auto operandTypes = fnType->getInputs(); + auto operandTypes = fnType.getInputs(); if (operandTypes.size() != operandInfos.size()) { auto plural = "s"[operandInfos.size() == 1]; return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) + @@ -2253,17 +2253,17 @@ public: return parser.parseToken(Token::comma, "expected ','"); } - bool parseColonType(Type *&result) override { + bool parseColonType(Type &result) override { return parser.parseToken(Token::colon, "expected ':'") || !(result = parser.parseType()); } - bool parseColonTypeList(SmallVectorImpl<Type *> &result) override { + bool parseColonTypeList(SmallVectorImpl<Type> &result) override { if (parser.parseToken(Token::colon, "expected ':'")) return true; do { - if (auto *type = parser.parseType()) + if (auto type = parser.parseType()) result.push_back(type); else return true; @@ -2273,7 +2273,7 @@ public: } /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type *&result) override { + bool parseKeywordType(const char *keyword, Type &result) override { if (parser.getTokenSpelling() != keyword) return parser.emitError("expected '" + Twine(keyword) + "'"); parser.consumeToken(); @@ -2396,7 +2396,7 @@ public: } /// Resolve a parse function name and a type into a function reference. - virtual bool resolveFunctionName(StringRef name, FunctionType *type, + virtual bool resolveFunctionName(StringRef name, FunctionType type, llvm::SMLoc loc, Function *&result) { result = parser.resolveFunctionReference(name, loc, type); return result == nullptr; @@ -2410,7 +2410,7 @@ public: llvm::SMLoc getNameLoc() const override { return nameLoc; } - bool resolveOperand(const OperandType &operand, Type *type, + bool resolveOperand(const OperandType &operand, Type type, SmallVectorImpl<SSAValue *> &result) override { FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; @@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList( return ParseSuccess; return parseCommaSeparatedList([&]() -> ParseResult { - auto type = parseSSADefOrUseAndType<Type *>( - [&](SSAUseInfo useInfo, Type *type) -> Type * { + auto type = parseSSADefOrUseAndType<Type>( + [&](SSAUseInfo useInfo, Type type) -> Type { BBArgument *arg = owner->addArgument(type); if (addDefinition(useInfo, arg)) - return nullptr; + return {}; return type; }); return type ? ParseSuccess : ParseFailure; @@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands, " symbol count must match"); // Resolve SSA uses. - Type *indexType = builder.getIndexType(); + Type indexType = builder.getIndexType(); for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { SSAValue *sval = resolveSSAUse(opInfo[i], indexType); if (!sval) @@ -3187,9 +3187,9 @@ private: ParseResult parseAffineStructureDef(); // Functions. - ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, + ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames); - ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type, + ParseResult parseFunctionSignature(StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> *argNames); ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs); ParseResult parseExtFunc(); @@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() { /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ /// ParseResult -ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, +ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames) { consumeToken(Token::l_paren); @@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, /// type-list)? /// ParseResult -ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, +ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> *argNames) { if (getToken().isNot(Token::at_identifier)) return emitError("expected a function identifier like '@foo'"); @@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, if (getToken().isNot(Token::l_paren)) return emitError("expected '(' in function signature"); - SmallVector<Type *, 4> argTypes; + SmallVector<Type, 4> argTypes; ParseResult parseResult; if (argNames) @@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, return ParseFailure; // Parse the return type if present. - SmallVector<Type *, 4> results; + SmallVector<Type, 4> results; if (consumeIf(Token::arrow)) { if (parseTypeList(results)) return ParseFailure; @@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() { auto loc = getToken().getLoc(); StringRef name; - FunctionType *type = nullptr; + FunctionType type; if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; @@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() { auto loc = getToken().getLoc(); StringRef name; - FunctionType *type = nullptr; + FunctionType type; if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; @@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() { consumeToken(Token::kw_mlfunc); StringRef name; - FunctionType *type = nullptr; + FunctionType type; SmallVector<StringRef, 4> argNames; auto loc = getToken().getLoc(); diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index b60d209e1f5..e2bdfd7a18b 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// void AllocOp::build(Builder *builder, OperationState *result, - MemRefType *memrefType, ArrayRef<SSAValue *> operands) { + MemRefType memrefType, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->types.push_back(memrefType); } void AllocOp::print(OpAsmPrinter *p) const { - MemRefType *type = getType(); + MemRefType type = getType(); *p << "alloc"; // Print dynamic dimension operands. printDimAndSymbolList(operand_begin(), operand_end(), - type->getNumDynamicDims(), p); + type.getNumDynamicDims(), p); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); - *p << " : " << *type; + *p << " : " << type; } bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { - MemRefType *type; + MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a // memref type. @@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. - if (numDimOperands != type->getNumDynamicDims()) { + if (numDimOperands != type.getNumDynamicDims()) { return parser->emitError(parser->getNameLoc(), "dimension operand count does not equal memref " "dynamic dimension count"); @@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { } bool AllocOp::verify() const { - auto *memRefType = dyn_cast<MemRefType>(getResult()->getType()); + auto memRefType = getResult()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("result must be a memref"); unsigned numSymbols = 0; - if (!memRefType->getAffineMaps().empty()) { - AffineMap affineMap = memRefType->getAffineMaps()[0]; + if (!memRefType.getAffineMaps().empty()) { + AffineMap affineMap = memRefType.getAffineMaps()[0]; // Store number of symbols used in affine map (used in subsequent check). numSymbols = affineMap.getNumSymbols(); // TODO(zinenko): this check does not belong to AllocOp, or any other op but @@ -195,10 +195,10 @@ bool AllocOp::verify() const { // Remove when we can emit errors directly from *Type::get(...) functions. // // Verify that the layout affine map matches the rank of the memref. - if (affineMap.getNumDims() != memRefType->getRank()) + if (affineMap.getNumDims() != memRefType.getRank()) return emitOpError("affine map dimension count must equal memref rank"); } - unsigned numDynamicDims = memRefType->getNumDynamicDims(); + unsigned numDynamicDims = memRefType.getNumDynamicDims(); // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. @@ -208,7 +208,7 @@ bool AllocOp::verify() const { } // Verify that all operands are of type Index. for (auto *operand : getOperands()) { - if (!operand->getType()->isIndex()) + if (!operand->getType().isIndex()) return emitOpError("requires operands to be of type Index"); } return false; @@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern { // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector<int, 4> newShapeConstants; - newShapeConstants.reserve(memrefType->getRank()); + newShapeConstants.reserve(memrefType.getRank()); SmallVector<SSAValue *, 4> newOperands; SmallVector<SSAValue *, 4> droppedOperands; unsigned dynamicDimPos = 0; - for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) { - int dimSize = memrefType->getDimSize(dim); + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); @@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern { } // Create new memref type (which will have fewer dynamic dimensions). - auto *newMemRefType = MemRefType::get( - newShapeConstants, memrefType->getElementType(), - memrefType->getAffineMaps(), memrefType->getMemorySpace()); - assert(newOperands.size() == newMemRefType->getNumDynamicDims()); + auto newMemRefType = MemRefType::get( + newShapeConstants, memrefType.getElementType(), + memrefType.getAffineMaps(), memrefType.getMemorySpace()); + assert(newOperands.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = @@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee, ArrayRef<SSAValue *> operands) { result->addOperands(operands); result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType()->getResults()); + result->addTypes(callee->getType().getResults()); } bool CallOp::parse(OpAsmParser *parser, OperationState *result) { StringRef calleeName; llvm::SMLoc calleeLoc; - FunctionType *calleeType = nullptr; + FunctionType calleeType; SmallVector<OpAsmParser::OperandType, 4> operands; Function *callee = nullptr; if (parser->parseFunctionName(calleeName, calleeLoc) || @@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(calleeType) || parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || - parser->addTypesToList(calleeType->getResults(), result->types) || - parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, + parser->addTypesToList(calleeType.getResults(), result->types) || + parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, result->operands)) return true; @@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const { p->printOperands(getOperands()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallOp::verify() const { @@ -338,20 +338,20 @@ bool CallOp::verify() const { return emitOpError("requires a 'callee' function attribute"); // Verify that the operand and result types match the callee. - auto *fnType = fnAttr.getValue()->getType(); - if (fnType->getNumInputs() != getNumOperands()) + auto fnType = fnAttr.getValue()->getType(); + if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -364,14 +364,14 @@ bool CallOp::verify() const { void CallIndirectOp::build(Builder *builder, OperationState *result, SSAValue *callee, ArrayRef<SSAValue *> operands) { - auto *fnType = cast<FunctionType>(callee->getType()); + auto fnType = callee->getType().cast<FunctionType>(); result->operands.push_back(callee); result->addOperands(operands); - result->addTypes(fnType->getResults()); + result->addTypes(fnType.getResults()); } bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { - FunctionType *calleeType = nullptr; + FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector<OpAsmParser::OperandType, 4> operands; @@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(calleeType) || parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, result->operands) || - parser->addTypesToList(calleeType->getResults(), result->types); + parser->addTypesToList(calleeType.getResults(), result->types); } void CallIndirectOp::print(OpAsmPrinter *p) const { @@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const { p->printOperands(++operandRange.begin(), operandRange.end()); *p << ')'; p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); + *p << " : " << getCallee()->getType(); } bool CallIndirectOp::verify() const { // The callee must be a function. - auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); + auto fnType = getCallee()->getType().dyn_cast<FunctionType>(); if (!fnType) return emitOpError("callee must have function type"); // Verify that the operand and result types match the callee. - if (fnType->getNumInputs() != getNumOperands() - 1) + if (fnType.getNumInputs() != getNumOperands() - 1) return emitOpError("incorrect number of operands for callee"); - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i + 1)->getType() != fnType->getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (getOperand(i + 1)->getType() != fnType.getInput(i)) return emitOpError("operand type mismatch"); } - if (fnType->getNumResults() != getNumResults()) + if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (getResult(i)->getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); } @@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result, } void DeallocOp::print(OpAsmPrinter *p) const { - *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); + *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); } bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; - MemRefType *type; + MemRefType type; return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands); } bool DeallocOp::verify() const { - if (!isa<MemRefType>(getMemRef()->getType())) + if (!getMemRef()->getType().isa<MemRefType>()) return emitOpError("operand must be a memref"); return false; } @@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result, void DimOp::print(OpAsmPrinter *p) const { *p << "dim " << *getOperand() << ", " << getIndex(); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); - *p << " : " << *getOperand()->getType(); + *p << " : " << getOperand()->getType(); } bool DimOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; - Type *type; + Type type; return parser->parseOperand(operandInfo) || parser->parseComma() || parser->parseAttribute(indexAttr, "index", result->attributes) || @@ -496,15 +496,15 @@ bool DimOp::verify() const { return emitOpError("requires an integer attribute named 'index'"); uint64_t index = (uint64_t)indexAttr.getValue(); - auto *type = getOperand()->getType(); - if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { - if (index >= tensorType->getRank()) + auto type = getOperand()->getType(); + if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + if (index >= tensorType.getRank()) return emitOpError("index is out of range"); - } else if (auto *memrefType = dyn_cast<MemRefType>(type)) { - if (index >= memrefType->getRank()) + } else if (auto memrefType = type.dyn_cast<MemRefType>()) { + if (index >= memrefType.getRank()) return emitOpError("index is out of range"); - } else if (isa<UnrankedTensorType>(type)) { + } else if (type.isa<UnrankedTensorType>()) { // ok, assumed to be in-range. } else { return emitOpError("requires an operand with tensor or memref type"); @@ -516,12 +516,12 @@ bool DimOp::verify() const { Attribute DimOp::constantFold(ArrayRef<Attribute> operands, MLIRContext *context) const { // Constant fold dim when the size along the index referred to is a constant. - auto *opType = getOperand()->getType(); + auto opType = getOperand()->getType(); int indexSize = -1; - if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) { - indexSize = tensorType->getShape()[getIndex()]; - } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) { - indexSize = memrefType->getShape()[getIndex()]; + if (auto tensorType = opType.dyn_cast<RankedTensorType>()) { + indexSize = tensorType.getShape()[getIndex()]; + } else if (auto memrefType = opType.dyn_cast<MemRefType>()) { + indexSize = memrefType.getShape()[getIndex()]; } if (indexSize >= 0) @@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getSrcMemRef()->getType(); - *p << ", " << *getDstMemRef()->getType(); - *p << ", " << *getTagMemRef()->getType(); + *p << " : " << getSrcMemRef()->getType(); + *p << ", " << getDstMemRef()->getType(); + *p << ", " << getTagMemRef()->getType(); } // Parse DmaStartOp. @@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; - SmallVector<Type *, 3> types; - auto *indexType = parser->getBuilder().getIndexType(); + SmallVector<Type, 3> types; + auto indexType = parser->getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). @@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { return true; // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || - dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) + if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() || + dstIndexInfos.size() != types[1].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "memref rank not equal to indices count"); - if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) + if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { p->printOperands(getTagIndices()); *p << "], "; p->printOperand(getNumElements()); - *p << " : " << *getTagMemRef()->getType(); + *p << " : " << getTagMemRef()->getType(); } // Parse DmaWaitOp. @@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const { bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; - Type *type; - auto *indexType = parser->getBuilder().getIndexType(); + Type type; + auto indexType = parser->getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. @@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperand(numElementsInfo, indexType, result->operands)) return true; - if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) + if (tagIndexInfos.size() != type.cast<MemRefType>().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results, void ExtractElementOp::build(Builder *builder, OperationState *result, SSAValue *aggregate, ArrayRef<SSAValue *> indices) { - auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); + auto aggregateType = aggregate->getType().cast<VectorOrTensorType>(); result->addOperands(aggregate); result->addOperands(indices); - result->types.push_back(aggregateType->getElementType()); + result->types.push_back(aggregateType.getElementType()); } void ExtractElementOp::print(OpAsmPrinter *p) const { @@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getAggregate()->getType(); + *p << " : " << getAggregate()->getType(); } bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - VectorOrTensorType *type; + VectorOrTensorType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(aggregateInfo) || @@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(aggregateInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool ExtractElementOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected an aggregate to index into"); - auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); + auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>(); if (!aggregateType) return emitOpError("first operand must be a vector or tensor"); - if (getType() != aggregateType->getElementType()) + if (getType() != aggregateType.getElementType()) return emitOpError("result type must match element type of aggregate"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to extract_element must have 'index' type"); // Verify the # indices match if we have a ranked type. - auto aggregateRank = aggregateType->getRank(); + auto aggregateRank = aggregateType.getRank(); if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) return emitOpError("incorrect number of indices for extract_element"); @@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const { void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, ArrayRef<SSAValue *> indices) { - auto *memrefType = cast<MemRefType>(memref->getType()); + auto memrefType = memref->getType().cast<MemRefType>(); result->addOperands(memref); result->addOperands(indices); - result->types.push_back(memrefType->getElementType()); + result->types.push_back(memrefType.getElementType()); } void LoadOp::print(OpAsmPrinter *p) const { @@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *type; + MemRefType type; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(memrefInfo) || @@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); + parser->addTypeToList(type.getElementType(), result->types); } bool LoadOp::verify() const { if (getNumOperands() == 0) return emitOpError("expected a memref to load from"); - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("first operand must be a memref"); - if (getType() != memRefType->getElementType()) + if (getType() != memRefType.getElementType()) return emitOpError("result type must match element type of memref"); - if (memRefType->getRank() != getNumOperands() - 1) + if (memRefType.getRank() != getNumOperands() - 1) return emitOpError("incorrect number of indices for load"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool MemRefCastOp::verify() const { - auto *opType = dyn_cast<MemRefType>(getOperand()->getType()); - auto *resType = dyn_cast<MemRefType>(getType()); + auto opType = getOperand()->getType().dyn_cast<MemRefType>(); + auto resType = getType().dyn_cast<MemRefType>(); if (!opType || !resType) return emitOpError("requires input and result types to be memrefs"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); - if (opType->getAffineMaps() != resType->getAffineMaps()) + if (opType.getAffineMaps() != resType.getAffineMaps()) return emitOpError("requires input and result mappings to be the same"); - if (opType->getMemorySpace() != resType->getMemorySpace()) + if (opType.getMemorySpace() != resType.getMemorySpace()) return emitOpError( "requires input and result memory spaces to be the same"); // They must have the same rank, and any specified dimensions must match. - if (opType->getRank() != resType->getRank()) + if (opType.getRank() != resType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opType->getRank(); i != e; ++i) { - int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i); + for (unsigned i = 0, e = opType.getRank(); i != e; ++i) { + int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } @@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const { p->printOperands(getIndices()); *p << ']'; p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRefType(); + *p << " : " << getMemRefType(); } bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *memrefType; + MemRefType memrefType; auto affineIntTy = parser->getBuilder().getIndexType(); return parser->parseOperand(storeValueInfo) || parser->parseComma() || @@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType->getElementType(), + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands); @@ -950,19 +950,19 @@ bool StoreOp::verify() const { return emitOpError("expected a value to store and a memref"); // Second operand is a memref type. - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); + auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>(); if (!memRefType) return emitOpError("second operand must be a memref"); // First operand must have same type as memref element type. - if (getValueToStore()->getType() != memRefType->getElementType()) + if (getValueToStore()->getType() != memRefType.getElementType()) return emitOpError("first operand must have same type memref element type"); - if (getNumOperands() != 2 + memRefType->getRank()) + if (getNumOperands() != 2 + memRefType.getRank()) return emitOpError("store index operand count not equal to memref rank"); for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) + if (!idx->getType().isIndex()) return emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results, //===----------------------------------------------------------------------===// bool TensorCastOp::verify() const { - auto *opType = dyn_cast<TensorType>(getOperand()->getType()); - auto *resType = dyn_cast<TensorType>(getType()); + auto opType = getOperand()->getType().dyn_cast<TensorType>(); + auto resType = getType().dyn_cast<TensorType>(); if (!opType || !resType) return emitOpError("requires input and result types to be tensors"); if (opType == resType) return emitOpError("requires the input and result type to be different"); - if (opType->getElementType() != resType->getElementType()) + if (opType.getElementType() != resType.getElementType()) return emitOpError( "requires input and result element types to be the same"); // If the source or destination are unranked, then the cast is valid. - auto *opRType = dyn_cast<RankedTensorType>(opType); - auto *resRType = dyn_cast<RankedTensorType>(resType); + auto opRType = opType.dyn_cast<RankedTensorType>(); + auto resRType = resType.dyn_cast<RankedTensorType>(); if (!opRType || !resRType) return false; // If they are both ranked, they have to have the same rank, and any specified // dimensions must match. - if (opRType->getRank() != resRType->getRank()) + if (opRType.getRank() != resRType.getRank()) return emitOpError("requires input and result ranks to match"); - for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { - int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); + for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) { + int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i); if (opDim != -1 && resultDim != -1 && opDim != resultDim) return emitOpError("requires static dimensions to match"); } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 81994ddfab4..15dd89bb758 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> { SmallVector<SSAValue *, 8> existingConstants; // Operation statements that were folded and that need to be erased. std::vector<OperationStmt *> opStmtsToErase; - using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>; + using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>; bool foldOperation(Operation *op, SmallVectorImpl<SSAValue *> &existingConstants, @@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) { auto &inst = *instIt++; - auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { builder.setInsertionPoint(&inst); return builder.create<ConstantOp>(inst.getLoc(), value, type); }; @@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { // Override the walker's operation statement visit for constant folding. void ConstantFold::visitOperationStmt(OperationStmt *stmt) { - auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { + auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { MLFuncBuilder builder(stmt); return builder.create<ConstantOp>(stmt->getLoc(), value, type); }; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index d96d65b5fb7..90421819d82 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { bInner.setInsertionPoint(forStmt, forStmt->begin()); // Doubles the shape with a leading dimension extent of 2. - auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * { + auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { // Add the leading dimension in the shape for the double buffer. - ArrayRef<int> shape = oldMemRefType->getShape(); + ArrayRef<int> shape = oldMemRefType.getShape(); SmallVector<int, 4> shapeSizes(shape.begin(), shape.end()); shapeSizes.insert(shapeSizes.begin(), 2); - auto *newMemRefType = - bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {}, - oldMemRefType->getMemorySpace()); + auto newMemRefType = + bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {}, + oldMemRefType.getMemorySpace()); return newMemRefType; }; - auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType())); + auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>()); // Create and place the alloc at the top level. MLFuncBuilder topBuilder(forStmt->getFunction()); - auto *newMemRef = cast<MLValue>( + auto newMemRef = cast<MLValue>( topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType) ->getResult()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index cdf5b7166a0..4ec89425189 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -78,7 +78,7 @@ private: /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants; + DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants; }; }; // end anonymous namespace diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index edd8ce85317..ad9d6dcb769 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, ArrayRef<MLValue *> extraIndices, AffineMap indexRemap) { - unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank(); + unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank(); + unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank(); (void)newMemRefRank; if (indexRemap) { assert(indexRemap.getNumInputs() == oldMemRefRank); @@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, } // Assert same elemental type. - assert(cast<MemRefType>(oldMemRef->getType())->getElementType() == - cast<MemRefType>(newMemRef->getType())->getElementType()); + assert(oldMemRef->getType().cast<MemRefType>().getElementType() == + newMemRef->getType().cast<MemRefType>().getElementType()); // Check if memref was used in a non-deferencing context. for (const StmtOperand &use : oldMemRef->getUses()) { @@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, opStmt->operand_end()); // Result types don't change. Both memref's are of the same elemental type. - SmallVector<Type *, 8> resultTypes; + SmallVector<Type, 8> resultTypes; resultTypes.reserve(opStmt->getNumResults()); for (const auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d7a1f531cef..511afa95993 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches, /// sizes specified by vectorSize. The MemRef lives in the same memory space as /// tmpl. The MemRef should be promoted to a closer memory address space in a /// later pass. -static MemRefType *getVectorizedMemRefType(MemRefType *tmpl, - ArrayRef<int> vectorSizes) { - auto *elementType = tmpl->getElementType(); - assert(!dyn_cast<VectorType>(elementType) && +static MemRefType getVectorizedMemRefType(MemRefType tmpl, + ArrayRef<int> vectorSizes) { + auto elementType = tmpl.getElementType(); + assert(!elementType.dyn_cast<VectorType>() && "Can't vectorize an already vector type"); - assert(tmpl->getAffineMaps().empty() && + assert(tmpl.getAffineMaps().empty() && "Unsupported non-implicit identity map"); return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {}, - tmpl->getMemorySpace()); + tmpl.getMemorySpace()); } /// Creates an unaligned load with the following semantics: @@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() { template <typename LoadOrStoreOpPointer> static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp, ArrayRef<int> vectorSize) { - auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType()); - auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); + auto memRefType = + memoryOp->getMemRef()->getType().template cast<MemRefType>(); + auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); // Materialize a MemRef with 1 vector. auto *opStmt = cast<OperationStmt>(memoryOp->getOperation()); |

