mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-08 00:26:52 +00:00
add new file
This commit is contained in:
110
core/action/action.go
Normal file
110
core/action/action.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"errors"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bp = pool.NewBufferPool(MaxReadSize)
|
||||
|
||||
const MaxReadSize = 32 * 1024
|
||||
|
||||
var (
|
||||
_ Action = (*AdminAction)(nil)
|
||||
_ Action = (*BridgeAction)(nil)
|
||||
_ Action = (*LocalAction)(nil)
|
||||
_ Action = (*NpcAction)(nil)
|
||||
)
|
||||
|
||||
type Action interface {
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
Init() error
|
||||
RunConnWithAddr(net.Conn, string) error
|
||||
RunConn(net.Conn) error
|
||||
GetServeConnWithAddr(string) (net.Conn, error)
|
||||
GetServerConn() (net.Conn, error)
|
||||
CanServe() bool
|
||||
RunPacketConn(conn net.PacketConn) error
|
||||
}
|
||||
|
||||
type DefaultAction struct {
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetName() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetZhName() string {
|
||||
return "默认"
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) RunConn(clientConn net.Conn) error {
|
||||
return errors.New("not supported")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) CanServe() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) RunPacketConn(conn net.PacketConn) error {
|
||||
return errors.New("not supported")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetServerConn() (net.Conn, error) {
|
||||
return nil, errors.New("can not get component connection")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return nil, errors.New("can not get component connection")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) startCopy(c1 net.Conn, c2 net.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
err := pool.CopyConnGoroutinePool.Invoke(&pool.CopyConnGpParams{
|
||||
Reader: c2,
|
||||
Writer: c1,
|
||||
Wg: &wg,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Invoke goroutine failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
buf := bp.Get()
|
||||
_, _ = common.CopyBuffer(c2, c1, buf)
|
||||
bp.Put(buf)
|
||||
if v, ok := c1.(*net.TCPConn); ok {
|
||||
_ = v.CloseRead()
|
||||
}
|
||||
if v, ok := c2.(*net.TCPConn); ok {
|
||||
_ = v.CloseWrite()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) startCopyPacketConn(p1 net.PacketConn, p2 net.PacketConn) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
_ = pool.CopyPacketGoroutinePool.Invoke(&pool.CopyPacketGpParams{
|
||||
RPacket: p1,
|
||||
WPacket: p2,
|
||||
Wg: &wg,
|
||||
})
|
||||
_ = pool.CopyPacketGoroutinePool.Invoke(&pool.CopyPacketGpParams{
|
||||
RPacket: p2,
|
||||
WPacket: p1,
|
||||
Wg: &wg,
|
||||
})
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
32
core/action/admin.go
Normal file
32
core/action/admin.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"net"
|
||||
)
|
||||
|
||||
var adminListener = enet.NewListener()
|
||||
|
||||
func GetAdminListener() net.Listener {
|
||||
return adminListener
|
||||
}
|
||||
|
||||
type AdminAction struct {
|
||||
DefaultAction
|
||||
}
|
||||
|
||||
func (la *AdminAction) GetName() string {
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func (la *AdminAction) GetZhName() string {
|
||||
return "转发到控制台"
|
||||
}
|
||||
|
||||
func (la *AdminAction) RunConn(clientConn net.Conn) error {
|
||||
return adminListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (la *AdminAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
return adminListener.SendConn(clientConn)
|
||||
}
|
29
core/action/admin_test.go
Normal file
29
core/action/admin_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdminRunConn(t *testing.T) {
|
||||
ac := &AdminAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
}
|
||||
finish := make(chan struct{}, 0)
|
||||
go func() {
|
||||
_, err := GetAdminListener().Accept()
|
||||
assert.NoError(t, err)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, ac.RunConn(conn))
|
||||
}()
|
||||
_, err = net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
<-finish
|
||||
}
|
61
core/action/bridge.go
Normal file
61
core/action/bridge.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"net"
|
||||
)
|
||||
|
||||
var bridgeListener = enet.NewListener()
|
||||
var bridgePacketConn enet.PacketConn
|
||||
var packetBp = pool.NewBufferPool(1500)
|
||||
|
||||
func GetBridgeListener() net.Listener {
|
||||
return bridgeListener
|
||||
}
|
||||
|
||||
func GetBridgePacketConn() net.PacketConn {
|
||||
return bridgePacketConn
|
||||
}
|
||||
|
||||
type BridgeAction struct {
|
||||
DefaultAction
|
||||
WritePacketConn net.PacketConn `json:"-"`
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) GetName() string {
|
||||
return "bridge"
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) GetZhName() string {
|
||||
return "转发到网桥"
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) Init() error {
|
||||
bridgePacketConn = enet.NewReaderPacketConn(ba.WritePacketConn, nil, ba.WritePacketConn.LocalAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunConn(clientConn net.Conn) error {
|
||||
return bridgeListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
return bridgeListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunPacketConn(pc net.PacketConn) error {
|
||||
b := packetBp.Get()
|
||||
defer packetBp.Put(b)
|
||||
for {
|
||||
n, addr, err := pc.ReadFrom(b)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = bridgePacketConn.SendPacket(b[:n], addr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
73
core/action/bridge_test.go
Normal file
73
core/action/bridge_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBridgeRunConn(t *testing.T) {
|
||||
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
ac := &BridgeAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
WritePacketConn: packetConn,
|
||||
}
|
||||
finish := make(chan struct{}, 0)
|
||||
go func() {
|
||||
_, err := GetBridgeListener().Accept()
|
||||
assert.NoError(t, err)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, ac.RunConn(conn))
|
||||
}()
|
||||
_, err = net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
<-finish
|
||||
}
|
||||
|
||||
func TestBridgeRunPacket(t *testing.T) {
|
||||
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
ac := &BridgeAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
WritePacketConn: packetConn,
|
||||
}
|
||||
assert.NoError(t, ac.Init())
|
||||
go func() {
|
||||
p := make([]byte, 1024)
|
||||
pc := GetBridgePacketConn()
|
||||
n, addr, err := pc.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
_, err = pc.WriteTo(p[:n], addr)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := packetConn.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
bPacketConn := enet.NewReaderPacketConn(packetConn, p[:n], addr)
|
||||
go func() {
|
||||
err = ac.RunPacketConn(bPacketConn)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
err = bPacketConn.SendPacket(p[:n], addr)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
cPacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
b := []byte("12345")
|
||||
_, err = cPacketConn.WriteTo(b, packetConn.LocalAddr())
|
||||
assert.NoError(t, err)
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := cPacketConn.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, addr.String(), packetConn.LocalAddr().String())
|
||||
assert.Equal(t, p[:n], b)
|
||||
}
|
77
core/action/local.go
Normal file
77
core/action/local.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/lb"
|
||||
"net"
|
||||
)
|
||||
|
||||
type LocalAction struct {
|
||||
DefaultAction
|
||||
TargetAddr []string `json:"target_addr" placeholder:"1.1.1.1:80\n1.1.1.2:80" zh_name:"目标地址"`
|
||||
UnixSocket bool `json:"unix_sock" placeholder:"" zh_name:"转发到unix socket"`
|
||||
networkTcp string
|
||||
localLb lb.Algo
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetName() string {
|
||||
return "local"
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetZhName() string {
|
||||
return "转发到本地"
|
||||
}
|
||||
|
||||
func (la *LocalAction) Init() error {
|
||||
la.localLb = lb.GetLbAlgo("roundRobin")
|
||||
for _, v := range la.TargetAddr {
|
||||
_ = la.localLb.Append(v)
|
||||
}
|
||||
la.networkTcp = "tcp"
|
||||
if la.UnixSocket {
|
||||
// just support unix
|
||||
la.networkTcp = "unix"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunConn(clientConn net.Conn) error {
|
||||
serverConn, err := la.GetServerConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
la.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
serverConn, err := la.GetServeConnWithAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
la.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) CanServe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetServerConn() (net.Conn, error) {
|
||||
addr, err := la.localLb.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return la.GetServeConnWithAddr(addr.(string))
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return net.Dial(la.networkTcp, addr)
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunPacketConn(pc net.PacketConn) error {
|
||||
localPacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return la.startCopyPacketConn(pc, localPacketConn)
|
||||
}
|
1
core/action/local_test.go
Normal file
1
core/action/local_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package action
|
84
core/action/npc.go
Normal file
84
core/action/npc.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type NpcAction struct {
|
||||
NpcId string `json:"npc_id" required:"true" placeholder:"npc id" zh_name:"客户端"`
|
||||
BridgeAddr string `json:"bridge_addr" placeholder:"127.0.0.1:8080" zh_name:"网桥地址"`
|
||||
UnixSocket bool `json:"unix_sock" placeholder:"" zh_name:"转发到unix socket"`
|
||||
networkTcp pb.ConnType
|
||||
tlsConfig *tls.Config
|
||||
connRequest *pb.ConnRequest
|
||||
DefaultAction
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetName() string {
|
||||
return "npc"
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetZhName() string {
|
||||
return "转发到客户端"
|
||||
}
|
||||
|
||||
func (na *NpcAction) Init() error {
|
||||
if na.tlsConfig == nil {
|
||||
return errors.New("tls config is nil")
|
||||
}
|
||||
sn, err := cert.GetCertSnFromConfig(na.tlsConfig)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get serial number")
|
||||
}
|
||||
na.connRequest = &pb.ConnRequest{Id: sn}
|
||||
na.networkTcp = pb.ConnType_tcp
|
||||
if na.UnixSocket {
|
||||
// just support unix
|
||||
na.networkTcp = pb.ConnType_unix
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (na *NpcAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
serverConn, err := na.GetServeConnWithAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
na.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (na *NpcAction) CanServe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return dialBridge(na, na.networkTcp, addr)
|
||||
}
|
||||
|
||||
func (na *NpcAction) RunPacketConn(pc net.PacketConn) error {
|
||||
serverPacketConn, err := dialBridge(na, pb.ConnType_udp, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return na.startCopyPacketConn(pc, enet.NewTcpPacketConn(serverPacketConn))
|
||||
}
|
||||
|
||||
func dialBridge(npc *NpcAction, connType pb.ConnType, addr string) (net.Conn, error) {
|
||||
tlsConn, err := tls.Dial("tcp", npc.BridgeAddr, npc.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "dial bridge tls")
|
||||
}
|
||||
cr := proto.Clone(npc.connRequest).(*pb.ConnRequest)
|
||||
cr.ConnType = &pb.ConnRequest_AppInfo{AppInfo: &pb.AppInfo{ConnType: connType, AppAddr: addr, NpcId: npc.NpcId}}
|
||||
if _, err = pb.WriteMessage(tlsConn, cr); err != nil {
|
||||
return nil, errors.Wrap(err, "write enet request")
|
||||
}
|
||||
return tlsConn, err
|
||||
}
|
1
core/action/npc_test.go
Normal file
1
core/action/npc_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package action
|
Reference in New Issue
Block a user