summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToLLVM
diff options
context:
space:
mode:
authornmostafa <nagy.h.mostafa@intel.com>2019-12-05 13:12:50 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-05 13:13:20 -0800
commitdaff60cd68ddf7a198925c9daa3da4320b5c3f25 (patch)
tree08ec6407737e5c0bfea663e191d92d11c3f33974 /mlir/lib/Conversion/StandardToLLVM
parente67acfa4684e4bee38d3b4c90eff1e78adc62cef (diff)
downloadbcm5719-llvm-daff60cd68ddf7a198925c9daa3da4320b5c3f25.tar.gz
bcm5719-llvm-daff60cd68ddf7a198925c9daa3da4320b5c3f25.zip
Add UnrankedMemRef Type
Closes tensorflow/mlir#261 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/261 from nmostafa:nmostafa/unranked 96b6e918f6ed64496f7573b2db33c0b02658ca45 PiperOrigin-RevId: 284037040
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp195
1 files changed, 164 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 23c7be310a9..5a6282e8d4d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -193,6 +193,22 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
}
+// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
+// contains:
+// 1. int64_t rank, the dynamic rank of this MemRef
+// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
+// stack allocated (alloca) copy of a MemRef descriptor that got casted to
+// be unranked.
+
+static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
+static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
+
+Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
+ auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
+}
+
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
@@ -221,6 +237,8 @@ Type LLVMTypeConverter::convertStandardType(Type type) {
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
+ if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
+ return convertUnrankedMemRefType(memRefType);
if (auto vectorType = type.dyn_cast<VectorType>())
return convertVectorType(vectorType);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
@@ -246,21 +264,41 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
/*============================================================================*/
+/* StructBuilder implementation */
+/*============================================================================*/
+StructBuilder::StructBuilder(Value *v) : value(v) {
+ assert(value != nullptr && "value cannot be null");
+ structType = value->getType().cast<LLVM::LLVMType>();
+}
+
+Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc,
+ unsigned pos) {
+ Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+ return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+ builder.getI64ArrayAttr(pos));
+}
+
+void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
+ Value *ptr) {
+ value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+ builder.getI64ArrayAttr(pos));
+}
+/*============================================================================*/
/* MemRefDescriptor implementation */
/*============================================================================*/
/// Construct a helper for the given descriptor value.
-MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
- if (value) {
- structType = value->getType().cast<LLVM::LLVMType>();
- indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
- kOffsetPosInMemRefDescriptor);
- }
+MemRefDescriptor::MemRefDescriptor(Value *descriptor)
+ : StructBuilder(descriptor) {
+ assert(value != nullptr && "value cannot be null");
+ indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ kOffsetPosInMemRefDescriptor);
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
+
Value *descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
@@ -334,24 +372,42 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
-Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
- unsigned pos) {
- Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
- return builder.create<LLVM::ExtractValueOp>(loc, type, value,
- builder.getI64ArrayAttr(pos));
-}
-
-void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
- Value *ptr) {
- value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
- builder.getI64ArrayAttr(pos));
-}
-
LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
}
+/*============================================================================*/
+/* UnrankedMemRefDescriptor implementation */
+/*============================================================================*/
+
+/// Construct a helper for the given descriptor value.
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor)
+ : StructBuilder(descriptor) {}
+
+/// Builds IR creating an `undef` value of the descriptor type.
+UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
+ Location loc,
+ Type descriptorType) {
+ Value *descriptor =
+ builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ return UnrankedMemRefDescriptor(descriptor);
+}
+Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
+}
+void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
+ Value *v) {
+ setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
+}
+Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
+ Location loc) {
+ return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
+}
+void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
+ Location loc, Value *v) {
+ setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
+}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
@@ -432,7 +488,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return matchFailure();
- if (t.isa<MemRefType>()) {
+ if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>()) {
converted = converted.getPointerTo();
promotedArgIndices.push_back(en.index());
}
@@ -983,6 +1039,14 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
Type packedResult;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
+
+ for (Type resType : resultTypes) {
+ assert(!resType.isa<UnrankedMemRefType>() &&
+ "Returning unranked memref is not supported. Pass result as an"
+ "argument instead.");
+ (void)resType;
+ }
+
if (numResults != 0) {
if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
return this->matchFailure();
@@ -1076,11 +1140,26 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
PatternMatchResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
- MemRefType sourceType =
- memRefCastOp.getOperand()->getType().cast<MemRefType>();
- MemRefType targetType = memRefCastOp.getType();
- return (isSupportedMemRefType(targetType) &&
- isSupportedMemRefType(sourceType))
+ Type srcType = memRefCastOp.getOperand()->getType();
+ Type dstType = memRefCastOp.getType();
+
+ if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
+ MemRefType sourceType =
+ memRefCastOp.getOperand()->getType().cast<MemRefType>();
+ MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
+ return (isSupportedMemRefType(targetType) &&
+ isSupportedMemRefType(sourceType))
+ ? matchSuccess()
+ : matchFailure();
+ }
+
+ // At least one of the operands is unranked type
+ assert(srcType.isa<UnrankedMemRefType>() ||
+ dstType.isa<UnrankedMemRefType>());
+
+ // Unranked to unranked cast is disallowed
+ return !(srcType.isa<UnrankedMemRefType>() &&
+ dstType.isa<UnrankedMemRefType>())
? matchSuccess()
: matchFailure();
}
@@ -1089,12 +1168,65 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
- // memref_cast is defined for source and destination memref types with the
- // same element type, same mappings, same address space and same rank.
- // Therefore a simple bitcast suffices. If not it is undefined behavior.
+
+ auto srcType = memRefCastOp.getOperand()->getType();
+ auto dstType = memRefCastOp.getType();
auto targetStructType = lowering.convertType(memRefCastOp.getType());
- rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
- transformed.source());
+ auto loc = op->getLoc();
+
+ if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
+ // memref_cast is defined for source and destination memref types with the
+ // same element type, same mappings, same address space and same rank.
+ // Therefore a simple bitcast suffices. If not it is undefined behavior.
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
+ transformed.source());
+ } else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
+ // Casting ranked to unranked memref type
+ // Set the rank in the destination from the memref type
+ // Allocate space on the stack and copy the src memref decsriptor
+ // Set the ptr in the destination to the stack space
+ auto srcMemRefType = srcType.cast<MemRefType>();
+ int64_t rank = srcMemRefType.getRank();
+ // ptr = AllocaOp sizeof(MemRefDescriptor)
+ auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
+ rewriter);
+ // voidptr = BitCastOp srcType* to void*
+ auto voidPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
+ .getResult();
+ // rank = ConstantOp srcRank
+ auto rankVal = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(rewriter.getIntegerType(64)),
+ rewriter.getI64IntegerAttr(rank));
+ // undef = UndefOp
+ UnrankedMemRefDescriptor memRefDesc =
+ UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
+ // d1 = InsertValueOp undef, rank, 0
+ memRefDesc.setRank(rewriter, loc, rankVal);
+ // d2 = InsertValueOp d1, voidptr, 1
+ memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
+ rewriter.replaceOp(op, (Value *)memRefDesc);
+
+ } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
+ // Casting from unranked type to ranked.
+ // The operation is assumed to be doing a correct cast. If the destination
+ // type mismatches the unranked the type, it is undefined behavior.
+ UnrankedMemRefDescriptor memRefDesc(transformed.source());
+ // ptr = ExtractValueOp src, 1
+ auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
+ // castPtr = BitCastOp i8* to structTy*
+ auto castPtr =
+ rewriter
+ .create<LLVM::BitcastOp>(
+ loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
+ ptr)
+ .getResult();
+ // struct = LoadOp castPtr
+ auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
+ rewriter.replaceOp(op, loadOp.getResult());
+ } else {
+ llvm_unreachable("Unsuppored unranked memref to unranked memref cast");
+ }
}
};
@@ -1896,7 +2028,8 @@ SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
for (auto it : llvm::zip(opOperands, operands)) {
auto *operand = std::get<0>(it);
auto *llvmOperand = std::get<1>(it);
- if (!operand->getType().isa<MemRefType>()) {
+ if (!operand->getType().isa<MemRefType>() &&
+ !operand->getType().isa<UnrankedMemRefType>()) {
promotedOperands.push_back(operand);
continue;
}
OpenPOWER on IntegriCloud