mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-02 11:56:53 +00:00
MUX optimization
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/cnlh/nps/vender/github.com/xtaci/kcp"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -42,6 +43,7 @@ type Bridge struct {
|
||||
Client map[int]*Client
|
||||
tunnelType string //bridge type kcp or tcp
|
||||
OpenTask chan *file.Tunnel
|
||||
CloseTask chan *file.Tunnel
|
||||
CloseClient chan int
|
||||
SecretChan chan *conn.Secret
|
||||
clientLock sync.RWMutex
|
||||
@@ -57,6 +59,7 @@ func NewTunnel(tunnelPort int, tunnelType string, ipVerify bool, runList map[int
|
||||
t.Client = make(map[int]*Client)
|
||||
t.tunnelType = tunnelType
|
||||
t.OpenTask = make(chan *file.Tunnel)
|
||||
t.CloseTask = make(chan *file.Tunnel)
|
||||
t.CloseClient = make(chan int)
|
||||
t.Register = make(map[string]time.Time)
|
||||
t.ipVerify = ipVerify
|
||||
@@ -106,6 +109,62 @@ func (s *Bridge) StartTunnel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//get health information form client
|
||||
func (s *Bridge) GetHealthFromClient(id int, c *conn.Conn) {
|
||||
for {
|
||||
if info, status, err := c.GetHealthInfo(); err != nil {
|
||||
logs.Error(err)
|
||||
break
|
||||
} else if !status { //the status is true , return target to the targetArr
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if v.Client.Id == id && v.Mode == "tcp" && strings.Contains(v.Target, info) {
|
||||
v.Lock()
|
||||
if v.TargetArr == nil || (len(v.TargetArr) == 0 && len(v.HealthRemoveArr) == 0) {
|
||||
v.TargetArr = common.TrimArr(strings.Split(v.Target, "\n"))
|
||||
}
|
||||
v.TargetArr = common.RemoveArrVal(v.TargetArr, info)
|
||||
if v.HealthRemoveArr == nil {
|
||||
v.HealthRemoveArr = make([]string, 0)
|
||||
}
|
||||
v.HealthRemoveArr = append(v.HealthRemoveArr, info)
|
||||
v.Unlock()
|
||||
}
|
||||
}
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
if v.Client.Id == id && strings.Contains(v.Target, info) {
|
||||
v.Lock()
|
||||
if v.TargetArr == nil || (len(v.TargetArr) == 0 && len(v.HealthRemoveArr) == 0) {
|
||||
v.TargetArr = common.TrimArr(strings.Split(v.Target, "\n"))
|
||||
}
|
||||
v.TargetArr = common.RemoveArrVal(v.TargetArr, info)
|
||||
if v.HealthRemoveArr == nil {
|
||||
v.HealthRemoveArr = make([]string, 0)
|
||||
}
|
||||
v.HealthRemoveArr = append(v.HealthRemoveArr, info)
|
||||
v.Unlock()
|
||||
}
|
||||
}
|
||||
} else { //the status is false,remove target from the targetArr
|
||||
for _, v := range file.GetCsvDb().Tasks {
|
||||
if v.Client.Id == id && v.Mode == "tcp" && common.IsArrContains(v.HealthRemoveArr, info) && !common.IsArrContains(v.TargetArr, info) {
|
||||
v.Lock()
|
||||
v.TargetArr = append(v.TargetArr, info)
|
||||
v.HealthRemoveArr = common.RemoveArrVal(v.HealthRemoveArr, info)
|
||||
v.Unlock()
|
||||
}
|
||||
}
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
if v.Client.Id == id && common.IsArrContains(v.HealthRemoveArr, info) && !common.IsArrContains(v.TargetArr, info) {
|
||||
v.Lock()
|
||||
v.TargetArr = append(v.TargetArr, info)
|
||||
v.HealthRemoveArr = common.RemoveArrVal(v.HealthRemoveArr, info)
|
||||
v.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//验证失败,返回错误验证flag,并且关闭连接
|
||||
func (s *Bridge) verifyError(c *conn.Conn) {
|
||||
c.Write([]byte(common.VERIFY_EER))
|
||||
@@ -187,6 +246,7 @@ func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) {
|
||||
s.Client[id] = NewClient(nil, nil, c)
|
||||
s.clientLock.Unlock()
|
||||
}
|
||||
go s.GetHealthFromClient(id, c)
|
||||
logs.Info("clientId %d connection succeeded, address:%s ", id, c.Conn.RemoteAddr())
|
||||
case common.WORK_CHAN:
|
||||
s.clientLock.Lock()
|
||||
@@ -264,7 +324,7 @@ func (s *Bridge) register(c *conn.Conn) {
|
||||
var hour int32
|
||||
if err := binary.Read(c, binary.LittleEndian, &hour); err == nil {
|
||||
s.registerLock.Lock()
|
||||
s.Register[common.GetIpByAddr(c.Conn.RemoteAddr().String())] = time.Now().Add(time.Hour * time.Duration(hour))
|
||||
s.Register[common.GetIpByAddr(c.Conn.RemoteAddr().String())] = time.Now().Add(time.Minute * time.Duration(hour))
|
||||
s.registerLock.Unlock()
|
||||
}
|
||||
}
|
||||
@@ -282,11 +342,11 @@ func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link, linkAddr string, t
|
||||
s.registerLock.Unlock()
|
||||
return nil, errors.New(fmt.Sprintf("The ip %s is not in the validation list", ip))
|
||||
} else {
|
||||
s.registerLock.Unlock()
|
||||
if !v.After(time.Now()) {
|
||||
return nil, errors.New(fmt.Sprintf("The validity of the ip %s has expired", ip))
|
||||
}
|
||||
}
|
||||
s.registerLock.Unlock()
|
||||
}
|
||||
var tunnel *mux.Mux
|
||||
if t != nil && t.Mode == "file" {
|
||||
@@ -311,7 +371,6 @@ func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link, linkAddr string, t
|
||||
logs.Info("new connect error ,the target %s refuse to connect", link.Host)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
s.clientLock.Unlock()
|
||||
err = errors.New(fmt.Sprintf("the client %d is not connect", clientId))
|
||||
@@ -366,6 +425,7 @@ loop:
|
||||
if err != nil {
|
||||
break loop
|
||||
}
|
||||
file.GetCsvDb().Lock()
|
||||
for _, v := range file.GetCsvDb().Hosts {
|
||||
if v.Client.Id == id {
|
||||
str += v.Remark + common.CONN_DATA_SEQ
|
||||
@@ -376,6 +436,7 @@ loop:
|
||||
str += v.Remark + common.CONN_DATA_SEQ
|
||||
}
|
||||
}
|
||||
file.GetCsvDb().Unlock()
|
||||
binary.Write(c, binary.LittleEndian, int32(len([]byte(str))))
|
||||
binary.Write(c, binary.LittleEndian, []byte(str))
|
||||
}
|
||||
|
Reference in New Issue
Block a user