nps/core/limiter/rate.go
2022-01-23 17:30:38 +08:00

68 lines
1.4 KiB
Go

package limiter
import (
"ehang.io/nps/lib/enet"
"ehang.io/nps/lib/rate"
)
// RateLimiter is used to limit the speed of transport. Every connection
// passed to DoLimit is wrapped with the same rate pool (rl.rate), so the
// limit applies to the combined traffic of all wrapped connections.
type RateLimiter struct {
	baseLimiter

	// RateLimit is the configured maximum speed (zh_name "最大速度" means
	// "maximum speed"). Values <= 0 leave limiting disabled.
	// NOTE(review): the placeholder tag suggests the unit is KB, but Init
	// passes this value to rate.NewRate unchanged — confirm which unit
	// rate.NewRate expects.
	RateLimit int64 `json:"rate_limit" required:"true" placeholder:"10(kb)" zh_name:"最大速度"`

	// rate is the shared pool created by Init; it stays nil until Init is
	// called with RateLimit > 0.
	rate *rate.Rate
}
// GetName returns the name of this limiter type, "rate".
func (rl *RateLimiter) GetName() string {
	const name = "rate"
	return name
}
// GetZhName returns the Chinese name of this limiter type,
// "带宽限制" ("bandwidth limit").
func (rl *RateLimiter) GetZhName() string {
	const zhName = "带宽限制"
	return zhName
}
// Init creates and starts the shared rate pool used by DoLimit.
//
// It does nothing when RateLimit <= 0 (limiting disabled) or when the pool
// already exists, so calling it more than once is harmless. It currently
// always returns nil.
//
// NOTE(review): RateLimit is handed to rate.NewRate as-is although the field's
// placeholder suggests KB; confirm the unit rate.NewRate expects.
func (rl *RateLimiter) Init() error {
	if rl.RateLimit <= 0 || rl.rate != nil {
		return nil
	}
	rl.rate = rate.NewRate(rl.RateLimit)
	rl.rate.Start()
	return nil
}
// DoLimit returns c wrapped so that its reads and writes are charged against
// the limiter's shared rate pool. The error result is always nil.
//
// If the pool is nil (RateLimit <= 0, or Init has not run yet), the wrapper
// passes traffic through without throttling.
func (rl *RateLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
	limited := NewRateConn(c, rl.rate)
	return limited, nil
}
// rateConn wraps an enet.Conn so that the bytes moved by each Read and Write
// are charged against a rate.Rate pool. A nil rate disables limiting.
type rateConn struct {
	enet.Conn
	// rate is the pool charged for transferred bytes; it may be nil.
	rate *rate.Rate
}
// NewRateConn returns rc wrapped so that every Read and Write is charged
// against r. Passing a nil r yields a wrapper that does not throttle.
//
// The parameter is named r rather than rate so it does not shadow the
// imported rate package inside the function body.
func NewRateConn(rc enet.Conn, r *rate.Rate) enet.Conn {
	return &rateConn{
		Conn: rc,
		rate: r,
	}
}
// Read reads from the underlying connection, then charges the bytes actually
// received against the rate pool via rate.Get. (Whether Get waits for
// capacity is defined by the rate package and is not visible here.)
//
// The io.Reader contract allows n > 0 together with a non-nil error
// (e.g. io.EOF), so bytes are charged whenever n > 0, not only on success.
// An error from the underlying connection takes precedence over an error
// returned by the rate pool.
func (s *rateConn) Read(b []byte) (n int, err error) {
	n, err = s.Conn.Read(b)
	if s.rate == nil || n <= 0 {
		return n, err
	}
	if rateErr := s.rate.Get(int64(n)); rateErr != nil && err == nil {
		err = rateErr
	}
	return n, err
}
// Write writes b to the underlying connection, then charges the bytes
// actually written against the rate pool via rate.Get.
//
// A failed Write may still report n > 0 bytes sent (the io.Writer contract
// only requires a non-nil error when n < len(b)), so bytes are charged
// whenever n > 0. An error from the underlying connection takes precedence
// over an error returned by the rate pool.
func (s *rateConn) Write(b []byte) (n int, err error) {
	n, err = s.Conn.Write(b)
	if s.rate == nil || n <= 0 {
		return n, err
	}
	if rateErr := s.rate.Get(int64(n)); rateErr != nil && err == nil {
		err = rateErr
	}
	return n, err
}