1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 14:07:02 +09:00

remove tls listener, move handshake to serve goroutine

This commit is contained in:
Sergey Cherepanov 2023-02-17 00:09:52 +03:00 committed by Mikhail Iudin
parent f796cc6c6d
commit 1a588d0a72
No known key found for this signature in database
GPG key ID: FAAAA8BAABDFF1C0
8 changed files with 70 additions and 162 deletions

View file

@ -109,7 +109,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc
} }
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
sc, err = d.transport.TLSConn(ctx, timeoutConn) sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
} }

View file

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/net/secureservice" "github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"io" "io"
@ -18,19 +19,18 @@ import (
type BaseDrpcServer struct { type BaseDrpcServer struct {
drpcServer *drpcserver.Server drpcServer *drpcserver.Server
transport secureservice.SecureService transport secureservice.SecureService
listeners []secureservice.ContextListener listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func() cancel func()
*drpcmux.Mux *drpcmux.Mux
} }
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener
type Params struct { type Params struct {
BufferSizeMb int BufferSizeMb int
ListenAddrs []string ListenAddrs []string
Wrapper DRPCHandlerWrapper Wrapper DRPCHandlerWrapper
Converter ListenerConverter
TimeoutMillis int TimeoutMillis int
} }
@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
}}) }})
ctx, s.cancel = context.WithCancel(ctx) ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs { for _, addr := range params.ListenAddrs {
tcpList, err := net.Listen("tcp", addr) list, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
tlsList := params.Converter(tcpList, params.TimeoutMillis) s.listeners = append(s.listeners, list)
s.listeners = append(s.listeners, tlsList) go s.serve(ctx, list)
go s.serve(ctx, tlsList)
} }
return return
} }
func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) { func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String())) l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started") l.Info("drpc listener started")
defer func() { defer func() {
@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
return return
default: default:
} }
cctx, conn, err := lis.Accept(ctx) conn, err := lis.Accept()
if err != nil { if err != nil {
if isTemporary(err) { if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err)) l.Debug("listener temporary accept error", zap.Error(err))
@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
l.Error("listener accept error", zap.Error(err)) l.Error("listener accept error", zap.Error(err))
return return
} }
go s.serveConn(cctx, conn) go s.serveConn(conn)
} }
} }
func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) { func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
}
}
l.Debug("connection opened") l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil { if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {

View file

@ -7,9 +7,11 @@ import (
"github.com/anytypeio/any-sync/metric" "github.com/anytypeio/any-sync/metric"
anyNet "github.com/anytypeio/any-sync/net" anyNet "github.com/anytypeio/any-sync/net"
"github.com/anytypeio/any-sync/net/secureservice" "github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"net" "net"
"storj.io/drpc" "storj.io/drpc"
"time"
) )
const CName = "common.net.drpcserver" const CName = "common.net.drpcserver"
@ -68,9 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) {
SummaryVec: histVec, SummaryVec: histVec,
} }
}, },
Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener { }
return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake) s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
}, ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.transport.SecureInbound(ctx, conn)
} }
return s.BaseDrpcServer.Run(ctx, params) return s.BaseDrpcServer.Run(ctx, params)
} }

View file

@ -1,26 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/timeoutconn"
"net"
"time"
)
type basicListener struct {
net.Listener
timeoutMillis int
}
func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener {
return &basicListener{listener, timeoutMillis}
}
func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := b.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond)
return ctx, timeoutConn, err
}

View file

@ -6,6 +6,7 @@ import (
"github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/accountdata" "github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/secureservice/handshake" "github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/nodeconf" "github.com/anytypeio/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
@ -37,20 +38,20 @@ func New() SecureService {
} }
type SecureService interface { type SecureService interface {
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
BasicListener(lis net.Listener, timeoutMillis int) ContextListener SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
app.Component app.Component
} }
type secureService struct { type secureService struct {
outboundTr *libp2ptls.Transport p2pTr *libp2ptls.Transport
account *accountdata.AccountData account *accountdata.AccountData
key crypto.PrivKey key crypto.PrivKey
nodeconf nodeconf.Service nodeconf nodeconf.Service
noVerifyChecker handshake.CredentialChecker noVerifyChecker handshake.CredentialChecker
peerSignVerifier handshake.CredentialChecker peerSignVerifier handshake.CredentialChecker
inboundChecker handshake.CredentialChecker
} }
func (s *secureService) Init(a *app.App) (err error) { func (s *secureService) Init(a *app.App) (err error) {
@ -68,7 +69,14 @@ func (s *secureService) Init(a *app.App) (err error) {
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service) s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil { s.inboundChecker = s.noVerifyChecker
confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId)
if len(confTypes) > 0 {
// require identity verification if we are node
s.inboundChecker = s.peerSignVerifier
}
if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
return return
} }
@ -80,20 +88,30 @@ func (s *secureService) Name() (name string) {
return CName return CName
} }
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener { func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
cc := s.noVerifyChecker sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
if identityHandshake { if err != nil {
cc = s.peerSignVerifier return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
} }
return newTLSListener(cc, s.key, lis, timeoutMillis)
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return
} }
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener { func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
return newBasicListener(lis, timeoutMillis) sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
}
func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.outboundTr.SecureOutbound(ctx, conn, "")
if err != nil { if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
} }

View file

@ -20,10 +20,8 @@ func TestHandshake(t *testing.T) {
fxS := newFixture(t, nc, nc.GetAccountService(0)) fxS := newFixture(t, nc, nc.GetAccountService(0))
defer fxS.Finish(t) defer fxS.Finish(t)
tl := &testListener{conn: make(chan net.Conn, 1)} sc, cc := net.Pipe()
defer tl.Close()
list := fxS.TLSListener(tl, 1000, true)
type acceptRes struct { type acceptRes struct {
ctx context.Context ctx context.Context
conn net.Conn conn net.Conn
@ -32,16 +30,14 @@ func TestHandshake(t *testing.T) {
resCh := make(chan acceptRes) resCh := make(chan acceptRes)
go func() { go func() {
var ar acceptRes var ar acceptRes
ar.ctx, ar.conn, ar.err = list.Accept(ctx) ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar resCh <- ar
}() }()
fxC := newFixture(t, nc, nc.GetAccountService(1)) fxC := newFixture(t, nc, nc.GetAccountService(1))
defer fxC.Finish(t) defer fxC.Finish(t)
sc, cc := net.Pipe() secConn, err := fxC.SecureOutbound(ctx, cc)
tl.conn <- sc
secConn, err := fxC.TLSConn(ctx, cc)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
res := <-resCh res := <-resCh
@ -61,7 +57,7 @@ func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Servic
a: new(app.App), a: new(app.App),
} }
fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc) fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService)
require.NoError(t, fx.a.Start(ctx)) require.NoError(t, fx.a.Start(ctx))
return fx return fx
} }
@ -75,24 +71,3 @@ type fixture struct {
func (fx *fixture) Finish(t *testing.T) { func (fx *fixture) Finish(t *testing.T) {
require.NoError(t, fx.a.Close(ctx)) require.NoError(t, fx.a.Close(ctx))
} }
type testListener struct {
conn chan net.Conn
}
func (t *testListener) Accept() (net.Conn, error) {
conn, ok := <-t.conn
if !ok {
return nil, net.ErrClosed
}
return conn, nil
}
func (t *testListener) Close() error {
close(t.conn)
return nil
}
func (t *testListener) Addr() net.Addr {
return nil
}

View file

@ -1,70 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/net/timeoutconn"
"github.com/libp2p/go-libp2p/core/crypto"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"net"
"time"
)
type ContextListener interface {
// Accept works like net.Listener accept but add context
Accept(ctx context.Context) (context.Context, net.Conn, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Addr returns the listener's network address.
Addr() net.Addr
}
func newTLSListener(cc handshake.CredentialChecker, key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener {
tr, _ := libp2ptls.New(libp2ptls.ID, key, nil)
return &tlsListener{
cc: cc,
tr: tr,
Listener: lis,
timeoutMillis: timeoutMillis,
}
}
type tlsListener struct {
net.Listener
tr *libp2ptls.Transport
timeoutMillis int
cc handshake.CredentialChecker
}
func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := p.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond)
return p.upgradeConn(ctx, timeoutConn)
}
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
secure, err := p.tr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
ctx = peer.CtxWithIdentity(ctx, identity)
return ctx, secure, nil
}

View file

@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) {
} }
members = append(members, member) members = append(members, member)
} }
if n.PeerId == s.accountId {
continue
}
if n.HasType(NodeTypeConsensus) { if n.HasType(NodeTypeConsensus) {
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId) fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
} }