diff --git a/commonspace/sync/headupdate.go b/commonspace/sync/headupdate.go new file mode 100644 index 00000000..d1bbdbe5 --- /dev/null +++ b/commonspace/sync/headupdate.go @@ -0,0 +1,95 @@ +package sync + +import ( + "fmt" + "slices" + + "github.com/gogo/protobuf/proto" + + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" +) + +type HeadUpdate struct { + peerId string + objectId string + spaceId string + heads []string + changes []*treechangeproto.RawTreeChangeWithId + snapshotPath []string + root *treechangeproto.RawTreeChangeWithId + opts BroadcastOptions +} + +func (h *HeadUpdate) SetPeerId(peerId string) { + h.peerId = peerId +} + +func (h *HeadUpdate) SetProtoMessage(message proto.Message) error { + var ( + msg *spacesyncproto.ObjectSyncMessage + ok bool + ) + if msg, ok = message.(*spacesyncproto.ObjectSyncMessage); !ok { + return fmt.Errorf("unexpected message type: %T", message) + } + treeMsg := &treechangeproto.TreeSyncMessage{} + err := proto.Unmarshal(msg.Payload, treeMsg) + if err != nil { + return err + } + h.root = treeMsg.RootChange + headMsg := treeMsg.GetContent().GetHeadUpdate() + if headMsg == nil { + return fmt.Errorf("unexpected message type: %T", treeMsg.GetContent()) + } + h.heads = headMsg.Heads + h.changes = headMsg.Changes + h.snapshotPath = headMsg.SnapshotPath + h.spaceId = msg.SpaceId + h.objectId = msg.ObjectId + return nil +} + +func (h *HeadUpdate) ProtoMessage() (proto.Message, error) { + if h.heads != nil { + return h.SyncMessage() + } + return &spacesyncproto.ObjectSyncMessage{}, nil +} + +func (h *HeadUpdate) PeerId() string { + return h.peerId +} + +func (h *HeadUpdate) ObjectId() string { + return h.objectId +} + +func (h *HeadUpdate) ShallowCopy() *HeadUpdate { + return &HeadUpdate{ + peerId: h.peerId, + objectId: h.objectId, + heads: h.heads, + changes: h.changes, + snapshotPath: h.snapshotPath, + root: h.root, + } +} + +func (h *HeadUpdate) SyncMessage() (*spacesyncproto.ObjectSyncMessage, error) { + changes := h.changes + if slices.Contains(h.opts.EmptyPeers, h.peerId) { + changes = nil + } + treeMsg := treechangeproto.WrapHeadUpdate(&treechangeproto.TreeHeadUpdate{ + Heads: h.heads, + SnapshotPath: h.snapshotPath, + Changes: changes, + }, h.root) + return spacesyncproto.MarshallSyncMessage(treeMsg, h.spaceId, h.objectId) +} + +func (h *HeadUpdate) RemoveChanges() { + h.changes = nil +} diff --git a/commonspace/sync/headupdatehandler.go b/commonspace/sync/headupdatehandler.go new file mode 100644 index 00000000..9064cb36 --- /dev/null +++ b/commonspace/sync/headupdatehandler.go @@ -0,0 +1,11 @@ +package sync + +import ( + "context" + + "storj.io/drpc" +) + +type HeadUpdateHandler interface { + HandleHeadUpdate(ctx context.Context, headUpdate drpc.Message) (Request, error) +} diff --git a/commonspace/sync/headupdatesender.go b/commonspace/sync/headupdatesender.go new file mode 100644 index 00000000..252b463b --- /dev/null +++ b/commonspace/sync/headupdatesender.go @@ -0,0 +1,12 @@ +package sync + +import "context" + +type BroadcastOptions struct { + EmptyPeers []string +} + +type HeadUpdateSender interface { + SendHeadUpdate(ctx context.Context, peerId string, headUpdate *HeadUpdate) error + BroadcastHeadUpdate(ctx context.Context, opts BroadcastOptions, headUpdate *HeadUpdate) error +} diff --git a/commonspace/sync/requestmanager.go b/commonspace/sync/requestmanager.go new file mode 100644 index 00000000..f7fda85d --- /dev/null +++ b/commonspace/sync/requestmanager.go @@ -0,0 +1,125 @@ +package sync + +import ( + "context" + "strings" + "sync" + + "github.com/gogo/protobuf/proto" + "storj.io/drpc" + + "github.com/anyproto/any-sync/net/streampool" +) + +type Request interface { + //heads []string + //changes []*treechangeproto.RawTreeChangeWithId + //root *treechangeproto.RawTreeChangeWithId +} + +type Response interface { + //heads []string + //changes []*treechangeproto.RawTreeChangeWithId + //root *treechangeproto.RawTreeChangeWithId +} + +type RequestAccepter func(ctx context.Context, resp Response) error + +type RequestManager interface { + QueueRequest(peerId, objectId string, rq Request) error + HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error + HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error +} + +type RequestHandler interface { + HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error + HandleStreamRequest(peerId, objectId string, rq Request, send func(resp proto.Message) error) error +} + +type StreamResponse struct { + Stream drpc.Stream + Connection drpc.Conn +} + +type RequestSender interface { + SendRequest(peerId, objectId string, rq Request) (resp Response, err error) + SendStreamRequest(peerId, objectId string, rq Request, receive func(stream drpc.Stream) error) (err error) +} + +type ResponseHandler interface { + NewResponse() Response + HandleResponse(peerId, objectId string, resp Response) error +} + +type requestManager struct { + requestPool RequestPool + requestHandler RequestHandler + responseHandler ResponseHandler + requestSender RequestSender + currentRequests map[string]struct{} + mx sync.Mutex + ctx context.Context + cancel context.CancelFunc + wait chan struct{} +} + +func (r *requestManager) QueueRequest(peerId, objectId string, rq Request) error { + return r.requestPool.QueueRequestAction(peerId, objectId, func() { + r.requestSender.SendStreamRequest(peerId, objectId, rq, func(stream drpc.Stream) error { + for { + resp := r.responseHandler.NewResponse() + err := stream.MsgRecv(resp, streampool.EncodingProto) + if err != nil { + return err + } + err = r.responseHandler.HandleResponse(peerId, objectId, resp) + if err != nil { + return err + } + } + return nil + }) + }) +} + +func (r *requestManager) HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error { + id := fullId(peerId, objectId) + r.mx.Lock() + if _, ok := r.currentRequests[id]; ok { + r.mx.Unlock() + return nil + } + r.currentRequests[id] = struct{}{} + r.mx.Unlock() + defer func() { + r.mx.Lock() + delete(r.currentRequests, id) + r.mx.Unlock() + }() + return r.requestHandler.HandleRequest(peerId, objectId, rq, accept) +} + +func (r *requestManager) HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error { + id := fullId(peerId, objectId) + r.mx.Lock() + if _, ok := r.currentRequests[id]; ok { + r.mx.Unlock() + return nil + } + r.currentRequests[id] = struct{}{} + r.mx.Unlock() + defer func() { + r.mx.Lock() + delete(r.currentRequests, id) + r.mx.Unlock() + }() + + err := r.requestHandler.HandleStreamRequest(peerId, objectId, rq, func(resp proto.Message) error { + return stream.MsgSend(resp, streampool.EncodingProto) + }) + return err +} + +func fullId(peerId, objectId string) string { + return strings.Join([]string{peerId, objectId}, "-") +} diff --git a/commonspace/sync/requestpool.go b/commonspace/sync/requestpool.go new file mode 100644 index 00000000..c34944ba --- /dev/null +++ b/commonspace/sync/requestpool.go @@ -0,0 +1,5 @@ +package sync + +type RequestPool interface { + QueueRequestAction(peerId, objectId string, action func()) (err error) +} diff --git a/commonspace/sync/sync.go b/commonspace/sync/sync.go new file mode 100644 index 00000000..ee2a588b --- /dev/null +++ b/commonspace/sync/sync.go @@ -0,0 +1,84 @@ +package sync + +import ( + "context" + + "github.com/cheggaaa/mb/v3" + "go.uber.org/zap" + "storj.io/drpc" + + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/util/multiqueue" +) + +const CName = "common.commonspace.sync" + +var log = logger.NewNamed("sync") + +type SyncService interface { + GetQueueProvider() multiqueue.QueueProvider[drpc.Message] +} + +type MergeFilterFunc func(ctx context.Context, msg drpc.Message, q *mb.MB[drpc.Message]) error + +type syncService struct { + // sendQueue is a multiqueue: peerId -> queue + // this queue exists for sending head updates + sendQueueProvider multiqueue.QueueProvider[drpc.Message] + // receiveQueue is a multiqueue: objectId -> queue + // this queue exists for receiving head updates + receiveQueue multiqueue.MultiQueue[drpc.Message] + // manager is a Request manager which works with both incoming and outgoing requests + manager RequestManager + // handler checks if head update is relevant and then queues Request intent if necessary + handler HeadUpdateHandler + // sender sends head updates to peers + sender HeadUpdateSender + mergeFilter MergeFilterFunc + ctx context.Context + cancel context.CancelFunc +} + +func NewSyncService() SyncService { + s := &syncService{} + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.sendQueueProvider = multiqueue.NewQueueProvider[drpc.Message](100, s.handleOutgoingMessage) + s.receiveQueue = multiqueue.New[drpc.Message](s.handleIncomingMessage, 100) + return s +} + +func (s *syncService) handleOutgoingMessage(id string, msg drpc.Message, q *mb.MB[drpc.Message]) error { + //headUpdate := msg.(*HeadUpdate) + //cp := headUpdate.ShallowCopy() + //cp.SetPeerId(id) + //// TODO: add some merging/filtering logic if needed + //// for example we can filter empty messages for the same peer + //// or we can merge the messages together + return s.mergeFilter(s.ctx, msg, q) +} + +func (s *syncService) handleIncomingMessage(msg drpc.Message) { + req, err := s.handler.HandleHeadUpdate(s.ctx, msg) + if err != nil { + log.Error("failed to handle head update", zap.Error(err)) + } + if req == nil { + return + } + err = s.manager.QueueRequest("", "", req) + if err != nil { + log.Error("failed to queue request", zap.Error(err)) + } +} + +func (s *syncService) GetQueueProvider() multiqueue.QueueProvider[drpc.Message] { + return s.sendQueueProvider +} + +func (s *syncService) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) error { + return s.receiveQueue.Add(ctx, peerId, msg.(*HeadUpdate)) +} + +func (s *syncService) NewReadMessage() drpc.Message { + return &HeadUpdate{} +} diff --git a/net/streampool/encoding.go b/net/streampool/encoding.go index d724bf90..b494c4eb 100644 --- a/net/streampool/encoding.go +++ b/net/streampool/encoding.go @@ -2,6 +2,7 @@ package streampool import ( "errors" + "github.com/gogo/protobuf/proto" "storj.io/drpc" ) @@ -15,20 +16,51 @@ var ( errNotAProtoMsg = errors.New("encoding: not a proto message") ) +type ProtoMessageGettable interface { + ProtoMessage() (proto.Message, error) +} + +type ProtoMessageSettable interface { + ProtoMessageGettable + SetProtoMessage(proto.Message) error +} + type protoEncoding struct{} -func (p protoEncoding) Marshal(msg drpc.Message) ([]byte, error) { +func (p protoEncoding) Marshal(msg drpc.Message) (res []byte, err error) { pmsg, ok := msg.(proto.Message) if !ok { - return nil, errNotAProtoMsg + if pmg, ok := msg.(ProtoMessageGettable); ok { + pmsg, err = pmg.ProtoMessage() + if err != nil { + return nil, err + } + } else { + return nil, errNotAProtoMsg + } } return proto.Marshal(pmsg) } -func (p protoEncoding) Unmarshal(buf []byte, msg drpc.Message) error { +func (p protoEncoding) Unmarshal(buf []byte, msg drpc.Message) (err error) { + var pms ProtoMessageSettable pmsg, ok := msg.(proto.Message) if !ok { - return errNotAProtoMsg + if pms, ok = msg.(ProtoMessageSettable); ok { + pmsg, err = pms.ProtoMessage() + if err != nil { + return err + } + } else { + return errNotAProtoMsg + } } - return proto.Unmarshal(buf, pmsg) + err = proto.Unmarshal(buf, pmsg) + if err != nil { + return err + } + if pms != nil { + err = pms.SetProtoMessage(pmsg) + } + return } diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 28f60d76..8aae8161 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -9,6 +9,7 @@ import ( "storj.io/drpc" "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/util/multiqueue" ) type stream struct { @@ -19,7 +20,7 @@ type stream struct { streamId uint32 closed atomic.Bool l logger.CtxLogger - queue *mb.MB[drpc.Message] + queue *multiqueue.Queue[drpc.Message] stats streamStat tags []string } diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 14be0b38..4cae48e0 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -4,7 +4,6 @@ import ( "fmt" "sync" - "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" @@ -13,6 +12,7 @@ import ( "github.com/anyproto/any-sync/app/debugstat" "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/util/multiqueue" ) // StreamHandler handles incoming messages from streams @@ -23,6 +23,8 @@ type StreamHandler interface { HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) // NewReadMessage creates new empty message for unmarshalling into it NewReadMessage() drpc.Message + // GetQueueProvider returns queue provider for outgoing messages + GetQueueProvider() multiqueue.QueueProvider[drpc.Message] } // PeerGetter should dial or return cached peers @@ -154,10 +156,6 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, defer s.mu.Unlock() s.lastStreamId++ streamId := s.lastStreamId - queueSize := s.writeQueueSize - if queueSize <= 0 { - queueSize = 100 - } st := &stream{ peerId: peerId, peerCtx: ctx, @@ -168,7 +166,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, tags: tags, stats: newStreamStat(peerId), } - st.queue = mb.New[drpc.Message](s.writeQueueSize) + st.queue = s.handler.GetQueueProvider().GetQueue(peerId) s.streams[streamId] = st s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId) for _, tag := range tags { diff --git a/util/multiqueue/multiqueue.go b/util/multiqueue/multiqueue.go index ee1cce81..90974f66 100644 --- a/util/multiqueue/multiqueue.go +++ b/util/multiqueue/multiqueue.go @@ -3,8 +3,9 @@ package multiqueue import ( "context" "errors" - "github.com/cheggaaa/mb/v3" "sync" + + "github.com/cheggaaa/mb/v3" ) var ( @@ -25,6 +26,7 @@ type HandleFunc[T any] func(msg T) type MultiQueue[T any] interface { Add(ctx context.Context, threadId string, msg T) (err error) CloseThread(threadId string) (err error) + ThreadIds() []string Close() (err error) } @@ -36,6 +38,16 @@ type multiQueue[T any] struct { closed bool } +func (m *multiQueue[T]) ThreadIds() []string { + m.mu.Lock() + defer m.mu.Unlock() + ids := make([]string, 0, len(m.threads)) + for id := range m.threads { + ids = append(ids, id) + } + return ids +} + func (m *multiQueue[T]) Add(ctx context.Context, threadId string, msg T) (err error) { m.mu.Lock() if m.closed { diff --git a/util/multiqueue/queue.go b/util/multiqueue/queue.go new file mode 100644 index 00000000..96c50816 --- /dev/null +++ b/util/multiqueue/queue.go @@ -0,0 +1,35 @@ +package multiqueue + +import ( + "context" + + "github.com/cheggaaa/mb/v3" +) + +type QueueHandler[T any] func(id string, msg T, q *mb.MB[T]) error + +type Queue[T any] struct { + id string + q *mb.MB[T] + handler QueueHandler[T] +} + +func NewQueue[T any](id string, size int, h QueueHandler[T]) *Queue[T] { + return &Queue[T]{ + id: id, + q: mb.New[T](size), + handler: h, + } +} + +func (q *Queue[T]) TryAdd(msg T) error { + return q.handler(q.id, msg, q.q) +} + +func (q *Queue[T]) WaitOne(ctx context.Context) (T, error) { + return q.q.WaitOne(ctx) +} + +func (q *Queue[T]) Close() error { + return q.q.Close() +} diff --git a/util/multiqueue/queueprovider.go b/util/multiqueue/queueprovider.go new file mode 100644 index 00000000..e16c24fc --- /dev/null +++ b/util/multiqueue/queueprovider.go @@ -0,0 +1,47 @@ +package multiqueue + +import ( + "sync" +) + +type QueueProvider[T any] interface { + GetQueue(id string) *Queue[T] + RemoveQueue(id string) error +} + +type queueProvider[T any] struct { + queues map[string]*Queue[T] + mx sync.Mutex + size int + handler QueueHandler[T] +} + +func NewQueueProvider[T any](size int, handler QueueHandler[T]) QueueProvider[T] { + return &queueProvider[T]{ + queues: make(map[string]*Queue[T]), + size: size, + handler: handler, + } +} + +func (p *queueProvider[T]) GetQueue(id string) *Queue[T] { + p.mx.Lock() + defer p.mx.Unlock() + q, ok := p.queues[id] + if !ok { + q = NewQueue(id, p.size, p.handler) + p.queues[id] = q + } + return q +} + +func (p *queueProvider[T]) RemoveQueue(id string) error { + p.mx.Lock() + defer p.mx.Unlock() + q, ok := p.queues[id] + if !ok { + return nil + } + delete(p.queues, id) + return q.Close() +}