mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-08 18:39:01 +00:00
add new file
This commit is contained in:
43
core/limiter/conn_num.go
Normal file
43
core/limiter/conn_num.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnNumLimiter is used to limit the connection num of a service
|
||||
type ConnNumLimiter struct {
|
||||
baseLimiter
|
||||
nowNum int32
|
||||
MaxConnNum int32 `json:"max_conn_num" required:"true" placeholder:"10" zh_name:"最大连接数"` //0 means not limit
|
||||
}
|
||||
|
||||
func (cl *ConnNumLimiter) GetName() string {
|
||||
return "conn_num"
|
||||
}
|
||||
|
||||
func (cl *ConnNumLimiter) GetZhName() string {
|
||||
return "总连接数限制"
|
||||
}
|
||||
|
||||
// DoLimit return an error if the connection num exceed the maximum
|
||||
func (cl *ConnNumLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
if atomic.AddInt32(&cl.nowNum, 1) > cl.MaxConnNum && cl.MaxConnNum > 0 {
|
||||
atomic.AddInt32(&cl.nowNum, -1)
|
||||
return nil, errors.New("exceed maximum number of connections")
|
||||
}
|
||||
return &connNumConn{nowNum: &cl.nowNum}, nil
|
||||
}
|
||||
|
||||
// connNumConn is an implementation of enet.Conn
|
||||
type connNumConn struct {
|
||||
nowNum *int32
|
||||
enet.Conn
|
||||
}
|
||||
|
||||
// Close decrease the connection num
|
||||
func (cn *connNumConn) Close() error {
|
||||
atomic.AddInt32(cn.nowNum, -1)
|
||||
return cn.Conn.Close()
|
||||
}
|
35
core/limiter/conn_num_test.go
Normal file
35
core/limiter/conn_num_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnNumLimiter(t *testing.T) {
|
||||
cl := ConnNumLimiter{MaxConnNum: 5}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowNum := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
nowNum++
|
||||
_, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
if nowNum > 5 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 6; i > 0; i-- {
|
||||
go net.Dial("tcp", ln.Addr().String())
|
||||
}
|
||||
<-close
|
||||
}
|
87
core/limiter/flow.go
Normal file
87
core/limiter/flow.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// FlowStore is an interface to store or get the flow now
|
||||
type FlowStore interface {
|
||||
GetOutIn() (uint32, uint32)
|
||||
AddOut(out uint32) uint32
|
||||
AddIn(in uint32) uint32
|
||||
}
|
||||
|
||||
// memStore is an implement of FlowStore
|
||||
type memStore struct {
|
||||
nowOut uint32
|
||||
nowIn uint32
|
||||
}
|
||||
|
||||
// GetOutIn return out and in num 0
|
||||
func (m *memStore) GetOutIn() (uint32, uint32) {
|
||||
return m.nowOut, m.nowIn
|
||||
}
|
||||
|
||||
// AddOut is used to add out now
|
||||
func (m *memStore) AddOut(out uint32) uint32 {
|
||||
return atomic.AddUint32(&m.nowOut, out)
|
||||
}
|
||||
|
||||
// AddIn is used to add in now
|
||||
func (m *memStore) AddIn(in uint32) uint32 {
|
||||
return atomic.AddUint32(&m.nowIn, in)
|
||||
}
|
||||
|
||||
// FlowLimiter is used to limit the flow of a service
|
||||
type FlowLimiter struct {
|
||||
Store FlowStore
|
||||
OutLimit uint32 `json:"out_limit" required:"true" placeholder:"1024(kb)" zh_name:"出口最大流量"` //unit: kb, 0 means not limit
|
||||
InLimit uint32 `json:"in_limit" required:"true" placeholder:"1024(kb)" zh_name:"入口最大流量"` //unit: kb, 0 means not limit
|
||||
}
|
||||
|
||||
func (f *FlowLimiter) GetName() string {
|
||||
return "flow"
|
||||
}
|
||||
|
||||
func (f *FlowLimiter) GetZhName() string {
|
||||
return "流量限制"
|
||||
}
|
||||
|
||||
// DoLimit return a flow limited enet.Conn
|
||||
func (f *FlowLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
return &flowConn{fl: f, Conn: c}, nil
|
||||
}
|
||||
|
||||
// Init is used to set out or in num now
|
||||
func (f *FlowLimiter) Init() error {
|
||||
if f.Store == nil {
|
||||
f.Store = &memStore{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// flowConn is an implement of
|
||||
type flowConn struct {
|
||||
enet.Conn
|
||||
fl *FlowLimiter
|
||||
}
|
||||
|
||||
// Read add the in flow num of the service
|
||||
func (fs *flowConn) Read(b []byte) (n int, err error) {
|
||||
n, err = fs.Conn.Read(b)
|
||||
if fs.fl.InLimit > 0 && fs.fl.Store.AddIn(uint32(n)) > fs.fl.InLimit {
|
||||
err = errors.New("exceed the in flow limit")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Write add the out flow num of the service
|
||||
func (fs *flowConn) Write(b []byte) (n int, err error) {
|
||||
n, err = fs.Conn.Write(b)
|
||||
if fs.fl.OutLimit > 0 && fs.fl.Store.AddOut(uint32(n)) > fs.fl.OutLimit {
|
||||
err = errors.New("exceed the out flow limit")
|
||||
}
|
||||
return
|
||||
}
|
59
core/limiter/flow_test.go
Normal file
59
core/limiter/flow_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlowLimiter(t *testing.T) {
|
||||
cl := FlowLimiter{
|
||||
OutLimit: 100,
|
||||
InLimit: 100,
|
||||
}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowBytes := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 10)
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
c, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
nowBytes += n
|
||||
if nowBytes > 100 {
|
||||
assert.Error(t, err)
|
||||
nowBytes = 0
|
||||
for i := 11; i > 0; i-- {
|
||||
n, err = c.Write(bytes.Repeat([]byte{0}, 10))
|
||||
nowBytes += n
|
||||
if nowBytes > 100 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Write(bytes.Repeat([]byte{0}, 10))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
<-close
|
||||
}
|
99
core/limiter/ip_conn_num.go
Normal file
99
core/limiter/ip_conn_num.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ipNumMap is used to store the connection num of a ip address
|
||||
type ipNumMap struct {
|
||||
m map[string]int32
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// AddOrSet is used to add connection num of a ip address
|
||||
func (i *ipNumMap) AddOrSet(key string) {
|
||||
i.Lock()
|
||||
if v, ok := i.m[key]; ok {
|
||||
i.m[key] = v + 1
|
||||
} else {
|
||||
i.m[key] = 1
|
||||
}
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
// SubOrDel is used to decrease connection of a ip address
|
||||
func (i *ipNumMap) SubOrDel(key string) {
|
||||
i.Lock()
|
||||
if v, ok := i.m[key]; ok {
|
||||
i.m[key] = v - 1
|
||||
if i.m[key] == 0 {
|
||||
delete(i.m, key)
|
||||
}
|
||||
}
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
// Get return the connection num of a ip
|
||||
func (i *ipNumMap) Get(key string) int32 {
|
||||
return i.m[key]
|
||||
}
|
||||
|
||||
// IpConnNumLimiter is used to limit the connection num of a service at the same time of same ip
|
||||
type IpConnNumLimiter struct {
|
||||
m *ipNumMap
|
||||
MaxNum int32 `json:"max_num" required:"true" placeholder:"10" zh_name:"单ip最大连接数"`
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (cl *IpConnNumLimiter) GetName() string {
|
||||
return "ip_conn_num"
|
||||
}
|
||||
|
||||
func (cl *IpConnNumLimiter) GetZhName() string {
|
||||
return "单ip连接数限制"
|
||||
}
|
||||
|
||||
// Init the ipNumMap
|
||||
func (cl *IpConnNumLimiter) Init() error {
|
||||
cl.m = &ipNumMap{m: make(map[string]int32)}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoLimit reports whether the connection num of the ip exceed the maximum number
|
||||
// If true, return error
|
||||
func (cl *IpConnNumLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return c, errors.Wrap(err, "split ip addr")
|
||||
}
|
||||
if cl.m.Get(ip) >= cl.MaxNum {
|
||||
return c, errors.Errorf("the ip(%s) exceed the maximum number(%d)", ip, cl.MaxNum)
|
||||
}
|
||||
return NewNumConn(c, ip, cl.m), nil
|
||||
}
|
||||
|
||||
// numConn is an implement of enet.Conn
|
||||
type numConn struct {
|
||||
key string
|
||||
m *ipNumMap
|
||||
enet.Conn
|
||||
}
|
||||
|
||||
// NewNumConn return a numConn
|
||||
func NewNumConn(c enet.Conn, key string, m *ipNumMap) *numConn {
|
||||
m.AddOrSet(key)
|
||||
return &numConn{
|
||||
m: m,
|
||||
key: key,
|
||||
Conn: c,
|
||||
}
|
||||
}
|
||||
|
||||
// Close is used to decrease the connection num of a ip when connection closing
|
||||
func (n *numConn) Close() error {
|
||||
n.m.SubOrDel(n.key)
|
||||
return n.Conn.Close()
|
||||
}
|
35
core/limiter/ip_conn_num_test.go
Normal file
35
core/limiter/ip_conn_num_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIpConnNumLimiter(t *testing.T) {
|
||||
cl := IpConnNumLimiter{MaxNum: 5}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowNum := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
nowNum++
|
||||
_, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
if nowNum > 5 {
|
||||
assert.Error(t, err)
|
||||
close <- struct{}{}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 6; i > 0; i-- {
|
||||
go net.Dial("tcp", ln.Addr().String())
|
||||
}
|
||||
<-close
|
||||
}
|
26
core/limiter/limiter.go
Normal file
26
core/limiter/limiter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Limiter = (*RateLimiter)(nil)
|
||||
_ Limiter = (*ConnNumLimiter)(nil)
|
||||
_ Limiter = (*IpConnNumLimiter)(nil)
|
||||
_ Limiter = (*FlowLimiter)(nil)
|
||||
)
|
||||
|
||||
type Limiter interface {
|
||||
DoLimit(conn enet.Conn) (enet.Conn, error)
|
||||
Init() error
|
||||
GetName() string
|
||||
GetZhName() string
|
||||
}
|
||||
|
||||
type baseLimiter struct {
|
||||
}
|
||||
|
||||
func (bl *baseLimiter) Init() error {
|
||||
return nil
|
||||
}
|
67
core/limiter/rate.go
Normal file
67
core/limiter/rate.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"ehang.io/nps/lib/enet"
|
||||
"ehang.io/nps/lib/rate"
|
||||
)
|
||||
|
||||
// RateLimiter is used to limit the speed of transport
|
||||
type RateLimiter struct {
|
||||
baseLimiter
|
||||
RateLimit int64 `json:"rate_limit" required:"true" placeholder:"10(kb)" zh_name:"最大速度"`
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetName() string {
|
||||
return "rate"
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetZhName() string {
|
||||
return "带宽限制"
|
||||
}
|
||||
|
||||
// Init the rate controller
|
||||
func (rl *RateLimiter) Init() error {
|
||||
if rl.RateLimit > 0 && rl.rate == nil {
|
||||
rl.rate = rate.NewRate(rl.RateLimit)
|
||||
rl.rate.Start()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoLimit return limited Conn
|
||||
func (rl *RateLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
|
||||
return NewRateConn(c, rl.rate), nil
|
||||
}
|
||||
|
||||
// rateConn is used to limiter the rate fo connection
|
||||
type rateConn struct {
|
||||
enet.Conn
|
||||
rate *rate.Rate
|
||||
}
|
||||
|
||||
// NewRateConn return limited connection by rate interface
|
||||
func NewRateConn(rc enet.Conn, rate *rate.Rate) enet.Conn {
|
||||
return &rateConn{
|
||||
Conn: rc,
|
||||
rate: rate,
|
||||
}
|
||||
}
|
||||
|
||||
// Read data and remove capacity from rate pool
|
||||
func (s *rateConn) Read(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Read(b)
|
||||
if s.rate != nil && err == nil {
|
||||
err = s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Write data and remove capacity from rate pool
|
||||
func (s *rateConn) Write(b []byte) (n int, err error) {
|
||||
n, err = s.Conn.Write(b)
|
||||
if s.rate != nil && err == nil {
|
||||
err = s.rate.Get(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
46
core/limiter/rate_test.go
Normal file
46
core/limiter/rate_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ehang.io/nps/lib/enet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
cl := RateLimiter{
|
||||
RateLimit: 100,
|
||||
}
|
||||
assert.NoError(t, cl.Init())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
nowBytes := 0
|
||||
close := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 10)
|
||||
c, err := ln.Accept()
|
||||
assert.NoError(t, err)
|
||||
c, err = cl.DoLimit(enet.NewReaderConn(c))
|
||||
go func() {
|
||||
<-time.After(time.Second * 2)
|
||||
if nowBytes > 500 {
|
||||
t.Fail()
|
||||
}
|
||||
close <- struct{}{}
|
||||
}()
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
nowBytes += n
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(t, err)
|
||||
for i := 11; i > 0; i-- {
|
||||
_, err := c.Write(bytes.Repeat([]byte{0}, 10000))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
<-close
|
||||
}
|
Reference in New Issue
Block a user