diff --git a/net/config.go b/net/config.go index b0cdf564..dde5db96 100644 --- a/net/config.go +++ b/net/config.go @@ -10,7 +10,8 @@ type Config struct { } type ServerConfig struct { - ListenAddrs []string `yaml:"listenAddrs"` + IdentityHandshake bool `yaml:"identityHandshake"` + ListenAddrs []string `yaml:"listenAddrs"` } type StreamConfig struct { diff --git a/net/rpc/server/drpcserver.go b/net/rpc/server/drpcserver.go index 0bc0cd71..b71dcadc 100644 --- a/net/rpc/server/drpcserver.go +++ b/net/rpc/server/drpcserver.go @@ -5,9 +5,10 @@ import ( "github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/metric" - "github.com/anytypeio/any-sync/net" + anyNet "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/secureservice" "github.com/prometheus/client_golang/prometheus" + "net" "storj.io/drpc" ) @@ -25,14 +26,14 @@ type DRPCServer interface { } type drpcServer struct { - config net.Config + config anyNet.Config metric metric.Metric transport secureservice.SecureService *BaseDrpcServer } func (s *drpcServer) Init(a *app.App) (err error) { - s.config = a.MustComponent("config").(net.ConfigGetter).GetNet() + s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() s.metric = a.MustComponent(metric.CName).(metric.Metric) s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) return nil @@ -67,7 +68,9 @@ func (s *drpcServer) Run(ctx context.Context) (err error) { SummaryVec: histVec, } }, - Converter: s.transport.TLSListener, + Converter: func(listener net.Listener, timeoutMillis int) secureservice.ContextListener { + return s.transport.TLSListener(listener, timeoutMillis, s.config.Server.IdentityHandshake) + }, } return s.BaseDrpcServer.Run(ctx, params) } diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index 8f1f4a4d..48336aec 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -68,6 +68,10 @@ func (s *secureService) Init(a *app.App) (err error) { s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service) + if s.outboundTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil { + return + } + log.Info("secure service init", zap.String("peerId", account.Account().PeerId)) return nil } diff --git a/net/secureservice/secureservice_test.go b/net/secureservice/secureservice_test.go index 686b3cac..d91e5050 100644 --- a/net/secureservice/secureservice_test.go +++ b/net/secureservice/secureservice_test.go @@ -2,34 +2,97 @@ package secureservice import ( "context" + "github.com/anytypeio/any-sync/accountservice" "github.com/anytypeio/any-sync/app" - "github.com/anytypeio/any-sync/testutil/accounttest" + "github.com/anytypeio/any-sync/net/peer" + "github.com/anytypeio/any-sync/nodeconf" + "github.com/anytypeio/any-sync/testutil/testnodeconf" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "net" "testing" ) var ctx = context.Background() func TestHandshake(t *testing.T) { - fx := newFixture(t) - defer fx.Finish(t) + nc := testnodeconf.GenNodeConfig(2) + fxS := newFixture(t, nc, nc.GetAccountService(0)) + defer fxS.Finish(t) + + tl := &testListener{conn: make(chan net.Conn, 1)} + defer tl.Close() + + list := fxS.TLSListener(tl, 1000, true) + type acceptRes struct { + ctx context.Context + conn net.Conn + err error + } + resCh := make(chan acceptRes) + go func() { + var ar acceptRes + ar.ctx, ar.conn, ar.err = list.Accept(ctx) + resCh <- ar + }() + + fxC := newFixture(t, nc, nc.GetAccountService(1)) + defer fxC.Finish(t) + + sc, cc := net.Pipe() + tl.conn <- sc + secConn, err := fxC.TLSConn(ctx, cc) + require.NoError(t, err) + assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) + res := <-resCh + require.NoError(t, res.err) + peerId, err := peer.CtxPeerId(res.ctx) + require.NoError(t, err) + accId, err := peer.CtxIdentity(res.ctx) + require.NoError(t, err) + assert.Equal(t, nc.GetAccountService(1).Account().PeerId, peerId) + assert.Equal(t, nc.GetAccountService(1).Account().Identity, accId) } -func newFixture(t *testing.T) *fixture { +func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Service) *fixture { fx := &fixture{ secureService: New().(*secureService), + acc: acc, a: new(app.App), } - fx.a.Register(&accounttest.AccountTestService{}).Register(fx.secureService) + + fx.a.Register(fx.acc).Register(fx.secureService).Register(nodeconf.New()).Register(nc) require.NoError(t, fx.a.Start(ctx)) return fx } type fixture struct { *secureService - a *app.App + a *app.App + acc accountservice.Service } func (fx *fixture) Finish(t *testing.T) { require.NoError(t, fx.a.Close(ctx)) } + +type testListener struct { + conn chan net.Conn +} + +func (t *testListener) Accept() (net.Conn, error) { + conn, ok := <-t.conn + if !ok { + return nil, net.ErrClosed + } + return conn, nil +} + +func (t *testListener) Close() error { + close(t.conn) + return nil +} + +func (t *testListener) Addr() net.Addr { + return nil +} diff --git a/testutil/testnodeconf/testnodeconf.go b/testutil/testnodeconf/testnodeconf.go new file mode 100644 index 00000000..2d6b195e --- /dev/null +++ b/testutil/testnodeconf/testnodeconf.go @@ -0,0 +1,40 @@ +package testnodeconf + +import ( + "github.com/anytypeio/any-sync/accountservice" + "github.com/anytypeio/any-sync/app" + "github.com/anytypeio/any-sync/nodeconf" + "github.com/anytypeio/any-sync/testutil/accounttest" +) + +func GenNodeConfig(num int) (conf *Config) { + conf = &Config{} + if num <= 0 { + num = 1 + } + for i := 0; i < num; i++ { + ac := &accounttest.AccountTestService{} + if err := ac.Init(nil); err != nil { + panic(err) + } + conf.nodes = append(conf.nodes, ac.NodeConf(nil)) + conf.configs = append(conf.configs, ac) + } + return conf +} + +type Config struct { + nodes []nodeconf.NodeConfig + configs []*accounttest.AccountTestService +} + +func (c *Config) Init(a *app.App) (err error) { return } +func (c *Config) Name() string { return "config" } + +func (c *Config) GetNodes() []nodeconf.NodeConfig { + return c.nodes +} + +func (c *Config) GetAccountService(idx int) accountservice.Service { + return c.configs[idx] +}