diff options
Diffstat (limited to 'llgo/irgen/typemap.go')
| -rw-r--r-- | llgo/irgen/typemap.go | 217 |
1 files changed, 116 insertions, 101 deletions
diff --git a/llgo/irgen/typemap.go b/llgo/irgen/typemap.go index 486ee6f8958..2ae1a4e2b93 100644 --- a/llgo/irgen/typemap.go +++ b/llgo/irgen/typemap.go @@ -71,13 +71,20 @@ type TypeMap struct { typeSliceType, methodSliceType, imethodSliceType, structFieldSliceType llvm.Type + funcValType llvm.Type hashFnType, equalFnType llvm.Type - hashFnEmptyInterface, hashFnInterface, hashFnFloat, hashFnComplex, hashFnString, hashFnIdentity, hashFnError llvm.Value - equalFnEmptyInterface, equalFnInterface, equalFnFloat, equalFnComplex, equalFnString, equalFnIdentity, equalFnError llvm.Value + algsEmptyInterface, + algsInterface, + algsFloat, + algsComplex, + algsString, + algsIdentity, + algsError algorithms +} - zeroType llvm.Type - zeroValue llvm.Value +type algorithms struct { + hash, hashDescriptor, equal, equalDescriptor llvm.Value } func NewLLVMTypeMap(ctx llvm.Context, target llvm.TargetData) *llvmTypeMap { @@ -117,35 +124,38 @@ func NewTypeMap(pkg *ssa.Package, llvmtm *llvmTypeMap, module llvm.Module, r *ru boolType := llvm.Int8Type() stringPtrType := llvm.PointerType(tm.stringType, 0) - // Create runtime algorithm function types. + tm.funcValType = tm.ctx.StructCreateNamed("funcVal") + tm.funcValType.StructSetBody([]llvm.Type{ + llvm.PointerType(llvm.FunctionType(llvm.VoidType(), []llvm.Type{}, false), 0), + }, false) + params := []llvm.Type{voidPtrType, uintptrType} tm.hashFnType = llvm.FunctionType(uintptrType, params, false) params = []llvm.Type{voidPtrType, voidPtrType, uintptrType} tm.equalFnType = llvm.FunctionType(boolType, params, false) - tm.hashFnEmptyInterface = llvm.AddFunction(tm.module, "__go_type_hash_empty_interface", tm.hashFnType) - tm.hashFnInterface = llvm.AddFunction(tm.module, "__go_type_hash_interface", tm.hashFnType) - tm.hashFnFloat = llvm.AddFunction(tm.module, "__go_type_hash_float", tm.hashFnType) - tm.hashFnComplex = llvm.AddFunction(tm.module, "__go_type_hash_complex", tm.hashFnType) - tm.hashFnString = llvm.AddFunction(tm.module, "__go_type_hash_string", tm.hashFnType) - tm.hashFnIdentity = llvm.AddFunction(tm.module, "__go_type_hash_identity", tm.hashFnType) - tm.hashFnError = llvm.AddFunction(tm.module, "__go_type_hash_error", tm.hashFnType) - - tm.equalFnEmptyInterface = llvm.AddFunction(tm.module, "__go_type_equal_empty_interface", tm.equalFnType) - tm.equalFnInterface = llvm.AddFunction(tm.module, "__go_type_equal_interface", tm.equalFnType) - tm.equalFnFloat = llvm.AddFunction(tm.module, "__go_type_equal_float", tm.equalFnType) - tm.equalFnComplex = llvm.AddFunction(tm.module, "__go_type_equal_complex", tm.equalFnType) - tm.equalFnString = llvm.AddFunction(tm.module, "__go_type_equal_string", tm.equalFnType) - tm.equalFnIdentity = llvm.AddFunction(tm.module, "__go_type_equal_identity", tm.equalFnType) - tm.equalFnError = llvm.AddFunction(tm.module, "__go_type_equal_error", tm.equalFnType) - - // The body of this type is set in emitTypeDescInitializers once we have scanned - // every type, as it needs to be as large and well aligned as the - // largest/most aligned type. - tm.zeroType = tm.ctx.StructCreateNamed("zero") - tm.zeroValue = llvm.AddGlobal(tm.module, tm.zeroType, "go$zerovalue") - tm.zeroValue.SetLinkage(llvm.CommonLinkage) - tm.zeroValue.SetInitializer(llvm.ConstNull(tm.zeroType)) + typeAlgorithms := [...]struct { + Name string + *algorithms + }{ + {"empty_interface", &tm.algsEmptyInterface}, + {"interface", &tm.algsInterface}, + {"float", &tm.algsFloat}, + {"complex", &tm.algsComplex}, + {"string", &tm.algsString}, + {"identity", &tm.algsIdentity}, + {"error", &tm.algsError}, + } + for _, typeAlgs := range typeAlgorithms { + hashFnName := "__go_type_hash_" + typeAlgs.Name + hashDescriptorName := hashFnName + "_descriptor" + equalFnName := "__go_type_equal_" + typeAlgs.Name + equalDescriptorName := equalFnName + "_descriptor" + typeAlgs.hash = llvm.AddGlobal(tm.module, tm.hashFnType, hashFnName) + typeAlgs.hashDescriptor = llvm.AddGlobal(tm.module, tm.funcValType, hashDescriptorName) + typeAlgs.equal = llvm.AddGlobal(tm.module, tm.equalFnType, equalFnName) + typeAlgs.equalDescriptor = llvm.AddGlobal(tm.module, tm.funcValType, equalDescriptorName) + } tm.commonTypeType = tm.ctx.StructCreateNamed("commonType") commonTypeTypePtr := llvm.PointerType(tm.commonTypeType, 0) @@ -174,13 +184,12 @@ func NewTypeMap(pkg *ssa.Package, llvmtm *llvmTypeMap, module llvm.Module, r *ru tm.ctx.Int8Type(), // fieldAlign uintptrType, // size tm.ctx.Int32Type(), // hash - llvm.PointerType(tm.hashFnType, 0), // hashfn - llvm.PointerType(tm.equalFnType, 0), // equalfn + llvm.PointerType(tm.funcValType, 0), // hashfn + llvm.PointerType(tm.funcValType, 0), // equalfn voidPtrType, // gc stringPtrType, // string llvm.PointerType(tm.uncommonTypeType, 0), // uncommonType commonTypeTypePtr, // ptrToThis - llvm.PointerType(tm.zeroType, 0), // zero }, false) tm.typeSliceType = tm.makeNamedSliceType("typeSlice", commonTypeTypePtr) @@ -1096,9 +1105,6 @@ func (tm *TypeMap) emitTypeDescInitializers() { } } } - - tm.zeroType.StructSetBody([]llvm.Type{llvm.ArrayType(tm.ctx.Int8Type(), int(maxSize))}, false) - tm.zeroValue.SetAlignment(int(maxAlign)) } const ( @@ -1234,24 +1240,20 @@ func (tm *TypeMap) makeTypeDescInitializer(t types.Type) llvm.Value { } } -type algorithmFns struct { - hash, equal llvm.Value -} - -func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal llvm.Value) { - if algs, ok := tm.algs.At(st).(algorithmFns); ok { - return algs.hash, algs.equal +func (tm *TypeMap) getStructAlgorithms(st *types.Struct) algorithms { + if algs, ok := tm.algs.At(st).(algorithms); ok { + return algs } hashes := make([]llvm.Value, st.NumFields()) equals := make([]llvm.Value, st.NumFields()) for i := range hashes { - fhash, fequal := tm.getAlgorithmFunctions(st.Field(i).Type()) - if fhash == tm.hashFnError { - return fhash, fequal + algs := tm.getAlgorithms(st.Field(i).Type()) + if algs.hashDescriptor == tm.algsError.hashDescriptor { + return algs } - hashes[i], equals[i] = fhash, fequal + hashes[i], equals[i] = algs.hash, algs.equal } i8ptr := llvm.PointerType(tm.ctx.Int8Type(), 0) @@ -1260,8 +1262,11 @@ func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal ll builder := tm.ctx.NewBuilder() defer builder.Dispose() - hash = llvm.AddFunction(tm.module, tm.mc.mangleHashFunctionName(st), tm.hashFnType) + hashFunctionName := tm.mc.mangleHashFunctionName(st) + hash := llvm.AddFunction(tm.module, hashFunctionName, tm.hashFnType) hash.SetLinkage(llvm.LinkOnceODRLinkage) + hashDescriptor := tm.createAlgorithmDescriptor(hashFunctionName+"_descriptor", hash) + builder.SetInsertPointAtEnd(llvm.AddBasicBlock(hash, "entry")) sptr := builder.CreateBitCast(hash.Param(0), llsptrty, "") @@ -1271,9 +1276,7 @@ func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal ll for i, fhash := range hashes { fptr := builder.CreateStructGEP(sptr, i, "") fptr = builder.CreateBitCast(fptr, i8ptr, "") - fsize := llvm.ConstInt(tm.inttype, uint64(tm.sizes.Sizeof(st.Field(i).Type())), false) - hashcall := builder.CreateCall(fhash, []llvm.Value{fptr, fsize}, "") hashval = builder.CreateMul(hashval, i33, "") hashval = builder.CreateAdd(hashval, hashcall, "") @@ -1281,11 +1284,13 @@ func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal ll builder.CreateRet(hashval) - equal = llvm.AddFunction(tm.module, tm.mc.mangleEqualFunctionName(st), tm.equalFnType) + equalFunctionName := tm.mc.mangleEqualFunctionName(st) + equal := llvm.AddFunction(tm.module, equalFunctionName, tm.equalFnType) equal.SetLinkage(llvm.LinkOnceODRLinkage) + equalDescriptor := tm.createAlgorithmDescriptor(equalFunctionName+"_descriptor", equal) + eqentrybb := llvm.AddBasicBlock(equal, "entry") eqretzerobb := llvm.AddBasicBlock(equal, "retzero") - builder.SetInsertPointAtEnd(eqentrybb) s1ptr := builder.CreateBitCast(equal.Param(0), llsptrty, "") s2ptr := builder.CreateBitCast(equal.Param(1), llsptrty, "") @@ -1298,15 +1303,11 @@ func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal ll f1ptr = builder.CreateBitCast(f1ptr, i8ptr, "") f2ptr := builder.CreateStructGEP(s2ptr, i, "") f2ptr = builder.CreateBitCast(f2ptr, i8ptr, "") - fsize := llvm.ConstInt(tm.inttype, uint64(tm.sizes.Sizeof(st.Field(i).Type())), false) - equalcall := builder.CreateCall(fequal, []llvm.Value{f1ptr, f2ptr, fsize}, "") equaleqzero := builder.CreateICmp(llvm.IntEQ, equalcall, zerobool, "") - contbb := llvm.AddBasicBlock(equal, "cont") builder.CreateCondBr(equaleqzero, eqretzerobb, contbb) - builder.SetInsertPointAtEnd(contbb) } @@ -1315,18 +1316,24 @@ func (tm *TypeMap) getStructAlgorithmFunctions(st *types.Struct) (hash, equal ll builder.SetInsertPointAtEnd(eqretzerobb) builder.CreateRet(zerobool) - tm.algs.Set(st, algorithmFns{hash, equal}) - return + algs := algorithms{ + hash: hash, + hashDescriptor: hashDescriptor, + equal: equal, + equalDescriptor: equalDescriptor, + } + tm.algs.Set(st, algs) + return algs } -func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm.Value) { - if algs, ok := tm.algs.At(at).(algorithmFns); ok { - return algs.hash, algs.equal +func (tm *TypeMap) getArrayAlgorithms(at *types.Array) algorithms { + if algs, ok := tm.algs.At(at).(algorithms); ok { + return algs } - ehash, eequal := tm.getAlgorithmFunctions(at.Elem()) - if ehash == tm.hashFnError { - return ehash, eequal + elemAlgs := tm.getAlgorithms(at.Elem()) + if elemAlgs.hashDescriptor == tm.algsError.hashDescriptor { + return elemAlgs } i8ptr := llvm.PointerType(tm.ctx.Int8Type(), 0) @@ -1339,8 +1346,22 @@ func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm builder := tm.ctx.NewBuilder() defer builder.Dispose() - hash = llvm.AddFunction(tm.module, tm.mc.mangleHashFunctionName(at), tm.hashFnType) + hashFunctionName := tm.mc.mangleHashFunctionName(at) + hash := llvm.AddFunction(tm.module, hashFunctionName, tm.hashFnType) hash.SetLinkage(llvm.LinkOnceODRLinkage) + hashDescriptor := tm.createAlgorithmDescriptor(hashFunctionName+"_descriptor", hash) + equalFunctionName := tm.mc.mangleHashFunctionName(at) + equal := llvm.AddFunction(tm.module, equalFunctionName, tm.equalFnType) + equal.SetLinkage(llvm.LinkOnceODRLinkage) + equalDescriptor := tm.createAlgorithmDescriptor(equalFunctionName+"_descriptor", equal) + algs := algorithms{ + hash: hash, + hashDescriptor: hashDescriptor, + equal: equal, + equalDescriptor: equalDescriptor, + } + tm.algs.Set(at, algs) + hashentrybb := llvm.AddBasicBlock(hash, "entry") builder.SetInsertPointAtEnd(hashentrybb) if at.Len() == 0 { @@ -1363,7 +1384,7 @@ func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm eptr := builder.CreateGEP(aptr, []llvm.Value{index}, "") eptr = builder.CreateBitCast(eptr, i8ptr, "") - hashcall := builder.CreateCall(ehash, []llvm.Value{eptr, esize}, "") + hashcall := builder.CreateCall(elemAlgs.hash, []llvm.Value{eptr, esize}, "") hashval = builder.CreateMul(hashval, i33, "") hashval = builder.CreateAdd(hashval, hashcall, "") @@ -1388,8 +1409,6 @@ func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm zerobool := llvm.ConstNull(tm.ctx.Int8Type()) onebool := llvm.ConstInt(tm.ctx.Int8Type(), 1, false) - equal = llvm.AddFunction(tm.module, tm.mc.mangleEqualFunctionName(at), tm.equalFnType) - equal.SetLinkage(llvm.LinkOnceODRLinkage) eqentrybb := llvm.AddBasicBlock(equal, "entry") builder.SetInsertPointAtEnd(eqentrybb) if at.Len() == 0 { @@ -1412,7 +1431,7 @@ func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm e2ptr := builder.CreateGEP(a2ptr, []llvm.Value{index}, "") e2ptr = builder.CreateBitCast(e2ptr, i8ptr, "") - equalcall := builder.CreateCall(eequal, []llvm.Value{e1ptr, e2ptr, esize}, "") + equalcall := builder.CreateCall(elemAlgs.equal, []llvm.Value{e1ptr, e2ptr, esize}, "") equaleqzero := builder.CreateICmp(llvm.IntEQ, equalcall, zerobool, "") contbb := llvm.AddBasicBlock(equal, "cont") @@ -1437,48 +1456,45 @@ func (tm *TypeMap) getArrayAlgorithmFunctions(at *types.Array) (hash, equal llvm builder.CreateRet(zerobool) } - tm.algs.Set(at, algorithmFns{hash, equal}) - return + return algs +} + +func (tm *TypeMap) createAlgorithmDescriptor(name string, fn llvm.Value) llvm.Value { + d := llvm.AddGlobal(tm.module, tm.funcValType, name) + d.SetLinkage(llvm.LinkOnceODRLinkage) + d.SetGlobalConstant(true) + fn = llvm.ConstBitCast(fn, tm.funcValType.StructElementTypes()[0]) + init := llvm.ConstNull(tm.funcValType) + init = llvm.ConstInsertValue(init, fn, []uint32{0}) + d.SetInitializer(init) + return d } -func (tm *TypeMap) getAlgorithmFunctions(t types.Type) (hash, equal llvm.Value) { +func (tm *TypeMap) getAlgorithms(t types.Type) algorithms { switch t := t.Underlying().(type) { case *types.Interface: if t.NumMethods() == 0 { - hash = tm.hashFnEmptyInterface - equal = tm.equalFnEmptyInterface - } else { - hash = tm.hashFnInterface - equal = tm.equalFnInterface + return tm.algsEmptyInterface } + return tm.algsInterface case *types.Basic: switch t.Kind() { case types.Float32, types.Float64: - hash = tm.hashFnFloat - equal = tm.equalFnFloat + return tm.algsFloat case types.Complex64, types.Complex128: - hash = tm.hashFnComplex - equal = tm.equalFnComplex + return tm.algsComplex case types.String: - hash = tm.hashFnString - equal = tm.equalFnString - default: - hash = tm.hashFnIdentity - equal = tm.equalFnIdentity + return tm.algsString } + return tm.algsIdentity case *types.Signature, *types.Map, *types.Slice: - hash = tm.hashFnError - equal = tm.equalFnError + return tm.algsError case *types.Struct: - hash, equal = tm.getStructAlgorithmFunctions(t) + return tm.getStructAlgorithms(t) case *types.Array: - hash, equal = tm.getArrayAlgorithmFunctions(t) - default: - hash = tm.hashFnIdentity - equal = tm.equalFnIdentity + return tm.getArrayAlgorithms(t) } - - return + return tm.algsIdentity } func (tm *TypeMap) getTypeDescInfo(t types.Type) *typeDescInfo { @@ -1704,27 +1720,26 @@ func runtimeTypeKind(t types.Type) (k uint8) { } func (tm *TypeMap) makeCommonType(t types.Type) llvm.Value { - var vals [12]llvm.Value + var vals [11]llvm.Value vals[0] = llvm.ConstInt(tm.ctx.Int8Type(), uint64(runtimeTypeKind(t)), false) vals[1] = llvm.ConstInt(tm.ctx.Int8Type(), uint64(tm.Alignof(t)), false) vals[2] = vals[1] vals[3] = llvm.ConstInt(tm.inttype, uint64(tm.Sizeof(t)), false) vals[4] = llvm.ConstInt(tm.ctx.Int32Type(), uint64(tm.getTypeHash(t)), false) - hash, equal := tm.getAlgorithmFunctions(t) - vals[5] = hash - vals[6] = equal + algs := tm.getAlgorithms(t) + vals[5] = algs.hashDescriptor + vals[6] = algs.equalDescriptor vals[7] = tm.getGcPointer(t) var b bytes.Buffer tm.writeType(t, &b) vals[8] = tm.globalStringPtr(b.String()) vals[9] = tm.makeUncommonTypePtr(t) - if _, ok := t.(*types.Named); ok { + switch t.(type) { + case *types.Named, *types.Struct: vals[10] = tm.getTypeDescriptorPointer(types.NewPointer(t)) - } else { + default: vals[10] = llvm.ConstPointerNull(llvm.PointerType(tm.commonTypeType, 0)) } - vals[11] = tm.zeroValue - return llvm.ConstNamedStruct(tm.commonTypeType, vals[:]) } |

