diff --git a/compiler/compiler.go b/compiler/compiler.go index 2a0246971a..11854c1736 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1416,7 +1416,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return c.parseMakeClosure(frame, expr) case *ssa.MakeInterface: val := c.getValue(frame, expr.X) - return c.parseMakeInterface(val, expr.X.Type(), expr.Pos()) + return c.parseMakeInterface(val, expr.X.Type(), expr.Pos()), nil case *ssa.MakeMap: mapType := expr.Type().Underlying().(*types.Map) llvmKeyType := c.getLLVMType(mapType.Key().Underlying()) @@ -1425,7 +1425,16 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { valueSize := c.targetData.TypeAllocSize(llvmValueType) llvmKeySize := llvm.ConstInt(c.ctx.Int8Type(), keySize, false) llvmValueSize := llvm.ConstInt(c.ctx.Int8Type(), valueSize, false) - hashmap := c.createRuntimeCall("hashmapMake", []llvm.Value{llvmKeySize, llvmValueSize}, "") + sizeHint := llvm.ConstInt(c.uintptrType, 8, false) + if expr.Reserve != nil { + sizeHint = c.getValue(frame, expr.Reserve) + var err error + sizeHint, err = c.parseConvert(expr.Reserve.Type(), types.Typ[types.Uintptr], sizeHint, expr.Pos()) + if err != nil { + return llvm.Value{}, err + } + } + hashmap := c.createRuntimeCall("hashmapMake", []llvm.Value{llvmKeySize, llvmValueSize, sizeHint}, "") return hashmap, nil case *ssa.MakeSlice: sliceLen := c.getValue(frame, expr.Len) diff --git a/compiler/interface.go b/compiler/interface.go index 6bc8580610..3d2014f317 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -22,13 +22,10 @@ import ( // value field. // // An interface value is a {typecode, value} tuple, or {i16, i8*} to be exact. -func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, pos token.Pos) (llvm.Value, error) { +func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, pos token.Pos) llvm.Value { itfValue := c.emitPointerPack([]llvm.Value{val}) itfTypeCodeGlobal := c.getTypeCode(typ) - itfMethodSetGlobal, err := c.getTypeMethodSet(typ) - if err != nil { - return llvm.Value{}, nil - } + itfMethodSetGlobal := c.getTypeMethodSet(typ) itfConcreteTypeGlobal := c.mod.NamedGlobal("typeInInterface:" + itfTypeCodeGlobal.Name()) if itfConcreteTypeGlobal.IsNil() { typeInInterface := c.mod.GetTypeByName("runtime.typeInInterface") @@ -41,7 +38,7 @@ func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, pos token. itf := llvm.Undef(c.mod.GetTypeByName("runtime._interface")) itf = c.builder.CreateInsertValue(itf, itfTypeCode, 0, "") itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") - return itf, nil + return itf } // getTypeCode returns a reference to a type code. @@ -155,18 +152,18 @@ func getTypeCodeName(t types.Type) string { // getTypeMethodSet returns a reference (GEP) to a global method set. This // method set should be unreferenced after the interface lowering pass. -func (c *Compiler) getTypeMethodSet(typ types.Type) (llvm.Value, error) { +func (c *Compiler) getTypeMethodSet(typ types.Type) llvm.Value { global := c.mod.NamedGlobal(typ.String() + "$methodset") zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) if !global.IsNil() { // the method set already exists - return llvm.ConstGEP(global, []llvm.Value{zero, zero}), nil + return llvm.ConstGEP(global, []llvm.Value{zero, zero}) } ms := c.ir.Program.MethodSets.MethodSet(typ) if ms.Len() == 0 { // no methods, so can leave that one out - return llvm.ConstPointerNull(llvm.PointerType(c.mod.GetTypeByName("runtime.interfaceMethodInfo"), 0)), nil + return llvm.ConstPointerNull(llvm.PointerType(c.mod.GetTypeByName("runtime.interfaceMethodInfo"), 0)) } methods := make([]llvm.Value, ms.Len()) @@ -179,10 +176,7 @@ func (c *Compiler) getTypeMethodSet(typ types.Type) (llvm.Value, error) { // compiler error, so panic panic("cannot find function: " + f.LinkName()) } - fn, err := c.getInterfaceInvokeWrapper(f) - if err != nil { - return llvm.Value{}, err - } + fn := c.getInterfaceInvokeWrapper(f) methodInfo := llvm.ConstNamedStruct(interfaceMethodInfoType, []llvm.Value{ signatureGlobal, llvm.ConstPtrToInt(fn, c.uintptrType), @@ -195,7 +189,7 @@ func (c *Compiler) getTypeMethodSet(typ types.Type) (llvm.Value, error) { global.SetInitializer(value) global.SetGlobalConstant(true) global.SetLinkage(llvm.PrivateLinkage) - return llvm.ConstGEP(global, []llvm.Value{zero, zero}), nil + return llvm.ConstGEP(global, []llvm.Value{zero, zero}) } // getInterfaceMethodSet returns a global variable with the method set of the @@ -365,12 +359,12 @@ type interfaceInvokeWrapper struct { // the underlying value, dereferences it, and calls the real method. This // wrapper is only needed when the interface value actually doesn't fit in a // pointer and a pointer to the value must be created. -func (c *Compiler) getInterfaceInvokeWrapper(f *ir.Function) (llvm.Value, error) { +func (c *Compiler) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value { wrapperName := f.LinkName() + "$invoke" wrapper := c.mod.NamedFunction(wrapperName) if !wrapper.IsNil() { // Wrapper already created. Return it directly. - return wrapper, nil + return wrapper } // Get the expanded receiver type. @@ -383,7 +377,7 @@ func (c *Compiler) getInterfaceInvokeWrapper(f *ir.Function) (llvm.Value, error) // Casting a function signature to a different signature and calling it // with a receiver pointer bitcasted to *i8 (as done in calls on an // interface) is hopefully a safe (defined) operation. - return f.LLVMFn, nil + return f.LLVMFn } // create wrapper function @@ -396,7 +390,7 @@ func (c *Compiler) getInterfaceInvokeWrapper(f *ir.Function) (llvm.Value, error) wrapper: wrapper, receiverType: receiverType, }) - return wrapper, nil + return wrapper } // createInterfaceInvokeWrapper finishes the work of getInterfaceInvokeWrapper, diff --git a/compiler/llvm.go b/compiler/llvm.go index 77fc9c44c7..a49573456d 100644 --- a/compiler/llvm.go +++ b/compiler/llvm.go @@ -22,6 +22,36 @@ func getUses(value llvm.Value) []llvm.Value { return uses } +// createEntryBlockAlloca creates a new alloca in the entry block, even though +// the IR builder is located elsewhere. It assumes that the insert point is +// after the last instruction in the current block. Also, it adds lifetime +// information to the IR signalling that the alloca won't be used before this +// point. +// +// This is useful for creating temporary allocas for intrinsics. Don't forget to +// end the lifetime after you're done with it. +func (c *Compiler) createEntryBlockAlloca(t llvm.Type, name string) (alloca, bitcast, size llvm.Value) { + currentBlock := c.builder.GetInsertBlock() + c.builder.SetInsertPointBefore(currentBlock.Parent().EntryBasicBlock().FirstInstruction()) + alloca = c.builder.CreateAlloca(t, name) + c.builder.SetInsertPointAtEnd(currentBlock) + bitcast = c.builder.CreateBitCast(alloca, c.i8ptrType, name+".bitcast") + size = llvm.ConstInt(c.ctx.Int64Type(), c.targetData.TypeAllocSize(t), false) + c.builder.CreateCall(c.getLifetimeStartFunc(), []llvm.Value{size, bitcast}, "") + return +} + +// getLifetimeStartFunc returns the llvm.lifetime.start intrinsic and creates it +// first if it doesn't exist yet. +func (c *Compiler) getLifetimeStartFunc() llvm.Value { + fn := c.mod.NamedFunction("llvm.lifetime.start.p0i8") + if fn.IsNil() { + fnType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.ctx.Int64Type(), c.i8ptrType}, false) + fn = llvm.AddFunction(c.mod, "llvm.lifetime.start.p0i8", fnType) + } + return fn +} + // getLifetimeEndFunc returns the llvm.lifetime.end intrinsic and creates it // first if it doesn't exist yet. func (c *Compiler) getLifetimeEndFunc() llvm.Value { diff --git a/compiler/map.go b/compiler/map.go index e69393c22e..da8e687431 100644 --- a/compiler/map.go +++ b/compiler/map.go @@ -11,8 +11,13 @@ import ( func (c *Compiler) emitMapLookup(keyType, valueType types.Type, m, key llvm.Value, commaOk bool, pos token.Pos) (llvm.Value, error) { llvmValueType := c.getLLVMType(valueType) - mapValueAlloca := c.builder.CreateAlloca(llvmValueType, "hashmap.value") - mapValuePtr := c.builder.CreateBitCast(mapValueAlloca, c.i8ptrType, "hashmap.valueptr") + + // Allocate the memory for the resulting type. Do not zero this memory: it + // will be zeroed by the hashmap get implementation if the key is not + // present in the map. + mapValueAlloca, mapValuePtr, mapValueSize := c.createEntryBlockAlloca(llvmValueType, "hashmap.value") + + // Do the lookup. How it is done depends on the key type. var commaOkValue llvm.Value if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string @@ -20,15 +25,24 @@ func (c *Compiler) emitMapLookup(keyType, valueType types.Type, m, key llvm.Valu commaOkValue = c.createRuntimeCall("hashmapStringGet", params, "") } else if hashmapIsBinaryKey(keyType) { // key can be compared with runtime.memequal - keyAlloca := c.builder.CreateAlloca(key.Type(), "hashmap.key") - c.builder.CreateStore(key, keyAlloca) - keyPtr := c.builder.CreateBitCast(keyAlloca, c.i8ptrType, "hashmap.keyptr") - params := []llvm.Value{m, keyPtr, mapValuePtr} + // Store the key in an alloca, in the entry block to avoid dynamic stack + // growth. + mapKeyAlloca, mapKeyPtr, mapKeySize := c.createEntryBlockAlloca(key.Type(), "hashmap.key") + c.builder.CreateStore(key, mapKeyAlloca) + // Fetch the value from the hashmap. + params := []llvm.Value{m, mapKeyPtr, mapValuePtr} commaOkValue = c.createRuntimeCall("hashmapBinaryGet", params, "") + c.builder.CreateCall(c.getLifetimeEndFunc(), []llvm.Value{mapKeySize, mapKeyPtr}, "") } else { + // Not trivially comparable using memcmp. return llvm.Value{}, c.makeError(pos, "only strings, bools, ints or structs of bools/ints are supported as map keys, but got: "+keyType.String()) } + + // Load the resulting value from the hashmap. The value is set to the zero + // value if the key doesn't exist in the hashmap. mapValue := c.builder.CreateLoad(mapValueAlloca, "") + c.builder.CreateCall(c.getLifetimeEndFunc(), []llvm.Value{mapValueSize, mapValuePtr}, "") + if commaOk { tuple := llvm.Undef(c.ctx.StructType([]llvm.Type{llvmValueType, c.ctx.Int1Type()}, false)) tuple = c.builder.CreateInsertValue(tuple, mapValue, 0, "") @@ -40,9 +54,8 @@ func (c *Compiler) emitMapLookup(keyType, valueType types.Type, m, key llvm.Valu } func (c *Compiler) emitMapUpdate(keyType types.Type, m, key, value llvm.Value, pos token.Pos) { - valueAlloca := c.builder.CreateAlloca(value.Type(), "hashmap.value") + valueAlloca, valuePtr, valueSize := c.createEntryBlockAlloca(value.Type(), "hashmap.value") c.builder.CreateStore(value, valueAlloca) - valuePtr := c.builder.CreateBitCast(valueAlloca, c.i8ptrType, "hashmap.valueptr") keyType = keyType.Underlying() if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string @@ -50,14 +63,15 @@ func (c *Compiler) emitMapUpdate(keyType types.Type, m, key, value llvm.Value, p c.createRuntimeCall("hashmapStringSet", params, "") } else if hashmapIsBinaryKey(keyType) { // key can be compared with runtime.memequal - keyAlloca := c.builder.CreateAlloca(key.Type(), "hashmap.key") + keyAlloca, keyPtr, keySize := c.createEntryBlockAlloca(key.Type(), "hashmap.key") c.builder.CreateStore(key, keyAlloca) - keyPtr := c.builder.CreateBitCast(keyAlloca, c.i8ptrType, "hashmap.keyptr") params := []llvm.Value{m, keyPtr, valuePtr} c.createRuntimeCall("hashmapBinarySet", params, "") + c.builder.CreateCall(c.getLifetimeEndFunc(), []llvm.Value{keySize, keyPtr}, "") } else { c.addError(pos, "only strings, bools, ints or structs of bools/ints are supported as map keys, but got: "+keyType.String()) } + c.builder.CreateCall(c.getLifetimeEndFunc(), []llvm.Value{valueSize, valuePtr}, "") } func (c *Compiler) emitMapDelete(keyType types.Type, m, key llvm.Value, pos token.Pos) error { @@ -68,11 +82,11 @@ func (c *Compiler) emitMapDelete(keyType types.Type, m, key llvm.Value, pos toke c.createRuntimeCall("hashmapStringDelete", params, "") return nil } else if hashmapIsBinaryKey(keyType) { - keyAlloca := c.builder.CreateAlloca(key.Type(), "hashmap.key") + keyAlloca, keyPtr, keySize := c.createEntryBlockAlloca(key.Type(), "hashmap.key") c.builder.CreateStore(key, keyAlloca) - keyPtr := c.builder.CreateBitCast(keyAlloca, c.i8ptrType, "hashmap.keyptr") params := []llvm.Value{m, keyPtr} c.createRuntimeCall("hashmapBinaryDelete", params, "") + c.builder.CreateCall(c.getLifetimeEndFunc(), []llvm.Value{keySize, keyPtr}, "") return nil } else { return c.makeError(pos, "only strings, bools, ints or structs of bools/ints are supported as map keys, but got: "+keyType.String()) diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go index 6afde9d9d6..2cccebcddc 100644 --- a/src/runtime/hashmap.go +++ b/src/runtime/hashmap.go @@ -60,14 +60,20 @@ func hashmapTopHash(hash uint32) uint8 { } // Create a new hashmap with the given keySize and valueSize. -func hashmapMake(keySize, valueSize uint8) *hashmap { +func hashmapMake(keySize, valueSize uint8, sizeHint uintptr) *hashmap { + numBuckets := sizeHint / 8 + bucketBits := uint8(0) + for numBuckets != 0 { + numBuckets /= 2 + bucketBits++ + } bucketBufSize := unsafe.Sizeof(hashmapBucket{}) + uintptr(keySize)*8 + uintptr(valueSize)*8 - bucket := alloc(bucketBufSize) + buckets := alloc(bucketBufSize * (1 << bucketBits)) return &hashmap{ - buckets: bucket, + buckets: buckets, keySize: keySize, valueSize: valueSize, - bucketBits: 0, + bucketBits: bucketBits, } } @@ -83,13 +89,20 @@ func hashmapLen(m *hashmap) int { // Set a specified key to a given value. Grow the map if necessary. //go:nobounds func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint32, keyEqual func(x, y unsafe.Pointer, n uintptr) bool) { + tophash := hashmapTopHash(hash) + + if m.buckets == nil { + // No bucket was allocated yet, do so now. + m.buckets = unsafe.Pointer(hashmapInsertIntoNewBucket(m, key, value, tophash)) + return + } + numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) bucketSize := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*8 bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) - - tophash := hashmapTopHash(hash) + var lastBucket *hashmapBucket // See whether the key already exists somewhere. var emptySlotKey unsafe.Pointer @@ -98,9 +111,9 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 for bucket != nil { for i := uintptr(0); i < 8; i++ { slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*uintptr(i) - slotKey := unsafe.Pointer(bucketAddr + slotKeyOffset) + slotKey := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotKeyOffset) slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*uintptr(i) - slotValue := unsafe.Pointer(bucketAddr + slotValueOffset) + slotValue := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotValueOffset) if bucket.tophash[i] == 0 && emptySlotKey == nil { // Found an empty slot, store it for if we couldn't find an // existing slot. @@ -109,7 +122,7 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 emptySlotTophash = &bucket.tophash[i] } if bucket.tophash[i] == tophash { - // Could be an existing value that's the same. + // Could be an existing key that's the same. if keyEqual(key, slotKey, uintptr(m.keySize)) { // found same key, replace it memcpy(slotValue, value, uintptr(m.valueSize)) @@ -117,16 +130,37 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 } } } + lastBucket = bucket bucket = bucket.next } - if emptySlotKey != nil { - m.count++ - memcpy(emptySlotKey, key, uintptr(m.keySize)) - memcpy(emptySlotValue, value, uintptr(m.valueSize)) - *emptySlotTophash = tophash + if emptySlotKey == nil { + // Add a new bucket to the bucket chain. + // TODO: rebalance if necessary to avoid O(n) insert and lookup time. + lastBucket.next = (*hashmapBucket)(hashmapInsertIntoNewBucket(m, key, value, tophash)) return } - panic("todo: hashmap: grow bucket") + m.count++ + memcpy(emptySlotKey, key, uintptr(m.keySize)) + memcpy(emptySlotValue, value, uintptr(m.valueSize)) + *emptySlotTophash = tophash +} + +// hashmapInsertIntoNewBucket creates a new bucket, inserts the given key and +// value into the bucket, and returns a pointer to this bucket. +func hashmapInsertIntoNewBucket(m *hashmap, key, value unsafe.Pointer, tophash uint8) *hashmapBucket { + bucketBufSize := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*8 + bucketBuf := alloc(bucketBufSize) + // Insert into the first slot, which is empty as it has just been allocated. + slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + slotKey := unsafe.Pointer(uintptr(bucketBuf) + slotKeyOffset) + slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + slotValue := unsafe.Pointer(uintptr(bucketBuf) + slotValueOffset) + m.count++ + memcpy(slotKey, key, uintptr(m.keySize)) + memcpy(slotValue, value, uintptr(m.valueSize)) + bucket := (*hashmapBucket)(bucketBuf) + bucket.tophash[0] = tophash + return bucket } // Get the value of a specified key, or zero the value if not found. diff --git a/testdata/map.go b/testdata/map.go index fdf615021e..c780c01341 100644 --- a/testdata/map.go +++ b/testdata/map.go @@ -47,6 +47,18 @@ func main() { println(testMapArrayKey[arrKey]) testMapArrayKey[arrKey] = 5555 println(testMapArrayKey[arrKey]) + + // test preallocated map + squares := make(map[int]int, 200) + testBigMap(squares, 100) + println("tested preallocated map") + + // test growing maps + squares = make(map[int]int, 0) + testBigMap(squares, 10) + squares = make(map[int]int, 20) + testBigMap(squares, 40) + println("tested growing of a map") } func readMap(m map[string]int, key string) { @@ -56,7 +68,27 @@ func readMap(m map[string]int, key string) { println(" ", k, "=", v) } } + func lookup(m map[string]int, key string) { value, ok := m[key] println("lookup with comma-ok:", key, value, ok) } + +func testBigMap(squares map[int]int, n int) { + for i := 0; i < n; i++ { + if len(squares) != i { + println("unexpected length:", len(squares), "at i =", i) + } + squares[i] = i*i + for j := 0; j <= i; j++ { + if v, ok := squares[j]; !ok || v != j*j { + if !ok { + println("key not found in squares map:", j) + } else { + println("unexpected value read back from squares map:", j, v) + } + return + } + } + } +} diff --git a/testdata/map.txt b/testdata/map.txt index a41522f782..66636d11f6 100644 --- a/testdata/map.txt +++ b/testdata/map.txt @@ -54,3 +54,5 @@ true false 0 42 4321 5555 +tested preallocated map +tested growing of a map