diff --git a/commonspace/headsync/diffsyncer.go b/commonspace/headsync/diffsyncer.go index 2a03dd2c..793ea7d6 100644 --- a/commonspace/headsync/diffsyncer.go +++ b/commonspace/headsync/diffsyncer.go @@ -154,12 +154,8 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) if err != nil { return } - defer p.ReleaseDrpcConn(conn) - defer func() { - if ctx.Err() != nil { - _ = conn.Close() - } - }() + defer p.ReleaseDrpcConn(ctx, conn) + var ( cl = d.clientFactory.Client(conn) rdiff = NewRemoteDiff(d.spaceId, cl) diff --git a/commonspace/object/keyvalue/keyvalue.go b/commonspace/object/keyvalue/keyvalue.go index 23ca41db..ea353fd6 100644 --- a/commonspace/object/keyvalue/keyvalue.go +++ b/commonspace/object/keyvalue/keyvalue.go @@ -64,12 +64,7 @@ func (k *keyValueService) syncWithPeer(ctx context.Context, p peer.Peer) (err er if err != nil { return } - defer p.ReleaseDrpcConn(conn) - defer func() { - if ctx.Err() != nil { - _ = conn.Close() - } - }() + defer p.ReleaseDrpcConn(ctx, conn) var ( client = k.clientFactory.Client(conn) rdiff = NewRemoteDiff(k.spaceId, client) diff --git a/consensus/consensusclient/client.go b/consensus/consensusclient/client.go index 4bdb0a5e..2a73635e 100644 --- a/consensus/consensusclient/client.go +++ b/consensus/consensusclient/client.go @@ -86,12 +86,7 @@ func (s *service) doClient(ctx context.Context, fn func(cl consensusproto.DRPCCo if err != nil { return err } - defer peer.ReleaseDrpcConn(dc) - defer func() { - if ctx.Err() != nil { - _ = dc.Close() - } - }() + defer peer.ReleaseDrpcConn(ctx, dc) return fn(consensusproto.NewDRPCConsensusClient(dc)) } diff --git a/nameservice/nameserviceclient/nameserviceclient.go b/nameservice/nameserviceclient/nameserviceclient.go index 93fef40a..2cae941f 100644 --- a/nameservice/nameserviceclient/nameserviceclient.go +++ b/nameservice/nameserviceclient/nameserviceclient.go @@ -91,7 +91,7 @@ func (s *service) doClient(ctx context.Context, fn func(cl nsp.DRPCAnynsClient) log.Error("failed to acquire a DRPC connection to namingnode", zap.Error(err)) return err } - defer peer.ReleaseDrpcConn(dc) + defer peer.ReleaseDrpcConn(ctx, dc) return fn(nsp.NewDRPCAnynsClient(dc)) } @@ -110,12 +110,7 @@ func (s *service) doClientAA(ctx context.Context, fn func(cl nsp.DRPCAnynsAccoun if err != nil { return err } - defer peer.ReleaseDrpcConn(dc) - defer func() { - if ctx.Err() != nil { - _ = dc.Close() - } - }() + defer peer.ReleaseDrpcConn(ctx, dc) return fn(nsp.NewDRPCAnynsAccountAbstractionClient(dc)) } diff --git a/net/peer/mock_peer/mock_peer.go b/net/peer/mock_peer/mock_peer.go index 4ace14a8..e116b296 100644 --- a/net/peer/mock_peer/mock_peer.go +++ b/net/peer/mock_peer/mock_peer.go @@ -5,6 +5,7 @@ // // mockgen -destination mock_peer/mock_peer.go github.com/anyproto/any-sync/net/peer Peer // + // Package mock_peer is a generated GoMock package. package mock_peer @@ -21,6 +22,7 @@ import ( type MockPeer struct { ctrl *gomock.Controller recorder *MockPeerMockRecorder + isgomock struct{} } // MockPeerMockRecorder is the mock recorder for MockPeer. @@ -41,18 +43,18 @@ func (m *MockPeer) EXPECT() *MockPeerMockRecorder { } // AcquireDrpcConn mocks base method. -func (m *MockPeer) AcquireDrpcConn(arg0 context.Context) (drpc.Conn, error) { +func (m *MockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AcquireDrpcConn", arg0) + ret := m.ctrl.Call(m, "AcquireDrpcConn", ctx) ret0, _ := ret[0].(drpc.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // AcquireDrpcConn indicates an expected call of AcquireDrpcConn. -func (mr *MockPeerMockRecorder) AcquireDrpcConn(arg0 any) *gomock.Call { +func (mr *MockPeerMockRecorder) AcquireDrpcConn(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireDrpcConn", reflect.TypeOf((*MockPeer)(nil).AcquireDrpcConn), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireDrpcConn", reflect.TypeOf((*MockPeer)(nil).AcquireDrpcConn), ctx) } // Close mocks base method. @@ -98,17 +100,17 @@ func (mr *MockPeerMockRecorder) Context() *gomock.Call { } // DoDrpc mocks base method. -func (m *MockPeer) DoDrpc(arg0 context.Context, arg1 func(drpc.Conn) error) error { +func (m *MockPeer) DoDrpc(ctx context.Context, do func(drpc.Conn) error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoDrpc", arg0, arg1) + ret := m.ctrl.Call(m, "DoDrpc", ctx, do) ret0, _ := ret[0].(error) return ret0 } // DoDrpc indicates an expected call of DoDrpc. -func (mr *MockPeerMockRecorder) DoDrpc(arg0, arg1 any) *gomock.Call { +func (mr *MockPeerMockRecorder) DoDrpc(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoDrpc", reflect.TypeOf((*MockPeer)(nil).DoDrpc), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoDrpc", reflect.TypeOf((*MockPeer)(nil).DoDrpc), ctx, do) } // Id mocks base method. @@ -140,40 +142,40 @@ func (mr *MockPeerMockRecorder) IsClosed() *gomock.Call { } // ReleaseDrpcConn mocks base method. -func (m *MockPeer) ReleaseDrpcConn(arg0 drpc.Conn) { +func (m *MockPeer) ReleaseDrpcConn(ctx context.Context, conn drpc.Conn) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReleaseDrpcConn", arg0) + m.ctrl.Call(m, "ReleaseDrpcConn", ctx, conn) } // ReleaseDrpcConn indicates an expected call of ReleaseDrpcConn. -func (mr *MockPeerMockRecorder) ReleaseDrpcConn(arg0 any) *gomock.Call { +func (mr *MockPeerMockRecorder) ReleaseDrpcConn(ctx, conn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseDrpcConn", reflect.TypeOf((*MockPeer)(nil).ReleaseDrpcConn), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseDrpcConn", reflect.TypeOf((*MockPeer)(nil).ReleaseDrpcConn), ctx, conn) } // SetTTL mocks base method. -func (m *MockPeer) SetTTL(arg0 time.Duration) { +func (m *MockPeer) SetTTL(ttl time.Duration) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTTL", arg0) + m.ctrl.Call(m, "SetTTL", ttl) } // SetTTL indicates an expected call of SetTTL. -func (mr *MockPeerMockRecorder) SetTTL(arg0 any) *gomock.Call { +func (mr *MockPeerMockRecorder) SetTTL(ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTTL", reflect.TypeOf((*MockPeer)(nil).SetTTL), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTTL", reflect.TypeOf((*MockPeer)(nil).SetTTL), ttl) } // TryClose mocks base method. -func (m *MockPeer) TryClose(arg0 time.Duration) (bool, error) { +func (m *MockPeer) TryClose(objectTTL time.Duration) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TryClose", arg0) + ret := m.ctrl.Call(m, "TryClose", objectTTL) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // TryClose indicates an expected call of TryClose. -func (mr *MockPeerMockRecorder) TryClose(arg0 any) *gomock.Call { +func (mr *MockPeerMockRecorder) TryClose(objectTTL any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockPeer)(nil).TryClose), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockPeer)(nil).TryClose), objectTTL) } diff --git a/net/peer/peer.go b/net/peer/peer.go index dfa68b4d..1c73af0b 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -71,7 +71,7 @@ type Peer interface { Context() context.Context AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) - ReleaseDrpcConn(conn drpc.Conn) + ReleaseDrpcConn(ctx context.Context, conn drpc.Conn) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error IsClosed() bool @@ -169,34 +169,37 @@ func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { return res, nil } -func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { +// ReleaseDrpcConn releases the connection back to the pool. +// you should pass the same ctx you passed to AcquireDrpcConn +func (p *peer) ReleaseDrpcConn(ctx context.Context, conn drpc.Conn) { var closed bool select { + case <-ctx.Done(): + // in case ctx is closed the connection may be not yet closed because of the signal logic in the drpc manager + _ = conn.Close() + closed = true case <-conn.Closed(): closed = true default: - } - - // make sure this connection doesn't have an unfinished work - if connCasted, ok := conn.(connUnblocked); ok { - select { - case <-conn.Closed(): - closed = true - case <-connCasted.Unblocked(): - // semi-safe to reuse this connection - // it may be still a chance that connection will be closed in next milliseconds - default: - // means the connection has some unfinished work, - // e.g. not fully read stream - // we cannot reuse this connection so let's close it - err := conn.Close() - if err != nil { - log.Info("ReleaseDrpcConn failed to close connection", zap.String("peerId", p.id), zap.Error(err)) + // make sure this connection doesn't have an unfinished work + if connCasted, ok := conn.(connUnblocked); ok { + select { + case <-conn.Closed(): + closed = true + case <-connCasted.Unblocked(): + // semi-safe to reuse this connection + // it may be still a chance that connection will be closed in next milliseconds + // but this is a trade-off for performance + default: + // means the connection has some unfinished work, + // e.g. not fully read stream + // we cannot reuse this connection so let's close it + _ = conn.Close() + closed = true } - closed = true + } else { + panic("conn does not implement Unblocked()") } - } else { - panic("conn does not implement Unblocked()") } if !closed { @@ -229,7 +232,7 @@ func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { select { case p.subConnRelease <- nil: // wake up the waiting AcquireDrpcConn - // even in case it is closed, it will be discarded + // it will take the next one from the inactive pool return default: } @@ -243,12 +246,7 @@ func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error return err } err = do(conn) - defer p.ReleaseDrpcConn(conn) - defer func() { - if ctx.Err() != nil { - _ = conn.Close() - } - }() + defer p.ReleaseDrpcConn(ctx, conn) return err } diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go index 1dc40433..bf09b7bf 100644 --- a/net/peer/peer_test.go +++ b/net/peer/peer_test.go @@ -86,7 +86,7 @@ func TestPeer_AcquireDrpcConn(t *testing.T) { assert.Len(t, fx.active, 1) assert.Len(t, fx.inactive, 0) - fx.ReleaseDrpcConn(dc) + fx.ReleaseDrpcConn(ctx, dc) assert.Len(t, fx.active, 0) assert.Len(t, fx.inactive, 1) @@ -103,7 +103,7 @@ func TestPeer_AcquireDrpcConn(t *testing.T) { closedIn, _ := net.Pipe() dc := drpcconn.New(closedIn) - fx.ReleaseDrpcConn(&subConn{Conn: dc}) + fx.ReleaseDrpcConn(ctx, &subConn{Conn: dc}) dc.Close() in, out := net.Pipe() @@ -144,7 +144,7 @@ func TestPeer_DrpcConn_OpenThrottling(t *testing.T) { go func() { time.Sleep(fx.limiter.slowDownStep) - fx.ReleaseDrpcConn(conns[0]) + fx.ReleaseDrpcConn(ctx, conns[0]) conns = conns[1:] }() _, err := fx.AcquireDrpcConn(ctx) @@ -273,9 +273,9 @@ func TestPeer_TryClose(t *testing.T) { require.NoError(t, err) defer dc4.Close() - fx.ReleaseDrpcConn(dc3) + fx.ReleaseDrpcConn(ctx, dc3) _ = dc3.Close() - fx.ReleaseDrpcConn(dc) + fx.ReleaseDrpcConn(ctx, dc) time.Sleep(time.Millisecond * 100) diff --git a/net/rpc/rpctest/peer.go b/net/rpc/rpctest/peer.go index 202e07ba..66ba62f0 100644 --- a/net/rpc/rpctest/peer.go +++ b/net/rpc/rpctest/peer.go @@ -46,7 +46,7 @@ func (m MockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { return nil, nil } -func (m MockPeer) ReleaseDrpcConn(conn drpc.Conn) { +func (m MockPeer) ReleaseDrpcConn(ctx context.Context, conn drpc.Conn) { return } diff --git a/paymentservice/paymentserviceclient/paymentserviceclient.go b/paymentservice/paymentserviceclient/paymentserviceclient.go index 0fd39ede..989979e8 100644 --- a/paymentservice/paymentserviceclient/paymentserviceclient.go +++ b/paymentservice/paymentserviceclient/paymentserviceclient.go @@ -92,7 +92,7 @@ func (s *service) doClient(ctx context.Context, fn func(cl pp.DRPCAnyPaymentProc log.Error("failed to acquire a DRPC connection to paymentnode", zap.Error(err)) return err } - defer peer.ReleaseDrpcConn(dc) + defer peer.ReleaseDrpcConn(ctx, dc) return fn(pp.NewDRPCAnyPaymentProcessingClient(dc)) }