MUX optimization

This commit is contained in:
刘河
2019-03-15 14:03:49 +08:00
parent f78e81b452
commit 97330bfbdc
33 changed files with 749 additions and 328 deletions

View File

@@ -32,7 +32,7 @@ type Csv struct {
ClientIncreaseId int //客户端id
TaskIncreaseId int //任务自增ID
HostIncreaseId int
sync.Mutex
sync.RWMutex
}
func (s *Csv) StoreTasksToCsv() {
@@ -43,6 +43,7 @@ func (s *Csv) StoreTasksToCsv() {
}
defer csvFile.Close()
writer := csv.NewWriter(csvFile)
s.Lock()
for _, task := range s.Tasks {
if task.NoStore {
continue
@@ -64,6 +65,7 @@ func (s *Csv) StoreTasksToCsv() {
logs.Error(err.Error())
}
}
s.Unlock()
writer.Flush()
}
@@ -147,6 +149,7 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
}
func (s *Csv) NewTask(t *Tunnel) error {
s.Lock()
for _, v := range s.Tasks {
if (v.Mode == "secret" || v.Mode == "p2p") && v.Password == t.Password {
return errors.New(fmt.Sprintf("Secret mode keys %s must be unique", t.Password))
@@ -154,33 +157,42 @@ func (s *Csv) NewTask(t *Tunnel) error {
}
t.Flow = new(Flow)
s.Tasks = append(s.Tasks, t)
s.Unlock()
s.StoreTasksToCsv()
return nil
}
func (s *Csv) UpdateTask(t *Tunnel) error {
s.Lock()
for _, v := range s.Tasks {
if v.Id == t.Id {
s.Unlock()
s.StoreTasksToCsv()
return nil
}
}
s.Unlock()
return errors.New("the task is not exist")
}
func (s *Csv) DelTask(id int) error {
s.Lock()
for k, v := range s.Tasks {
if v.Id == id {
s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...)
s.Unlock()
s.StoreTasksToCsv()
return nil
}
}
s.Unlock()
return errors.New("不存在")
}
//md5 password
func (s *Csv) GetTaskByMd5Password(p string) *Tunnel {
s.Lock()
defer s.Unlock()
for _, v := range s.Tasks {
if crypt.Md5(v.Password) == p {
return v
@@ -190,6 +202,8 @@ func (s *Csv) GetTaskByMd5Password(p string) *Tunnel {
}
func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
s.Lock()
defer s.Unlock()
for _, v = range s.Tasks {
if v.Id == id {
return
@@ -210,6 +224,8 @@ func (s *Csv) StoreHostToCsv() {
writer := csv.NewWriter(csvFile)
// 将map中的Post转换成slice因为csv的Write需要slice参数
// 并写入csv文件
s.Lock()
defer s.Unlock()
for _, host := range s.Hosts {
if host.NoStore {
continue
@@ -313,17 +329,22 @@ func (s *Csv) LoadHostFromCsv() {
}
func (s *Csv) DelHost(id int) error {
s.Lock()
for k, v := range s.Hosts {
if v.Id == id {
s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...)
s.Unlock()
s.StoreHostToCsv()
return nil
}
}
s.Unlock()
return errors.New("不存在")
}
func (s *Csv) IsHostExist(h *Host) bool {
s.Lock()
defer s.Unlock()
for _, v := range s.Hosts {
if v.Host == h.Host && h.Location == v.Location && (v.Scheme == "all" || v.Scheme == h.Scheme) {
return true
@@ -340,24 +361,31 @@ func (s *Csv) NewHost(t *Host) error {
t.Location = "/"
}
t.Flow = new(Flow)
s.Lock()
s.Hosts = append(s.Hosts, t)
s.Unlock()
s.StoreHostToCsv()
return nil
}
func (s *Csv) UpdateHost(t *Host) error {
s.Lock()
for _, v := range s.Hosts {
if v.Host == t.Host {
s.Unlock()
s.StoreHostToCsv()
return nil
}
}
s.Unlock()
return errors.New("不存在")
}
func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) {
list := make([]*Host, 0)
var cnt int
s.Lock()
defer s.Unlock()
for _, v := range s.Hosts {
if id == 0 || v.Client.Id == id {
cnt++
@@ -372,13 +400,16 @@ func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) {
}
func (s *Csv) DelClient(id int) error {
s.Lock()
for k, v := range s.Clients {
if v.Id == id {
s.Clients = append(s.Clients[:k], s.Clients[k+1:]...)
s.Unlock()
s.StoreClientsToCsv()
return nil
}
}
s.Unlock()
return errors.New("不存在")
}
@@ -402,13 +433,15 @@ reset:
c.Flow = new(Flow)
}
s.Lock()
defer s.Unlock()
s.Clients = append(s.Clients, c)
s.Unlock()
s.StoreClientsToCsv()
return nil
}
func (s *Csv) VerifyVkey(vkey string, id int) bool {
s.Lock()
defer s.Unlock()
for _, v := range s.Clients {
if v.VerifyKey == vkey && v.Id != id {
return false
@@ -426,7 +459,6 @@ func (s *Csv) GetClientId() int {
func (s *Csv) UpdateClient(t *Client) error {
s.Lock()
defer s.Unlock()
for _, v := range s.Clients {
if v.Id == t.Id {
v.Cnf = t.Cnf
@@ -435,16 +467,20 @@ func (s *Csv) UpdateClient(t *Client) error {
v.RateLimit = t.RateLimit
v.Flow = t.Flow
v.Rate = t.Rate
s.Unlock()
s.StoreClientsToCsv()
return nil
}
}
s.Unlock()
return errors.New("该客户端不存在")
}
func (s *Csv) GetClientList(start, length int) ([]*Client, int) {
list := make([]*Client, 0)
var cnt int
s.Lock()
defer s.Unlock()
for _, v := range s.Clients {
if v.NoDisplay {
continue
@@ -460,6 +496,8 @@ func (s *Csv) GetClientList(start, length int) ([]*Client, int) {
}
func (s *Csv) GetClient(id int) (v *Client, err error) {
s.Lock()
defer s.Unlock()
for _, v = range s.Clients {
if v.Id == id {
return
@@ -469,6 +507,8 @@ func (s *Csv) GetClient(id int) (v *Client, err error) {
return
}
func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
s.Lock()
defer s.Unlock()
for _, v := range s.Clients {
if crypt.Md5(v.VerifyKey) == vkey {
id = v.Id
@@ -480,6 +520,8 @@ func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
}
func (s *Csv) GetHostById(id int) (h *Host, err error) {
s.Lock()
defer s.Unlock()
for _, v := range s.Hosts {
if v.Id == id {
h = v
@@ -495,7 +537,12 @@ func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) {
var hosts []*Host
//Handling Ported Access
host = common.GetIpByAddr(host)
s.Lock()
defer s.Unlock()
for _, v := range s.Hosts {
if v.IsClose {
continue
}
//Remove http(s) http(s)://a.proxy.com
//*.proxy.com *.a.proxy.com Do some pan-parsing
tmp := strings.Replace(v.Host, "*", `\w+?`, -1)
@@ -533,6 +580,8 @@ func (s *Csv) StoreClientsToCsv() {
}
defer csvFile.Close()
writer := csv.NewWriter(csvFile)
s.Lock()
defer s.Unlock()
for _, client := range s.Clients {
if client.NoStore {
continue

View File

@@ -2,6 +2,7 @@ package file
import (
"github.com/cnlh/nps/lib/rate"
"github.com/pkg/errors"
"strings"
"sync"
"time"
@@ -78,7 +79,14 @@ func (s *Client) GetConn() bool {
return false
}
//modify the hosts and the tunnels by health information
func (s *Client) ModifyTarget() {
}
func (s *Client) HasTunnel(t *Tunnel) bool {
GetCsvDb().Lock()
defer GetCsvDb().Unlock()
for _, v := range GetCsvDb().Tasks {
if v.Client.Id == s.Id && v.Port == t.Port {
return true
@@ -88,6 +96,8 @@ func (s *Client) HasTunnel(t *Tunnel) bool {
}
func (s *Client) HasHost(h *Host) bool {
GetCsvDb().Lock()
defer GetCsvDb().Unlock()
for _, v := range GetCsvDb().Hosts {
if v.Client.Id == s.Id && v.Host == h.Host && h.Location == v.Location {
return true
@@ -126,14 +136,19 @@ type Health struct {
HealthMap map[string]int
HttpHealthUrl string
HealthRemoveArr []string
HealthCheckType string
HealthCheckTarget string
}
func (s *Tunnel) GetRandomTarget() string {
func (s *Tunnel) GetRandomTarget() (string, error) {
if s.TargetArr == nil {
s.TargetArr = strings.Split(s.Target, "\n")
}
if len(s.TargetArr) == 1 {
return s.TargetArr[0]
return s.TargetArr[0], nil
}
if len(s.TargetArr) == 0 {
return "", errors.New("all inward-bending targets are offline")
}
s.Lock()
defer s.Unlock()
@@ -141,7 +156,7 @@ func (s *Tunnel) GetRandomTarget() string {
s.NowIndex = -1
}
s.NowIndex++
return s.TargetArr[s.NowIndex]
return s.TargetArr[s.NowIndex], nil
}
type Config struct {
@@ -165,23 +180,26 @@ type Host struct {
TargetArr []string
NoStore bool
Scheme string //http https all
IsClose bool
Health
sync.RWMutex
}
func (s *Host) GetRandomTarget() string {
func (s *Host) GetRandomTarget() (string, error) {
if s.TargetArr == nil {
s.TargetArr = strings.Split(s.Target, "\n")
}
if len(s.TargetArr) == 1 {
return s.TargetArr[0]
return s.TargetArr[0], nil
}
if len(s.TargetArr) == 0 {
return "", errors.New("all inward-bending targets are offline")
}
s.Lock()
defer s.Unlock()
if s.NowIndex >= len(s.TargetArr)-1 {
s.NowIndex = -1
} else {
s.NowIndex++
}
return s.TargetArr[s.NowIndex]
s.NowIndex++
return s.TargetArr[s.NowIndex], nil
}