mirror of
https://github.com/anyproto/any-sync.git
synced 2025-06-08 05:57:03 +09:00
301 lines
7.8 KiB
Go
301 lines
7.8 KiB
Go
package streampool
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/context"
|
|
"storj.io/drpc"
|
|
|
|
"github.com/anyproto/any-sync/app"
|
|
"github.com/anyproto/any-sync/app/debugstat"
|
|
"github.com/anyproto/any-sync/net/peer"
|
|
"github.com/anyproto/any-sync/net/rpc/rpctest"
|
|
"github.com/anyproto/any-sync/net/streampool/streamhandler"
|
|
"github.com/anyproto/any-sync/net/streampool/testservice"
|
|
)
|
|
|
|
// ctx is the package-level background context used by the tests and helpers
// in this file.
var ctx = context.Background()
|
|
|
|
func makePeerPair(t *testing.T, fx *fixture, peerId string) (pS, pC peer.Peer) {
|
|
mcS, mcC := rpctest.MultiConnPair(peerId+"server", peerId)
|
|
pS, err := peer.NewPeer(mcS, fx.ts)
|
|
require.NoError(t, err)
|
|
pC, err = peer.NewPeer(mcC, fx.ts)
|
|
require.NoError(t, err)
|
|
return
|
|
}
|
|
|
|
func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
|
|
_, pC := makePeerPair(t, fx, peerId)
|
|
drpcConn, err := pC.AcquireDrpcConn(ctx)
|
|
require.NoError(t, err)
|
|
st, err = testservice.NewDRPCTestClient(drpcConn).TestStream(pC.Context())
|
|
require.NoError(t, err)
|
|
return st, pC
|
|
}
|
|
|
|
func TestStreamPool_AddStream(t *testing.T) {
|
|
t.Run("broadcast incoming", func(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
s1, _ := newClientStream(t, fx, "p1")
|
|
require.NoError(t, fx.AddStream(s1, 100, "space1", "common"))
|
|
s2, _ := newClientStream(t, fx, "p2")
|
|
require.NoError(t, fx.AddStream(s2, 100, "space2", "common"))
|
|
|
|
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1"))
|
|
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2"))
|
|
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common"))
|
|
|
|
var serverResults []string
|
|
for i := 0; i < 4; i++ {
|
|
select {
|
|
case msg := <-fx.tsh.receiveCh:
|
|
serverResults = append(serverResults, msg.ReqData)
|
|
case <-time.After(time.Second):
|
|
require.NoError(t, fmt.Errorf("timeout"))
|
|
}
|
|
}
|
|
|
|
sort.Strings(serverResults)
|
|
assert.Equal(t, []string{"common", "common", "space1", "space2"}, serverResults)
|
|
|
|
assert.NoError(t, s1.Close())
|
|
assert.NoError(t, s2.Close())
|
|
})
|
|
|
|
t.Run("send incoming", func(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
s1, p1 := newClientStream(t, fx, "p1")
|
|
defer s1.Close()
|
|
require.NoError(t, fx.AddStream(s1, 100, "space1", "common"))
|
|
|
|
require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "test"}, func(ctx context.Context) (peers []peer.Peer, err error) {
|
|
return []peer.Peer{p1}, nil
|
|
}))
|
|
var msg *testservice.StreamMessage
|
|
select {
|
|
case msg = <-fx.tsh.receiveCh:
|
|
case <-time.After(time.Second):
|
|
require.NoError(t, fmt.Errorf("timeout"))
|
|
}
|
|
assert.Equal(t, "test", msg.ReqData)
|
|
})
|
|
}
|
|
|
|
func TestStreamPool_Send(t *testing.T) {
|
|
t.Run("open stream", func(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
pS, _ := makePeerPair(t, fx, "p1")
|
|
|
|
require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
|
|
return []peer.Peer{pS}, nil
|
|
}))
|
|
|
|
var msg *testservice.StreamMessage
|
|
select {
|
|
case msg = <-fx.tsh.receiveCh:
|
|
case <-time.After(time.Second):
|
|
require.NoError(t, fmt.Errorf("timeout"))
|
|
}
|
|
assert.Equal(t, "should open stream", msg.ReqData)
|
|
})
|
|
|
|
t.Run("parallel open stream", func(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
pS, _ := makePeerPair(t, fx, "p1")
|
|
|
|
fx.th.streamOpenDelay = time.Second / 3
|
|
|
|
var numMsgs = 5
|
|
|
|
for i := 0; i < numMsgs; i++ {
|
|
go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
|
|
return []peer.Peer{pS}, nil
|
|
}))
|
|
}
|
|
|
|
var msgs []*testservice.StreamMessage
|
|
for i := 0; i < numMsgs; i++ {
|
|
select {
|
|
case msg := <-fx.tsh.receiveCh:
|
|
msgs = append(msgs, msg)
|
|
case <-time.After(time.Second):
|
|
require.NoError(t, fmt.Errorf("timeout"))
|
|
}
|
|
}
|
|
assert.Len(t, msgs, numMsgs)
|
|
// make sure that we have only one stream
|
|
assert.Equal(t, int32(1), fx.tsh.streamsCount.Load())
|
|
})
|
|
t.Run("parallel open stream error", func(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
pS, _ := makePeerPair(t, fx, "p1")
|
|
_ = pS.Close()
|
|
|
|
fx.th.streamOpenDelay = time.Second / 3
|
|
|
|
var numMsgs = 5
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numMsgs; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, pS, &testservice.StreamMessage{ReqData: "should open stream"}))
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
})
|
|
|
|
}
|
|
|
|
func TestStreamPool_SendById(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
s1, _ := newClientStream(t, fx, "p1")
|
|
defer s1.Close()
|
|
require.NoError(t, fx.AddStream(s1, 100, "space1", "common"))
|
|
|
|
require.NoError(t, fx.SendById(ctx, &testservice.StreamMessage{ReqData: "test"}, "p1"))
|
|
var msg *testservice.StreamMessage
|
|
select {
|
|
case msg = <-fx.tsh.receiveCh:
|
|
case <-time.After(time.Second):
|
|
require.NoError(t, fmt.Errorf("timeout"))
|
|
}
|
|
assert.Equal(t, "test", msg.ReqData)
|
|
}
|
|
|
|
func TestStreamPool_Tags(t *testing.T) {
|
|
fx := newFixture(t)
|
|
defer fx.Finish(t)
|
|
|
|
s1, _ := newClientStream(t, fx, "p1")
|
|
defer s1.Close()
|
|
require.NoError(t, fx.AddStream(s1, 100, "t1"))
|
|
|
|
s2, _ := newClientStream(t, fx, "p2")
|
|
defer s1.Close()
|
|
require.NoError(t, fx.AddStream(s2, 100, "t2"))
|
|
|
|
err := fx.AddTagsCtx(streamCtx(ctx, 1, "p1"), "t3", "t3")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, []uint32{1}, fx.StreamPool.(*streamPool).streamIdsByTag["t3"])
|
|
|
|
err = fx.RemoveTagsCtx(streamCtx(ctx, 2, "p2"), "t2")
|
|
require.NoError(t, err)
|
|
assert.Len(t, fx.StreamPool.(*streamPool).streamIdsByTag["t2"], 0)
|
|
|
|
}
|
|
|
|
func newFixture(t *testing.T) *fixture {
|
|
fx := &fixture{}
|
|
fx.ts = rpctest.NewTestServer()
|
|
fx.tsh = &testServerHandler{receiveCh: make(chan *testservice.StreamMessage, 100)}
|
|
require.NoError(t, testservice.DRPCRegisterTest(fx.ts, fx.tsh))
|
|
fx.th = &testHandler{}
|
|
s := New().(*streamPool)
|
|
s.handler = fx.th
|
|
s.statService = debugstat.NewNoOp()
|
|
s.streamConfig = StreamConfig{
|
|
SendQueueSize: 10,
|
|
DialQueueWorkers: 1,
|
|
DialQueueSize: 10,
|
|
}
|
|
fx.StreamPool = s
|
|
require.NoError(t, fx.StreamPool.Run(context.Background()))
|
|
return fx
|
|
}
|
|
|
|
type fixture struct {
|
|
StreamPool
|
|
th *testHandler
|
|
tsh *testServerHandler
|
|
ts *rpctest.TestServer
|
|
}
|
|
|
|
func (fx *fixture) Finish(t *testing.T) {
|
|
require.NoError(t, fx.Close(context.Background()))
|
|
}
|
|
|
|
type testHandler struct {
|
|
streamOpenDelay time.Duration
|
|
incomingMessages []drpc.Message
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (t *testHandler) Init(a *app.App) (err error) {
|
|
return nil
|
|
}
|
|
|
|
func (t *testHandler) Name() (name string) {
|
|
return streamhandler.CName
|
|
}
|
|
|
|
func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc.Stream, tags []string, queueSize int, err error) {
|
|
if t.streamOpenDelay > 0 {
|
|
time.Sleep(t.streamOpenDelay)
|
|
}
|
|
conn, err := p.AcquireDrpcConn(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
queueSize = 100
|
|
stream, err = testservice.NewDRPCTestClient(conn).TestStream(p.Context())
|
|
return
|
|
}
|
|
|
|
func (t *testHandler) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.incomingMessages = append(t.incomingMessages, msg)
|
|
return nil
|
|
}
|
|
|
|
func (t *testHandler) DRPCEncoding() drpc.Encoding {
|
|
return EncodingProto
|
|
}
|
|
|
|
func (t *testHandler) NewReadMessage() drpc.Message {
|
|
return new(testservice.StreamMessage)
|
|
}
|
|
|
|
type testServerHandler struct {
|
|
receiveCh chan *testservice.StreamMessage
|
|
streamsCount atomic.Int32
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (t *testServerHandler) TestStream(st testservice.DRPCTest_TestStreamStream) error {
|
|
t.streamsCount.Add(1)
|
|
defer t.streamsCount.Add(-1)
|
|
for {
|
|
msg, err := st.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t.receiveCh <- msg
|
|
if err = st.Send(msg); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|