mirror of
https://github.com/anyproto/any-sync.git
synced 2025-06-08 14:07:02 +09:00
remove tls listener, move handshake to serve goroutine
This commit is contained in:
parent
f796cc6c6d
commit
1a588d0a72
8 changed files with 70 additions and 162 deletions
|
@ -109,7 +109,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc
|
||||||
}
|
}
|
||||||
|
|
||||||
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
|
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
|
||||||
sc, err = d.transport.TLSConn(ctx, timeoutConn)
|
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
|
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/anytypeio/any-sync/net/secureservice"
|
"github.com/anytypeio/any-sync/net/secureservice"
|
||||||
|
"github.com/libp2p/go-libp2p/core/sec"
|
||||||
"github.com/zeebo/errs"
|
"github.com/zeebo/errs"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"io"
|
"io"
|
||||||
|
@ -18,19 +19,18 @@ import (
|
||||||
type BaseDrpcServer struct {
|
type BaseDrpcServer struct {
|
||||||
drpcServer *drpcserver.Server
|
drpcServer *drpcserver.Server
|
||||||
transport secureservice.SecureService
|
transport secureservice.SecureService
|
||||||
listeners []secureservice.ContextListener
|
listeners []net.Listener
|
||||||
|
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||||
cancel func()
|
cancel func()
|
||||||
*drpcmux.Mux
|
*drpcmux.Mux
|
||||||
}
|
}
|
||||||
|
|
||||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||||
type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener
|
|
||||||
|
|
||||||
type Params struct {
|
type Params struct {
|
||||||
BufferSizeMb int
|
BufferSizeMb int
|
||||||
ListenAddrs []string
|
ListenAddrs []string
|
||||||
Wrapper DRPCHandlerWrapper
|
Wrapper DRPCHandlerWrapper
|
||||||
Converter ListenerConverter
|
|
||||||
TimeoutMillis int
|
TimeoutMillis int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
|
||||||
}})
|
}})
|
||||||
ctx, s.cancel = context.WithCancel(ctx)
|
ctx, s.cancel = context.WithCancel(ctx)
|
||||||
for _, addr := range params.ListenAddrs {
|
for _, addr := range params.ListenAddrs {
|
||||||
tcpList, err := net.Listen("tcp", addr)
|
list, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tlsList := params.Converter(tcpList, params.TimeoutMillis)
|
s.listeners = append(s.listeners, list)
|
||||||
s.listeners = append(s.listeners, tlsList)
|
go s.serve(ctx, list)
|
||||||
go s.serve(ctx, tlsList)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) {
|
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
|
||||||
l := log.With(zap.String("localAddr", lis.Addr().String()))
|
l := log.With(zap.String("localAddr", lis.Addr().String()))
|
||||||
l.Info("drpc listener started")
|
l.Info("drpc listener started")
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
cctx, conn, err := lis.Accept(ctx)
|
conn, err := lis.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isTemporary(err) {
|
if isTemporary(err) {
|
||||||
l.Debug("listener temporary accept error", zap.Error(err))
|
l.Debug("listener temporary accept error", zap.Error(err))
|
||||||
|
@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
|
||||||
l.Error("listener accept error", zap.Error(err))
|
l.Error("listener accept error", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go s.serveConn(cctx, conn)
|
go s.serveConn(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) {
|
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
|
||||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if s.handshake != nil {
|
||||||
|
ctx, conn, err = s.handshake(conn)
|
||||||
|
if err != nil {
|
||||||
|
l.Info("handshake error", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
l.Debug("connection opened")
|
l.Debug("connection opened")
|
||||||
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
|
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
|
||||||
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
|
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
|
||||||
|
|
|
@ -7,9 +7,11 @@ import (
|
||||||
"github.com/anytypeio/any-sync/metric"
|
"github.com/anytypeio/any-sync/metric"
|
||||||
anyNet "github.com/anytypeio/any-sync/net"
|
anyNet "github.com/anytypeio/any-sync/net"
|
||||||
"github.com/anytypeio/any-sync/net/secureservice"
|
"github.com/anytypeio/any-sync/net/secureservice"
|
||||||
|
"github.com/libp2p/go-libp2p/core/sec"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"net"
|
"net"
|
||||||
"storj.io/drpc"
|
"storj.io/drpc"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const CName = "common.net.drpcserver"
|
const CName = "common.net.drpcserver"
|
||||||
|
@ -68,9 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) {
|
||||||
SummaryVec: histVec,
|
SummaryVec: histVec,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener {
|
}
|
||||||
return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake)
|
s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
|
||||||
},
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
return s.transport.SecureInbound(ctx, conn)
|
||||||
}
|
}
|
||||||
return s.BaseDrpcServer.Run(ctx, params)
|
return s.BaseDrpcServer.Run(ctx, params)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
package secureservice
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"github.com/anytypeio/any-sync/net/timeoutconn"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type basicListener struct {
|
|
||||||
net.Listener
|
|
||||||
timeoutMillis int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener {
|
|
||||||
return &basicListener{listener, timeoutMillis}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
|
|
||||||
conn, err := b.Listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond)
|
|
||||||
return ctx, timeoutConn, err
|
|
||||||
}
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"github.com/anytypeio/any-sync/app"
|
"github.com/anytypeio/any-sync/app"
|
||||||
"github.com/anytypeio/any-sync/app/logger"
|
"github.com/anytypeio/any-sync/app/logger"
|
||||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||||
|
"github.com/anytypeio/any-sync/net/peer"
|
||||||
"github.com/anytypeio/any-sync/net/secureservice/handshake"
|
"github.com/anytypeio/any-sync/net/secureservice/handshake"
|
||||||
"github.com/anytypeio/any-sync/nodeconf"
|
"github.com/anytypeio/any-sync/nodeconf"
|
||||||
"github.com/libp2p/go-libp2p/core/crypto"
|
"github.com/libp2p/go-libp2p/core/crypto"
|
||||||
|
@ -37,20 +38,20 @@ func New() SecureService {
|
||||||
}
|
}
|
||||||
|
|
||||||
type SecureService interface {
|
type SecureService interface {
|
||||||
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
|
SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||||
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
|
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||||
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
|
||||||
app.Component
|
app.Component
|
||||||
}
|
}
|
||||||
|
|
||||||
type secureService struct {
|
type secureService struct {
|
||||||
outboundTr *libp2ptls.Transport
|
p2pTr *libp2ptls.Transport
|
||||||
account *accountdata.AccountData
|
account *accountdata.AccountData
|
||||||
key crypto.PrivKey
|
key crypto.PrivKey
|
||||||
nodeconf nodeconf.Service
|
nodeconf nodeconf.Service
|
||||||
|
|
||||||
noVerifyChecker handshake.CredentialChecker
|
noVerifyChecker handshake.CredentialChecker
|
||||||
peerSignVerifier handshake.CredentialChecker
|
peerSignVerifier handshake.CredentialChecker
|
||||||
|
inboundChecker handshake.CredentialChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureService) Init(a *app.App) (err error) {
|
func (s *secureService) Init(a *app.App) (err error) {
|
||||||
|
@ -68,7 +69,14 @@ func (s *secureService) Init(a *app.App) (err error) {
|
||||||
|
|
||||||
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
|
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
|
||||||
|
|
||||||
if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
|
s.inboundChecker = s.noVerifyChecker
|
||||||
|
confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId)
|
||||||
|
if len(confTypes) > 0 {
|
||||||
|
// require identity verification if we are node
|
||||||
|
s.inboundChecker = s.peerSignVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,20 +88,30 @@ func (s *secureService) Name() (name string) {
|
||||||
return CName
|
return CName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
|
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
|
||||||
cc := s.noVerifyChecker
|
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
|
||||||
if identityHandshake {
|
if err != nil {
|
||||||
cc = s.peerSignVerifier
|
return nil, nil, HandshakeError{
|
||||||
|
remoteAddr: conn.RemoteAddr().String(),
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return newTLSListener(cc, s.key, lis, timeoutMillis)
|
|
||||||
|
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, HandshakeError{
|
||||||
|
remoteAddr: conn.RemoteAddr().String(),
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cctx = context.Background()
|
||||||
|
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
|
||||||
|
cctx = peer.CtxWithIdentity(cctx, identity)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
|
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
|
||||||
return newBasicListener(lis, timeoutMillis)
|
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||||
}
|
|
||||||
|
|
||||||
func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
|
|
||||||
sc, err := s.outboundTr.SecureOutbound(ctx, conn, "")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,8 @@ func TestHandshake(t *testing.T) {
|
||||||
fxS := newFixture(t, nc, nc.GetAccountService(0))
|
fxS := newFixture(t, nc, nc.GetAccountService(0))
|
||||||
defer fxS.Finish(t)
|
defer fxS.Finish(t)
|
||||||
|
|
||||||
tl := &testListener{conn: make(chan net.Conn, 1)}
|
sc, cc := net.Pipe()
|
||||||
defer tl.Close()
|
|
||||||
|
|
||||||
list := fxS.TLSListener(tl, 1000, true)
|
|
||||||
type acceptRes struct {
|
type acceptRes struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
@ -32,16 +30,14 @@ func TestHandshake(t *testing.T) {
|
||||||
resCh := make(chan acceptRes)
|
resCh := make(chan acceptRes)
|
||||||
go func() {
|
go func() {
|
||||||
var ar acceptRes
|
var ar acceptRes
|
||||||
ar.ctx, ar.conn, ar.err = list.Accept(ctx)
|
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
|
||||||
resCh <- ar
|
resCh <- ar
|
||||||
}()
|
}()
|
||||||
|
|
||||||
fxC := newFixture(t, nc, nc.GetAccountService(1))
|
fxC := newFixture(t, nc, nc.GetAccountService(1))
|
||||||
defer fxC.Finish(t)
|
defer fxC.Finish(t)
|
||||||
|
|
||||||
sc, cc := net.Pipe()
|
secConn, err := fxC.SecureOutbound(ctx, cc)
|
||||||
tl.conn <- sc
|
|
||||||
secConn, err := fxC.TLSConn(ctx, cc)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
|
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
|
||||||
res := <-resCh
|
res := <-resCh
|
||||||
|
@ -61,7 +57,7 @@ func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Servic
|
||||||
a: new(app.App),
|
a: new(app.App),
|
||||||
}
|
}
|
||||||
|
|
||||||
fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc)
|
fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService)
|
||||||
require.NoError(t, fx.a.Start(ctx))
|
require.NoError(t, fx.a.Start(ctx))
|
||||||
return fx
|
return fx
|
||||||
}
|
}
|
||||||
|
@ -75,24 +71,3 @@ type fixture struct {
|
||||||
func (fx *fixture) Finish(t *testing.T) {
|
func (fx *fixture) Finish(t *testing.T) {
|
||||||
require.NoError(t, fx.a.Close(ctx))
|
require.NoError(t, fx.a.Close(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
type testListener struct {
|
|
||||||
conn chan net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testListener) Accept() (net.Conn, error) {
|
|
||||||
conn, ok := <-t.conn
|
|
||||||
if !ok {
|
|
||||||
return nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testListener) Close() error {
|
|
||||||
close(t.conn)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testListener) Addr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
package secureservice
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"github.com/anytypeio/any-sync/net/peer"
|
|
||||||
"github.com/anytypeio/any-sync/net/secureservice/handshake"
|
|
||||||
"github.com/anytypeio/any-sync/net/timeoutconn"
|
|
||||||
"github.com/libp2p/go-libp2p/core/crypto"
|
|
||||||
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ContextListener interface {
|
|
||||||
// Accept works like net.Listener accept but add context
|
|
||||||
Accept(ctx context.Context) (context.Context, net.Conn, error)
|
|
||||||
|
|
||||||
// Close closes the listener.
|
|
||||||
// Any blocked Accept operations will be unblocked and return errors.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// Addr returns the listener's network address.
|
|
||||||
Addr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTLSListener(cc handshake.CredentialChecker, key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener {
|
|
||||||
tr, _ := libp2ptls.New(libp2ptls.ID, key, nil)
|
|
||||||
return &tlsListener{
|
|
||||||
cc: cc,
|
|
||||||
tr: tr,
|
|
||||||
Listener: lis,
|
|
||||||
timeoutMillis: timeoutMillis,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type tlsListener struct {
|
|
||||||
net.Listener
|
|
||||||
tr *libp2ptls.Transport
|
|
||||||
timeoutMillis int
|
|
||||||
cc handshake.CredentialChecker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
|
|
||||||
conn, err := p.Listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond)
|
|
||||||
return p.upgradeConn(ctx, timeoutConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
|
|
||||||
secure, err := p.tr.SecureInbound(ctx, conn, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, HandshakeError{
|
|
||||||
remoteAddr: conn.RemoteAddr().String(),
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, HandshakeError{
|
|
||||||
remoteAddr: conn.RemoteAddr().String(),
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
|
|
||||||
ctx = peer.CtxWithIdentity(ctx, identity)
|
|
||||||
return ctx, secure, nil
|
|
||||||
}
|
|
|
@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) {
|
||||||
}
|
}
|
||||||
members = append(members, member)
|
members = append(members, member)
|
||||||
}
|
}
|
||||||
if n.PeerId == s.accountId {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if n.HasType(NodeTypeConsensus) {
|
if n.HasType(NodeTypeConsensus) {
|
||||||
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
|
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue