diff --git a/net/transport/quic/conn.go b/net/transport/quic/conn.go index a160dbf3..8563c74c 100644 --- a/net/transport/quic/conn.go +++ b/net/transport/quic/conn.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "time" "github.com/quic-go/quic-go" @@ -11,16 +12,18 @@ import ( "github.com/anyproto/any-sync/net/transport" ) -func newConn(cctx context.Context, qconn quic.Connection) transport.MultiConn { +func newConn(cctx context.Context, qconn quic.Connection, writeTimeout time.Duration) transport.MultiConn { cctx = peer.CtxWithPeerAddr(cctx, transport.Quic+"://"+qconn.RemoteAddr().String()) return &quicMultiConn{ - cctx: cctx, - Connection: qconn, + cctx: cctx, + Connection: qconn, + writeTimeout: writeTimeout, } } type quicMultiConn struct { - cctx context.Context + cctx context.Context + writeTimeout time.Duration quic.Connection } @@ -39,9 +42,10 @@ func (q *quicMultiConn) Accept() (conn net.Conn, err error) { return nil, err } return quicNetConn{ - Stream: stream, - localAddr: q.LocalAddr(), - remoteAddr: q.RemoteAddr(), + Stream: stream, + localAddr: q.LocalAddr(), + remoteAddr: q.RemoteAddr(), + writeTimeout: q.writeTimeout, }, nil } @@ -84,6 +88,7 @@ const ( type quicNetConn struct { quic.Stream + writeTimeout time.Duration localAddr, remoteAddr net.Addr } @@ -98,6 +103,15 @@ func (q quicNetConn) Close() error { return q.Stream.Close() } +func (q quicNetConn) Write(b []byte) (n int, err error) { + if q.writeTimeout > 0 { + if err = q.Stream.SetWriteDeadline(time.Now().Add(q.writeTimeout)); err != nil { + return + } + } + return q.Stream.Write(b) +} + func (q quicNetConn) LocalAddr() net.Addr { return q.localAddr } diff --git a/net/transport/quic/quic.go b/net/transport/quic/quic.go index dbd712ca..27c7b786 100644 --- a/net/transport/quic/quic.go +++ b/net/transport/quic/quic.go @@ -147,7 +147,7 @@ func (q *quicTransport) Dial(ctx context.Context, addr string) (mc transport.Mul return nil, err } - return newConn(cctx, qConn), nil + return newConn(cctx, qConn, time.Second*time.Duration(q.conf.WriteTimeoutSec)), nil } func (q *quicTransport) acceptLoop(ctx context.Context, list *quic.Listener) { @@ -199,7 +199,7 @@ func (q *quicTransport) accept(conn quic.Connection) (err error) { }() return } - mc := newConn(cctx, conn) + mc := newConn(cctx, conn, time.Second*time.Duration(q.conf.WriteTimeoutSec)) return q.accepter.Accept(mc) }