summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/AsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/AsmPrinter.cpp')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp91
1 files changed, 45 insertions, 46 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 454a28a6558..cb5e96f0086 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -122,7 +122,7 @@ private:
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
- void visitType(const Type *type);
+ void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const Operation *op);
@@ -135,16 +135,16 @@ private:
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
-void ModuleState::visitType(const Type *type) {
- if (auto *funcType = dyn_cast<FunctionType>(type)) {
+void ModuleState::visitType(Type type) {
+ if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
- for (auto *input : funcType->getInputs())
+ for (auto input : funcType.getInputs())
visitType(input);
- for (auto *result : funcType->getResults())
+ for (auto result : funcType.getResults())
visitType(result);
- } else if (auto *memref = dyn_cast<MemRefType>(type)) {
+ } else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
- for (auto map : memref->getAffineMaps()) {
+ for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map);
}
}
@@ -271,7 +271,7 @@ public:
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(Attribute attr);
- void printType(const Type *type);
+ void printType(Type type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
@@ -290,7 +290,7 @@ protected:
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
- void printFunctionResultType(const FunctionType *type);
+ void printFunctionResultType(FunctionType type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
@@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
- auto *type = attr.getType();
- auto shape = type->getShape();
- auto rank = type->getRank();
+ auto type = attr.getType();
+ auto shape = type.getShape();
+ auto rank = type.getRank();
SmallVector<Attribute, 16> elements;
attr.getValues(elements);
@@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
os << ']';
}
-void ModulePrinter::printType(const Type *type) {
- switch (type->getKind()) {
+void ModulePrinter::printType(Type type) {
+ switch (type.getKind()) {
case Type::Kind::Index:
os << "index";
return;
@@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
return;
case Type::Kind::Integer: {
- auto *integer = cast<IntegerType>(type);
- os << 'i' << integer->getWidth();
+ auto integer = type.cast<IntegerType>();
+ os << 'i' << integer.getWidth();
return;
}
case Type::Kind::Function: {
- auto *func = cast<FunctionType>(type);
+ auto func = type.cast<FunctionType>();
os << '(';
- interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
+ interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> ";
- auto results = func->getResults();
+ auto results = func.getResults();
if (results.size() == 1)
- os << *results[0];
+ os << results[0];
else {
os << '(';
- interleaveComma(results, [&](Type *type) { printType(type); });
+ interleaveComma(results, [&](Type type) { printType(type); });
os << ')';
}
return;
}
case Type::Kind::Vector: {
- auto *v = cast<VectorType>(type);
+ auto v = type.cast<VectorType>();
os << "vector<";
- for (auto dim : v->getShape())
+ for (auto dim : v.getShape())
os << dim << 'x';
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::RankedTensor: {
- auto *v = cast<RankedTensorType>(type);
+ auto v = type.cast<RankedTensorType>();
os << "tensor<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::UnrankedTensor: {
- auto *v = cast<UnrankedTensorType>(type);
+ auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x";
- printType(v->getElementType());
+ printType(v.getElementType());
os << '>';
return;
}
case Type::Kind::MemRef: {
- auto *v = cast<MemRefType>(type);
+ auto v = type.cast<MemRefType>();
os << "memref<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- printType(v->getElementType());
- for (auto map : v->getAffineMaps()) {
+ printType(v.getElementType());
+ for (auto map : v.getAffineMaps()) {
os << ", ";
printAffineMapReference(map);
}
// Only print the memory space if it is the non-default one.
- if (v->getMemorySpace())
- os << ", " << v->getMemorySpace();
+ if (v.getMemorySpace())
+ os << ", " << v.getMemorySpace();
os << '>';
return;
}
@@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
// Function printing
//===----------------------------------------------------------------------===//
-void ModulePrinter::printFunctionResultType(const FunctionType *type) {
- switch (type->getResults().size()) {
+void ModulePrinter::printFunctionResultType(FunctionType type) {
+ switch (type.getResults().size()) {
case 0:
break;
case 1:
os << " -> ";
- printType(type->getResults()[0]);
+ printType(type.getResults()[0]);
break;
default:
os << " -> (";
- interleaveComma(type->getResults(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getResults(),
+ [&](Type eltType) { printType(eltType); });
os << ')';
break;
}
@@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleaveComma(type->getInputs(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
@@ -937,7 +936,7 @@ public:
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
- void printType(const Type *type) { ModulePrinter::printType(type); }
+ void printType(Type type) { ModulePrinter::printType(type); }
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
@@ -974,10 +973,10 @@ protected:
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
- if (intOp->getType()->isInteger(1)) {
+ if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
- specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
+ specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
@@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
void Type::print(raw_ostream &os) const {
ModuleState state(getContext());
- ModulePrinter(os, state).printType(this);
+ ModulePrinter(os, state).printType(*this);
}
void Type::dump() const { print(llvm::errs()); }
OpenPOWER on IntegriCloud