mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-08 18:09:03 +00:00
add new file
This commit is contained in:
23
component/bridge/bridge.go
Normal file
23
component/bridge/bridge.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"net"
|
||||
)
|
||||
|
||||
func StartTcpBridge(ln net.Listener, config *tls.Config, serverCheck, clientCheck func(string) bool) error {
|
||||
h, err := NewTcpServer(ln, config, serverCheck, clientCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.run()
|
||||
}
|
||||
|
||||
func StartQUICBridge(ln net.PacketConn, config *tls.Config, quicConfig *quic.Config, clientCheck func(string) bool) error {
|
||||
h, err := NewQUICServer(ln, config, quicConfig, clientCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.run()
|
||||
}
|
102
component/bridge/client.go
Normal file
102
component/bridge/client.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
*manager
|
||||
tunnelId string
|
||||
clientId string
|
||||
control transport.Conn
|
||||
data transport.Conn
|
||||
httpClient *http.Client
|
||||
pingErrTimes int
|
||||
}
|
||||
|
||||
func NewClient(tunnelId string, clientId string, mg *manager) *client {
|
||||
c := &client{tunnelId: tunnelId, clientId: clientId, manager: mg}
|
||||
c.httpClient = &http.Client{Transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return c.NewControlConn()
|
||||
}}}
|
||||
go c.doPing()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *client) SetTunnel(conn transport.Conn, control bool) {
|
||||
if control {
|
||||
c.control = conn
|
||||
} else {
|
||||
c.data = conn
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) NewDataConn() (net.Conn, error) {
|
||||
if c.data == nil {
|
||||
return nil, errors.New("the data tunnel is not exist")
|
||||
}
|
||||
return c.data.Open()
|
||||
}
|
||||
|
||||
func (c *client) NewControlConn() (net.Conn, error) {
|
||||
if c.control == nil {
|
||||
return nil, errors.New("the data tunnel is not exist")
|
||||
}
|
||||
return c.control.Open()
|
||||
}
|
||||
|
||||
func (c *client) doPing() {
|
||||
for range time.NewTicker(time.Second * 5).C {
|
||||
if err := c.ping(); err != nil {
|
||||
c.pingErrTimes++
|
||||
logger.Error("do ping error", zap.Error(err))
|
||||
} else {
|
||||
logger.Debug("do ping success", zap.String("client id", c.clientId), zap.String("tunnel id", c.tunnelId))
|
||||
c.pingErrTimes = 0
|
||||
}
|
||||
if c.pingErrTimes > 3 {
|
||||
logger.Error("ping failed, close")
|
||||
c.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) close() {
|
||||
if c.data != nil {
|
||||
_ = c.data.Close()
|
||||
}
|
||||
if c.control != nil {
|
||||
_ = c.control.Close()
|
||||
}
|
||||
_ = c.RemoveClient(c)
|
||||
}
|
||||
|
||||
func (c *client) ping() error {
|
||||
conn, err := c.NewDataConn()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
_, err = pb.WriteMessage(conn, &pb.ClientRequest{ConnType: &pb.ClientRequest_Ping{Ping: &pb.Ping{Now: time.Now().String()}}})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
_, err = pb.ReadMessage(conn, &pb.Ping{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
resp, err := c.httpClient.Get("http://nps.ehang.io/ping")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "control ping")
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
return nil
|
||||
}
|
60
component/bridge/manager.go
Normal file
60
component/bridge/manager.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/lb"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/transport"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
clients map[string]*client
|
||||
clientLb *lb.LoadBalancer
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager() *manager {
|
||||
return &manager{
|
||||
clients: make(map[string]*client),
|
||||
clientLb: lb.NewLoadBalancer(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) SetClient(clientId string, tunnelId string, isControl bool, conn transport.Conn) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
client, ok := m.clients[tunnelId]
|
||||
if !ok {
|
||||
client = NewClient(tunnelId, clientId, m)
|
||||
err := m.clientLb.SetClient(clientId, client)
|
||||
if err != nil {
|
||||
logger.Error("set client error", zap.Error(err), zap.String("clientId", clientId), zap.String("tunnelId", tunnelId))
|
||||
return err
|
||||
}
|
||||
m.clients[tunnelId] = client
|
||||
}
|
||||
client.SetTunnel(conn, isControl)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) GetDataConn(clientId string) (net.Conn, error) {
|
||||
c, err := m.clientLb.GetClient(clientId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.(*client).NewDataConn()
|
||||
}
|
||||
|
||||
func (m *manager) RemoveClient(client *client) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
err := m.clientLb.RemoveClient(client.clientId, client)
|
||||
if err != nil {
|
||||
logger.Error("remove client error", zap.Error(err), zap.String("clientId", client.clientId), zap.String("tunnelId", client.tunnelId))
|
||||
return err
|
||||
}
|
||||
delete(m.clients, client.tunnelId)
|
||||
return nil
|
||||
}
|
91
component/bridge/quic.go
Normal file
91
component/bridge/quic.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type QUICServer struct {
|
||||
packetConn net.PacketConn
|
||||
tlsConfig *tls.Config
|
||||
config *quic.Config
|
||||
listener quic.Listener
|
||||
gp *ants.PoolWithFunc
|
||||
clientCheck func(string) bool
|
||||
manager *manager
|
||||
}
|
||||
|
||||
func NewQUICServer(packetConn net.PacketConn, tlsConfig *tls.Config, config *quic.Config, clientCheck func(string) bool) (*QUICServer, error) {
|
||||
qs := &QUICServer{
|
||||
packetConn: packetConn,
|
||||
tlsConfig: tlsConfig,
|
||||
config: config,
|
||||
clientCheck: clientCheck,
|
||||
manager: NewManager(),
|
||||
}
|
||||
var err error
|
||||
if qs.listener, err = quic.Listen(packetConn, tlsConfig, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qs.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
session := i.(quic.Session)
|
||||
logger.Debug("accept a session", zap.String("remote addr", session.RemoteAddr().String()))
|
||||
stream, err := session.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
logger.Warn("accept stream error", zap.Error(err))
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
return
|
||||
}
|
||||
cr := &pb.ConnRequest{}
|
||||
_, err = pb.ReadMessage(stream, cr)
|
||||
if err != nil {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Warn("read message error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if !qs.clientCheck(cr.GetId()) {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
_ = qs.responseClient(stream, false, "id check failed")
|
||||
return
|
||||
}
|
||||
qc := transport.NewQUIC(session)
|
||||
_ = qc.Client()
|
||||
|
||||
_ = qs.responseClient(stream, true, "success")
|
||||
err = qs.manager.SetClient(cr.GetId(), cr.GetNpcInfo().GetTunnelId(), cr.GetNpcInfo().GetIsControlTunnel(), qc)
|
||||
if err != nil {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Error("set client error", zap.Error(err), zap.String("info", cr.String()))
|
||||
}
|
||||
})
|
||||
return qs, err
|
||||
}
|
||||
|
||||
func (qs *QUICServer) responseClient(conn io.Writer, success bool, msg string) error {
|
||||
_, err := pb.WriteMessage(conn, &pb.NpcResponse{Success: success, Message: msg})
|
||||
return err
|
||||
}
|
||||
|
||||
func (qs *QUICServer) run() error {
|
||||
for {
|
||||
session, err := qs.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
logger.Error("accept connection failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
err = qs.gp.Invoke(session)
|
||||
if err != nil {
|
||||
logger.Error("Invoke session error", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
116
component/bridge/tcp.go
Normal file
116
component/bridge/tcp.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type tcpServer struct {
|
||||
config *tls.Config // config must contain root ca and server cert
|
||||
ln net.Listener
|
||||
gp *ants.PoolWithFunc
|
||||
manager *manager
|
||||
serverCheck func(id string) bool
|
||||
clientCheck func(id string) bool
|
||||
}
|
||||
|
||||
func NewTcpServer(ln net.Listener, config *tls.Config, serverCheck, clientCheck func(string) bool) (*tcpServer, error) {
|
||||
h := &tcpServer{
|
||||
config: config,
|
||||
ln: ln,
|
||||
serverCheck: serverCheck,
|
||||
clientCheck: clientCheck,
|
||||
manager: NewManager(),
|
||||
}
|
||||
var err error
|
||||
h.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
conn := i.(net.Conn)
|
||||
cr := &pb.ConnRequest{}
|
||||
_, err := pb.ReadMessage(conn, cr)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
logger.Warn("read message error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
switch cr.ConnType.(type) {
|
||||
case *pb.ConnRequest_AppInfo:
|
||||
h.handleApp(cr, conn)
|
||||
case *pb.ConnRequest_NpcInfo:
|
||||
h.handleClient(cr, conn)
|
||||
}
|
||||
})
|
||||
return h, err
|
||||
}
|
||||
|
||||
func (h *tcpServer) handleApp(cr *pb.ConnRequest, conn net.Conn) {
|
||||
if !h.serverCheck(cr.GetId()) {
|
||||
_ = conn.Close()
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
return
|
||||
}
|
||||
clientConn, err := h.manager.GetDataConn(cr.GetAppInfo().GetNpcId())
|
||||
if err != nil {
|
||||
logger.Error("get client error", zap.Error(err), zap.String("app_info", cr.String()))
|
||||
return
|
||||
}
|
||||
_, err = pb.WriteMessage(clientConn, &pb.ClientRequest{ConnType: &pb.ClientRequest_AppInfo{AppInfo: cr.GetAppInfo()}})
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = conn.Close()
|
||||
logger.Error("write app_info error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: conn, Reader: clientConn, Wg: &wg})
|
||||
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: clientConn, Reader: conn, Wg: &wg})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *tcpServer) responseClient(conn io.Writer, success bool, msg string) error {
|
||||
_, err := pb.WriteMessage(conn, &pb.NpcResponse{Success: success, Message: msg})
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *tcpServer) handleClient(cr *pb.ConnRequest, conn net.Conn) {
|
||||
if !h.clientCheck(cr.GetId()) {
|
||||
_ = conn.Close()
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
_ = h.responseClient(conn, false, "id check failed")
|
||||
return
|
||||
}
|
||||
yc := transport.NewYaMux(conn, nil)
|
||||
err := yc.Client()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
_ = h.responseClient(conn, false, "client failed")
|
||||
logger.Error("new yamux client error", zap.Error(err), zap.String("remote address", conn.RemoteAddr().String()))
|
||||
return
|
||||
}
|
||||
_ = h.responseClient(conn, true, "success")
|
||||
err = h.manager.SetClient(cr.GetId(), cr.GetNpcInfo().GetTunnelId(), cr.GetNpcInfo().GetIsControlTunnel(), yc)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
logger.Error("set client error", zap.Error(err), zap.String("info", cr.String()))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tcpServer) run() error {
|
||||
h.ln = tls.NewListener(h.ln, h.config)
|
||||
for {
|
||||
conn, err := h.ln.Accept()
|
||||
if err != nil {
|
||||
logger.Error("Accept conn error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
_ = h.gp.Invoke(conn)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user