mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-08 00:26:52 +00:00
add new file
This commit is contained in:
23
component/bridge/bridge.go
Normal file
23
component/bridge/bridge.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"net"
|
||||
)
|
||||
|
||||
func StartTcpBridge(ln net.Listener, config *tls.Config, serverCheck, clientCheck func(string) bool) error {
|
||||
h, err := NewTcpServer(ln, config, serverCheck, clientCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.run()
|
||||
}
|
||||
|
||||
func StartQUICBridge(ln net.PacketConn, config *tls.Config, quicConfig *quic.Config, clientCheck func(string) bool) error {
|
||||
h, err := NewQUICServer(ln, config, quicConfig, clientCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.run()
|
||||
}
|
102
component/bridge/client.go
Normal file
102
component/bridge/client.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
*manager
|
||||
tunnelId string
|
||||
clientId string
|
||||
control transport.Conn
|
||||
data transport.Conn
|
||||
httpClient *http.Client
|
||||
pingErrTimes int
|
||||
}
|
||||
|
||||
func NewClient(tunnelId string, clientId string, mg *manager) *client {
|
||||
c := &client{tunnelId: tunnelId, clientId: clientId, manager: mg}
|
||||
c.httpClient = &http.Client{Transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return c.NewControlConn()
|
||||
}}}
|
||||
go c.doPing()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *client) SetTunnel(conn transport.Conn, control bool) {
|
||||
if control {
|
||||
c.control = conn
|
||||
} else {
|
||||
c.data = conn
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) NewDataConn() (net.Conn, error) {
|
||||
if c.data == nil {
|
||||
return nil, errors.New("the data tunnel is not exist")
|
||||
}
|
||||
return c.data.Open()
|
||||
}
|
||||
|
||||
func (c *client) NewControlConn() (net.Conn, error) {
|
||||
if c.control == nil {
|
||||
return nil, errors.New("the data tunnel is not exist")
|
||||
}
|
||||
return c.control.Open()
|
||||
}
|
||||
|
||||
func (c *client) doPing() {
|
||||
for range time.NewTicker(time.Second * 5).C {
|
||||
if err := c.ping(); err != nil {
|
||||
c.pingErrTimes++
|
||||
logger.Error("do ping error", zap.Error(err))
|
||||
} else {
|
||||
logger.Debug("do ping success", zap.String("client id", c.clientId), zap.String("tunnel id", c.tunnelId))
|
||||
c.pingErrTimes = 0
|
||||
}
|
||||
if c.pingErrTimes > 3 {
|
||||
logger.Error("ping failed, close")
|
||||
c.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) close() {
|
||||
if c.data != nil {
|
||||
_ = c.data.Close()
|
||||
}
|
||||
if c.control != nil {
|
||||
_ = c.control.Close()
|
||||
}
|
||||
_ = c.RemoveClient(c)
|
||||
}
|
||||
|
||||
func (c *client) ping() error {
|
||||
conn, err := c.NewDataConn()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
_, err = pb.WriteMessage(conn, &pb.ClientRequest{ConnType: &pb.ClientRequest_Ping{Ping: &pb.Ping{Now: time.Now().String()}}})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
_, err = pb.ReadMessage(conn, &pb.Ping{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "data ping")
|
||||
}
|
||||
resp, err := c.httpClient.Get("http://nps.ehang.io/ping")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "control ping")
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
return nil
|
||||
}
|
60
component/bridge/manager.go
Normal file
60
component/bridge/manager.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/lb"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/transport"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
clients map[string]*client
|
||||
clientLb *lb.LoadBalancer
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager() *manager {
|
||||
return &manager{
|
||||
clients: make(map[string]*client),
|
||||
clientLb: lb.NewLoadBalancer(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) SetClient(clientId string, tunnelId string, isControl bool, conn transport.Conn) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
client, ok := m.clients[tunnelId]
|
||||
if !ok {
|
||||
client = NewClient(tunnelId, clientId, m)
|
||||
err := m.clientLb.SetClient(clientId, client)
|
||||
if err != nil {
|
||||
logger.Error("set client error", zap.Error(err), zap.String("clientId", clientId), zap.String("tunnelId", tunnelId))
|
||||
return err
|
||||
}
|
||||
m.clients[tunnelId] = client
|
||||
}
|
||||
client.SetTunnel(conn, isControl)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) GetDataConn(clientId string) (net.Conn, error) {
|
||||
c, err := m.clientLb.GetClient(clientId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.(*client).NewDataConn()
|
||||
}
|
||||
|
||||
func (m *manager) RemoveClient(client *client) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
err := m.clientLb.RemoveClient(client.clientId, client)
|
||||
if err != nil {
|
||||
logger.Error("remove client error", zap.Error(err), zap.String("clientId", client.clientId), zap.String("tunnelId", client.tunnelId))
|
||||
return err
|
||||
}
|
||||
delete(m.clients, client.tunnelId)
|
||||
return nil
|
||||
}
|
91
component/bridge/quic.go
Normal file
91
component/bridge/quic.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type QUICServer struct {
|
||||
packetConn net.PacketConn
|
||||
tlsConfig *tls.Config
|
||||
config *quic.Config
|
||||
listener quic.Listener
|
||||
gp *ants.PoolWithFunc
|
||||
clientCheck func(string) bool
|
||||
manager *manager
|
||||
}
|
||||
|
||||
func NewQUICServer(packetConn net.PacketConn, tlsConfig *tls.Config, config *quic.Config, clientCheck func(string) bool) (*QUICServer, error) {
|
||||
qs := &QUICServer{
|
||||
packetConn: packetConn,
|
||||
tlsConfig: tlsConfig,
|
||||
config: config,
|
||||
clientCheck: clientCheck,
|
||||
manager: NewManager(),
|
||||
}
|
||||
var err error
|
||||
if qs.listener, err = quic.Listen(packetConn, tlsConfig, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qs.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
session := i.(quic.Session)
|
||||
logger.Debug("accept a session", zap.String("remote addr", session.RemoteAddr().String()))
|
||||
stream, err := session.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
logger.Warn("accept stream error", zap.Error(err))
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
return
|
||||
}
|
||||
cr := &pb.ConnRequest{}
|
||||
_, err = pb.ReadMessage(stream, cr)
|
||||
if err != nil {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Warn("read message error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if !qs.clientCheck(cr.GetId()) {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
_ = qs.responseClient(stream, false, "id check failed")
|
||||
return
|
||||
}
|
||||
qc := transport.NewQUIC(session)
|
||||
_ = qc.Client()
|
||||
|
||||
_ = qs.responseClient(stream, true, "success")
|
||||
err = qs.manager.SetClient(cr.GetId(), cr.GetNpcInfo().GetTunnelId(), cr.GetNpcInfo().GetIsControlTunnel(), qc)
|
||||
if err != nil {
|
||||
_ = session.CloseWithError(0, "check auth failed")
|
||||
logger.Error("set client error", zap.Error(err), zap.String("info", cr.String()))
|
||||
}
|
||||
})
|
||||
return qs, err
|
||||
}
|
||||
|
||||
func (qs *QUICServer) responseClient(conn io.Writer, success bool, msg string) error {
|
||||
_, err := pb.WriteMessage(conn, &pb.NpcResponse{Success: success, Message: msg})
|
||||
return err
|
||||
}
|
||||
|
||||
func (qs *QUICServer) run() error {
|
||||
for {
|
||||
session, err := qs.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
logger.Error("accept connection failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
err = qs.gp.Invoke(session)
|
||||
if err != nil {
|
||||
logger.Error("Invoke session error", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
116
component/bridge/tcp.go
Normal file
116
component/bridge/tcp.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type tcpServer struct {
|
||||
config *tls.Config // config must contain root ca and server cert
|
||||
ln net.Listener
|
||||
gp *ants.PoolWithFunc
|
||||
manager *manager
|
||||
serverCheck func(id string) bool
|
||||
clientCheck func(id string) bool
|
||||
}
|
||||
|
||||
func NewTcpServer(ln net.Listener, config *tls.Config, serverCheck, clientCheck func(string) bool) (*tcpServer, error) {
|
||||
h := &tcpServer{
|
||||
config: config,
|
||||
ln: ln,
|
||||
serverCheck: serverCheck,
|
||||
clientCheck: clientCheck,
|
||||
manager: NewManager(),
|
||||
}
|
||||
var err error
|
||||
h.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
conn := i.(net.Conn)
|
||||
cr := &pb.ConnRequest{}
|
||||
_, err := pb.ReadMessage(conn, cr)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
logger.Warn("read message error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
switch cr.ConnType.(type) {
|
||||
case *pb.ConnRequest_AppInfo:
|
||||
h.handleApp(cr, conn)
|
||||
case *pb.ConnRequest_NpcInfo:
|
||||
h.handleClient(cr, conn)
|
||||
}
|
||||
})
|
||||
return h, err
|
||||
}
|
||||
|
||||
func (h *tcpServer) handleApp(cr *pb.ConnRequest, conn net.Conn) {
|
||||
if !h.serverCheck(cr.GetId()) {
|
||||
_ = conn.Close()
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
return
|
||||
}
|
||||
clientConn, err := h.manager.GetDataConn(cr.GetAppInfo().GetNpcId())
|
||||
if err != nil {
|
||||
logger.Error("get client error", zap.Error(err), zap.String("app_info", cr.String()))
|
||||
return
|
||||
}
|
||||
_, err = pb.WriteMessage(clientConn, &pb.ClientRequest{ConnType: &pb.ClientRequest_AppInfo{AppInfo: cr.GetAppInfo()}})
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = conn.Close()
|
||||
logger.Error("write app_info error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: conn, Reader: clientConn, Wg: &wg})
|
||||
_ = pool.CopyConnGoroutinePool.Invoke(pool.CopyConnGpParams{Writer: clientConn, Reader: conn, Wg: &wg})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *tcpServer) responseClient(conn io.Writer, success bool, msg string) error {
|
||||
_, err := pb.WriteMessage(conn, &pb.NpcResponse{Success: success, Message: msg})
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *tcpServer) handleClient(cr *pb.ConnRequest, conn net.Conn) {
|
||||
if !h.clientCheck(cr.GetId()) {
|
||||
_ = conn.Close()
|
||||
logger.Error("check server id error", zap.String("id", cr.GetId()))
|
||||
_ = h.responseClient(conn, false, "id check failed")
|
||||
return
|
||||
}
|
||||
yc := transport.NewYaMux(conn, nil)
|
||||
err := yc.Client()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
_ = h.responseClient(conn, false, "client failed")
|
||||
logger.Error("new yamux client error", zap.Error(err), zap.String("remote address", conn.RemoteAddr().String()))
|
||||
return
|
||||
}
|
||||
_ = h.responseClient(conn, true, "success")
|
||||
err = h.manager.SetClient(cr.GetId(), cr.GetNpcInfo().GetTunnelId(), cr.GetNpcInfo().GetIsControlTunnel(), yc)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
logger.Error("set client error", zap.Error(err), zap.String("info", cr.String()))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *tcpServer) run() error {
|
||||
h.ln = tls.NewListener(h.ln, h.config)
|
||||
for {
|
||||
conn, err := h.ln.Accept()
|
||||
if err != nil {
|
||||
logger.Error("Accept conn error", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
_ = h.gp.Invoke(conn)
|
||||
}
|
||||
}
|
129
component/client/client.go
Normal file
129
component/client/client.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/rule"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
controlLn net.Listener
|
||||
dataLn net.Listener
|
||||
lastPongTime time.Time
|
||||
mux *http.ServeMux
|
||||
ticker *time.Ticker
|
||||
closeCh chan struct{}
|
||||
closed int32
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewClient(controlLn, dataLn net.Listener) *Client {
|
||||
return &Client{
|
||||
controlLn: controlLn,
|
||||
dataLn: dataLn,
|
||||
mux: &http.ServeMux{},
|
||||
ticker: time.NewTicker(time.Second * 5),
|
||||
closeCh: make(chan struct{}, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ping(writer http.ResponseWriter, request *http.Request) {
|
||||
c.lastPongTime = time.Now()
|
||||
_, err := io.WriteString(writer, "pong")
|
||||
if err != nil {
|
||||
logger.Warn("write pong error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
logger.Debug("write pong success")
|
||||
}
|
||||
|
||||
func (c *Client) Run() {
|
||||
c.mux.HandleFunc("/ping", c.ping)
|
||||
c.wg.Add(3)
|
||||
go c.handleControlConn()
|
||||
go c.handleDataConn()
|
||||
go c.checkPing()
|
||||
c.wg.Wait()
|
||||
}
|
||||
|
||||
func (c *Client) HasPong() bool {
|
||||
return time.Now().Sub(c.lastPongTime).Seconds() < 10
|
||||
}
|
||||
|
||||
func (c *Client) checkPing() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ticker.C:
|
||||
if !c.lastPongTime.IsZero() && time.Now().Sub(c.lastPongTime).Seconds() > 15 && c.controlLn != nil {
|
||||
logger.Debug("close connection", zap.Time("lastPongTime", c.lastPongTime), zap.Time("now", time.Now()))
|
||||
_ = c.controlLn.Close()
|
||||
}
|
||||
case <-c.closeCh:
|
||||
c.wg.Done()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleDataConn() {
|
||||
h := &handler.DefaultHandler{}
|
||||
ac := &action.LocalAction{}
|
||||
err := ac.Init()
|
||||
if err != nil {
|
||||
logger.Warn("init action failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
appPr := &process.PbAppProcessor{}
|
||||
_ = appPr.Init(ac)
|
||||
h.AddRule(&rule.Rule{Handler: h, Process: appPr, Action: ac})
|
||||
|
||||
pingPr := &process.PbPingProcessor{}
|
||||
_ = appPr.Init(ac)
|
||||
h.AddRule(&rule.Rule{Handler: h, Process: pingPr, Action: ac})
|
||||
|
||||
var conn net.Conn
|
||||
for {
|
||||
conn, err = c.dataLn.Accept()
|
||||
if err != nil {
|
||||
logger.Error("accept connection failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
_, err = h.HandleConn(nil, enet.NewReaderConn(conn))
|
||||
if err != nil {
|
||||
logger.Warn("process failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
c.wg.Done()
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (c *Client) handleControlConn() {
|
||||
err := http.Serve(c.controlLn, c.mux)
|
||||
if err != nil {
|
||||
logger.Error("http error", zap.Error(err))
|
||||
}
|
||||
c.wg.Done()
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
|
||||
c.closeCh <- struct{}{}
|
||||
c.ticker.Stop()
|
||||
_ = c.controlLn.Close()
|
||||
_ = c.dataLn.Close()
|
||||
}
|
||||
}
|
69
component/client/conn.go
Normal file
69
component/client/conn.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"ehang.io/nps/transport"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TunnelCreator interface {
|
||||
NewMux(bridgeAddr string, message *pb.ConnRequest, config *tls.Config) (net.Listener, error)
|
||||
}
|
||||
|
||||
type BaseTunnelCreator struct{}
|
||||
|
||||
func (bc BaseTunnelCreator) handshake(npcInfo *pb.ConnRequest, rw io.ReadWriteCloser) error {
|
||||
_, err := pb.WriteMessage(rw, npcInfo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "write handshake message")
|
||||
}
|
||||
var resp pb.NpcResponse
|
||||
_, err = pb.ReadMessage(rw, &resp)
|
||||
if err != nil || !resp.Success {
|
||||
return errors.Wrap(err, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TcpTunnelCreator struct{ BaseTunnelCreator }
|
||||
|
||||
func (tc TcpTunnelCreator) NewMux(bridgeAddr string, message *pb.ConnRequest, config *tls.Config) (net.Listener, error) {
|
||||
conn, err := tls.Dial("tcp", bridgeAddr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tc.handshake(message, conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server := transport.NewYaMux(conn, nil)
|
||||
return server, server.Server()
|
||||
}
|
||||
|
||||
type QUICTunnelCreator struct{ BaseTunnelCreator }
|
||||
|
||||
func (tc QUICTunnelCreator) NewMux(bridgeAddr string, message *pb.ConnRequest, config *tls.Config) (net.Listener, error) {
|
||||
session, err := quic.DialAddr(bridgeAddr, config, &quic.Config{
|
||||
MaxIncomingStreams: 1000000,
|
||||
MaxIncomingUniStreams: 1000000,
|
||||
MaxIdleTimeout: time.Minute,
|
||||
KeepAlive: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.handshake(message, stream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server := transport.NewQUIC(session)
|
||||
return server, server.Server()
|
||||
}
|
44
component/client/npc.go
Normal file
44
component/client/npc.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pb"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// StartNpc is used to connect to bridge
|
||||
// proto is quic or tcp
|
||||
// tlsConfig must contain a npc cert
|
||||
func StartNpc(proto string, bridgeAddr string, tlsConfig *tls.Config) error {
|
||||
id, err := cert.GetCertSnFromConfig(tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var creator TunnelCreator
|
||||
if proto == "quic" {
|
||||
creator = QUICTunnelCreator{}
|
||||
} else {
|
||||
creator = TcpTunnelCreator{}
|
||||
}
|
||||
connId := uuid.NewV1().String()
|
||||
retry:
|
||||
logger.Info("start connecting to bridge")
|
||||
controlLn, err := creator.NewMux(bridgeAddr,
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, tlsConfig)
|
||||
if err != nil {
|
||||
logger.Error("new control connection error", zap.Error(err))
|
||||
goto retry
|
||||
}
|
||||
dataLn, err := creator.NewMux(bridgeAddr,
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, tlsConfig)
|
||||
if err != nil {
|
||||
logger.Error("new data connection error", zap.Error(err))
|
||||
goto retry
|
||||
}
|
||||
c := NewClient(controlLn, dataLn)
|
||||
c.Run()
|
||||
goto retry
|
||||
}
|
1
component/component.go
Normal file
1
component/component.go
Normal file
@@ -0,0 +1 @@
|
||||
package component
|
193
component/component_test.go
Normal file
193
component/component_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package component
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"ehang.io/nps/component/bridge"
|
||||
"ehang.io/nps/component/client"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var createCertOnce sync.Once
|
||||
|
||||
func createCertFile(t *testing.T) (string, string) {
|
||||
createCertOnce.Do(func() {
|
||||
g := cert.NewX509Generator(pkix.Name{
|
||||
Country: []string{"CN"},
|
||||
Organization: []string{"Ehang.io"},
|
||||
OrganizationalUnit: []string{"nps"},
|
||||
Province: []string{"Beijing"},
|
||||
CommonName: "nps",
|
||||
Locality: []string{"Beijing"},
|
||||
})
|
||||
cert, key, err := g.CreateRootCa()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_cert.pem"), cert, 0600))
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_key.pem"), key, 0600))
|
||||
assert.NoError(t, g.InitRootCa(cert, key))
|
||||
cert, key, err = g.CreateCert("bridge.nps.ehang.io")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_cert.pem"), cert, 0600))
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_key.pem"), key, 0600))
|
||||
cert, key, err = g.CreateCert("client.nps.ehang.io")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_cert.pem"), cert, 0600))
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_key.pem"), key, 0600))
|
||||
})
|
||||
return filepath.Join(os.TempDir(), "cert.pem"), filepath.Join(os.TempDir(), "key.pem")
|
||||
}
|
||||
|
||||
func TestTcpConnect(t *testing.T) {
|
||||
createCertFile(t)
|
||||
bridgeLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
|
||||
assert.NoError(t, err)
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM(buf)
|
||||
|
||||
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
bridgeTlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{crt},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: pool,
|
||||
}
|
||||
|
||||
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
Certificates: []tls.Certificate{crt},
|
||||
}
|
||||
go func() {
|
||||
assert.NoError(t, bridge.StartTcpBridge(bridgeLn, bridgeTlsConfig, func(s string) bool {
|
||||
return true
|
||||
}, func(s string) bool {
|
||||
sn, err := cert.GetCertSnFromConfig(clientConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sn, s)
|
||||
return true
|
||||
}))
|
||||
}()
|
||||
var c *client.Client
|
||||
go func() {
|
||||
id, err := cert.GetCertSnFromConfig(clientConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
creator := client.TcpTunnelCreator{}
|
||||
connId := uuid.NewV1().String()
|
||||
controlLn, err := creator.NewMux(bridgeLn.Addr().String(),
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
|
||||
assert.NoError(t, err)
|
||||
dataLn, err := creator.NewMux(bridgeLn.Addr().String(),
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
|
||||
assert.NoError(t, err)
|
||||
c = client.NewClient(controlLn, dataLn)
|
||||
c.Run()
|
||||
}()
|
||||
timeout := time.NewTimer(time.Second * 30)
|
||||
ticker := time.NewTicker(time.Millisecond * 100)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if c != nil && c.HasPong() {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-timeout.C:
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICConnect(t *testing.T) {
|
||||
createCertFile(t)
|
||||
bridgePacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
|
||||
assert.NoError(t, err)
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM(buf)
|
||||
|
||||
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
bridgeTlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{crt},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: pool,
|
||||
NextProtos: []string{"quic-nps"},
|
||||
}
|
||||
|
||||
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
Certificates: []tls.Certificate{crt},
|
||||
NextProtos: []string{"quic-nps"},
|
||||
}
|
||||
go func() {
|
||||
assert.NoError(t, bridge.StartQUICBridge(bridgePacketConn, bridgeTlsConfig, &quic.Config{
|
||||
MaxIncomingStreams: 1000000,
|
||||
MaxIncomingUniStreams: 1000000,
|
||||
MaxIdleTimeout: time.Minute,
|
||||
KeepAlive: true,
|
||||
}, func(s string) bool {
|
||||
sn, err := cert.GetCertSnFromConfig(clientConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sn, s)
|
||||
return true
|
||||
}))
|
||||
}()
|
||||
var c *client.Client
|
||||
go func() {
|
||||
id, err := cert.GetCertSnFromConfig(clientConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
creator := client.QUICTunnelCreator{}
|
||||
connId := uuid.NewV1().String()
|
||||
controlLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
|
||||
assert.NoError(t, err)
|
||||
dataLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
|
||||
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, dataLn)
|
||||
assert.NotEmpty(t, controlLn)
|
||||
c = client.NewClient(controlLn, dataLn)
|
||||
c.Run()
|
||||
}()
|
||||
timeout := time.NewTimer(time.Second * 30)
|
||||
ticker := time.NewTicker(time.Millisecond * 100)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if c != nil && c.HasPong() {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
case <-timeout.C:
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
1
component/controller/asset
Submodule
1
component/controller/asset
Submodule
Submodule component/controller/asset added at 07c2567751
92
component/controller/cert.go
Normal file
92
component/controller/cert.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"crypto/x509/pkix"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type certServe struct {
|
||||
baseController
|
||||
cg *cert.X509Generator
|
||||
}
|
||||
|
||||
func (cs *certServe) Init(rootCert []byte, rootKey []byte) error {
|
||||
cs.cg = cert.NewX509Generator(pkix.Name{
|
||||
Country: []string{"cn"},
|
||||
Organization: []string{"ehang"},
|
||||
OrganizationalUnit: []string{"nps"},
|
||||
Province: []string{"beijing"},
|
||||
CommonName: "nps",
|
||||
Locality: []string{"beijing"},
|
||||
})
|
||||
return cs.cg.InitRootCa(rootCert, rootKey)
|
||||
}
|
||||
|
||||
type certInfo struct {
|
||||
Name string `json:"name"`
|
||||
Uuid string `json:"uuid"`
|
||||
CertType string `json:"cert_type"`
|
||||
Cert string `json:"cert"`
|
||||
Key string `json:"key"`
|
||||
Sn string `json:"sn"`
|
||||
Remark string `json:"remark"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// Create
|
||||
// Cert type root|bridge|server|client|secret
|
||||
func (cs *certServe) Create(c *gin.Context) {
|
||||
var ci certInfo
|
||||
err := c.BindJSON(&ci)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
crt, key, err := cs.cg.CreateCert(fmt.Sprintf("%s.nps.ehang.io", ci.CertType))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
sn, err := cert.GetCertSnFromEncode(crt)
|
||||
ci.Cert, ci.Key, ci.Sn, ci.Uuid = string(crt), string(key), sn, uuid.NewV4().String()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
b, err := json.Marshal(ci)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
err = cs.db.Insert("cert", ci.Uuid, string(b))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
||||
}
|
||||
|
||||
func (cs *certServe) Update(c *gin.Context) {
|
||||
var ci certInfo
|
||||
err := c.BindJSON(&ci)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
b, err := json.Marshal(ci)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
err = cs.db.Update(cs.tableName, ci.Uuid, string(b))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
||||
}
|
50
component/controller/config.go
Normal file
50
component/controller/config.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"ehang.io/nps/db"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type configServer struct {
|
||||
db db.Db
|
||||
}
|
||||
|
||||
func (cs *configServer) ChangeSystemConfig(c *gin.Context) {
|
||||
type config struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
NewUsername string `json:"new_username"`
|
||||
}
|
||||
var cfg config
|
||||
err := c.BindJSON(&cfg)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
oldPassword, err := cs.db.GetConfig("admin_pass")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
fmt.Println(cfg, oldPassword)
|
||||
if cfg.OldPassword != oldPassword {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": "old password does not match"})
|
||||
return
|
||||
}
|
||||
if err := cs.db.SetConfig("admin_pass", cfg.NewPassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if cfg.NewUsername != "" {
|
||||
if err := cs.db.SetConfig("admin_pass", cfg.NewUsername); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
})
|
||||
}
|
101
component/controller/controller.go
Normal file
101
component/controller/controller.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"ehang.io/nps/db"
|
||||
"ehang.io/nps/lib/logger"
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func StartController(ln net.Listener, db db.Db, rootCert []byte, rootKey []byte, staticRootPath string, pagePath string) error {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
cfgServer := &configServer{db: db}
|
||||
crtServer := &certServe{baseController: baseController{db: db, tableName: "cert"}}
|
||||
err := crtServer.Init(rootCert, rootKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ruleServer := &ruleServer{baseController: baseController{db: db, tableName: "rule"}}
|
||||
|
||||
authMiddleware, err := newAuthMiddleware(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(CORSMiddleware(), gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
logger.Debug("http request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("proto", param.Request.Proto),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.String("error_message", param.ErrorMessage),
|
||||
zap.Int("response_code", param.StatusCode),
|
||||
)
|
||||
return ""
|
||||
}))
|
||||
router.POST("/login", authMiddleware.LoginHandler)
|
||||
|
||||
router.NoRoute(authMiddleware.MiddlewareFunc(), func(c *gin.Context) {
|
||||
claims := jwt.ExtractClaims(c)
|
||||
logger.Warn("NoRoute", zap.Any("claims", claims))
|
||||
c.JSON(404, gin.H{"code": "PAGE_NOT_FOUND", "message": "Page not found"})
|
||||
})
|
||||
|
||||
auth := router.Group("/auth")
|
||||
auth.Use(authMiddleware.MiddlewareFunc())
|
||||
auth.GET("/refresh_token", authMiddleware.RefreshHandler)
|
||||
auth.GET("/userinfo", userinfo)
|
||||
|
||||
v1 := router.Group("v1")
|
||||
v1.Use(authMiddleware.MiddlewareFunc())
|
||||
{
|
||||
v1.PUT("/config", cfgServer.ChangeSystemConfig)
|
||||
|
||||
v1.GET("/status", status)
|
||||
|
||||
v1.POST("/cert", crtServer.Create)
|
||||
v1.DELETE("/cert", crtServer.Delete)
|
||||
v1.GET("/cert/page", crtServer.Page)
|
||||
v1.PUT("/cert", crtServer.Update)
|
||||
|
||||
v1.POST("/rule", ruleServer.Create)
|
||||
v1.DELETE("/rule", ruleServer.Delete)
|
||||
v1.PUT("/rule", ruleServer.Update)
|
||||
v1.GET("/rule", ruleServer.One)
|
||||
v1.GET("/rule/page", ruleServer.Page)
|
||||
v1.GET("/rule/field", ruleServer.Field)
|
||||
v1.GET("/rule/limiter", ruleServer.Limiter)
|
||||
}
|
||||
router.Static("/static/", staticRootPath)
|
||||
router.Static("/page/", pagePath)
|
||||
|
||||
go storeSystemStatus()
|
||||
|
||||
return router.RunListener(ln)
|
||||
}
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8",
|
||||
[]byte(`{"code":0,"result":{"userId":"1","username":"vben","realName":"Vben Admin","avatar":"https://q1.qlogo.cn/g?b=qq&nk=190848757&s=640","desc":"manager","password":"123456","token":"fakeToken1","homePath":"/dashboard/analysis","roles":[{"roleName":"Super Admin","value":"super"}]},"message":"ok","type":"success"}`))
|
||||
}
|
149
component/controller/controller_test.go
Normal file
149
component/controller/controller_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509/pkix"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/rule"
|
||||
"ehang.io/nps/core/server"
|
||||
"ehang.io/nps/db"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestController(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:3500")
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(filepath.Join(os.TempDir(), "test_control.db"))
|
||||
d := db.NewSqliteDb(filepath.Join(os.TempDir(), "test_control.db"))
|
||||
err = d.Init()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, d.SetConfig("admin_user", "admin"))
|
||||
assert.NoError(t, d.SetConfig("admin_pass", "pass"))
|
||||
cg := cert.NewX509Generator(pkix.Name{
|
||||
Country: []string{"cn"},
|
||||
Organization: []string{"ehang"},
|
||||
OrganizationalUnit: []string{"nps"},
|
||||
Province: []string{"beijing"},
|
||||
CommonName: "nps",
|
||||
Locality: []string{"beijing"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
cert, key, err := cg.CreateRootCa()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
err = StartController(ln, d, cert, key, "./web/static/", "./web/views/")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
resp, err := doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/login"), "POST", `{"username": "admin","password": "pass"}`)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
|
||||
|
||||
for i := 0; i < 18; i++ {
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/cert"), "POST", fmt.Sprintf(`{"status":1,"name":"name_%d","cert_type": "client"}`, i))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
|
||||
}
|
||||
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s?page=%d&pageSize=%d", ln.Addr().String(), "/v1/cert/page", 4, 5), "GET", ``)
|
||||
assert.NoError(t, err)
|
||||
now := 2
|
||||
var lastUuid string
|
||||
assert.Equal(t, len(gjson.Parse(resp).Get("result.items").Array()), 3)
|
||||
for _, v := range gjson.Parse(resp).Get("result.items").Array() {
|
||||
assert.Equal(t, v.Get("name").String(), fmt.Sprintf(`name_%d`, now))
|
||||
lastUuid = v.Get("uuid").String()
|
||||
now--
|
||||
}
|
||||
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/cert"), "DELETE", fmt.Sprintf(`{"uuid":"%s"}`, lastUuid))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
|
||||
|
||||
s := &server.TcpServer{ServerAddr: "127.0.0.1:0"}
|
||||
h := &handler.DefaultHandler{}
|
||||
p := &process.DefaultProcess{}
|
||||
a := &action.LocalAction{}
|
||||
rj := &rule.JsonRule{
|
||||
Name: "test",
|
||||
Status: 1,
|
||||
Server: rule.JsonData{ObjType: s.GetName(), ObjData: getJson(t, s)},
|
||||
Handler: rule.JsonData{ObjType: h.GetName(), ObjData: getJson(t, h)},
|
||||
Process: rule.JsonData{ObjType: p.GetName(), ObjData: getJson(t, p)},
|
||||
Action: rule.JsonData{ObjType: a.GetName(), ObjData: getJson(t, a)},
|
||||
Limiters: []rule.JsonData{},
|
||||
}
|
||||
js := getJson(t, rj)
|
||||
for i := 0; i < 18; i++ {
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "POST", js)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
|
||||
}
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s?page=%d&pageSize=%d", ln.Addr().String(), "/v1/rule/page", 1, 10), "GET", ``)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("result.total").Int()), 18)
|
||||
|
||||
uuid := gjson.Parse(resp).Get("result.items").Array()[0].Get("uuid").String()
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "GET", fmt.Sprintf(`{"uuid": "%s"}`, uuid))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, gjson.Parse(resp).Get("result.uuid").String(), uuid)
|
||||
|
||||
rj.Uuid = uuid
|
||||
resp, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/v1/rule"), "PUT", getJson(t, rj))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int(gjson.Parse(resp).Get("code").Int()), 0)
|
||||
time.Sleep(time.Minute * 600)
|
||||
}
|
||||
|
||||
func getJson(t *testing.T, i interface{}) string {
|
||||
b, err := json.Marshal(i)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, string(b))
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
var once sync.Once
|
||||
var cookies []*http.Cookie
|
||||
|
||||
func doRequest(url string, method string, body string) (string, error) {
|
||||
once.Do(func() {
|
||||
client = &http.Client{}
|
||||
})
|
||||
payload := bytes.NewBufferString(body)
|
||||
req, err := http.NewRequest(method, url, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
if len(res.Cookies()) > 0 {
|
||||
cookies = res.Cookies()
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
return string(b), errors.New("bad doRequest")
|
||||
}
|
||||
return string(b), err
|
||||
}
|
105
component/controller/jwt.go
Normal file
105
component/controller/jwt.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"ehang.io/nps/db"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type login struct {
|
||||
Username string `form:"username" json:"username" binding:"required"`
|
||||
Password string `form:"password" json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
var identityKey = "id"
|
||||
|
||||
type User struct {
|
||||
UserName string
|
||||
}
|
||||
|
||||
func newAuthMiddleware(db db.Db) (authMiddleware *jwt.GinJWTMiddleware, err error) {
|
||||
authMiddleware, err = jwt.New(&jwt.GinJWTMiddleware{
|
||||
Realm: "nps",
|
||||
Key: []byte("secret key"),
|
||||
Timeout: time.Hour * 24,
|
||||
MaxRefresh: time.Hour * 72,
|
||||
IdentityKey: identityKey,
|
||||
SendCookie: true,
|
||||
LoginResponse: func(c *gin.Context, code int, message string, time time.Time) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"result": gin.H{
|
||||
"token": message,
|
||||
},
|
||||
"message": "ok",
|
||||
})
|
||||
},
|
||||
PayloadFunc: func(data interface{}) jwt.MapClaims {
|
||||
if v, ok := data.(*User); ok {
|
||||
return jwt.MapClaims{
|
||||
identityKey: v.UserName,
|
||||
}
|
||||
}
|
||||
return jwt.MapClaims{}
|
||||
},
|
||||
IdentityHandler: func(c *gin.Context) interface{} {
|
||||
claims := jwt.ExtractClaims(c)
|
||||
return &User{
|
||||
UserName: claims[identityKey].(string),
|
||||
}
|
||||
},
|
||||
Authenticator: func(c *gin.Context) (interface{}, error) {
|
||||
var loginVals login
|
||||
if err := c.ShouldBind(&loginVals); err != nil {
|
||||
return "", jwt.ErrMissingLoginValues
|
||||
}
|
||||
userID := loginVals.Username
|
||||
password := loginVals.Password
|
||||
adminUser, err := db.GetConfig("admin_user")
|
||||
if err != nil {
|
||||
return "", jwt.ErrFailedAuthentication
|
||||
}
|
||||
adminPass, err := db.GetConfig("admin_pass")
|
||||
if err != nil {
|
||||
return "", jwt.ErrFailedAuthentication
|
||||
}
|
||||
if userID == adminUser && password == adminPass {
|
||||
return &User{
|
||||
UserName: userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
},
|
||||
Authorizator: func(data interface{}, c *gin.Context) bool {
|
||||
adminUser, err := db.GetConfig("admin_user")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := data.(*User); ok && v.UserName ==adminUser {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
Unauthorized: func(c *gin.Context, code int, message string) {
|
||||
c.JSON(code, gin.H{
|
||||
"code": code,
|
||||
"message": message,
|
||||
})
|
||||
},
|
||||
TokenLookup: "header: Authorization, query: token, cookie: jwt",
|
||||
TokenHeadName: "Bearer",
|
||||
TimeFunc: time.Now,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = authMiddleware.MiddlewareInit()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
151
component/controller/rule.go
Normal file
151
component/controller/rule.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/rule"
|
||||
"ehang.io/nps/db"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type baseController struct {
|
||||
db db.Db
|
||||
tableName string
|
||||
}
|
||||
|
||||
func (bc *baseController) Page(c *gin.Context) {
|
||||
page, err := strconv.Atoi(c.Query("page"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
pageSize, err := strconv.Atoi(c.Query("pageSize"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
dataArr, err := bc.db.QueryPage(bc.tableName, pageSize, (page-1)*pageSize, c.Query("key"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
list := make([]map[string]interface{}, 0)
|
||||
for _, s := range dataArr {
|
||||
dd := make(map[string]interface{}, 0)
|
||||
err = json.Unmarshal([]byte(s), &dd)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
list = append(list, dd)
|
||||
}
|
||||
n, err := bc.db.Count(bc.tableName, c.Query("key"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"result": gin.H{
|
||||
"items": list,
|
||||
"total": n,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (bc *baseController) Delete(c *gin.Context) {
|
||||
type uid struct {
|
||||
Uuid string
|
||||
}
|
||||
var js uid
|
||||
err := c.BindJSON(&js)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
err = bc.db.Delete(bc.tableName, js.Uuid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
||||
}
|
||||
|
||||
type ruleServer struct {
|
||||
baseController
|
||||
}
|
||||
|
||||
func (rs *ruleServer) Create(c *gin.Context) {
|
||||
var jr rule.JsonRule
|
||||
err := c.BindJSON(&jr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
jr.Uuid = uuid.NewV4().String()
|
||||
b, err := json.Marshal(jr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
err = rs.db.Insert(rs.tableName, jr.Uuid, string(b))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
||||
}
|
||||
|
||||
func (rs *ruleServer) Update(c *gin.Context) {
|
||||
var jr rule.JsonRule
|
||||
err := c.BindJSON(&jr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
b, err := json.Marshal(jr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
err = rs.db.Update(rs.tableName, jr.Uuid, string(b))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
||||
}
|
||||
|
||||
func (rs *ruleServer) One(c *gin.Context) {
|
||||
var js map[string]string
|
||||
err := c.BindJSON(&js)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
s, err := rs.db.QueryOne("rule", js["uuid"])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var r rule.JsonRule
|
||||
err = json.Unmarshal([]byte(s), &r)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 1, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "result": r, "message": "ok"})
|
||||
}
|
||||
|
||||
func (rs *ruleServer) Field(c *gin.Context) {
|
||||
chains := rule.GetChains()
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "result": chains, "message": "ok"})
|
||||
}
|
||||
|
||||
func (rs *ruleServer) Limiter(c *gin.Context) {
|
||||
chains := rule.GetLimiters()
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "result": chains, "message": "ok"})
|
||||
}
|
152
component/controller/status.go
Normal file
152
component/controller/status.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/disk"
|
||||
"github.com/shirou/gopsutil/load"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
"github.com/shirou/gopsutil/net"
|
||||
"math"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
tData = list.New()
|
||||
)
|
||||
|
||||
type timeData struct {
|
||||
now time.Time
|
||||
cpuData float64
|
||||
load1Data float64
|
||||
load5Data float64
|
||||
load15Data float64
|
||||
swapData float64
|
||||
virtualData float64
|
||||
bandwidthRecvData float64
|
||||
bandwidthSendData float64
|
||||
tcpConnNumData float64
|
||||
udpConnNumData float64
|
||||
diskData float64
|
||||
}
|
||||
|
||||
type dataAddr []float64
|
||||
|
||||
func status(c *gin.Context) {
|
||||
timeArr := make([]string, 0)
|
||||
dataMap := make(map[string][]float64, 0)
|
||||
dataMap["cpu"] = make([]float64, 0)
|
||||
dataMap["load1"] = make([]float64, 0)
|
||||
dataMap["load5"] = make([]float64, 0)
|
||||
dataMap["load15"] = make([]float64, 0)
|
||||
dataMap["swap"] = make([]float64, 0)
|
||||
dataMap["virtual"] = make([]float64, 0)
|
||||
dataMap["bandwidthRecvData"] = make([]float64, 0)
|
||||
dataMap["bandwidthSendData"] = make([]float64, 0)
|
||||
dataMap["tcpConnNumData"] = make([]float64, 0)
|
||||
dataMap["udpConnNumData"] = make([]float64, 0)
|
||||
dataMap["disk"] = make([]float64, 0)
|
||||
now := tData.Front()
|
||||
for {
|
||||
if now == nil {
|
||||
break
|
||||
}
|
||||
data := now.Value.(*timeData)
|
||||
timeArr = append(timeArr, data.now.Format("01-02 15:04"))
|
||||
dataMap["cpu"] = append(dataMap["cpu"], data.cpuData)
|
||||
dataMap["load1"] = append(dataMap["load1"], data.load15Data)
|
||||
dataMap["load5"] = append(dataMap["load5"], data.load5Data)
|
||||
dataMap["load15"] = append(dataMap["load15"], data.load15Data)
|
||||
dataMap["swap"] = append(dataMap["swap"], data.swapData)
|
||||
dataMap["virtual"] = append(dataMap["virtual"], data.virtualData)
|
||||
dataMap["bandwidthSend"] = append(dataMap["bandwidthSend"], data.bandwidthRecvData)
|
||||
dataMap["bandwidthRecv"] = append(dataMap["bandwidthRecv"], data.bandwidthSendData)
|
||||
dataMap["tcp"] = append(dataMap["v"], data.tcpConnNumData)
|
||||
dataMap["udp"] = append(dataMap["udp"], data.udpConnNumData)
|
||||
dataMap["disk"] = append(dataMap["disk"], data.diskData)
|
||||
now = now.Next()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"result": gin.H{
|
||||
"time": timeArr,
|
||||
"data": dataMap,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func storeSystemStatus() {
|
||||
path := "/"
|
||||
if runtime.GOOS == "windows" {
|
||||
path = "C:"
|
||||
}
|
||||
|
||||
for range time.NewTicker(time.Second).C {
|
||||
td := &timeData{now: time.Now()}
|
||||
checkListLen(tData)
|
||||
cpuPercent, err := cpu.Percent(0, true)
|
||||
if err == nil {
|
||||
var cpuAll float64
|
||||
for _, v := range cpuPercent {
|
||||
cpuAll += v
|
||||
}
|
||||
td.cpuData = float64(len(cpuPercent))
|
||||
}
|
||||
|
||||
loads, err := load.Avg()
|
||||
if err == nil {
|
||||
td.load1Data = loads.Load1
|
||||
td.load1Data = loads.Load5
|
||||
td.load15Data = loads.Load15
|
||||
}
|
||||
|
||||
swap, err := mem.SwapMemory()
|
||||
if err == nil {
|
||||
td.swapData = math.Round(swap.UsedPercent)
|
||||
}
|
||||
vir, err := mem.VirtualMemory()
|
||||
if err == nil {
|
||||
td.virtualData = math.Round(vir.UsedPercent)
|
||||
}
|
||||
io1, err := net.IOCounters(false)
|
||||
if err == nil {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
io2, err := net.IOCounters(false)
|
||||
if err == nil && len(io2) > 0 && len(io1) > 0 {
|
||||
td.bandwidthRecvData = float64((io2[0].BytesRecv-io1[0].BytesRecv)*2) / 1024 / 1024
|
||||
td.bandwidthSendData = float64((io2[0].BytesSent-io1[0].BytesSent)*2) / 1024 / 1024
|
||||
}
|
||||
}
|
||||
conn, err := net.ProtoCounters(nil)
|
||||
if err == nil {
|
||||
for _, v := range conn {
|
||||
if v.Protocol == "tcp" {
|
||||
td.tcpConnNumData = float64(v.Stats["CurrEstab"])
|
||||
}
|
||||
if v.Protocol == "udp" {
|
||||
td.udpConnNumData = float64(v.Stats["CurrEstab"])
|
||||
}
|
||||
}
|
||||
}
|
||||
usage, err := disk.Usage(path)
|
||||
if err == nil {
|
||||
td.diskData = math.Round(usage.UsedPercent)
|
||||
}
|
||||
tData.PushBack(td)
|
||||
}
|
||||
}
|
||||
|
||||
func checkListLen(lists ...*list.List) {
|
||||
for _, l := range lists {
|
||||
if l.Len() > 4320 {
|
||||
if first := l.Front(); first != nil {
|
||||
l.Remove(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
11
component/service/service.go
Normal file
11
component/service/service.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
import "net"
|
||||
|
||||
type HttpService struct {
|
||||
ln net.Listener
|
||||
}
|
||||
|
||||
func NewHttpService(ln net.Listener) *HttpService {
|
||||
return &HttpService{ln: ln}
|
||||
}
|
Reference in New Issue
Block a user