mirror of
https://github.com/anyproto/any-sync.git
synced 2025-06-11 10:18:08 +09:00
WIP add some sync stuff
This commit is contained in:
parent
869943723e
commit
ab758062df
12 changed files with 470 additions and 13 deletions
95
commonspace/sync/headupdate.go
Normal file
95
commonspace/sync/headupdate.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||||
|
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HeadUpdate struct {
|
||||||
|
peerId string
|
||||||
|
objectId string
|
||||||
|
spaceId string
|
||||||
|
heads []string
|
||||||
|
changes []*treechangeproto.RawTreeChangeWithId
|
||||||
|
snapshotPath []string
|
||||||
|
root *treechangeproto.RawTreeChangeWithId
|
||||||
|
opts BroadcastOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) SetPeerId(peerId string) {
|
||||||
|
h.peerId = peerId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) SetProtoMessage(message proto.Message) error {
|
||||||
|
var (
|
||||||
|
msg *spacesyncproto.ObjectSyncMessage
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
if msg, ok = message.(*spacesyncproto.ObjectSyncMessage); !ok {
|
||||||
|
return fmt.Errorf("unexpected message type: %T", message)
|
||||||
|
}
|
||||||
|
treeMsg := &treechangeproto.TreeSyncMessage{}
|
||||||
|
err := proto.Unmarshal(msg.Payload, treeMsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.root = treeMsg.RootChange
|
||||||
|
headMsg := treeMsg.GetContent().GetHeadUpdate()
|
||||||
|
if headMsg == nil {
|
||||||
|
return fmt.Errorf("unexpected message type: %T", treeMsg.GetContent())
|
||||||
|
}
|
||||||
|
h.heads = headMsg.Heads
|
||||||
|
h.changes = headMsg.Changes
|
||||||
|
h.snapshotPath = headMsg.SnapshotPath
|
||||||
|
h.spaceId = msg.SpaceId
|
||||||
|
h.objectId = msg.ObjectId
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) ProtoMessage() (proto.Message, error) {
|
||||||
|
if h.heads != nil {
|
||||||
|
return h.SyncMessage()
|
||||||
|
}
|
||||||
|
return &spacesyncproto.ObjectSyncMessage{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) PeerId() string {
|
||||||
|
return h.peerId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) ObjectId() string {
|
||||||
|
return h.objectId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) ShallowCopy() *HeadUpdate {
|
||||||
|
return &HeadUpdate{
|
||||||
|
peerId: h.peerId,
|
||||||
|
objectId: h.objectId,
|
||||||
|
heads: h.heads,
|
||||||
|
changes: h.changes,
|
||||||
|
snapshotPath: h.snapshotPath,
|
||||||
|
root: h.root,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) SyncMessage() (*spacesyncproto.ObjectSyncMessage, error) {
|
||||||
|
changes := h.changes
|
||||||
|
if slices.Contains(h.opts.EmptyPeers, h.peerId) {
|
||||||
|
changes = nil
|
||||||
|
}
|
||||||
|
treeMsg := treechangeproto.WrapHeadUpdate(&treechangeproto.TreeHeadUpdate{
|
||||||
|
Heads: h.heads,
|
||||||
|
SnapshotPath: h.snapshotPath,
|
||||||
|
Changes: changes,
|
||||||
|
}, h.root)
|
||||||
|
return spacesyncproto.MarshallSyncMessage(treeMsg, h.spaceId, h.objectId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HeadUpdate) RemoveChanges() {
|
||||||
|
h.changes = nil
|
||||||
|
}
|
11
commonspace/sync/headupdatehandler.go
Normal file
11
commonspace/sync/headupdatehandler.go
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"storj.io/drpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HeadUpdateHandler interface {
|
||||||
|
HandleHeadUpdate(ctx context.Context, headUpdate drpc.Message) (Request, error)
|
||||||
|
}
|
12
commonspace/sync/headupdatesender.go
Normal file
12
commonspace/sync/headupdatesender.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type BroadcastOptions struct {
|
||||||
|
EmptyPeers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeadUpdateSender interface {
|
||||||
|
SendHeadUpdate(ctx context.Context, peerId string, headUpdate *HeadUpdate) error
|
||||||
|
BroadcastHeadUpdate(ctx context.Context, opts BroadcastOptions, headUpdate *HeadUpdate) error
|
||||||
|
}
|
125
commonspace/sync/requestmanager.go
Normal file
125
commonspace/sync/requestmanager.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto"
|
||||||
|
"storj.io/drpc"
|
||||||
|
|
||||||
|
"github.com/anyproto/any-sync/net/streampool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request interface {
|
||||||
|
//heads []string
|
||||||
|
//changes []*treechangeproto.RawTreeChangeWithId
|
||||||
|
//root *treechangeproto.RawTreeChangeWithId
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response interface {
|
||||||
|
//heads []string
|
||||||
|
//changes []*treechangeproto.RawTreeChangeWithId
|
||||||
|
//root *treechangeproto.RawTreeChangeWithId
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestAccepter func(ctx context.Context, resp Response) error
|
||||||
|
|
||||||
|
type RequestManager interface {
|
||||||
|
QueueRequest(peerId, objectId string, rq Request) error
|
||||||
|
HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error
|
||||||
|
HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestHandler interface {
|
||||||
|
HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error
|
||||||
|
HandleStreamRequest(peerId, objectId string, rq Request, send func(resp proto.Message) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamResponse struct {
|
||||||
|
Stream drpc.Stream
|
||||||
|
Connection drpc.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestSender interface {
|
||||||
|
SendRequest(peerId, objectId string, rq Request) (resp Response, err error)
|
||||||
|
SendStreamRequest(peerId, objectId string, rq Request, receive func(stream drpc.Stream) error) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseHandler interface {
|
||||||
|
NewResponse() Response
|
||||||
|
HandleResponse(peerId, objectId string, resp Response) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestManager struct {
|
||||||
|
requestPool RequestPool
|
||||||
|
requestHandler RequestHandler
|
||||||
|
responseHandler ResponseHandler
|
||||||
|
requestSender RequestSender
|
||||||
|
currentRequests map[string]struct{}
|
||||||
|
mx sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wait chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *requestManager) QueueRequest(peerId, objectId string, rq Request) error {
|
||||||
|
return r.requestPool.QueueRequestAction(peerId, objectId, func() {
|
||||||
|
r.requestSender.SendStreamRequest(peerId, objectId, rq, func(stream drpc.Stream) error {
|
||||||
|
for {
|
||||||
|
resp := r.responseHandler.NewResponse()
|
||||||
|
err := stream.MsgRecv(resp, streampool.EncodingProto)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = r.responseHandler.HandleResponse(peerId, objectId, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *requestManager) HandleRequest(peerId, objectId string, rq Request, accept RequestAccepter) error {
|
||||||
|
id := fullId(peerId, objectId)
|
||||||
|
r.mx.Lock()
|
||||||
|
if _, ok := r.currentRequests[id]; ok {
|
||||||
|
r.mx.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.currentRequests[id] = struct{}{}
|
||||||
|
r.mx.Unlock()
|
||||||
|
defer func() {
|
||||||
|
r.mx.Lock()
|
||||||
|
delete(r.currentRequests, id)
|
||||||
|
r.mx.Unlock()
|
||||||
|
}()
|
||||||
|
return r.requestHandler.HandleRequest(peerId, objectId, rq, accept)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *requestManager) HandleStreamRequest(peerId, objectId string, rq Request, stream drpc.Stream) error {
|
||||||
|
id := fullId(peerId, objectId)
|
||||||
|
r.mx.Lock()
|
||||||
|
if _, ok := r.currentRequests[id]; ok {
|
||||||
|
r.mx.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.currentRequests[id] = struct{}{}
|
||||||
|
r.mx.Unlock()
|
||||||
|
defer func() {
|
||||||
|
r.mx.Lock()
|
||||||
|
delete(r.currentRequests, id)
|
||||||
|
r.mx.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := r.requestHandler.HandleStreamRequest(peerId, objectId, rq, func(resp proto.Message) error {
|
||||||
|
return stream.MsgSend(resp, streampool.EncodingProto)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func fullId(peerId, objectId string) string {
|
||||||
|
return strings.Join([]string{peerId, objectId}, "-")
|
||||||
|
}
|
5
commonspace/sync/requestpool.go
Normal file
5
commonspace/sync/requestpool.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
type RequestPool interface {
|
||||||
|
QueueRequestAction(peerId, objectId string, action func()) (err error)
|
||||||
|
}
|
84
commonspace/sync/sync.go
Normal file
84
commonspace/sync/sync.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cheggaaa/mb/v3"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"storj.io/drpc"
|
||||||
|
|
||||||
|
"github.com/anyproto/any-sync/app/logger"
|
||||||
|
"github.com/anyproto/any-sync/util/multiqueue"
|
||||||
|
)
|
||||||
|
|
||||||
|
const CName = "common.commonspace.sync"
|
||||||
|
|
||||||
|
var log = logger.NewNamed("sync")
|
||||||
|
|
||||||
|
type SyncService interface {
|
||||||
|
GetQueueProvider() multiqueue.QueueProvider[drpc.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type MergeFilterFunc func(ctx context.Context, msg drpc.Message, q *mb.MB[drpc.Message]) error
|
||||||
|
|
||||||
|
type syncService struct {
|
||||||
|
// sendQueue is a multiqueue: peerId -> queue
|
||||||
|
// this queue exists for sending head updates
|
||||||
|
sendQueueProvider multiqueue.QueueProvider[drpc.Message]
|
||||||
|
// receiveQueue is a multiqueue: objectId -> queue
|
||||||
|
// this queue exists for receiving head updates
|
||||||
|
receiveQueue multiqueue.MultiQueue[drpc.Message]
|
||||||
|
// manager is a Request manager which works with both incoming and outgoing requests
|
||||||
|
manager RequestManager
|
||||||
|
// handler checks if head update is relevant and then queues Request intent if necessary
|
||||||
|
handler HeadUpdateHandler
|
||||||
|
// sender sends head updates to peers
|
||||||
|
sender HeadUpdateSender
|
||||||
|
mergeFilter MergeFilterFunc
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSyncService() SyncService {
|
||||||
|
s := &syncService{}
|
||||||
|
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||||
|
s.sendQueueProvider = multiqueue.NewQueueProvider[drpc.Message](100, s.handleOutgoingMessage)
|
||||||
|
s.receiveQueue = multiqueue.New[drpc.Message](s.handleIncomingMessage, 100)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *syncService) handleOutgoingMessage(id string, msg drpc.Message, q *mb.MB[drpc.Message]) error {
|
||||||
|
//headUpdate := msg.(*HeadUpdate)
|
||||||
|
//cp := headUpdate.ShallowCopy()
|
||||||
|
//cp.SetPeerId(id)
|
||||||
|
//// TODO: add some merging/filtering logic if needed
|
||||||
|
//// for example we can filter empty messages for the same peer
|
||||||
|
//// or we can merge the messages together
|
||||||
|
return s.mergeFilter(s.ctx, msg, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *syncService) handleIncomingMessage(msg drpc.Message) {
|
||||||
|
req, err := s.handler.HandleHeadUpdate(s.ctx, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to handle head update", zap.Error(err))
|
||||||
|
}
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.manager.QueueRequest("", "", req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to queue request", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *syncService) GetQueueProvider() multiqueue.QueueProvider[drpc.Message] {
|
||||||
|
return s.sendQueueProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *syncService) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) error {
|
||||||
|
return s.receiveQueue.Add(ctx, peerId, msg.(*HeadUpdate))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *syncService) NewReadMessage() drpc.Message {
|
||||||
|
return &HeadUpdate{}
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package streampool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
"storj.io/drpc"
|
"storj.io/drpc"
|
||||||
)
|
)
|
||||||
|
@ -15,20 +16,51 @@ var (
|
||||||
errNotAProtoMsg = errors.New("encoding: not a proto message")
|
errNotAProtoMsg = errors.New("encoding: not a proto message")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ProtoMessageGettable interface {
|
||||||
|
ProtoMessage() (proto.Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtoMessageSettable interface {
|
||||||
|
ProtoMessageGettable
|
||||||
|
SetProtoMessage(proto.Message) error
|
||||||
|
}
|
||||||
|
|
||||||
type protoEncoding struct{}
|
type protoEncoding struct{}
|
||||||
|
|
||||||
func (p protoEncoding) Marshal(msg drpc.Message) ([]byte, error) {
|
func (p protoEncoding) Marshal(msg drpc.Message) (res []byte, err error) {
|
||||||
pmsg, ok := msg.(proto.Message)
|
pmsg, ok := msg.(proto.Message)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errNotAProtoMsg
|
if pmg, ok := msg.(ProtoMessageGettable); ok {
|
||||||
|
pmsg, err = pmg.ProtoMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, errNotAProtoMsg
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return proto.Marshal(pmsg)
|
return proto.Marshal(pmsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p protoEncoding) Unmarshal(buf []byte, msg drpc.Message) error {
|
func (p protoEncoding) Unmarshal(buf []byte, msg drpc.Message) (err error) {
|
||||||
|
var pms ProtoMessageSettable
|
||||||
pmsg, ok := msg.(proto.Message)
|
pmsg, ok := msg.(proto.Message)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errNotAProtoMsg
|
if pms, ok = msg.(ProtoMessageSettable); ok {
|
||||||
|
pmsg, err = pms.ProtoMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errNotAProtoMsg
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return proto.Unmarshal(buf, pmsg)
|
err = proto.Unmarshal(buf, pmsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if pms != nil {
|
||||||
|
err = pms.SetProtoMessage(pmsg)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"storj.io/drpc"
|
"storj.io/drpc"
|
||||||
|
|
||||||
"github.com/anyproto/any-sync/app/logger"
|
"github.com/anyproto/any-sync/app/logger"
|
||||||
|
"github.com/anyproto/any-sync/util/multiqueue"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stream struct {
|
type stream struct {
|
||||||
|
@ -19,7 +20,7 @@ type stream struct {
|
||||||
streamId uint32
|
streamId uint32
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
l logger.CtxLogger
|
l logger.CtxLogger
|
||||||
queue *mb.MB[drpc.Message]
|
queue *multiqueue.Queue[drpc.Message]
|
||||||
stats streamStat
|
stats streamStat
|
||||||
tags []string
|
tags []string
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cheggaaa/mb/v3"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
@ -13,6 +12,7 @@ import (
|
||||||
"github.com/anyproto/any-sync/app/debugstat"
|
"github.com/anyproto/any-sync/app/debugstat"
|
||||||
"github.com/anyproto/any-sync/net"
|
"github.com/anyproto/any-sync/net"
|
||||||
"github.com/anyproto/any-sync/net/peer"
|
"github.com/anyproto/any-sync/net/peer"
|
||||||
|
"github.com/anyproto/any-sync/util/multiqueue"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamHandler handles incoming messages from streams
|
// StreamHandler handles incoming messages from streams
|
||||||
|
@ -23,6 +23,8 @@ type StreamHandler interface {
|
||||||
HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error)
|
HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error)
|
||||||
// NewReadMessage creates new empty message for unmarshalling into it
|
// NewReadMessage creates new empty message for unmarshalling into it
|
||||||
NewReadMessage() drpc.Message
|
NewReadMessage() drpc.Message
|
||||||
|
// GetQueueProvider returns queue provider for outgoing messages
|
||||||
|
GetQueueProvider() multiqueue.QueueProvider[drpc.Message]
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerGetter should dial or return cached peers
|
// PeerGetter should dial or return cached peers
|
||||||
|
@ -154,10 +156,6 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream,
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.lastStreamId++
|
s.lastStreamId++
|
||||||
streamId := s.lastStreamId
|
streamId := s.lastStreamId
|
||||||
queueSize := s.writeQueueSize
|
|
||||||
if queueSize <= 0 {
|
|
||||||
queueSize = 100
|
|
||||||
}
|
|
||||||
st := &stream{
|
st := &stream{
|
||||||
peerId: peerId,
|
peerId: peerId,
|
||||||
peerCtx: ctx,
|
peerCtx: ctx,
|
||||||
|
@ -168,7 +166,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
stats: newStreamStat(peerId),
|
stats: newStreamStat(peerId),
|
||||||
}
|
}
|
||||||
st.queue = mb.New[drpc.Message](s.writeQueueSize)
|
st.queue = s.handler.GetQueueProvider().GetQueue(peerId)
|
||||||
s.streams[streamId] = st
|
s.streams[streamId] = st
|
||||||
s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId)
|
s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId)
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
|
|
|
@ -3,8 +3,9 @@ package multiqueue
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/cheggaaa/mb/v3"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cheggaaa/mb/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -25,6 +26,7 @@ type HandleFunc[T any] func(msg T)
|
||||||
type MultiQueue[T any] interface {
|
type MultiQueue[T any] interface {
|
||||||
Add(ctx context.Context, threadId string, msg T) (err error)
|
Add(ctx context.Context, threadId string, msg T) (err error)
|
||||||
CloseThread(threadId string) (err error)
|
CloseThread(threadId string) (err error)
|
||||||
|
ThreadIds() []string
|
||||||
Close() (err error)
|
Close() (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,6 +38,16 @@ type multiQueue[T any] struct {
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *multiQueue[T]) ThreadIds() []string {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
ids := make([]string, 0, len(m.threads))
|
||||||
|
for id := range m.threads {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
func (m *multiQueue[T]) Add(ctx context.Context, threadId string, msg T) (err error) {
|
func (m *multiQueue[T]) Add(ctx context.Context, threadId string, msg T) (err error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if m.closed {
|
if m.closed {
|
||||||
|
|
35
util/multiqueue/queue.go
Normal file
35
util/multiqueue/queue.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package multiqueue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cheggaaa/mb/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QueueHandler[T any] func(id string, msg T, q *mb.MB[T]) error
|
||||||
|
|
||||||
|
type Queue[T any] struct {
|
||||||
|
id string
|
||||||
|
q *mb.MB[T]
|
||||||
|
handler QueueHandler[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueue[T any](id string, size int, h QueueHandler[T]) *Queue[T] {
|
||||||
|
return &Queue[T]{
|
||||||
|
id: id,
|
||||||
|
q: mb.New[T](size),
|
||||||
|
handler: h,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queue[T]) TryAdd(msg T) error {
|
||||||
|
return q.handler(q.id, msg, q.q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queue[T]) WaitOne(ctx context.Context) (T, error) {
|
||||||
|
return q.q.WaitOne(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queue[T]) Close() error {
|
||||||
|
return q.q.Close()
|
||||||
|
}
|
47
util/multiqueue/queueprovider.go
Normal file
47
util/multiqueue/queueprovider.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package multiqueue
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QueueProvider[T any] interface {
|
||||||
|
GetQueue(id string) *Queue[T]
|
||||||
|
RemoveQueue(id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type queueProvider[T any] struct {
|
||||||
|
queues map[string]*Queue[T]
|
||||||
|
mx sync.Mutex
|
||||||
|
size int
|
||||||
|
handler QueueHandler[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueueProvider[T any](size int, handler QueueHandler[T]) QueueProvider[T] {
|
||||||
|
return &queueProvider[T]{
|
||||||
|
queues: make(map[string]*Queue[T]),
|
||||||
|
size: size,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *queueProvider[T]) GetQueue(id string) *Queue[T] {
|
||||||
|
p.mx.Lock()
|
||||||
|
defer p.mx.Unlock()
|
||||||
|
q, ok := p.queues[id]
|
||||||
|
if !ok {
|
||||||
|
q = NewQueue(id, p.size, p.handler)
|
||||||
|
p.queues[id] = q
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *queueProvider[T]) RemoveQueue(id string) error {
|
||||||
|
p.mx.Lock()
|
||||||
|
defer p.mx.Unlock()
|
||||||
|
q, ok := p.queues[id]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
delete(p.queues, id)
|
||||||
|
return q.Close()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue