From 43383853e13f38c122e3ffc174a05e16a0d2a396 Mon Sep 17 00:00:00 2001 From: mcrakhman Date: Thu, 6 Jun 2024 00:37:01 +0200 Subject: [PATCH] Add global pool for peers --- commonspace/sync/sync_test.go | 15 ++-- commonspace/sync/synctest/connprovider.go | 34 +------ commonspace/sync/synctest/countergenerator.go | 3 +- commonspace/sync/synctest/peerglobalpool.go | 89 +++++++++++++++++++ commonspace/sync/synctest/peerprovider.go | 32 ++++--- 5 files changed, 117 insertions(+), 56 deletions(-) create mode 100644 commonspace/sync/synctest/peerglobalpool.go diff --git a/commonspace/sync/sync_test.go b/commonspace/sync/sync_test.go index 0cd8974c..08751ee4 100644 --- a/commonspace/sync/sync_test.go +++ b/commonspace/sync/sync_test.go @@ -16,10 +16,11 @@ import ( var ctx = context.Background() func TestNewSyncService(t *testing.T) { - connProvider := synctest.NewConnProvider([]string{"first", "second"}) + peerPool := synctest.NewPeerGlobalPool([]string{"first", "second"}) + peerPool.MakePeers() var ( - firstApp = newFixture(t, "first", counterFixtureParams{connProvider: connProvider, start: 0, delta: 2}) - secondApp = newFixture(t, "second", counterFixtureParams{connProvider: connProvider, start: 1, delta: 2}) + firstApp = newFixture(t, "first", counterFixtureParams{peerPool: peerPool, start: 0, delta: 2}) + secondApp = newFixture(t, "second", counterFixtureParams{peerPool: peerPool, start: 1, delta: 2}) ) require.NoError(t, firstApp.a.Start(ctx)) require.NoError(t, secondApp.a.Start(ctx)) @@ -33,14 +34,14 @@ type counterFixture struct { } type counterFixtureParams struct { - connProvider *synctest.ConnProvider - start int32 - delta int32 + peerPool *synctest.PeerGlobalPool + start int32 + delta int32 } func newFixture(t *testing.T, peerId string, params counterFixtureParams) *counterFixture { a := &app.App{} - a.Register(params.connProvider). + a.Register(params.peerPool). Register(synctest.NewConfig()). Register(rpctest.NewTestServer()). Register(synctest.NewCounterStreamOpener()). diff --git a/commonspace/sync/synctest/connprovider.go b/commonspace/sync/synctest/connprovider.go index c684a2f4..b4db7738 100644 --- a/commonspace/sync/synctest/connprovider.go +++ b/commonspace/sync/synctest/connprovider.go @@ -3,44 +3,18 @@ package synctest import ( "sync" - "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/net/rpc/rpctest" "github.com/anyproto/any-sync/net/transport" ) -const ConnName = "connprovider" - type ConnProvider struct { sync.Mutex multiConns map[string]transport.MultiConn - providers map[string]*PeerProvider - peerIds []string -} - -func (c *ConnProvider) Init(a *app.App) (err error) { - return -} - -func (c *ConnProvider) Name() (name string) { - return ConnName -} - -func (c *ConnProvider) Observe(provider *PeerProvider, peerId string) { - c.Lock() - defer c.Unlock() - c.providers[peerId] = provider -} - -func (c *ConnProvider) GetPeerIds() []string { - return c.peerIds } func (c *ConnProvider) GetConn(firstId, secondId string) (conn transport.MultiConn) { c.Lock() defer c.Unlock() - if firstId == secondId { - panic("cannot connect to self") - } id := mapId(firstId, secondId) if conn, ok := c.multiConns[id]; ok { return conn @@ -48,18 +22,12 @@ func (c *ConnProvider) GetConn(firstId, secondId string) (conn transport.MultiCo first, second := rpctest.MultiConnPair(firstId, secondId) c.multiConns[id] = second c.multiConns[mapId(secondId, firstId)] = first - err := c.providers[secondId].StartPeer(secondId, second) - if err != nil { - panic(err) - } return second } -func NewConnProvider(peerIds []string) *ConnProvider { +func NewConnProvider() *ConnProvider { return &ConnProvider{ - peerIds: peerIds, multiConns: make(map[string]transport.MultiConn), - providers: make(map[string]*PeerProvider), } } diff --git a/commonspace/sync/synctest/countergenerator.go b/commonspace/sync/synctest/countergenerator.go index 17ec19ba..88d05ae0 100644 --- a/commonspace/sync/synctest/countergenerator.go +++ b/commonspace/sync/synctest/countergenerator.go @@ -33,7 +33,6 @@ func NewCounterGenerator() *CounterGenerator { func (c *CounterGenerator) Init(a *app.App) (err error) { c.counter = a.MustComponent(CounterName).(*Counter) - c.connProvider = a.MustComponent(ConnName).(*ConnProvider) c.peerProvider = a.MustComponent(PeerName).(*PeerProvider) c.ownId = c.peerProvider.myPeer c.streamPool = a.MustComponent(streampool.CName).(streampool.StreamPool) @@ -53,7 +52,7 @@ func (c *CounterGenerator) update(ctx context.Context) error { Value: res, ObjectId: "counter", }, func(ctx context.Context) (peers []peer.Peer, err error) { - for _, peerId := range c.connProvider.GetPeerIds() { + for _, peerId := range c.peerProvider.GetPeerIds() { if peerId == c.ownId { continue } diff --git a/commonspace/sync/synctest/peerglobalpool.go b/commonspace/sync/synctest/peerglobalpool.go new file mode 100644 index 00000000..9c460faa --- /dev/null +++ b/commonspace/sync/synctest/peerglobalpool.go @@ -0,0 +1,89 @@ +package synctest + +import ( + "context" + "fmt" + "net" + "strings" + "sync" + + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/rpc" +) + +const PeerGlobalName = "peerglobalpool" + +type connCtrl interface { + ServeConn(ctx context.Context, conn net.Conn) (err error) + DrpcConfig() rpc.Config +} + +type connCtrlWrapper struct { + connCtrl +} + +type PeerGlobalPool struct { + ctrls map[string]*connCtrlWrapper + peers map[string]peer.Peer + peerIds []string + connProvider *ConnProvider + sync.Mutex +} + +func NewPeerGlobalPool(peerIds []string) *PeerGlobalPool { + return &PeerGlobalPool{ + peerIds: peerIds, + ctrls: make(map[string]*connCtrlWrapper), + peers: make(map[string]peer.Peer), + connProvider: NewConnProvider(), + } +} + +func (p *PeerGlobalPool) Init(a *app.App) (err error) { + return nil +} + +func (p *PeerGlobalPool) Name() (name string) { + return PeerGlobalName +} + +func (p *PeerGlobalPool) MakePeers() { + p.Lock() + defer p.Unlock() + for _, first := range p.peerIds { + for _, second := range p.peerIds { + if first == second { + continue + } + id := mapId(first, second) + p.ctrls[id] = &connCtrlWrapper{} + conn := p.connProvider.GetConn(first, second) + p.peers[id], _ = peer.NewPeer(conn, p.ctrls[id]) + } + } +} + +func (c *PeerGlobalPool) GetPeerIds() (peerIds []string) { + return c.peerIds +} + +func (p *PeerGlobalPool) AddCtrl(peerId string, addCtrl connCtrl) { + p.Lock() + defer p.Unlock() + for id, ctrl := range p.ctrls { + splitId := strings.Split(id, "-") + if splitId[0] == peerId { + ctrl.connCtrl = addCtrl + } + } +} + +func (p *PeerGlobalPool) GetPeer(id string) (peer.Peer, error) { + p.Lock() + defer p.Unlock() + if pr, ok := p.peers[id]; ok { + return pr, nil + } + return nil, fmt.Errorf("peer not found") +} diff --git a/commonspace/sync/synctest/peerprovider.go b/commonspace/sync/synctest/peerprovider.go index 42f4c0d7..bdfe2b30 100644 --- a/commonspace/sync/synctest/peerprovider.go +++ b/commonspace/sync/synctest/peerprovider.go @@ -1,29 +1,37 @@ package synctest import ( + "context" "sync" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/rpc/rpctest" "github.com/anyproto/any-sync/net/rpc/server" - "github.com/anyproto/any-sync/net/transport" ) const PeerName = "peerprovider" type PeerProvider struct { sync.Mutex - myPeer string - peers map[string]peer.Peer - connProvider *ConnProvider - server *rpctest.TestServer + myPeer string + peers map[string]peer.Peer + pool *PeerGlobalPool + server *rpctest.TestServer +} + +func (c *PeerProvider) Run(ctx context.Context) (err error) { + c.pool.AddCtrl(c.myPeer, c.server) + return nil +} + +func (c *PeerProvider) Close(ctx context.Context) (err error) { + return nil } func (c *PeerProvider) Init(a *app.App) (err error) { - c.connProvider = a.MustComponent(ConnName).(*ConnProvider) + c.pool = a.MustComponent(PeerGlobalName).(*PeerGlobalPool) c.server = a.MustComponent(server.CName).(*rpctest.TestServer) - c.connProvider.Observe(c, c.myPeer) return } @@ -31,11 +39,8 @@ func (c *PeerProvider) Name() (name string) { return PeerName } -func (c *PeerProvider) StartPeer(peerId string, conn transport.MultiConn) (err error) { - c.Lock() - defer c.Unlock() - c.peers[peerId], err = peer.NewPeer(conn, c.server) - return err +func (c *PeerProvider) GetPeerIds() (peerIds []string) { + return c.pool.GetPeerIds() } func (c *PeerProvider) GetPeer(peerId string) (pr peer.Peer, err error) { @@ -44,8 +49,7 @@ func (c *PeerProvider) GetPeer(peerId string) (pr peer.Peer, err error) { if pr, ok := c.peers[peerId]; ok { return pr, nil } - conn := c.connProvider.GetConn(c.myPeer, peerId) - c.peers[peerId], err = peer.NewPeer(conn, c.server) + c.peers[peerId], err = c.pool.GetPeer(mapId(c.myPeer, peerId)) if err != nil { return nil, err }