1
0
Fork 0
mirror of https://github.com/anyproto/any-sync.git synced 2025-06-09 17:45:03 +09:00
any-sync/net/transport/yamux/yamux.go
2023-06-07 14:49:44 +02:00

163 lines
4 KiB
Go

package yamux
import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/anyproto/any-sync/app"
	"github.com/anyproto/any-sync/app/logger"
	"github.com/anyproto/any-sync/net/connutil"
	"github.com/anyproto/any-sync/net/secureservice"
	"github.com/anyproto/any-sync/net/transport"
	"github.com/hashicorp/yamux"
	"go.uber.org/zap"
)
const (
	// CName is the component name this transport reports via Name.
	CName = "net.transport.yamux"
)

var (
	// log is the package-level logger, named after CName.
	log = logger.NewNamed(CName)
)
// New returns a fresh, zero-valued yamux transport. Its dependencies and
// configuration are populated later by Init.
func New() Yamux {
	t := &yamuxTransport{}
	return t
}
// Yamux implements transport.Transport with tcp+yamux
type Yamux interface {
transport.Transport
app.ComponentRunnable
}
// yamuxTransport is the concrete Yamux implementation returned by New.
type yamuxTransport struct {
	// secure performs the handshake: SecureOutbound in Dial, SecureInbound in accept.
	secure secureservice.SecureService
	// accepter receives inbound MultiConns; set via SetAccepter and required by Run.
	accepter transport.Accepter
	// conf holds the yamux settings from the app config; Init defaults the timeouts.
	conf Config
	// listeners are the TCP listeners opened by Run and closed by Close.
	listeners []net.Listener
	// listCtx is handed to every accept loop; Close cancels it via listCtxCancel.
	listCtx       context.Context
	listCtxCancel context.CancelFunc
	// yamuxConf is built in Init and shared by client (Dial) and server (accept) sessions.
	yamuxConf *yamux.Config
}
// Init resolves the secure service and the yamux config section from the app.
// It then prepares the yamux session configuration.
//
// Non-positive DialTimeoutSec / WriteTimeoutSec values fall back to 10 seconds.
// In the session config, keepalive is disabled, StreamOpenTimeout uses the dial
// timeout, and ConnectionWriteTimeout uses the write timeout.
func (y *yamuxTransport) Init(a *app.App) (err error) {
	y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
	y.conf = a.MustComponent("config").(configGetter).GetYamux()

	const defaultTimeoutSec = 10
	if y.conf.DialTimeoutSec <= 0 {
		y.conf.DialTimeoutSec = defaultTimeoutSec
	}
	if y.conf.WriteTimeoutSec <= 0 {
		y.conf.WriteTimeoutSec = defaultTimeoutSec
	}

	sessionConf := yamux.DefaultConfig()
	sessionConf.EnableKeepAlive = false
	sessionConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
	sessionConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
	y.yamuxConf = sessionConf
	return nil
}
// Name reports this component's name, CName, to the app.
func (y *yamuxTransport) Name() string { return CName }
// Run opens a TCP listener on every address in conf.ListenAddrs and starts one
// accept goroutine per listener. It fails if no accepter has been set.
//
// If binding any address fails, the listeners already opened by this call are
// closed before the error is returned, so their ports are not left bound.
// The accept goroutines use an internal context that Close cancels; the ctx
// argument is not used for them.
func (y *yamuxTransport) Run(ctx context.Context) (err error) {
	if y.accepter == nil {
		return fmt.Errorf("can't run service without accepter")
	}
	opened := make([]net.Listener, 0, len(y.conf.ListenAddrs))
	for _, listAddr := range y.conf.ListenAddrs {
		list, listenErr := net.Listen("tcp", listAddr)
		if listenErr != nil {
			// Release ports bound earlier in this loop: they were never recorded
			// in y.listeners, and Close may not be called after a failed Run.
			for _, l := range opened {
				_ = l.Close()
			}
			return listenErr
		}
		opened = append(opened, list)
	}
	y.listeners = append(y.listeners, opened...)
	y.listCtx, y.listCtxCancel = context.WithCancel(context.Background())
	for _, list := range y.listeners {
		go y.acceptLoop(y.listCtx, list)
	}
	return nil
}
// SetAccepter registers the handler that receives inbound connections.
// Run refuses to start while no accepter is set.
func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) { y.accepter = accepter }
// Dial connects to addr and returns a yamux client session over it as a
// transport.MultiConn. It proceeds in three steps:
//  1. Open a TCP connection.
//  2. Run the outbound security handshake.
//  3. Start a yamux client session.
//
// The TCP dial and the handshake each get up to conf.DialTimeoutSec, and both
// also stop early if ctx is cancelled. If the handshake or session setup fails,
// the TCP connection is closed so it does not leak.
func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) {
	dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second

	// DialContext (unlike net.DialTimeout) also honours cancellation of ctx.
	dialer := net.Dialer{Timeout: dialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	handshakeCtx, cancel := context.WithTimeout(ctx, dialTimeout)
	defer cancel()
	cctx, err := y.secure.SecureOutbound(handshakeCtx, conn)
	if err != nil {
		_ = conn.Close()
		return nil, err
	}

	luc := connutil.NewLastUsageConn(conn)
	sess, err := yamux.Client(luc, y.yamuxConf)
	if err != nil {
		// No session took ownership of conn, so release it here.
		_ = conn.Close()
		return nil, err
	}
	return NewMultiConn(cctx, luc, addr, sess), nil
}
// acceptLoop accepts connections on list and hands each one to accept in its
// own goroutine.
//
// Errors that isTemporary classifies as temporary are retried after a
// one-second pause; the pause is cut short, ending the loop, if ctx is done.
// Any other Accept error ends the loop. A listener closed normally (for
// example by Close) is logged at info level; anything else at error level.
func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
	l := log.With(zap.String("localAddr", list.Addr().String()))
	l.Info("yamux listener started")
	defer func() {
		l.Debug("yamux listener stopped")
	}()
	for {
		conn, err := list.Accept()
		if err != nil {
			if isTemporary(err) {
				l.Debug("listener temporary accept error", zap.Error(err))
				select {
				case <-time.After(time.Second):
				case <-ctx.Done():
					return
				}
				continue
			}
			// Accept wraps net.ErrClosed (e.g. in *net.OpError), so a plain ==/!=
			// comparison never matches and every normal shutdown would be
			// reported as an error.
			if errors.Is(err, net.ErrClosed) {
				l.Info("listener closed")
			} else {
				l.Error("listener closed with error", zap.Error(err))
			}
			return
		}
		go y.accept(conn)
	}
}
// accept completes the server side of one inbound connection:
//  1. Run the inbound security handshake, bounded by conf.DialTimeoutSec.
//  2. Start a yamux server session.
//  3. Pass the resulting MultiConn to the accepter.
//
// Failures are logged, not returned. If the handshake or session setup fails,
// the raw connection is closed so it does not leak.
func (y *yamuxTransport) accept(conn net.Conn) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
	defer cancel()
	cctx, err := y.secure.SecureInbound(ctx, conn)
	if err != nil {
		// Mirrors Dial, which closes the conn when SecureOutbound fails.
		_ = conn.Close()
		log.Warn("incoming connection handshake error", zap.Error(err))
		return
	}
	luc := connutil.NewLastUsageConn(conn)
	sess, err := yamux.Server(luc, y.yamuxConf)
	if err != nil {
		_ = conn.Close()
		log.Warn("incoming connection yamux session error", zap.Error(err))
		return
	}
	mc := NewMultiConn(cctx, luc, conn.RemoteAddr().String(), sess)
	if err = y.accepter.Accept(mc); err != nil {
		// NOTE(review): mc is not closed here; whether the accepter keeps or
		// releases it on error is not visible from this file — confirm.
		log.Warn("connection accept error", zap.Error(err))
	}
}
// Close cancels the accept-loop context, if Run created one, and closes every
// listener Run opened. Listener close errors are ignored; Close always
// returns nil.
func (y *yamuxTransport) Close(ctx context.Context) (err error) {
	if cancel := y.listCtxCancel; cancel != nil {
		cancel()
	}
	for i := range y.listeners {
		_ = y.listeners[i].Close()
	}
	return nil
}