1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 05:57:03 +09:00

remove tls listener, move handshake to serve goroutine

This commit is contained in:
Sergey Cherepanov 2023-02-17 00:09:52 +03:00 committed by Mikhail Iudin
parent f796cc6c6d
commit 1a588d0a72
No known key found for this signature in database
GPG key ID: FAAAA8BAABDFF1C0
8 changed files with 70 additions and 162 deletions

View file

@ -109,7 +109,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc
}
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
sc, err = d.transport.TLSConn(ctx, timeoutConn)
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
if err != nil {
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
}

View file

@ -3,6 +3,7 @@ package server
import (
"context"
"github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs"
"go.uber.org/zap"
"io"
@ -18,19 +19,18 @@ import (
type BaseDrpcServer struct {
drpcServer *drpcserver.Server
transport secureservice.SecureService
listeners []secureservice.ContextListener
listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func()
*drpcmux.Mux
}
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener
type Params struct {
BufferSizeMb int
ListenAddrs []string
Wrapper DRPCHandlerWrapper
Converter ListenerConverter
TimeoutMillis int
}
@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
}})
ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs {
tcpList, err := net.Listen("tcp", addr)
list, err := net.Listen("tcp", addr)
if err != nil {
return err
}
tlsList := params.Converter(tcpList, params.TimeoutMillis)
s.listeners = append(s.listeners, tlsList)
go s.serve(ctx, tlsList)
s.listeners = append(s.listeners, list)
go s.serve(ctx, list)
}
return
}
func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) {
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started")
defer func() {
@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
return
default:
}
cctx, conn, err := lis.Accept(ctx)
conn, err := lis.Accept()
if err != nil {
if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err))
@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
l.Error("listener accept error", zap.Error(err))
return
}
go s.serveConn(cctx, conn)
go s.serveConn(conn)
}
}
func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) {
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
}
}
l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {

View file

@ -7,9 +7,11 @@ import (
"github.com/anytypeio/any-sync/metric"
anyNet "github.com/anytypeio/any-sync/net"
"github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/prometheus/client_golang/prometheus"
"net"
"storj.io/drpc"
"time"
)
const CName = "common.net.drpcserver"
@ -68,9 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) {
SummaryVec: histVec,
}
},
Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener {
return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake)
},
}
s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.transport.SecureInbound(ctx, conn)
}
return s.BaseDrpcServer.Run(ctx, params)
}

View file

@ -1,26 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/timeoutconn"
"net"
"time"
)
type basicListener struct {
net.Listener
timeoutMillis int
}
func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener {
return &basicListener{listener, timeoutMillis}
}
func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := b.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond)
return ctx, timeoutConn, err
}

View file

@ -6,6 +6,7 @@ import (
"github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto"
@ -37,20 +38,20 @@ func New() SecureService {
}
type SecureService interface {
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
app.Component
}
type secureService struct {
outboundTr *libp2ptls.Transport
account *accountdata.AccountData
key crypto.PrivKey
nodeconf nodeconf.Service
p2pTr *libp2ptls.Transport
account *accountdata.AccountData
key crypto.PrivKey
nodeconf nodeconf.Service
noVerifyChecker handshake.CredentialChecker
peerSignVerifier handshake.CredentialChecker
inboundChecker handshake.CredentialChecker
}
func (s *secureService) Init(a *app.App) (err error) {
@ -68,7 +69,14 @@ func (s *secureService) Init(a *app.App) (err error) {
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
s.inboundChecker = s.noVerifyChecker
confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId)
if len(confTypes) > 0 {
// require identity verification if we are node
s.inboundChecker = s.peerSignVerifier
}
if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
return
}
@ -80,20 +88,30 @@ func (s *secureService) Name() (name string) {
return CName
}
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
cc := s.noVerifyChecker
if identityHandshake {
cc = s.peerSignVerifier
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
return newTLSListener(cc, s.key, lis, timeoutMillis)
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return
}
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
return newBasicListener(lis, timeoutMillis)
}
func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.outboundTr.SecureOutbound(ctx, conn, "")
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
}

View file

@ -20,10 +20,8 @@ func TestHandshake(t *testing.T) {
fxS := newFixture(t, nc, nc.GetAccountService(0))
defer fxS.Finish(t)
tl := &testListener{conn: make(chan net.Conn, 1)}
defer tl.Close()
sc, cc := net.Pipe()
list := fxS.TLSListener(tl, 1000, true)
type acceptRes struct {
ctx context.Context
conn net.Conn
@ -32,16 +30,14 @@ func TestHandshake(t *testing.T) {
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = list.Accept(ctx)
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1))
defer fxC.Finish(t)
sc, cc := net.Pipe()
tl.conn <- sc
secConn, err := fxC.TLSConn(ctx, cc)
secConn, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
res := <-resCh
@ -61,7 +57,7 @@ func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Servic
a: new(app.App),
}
fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc)
fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService)
require.NoError(t, fx.a.Start(ctx))
return fx
}
@ -75,24 +71,3 @@ type fixture struct {
func (fx *fixture) Finish(t *testing.T) {
require.NoError(t, fx.a.Close(ctx))
}
type testListener struct {
conn chan net.Conn
}
func (t *testListener) Accept() (net.Conn, error) {
conn, ok := <-t.conn
if !ok {
return nil, net.ErrClosed
}
return conn, nil
}
func (t *testListener) Close() error {
close(t.conn)
return nil
}
func (t *testListener) Addr() net.Addr {
return nil
}

View file

@ -1,70 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/net/timeoutconn"
"github.com/libp2p/go-libp2p/core/crypto"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"net"
"time"
)
type ContextListener interface {
// Accept works like net.Listener accept but add context
Accept(ctx context.Context) (context.Context, net.Conn, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Addr returns the listener's network address.
Addr() net.Addr
}
func newTLSListener(cc handshake.CredentialChecker, key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener {
tr, _ := libp2ptls.New(libp2ptls.ID, key, nil)
return &tlsListener{
cc: cc,
tr: tr,
Listener: lis,
timeoutMillis: timeoutMillis,
}
}
type tlsListener struct {
net.Listener
tr *libp2ptls.Transport
timeoutMillis int
cc handshake.CredentialChecker
}
func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := p.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond)
return p.upgradeConn(ctx, timeoutConn)
}
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
secure, err := p.tr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
ctx = peer.CtxWithIdentity(ctx, identity)
return ctx, secure, nil
}

View file

@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) {
}
members = append(members, member)
}
if n.PeerId == s.accountId {
continue
}
if n.HasType(NodeTypeConsensus) {
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
}