diff --git a/commonspace/sync/requestmanager.go b/commonspace/sync/requestmanager.go
index a5006e79..1ff563ef 100644
--- a/commonspace/sync/requestmanager.go
+++ b/commonspace/sync/requestmanager.go
@@ -31,14 +31,14 @@ type StreamResponse struct {
 }
 
 type requestManager struct {
-	requestPool   syncqueues.RequestPool
+	requestPool   syncqueues.ActionPool
 	incomingGuard *syncqueues.Guard
 	limit         *syncqueues.Limit
 	handler       syncdeps.SyncHandler
 	metric        syncdeps.QueueSizeUpdater
 }
 
-func NewRequestManager(handler syncdeps.SyncHandler, metric syncdeps.QueueSizeUpdater, requestPool syncqueues.RequestPool, limit *syncqueues.Limit) RequestManager {
+func NewRequestManager(handler syncdeps.SyncHandler, metric syncdeps.QueueSizeUpdater, requestPool syncqueues.ActionPool, limit *syncqueues.Limit) RequestManager {
 	return &requestManager{
 		requestPool: requestPool,
 		limit:       limit,
@@ -75,7 +75,7 @@ func (r *requestManager) SendRequest(ctx context.Context, rq syncdeps.Request, c
 func (r *requestManager) QueueRequest(rq syncdeps.Request) error {
 	size := rq.MsgSize()
 	r.metric.UpdateQueueSize(size, syncdeps.MsgTypeOutgoingRequest, true)
-	return r.requestPool.QueueRequestAction(rq.PeerId(), rq.ObjectId(), func(ctx context.Context) {
+	r.requestPool.Add(rq.PeerId(), rq.ObjectId(), func(ctx context.Context) {
 		err := r.handler.ApplyRequest(ctx, rq, r)
 		if err != nil {
 			log.Error("failed to apply request", zap.Error(err), zap.String("limit stats", r.limit.Stats(rq.PeerId())))
@@ -83,6 +83,7 @@ func (r *requestManager) QueueRequest(rq syncdeps.Request) error {
 	}, func() {
 		r.metric.UpdateQueueSize(size, syncdeps.MsgTypeOutgoingRequest, false)
 	})
+	return nil
 }
 
 func (r *requestManager) HandleDeprecatedObjectSync(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
diff --git a/commonspace/sync/sync.go b/commonspace/sync/sync.go
index 7bbe3c28..bda37c1c 100644
--- a/commonspace/sync/sync.go
+++ b/commonspace/sync/sync.go
@@ -78,7 +78,7 @@ func (s *syncService) Init(a *app.App) (err error) {
 	s.streamPool = a.MustComponent(streampool.CName).(streampool.StreamPool)
 	s.commonMetric, _ = a.Component(metric.CName).(metric.Metric)
 	syncQueues := a.MustComponent(syncqueues.CName).(syncqueues.SyncQueues)
-	s.manager = NewRequestManager(s.handler, s.metric, syncQueues.RequestPool(s.spaceId), syncQueues.Limit(s.spaceId))
+	s.manager = NewRequestManager(s.handler, s.metric, syncQueues.ActionPool(s.spaceId), syncQueues.Limit(s.spaceId))
 	s.ctx, s.cancel = context.WithCancel(context.Background())
 	return nil
 }
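For context on the change above: `QueueRequestAction` used to return an error, while the new `Add` is fire-and-forget, so `QueueRequest` now returns `nil` unconditionally, purely to keep the `RequestManager` interface unchanged. A minimal sketch of the new calling convention, assuming import paths from the module layout (package and function names here are hypothetical):

```go
package example

import (
	"context"

	"github.com/anyproto/any-sync/commonspace/sync/syncdeps"
	"github.com/anyproto/any-sync/util/syncqueues"
)

// queueExample mirrors requestManager.QueueRequest after this change:
// Add never returns an error, so failures surface inside the queued
// action (where they are logged), not at enqueue time.
func queueExample(pool syncqueues.ActionPool, rq syncdeps.Request) error {
	size := rq.MsgSize()
	pool.Add(rq.PeerId(), rq.ObjectId(), func(ctx context.Context) {
		// runs later on one of the peer queue's workers
	}, func() {
		// remove callback: fires if the entry is replaced or dropped,
		// letting the caller roll back bookkeeping (e.g. size bytes of metrics)
		_ = size
	})
	return nil // kept only to satisfy the RequestManager interface
}
```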
diff --git a/util/syncqueues/tryaddqueue.go b/util/syncqueues/replaceablequeue.go
similarity index 63%
rename from util/syncqueues/tryaddqueue.go
rename to util/syncqueues/replaceablequeue.go
index c8332421..5d748fa4 100644
--- a/util/syncqueues/tryaddqueue.go
+++ b/util/syncqueues/replaceablequeue.go
@@ -16,9 +16,9 @@ type entry struct {
 	cnt      uint64
 }
 
-func newTryAddQueue(workers, maxSize int) *tryAddQueue {
+func newReplaceableQueue(workers, maxSize int) *replaceableQueue {
 	ctx, cancel := context.WithCancel(context.Background())
-	ss := &tryAddQueue{
+	ss := &replaceableQueue{
 		ctx:     ctx,
 		cancel:  cancel,
 		workers: workers,
@@ -28,7 +28,7 @@ func newTryAddQueue(workers, maxSize int) *tryAddQueue {
 	return ss
 }
 
-type tryAddQueue struct {
+type replaceableQueue struct {
 	ctx     context.Context
 	cancel  context.CancelFunc
 	workers int
@@ -39,7 +39,7 @@ type tryAddQueue struct {
 	mx sync.Mutex
 }
 
-func (rp *tryAddQueue) Replace(id string, call, remove func()) {
+func (rp *replaceableQueue) Replace(id string, call, remove func()) {
 	curCnt := rp.cnt.Load()
 	rp.cnt.Add(1)
 	rp.mx.Lock()
@@ -74,47 +74,13 @@ func (rp *tryAddQueue) Replace(id string, call, remove func()) {
 	}
 }
 
-func (rp *tryAddQueue) TryAdd(id string, call, remove func()) bool {
-	curCnt := rp.cnt.Load()
-	rp.cnt.Add(1)
-	rp.mx.Lock()
-	if _, ok := rp.entries[id]; ok {
-		rp.mx.Unlock()
-		if remove != nil {
-			remove()
-		}
-		return false
-	}
-	ent := entry{
-		call:     call,
-		onRemove: remove,
-		cnt:      curCnt,
-	}
-	rp.entries[id] = ent
-	rp.mx.Unlock()
-	err := rp.batch.TryAdd(id)
-	if err != nil {
-		rp.mx.Lock()
-		curEntry := rp.entries[id]
-		if curEntry.cnt == curCnt {
-			delete(rp.entries, id)
-		}
-		rp.mx.Unlock()
-		if ent.onRemove != nil {
-			ent.onRemove()
-		}
-		return false
-	}
-	return true
-}
-
-func (rp *tryAddQueue) Run() {
+func (rp *replaceableQueue) Run() {
 	for i := 0; i < rp.workers; i++ {
-		go rp.sendLoop()
+		go rp.callLoop()
 	}
 }
 
-func (rp *tryAddQueue) sendLoop() {
+func (rp *replaceableQueue) callLoop() {
 	for {
 		id, err := rp.batch.WaitOne(rp.ctx)
 		if err != nil {
@@ -135,13 +101,13 @@ func (rp *tryAddQueue) sendLoop() {
 	}
 }
 
-func (rp *tryAddQueue) ShouldClose(curTime time.Time, timeout time.Duration) bool {
+func (rp *replaceableQueue) ShouldClose(curTime time.Time, timeout time.Duration) bool {
 	rp.mx.Lock()
 	defer rp.mx.Unlock()
 	return curTime.Sub(rp.lastServed) > timeout && rp.batch.Len() == 0
 }
 
-func (rp *tryAddQueue) Close() (err error) {
+func (rp *replaceableQueue) Close() (err error) {
 	rp.cancel()
 	return rp.batch.Close()
 }
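With `TryAdd` removed, the renamed `replaceableQueue` exposes only `Replace`, since the pool always wants last-write-wins per id. A sketch of those semantics, written in the same package because the type is unexported (function name and worker/capacity values are illustrative):

```go
package syncqueues

// replaceSketch shows the per-id last-write-wins behaviour of Replace:
// while an entry with the same id is still pending, a new Replace swaps
// it out and runs its remove callback; only the latest call reaches a worker.
func replaceSketch() {
	q := newReplaceableQueue(1, 10) // 1 worker, batch capacity 10
	q.Run()
	defer q.Close()
	q.Replace("objectId",
		func() { /* may never run: replaced below before a worker picks it up */ },
		func() { /* runs if this entry is replaced or dropped */ })
	q.Replace("objectId",
		func() { /* the surviving call for "objectId" */ },
		func() {})
}
```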
diff --git a/util/syncqueues/requestpool.go b/util/syncqueues/requestpool.go
index dc905632..e80a204a 100644
--- a/util/syncqueues/requestpool.go
+++ b/util/syncqueues/requestpool.go
@@ -10,41 +10,39 @@ import (
 	"github.com/anyproto/any-sync/util/periodicsync"
 )
 
-type RequestPool interface {
-	TryTake(peerId, objectId string) bool
-	Release(peerId, objectId string)
+type ActionPool interface {
 	Run()
-	QueueRequestAction(peerId, objectId string, action func(ctx context.Context), remove func()) (err error)
+	Add(peerId, objectId string, action func(ctx context.Context), remove func())
 	Close()
 }
 
-type requestPool struct {
+type actionPool struct {
 	mu           sync.Mutex
 	peerGuard    *Guard
-	pools        map[string]*tryAddQueue
+	queues       map[string]*replaceableQueue
 	periodicLoop periodicsync.PeriodicSync
 	closePeriod  time.Duration
 	gcPeriod     time.Duration
 	ctx          context.Context
 	cancel       context.CancelFunc
-	openFunc     func(peerId string) *tryAddQueue
+	openFunc     func(peerId string) *replaceableQueue
 	isClosed     bool
 }
 
-func NewRequestPool(closePeriod, gcPeriod time.Duration, openFunc func(peerId string) *tryAddQueue) RequestPool {
+func NewActionPool(closePeriod, gcPeriod time.Duration, openFunc func(peerId string) *replaceableQueue) ActionPool {
 	ctx, cancel := context.WithCancel(context.Background())
-	return &requestPool{
+	return &actionPool{
 		ctx:         ctx,
 		cancel:      cancel,
 		closePeriod: closePeriod,
 		gcPeriod:    gcPeriod,
 		openFunc:    openFunc,
-		pools:       make(map[string]*tryAddQueue),
+		queues:      make(map[string]*replaceableQueue),
 		peerGuard:   NewGuard(),
 	}
 }
 
-func (rp *requestPool) TryTake(peerId, objectId string) bool {
+func (rp *actionPool) tryTake(peerId, objectId string) bool {
 	rp.mu.Lock()
 	defer rp.mu.Unlock()
 	if rp.isClosed {
@@ -54,68 +52,68 @@ func (rp *requestPool) TryTake(peerId, objectId string) bool {
 	return rp.peerGuard.TryTake(fullId(peerId, objectId))
 }
 
-func (rp *requestPool) Release(peerId, objectId string) {
+func (rp *actionPool) release(peerId, objectId string) {
 	rp.peerGuard.Release(fullId(peerId, objectId))
 }
 
-func (rp *requestPool) Run() {
+func (rp *actionPool) Run() {
 	rp.periodicLoop = periodicsync.NewPeriodicSyncDuration(rp.gcPeriod, time.Minute, rp.gc, log)
 	rp.periodicLoop.Run()
 }
 
-func (rp *requestPool) gc(ctx context.Context) error {
+func (rp *actionPool) gc(ctx context.Context) error {
 	rp.mu.Lock()
-	var poolsToClose []*tryAddQueue
+	var queuesToClose []*replaceableQueue
 	tm := time.Now()
-	for id, pool := range rp.pools {
-		if pool.ShouldClose(tm, rp.closePeriod) {
-			delete(rp.pools, id)
-			log.Debug("closing pool", zap.String("peerId", id))
-			poolsToClose = append(poolsToClose, pool)
+	for id, queue := range rp.queues {
+		if queue.ShouldClose(tm, rp.closePeriod) {
+			delete(rp.queues, id)
+			log.Debug("closing queue", zap.String("peerId", id))
+			queuesToClose = append(queuesToClose, queue)
 		}
 	}
 	rp.mu.Unlock()
-	for _, pool := range poolsToClose {
-		_ = pool.Close()
+	for _, queue := range queuesToClose {
+		_ = queue.Close()
 	}
 	return nil
}
 
-func (rp *requestPool) QueueRequestAction(peerId, objectId string, action func(ctx context.Context), remove func()) (err error) {
+func (rp *actionPool) Add(peerId, objectId string, action func(ctx context.Context), remove func()) {
 	rp.mu.Lock()
 	if rp.isClosed {
 		rp.mu.Unlock()
-		return nil
+		return
 	}
 	var (
-		pool   *tryAddQueue
+		queue  *replaceableQueue
 		exists bool
 	)
-	pool, exists = rp.pools[peerId]
 	if !exists {
-		pool = rp.openFunc(peerId)
-		rp.pools[peerId] = pool
-		pool.Run()
+		queue = rp.openFunc(peerId)
+		rp.queues[peerId] = queue
+		queue.Run()
 	}
 	rp.mu.Unlock()
 	var wrappedAction func()
 	wrappedAction = func() {
-		if !rp.TryTake(peerId, objectId) {
+		// prevents two actions for the same peer and object from running at the same time
+		if !rp.tryTake(peerId, objectId) {
 			return
 		}
 		action(rp.ctx)
-		rp.Release(peerId, objectId)
+		rp.release(peerId, objectId)
 	}
-	pool.Replace(objectId, wrappedAction, remove)
-	return nil
+	queue.Replace(objectId, wrappedAction, remove)
 }
 
-func (rp *requestPool) Close() {
+func (rp *actionPool) Close() {
 	rp.periodicLoop.Close()
 	rp.mu.Lock()
 	defer rp.mu.Unlock()
 	rp.isClosed = true
-	for _, pool := range rp.pools {
-		_ = pool.Close()
+	for _, queue := range rp.queues {
+		_ = queue.Close()
 	}
 }
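Note that `tryTake`/`release` are now unexported: their only consumer is the wrapper that `Add` builds around each action. Spelled out as a standalone sketch in the same package (`Guard` and `fullId` are the existing helpers used above; the function itself is hypothetical):

```go
package syncqueues

import "context"

// guardSketch spells out the wrapper Add builds around each queued action:
// a Guard keyed by peerId+objectId admits at most one running action per
// (peer, object) pair; a concurrent duplicate is skipped rather than queued.
func guardSketch(g *Guard, peerId, objectId string, action func(ctx context.Context)) {
	if !g.TryTake(fullId(peerId, objectId)) {
		return // an action for this pair is already in flight
	}
	defer g.Release(fullId(peerId, objectId))
	action(context.Background())
}
```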
diff --git a/util/syncqueues/requestpool_test.go b/util/syncqueues/requestpool_test.go
new file mode 100644
index 00000000..7762aed4
--- /dev/null
+++ b/util/syncqueues/requestpool_test.go
@@ -0,0 +1,146 @@
+package syncqueues
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"go.uber.org/atomic"
+)
+
+func TestRequestPool(t *testing.T) {
+	t.Run("parallel, different peer, same object", func(t *testing.T) {
+		rp := NewActionPool(time.Minute, time.Minute, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(1, 1)
+		})
+		rp.Run()
+		// the wait channel makes sure that one blocked action does not prevent the other from being called
+		wait := make(chan struct{})
+		wg := &sync.WaitGroup{}
+		wg.Add(2)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			wg.Done()
+			<-wait
+		}, func() {})
+		rp.Add("peerId1", "objectId", func(ctx context.Context) {
+			wg.Done()
+			<-wait
+		}, func() {})
+		wg.Wait()
+		rp.Close()
+	})
+	t.Run("parallel, same peer, different object", func(t *testing.T) {
+		rp := NewActionPool(time.Minute, time.Minute, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(2, 2)
+		})
+		rp.Run()
+		// the wait channel makes sure that one blocked action does not prevent the other from being called
+		wait := make(chan struct{})
+		wg := &sync.WaitGroup{}
+		wg.Add(2)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			wg.Done()
+			<-wait
+		}, func() {})
+		rp.Add("peerId", "objectId1", func(ctx context.Context) {
+			wg.Done()
+			<-wait
+		}, func() {})
+		wg.Wait()
+		rp.Close()
+	})
+	t.Run("parallel, same peer, same object", func(t *testing.T) {
+		rp := NewActionPool(time.Minute, time.Minute, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(2, 2)
+		})
+		rp.Run()
+		// check that the second action is not called in parallel while the first action has not yet finished
+		wait := make(chan struct{})
+		cnt := atomic.NewBool(false)
+		wg := &sync.WaitGroup{}
+		wg.Add(1)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			cnt.Store(true)
+			wg.Done()
+			<-wait
+		}, func() {})
+		time.Sleep(100 * time.Millisecond)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			require.Fail(t, "should not be called")
+			wg.Done()
+			<-wait
+		}, func() {})
+		wg.Wait()
+		time.Sleep(100 * time.Millisecond)
+		require.True(t, cnt.Load())
+		rp.Close()
+	})
+	t.Run("parallel, same peer, different object, replace", func(t *testing.T) {
+		rp := NewActionPool(time.Minute, time.Minute, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(1, 3)
+		})
+		rp.Run()
+		// we expect the second action to be replaced by the third one
+		wait := make(chan struct{})
+		cnt := atomic.NewBool(false)
+		wg := &sync.WaitGroup{}
+		wg.Add(1)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			<-wait
+		}, func() {})
+		rp.Add("peerId", "objectId1", func(ctx context.Context) {
+			require.Fail(t, "should not be called")
+		}, func() {})
+		rp.Add("peerId", "objectId1", func(ctx context.Context) {
+			cnt.Store(true)
+			wg.Done()
+		}, func() {})
+		close(wait)
+		wg.Wait()
+		time.Sleep(100 * time.Millisecond)
+		require.True(t, cnt.Load())
+		rp.Close()
+	})
+	t.Run("parallel, same peer, different object, try add failed", func(t *testing.T) {
+		rp := NewActionPool(time.Minute, time.Minute, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(1, 1)
+		})
+		rp.Run()
+		// we expect the add to fail on the full queue and the remove callback to be called
+		wait := make(chan struct{})
+		wg := &sync.WaitGroup{}
+		wg.Add(1)
+		rp.Add("peerId", "objectId", func(ctx context.Context) {
+			<-wait
+		}, func() {})
+		rp.Add("peerId", "objectId1", func(ctx context.Context) {
+			require.Fail(t, "should not be called")
+		}, func() {
+			wg.Done()
+		})
+		close(wait)
+		wg.Wait()
+		rp.Close()
+	})
+	t.Run("gc", func(t *testing.T) {
+		rp := NewActionPool(time.Millisecond*20, time.Millisecond*20, func(peerId string) *replaceableQueue {
+			return newReplaceableQueue(2, 2)
+		})
+		rp.Run()
+		wg := &sync.WaitGroup{}
+		wg.Add(2)
+		rp.Add("peerId1", "objectId1", func(ctx context.Context) {
+			wg.Done()
+		}, func() {})
+		rp.Add("peerId2", "objectId2", func(ctx context.Context) {
+			wg.Done()
+		}, func() {})
+		wg.Wait()
+		time.Sleep(200 * time.Millisecond)
+		require.Empty(t, rp.(*actionPool).queues)
+		rp.Close()
+	})
+}
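The gc subtest above hinges on the two durations passed to `NewActionPool`. To make their roles explicit, a lifecycle sketch in the same package (function name and values are illustrative, not the production configuration):

```go
package syncqueues

import "time"

// gcSketch: per-peer queues are closed by gc once they have been idle
// for closePeriod; the gc loop itself wakes up every gcPeriod.
func gcSketch() ActionPool {
	pool := NewActionPool(
		30*time.Second, // closePeriod: idle time before a peer queue is closed
		time.Minute,    // gcPeriod: how often gc scans for idle queues
		func(peerId string) *replaceableQueue {
			return newReplaceableQueue(10, 100) // 10 workers, capacity 100
		},
	)
	pool.Run()
	return pool
}
```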
diff --git a/util/syncqueues/sync.go b/util/syncqueues/sync.go
index 1355b8bb..1f1cbe99 100644
--- a/util/syncqueues/sync.go
+++ b/util/syncqueues/sync.go
@@ -18,7 +18,7 @@ var log = logger.NewNamed(CName)
 
 type SyncQueues interface {
 	app.ComponentRunnable
-	RequestPool(spaceId string) RequestPool
+	ActionPool(spaceId string) ActionPool
 	Limit(spaceId string) *Limit
 }
 
@@ -28,7 +28,7 @@ func New() SyncQueues {
 
 type syncQueues struct {
 	limit          *Limit
-	rp             RequestPool
+	pool           ActionPool
 	nodeConf       nodeconf.Service
 	accountService accountService.Service
 }
 
@@ -47,12 +47,12 @@ func (g *syncQueues) Init(a *app.App) (err error) {
 			iAmResponsible = true
 		}
 	}
-	g.rp = NewRequestPool(time.Second*30, time.Minute, func(peerId string) *tryAddQueue {
+	g.pool = NewActionPool(time.Second*30, time.Minute, func(peerId string) *replaceableQueue {
 		// increase limits between responsible nodes
 		if slices.Contains(nodeIds, peerId) && iAmResponsible {
-			return newTryAddQueue(30, 400)
+			return newReplaceableQueue(30, 400)
 		} else {
-			return newTryAddQueue(10, 100)
+			return newReplaceableQueue(10, 100)
 		}
 	})
 	g.limit = NewLimit([]int{20, 15, 10, 5}, []int{200, 400, 800}, nodeIds, 100)
@@ -60,17 +60,17 @@ func (g *syncQueues) Init(a *app.App) (err error) {
 }
 
 func (g *syncQueues) Run(ctx context.Context) (err error) {
-	g.rp.Run()
+	g.pool.Run()
 	return
 }
 
 func (g *syncQueues) Close(ctx context.Context) (err error) {
-	g.rp.Close()
+	g.pool.Close()
 	return
 }
 
-func (g *syncQueues) RequestPool(spaceId string) RequestPool {
-	return g.rp
+func (g *syncQueues) ActionPool(spaceId string) ActionPool {
+	return g.pool
 }
 
 func (g *syncQueues) Limit(spaceId string) *Limit {
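Finally, how a consumer gets hold of the pool, mirroring `syncService.Init` in the first file of this diff. Note that `spaceId` is currently ignored: every space shares one pool and one limit. A minimal sketch, assuming the app container import path (package and function names are hypothetical):

```go
package example

import (
	"github.com/anyproto/any-sync/app"
	"github.com/anyproto/any-sync/util/syncqueues"
)

// wireSketch resolves the shared SyncQueues component and asks for the
// per-space pool; today this returns the same ActionPool for any spaceId.
func wireSketch(a *app.App, spaceId string) syncqueues.ActionPool {
	sq := a.MustComponent(syncqueues.CName).(syncqueues.SyncQueues)
	return sq.ActionPool(spaceId)
}
```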