diff --git a/coordinator/coordinatorclient/coordinatorclient.go b/coordinator/coordinatorclient/coordinatorclient.go index f63f8b55..82dfd4b2 100644 --- a/coordinator/coordinatorclient/coordinatorclient.go +++ b/coordinator/coordinatorclient/coordinatorclient.go @@ -361,7 +361,7 @@ func (c *coordinatorClient) IsNetworkNeedsUpdate(ctx context.Context) (bool, err if err != nil { return false, err } - return version != secureservice.ProtoVersion, nil + return secureservice.ProtoVersion < version, nil } func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error { diff --git a/coordinator/nodeconfsource/nodeconfsource.go b/coordinator/nodeconfsource/nodeconfsource.go index 2e02edb9..b2130b4f 100644 --- a/coordinator/nodeconfsource/nodeconfsource.go +++ b/coordinator/nodeconfsource/nodeconfsource.go @@ -67,17 +67,10 @@ func (n *nodeConfSource) GetLast(ctx context.Context, currentId string) (c nodec } } - needsUpdate, err := n.cl.IsNetworkNeedsUpdate(ctx) - if err != nil { - return - } - if needsUpdate { - err = nodeconf.ErrNetworkNeedsUpdate - } return nodeconf.Configuration{ Id: res.ConfigurationId, NetworkId: res.NetworkId, Nodes: nodes, CreationTime: time.Unix(int64(res.CreationTimeUnix), 0), - }, err + }, nil } diff --git a/nodeconf/service.go b/nodeconf/service.go index 85376703..4c88ebba 100644 --- a/nodeconf/service.go +++ b/nodeconf/service.go @@ -3,6 +3,8 @@ package nodeconf import ( "context" + "errors" + "fmt" "sync" commonaccount "github.com/anyproto/any-sync/accountservice" @@ -44,6 +46,10 @@ type Service interface { app.ComponentRunnable } +type NetworkProtoVersionChecker interface { + IsNetworkNeedsUpdate(ctx context.Context) (bool, error) +} + type service struct { accountId string config Configuration @@ -53,7 +59,8 @@ type service struct { mu sync.RWMutex sync periodicsync.PeriodicSync - compatibilityStatus NetworkCompatibilityStatus + compatibilityStatus NetworkCompatibilityStatus + networkProtoVersionChecker NetworkProtoVersionChecker } func (s *service) Init(a *app.App) (err error) { @@ -80,6 +87,7 @@ func (s *service) Init(a *app.App) (err error) { } return }, log) + s.networkProtoVersionChecker = app.MustComponent[NetworkProtoVersionChecker](a) return s.setLastConfiguration(lastStored) } @@ -99,17 +107,78 @@ func (s *service) NetworkCompatibilityStatus() NetworkCompatibilityStatus { } func (s *service) updateConfiguration(ctx context.Context) (err error) { - last, err := s.source.GetLast(ctx, s.Configuration().Id) + last, err := s.fetchLastConfiguration(ctx) if err != nil { + return err + } + + if err = s.updateCompatibilityStatus(ctx, err); err != nil { + return err + } + + if err = s.saveAndSetLastConfiguration(ctx, last); err != nil { + return err + } + + return nil +} + +func (s *service) fetchLastConfiguration(ctx context.Context) (Configuration, error) { + last, err := s.source.GetLast(ctx, s.Configuration().Id) + if err != nil && !errors.Is(err, ErrConfigurationNotChanged) { s.setCompatibilityStatusByErr(err) - return - } else { - s.setCompatibilityStatusByErr(nil) } - if err = s.store.SaveLast(ctx, last); err != nil { - return + return last, err +} + +func (s *service) updateCompatibilityStatus(ctx context.Context, err error) error { + needsUpdate, checkErr := s.networkProtoVersionChecker.IsNetworkNeedsUpdate(ctx) + if checkErr != nil { + return fmt.Errorf("network protocol version check failed: %w", checkErr) } - return s.setLastConfiguration(last) + + if needsUpdate { + s.setCompatibilityStatus(NetworkCompatibilityStatusNeedsUpdate) + return nil + } + + s.setCompatibilityStatusByErr(err) + return nil +} + +func (s *service) saveAndSetLastConfiguration(ctx context.Context, last Configuration) error { + if err := s.store.SaveLast(ctx, last); err != nil { + return fmt.Errorf("failed to save last configuration: %w", err) + } + + if err := s.setLastConfiguration(last); err != nil { + return fmt.Errorf("failed to set last configuration: %w", err) + } + + return nil +} + +func (s *service) setCompatibilityStatus(status NetworkCompatibilityStatus) { + s.mu.Lock() + defer s.mu.Unlock() + s.compatibilityStatus = status +} + +func (s *service) setCompatibilityStatusByErr(err error) { + var status NetworkCompatibilityStatus + + switch err { + case nil: + status = NetworkCompatibilityStatusOk + case handshake.ErrIncompatibleVersion: + status = NetworkCompatibilityStatusIncompatible + case net.ErrUnableToConnect: + status = NetworkCompatibilityStatusUnknown + default: + status = NetworkCompatibilityStatusError + } + + s.setCompatibilityStatus(status) } func (s *service) setLastConfiguration(c Configuration) (err error) { @@ -138,23 +207,6 @@ func (s *service) setLastConfiguration(c Configuration) (err error) { return } -func (s *service) setCompatibilityStatusByErr(err error) { - s.mu.Lock() - defer s.mu.Unlock() - switch err { - case nil: - s.compatibilityStatus = NetworkCompatibilityStatusOk - case handshake.ErrIncompatibleVersion: - s.compatibilityStatus = NetworkCompatibilityStatusIncompatible - case net.ErrUnableToConnect: - s.compatibilityStatus = NetworkCompatibilityStatusUnknown - case ErrNetworkNeedsUpdate: - s.compatibilityStatus = NetworkCompatibilityStatusNeedsUpdate - default: - s.compatibilityStatus = NetworkCompatibilityStatusError - } -} - func (s *service) Id() string { s.mu.RLock() defer s.mu.RUnlock() diff --git a/nodeconf/service_test.go b/nodeconf/service_test.go index 925e4c7d..fafd56d5 100644 --- a/nodeconf/service_test.go +++ b/nodeconf/service_test.go @@ -56,35 +56,42 @@ func TestService_NetworkCompatibilityStatus(t *testing.T) { }) t.Run("needs update", func(t *testing.T) { fx := newFixture(t) + fx.testCoordinator.needsUpdate = true defer fx.finish(t) - fx.testSource.call = func() (c Configuration, e error) { - e = ErrNetworkNeedsUpdate - return - } fx.run(t) time.Sleep(time.Millisecond * 10) assert.Equal(t, NetworkCompatibilityStatusNeedsUpdate, fx.NetworkCompatibilityStatus()) }) + t.Run("network not changed update", func(t *testing.T) { + fx := newFixture(t) + fx.testSource.err = ErrConfigurationNotChanged + defer fx.finish(t) + fx.run(t) + time.Sleep(time.Millisecond * 10) + assert.Equal(t, NetworkCompatibilityStatusError, fx.NetworkCompatibilityStatus()) + }) } func newFixture(t *testing.T) *fixture { fx := &fixture{ - Service: New(), - a: new(app.App), - testStore: &testStore{}, - testSource: &testSource{}, - testConf: newTestConf(), + Service: New(), + testCoordinator: &testCoordinator{}, + a: new(app.App), + testStore: &testStore{}, + testSource: &testSource{}, + testConf: newTestConf(), } - fx.a.Register(fx.testConf).Register(&accounttest.AccountTestService{}).Register(fx.Service).Register(fx.testSource).Register(fx.testStore) + fx.a.Register(fx.testConf).Register(&accounttest.AccountTestService{}).Register(fx.Service).Register(fx.testSource).Register(fx.testStore).Register(fx.testCoordinator) return fx } type fixture struct { Service - a *app.App - testStore *testStore - testSource *testSource - testConf *testConf + a *app.App + testStore *testStore + testSource *testSource + testConf *testConf + testCoordinator *testCoordinator } func (fx *fixture) run(t *testing.T) { @@ -95,6 +102,17 @@ func (fx *fixture) finish(t *testing.T) { require.NoError(t, fx.a.Close(ctx)) } +type testCoordinator struct { + needsUpdate bool +} + +func (t *testCoordinator) IsNetworkNeedsUpdate(ctx context.Context) (bool, error) { + return t.needsUpdate, nil +} + +func (t *testCoordinator) Init(a *app.App) error { return nil } +func (t *testCoordinator) Name() string { return "testCoordinator" } + type testSource struct { conf Configuration err error diff --git a/nodeconf/source.go b/nodeconf/source.go index 6dafe5af..56ea2c54 100644 --- a/nodeconf/source.go +++ b/nodeconf/source.go @@ -9,7 +9,6 @@ const CNameSource = "common.nodeconf.source" var ( ErrConfigurationNotChanged = errors.New("configuration not changed") - ErrNetworkNeedsUpdate = errors.New("network needs update") ) type Source interface {