From 2edd9c7b6dae742003c8bddd3bf253fb6868f65d Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Fri, 11 Aug 2023 15:34:05 +0200 Subject: [PATCH] addr scheme + fixes + peerservice test --- net/peerservice/peerservice.go | 72 +++++++-- net/peerservice/peerservice_test.go | 170 ++++++++++++++++++++++ net/transport/mock_transport/component.go | 26 ++++ net/transport/quic/config.go | 9 +- net/transport/quic/conn.go | 10 +- net/transport/quic/quic.go | 19 ++- net/transport/quic/quic_test.go | 2 +- net/transport/transport.go | 5 + net/transport/yamux/conn.go | 4 +- 9 files changed, 296 insertions(+), 21 deletions(-) create mode 100644 net/peerservice/peerservice_test.go create mode 100644 net/transport/mock_transport/component.go diff --git a/net/peerservice/peerservice.go b/net/peerservice/peerservice.go index dae7bf76..18e6d692 100644 --- a/net/peerservice/peerservice.go +++ b/net/peerservice/peerservice.go @@ -9,9 +9,11 @@ import ( "github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/rpc/server" "github.com/anyproto/any-sync/net/transport" + "github.com/anyproto/any-sync/net/transport/quic" "github.com/anyproto/any-sync/net/transport/yamux" "github.com/anyproto/any-sync/nodeconf" "go.uber.org/zap" + "strings" "sync" ) @@ -31,33 +33,49 @@ func New() PeerService { type PeerService interface { Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) SetPeerAddrs(peerId string, addrs []string) + PreferQuic(prefer bool) transport.Accepter app.Component } type peerService struct { - yamux transport.Transport - nodeConf nodeconf.NodeConf - peerAddrs map[string][]string - pool pool.Pool - server server.DRPCServer - mu sync.RWMutex + yamux transport.Transport + quic transport.Transport + nodeConf nodeconf.NodeConf + peerAddrs map[string][]string + pool pool.Pool + server server.DRPCServer + preferQuic bool + mu sync.RWMutex } func (p *peerService) Init(a *app.App) (err error) { p.yamux = a.MustComponent(yamux.CName).(transport.Transport) + p.quic = a.MustComponent(quic.CName).(transport.Transport) p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) p.pool = a.MustComponent(pool.CName).(pool.Pool) p.server = a.MustComponent(server.CName).(server.DRPCServer) p.peerAddrs = map[string][]string{} p.yamux.SetAccepter(p) + p.quic.SetAccepter(p) return nil } +var ( + yamuxPreferSchemes = []string{transport.Yamux, transport.Quic} + quicPreferSchemes = []string{transport.Quic, transport.Yamux} +) + func (p *peerService) Name() (name string) { return CName } +func (p *peerService) PreferQuic(prefer bool) { + p.mu.Lock() + p.preferQuic = prefer + p.mu.Unlock() +} + func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) { p.mu.RLock() defer p.mu.RUnlock() @@ -69,11 +87,29 @@ func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, er var mc transport.MultiConn log.DebugCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) - for _, addr := range addrs { - mc, err = p.yamux.Dial(ctx, addr) - if err != nil { - log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) - } else { + + var schemes = yamuxPreferSchemes + if p.preferQuic { + schemes = quicPreferSchemes + } + + for _, sch := range schemes { + for _, addr := range addrs { + if scheme(addr) != sch { + continue + } + if sch == transport.Quic { + mc, err = p.quic.Dial(ctx, stripScheme(addr)) + } else { + mc, err = p.yamux.Dial(ctx, stripScheme(addr)) + } + if err != nil { + log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) + } else { + break + } + } + if err == nil { break } } @@ -117,3 +153,17 @@ func (p *peerService) getPeerAddrs(peerId string) ([]string, error) { } return addrs, nil } + +func scheme(addr string) string { + if idx := strings.Index(addr, "://"); idx != -1 { + return addr[:idx] + } + return transport.Yamux +} + +func stripScheme(addr string) string { + if idx := strings.Index(addr, "://"); idx != -1 { + return addr[idx+3:] + } + return addr +} diff --git a/net/peerservice/peerservice_test.go b/net/peerservice/peerservice_test.go new file mode 100644 index 00000000..0856f5bc --- /dev/null +++ b/net/peerservice/peerservice_test.go @@ -0,0 +1,170 @@ +package peerservice + +import ( + "context" + "fmt" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/pool" + "github.com/anyproto/any-sync/net/rpc/rpctest" + "github.com/anyproto/any-sync/net/transport/mock_transport" + "github.com/anyproto/any-sync/net/transport/quic" + "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/nodeconf/mock_nodeconf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "testing" +) + +var ctx = context.Background() + +func TestPeerService_Dial(t *testing.T) { + var addrs = []string{ + "yamux://127.0.0.1:1111", + "quic://127.0.0.1:1112", + } + t.Run("prefer yamux", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(false) + var peerId = "p1" + + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return(addrs, true) + + fx.yamux.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1111").Return(fx.mockMC(peerId), nil) + + p, err := fx.Dial(ctx, peerId) + require.NoError(t, err) + assert.NotNil(t, p) + }) + t.Run("prefer quic", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(true) + var peerId = "p1" + + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return(addrs, true) + + fx.quic.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1112").Return(fx.mockMC(peerId), nil) + + p, err := fx.Dial(ctx, peerId) + require.NoError(t, err) + assert.NotNil(t, p) + }) + t.Run("first failed", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(true) + var peerId = "p1" + + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return(addrs, true) + + fx.quic.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1112").Return(nil, fmt.Errorf("test")) + fx.yamux.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1111").Return(fx.mockMC(peerId), nil) + + p, err := fx.Dial(ctx, peerId) + require.NoError(t, err) + assert.NotNil(t, p) + }) + t.Run("peerId mismatched", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(false) + var peerId = "p1" + + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return(addrs, true) + + fx.yamux.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1111").Return(fx.mockMC(peerId+"not valid"), nil) + + p, err := fx.Dial(ctx, peerId) + assert.EqualError(t, err, ErrPeerIdMismatched.Error()) + assert.Nil(t, p) + }) + t.Run("custom addr", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(false) + var peerId = "p1" + + fx.SetPeerAddrs(peerId, addrs) + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return(nil, false) + + fx.yamux.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1111").Return(fx.mockMC(peerId), nil) + + p, err := fx.Dial(ctx, peerId) + require.NoError(t, err) + assert.NotNil(t, p) + }) + t.Run("addr without scheme", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.PreferQuic(false) + var peerId = "p1" + + fx.nodeConf.EXPECT().PeerAddresses(peerId).Return([]string{"127.0.0.1:1111"}, true) + + fx.yamux.MockTransport.EXPECT().Dial(ctx, "127.0.0.1:1111").Return(fx.mockMC(peerId), nil) + + p, err := fx.Dial(ctx, peerId) + require.NoError(t, err) + assert.NotNil(t, p) + }) +} + +func TestPeerService_Accept(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + + mc := fx.mockMC("p1") + require.NoError(t, fx.Accept(mc)) +} + +type fixture struct { + PeerService + a *app.App + ctrl *gomock.Controller + quic mock_transport.TransportComponent + yamux mock_transport.TransportComponent + nodeConf *mock_nodeconf.MockService +} + +func newFixture(t *testing.T) *fixture { + ctrl := gomock.NewController(t) + fx := &fixture{ + PeerService: New(), + ctrl: ctrl, + a: new(app.App), + quic: mock_transport.NewTransportComponent(ctrl, quic.CName), + yamux: mock_transport.NewTransportComponent(ctrl, yamux.CName), + nodeConf: mock_nodeconf.NewMockService(ctrl), + } + + fx.quic.EXPECT().SetAccepter(fx.PeerService) + fx.yamux.EXPECT().SetAccepter(fx.PeerService) + + fx.nodeConf.EXPECT().Name().Return(nodeconf.CName).AnyTimes() + fx.nodeConf.EXPECT().Init(gomock.Any()) + fx.nodeConf.EXPECT().Run(gomock.Any()) + fx.nodeConf.EXPECT().Close(gomock.Any()) + + fx.a.Register(fx.PeerService).Register(fx.quic).Register(fx.yamux).Register(fx.nodeConf).Register(pool.New()).Register(rpctest.NewTestServer()) + + require.NoError(t, fx.a.Start(ctx)) + return fx +} + +func (fx *fixture) mockMC(peerId string) *mock_transport.MockMultiConn { + mc := mock_transport.NewMockMultiConn(fx.ctrl) + cctx := peer.CtxWithPeerId(ctx, peerId) + mc.EXPECT().Context().Return(cctx).AnyTimes() + mc.EXPECT().Accept().Return(nil, fmt.Errorf("test")).AnyTimes() + mc.EXPECT().Close().AnyTimes() + return mc +} + +func (fx *fixture) finish(t *testing.T) { + require.NoError(t, fx.a.Close(ctx)) + fx.ctrl.Finish() +} diff --git a/net/transport/mock_transport/component.go b/net/transport/mock_transport/component.go new file mode 100644 index 00000000..2d572501 --- /dev/null +++ b/net/transport/mock_transport/component.go @@ -0,0 +1,26 @@ +package mock_transport + +import ( + "github.com/anyproto/any-sync/app" + "go.uber.org/mock/gomock" +) + +func NewTransportComponent(ctrl *gomock.Controller, name string) TransportComponent { + return TransportComponent{ + CName: name, + MockTransport: NewMockTransport(ctrl), + } +} + +type TransportComponent struct { + CName string + *MockTransport +} + +func (t TransportComponent) Init(a *app.App) (err error) { + return nil +} + +func (t TransportComponent) Name() (name string) { + return t.CName +} diff --git a/net/transport/quic/config.go b/net/transport/quic/config.go index a08dbbda..2292a260 100644 --- a/net/transport/quic/config.go +++ b/net/transport/quic/config.go @@ -5,8 +5,9 @@ type configGetter interface { } type Config struct { - ListenAddrs []string `yaml:"listenAddrs"` - WriteTimeoutSec int `yaml:"writeTimeoutSec"` - DialTimeoutSec int `yaml:"dialTimeoutSec"` - MaxStreams int64 `yaml:"maxStreams"` + ListenAddrs []string `yaml:"listenAddrs"` + WriteTimeoutSec int `yaml:"writeTimeoutSec"` + DialTimeoutSec int `yaml:"dialTimeoutSec"` + MaxStreams int64 `yaml:"maxStreams"` + KeepAlivePeriodSec int `yaml:"keepAlivePeriodSec"` } diff --git a/net/transport/quic/conn.go b/net/transport/quic/conn.go index 5ed704af..33df01d6 100644 --- a/net/transport/quic/conn.go +++ b/net/transport/quic/conn.go @@ -2,12 +2,15 @@ package quic import ( "context" + "errors" + "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/transport" "github.com/quic-go/quic-go" "net" ) func newConn(cctx context.Context, qconn quic.Connection) transport.MultiConn { + cctx = peer.CtxWithPeerAddr(cctx, transport.Quic+"://"+qconn.RemoteAddr().String()) return &quicMultiConn{ cctx: cctx, Connection: qconn, @@ -26,6 +29,11 @@ func (q *quicMultiConn) Context() context.Context { func (q *quicMultiConn) Accept() (conn net.Conn, err error) { stream, err := q.Connection.AcceptStream(context.Background()) if err != nil { + if errors.Is(err, quic.ErrServerClosed) { + err = transport.ErrConnClosed + } else if aerr, ok := err.(*quic.ApplicationError); ok && aerr.ErrorCode == 2 { + err = transport.ErrConnClosed + } return nil, err } return quicNetConn{ @@ -48,7 +56,7 @@ func (q *quicMultiConn) Open(ctx context.Context) (conn net.Conn, err error) { } func (q *quicMultiConn) Addr() string { - return q.RemoteAddr().String() + return transport.Quic + "://" + q.RemoteAddr().String() } func (q *quicMultiConn) IsClosed() bool { diff --git a/net/transport/quic/quic.go b/net/transport/quic/quic.go index 5024e38e..8213e9b2 100644 --- a/net/transport/quic/quic.go +++ b/net/transport/quic/quic.go @@ -2,6 +2,7 @@ package quic import ( "context" + "crypto/tls" "fmt" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" @@ -42,9 +43,16 @@ type quicTransport struct { func (q *quicTransport) Init(a *app.App) (err error) { q.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService) q.conf = a.MustComponent("config").(configGetter).GetQuic() + if q.conf.MaxStreams <= 0 { + q.conf.MaxStreams = 128 + } + if q.conf.KeepAlivePeriodSec <= 0 { + q.conf.KeepAlivePeriodSec = 25 + } q.quicConf = &quic.Config{ HandshakeIdleTimeout: time.Duration(q.conf.DialTimeoutSec) * time.Second, MaxIncomingStreams: q.conf.MaxStreams, + KeepAlivePeriod: time.Duration(q.conf.KeepAlivePeriodSec) * time.Second, } return } @@ -61,12 +69,19 @@ func (q *quicTransport) Run(ctx context.Context) (err error) { if q.accepter == nil { return fmt.Errorf("can't run service without accepter") } - tlConf, _, err := q.secure.TlsConfig() + + var tlsConf tls.Config + tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + conf, _, tlsErr := q.secure.TlsConfig() + return conf, tlsErr + } + tlsConf.NextProtos = []string{"anysync"} + if err != nil { return } for _, listAddr := range q.conf.ListenAddrs { - list, err := quic.ListenAddr(listAddr, tlConf, q.quicConf) + list, err := quic.ListenAddr(listAddr, &tlsConf, q.quicConf) if err != nil { return err } diff --git a/net/transport/quic/quic_test.go b/net/transport/quic/quic_test.go index 2d47aa8c..7ed01d75 100644 --- a/net/transport/quic/quic_test.go +++ b/net/transport/quic/quic_test.go @@ -74,7 +74,7 @@ func TestQuicTransport_Dial(t *testing.T) { // common write deadline - 66700 rps // subconn write deadline - 67100 rps func TestWriteBenchReuse(t *testing.T) { - t.Skip() + //t.Skip() var ( numSubConn = 10 numWrites = 10000 diff --git a/net/transport/transport.go b/net/transport/transport.go index 9d36435a..6697a65c 100644 --- a/net/transport/transport.go +++ b/net/transport/transport.go @@ -11,6 +11,11 @@ var ( ErrConnClosed = errors.New("connection closed") ) +const ( + Yamux = "yamux" + Quic = "quic" +) + // Transport is a common interface for a network transport type Transport interface { // SetAccepter sets accepter that will be called for new connections diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go index 6a473796..c34497b2 100644 --- a/net/transport/yamux/conn.go +++ b/net/transport/yamux/conn.go @@ -12,7 +12,7 @@ import ( ) func NewMultiConn(cctx context.Context, luConn *connutil.LastUsageConn, addr string, sess *yamux.Session) transport.MultiConn { - cctx = peer.CtxWithPeerAddr(cctx, sess.RemoteAddr().String()) + cctx = peer.CtxWithPeerAddr(cctx, transport.Yamux+"://"+sess.RemoteAddr().String()) return &yamuxConn{ ctx: cctx, luConn: luConn, @@ -44,7 +44,7 @@ func (y *yamuxConn) Context() context.Context { } func (y *yamuxConn) Addr() string { - return y.addr + return transport.Yamux + "://" + y.addr } func (y *yamuxConn) Accept() (conn net.Conn, err error) {