diff --git a/go.mod b/go.mod index 9bd856a..9d0d0f6 100644 --- a/go.mod +++ b/go.mod @@ -2,20 +2,19 @@ module github.com/coder/hnsw go 1.21.4 -require github.com/stretchr/testify v1.9.0 - -require github.com/google/renameio v1.0.1 - require ( - github.com/chewxy/math32 v1.10.1 // indirect - github.com/viterin/partial v1.1.0 // indirect - github.com/viterin/vek v0.4.2 // indirect - golang.org/x/sys v0.11.0 // indirect + github.com/google/renameio v1.0.1 + github.com/stretchr/testify v1.9.0 ) require ( + github.com/chewxy/math32 v1.10.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/viterin/partial v1.1.0 // indirect + github.com/viterin/vek v0.4.2 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/sys v0.11.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f571897..f74eb2b 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,14 @@ github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU= github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/renameio v1.0.1 h1:Lh/jXZmvZxb0BBeSY5VKEfidcbcbenKjZFzM/q0fSeU= github.com/google/renameio v1.0.1/go.mod h1:t/HQoYBZSsWSNK35C6CO/TpPLDVWvxOHboWUAweKUpk= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= @@ -16,7 +21,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/graph.go b/graph.go index 88fdeba..5dba454 100644 --- a/graph.go +++ b/graph.go @@ -402,13 +402,35 @@ func (g *Graph[K]) Add(nodes ...Node[K]) { // Invariant check: the node should have been added to the graph. if g.Len() != preLen+1 { - panic("node not added") + if len(g.layers) > 0 && g.layers[len(g.layers)-1].entry() == nil { + g.layers = g.layers[:len(g.layers)-1] + } } } } // Search finds the k nearest neighbors from the target node. func (h *Graph[K]) Search(near Vector, k int) []Node[K] { + sr := h.search(near, k) + out := make([]Node[K], len(sr)) + for i, node := range sr { + out[i] = node.Node + } + return out +} + +// SearchWithDistance finds the k nearest neighbors from the target node +// and returns the distance. +func (h *Graph[K]) SearchWithDistance(near Vector, k int) []SearchResult[K] { + return h.search(near, k) +} + +type SearchResult[T cmp.Ordered] struct { + Node[T] + Distance float32 +} + +func (h *Graph[K]) search(near Vector, k int) []SearchResult[K] { h.assertDims(near) if len(h.layers) == 0 { return nil @@ -434,10 +456,13 @@ func (h *Graph[K]) Search(near Vector, k int) []Node[K] { } nodes := searchPoint.search(k, efSearch, near, h.Distance) - out := make([]Node[K], 0, len(nodes)) + out := make([]SearchResult[K], 0, len(nodes)) for _, node := range nodes { - out = append(out, node.node.Node) + out = append(out, SearchResult[K]{ + Node: node.node.Node, + Distance: node.dist, + }) } return out @@ -462,17 +487,33 @@ func (h *Graph[K]) Delete(key K) bool { return false } + var deleteLayer = map[int]struct{}{} var deleted bool - for _, layer := range h.layers { + for i, layer := range h.layers { node, ok := layer.nodes[key] if !ok { continue } delete(layer.nodes, key) + if len(layer.nodes) == 0 { + deleteLayer[i] = struct{}{} + } node.isolate(h.M) deleted = true } + if len(deleteLayer) > 0 { + var newLayers = make([]*layer[K], 0, len(h.layers)-len(deleteLayer)) + for i, layer := range h.layers { + if _, ok := deleteLayer[i]; ok { + continue + } + newLayers = append(newLayers, layer) + } + + h.layers = newLayers + } + return deleted } diff --git a/graph_test.go b/graph_test.go index d2a9cab..df795e6 100644 --- a/graph_test.go +++ b/graph_test.go @@ -248,3 +248,14 @@ func TestGraph_DefaultCosine(t *testing.T) { neighbors, ) } + +func TestGraph_RemoveAllNodes(t *testing.T) { + var vec = []float32{1} + + for i := 0; i < 10; i++ { + g := NewGraph[int]() + g.Add(MakeNode(1, vec)) + g.Delete(1) + g.Add(MakeNode(1, vec)) + } +}