mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-02 11:56:53 +00:00
Bug修复+流量限制+带宽限制
This commit is contained in:
@@ -21,12 +21,14 @@ const cryptKey = "1234567812345678"
|
||||
type CryptConn struct {
|
||||
conn net.Conn
|
||||
crypt bool
|
||||
rate *Rate
|
||||
}
|
||||
|
||||
func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
|
||||
func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn {
|
||||
c := new(CryptConn)
|
||||
c.conn = conn
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -42,6 +44,9 @@ func (s *CryptConn) Write(b []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
_, err = s.conn.Write(b)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -72,6 +77,9 @@ func (s *CryptConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
copy(b, rb)
|
||||
n = len(rb)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -79,13 +87,15 @@ type SnappyConn struct {
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
crypt bool
|
||||
rate *Rate
|
||||
}
|
||||
|
||||
func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
|
||||
func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn {
|
||||
c := new(SnappyConn)
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
c.crypt = crypt
|
||||
c.rate = rate
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -101,7 +111,12 @@ func (s *SnappyConn) Write(b []byte) (n int, err error) {
|
||||
if _, err = s.w.Write(b); err != nil {
|
||||
return
|
||||
}
|
||||
err = s.w.Flush()
|
||||
if err = s.w.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -129,6 +144,9 @@ func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
n = len(bs)
|
||||
copy(b, bs)
|
||||
if s.rate != nil {
|
||||
s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,6 +251,10 @@ func (s *Conn) SetAlive() {
|
||||
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
||||
}
|
||||
|
||||
func (s *Conn) SetReadDeadline(t time.Duration) {
|
||||
s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
|
||||
}
|
||||
|
||||
//从tcp报文中解析出host,连接类型等 TODO 多种情况
|
||||
func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
|
||||
var b [32 * 1024]byte
|
||||
@@ -264,19 +286,19 @@ func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.
|
||||
}
|
||||
|
||||
//单独读(加密|压缩)
|
||||
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
|
||||
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
|
||||
if COMPRESS_SNAPY_DECODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt).Read(b)
|
||||
return NewSnappyConn(s.Conn, crypt, rate).Read(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt).Read(b)
|
||||
return NewCryptConn(s.Conn, crypt, rate).Read(b)
|
||||
}
|
||||
|
||||
//单独写(加密|压缩)
|
||||
func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
|
||||
func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) {
|
||||
if COMPRESS_SNAPY_ENCODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt).Write(b)
|
||||
return NewSnappyConn(s.Conn, crypt, rate).Write(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt).Write(b)
|
||||
return NewCryptConn(s.Conn, crypt, rate).Write(b)
|
||||
}
|
||||
|
||||
//写压缩方式,加密
|
||||
@@ -322,6 +344,11 @@ func (s *Conn) WriteSign() (int, error) {
|
||||
return s.Write([]byte(RES_SIGN))
|
||||
}
|
||||
|
||||
//write sign flag
|
||||
func (s *Conn) WriteClose() (int, error) {
|
||||
return s.Write([]byte(RES_CLOSE))
|
||||
}
|
||||
|
||||
//write main
|
||||
func (s *Conn) WriteMain() (int, error) {
|
||||
return s.Write([]byte(WORK_MAIN))
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,7 @@ var (
|
||||
type Flow struct {
|
||||
ExportFlow int64 //出口流量
|
||||
InletFlow int64 //入口流量
|
||||
FlowLimit int64 //流量限制,出口+入口 /M
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -29,7 +31,9 @@ type Client struct {
|
||||
Remark string //备注
|
||||
Status bool //是否开启
|
||||
IsConnect bool //是否连接
|
||||
Flow *Flow
|
||||
RateLimit int //速度限制 /kb
|
||||
Flow *Flow //流量
|
||||
Rate *Rate //速度控制
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
@@ -189,7 +193,9 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
|
||||
defer s.Unlock()
|
||||
for _, v := range s.Clients {
|
||||
if utils.Getverifyval(v.VerifyKey) == vKey && v.Status {
|
||||
v.Addr = addr
|
||||
if arr := strings.Split(addr, ":"); len(arr) > 0 {
|
||||
v.Addr = arr[0]
|
||||
}
|
||||
return v.Id, nil
|
||||
}
|
||||
}
|
||||
@@ -276,21 +282,26 @@ func (s *Csv) LoadClientFromCsv() {
|
||||
post := &Client{
|
||||
Id: GetIntNoErrByStr(item[0]),
|
||||
VerifyKey: item[1],
|
||||
Addr: item[2],
|
||||
Remark: item[3],
|
||||
Status: GetBoolByStr(item[4]),
|
||||
Remark: item[2],
|
||||
Status: GetBoolByStr(item[3]),
|
||||
RateLimit: GetIntNoErrByStr(item[9]),
|
||||
Cnf: &Config{
|
||||
U: item[5],
|
||||
P: item[6],
|
||||
Crypt: GetBoolByStr(item[7]),
|
||||
Mux: GetBoolByStr(item[8]),
|
||||
Compress: item[9],
|
||||
U: item[4],
|
||||
P: item[5],
|
||||
Crypt: GetBoolByStr(item[6]),
|
||||
Mux: GetBoolByStr(item[7]),
|
||||
Compress: item[8],
|
||||
},
|
||||
}
|
||||
if post.Id > s.ClientIncreaseId {
|
||||
s.ClientIncreaseId = post.Id
|
||||
}
|
||||
if post.RateLimit > 0 {
|
||||
post.Rate = NewRate(int64(post.RateLimit * 1024))
|
||||
post.Rate.Start()
|
||||
}
|
||||
post.Flow = new(Flow)
|
||||
post.Flow.FlowLimit = int64(utils.GetIntNoerrByStr(item[10]))
|
||||
clients = append(clients, post)
|
||||
}
|
||||
s.Clients = clients
|
||||
@@ -442,7 +453,6 @@ func (s *Csv) StoreClientsToCsv() {
|
||||
record := []string{
|
||||
strconv.Itoa(client.Id),
|
||||
client.VerifyKey,
|
||||
client.Addr,
|
||||
client.Remark,
|
||||
strconv.FormatBool(client.Status),
|
||||
client.Cnf.U,
|
||||
@@ -450,6 +460,8 @@ func (s *Csv) StoreClientsToCsv() {
|
||||
utils.GetStrByBool(client.Cnf.Crypt),
|
||||
utils.GetStrByBool(client.Cnf.Mux),
|
||||
client.Cnf.Compress,
|
||||
strconv.Itoa(client.RateLimit),
|
||||
strconv.Itoa(int(client.Flow.FlowLimit)),
|
||||
}
|
||||
err := writer.Write(record)
|
||||
if err != nil {
|
||||
|
@@ -12,6 +12,7 @@ var bufPool = sync.Pool{
|
||||
return make([]byte, poolSize)
|
||||
},
|
||||
}
|
||||
|
||||
var BufPoolUdp = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, poolSizeUdp)
|
||||
|
74
utils/rate.go
Normal file
74
utils/rate.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Rate struct {
|
||||
bucketSize int64 //木桶容量
|
||||
bucketSurplusSize int64 //当前桶中体积
|
||||
bucketAddSize int64 //每次加水大小
|
||||
stopChan chan bool //停止
|
||||
}
|
||||
|
||||
func NewRate(addSize int64) *Rate {
|
||||
return &Rate{
|
||||
bucketSize: addSize * 2,
|
||||
bucketSurplusSize: 0,
|
||||
bucketAddSize: addSize,
|
||||
stopChan: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Rate) Start() {
|
||||
go s.session()
|
||||
}
|
||||
|
||||
func (s *Rate) add(size int64) {
|
||||
if (s.bucketSize - s.bucketSurplusSize) < s.bucketAddSize {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&s.bucketSurplusSize, size)
|
||||
}
|
||||
|
||||
//回桶
|
||||
func (s *Rate) ReturnBucket(size int64) {
|
||||
s.add(size)
|
||||
}
|
||||
|
||||
//停止
|
||||
func (s *Rate) Stop() {
|
||||
s.stopChan <- true
|
||||
}
|
||||
|
||||
func (s *Rate) Get(size int64) {
|
||||
if s.bucketSurplusSize >= size {
|
||||
atomic.AddInt64(&s.bucketSurplusSize, -size)
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(time.Millisecond * 100)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if s.bucketSurplusSize >= size {
|
||||
atomic.AddInt64(&s.bucketSurplusSize, -size)
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Rate) session() {
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.add(s.bucketAddSize)
|
||||
case <-s.stopChan:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
23
utils/rate_test.go
Normal file
23
utils/rate_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var rate = NewRate(100 * 1024)
|
||||
|
||||
func TestRate_Get(t *testing.T) {
|
||||
rate.Start()
|
||||
for i := 0; i < 5; i++ {
|
||||
go test(i)
|
||||
}
|
||||
test(5)
|
||||
}
|
||||
|
||||
func test(i int) {
|
||||
for {
|
||||
rate.Get(64 * 1024)
|
||||
log.Println("get ok", i)
|
||||
}
|
||||
}
|
@@ -25,6 +25,7 @@ const (
|
||||
WORK_CHAN = "chan"
|
||||
RES_SIGN = "sign"
|
||||
RES_MSG = "msg0"
|
||||
RES_CLOSE = "clse"
|
||||
CONN_SUCCESS = "sucs"
|
||||
CONN_ERROR = "fail"
|
||||
TEST_FLAG = "tst"
|
||||
@@ -42,24 +43,24 @@ WWW-Authenticate: Basic realm="easyProxy"
|
||||
)
|
||||
|
||||
//copy
|
||||
func Relay(in, out net.Conn, compressType int, crypt, mux bool) (n int64, err error) {
|
||||
func Relay(in, out net.Conn, compressType int, crypt, mux bool, rate *Rate) (n int64, err error) {
|
||||
switch compressType {
|
||||
case COMPRESS_SNAPY_ENCODE:
|
||||
n, err = copyBuffer(NewSnappyConn(in, crypt), out)
|
||||
n, err = copyBuffer(NewSnappyConn(in, crypt, rate), out)
|
||||
out.Close()
|
||||
NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
|
||||
NewSnappyConn(in, crypt, rate).Write([]byte(IO_EOF))
|
||||
case COMPRESS_SNAPY_DECODE:
|
||||
n, err = copyBuffer(in, NewSnappyConn(out, crypt))
|
||||
n, err = copyBuffer(in, NewSnappyConn(out, crypt, rate))
|
||||
in.Close()
|
||||
if !mux {
|
||||
out.Close()
|
||||
}
|
||||
case COMPRESS_NONE_ENCODE:
|
||||
n, err = copyBuffer(NewCryptConn(in, crypt), out)
|
||||
n, err = copyBuffer(NewCryptConn(in, crypt, rate), out)
|
||||
out.Close()
|
||||
NewCryptConn(in, crypt).Write([]byte(IO_EOF))
|
||||
NewCryptConn(in, crypt, rate).Write([]byte(IO_EOF))
|
||||
case COMPRESS_NONE_DECODE:
|
||||
n, err = copyBuffer(in, NewCryptConn(out, crypt))
|
||||
n, err = copyBuffer(in, NewCryptConn(out, crypt, rate))
|
||||
in.Close()
|
||||
if !mux {
|
||||
out.Close()
|
||||
@@ -205,14 +206,14 @@ func Getverifyval(vkey string) string {
|
||||
|
||||
//wait replay group
|
||||
//conn1 网桥 conn2
|
||||
func ReplayWaitGroup(conn1 net.Conn, conn2 net.Conn, compressEncode, compressDecode int, crypt, mux bool) (out int64, in int64) {
|
||||
func ReplayWaitGroup(conn1 net.Conn, conn2 net.Conn, compressEncode, compressDecode int, crypt, mux bool, rate *Rate) (out int64, in int64) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
in, _ = Relay(conn1, conn2, compressEncode, crypt, mux)
|
||||
in, _ = Relay(conn1, conn2, compressEncode, crypt, mux, rate)
|
||||
wg.Done()
|
||||
}()
|
||||
out, _ = Relay(conn2, conn1, compressDecode, crypt, mux)
|
||||
out, _ = Relay(conn2, conn1, compressDecode, crypt, mux, rate)
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user