diff options
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 91 |
1 files changed, 45 insertions, 46 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 454a28a6558..cb5e96f0086 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -122,7 +122,7 @@ private: void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); void visitOperationStmt(const OperationStmt *opStmt); - void visitType(const Type *type); + void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const Operation *op); @@ -135,16 +135,16 @@ private: } // end anonymous namespace // TODO Support visiting other types/instructions when implemented. -void ModuleState::visitType(const Type *type) { - if (auto *funcType = dyn_cast<FunctionType>(type)) { +void ModuleState::visitType(Type type) { + if (auto funcType = type.dyn_cast<FunctionType>()) { // Visit input and result types for functions. - for (auto *input : funcType->getInputs()) + for (auto input : funcType.getInputs()) visitType(input); - for (auto *result : funcType->getResults()) + for (auto result : funcType.getResults()) visitType(result); - } else if (auto *memref = dyn_cast<MemRefType>(type)) { + } else if (auto memref = type.dyn_cast<MemRefType>()) { // Visit affine maps in memref type. - for (auto map : memref->getAffineMaps()) { + for (auto map : memref.getAffineMaps()) { recordAffineMapReference(map); } } @@ -271,7 +271,7 @@ public: void print(const Module *module); void printFunctionReference(const Function *func); void printAttribute(Attribute attr); - void printType(const Type *type); + void printType(Type type); void print(const Function *fn); void print(const ExtFunction *fn); void print(const CFGFunction *fn); @@ -290,7 +290,7 @@ protected: void printFunctionAttributes(const Function *fn); void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs = {}); - void printFunctionResultType(const FunctionType *type); + void printFunctionResultType(FunctionType type); void printAffineMapId(int affineMapId) const; void printAffineMapReference(AffineMap affineMap); void printIntegerSetId(int integerSetId) const; @@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) { } void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { - auto *type = attr.getType(); - auto shape = type->getShape(); - auto rank = type->getRank(); + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); SmallVector<Attribute, 16> elements; attr.getValues(elements); @@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { os << ']'; } -void ModulePrinter::printType(const Type *type) { - switch (type->getKind()) { +void ModulePrinter::printType(Type type) { + switch (type.getKind()) { case Type::Kind::Index: os << "index"; return; @@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) { return; case Type::Kind::Integer: { - auto *integer = cast<IntegerType>(type); - os << 'i' << integer->getWidth(); + auto integer = type.cast<IntegerType>(); + os << 'i' << integer.getWidth(); return; } case Type::Kind::Function: { - auto *func = cast<FunctionType>(type); + auto func = type.cast<FunctionType>(); os << '('; - interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); + interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); os << ") -> "; - auto results = func->getResults(); + auto results = func.getResults(); if (results.size() == 1) - os << *results[0]; + os << results[0]; else { os << '('; - interleaveComma(results, [&](Type *type) { printType(type); }); + interleaveComma(results, [&](Type type) { printType(type); }); os << ')'; } return; } case Type::Kind::Vector: { - auto *v = cast<VectorType>(type); + auto v = type.cast<VectorType>(); os << "vector<"; - for (auto dim : v->getShape()) + for (auto dim : v.getShape()) os << dim << 'x'; - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::RankedTensor: { - auto *v = cast<RankedTensorType>(type); + auto v = type.cast<RankedTensorType>(); os << "tensor<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - os << *v->getElementType() << '>'; + os << v.getElementType() << '>'; return; } case Type::Kind::UnrankedTensor: { - auto *v = cast<UnrankedTensorType>(type); + auto v = type.cast<UnrankedTensorType>(); os << "tensor<*x"; - printType(v->getElementType()); + printType(v.getElementType()); os << '>'; return; } case Type::Kind::MemRef: { - auto *v = cast<MemRefType>(type); + auto v = type.cast<MemRefType>(); os << "memref<"; - for (auto dim : v->getShape()) { + for (auto dim : v.getShape()) { if (dim < 0) os << '?'; else os << dim; os << 'x'; } - printType(v->getElementType()); - for (auto map : v->getAffineMaps()) { + printType(v.getElementType()); + for (auto map : v.getAffineMaps()) { os << ", "; printAffineMapReference(map); } // Only print the memory space if it is the non-default one. - if (v->getMemorySpace()) - os << ", " << v->getMemorySpace(); + if (v.getMemorySpace()) + os << ", " << v.getMemorySpace(); os << '>'; return; } @@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) { // Function printing //===----------------------------------------------------------------------===// -void ModulePrinter::printFunctionResultType(const FunctionType *type) { - switch (type->getResults().size()) { +void ModulePrinter::printFunctionResultType(FunctionType type) { + switch (type.getResults().size()) { case 0: break; case 1: os << " -> "; - printType(type->getResults()[0]); + printType(type.getResults()[0]); break; default: os << " -> ("; - interleaveComma(type->getResults(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getResults(), + [&](Type eltType) { printType(eltType); }); os << ')'; break; } @@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) { auto type = fn->getType(); os << "@" << fn->getName() << '('; - interleaveComma(type->getInputs(), - [&](Type *eltType) { printType(eltType); }); + interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); }); os << ')'; printFunctionResultType(type); @@ -937,7 +936,7 @@ public: // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } - void printType(const Type *type) { ModulePrinter::printType(type); } + void printType(Type type) { ModulePrinter::printType(type); } void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); } void printAffineMap(AffineMap map) { return ModulePrinter::printAffineMapReference(map); @@ -974,10 +973,10 @@ protected: if (auto *op = value->getDefiningOperation()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) { // i1 constants get special names. - if (intOp->getType()->isInteger(1)) { + if (intOp->getType().isInteger(1)) { specialName << (intOp->getValue() ? "true" : "false"); } else { - specialName << 'c' << intOp->getValue() << '_' << *intOp->getType(); + specialName << 'c' << intOp->getValue() << '_' << intOp->getType(); } } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { specialName << 'c' << intOp->getValue(); @@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); } void Type::print(raw_ostream &os) const { ModuleState state(getContext()); - ModulePrinter(os, state).printType(this); + ModulePrinter(os, state).printType(*this); } void Type::dump() const { print(llvm::errs()); } |

