mirror of
https://github.com/ehang-io/nps.git
synced 2025-07-02 04:00:42 +00:00
117 lines
3.2 KiB
Go
117 lines
3.2 KiB
Go
package bridge
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"ehang.io/nps/lib/logger"
|
|
"ehang.io/nps/lib/pb"
|
|
"ehang.io/nps/lib/pool"
|
|
"ehang.io/nps/transport"
|
|
"github.com/panjf2000/ants/v2"
|
|
"go.uber.org/zap"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
)
|
|
|
|
type tcpServer struct {
|
|
config *tls.Config // config must contain root ca and server cert
|
|
ln net.Listener
|
|
gp *ants.PoolWithFunc
|
|
manager *manager
|
|
serverCheck func(id string) bool
|
|
clientCheck func(id string) bool
|
|
}
|
|
|
|
func NewTcpServer(ln net.Listener, config *tls.Config, serverCheck, clientCheck func(string) bool) (*tcpServer, error) {
|
|
h := &tcpServer{
|
|
config: config,
|
|
ln: ln,
|
|
serverCheck: serverCheck,
|
|
clientCheck: clientCheck,
|
|
manager: NewManager(),
|
|
}
|
|
var err error
|
|
h.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
|
conn := i.(net.Conn)
|
|
cr := &pb.ConnRequest{}
|
|
_, err := pb.ReadMessage(conn, cr)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
logger.Warn("read message error", zap.Error(err))
|
|
return
|
|
}
|
|
switch cr.ConnType.(type) {
|
|
case *pb.ConnRequest_AppInfo:
|
|
h.handleApp(cr, conn)
|
|
case *pb.ConnRequest_NpcInfo:
|
|
h.handleClient(cr, conn)
|
|
}
|
|
})
|
|
return h, err
|
|
}
|
|
|
|
func (h *tcpServer) handleApp(cr *pb.ConnRequest, conn net.Conn) {
|
|
if !h.serverCheck(cr.GetId()) {
|
|
_ = conn.Close()
|
|
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
|
return
|
|
}
|
|
clientConn, err := h.manager.GetDataConn(cr.GetAppInfo().GetNpcId())
|
|
if err != nil {
|
|
logger.Error("get client error", zap.Error(err), zap.String("app_info", cr.String()))
|
|
return
|
|
}
|
|
_, err = pb.WriteMessage(clientConn, &pb.ClientRequest{ConnType: &pb.ClientRequest_AppInfo{AppInfo: cr.GetAppInfo()}})
|
|
if err != nil {
|
|
_ = clientConn.Close()
|
|
_ = conn.Close()
|
|
logger.Error("write app_info error", zap.Error(err))
|
|
return
|
|
}
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: conn, Reader: clientConn, Wg: &wg})
|
|
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: clientConn, Reader: conn, Wg: &wg})
|
|
wg.Wait()
|
|
}
|
|
|
|
func (h *tcpServer) responseClient(conn io.Writer, success bool, msg string) error {
|
|
_, err := pb.WriteMessage(conn, &pb.NpcResponse{Success: success, Message: msg})
|
|
return err
|
|
}
|
|
|
|
func (h *tcpServer) handleClient(cr *pb.ConnRequest, conn net.Conn) {
|
|
if !h.clientCheck(cr.GetId()) {
|
|
_ = conn.Close()
|
|
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
|
_ = h.responseClient(conn, false, "id check failed")
|
|
return
|
|
}
|
|
yc := transport.NewYaMux(conn, nil)
|
|
err := yc.Client()
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
_ = h.responseClient(conn, false, "client failed")
|
|
logger.Error("new yamux client error", zap.Error(err), zap.String("remote address", conn.RemoteAddr().String()))
|
|
return
|
|
}
|
|
_ = h.responseClient(conn, true, "success")
|
|
err = h.manager.SetClient(cr.GetId(), cr.GetNpcInfo().GetTunnelId(), cr.GetNpcInfo().GetIsControlTunnel(), yc)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
logger.Error("set client error", zap.Error(err), zap.String("info", cr.String()))
|
|
}
|
|
}
|
|
|
|
func (h *tcpServer) run() error {
|
|
h.ln = tls.NewListener(h.ln, h.config)
|
|
for {
|
|
conn, err := h.ln.Accept()
|
|
if err != nil {
|
|
logger.Error("Accept conn error", zap.Error(err))
|
|
return err
|
|
}
|
|
_ = h.gp.Invoke(conn)
|
|
}
|
|
}
|