diff --git a/go.mod b/go.mod index 5dd9ad82..b05925ce 100644 --- a/go.mod +++ b/go.mod @@ -55,9 +55,10 @@ require ( github.com/fogleman/gg v1.3.0 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect + github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/ipfs/bbloom v0.0.4 // indirect @@ -90,7 +91,7 @@ require ( github.com/multiformats/go-multicodec v0.9.0 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-varint v0.0.7 // indirect - github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/onsi/ginkgo/v2 v2.9.7 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -98,7 +99,9 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.0 // indirect - github.com/quic-go/quic-go v0.34.0 // indirect + github.com/quic-go/qtls-go1-19 v0.3.2 // indirect + github.com/quic-go/qtls-go1-20 v0.2.2 // indirect + github.com/quic-go/quic-go v0.35.1 // indirect github.com/quic-go/webtransport-go v0.5.3 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect @@ -107,9 +110,10 @@ require ( go.opentelemetry.io/otel/trace v1.7.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.6.0 // indirect + golang.org/x/mod v0.10.0 // indirect golang.org/x/sync v0.2.0 // indirect golang.org/x/sys v0.8.0 // indirect - golang.org/x/tools v0.9.1 // indirect + golang.org/x/tools v0.9.3 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/protobuf v1.30.0 // indirect lukechampine.com/blake3 v1.2.1 // indirect diff --git a/go.sum b/go.sum index ed9b778b..8fab2e97 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,7 @@ github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbV github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-graphviz v0.1.1 h1:MGrsnzBxTyt7KG8FhHsFPDTGvF7UaQMmSa6A610DqPg= @@ -69,6 +70,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -244,8 +247,11 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= +github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -267,9 +273,13 @@ github.com/prometheus/procfs v0.10.0 h1:UkG7GPYkO4UZyLnyXjaWYcgOSONqwdBqFUT95ugm github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= +github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= +github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= +github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU= github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -285,6 +295,7 @@ github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= @@ -353,6 +364,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -417,6 +429,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= +golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/net/connutil/timeout.go b/net/connutil/timeout.go index 381998f9..057c7b8c 100644 --- a/net/connutil/timeout.go +++ b/net/connutil/timeout.go @@ -17,7 +17,7 @@ type TimeoutConn struct { timeout time.Duration } -func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn { +func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn { return &TimeoutConn{conn, timeout} } diff --git a/net/secureservice/credential.go b/net/secureservice/credential.go index 5e97e8fc..9d992430 100644 --- a/net/secureservice/credential.go +++ b/net/secureservice/credential.go @@ -5,7 +5,6 @@ import ( "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/util/crypto" - "github.com/libp2p/go-libp2p/core/sec" "go.uber.org/zap" ) @@ -19,11 +18,11 @@ type noVerifyChecker struct { cred *handshakeproto.Credentials } -func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { +func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials { return n.cred } -func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if cred.Version != n.cred.Version { return nil, handshake.ErrIncompatibleVersion } @@ -42,8 +41,8 @@ type peerSignVerifier struct { account *accountdata.AccountKeys } -func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { - sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String())) +func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials { + sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId)) if err != nil { log.Warn("can't sign identity credentials", zap.Error(err)) } @@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr } } -func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if cred.Version != p.protoVersion { return nil, handshake.ErrIncompatibleVersion } @@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro if err != nil { return nil, handshake.ErrInvalidCredentials } - ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign) + ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign) if err != nil { return nil, err } diff --git a/net/secureservice/credential_test.go b/net/secureservice/credential_test.go index e64e173a..50b24ece 100644 --- a/net/secureservice/credential_test.go +++ b/net/secureservice/credential_test.go @@ -4,13 +4,8 @@ import ( "github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/testutil/accounttest" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "net" "testing" ) @@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) { cc1 := newPeerSignVerifier(0, a1) cc2 := newPeerSignVerifier(0, a2) - c1 := newTestSC(a2.PeerId) - c2 := newTestSC(a1.PeerId) + c1 := a2.PeerId + c2 := a1.PeerId cr1 := cc1.MakeCredentials(c1) cr2 := cc2.MakeCredentials(c2) @@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) { cc1 := newPeerSignVerifier(0, a1) cc2 := newPeerSignVerifier(1, a2) - c1 := newTestSC(a2.PeerId) - c2 := newTestSC(a1.PeerId) + c1 := a2.PeerId + c2 := a1.PeerId cr1 := cc1.MakeCredentials(c1) cr2 := cc2.MakeCredentials(c2) @@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys { require.NoError(t, as.Init(nil)) return as.Account() } - -func newTestSC(peerId string) sec.SecureConn { - pid, _ := peer.Decode(peerId) - return &testSc{ - ID: pid, - } -} - -type testSc struct { - net.Conn - peer.ID -} - -func (t *testSc) LocalPeer() peer.ID { - return "" -} - -func (t *testSc) LocalPrivateKey() crypto.PrivKey { - return nil -} - -func (t *testSc) RemotePeer() peer.ID { - return t.ID -} - -func (t *testSc) RemotePublicKey() crypto.PubKey { - return nil -} - -func (t *testSc) ConnState() network.ConnectionState { - return network.ConnectionState{} -} diff --git a/net/secureservice/handshake/credential.go b/net/secureservice/handshake/credential.go index 06108928..e9baeea5 100644 --- a/net/secureservice/handshake/credential.go +++ b/net/secureservice/handshake/credential.go @@ -3,10 +3,10 @@ package handshake import ( "context" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" - "github.com/libp2p/go-libp2p/core/sec" + "io" ) -func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { if ctx == nil { ctx = context.Background() } @@ -14,21 +14,21 @@ func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChec done := make(chan struct{}) go func() { defer close(done) - identity, err = outgoingHandshake(h, sc, cc) + identity, err = outgoingHandshake(h, conn, peerId, cc) }() select { case <-done: return case <-ctx.Done(): - _ = sc.Close() + _ = conn.Close() return nil, ctx.Err() } } -func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { defer h.release() - h.conn = sc - localCred := cc.MakeCredentials(sc) + h.conn = conn + localCred := cc.MakeCredentials(peerId) if err = h.writeCredentials(localCred); err != nil { h.tryWriteErrAndClose(err) return @@ -45,7 +45,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i return nil, HandshakeError{e: msg.ack.Error} } - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { h.tryWriteErrAndClose(err) return } @@ -68,7 +68,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i } } -func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { if ctx == nil { ctx = context.Background() } @@ -76,32 +76,32 @@ func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChec done := make(chan struct{}) go func() { defer close(done) - identity, err = incomingHandshake(h, sc, cc) + identity, err = incomingHandshake(h, conn, peerId, cc) }() select { case <-done: return case <-ctx.Done(): - _ = sc.Close() + _ = conn.Close() return nil, ctx.Err() } } -func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { defer h.release() - h.conn = sc + h.conn = conn msg, err := h.readMsg(msgTypeCred) if err != nil { h.tryWriteErrAndClose(err) return } - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { h.tryWriteErrAndClose(err) return } - if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { + if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil { h.tryWriteErrAndClose(err) return nil, err } diff --git a/net/secureservice/handshake/credential_test.go b/net/secureservice/handshake/credential_test.go index 6a34f9cb..49865848 100644 --- a/net/secureservice/handshake/credential_test.go +++ b/net/secureservice/handshake/credential_test.go @@ -7,7 +7,6 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net" @@ -17,7 +16,7 @@ import ( var noVerifyChecker = &testCredChecker{ makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, - checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { + checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { return []byte("identity"), nil }, } @@ -32,7 +31,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -40,10 +39,10 @@ func TestOutgoingHandshake(t *testing.T) { // receive credential message msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto) require.NoError(t, err) - _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + _, err = noVerifyChecker.CheckCredential("p1", msg.cred) require.NoError(t, err) // send credential message - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // receive ack msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) @@ -58,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -69,7 +68,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -85,7 +84,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -101,7 +100,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -109,7 +108,7 @@ func TestOutgoingHandshake(t *testing.T) { // receive credential message _, err := h.readMsg(msgTypeCred) require.NoError(t, err) - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) @@ -120,7 +119,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -129,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) { _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials and close conn - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = c2.Close() res := <-handshakeResCh require.Error(t, res.err) @@ -138,7 +137,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -147,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) { _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read ack and close conn _, err = h.readMsg(msgTypeAck) require.NoError(t, err) @@ -159,7 +158,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -168,12 +167,12 @@ func TestOutgoingHandshake(t *testing.T) { _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read ack _, err = h.readMsg(msgTypeAck) require.NoError(t, err) // write cred instead ack - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _, err = h.readMsg(msgTypeAck) require.Error(t, err) res := <-handshakeResCh @@ -183,7 +182,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -192,10 +191,10 @@ func TestOutgoingHandshake(t *testing.T) { msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) - _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + _, err = noVerifyChecker.CheckCredential("", msg.cred) require.NoError(t, err) // send credential message - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // receive ack msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) @@ -211,7 +210,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker) + identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -234,13 +233,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -260,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -271,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials and close conn - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = c2.Close() res := <-handshakeResCh require.Error(t, res.err) @@ -286,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -300,13 +299,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // except ack with error msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) @@ -320,13 +319,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) + identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // except ack with error msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) @@ -340,18 +339,18 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read cred _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write cred instead ack - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // expect EOF _, err = h.readMsg(msgTypeAck) require.Error(t, err) @@ -362,13 +361,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read cred and close conn _, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -381,13 +380,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -403,13 +402,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -425,13 +424,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -448,13 +447,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(ctx, c1, noVerifyChecker) + identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials _, err := h.readMsg(msgTypeCred) require.NoError(t, err) @@ -472,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -491,11 +490,11 @@ func TestEndToEnd(t *testing.T) { ) st := time.Now() go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) outResCh <- handshakeRes{identity: identity, err: err} }() go func() { - identity, err := IncomingHandshake(nil, c2, noVerifyChecker) + identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker) inResCh <- handshakeRes{identity: identity, err: err} }() @@ -519,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) { defer close(done) go func() { for { - _, _ = OutgoingHandshake(nil, c1, noVerifyChecker) + _, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker) select { case outRes <- struct{}{}: case <-done: @@ -529,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) { }() go func() { for { - _, _ = IncomingHandshake(nil, c2, noVerifyChecker) + _, _ = IncomingHandshake(nil, c2, "", noVerifyChecker) select { case inRes <- struct{}{}: case <-done: @@ -549,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) { type testCredChecker struct { makeCred *handshakeproto.Credentials - checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) + checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) checkErr error } -func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { +func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials { return t.makeCred } -func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if t.checkErr != nil { return nil, t.checkErr } if t.checkCred != nil { - return t.checkCred(sc, cred) + return t.checkCred(peerId, cred) } return nil, nil } diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index abbafeb5..42aa2a8c 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -4,10 +4,8 @@ import ( "encoding/binary" "errors" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" - "github.com/libp2p/go-libp2p/core/sec" "golang.org/x/exp/slices" "io" - "net" "sync" ) @@ -65,8 +63,8 @@ var handshakePool = &sync.Pool{New: func() any { }} type CredentialChecker interface { - MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials - CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) + MakeCredentials(remotePeerId string) *handshakeproto.Credentials + CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) } func newHandshake() *handshake { @@ -74,7 +72,7 @@ func newHandshake() *handshake { } type handshake struct { - conn net.Conn + conn io.ReadWriteCloser remoteCred *handshakeproto.Credentials remoteProto *handshakeproto.Proto remoteAck *handshakeproto.Ack diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index 4faca6c0..1e5d3231 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -2,6 +2,7 @@ package secureservice import ( "context" + "crypto/tls" commonaccount "github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" @@ -10,9 +11,9 @@ import ( "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/nodeconf" "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/sec" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" "go.uber.org/zap" + "io" "net" ) @@ -25,8 +26,10 @@ func New() SecureService { } type SecureService interface { - SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) - SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) + SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) + SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) + HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error) + ServerTlsConfig() (*tls.Config, error) app.Component } @@ -75,28 +78,31 @@ func (s *secureService) Name() (name string) { return CName } -func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { - sc, err = s.p2pTr.SecureInbound(ctx, conn, "") +func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) { + sc, err := s.p2pTr.SecureInbound(ctx, conn, "") if err != nil { - return nil, nil, handshake.HandshakeError{ + return nil, handshake.HandshakeError{ Err: err, } } + return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String()) +} - identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker) +func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) { + identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker) if err != nil { - return nil, nil, err + return nil, err } cctx = context.Background() - cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) + cctx = peer.CtxWithPeerId(cctx, peerId) cctx = peer.CtxWithIdentity(cctx, identity) return } -func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { - sc, err = s.p2pTr.SecureOutbound(ctx, conn, "") +func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) { + sc, err := s.p2pTr.SecureOutbound(ctx, conn, "") if err != nil { - return nil, nil, handshake.HandshakeError{Err: err} + return nil, handshake.HandshakeError{Err: err} } peerId := sc.RemotePeer().String() confTypes := s.nodeconf.NodeTypes(peerId) @@ -106,12 +112,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx } else { checker = s.noVerifyChecker } - identity, err := handshake.OutgoingHandshake(ctx, sc, checker) + identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker) if err != nil { - return nil, nil, err + return nil, err } cctx = context.Background() cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) cctx = peer.CtxWithIdentity(cctx, identity) - return cctx, sc, nil + return cctx, nil +} + +func (s *secureService) ServerTlsConfig() (*tls.Config, error) { + p2pIdn, err := libp2ptls.NewIdentity(s.key) + if err != nil { + return nil, err + } + conf, _ := p2pIdn.ConfigForPeer("") + conf.NextProtos = []string{"anysync"} + return conf, nil } diff --git a/net/secureservice/secureservice_test.go b/net/secureservice/secureservice_test.go index e03b92a4..6aee985d 100644 --- a/net/secureservice/secureservice_test.go +++ b/net/secureservice/secureservice_test.go @@ -32,18 +32,17 @@ func TestHandshake(t *testing.T) { resCh := make(chan acceptRes) go func() { var ar acceptRes - ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) + ar.ctx, ar.err = fxS.SecureInbound(ctx, sc) resCh <- ar }() fxC := newFixture(t, nc, nc.GetAccountService(1), 0) defer fxC.Finish(t) - cctx, secConn, err := fxC.SecureOutbound(ctx, cc) + cctx, err := fxC.SecureOutbound(ctx, cc) require.NoError(t, err) ctxPeerId, err := peer.CtxPeerId(cctx) require.NoError(t, err) - assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId) res := <-resCh require.NoError(t, res.err) @@ -70,12 +69,12 @@ func TestHandshakeIncompatibleVersion(t *testing.T) { resCh := make(chan acceptRes) go func() { var ar acceptRes - ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) + ar.ctx, ar.err = fxS.SecureInbound(ctx, sc) resCh <- ar }() fxC := newFixture(t, nc, nc.GetAccountService(1), 1) defer fxC.Finish(t) - _, _, err := fxC.SecureOutbound(ctx, cc) + _, err := fxC.SecureOutbound(ctx, cc) require.Equal(t, handshake.ErrIncompatibleVersion, err) res := <-resCh require.Equal(t, handshake.ErrIncompatibleVersion, res.err) diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go index 5aa66162..26a6f3f2 100644 --- a/net/transport/yamux/conn.go +++ b/net/transport/yamux/conn.go @@ -26,7 +26,10 @@ type yamuxConn struct { } func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) { - return y.Session.Open() + if conn, err = y.Session.Open(); err != nil { + return + } + return connutil.NewTimeout(conn, time.Second*10), nil } func (y *yamuxConn) LastUsage() time.Time { @@ -46,6 +49,7 @@ func (y *yamuxConn) Accept() (conn net.Conn, err error) { if err == yamux.ErrSessionShutdown { err = transport.ErrConnClosed } + return } - return + return connutil.NewTimeout(conn, time.Second*10), nil } diff --git a/net/transport/yamux/yamux.go b/net/transport/yamux/yamux.go index 305fcb92..c0f85777 100644 --- a/net/transport/yamux/yamux.go +++ b/net/transport/yamux/yamux.go @@ -86,12 +86,12 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu } ctx, cancel := context.WithTimeout(ctx, dialTimeout) defer cancel() - cctx, sc, err := y.secure.SecureOutbound(ctx, conn) + cctx, err := y.secure.SecureOutbound(ctx, conn) if err != nil { _ = conn.Close() return nil, err } - luc := connutil.NewLastUsageConn(sc) + luc := connutil.NewLastUsageConn(conn) sess, err := yamux.Client(luc, y.yamuxConf) if err != nil { return @@ -132,12 +132,12 @@ func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) { func (y *yamuxTransport) accept(conn net.Conn) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second) defer cancel() - cctx, sc, err := y.secure.SecureInbound(ctx, conn) + cctx, err := y.secure.SecureInbound(ctx, conn) if err != nil { log.Warn("incoming connection handshake error", zap.Error(err)) return } - luc := connutil.NewLastUsageConn(sc) + luc := connutil.NewLastUsageConn(conn) sess, err := yamux.Server(luc, y.yamuxConf) if err != nil { log.Warn("incoming connection yamux session error", zap.Error(err)) diff --git a/net/transport/yamux/yamux_test.go b/net/transport/yamux/yamux_test.go index 20efdfce..af2850a4 100644 --- a/net/transport/yamux/yamux_test.go +++ b/net/transport/yamux/yamux_test.go @@ -14,7 +14,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io" + "net" + "sync" "testing" + "time" ) var ctx = context.Background() @@ -63,6 +66,64 @@ func TestYamuxTransport_Dial(t *testing.T) { assert.NoError(t, copyErr) } +// no deadline - 69100 rps +// common write deadline - 66700 rps +// subconn write deadline - 67100 rps +func TestWriteBench(t *testing.T) { + t.Skip() + var ( + numSubConn = 10 + numWrites = 100000 + ) + + fxS := newFixture(t) + defer fxS.finish(t) + fxC := newFixture(t) + defer fxC.finish(t) + + mcC, err := fxC.Dial(ctx, fxS.addr) + require.NoError(t, err) + mcS := fxS.accepter.mcs[0] + + go func() { + for i := 0; i < numSubConn; i++ { + conn, err := mcS.Accept() + require.NoError(t, err) + go func(sc net.Conn) { + var b = make([]byte, 1024) + for { + n, _ := sc.Read(b) + if n > 0 { + sc.Write(b[:n]) + } else { + break + } + } + }(conn) + } + }() + + var wg sync.WaitGroup + wg.Add(numSubConn) + st := time.Now() + for i := 0; i < numSubConn; i++ { + conn, err := mcC.Open(ctx) + require.NoError(t, err) + go func(sc net.Conn) { + defer sc.Close() + defer wg.Done() + for j := 0; j < numWrites; j++ { + var b = []byte("some data some data some data some data some data some data some data some data some data") + sc.Write(b) + sc.Read(b) + } + }(conn) + } + wg.Wait() + dur := time.Since(st) + t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds()) +} + type fixture struct { *yamuxTransport a *app.App