diff --git a/core/block/chats/chatsubscription/service.go b/core/block/chats/chatsubscription/service.go index dca5d3714..6f3dfe368 100644 --- a/core/block/chats/chatsubscription/service.go +++ b/core/block/chats/chatsubscription/service.go @@ -18,6 +18,7 @@ import ( "github.com/anyproto/anytype-heart/pkg/lib/localstore/objectstore" "github.com/anyproto/anytype-heart/pkg/lib/logging" "github.com/anyproto/anytype-heart/pkg/lib/pb/model" + "github.com/anyproto/anytype-heart/util/futures" ) const CName = "chatsubscription" @@ -64,12 +65,12 @@ type service struct { identityCache *expirable.LRU[string, *domain.Details] lock sync.Mutex - managers map[string]*subscriptionManager + managers map[string]*futures.Future[*subscriptionManager] } func New() Service { return &service{ - managers: make(map[string]*subscriptionManager), + managers: make(map[string]*futures.Future[*subscriptionManager]), } } @@ -104,7 +105,9 @@ func (s *service) GetManager(chatObjectId string) (Manager, error) { return s.getManager(chatObjectId) } -func (s *service) getManager(chatObjectId string) (*subscriptionManager, error) { +// getManagerFuture returns a future that should be resolved by the first who called this method. +// The idea behind using futures here is to initialize a manager once without blocking the whole service. +func (s *service) getManagerFuture(chatObjectId string) (*futures.Future[*subscriptionManager], error) { s.lock.Lock() mngr, ok := s.managers[chatObjectId] if ok { @@ -112,9 +115,7 @@ func (s *service) getManager(chatObjectId string) (*subscriptionManager, error) return mngr, nil } - mngr = &subscriptionManager{} - mngr.Lock() - defer mngr.Unlock() + mngr = futures.New[*subscriptionManager]() s.managers[chatObjectId] = mngr s.lock.Unlock() @@ -126,7 +127,15 @@ func (s *service) getManager(chatObjectId string) (*subscriptionManager, error) return mngr, nil } -func (s *service) initManager(chatObjectId string, mngr *subscriptionManager) error { +func (s *service) getManager(chatObjectId string) (*subscriptionManager, error) { + fut, err := s.getManagerFuture(chatObjectId) + if err != nil { + return nil, fmt.Errorf("get future: %w", err) + } + return fut.Wait() +} + +func (s *service) initManager(chatObjectId string, mngrFut *futures.Future[*subscriptionManager]) error { spaceId, err := s.spaceIdResolver.ResolveSpaceID(chatObjectId) if err != nil { return fmt.Errorf("resolve space id: %w", err) @@ -139,21 +148,26 @@ func (s *service) initManager(chatObjectId string, mngr *subscriptionManager) er if err != nil { return fmt.Errorf("get repository: %w", err) } - mngr.componentCtx = s.componentCtx - mngr.spaceId = spaceId - mngr.chatId = chatObjectId - mngr.myIdentity = currentIdentity - mngr.myParticipantId = currentParticipantId - mngr.identityCache = s.identityCache - mngr.subscriptions = make(map[string]*subscription) - mngr.spaceIndex = s.objectStore.SpaceIndex(spaceId) - mngr.eventSender = s.eventSender - mngr.repository = repository + mngr := &subscriptionManager{ + componentCtx: s.componentCtx, + spaceId: spaceId, + chatId: chatObjectId, + myIdentity: currentIdentity, + myParticipantId: currentParticipantId, + identityCache: s.identityCache, + subscriptions: make(map[string]*subscription), + spaceIndex: s.objectStore.SpaceIndex(spaceId), + eventSender: s.eventSender, + repository: repository, + } err = mngr.loadChatState(s.componentCtx) if err != nil { - return fmt.Errorf("init chat state: %w", err) + err = fmt.Errorf("init chat state: %w", err) + mngrFut.ResolveErr(err) + return err } + mngrFut.ResolveValue(mngr) return nil } @@ -183,6 +197,7 @@ func (s *service) SubscribeLastMessages(ctx context.Context, req SubscribeLastMe if err != nil { return nil, fmt.Errorf("get manager: %w", err) } + mngr.Lock() defer mngr.Unlock() diff --git a/util/futures/future.go b/util/futures/future.go new file mode 100644 index 000000000..7fbee9410 --- /dev/null +++ b/util/futures/future.go @@ -0,0 +1,59 @@ +package futures + +import ( + "sync" +) + +type Future[T any] struct { + cond *sync.Cond + + ok bool + value T + err error +} + +// New creates a value that should be resolved later. It's necessary to resolve a future eventually, otherwise there is +// a possibility of deadlock, when someone waits for never-resolving future. +func New[T any]() *Future[T] { + return &Future[T]{ + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} + +func (f *Future[T]) Wait() (T, error) { + f.cond.L.Lock() + for !f.ok { + f.cond.Wait() + } + f.cond.L.Unlock() + + return f.value, f.err +} + +// Resolve sets value or error for future only once, all consequent calls to Resolve have no effect +func (f *Future[T]) Resolve(val T, err error) { + f.cond.L.Lock() + defer f.cond.L.Unlock() + + // Resolve once + if f.ok { + return + } + + f.ok = true + f.value = val + f.err = err + + f.cond.Broadcast() +} + +func (f *Future[T]) ResolveValue(val T) { + f.Resolve(val, nil) +} + +func (f *Future[T]) ResolveErr(err error) { + var defaultValue T + f.Resolve(defaultValue, err) +} diff --git a/util/futures/future_test.go b/util/futures/future_test.go new file mode 100644 index 000000000..7d94ed8e0 --- /dev/null +++ b/util/futures/future_test.go @@ -0,0 +1,88 @@ +package futures + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFutures(t *testing.T) { + t.Run("synchronously in linear order: has value", func(t *testing.T) { + f := New[int]() + f.ResolveValue(42) + + got, err := f.Wait() + require.NoError(t, err) + assert.Equal(t, 42, got) + }) + + t.Run("synchronously in linear order: has error", func(t *testing.T) { + f := New[int]() + f.ResolveErr(fmt.Errorf("test error")) + + got, err := f.Wait() + require.Error(t, err) + assert.Equal(t, 0, got) + }) + + t.Run("one producer, multiple consumers: has value", func(t *testing.T) { + f := New[int]() + + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + got, err := f.Wait() + require.NoError(t, err) + assert.Equal(t, 42, got) + }() + } + + f.ResolveValue(42) + + wg.Wait() + }) + + t.Run("one producer, multiple consumers: has error", func(t *testing.T) { + f := New[int]() + + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + got, err := f.Wait() + require.Error(t, err) + assert.Equal(t, 0, got) + }() + } + + f.ResolveErr(fmt.Errorf("test error")) + + wg.Wait() + }) + + t.Run("multiple producers: has first resolved value", func(t *testing.T) { + f := New[int]() + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func() { + defer wg.Done() + + f.ResolveValue(i + 1) + }() + } + wg.Wait() + + got, err := f.Wait() + require.NoError(t, err) + + assert.True(t, got >= 1 && got <= 11) + }) +}