nps/lib/enet/paket.go
2022-01-23 17:30:38 +08:00

162 lines
4.0 KiB
Go

package enet
import (
"bytes"
"ehang.io/nps/lib/common"
"ehang.io/nps/lib/pool"
"github.com/pkg/errors"
"net"
"sync/atomic"
"time"
)
var (
_ net.PacketConn = (*TcpPacketConn)(nil)
_ PacketConn = (*ReaderPacketConn)(nil)
)
type PacketConn interface {
net.PacketConn
SendPacket([]byte, net.Addr) error
FirstPacket() ([]byte, net.Addr, error)
}
var udpBp = pool.NewBufferPool(1500)
// TcpPacketConn is an implement of net.PacketConn by net.Conn
type TcpPacketConn struct {
udpBp []byte
net.Conn
}
// NewTcpPacketConn return a *TcpPacketConn
func NewTcpPacketConn(conn net.Conn) *TcpPacketConn {
return &TcpPacketConn{Conn: conn}
}
// ReadFrom is a implement of net.PacketConn ReadFrom
func (tp *TcpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
b := udpBp.Get()
defer udpBp.Put(b)
n, err = common.ReadLenBytes(tp.Conn, b)
if err != nil {
return
}
rAddr, err := common.ReadAddr(bytes.NewReader(b[:n]))
if err != nil {
return
}
n = copy(p, b[len(rAddr):n])
addr, err = net.ResolveUDPAddr("udp", rAddr.String())
return
}
// WriteTo is a implement of net.PacketConn WriteTo
func (tp *TcpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
var pAddr common.Addr
pAddr, err = common.ParseAddr(addr.String())
if err != nil {
return
}
return common.WriteLenBytes(tp.Conn, append(pAddr, p...))
}
// ReaderPacketConn is an implementation of net.PacketConn
type ReaderPacketConn struct {
ch chan *packet
closeCh chan struct{}
closed int32
nowNum int32
addr net.Addr
writePacketConn net.PacketConn
readTimer *time.Timer
firstPacket []byte
}
type packet struct {
b []byte
addr net.Addr
}
// NewReaderPacketConn returns an initialized PacketConn
func NewReaderPacketConn(writePacketConn net.PacketConn, firstPacket []byte, addr net.Addr) *ReaderPacketConn {
return &ReaderPacketConn{
ch: make(chan *packet, 10),
closeCh: make(chan struct{}),
addr: addr,
writePacketConn: writePacketConn,
readTimer: time.NewTimer(time.Hour * 24 * 3650),
firstPacket: firstPacket,
}
}
func (pc *ReaderPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
var pt *packet
select {
case pt = <-pc.ch:
case <-pc.readTimer.C:
}
if pt == nil {
return 0, nil, errors.New("the PacketConn is already closed")
}
copy(p, pt.b)
return len(pt.b), pt.addr, nil
}
func (pc *ReaderPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return pc.writePacketConn.WriteTo(p, addr)
}
// LocalAddr returns the listener's address
func (pc *ReaderPacketConn) LocalAddr() net.Addr {
return pc.addr
}
func (pc *ReaderPacketConn) SetDeadline(t time.Time) error {
pc.readTimer.Reset(t.Sub(time.Now()))
return pc.writePacketConn.SetWriteDeadline(t)
}
func (pc *ReaderPacketConn) SetReadDeadline(t time.Time) error {
pc.readTimer.Reset(t.Sub(time.Now()))
return nil
}
func (pc *ReaderPacketConn) SetWriteDeadline(t time.Time) error {
return pc.writePacketConn.SetWriteDeadline(t)
}
func (pc *ReaderPacketConn) FirstPacket() ([]byte, net.Addr, error) {
if pc.firstPacket == nil || pc.addr == nil {
return nil, nil, errors.New("not found first packet")
}
return pc.firstPacket, pc.addr, nil
}
// SendPacket is used to add connection to the listener
func (pc *ReaderPacketConn) SendPacket(b []byte, addr net.Addr) error {
if atomic.LoadInt32(&pc.closed) == 1 {
return errors.New("the listener is already closed")
}
atomic.AddInt32(&pc.nowNum, 1)
select {
case pc.ch <- &packet{b: b, addr: addr}:
return nil
case <-pc.closeCh:
case <-pc.readTimer.C:
_ = pc.Close()
}
if atomic.AddInt32(&pc.nowNum, -1) == 0 && atomic.LoadInt32(&pc.closed) == 1 {
close(pc.ch)
}
return errors.New("the packetConn is already closed")
}
// Close is used to close the listener, it will discard all existing connections
func (pc *ReaderPacketConn) Close() error {
if atomic.CompareAndSwapInt32(&pc.closed, 0, 1) {
close(pc.closeCh)
}
return nil
}