mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-06 23:26:52 +00:00
add new file
This commit is contained in:
110
core/action/action.go
Normal file
110
core/action/action.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"errors"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bp = pool.NewBufferPool(MaxReadSize)
|
||||
|
||||
const MaxReadSize = 32 * 1024
|
||||
|
||||
var (
|
||||
_ Action = (*AdminAction)(nil)
|
||||
_ Action = (*BridgeAction)(nil)
|
||||
_ Action = (*LocalAction)(nil)
|
||||
_ Action = (*NpcAction)(nil)
|
||||
)
|
||||
|
||||
type Action interface {
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
Init() error
|
||||
RunConnWithAddr(net.Conn, string) error
|
||||
RunConn(net.Conn) error
|
||||
GetServeConnWithAddr(string) (net.Conn, error)
|
||||
GetServerConn() (net.Conn, error)
|
||||
CanServe() bool
|
||||
RunPacketConn(conn net.PacketConn) error
|
||||
}
|
||||
|
||||
type DefaultAction struct {
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetName() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetZhName() string {
|
||||
return "默认"
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) RunConn(clientConn net.Conn) error {
|
||||
return errors.New("not supported")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) CanServe() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) RunPacketConn(conn net.PacketConn) error {
|
||||
return errors.New("not supported")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetServerConn() (net.Conn, error) {
|
||||
return nil, errors.New("can not get component connection")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return nil, errors.New("can not get component connection")
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) startCopy(c1 net.Conn, c2 net.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
err := pool.CopyConnGoroutinePool.Invoke(&pool.CopyConnGpParams{
|
||||
Reader: c2,
|
||||
Writer: c1,
|
||||
Wg: &wg,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Invoke goroutine failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
buf := bp.Get()
|
||||
_, _ = common.CopyBuffer(c2, c1, buf)
|
||||
bp.Put(buf)
|
||||
if v, ok := c1.(*net.TCPConn); ok {
|
||||
_ = v.CloseRead()
|
||||
}
|
||||
if v, ok := c2.(*net.TCPConn); ok {
|
||||
_ = v.CloseWrite()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (ba *DefaultAction) startCopyPacketConn(p1 net.PacketConn, p2 net.PacketConn) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
_ = pool.CopyPacketGoroutinePool.Invoke(&pool.CopyPacketGpParams{
|
||||
RPacket: p1,
|
||||
WPacket: p2,
|
||||
Wg: &wg,
|
||||
})
|
||||
_ = pool.CopyPacketGoroutinePool.Invoke(&pool.CopyPacketGpParams{
|
||||
RPacket: p2,
|
||||
WPacket: p1,
|
||||
Wg: &wg,
|
||||
})
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
32
core/action/admin.go
Normal file
32
core/action/admin.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"net"
|
||||
)
|
||||
|
||||
var adminListener = enet.NewListener()
|
||||
|
||||
func GetAdminListener() net.Listener {
|
||||
return adminListener
|
||||
}
|
||||
|
||||
type AdminAction struct {
|
||||
DefaultAction
|
||||
}
|
||||
|
||||
func (la *AdminAction) GetName() string {
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func (la *AdminAction) GetZhName() string {
|
||||
return "转发到控制台"
|
||||
}
|
||||
|
||||
func (la *AdminAction) RunConn(clientConn net.Conn) error {
|
||||
return adminListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (la *AdminAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
return adminListener.SendConn(clientConn)
|
||||
}
|
29
core/action/admin_test.go
Normal file
29
core/action/admin_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdminRunConn(t *testing.T) {
|
||||
ac := &AdminAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
}
|
||||
finish := make(chan struct{}, 0)
|
||||
go func() {
|
||||
_, err := GetAdminListener().Accept()
|
||||
assert.NoError(t, err)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, ac.RunConn(conn))
|
||||
}()
|
||||
_, err = net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
<-finish
|
||||
}
|
61
core/action/bridge.go
Normal file
61
core/action/bridge.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"net"
|
||||
)
|
||||
|
||||
var bridgeListener = enet.NewListener()
|
||||
var bridgePacketConn enet.PacketConn
|
||||
var packetBp = pool.NewBufferPool(1500)
|
||||
|
||||
func GetBridgeListener() net.Listener {
|
||||
return bridgeListener
|
||||
}
|
||||
|
||||
func GetBridgePacketConn() net.PacketConn {
|
||||
return bridgePacketConn
|
||||
}
|
||||
|
||||
type BridgeAction struct {
|
||||
DefaultAction
|
||||
WritePacketConn net.PacketConn `json:"-"`
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) GetName() string {
|
||||
return "bridge"
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) GetZhName() string {
|
||||
return "转发到网桥"
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) Init() error {
|
||||
bridgePacketConn = enet.NewReaderPacketConn(ba.WritePacketConn, nil, ba.WritePacketConn.LocalAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunConn(clientConn net.Conn) error {
|
||||
return bridgeListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
return bridgeListener.SendConn(clientConn)
|
||||
}
|
||||
|
||||
func (ba *BridgeAction) RunPacketConn(pc net.PacketConn) error {
|
||||
b := packetBp.Get()
|
||||
defer packetBp.Put(b)
|
||||
for {
|
||||
n, addr, err := pc.ReadFrom(b)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = bridgePacketConn.SendPacket(b[:n], addr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
73
core/action/bridge_test.go
Normal file
73
core/action/bridge_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBridgeRunConn(t *testing.T) {
|
||||
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
ac := &BridgeAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
WritePacketConn: packetConn,
|
||||
}
|
||||
finish := make(chan struct{}, 0)
|
||||
go func() {
|
||||
_, err := GetBridgeListener().Accept()
|
||||
assert.NoError(t, err)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, ac.RunConn(conn))
|
||||
}()
|
||||
_, err = net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
<-finish
|
||||
}
|
||||
|
||||
func TestBridgeRunPacket(t *testing.T) {
|
||||
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
ac := &BridgeAction{
|
||||
DefaultAction: DefaultAction{},
|
||||
WritePacketConn: packetConn,
|
||||
}
|
||||
assert.NoError(t, ac.Init())
|
||||
go func() {
|
||||
p := make([]byte, 1024)
|
||||
pc := GetBridgePacketConn()
|
||||
n, addr, err := pc.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
_, err = pc.WriteTo(p[:n], addr)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := packetConn.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
bPacketConn := enet.NewReaderPacketConn(packetConn, p[:n], addr)
|
||||
go func() {
|
||||
err = ac.RunPacketConn(bPacketConn)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
err = bPacketConn.SendPacket(p[:n], addr)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
cPacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
b := []byte("12345")
|
||||
_, err = cPacketConn.WriteTo(b, packetConn.LocalAddr())
|
||||
assert.NoError(t, err)
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := cPacketConn.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, addr.String(), packetConn.LocalAddr().String())
|
||||
assert.Equal(t, p[:n], b)
|
||||
}
|
77
core/action/local.go
Normal file
77
core/action/local.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/lb"
|
||||
"net"
|
||||
)
|
||||
|
||||
type LocalAction struct {
|
||||
DefaultAction
|
||||
TargetAddr []string `json:"target_addr" placeholder:"1.1.1.1:80\n1.1.1.2:80" zh_name:"目标地址"`
|
||||
UnixSocket bool `json:"unix_sock" placeholder:"" zh_name:"转发到unix socket"`
|
||||
networkTcp string
|
||||
localLb lb.Algo
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetName() string {
|
||||
return "local"
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetZhName() string {
|
||||
return "转发到本地"
|
||||
}
|
||||
|
||||
func (la *LocalAction) Init() error {
|
||||
la.localLb = lb.GetLbAlgo("roundRobin")
|
||||
for _, v := range la.TargetAddr {
|
||||
_ = la.localLb.Append(v)
|
||||
}
|
||||
la.networkTcp = "tcp"
|
||||
if la.UnixSocket {
|
||||
// just support unix
|
||||
la.networkTcp = "unix"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunConn(clientConn net.Conn) error {
|
||||
serverConn, err := la.GetServerConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
la.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
serverConn, err := la.GetServeConnWithAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
la.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (la *LocalAction) CanServe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetServerConn() (net.Conn, error) {
|
||||
addr, err := la.localLb.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return la.GetServeConnWithAddr(addr.(string))
|
||||
}
|
||||
|
||||
func (la *LocalAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return net.Dial(la.networkTcp, addr)
|
||||
}
|
||||
|
||||
func (la *LocalAction) RunPacketConn(pc net.PacketConn) error {
|
||||
localPacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return la.startCopyPacketConn(pc, localPacketConn)
|
||||
}
|
1
core/action/local_test.go
Normal file
1
core/action/local_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package action
|
84
core/action/npc.go
Normal file
84
core/action/npc.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type NpcAction struct {
|
||||
NpcId string `json:"npc_id" required:"true" placeholder:"npc id" zh_name:"客户端"`
|
||||
BridgeAddr string `json:"bridge_addr" placeholder:"127.0.0.1:8080" zh_name:"网桥地址"`
|
||||
UnixSocket bool `json:"unix_sock" placeholder:"" zh_name:"转发到unix socket"`
|
||||
networkTcp pb.ConnType
|
||||
tlsConfig *tls.Config
|
||||
connRequest *pb.ConnRequest
|
||||
DefaultAction
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetName() string {
|
||||
return "npc"
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetZhName() string {
|
||||
return "转发到客户端"
|
||||
}
|
||||
|
||||
func (na *NpcAction) Init() error {
|
||||
if na.tlsConfig == nil {
|
||||
return errors.New("tls config is nil")
|
||||
}
|
||||
sn, err := cert.GetCertSnFromConfig(na.tlsConfig)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get serial number")
|
||||
}
|
||||
na.connRequest = &pb.ConnRequest{Id: sn}
|
||||
na.networkTcp = pb.ConnType_tcp
|
||||
if na.UnixSocket {
|
||||
// just support unix
|
||||
na.networkTcp = pb.ConnType_unix
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (na *NpcAction) RunConnWithAddr(clientConn net.Conn, addr string) error {
|
||||
serverConn, err := na.GetServeConnWithAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
na.startCopy(clientConn, serverConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (na *NpcAction) CanServe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (na *NpcAction) GetServeConnWithAddr(addr string) (net.Conn, error) {
|
||||
return dialBridge(na, na.networkTcp, addr)
|
||||
}
|
||||
|
||||
func (na *NpcAction) RunPacketConn(pc net.PacketConn) error {
|
||||
serverPacketConn, err := dialBridge(na, pb.ConnType_udp, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return na.startCopyPacketConn(pc, enet.NewTcpPacketConn(serverPacketConn))
|
||||
}
|
||||
|
||||
func dialBridge(npc *NpcAction, connType pb.ConnType, addr string) (net.Conn, error) {
|
||||
tlsConn, err := tls.Dial("tcp", npc.BridgeAddr, npc.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "dial bridge tls")
|
||||
}
|
||||
cr := proto.Clone(npc.connRequest).(*pb.ConnRequest)
|
||||
cr.ConnType = &pb.ConnRequest_AppInfo{AppInfo: &pb.AppInfo{ConnType: connType, AppAddr: addr, NpcId: npc.NpcId}}
|
||||
if _, err = pb.WriteMessage(tlsConn, cr); err != nil {
|
||||
return nil, errors.Wrap(err, "write enet request")
|
||||
}
|
||||
return tlsConn, err
|
||||
}
|
1
core/action/npc_test.go
Normal file
1
core/action/npc_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package action
|
71
core/handler/default.go
Normal file
71
core/handler/default.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Handler = (*HttpHandler)(nil)
|
||||
_ Handler = (*HttpsHandler)(nil)
|
||||
_ Handler = (*RdpHandler)(nil)
|
||||
_ Handler = (*RedisHandler)(nil)
|
||||
_ Handler = (*Socks5Handler)(nil)
|
||||
_ Handler = (*TransparentHandler)(nil)
|
||||
_ Handler = (*DefaultHandler)(nil)
|
||||
_ Handler = (*DnsHandler)(nil)
|
||||
_ Handler = (*P2PHandler)(nil)
|
||||
_ Handler = (*QUICHandler)(nil)
|
||||
_ Handler = (*DefaultHandler)(nil)
|
||||
_ Handler = (*Socks5UdpHandler)(nil)
|
||||
)
|
||||
|
||||
type RuleRun interface {
|
||||
RunConn(enet.Conn) (bool, error)
|
||||
RunPacketConn(enet.PacketConn) (bool, error)
|
||||
}
|
||||
|
||||
type DefaultHandler struct {
|
||||
ruleList []RuleRun
|
||||
}
|
||||
|
||||
func NewBaseTcpHandler() *DefaultHandler {
|
||||
return &DefaultHandler{ruleList: make([]RuleRun, 0)}
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) GetName() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) GetZhName() string {
|
||||
return "默认"
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) HandleConn(_ []byte, c enet.Conn) (bool, error) {
|
||||
return b.processConn(c)
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) AddRule(r RuleRun) {
|
||||
b.ruleList = append(b.ruleList, r)
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) HandlePacketConn(_ enet.PacketConn) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) processConn(c enet.Conn) (bool, error) {
|
||||
for _, r := range b.ruleList {
|
||||
if ok, err := r.RunConn(c); err != nil || ok {
|
||||
return ok, err
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (b *DefaultHandler) processPacketConn(pc enet.PacketConn) (bool, error) {
|
||||
for _, r := range b.ruleList {
|
||||
if ok, err := r.RunPacketConn(pc); err != nil || ok {
|
||||
return ok, err
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
35
core/handler/dns.go
Normal file
35
core/handler/dns.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DnsHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (dh *DnsHandler) GetName() string {
|
||||
return "dns"
|
||||
}
|
||||
|
||||
func (dh *DnsHandler) GetZhName() string {
|
||||
return "dns协议"
|
||||
}
|
||||
|
||||
func (dh *DnsHandler) HandlePacketConn(pc enet.PacketConn) (bool, error) {
|
||||
b, _, err := pc.FirstPacket()
|
||||
if err != nil {
|
||||
logger.Warn("firstPacket error", zap.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
err = m.Unpack(b)
|
||||
if err != nil {
|
||||
logger.Debug("parse dns request error", zap.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
return dh.processPacketConn(pc)
|
||||
}
|
46
core/handler/dns_test.go
Normal file
46
core/handler/dns_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testRule struct {
|
||||
run bool
|
||||
}
|
||||
|
||||
func (t *testRule) RunConn(c enet.Conn) (bool, error) {
|
||||
t.run = true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (t *testRule) RunPacketConn(_ enet.PacketConn) (bool, error) {
|
||||
t.run = true
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func TestHandleDnsPacket(t *testing.T) {
|
||||
lPacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
h := DnsHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn("www.google.com"), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
|
||||
b, err := m.Pack()
|
||||
assert.NoError(t, err)
|
||||
pc := enet.NewReaderPacketConn(nil, b, lPacketConn.LocalAddr())
|
||||
|
||||
assert.NoError(t, pc.SendPacket(b, nil))
|
||||
res, err := h.HandlePacketConn(pc)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
}
|
11
core/handler/handler.go
Normal file
11
core/handler/handler.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package handler
|
||||
|
||||
import "ehang.io/nps/lib/enet"
|
||||
|
||||
type Handler interface {
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
AddRule(RuleRun)
|
||||
HandleConn([]byte, enet.Conn) (bool, error)
|
||||
HandlePacketConn(enet.PacketConn) (bool, error)
|
||||
}
|
30
core/handler/http.go
Normal file
30
core/handler/http.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type HttpHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func NewHttpHandler() *HttpHandler {
|
||||
return &HttpHandler{}
|
||||
}
|
||||
|
||||
func (h *HttpHandler) GetName() string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
func (h *HttpHandler) GetZhName() string {
|
||||
return "http协议"
|
||||
}
|
||||
|
||||
func (h *HttpHandler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
switch string(b[:3]) {
|
||||
case http.MethodGet[:3], http.MethodHead[:3], http.MethodPost[:3], http.MethodPut[:3], http.MethodPatch[:3], http.MethodDelete[:3], http.MethodConnect[:3], http.MethodOptions[:3], http.MethodTrace[:3]:
|
||||
return h.processConn(c)
|
||||
}
|
||||
return false, nil
|
||||
}
|
27
core/handler/http_test.go
Normal file
27
core/handler/http_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleHttpConn(t *testing.T) {
|
||||
|
||||
h := HttpHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
b, err := httputil.DumpRequest(r, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := h.HandleConn(b, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
}
|
33
core/handler/https.go
Normal file
33
core/handler/https.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeHandshake uint8 = 22
|
||||
typeClientHello uint8 = 1
|
||||
)
|
||||
|
||||
type HttpsHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func NewHttpsHandler() *HttpsHandler {
|
||||
return &HttpsHandler{}
|
||||
}
|
||||
|
||||
func (h *HttpsHandler) GetName() string {
|
||||
return "https"
|
||||
}
|
||||
|
||||
func (h *HttpsHandler) GetZhName() string {
|
||||
return "https协议"
|
||||
}
|
||||
|
||||
func (h *HttpsHandler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
if b[0] == recordTypeHandshake{
|
||||
return h.processConn(c)
|
||||
}
|
||||
return false, nil
|
||||
}
|
39
core/handler/https_test.go
Normal file
39
core/handler/https_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleHttpsConn(t *testing.T) {
|
||||
h := HttpsHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
finish := make(chan struct{}, 0)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
res, err := h.HandleConn(buf[:n], enet.NewReaderConn(conn))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err = tls.Dial("tcp", ln.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
<-finish
|
||||
}
|
32
core/handler/p2p.go
Normal file
32
core/handler/p2p.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type P2PHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (ph *P2PHandler) GetName() string {
|
||||
return "p2p"
|
||||
}
|
||||
|
||||
func (ph *P2PHandler) GetZhName() string {
|
||||
return "点对点协议"
|
||||
}
|
||||
|
||||
func (ph *P2PHandler) HandlePacketConn(pc enet.PacketConn) (bool, error) {
|
||||
b, _, err := pc.FirstPacket()
|
||||
if err != nil {
|
||||
logger.Warn("firstPacket error", zap.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
if bytes.HasPrefix(b, []byte("p2p")) {
|
||||
return ph.processPacketConn(pc)
|
||||
}
|
||||
return false, nil
|
||||
}
|
26
core/handler/p2p_test.go
Normal file
26
core/handler/p2p_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleP2PPacket(t *testing.T) {
|
||||
|
||||
h := P2PHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080")
|
||||
assert.NoError(t, err)
|
||||
pc := enet.NewReaderPacketConn(nil, []byte("p2p xxxx"), addr)
|
||||
|
||||
assert.NoError(t, pc.SendPacket([]byte("p2p xxxx"), nil))
|
||||
|
||||
res, err := h.HandlePacketConn(pc)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
}
|
32
core/handler/quic.go
Normal file
32
core/handler/quic.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type QUICHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (qh *QUICHandler) GetName() string {
|
||||
return "quic"
|
||||
}
|
||||
|
||||
func (qh *QUICHandler) GetZhName() string {
|
||||
return "quic协议"
|
||||
}
|
||||
|
||||
func (qh *QUICHandler) HandlePacketConn(pc enet.PacketConn) (bool, error) {
|
||||
b, _, err := pc.FirstPacket()
|
||||
if err != nil {
|
||||
logger.Warn("firstPacket error", zap.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
if len(b) >= 5 && bytes.HasPrefix(b[1:5], []byte{0, 0, 0, 1}) {
|
||||
return qh.processPacketConn(pc)
|
||||
}
|
||||
return false, nil
|
||||
}
|
35
core/handler/quic_test.go
Normal file
35
core/handler/quic_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleQUICPacket(t *testing.T) {
|
||||
h := QUICHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
finish := make(chan struct{}, 0)
|
||||
packetConn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
b := make([]byte, 1500)
|
||||
n, addr, err := packetConn.ReadFrom(b)
|
||||
assert.NoError(t, err)
|
||||
pc := enet.NewReaderPacketConn(nil, b[:n], packetConn.LocalAddr())
|
||||
assert.NoError(t, pc.SendPacket(b[:n], addr))
|
||||
res, err := h.HandlePacketConn(pc)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
go quic.DialAddr(packetConn.LocalAddr().String(), &tls.Config{}, nil)
|
||||
<-finish
|
||||
}
|
24
core/handler/rdp.go
Normal file
24
core/handler/rdp.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
type RdpHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (rh *RdpHandler) GetName() string {
|
||||
return "rdp"
|
||||
}
|
||||
|
||||
func (rh *RdpHandler) GetZhName() string {
|
||||
return "rdp协议"
|
||||
}
|
||||
|
||||
func (rh *RdpHandler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
if b[0] == 3 && b[1] == 0 {
|
||||
return rh.processConn(c)
|
||||
}
|
||||
return false, nil
|
||||
}
|
37
core/handler/rdp_test.go
Normal file
37
core/handler/rdp_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/icodeface/grdp"
|
||||
"github.com/icodeface/grdp/glog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleRdpConn(t *testing.T) {
|
||||
h := RdpHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
finish := make(chan struct{}, 0)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
res, err := h.HandleConn(buf[:n], enet.NewReaderConn(conn))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
grdp.NewClient(ln.Addr().String(), glog.DEBUG).Login("Administrator", "123456")
|
||||
}()
|
||||
<-finish
|
||||
}
|
22
core/handler/redis.go
Normal file
22
core/handler/redis.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package handler
|
||||
|
||||
import "ehang.io/nps/lib/enet"
|
||||
|
||||
type RedisHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (rds *RedisHandler) GetName() string {
|
||||
return "redis"
|
||||
}
|
||||
|
||||
func (rds *RedisHandler) GetZhName() string {
|
||||
return "redis协议"
|
||||
}
|
||||
|
||||
func (rds *RedisHandler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
if b[0] == 42 && b[1] == 49 && b[2] == 13 {
|
||||
return rds.processConn(c)
|
||||
}
|
||||
return false, nil
|
||||
}
|
40
core/handler/redis_test.go
Normal file
40
core/handler/redis_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleRedisConn(t *testing.T) {
|
||||
h := RedisHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
finish := make(chan struct{}, 0)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
res, err := h.HandleConn(buf[:n], enet.NewReaderConn(conn))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: ln.Addr().String(),
|
||||
})
|
||||
rdb.Ping(context.Background())
|
||||
}()
|
||||
<-finish
|
||||
}
|
22
core/handler/socks5.go
Normal file
22
core/handler/socks5.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package handler
|
||||
|
||||
import "ehang.io/nps/lib/enet"
|
||||
|
||||
type Socks5Handler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (sh *Socks5Handler) GetName() string {
|
||||
return "socks5"
|
||||
}
|
||||
|
||||
func (sh *Socks5Handler) GetZhName() string {
|
||||
return "socks5协议"
|
||||
}
|
||||
|
||||
func (sh *Socks5Handler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
if b[0] == 5 {
|
||||
return sh.processConn(c)
|
||||
}
|
||||
return false, nil
|
||||
}
|
47
core/handler/socks5_test.go
Normal file
47
core/handler/socks5_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleSocks5Conn(t *testing.T) {
|
||||
h := Socks5Handler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
finish := make(chan struct{}, 0)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
res, err := h.HandleConn(buf[:n], enet.NewReaderConn(conn))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("socks5://%s", ln.Addr().String()))
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
_, _ = client.Get("https://google.com/")
|
||||
}()
|
||||
<-finish
|
||||
}
|
26
core/handler/socks5_udp.go
Normal file
26
core/handler/socks5_udp.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package handler
|
||||
|
||||
import "ehang.io/nps/lib/enet"
|
||||
|
||||
type Socks5UdpHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (sh *Socks5UdpHandler) GetName() string {
|
||||
return "socks5_udp"
|
||||
}
|
||||
|
||||
func (sh *Socks5UdpHandler) GetZhName() string {
|
||||
return "socks5 udp协议"
|
||||
}
|
||||
|
||||
func (sh *Socks5UdpHandler) HandlePacketConn(pc enet.PacketConn) (bool, error) {
|
||||
b, _, err := pc.FirstPacket()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
if b[0] == 0 {
|
||||
return sh.processPacketConn(pc)
|
||||
}
|
||||
return false, nil
|
||||
}
|
44
core/handler/socks5_udp_test.go
Normal file
44
core/handler/socks5_udp_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSocks5Handle(t *testing.T) {
|
||||
h := Socks5UdpHandler{}
|
||||
rule := &testRule{}
|
||||
h.AddRule(rule)
|
||||
|
||||
finish := make(chan struct{}, 0)
|
||||
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
buf := make([]byte, 1024)
|
||||
n, addr, err := pc.ReadFrom(buf)
|
||||
assert.NoError(t, err)
|
||||
rPc := enet.NewReaderPacketConn(nil, buf[:n], addr)
|
||||
res, err := h.HandlePacketConn(rPc)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, res)
|
||||
assert.Equal(t, true, rule.run)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
data := []byte("test")
|
||||
go func() {
|
||||
cPc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
pAddr, err := common.ParseAddr("8.8.8.8:53")
|
||||
assert.NoError(t, err)
|
||||
b := append([]byte{0, 0, 0}, pAddr...)
|
||||
b = append(b, data...)
|
||||
_, err = cPc.WriteTo(b, pc.LocalAddr())
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
<-finish
|
||||
}
|
21
core/handler/transparent.go
Normal file
21
core/handler/transparent.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
type TransparentHandler struct {
|
||||
DefaultHandler
|
||||
}
|
||||
|
||||
func (ts *TransparentHandler) GetName() string {
|
||||
return "transparent"
|
||||
}
|
||||
|
||||
func (ts *TransparentHandler) GetZhName() string {
|
||||
return "linux透明代理协议"
|
||||
}
|
||||
|
||||
func (ts *TransparentHandler) HandleConn(b []byte, c enet.Conn) (bool, error) {
|
||||
return ts.processConn(c)
|
||||
}
|
43
core/limiter/conn_num.go
Normal file
43
core/limiter/conn_num.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnNumLimiter is used to limit the connection num of a service
|
||||
type ConnNumLimiter struct {
|
||||
baseLimiter
|
||||
nowNum int32
|
||||
MaxConnNum int32 `json:"max_conn_num" required:"true" placeholder:"10" zh_name:"最大连接数"` //0 means not limit
|
||||
}
|
||||
|
||||
func (cl *ConnNumLimiter) GetName() string {
|
||||
return "conn_num"
|
||||
}
|
||||
|
||||
func (cl *ConnNumLimiter) GetZhName() string {
|
||||
return "总连接数限制"
|
||||
}
|
||||
|
||||
// DoLimit return an error if the connection num exceed the maximum
|
||||
func (cl *ConnNumLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
if atomic.AddInt32(&cl.nowNum, 1) > cl.MaxConnNum && cl.MaxConnNum > 0 {
|
||||
atomic.AddInt32(&cl.nowNum, -1)
|
||||
return nil, errors.New("exceed maximum number of connections")
|
||||
}
|
||||
return &connNumConn{nowNum: &cl.nowNum}, nil
|
||||
}
|
||||
|
||||
// connNumConn is an implementation of enet.Conn
|
||||
type connNumConn struct {
|
||||
nowNum *int32
|
||||
enet.Conn
|
||||
}
|
||||
|
||||
// Close decrease the connection num
|
||||
func (cn *connNumConn) Close() error {
|
||||
atomic.AddInt32(cn.nowNum, -1)
|
||||
return cn.Conn.Close()
|
||||
}
|
35
core/limiter/conn_num_test.go
Normal file
35
core/limiter/conn_num_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnNumLimiter(t *testing.T) {
|
||||
cl := ConnNumLimiter{MaxConnNum: 5}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowNum := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
nowNum++
|
||||
_, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
if nowNum > 5 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 6; i > 0; i-- {
|
||||
go net.Dial("tcp", ln.Addr().String())
|
||||
}
|
||||
<-close
|
||||
}
|
87
core/limiter/flow.go
Normal file
87
core/limiter/flow.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// FlowStore is an interface to store or get the flow now
|
||||
type FlowStore interface {
|
||||
GetOutIn() (uint32, uint32)
|
||||
AddOut(out uint32) uint32
|
||||
AddIn(in uint32) uint32
|
||||
}
|
||||
|
||||
// memStore is an implement of FlowStore
|
||||
type memStore struct {
|
||||
nowOut uint32
|
||||
nowIn uint32
|
||||
}
|
||||
|
||||
// GetOutIn return out and in num 0
|
||||
func (m *memStore) GetOutIn() (uint32, uint32) {
|
||||
return m.nowOut, m.nowIn
|
||||
}
|
||||
|
||||
// AddOut is used to add out now
|
||||
func (m *memStore) AddOut(out uint32) uint32 {
|
||||
return atomic.AddUint32(&m.nowOut, out)
|
||||
}
|
||||
|
||||
// AddIn is used to add in now
|
||||
func (m *memStore) AddIn(in uint32) uint32 {
|
||||
return atomic.AddUint32(&m.nowIn, in)
|
||||
}
|
||||
|
||||
// FlowLimiter is used to limit the flow of a service
|
||||
type FlowLimiter struct {
|
||||
Store FlowStore
|
||||
OutLimit uint32 `json:"out_limit" required:"true" placeholder:"1024(kb)" zh_name:"出口最大流量"` //unit: kb, 0 means not limit
|
||||
InLimit uint32 `json:"in_limit" required:"true" placeholder:"1024(kb)" zh_name:"入口最大流量"` //unit: kb, 0 means not limit
|
||||
}
|
||||
|
||||
func (f *FlowLimiter) GetName() string {
|
||||
return "flow"
|
||||
}
|
||||
|
||||
func (f *FlowLimiter) GetZhName() string {
|
||||
return "流量限制"
|
||||
}
|
||||
|
||||
// DoLimit return a flow limited enet.Conn
|
||||
func (f *FlowLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
return &flowConn{fl: f, Conn: c}, nil
|
||||
}
|
||||
|
||||
// Init is used to set out or in num now
|
||||
func (f *FlowLimiter) Init() error {
|
||||
if f.Store == nil {
|
||||
f.Store = &memStore{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flowConn is an implement of
|
||||
type flowConn struct {
|
||||
enet.Conn
|
||||
fl *FlowLimiter
|
||||
}
|
||||
|
||||
// Read add the in flow num of the service
|
||||
func (fs *flowConn) Read(b []byte) (n int, err error) {
|
||||
n, err = fs.Conn.Read(b)
|
||||
if fs.fl.InLimit > 0 && fs.fl.Store.AddIn(uint32(n)) > fs.fl.InLimit {
|
||||
err = errors.New("exceed the in flow limit")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Write add the out flow num of the service
|
||||
func (fs *flowConn) Write(b []byte) (n int, err error) {
|
||||
n, err = fs.Conn.Write(b)
|
||||
if fs.fl.OutLimit > 0 && fs.fl.Store.AddOut(uint32(n)) > fs.fl.OutLimit {
|
||||
err = errors.New("exceed the out flow limit")
|
||||
}
|
||||
return
|
||||
}
|
59
core/limiter/flow_test.go
Normal file
59
core/limiter/flow_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlowLimiter(t *testing.T) {
|
||||
cl := FlowLimiter{
|
||||
OutLimit: 100,
|
||||
InLimit: 100,
|
||||
}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowBytes := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 10)
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
c, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
nowBytes += n
|
||||
if nowBytes > 100 {
|
||||
assert.Error(t, err)
|
||||
nowBytes = 0
|
||||
for i := 11; i > 0; i-- {
|
||||
n, err = c.Write(bytes.Repeat([]byte{0}, 10))
|
||||
nowBytes += n
|
||||
if nowBytes > 100 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Write(bytes.Repeat([]byte{0}, 10))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
<-close
|
||||
}
|
99
core/limiter/ip_conn_num.go
Normal file
99
core/limiter/ip_conn_num.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ipNumMap is used to store the connection num of a ip address
|
||||
type ipNumMap struct {
|
||||
m map[string]int32
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// AddOrSet is used to add connection num of a ip address
|
||||
func (i *ipNumMap) AddOrSet(key string) {
|
||||
i.Lock()
|
||||
if v, ok := i.m[key]; ok {
|
||||
i.m[key] = v + 1
|
||||
} else {
|
||||
i.m[key] = 1
|
||||
}
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
// SubOrDel is used to decrease connection of a ip address
|
||||
func (i *ipNumMap) SubOrDel(key string) {
|
||||
i.Lock()
|
||||
if v, ok := i.m[key]; ok {
|
||||
i.m[key] = v - 1
|
||||
if i.m[key] == 0 {
|
||||
delete(i.m, key)
|
||||
}
|
||||
}
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
// Get return the connection num of a ip
|
||||
func (i *ipNumMap) Get(key string) int32 {
|
||||
return i.m[key]
|
||||
}
|
||||
|
||||
// IpConnNumLimiter is used to limit the connection num of a service at the same time of same ip
|
||||
type IpConnNumLimiter struct {
|
||||
m *ipNumMap
|
||||
MaxNum int32 `json:"max_num" required:"true" placeholder:"10" zh_name:"单ip最大连接数"`
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (cl *IpConnNumLimiter) GetName() string {
|
||||
return "ip_conn_num"
|
||||
}
|
||||
|
||||
func (cl *IpConnNumLimiter) GetZhName() string {
|
||||
return "单ip连接数限制"
|
||||
}
|
||||
|
||||
// Init the ipNumMap
|
||||
func (cl *IpConnNumLimiter) Init() error {
|
||||
cl.m = &ipNumMap{m: make(map[string]int32)}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoLimit reports whether the connection num of the ip exceed the maximum number
|
||||
// If true, return error
|
||||
func (cl *IpConnNumLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return c, errors.Wrap(err, "split ip addr")
|
||||
}
|
||||
if cl.m.Get(ip) >= cl.MaxNum {
|
||||
return c, errors.Errorf("the ip(%s) exceed the maximum number(%d)", ip, cl.MaxNum)
|
||||
}
|
||||
return NewNumConn(c, ip, cl.m), nil
|
||||
}
|
||||
|
||||
// numConn is an implement of enet.Conn
|
||||
type numConn struct {
|
||||
key string
|
||||
m *ipNumMap
|
||||
enet.Conn
|
||||
}
|
||||
|
||||
// NewNumConn return a numConn
|
||||
func NewNumConn(c enet.Conn, key string, m *ipNumMap) *numConn {
|
||||
m.AddOrSet(key)
|
||||
return &numConn{
|
||||
m: m,
|
||||
key: key,
|
||||
Conn: c,
|
||||
}
|
||||
}
|
||||
|
||||
// Close is used to decrease the connection num of a ip when connection closing
|
||||
func (n *numConn) Close() error {
|
||||
n.m.SubOrDel(n.key)
|
||||
return n.Conn.Close()
|
||||
}
|
35
core/limiter/ip_conn_num_test.go
Normal file
35
core/limiter/ip_conn_num_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIpConnNumLimiter(t *testing.T) {
|
||||
cl := IpConnNumLimiter{MaxNum: 5}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowNum := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
nowNum++
|
||||
_, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
if nowNum > 5 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 6; i > 0; i-- {
|
||||
go net.Dial("tcp", ln.Addr().String())
|
||||
}
|
||||
<-close
|
||||
}
|
26
core/limiter/limiter.go
Normal file
26
core/limiter/limiter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Limiter = (*RateLimiter)(nil)
|
||||
_ Limiter = (*ConnNumLimiter)(nil)
|
||||
_ Limiter = (*IpConnNumLimiter)(nil)
|
||||
_ Limiter = (*FlowLimiter)(nil)
|
||||
)
|
||||
|
||||
type Limiter interface {
|
||||
DoLimit(conn enet.Conn) (enet.Conn, error)
|
||||
Init() error
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
}
|
||||
|
||||
type baseLimiter struct {
|
||||
}
|
||||
|
||||
func (bl *baseLimiter) Init() error {
|
||||
return nil
|
||||
}
|
67
core/limiter/rate.go
Normal file
67
core/limiter/rate.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/rate"
|
||||
)
|
||||
|
||||
// RateLimiter is used to limit the speed of transport
|
||||
type RateLimiter struct {
|
||||
baseLimiter
|
||||
RateLimit int64 `json:"rate_limit" required:"true" placeholder:"10(kb)" zh_name:"最大速度"`
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetName() string {
|
||||
return "rate"
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetZhName() string {
|
||||
return "带宽限制"
|
||||
}
|
||||
|
||||
// Init the rate controller
|
||||
func (rl *RateLimiter) Init() error {
|
||||
if rl.RateLimit > 0 && rl.rate == nil {
|
||||
rl.rate = rate.NewRate(rl.RateLimit)
|
||||
rl.rate.Start()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoLimit return limited Conn
|
||||
func (rl *RateLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
return NewRateConn(c, rl.rate), nil
|
||||
}
|
||||
|
||||
// rateConn is used to limiter the rate fo connection
|
||||
type rateConn struct {
|
||||
enet.Conn
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
// NewRateConn return limited connection by rate interface
|
||||
func NewRateConn(rc enet.Conn, rate *rate.Rate) enet.Conn {
|
||||
return &rateConn{
|
||||
Conn: rc,
|
||||
rate: rate,
|
||||
}
|
||||
}
|
||||
|
||||
// Read data and remove capacity from rate pool
|
||||
func (s *rateConn) Read(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Read(b)
|
||||
if s.rate != nil && err == nil {
|
||||
err = s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Write data and remove capacity from rate pool
|
||||
func (s *rateConn) Write(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Write(b)
|
||||
if s.rate != nil && err == nil {
|
||||
err = s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
46
core/limiter/rate_test.go
Normal file
46
core/limiter/rate_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
cl := RateLimiter{
|
||||
RateLimit: 100,
|
||||
}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowBytes := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 10)
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
c, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
go func() {
|
||||
<-time.After(time.Second * 2)
|
||||
if nowBytes > 500 {
|
||||
t.Fail()
|
||||
}
|
||||
close <- struct{}{}
|
||||
}()
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
nowBytes += n
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Write(bytes.Repeat([]byte{0}, 10000))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
<-close
|
||||
}
|
94
core/process/http_proxy.go
Normal file
94
core/process/http_proxy.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"encoding/base64"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type HttpProxyProcess struct {
|
||||
DefaultProcess
|
||||
BasicAuth map[string]string `json:"basic_auth" placeholder:"username1 password1\nusername2 password2" zh_name:"basic认证"`
|
||||
}
|
||||
|
||||
func (hpp *HttpProxyProcess) GetName() string {
|
||||
return "http_proxy"
|
||||
}
|
||||
|
||||
func (hpp *HttpProxyProcess) GetZhName() string {
|
||||
return "http代理"
|
||||
}
|
||||
|
||||
func (hpp *HttpProxyProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
r, err := http.ReadRequest(bufio.NewReader(c))
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "read proxy request")
|
||||
}
|
||||
if len(hpp.BasicAuth) != 0 && !hpp.checkAuth(r) {
|
||||
return true, hpp.response(http.StatusProxyAuthRequired, map[string]string{"Proxy-Authenticate": "Basic realm=" + strconv.Quote("nps")}, c)
|
||||
}
|
||||
if r.Method == "CONNECT" {
|
||||
err = hpp.response(200, map[string]string{}, c)
|
||||
if err != nil {
|
||||
return true, errors.Wrap(err, "http proxy response")
|
||||
}
|
||||
} else if err = c.Reset(0); err != nil {
|
||||
logger.Warn("reset enet.Conn error", zap.Error(err))
|
||||
return true, err
|
||||
}
|
||||
address := r.Host
|
||||
if !strings.Contains(r.Host, ":") {
|
||||
if r.URL.Scheme == "https" {
|
||||
address = r.Host + ":443"
|
||||
} else {
|
||||
address = r.Host + ":80"
|
||||
}
|
||||
}
|
||||
return true, hpp.ac.RunConnWithAddr(c, address)
|
||||
}
|
||||
func (hpp *HttpProxyProcess) response(statusCode int, headers map[string]string, c enet.Conn) error {
|
||||
resp := &http.Response{
|
||||
Status: http.StatusText(statusCode),
|
||||
StatusCode: statusCode,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
}
|
||||
for k, v := range headers {
|
||||
resp.Header.Set(k, v)
|
||||
}
|
||||
return resp.Write(c)
|
||||
}
|
||||
|
||||
func (hpp *HttpProxyProcess) checkAuth(r *http.Request) bool {
|
||||
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
||||
if len(s) != 2 {
|
||||
s = strings.SplitN(r.Header.Get("Proxy-Authorization"), " ", 2)
|
||||
if len(s) != 2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return false
|
||||
}
|
||||
for u, p := range hpp.BasicAuth {
|
||||
if pair[0] == u && pair[1] == p {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
93
core/process/http_proxy_test.go
Normal file
93
core/process/http_proxy_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpProxyProcess(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hsAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
h := HttpProxyProcess{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("http://%s", ln.Addr().String()))
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = client.Get(fmt.Sprintf("http://%s/now", hsAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestHttpProxyProcessBasic(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
h := HttpProxyProcess{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
BasicAuth: map[string]string{"aaa": "bbb"},
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("http://%s", ln.Addr().String()))
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.Error(t, err)
|
||||
transport.Proxy = func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("http://%s:%s@%s", "aaa", "bbb", ln.Addr().String()))
|
||||
}
|
||||
|
||||
resp, err = client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
79
core/process/http_serve.go
Normal file
79
core/process/http_serve.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HttpServeProcess is proxy and modify http request
|
||||
type HttpServeProcess struct {
|
||||
DefaultProcess
|
||||
tls bool
|
||||
Host string `json:"host" required:"true" placeholder:"eg: www.nps.com or *.nps.com" zh_name:"域名"`
|
||||
RouteUrl string `json:"route_url" placeholder:"/api" zh_name:"匹配路径"`
|
||||
HeaderModify map[string]string `json:"header_modify" placeholder:"字段 修改值\nHost www.nps-change.com\nAccept */*" zh_name:"请求头修改"`
|
||||
HostModify string `json:"host_modify" placeholder:"www.nps-changed.com" zh_name:"请求域名"`
|
||||
AddOrigin bool `json:"add_origin" zh_name:"添加来源"`
|
||||
CacheTime int64 `json:"cache_time" placeholder:"600s" zh_name:"缓存时间"`
|
||||
CachePath []string `json:"cache_path" placeholder:".jd\n.css\n.png" zh_name:"缓存路径"`
|
||||
BasicAuth map[string]string `json:"basic_auth" placeholder:"username1 password1\nusername2 password2" zh_name:"basic认证"`
|
||||
httpServe *HttpServe
|
||||
ln *enet.Listener
|
||||
}
|
||||
|
||||
func (hp *HttpServeProcess) GetName() string {
|
||||
return "http_serve"
|
||||
}
|
||||
|
||||
func (hp *HttpServeProcess) GetZhName() string {
|
||||
return "http服务"
|
||||
}
|
||||
|
||||
// Init the action of process
|
||||
func (hp *HttpServeProcess) Init(ac action.Action) error {
|
||||
hp.ac = ac
|
||||
hp.ln = enet.NewListener()
|
||||
hp.httpServe = NewHttpServe(hp.ln, ac)
|
||||
hp.httpServe.SetModify(hp.HeaderModify, hp.HostModify, hp.AddOrigin)
|
||||
if hp.CacheTime > 0 {
|
||||
hp.httpServe.SetCache(hp.CachePath, time.Duration(hp.CacheTime)*time.Second)
|
||||
}
|
||||
if len(hp.BasicAuth) != 0 {
|
||||
hp.httpServe.SetBasicAuth(hp.BasicAuth)
|
||||
}
|
||||
if !hp.tls {
|
||||
go hp.httpServe.Serve()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessConn is used to determine whether to hit the rule
|
||||
// If true, send enet to httpServe
|
||||
func (hp *HttpServeProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
req, err := http.ReadRequest(bufio.NewReader(c))
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "read request")
|
||||
}
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "split host")
|
||||
}
|
||||
if !(common.HostContains(hp.Host, host) && (hp.RouteUrl == "" || strings.HasPrefix(req.URL.Path, hp.RouteUrl))) {
|
||||
logger.Debug("do http proxy failed", zap.String("host", host), zap.String("url", hp.RouteUrl))
|
||||
return false, nil
|
||||
}
|
||||
logger.Debug("do http proxy", zap.String("host", host), zap.String("url", hp.RouteUrl))
|
||||
if err := c.Reset(0); err != nil {
|
||||
return true, errors.Wrap(err, "reset connection data")
|
||||
}
|
||||
return true, hp.ln.SendConn(c)
|
||||
}
|
74
core/process/http_serve_test.go
Normal file
74
core/process/http_serve_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpServeProcess(t *testing.T) {
|
||||
sAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
h := &HttpServeProcess{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
Host: "127.0.0.1",
|
||||
RouteUrl: "",
|
||||
HeaderModify: map[string]string{"modify": "nps"},
|
||||
HostModify: "ehang.io",
|
||||
AddOrigin: true,
|
||||
}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{sAddr},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
rep, err := doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/header/modify"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "nps", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/host"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ehang.io", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/origin/xff"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/origin/xri"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
h.BasicAuth = map[string]string{"aaa": "bbb"}
|
||||
h.Init(ac)
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.Error(t, err)
|
||||
_, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/now"), "aaa", "bbb")
|
||||
assert.NoError(t, err)
|
||||
|
||||
h.BasicAuth = map[string]string{}
|
||||
h.CacheTime = 100
|
||||
h.CachePath = []string{"/now"}
|
||||
h.Init(ac)
|
||||
var time1, time2 string
|
||||
time1, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
time2, err = doRequest(fmt.Sprintf("http://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, time1)
|
||||
assert.Equal(t, time1, time2)
|
||||
|
||||
}
|
36
core/process/https_proxy.go
Normal file
36
core/process/https_proxy.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
type HttpsProxyProcess struct {
|
||||
CertFile string `json:"cert_file" required:"true" placeholder:"/var/cert/cert.pem" zh_name:"cert文件路径"`
|
||||
KeyFile string `json:"key_file" required:"true" placeholder:"/var/cert/key.pem" zh_name:"key文件路径"`
|
||||
config *tls.Config
|
||||
HttpProxyProcess
|
||||
}
|
||||
|
||||
func (hpp *HttpsProxyProcess) GetName() string {
|
||||
return "https_proxy"
|
||||
}
|
||||
|
||||
func (hpp *HttpsProxyProcess) GetZhName() string {
|
||||
return "https代理"
|
||||
}
|
||||
|
||||
func (hpp *HttpsProxyProcess) Init(ac action.Action) error {
|
||||
cer, err := tls.LoadX509KeyPair(hpp.CertFile, hpp.KeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hpp.config = &tls.Config{Certificates: []tls.Certificate{cer}}
|
||||
hpp.ac = ac
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hpp *HttpsProxyProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
return hpp.HttpProxyProcess.ProcessConn(enet.NewReaderConn(tls.Server(c, hpp.config)))
|
||||
}
|
120
core/process/https_proxy_test.go
Normal file
120
core/process/https_proxy_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509/pkix"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var createCertOnce sync.Once
|
||||
|
||||
func createCertFile(t *testing.T) (string, string) {
|
||||
createCertOnce.Do(func() {
|
||||
g := cert.NewX509Generator(pkix.Name{
|
||||
Country: []string{"CN"},
|
||||
Organization: []string{"Ehang.io"},
|
||||
OrganizationalUnit: []string{"nps"},
|
||||
Province: []string{"Beijing"},
|
||||
CommonName: "nps",
|
||||
Locality: []string{"Beijing"},
|
||||
})
|
||||
cert, key, err := g.CreateRootCa()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "cert.pem"), cert, 0600))
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "key.pem"), key, 0600))
|
||||
})
|
||||
return filepath.Join(os.TempDir(), "cert.pem"), filepath.Join(os.TempDir(), "key.pem")
|
||||
}
|
||||
|
||||
func TestHttpsProxyProcess(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
certFilePath, keyFilePath := createCertFile(t)
|
||||
h := HttpsProxyProcess{
|
||||
HttpProxyProcess: HttpProxyProcess{},
|
||||
CertFile: certFilePath,
|
||||
KeyFile: keyFilePath,
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("https://%s", ln.Addr().String()))
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestHttpsProxyProcessBasic(t *testing.T) {
|
||||
certFilePath, keyFilePath := createCertFile(t)
|
||||
sAddr, err := startHttps(t)
|
||||
h := HttpsProxyProcess{
|
||||
HttpProxyProcess: HttpProxyProcess{
|
||||
BasicAuth: map[string]string{"aaa": "bbb"},
|
||||
},
|
||||
CertFile: certFilePath,
|
||||
KeyFile: keyFilePath,
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("https://%s", ln.Addr().String()))
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.Error(t, err)
|
||||
transport.Proxy = func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("https://%s:%s@%s", "aaa", "bbb", ln.Addr().String()))
|
||||
}
|
||||
|
||||
resp, err = client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
41
core/process/https_redirect.go
Normal file
41
core/process/https_redirect.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// HttpsRedirectProcess is used to forward https request by ClientHelloMsg
|
||||
type HttpsRedirectProcess struct {
|
||||
DefaultProcess
|
||||
Host string `json:"host" required:"true" placeholder:"https.nps.com" zh_name:"域名"`
|
||||
}
|
||||
|
||||
func (hrp *HttpsRedirectProcess) GetName() string {
|
||||
return "https_redirect"
|
||||
}
|
||||
|
||||
func (hrp *HttpsRedirectProcess) GetZhName() string {
|
||||
return "https透传"
|
||||
}
|
||||
|
||||
// ProcessConn is used to determine whether to hit the host rule
|
||||
func (hrp *HttpsRedirectProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
clientMsg := cert.ClientHelloMsg{}
|
||||
buf, err := c.AllBytes()
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "get bytes")
|
||||
}
|
||||
if !clientMsg.Unmarshal(buf[5:]) {
|
||||
return false, errors.New("can not unmarshal client hello message")
|
||||
}
|
||||
if common.HostContains(hrp.Host, clientMsg.GetServerName()) {
|
||||
if err = c.Reset(0); err != nil {
|
||||
return false, errors.Wrap(err, "reset reader connection")
|
||||
}
|
||||
return true, errors.Wrap(hrp.ac.RunConn(c), "run action")
|
||||
}
|
||||
return false, nil
|
||||
}
|
43
core/process/https_redirect_test.go
Normal file
43
core/process/https_redirect_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpsRedirectProcess(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
assert.NoError(t, err)
|
||||
h := &HttpsRedirectProcess{
|
||||
Host: "ehang.io",
|
||||
}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{sAddr},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = h.ProcessConn(enet.NewReaderConn(c))
|
||||
_ = c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
_, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.Error(t, err)
|
||||
|
||||
h.Host = "*.github.com"
|
||||
_, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
}
|
51
core/process/https_serve.go
Normal file
51
core/process/https_serve.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/cert"
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type HttpsServeProcess struct {
|
||||
CertFile string `json:"cert_file" required:"true" placeholder:"/var/cert/cert.pem" zh_name:"cert文件路径"`
|
||||
KeyFile string `json:"key_file" required:"true" placeholder:"/var/cert/key.pem" zh_name:"key文件路径"`
|
||||
HttpServeProcess
|
||||
}
|
||||
|
||||
func (hsp *HttpsServeProcess) GetName() string {
|
||||
return "https_serve"
|
||||
}
|
||||
func (hsp *HttpsServeProcess) GetZhName() string {
|
||||
return "https服务"
|
||||
}
|
||||
|
||||
func (hsp *HttpsServeProcess) Init(ac action.Action) error {
|
||||
hsp.tls = true
|
||||
err := hsp.HttpServeProcess.Init(ac)
|
||||
go hsp.httpServe.ServeTLS(hsp.CertFile, hsp.KeyFile)
|
||||
return err
|
||||
}
|
||||
|
||||
func (hsp *HttpsServeProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
clientMsg := cert.ClientHelloMsg{}
|
||||
b, err := c.AllBytes()
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "get bytes")
|
||||
}
|
||||
if !clientMsg.Unmarshal(b[5:]) {
|
||||
return false, errors.New("can not unmarshal client hello message")
|
||||
}
|
||||
if common.HostContains(hsp.Host, clientMsg.GetServerName()) {
|
||||
logger.Debug("do https serve failed", zap.String("host", clientMsg.GetServerName()), zap.String("url", hsp.RouteUrl))
|
||||
if err := c.Reset(0); err != nil {
|
||||
return true, errors.Wrap(err, "reset reader connection")
|
||||
}
|
||||
return true, hsp.HttpServeProcess.ln.SendConn(c)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
78
core/process/https_serve_test.go
Normal file
78
core/process/https_serve_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpsProcess(t *testing.T) {
|
||||
certFile, keyFile := createCertFile(t)
|
||||
sAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
h := &HttpsServeProcess{
|
||||
HttpServeProcess: HttpServeProcess{
|
||||
Host: "www.github.com",
|
||||
RouteUrl: "",
|
||||
HeaderModify: map[string]string{"modify": "nps"},
|
||||
HostModify: "ehang.io",
|
||||
AddOrigin: true,
|
||||
},
|
||||
CertFile: certFile,
|
||||
KeyFile: keyFile,
|
||||
}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{sAddr},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
rep, err := doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/header/modify"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "nps", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/host"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ehang.io", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/origin/xff"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/origin/xri"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
h.BasicAuth = map[string]string{"aaa": "bbb"}
|
||||
assert.NoError(t, h.Init(ac))
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.Error(t, err)
|
||||
_, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"), "aaa", "bbb")
|
||||
assert.NoError(t, err)
|
||||
|
||||
h.BasicAuth = map[string]string{}
|
||||
h.CacheTime = 100
|
||||
h.CachePath = []string{"/now"}
|
||||
assert.NoError(t, h.Init(ac))
|
||||
var time1, time2 string
|
||||
time1, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
time2, err = doRequest(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, time1)
|
||||
assert.Equal(t, time1, time2)
|
||||
|
||||
}
|
45
core/process/pb_app.go
Normal file
45
core/process/pb_app.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type PbAppProcessor struct {
|
||||
DefaultProcess
|
||||
}
|
||||
|
||||
func (pp *PbAppProcessor) GetName() string {
|
||||
return "pb_app"
|
||||
}
|
||||
|
||||
func (pp *PbAppProcessor) ProcessConn(c enet.Conn) (bool, error) {
|
||||
m := &pb.ClientRequest{}
|
||||
n, err := pb.ReadMessage(c, m)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if _, ok := m.ConnType.(*pb.ClientRequest_AppInfo); !ok {
|
||||
return false, nil
|
||||
}
|
||||
if err := c.Reset(n + 4); err != nil {
|
||||
return true, errors.Wrap(err, "reset connection data")
|
||||
}
|
||||
switch m.GetAppInfo().GetConnType() {
|
||||
case pb.ConnType_udp:
|
||||
return true, pp.RunUdp(c)
|
||||
case pb.ConnType_tcp:
|
||||
return true, pp.ac.RunConnWithAddr(c, m.GetAppInfo().GetAppAddr())
|
||||
case pb.ConnType_unix:
|
||||
ac := &action.LocalAction{TargetAddr: []string{m.GetAppInfo().GetAppAddr()}, UnixSocket: true}
|
||||
_ = ac.Init()
|
||||
return true, ac.RunConn(c)
|
||||
}
|
||||
return true, errors.Errorf("can not support the conn type(%d)", m.GetAppInfo().GetConnType())
|
||||
}
|
||||
|
||||
func (pp *PbAppProcessor) RunUdp(c enet.Conn) error {
|
||||
return pp.ac.RunPacketConn(enet.NewTcpPacketConn(c))
|
||||
}
|
113
core/process/pb_app_test.go
Normal file
113
core/process/pb_app_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestProtobufProcess(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
h := &PbAppProcessor{}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{sAddr},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = h.ProcessConn(enet.NewReaderConn(c))
|
||||
_ = c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
client := http.Client{Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
MaxIdleConns: 10000,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
_, err = pb.WriteMessage(conn, &pb.AppInfo{AppAddr: sAddr})
|
||||
return conn, err
|
||||
},
|
||||
}}
|
||||
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s%s", ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, resp)
|
||||
}
|
||||
|
||||
func TestProtobufUdpProcess(t *testing.T) {
|
||||
finish := make(chan struct{}, 0)
|
||||
lAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
udpServer, err := net.ListenUDP("udp", lAddr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
h := &PbAppProcessor{}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{udpServer.LocalAddr().String()},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = h.ProcessConn(enet.NewReaderConn(c))
|
||||
_ = c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
data := []byte{1, 2, 3, 4}
|
||||
dataReturn := []byte{4, 5, 6, 7}
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
_, err = pb.WriteMessage(conn, &pb.AppInfo{AppAddr: udpServer.LocalAddr().String(), ConnType: pb.ConnType_udp})
|
||||
|
||||
go func() {
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := udpServer.ReadFrom(b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, b[:n], data)
|
||||
|
||||
_, err = udpServer.WriteTo(dataReturn, addr)
|
||||
assert.NoError(t, err)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
c := enet.NewTcpPacketConn(conn)
|
||||
_, err = c.WriteTo(data, udpServer.LocalAddr())
|
||||
assert.NoError(t, err)
|
||||
|
||||
<-finish
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := c.ReadFrom(b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataReturn, b[:n])
|
||||
assert.Equal(t, addr.String(), udpServer.LocalAddr().String())
|
||||
}
|
29
core/process/pb_ping.go
Normal file
29
core/process/pb_ping.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PbPingProcessor struct {
|
||||
DefaultProcess
|
||||
}
|
||||
|
||||
func (pp *PbPingProcessor) GetName() string {
|
||||
return "pb_ping"
|
||||
}
|
||||
|
||||
func (pp *PbPingProcessor) ProcessConn(c enet.Conn) (bool, error) {
|
||||
m := &pb.ClientRequest{}
|
||||
_, err := pb.ReadMessage(c, m)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if _, ok := m.ConnType.(*pb.ClientRequest_Ping); !ok {
|
||||
return false, nil
|
||||
}
|
||||
m.GetPing().Now = time.Now().String()
|
||||
_, err = pb.WriteMessage(c, m)
|
||||
return true, err
|
||||
}
|
43
core/process/pb_ping_test.go
Normal file
43
core/process/pb_ping_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/pb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPbPingProcess(t *testing.T) {
|
||||
|
||||
h := &PbPingProcessor{}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{},
|
||||
}
|
||||
ac.Init()
|
||||
err = h.Init(ac)
|
||||
assert.NoError(t, err)
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = h.ProcessConn(enet.NewReaderConn(c))
|
||||
_ = c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
_, err = pb.WriteMessage(conn, &pb.Ping{Now: time.Now().String()})
|
||||
assert.NoError(t, err)
|
||||
m := &pb.Ping{}
|
||||
_, err = pb.ReadMessage(conn, m)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, m.Now)
|
||||
}
|
51
core/process/process.go
Normal file
51
core/process/process.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Process = (*DefaultProcess)(nil)
|
||||
_ Process = (*HttpServeProcess)(nil)
|
||||
_ Process = (*HttpsServeProcess)(nil)
|
||||
_ Process = (*HttpProxyProcess)(nil)
|
||||
_ Process = (*HttpsProxyProcess)(nil)
|
||||
_ Process = (*HttpsRedirectProcess)(nil)
|
||||
_ Process = (*Socks5Process)(nil)
|
||||
_ Process = (*TransparentProcess)(nil)
|
||||
)
|
||||
|
||||
type Process interface {
|
||||
Init(action action.Action) error
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
ProcessConn(enet.Conn) (bool, error)
|
||||
ProcessPacketConn(enet.PacketConn) (bool, error)
|
||||
}
|
||||
|
||||
type DefaultProcess struct {
|
||||
ac action.Action
|
||||
}
|
||||
|
||||
func (bp *DefaultProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
return true, bp.ac.RunConn(c)
|
||||
}
|
||||
|
||||
func (bp *DefaultProcess) GetName() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (bp *DefaultProcess) GetZhName() string {
|
||||
return "默认"
|
||||
}
|
||||
|
||||
// Init the action of process
|
||||
func (bp *DefaultProcess) Init(ac action.Action) error {
|
||||
bp.ac = ac
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bp *DefaultProcess) ProcessPacketConn(pc enet.PacketConn) (bool, error) {
|
||||
return true, bp.ac.RunPacketConn(pc)
|
||||
}
|
169
core/process/serve.go
Normal file
169
core/process/serve.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"github.com/gin-contrib/cache"
|
||||
"github.com/gin-contrib/cache/persistence"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HttpServe struct {
|
||||
engine *gin.Engine
|
||||
ln net.Listener
|
||||
ac action.Action
|
||||
httpServe *http.Server
|
||||
cacheStore *persistence.InMemoryStore
|
||||
cacheTime time.Duration
|
||||
cachePath []string
|
||||
headerModify map[string]string
|
||||
hostModify string
|
||||
addOrigin bool
|
||||
reverseProxy *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
func NewHttpServe(ln net.Listener, ac action.Action) *HttpServe {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
hs := &HttpServe{
|
||||
ln: ln,
|
||||
ac: ac,
|
||||
engine: gin.New(),
|
||||
}
|
||||
hs.httpServe = &http.Server{
|
||||
Handler: hs.engine,
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
|
||||
}
|
||||
|
||||
hs.reverseProxy = &httputil.ReverseProxy{
|
||||
Director: func(request *http.Request) {
|
||||
_ = hs.transport(request)
|
||||
hs.doModify(request)
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10000,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return ac.GetServerConn()
|
||||
},
|
||||
},
|
||||
}
|
||||
serverHttp := func(w http.ResponseWriter, r *http.Request) {
|
||||
hs.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
hs.engine.NoRoute(func(c *gin.Context) {
|
||||
cached := false
|
||||
for _, p := range hs.cachePath {
|
||||
if strings.Contains(c.Request.RequestURI, p) {
|
||||
cached = true
|
||||
cache.CachePage(hs.cacheStore, hs.cacheTime, func(c *gin.Context) {
|
||||
serverHttp(c.Writer, c.Request)
|
||||
})(c)
|
||||
}
|
||||
}
|
||||
if !cached {
|
||||
serverHttp(c.Writer, c.Request)
|
||||
}
|
||||
})
|
||||
hs.engine.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
logger.Debug("http request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("proto", param.Request.Proto),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.String("error_message", param.ErrorMessage),
|
||||
zap.Int("response_code", param.StatusCode),
|
||||
)
|
||||
return ""
|
||||
}))
|
||||
return hs
|
||||
}
|
||||
|
||||
func (hs *HttpServe) transport(req *http.Request) error {
|
||||
ruri := req.URL.RequestURI()
|
||||
req.URL.Scheme = "http"
|
||||
if req.URL.Scheme != "" && req.URL.Opaque == "" {
|
||||
ruri = req.URL.Scheme + "://" + req.Host + ruri
|
||||
} else if req.Method == "CONNECT" && req.URL.Path == "" {
|
||||
// CONNECT requests normally give just the host and port, not a full URL.
|
||||
ruri = req.Host
|
||||
if req.URL.Opaque != "" {
|
||||
ruri = req.URL.Opaque
|
||||
}
|
||||
}
|
||||
req.RequestURI = ""
|
||||
var err error
|
||||
req.URL, err = url.Parse(ruri)
|
||||
return err
|
||||
}
|
||||
|
||||
func (hs *HttpServe) SetBasicAuth(accounts map[string]string) {
|
||||
hs.engine.Use(gin.BasicAuth(accounts), gin.Recovery())
|
||||
}
|
||||
|
||||
func (hs *HttpServe) SetCache(cachePath []string, cacheTime time.Duration) {
|
||||
hs.cachePath = cachePath
|
||||
hs.cacheTime = cacheTime
|
||||
hs.cacheStore = persistence.NewInMemoryStore(cacheTime * time.Second)
|
||||
}
|
||||
|
||||
func (hs *HttpServe) SetModify(headerModify map[string]string, hostModify string, addOrigin bool) {
|
||||
hs.headerModify = headerModify
|
||||
hs.hostModify = hostModify
|
||||
hs.addOrigin = addOrigin
|
||||
return
|
||||
}
|
||||
|
||||
func (hs *HttpServe) Serve() error {
|
||||
return hs.httpServe.Serve(hs.ln)
|
||||
}
|
||||
|
||||
func (hs *HttpServe) ServeTLS(certFile string, keyFile string) error {
|
||||
return hs.httpServe.ServeTLS(hs.ln, certFile, keyFile)
|
||||
}
|
||||
|
||||
// doModify is used to modify http request
|
||||
func (hs *HttpServe) doModify(req *http.Request) {
|
||||
if hs.hostModify != "" {
|
||||
req.Host = hs.hostModify
|
||||
}
|
||||
for k, v := range hs.headerModify {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
addr := strings.Split(req.RemoteAddr, ":")[0]
|
||||
if hs.addOrigin {
|
||||
// XFF is setting in reverseProxy
|
||||
req.Header.Set("X-Real-IP", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func writeContentType(w http.ResponseWriter, value []string) {
|
||||
header := w.Header()
|
||||
if val := header["Content-Type"]; len(val) == 0 {
|
||||
header["Content-Type"] = value
|
||||
}
|
||||
}
|
||||
|
||||
type render struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (r *render) Render(writer http.ResponseWriter) error {
|
||||
_, err := io.Copy(writer, r.resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *render) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, []string{r.resp.Header.Get("Content-Type")})
|
||||
}
|
299
core/process/serve_test.go
Normal file
299
core/process/serve_test.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/websocket"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var startHttpOnce sync.Once
|
||||
var startHttpsOnce sync.Once
|
||||
var handleOnce sync.Once
|
||||
var ln net.Listener
|
||||
var lns net.Listener
|
||||
var err error
|
||||
|
||||
func registerHandle() {
|
||||
handleOnce.Do(func() {
|
||||
http.Handle("/ws", websocket.Handler(func(ws *websocket.Conn) {
|
||||
msg := make([]byte, 512)
|
||||
n, err := ws.Read(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ws.Write(msg[:n])
|
||||
}))
|
||||
http.HandleFunc("/now", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(time.Now().String()))
|
||||
})
|
||||
http.HandleFunc("/host", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.Host))
|
||||
})
|
||||
http.HandleFunc("/header/modify", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.Header.Get("modify")))
|
||||
})
|
||||
http.HandleFunc("/origin/xff", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.Header.Get("X-Forwarded-For")))
|
||||
})
|
||||
http.HandleFunc("/origin/xri", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.Header.Get("X-Real-IP")))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func startHttp(t *testing.T) (string, error) {
|
||||
startHttpOnce.Do(func() {
|
||||
ln, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
registerHandle()
|
||||
go http.Serve(ln, nil)
|
||||
|
||||
})
|
||||
|
||||
return ln.Addr().String(), err
|
||||
}
|
||||
|
||||
func startHttps(t *testing.T) (string, error) {
|
||||
startHttpsOnce.Do(func() {
|
||||
lns, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
registerHandle()
|
||||
certFilePath, keyFilePath := createCertFile(t)
|
||||
go http.ServeTLS(lns, nil, certFilePath, keyFilePath)
|
||||
})
|
||||
|
||||
return lns.Addr().String(), err
|
||||
}
|
||||
|
||||
func doRequest(params ...string) (string, error) {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest("GET", params[0], nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req.Header.Set("Connection", "close")
|
||||
if len(params) >= 3 && params[1] != "" {
|
||||
req.SetBasicAuth(params[1], params[2])
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
client.Transport = &http.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return tls.Dial(network, addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "www.github.com",
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return "0", errors.Errorf("respond error, code %d", resp.StatusCode)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func createHttpServe(serverAddr string) (*HttpServe, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ac := &action.LocalAction{
|
||||
DefaultAction: action.DefaultAction{},
|
||||
TargetAddr: []string{serverAddr},
|
||||
}
|
||||
ac.Init()
|
||||
return NewHttpServe(ln, ac), nil
|
||||
}
|
||||
|
||||
func TestHttpServeWebsocket(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
go hs.Serve()
|
||||
|
||||
ws, err := websocket.Dial(fmt.Sprintf("ws://%s/ws", hs.ln.Addr().String()), "", fmt.Sprintf("http://%s/ws", hs.ln.Addr().String()))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer ws.Close() //关闭连接
|
||||
|
||||
sendMsg := []byte("nps")
|
||||
_, err = ws.Write(sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := make([]byte, 512)
|
||||
m, err := ws.Read(msg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, sendMsg, msg[:m])
|
||||
}
|
||||
|
||||
func TestHttpsServeWebsocket(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
cert, key := createCertFile(t)
|
||||
go hs.ServeTLS(cert, key)
|
||||
|
||||
config, err := websocket.NewConfig(fmt.Sprintf("wss://%s/ws", hs.ln.Addr().String()), fmt.Sprintf("https://%s/ws", hs.ln.Addr().String()))
|
||||
assert.NoError(t, err)
|
||||
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
ws, err := websocket.DialConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer ws.Close() //关闭连接
|
||||
|
||||
sendMsg := []byte("nps")
|
||||
_, err = ws.Write(sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := make([]byte, 512)
|
||||
m, err := ws.Read(msg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, sendMsg, msg[:m])
|
||||
}
|
||||
|
||||
func TestHttpServeModify(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
go hs.Serve()
|
||||
|
||||
hs.SetModify(map[string]string{"modify": "test"}, "ehang.io", true)
|
||||
|
||||
rep, err := doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/header/modify"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/host"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ehang.io", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/origin/xff"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/origin/xri"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
}
|
||||
|
||||
func TestHttpsServeModify(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
cert, key := createCertFile(t)
|
||||
go hs.ServeTLS(cert, key)
|
||||
|
||||
hs.SetModify(map[string]string{"modify": "test"}, "ehang.io", true)
|
||||
|
||||
rep, err := doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/header/modify"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/host"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ehang.io", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/origin/xff"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
|
||||
rep, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/origin/xri"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1", rep)
|
||||
}
|
||||
|
||||
func TestHttpServeCache(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
go hs.Serve()
|
||||
hs.SetCache([]string{"now"}, time.Second*10)
|
||||
|
||||
var time1, time2 string
|
||||
time1, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
time2, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, time1)
|
||||
assert.Equal(t, time1, time2)
|
||||
}
|
||||
|
||||
func TestHttpsServeCache(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
cert, key := createCertFile(t)
|
||||
go hs.ServeTLS(cert, key)
|
||||
hs.SetCache([]string{"now"}, time.Second*10)
|
||||
|
||||
var time1, time2 string
|
||||
time1, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
time2, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/now"))
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, time1)
|
||||
assert.Equal(t, time1, time2)
|
||||
}
|
||||
|
||||
func TestHttpServeBasicAuth(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
go hs.Serve()
|
||||
hs.SetBasicAuth(map[string]string{"aaa": "bbb"})
|
||||
_, err = doRequest(fmt.Sprintf("http://%s%s", hs.ln.Addr().String(), "/now"), "aaa", "bbb")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHttpsServeBasicAuth(t *testing.T) {
|
||||
serverAddr, err := startHttp(t)
|
||||
assert.NoError(t, err)
|
||||
hs, err := createHttpServe(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
cert, key := createCertFile(t)
|
||||
go hs.ServeTLS(cert, key)
|
||||
|
||||
hs.SetBasicAuth(map[string]string{"aaa": "bbb"})
|
||||
_, err = doRequest(fmt.Sprintf("https://%s%s", hs.ln.Addr().String(), "/now"), "aaa", "bbb")
|
||||
assert.NoError(t, err)
|
||||
}
|
219
core/process/socks5.go
Normal file
219
core/process/socks5.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"encoding/binary"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/robfig/go-cache"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Socks5Process struct {
|
||||
DefaultProcess
|
||||
Accounts map[string]string `json:"accounts" placeholder:"username1 password1\nusername2 password2" zh_name:"授权账号密码"`
|
||||
ServerIp string `json:"server_ip" placeholder:"123.123.123.123" zh_name:"udp连接地址"`
|
||||
ipStore *cache.Cache
|
||||
}
|
||||
|
||||
const (
|
||||
ipV4 = 1
|
||||
domainName = 3
|
||||
ipV6 = 4
|
||||
connectMethod = 1
|
||||
bindMethod = 2
|
||||
associateMethod = 3
|
||||
// The maximum packet size of any udp Associate packet, based on ethernet's max size,
|
||||
// minus the IP and UDP headers5. IPv4 has a 20 byte header, UDP adds an
|
||||
// additional 4 bytes5. This is a total overhead of 24 bytes5. Ethernet's
|
||||
// max packet size is 1500 bytes, 1500 - 24 = 1476.
|
||||
maxUDPPacketSize = 1476
|
||||
)
|
||||
|
||||
const (
|
||||
succeeded uint8 = iota
|
||||
serverFailure
|
||||
notAllowed
|
||||
networkUnreachable
|
||||
hostUnreachable
|
||||
connectionRefused
|
||||
ttlExpired
|
||||
commandNotSupported
|
||||
addrTypeNotSupported
|
||||
)
|
||||
|
||||
const (
|
||||
UserPassAuth = uint8(2)
|
||||
userAuthVersion = uint8(1)
|
||||
authSuccess = uint8(0)
|
||||
authFailure = uint8(1)
|
||||
)
|
||||
|
||||
func (s5 *Socks5Process) GetName() string {
|
||||
return "socks5"
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) GetZhName() string {
|
||||
return "socks5代理"
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) Init(ac action.Action) error {
|
||||
s5.ipStore = cache.New(time.Minute, time.Minute*2)
|
||||
return s5.DefaultProcess.Init(ac)
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) ProcessConn(c enet.Conn) (bool, error) {
|
||||
return true, s5.handleConn(c)
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) ProcessPacketConn(pc enet.PacketConn) (bool, error) {
|
||||
ip, _, _ := net.SplitHostPort(pc.LocalAddr().String())
|
||||
if _, ok := s5.ipStore.Get(ip); !ok {
|
||||
return false, nil
|
||||
}
|
||||
_, addr, err := pc.FirstPacket()
|
||||
if err != nil {
|
||||
return false, errors.New("addr not found")
|
||||
}
|
||||
return true, s5.ac.RunPacketConn(enet.NewS5PacketConn(pc, addr))
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) handleConn(c enet.Conn) error {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if version := buf[0]; version != 5 {
|
||||
return errors.New("only support socks5")
|
||||
}
|
||||
nMethods := buf[1]
|
||||
|
||||
methods := make([]byte, nMethods)
|
||||
if l, err := c.Read(methods); l != int(nMethods) || err != nil {
|
||||
return errors.New("wrong method")
|
||||
}
|
||||
|
||||
if len(s5.Accounts) > 0 {
|
||||
buf[1] = UserPassAuth
|
||||
_, err := c.Write(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s5.Auth(c); err != nil {
|
||||
return errors.Wrap(err, "auth failed")
|
||||
}
|
||||
} else {
|
||||
buf[1] = 0
|
||||
_, _ = c.Write(buf)
|
||||
}
|
||||
return s5.handleRequest(c)
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) Auth(c enet.Conn) error {
|
||||
header := []byte{0, 0}
|
||||
if _, err := io.ReadAtLeast(c, header, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != userAuthVersion {
|
||||
return errors.New("auth type not support")
|
||||
}
|
||||
userLen := int(header[1])
|
||||
user := make([]byte, userLen)
|
||||
if _, err := io.ReadAtLeast(c, user, userLen); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.Read(header[:1]); err != nil {
|
||||
return errors.New("the length of password is incorrect")
|
||||
}
|
||||
passLen := int(header[0])
|
||||
pass := make([]byte, passLen)
|
||||
if _, err := io.ReadAtLeast(c, pass, passLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := s5.Accounts[string(user)]
|
||||
|
||||
if p == "" || string(pass) != p {
|
||||
_, _ = c.Write([]byte{userAuthVersion, authFailure})
|
||||
return errors.New("auth failure")
|
||||
}
|
||||
|
||||
if _, err := c.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
||||
return errors.Wrap(err, "write auth success")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) handleRequest(c enet.Conn) error {
|
||||
header := make([]byte, 3)
|
||||
|
||||
_, err := io.ReadFull(c, header)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch header[1] {
|
||||
case connectMethod:
|
||||
s5.handleConnect(c)
|
||||
case associateMethod:
|
||||
s5.handleUDP(c)
|
||||
default:
|
||||
s5.sendReply(c, commandNotSupported)
|
||||
c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//enet
|
||||
func (s5 *Socks5Process) handleConnect(c enet.Conn) {
|
||||
addr, err := common.ReadAddr(c)
|
||||
if err != nil {
|
||||
s5.sendReply(c, addrTypeNotSupported)
|
||||
logger.Warn("read socks addr error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
s5.sendReply(c, succeeded)
|
||||
_ = s5.ac.RunConnWithAddr(c, addr.String())
|
||||
return
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) handleUDP(c net.Conn) {
|
||||
_, err := common.ReadAddr(c)
|
||||
if err != nil {
|
||||
s5.sendReply(c, addrTypeNotSupported)
|
||||
logger.Warn("read socks addr error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
ip, _, _ := net.SplitHostPort(c.RemoteAddr().String())
|
||||
s5.ipStore.Set(ip, true, time.Minute)
|
||||
s5.sendReply(c, succeeded)
|
||||
}
|
||||
|
||||
func (s5 *Socks5Process) sendReply(c net.Conn, rep uint8) {
|
||||
reply := []byte{
|
||||
5,
|
||||
rep,
|
||||
0,
|
||||
1,
|
||||
}
|
||||
|
||||
localHost, localPort, _ := net.SplitHostPort(c.LocalAddr().String())
|
||||
if s5.ServerIp != "" {
|
||||
localHost = s5.ServerIp
|
||||
}
|
||||
ipBytes := net.ParseIP(localHost).To4()
|
||||
nPort, _ := strconv.Atoi(localPort)
|
||||
reply = append(reply, ipBytes...)
|
||||
portBytes := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(portBytes, uint16(nPort))
|
||||
reply = append(reply, portBytes...)
|
||||
_, _ = c.Write(reply)
|
||||
}
|
152
core/process/socks5_test.go
Normal file
152
core/process/socks5_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/lib/common"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/proxy"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSocks5ProxyProcess(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
assert.NoError(t, err)
|
||||
h := Socks5Process{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go h.ProcessConn(enet.NewReaderConn(c))
|
||||
}
|
||||
}()
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Proxy: func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(fmt.Sprintf("socks5://%s", ln.Addr().String()))
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestSocks5ProxyProcessAuth(t *testing.T) {
|
||||
sAddr, err := startHttps(t)
|
||||
h := Socks5Process{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
Accounts: map[string]string{"aaa": "bbb"},
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = h.ProcessConn(enet.NewReaderConn(c))
|
||||
_ = c.Close()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
auth := proxy.Auth{
|
||||
User: "aaa",
|
||||
Password: "bbb",
|
||||
}
|
||||
|
||||
dialer, err := proxy.SOCKS5("tcp", ln.Addr().String(), nil, proxy.Direct)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tr := &http.Transport{Dial: dialer.Dial, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
client := &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.Error(t, err)
|
||||
|
||||
dialer, err = proxy.SOCKS5("tcp", ln.Addr().String(), &auth, proxy.Direct)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tr = &http.Transport{Dial: dialer.Dial, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
client = &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
resp, err = client.Get(fmt.Sprintf("https://%s/now", sAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestSocks5ProxyProcessUdp(t *testing.T) {
|
||||
h := Socks5Process{
|
||||
DefaultProcess: DefaultProcess{},
|
||||
}
|
||||
ac := &action.LocalAction{}
|
||||
ac.Init()
|
||||
assert.NoError(t, h.Init(ac))
|
||||
h.ipStore.Set("127.0.0.1", true, time.Minute)
|
||||
|
||||
serverPc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
localPc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
appPc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
data := []byte("test")
|
||||
go func() {
|
||||
p := make([]byte, 1500)
|
||||
n, addr, err := appPc.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p[:n], data)
|
||||
_, err = appPc.WriteTo(data, addr)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
p := make([]byte, 1500)
|
||||
n, addr, err := serverPc.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
pc := enet.NewReaderPacketConn(serverPc, p[:n], addr)
|
||||
err = pc.SendPacket(p[:n], addr)
|
||||
assert.NoError(t, err)
|
||||
b, err := h.ProcessPacketConn(pc)
|
||||
assert.Equal(t, b, true)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
b := []byte{0, 0, 0}
|
||||
pAddr, err := common.ParseAddr(appPc.LocalAddr().String())
|
||||
assert.NoError(t, err)
|
||||
b = append(b, pAddr...)
|
||||
b = append(b, data...)
|
||||
_, err = localPc.WriteTo(b, serverPc.LocalAddr())
|
||||
assert.NoError(t, err)
|
||||
p := make([]byte, 1500)
|
||||
n, _, err := localPc.ReadFrom(p)
|
||||
assert.NoError(t, err)
|
||||
respAddr, err := common.SplitAddr(p[3:])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, respAddr.String(), appPc.LocalAddr().String())
|
||||
assert.Equal(t, p[3+len(respAddr):n], data)
|
||||
}
|
56
core/process/transparent_linux.go
Normal file
56
core/process/transparent_linux.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"strconv"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const SO_ORIGINAL_DST = 80
|
||||
|
||||
type TransparentProcess struct {
|
||||
DefaultProcess
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) GetName() string {
|
||||
return "transparent"
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) GetZhName() string {
|
||||
return "透明代理"
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
addr, err := tp.getAddress(c)
|
||||
if err != nil {
|
||||
logger.Debug("get syscall error", zap.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
return true, tp.ac.RunConnWithAddr(c, addr)
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) getAddress(conn net.Conn) (string, error) {
|
||||
// TODO: IPV6 support
|
||||
sysrawConn, f := conn.(syscall.Conn)
|
||||
if !f {
|
||||
return "", nil
|
||||
}
|
||||
rawConn, err := sysrawConn.SyscallConn()
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
var ip string
|
||||
var port uint16
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
addr, err := syscall.GetsockoptIPv6Mreq(int(fd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ip = net.IP(addr.Multiaddr[4:8]).String()
|
||||
port = uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3])
|
||||
})
|
||||
return net.JoinHostPort(ip, strconv.Itoa(int(port))), nil
|
||||
}
|
23
core/process/transparent_others.go
Normal file
23
core/process/transparent_others.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// +build !linux
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
type TransparentProcess struct {
|
||||
DefaultProcess
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) GetName() string {
|
||||
return "transparent"
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) GetZhName() string {
|
||||
return "透明代理"
|
||||
}
|
||||
|
||||
func (tp *TransparentProcess) ProcessConn(c enet.Conn) (bool, error) {
|
||||
return false, nil
|
||||
}
|
138
core/rule/list.go
Normal file
138
core/rule/list.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/limiter"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/server"
|
||||
"github.com/fatih/structtag"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var orderMap map[string]int
|
||||
var nowOrder = 2<<8 - 1
|
||||
|
||||
type children map[string]*List
|
||||
|
||||
var chains children
|
||||
var limiters children
|
||||
|
||||
func init() {
|
||||
orderMap = make(map[string]int, 0)
|
||||
chains = make(map[string]*List, 0)
|
||||
limiters = make(map[string]*List, 0)
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.HttpHandler{}).Append(&process.HttpServeProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{}, &action.AdminAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.HttpsHandler{}).Append(&process.HttpsServeProcess{HttpServeProcess: process.HttpServeProcess{}}).AppendMany(&action.NpcAction{}, &action.LocalAction{}, &action.AdminAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.HttpsHandler{}).Append(&process.HttpsRedirectProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.HttpHandler{}).Append(&process.HttpProxyProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.HttpsHandler{}).Append(&process.HttpsProxyProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.Socks5Handler{}).Append(&process.Socks5Process{}).AppendMany(&action.LocalAction{}, &action.NpcAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.TransparentHandler{}).Append(&process.TransparentProcess{}).AppendMany(&action.NpcAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.RdpHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.RedisHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
|
||||
chains.Append(&server.UdpServer{}).Append(&handler.DnsHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
// TODO p2p
|
||||
chains.Append(&server.UdpServer{}).Append(&handler.P2PHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.NpcAction{}, &action.LocalAction{})
|
||||
chains.Append(&server.UdpServer{}).Append(&handler.QUICHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.BridgeAction{})
|
||||
chains.Append(&server.UdpServer{}).Append(&handler.Socks5UdpHandler{}).Append(&process.Socks5Process{}).AppendMany(&action.LocalAction{}, &action.NpcAction{})
|
||||
|
||||
chains.Append(&server.TcpServer{}).Append(&handler.DefaultHandler{}).Append(&process.DefaultProcess{}).AppendMany(&action.BridgeAction{}, &action.AdminAction{}, &action.NpcAction{}, &action.LocalAction{})
|
||||
|
||||
limiters.AppendMany(&limiter.RateLimiter{}, &limiter.ConnNumLimiter{}, &limiter.FlowLimiter{}, &limiter.IpConnNumLimiter{})
|
||||
}
|
||||
|
||||
func GetLimiters() children {
|
||||
return limiters
|
||||
}
|
||||
|
||||
func GetChains() children {
|
||||
return chains
|
||||
}
|
||||
|
||||
type NameInterface interface {
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
}
|
||||
|
||||
type List struct {
|
||||
ZhName string `json:"zh_name"`
|
||||
Self interface{} `json:"-"`
|
||||
Field []field `json:"field"`
|
||||
Children children `json:"children"`
|
||||
}
|
||||
|
||||
func (c children) AppendMany(child ...NameInterface) {
|
||||
for _, cd := range child {
|
||||
c.Append(cd)
|
||||
}
|
||||
}
|
||||
|
||||
func (c children) Append(child NameInterface) children {
|
||||
if v, ok := c[child.GetName()]; ok {
|
||||
return v.Children
|
||||
}
|
||||
if _, ok := orderMap[child.GetName()]; !ok {
|
||||
orderMap[child.GetName()] = nowOrder
|
||||
nowOrder--
|
||||
}
|
||||
cd := &List{Self: child, Field: getFieldName(child), Children: make(map[string]*List, 0), ZhName: child.GetZhName()}
|
||||
c[child.GetName()] = cd
|
||||
return cd.Children
|
||||
}
|
||||
|
||||
type field struct {
|
||||
FiledType string `json:"field_type"`
|
||||
FieldName string `json:"field_name"`
|
||||
FieldZhName string `json:"field_zh_name"`
|
||||
FieldRequired bool `json:"field_required"`
|
||||
FieldExample string `json:"field_example"`
|
||||
}
|
||||
|
||||
func getFieldName(structName interface{}, child ...bool) []field {
|
||||
result := make([]field, 0)
|
||||
t := reflect.TypeOf(structName)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return result
|
||||
}
|
||||
fieldNum := t.NumField()
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
if len(child) == 0 && t.Field(i).Type.Kind() == reflect.Struct {
|
||||
value := reflect.ValueOf(structName)
|
||||
if value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
if value.Field(i).CanInterface() {
|
||||
result = append(result, getFieldName(value.Field(i).Interface(), true)...)
|
||||
}
|
||||
}
|
||||
tags, err := structtag.Parse(string(t.Field(i).Tag))
|
||||
if err == nil {
|
||||
tag, err := tags.Get("json")
|
||||
if err == nil {
|
||||
f := field{}
|
||||
f.FiledType = t.Field(i).Type.Kind().String()
|
||||
f.FieldName = tag.Name
|
||||
tag, err = tags.Get("required")
|
||||
if err == nil {
|
||||
f.FieldRequired, _ = strconv.ParseBool(tag.Name)
|
||||
}
|
||||
tag, err = tags.Get("placeholder")
|
||||
if err == nil {
|
||||
f.FieldExample = tag.Name
|
||||
}
|
||||
tag, err = tags.Get("zh_name")
|
||||
if err == nil {
|
||||
f.FieldZhName = tag.Name
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
13
core/rule/list_test.go
Normal file
13
core/rule/list_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/process"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetFields(t *testing.T) {
|
||||
h := process.HttpsServeProcess{HttpServeProcess: process.HttpServeProcess{}}
|
||||
if len(getFieldName(h)) < 3 {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
71
core/rule/rule.go
Normal file
71
core/rule/rule.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/limiter"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/server"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
Server server.Server `json:"server"`
|
||||
Handler handler.Handler `json:"handler"`
|
||||
Process process.Process `json:"process"`
|
||||
Action action.Action `json:"action"`
|
||||
Limiters []limiter.Limiter `json:"limiters"`
|
||||
}
|
||||
|
||||
var servers map[string]server.Server
|
||||
|
||||
func init() {
|
||||
servers = make(map[string]server.Server, 0)
|
||||
}
|
||||
|
||||
func (r *Rule) GetHandler() handler.Handler {
|
||||
return r.Handler
|
||||
}
|
||||
|
||||
func (r *Rule) Init() error {
|
||||
s := r.Server
|
||||
var ok bool
|
||||
if s, ok = servers[r.Server.GetName()+":"+r.Server.GetServerAddr()]; !ok {
|
||||
s = r.Server
|
||||
err := s.Init()
|
||||
servers[r.Server.GetName()+":"+r.Server.GetServerAddr()] = s
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.Serve()
|
||||
}
|
||||
s.RegisterHandle(r)
|
||||
r.Handler.AddRule(r)
|
||||
if err := r.Action.Init(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, l := range r.Limiters {
|
||||
if err := l.Init(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return r.Process.Init(r.Action)
|
||||
}
|
||||
|
||||
func (r *Rule) RunConn(c enet.Conn) (bool, error) {
|
||||
var err error
|
||||
for _, lm := range r.Limiters {
|
||||
if c, err = lm.DoLimit(c); err != nil {
|
||||
return true, errors.Wrap(err, "rule run")
|
||||
}
|
||||
}
|
||||
if err = c.Reset(0); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r.Process.ProcessConn(c)
|
||||
}
|
||||
|
||||
func (r *Rule) RunPacketConn(pc enet.PacketConn) (bool, error) {
|
||||
return r.Process.ProcessPacketConn(pc)
|
||||
}
|
92
core/rule/rule_json.go
Normal file
92
core/rule/rule_json.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/limiter"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/server"
|
||||
"encoding/json"
|
||||
"github.com/pkg/errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type JsonData struct {
|
||||
ObjType string `json:"obj_type"`
|
||||
ObjData string `json:"obj_data"`
|
||||
}
|
||||
|
||||
type JsonRule struct {
|
||||
Name string `json:"name"`
|
||||
Uuid string `json:"uuid"`
|
||||
Status int `json:"status"`
|
||||
Extend int `json:"extend"`
|
||||
Server JsonData `json:"server"`
|
||||
Handler JsonData `json:"handler"`
|
||||
Process JsonData `json:"process"`
|
||||
Action JsonData `json:"action"`
|
||||
Limiters []JsonData `json:"limiters"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
var NotFoundError = errors.New("not found")
|
||||
|
||||
func (jd *JsonRule) ToRule() (*Rule, error) {
|
||||
r := &Rule{Limiters: make([]limiter.Limiter, 0)}
|
||||
s, ok := chains[jd.Server.ObjType]
|
||||
if !ok {
|
||||
return nil, NotFoundError
|
||||
}
|
||||
r.Server = clone(s.Self).(server.Server)
|
||||
err := json.Unmarshal([]byte(jd.Server.ObjData), r.Server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h, ok := s.Children[jd.Handler.ObjType]
|
||||
if !ok {
|
||||
return nil, NotFoundError
|
||||
}
|
||||
r.Handler = clone(h.Self).(handler.Handler)
|
||||
err = json.Unmarshal([]byte(jd.Handler.ObjData), r.Handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, ok := h.Children[jd.Process.ObjType]
|
||||
if !ok {
|
||||
return nil, NotFoundError
|
||||
}
|
||||
r.Process = clone(p.Self).(process.Process)
|
||||
err = json.Unmarshal([]byte(jd.Process.ObjData), r.Process)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a, ok := p.Children[jd.Action.ObjType]
|
||||
if !ok {
|
||||
return nil, NotFoundError
|
||||
}
|
||||
r.Action = clone(a.Self).(action.Action)
|
||||
err = json.Unmarshal([]byte(jd.Action.ObjData), r.Action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range jd.Limiters {
|
||||
l, ok := limiters[v.ObjType]
|
||||
if !ok {
|
||||
return nil, NotFoundError
|
||||
}
|
||||
lm := clone(l.Self).(limiter.Limiter)
|
||||
err = json.Unmarshal([]byte(v.ObjData), lm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Limiters = append(r.Limiters, lm)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func clone(i interface{}) interface{} {
|
||||
v := reflect.ValueOf(i).Elem()
|
||||
vNew := reflect.New(v.Type())
|
||||
vNew.Elem().Set(v)
|
||||
return vNew.Interface()
|
||||
}
|
58
core/rule/rule_json_test.go
Normal file
58
core/rule/rule_json_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/server"
|
||||
"encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
type person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
a := &person{
|
||||
Name: "ALice",
|
||||
Age: 20,
|
||||
}
|
||||
b := clone(a).(*person)
|
||||
assert.Equal(t, a.Name, b.Name)
|
||||
assert.Equal(t, a.Age, b.Age)
|
||||
a.Name = "Bob"
|
||||
a.Age = 21
|
||||
assert.NotEqual(t, a.Name, b.Name)
|
||||
assert.NotEqual(t, a.Age, b.Age)
|
||||
assert.NotEqual(t, reflect.ValueOf(a).Pointer(), reflect.ValueOf(b).Pointer())
|
||||
}
|
||||
|
||||
func getJson(t *testing.T, i interface{}) string {
|
||||
b, err := json.Marshal(i)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, string(b))
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestJsonRule(t *testing.T) {
|
||||
s := &server.TcpServer{ ServerAddr: "127.0.0.1:0"}
|
||||
h := &handler.HttpHandler{}
|
||||
p := &process.HttpServeProcess{}
|
||||
a := &action.LocalAction{}
|
||||
js := JsonRule{
|
||||
Uuid: "",
|
||||
Server: JsonData{s.GetName(), getJson(t, s)},
|
||||
Handler: JsonData{h.GetName(), getJson(t, h)},
|
||||
Process: JsonData{p.GetName(), getJson(t, p)},
|
||||
Action: JsonData{a.GetName(), getJson(t, a)},
|
||||
Limiters: make([]JsonData, 0),
|
||||
}
|
||||
rl, err := js.ToRule()
|
||||
assert.NoError(t, err)
|
||||
err = rl.Init()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, rl.Server.(*server.TcpServer).ServerAddr, "127.0.0.1:0")
|
||||
}
|
45
core/rule/rule_test.go
Normal file
45
core/rule/rule_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/action"
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/limiter"
|
||||
"ehang.io/nps/core/process"
|
||||
"ehang.io/nps/core/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRule(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
r := &Rule{
|
||||
Server: &server.TcpServer{ServerAddr: "127.0.0.1:0"},
|
||||
Handler: &handler.DefaultHandler{},
|
||||
Process: &process.DefaultProcess{},
|
||||
Action: &action.LocalAction{TargetAddr: []string{ln.Addr().String()}},
|
||||
Limiters: make([]limiter.Limiter, 0),
|
||||
}
|
||||
err = r.Init()
|
||||
assert.NoError(t, err)
|
||||
data := []byte("test")
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
b := make([]byte, 1024)
|
||||
n, err := conn.Read(b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, b[:n])
|
||||
_, err = conn.Write(b[:n])
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
conn, err := net.Dial(r.Server.GetName(), r.Server.GetServerAddr())
|
||||
assert.NoError(t, err)
|
||||
_, err = conn.Write(data)
|
||||
assert.NoError(t, err)
|
||||
b := make([]byte, 1024)
|
||||
n, err := conn.Read(b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, b[:n], data)
|
||||
}
|
30
core/rule/sort.go
Normal file
30
core/rule/sort.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package rule
|
||||
|
||||
import "ehang.io/nps/core/process"
|
||||
|
||||
type Sort []*Rule
|
||||
|
||||
func (s Sort) Len() int { return len(s) }
|
||||
func (s Sort) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
// Less rule sort by
|
||||
func (s Sort) Less(i, j int) bool {
|
||||
iHandlerSort := orderMap[s[i].Handler.GetName()]
|
||||
iProcessSort := orderMap[s[i].Process.GetName()]
|
||||
jHandlerSort := orderMap[s[j].Handler.GetName()]
|
||||
jProcessSort := orderMap[s[j].Process.GetName()]
|
||||
iSort := iHandlerSort<<16 | iProcessSort<<8
|
||||
jSort := jHandlerSort<<16 | jProcessSort<<8
|
||||
if vi, ok := s[i].Process.(*process.HttpServeProcess); ok {
|
||||
if vj, ok := s[j].Process.(*process.HttpServeProcess); ok {
|
||||
iSort = iSort | (len(vj.RouteUrl) & (2 ^ 8 - 1))
|
||||
jSort = jSort | (len(vi.RouteUrl) & (2 ^ 8 - 1))
|
||||
}
|
||||
}
|
||||
if vi, ok := s[i].Process.(*process.HttpsServeProcess); ok {
|
||||
if vj, ok := s[j].Process.(*process.HttpsServeProcess); ok {
|
||||
iSort = iSort | (len(vj.RouteUrl) & (2 ^ 8 - 1))
|
||||
jSort = jSort | (len(vi.RouteUrl) & (2 ^ 8 - 1))
|
||||
}
|
||||
}
|
||||
return iSort > jSort
|
||||
}
|
27
core/rule/sort_test.go
Normal file
27
core/rule/sort_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/core/process"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSort_Len(t *testing.T) {
|
||||
r1 := &Rule{Handler: &handler.DefaultHandler{}, Process: &process.TransparentProcess{}}
|
||||
r2 := &Rule{Handler: &handler.DefaultHandler{}, Process: &process.DefaultProcess{}}
|
||||
r3 := &Rule{Handler: &handler.DefaultHandler{}, Process: &process.HttpServeProcess{RouteUrl: "/test/aaa"}}
|
||||
r4 := &Rule{Handler: &handler.DefaultHandler{}, Process: &process.Socks5Process{}}
|
||||
r5 := &Rule{Handler: &handler.DefaultHandler{}, Process: &process.HttpServeProcess{RouteUrl: "/test"}}
|
||||
r6 := &Rule{Handler: &handler.HttpsHandler{}, Process: &process.HttpsProxyProcess{}}
|
||||
s := make(Sort, 0)
|
||||
s = append(s, r1, r2, r3, r4, r5, r6)
|
||||
sort.Sort(s)
|
||||
expected := make(Sort, 0)
|
||||
expected = append(expected, r6, r5, r3, r4, r1, r2)
|
||||
for k, v := range expected {
|
||||
if v != s[k] {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
32
core/server/server.go
Normal file
32
core/server/server.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package server
|
||||
|
||||
import "ehang.io/nps/core/handler"
|
||||
|
||||
type rule interface {
|
||||
handler.RuleRun
|
||||
GetHandler() handler.Handler
|
||||
}
|
||||
|
||||
type Server interface {
|
||||
Init() error
|
||||
Serve()
|
||||
GetServerAddr() string
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
RegisterHandle(rl rule)
|
||||
}
|
||||
|
||||
type BaseServer struct {
|
||||
handlers map[string]handler.Handler
|
||||
}
|
||||
|
||||
func (bs *BaseServer) RegisterHandle(rl rule) {
|
||||
var h handler.Handler
|
||||
var ok bool
|
||||
if h, ok = bs.handlers[rl.GetHandler().GetName()]; !ok {
|
||||
h = rl.GetHandler()
|
||||
bs.handlers[h.GetName()] = h
|
||||
}
|
||||
h.AddRule(rl)
|
||||
return
|
||||
}
|
97
core/server/tcp.go
Normal file
97
core/server/tcp.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"ehang.io/nps/lib/pool"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var bp = pool.NewBufferPool(1500)
|
||||
|
||||
type TcpServer struct {
|
||||
BaseServer
|
||||
ServerAddr string `json:"server_addr" required:"true" placeholder:"0.0.0.0:8080 or :8080" zh_name:"监听地址"`
|
||||
listener net.Listener
|
||||
gp *ants.PoolWithFunc
|
||||
}
|
||||
|
||||
func (cm *TcpServer) GetServerAddr() string {
|
||||
if cm.listener == nil {
|
||||
return cm.ServerAddr
|
||||
}
|
||||
return cm.listener.Addr().String()
|
||||
}
|
||||
|
||||
func (cm *TcpServer) Init() error {
|
||||
var err error
|
||||
cm.handlers = make(map[string]handler.Handler, 0)
|
||||
if err = cm.listen(); err != nil {
|
||||
return err
|
||||
}
|
||||
cm.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
rc := enet.NewReaderConn(i.(net.Conn))
|
||||
buf := bp.Get()
|
||||
defer bp.Put(buf)
|
||||
|
||||
if _, err := io.ReadAtLeast(rc, buf, 3); err != nil {
|
||||
logger.Warn("read handle type fom connection failed", zap.String("remote addr", rc.RemoteAddr().String()))
|
||||
_ = rc.Close()
|
||||
return
|
||||
}
|
||||
logger.Debug("read handle type", zap.Uint8("type 1", buf[0]), zap.Uint8("type 2", buf[1]),
|
||||
zap.Uint8("type 3", buf[2]), zap.String("remote addr", rc.RemoteAddr().String()))
|
||||
|
||||
for _, h := range cm.handlers {
|
||||
err = rc.Reset(0)
|
||||
if err != nil {
|
||||
logger.Warn("reset connection error", zap.Error(err), zap.String("remote addr", rc.RemoteAddr().String()))
|
||||
_ = rc.Close()
|
||||
return
|
||||
}
|
||||
ok, err := h.HandleConn(buf, rc)
|
||||
if err != nil {
|
||||
logger.Warn("handle connection error", zap.Error(err), zap.String("remote addr", rc.RemoteAddr().String()))
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
logger.Debug("handle connection success", zap.String("remote addr", rc.RemoteAddr().String()))
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *TcpServer) GetName() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (cm *TcpServer) GetZhName() string {
|
||||
return "tcp服务"
|
||||
}
|
||||
|
||||
// create a listener accept user and npc
|
||||
func (cm *TcpServer) listen() error {
|
||||
var err error
|
||||
cm.listener, err = net.Listen("tcp", cm.ServerAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *TcpServer) Serve() {
|
||||
for {
|
||||
c, err := cm.listener.Accept()
|
||||
if err != nil {
|
||||
logger.Error("accept enet error", zap.Error(err))
|
||||
break
|
||||
}
|
||||
_ = cm.gp.Invoke(c)
|
||||
}
|
||||
}
|
80
core/server/udp.go
Normal file
80
core/server/udp.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"ehang.io/nps/core/handler"
|
||||
"ehang.io/nps/lib/logger"
|
||||
"github.com/panjf2000/ants/v2"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
)
|
||||
|
||||
type UdpServer struct {
|
||||
ServerAddr string `json:"server_addr" required:"true" placeholder:"0.0.0.0:8080 or :8080" zh_name:"监听地址"`
|
||||
gp *ants.PoolWithFunc
|
||||
packetConn net.PacketConn
|
||||
handlers map[string]handler.Handler
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
n int
|
||||
buf []byte
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func (us *UdpServer) Init() error {
|
||||
us.handlers = make(map[string]handler.Handler, 0)
|
||||
if err := us.listen(); err != nil {
|
||||
return err
|
||||
}
|
||||
var err error
|
||||
us.gp, err = ants.NewPoolWithFunc(1000000, func(i interface{}) {
|
||||
p := i.(*udpPacket)
|
||||
defer bp.Put(p.buf)
|
||||
|
||||
logger.Debug("accept a now packet", zap.String("remote addr", p.addr.String()))
|
||||
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (us *UdpServer) GetServerAddr() string {
|
||||
if us.packetConn == nil {
|
||||
return us.ServerAddr
|
||||
}
|
||||
return us.packetConn.LocalAddr().String()
|
||||
}
|
||||
|
||||
func (us *UdpServer) GetName() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
func (us *UdpServer) GetZhName() string {
|
||||
return "udp服务"
|
||||
}
|
||||
|
||||
func (us *UdpServer) listen() error {
|
||||
addr, err := net.ResolveUDPAddr("udp", us.ServerAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
us.packetConn, err = net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (us *UdpServer) Serve() {
|
||||
for {
|
||||
buf := bp.Get()
|
||||
n, addr, err := us.packetConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
logger.Error("accept packet failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
err = us.gp.Invoke(udpPacket{n: n, buf: buf, addr: addr})
|
||||
if err != nil {
|
||||
logger.Error("Invoke error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user