mirror of
https://github.com/anyproto/any-sync.git
synced 2025-06-08 05:57:03 +09:00
sync fixes + fix tests
This commit is contained in:
parent
34848254be
commit
5fb6ee5a7b
16 changed files with 197 additions and 392 deletions
|
@ -1,37 +0,0 @@
|
|||
package synctree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/objectsync"
|
||||
)
|
||||
|
||||
type queuedClient struct {
|
||||
SyncClient
|
||||
queue objectsync.ActionQueue
|
||||
}
|
||||
|
||||
func newQueuedClient(client SyncClient, queue objectsync.ActionQueue) SyncClient {
|
||||
return &queuedClient{
|
||||
SyncClient: client,
|
||||
queue: queue,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queuedClient) Broadcast(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) {
|
||||
return q.queue.Send(func() error {
|
||||
return q.SyncClient.Broadcast(ctx, message)
|
||||
})
|
||||
}
|
||||
|
||||
func (q *queuedClient) SendWithReply(ctx context.Context, peerId string, message *treechangeproto.TreeSyncMessage, replyId string) (err error) {
|
||||
return q.queue.Send(func() error {
|
||||
return q.SyncClient.SendWithReply(ctx, peerId, message, replyId)
|
||||
})
|
||||
}
|
||||
|
||||
func (q *queuedClient) BroadcastAsyncOrSendResponsible(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) {
|
||||
return q.queue.Send(func() error {
|
||||
return q.SyncClient.BroadcastAsyncOrSendResponsible(ctx, message)
|
||||
})
|
||||
}
|
|
@ -52,7 +52,7 @@ type syncTree struct {
|
|||
var log = logger.NewNamed("commonspace.synctree").Sugar()
|
||||
|
||||
var buildObjectTree = objecttree.BuildObjectTree
|
||||
var createSyncClient = newWrappedSyncClient
|
||||
var createSyncClient = newSyncClient
|
||||
|
||||
type BuildDeps struct {
|
||||
SpaceId string
|
||||
|
@ -68,15 +68,6 @@ type BuildDeps struct {
|
|||
WaitTreeRemoteSync bool
|
||||
}
|
||||
|
||||
func newWrappedSyncClient(
|
||||
spaceId string,
|
||||
factory RequestFactory,
|
||||
objectSync objectsync.ObjectSync,
|
||||
configuration nodeconf.Configuration) SyncClient {
|
||||
syncClient := newSyncClient(spaceId, objectSync.MessagePool(), factory, configuration)
|
||||
return newQueuedClient(syncClient, objectSync.ActionQueue())
|
||||
}
|
||||
|
||||
func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t SyncTree, err error) {
|
||||
getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) {
|
||||
peerId, err := peer.CtxPeerId(ctx)
|
||||
|
@ -182,8 +173,8 @@ func buildSyncTree(ctx context.Context, isFirstBuild bool, deps BuildDeps) (t Sy
|
|||
}
|
||||
syncClient := createSyncClient(
|
||||
deps.SpaceId,
|
||||
deps.ObjectSync.MessagePool(),
|
||||
sharedFactory,
|
||||
deps.ObjectSync,
|
||||
deps.Configuration)
|
||||
syncTree := &syncTree{
|
||||
ObjectTree: objTree,
|
||||
|
|
|
@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) {
|
|||
updateListenerMock.EXPECT().Update(tr)
|
||||
|
||||
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
|
||||
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
|
||||
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
|
||||
res, err := tr.AddRawChanges(ctx, payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedRes, res)
|
||||
|
@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) {
|
|||
updateListenerMock.EXPECT().Rebuild(tr)
|
||||
|
||||
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
|
||||
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
|
||||
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
|
||||
res, err := tr.AddRawChanges(ctx, payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedRes, res)
|
||||
|
@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) {
|
|||
Return(expectedRes, nil)
|
||||
|
||||
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
|
||||
syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil)
|
||||
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil)
|
||||
res, err := tr.AddContent(ctx, content)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedRes, res)
|
||||
|
|
|
@ -38,7 +38,7 @@ type syncHandlerFixture struct {
|
|||
ctrl *gomock.Controller
|
||||
syncClientMock *mock_synctree.MockSyncClient
|
||||
objectTreeMock *testObjTreeMock
|
||||
receiveQueueMock *mock_synctree.MockReceiveQueue
|
||||
receiveQueueMock ReceiveQueue
|
||||
|
||||
syncHandler *syncTreeHandler
|
||||
}
|
||||
|
@ -47,19 +47,19 @@ func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
|
|||
ctrl := gomock.NewController(t)
|
||||
syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
|
||||
objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
|
||||
receiveQueueMock := mock_synctree.NewMockReceiveQueue(ctrl)
|
||||
receiveQueue := newReceiveQueue(5)
|
||||
|
||||
syncHandler := &syncTreeHandler{
|
||||
objTree: objectTreeMock,
|
||||
syncClient: syncClientMock,
|
||||
queue: receiveQueueMock,
|
||||
queue: receiveQueue,
|
||||
syncStatus: syncstatus.NewNoOpSyncStatus(),
|
||||
}
|
||||
return &syncHandlerFixture{
|
||||
ctrl: ctrl,
|
||||
syncClientMock: syncClientMock,
|
||||
objectTreeMock: objectTreeMock,
|
||||
receiveQueueMock: receiveQueueMock,
|
||||
receiveQueueMock: receiveQueue,
|
||||
syncHandler: syncHandler,
|
||||
}
|
||||
}
|
||||
|
@ -84,10 +84,7 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2)
|
||||
|
@ -101,7 +98,6 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2", "h1"})
|
||||
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -118,10 +114,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
fullRequest := &treechangeproto.TreeSyncMessage{}
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
|
||||
|
@ -136,9 +130,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
fx.syncClientMock.EXPECT().
|
||||
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
|
||||
Return(fullRequest, nil)
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -155,14 +148,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -179,19 +169,16 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
fullRequest := &treechangeproto.TreeSyncMessage{}
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
|
||||
fx.syncClientMock.EXPECT().
|
||||
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
|
||||
Return(fullRequest, nil)
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq(""))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -208,14 +195,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -237,10 +221,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
fullResponse := &treechangeproto.TreeSyncMessage{}
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().Header().Return(nil)
|
||||
|
@ -255,9 +237,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
fx.syncClientMock.EXPECT().
|
||||
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
|
||||
Return(fullResponse, nil)
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -274,10 +255,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
fullResponse := &treechangeproto.TreeSyncMessage{}
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
|
||||
fx.objectTreeMock.EXPECT().
|
||||
Id().AnyTimes().Return(treeId)
|
||||
|
@ -288,9 +267,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
fx.syncClientMock.EXPECT().
|
||||
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
|
||||
Return(fullResponse, nil)
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(""))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -307,10 +285,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
|
||||
fullResponse := &treechangeproto.TreeSyncMessage{}
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
|
||||
|
||||
fx.objectTreeMock.EXPECT().
|
||||
Id().AnyTimes().Return(treeId)
|
||||
|
@ -318,9 +294,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
fx.syncClientMock.EXPECT().
|
||||
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
|
||||
Return(fullResponse, nil)
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -337,9 +312,7 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "")
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "")
|
||||
|
||||
fx.objectTreeMock.EXPECT().
|
||||
Id().AnyTimes().Return(treeId)
|
||||
|
@ -356,9 +329,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
|
|||
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
|
||||
})).
|
||||
Return(objecttree.AddResult{}, fmt.Errorf(""))
|
||||
fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Any(), gomock.Eq(""))
|
||||
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Any(), gomock.Eq(""))
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
@ -381,9 +353,7 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().
|
||||
|
@ -399,7 +369,6 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
|
|||
})).
|
||||
Return(objecttree.AddResult{}, nil)
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
@ -417,16 +386,13 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
|
|||
SnapshotPath: []string{"h1"},
|
||||
}
|
||||
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId)
|
||||
fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false)
|
||||
fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil)
|
||||
objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
|
||||
|
||||
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
|
||||
fx.objectTreeMock.EXPECT().
|
||||
Heads().
|
||||
Return([]string{"h1"}).AnyTimes()
|
||||
|
||||
fx.receiveQueueMock.EXPECT().ClearQueue(senderId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
package objectsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cheggaaa/mb/v3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ActionFunc func() error
|
||||
|
||||
type ActionQueue interface {
|
||||
Send(action ActionFunc) (err error)
|
||||
Run()
|
||||
Close()
|
||||
}
|
||||
|
||||
type actionQueue struct {
|
||||
batcher *mb.MB[ActionFunc]
|
||||
maxReaders int
|
||||
maxQueueLen int
|
||||
readers chan struct{}
|
||||
}
|
||||
|
||||
func NewDefaultActionQueue() ActionQueue {
|
||||
return NewActionQueue(10, 200)
|
||||
}
|
||||
|
||||
func NewActionQueue(maxReaders int, maxQueueLen int) ActionQueue {
|
||||
return &actionQueue{
|
||||
batcher: mb.New[ActionFunc](maxQueueLen),
|
||||
maxReaders: maxReaders,
|
||||
maxQueueLen: maxQueueLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *actionQueue) Send(action ActionFunc) (err error) {
|
||||
log.Debug("adding action to batcher")
|
||||
err = q.batcher.TryAdd(action)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.With(zap.Error(err)).Debug("queue returned error")
|
||||
actions := q.batcher.GetAll()
|
||||
actions = append(actions[len(actions)/2:], action)
|
||||
return q.batcher.Add(context.Background(), actions...)
|
||||
}
|
||||
|
||||
func (q *actionQueue) Run() {
|
||||
log.Debug("running the queue")
|
||||
q.readers = make(chan struct{}, q.maxReaders)
|
||||
for i := 0; i < q.maxReaders; i++ {
|
||||
go q.startReading()
|
||||
}
|
||||
}
|
||||
|
||||
func (q *actionQueue) startReading() {
|
||||
defer func() {
|
||||
q.readers <- struct{}{}
|
||||
}()
|
||||
for {
|
||||
action, err := q.batcher.WaitOne(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = action()
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Debug("action errored out")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *actionQueue) Close() {
|
||||
log.Debug("closing the queue")
|
||||
q.batcher.Close()
|
||||
for i := 0; i < q.maxReaders; i++ {
|
||||
<-q.readers
|
||||
}
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
package objectsync
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestActionQueue_Send(t *testing.T) {
|
||||
maxReaders := 41
|
||||
maxLen := 93
|
||||
|
||||
queue := NewActionQueue(maxReaders, maxLen).(*actionQueue)
|
||||
counter := atomic.Int32{}
|
||||
expectedCounter := int32(maxReaders + (maxLen+1)/2 + 1)
|
||||
blocker := make(chan struct{}, expectedCounter)
|
||||
waiter := make(chan struct{}, expectedCounter)
|
||||
increase := func() error {
|
||||
counter.Add(1)
|
||||
waiter <- struct{}{}
|
||||
<-blocker
|
||||
return nil
|
||||
}
|
||||
|
||||
queue.Run()
|
||||
// sending maxReaders messages, so the goroutines will block on `blocker` channel
|
||||
for i := 0; i < maxReaders; i++ {
|
||||
queue.Send(increase)
|
||||
}
|
||||
// waiting until they all make progress
|
||||
for i := 0; i < maxReaders; i++ {
|
||||
<-waiter
|
||||
}
|
||||
fmt.Println(counter.Load())
|
||||
// check that queue is empty
|
||||
require.Equal(t, queue.batcher.Len(), 0)
|
||||
// making queue to overflow while readers are blocked
|
||||
for i := 0; i < maxLen+1; i++ {
|
||||
queue.Send(increase)
|
||||
}
|
||||
// check that queue was halved after overflow
|
||||
require.Equal(t, (maxLen+1)/2+1, queue.batcher.Len())
|
||||
// unblocking maxReaders waiting + then we should also unblock the new readers to do a bit more readings
|
||||
for i := 0; i < int(expectedCounter); i++ {
|
||||
blocker <- struct{}{}
|
||||
}
|
||||
// waiting for all readers to finish adding
|
||||
for i := 0; i < int(expectedCounter)-maxReaders; i++ {
|
||||
<-waiter
|
||||
}
|
||||
queue.Close()
|
||||
require.Equal(t, expectedCounter, counter.Load())
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/objectsync (interfaces: ActionQueue)
|
||||
|
||||
// Package mock_objectsync is a generated GoMock package.
|
||||
package mock_objectsync
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
objectsync "github.com/anytypeio/any-sync/commonspace/objectsync"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockActionQueue is a mock of ActionQueue interface.
|
||||
type MockActionQueue struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockActionQueueMockRecorder
|
||||
}
|
||||
|
||||
// MockActionQueueMockRecorder is the mock recorder for MockActionQueue.
|
||||
type MockActionQueueMockRecorder struct {
|
||||
mock *MockActionQueue
|
||||
}
|
||||
|
||||
// NewMockActionQueue creates a new mock instance.
|
||||
func NewMockActionQueue(ctrl *gomock.Controller) *MockActionQueue {
|
||||
mock := &MockActionQueue{ctrl: ctrl}
|
||||
mock.recorder = &MockActionQueueMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockActionQueue) EXPECT() *MockActionQueueMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockActionQueue) Close() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close")
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockActionQueueMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockActionQueue)(nil).Close))
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockActionQueue) Run() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Run")
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockActionQueueMockRecorder) Run() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockActionQueue)(nil).Run))
|
||||
}
|
||||
|
||||
// Send mocks base method.
|
||||
func (m *MockActionQueue) Send(arg0 objectsync.ActionFunc) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Send", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Send indicates an expected call of Send.
|
||||
func (mr *MockActionQueueMockRecorder) Send(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockActionQueue)(nil).Send), arg0)
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StreamManager interface {
|
||||
|
@ -37,7 +38,6 @@ type messagePool struct {
|
|||
waiters map[string]responseWaiter
|
||||
waitersMx sync.Mutex
|
||||
counter atomic.Uint64
|
||||
queue ActionQueue
|
||||
}
|
||||
|
||||
func newMessagePool(streamManager StreamManager, messageHandler MessageHandler) MessagePool {
|
||||
|
@ -45,15 +45,17 @@ func newMessagePool(streamManager StreamManager, messageHandler MessageHandler)
|
|||
StreamManager: streamManager,
|
||||
messageHandler: messageHandler,
|
||||
waiters: make(map[string]responseWaiter),
|
||||
queue: NewDefaultActionQueue(),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
newCounter := s.counter.Add(1)
|
||||
msg.ReplyId = genReplyKey(peerId, msg.ObjectId, newCounter)
|
||||
|
||||
log.Info("mpool sendSync", zap.String("replyId", msg.ReplyId))
|
||||
s.waitersMx.Lock()
|
||||
waiter := responseWaiter{
|
||||
ch: make(chan *spacesyncproto.ObjectSyncMessage, 1),
|
||||
|
@ -81,19 +83,14 @@ func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyn
|
|||
|
||||
func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
if msg.ReplyId != "" {
|
||||
log.Info("mpool receive reply", zap.String("replyId", msg.ReplyId))
|
||||
// we got reply, send it to waiter
|
||||
if s.stopWaiter(msg) {
|
||||
return
|
||||
}
|
||||
log.With(zap.String("replyId", msg.ReplyId)).Debug("reply id does not exist")
|
||||
return
|
||||
}
|
||||
return s.queue.Send(func() error {
|
||||
if e := s.messageHandler(ctx, senderId, msg); e != nil {
|
||||
log.Info("handle message error", zap.Error(e))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return s.messageHandler(ctx, senderId, msg)
|
||||
}
|
||||
|
||||
func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anytypeio/any-sync/commonspace/objectsync ActionQueue
|
||||
package objectsync
|
||||
|
||||
import (
|
||||
|
@ -18,58 +17,57 @@ type ObjectSync interface {
|
|||
ocache.ObjectLastUsage
|
||||
synchandler.SyncHandler
|
||||
MessagePool() MessagePool
|
||||
ActionQueue() ActionQueue
|
||||
|
||||
Init(getter syncobjectgetter.SyncObjectGetter)
|
||||
Init()
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type objectSync struct {
|
||||
spaceId string
|
||||
|
||||
streamPool MessagePool
|
||||
messagePool MessagePool
|
||||
objectGetter syncobjectgetter.SyncObjectGetter
|
||||
actionQueue ActionQueue
|
||||
|
||||
syncCtx context.Context
|
||||
cancelSync context.CancelFunc
|
||||
}
|
||||
|
||||
func NewObjectSync(streamManager StreamManager, spaceId string) (objectSync ObjectSync) {
|
||||
msgPool := newMessagePool(streamManager, func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
return objectSync.HandleMessage(ctx, senderId, message)
|
||||
})
|
||||
func NewObjectSync(
|
||||
spaceId string,
|
||||
streamManager StreamManager,
|
||||
objectGetter syncobjectgetter.SyncObjectGetter) ObjectSync {
|
||||
syncCtx, cancel := context.WithCancel(context.Background())
|
||||
objectSync = newObjectSync(
|
||||
os := newObjectSync(
|
||||
spaceId,
|
||||
msgPool,
|
||||
objectGetter,
|
||||
syncCtx,
|
||||
cancel)
|
||||
return
|
||||
msgPool := newMessagePool(streamManager, os.handleMessage)
|
||||
os.messagePool = msgPool
|
||||
return os
|
||||
}
|
||||
|
||||
func newObjectSync(
|
||||
spaceId string,
|
||||
streamPool MessagePool,
|
||||
objectGetter syncobjectgetter.SyncObjectGetter,
|
||||
syncCtx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
) *objectSync {
|
||||
return &objectSync{
|
||||
streamPool: streamPool,
|
||||
spaceId: spaceId,
|
||||
syncCtx: syncCtx,
|
||||
cancelSync: cancel,
|
||||
actionQueue: NewDefaultActionQueue(),
|
||||
objectGetter: objectGetter,
|
||||
spaceId: spaceId,
|
||||
syncCtx: syncCtx,
|
||||
cancelSync: cancel,
|
||||
//actionQueue: NewDefaultActionQueue(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *objectSync) Init(objectGetter syncobjectgetter.SyncObjectGetter) {
|
||||
s.objectGetter = objectGetter
|
||||
s.actionQueue.Run()
|
||||
func (s *objectSync) Init() {
|
||||
//s.actionQueue.Run()
|
||||
}
|
||||
|
||||
func (s *objectSync) Close() (err error) {
|
||||
s.actionQueue.Close()
|
||||
//s.actionQueue.Close()
|
||||
s.cancelSync()
|
||||
return
|
||||
}
|
||||
|
@ -80,7 +78,11 @@ func (s *objectSync) LastUsage() time.Time {
|
|||
}
|
||||
|
||||
func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId)).Debug("handling message")
|
||||
return s.messagePool.HandleMessage(ctx, senderId, message)
|
||||
}
|
||||
|
||||
func (s *objectSync) handleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId), zap.String("replyId", message.ReplyId)).Debug("handling message")
|
||||
obj, err := s.objectGetter.GetObject(ctx, message.ObjectId)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -89,9 +91,5 @@ func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message
|
|||
}
|
||||
|
||||
func (s *objectSync) MessagePool() MessagePool {
|
||||
return s.streamPool
|
||||
}
|
||||
|
||||
func (s *objectSync) ActionQueue() ActionQueue {
|
||||
return s.actionQueue
|
||||
return s.messagePool
|
||||
}
|
||||
|
|
|
@ -78,8 +78,8 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
|
|||
}
|
||||
continue
|
||||
}
|
||||
if _, ok := err.(secureservice.HandshakeError); ok {
|
||||
l.Warn("listener handshake error", zap.Error(err))
|
||||
if herr, ok := err.(secureservice.HandshakeError); ok {
|
||||
l.Warn("listener handshake error", zap.Error(herr), zap.String("remoteAddr", herr.RemoteAddr()))
|
||||
continue
|
||||
}
|
||||
l.Error("listener accept error", zap.Error(err))
|
||||
|
|
|
@ -12,7 +12,18 @@ import (
|
|||
"net"
|
||||
)
|
||||
|
||||
type HandshakeError error
|
||||
type HandshakeError struct {
|
||||
remoteAddr string
|
||||
err error
|
||||
}
|
||||
|
||||
func (he HandshakeError) RemoteAddr() string {
|
||||
return he.remoteAddr
|
||||
}
|
||||
|
||||
func (he HandshakeError) Error() string {
|
||||
return he.err.Error()
|
||||
}
|
||||
|
||||
const CName = "common.net.secure"
|
||||
|
||||
|
|
|
@ -49,7 +49,10 @@ func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, er
|
|||
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
|
||||
secure, err := p.tr.SecureInbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError(err)
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
|
||||
return ctx, secure, nil
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package streampool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"storj.io/drpc"
|
||||
"sync/atomic"
|
||||
|
@ -18,11 +17,7 @@ type stream struct {
|
|||
}
|
||||
|
||||
func (sr *stream) write(msg drpc.Message) (err error) {
|
||||
defer func() {
|
||||
sr.l.Debug("write", zap.String("msg", msg.(fmt.Stringer).String()), zap.Error(err))
|
||||
}()
|
||||
if err = sr.stream.MsgSend(msg, EncodingProto); err != nil {
|
||||
sr.l.Info("stream write error", zap.Error(err))
|
||||
sr.streamClose()
|
||||
}
|
||||
return err
|
||||
|
@ -38,8 +33,7 @@ func (sr *stream) readLoop() error {
|
|||
sr.l.Info("msg receive error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
sr.l.Debug("read msg", zap.String("msg", msg.(fmt.Stringer).String()))
|
||||
if err := sr.pool.handler.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil {
|
||||
if err := sr.pool.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil {
|
||||
sr.l.Info("msg handle error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package streampool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/anytypeio/any-sync/net/pool"
|
||||
"github.com/cheggaaa/mb/v3"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/net/context"
|
||||
|
@ -42,12 +42,30 @@ type streamPool struct {
|
|||
streamIdsByPeer map[string][]uint32
|
||||
streamIdsByTag map[string][]uint32
|
||||
streams map[uint32]*stream
|
||||
opening map[string]chan struct{}
|
||||
opening map[string]*openingProcess
|
||||
exec *sendPool
|
||||
handleQueue *mb.MB[handleMessage]
|
||||
mu sync.RWMutex
|
||||
lastStreamId uint32
|
||||
}
|
||||
|
||||
type openingProcess struct {
|
||||
ch chan struct{}
|
||||
err error
|
||||
}
|
||||
type handleMessage struct {
|
||||
ctx context.Context
|
||||
msg drpc.Message
|
||||
peerId string
|
||||
}
|
||||
|
||||
func (s *streamPool) init() {
|
||||
// TODO: to config
|
||||
for i := 0; i < 10; i++ {
|
||||
go s.handleMessageLoop()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamPool) ReadStream(peerId string, drpcStream drpc.Stream, tags ...string) error {
|
||||
st := s.addStream(peerId, drpcStream, tags...)
|
||||
return st.readLoop()
|
||||
|
@ -78,7 +96,6 @@ func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...st
|
|||
for _, tag := range tags {
|
||||
s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], streamId)
|
||||
}
|
||||
st.l.Debug("stream added", zap.Strings("tags", st.tags))
|
||||
return st
|
||||
}
|
||||
|
||||
|
@ -87,7 +104,7 @@ func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.P
|
|||
for _, p := range peers {
|
||||
funcs = append(funcs, func() {
|
||||
if e := s.sendOne(ctx, p, msg); e != nil {
|
||||
log.Info("send peer error", zap.Error(e))
|
||||
log.Info("send peer error", zap.Error(e), zap.String("peerId", p.Id()))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -103,12 +120,11 @@ func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ...
|
|||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
log.Debug("sendById", zap.String("msg", msg.(fmt.Stringer).String()), zap.Int("streams", len(streams)))
|
||||
var funcs []func()
|
||||
for _, st := range streams {
|
||||
funcs = append(funcs, func() {
|
||||
if e := st.write(msg); e != nil {
|
||||
log.Debug("sendById write error", zap.Error(e))
|
||||
st.l.Debug("sendById write error", zap.Error(e))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -126,7 +142,7 @@ func (s *streamPool) sendOne(ctx context.Context, p peer.Peer, msg drpc.Message)
|
|||
}
|
||||
for _, st := range streams {
|
||||
if err = st.write(msg); err != nil {
|
||||
log.Info("stream write error", zap.Error(err))
|
||||
st.l.Info("sendOne write error", zap.Error(err))
|
||||
// continue with next stream
|
||||
continue
|
||||
} else {
|
||||
|
@ -144,18 +160,21 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st
|
|||
for _, streamId := range streamIds {
|
||||
streams = append(streams, s.streams[streamId])
|
||||
}
|
||||
var openingCh chan struct{}
|
||||
var op *openingProcess
|
||||
// no cached streams found
|
||||
if len(streams) == 0 {
|
||||
// start opening process
|
||||
openingCh = s.openStream(ctx, p)
|
||||
op = s.openStream(ctx, p)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// not empty openingCh means we should wait for the stream opening and try again
|
||||
if openingCh != nil {
|
||||
if op != nil {
|
||||
select {
|
||||
case <-openingCh:
|
||||
case <-op.ch:
|
||||
if op.err != nil {
|
||||
return nil, op.err
|
||||
}
|
||||
return s.getStreams(ctx, p)
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
@ -164,30 +183,32 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st
|
|||
return streams, nil
|
||||
}
|
||||
|
||||
func (s *streamPool) openStream(ctx context.Context, p peer.Peer) chan struct{} {
|
||||
if ch, ok := s.opening[p.Id()]; ok {
|
||||
func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProcess {
|
||||
if op, ok := s.opening[p.Id()]; ok {
|
||||
// already have an opening process for this stream - return channel
|
||||
return ch
|
||||
return op
|
||||
}
|
||||
ch := make(chan struct{})
|
||||
s.opening[p.Id()] = ch
|
||||
op := &openingProcess{
|
||||
ch: make(chan struct{}),
|
||||
}
|
||||
s.opening[p.Id()] = op
|
||||
go func() {
|
||||
// start stream opening in separate goroutine to avoid lock whole pool
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
close(ch)
|
||||
close(op.ch)
|
||||
delete(s.opening, p.Id())
|
||||
}()
|
||||
// open new stream and add to pool
|
||||
st, tags, err := s.handler.OpenStream(ctx, p)
|
||||
if err != nil {
|
||||
log.Warn("stream open error", zap.Error(err))
|
||||
op.err = err
|
||||
return
|
||||
}
|
||||
s.AddStream(p.Id(), st, tags...)
|
||||
}()
|
||||
return ch
|
||||
return op
|
||||
}
|
||||
|
||||
func (s *streamPool) Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) {
|
||||
|
@ -244,6 +265,28 @@ func (s *streamPool) removeStream(streamId uint32) {
|
|||
st.l.Debug("stream removed", zap.Strings("tags", st.tags))
|
||||
}
|
||||
|
||||
func (s *streamPool) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) {
|
||||
return s.handleQueue.Add(ctx, handleMessage{
|
||||
ctx: ctx,
|
||||
msg: msg,
|
||||
peerId: peerId,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *streamPool) handleMessageLoop() {
|
||||
for {
|
||||
hm, err := s.handleQueue.WaitOne(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if err = s.handler.HandleMessage(hm.ctx, hm.peerId, hm.msg); err != nil {
|
||||
log.Warn("handle message error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamPool) Close() (err error) {
|
||||
return s.exec.Close()
|
||||
}
|
||||
|
|
|
@ -18,22 +18,23 @@ import (
|
|||
|
||||
var ctx = context.Background()
|
||||
|
||||
func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
|
||||
p, err := fx.tp.Dial(ctx, peerId)
|
||||
require.NoError(t, err)
|
||||
s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
|
||||
require.NoError(t, err)
|
||||
return s, p
|
||||
}
|
||||
|
||||
func TestStreamPool_AddStream(t *testing.T) {
|
||||
newClientStream := func(fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
|
||||
p, err := fx.tp.Dial(ctx, peerId)
|
||||
require.NoError(t, err)
|
||||
s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
|
||||
require.NoError(t, err)
|
||||
return s, p
|
||||
}
|
||||
|
||||
t.Run("broadcast incoming", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
s1, _ := newClientStream(fx, "p1")
|
||||
s1, _ := newClientStream(t, fx, "p1")
|
||||
fx.AddStream("p1", s1, "space1", "common")
|
||||
s2, _ := newClientStream(fx, "p2")
|
||||
s2, _ := newClientStream(t, fx, "p2")
|
||||
fx.AddStream("p2", s2, "space2", "common")
|
||||
|
||||
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1"))
|
||||
|
@ -61,7 +62,7 @@ func TestStreamPool_AddStream(t *testing.T) {
|
|||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
s1, p1 := newClientStream(fx, "p1")
|
||||
s1, p1 := newClientStream(t, fx, "p1")
|
||||
defer s1.Close()
|
||||
fx.AddStream("p1", s1, "space1", "common")
|
||||
|
||||
|
@ -122,6 +123,46 @@ func TestStreamPool_Send(t *testing.T) {
|
|||
// make sure that we have only one stream
|
||||
assert.Equal(t, int32(1), fx.tsh.streamsCount.Load())
|
||||
})
|
||||
t.Run("parallel open stream error", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
p, err := fx.tp.Dial(ctx, "p1")
|
||||
require.NoError(t, err)
|
||||
_ = p.Close()
|
||||
|
||||
fx.th.streamOpenDelay = time.Second / 3
|
||||
|
||||
var numMsgs = 5
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"}))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_SendById(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
s1, _ := newClientStream(t, fx, "p1")
|
||||
defer s1.Close()
|
||||
fx.AddStream("p1", s1, "space1", "common")
|
||||
|
||||
require.NoError(t, fx.SendById(ctx, &testservice.StreamMessage{ReqData: "test"}, "p1"))
|
||||
var msg *testservice.StreamMessage
|
||||
select {
|
||||
case msg = <-fx.tsh.receiveCh:
|
||||
case <-time.After(time.Second):
|
||||
require.NoError(t, fmt.Errorf("timeout"))
|
||||
}
|
||||
assert.Equal(t, "test", msg.ReqData)
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) *fixture {
|
||||
|
|
|
@ -3,6 +3,7 @@ package streampool
|
|||
import (
|
||||
"github.com/anytypeio/any-sync/app"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/cheggaaa/mb/v3"
|
||||
)
|
||||
|
||||
const CName = "common.net.streampool"
|
||||
|
@ -22,15 +23,17 @@ type service struct {
|
|||
}
|
||||
|
||||
func (s *service) NewStreamPool(h StreamHandler) StreamPool {
|
||||
return &streamPool{
|
||||
sp := &streamPool{
|
||||
handler: h,
|
||||
streamIdsByPeer: map[string][]uint32{},
|
||||
streamIdsByTag: map[string][]uint32{},
|
||||
streams: map[uint32]*stream{},
|
||||
opening: map[string]chan struct{}{},
|
||||
opening: map[string]*openingProcess{},
|
||||
exec: newStreamSender(10, 100),
|
||||
lastStreamId: 0,
|
||||
handleQueue: mb.New[handleMessage](100),
|
||||
}
|
||||
sp.init()
|
||||
return sp
|
||||
}
|
||||
|
||||
func (s *service) Init(a *app.App) (err error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue