mirror of
https://github.com/anyproto/any-sync.git
synced 2025-06-08 05:57:03 +09:00
remove tls listener, move handshake to serve goroutine
This commit is contained in:
parent
f796cc6c6d
commit
1a588d0a72
8 changed files with 70 additions and 162 deletions
|
@ -109,7 +109,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc
|
|||
}
|
||||
|
||||
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
|
||||
sc, err = d.transport.TLSConn(ctx, timeoutConn)
|
||||
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/zeebo/errs"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
|
@ -18,19 +19,18 @@ import (
|
|||
type BaseDrpcServer struct {
|
||||
drpcServer *drpcserver.Server
|
||||
transport secureservice.SecureService
|
||||
listeners []secureservice.ContextListener
|
||||
listeners []net.Listener
|
||||
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||
cancel func()
|
||||
*drpcmux.Mux
|
||||
}
|
||||
|
||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||
type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener
|
||||
|
||||
type Params struct {
|
||||
BufferSizeMb int
|
||||
ListenAddrs []string
|
||||
Wrapper DRPCHandlerWrapper
|
||||
Converter ListenerConverter
|
||||
TimeoutMillis int
|
||||
}
|
||||
|
||||
|
@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
|
|||
}})
|
||||
ctx, s.cancel = context.WithCancel(ctx)
|
||||
for _, addr := range params.ListenAddrs {
|
||||
tcpList, err := net.Listen("tcp", addr)
|
||||
list, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tlsList := params.Converter(tcpList, params.TimeoutMillis)
|
||||
s.listeners = append(s.listeners, tlsList)
|
||||
go s.serve(ctx, tlsList)
|
||||
s.listeners = append(s.listeners, list)
|
||||
go s.serve(ctx, list)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) {
|
||||
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
|
||||
l := log.With(zap.String("localAddr", lis.Addr().String()))
|
||||
l.Info("drpc listener started")
|
||||
defer func() {
|
||||
|
@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
|
|||
return
|
||||
default:
|
||||
}
|
||||
cctx, conn, err := lis.Accept(ctx)
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
if isTemporary(err) {
|
||||
l.Debug("listener temporary accept error", zap.Error(err))
|
||||
|
@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
|
|||
l.Error("listener accept error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
go s.serveConn(cctx, conn)
|
||||
go s.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) {
|
||||
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
|
||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||
var (
|
||||
ctx = context.Background()
|
||||
err error
|
||||
)
|
||||
if s.handshake != nil {
|
||||
ctx, conn, err = s.handshake(conn)
|
||||
if err != nil {
|
||||
l.Info("handshake error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
l.Debug("connection opened")
|
||||
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
|
||||
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
|
||||
|
|
|
@ -7,9 +7,11 @@ import (
|
|||
"github.com/anytypeio/any-sync/metric"
|
||||
anyNet "github.com/anytypeio/any-sync/net"
|
||||
"github.com/anytypeio/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"time"
|
||||
)
|
||||
|
||||
const CName = "common.net.drpcserver"
|
||||
|
@ -68,9 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) {
|
|||
SummaryVec: histVec,
|
||||
}
|
||||
},
|
||||
Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener {
|
||||
return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake)
|
||||
},
|
||||
}
|
||||
s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
return s.transport.SecureInbound(ctx, conn)
|
||||
}
|
||||
return s.BaseDrpcServer.Run(ctx, params)
|
||||
}
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
package secureservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/net/timeoutconn"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type basicListener struct {
|
||||
net.Listener
|
||||
timeoutMillis int
|
||||
}
|
||||
|
||||
func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener {
|
||||
return &basicListener{listener, timeoutMillis}
|
||||
}
|
||||
|
||||
func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
|
||||
conn, err := b.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond)
|
||||
return ctx, timeoutConn, err
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/anytypeio/any-sync/app"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/anytypeio/any-sync/net/secureservice/handshake"
|
||||
"github.com/anytypeio/any-sync/nodeconf"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
|
@ -37,20 +38,20 @@ func New() SecureService {
|
|||
}
|
||||
|
||||
type SecureService interface {
|
||||
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
|
||||
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
|
||||
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||
SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||
app.Component
|
||||
}
|
||||
|
||||
type secureService struct {
|
||||
outboundTr *libp2ptls.Transport
|
||||
account *accountdata.AccountData
|
||||
key crypto.PrivKey
|
||||
nodeconf nodeconf.Service
|
||||
p2pTr *libp2ptls.Transport
|
||||
account *accountdata.AccountData
|
||||
key crypto.PrivKey
|
||||
nodeconf nodeconf.Service
|
||||
|
||||
noVerifyChecker handshake.CredentialChecker
|
||||
peerSignVerifier handshake.CredentialChecker
|
||||
inboundChecker handshake.CredentialChecker
|
||||
}
|
||||
|
||||
func (s *secureService) Init(a *app.App) (err error) {
|
||||
|
@ -68,7 +69,14 @@ func (s *secureService) Init(a *app.App) (err error) {
|
|||
|
||||
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
|
||||
|
||||
if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
|
||||
s.inboundChecker = s.noVerifyChecker
|
||||
confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId)
|
||||
if len(confTypes) > 0 {
|
||||
// require identity verification if we are node
|
||||
s.inboundChecker = s.peerSignVerifier
|
||||
}
|
||||
|
||||
if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -80,20 +88,30 @@ func (s *secureService) Name() (name string) {
|
|||
return CName
|
||||
}
|
||||
|
||||
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
|
||||
cc := s.noVerifyChecker
|
||||
if identityHandshake {
|
||||
cc = s.peerSignVerifier
|
||||
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
|
||||
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
return newTLSListener(cc, s.key, lis, timeoutMillis)
|
||||
|
||||
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
cctx = context.Background()
|
||||
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
|
||||
cctx = peer.CtxWithIdentity(cctx, identity)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
|
||||
return newBasicListener(lis, timeoutMillis)
|
||||
}
|
||||
|
||||
func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
|
||||
sc, err := s.outboundTr.SecureOutbound(ctx, conn, "")
|
||||
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
|
||||
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
||||
}
|
||||
|
|
|
@ -20,10 +20,8 @@ func TestHandshake(t *testing.T) {
|
|||
fxS := newFixture(t, nc, nc.GetAccountService(0))
|
||||
defer fxS.Finish(t)
|
||||
|
||||
tl := &testListener{conn: make(chan net.Conn, 1)}
|
||||
defer tl.Close()
|
||||
sc, cc := net.Pipe()
|
||||
|
||||
list := fxS.TLSListener(tl, 1000, true)
|
||||
type acceptRes struct {
|
||||
ctx context.Context
|
||||
conn net.Conn
|
||||
|
@ -32,16 +30,14 @@ func TestHandshake(t *testing.T) {
|
|||
resCh := make(chan acceptRes)
|
||||
go func() {
|
||||
var ar acceptRes
|
||||
ar.ctx, ar.conn, ar.err = list.Accept(ctx)
|
||||
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
|
||||
resCh <- ar
|
||||
}()
|
||||
|
||||
fxC := newFixture(t, nc, nc.GetAccountService(1))
|
||||
defer fxC.Finish(t)
|
||||
|
||||
sc, cc := net.Pipe()
|
||||
tl.conn <- sc
|
||||
secConn, err := fxC.TLSConn(ctx, cc)
|
||||
secConn, err := fxC.SecureOutbound(ctx, cc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
|
||||
res := <-resCh
|
||||
|
@ -61,7 +57,7 @@ func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Servic
|
|||
a: new(app.App),
|
||||
}
|
||||
|
||||
fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc)
|
||||
fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService)
|
||||
require.NoError(t, fx.a.Start(ctx))
|
||||
return fx
|
||||
}
|
||||
|
@ -75,24 +71,3 @@ type fixture struct {
|
|||
func (fx *fixture) Finish(t *testing.T) {
|
||||
require.NoError(t, fx.a.Close(ctx))
|
||||
}
|
||||
|
||||
type testListener struct {
|
||||
conn chan net.Conn
|
||||
}
|
||||
|
||||
func (t *testListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-t.conn
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (t *testListener) Close() error {
|
||||
close(t.conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testListener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
package secureservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/anytypeio/any-sync/net/secureservice/handshake"
|
||||
"github.com/anytypeio/any-sync/net/timeoutconn"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ContextListener interface {
|
||||
// Accept works like net.Listener accept but add context
|
||||
Accept(ctx context.Context) (context.Context, net.Conn, error)
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
Close() error
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
Addr() net.Addr
|
||||
}
|
||||
|
||||
func newTLSListener(cc handshake.CredentialChecker, key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener {
|
||||
tr, _ := libp2ptls.New(libp2ptls.ID, key, nil)
|
||||
return &tlsListener{
|
||||
cc: cc,
|
||||
tr: tr,
|
||||
Listener: lis,
|
||||
timeoutMillis: timeoutMillis,
|
||||
}
|
||||
}
|
||||
|
||||
type tlsListener struct {
|
||||
net.Listener
|
||||
tr *libp2ptls.Transport
|
||||
timeoutMillis int
|
||||
cc handshake.CredentialChecker
|
||||
}
|
||||
|
||||
func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
|
||||
conn, err := p.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond)
|
||||
return p.upgradeConn(ctx, timeoutConn)
|
||||
}
|
||||
|
||||
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
|
||||
secure, err := p.tr.SecureInbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
|
||||
ctx = peer.CtxWithIdentity(ctx, identity)
|
||||
return ctx, secure, nil
|
||||
}
|
|
@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) {
|
|||
}
|
||||
members = append(members, member)
|
||||
}
|
||||
if n.PeerId == s.accountId {
|
||||
continue
|
||||
}
|
||||
if n.HasType(NodeTypeConsensus) {
|
||||
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue