1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-10 01:51:11 +09:00

Add send request sync

This commit is contained in:
mcrakhman 2024-06-09 15:03:28 +02:00
parent febbdfa09b
commit 51f740cb90
No known key found for this signature in database
GPG key ID: DED12CFEF5B8396B
3 changed files with 33 additions and 1 deletions

View file

@ -14,9 +14,14 @@ import (
type RequestManager interface { type RequestManager interface {
QueueRequest(rq syncdeps.Request) error QueueRequest(rq syncdeps.Request) error
SendRequest(ctx context.Context, rq syncdeps.Request, collector ResponseCollector) error
HandleStreamRequest(ctx context.Context, rq syncdeps.Request, stream drpc.Stream) error HandleStreamRequest(ctx context.Context, rq syncdeps.Request, stream drpc.Stream) error
} }
type ResponseCollector interface {
CollectResponse(ctx context.Context, peerId, objectId string, resp syncdeps.Response) error
}
type StreamResponse struct { type StreamResponse struct {
Stream drpc.Stream Stream drpc.Stream
Connection drpc.Conn Connection drpc.Conn
@ -36,6 +41,22 @@ func NewRequestManager(handler syncdeps.SyncHandler) RequestManager {
} }
} }
func (r *requestManager) SendRequest(ctx context.Context, rq syncdeps.Request, collector ResponseCollector) error {
return r.handler.SendStreamRequest(ctx, rq, func(stream drpc.Stream) error {
for {
resp := r.handler.NewResponse()
err := stream.MsgRecv(resp, streampool.EncodingProto)
if err != nil {
return err
}
err = collector.CollectResponse(ctx, rq.PeerId(), rq.ObjectId(), resp)
if err != nil {
return err
}
}
})
}
func (r *requestManager) QueueRequest(rq syncdeps.Request) error { func (r *requestManager) QueueRequest(rq syncdeps.Request) error {
return r.requestPool.QueueRequestAction(rq.PeerId(), rq.ObjectId(), func(ctx context.Context) { return r.requestPool.QueueRequestAction(rq.PeerId(), rq.ObjectId(), func(ctx context.Context) {
err := r.handler.SendStreamRequest(ctx, rq, func(stream drpc.Stream) error { err := r.handler.SendStreamRequest(ctx, rq, func(stream drpc.Stream) error {

View file

@ -23,6 +23,7 @@ type SyncService interface {
app.Component app.Component
BroadcastMessage(ctx context.Context, msg drpc.Message) error BroadcastMessage(ctx context.Context, msg drpc.Message) error
HandleStreamRequest(ctx context.Context, req syncdeps.Request, stream drpc.Stream) error HandleStreamRequest(ctx context.Context, req syncdeps.Request, stream drpc.Stream) error
SendRequest(ctx context.Context, rq syncdeps.Request, collector ResponseCollector) error
QueueRequest(ctx context.Context, rq syncdeps.Request) error QueueRequest(ctx context.Context, rq syncdeps.Request) error
} }
@ -104,6 +105,10 @@ func (s *syncService) QueueRequest(ctx context.Context, rq syncdeps.Request) err
return s.manager.QueueRequest(rq) return s.manager.QueueRequest(rq)
} }
func (s *syncService) SendRequest(ctx context.Context, rq syncdeps.Request, collector ResponseCollector) error {
return s.manager.SendRequest(ctx, rq, collector)
}
func (s *syncService) HandleStreamRequest(ctx context.Context, req syncdeps.Request, stream drpc.Stream) error { func (s *syncService) HandleStreamRequest(ctx context.Context, req syncdeps.Request, stream drpc.Stream) error {
return s.manager.HandleStreamRequest(ctx, req, stream) return s.manager.HandleStreamRequest(ctx, req, stream)
} }

View file

@ -2,6 +2,8 @@ package synctest
import ( import (
"context" "context"
"errors"
"io"
"storj.io/drpc" "storj.io/drpc"
@ -25,6 +27,10 @@ func (c *CounterRequestSender) SendStreamRequest(ctx context.Context, rq syncdep
if err != nil { if err != nil {
return err return err
} }
return receive(stream) err = receive(stream)
if errors.Is(err, io.EOF) {
return nil
}
return err
}) })
} }