
sync fixes + fix tests

This commit is contained in:
Sergey Cherepanov 2023-01-20 14:55:45 +03:00 committed by Mikhail Iudin
parent 34848254be
commit 5fb6ee5a7b
No known key found for this signature in database
GPG key ID: FAAAA8BAABDFF1C0
16 changed files with 197 additions and 392 deletions

View file

@ -1,37 +0,0 @@
package synctree
import (
"context"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anytypeio/any-sync/commonspace/objectsync"
)
type queuedClient struct {
SyncClient
queue objectsync.ActionQueue
}
func newQueuedClient(client SyncClient, queue objectsync.ActionQueue) SyncClient {
return &queuedClient{
SyncClient: client,
queue: queue,
}
}
func (q *queuedClient) Broadcast(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) {
return q.queue.Send(func() error {
return q.SyncClient.Broadcast(ctx, message)
})
}
func (q *queuedClient) SendWithReply(ctx context.Context, peerId string, message *treechangeproto.TreeSyncMessage, replyId string) (err error) {
return q.queue.Send(func() error {
return q.SyncClient.SendWithReply(ctx, peerId, message, replyId)
})
}
func (q *queuedClient) BroadcastAsyncOrSendResponsible(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) {
return q.queue.Send(func() error {
return q.SyncClient.BroadcastAsyncOrSendResponsible(ctx, message)
})
}

View file

@ -52,7 +52,7 @@ type syncTree struct {
var log = logger.NewNamed("commonspace.synctree").Sugar()
var buildObjectTree = objecttree.BuildObjectTree
var createSyncClient = newWrappedSyncClient
var createSyncClient = newSyncClient
type BuildDeps struct {
SpaceId string
@ -68,15 +68,6 @@ type BuildDeps struct {
WaitTreeRemoteSync bool
}
func newWrappedSyncClient(
spaceId string,
factory RequestFactory,
objectSync objectsync.ObjectSync,
configuration nodeconf.Configuration) SyncClient {
syncClient := newSyncClient(spaceId, objectSync.MessagePool(), factory, configuration)
return newQueuedClient(syncClient, objectSync.ActionQueue())
}
func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t SyncTree, err error) {
getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) {
peerId, err := peer.CtxPeerId(ctx)
@ -182,8 +173,8 @@ func buildSyncTree(ctx context.Context, isFirstBuild bool, deps BuildDeps) (t Sy
}
syncClient := createSyncClient(
deps.SpaceId,
deps.ObjectSync.MessagePool(),
sharedFactory,
deps.ObjectSync,
deps.Configuration)
syncTree := &syncTree{
ObjectTree: objTree,

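Note on the change above: the sync tree now builds its sync client straight from the object sync's message pool, so the queued wrapper deleted in the first file of this commit is gone, and the package-level createSyncClient variable is what lets the tests in the next file substitute a mock client. Below is a minimal standalone sketch of that substitution pattern; every name in it is illustrative and not part of the any-sync API.

package main

import "fmt"

type SyncClient interface {
    Broadcast(msg string) error
}

type realClient struct{ spaceId string }

func (c realClient) Broadcast(msg string) error {
    fmt.Printf("space %s broadcast: %s\n", c.spaceId, msg)
    return nil
}

func newSyncClient(spaceId string) SyncClient { return realClient{spaceId: spaceId} }

// package-level constructor variable: production code calls it, tests overwrite it,
// mirroring "var createSyncClient = newSyncClient" in the diff above
var createSyncClient = newSyncClient

func buildSyncTree(spaceId string) SyncClient {
    return createSyncClient(spaceId)
}

func main() {
    // what a test does: swap the constructor for a stub, restore it afterwards
    orig := createSyncClient
    createSyncClient = func(spaceId string) SyncClient {
        return realClient{spaceId: spaceId + "-stub"}
    }
    defer func() { createSyncClient = orig }()

    client := buildSyncTree("space1")
    _ = client.Broadcast("heads updated")
}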
View file

@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Update(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err)
require.Equal(t, expectedRes, res)
@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Rebuild(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err)
require.Equal(t, expectedRes, res)
@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) {
Return(expectedRes, nil)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
res, err := tr.AddContent(ctx, content)
require.NoError(t, err)
require.Equal(t, expectedRes, res)

View file

@ -38,7 +38,7 @@ type syncHandlerFixture struct {
ctrl *gomock.Controller
syncClientMock *mock_synctree.MockSyncClient
objectTreeMock *testObjTreeMock
receiveQueueMock *mock_synctree.MockReceiveQueue
receiveQueueMock ReceiveQueue
syncHandler *syncTreeHandler
}
@ -47,19 +47,19 @@ func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
ctrl := gomock.NewController(t)
syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
receiveQueueMock := mock_synctree.NewMockReceiveQueue(ctrl)
receiveQueue := newReceiveQueue(5)
syncHandler := &syncTreeHandler{
objTree: objectTreeMock,
syncClient: syncClientMock,
queue: receiveQueueMock,
queue: receiveQueue,
syncStatus: syncstatus.NewNoOpSyncStatus(),
}
return &syncHandlerFixture{
ctrl: ctrl,
syncClientMock: syncClientMock,
objectTreeMock: objectTreeMock,
receiveQueueMock: receiveQueueMock,
receiveQueueMock: receiveQueue,
syncHandler: syncHandler,
}
}
@ -84,10 +84,7 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2)
@ -101,7 +98,6 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2", "h1"})
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -118,10 +114,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fullRequest := &treechangeproto.TreeSyncMessage{}
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
@ -136,9 +130,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
fx.syncClientMock.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -155,14 +148,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -179,19 +169,16 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fullRequest := &treechangeproto.TreeSyncMessage{}
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.syncClientMock.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -208,14 +195,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -237,10 +221,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fullResponse := &treechangeproto.TreeSyncMessage{}
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
@ -255,9 +237,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -274,10 +255,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fullResponse := &treechangeproto.TreeSyncMessage{}
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
@ -288,9 +267,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -307,10 +285,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
fullResponse := &treechangeproto.TreeSyncMessage{}
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
@ -318,9 +294,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -337,9 +312,7 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
@ -356,9 +329,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, fmt.Errorf(""))
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Any(), gomock.Eq(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Any(), gomock.Eq(""))
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.Error(t, err)
})
@ -381,9 +353,7 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().
@ -399,7 +369,6 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
})).
Return(objecttree.AddResult{}, nil)
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
@ -417,16 +386,13 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})

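The fixture above now wires in the real receive queue via newReceiveQueue(5) instead of a gomock ReceiveQueue, which is why all of the AddMessage/GetMessage/ClearQueue expectations disappear from the tests. The real implementation is not part of this diff; the sketch below is only a plausible minimal per-sender queue inferred from how the removed expectations used the interface, and is not the actual any-sync code.

package main

import (
    "errors"
    "fmt"
    "sync"
)

// queuedMessage pairs a payload with the replyId it arrived with.
type queuedMessage struct {
    msg     string
    replyId string
}

// receiveQueue is a bounded per-sender FIFO. The semantics here are assumptions
// for illustration: AddMessage reports overflow, GetMessage pops the oldest
// message, ClearQueue drops whatever is left for that sender.
type receiveQueue struct {
    mu      sync.Mutex
    maxSize int
    queues  map[string][]queuedMessage
}

func newReceiveQueue(maxSize int) *receiveQueue {
    return &receiveQueue{maxSize: maxSize, queues: map[string][]queuedMessage{}}
}

func (q *receiveQueue) AddMessage(senderId, msg, replyId string) (overflow bool) {
    q.mu.Lock()
    defer q.mu.Unlock()
    if len(q.queues[senderId]) >= q.maxSize {
        return true
    }
    q.queues[senderId] = append(q.queues[senderId], queuedMessage{msg: msg, replyId: replyId})
    return false
}

func (q *receiveQueue) GetMessage(senderId string) (msg, replyId string, err error) {
    q.mu.Lock()
    defer q.mu.Unlock()
    pending := q.queues[senderId]
    if len(pending) == 0 {
        return "", "", errors.New("empty queue")
    }
    head := pending[0]
    q.queues[senderId] = pending[1:]
    return head.msg, head.replyId, nil
}

func (q *receiveQueue) ClearQueue(senderId string) {
    q.mu.Lock()
    defer q.mu.Unlock()
    delete(q.queues, senderId)
}

func main() {
    q := newReceiveQueue(5)
    fmt.Println(q.AddMessage("peerA", "head update", "")) // false: accepted
    fmt.Println(q.GetMessage("peerA"))
    q.ClearQueue("peerA")
}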
View file

@ -1,78 +0,0 @@
package objectsync
import (
"context"
"github.com/cheggaaa/mb/v3"
"go.uber.org/zap"
)
type ActionFunc func() error
type ActionQueue interface {
Send(action ActionFunc) (err error)
Run()
Close()
}
type actionQueue struct {
batcher *mb.MB[ActionFunc]
maxReaders int
maxQueueLen int
readers chan struct{}
}
func NewDefaultActionQueue() ActionQueue {
return NewActionQueue(10, 200)
}
func NewActionQueue(maxReaders int, maxQueueLen int) ActionQueue {
return &actionQueue{
batcher: mb.New[ActionFunc](maxQueueLen),
maxReaders: maxReaders,
maxQueueLen: maxQueueLen,
}
}
func (q *actionQueue) Send(action ActionFunc) (err error) {
log.Debug("adding action to batcher")
err = q.batcher.TryAdd(action)
if err == nil {
return
}
log.With(zap.Error(err)).Debug("queue returned error")
actions := q.batcher.GetAll()
actions = append(actions[len(actions)/2:], action)
return q.batcher.Add(context.Background(), actions...)
}
func (q *actionQueue) Run() {
log.Debug("running the queue")
q.readers = make(chan struct{}, q.maxReaders)
for i := 0; i < q.maxReaders; i++ {
go q.startReading()
}
}
func (q *actionQueue) startReading() {
defer func() {
q.readers <- struct{}{}
}()
for {
action, err := q.batcher.WaitOne(context.Background())
if err != nil {
return
}
err = action()
if err != nil {
log.With(zap.Error(err)).Debug("action errored out")
}
}
}
func (q *actionQueue) Close() {
log.Debug("closing the queue")
q.batcher.Close()
for i := 0; i < q.maxReaders; i++ {
<-q.readers
}
}

View file

@ -1,54 +0,0 @@
package objectsync
import (
"fmt"
"github.com/stretchr/testify/require"
"sync/atomic"
"testing"
)
func TestActionQueue_Send(t *testing.T) {
maxReaders := 41
maxLen := 93
queue := NewActionQueue(maxReaders, maxLen).(*actionQueue)
counter := atomic.Int32{}
expectedCounter := int32(maxReaders + (maxLen+1)/2 + 1)
blocker := make(chan struct{}, expectedCounter)
waiter := make(chan struct{}, expectedCounter)
increase := func() error {
counter.Add(1)
waiter <- struct{}{}
<-blocker
return nil
}
queue.Run()
// sending maxReaders messages, so the goroutines will block on `blocker` channel
for i := 0; i < maxReaders; i++ {
queue.Send(increase)
}
// waiting until they all make progress
for i := 0; i < maxReaders; i++ {
<-waiter
}
fmt.Println(counter.Load())
// check that queue is empty
require.Equal(t, queue.batcher.Len(), 0)
// making queue to overflow while readers are blocked
for i := 0; i < maxLen+1; i++ {
queue.Send(increase)
}
// check that queue was halved after overflow
require.Equal(t, (maxLen+1)/2+1, queue.batcher.Len())
// unblocking maxReaders waiting + then we should also unblock the new readers to do a bit more readings
for i := 0; i < int(expectedCounter); i++ {
blocker <- struct{}{}
}
// waiting for all readers to finish adding
for i := 0; i < int(expectedCounter)-maxReaders; i++ {
<-waiter
}
queue.Close()
require.Equal(t, expectedCounter, counter.Load())
}

View file

@ -1,73 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/commonspace/objectsync (interfaces: ActionQueue)
// Package mock_objectsync is a generated GoMock package.
package mock_objectsync
import (
reflect "reflect"
objectsync "github.com/anytypeio/any-sync/commonspace/objectsync"
gomock "github.com/golang/mock/gomock"
)
// MockActionQueue is a mock of ActionQueue interface.
type MockActionQueue struct {
ctrl *gomock.Controller
recorder *MockActionQueueMockRecorder
}
// MockActionQueueMockRecorder is the mock recorder for MockActionQueue.
type MockActionQueueMockRecorder struct {
mock *MockActionQueue
}
// NewMockActionQueue creates a new mock instance.
func NewMockActionQueue(ctrl *gomock.Controller) *MockActionQueue {
mock := &MockActionQueue{ctrl: ctrl}
mock.recorder = &MockActionQueueMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockActionQueue) EXPECT() *MockActionQueueMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockActionQueue) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockActionQueueMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockActionQueue)(nil).Close))
}
// Run mocks base method.
func (m *MockActionQueue) Run() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Run")
}
// Run indicates an expected call of Run.
func (mr *MockActionQueueMockRecorder) Run() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockActionQueue)(nil).Run))
}
// Send mocks base method.
func (m *MockActionQueue) Send(arg0 objectsync.ActionFunc) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Send", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Send indicates an expected call of Send.
func (mr *MockActionQueueMockRecorder) Send(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockActionQueue)(nil).Send), arg0)
}

View file

@ -9,6 +9,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
)
type StreamManager interface {
@ -37,7 +38,6 @@ type messagePool struct {
waiters map[string]responseWaiter
waitersMx sync.Mutex
counter atomic.Uint64
queue ActionQueue
}
func newMessagePool(streamManager StreamManager, messageHandler MessageHandler) MessagePool {
@ -45,15 +45,17 @@ func newMessagePool(streamManager StreamManager, messageHandler MessageHandler)
StreamManager: streamManager,
messageHandler: messageHandler,
waiters: make(map[string]responseWaiter),
queue: NewDefaultActionQueue(),
}
return s
}
func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Second*10)
defer cancel()
newCounter := s.counter.Add(1)
msg.ReplyId = genReplyKey(peerId, msg.ObjectId, newCounter)
log.Info("mpool sendSync", zap.String("replyId", msg.ReplyId))
s.waitersMx.Lock()
waiter := responseWaiter{
ch: make(chan *spacesyncproto.ObjectSyncMessage, 1),
@ -81,19 +83,14 @@ func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyn
func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
if msg.ReplyId != "" {
log.Info("mpool receive reply", zap.String("replyId", msg.ReplyId))
// we got a reply, send it to the waiter
if s.stopWaiter(msg) {
return
}
log.With(zap.String("replyId", msg.ReplyId)).Debug("reply id does not exist")
return
}
return s.queue.Send(func() error {
if e := s.messageHandler(ctx, senderId, msg); e != nil {
log.Info("handle message error", zap.Error(e))
}
return nil
})
return s.messageHandler(ctx, senderId, msg)
}
func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool {

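Two changes above are worth spelling out: SendSync now bounds the whole request/reply round trip with a 10-second context timeout, and HandleMessage passes non-reply messages straight to the message handler instead of going through the removed ActionQueue. The standalone sketch below shows the waiter-with-deadline pattern in isolation; the types and names are placeholders, not the any-sync API.

package main

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"
)

// replyPool is a toy stand-in for the message pool's waiter map.
type replyPool struct {
    mu      sync.Mutex
    waiters map[string]chan string
}

func (p *replyPool) sendSync(ctx context.Context, replyId string, send func() error) (string, error) {
    // bound the whole round trip, mirroring the 10-second timeout added to SendSync
    ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
    defer cancel()

    ch := make(chan string, 1)
    p.mu.Lock()
    p.waiters[replyId] = ch
    p.mu.Unlock()
    defer func() {
        p.mu.Lock()
        delete(p.waiters, replyId)
        p.mu.Unlock()
    }()

    if err := send(); err != nil {
        return "", err
    }
    select {
    case reply := <-ch:
        return reply, nil
    case <-ctx.Done():
        return "", errors.New("no reply before deadline")
    }
}

// handleReply is what the reply branch of HandleMessage amounts to:
// find the waiter for this replyId and hand the payload over, or report that nobody waits.
func (p *replyPool) handleReply(replyId, payload string) bool {
    p.mu.Lock()
    ch, ok := p.waiters[replyId]
    p.mu.Unlock()
    if ok {
        ch <- payload
    }
    return ok
}

func main() {
    pool := &replyPool{waiters: map[string]chan string{}}
    go func() {
        time.Sleep(100 * time.Millisecond)
        pool.handleReply("peerA/obj1/1", "pong")
    }()
    reply, err := pool.sendSync(context.Background(), "peerA/obj1/1", func() error { return nil })
    fmt.Println(reply, err)
}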
View file

@ -1,4 +1,3 @@
//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anytypeio/any-sync/commonspace/objectsync ActionQueue
package objectsync
import (
@ -18,58 +17,57 @@ type ObjectSync interface {
ocache.ObjectLastUsage
synchandler.SyncHandler
MessagePool() MessagePool
ActionQueue() ActionQueue
Init(getter syncobjectgetter.SyncObjectGetter)
Init()
Close() (err error)
}
type objectSync struct {
spaceId string
streamPool MessagePool
messagePool MessagePool
objectGetter syncobjectgetter.SyncObjectGetter
actionQueue ActionQueue
syncCtx context.Context
cancelSync context.CancelFunc
}
func NewObjectSync(streamManager StreamManager, spaceId string) (objectSync ObjectSync) {
msgPool := newMessagePool(streamManager, func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
return objectSync.HandleMessage(ctx, senderId, message)
})
func NewObjectSync(
spaceId string,
streamManager StreamManager,
objectGetter syncobjectgetter.SyncObjectGetter) ObjectSync {
syncCtx, cancel := context.WithCancel(context.Background())
objectSync = newObjectSync(
os := newObjectSync(
spaceId,
msgPool,
objectGetter,
syncCtx,
cancel)
return
msgPool := newMessagePool(streamManager, os.handleMessage)
os.messagePool = msgPool
return os
}
func newObjectSync(
spaceId string,
streamPool MessagePool,
objectGetter syncobjectgetter.SyncObjectGetter,
syncCtx context.Context,
cancel context.CancelFunc,
) *objectSync {
return &objectSync{
streamPool: streamPool,
spaceId: spaceId,
syncCtx: syncCtx,
cancelSync: cancel,
actionQueue: NewDefaultActionQueue(),
objectGetter: objectGetter,
spaceId: spaceId,
syncCtx: syncCtx,
cancelSync: cancel,
//actionQueue: NewDefaultActionQueue(),
}
}
func (s *objectSync) Init(objectGetter syncobjectgetter.SyncObjectGetter) {
s.objectGetter = objectGetter
s.actionQueue.Run()
func (s *objectSync) Init() {
//s.actionQueue.Run()
}
func (s *objectSync) Close() (err error) {
s.actionQueue.Close()
//s.actionQueue.Close()
s.cancelSync()
return
}
@ -80,7 +78,11 @@ func (s *objectSync) LastUsage() time.Time {
}
func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId)).Debug("handling message")
return s.messagePool.HandleMessage(ctx, senderId, message)
}
func (s *objectSync) handleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId), zap.String("replyId", message.ReplyId)).Debug("handling message")
obj, err := s.objectGetter.GetObject(ctx, message.ObjectId)
if err != nil {
return
@ -89,9 +91,5 @@ func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message
}
func (s *objectSync) MessagePool() MessagePool {
return s.streamPool
}
func (s *objectSync) ActionQueue() ActionQueue {
return s.actionQueue
return s.messagePool
}

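The constructor above now takes the object getter up front and wires the message pool to the unexported handleMessage method after the objectSync value exists, which breaks the construction-time cycle between the pool (which needs a handler) and the sync object (which needs the pool). A simplified standalone sketch of that wiring order, with illustrative types only:

package main

import (
    "context"
    "fmt"
)

type handlerFunc func(ctx context.Context, senderId, msg string) error

// messagePool needs a handler at construction time.
type messagePool struct{ handle handlerFunc }

func newMessagePool(handle handlerFunc) *messagePool { return &messagePool{handle: handle} }

func (p *messagePool) HandleMessage(ctx context.Context, senderId, msg string) error {
    return p.handle(ctx, senderId, msg)
}

// objectSync needs the pool, but its handler is just a method, so it can be
// constructed first and the pool wired in afterwards.
type objectSync struct {
    spaceId string
    pool    *messagePool
}

func newObjectSync(spaceId string) *objectSync {
    objSync := &objectSync{spaceId: spaceId}
    objSync.pool = newMessagePool(objSync.handleMessage)
    return objSync
}

func (s *objectSync) handleMessage(ctx context.Context, senderId, msg string) error {
    fmt.Printf("space %s: message %q from %s\n", s.spaceId, msg, senderId)
    return nil
}

func main() {
    objSync := newObjectSync("space1")
    _ = objSync.pool.HandleMessage(context.Background(), "peerA", "head update")
}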
View file

@ -78,8 +78,8 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
}
continue
}
if _, ok := err.(secureservice.HandshakeError); ok {
l.Warn("listener handshake error", zap.Error(err))
if herr, ok := err.(secureservice.HandshakeError); ok {
l.Warn("listener handshake error", zap.Error(herr), zap.String("remoteAddr", herr.RemoteAddr()))
continue
}
l.Error("listener accept error", zap.Error(err))

View file

@ -12,7 +12,18 @@ import (
"net"
)
type HandshakeError error
type HandshakeError struct {
remoteAddr string
err error
}
func (he HandshakeError) RemoteAddr() string {
return he.remoteAddr
}
func (he HandshakeError) Error() string {
return he.err.Error()
}
const CName = "common.net.secure"

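HandshakeError changes from a bare error alias into a struct that also carries the remote address, which is what lets the drpcserver change earlier in this commit log where the failed handshake came from. Below is a self-contained sketch of that value-error pattern; it checks the error with errors.As where the server code above uses a direct type assertion, and all names are standalone placeholders rather than the secureservice package itself.

package main

import (
    "errors"
    "fmt"
)

// handshakeError mirrors the new struct: it wraps the underlying error and
// remembers the address of the peer whose handshake failed.
type handshakeError struct {
    remoteAddr string
    err        error
}

func (he handshakeError) Error() string      { return he.err.Error() }
func (he handshakeError) RemoteAddr() string { return he.remoteAddr }

// upgradeConn stands in for the tlsListener upgrade path, which on failure
// returns the error together with conn.RemoteAddr().String().
func upgradeConn(remoteAddr string) error {
    return handshakeError{remoteAddr: remoteAddr, err: errors.New("tls handshake failed")}
}

func main() {
    err := upgradeConn("192.0.2.1:4242")
    var herr handshakeError
    if errors.As(err, &herr) {
        fmt.Printf("listener handshake error from %s: %v\n", herr.RemoteAddr(), herr)
    }
}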
View file

@ -49,7 +49,10 @@ func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, er
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
secure, err := p.tr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, HandshakeError(err)
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
return ctx, secure, nil

View file

@ -1,7 +1,6 @@
package streampool
import (
"fmt"
"go.uber.org/zap"
"storj.io/drpc"
"sync/atomic"
@ -18,11 +17,7 @@ type stream struct {
}
func (sr *stream) write(msg drpc.Message) (err error) {
defer func() {
sr.l.Debug("write", zap.String("msg", msg.(fmt.Stringer).String()), zap.Error(err))
}()
if err = sr.stream.MsgSend(msg, EncodingProto); err != nil {
sr.l.Info("stream write error", zap.Error(err))
sr.streamClose()
}
return err
@ -38,8 +33,7 @@ func (sr *stream) readLoop() error {
sr.l.Info("msg receive error", zap.Error(err))
return err
}
sr.l.Debug("read msg", zap.String("msg", msg.(fmt.Stringer).String()))
if err := sr.pool.handler.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil {
if err := sr.pool.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil {
sr.l.Info("msg handle error", zap.Error(err))
return err
}

View file

@ -1,9 +1,9 @@
package streampool
import (
"fmt"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/pool"
"github.com/cheggaaa/mb/v3"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"golang.org/x/net/context"
@ -42,12 +42,30 @@ type streamPool struct {
streamIdsByPeer map[string][]uint32
streamIdsByTag map[string][]uint32
streams map[uint32]*stream
opening map[string]chan struct{}
opening map[string]*openingProcess
exec *sendPool
handleQueue *mb.MB[handleMessage]
mu sync.RWMutex
lastStreamId uint32
}
type openingProcess struct {
ch chan struct{}
err error
}
type handleMessage struct {
ctx context.Context
msg drpc.Message
peerId string
}
func (s *streamPool) init() {
// TODO: make the number of handler goroutines configurable
for i := 0; i < 10; i++ {
go s.handleMessageLoop()
}
}
func (s *streamPool) ReadStream(peerId string, drpcStream drpc.Stream, tags ...string) error {
st := s.addStream(peerId, drpcStream, tags...)
return st.readLoop()
@ -78,7 +96,6 @@ func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...st
for _, tag := range tags {
s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], streamId)
}
st.l.Debug("stream added", zap.Strings("tags", st.tags))
return st
}
@ -87,7 +104,7 @@ func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.P
for _, p := range peers {
funcs = append(funcs, func() {
if e := s.sendOne(ctx, p, msg); e != nil {
log.Info("send peer error", zap.Error(e))
log.Info("send peer error", zap.Error(e), zap.String("peerId", p.Id()))
}
})
}
@ -103,12 +120,11 @@ func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ...
}
}
s.mu.Unlock()
log.Debug("sendById", zap.String("msg", msg.(fmt.Stringer).String()), zap.Int("streams", len(streams)))
var funcs []func()
for _, st := range streams {
funcs = append(funcs, func() {
if e := st.write(msg); e != nil {
log.Debug("sendById write error", zap.Error(e))
st.l.Debug("sendById write error", zap.Error(e))
}
})
}
@ -126,7 +142,7 @@ func (s *streamPool) sendOne(ctx context.Context, p peer.Peer, msg drpc.Message)
}
for _, st := range streams {
if err = st.write(msg); err != nil {
log.Info("stream write error", zap.Error(err))
st.l.Info("sendOne write error", zap.Error(err))
// continue with next stream
continue
} else {
@ -144,18 +160,21 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st
for _, streamId := range streamIds {
streams = append(streams, s.streams[streamId])
}
var openingCh chan struct{}
var op *openingProcess
// no cached streams found
if len(streams) == 0 {
// start opening process
openingCh = s.openStream(ctx, p)
op = s.openStream(ctx, p)
}
s.mu.Unlock()
// a non-nil op means we should wait for the stream opening to finish and try again
if openingCh != nil {
if op != nil {
select {
case <-openingCh:
case <-op.ch:
if op.err != nil {
return nil, op.err
}
return s.getStreams(ctx, p)
case <-ctx.Done():
return nil, ctx.Err()
@ -164,30 +183,32 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st
return streams, nil
}
func (s *streamPool) openStream(ctx context.Context, p peer.Peer) chan struct{} {
if ch, ok := s.opening[p.Id()]; ok {
func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProcess {
if op, ok := s.opening[p.Id()]; ok {
// already have an opening process for this peer - return it
return ch
return op
}
ch := make(chan struct{})
s.opening[p.Id()] = ch
op := &openingProcess{
ch: make(chan struct{}),
}
s.opening[p.Id()] = op
go func() {
// start stream opening in a separate goroutine to avoid locking the whole pool
defer func() {
s.mu.Lock()
defer s.mu.Unlock()
close(ch)
close(op.ch)
delete(s.opening, p.Id())
}()
// open new stream and add to pool
st, tags, err := s.handler.OpenStream(ctx, p)
if err != nil {
log.Warn("stream open error", zap.Error(err))
op.err = err
return
}
s.AddStream(p.Id(), st, tags...)
}()
return ch
return op
}
func (s *streamPool) Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) {
@ -244,6 +265,28 @@ func (s *streamPool) removeStream(streamId uint32) {
st.l.Debug("stream removed", zap.Strings("tags", st.tags))
}
func (s *streamPool) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) {
return s.handleQueue.Add(ctx, handleMessage{
ctx: ctx,
msg: msg,
peerId: peerId,
})
}
func (s *streamPool) handleMessageLoop() {
for {
hm, err := s.handleQueue.WaitOne(context.Background())
if err != nil {
return
}
go func() {
if err = s.handler.HandleMessage(hm.ctx, hm.peerId, hm.msg); err != nil {
log.Warn("handle message error", zap.Error(err))
}
}()
}
}
func (s *streamPool) Close() (err error) {
return s.exec.Close()
}

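The opening map above now stores an *openingProcess that carries the open error, so goroutines queued up behind a failing dial observe that error instead of retrying blindly; the new "parallel open stream error" test in the next file exercises exactly this. The sketch below shows the single-flight idea in isolation, with illustrative names and simplified locking.

package main

import (
    "errors"
    "fmt"
    "sync"
)

// openingProcess mirrors the struct added above: a done channel plus the error
// of the open attempt, shared by everyone who waited on it.
type openingProcess struct {
    ch  chan struct{}
    err error
}

type opener struct {
    mu      sync.Mutex
    opening map[string]*openingProcess
}

func (o *opener) open(peerId string, dial func() error) error {
    o.mu.Lock()
    if op, ok := o.opening[peerId]; ok {
        o.mu.Unlock()
        <-op.ch // an open is already in flight; wait for it and share its result
        return op.err
    }
    op := &openingProcess{ch: make(chan struct{})}
    o.opening[peerId] = op
    o.mu.Unlock()

    defer func() {
        o.mu.Lock()
        defer o.mu.Unlock()
        close(op.ch) // op.err is written before the close, so waiters read it safely
        delete(o.opening, peerId)
    }()
    op.err = dial()
    return op.err
}

func main() {
    o := &opener{opening: map[string]*openingProcess{}}
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            // every caller gets the dial error instead of waiting forever
            fmt.Println(o.open("p1", func() error { return errors.New("dial failed") }))
        }()
    }
    wg.Wait()
}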
View file

@ -18,22 +18,23 @@ import (
var ctx = context.Background()
func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
p, err := fx.tp.Dial(ctx, peerId)
require.NoError(t, err)
s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
require.NoError(t, err)
return s, p
}
func TestStreamPool_AddStream(t *testing.T) {
newClientStream := func(fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
p, err := fx.tp.Dial(ctx, peerId)
require.NoError(t, err)
s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
require.NoError(t, err)
return s, p
}
t.Run("broadcast incoming", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish(t)
s1, _ := newClientStream(fx, "p1")
s1, _ := newClientStream(t, fx, "p1")
fx.AddStream("p1", s1, "space1", "common")
s2, _ := newClientStream(fx, "p2")
s2, _ := newClientStream(t, fx, "p2")
fx.AddStream("p2", s2, "space2", "common")
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1"))
@ -61,7 +62,7 @@ func TestStreamPool_AddStream(t *testing.T) {
fx := newFixture(t)
defer fx.Finish(t)
s1, p1 := newClientStream(fx, "p1")
s1, p1 := newClientStream(t, fx, "p1")
defer s1.Close()
fx.AddStream("p1", s1, "space1", "common")
@ -122,6 +123,46 @@ func TestStreamPool_Send(t *testing.T) {
// make sure that we have only one stream
assert.Equal(t, int32(1), fx.tsh.streamsCount.Load())
})
t.Run("parallel open stream error", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish(t)
p, err := fx.tp.Dial(ctx, "p1")
require.NoError(t, err)
_ = p.Close()
fx.th.streamOpenDelay = time.Second / 3
var numMsgs = 5
var wg sync.WaitGroup
for i := 0; i < numMsgs; i++ {
wg.Add(1)
go func() {
defer wg.Done()
assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"}))
}()
}
wg.Wait()
})
}
func TestStreamPool_SendById(t *testing.T) {
fx := newFixture(t)
defer fx.Finish(t)
s1, _ := newClientStream(t, fx, "p1")
defer s1.Close()
fx.AddStream("p1", s1, "space1", "common")
require.NoError(t, fx.SendById(ctx, &testservice.StreamMessage{ReqData: "test"}, "p1"))
var msg *testservice.StreamMessage
select {
case msg = <-fx.tsh.receiveCh:
case <-time.After(time.Second):
require.NoError(t, fmt.Errorf("timeout"))
}
assert.Equal(t, "test", msg.ReqData)
}
func newFixture(t *testing.T) *fixture {

View file

@ -3,6 +3,7 @@ package streampool
import (
"github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/app/logger"
"github.com/cheggaaa/mb/v3"
)
const CName = "common.net.streampool"
@ -22,15 +23,17 @@ type service struct {
}
func (s *service) NewStreamPool(h StreamHandler) StreamPool {
return &streamPool{
sp := &streamPool{
handler: h,
streamIdsByPeer: map[string][]uint32{},
streamIdsByTag: map[string][]uint32{},
streams: map[uint32]*stream{},
opening: map[string]chan struct{}{},
opening: map[string]*openingProcess{},
exec: newStreamSender(10, 100),
lastStreamId: 0,
handleQueue: mb.New[handleMessage](100),
}
sp.init()
return sp
}
func (s *service) Init(a *app.App) (err error) {
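NewStreamPool above now also creates the bounded handle queue (mb.New[handleMessage](100)) and calls init(), which starts the fixed set of handleMessageLoop goroutines shown earlier, so incoming messages are buffered and handled off the stream read loop. The standalone example below shows that bounded-queue-plus-workers pattern with the same mb/v3 package; the message type, capacity, and worker count are arbitrary.

package main

import (
    "context"
    "fmt"
    "sync"
    "time"

    "github.com/cheggaaa/mb/v3"
)

func main() {
    // bounded queue, analogous to handleQueue: mb.New[handleMessage](100)
    queue := mb.New[string](100)
    var wg sync.WaitGroup

    // a few drain loops, analogous to init() starting the handleMessageLoop goroutines
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go func(worker int) {
            defer wg.Done()
            for {
                msg, err := queue.WaitOne(context.Background())
                if err != nil {
                    return // queue closed
                }
                fmt.Printf("worker %d handled %s\n", worker, msg)
            }
        }(i)
    }

    // producers (the stream read loops in the real pool) just Add and move on
    for i := 0; i < 5; i++ {
        _ = queue.Add(context.Background(), fmt.Sprintf("msg-%d", i))
    }

    time.Sleep(200 * time.Millisecond) // toy example only: let the workers drain
    queue.Close()
    wg.Wait()
}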