diff --git a/commonspace/sync/requestmanager.go b/commonspace/sync/requestmanager.go index f7fda85d..8816cea8 100644 --- a/commonspace/sync/requestmanager.go +++ b/commonspace/sync/requestmanager.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/gogo/protobuf/proto" + "go.uber.org/zap" "storj.io/drpc" "github.com/anyproto/any-sync/net/streampool" @@ -23,17 +24,15 @@ type Response interface { //root *treechangeproto.RawTreeChangeWithId } -type RequestAccepter func(ctx context.Context, resp Response) error - type RequestManager interface { QueueRequest(peerId, objectId string, rq Request) error - HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error + HandleRequest(peerId, objectId string, rq Request) error HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error } type RequestHandler interface { - HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error - HandleStreamRequest(peerId, objectId string, rq Request, send func(resp proto.Message) error) error + HandleRequest(peerId, objectId string, rq Request) (Request, error) + HandleStreamRequest(peerId, objectId string, rq Request, send func(resp proto.Message) error) (Request, error) } type StreamResponse struct { @@ -65,7 +64,7 @@ type requestManager struct { func (r *requestManager) QueueRequest(peerId, objectId string, rq Request) error { return r.requestPool.QueueRequestAction(peerId, objectId, func() { - r.requestSender.SendStreamRequest(peerId, objectId, rq, func(stream drpc.Stream) error { + err := r.requestSender.SendStreamRequest(peerId, objectId, rq, func(stream drpc.Stream) error { for { resp := r.responseHandler.NewResponse() err := stream.MsgRecv(resp, streampool.EncodingProto) @@ -77,12 +76,14 @@ func (r *requestManager) QueueRequest(peerId, objectId string, rq Request) error return err } } - return nil }) + if err != nil { + log.Warn("failed to receive request", zap.Error(err)) + } }) } -func (r *requestManager) HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error { +func (r *requestManager) HandleRequest(peerId, objectId string, rq Request) error { id := fullId(peerId, objectId) r.mx.Lock() if _, ok := r.currentRequests[id]; ok { @@ -96,7 +97,14 @@ func (r *requestManager) HandleRequest(peerId, objectId string, rq Request, acce delete(r.currentRequests, id) r.mx.Unlock() }() - return r.requestHandler.HandleRequest(peerId, objectId, rq, accept) + newRq, err := r.requestHandler.HandleRequest(peerId, objectId, rq) + if err != nil { + return err + } + if newRq != nil { + return r.QueueRequest(peerId, objectId, newRq) + } + return nil } func (r *requestManager) HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error { @@ -113,11 +121,16 @@ func (r *requestManager) HandleStreamRequest(peerId, objectId string, rq Request delete(r.currentRequests, id) r.mx.Unlock() }() - - err := r.requestHandler.HandleStreamRequest(peerId, objectId, rq, func(resp proto.Message) error { + newRq, err := r.requestHandler.HandleStreamRequest(peerId, objectId, rq, func(resp proto.Message) error { return stream.MsgSend(resp, streampool.EncodingProto) }) - return err + if err != nil { + return err + } + if newRq != nil { + return r.QueueRequest(peerId, objectId, newRq) + } + return nil } func fullId(peerId, objectId string) string {