diff options
| author | nmostafa <nagy.h.mostafa@intel.com> | 2019-12-05 13:12:50 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-05 13:13:20 -0800 |
| commit | daff60cd68ddf7a198925c9daa3da4320b5c3f25 (patch) | |
| tree | 08ec6407737e5c0bfea663e191d92d11c3f33974 /mlir/lib/Conversion/StandardToLLVM | |
| parent | e67acfa4684e4bee38d3b4c90eff1e78adc62cef (diff) | |
| download | bcm5719-llvm-daff60cd68ddf7a198925c9daa3da4320b5c3f25.tar.gz bcm5719-llvm-daff60cd68ddf7a198925c9daa3da4320b5c3f25.zip | |
Add UnrankedMemRef Type
Closes tensorflow/mlir#261
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/261 from nmostafa:nmostafa/unranked 96b6e918f6ed64496f7573b2db33c0b02658ca45
PiperOrigin-RevId: 284037040
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 195 |
1 files changed, 164 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 23c7be310a9..5a6282e8d4d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -193,6 +193,22 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); } +// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which +// contains: +// 1. int64_t rank, the dynamic rank of this MemRef +// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be +// stack allocated (alloca) copy of a MemRef descriptor that got casted to +// be unranked. + +static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; +static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; + +Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { + auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + return LLVM::LLVMType::getStructTy(rankTy, ptrTy); +} + // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and @@ -221,6 +237,8 @@ Type LLVMTypeConverter::convertStandardType(Type type) { return convertIndexType(indexType); if (auto memRefType = type.dyn_cast<MemRefType>()) return convertMemRefType(memRefType); + if (auto memRefType = type.dyn_cast<UnrankedMemRefType>()) + return convertUnrankedMemRefType(memRefType); if (auto vectorType = type.dyn_cast<VectorType>()) return convertVectorType(vectorType); if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) @@ -246,21 +264,41 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} /*============================================================================*/ +/* StructBuilder implementation */ +/*============================================================================*/ +StructBuilder::StructBuilder(Value *v) : value(v) { + assert(value != nullptr && "value cannot be null"); + structType = value->getType().cast<LLVM::LLVMType>(); +} + +Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { + Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); + return builder.create<LLVM::ExtractValueOp>(loc, type, value, + builder.getI64ArrayAttr(pos)); +} + +void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, + Value *ptr) { + value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); +} +/*============================================================================*/ /* MemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) { - if (value) { - structType = value->getType().cast<LLVM::LLVMType>(); - indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( - kOffsetPosInMemRefDescriptor); - } +MemRefDescriptor::MemRefDescriptor(Value *descriptor) + : StructBuilder(descriptor) { + assert(value != nullptr && "value cannot be null"); + indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( + kOffsetPosInMemRefDescriptor); } /// Builds IR creating an `undef` value of the descriptor type. MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { + Value *descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); return MemRefDescriptor(descriptor); @@ -334,24 +372,42 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } -Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc, - unsigned pos) { - Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); - return builder.create<LLVM::ExtractValueOp>(loc, type, value, - builder.getI64ArrayAttr(pos)); -} - -void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos, - Value *ptr) { - value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, - builder.getI64ArrayAttr(pos)); -} - LLVM::LLVMType MemRefDescriptor::getElementType() { return value->getType().cast<LLVM::LLVMType>().getStructElementType( kAlignedPtrPosInMemRefDescriptor); } +/*============================================================================*/ +/* UnrankedMemRefDescriptor implementation */ +/*============================================================================*/ + +/// Construct a helper for the given descriptor value. +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor) + : StructBuilder(descriptor) {} + +/// Builds IR creating an `undef` value of the descriptor type. +UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, + Location loc, + Type descriptorType) { + Value *descriptor = + builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); + return UnrankedMemRefDescriptor(descriptor); +} +Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); +} +void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, + Value *v) { + setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); +} +Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { + return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); +} +void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, + Location loc, Value *v) { + setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); +} namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in @@ -432,7 +488,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>(); if (!converted) return matchFailure(); - if (t.isa<MemRefType>()) { + if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) { converted = converted.getPointerTo(); promotedArgIndices.push_back(en.index()); } @@ -983,6 +1039,14 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { Type packedResult; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + for (Type resType : resultTypes) { + assert(!resType.isa<UnrankedMemRefType>() && + "Returning unranked memref is not supported. Pass result as an" + "argument instead."); + (void)resType; + } + if (numResults != 0) { if (!(packedResult = this->lowering.packFunctionResults(resultTypes))) return this->matchFailure(); @@ -1076,11 +1140,26 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { PatternMatchResult match(Operation *op) const override { auto memRefCastOp = cast<MemRefCastOp>(op); - MemRefType sourceType = - memRefCastOp.getOperand()->getType().cast<MemRefType>(); - MemRefType targetType = memRefCastOp.getType(); - return (isSupportedMemRefType(targetType) && - isSupportedMemRefType(sourceType)) + Type srcType = memRefCastOp.getOperand()->getType(); + Type dstType = memRefCastOp.getType(); + + if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) { + MemRefType sourceType = + memRefCastOp.getOperand()->getType().cast<MemRefType>(); + MemRefType targetType = memRefCastOp.getType().cast<MemRefType>(); + return (isSupportedMemRefType(targetType) && + isSupportedMemRefType(sourceType)) + ? matchSuccess() + : matchFailure(); + } + + // At least one of the operands is unranked type + assert(srcType.isa<UnrankedMemRefType>() || + dstType.isa<UnrankedMemRefType>()); + + // Unranked to unranked cast is disallowed + return !(srcType.isa<UnrankedMemRefType>() && + dstType.isa<UnrankedMemRefType>()) ? matchSuccess() : matchFailure(); } @@ -1089,12 +1168,65 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast<MemRefCastOp>(op); OperandAdaptor<MemRefCastOp> transformed(operands); - // memref_cast is defined for source and destination memref types with the - // same element type, same mappings, same address space and same rank. - // Therefore a simple bitcast suffices. If not it is undefined behavior. + + auto srcType = memRefCastOp.getOperand()->getType(); + auto dstType = memRefCastOp.getType(); auto targetStructType = lowering.convertType(memRefCastOp.getType()); - rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType, - transformed.source()); + auto loc = op->getLoc(); + + if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) { + // memref_cast is defined for source and destination memref types with the + // same element type, same mappings, same address space and same rank. + // Therefore a simple bitcast suffices. If not it is undefined behavior. + rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType, + transformed.source()); + } else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) { + // Casting ranked to unranked memref type + // Set the rank in the destination from the memref type + // Allocate space on the stack and copy the src memref decsriptor + // Set the ptr in the destination to the stack space + auto srcMemRefType = srcType.cast<MemRefType>(); + int64_t rank = srcMemRefType.getRank(); + // ptr = AllocaOp sizeof(MemRefDescriptor) + auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(), + rewriter); + // voidptr = BitCastOp srcType* to void* + auto voidPtr = + rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) + .getResult(); + // rank = ConstantOp srcRank + auto rankVal = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(rewriter.getIntegerType(64)), + rewriter.getI64IntegerAttr(rank)); + // undef = UndefOp + UnrankedMemRefDescriptor memRefDesc = + UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); + // d1 = InsertValueOp undef, rank, 0 + memRefDesc.setRank(rewriter, loc, rankVal); + // d2 = InsertValueOp d1, voidptr, 1 + memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); + rewriter.replaceOp(op, (Value *)memRefDesc); + + } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) { + // Casting from unranked type to ranked. + // The operation is assumed to be doing a correct cast. If the destination + // type mismatches the unranked the type, it is undefined behavior. + UnrankedMemRefDescriptor memRefDesc(transformed.source()); + // ptr = ExtractValueOp src, 1 + auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); + // castPtr = BitCastOp i8* to structTy* + auto castPtr = + rewriter + .create<LLVM::BitcastOp>( + loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(), + ptr) + .getResult(); + // struct = LoadOp castPtr + auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); + rewriter.replaceOp(op, loadOp.getResult()); + } else { + llvm_unreachable("Unsuppored unranked memref to unranked memref cast"); + } } }; @@ -1896,7 +2028,8 @@ SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors( for (auto it : llvm::zip(opOperands, operands)) { auto *operand = std::get<0>(it); auto *llvmOperand = std::get<1>(it); - if (!operand->getType().isa<MemRefType>()) { + if (!operand->getType().isa<MemRefType>() && + !operand->getType().isa<UnrankedMemRefType>()) { promotedOperands.push_back(operand); continue; } |

