mirror of
https://github.com/ehang-io/nps.git
synced 2025-07-03 13:10:42 +00:00
112 lines
2.2 KiB
Go
112 lines
2.2 KiB
Go
package rate
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Rate is an implementation of the token bucket added regularly: a
// background goroutine (launched by Start) refills the bucket on a fixed
// tick, Get takes tokens out, and Write puts tokens back.
//
// Field order matters. bucketSurplusSize is updated with sync/atomic, and
// 64-bit atomic operations require 64-bit alignment on 32-bit platforms
// (386, ARM, 32-bit MIPS). Placing it directly after another int64 at the
// start of the struct gives it offset 8. That is aligned whenever the
// struct's first word is, which sync/atomic guarantees for allocated
// structs such as the one NewRate returns.
type Rate struct {
	bucketSize        int64      // capacity of the bucket; add clamps the surplus against it
	bucketSurplusSize int64      // tokens currently available; updated with sync/atomic
	bucketAddSize     int64      // tokens credited on every refill tick
	stopChan          chan bool  // tells the refill goroutine to exit
	nowRate           int64      // rough throughput estimate refreshed on each tick
	cond              *sync.Cond // wakes Get callers that are waiting for tokens
	hasStop           bool       // set by Stop; Get returns an error once it is true
	hasStart          bool       // set by Start so the refill goroutine is launched only once
}
|
|
|
|
// NewRate return token bucket with specified rate
|
|
func NewRate(addSize int64) *Rate {
|
|
r := &Rate{
|
|
bucketSize: addSize * 2,
|
|
bucketSurplusSize: 0,
|
|
bucketAddSize: addSize,
|
|
stopChan: make(chan bool),
|
|
cond: sync.NewCond(new(sync.Mutex)),
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Start is used to add token regularly
|
|
func (r *Rate) Start() {
|
|
if !r.hasStart {
|
|
r.hasStart = true
|
|
go r.session()
|
|
}
|
|
}
|
|
|
|
func (r *Rate) add(size int64) {
|
|
if res := r.bucketSize - r.bucketSurplusSize; res < r.bucketAddSize {
|
|
atomic.AddInt64(&r.bucketSurplusSize, res)
|
|
return
|
|
}
|
|
atomic.AddInt64(&r.bucketSurplusSize, size)
|
|
}
|
|
|
|
// Write is called when add token to bucket
|
|
func (r *Rate) Write(size int64) {
|
|
r.add(size)
|
|
}
|
|
|
|
// Stop is called when not use the rate bucket
|
|
func (r *Rate) Stop() {
|
|
if r.hasStart {
|
|
r.stopChan <- true
|
|
r.hasStop = true
|
|
r.cond.Broadcast()
|
|
}
|
|
}
|
|
|
|
// Get is called when get token from bucket
|
|
func (r *Rate) Get(size int64) error {
|
|
if r.hasStop {
|
|
return errors.New("the rate has closed")
|
|
}
|
|
if r.bucketSurplusSize >= size {
|
|
atomic.AddInt64(&r.bucketSurplusSize, -size)
|
|
return nil
|
|
}
|
|
for {
|
|
r.cond.L.Lock()
|
|
r.cond.Wait()
|
|
if r.bucketSurplusSize >= size {
|
|
r.cond.L.Unlock()
|
|
atomic.AddInt64(&r.bucketSurplusSize, -size)
|
|
return nil
|
|
}
|
|
if r.hasStop {
|
|
return errors.New("the rate has closed")
|
|
}
|
|
r.cond.L.Unlock()
|
|
}
|
|
}
|
|
|
|
// GetNowRate returns the current rate
|
|
// Just a rough number
|
|
func (r *Rate) GetNowRate() int64 {
|
|
return r.nowRate
|
|
}
|
|
|
|
func (r *Rate) session() {
|
|
ticker := time.NewTicker(time.Second * 1)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if rs := r.bucketAddSize - r.bucketSurplusSize; rs > 0 {
|
|
r.nowRate = rs
|
|
} else {
|
|
r.nowRate = r.bucketSize - r.bucketSurplusSize
|
|
}
|
|
r.add(r.bucketAddSize)
|
|
r.cond.Broadcast()
|
|
case <-r.stopChan:
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|