diff --git a/core/filestorage/rpcstore/clientmgr.go b/core/filestorage/rpcstore/clientmgr.go index edd3bce68..ba51aa430 100644 --- a/core/filestorage/rpcstore/clientmgr.go +++ b/core/filestorage/rpcstore/clientmgr.go @@ -15,6 +15,7 @@ import ( "golang.org/x/exp/slices" "github.com/anyproto/anytype-heart/space/spacecore/peerstore" + "github.com/anyproto/anytype-heart/util/contexthelper" ) const ( @@ -67,6 +68,8 @@ type clientManager struct { } func (m *clientManager) add(ctx context.Context, ts ...*task) (err error) { + ctx, cancel := contexthelper.ContextWithCloseChan(ctx, m.ctx.Done()) + defer cancel() select { case m.addLimiter <- struct{}{}: case <-ctx.Done(): diff --git a/core/payments/payments.go b/core/payments/payments.go index f8e2cc6b0..476f22f43 100644 --- a/core/payments/payments.go +++ b/core/payments/payments.go @@ -3,9 +3,6 @@ package payments import ( "context" "errors" - "fmt" - "sync" - "sync/atomic" "time" "unicode/utf8" @@ -26,6 +23,7 @@ import ( "github.com/anyproto/anytype-heart/pkg/lib/logging" "github.com/anyproto/anytype-heart/pkg/lib/pb/model" "github.com/anyproto/anytype-heart/space/deletioncontroller" + "github.com/anyproto/anytype-heart/util/contexthelper" ) const CName = "payments" @@ -125,16 +123,15 @@ func New() Service { } type service struct { - cache cache.CacheService - ppclient ppclient.AnyPpClientService - wallet wallet.Wallet - mx sync.Mutex - periodicGetStatus periodicsync.PeriodicSync - eventSender event.Sender - profileUpdater globalNamesUpdater - ns nameservice.Service - cancel context.CancelFunc - closed atomic.Bool + cache cache.CacheService + ppclient ppclient.AnyPpClientService + wallet wallet.Wallet + getSubscriptionLimiter chan struct{} + periodicGetStatus periodicsync.PeriodicSync + eventSender event.Sender + profileUpdater globalNamesUpdater + ns nameservice.Service + closing chan struct{} multiplayerLimitsUpdater deletioncontroller.DeletionController fileLimitsUpdater filesync.FileSync @@ -154,8 +151,7 @@ func (s *service) Init(a *app.App) (err error) { s.profileUpdater = app.MustComponent[globalNamesUpdater](a) s.multiplayerLimitsUpdater = app.MustComponent[deletioncontroller.DeletionController](a) s.fileLimitsUpdater = app.MustComponent[filesync.FileSync](a) - // setting empty cancel function, to not have nil function here - _, s.cancel = context.WithCancel(context.Background()) + s.getSubscriptionLimiter = make(chan struct{}, 1) return nil } @@ -170,8 +166,7 @@ func (s *service) Run(ctx context.Context) (err error) { } func (s *service) Close(_ context.Context) (err error) { - s.closed.Store(true) - s.cancel() + close(s.closing) s.periodicGetStatus.Close() return nil } @@ -197,7 +192,7 @@ func (s *service) sendMembershipUpdateEvent(status *pb.RpcMembershipGetStatusRes }) } -// Logic: +// GetSubscriptionStatus Logic: // // 1. Check in cache. if req.NoCache -> do not check in cache. // 2. If found in cache -> return it @@ -210,12 +205,17 @@ func (s *service) sendMembershipUpdateEvent(status *pb.RpcMembershipGetStatusRes // 8. UpdateLimits // 9. Enable cache again if status is active func (s *service) GetSubscriptionStatus(ctx context.Context, req *pb.RpcMembershipGetStatusRequest) (*pb.RpcMembershipGetStatusResponse, error) { - s.mx.Lock() - defer s.mx.Unlock() - if s.closed.Load() { - return nil, fmt.Errorf("service is closed") + // wrap context to stop in-flight request in case of component close + ctx, cancel := contexthelper.ContextWithCloseChan(ctx, s.closing) + defer cancel() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case s.getSubscriptionLimiter <- struct{}{}: + defer func() { + <-s.getSubscriptionLimiter + }() } - ctx, s.cancel = context.WithCancel(ctx) // 1 - check in cache first var ( cachedStatus *pb.RpcMembershipGetStatusResponse diff --git a/util/contexthelper/context.go b/util/contexthelper/context.go new file mode 100644 index 000000000..8778dfabb --- /dev/null +++ b/util/contexthelper/context.go @@ -0,0 +1,24 @@ +package contexthelper + +import "context" + +// ContextWithCloseChan returns a context that is canceled when either the parent context +// is canceled or when the provided close channel is closed. +func ContextWithCloseChan(ctx context.Context, closeChan <-chan struct{}) (context.Context, context.CancelFunc) { + // Create a new context that can be canceled + newCtx, cancel := context.WithCancel(ctx) + + // Start a goroutine that waits for either the closeChan to be closed or + // the new context to be canceled + go func() { + select { + case <-closeChan: + cancel() + case <-newCtx.Done(): + // newCtx is canceled, goroutine exits + } + }() + + // Return the cancel function + return newCtx, cancel +} diff --git a/util/contexthelper/context_test.go b/util/contexthelper/context_test.go new file mode 100644 index 000000000..614bc419d --- /dev/null +++ b/util/contexthelper/context_test.go @@ -0,0 +1,126 @@ +package contexthelper + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestContextWithCloseChan_CloseChanCancellation(t *testing.T) { + parentCtx := context.Background() + closeChan := make(chan struct{}) + ctx, cancelFunc := ContextWithCloseChan(parentCtx, closeChan) + defer cancelFunc() // Ensure resources are released + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // Expected to be canceled when closeChan is closed + case <-time.After(1 * time.Second): + t.Error("context was not canceled when closeChan was closed") + } + }() + + // Close the closeChan to trigger cancellation + close(closeChan) + + wg.Wait() + + // Verify that the context was canceled + if ctx.Err() == nil { + t.Error("context error is nil, expected cancellation error") + } +} + +func TestContextWithCloseChan_ParentContextCancellation(t *testing.T) { + parentCtx, parentCancel := context.WithCancel(context.Background()) + closeChan := make(chan struct{}) + ctx, cancelFunc := ContextWithCloseChan(parentCtx, closeChan) + defer cancelFunc() // Ensure resources are released + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // Expected to be canceled when parentCtx is canceled + case <-time.After(1 * time.Second): + t.Error("context was not canceled when parent context was canceled") + } + }() + + // Cancel the parent context + parentCancel() + + wg.Wait() + + // Verify that the context was canceled + if ctx.Err() == nil { + t.Error("context error is nil, expected cancellation error") + } +} + +func TestContextWithCloseChan_NoCancellation(t *testing.T) { + parentCtx := context.Background() + closeChan := make(chan struct{}) + ctx, cancelFunc := ContextWithCloseChan(parentCtx, closeChan) + defer cancelFunc() // Ensure resources are released + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + t.Error("context was canceled unexpectedly") + case <-time.After(50 * time.Millisecond): + // Expected to timeout here as neither context nor closeChan is canceled + } + }() + + wg.Wait() + + // Verify that the context is still active + if ctx.Err() != nil { + t.Errorf("context error is %v, expected nil", ctx.Err()) + } +} + +func TestContextWithCloseChan_BothCancellation(t *testing.T) { + parentCtx, parentCancel := context.WithCancel(context.Background()) + closeChan := make(chan struct{}) + ctx, cancelFunc := ContextWithCloseChan(parentCtx, closeChan) + defer cancelFunc() // Ensure resources are released + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // Expected to be canceled + case <-time.After(1 * time.Second): + t.Error("context was not canceled when both parent context and closeChan were canceled") + } + }() + + // Cancel both parent context and closeChan + parentCancel() + close(closeChan) + + wg.Wait() + + // Verify that the context was canceled + if ctx.Err() == nil { + t.Error("context error is nil, expected cancellation error") + } +}