diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 5dff0cb9..1d59af4b 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -3,7 +3,7 @@ package streampool import ( "context" "github.com/anyproto/any-sync/app/logger" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "storj.io/drpc" "sync/atomic" @@ -17,17 +17,12 @@ type stream struct { streamId uint32 closed atomic.Bool l logger.CtxLogger - queue multiqueue.MultiQueue[drpc.Message] + queue *mb.MB[drpc.Message] tags []string } func (sr *stream) write(msg drpc.Message) (err error) { - var queueId string - if qId, ok := msg.(MessageQueueId); ok { - queueId = qId.MessageQueueId() - msg = qId.DrpcMessage() - } - return sr.queue.Add(sr.stream.Context(), queueId, msg) + return sr.queue.Add(sr.stream.Context(), msg) } func (sr *stream) readLoop() error { @@ -50,13 +45,21 @@ func (sr *stream) readLoop() error { } } -func (sr *stream) writeToStream(msg drpc.Message) { - if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { - sr.l.Warn("msg send error", zap.Error(err)) - sr.streamClose() - return +func (sr *stream) writeLoop() { + for { + msg, err := sr.queue.WaitOne(sr.peerCtx) + if err != nil { + if err != mb.ErrClosed { + sr.streamClose() + } + return + } + if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { + sr.l.Warn("msg send error", zap.Error(err)) + sr.streamClose() + return + } } - return } func (sr *stream) streamClose() { diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 59ee6e4c..50a020b3 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" @@ -74,6 +74,9 @@ func (s *streamPool) ReadStream(drpcStream drpc.Stream, tags ...string) error { if err != nil { return err } + go func() { + st.writeLoop() + }() return st.readLoop() } @@ -85,6 +88,9 @@ func (s *streamPool) AddStream(drpcStream drpc.Stream, tags ...string) error { go func() { _ = st.readLoop() }() + go func() { + st.writeLoop() + }() return nil } @@ -122,7 +128,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, l: log.With(zap.String("peerId", peerId), zap.Uint32("streamId", streamId)), tags: tags, } - st.queue = multiqueue.New[drpc.Message](st.writeToStream, s.writeQueueSize) + st.queue = mb.New[drpc.Message](s.writeQueueSize) s.streams[streamId] = st s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId) for _, tag := range tags { @@ -364,21 +370,3 @@ func removeStream(m map[string][]uint32, key string, streamId uint32) { m[key] = streamIds } } - -// WithQueueId wraps the message and adds queueId -func WithQueueId(msg drpc.Message, queueId string) drpc.Message { - return &messageWithQueueId{queueId: queueId, Message: msg} -} - -type messageWithQueueId struct { - drpc.Message - queueId string -} - -func (m messageWithQueueId) MessageQueueId() string { - return m.queueId -} - -func (m messageWithQueueId) DrpcMessage() drpc.Message { - return m.Message -} diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 6a053bd4..d4a05de3 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -47,7 +47,7 @@ func TestStreamPool_AddStream(t *testing.T) { require.NoError(t, fx.AddStream(s2, "space2", "common")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) - require.NoError(t, fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2")) + require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common")) var serverResults []string