diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go index bf8cf2c5..98ff9ce9 100644 --- a/net/dialer/dialer.go +++ b/net/dialer/dialer.go @@ -109,7 +109,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc } timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) - sc, err = d.transport.TLSConn(ctx, timeoutConn) + sc, err = d.transport.SecureOutbound(ctx, timeoutConn) if err != nil { return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go index e5a9c9df..7a3a598f 100644 --- a/net/rpc/server/baseserver.go +++ b/net/rpc/server/baseserver.go @@ -3,6 +3,7 @@ package server import ( "context" "github.com/anytypeio/any-sync/net/secureservice" + "github.com/libp2p/go-libp2p/core/sec" "github.com/zeebo/errs" "go.uber.org/zap" "io" @@ -18,19 +19,18 @@ import ( type BaseDrpcServer struct { drpcServer *drpcserver.Server transport secureservice.SecureService - listeners []secureservice.ContextListener + listeners []net.Listener + handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) cancel func() *drpcmux.Mux } type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler -type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener type Params struct { BufferSizeMb int ListenAddrs []string Wrapper DRPCHandlerWrapper - Converter ListenerConverter TimeoutMillis int } @@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) { }}) ctx, s.cancel = context.WithCancel(ctx) for _, addr := range params.ListenAddrs { - tcpList, err := net.Listen("tcp", addr) + list, err := net.Listen("tcp", addr) if err != nil { return err } - tlsList := params.Converter(tcpList, params.TimeoutMillis) - s.listeners = append(s.listeners, tlsList) - go s.serve(ctx, tlsList) + s.listeners = append(s.listeners, list) + go s.serve(ctx, list) } return } -func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) { +func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) { l := log.With(zap.String("localAddr", lis.Addr().String())) l.Info("drpc listener started") defer func() { @@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis return default: } - cctx, conn, err := lis.Accept(ctx) + conn, err := lis.Accept() if err != nil { if isTemporary(err) { l.Debug("listener temporary accept error", zap.Error(err)) @@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis l.Error("listener accept error", zap.Error(err)) return } - go s.serveConn(cctx, conn) + go s.serveConn(conn) } } -func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) { +func (s *BaseDrpcServer) serveConn(conn net.Conn) { l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) + var ( + ctx = context.Background() + err error + ) + if s.handshake != nil { + ctx, conn, err = s.handshake(conn) + if err != nil { + l.Info("handshake error", zap.Error(err)) + } + } + l.Debug("connection opened") if err := s.drpcServer.ServeOne(ctx, conn); err != nil { if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { diff --git a/net/rpc/server/drpcserver.go b/net/rpc/server/drpcserver.go index b71dcadc..088eb777 100644 --- a/net/rpc/server/drpcserver.go +++ b/net/rpc/server/drpcserver.go @@ -7,9 +7,11 @@ import ( "github.com/anytypeio/any-sync/metric" anyNet "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/secureservice" + "github.com/libp2p/go-libp2p/core/sec" "github.com/prometheus/client_golang/prometheus" "net" "storj.io/drpc" + "time" ) const CName = "common.net.drpcserver" @@ -68,9 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) { SummaryVec: histVec, } }, - Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener { - return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake) - }, + } + s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return s.transport.SecureInbound(ctx, conn) } return s.BaseDrpcServer.Run(ctx, params) } diff --git a/net/secureservice/basiclistener.go b/net/secureservice/basiclistener.go deleted file mode 100644 index b7df9caa..00000000 --- a/net/secureservice/basiclistener.go +++ /dev/null @@ -1,26 +0,0 @@ -package secureservice - -import ( - "context" - "github.com/anytypeio/any-sync/net/timeoutconn" - "net" - "time" -) - -type basicListener struct { - net.Listener - timeoutMillis int -} - -func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener { - return &basicListener{listener, timeoutMillis} -} - -func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) { - conn, err := b.Listener.Accept() - if err != nil { - return nil, nil, err - } - timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond) - return ctx, timeoutConn, err -} diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index 48336aec..2e66ab5f 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -6,6 +6,7 @@ import ( "github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/commonspace/object/accountdata" + "github.com/anytypeio/any-sync/net/peer" "github.com/anytypeio/any-sync/net/secureservice/handshake" "github.com/anytypeio/any-sync/nodeconf" "github.com/libp2p/go-libp2p/core/crypto" @@ -37,20 +38,20 @@ func New() SecureService { } type SecureService interface { - TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener - BasicListener(lis net.Listener, timeoutMillis int) ContextListener - TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) + SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) + SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) app.Component } type secureService struct { - outboundTr *libp2ptls.Transport - account *accountdata.AccountData - key crypto.PrivKey - nodeconf nodeconf.Service + p2pTr *libp2ptls.Transport + account *accountdata.AccountData + key crypto.PrivKey + nodeconf nodeconf.Service noVerifyChecker handshake.CredentialChecker peerSignVerifier handshake.CredentialChecker + inboundChecker handshake.CredentialChecker } func (s *secureService) Init(a *app.App) (err error) { @@ -68,7 +69,14 @@ func (s *secureService) Init(a *app.App) (err error) { s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service) - if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil { + s.inboundChecker = s.noVerifyChecker + confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId) + if len(confTypes) > 0 { + // require identity verification if we are node + s.inboundChecker = s.peerSignVerifier + } + + if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil { return } @@ -80,20 +88,30 @@ func (s *secureService) Name() (name string) { return CName } -func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener { - cc := s.noVerifyChecker - if identityHandshake { - cc = s.peerSignVerifier +func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { + sc, err = s.p2pTr.SecureInbound(ctx, conn, "") + if err != nil { + return nil, nil, HandshakeError{ + remoteAddr: conn.RemoteAddr().String(), + err: err, + } } - return newTLSListener(cc, s.key, lis, timeoutMillis) + + identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker) + if err != nil { + return nil, nil, HandshakeError{ + remoteAddr: conn.RemoteAddr().String(), + err: err, + } + } + cctx = context.Background() + cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) + cctx = peer.CtxWithIdentity(cctx, identity) + return } -func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener { - return newBasicListener(lis, timeoutMillis) -} - -func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) { - sc, err := s.outboundTr.SecureOutbound(ctx, conn, "") +func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) { + sc, err := s.p2pTr.SecureOutbound(ctx, conn, "") if err != nil { return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} } diff --git a/net/secureservice/secureservice_test.go b/net/secureservice/secureservice_test.go index d91e5050..2807ef84 100644 --- a/net/secureservice/secureservice_test.go +++ b/net/secureservice/secureservice_test.go @@ -20,10 +20,8 @@ func TestHandshake(t *testing.T) { fxS := newFixture(t, nc, nc.GetAccountService(0)) defer fxS.Finish(t) - tl := &testListener{conn: make(chan net.Conn, 1)} - defer tl.Close() + sc, cc := net.Pipe() - list := fxS.TLSListener(tl, 1000, true) type acceptRes struct { ctx context.Context conn net.Conn @@ -32,16 +30,14 @@ func TestHandshake(t *testing.T) { resCh := make(chan acceptRes) go func() { var ar acceptRes - ar.ctx, ar.conn, ar.err = list.Accept(ctx) + ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) resCh <- ar }() fxC := newFixture(t, nc, nc.GetAccountService(1)) defer fxC.Finish(t) - sc, cc := net.Pipe() - tl.conn <- sc - secConn, err := fxC.TLSConn(ctx, cc) + secConn, err := fxC.SecureOutbound(ctx, cc) require.NoError(t, err) assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) res := <-resCh @@ -61,7 +57,7 @@ func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Servic a: new(app.App), } - fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc) + fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService) require.NoError(t, fx.a.Start(ctx)) return fx } @@ -75,24 +71,3 @@ type fixture struct { func (fx *fixture) Finish(t *testing.T) { require.NoError(t, fx.a.Close(ctx)) } - -type testListener struct { - conn chan net.Conn -} - -func (t *testListener) Accept() (net.Conn, error) { - conn, ok := <-t.conn - if !ok { - return nil, net.ErrClosed - } - return conn, nil -} - -func (t *testListener) Close() error { - close(t.conn) - return nil -} - -func (t *testListener) Addr() net.Addr { - return nil -} diff --git a/net/secureservice/tlslistener.go b/net/secureservice/tlslistener.go deleted file mode 100644 index 65c13fe4..00000000 --- a/net/secureservice/tlslistener.go +++ /dev/null @@ -1,70 +0,0 @@ -package secureservice - -import ( - "context" - "github.com/anytypeio/any-sync/net/peer" - "github.com/anytypeio/any-sync/net/secureservice/handshake" - "github.com/anytypeio/any-sync/net/timeoutconn" - "github.com/libp2p/go-libp2p/core/crypto" - libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" - "net" - "time" -) - -type ContextListener interface { - // Accept works like net.Listener accept but add context - Accept(ctx context.Context) (context.Context, net.Conn, error) - - // Close closes the listener. - // Any blocked Accept operations will be unblocked and return errors. - Close() error - - // Addr returns the listener's network address. - Addr() net.Addr -} - -func newTLSListener(cc handshake.CredentialChecker, key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener { - tr, _ := libp2ptls.New(libp2ptls.ID, key, nil) - return &tlsListener{ - cc: cc, - tr: tr, - Listener: lis, - timeoutMillis: timeoutMillis, - } -} - -type tlsListener struct { - net.Listener - tr *libp2ptls.Transport - timeoutMillis int - cc handshake.CredentialChecker -} - -func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) { - conn, err := p.Listener.Accept() - if err != nil { - return nil, nil, err - } - timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond) - return p.upgradeConn(ctx, timeoutConn) -} - -func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) { - secure, err := p.tr.SecureInbound(ctx, conn, "") - if err != nil { - return nil, nil, HandshakeError{ - remoteAddr: conn.RemoteAddr().String(), - err: err, - } - } - identity, err := handshake.IncomingHandshake(nil, secure, p.cc) - if err != nil { - return nil, nil, HandshakeError{ - remoteAddr: conn.RemoteAddr().String(), - err: err, - } - } - ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String()) - ctx = peer.CtxWithIdentity(ctx, identity) - return ctx, secure, nil -} diff --git a/nodeconf/service.go b/nodeconf/service.go index 279f35bc..e78b4a93 100644 --- a/nodeconf/service.go +++ b/nodeconf/service.go @@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) { } members = append(members, member) } - if n.PeerId == s.accountId { - continue - } if n.HasType(NodeTypeConsensus) { fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId) }