1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-08 05:57:03 +09:00

ref: use peerId instead libp2p sec connection + quic ready interfaces

This commit is contained in:
Sergey Cherepanov 2023-06-06 18:36:09 +02:00
parent 0b4f08fbef
commit 156f3387eb
No known key found for this signature in database
GPG key ID: 87F8EDE8FBDF637C
13 changed files with 209 additions and 152 deletions

12
go.mod
View file

@ -55,9 +55,10 @@ require (
github.com/fogleman/gg v1.3.0 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/ipfs/bbloom v0.0.4 // indirect
@ -90,7 +91,7 @@ require (
github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multistream v0.4.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@ -98,7 +99,9 @@ require (
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.0 // indirect
github.com/quic-go/quic-go v0.34.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/quic-go/quic-go v0.35.1 // indirect
github.com/quic-go/webtransport-go v0.5.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect
@ -107,9 +110,10 @@ require (
go.opentelemetry.io/otel/trace v1.7.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.6.0 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/tools v0.9.1 // indirect
golang.org/x/tools v0.9.3 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.30.0 // indirect
lukechampine.com/blake3 v1.2.1 // indirect

14
go.sum
View file

@ -47,6 +47,7 @@ github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbV
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-graphviz v0.1.1 h1:MGrsnzBxTyt7KG8FhHsFPDTGvF7UaQMmSa6A610DqPg=
@ -69,6 +70,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -244,8 +247,11 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -267,9 +273,13 @@ github.com/prometheus/procfs v0.10.0 h1:UkG7GPYkO4UZyLnyXjaWYcgOSONqwdBqFUT95ugm
github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU=
github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -285,6 +295,7 @@ github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
@ -353,6 +364,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -417,6 +429,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View file

@ -17,7 +17,7 @@ type TimeoutConn struct {
timeout time.Duration
}
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn {
return &TimeoutConn{conn, timeout}
}

View file

@ -5,7 +5,6 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
)
@ -19,11 +18,11 @@ type noVerifyChecker struct {
cred *handshakeproto.Credentials
}
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
return n.cred
}
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != n.cred.Version {
return nil, handshake.ErrIncompatibleVersion
}
@ -42,8 +41,8 @@ type peerSignVerifier struct {
account *accountdata.AccountKeys
}
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String()))
func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId))
if err != nil {
log.Warn("can't sign identity credentials", zap.Error(err))
}
@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr
}
}
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != p.protoVersion {
return nil, handshake.ErrIncompatibleVersion
}
@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro
if err != nil {
return nil, handshake.ErrInvalidCredentials
}
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign)
ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
if err != nil {
return nil, err
}

View file

@ -4,13 +4,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/testutil/accounttest"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"testing"
)
@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(0, a2)
c1 := newTestSC(a2.PeerId)
c2 := newTestSC(a1.PeerId)
c1 := a2.PeerId
c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2)
@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(1, a2)
c1 := newTestSC(a2.PeerId)
c2 := newTestSC(a1.PeerId)
c1 := a2.PeerId
c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2)
@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys {
require.NoError(t, as.Init(nil))
return as.Account()
}
func newTestSC(peerId string) sec.SecureConn {
pid, _ := peer.Decode(peerId)
return &testSc{
ID: pid,
}
}
type testSc struct {
net.Conn
peer.ID
}
func (t *testSc) LocalPeer() peer.ID {
return ""
}
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
return nil
}
func (t *testSc) RemotePeer() peer.ID {
return t.ID
}
func (t *testSc) RemotePublicKey() crypto.PubKey {
return nil
}
func (t *testSc) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View file

@ -3,10 +3,10 @@ package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"io"
)
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
@ -14,21 +14,21 @@ func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChec
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
identity, err = outgoingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
h.conn = conn
localCred := cc.MakeCredentials(peerId)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
@ -45,7 +45,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
@ -68,7 +68,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
@ -76,32 +76,32 @@ func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChec
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
identity, err = incomingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
h.conn = conn
msg, err := h.readMsg(msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}

View file

@ -7,7 +7,6 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
@ -17,7 +16,7 @@ import (
var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil
},
}
@ -32,7 +31,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -40,10 +39,10 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
_, err = noVerifyChecker.CheckCredential("p1", msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -58,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -69,7 +68,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -85,7 +84,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -101,7 +100,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -109,7 +108,7 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
@ -120,7 +119,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -129,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
@ -138,7 +137,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -147,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack and close conn
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -159,7 +158,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -168,12 +167,12 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
res := <-handshakeResCh
@ -183,7 +182,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -192,10 +191,10 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
_, err = noVerifyChecker.CheckCredential("", msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -211,7 +210,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -234,13 +233,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -260,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -271,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
@ -286,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -300,13 +299,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -320,13 +319,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -340,18 +339,18 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// expect EOF
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
@ -362,13 +361,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred and close conn
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -381,13 +380,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -403,13 +402,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -425,13 +424,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -448,13 +447,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -472,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -491,11 +490,11 @@ func TestEndToEnd(t *testing.T) {
)
st := time.Now()
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}()
go func() {
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}()
@ -519,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done)
go func() {
for {
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
_, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker)
select {
case outRes <- struct{}{}:
case <-done:
@ -529,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) {
}()
go func() {
for {
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
_, _ = IncomingHandshake(nil, c2, "", noVerifyChecker)
select {
case inRes <- struct{}{}:
case <-done:
@ -549,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct {
makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error
}
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials {
return t.makeCred
}
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil {
return nil, t.checkErr
}
if t.checkCred != nil {
return t.checkCred(sc, cred)
return t.checkCred(peerId, cred)
}
return nil, nil
}

View file

@ -4,10 +4,8 @@ import (
"encoding/binary"
"errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"net"
"sync"
)
@ -65,8 +63,8 @@ var handshakePool = &sync.Pool{New: func() any {
}}
type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
MakeCredentials(remotePeerId string) *handshakeproto.Credentials
CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func newHandshake() *handshake {
@ -74,7 +72,7 @@ func newHandshake() *handshake {
}
type handshake struct {
conn net.Conn
conn io.ReadWriteCloser
remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack

View file

@ -2,6 +2,7 @@ package secureservice
import (
"context"
"crypto/tls"
commonaccount "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
@ -10,9 +11,9 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/sec"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"go.uber.org/zap"
"io"
"net"
)
@ -25,8 +26,10 @@ func New() SecureService {
}
type SecureService interface {
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error)
ServerTlsConfig() (*tls.Config, error)
app.Component
}
@ -75,28 +78,31 @@ func (s *secureService) Name() (name string) {
return CName
}
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err := s.p2pTr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, handshake.HandshakeError{
return nil, handshake.HandshakeError{
Err: err,
}
}
return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String())
}
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) {
identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker)
if err != nil {
return nil, nil, err
return nil, err
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithPeerId(cctx, peerId)
cctx = peer.CtxWithIdentity(cctx, identity)
return
}
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "")
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil {
return nil, nil, handshake.HandshakeError{Err: err}
return nil, handshake.HandshakeError{Err: err}
}
peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.NodeTypes(peerId)
@ -106,12 +112,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx
} else {
checker = s.noVerifyChecker
}
identity, err := handshake.OutgoingHandshake(ctx, sc, checker)
identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker)
if err != nil {
return nil, nil, err
return nil, err
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return cctx, sc, nil
return cctx, nil
}
func (s *secureService) ServerTlsConfig() (*tls.Config, error) {
p2pIdn, err := libp2ptls.NewIdentity(s.key)
if err != nil {
return nil, err
}
conf, _ := p2pIdn.ConfigForPeer("")
conf.NextProtos = []string{"anysync"}
return conf, nil
}

View file

@ -32,18 +32,17 @@ func TestHandshake(t *testing.T) {
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
defer fxC.Finish(t)
cctx, secConn, err := fxC.SecureOutbound(ctx, cc)
cctx, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err)
ctxPeerId, err := peer.CtxPeerId(cctx)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
res := <-resCh
require.NoError(t, res.err)
@ -70,12 +69,12 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
defer fxC.Finish(t)
_, _, err := fxC.SecureOutbound(ctx, cc)
_, err := fxC.SecureOutbound(ctx, cc)
require.Equal(t, handshake.ErrIncompatibleVersion, err)
res := <-resCh
require.Equal(t, handshake.ErrIncompatibleVersion, res.err)

View file

@ -26,7 +26,10 @@ type yamuxConn struct {
}
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
return y.Session.Open()
if conn, err = y.Session.Open(); err != nil {
return
}
return connutil.NewTimeout(conn, time.Second*10), nil
}
func (y *yamuxConn) LastUsage() time.Time {
@ -46,6 +49,7 @@ func (y *yamuxConn) Accept() (conn net.Conn, err error) {
if err == yamux.ErrSessionShutdown {
err = transport.ErrConnClosed
}
}
return
}
return connutil.NewTimeout(conn, time.Second*10), nil
}

View file

@ -86,12 +86,12 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
}
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
cctx, sc, err := y.secure.SecureOutbound(ctx, conn)
cctx, err := y.secure.SecureOutbound(ctx, conn)
if err != nil {
_ = conn.Close()
return nil, err
}
luc := connutil.NewLastUsageConn(sc)
luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Client(luc, y.yamuxConf)
if err != nil {
return
@ -132,12 +132,12 @@ func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
func (y *yamuxTransport) accept(conn net.Conn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
defer cancel()
cctx, sc, err := y.secure.SecureInbound(ctx, conn)
cctx, err := y.secure.SecureInbound(ctx, conn)
if err != nil {
log.Warn("incoming connection handshake error", zap.Error(err))
return
}
luc := connutil.NewLastUsageConn(sc)
luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Server(luc, y.yamuxConf)
if err != nil {
log.Warn("incoming connection yamux session error", zap.Error(err))

View file

@ -14,7 +14,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net"
"sync"
"testing"
"time"
)
var ctx = context.Background()
@ -63,6 +66,64 @@ func TestYamuxTransport_Dial(t *testing.T) {
assert.NoError(t, copyErr)
}
// no deadline - 69100 rps
// common write deadline - 66700 rps
// subconn write deadline - 67100 rps
func TestWriteBench(t *testing.T) {
t.Skip()
var (
numSubConn = 10
numWrites = 100000
)
fxS := newFixture(t)
defer fxS.finish(t)
fxC := newFixture(t)
defer fxC.finish(t)
mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err)
mcS := fxS.accepter.mcs[0]
go func() {
for i := 0; i < numSubConn; i++ {
conn, err := mcS.Accept()
require.NoError(t, err)
go func(sc net.Conn) {
var b = make([]byte, 1024)
for {
n, _ := sc.Read(b)
if n > 0 {
sc.Write(b[:n])
} else {
break
}
}
}(conn)
}
}()
var wg sync.WaitGroup
wg.Add(numSubConn)
st := time.Now()
for i := 0; i < numSubConn; i++ {
conn, err := mcC.Open(ctx)
require.NoError(t, err)
go func(sc net.Conn) {
defer sc.Close()
defer wg.Done()
for j := 0; j < numWrites; j++ {
var b = []byte("some data some data some data some data some data some data some data some data some data")
sc.Write(b)
sc.Read(b)
}
}(conn)
}
wg.Wait()
dur := time.Since(st)
t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds())
}
type fixture struct {
*yamuxTransport
a *app.App