diff --git a/export_test.go b/export_test.go index 5e6d74668..f259fca0f 100644 --- a/export_test.go +++ b/export_test.go @@ -93,3 +93,11 @@ func GetSlavesAddrByName(ctx context.Context, c *SentinelClient, name string) [] } return parseReplicaAddrs(addrs, false) } + +func (c *Ring) ShardByName(name string) *ringShard { + return c.sharding.ShardByName(name) +} + +func (c *ringSharding) ShardByName(name string) *ringShard { + return c.shards.m[name] +} diff --git a/ring.go b/ring.go index dede1e495..a8e08df08 100644 --- a/ring.go +++ b/ring.go @@ -48,8 +48,8 @@ type RingOptions struct { // Map of name => host:port addresses of ring shards. Addrs map[string]string - // NewClient creates a shard client with provided name and options. - NewClient func(name string, opt *Options) *Client + // NewClient creates a shard client with provided options. + NewClient func(opt *Options) *Client // Frequency of PING commands sent to check shards availability. // Shard is considered down after 3 subsequent failed checks. @@ -95,7 +95,7 @@ type RingOptions struct { func (opt *RingOptions) init() { if opt.NewClient == nil { - opt.NewClient = func(name string, opt *Options) *Client { + opt.NewClient = func(opt *Options) *Client { return NewClient(opt) } } @@ -160,14 +160,16 @@ func (opt *RingOptions) clientOptions() *Options { type ringShard struct { Client *Client down int32 + addr string } -func newRingShard(opt *RingOptions, name, addr string) *ringShard { +func newRingShard(opt *RingOptions, addr string) *ringShard { clopt := opt.clientOptions() clopt.Addr = addr return &ringShard{ - Client: opt.NewClient(name, clopt), + Client: opt.NewClient(clopt), + addr: addr, } } @@ -208,52 +210,102 @@ func (shard *ringShard) Vote(up bool) bool { //------------------------------------------------------------------------------ -type ringShards struct { +type ringSharding struct { opt *RingOptions mu sync.RWMutex + shards *ringShards + closed bool hash ConsistentHash - shards map[string]*ringShard // read only - list []*ringShard // read only numShard int - closed bool } -func newRingShards(opt *RingOptions) *ringShards { - shards := make(map[string]*ringShard, len(opt.Addrs)) - list := make([]*ringShard, 0, len(shards)) - - for name, addr := range opt.Addrs { - shard := newRingShard(opt, name, addr) - shards[name] = shard +type ringShards struct { + m map[string]*ringShard + list []*ringShard +} - list = append(list, shard) +func newRingSharding(opt *RingOptions) *ringSharding { + c := &ringSharding{ + opt: opt, } + c.SetAddrs(opt.Addrs) - c := &ringShards{ - opt: opt, + return c +} + +// SetAddrs replaces the shards in use, such that you can increase and +// decrease number of shards, that you use. It will reuse shards that +// existed before and close the ones that will not be used anymore. +func (c *ringSharding) SetAddrs(addrs map[string]string) { + c.mu.Lock() - shards: shards, - list: list, + if c.closed { + c.mu.Unlock() + return } + + shards, cleanup := newRingShards(c.opt, addrs, c.shards) + c.shards = shards + c.mu.Unlock() + c.rebalance() + cleanup() +} - return c +func newRingShards( + opt *RingOptions, addrs map[string]string, existingShards *ringShards, +) (*ringShards, func()) { + shardMap := make(map[string]*ringShard) // indexed by addr + unusedShards := make(map[string]*ringShard) // indexed by addr + + if existingShards != nil { + for _, shard := range existingShards.list { + addr := shard.Client.opt.Addr + shardMap[addr] = shard + unusedShards[addr] = shard + } + } + + shards := &ringShards{ + m: make(map[string]*ringShard), + } + + for name, addr := range addrs { + if shard, ok := shardMap[addr]; ok { + shards.m[name] = shard + delete(unusedShards, addr) + } else { + shards.m[name] = newRingShard(opt, addr) + } + } + + for _, shard := range shards.m { + shards.list = append(shards.list, shard) + } + + return shards, func() { + for addr, shard := range unusedShards { + if err := shard.Client.Close(); err != nil { + internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) + } + } + } } -func (c *ringShards) List() []*ringShard { +func (c *ringSharding) List() []*ringShard { var list []*ringShard c.mu.RLock() if !c.closed { - list = c.list + list = c.shards.list } c.mu.RUnlock() return list } -func (c *ringShards) Hash(key string) string { +func (c *ringSharding) Hash(key string) string { key = hashtag.Key(key) var hash string @@ -268,7 +320,7 @@ func (c *ringShards) Hash(key string) string { return hash } -func (c *ringShards) GetByKey(key string) (*ringShard, error) { +func (c *ringSharding) GetByKey(key string) (*ringShard, error) { key = hashtag.Key(key) c.mu.RLock() @@ -282,15 +334,14 @@ func (c *ringShards) GetByKey(key string) (*ringShard, error) { return nil, errRingShardsDown } - hash := c.hash.Get(key) - if hash == "" { + shardName := c.hash.Get(key) + if shardName == "" { return nil, errRingShardsDown } - - return c.shards[hash], nil + return c.shards.m[shardName], nil } -func (c *ringShards) GetByName(shardName string) (*ringShard, error) { +func (c *ringSharding) GetByName(shardName string) (*ringShard, error) { if shardName == "" { return c.Random() } @@ -298,15 +349,15 @@ func (c *ringShards) GetByName(shardName string) (*ringShard, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.shards[shardName], nil + return c.shards.m[shardName], nil } -func (c *ringShards) Random() (*ringShard, error) { +func (c *ringSharding) Random() (*ringShard, error) { return c.GetByKey(strconv.Itoa(rand.Int())) } // Heartbeat monitors state of each shard in the ring. -func (c *ringShards) Heartbeat(ctx context.Context, frequency time.Duration) { +func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() @@ -334,14 +385,18 @@ func (c *ringShards) Heartbeat(ctx context.Context, frequency time.Duration) { } // rebalance removes dead shards from the Ring. -func (c *ringShards) rebalance() { +func (c *ringSharding) rebalance() { c.mu.RLock() shards := c.shards c.mu.RUnlock() - liveShards := make([]string, 0, len(shards)) + if shards == nil { + return + } + + liveShards := make([]string, 0, len(shards.m)) - for name, shard := range shards { + for name, shard := range shards.m { if shard.IsUp() { liveShards = append(liveShards, name) } @@ -350,19 +405,21 @@ func (c *ringShards) rebalance() { hash := c.opt.NewConsistentHash(liveShards) c.mu.Lock() - c.hash = hash - c.numShard = len(liveShards) + if !c.closed { + c.hash = hash + c.numShard = len(liveShards) + } c.mu.Unlock() } -func (c *ringShards) Len() int { +func (c *ringSharding) Len() int { c.mu.RLock() defer c.mu.RUnlock() return c.numShard } -func (c *ringShards) Close() error { +func (c *ringSharding) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -372,7 +429,8 @@ func (c *ringShards) Close() error { c.closed = true var firstErr error - for _, shard := range c.shards { + + for _, shard := range c.shards.list { if err := shard.Client.Close(); err != nil && firstErr == nil { firstErr = err } @@ -381,20 +439,12 @@ func (c *ringShards) Close() error { c.hash = nil c.shards = nil c.numShard = 0 - c.list = nil return firstErr } //------------------------------------------------------------------------------ -type ring struct { - opt *RingOptions - shards *ringShards - cmdsInfoCache *cmdsInfoCache //nolint:structcheck - heartbeatCancelFn context.CancelFunc -} - // Ring is a Redis client that uses consistent hashing to distribute // keys across multiple Redis servers (shards). It's safe for // concurrent use by multiple goroutines. @@ -410,7 +460,11 @@ type ring struct { // and can tolerate losing data when one of the servers dies. // Otherwise you should use Redis Cluster. type Ring struct { - *ring + opt *RingOptions + sharding *ringSharding + cmdsInfoCache *cmdsInfoCache + heartbeatCancelFn context.CancelFunc + cmdable hooks } @@ -421,21 +475,23 @@ func NewRing(opt *RingOptions) *Ring { hbCtx, hbCancel := context.WithCancel(context.Background()) ring := Ring{ - ring: &ring{ - opt: opt, - shards: newRingShards(opt), - heartbeatCancelFn: hbCancel, - }, + opt: opt, + sharding: newRingSharding(opt), + heartbeatCancelFn: hbCancel, } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - go ring.shards.Heartbeat(hbCtx, opt.HeartbeatFrequency) + go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) return &ring } +func (c *Ring) SetAddrs(addrs map[string]string) { + c.sharding.SetAddrs(addrs) +} + // Do creates a Cmd from the args and processes the cmd. func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(ctx, args...) @@ -458,7 +514,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration { // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { - shards := c.shards.List() + shards := c.sharding.List() var acc PoolStats for _, shard := range shards { s := shard.Client.connPool.Stats() @@ -473,7 +529,7 @@ func (c *Ring) PoolStats() *PoolStats { // Len returns the current number of shards in the ring. func (c *Ring) Len() int { - return c.shards.Len() + return c.sharding.Len() } // Subscribe subscribes the client to the specified channels. @@ -482,7 +538,7 @@ func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shards.GetByKey(channels[0]) + shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -496,7 +552,7 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shards.GetByKey(channels[0]) + shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -510,7 +566,7 @@ func (c *Ring) ForEachShard( ctx context.Context, fn func(ctx context.Context, client *Client) error, ) error { - shards := c.shards.List() + shards := c.sharding.List() var wg sync.WaitGroup errCh := make(chan error, 1) for _, shard := range shards { @@ -541,7 +597,7 @@ func (c *Ring) ForEachShard( } func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { - shards := c.shards.List() + shards := c.sharding.List() var firstErr error for _, shard := range shards { cmdsInfo, err := shard.Client.Command(ctx).Result() @@ -574,10 +630,10 @@ func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { cmdInfo := c.cmdInfo(ctx, cmd.Name()) pos := cmdFirstKeyPos(cmd, cmdInfo) if pos == 0 { - return c.shards.Random() + return c.sharding.Random() } firstKey := cmd.stringArg(pos) - return c.shards.GetByKey(firstKey) + return c.sharding.GetByKey(firstKey) } func (c *Ring) process(ctx context.Context, cmd Cmder) error { @@ -646,7 +702,7 @@ func (c *Ring) generalProcessPipeline( cmdInfo := c.cmdInfo(ctx, cmd.Name()) hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) if hash != "" { - hash = c.shards.Hash(hash) + hash = c.sharding.Hash(hash) } cmdsMap[hash] = append(cmdsMap[hash], cmd) } @@ -669,7 +725,7 @@ func (c *Ring) processShardPipeline( ctx context.Context, hash string, cmds []Cmder, tx bool, ) error { // TODO: retry? - shard, err := c.shards.GetByName(hash) + shard, err := c.sharding.GetByName(hash) if err != nil { setCmdsErr(cmds, err) return err @@ -689,7 +745,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er var shards []*ringShard for _, key := range keys { if key != "" { - shard, err := c.shards.GetByKey(hashtag.Key(key)) + shard, err := c.sharding.GetByKey(hashtag.Key(key)) if err != nil { return err } @@ -721,5 +777,5 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er func (c *Ring) Close() error { c.heartbeatCancelFn() - return c.shards.Close() + return c.sharding.Close() } diff --git a/ring_test.go b/ring_test.go index 1a6ec84b9..e0f58b237 100644 --- a/ring_test.go +++ b/ring_test.go @@ -113,6 +113,71 @@ var _ = Describe("Redis Ring", func() { Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100")) }) + Describe("[new] dynamic setting ring shards", func() { + It("downscale shard and check reuse shard, upscale shard and check reuse", func() { + Expect(ring.Len(), 2) + + wantShard := ring.ShardByName("ringShardOne") + ring.SetAddrs(map[string]string{ + "ringShardOne": ":" + ringShard1Port, + }) + Expect(ring.Len(), 1) + gotShard := ring.ShardByName("ringShardOne") + Expect(gotShard).To(Equal(wantShard)) + + ring.SetAddrs(map[string]string{ + "ringShardOne": ":" + ringShard1Port, + "ringShardTwo": ":" + ringShard2Port, + }) + Expect(ring.Len(), 2) + gotShard = ring.ShardByName("ringShardOne") + Expect(gotShard).To(Equal(wantShard)) + }) + + It("uses 3 shards after setting it to 3 shards", func() { + Expect(ring.Len(), 2) + + // Start ringShard3. + var err error + ringShard3, err = startRedis(ringShard3Port) + Expect(err).NotTo(HaveOccurred()) + defer ringShard3.Close() + + shardName1 := "ringShardOne" + shardAddr1 := ":" + ringShard1Port + wantShard1 := ring.ShardByName(shardName1) + shardName2 := "ringShardTwo" + shardAddr2 := ":" + ringShard2Port + wantShard2 := ring.ShardByName(shardName2) + shardName3 := "ringShardThree" + shardAddr3 := ":" + ringShard3Port + + ring.SetAddrs(map[string]string{ + shardName1: shardAddr1, + shardName2: shardAddr2, + shardName3: shardAddr3, + }) + Expect(ring.Len(), 3) + gotShard1 := ring.ShardByName(shardName1) + gotShard2 := ring.ShardByName(shardName2) + gotShard3 := ring.ShardByName(shardName3) + Expect(gotShard1).To(Equal(wantShard1)) + Expect(gotShard2).To(Equal(wantShard2)) + Expect(gotShard3).ToNot(BeNil()) + + ring.SetAddrs(map[string]string{ + shardName1: shardAddr1, + shardName2: shardAddr2, + }) + Expect(ring.Len(), 2) + gotShard1 = ring.ShardByName(shardName1) + gotShard2 = ring.ShardByName(shardName2) + gotShard3 = ring.ShardByName(shardName3) + Expect(gotShard1).To(Equal(wantShard1)) + Expect(gotShard2).To(Equal(wantShard2)) + Expect(gotShard3).To(BeNil()) + }) + }) Describe("pipeline", func() { It("doesn't panic closed ring, returns error", func() { pipe := ring.Pipeline() @@ -190,7 +255,7 @@ var _ = Describe("Redis Ring", func() { Describe("new client callback", func() { It("can be initialized with a new client callback", func() { opts := redisRingOptions() - opts.NewClient = func(name string, opt *redis.Options) *redis.Client { + opts.NewClient = func(opt *redis.Options) *redis.Client { opt.Username = "username1" opt.Password = "password1" return redis.NewClient(opt)