diff --git a/README.md b/README.md index 5c65b71..2adf649 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中 * [与nginx配合](#与nginx配合) * [关闭http|https代理](#关闭代理) * [将nps安装到系统](#将nps安装到系统) -* 单隧道模式及介绍 +* 单隧道模式及介绍(即将移除) * [tcp隧道模式](#tcp隧道模式) * [udp隧道模式](#udp隧道模式) * [socks5代理模式](#socks5代理模式) @@ -62,6 +62,7 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中 * [带宽限制](#带宽限制) * [负载均衡](#负载均衡) * [守护进程](#守护进程) + * [KCP协议支持](#KCP协议支持) * [相关说明](#相关说明) * [流量统计](#流量统计) * [热更新支持](#热更新支持) @@ -138,12 +139,13 @@ go语言编写,无第三方依赖,各个平台都已经编译在release中 ---|--- httpport | web管理端口 password | web界面管理密码 -tcpport | 服务端客户端通信端口 +bridePort | 服务端客户端通信端口 pemPath | ssl certFile绝对路径 keyPath | ssl keyFile绝对路径 httpsProxyPort | 域名代理https代理监听端口 httpProxyPort | 域名代理http代理监听端口 authip|web api免验证IP地址 +bridgeType|客户端与服务端连接方式kcp或tcp ### 详细说明 @@ -539,12 +541,23 @@ authip | 免验证ip,适用于web api ### 守护进程 本代理支持守护进程,使用示例如下,服务端客户端所有模式通用,支持linux,darwin,windows。 ``` -./(nps|npc) start|stop|restart|status xxxxxx +./(nps|npc) start|stop|restart|status 若有其他参数可加其他参数 ``` ``` -(nps|npc).exe start|stop|restart|status xxxxxx +(nps|npc).exe start|stop|restart|status 若有其他参数可加其他参数 ``` +### KCP协议支持 +KCP 是一个快速可靠协议,能以比 TCP浪费10%-20%的带宽的代价,换取平均延迟降低 30%-40%,在弱网环境下对性能能有一定的提升。可在app.conf中修改bridgeType为kcp +,设置后本代理将开启udp端口(bridgePort) + +注意:当服务端为kcp时,客户端连接时也需要加上参数 + +``` +-type=kcp +``` + + ## 相关说明 ### 获取用户真实ip diff --git a/bridge/bridge.go b/bridge/bridge.go index c18370e..973a2b1 100755 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -2,71 +2,103 @@ package bridge import ( "errors" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/kcp" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/pool" + "github.com/cnlh/nps/lib/common" "net" + "strconv" "sync" "time" ) type Client struct { - tunnel *lib.Conn - signal *lib.Conn - linkMap map[int]*lib.Link + tunnel *conn.Conn + signal *conn.Conn + linkMap map[int]*conn.Link linkStatusMap map[int]bool stop 
chan bool sync.RWMutex } -type Bridge struct { - TunnelPort int //通信隧道端口 - listener *net.TCPListener //server端监听 - Client map[int]*Client - RunList map[int]interface{} //运行中的任务 - lock sync.Mutex - tunnelLock sync.Mutex - clientLock sync.Mutex +func NewClient(t *conn.Conn, s *conn.Conn) *Client { + return &Client{ + linkMap: make(map[int]*conn.Link), + stop: make(chan bool), + linkStatusMap: make(map[int]bool), + signal: s, + tunnel: t, + } } -func NewTunnel(tunnelPort int, runList map[int]interface{}) *Bridge { +type Bridge struct { + TunnelPort int //通信隧道端口 + tcpListener *net.TCPListener //server端监听 + kcpListener *kcp.Listener //server端监听 + Client map[int]*Client + RunList map[int]interface{} //运行中的任务 + tunnelType string //bridge type kcp or tcp + lock sync.Mutex + tunnelLock sync.Mutex + clientLock sync.RWMutex +} + +func NewTunnel(tunnelPort int, runList map[int]interface{}, tunnelType string) *Bridge { t := new(Bridge) t.TunnelPort = tunnelPort t.Client = make(map[int]*Client) t.RunList = runList + t.tunnelType = tunnelType return t } func (s *Bridge) StartTunnel() error { var err error - s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""}) - if err != nil { - return err + if s.tunnelType == "kcp" { + s.kcpListener, err = kcp.ListenWithOptions(":"+strconv.Itoa(s.TunnelPort), nil, 150, 3) + if err != nil { + return err + } + go func() { + for { + c, err := s.kcpListener.AcceptKCP() + conn.SetUdpSession(c) + if err != nil { + lg.Println(err) + continue + } + go s.cliProcess(conn.NewConn(c)) + } + }() + } else { + s.tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.TunnelPort, ""}) + if err != nil { + return err + } + go func() { + for { + c, err := s.tcpListener.Accept() + if err != nil { + lg.Println(err) + continue + } + go s.cliProcess(conn.NewConn(c)) + } + }() } - go s.tunnelProcess() return nil } -//tcp server -func (s *Bridge) tunnelProcess() error { - var err error - for { - conn, err 
:= s.listener.Accept() - if err != nil { - lib.Println(err) - continue - } - go s.cliProcess(lib.NewConn(conn)) - } - return err -} - //验证失败,返回错误验证flag,并且关闭连接 -func (s *Bridge) verifyError(c *lib.Conn) { - c.Write([]byte(lib.VERIFY_EER)) +func (s *Bridge) verifyError(c *conn.Conn) { + c.Write([]byte(common.VERIFY_EER)) c.Conn.Close() } -func (s *Bridge) cliProcess(c *lib.Conn) { - c.SetReadDeadline(5) +func (s *Bridge) cliProcess(c *conn.Conn) { + c.SetReadDeadline(5, s.tunnelType) var buf []byte var err error if buf, err = c.ReadLen(32); err != nil { @@ -74,9 +106,9 @@ func (s *Bridge) cliProcess(c *lib.Conn) { return } //验证 - id, err := lib.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String()) + id, err := file.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String()) if err != nil { - lib.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr()) + lg.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr()) s.verifyError(c) return } @@ -97,40 +129,39 @@ func (s *Bridge) closeClient(id int) { } //tcp连接类型区分 -func (s *Bridge) typeDeal(typeVal string, c *lib.Conn, id int) { +func (s *Bridge) typeDeal(typeVal string, c *conn.Conn, id int) { switch typeVal { - case lib.WORK_MAIN: + case common.WORK_MAIN: //客户端已经存在,下线 - s.clientLock.Lock() - if _, ok := s.Client[id]; ok { - s.clientLock.Unlock() - s.closeClient(id) - } else { - s.clientLock.Unlock() - } - s.clientLock.Lock() - - s.Client[id] = &Client{ - linkMap: make(map[int]*lib.Link), - stop: make(chan bool), - linkStatusMap: make(map[int]bool), - } - lib.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr()) - s.Client[id].signal = c - s.clientLock.Unlock() - go s.GetStatus(id) - case lib.WORK_CHAN: s.clientLock.Lock() if v, ok := s.Client[id]; ok { s.clientLock.Unlock() - v.tunnel = c + if v.signal != nil { + v.signal.WriteClose() + } + v.Lock() + v.signal = c + v.Unlock() } else { + s.Client[id] = NewClient(nil, c) + s.clientLock.Unlock() + } + lg.Printf("客户端%d连接成功,地址为:%s", id, c.Conn.RemoteAddr()) 
+ go s.GetStatus(id) + case common.WORK_CHAN: + s.clientLock.Lock() + if v, ok := s.Client[id]; ok { + s.clientLock.Unlock() + v.Lock() + v.tunnel = c + v.Unlock() + } else { + s.Client[id] = NewClient(c, nil) s.clientLock.Unlock() - return } go s.clientCopy(id) } - c.SetAlive() + c.SetAlive(s.tunnelType) return } @@ -161,13 +192,13 @@ func (s *Bridge) waitStatus(clientId, id int) (bool) { return false } -func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, err error) { +func (s *Bridge) SendLinkInfo(clientId int, link *conn.Link) (tunnel *conn.Conn, err error) { s.clientLock.Lock() if v, ok := s.Client[clientId]; ok { s.clientLock.Unlock() v.signal.SendLinkInfo(link) if err != nil { - lib.Println("send error:", err, link.Id) + lg.Println("send error:", err, link.Id) s.DelClient(clientId) return } @@ -192,7 +223,7 @@ func (s *Bridge) SendLinkInfo(clientId int, link *lib.Link) (tunnel *lib.Conn, e } //得到一个tcp隧道 -func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn, err error) { +func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *conn.Conn, err error) { s.clientLock.Lock() defer s.clientLock.Unlock() if v, ok := s.Client[id]; !ok { @@ -204,7 +235,7 @@ func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (conn *lib.Conn, } //得到一个通信通道 -func (s *Bridge) GetSignal(id int) (conn *lib.Conn, err error) { +func (s *Bridge) GetSignal(id int) (conn *conn.Conn, err error) { s.clientLock.Lock() defer s.clientLock.Unlock() if v, ok := s.Client[id]; !ok { @@ -257,19 +288,19 @@ func (s *Bridge) clientCopy(clientId int) { for { if id, err := client.tunnel.GetLen(); err != nil { s.closeClient(clientId) - lib.Println("读取msg id 错误", err, id) + lg.Println("读取msg id 错误", err, id) break } else { client.Lock() if link, ok := client.linkMap[id]; ok { client.Unlock() if content, err := client.tunnel.GetMsgContent(link); err != nil { - lib.PutBufPoolCopy(content) + pool.PutBufPoolCopy(content) 
s.closeClient(clientId) - lib.Println("read msg content error", err, "close client") + lg.Println("read msg content error", err, "close client") break } else { - if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF { + if len(content) == len(common.IO_EOF) && string(content) == common.IO_EOF { if link.Conn != nil { link.Conn.Close() } @@ -281,7 +312,7 @@ func (s *Bridge) clientCopy(clientId int) { } link.Flow.Add(0, len(content)) } - lib.PutBufPoolCopy(content) + pool.PutBufPoolCopy(content) } } else { client.Unlock() @@ -289,5 +320,4 @@ func (s *Bridge) clientCopy(clientId int) { } } } - } diff --git a/client/client.go b/client/client.go index 996fdd1..00dc391 100755 --- a/client/client.go +++ b/client/client.go @@ -1,30 +1,35 @@ package client import ( - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/kcp" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/pool" "net" "sync" "time" ) type TRPClient struct { - svrAddr string - linkMap map[int]*lib.Link - stop chan bool - tunnel *lib.Conn + svrAddr string + linkMap map[int]*conn.Link + stop chan bool + tunnel *conn.Conn + bridgeConnType string sync.Mutex vKey string } //new client -func NewRPClient(svraddr string, vKey string) *TRPClient { +func NewRPClient(svraddr string, vKey string, bridgeConnType string) *TRPClient { return &TRPClient{ - svrAddr: svraddr, - linkMap: make(map[int]*lib.Link), - stop: make(chan bool), - tunnel: nil, - Mutex: sync.Mutex{}, - vKey: vKey, + svrAddr: svraddr, + linkMap: make(map[int]*conn.Link), + stop: make(chan bool), + Mutex: sync.Mutex{}, + vKey: vKey, + bridgeConnType: bridgeConnType, } } @@ -36,37 +41,44 @@ func (s *TRPClient) Start() error { //新建 func (s *TRPClient) NewConn() { + var err error + var c net.Conn retry: - conn, err := net.Dial("tcp", s.svrAddr) + if s.bridgeConnType == "tcp" { + c, err = net.Dial("tcp", s.svrAddr) + } else { + var sess *kcp.UDPSession + sess, err = 
kcp.DialWithOptions(s.svrAddr, nil, 150, 3) + conn.SetUdpSession(sess) + c = sess + } if err != nil { - lib.Println("连接服务端失败,五秒后将重连") + lg.Println("连接服务端失败,五秒后将重连") time.Sleep(time.Second * 5) goto retry return } - s.processor(lib.NewConn(conn)) + s.processor(conn.NewConn(c)) } //处理 -func (s *TRPClient) processor(c *lib.Conn) { - c.SetAlive() - if _, err := c.Write([]byte(lib.Getverifyval(s.vKey))); err != nil { +func (s *TRPClient) processor(c *conn.Conn) { + c.SetAlive(s.bridgeConnType) + if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil { return } c.WriteMain() - go s.dealChan() - for { flags, err := c.ReadFlag() if err != nil { - lib.Println("服务端断开,正在重新连接") + lg.Println("服务端断开,正在重新连接") break } switch flags { - case lib.VERIFY_EER: - lib.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey) - case lib.NEW_CONN: + case common.VERIFY_EER: + lg.Fatalf("vKey:%s不正确,服务端拒绝连接,请检查", s.vKey) + case common.NEW_CONN: if link, err := c.GetLinkInfo(); err != nil { break } else { @@ -75,54 +87,46 @@ func (s *TRPClient) processor(c *lib.Conn) { s.Unlock() go s.linkProcess(link, c) } - case lib.RES_CLOSE: - lib.Fatalln("该vkey被另一客户连接") - case lib.RES_MSG: - lib.Println("服务端返回错误,重新连接") + case common.RES_CLOSE: + lg.Fatalln("该vkey被另一客户连接") + case common.RES_MSG: + lg.Println("服务端返回错误,重新连接") break default: - lib.Println("无法解析该错误,重新连接") + lg.Println("无法解析该错误,重新连接") break } } s.stop <- true - s.linkMap = make(map[int]*lib.Link) + s.linkMap = make(map[int]*conn.Link) go s.NewConn() } -func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) { +func (s *TRPClient) linkProcess(link *conn.Link, c *conn.Conn) { //与目标建立连接 server, err := net.DialTimeout(link.ConnType, link.Host, time.Second*3) if err != nil { c.WriteFail(link.Id) - lib.Println("connect to ", link.Host, "error:", err) + lg.Println("connect to ", link.Host, "error:", err) return } c.WriteSuccess(link.Id) - link.Conn = lib.NewConn(server) - + link.Conn = conn.NewConn(server) + buf := pool.BufPoolCopy.Get().([]byte) 
for { - buf := lib.BufPoolCopy.Get().([]byte) if n, err := server.Read(buf); err != nil { - lib.PutBufPoolCopy(buf) - s.tunnel.SendMsg([]byte(lib.IO_EOF), link) + s.tunnel.SendMsg([]byte(common.IO_EOF), link) break } else { if _, err := s.tunnel.SendMsg(buf[:n], link); err != nil { - lib.PutBufPoolCopy(buf) c.Close() break } - lib.PutBufPoolCopy(buf) - //if link.ConnType == utils.CONN_UDP { - // c.Close() - // break - //} } } - + pool.PutBufPoolCopy(buf) s.Lock() delete(s.linkMap, link.Id) s.Unlock() @@ -131,41 +135,50 @@ func (s *TRPClient) linkProcess(link *lib.Link, c *lib.Conn) { //隧道模式处理 func (s *TRPClient) dealChan() { var err error - //创建一个tcp连接 - conn, err := net.Dial("tcp", s.svrAddr) + var c net.Conn + var sess *kcp.UDPSession + if s.bridgeConnType == "tcp" { + c, err = net.Dial("tcp", s.svrAddr) + } else { + sess, err = kcp.DialWithOptions(s.svrAddr, nil, 10, 3) + conn.SetUdpSession(sess) + c = sess + } if err != nil { - lib.Println("connect to ", s.svrAddr, "error:", err) + lg.Println("connect to ", s.svrAddr, "error:", err) return } //验证 - if _, err := conn.Write([]byte(lib.Getverifyval(s.vKey))); err != nil { - lib.Println("connect to ", s.svrAddr, "error:", err) + if _, err := c.Write([]byte(common.Getverifyval(s.vKey))); err != nil { + lg.Println("connect to ", s.svrAddr, "error:", err) return } //默认长连接保持 - s.tunnel = lib.NewConn(conn) - s.tunnel.SetAlive() + s.tunnel = conn.NewConn(c) + s.tunnel.SetAlive(s.bridgeConnType) //写标志 s.tunnel.WriteChan() go func() { for { if id, err := s.tunnel.GetLen(); err != nil { - lib.Println("get msg id error") + lg.Println("get msg id error") break } else { s.Lock() if v, ok := s.linkMap[id]; ok { s.Unlock() if content, err := s.tunnel.GetMsgContent(v); err != nil { - lib.Println("get msg content error:", err, id) + lg.Println("get msg content error:", err, id) + pool.PutBufPoolCopy(content) break } else { - if len(content) == len(lib.IO_EOF) && string(content) == lib.IO_EOF { + if len(content) == 
len(common.IO_EOF) && string(content) == common.IO_EOF { v.Conn.Close() } else if v.Conn != nil { v.Conn.Write(content) } + pool.PutBufPoolCopy(content) } } else { s.Unlock() @@ -175,5 +188,6 @@ func (s *TRPClient) dealChan() { }() select { case <-s.stop: + break } } diff --git a/cmd/npc/npc.go b/cmd/npc/npc.go index ac418f4..a72cbba 100644 --- a/cmd/npc/npc.go +++ b/cmd/npc/npc.go @@ -3,8 +3,9 @@ package main import ( "flag" "github.com/cnlh/nps/client" - "github.com/cnlh/nps/lib" - _ "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/daemon" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/common" "strings" ) @@ -14,20 +15,21 @@ var ( serverAddr = flag.String("server", "", "服务器地址ip:端口") verifyKey = flag.String("vkey", "", "验证密钥") logType = flag.String("log", "stdout", "日志输出方式(stdout|file)") + connType = flag.String("type", "tcp", "与服务端建立连接方式(kcp|tcp)") ) func main() { flag.Parse() - lib.InitDaemon("npc") + daemon.InitDaemon("npc", common.GetRunPath(), common.GetPidPath()) if *logType == "stdout" { - lib.InitLogFile("npc", true) + lg.InitLogFile("npc", true, common.GetLogPath()) } else { - lib.InitLogFile("npc", false) + lg.InitLogFile("npc", false, common.GetLogPath()) } stop := make(chan int) for _, v := range strings.Split(*verifyKey, ",") { - lib.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v) - go client.NewRPClient(*serverAddr, v).Start() + lg.Println("客户端启动,连接:", *serverAddr, " 验证令牌:", v) + go client.NewRPClient(*serverAddr, v, *connType).Start() } <-stop } diff --git a/cmd/nps/nps.go b/cmd/nps/nps.go index e966a61..03307ee 100644 --- a/cmd/nps/nps.go +++ b/cmd/nps/nps.go @@ -2,8 +2,12 @@ package main import ( "flag" - "github.com/astaxie/beego" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/beego" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/daemon" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/install" + "github.com/cnlh/nps/lib/lg" "github.com/cnlh/nps/server" _ "github.com/cnlh/nps/web/routers" "log" @@ 
-28,58 +32,65 @@ var ( func main() { flag.Parse() - if len(os.Args) > 1 && os.Args[1] == "test" { - server.TestServerConfig() - log.Println("test ok, no error") - return + if len(os.Args) > 1 { + switch os.Args[1] { + case "test": + server.TestServerConfig() + log.Println("test ok, no error") + return + case "start", "restart", "stop", "status": + daemon.InitDaemon("nps", common.GetRunPath(), common.GetPidPath()) + case "install": + install.InstallNps() + return + } } - lib.InitDaemon("nps") if *logType == "stdout" { - lib.InitLogFile("nps", true) + lg.InitLogFile("nps", true, common.GetLogPath()) } else { - lib.InitLogFile("nps", false) + lg.InitLogFile("nps", false, common.GetLogPath()) } - task := &lib.Tunnel{ + task := &file.Tunnel{ TcpPort: *httpPort, Mode: *rpMode, Target: *tunnelTarget, - Config: &lib.Config{ + Config: &file.Config{ U: *u, P: *p, Compress: *compress, - Crypt: lib.GetBoolByStr(*crypt), + Crypt: common.GetBoolByStr(*crypt), }, - Flow: &lib.Flow{}, + Flow: &file.Flow{}, UseClientCnf: false, } if *VerifyKey != "" { - c := &lib.Client{ + c := &file.Client{ Id: 0, VerifyKey: *VerifyKey, Addr: "", Remark: "", Status: true, IsConnect: false, - Cnf: &lib.Config{}, - Flow: &lib.Flow{}, + Cnf: &file.Config{}, + Flow: &file.Flow{}, } - c.Cnf.CompressDecode, c.Cnf.CompressEncode = lib.GetCompressType(c.Cnf.Compress) - lib.GetCsvDb().Clients[0] = c + c.Cnf.CompressDecode, c.Cnf.CompressEncode = common.GetCompressType(c.Cnf.Compress) + file.GetCsvDb().Clients[0] = c task.Client = c } if *TcpPort == 0 { - p, err := beego.AppConfig.Int("tcpport") + p, err := beego.AppConfig.Int("bridgePort") if err == nil && *rpMode == "webServer" { *TcpPort = p } else { *TcpPort = 8284 } } - lib.Println("服务端启动,监听tcp服务端端口:", *TcpPort) - task.Config.CompressDecode, task.Config.CompressEncode = lib.GetCompressType(task.Config.Compress) + lg.Printf("服务端启动,监听%s服务端口:%d", beego.AppConfig.String("bridgeType"), *TcpPort) + task.Config.CompressDecode, task.Config.CompressEncode = 
common.GetCompressType(task.Config.Compress) if *rpMode != "webServer" { - lib.GetCsvDb().Tasks[0] = task + file.GetCsvDb().Tasks[0] = task } - beego.LoadAppConfig("ini", filepath.Join(lib.GetRunPath(), "conf", "app.conf")) - server.StartNewServer(*TcpPort, task) + beego.LoadAppConfig("ini", filepath.Join(common.GetRunPath(), "conf", "app.conf")) + server.StartNewServer(*TcpPort, task, beego.AppConfig.String("bridgeType")) } diff --git a/conf/app.conf b/conf/app.conf index d032faa..f8588a6 100755 --- a/conf/app.conf +++ b/conf/app.conf @@ -1,28 +1,33 @@ appname = nps -#web管理端口 -httpport = 8081 +#Web Management Port +httpport = 8080 -#启动模式dev|pro +#Boot mode(dev|pro) runmode = dev -#web管理密码 +#Web Management Password password=123 -##客户端与服务端通信端口 -tcpport=8284 +##Communication Port between Client and Server +##If the data transfer mode is tcp, it is TCP port +##If the data transfer mode is kcp, it is UDP port +bridgePort=8284 -#web api免验证IP地址 +#Web API unauthenticated IP address authip=127.0.0.1 -##http代理端口,为空则不启动 +##HTTP proxy port, no startup if empty httpProxyPort=80 -##https代理端口,为空则不启动 +##HTTPS proxy port, no startup if empty httpsProxyPort= -##certFile绝对路径 +##certFile absolute path pemPath=/etc/nginx/certificate.crt -##keyFile绝对路径 -keyPath=/etc/nginx/private.key \ No newline at end of file +##KeyFile absolute path +keyPath=/etc/nginx/private.key + +##Data transmission mode(kcp or tcp) +bridgeType=tcp \ No newline at end of file diff --git a/conf/clients.csv b/conf/clients.csv index 6dd9ca2..99927d6 100644 --- a/conf/clients.csv +++ b/conf/clients.csv @@ -1 +1 @@ -1,ydiigrm4ghu7mym1,,true,,,0,,0,0 +1,ydiigrm4ghu7mym1,测试,true,,,0,,0,0 diff --git a/conf/hosts.csv b/conf/hosts.csv index eccf90a..fd44434 100644 --- a/conf/hosts.csv +++ b/conf/hosts.csv @@ -1 +1,2 @@ -a.o.com,127.0.0.1:8081,1,,,测试 +a.o.com,127.0.0.1:8080,1,,,测试 +b.o.com,127.0.0.1:8082,1,,, diff --git a/conf/tasks.csv b/conf/tasks.csv index ef6b307..4bea5c8 100644 --- a/conf/tasks.csv +++ 
b/conf/tasks.csv @@ -1,4 +1,4 @@ -9001,tunnelServer,123.206.77.88:22,,,,1,0,0,0,1,1,true,测试tcp 53,udpServer,114.114.114.114:53,,,,1,0,0,0,2,1,true,udp -0,socks5Server,,,,,1,0,0,0,3,1,true,socks5 9005,httpProxyServer,,,,,1,0,0,0,4,1,true, +9002,socks5Server,,,,,1,0,0,0,3,1,true,socks5 +9001,tunnelServer,127.0.0.1:8082,,,,1,0,0,0,1,1,true,测试tcp diff --git a/lib/common/const.go b/lib/common/const.go new file mode 100644 index 0000000..3cebd55 --- /dev/null +++ b/lib/common/const.go @@ -0,0 +1,28 @@ +package common + +const ( + COMPRESS_NONE_ENCODE = iota + COMPRESS_NONE_DECODE + COMPRESS_SNAPY_ENCODE + COMPRESS_SNAPY_DECODE + VERIFY_EER = "vkey" + WORK_MAIN = "main" + WORK_CHAN = "chan" + RES_SIGN = "sign" + RES_MSG = "msg0" + RES_CLOSE = "clse" + NEW_CONN = "conn" //新连接标志 + NEW_TASK = "task" //新连接标志 + CONN_SUCCESS = "sucs" + CONN_TCP = "tcp" + CONN_UDP = "udp" + UnauthorizedBytes = `HTTP/1.1 401 Unauthorized +Content-Type: text/plain; charset=utf-8 +WWW-Authenticate: Basic realm="easyProxy" + +401 Unauthorized` + IO_EOF = "PROXYEOF" + ConnectionFailBytes = `HTTP/1.1 404 Not Found + +` +) diff --git a/lib/common/run.go b/lib/common/run.go new file mode 100644 index 0000000..74ed1f3 --- /dev/null +++ b/lib/common/run.go @@ -0,0 +1,67 @@ +package common + +import ( + "os" + "path/filepath" + "runtime" +) + +//Get the currently selected configuration file directory +//For non-Windows systems, select the /etc/nps as config directory if exist, or select ./ +//windows system, select the C:\Program Files\nps as config directory if exist, or select ./ +func GetRunPath() string { + var path string + if path = GetInstallPath(); !FileExists(path) { + return "./" + } + return path +} + +//Different systems get different installation paths +func GetInstallPath() string { + var path string + if IsWindows() { + path = `C:\Program Files\nps` + } else { + path = "/etc/nps" + } + return path +} + +//Get the absolute path to the running directory +func GetAppPath() string { + if path, 
err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil { + return path + } + return os.Args[0] +} + +//Determine whether the current system is a Windows system? +func IsWindows() bool { + if runtime.GOOS == "windows" { + return true + } + return false +} + +//interface log file path +func GetLogPath() string { + var path string + if IsWindows() { + path = "./" + } else { + path = "/tmp" + } + return path +} + +//interface pid file path +func GetPidPath() string { + var path string + if IsWindows() { + path = "./" + } else { + path = "/tmp" + } + return path +} diff --git a/lib/util.go b/lib/common/util.go similarity index 62% rename from lib/util.go rename to lib/common/util.go index 6fbf9fc..52381c3 100755 --- a/lib/util.go +++ b/lib/common/util.go @@ -1,45 +1,21 @@ -package lib +package common import ( + "bytes" "encoding/base64" + "encoding/binary" + "github.com/cnlh/nps/lib/crypt" + "github.com/cnlh/nps/lib/lg" "io/ioutil" "net" "net/http" "os" - "path/filepath" "regexp" - "runtime" "strconv" "strings" ) -const ( - COMPRESS_NONE_ENCODE = iota - COMPRESS_NONE_DECODE - COMPRESS_SNAPY_ENCODE - COMPRESS_SNAPY_DECODE - VERIFY_EER = "vkey" - WORK_MAIN = "main" - WORK_CHAN = "chan" - RES_SIGN = "sign" - RES_MSG = "msg0" - RES_CLOSE = "clse" - NEW_CONN = "conn" //新连接标志 - CONN_SUCCESS = "sucs" - CONN_TCP = "tcp" - CONN_UDP = "udp" - UnauthorizedBytes = `HTTP/1.1 401 Unauthorized -Content-Type: text/plain; charset=utf-8 -WWW-Authenticate: Basic realm="easyProxy" - -401 Unauthorized` - IO_EOF = "PROXYEOF" - ConnectionFailBytes = `HTTP/1.1 404 Not Found - -` -) - -//判断压缩方式 +//Judging Compression Mode func GetCompressType(compress string) (int, int) { switch compress { case "": @@ -47,12 +23,12 @@ func GetCompressType(compress string) (int, int) { case "snappy": return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE default: - Fatalln("数据压缩格式错误") + lg.Fatalln("数据压缩格式错误") } return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE } -//通过host获取对应的ip地址 +//Get the corresponding IP 
address through domain name func GetHostByName(hostname string) string { if !DomainCheck(hostname) { return hostname @@ -68,7 +44,7 @@ func GetHostByName(hostname string) string { return "" } -//检查是否是域名 +//Check the legality of domain func DomainCheck(domain string) bool { var match bool IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)" @@ -80,7 +56,7 @@ func DomainCheck(domain string) bool { return match } -//检查basic认证 +//Check if the Request request is validated func CheckAuth(r *http.Request, user, passwd string) bool { s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) if len(s) != 2 { @@ -122,11 +98,12 @@ func GetIntNoErrByStr(str string) int { return i } -//简单的一个校验值 +//Get verify value func Getverifyval(vkey string) string { - return Md5(vkey) + return crypt.Md5(vkey) } +//Change headers and host of request func ChangeHostAndHeader(r *http.Request, host string, header string, addr string) { if host != "" { r.Host = host @@ -145,6 +122,7 @@ func ChangeHostAndHeader(r *http.Request, host string, header string, addr strin r.Header.Set("X-Real-IP", addr) } +//Read file content by file path func ReadAllFromFile(filePath string) ([]byte, error) { f, err := os.Open(filePath) if err != nil { @@ -163,53 +141,7 @@ func FileExists(name string) bool { return true } -func GetRunPath() string { - var path string - if path = GetInstallPath(); !FileExists(path) { - return "./" - } - return path -} -func GetInstallPath() string { - var path string - if IsWindows() { - path = `C:\Program Files\nps` - } else { - path = "/etc/nps" - } - return path -} -func GetAppPath() string { - if path, err := filepath.Abs(filepath.Dir(os.Args[0])); err == nil { - return path - } - return os.Args[0] -} -func IsWindows() bool { - if runtime.GOOS == "windows" { - return true - } - return false -} -func GetLogPath() string { - var path string - if IsWindows() { - path = "./" - } else { - path = "/tmp" - } - return path -} -func 
GetPidPath() string { - var path string - if IsWindows() { - path = "./" - } else { - path = "/tmp" - } - return path -} - +//Judge whether the TCP port can open normally func TestTcpPort(port int) bool { l, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), port, ""}) defer l.Close() @@ -218,3 +150,27 @@ func TestTcpPort(port int) bool { } return true } + +//Judge whether the UDP port can open normally +func TestUdpPort(port int) bool { + l, err := net.ListenUDP("udp", &net.UDPAddr{net.ParseIP("0.0.0.0"), port, ""}) + defer l.Close() + if err != nil { + return false + } + return true +} + +//Write length and individual byte data +//Length prevents sticking +//# Characters are used to separate data +func BinaryWrite(raw *bytes.Buffer, v ...string) { + buffer := new(bytes.Buffer) + var l int32 + for _, v := range v { + l += int32(len([]byte(v))) + int32(len([]byte("#"))) + binary.Write(buffer, binary.LittleEndian, []byte(v)) + binary.Write(buffer, binary.LittleEndian, []byte("#")) + } + binary.Write(raw, binary.LittleEndian, buffer.Bytes()) +} diff --git a/lib/conn.go b/lib/conn/conn.go similarity index 64% rename from lib/conn.go rename to lib/conn/conn.go index ff69c14..24eb92b 100755 --- a/lib/conn.go +++ b/lib/conn/conn.go @@ -1,11 +1,15 @@ -package lib +package conn import ( "bufio" "bytes" "encoding/binary" "errors" - "github.com/golang/snappy" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/kcp" + "github.com/cnlh/nps/lib/pool" + "github.com/cnlh/nps/lib/rate" "io" "net" "net/http" @@ -18,126 +22,6 @@ import ( const cryptKey = "1234567812345678" -type CryptConn struct { - conn net.Conn - crypt bool - rate *Rate -} - -func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn { - c := new(CryptConn) - c.conn = conn - c.crypt = crypt - c.rate = rate - return c -} - -//加密写 -func (s *CryptConn) Write(b []byte) (n int, err error) { - n = len(b) - if s.crypt { - if b, err = AesEncrypt(b, 
[]byte(cryptKey)); err != nil { - return - } - } - if b, err = GetLenBytes(b); err != nil { - return - } - _, err = s.conn.Write(b) - if s.rate != nil { - s.rate.Get(int64(n)) - } - return -} - -//解密读 -func (s *CryptConn) Read(b []byte) (n int, err error) { - var lens int - var buf []byte - var rb []byte - c := NewConn(s.conn) - if lens, err = c.GetLen(); err != nil { - return - } - if buf, err = c.ReadLen(lens); err != nil { - return - } - if s.crypt { - if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil { - return - } - } else { - rb = buf - } - copy(b, rb) - n = len(rb) - if s.rate != nil { - s.rate.Get(int64(n)) - } - return -} - -type SnappyConn struct { - w *snappy.Writer - r *snappy.Reader - crypt bool - rate *Rate -} - -func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn { - c := new(SnappyConn) - c.w = snappy.NewBufferedWriter(conn) - c.r = snappy.NewReader(conn) - c.crypt = crypt - c.rate = rate - return c -} - -//snappy压缩写 包含加密 -func (s *SnappyConn) Write(b []byte) (n int, err error) { - n = len(b) - if s.crypt { - if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil { - Println("encode crypt error:", err) - return - } - } - if _, err = s.w.Write(b); err != nil { - return - } - if err = s.w.Flush(); err != nil { - return - } - if s.rate != nil { - s.rate.Get(int64(n)) - } - return -} - -//snappy压缩读 包含解密 -func (s *SnappyConn) Read(b []byte) (n int, err error) { - buf := BufPool.Get().([]byte) - defer BufPool.Put(buf) - if n, err = s.r.Read(buf); err != nil { - return - } - var bs []byte - if s.crypt { - if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil { - Println("decode crypt error:", err) - return - } - } else { - bs = buf[:n] - } - n = len(bs) - copy(b, bs) - if s.rate != nil { - s.rate.Get(int64(n)) - } - return -} - type Conn struct { Conn net.Conn sync.Mutex @@ -186,16 +70,16 @@ func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http. 
//读取指定长度内容 func (s *Conn) ReadLen(cLen int) ([]byte, error) { - if cLen > poolSize { + if cLen > pool.PoolSize { return nil, errors.New("长度错误" + strconv.Itoa(cLen)) } var buf []byte - if cLen <= poolSizeSmall { - buf = BufPoolSmall.Get().([]byte)[:cLen] - defer BufPoolSmall.Put(buf) + if cLen <= pool.PoolSizeSmall { + buf = pool.BufPoolSmall.Get().([]byte)[:cLen] + defer pool.BufPoolSmall.Put(buf) } else { - buf = BufPoolMax.Get().([]byte)[:cLen] - defer BufPoolMax.Put(buf) + buf = pool.BufPoolMax.Get().([]byte)[:cLen] + defer pool.BufPoolMax.Put(buf) } if n, err := io.ReadFull(s, buf); err != nil || n != cLen { return buf, errors.New("读取指定长度错误" + err.Error()) @@ -231,35 +115,64 @@ func (s *Conn) GetConnStatus() (id int, status bool, err error) { if b, err = s.ReadLen(1); err != nil { return } else { - status = GetBoolByStr(string(b[0])) + status = common.GetBoolByStr(string(b[0])) } return } //设置连接为长连接 -func (s *Conn) SetAlive() { +func (s *Conn) SetAlive(tp string) { + if tp == "kcp" { + s.setKcpAlive() + } else { + s.setTcpAlive() + } +} + +//设置连接为长连接 +func (s *Conn) setTcpAlive() { conn := s.Conn.(*net.TCPConn) conn.SetReadDeadline(time.Time{}) conn.SetKeepAlive(true) conn.SetKeepAlivePeriod(time.Duration(2 * time.Second)) } +//设置连接为长连接 +func (s *Conn) setKcpAlive() { + conn := s.Conn.(*kcp.UDPSession) + conn.SetReadDeadline(time.Time{}) +} + +//设置连接为长连接 +func (s *Conn) SetReadDeadline(t time.Duration, tp string) { + if tp == "kcp" { + s.SetKcpReadDeadline(t) + } else { + s.SetTcpReadDeadline(t) + } +} + //set read dead time -func (s *Conn) SetReadDeadline(t time.Duration) { +func (s *Conn) SetTcpReadDeadline(t time.Duration) { s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second)) } +//set read dead time +func (s *Conn) SetKcpReadDeadline(t time.Duration) { + s.Conn.(*kcp.UDPSession).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second)) +} + //单独读(加密|压缩) -func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate 
*Rate) (int, error) { - if COMPRESS_SNAPY_DECODE == compress { +func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *rate.Rate) (int, error) { + if common.COMPRESS_SNAPY_DECODE == compress { return NewSnappyConn(s.Conn, crypt, rate).Read(b) } return NewCryptConn(s.Conn, crypt, rate).Read(b) } //单独写(加密|压缩) -func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) { - if COMPRESS_SNAPY_ENCODE == compress { +func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *rate.Rate) (n int, err error) { + if common.COMPRESS_SNAPY_ENCODE == compress { return NewSnappyConn(s.Conn, crypt, rate).Write(b) } return NewCryptConn(s.Conn, crypt, rate).Write(b) @@ -292,7 +205,7 @@ func (s *Conn) SendMsg(content []byte, link *Link) (n int, err error) { func (s *Conn) GetMsgContent(link *Link) (content []byte, err error) { s.Lock() defer s.Unlock() - buf := BufPoolCopy.Get().([]byte) + buf := pool.BufPoolCopy.Get().([]byte) if n, err := s.ReadFrom(buf, link.De, link.Crypt, link.Rate); err == nil && n > 4 { content = buf[:n] } @@ -310,7 +223,7 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) { +----------+------+----------+------+----+----+------+ */ raw := bytes.NewBuffer([]byte{}) - binary.Write(raw, binary.LittleEndian, []byte(NEW_CONN)) + binary.Write(raw, binary.LittleEndian, []byte(common.NEW_CONN)) binary.Write(raw, binary.LittleEndian, int32(14+len(link.Host))) binary.Write(raw, binary.LittleEndian, int32(link.Id)) binary.Write(raw, binary.LittleEndian, []byte(link.ConnType)) @@ -318,13 +231,13 @@ func (s *Conn) SendLinkInfo(link *Link) (int, error) { binary.Write(raw, binary.LittleEndian, []byte(link.Host)) binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.En))) binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.De))) - binary.Write(raw, binary.LittleEndian, []byte(GetStrByBool(link.Crypt))) + binary.Write(raw, binary.LittleEndian, []byte(common.GetStrByBool(link.Crypt))) s.Lock() defer 
s.Unlock() return s.Write(raw.Bytes()) } -func (s *Conn) GetLinkInfo() (link *Link, err error) { +func (s *Conn) GetLinkInfo() (lk *Link, err error) { s.Lock() defer s.Unlock() var hostLen, n int @@ -332,21 +245,69 @@ func (s *Conn) GetLinkInfo() (link *Link, err error) { if n, err = s.GetLen(); err != nil { return } - link = new(Link) + lk = new(Link) if buf, err = s.ReadLen(n); err != nil { return } - if link.Id, err = GetLenByBytes(buf[:4]); err != nil { + if lk.Id, err = GetLenByBytes(buf[:4]); err != nil { return } - link.ConnType = string(buf[4:7]) + lk.ConnType = string(buf[4:7]) if hostLen, err = GetLenByBytes(buf[7:11]); err != nil { return } else { - link.Host = string(buf[11 : 11+hostLen]) - link.En = GetIntNoErrByStr(string(buf[11+hostLen])) - link.De = GetIntNoErrByStr(string(buf[12+hostLen])) - link.Crypt = GetBoolByStr(string(buf[13+hostLen])) + lk.Host = string(buf[11 : 11+hostLen]) + lk.En = common.GetIntNoErrByStr(string(buf[11+hostLen])) + lk.De = common.GetIntNoErrByStr(string(buf[12+hostLen])) + lk.Crypt = common.GetBoolByStr(string(buf[13+hostLen])) + } + return +} + +//send task info +func (s *Conn) SendTaskInfo(t *file.Tunnel) (int, error) { + /* + The task info is formed as follows: + +----+-----+---------+ + |type| len | content | + +----+---------------+ + | 4 | 4 | ... 
| + +----+---------------+ +*/ + raw := bytes.NewBuffer([]byte{}) + binary.Write(raw, binary.LittleEndian, common.NEW_TASK) + common.BinaryWrite(raw, t.Mode, string(t.TcpPort), string(t.Target), string(t.Config.U), string(t.Config.P), common.GetStrByBool(t.Config.Crypt), t.Config.Compress, t.Remark) + s.Lock() + defer s.Unlock() + return s.Write(raw.Bytes()) +} + +//get task info +func (s *Conn) GetTaskInfo() (t *file.Tunnel, err error) { + var l int + var b []byte + if l, err = s.GetLen(); err != nil { + return + } else if b, err = s.ReadLen(l); err != nil { + return + } else { + arr := strings.Split(string(b), "#") + t.Mode = arr[0] + t.TcpPort, _ = strconv.Atoi(arr[1]) + t.Target = arr[2] + t.Config = new(file.Config) + t.Config.U = arr[3] + t.Config.P = arr[4] + t.Config.Compress = arr[5] + t.Config.CompressDecode, t.Config.CompressDecode = common.GetCompressType(arr[5]) + t.Id = file.GetCsvDb().GetTaskId() + t.Status = true + if t.Client, err = file.GetCsvDb().GetClient(0); err != nil { + return + } + t.Flow = new(file.Flow) + t.Remark = arr[6] + t.UseClientCnf = false } return } @@ -388,31 +349,31 @@ func (s *Conn) Read(b []byte) (int, error) { //write error func (s *Conn) WriteError() (int, error) { - return s.Write([]byte(RES_MSG)) + return s.Write([]byte(common.RES_MSG)) } //write sign flag func (s *Conn) WriteSign() (int, error) { - return s.Write([]byte(RES_SIGN)) + return s.Write([]byte(common.RES_SIGN)) } //write sign flag func (s *Conn) WriteClose() (int, error) { - return s.Write([]byte(RES_CLOSE)) + return s.Write([]byte(common.RES_CLOSE)) } //write main func (s *Conn) WriteMain() (int, error) { s.Lock() defer s.Unlock() - return s.Write([]byte(WORK_MAIN)) + return s.Write([]byte(common.WORK_MAIN)) } //write chan func (s *Conn) WriteChan() (int, error) { s.Lock() defer s.Unlock() - return s.Write([]byte(WORK_CHAN)) + return s.Write([]byte(common.WORK_CHAN)) } //获取长度+内容 @@ -436,3 +397,13 @@ func GetLenByBytes(buf []byte) (int, error) { } return 
int(nlen), nil } + +func SetUdpSession(sess *kcp.UDPSession) { + sess.SetStreamMode(true) + sess.SetWindowSize(1024, 1024) + sess.SetReadBuffer(64 * 1024) + sess.SetWriteBuffer(64 * 1024) + sess.SetNoDelay(1, 10, 2, 1) + sess.SetMtu(1600) + sess.SetACKNoDelay(true) +} diff --git a/lib/conn/link.go b/lib/conn/link.go new file mode 100644 index 0000000..9f50681 --- /dev/null +++ b/lib/conn/link.go @@ -0,0 +1,37 @@ +package conn + +import ( + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/rate" + "net" +) + +type Link struct { + Id int //id + ConnType string //连接类型 + Host string //目标 + En int //加密 + De int //解密 + Crypt bool //加密 + Conn *Conn + Flow *file.Flow + UdpListener *net.UDPConn + Rate *rate.Rate + UdpRemoteAddr *net.UDPAddr +} + +func NewLink(id int, connType string, host string, en, de int, crypt bool, c *Conn, flow *file.Flow, udpListener *net.UDPConn, rate *rate.Rate, UdpRemoteAddr *net.UDPAddr) *Link { + return &Link{ + Id: id, + ConnType: connType, + Host: host, + En: en, + De: de, + Crypt: crypt, + Conn: c, + Flow: flow, + UdpListener: udpListener, + Rate: rate, + UdpRemoteAddr: UdpRemoteAddr, + } +} diff --git a/lib/conn/normal.go b/lib/conn/normal.go new file mode 100644 index 0000000..9ce9777 --- /dev/null +++ b/lib/conn/normal.go @@ -0,0 +1,66 @@ +package conn + +import ( + "github.com/cnlh/nps/lib/crypt" + "github.com/cnlh/nps/lib/rate" + "net" +) + +type CryptConn struct { + conn net.Conn + crypt bool + rate *rate.Rate +} + +func NewCryptConn(conn net.Conn, crypt bool, rate *rate.Rate) *CryptConn { + c := new(CryptConn) + c.conn = conn + c.crypt = crypt + c.rate = rate + return c +} + +//加密写 +func (s *CryptConn) Write(b []byte) (n int, err error) { + n = len(b) + if s.crypt { + if b, err = crypt.AesEncrypt(b, []byte(cryptKey)); err != nil { + return + } + } + if b, err = GetLenBytes(b); err != nil { + return + } + _, err = s.conn.Write(b) + if s.rate != nil { + s.rate.Get(int64(n)) + } + return +} + +//解密读 +func (s *CryptConn) Read(b 
[]byte) (n int, err error) { + var lens int + var buf []byte + var rb []byte + c := NewConn(s.conn) + if lens, err = c.GetLen(); err != nil { + return + } + if buf, err = c.ReadLen(lens); err != nil { + return + } + if s.crypt { + if rb, err = crypt.AesDecrypt(buf, []byte(cryptKey)); err != nil { + return + } + } else { + rb = buf + } + copy(b, rb) + n = len(rb) + if s.rate != nil { + s.rate.Get(int64(n)) + } + return +} diff --git a/lib/conn/snappy.go b/lib/conn/snappy.go new file mode 100644 index 0000000..7750c1b --- /dev/null +++ b/lib/conn/snappy.go @@ -0,0 +1,72 @@ +package conn + +import ( + "github.com/cnlh/nps/lib/crypt" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/pool" + "github.com/cnlh/nps/lib/rate" + "github.com/cnlh/nps/lib/snappy" + "log" + "net" +) + +type SnappyConn struct { + w *snappy.Writer + r *snappy.Reader + crypt bool + rate *rate.Rate +} + +func NewSnappyConn(conn net.Conn, crypt bool, rate *rate.Rate) *SnappyConn { + c := new(SnappyConn) + c.w = snappy.NewBufferedWriter(conn) + c.r = snappy.NewReader(conn) + c.crypt = crypt + c.rate = rate + return c +} + +//snappy压缩写 包含加密 +func (s *SnappyConn) Write(b []byte) (n int, err error) { + n = len(b) + if s.crypt { + if b, err = crypt.AesEncrypt(b, []byte(cryptKey)); err != nil { + lg.Println("encode crypt error:", err) + return + } + } + if _, err = s.w.Write(b); err != nil { + return + } + if err = s.w.Flush(); err != nil { + return + } + if s.rate != nil { + s.rate.Get(int64(n)) + } + return +} + +//snappy压缩读 包含解密 +func (s *SnappyConn) Read(b []byte) (n int, err error) { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf) + if n, err = s.r.Read(buf); err != nil { + return + } + var bs []byte + if s.crypt { + if bs, err = crypt.AesDecrypt(buf[:n], []byte(cryptKey)); err != nil { + log.Println("decode crypt error:", err) + return + } + } else { + bs = buf[:n] + } + n = len(bs) + copy(b, bs) + if s.rate != nil { + s.rate.Get(int64(n)) + } + return +} diff --git 
a/lib/crypt.go b/lib/crypt/crypt.go similarity index 93% rename from lib/crypt.go rename to lib/crypt/crypt.go index 952ecfe..17e9947 100644 --- a/lib/crypt.go +++ b/lib/crypt/crypt.go @@ -1,4 +1,4 @@ -package lib +package crypt import ( "bytes" @@ -37,21 +37,19 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) { blockSize := block.BlockSize() blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) origData := make([]byte, len(crypted)) - // origData := crypted blockMode.CryptBlocks(origData, crypted) err, origData = PKCS5UnPadding(origData) - // origData = ZeroUnPadding(origData) return origData, err } -//补全 +//Completion when the length is insufficient func PKCS5Padding(ciphertext []byte, blockSize int) []byte { padding := blockSize - len(ciphertext)%blockSize padtext := bytes.Repeat([]byte{byte(padding)}, padding) return append(ciphertext, padtext...) } -//去补 +//Remove excess func PKCS5UnPadding(origData []byte) (error, []byte) { length := len(origData) // 去掉最后一个字节 unpadding 次 @@ -62,14 +60,14 @@ func PKCS5UnPadding(origData []byte) (error, []byte) { return nil, origData[:(length - unpadding)] } -//生成32位md5字串 +//Generate 32-bit MD5 strings func Md5(s string) string { h := md5.New() h.Write([]byte(s)) return hex.EncodeToString(h.Sum(nil)) } -//生成随机验证密钥 +//Generating Random Verification Key func GetRandomString(l int) string { str := "0123456789abcdefghijklmnopqrstuvwxyz" bytes := []byte(str) diff --git a/lib/daemon.go b/lib/daemon/daemon.go similarity index 65% rename from lib/daemon.go rename to lib/daemon/daemon.go index cef26b1..7d15e73 100644 --- a/lib/daemon.go +++ b/lib/daemon/daemon.go @@ -1,6 +1,7 @@ -package lib +package daemon import ( + "github.com/cnlh/nps/lib/common" "io/ioutil" "log" "os" @@ -10,7 +11,7 @@ import ( "strings" ) -func InitDaemon(f string) { +func InitDaemon(f string, runPath string, pidPath string) { if len(os.Args) < 2 { return } @@ -22,22 +23,17 @@ func InitDaemon(f string) { args = append(args, "-log=file") switch 
os.Args[1] { case "start": - start(args, f) + start(args, f, pidPath, runPath) os.Exit(0) case "stop": - stop(f, args[0]) + stop(f, args[0], pidPath) os.Exit(0) case "restart": - stop(f, args[0]) - start(args, f) - os.Exit(0) - case "install": - if f == "nps" { - InstallNps() - } + stop(f, args[0], pidPath) + start(args, f, pidPath, runPath) os.Exit(0) case "status": - if status(f) { + if status(f, pidPath) { log.Printf("%s is running", f) } else { log.Printf("%s is not running", f) @@ -46,11 +42,11 @@ func InitDaemon(f string) { } } -func status(f string) bool { +func status(f string, pidPath string) bool { var cmd *exec.Cmd - b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid")) + b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid")) if err == nil { - if !IsWindows() { + if !common.IsWindows() { cmd = exec.Command("/bin/sh", "-c", "ps -ax | awk '{ print $1 }' | grep "+string(b)) } else { cmd = exec.Command("tasklist", ) @@ -63,38 +59,38 @@ func status(f string) bool { return false } -func start(osArgs []string, f string) { - if status(f) { +func start(osArgs []string, f string, pidPath, runPath string) { + if status(f, pidPath) { log.Printf(" %s is running", f) return } cmd := exec.Command(osArgs[0], osArgs[1:]...) 
cmd.Start() if cmd.Process.Pid > 0 { - log.Println("start ok , pid:", cmd.Process.Pid, "config path:", GetRunPath()) + log.Println("start ok , pid:", cmd.Process.Pid, "config path:", runPath) d1 := []byte(strconv.Itoa(cmd.Process.Pid)) - ioutil.WriteFile(filepath.Join(GetPidPath(), f+".pid"), d1, 0600) + ioutil.WriteFile(filepath.Join(pidPath, f+".pid"), d1, 0600) } else { log.Println("start error") } } -func stop(f string, p string) { - if !status(f) { +func stop(f string, p string, pidPath string) { + if !status(f, pidPath) { log.Printf(" %s is not running", f) return } var c *exec.Cmd var err error - if IsWindows() { + if common.IsWindows() { p := strings.Split(p, `\`) c = exec.Command("taskkill", "/F", "/IM", p[len(p)-1]) } else { - b, err := ioutil.ReadFile(filepath.Join(GetPidPath(), f+".pid")) + b, err := ioutil.ReadFile(filepath.Join(pidPath, f+".pid")) if err == nil { c = exec.Command("/bin/bash", "-c", `kill -9 `+string(b)) } else { - log.Fatalln("stop error,PID file does not exist") + log.Fatalln("stop error,pid file does not exist") } } err = c.Run() diff --git a/lib/file/csv.go b/lib/file/csv.go new file mode 100644 index 0000000..e5347b5 --- /dev/null +++ b/lib/file/csv.go @@ -0,0 +1,19 @@ +package file + +import ( + "github.com/cnlh/nps/lib/common" + "sync" +) + +var ( + CsvDb *Csv + once sync.Once +) +//init csv from file +func GetCsvDb() *Csv { + once.Do(func() { + CsvDb = NewCsv(common.GetRunPath()) + CsvDb.Init() + }) + return CsvDb +} diff --git a/lib/file.go b/lib/file/file.go similarity index 79% rename from lib/file.go rename to lib/file/file.go index 08c7d45..d955b9a 100644 --- a/lib/file.go +++ b/lib/file/file.go @@ -1,8 +1,11 @@ -package lib +package file import ( "encoding/csv" "errors" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/rate" "os" "path/filepath" "strconv" @@ -10,13 +13,10 @@ import ( "sync" ) -var ( - CsvDb *Csv - once sync.Once -) - -func NewCsv() *Csv { - return new(Csv) +func 
NewCsv(runPath string) *Csv { + return &Csv{ + RunPath: runPath, + } } type Csv struct { @@ -24,6 +24,7 @@ type Csv struct { Path string Hosts []*Host //域名列表 Clients []*Client //客户端 + RunPath string //存储根目录 ClientIncreaseId int //客户端id TaskIncreaseId int //任务自增ID sync.Mutex @@ -37,9 +38,9 @@ func (s *Csv) Init() { func (s *Csv) StoreTasksToCsv() { // 创建文件 - csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "tasks.csv")) + csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv")) if err != nil { - Fatalf(err.Error()) + lg.Fatalf(err.Error()) } defer csvFile.Close() writer := csv.NewWriter(csvFile) @@ -51,8 +52,8 @@ func (s *Csv) StoreTasksToCsv() { task.Config.U, task.Config.P, task.Config.Compress, - GetStrByBool(task.Status), - GetStrByBool(task.Config.Crypt), + common.GetStrByBool(task.Status), + common.GetStrByBool(task.Config.Crypt), strconv.Itoa(task.Config.CompressEncode), strconv.Itoa(task.Config.CompressDecode), strconv.Itoa(task.Id), @@ -62,7 +63,7 @@ func (s *Csv) StoreTasksToCsv() { } err := writer.Write(record) if err != nil { - Fatalf(err.Error()) + lg.Fatalf(err.Error()) } } writer.Flush() @@ -87,33 +88,33 @@ func (s *Csv) openFile(path string) ([][]string, error) { } func (s *Csv) LoadTaskFromCsv() { - path := filepath.Join(GetRunPath(), "conf", "tasks.csv") + path := filepath.Join(s.RunPath, "conf", "tasks.csv") records, err := s.openFile(path) if err != nil { - Fatalln("配置文件打开错误:", path) + lg.Fatalln("配置文件打开错误:", path) } var tasks []*Tunnel // 将每一行数据保存到内存slice中 for _, item := range records { post := &Tunnel{ - TcpPort: GetIntNoErrByStr(item[0]), + TcpPort: common.GetIntNoErrByStr(item[0]), Mode: item[1], Target: item[2], Config: &Config{ U: item[3], P: item[4], Compress: item[5], - Crypt: GetBoolByStr(item[7]), - CompressEncode: GetIntNoErrByStr(item[8]), - CompressDecode: GetIntNoErrByStr(item[9]), + Crypt: common.GetBoolByStr(item[7]), + CompressEncode: common.GetIntNoErrByStr(item[8]), + CompressDecode: 
common.GetIntNoErrByStr(item[9]), }, - Status: GetBoolByStr(item[6]), - Id: GetIntNoErrByStr(item[10]), - UseClientCnf: GetBoolByStr(item[12]), + Status: common.GetBoolByStr(item[6]), + Id: common.GetIntNoErrByStr(item[10]), + UseClientCnf: common.GetBoolByStr(item[12]), Remark: item[13], } post.Flow = new(Flow) - if post.Client, err = s.GetClient(GetIntNoErrByStr(item[11])); err != nil { + if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[11])); err != nil { continue } tasks = append(tasks, post) @@ -135,7 +136,7 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) { s.Lock() defer s.Unlock() for _, v := range s.Clients { - if Getverifyval(v.VerifyKey) == vKey && v.Status { + if common.Getverifyval(v.VerifyKey) == vKey && v.Status { if arr := strings.Split(addr, ":"); len(arr) > 0 { v.Addr = arr[0] } @@ -186,7 +187,7 @@ func (s *Csv) GetTask(id int) (v *Tunnel, err error) { func (s *Csv) StoreHostToCsv() { // 创建文件 - csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "hosts.csv")) + csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv")) if err != nil { panic(err) } @@ -214,24 +215,24 @@ func (s *Csv) StoreHostToCsv() { } func (s *Csv) LoadClientFromCsv() { - path := filepath.Join(GetRunPath(), "conf", "clients.csv") + path := filepath.Join(s.RunPath, "conf", "clients.csv") records, err := s.openFile(path) if err != nil { - Fatalln("配置文件打开错误:", path) + lg.Fatalln("配置文件打开错误:", path) } var clients []*Client // 将每一行数据保存到内存slice中 for _, item := range records { post := &Client{ - Id: GetIntNoErrByStr(item[0]), + Id: common.GetIntNoErrByStr(item[0]), VerifyKey: item[1], Remark: item[2], - Status: GetBoolByStr(item[3]), - RateLimit: GetIntNoErrByStr(item[8]), + Status: common.GetBoolByStr(item[3]), + RateLimit: common.GetIntNoErrByStr(item[8]), Cnf: &Config{ U: item[4], P: item[5], - Crypt: GetBoolByStr(item[6]), + Crypt: common.GetBoolByStr(item[6]), Compress: item[7], }, } @@ -239,21 +240,21 @@ func (s *Csv) 
LoadClientFromCsv() { s.ClientIncreaseId = post.Id } if post.RateLimit > 0 { - post.Rate = NewRate(int64(post.RateLimit * 1024)) + post.Rate = rate.NewRate(int64(post.RateLimit * 1024)) post.Rate.Start() } post.Flow = new(Flow) - post.Flow.FlowLimit = int64(GetIntNoErrByStr(item[9])) + post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9])) clients = append(clients, post) } s.Clients = clients } func (s *Csv) LoadHostFromCsv() { - path := filepath.Join(GetRunPath(), "conf", "hosts.csv") + path := filepath.Join(s.RunPath, "conf", "hosts.csv") records, err := s.openFile(path) if err != nil { - Fatalln("配置文件打开错误:", path) + lg.Fatalln("配置文件打开错误:", path) } var hosts []*Host // 将每一行数据保存到内存slice中 @@ -265,7 +266,7 @@ func (s *Csv) LoadHostFromCsv() { HostChange: item[4], Remark: item[5], } - if post.Client, err = s.GetClient(GetIntNoErrByStr(item[2])); err != nil { + if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil { continue } post.Flow = new(Flow) @@ -387,11 +388,12 @@ func (s *Csv) GetClient(id int) (v *Client, err error) { err = errors.New("未找到") return } + func (s *Csv) StoreClientsToCsv() { // 创建文件 - csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "clients.csv")) + csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv")) if err != nil { - Fatalln(err.Error()) + lg.Fatalln(err.Error()) } defer csvFile.Close() writer := csv.NewWriter(csvFile) @@ -403,24 +405,15 @@ func (s *Csv) StoreClientsToCsv() { strconv.FormatBool(client.Status), client.Cnf.U, client.Cnf.P, - GetStrByBool(client.Cnf.Crypt), + common.GetStrByBool(client.Cnf.Crypt), client.Cnf.Compress, strconv.Itoa(client.RateLimit), strconv.Itoa(int(client.Flow.FlowLimit)), } err := writer.Write(record) if err != nil { - Fatalln(err.Error()) + lg.Fatalln(err.Error()) } } writer.Flush() } - -//init csv from file -func GetCsvDb() *Csv { - once.Do(func() { - CsvDb = NewCsv() - CsvDb.Init() - }) - return CsvDb -} diff --git a/lib/link.go 
b/lib/file/obj.go similarity index 61% rename from lib/link.go rename to lib/file/obj.go index 03d85bc..7dc2897 100644 --- a/lib/link.go +++ b/lib/file/obj.go @@ -1,41 +1,11 @@ -package lib +package file import ( - "net" + "github.com/cnlh/nps/lib/rate" "strings" "sync" ) -type Link struct { - Id int //id - ConnType string //连接类型 - Host string //目标 - En int //加密 - De int //解密 - Crypt bool //加密 - Conn *Conn - Flow *Flow - UdpListener *net.UDPConn - Rate *Rate - UdpRemoteAddr *net.UDPAddr -} - -func NewLink(id int, connType string, host string, en, de int, crypt bool, conn *Conn, flow *Flow, udpListener *net.UDPConn, rate *Rate, UdpRemoteAddr *net.UDPAddr) *Link { - return &Link{ - Id: id, - ConnType: connType, - Host: host, - En: en, - De: de, - Crypt: crypt, - Conn: conn, - Flow: flow, - UdpListener: udpListener, - Rate: rate, - UdpRemoteAddr: UdpRemoteAddr, - } -} - type Flow struct { ExportFlow int64 //出口流量 InletFlow int64 //入口流量 @@ -52,15 +22,15 @@ func (s *Flow) Add(in, out int) { type Client struct { Cnf *Config - Id int //id - VerifyKey string //验证密钥 - Addr string //客户端ip地址 - Remark string //备注 - Status bool //是否开启 - IsConnect bool //是否连接 - RateLimit int //速度限制 /kb - Flow *Flow //流量 - Rate *Rate //速度控制 + Id int //id + VerifyKey string //验证密钥 + Addr string //客户端ip地址 + Remark string //备注 + Status bool //是否开启 + IsConnect bool //是否连接 + RateLimit int //速度限制 /kb + Flow *Flow //流量 + Rate *rate.Rate //速度控制 id int sync.RWMutex } @@ -74,7 +44,7 @@ func (s *Client) GetId() int { type Tunnel struct { Id int //Id - TcpPort int //服务端与客户端通信端口 + TcpPort int //服务端监听端口 Mode string //启动方式 Target string //目标 Status bool //是否开启 diff --git a/lib/install.go b/lib/install/install.go similarity index 81% rename from lib/install.go rename to lib/install/install.go index 64fcc9e..63d2e46 100644 --- a/lib/install.go +++ b/lib/install/install.go @@ -1,8 +1,9 @@ -package lib +package install import ( "errors" "fmt" + "github.com/cnlh/nps/lib/common" "io" "log" "os" @@ -11,22 +12,22 @@ 
import ( ) func InstallNps() { - path := GetInstallPath() + path := common.GetInstallPath() MkidrDirAll(path, "conf", "web/static", "web/views") //复制文件到对应目录 - if err := CopyDir(filepath.Join(GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil { + if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "views"), filepath.Join(path, "web", "views")); err != nil { log.Fatalln(err) } - if err := CopyDir(filepath.Join(GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil { + if err := CopyDir(filepath.Join(common.GetAppPath(), "web", "static"), filepath.Join(path, "web", "static")); err != nil { log.Fatalln(err) } - if err := CopyDir(filepath.Join(GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil { + if err := CopyDir(filepath.Join(common.GetAppPath(), "conf"), filepath.Join(path, "conf")); err != nil { log.Fatalln(err) } - if !IsWindows() { - if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/bin/nps"); err != nil { - if _, err := copyFile(filepath.Join(GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil { + if !common.IsWindows() { + if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/bin/nps"); err != nil { + if _, err := copyFile(filepath.Join(common.GetAppPath(), "nps"), "/usr/local/bin/nps"); err != nil { log.Fatalln(err) } else { os.Chmod("/usr/local/bin/nps", 0777) @@ -41,7 +42,7 @@ func InstallNps() { log.Println("install ok!") log.Println("Static files and configuration files in the current directory will be useless") log.Println("The new configuration file is located in", path, "you can edit them") - if !IsWindows() { + if !common.IsWindows() { log.Println("You can start with nps test|start|stop|restart|status anywhere") } else { log.Println("You can copy executable files to any directory and start working with nps.exe test|start|stop|restart|status") diff --git a/lib/kcp/crypt.go b/lib/kcp/crypt.go new file mode 100644 index 0000000..958fdea --- 
/dev/null +++ b/lib/kcp/crypt.go @@ -0,0 +1,785 @@ +package kcp + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/sha1" + + "github.com/templexxx/xor" + "github.com/tjfoc/gmsm/sm4" + + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/crypto/salsa20" + "golang.org/x/crypto/tea" + "golang.org/x/crypto/twofish" + "golang.org/x/crypto/xtea" +) + +var ( + initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107} + saltxor = `sH3CIVoF#rWLtJo6` +) + +// BlockCrypt defines encryption/decryption methods for a given byte slice. +// Notes on implementing: the data to be encrypted contains a builtin +// nonce at the first 16 bytes +type BlockCrypt interface { + // Encrypt encrypts the whole block in src into dst. + // Dst and src may point at the same memory. + Encrypt(dst, src []byte) + + // Decrypt decrypts the whole block in src into dst. + // Dst and src may point at the same memory. 
+ Decrypt(dst, src []byte) +} + +type salsa20BlockCrypt struct { + key [32]byte +} + +// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20 +func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(salsa20BlockCrypt) + copy(c.key[:], key) + return c, nil +} + +func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} +func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} + +type sm4BlockCrypt struct { + encbuf [sm4.BlockSize]byte + decbuf [2 * sm4.BlockSize]byte + block cipher.Block +} + +// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4 +func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(sm4BlockCrypt) + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type twofishBlockCrypt struct { + encbuf [twofish.BlockSize]byte + decbuf [2 * twofish.BlockSize]byte + block cipher.Block +} + +// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish +func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(twofishBlockCrypt) + block, err := twofish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type tripleDESBlockCrypt struct { + encbuf [des.BlockSize]byte + decbuf [2 * des.BlockSize]byte + block cipher.Block +} + +// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES +func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) { + c := 
new(tripleDESBlockCrypt) + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type cast5BlockCrypt struct { + encbuf [cast5.BlockSize]byte + decbuf [2 * cast5.BlockSize]byte + block cipher.Block +} + +// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128 +func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(cast5BlockCrypt) + block, err := cast5.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type blowfishBlockCrypt struct { + encbuf [blowfish.BlockSize]byte + decbuf [2 * blowfish.BlockSize]byte + block cipher.Block +} + +// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher) +func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(blowfishBlockCrypt) + block, err := blowfish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type aesBlockCrypt struct { + encbuf [aes.BlockSize]byte + decbuf [2 * aes.BlockSize]byte + block cipher.Block +} + +// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard +func NewAESBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(aesBlockCrypt) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *aesBlockCrypt) Encrypt(dst, src []byte) { 
encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type teaBlockCrypt struct { + encbuf [tea.BlockSize]byte + decbuf [2 * tea.BlockSize]byte + block cipher.Block +} + +// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm +func NewTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(teaBlockCrypt) + block, err := tea.NewCipherWithRounds(key, 16) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type xteaBlockCrypt struct { + encbuf [xtea.BlockSize]byte + decbuf [2 * xtea.BlockSize]byte + block cipher.Block +} + +// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA +func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(xteaBlockCrypt) + block, err := xtea.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type simpleXORBlockCrypt struct { + xortbl []byte +} + +// NewSimpleXORBlockCrypt simple xor with key expanding +func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(simpleXORBlockCrypt) + c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New) + return c, nil +} + +func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } +func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } + +type noneBlockCrypt struct{} + +// NewNoneBlockCrypt does nothing but copying +func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) { + return new(noneBlockCrypt), nil +} + +func (c *noneBlockCrypt) Encrypt(dst, 
src []byte) { copy(dst, src) } +func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) } + +// packet encryption with local CFB mode +func encrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + encrypt8(block, dst, src, buf) + case 16: + encrypt16(block, dst, src, buf) + default: + encryptVariant(block, dst, src, buf) + } +} + +// optimized encryption for the ciphers which works in 8-bytes +func encrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:8] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + xor.BytesSrc1(d[0:8], s[0:8], tbl) + block.Encrypt(tbl, d[0:8]) + // 2 + xor.BytesSrc1(d[8:16], s[8:16], tbl) + block.Encrypt(tbl, d[8:16]) + // 3 + xor.BytesSrc1(d[16:24], s[16:24], tbl) + block.Encrypt(tbl, d[16:24]) + // 4 + xor.BytesSrc1(d[24:32], s[24:32], tbl) + block.Encrypt(tbl, d[24:32]) + // 5 + xor.BytesSrc1(d[32:40], s[32:40], tbl) + block.Encrypt(tbl, d[32:40]) + // 6 + xor.BytesSrc1(d[40:48], s[40:48], tbl) + block.Encrypt(tbl, d[40:48]) + // 7 + xor.BytesSrc1(d[48:56], s[48:56], tbl) + block.Encrypt(tbl, d[48:56]) + // 8 + xor.BytesSrc1(d[56:64], s[56:64], tbl) + block.Encrypt(tbl, d[56:64]) + base += 64 + } + + switch left { + case 7: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 6: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 5: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 4: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 3: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 2: + xor.BytesSrc1(dst[base:], src[base:], tbl) + 
block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 1: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} + +// optimized encryption for the ciphers which works in 16-bytes +func encrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + xor.BytesSrc1(d[0:16], s[0:16], tbl) + block.Encrypt(tbl, d[0:16]) + // 2 + xor.BytesSrc1(d[16:32], s[16:32], tbl) + block.Encrypt(tbl, d[16:32]) + // 3 + xor.BytesSrc1(d[32:48], s[32:48], tbl) + block.Encrypt(tbl, d[32:48]) + // 4 + xor.BytesSrc1(d[48:64], s[48:64], tbl) + block.Encrypt(tbl, d[48:64]) + // 5 + xor.BytesSrc1(d[64:80], s[64:80], tbl) + block.Encrypt(tbl, d[64:80]) + // 6 + xor.BytesSrc1(d[80:96], s[80:96], tbl) + block.Encrypt(tbl, d[80:96]) + // 7 + xor.BytesSrc1(d[96:112], s[96:112], tbl) + block.Encrypt(tbl, d[96:112]) + // 8 + xor.BytesSrc1(d[112:128], s[112:128], tbl) + block.Encrypt(tbl, d[112:128]) + base += 128 + } + + switch left { + case 7: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 6: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 5: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 4: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 3: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 2: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 1: + xor.BytesSrc1(dst[base:], src[base:], tbl) + 
block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} + +func encryptVariant(block cipher.Block, dst, src, buf []byte) { + blocksize := block.BlockSize() + tbl := buf[:blocksize] + block.Encrypt(tbl, initialVector) + n := len(src) / blocksize + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + // 1 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 2 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 3 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 4 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 5 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 6 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 7 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + + // 8 + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + } + + switch left { + case 7: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 6: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 5: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 4: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 3: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 2: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + 
case 1: + xor.BytesSrc1(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += blocksize + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} + +// decryption +func decrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + decrypt8(block, dst, src, buf) + case 16: + decrypt16(block, dst, src, buf) + default: + decryptVariant(block, dst, src, buf) + } +} + +func decrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:8] + next := buf[8:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + block.Encrypt(next, s[0:8]) + xor.BytesSrc1(d[0:8], s[0:8], tbl) + // 2 + block.Encrypt(tbl, s[8:16]) + xor.BytesSrc1(d[8:16], s[8:16], next) + // 3 + block.Encrypt(next, s[16:24]) + xor.BytesSrc1(d[16:24], s[16:24], tbl) + // 4 + block.Encrypt(tbl, s[24:32]) + xor.BytesSrc1(d[24:32], s[24:32], next) + // 5 + block.Encrypt(next, s[32:40]) + xor.BytesSrc1(d[32:40], s[32:40], tbl) + // 6 + block.Encrypt(tbl, s[40:48]) + xor.BytesSrc1(d[40:48], s[40:48], next) + // 7 + block.Encrypt(next, s[48:56]) + xor.BytesSrc1(d[48:56], s[48:56], tbl) + // 8 + block.Encrypt(tbl, s[56:64]) + xor.BytesSrc1(d[56:64], s[56:64], next) + base += 64 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 5: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 3: + block.Encrypt(next, src[base:]) + 
xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 8 + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} + +func decrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:16] + next := buf[16:32] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + block.Encrypt(next, s[0:16]) + xor.BytesSrc1(d[0:16], s[0:16], tbl) + // 2 + block.Encrypt(tbl, s[16:32]) + xor.BytesSrc1(d[16:32], s[16:32], next) + // 3 + block.Encrypt(next, s[32:48]) + xor.BytesSrc1(d[32:48], s[32:48], tbl) + // 4 + block.Encrypt(tbl, s[48:64]) + xor.BytesSrc1(d[48:64], s[48:64], next) + // 5 + block.Encrypt(next, s[64:80]) + xor.BytesSrc1(d[64:80], s[64:80], tbl) + // 6 + block.Encrypt(tbl, s[80:96]) + xor.BytesSrc1(d[80:96], s[80:96], next) + // 7 + block.Encrypt(next, s[96:112]) + xor.BytesSrc1(d[96:112], s[96:112], tbl) + // 8 + block.Encrypt(tbl, s[112:128]) + xor.BytesSrc1(d[112:128], s[112:128], next) + base += 128 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 5: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 3: + block.Encrypt(next, 
src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} + +func decryptVariant(block cipher.Block, dst, src, buf []byte) { + blocksize := block.BlockSize() + tbl := buf[:blocksize] + next := buf[blocksize:] + block.Encrypt(tbl, initialVector) + n := len(src) / blocksize + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + // 1 + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + base += blocksize + + // 2 + block.Encrypt(tbl, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], next) + base += blocksize + + // 3 + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + base += blocksize + + // 4 + block.Encrypt(tbl, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], next) + base += blocksize + + // 5 + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + base += blocksize + + // 6 + block.Encrypt(tbl, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], next) + base += blocksize + + // 7 + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + base += blocksize + + // 8 + block.Encrypt(tbl, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], next) + base += blocksize + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 6: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 5: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], 
src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 4: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 3: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 2: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 1: + block.Encrypt(next, src[base:]) + xor.BytesSrc1(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += blocksize + fallthrough + case 0: + xor.BytesSrc0(dst[base:], src[base:], tbl) + } +} diff --git a/lib/kcp/crypt_test.go b/lib/kcp/crypt_test.go new file mode 100644 index 0000000..2ef4dc8 --- /dev/null +++ b/lib/kcp/crypt_test.go @@ -0,0 +1,289 @@ +package kcp + +import ( + "bytes" + "crypto/aes" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "hash/crc32" + "io" + "testing" +) + +func TestSM4(t *testing.T) { + bc, err := NewSM4BlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestAES(t *testing.T) { + bc, err := NewAESBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestTEA(t *testing.T) { + bc, err := NewTEABlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestXOR(t *testing.T) { + bc, err := NewSimpleXORBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestBlowfish(t *testing.T) { + bc, err := NewBlowfishBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestNone(t *testing.T) { + bc, err := NewNoneBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestCast5(t *testing.T) { + bc, err := NewCast5BlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func Test3DES(t *testing.T) 
{ + bc, err := NewTripleDESBlockCrypt(pass[:24]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestTwofish(t *testing.T) { + bc, err := NewTwofishBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestXTEA(t *testing.T) { + bc, err := NewXTEABlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestSalsa20(t *testing.T) { + bc, err := NewSalsa20BlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func cryptTest(t *testing.T, bc BlockCrypt) { + data := make([]byte, mtuLimit) + io.ReadFull(rand.Reader, data) + dec := make([]byte, mtuLimit) + enc := make([]byte, mtuLimit) + bc.Encrypt(enc, data) + bc.Decrypt(dec, enc) + if !bytes.Equal(data, dec) { + t.Fail() + } +} + +func BenchmarkSM4(b *testing.B) { + bc, err := NewSM4BlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkAES128(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkAES192(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:24]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkAES256(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkTEA(b *testing.B) { + bc, err := NewTEABlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkXOR(b *testing.B) { + bc, err := NewSimpleXORBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkBlowfish(b *testing.B) { + bc, err := NewBlowfishBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkNone(b *testing.B) { + bc, err := NewNoneBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkCast5(b *testing.B) { + bc, err := 
NewCast5BlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func Benchmark3DES(b *testing.B) { + bc, err := NewTripleDESBlockCrypt(pass[:24]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkTwofish(b *testing.B) { + bc, err := NewTwofishBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkXTEA(b *testing.B) { + bc, err := NewXTEABlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkSalsa20(b *testing.B) { + bc, err := NewSalsa20BlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func benchCrypt(b *testing.B, bc BlockCrypt) { + data := make([]byte, mtuLimit) + io.ReadFull(rand.Reader, data) + dec := make([]byte, mtuLimit) + enc := make([]byte, mtuLimit) + + b.ReportAllocs() + b.SetBytes(int64(len(enc) * 2)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bc.Encrypt(enc, data) + bc.Decrypt(dec, enc) + } +} + +func BenchmarkCRC32(b *testing.B) { + content := make([]byte, 1024) + b.SetBytes(int64(len(content))) + for i := 0; i < b.N; i++ { + crc32.ChecksumIEEE(content) + } +} + +func BenchmarkCsprngSystem(b *testing.B) { + data := make([]byte, md5.Size) + b.SetBytes(int64(len(data))) + + for i := 0; i < b.N; i++ { + io.ReadFull(rand.Reader, data) + } +} + +func BenchmarkCsprngMD5(b *testing.B) { + var data [md5.Size]byte + b.SetBytes(md5.Size) + + for i := 0; i < b.N; i++ { + data = md5.Sum(data[:]) + } +} +func BenchmarkCsprngSHA1(b *testing.B) { + var data [sha1.Size]byte + b.SetBytes(sha1.Size) + + for i := 0; i < b.N; i++ { + data = sha1.Sum(data[:]) + } +} + +func BenchmarkCsprngNonceMD5(b *testing.B) { + var ng nonceMD5 + ng.Init() + b.SetBytes(md5.Size) + data := make([]byte, md5.Size) + for i := 0; i < b.N; i++ { + ng.Fill(data) + } +} + +func BenchmarkCsprngNonceAES128(b *testing.B) { + var ng nonceAES128 + ng.Init() + + b.SetBytes(aes.BlockSize) + data := make([]byte, 
aes.BlockSize) + for i := 0; i < b.N; i++ { + ng.Fill(data) + } +} diff --git a/lib/kcp/entropy.go b/lib/kcp/entropy.go new file mode 100644 index 0000000..156c1cd --- /dev/null +++ b/lib/kcp/entropy.go @@ -0,0 +1,52 @@ +package kcp + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "io" +) + +// Entropy defines a entropy source +type Entropy interface { + Init() + Fill(nonce []byte) +} + +// nonceMD5 nonce generator for packet header +type nonceMD5 struct { + seed [md5.Size]byte +} + +func (n *nonceMD5) Init() { /*nothing required*/ } + +func (n *nonceMD5) Fill(nonce []byte) { + if n.seed[0] == 0 { // entropy update + io.ReadFull(rand.Reader, n.seed[:]) + } + n.seed = md5.Sum(n.seed[:]) + copy(nonce, n.seed[:]) +} + +// nonceAES128 nonce generator for packet headers +type nonceAES128 struct { + seed [aes.BlockSize]byte + block cipher.Block +} + +func (n *nonceAES128) Init() { + var key [16]byte //aes-128 + io.ReadFull(rand.Reader, key[:]) + io.ReadFull(rand.Reader, n.seed[:]) + block, _ := aes.NewCipher(key[:]) + n.block = block +} + +func (n *nonceAES128) Fill(nonce []byte) { + if n.seed[0] == 0 { // entropy update + io.ReadFull(rand.Reader, n.seed[:]) + } + n.block.Encrypt(n.seed[:], n.seed[:]) + copy(nonce, n.seed[:]) +} diff --git a/lib/kcp/fec.go b/lib/kcp/fec.go new file mode 100644 index 0000000..366637b --- /dev/null +++ b/lib/kcp/fec.go @@ -0,0 +1,311 @@ +package kcp + +import ( + "encoding/binary" + "sync/atomic" + + "github.com/klauspost/reedsolomon" +) + +const ( + fecHeaderSize = 6 + fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size + typeData = 0xf1 + typeFEC = 0xf2 +) + +type ( + // fecPacket is a decoded FEC packet + fecPacket struct { + seqid uint32 + flag uint16 + data []byte + } + + // fecDecoder for decoding incoming packets + fecDecoder struct { + rxlimit int // queue size limit + dataShards int + parityShards int + shardSize int + rx []fecPacket // ordered receive queue + + // caches + decodeCache [][]byte 
+ flagCache []bool + + // zeros + zeros []byte + + // RS decoder + codec reedsolomon.Encoder + } +) + +func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder { + if dataShards <= 0 || parityShards <= 0 { + return nil + } + if rxlimit < dataShards+parityShards { + return nil + } + + dec := new(fecDecoder) + dec.rxlimit = rxlimit + dec.dataShards = dataShards + dec.parityShards = parityShards + dec.shardSize = dataShards + parityShards + codec, err := reedsolomon.New(dataShards, parityShards) + if err != nil { + return nil + } + dec.codec = codec + dec.decodeCache = make([][]byte, dec.shardSize) + dec.flagCache = make([]bool, dec.shardSize) + dec.zeros = make([]byte, mtuLimit) + return dec +} + +// decodeBytes a fec packet +func (dec *fecDecoder) decodeBytes(data []byte) fecPacket { + var pkt fecPacket + pkt.seqid = binary.LittleEndian.Uint32(data) + pkt.flag = binary.LittleEndian.Uint16(data[4:]) + // allocate memory & copy + buf := xmitBuf.Get().([]byte)[:len(data)-6] + copy(buf, data[6:]) + pkt.data = buf + return pkt +} + +// decode a fec packet +func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) { + // insertion + n := len(dec.rx) - 1 + insertIdx := 0 + for i := n; i >= 0; i-- { + if pkt.seqid == dec.rx[i].seqid { // de-duplicate + xmitBuf.Put(pkt.data) + return nil + } else if _itimediff(pkt.seqid, dec.rx[i].seqid) > 0 { // insertion + insertIdx = i + 1 + break + } + } + + // insert into ordered rx queue + if insertIdx == n+1 { + dec.rx = append(dec.rx, pkt) + } else { + dec.rx = append(dec.rx, fecPacket{}) + copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right + dec.rx[insertIdx] = pkt + } + + // shard range for current packet + shardBegin := pkt.seqid - pkt.seqid%uint32(dec.shardSize) + shardEnd := shardBegin + uint32(dec.shardSize) - 1 + + // max search range in ordered queue for current shard + searchBegin := insertIdx - int(pkt.seqid%uint32(dec.shardSize)) + if searchBegin < 0 { + searchBegin = 0 + } + searchEnd := 
searchBegin + dec.shardSize - 1 + if searchEnd >= len(dec.rx) { + searchEnd = len(dec.rx) - 1 + } + + // re-construct datashards + if searchEnd-searchBegin+1 >= dec.dataShards { + var numshard, numDataShard, first, maxlen int + + // zero caches + shards := dec.decodeCache + shardsflag := dec.flagCache + for k := range dec.decodeCache { + shards[k] = nil + shardsflag[k] = false + } + + // shard assembly + for i := searchBegin; i <= searchEnd; i++ { + seqid := dec.rx[i].seqid + if _itimediff(seqid, shardEnd) > 0 { + break + } else if _itimediff(seqid, shardBegin) >= 0 { + shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data + shardsflag[seqid%uint32(dec.shardSize)] = true + numshard++ + if dec.rx[i].flag == typeData { + numDataShard++ + } + if numshard == 1 { + first = i + } + if len(dec.rx[i].data) > maxlen { + maxlen = len(dec.rx[i].data) + } + } + } + + if numDataShard == dec.dataShards { + // case 1: no loss on data shards + dec.rx = dec.freeRange(first, numshard, dec.rx) + } else if numshard >= dec.dataShards { + // case 2: loss on data shards, but it's recoverable from parity shards + for k := range shards { + if shards[k] != nil { + dlen := len(shards[k]) + shards[k] = shards[k][:maxlen] + copy(shards[k][dlen:], dec.zeros) + } + } + if err := dec.codec.ReconstructData(shards); err == nil { + for k := range shards[:dec.dataShards] { + if !shardsflag[k] { + recovered = append(recovered, shards[k]) + } + } + } + dec.rx = dec.freeRange(first, numshard, dec.rx) + } + } + + // keep rxlimit + if len(dec.rx) > dec.rxlimit { + if dec.rx[0].flag == typeData { // track the unrecoverable data + atomic.AddUint64(&DefaultSnmp.FECShortShards, 1) + } + dec.rx = dec.freeRange(0, 1, dec.rx) + } + return +} + +// free a range of fecPacket, and zero for GC recycling +func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket { + for i := first; i < first+n; i++ { // recycle buffer + xmitBuf.Put(q[i].data) + } + copy(q[first:], q[first+n:]) + for i := 0; i < n; i++ 
{ // dereference data + q[len(q)-1-i].data = nil + } + return q[:len(q)-n] +} + +type ( + // fecEncoder for encoding outgoing packets + fecEncoder struct { + dataShards int + parityShards int + shardSize int + paws uint32 // Protect Against Wrapped Sequence numbers + next uint32 // next seqid + + shardCount int // count the number of datashards collected + maxSize int // track maximum data length in datashard + + headerOffset int // FEC header offset + payloadOffset int // FEC payload offset + + // caches + shardCache [][]byte + encodeCache [][]byte + + // zeros + zeros []byte + + // RS encoder + codec reedsolomon.Encoder + } +) + +func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder { + if dataShards <= 0 || parityShards <= 0 { + return nil + } + enc := new(fecEncoder) + enc.dataShards = dataShards + enc.parityShards = parityShards + enc.shardSize = dataShards + parityShards + enc.paws = (0xffffffff/uint32(enc.shardSize) - 1) * uint32(enc.shardSize) + enc.headerOffset = offset + enc.payloadOffset = enc.headerOffset + fecHeaderSize + + codec, err := reedsolomon.New(dataShards, parityShards) + if err != nil { + return nil + } + enc.codec = codec + + // caches + enc.encodeCache = make([][]byte, enc.shardSize) + enc.shardCache = make([][]byte, enc.shardSize) + for k := range enc.shardCache { + enc.shardCache[k] = make([]byte, mtuLimit) + } + enc.zeros = make([]byte, mtuLimit) + return enc +} + +// encodes the packet, outputs parity shards if we have collected quorum datashards +// notice: the contents of 'ps' will be re-written in successive calling +func (enc *fecEncoder) encode(b []byte) (ps [][]byte) { + enc.markData(b[enc.headerOffset:]) + binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:]))) + + // copy data to fec datashards + sz := len(b) + enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz] + copy(enc.shardCache[enc.shardCount], b) + enc.shardCount++ + + // track max datashard length + if sz > 
enc.maxSize { + enc.maxSize = sz + } + + // Generation of Reed-Solomon Erasure Code + if enc.shardCount == enc.dataShards { + // fill '0' into the tail of each datashard + for i := 0; i < enc.dataShards; i++ { + shard := enc.shardCache[i] + slen := len(shard) + copy(shard[slen:enc.maxSize], enc.zeros) + } + + // construct equal-sized slice with stripped header + cache := enc.encodeCache + for k := range cache { + cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize] + } + + // encoding + if err := enc.codec.Encode(cache); err == nil { + ps = enc.shardCache[enc.dataShards:] + for k := range ps { + enc.markFEC(ps[k][enc.headerOffset:]) + ps[k] = ps[k][:enc.maxSize] + } + } + + // counters resetting + enc.shardCount = 0 + enc.maxSize = 0 + } + + return +} + +func (enc *fecEncoder) markData(data []byte) { + binary.LittleEndian.PutUint32(data, enc.next) + binary.LittleEndian.PutUint16(data[4:], typeData) + enc.next++ +} + +func (enc *fecEncoder) markFEC(data []byte) { + binary.LittleEndian.PutUint32(data, enc.next) + binary.LittleEndian.PutUint16(data[4:], typeFEC) + enc.next = (enc.next + 1) % enc.paws +} diff --git a/lib/kcp/fec_test.go b/lib/kcp/fec_test.go new file mode 100644 index 0000000..9279f00 --- /dev/null +++ b/lib/kcp/fec_test.go @@ -0,0 +1,43 @@ +package kcp + +import ( + "math/rand" + "testing" +) + +func BenchmarkFECDecode(b *testing.B) { + const dataSize = 10 + const paritySize = 3 + const payLoad = 1500 + decoder := newFECDecoder(1024, dataSize, paritySize) + b.ReportAllocs() + b.SetBytes(payLoad) + for i := 0; i < b.N; i++ { + if rand.Int()%(dataSize+paritySize) == 0 { // random loss + continue + } + var pkt fecPacket + pkt.seqid = uint32(i) + if i%(dataSize+paritySize) >= dataSize { + pkt.flag = typeFEC + } else { + pkt.flag = typeData + } + pkt.data = make([]byte, payLoad) + decoder.decode(pkt) + } +} + +func BenchmarkFECEncode(b *testing.B) { + const dataSize = 10 + const paritySize = 3 + const payLoad = 1500 + + b.ReportAllocs() + 
b.SetBytes(payLoad) + encoder := newFECEncoder(dataSize, paritySize, 0) + for i := 0; i < b.N; i++ { + data := make([]byte, payLoad) + encoder.encode(data) + } +} diff --git a/lib/kcp/kcp.go b/lib/kcp/kcp.go new file mode 100644 index 0000000..6bfb04e --- /dev/null +++ b/lib/kcp/kcp.go @@ -0,0 +1,1012 @@ +// Package kcp - A Fast and Reliable ARQ Protocol +package kcp + +import ( + "encoding/binary" + "sync/atomic" +) + +const ( + IKCP_RTO_NDL = 30 // no delay min rto + IKCP_RTO_MIN = 100 // normal min rto + IKCP_RTO_DEF = 200 + IKCP_RTO_MAX = 60000 + IKCP_CMD_PUSH = 81 // cmd: push data + IKCP_CMD_ACK = 82 // cmd: ack + IKCP_CMD_WASK = 83 // cmd: window probe (ask) + IKCP_CMD_WINS = 84 // cmd: window size (tell) + IKCP_ASK_SEND = 1 // need to send IKCP_CMD_WASK + IKCP_ASK_TELL = 2 // need to send IKCP_CMD_WINS + IKCP_WND_SND = 32 + IKCP_WND_RCV = 32 + IKCP_MTU_DEF = 1400 + IKCP_ACK_FAST = 3 + IKCP_INTERVAL = 100 + IKCP_OVERHEAD = 24 + IKCP_DEADLINK = 20 + IKCP_THRESH_INIT = 2 + IKCP_THRESH_MIN = 2 + IKCP_PROBE_INIT = 7000 // 7 secs to probe window size + IKCP_PROBE_LIMIT = 120000 // up to 120 secs to probe window +) + +// output_callback is a prototype which ought capture conn and call conn.Write +type output_callback func(buf []byte, size int) + +/* encode 8 bits unsigned int */ +func ikcp_encode8u(p []byte, c byte) []byte { + p[0] = c + return p[1:] +} + +/* decode 8 bits unsigned int */ +func ikcp_decode8u(p []byte, c *byte) []byte { + *c = p[0] + return p[1:] +} + +/* encode 16 bits unsigned int (lsb) */ +func ikcp_encode16u(p []byte, w uint16) []byte { + binary.LittleEndian.PutUint16(p, w) + return p[2:] +} + +/* decode 16 bits unsigned int (lsb) */ +func ikcp_decode16u(p []byte, w *uint16) []byte { + *w = binary.LittleEndian.Uint16(p) + return p[2:] +} + +/* encode 32 bits unsigned int (lsb) */ +func ikcp_encode32u(p []byte, l uint32) []byte { + binary.LittleEndian.PutUint32(p, l) + return p[4:] +} + +/* decode 32 bits unsigned int (lsb) */ +func 
ikcp_decode32u(p []byte, l *uint32) []byte { + *l = binary.LittleEndian.Uint32(p) + return p[4:] +} + +func _imin_(a, b uint32) uint32 { + if a <= b { + return a + } + return b +} + +func _imax_(a, b uint32) uint32 { + if a >= b { + return a + } + return b +} + +func _ibound_(lower, middle, upper uint32) uint32 { + return _imin_(_imax_(lower, middle), upper) +} + +func _itimediff(later, earlier uint32) int32 { + return (int32)(later - earlier) +} + +// segment defines a KCP segment +type segment struct { + conv uint32 + cmd uint8 + frg uint8 + wnd uint16 + ts uint32 + sn uint32 + una uint32 + rto uint32 + xmit uint32 + resendts uint32 + fastack uint32 + acked uint32 // mark if the seg has acked + data []byte +} + +// encode a segment into buffer +func (seg *segment) encode(ptr []byte) []byte { + ptr = ikcp_encode32u(ptr, seg.conv) + ptr = ikcp_encode8u(ptr, seg.cmd) + ptr = ikcp_encode8u(ptr, seg.frg) + ptr = ikcp_encode16u(ptr, seg.wnd) + ptr = ikcp_encode32u(ptr, seg.ts) + ptr = ikcp_encode32u(ptr, seg.sn) + ptr = ikcp_encode32u(ptr, seg.una) + ptr = ikcp_encode32u(ptr, uint32(len(seg.data))) + atomic.AddUint64(&DefaultSnmp.OutSegs, 1) + return ptr +} + +// KCP defines a single KCP connection +type KCP struct { + conv, mtu, mss, state uint32 + snd_una, snd_nxt, rcv_nxt uint32 + ssthresh uint32 + rx_rttvar, rx_srtt int32 + rx_rto, rx_minrto uint32 + snd_wnd, rcv_wnd, rmt_wnd, cwnd, probe uint32 + interval, ts_flush uint32 + nodelay, updated uint32 + ts_probe, probe_wait uint32 + dead_link, incr uint32 + + fastresend int32 + nocwnd, stream int32 + + snd_queue []segment + rcv_queue []segment + snd_buf []segment + rcv_buf []segment + + acklist []ackItem + + buffer []byte + output output_callback +} + +type ackItem struct { + sn uint32 + ts uint32 +} + +// NewKCP create a new kcp control object, 'conv' must equal in two endpoint +// from the same connection. 
+func NewKCP(conv uint32, output output_callback) *KCP { + kcp := new(KCP) + kcp.conv = conv + kcp.snd_wnd = IKCP_WND_SND + kcp.rcv_wnd = IKCP_WND_RCV + kcp.rmt_wnd = IKCP_WND_RCV + kcp.mtu = IKCP_MTU_DEF + kcp.mss = kcp.mtu - IKCP_OVERHEAD + kcp.buffer = make([]byte, (kcp.mtu+IKCP_OVERHEAD)*3) + kcp.rx_rto = IKCP_RTO_DEF + kcp.rx_minrto = IKCP_RTO_MIN + kcp.interval = IKCP_INTERVAL + kcp.ts_flush = IKCP_INTERVAL + kcp.ssthresh = IKCP_THRESH_INIT + kcp.dead_link = IKCP_DEADLINK + kcp.output = output + return kcp +} + +// newSegment creates a KCP segment +func (kcp *KCP) newSegment(size int) (seg segment) { + seg.data = xmitBuf.Get().([]byte)[:size] + return +} + +// delSegment recycles a KCP segment +func (kcp *KCP) delSegment(seg *segment) { + if seg.data != nil { + xmitBuf.Put(seg.data) + seg.data = nil + } +} + +// PeekSize checks the size of next message in the recv queue +func (kcp *KCP) PeekSize() (length int) { + if len(kcp.rcv_queue) == 0 { + return -1 + } + + seg := &kcp.rcv_queue[0] + if seg.frg == 0 { + return len(seg.data) + } + + if len(kcp.rcv_queue) < int(seg.frg+1) { + return -1 + } + + for k := range kcp.rcv_queue { + seg := &kcp.rcv_queue[k] + length += len(seg.data) + if seg.frg == 0 { + break + } + } + return +} + +// Recv is user/upper level recv: returns size, returns below zero for EAGAIN +func (kcp *KCP) Recv(buffer []byte) (n int) { + if len(kcp.rcv_queue) == 0 { + return -1 + } + + peeksize := kcp.PeekSize() + if peeksize < 0 { + return -2 + } + + if peeksize > len(buffer) { + return -3 + } + + var fast_recover bool + if len(kcp.rcv_queue) >= int(kcp.rcv_wnd) { + fast_recover = true + } + + // merge fragment + count := 0 + for k := range kcp.rcv_queue { + seg := &kcp.rcv_queue[k] + copy(buffer, seg.data) + buffer = buffer[len(seg.data):] + n += len(seg.data) + count++ + kcp.delSegment(seg) + if seg.frg == 0 { + break + } + } + if count > 0 { + kcp.rcv_queue = kcp.remove_front(kcp.rcv_queue, count) + } + + // move available data from 
rcv_buf -> rcv_queue + count = 0 + for k := range kcp.rcv_buf { + seg := &kcp.rcv_buf[k] + if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { + kcp.rcv_nxt++ + count++ + } else { + break + } + } + + if count > 0 { + kcp.rcv_queue = append(kcp.rcv_queue, kcp.rcv_buf[:count]...) + kcp.rcv_buf = kcp.remove_front(kcp.rcv_buf, count) + } + + // fast recover + if len(kcp.rcv_queue) < int(kcp.rcv_wnd) && fast_recover { + // ready to send back IKCP_CMD_WINS in ikcp_flush + // tell remote my window size + kcp.probe |= IKCP_ASK_TELL + } + return +} + +// Send is user/upper level send, returns below zero for error +func (kcp *KCP) Send(buffer []byte) int { + var count int + if len(buffer) == 0 { + return -1 + } + + // append to previous segment in streaming mode (if possible) + if kcp.stream != 0 { + n := len(kcp.snd_queue) + if n > 0 { + seg := &kcp.snd_queue[n-1] + if len(seg.data) < int(kcp.mss) { + capacity := int(kcp.mss) - len(seg.data) + extend := capacity + if len(buffer) < capacity { + extend = len(buffer) + } + + // grow slice, the underlying cap is guaranteed to + // be larger than kcp.mss + oldlen := len(seg.data) + seg.data = seg.data[:oldlen+extend] + copy(seg.data[oldlen:], buffer) + buffer = buffer[extend:] + } + } + + if len(buffer) == 0 { + return 0 + } + } + + if len(buffer) <= int(kcp.mss) { + count = 1 + } else { + count = (len(buffer) + int(kcp.mss) - 1) / int(kcp.mss) + } + + if count > 255 { + return -2 + } + + if count == 0 { + count = 1 + } + + for i := 0; i < count; i++ { + var size int + if len(buffer) > int(kcp.mss) { + size = int(kcp.mss) + } else { + size = len(buffer) + } + seg := kcp.newSegment(size) + copy(seg.data, buffer[:size]) + if kcp.stream == 0 { // message mode + seg.frg = uint8(count - i - 1) + } else { // stream mode + seg.frg = 0 + } + kcp.snd_queue = append(kcp.snd_queue, seg) + buffer = buffer[size:] + } + return 0 +} + +func (kcp *KCP) update_ack(rtt int32) { + // https://tools.ietf.org/html/rfc6298 + var rto 
uint32 + if kcp.rx_srtt == 0 { + kcp.rx_srtt = rtt + kcp.rx_rttvar = rtt >> 1 + } else { + delta := rtt - kcp.rx_srtt + kcp.rx_srtt += delta >> 3 + if delta < 0 { + delta = -delta + } + if rtt < kcp.rx_srtt-kcp.rx_rttvar { + // if the new RTT sample is below the bottom of the range of + // what an RTT measurement is expected to be. + // give an 8x reduced weight versus its normal weighting + kcp.rx_rttvar += (delta - kcp.rx_rttvar) >> 5 + } else { + kcp.rx_rttvar += (delta - kcp.rx_rttvar) >> 2 + } + } + rto = uint32(kcp.rx_srtt) + _imax_(kcp.interval, uint32(kcp.rx_rttvar)<<2) + kcp.rx_rto = _ibound_(kcp.rx_minrto, rto, IKCP_RTO_MAX) +} + +func (kcp *KCP) shrink_buf() { + if len(kcp.snd_buf) > 0 { + seg := &kcp.snd_buf[0] + kcp.snd_una = seg.sn + } else { + kcp.snd_una = kcp.snd_nxt + } +} + +func (kcp *KCP) parse_ack(sn uint32) { + if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { + return + } + + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if sn == seg.sn { + seg.acked = 1 + kcp.delSegment(seg) + break + } + if _itimediff(sn, seg.sn) < 0 { + break + } + } +} + +func (kcp *KCP) parse_fastack(sn, ts uint32) { + if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { + return + } + + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if _itimediff(sn, seg.sn) < 0 { + break + } else if sn != seg.sn && _itimediff(seg.ts, ts) <= 0 { + seg.fastack++ + } + } +} + +func (kcp *KCP) parse_una(una uint32) { + count := 0 + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if _itimediff(una, seg.sn) > 0 { + kcp.delSegment(seg) + count++ + } else { + break + } + } + if count > 0 { + kcp.snd_buf = kcp.remove_front(kcp.snd_buf, count) + } +} + +// ack append +func (kcp *KCP) ack_push(sn, ts uint32) { + kcp.acklist = append(kcp.acklist, ackItem{sn, ts}) +} + +// returns true if data has repeated +func (kcp *KCP) parse_data(newseg segment) bool { + sn := newseg.sn + if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) >= 0 || + 
_itimediff(sn, kcp.rcv_nxt) < 0 { + return true + } + + n := len(kcp.rcv_buf) - 1 + insert_idx := 0 + repeat := false + for i := n; i >= 0; i-- { + seg := &kcp.rcv_buf[i] + if seg.sn == sn { + repeat = true + break + } + if _itimediff(sn, seg.sn) > 0 { + insert_idx = i + 1 + break + } + } + + if !repeat { + // replicate the content if it's new + dataCopy := xmitBuf.Get().([]byte)[:len(newseg.data)] + copy(dataCopy, newseg.data) + newseg.data = dataCopy + + if insert_idx == n+1 { + kcp.rcv_buf = append(kcp.rcv_buf, newseg) + } else { + kcp.rcv_buf = append(kcp.rcv_buf, segment{}) + copy(kcp.rcv_buf[insert_idx+1:], kcp.rcv_buf[insert_idx:]) + kcp.rcv_buf[insert_idx] = newseg + } + } + + // move available data from rcv_buf -> rcv_queue + count := 0 + for k := range kcp.rcv_buf { + seg := &kcp.rcv_buf[k] + if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { + kcp.rcv_nxt++ + count++ + } else { + break + } + } + if count > 0 { + kcp.rcv_queue = append(kcp.rcv_queue, kcp.rcv_buf[:count]...) + kcp.rcv_buf = kcp.remove_front(kcp.rcv_buf, count) + } + + return repeat +} + +// Input when you received a low level packet (eg. 
UDP packet), call it +// regular indicates a regular packet has received(not from FEC) +func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int { + snd_una := kcp.snd_una + if len(data) < IKCP_OVERHEAD { + return -1 + } + + var latest uint32 // the latest ack packet + var flag int + var inSegs uint64 + + for { + var ts, sn, length, una, conv uint32 + var wnd uint16 + var cmd, frg uint8 + + if len(data) < int(IKCP_OVERHEAD) { + break + } + + data = ikcp_decode32u(data, &conv) + if conv != kcp.conv { + return -1 + } + + data = ikcp_decode8u(data, &cmd) + data = ikcp_decode8u(data, &frg) + data = ikcp_decode16u(data, &wnd) + data = ikcp_decode32u(data, &ts) + data = ikcp_decode32u(data, &sn) + data = ikcp_decode32u(data, &una) + data = ikcp_decode32u(data, &length) + if len(data) < int(length) { + return -2 + } + + if cmd != IKCP_CMD_PUSH && cmd != IKCP_CMD_ACK && + cmd != IKCP_CMD_WASK && cmd != IKCP_CMD_WINS { + return -3 + } + + // only trust window updates from regular packets. 
i.e: latest update + if regular { + kcp.rmt_wnd = uint32(wnd) + } + kcp.parse_una(una) + kcp.shrink_buf() + + if cmd == IKCP_CMD_ACK { + kcp.parse_ack(sn) + kcp.parse_fastack(sn, ts) + flag |= 1 + latest = ts + } else if cmd == IKCP_CMD_PUSH { + repeat := true + if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) < 0 { + kcp.ack_push(sn, ts) + if _itimediff(sn, kcp.rcv_nxt) >= 0 { + var seg segment + seg.conv = conv + seg.cmd = cmd + seg.frg = frg + seg.wnd = wnd + seg.ts = ts + seg.sn = sn + seg.una = una + seg.data = data[:length] // delayed data copying + repeat = kcp.parse_data(seg) + } + } + if regular && repeat { + atomic.AddUint64(&DefaultSnmp.RepeatSegs, 1) + } + } else if cmd == IKCP_CMD_WASK { + // ready to send back IKCP_CMD_WINS in Ikcp_flush + // tell remote my window size + kcp.probe |= IKCP_ASK_TELL + } else if cmd == IKCP_CMD_WINS { + // do nothing + } else { + return -3 + } + + inSegs++ + data = data[length:] + } + atomic.AddUint64(&DefaultSnmp.InSegs, inSegs) + + // update rtt with the latest ts + // ignore the FEC packet + if flag != 0 && regular { + current := currentMs() + if _itimediff(current, latest) >= 0 { + kcp.update_ack(_itimediff(current, latest)) + } + } + + // cwnd update when packet arrived + if kcp.nocwnd == 0 { + if _itimediff(kcp.snd_una, snd_una) > 0 { + if kcp.cwnd < kcp.rmt_wnd { + mss := kcp.mss + if kcp.cwnd < kcp.ssthresh { + kcp.cwnd++ + kcp.incr += mss + } else { + if kcp.incr < mss { + kcp.incr = mss + } + kcp.incr += (mss*mss)/kcp.incr + (mss / 16) + if (kcp.cwnd+1)*mss <= kcp.incr { + kcp.cwnd++ + } + } + if kcp.cwnd > kcp.rmt_wnd { + kcp.cwnd = kcp.rmt_wnd + kcp.incr = kcp.rmt_wnd * mss + } + } + } + } + + if ackNoDelay && len(kcp.acklist) > 0 { // ack immediately + kcp.flush(true) + } + return 0 +} + +func (kcp *KCP) wnd_unused() uint16 { + if len(kcp.rcv_queue) < int(kcp.rcv_wnd) { + return uint16(int(kcp.rcv_wnd) - len(kcp.rcv_queue)) + } + return 0 +} + +// flush pending data +func (kcp *KCP) flush(ackOnly bool) uint32 { + 
var seg segment + seg.conv = kcp.conv + seg.cmd = IKCP_CMD_ACK + seg.wnd = kcp.wnd_unused() + seg.una = kcp.rcv_nxt + + buffer := kcp.buffer + // flush acknowledges + ptr := buffer + for i, ack := range kcp.acklist { + size := len(buffer) - len(ptr) + if size+IKCP_OVERHEAD > int(kcp.mtu) { + kcp.output(buffer, size) + ptr = buffer + } + // filter jitters caused by bufferbloat + if ack.sn >= kcp.rcv_nxt || len(kcp.acklist)-1 == i { + seg.sn, seg.ts = ack.sn, ack.ts + ptr = seg.encode(ptr) + } + } + kcp.acklist = kcp.acklist[0:0] + + if ackOnly { // flash remain ack segments + size := len(buffer) - len(ptr) + if size > 0 { + kcp.output(buffer, size) + } + return kcp.interval + } + + // probe window size (if remote window size equals zero) + if kcp.rmt_wnd == 0 { + current := currentMs() + if kcp.probe_wait == 0 { + kcp.probe_wait = IKCP_PROBE_INIT + kcp.ts_probe = current + kcp.probe_wait + } else { + if _itimediff(current, kcp.ts_probe) >= 0 { + if kcp.probe_wait < IKCP_PROBE_INIT { + kcp.probe_wait = IKCP_PROBE_INIT + } + kcp.probe_wait += kcp.probe_wait / 2 + if kcp.probe_wait > IKCP_PROBE_LIMIT { + kcp.probe_wait = IKCP_PROBE_LIMIT + } + kcp.ts_probe = current + kcp.probe_wait + kcp.probe |= IKCP_ASK_SEND + } + } + } else { + kcp.ts_probe = 0 + kcp.probe_wait = 0 + } + + // flush window probing commands + if (kcp.probe & IKCP_ASK_SEND) != 0 { + seg.cmd = IKCP_CMD_WASK + size := len(buffer) - len(ptr) + if size+IKCP_OVERHEAD > int(kcp.mtu) { + kcp.output(buffer, size) + ptr = buffer + } + ptr = seg.encode(ptr) + } + + // flush window probing commands + if (kcp.probe & IKCP_ASK_TELL) != 0 { + seg.cmd = IKCP_CMD_WINS + size := len(buffer) - len(ptr) + if size+IKCP_OVERHEAD > int(kcp.mtu) { + kcp.output(buffer, size) + ptr = buffer + } + ptr = seg.encode(ptr) + } + + kcp.probe = 0 + + // calculate window size + cwnd := _imin_(kcp.snd_wnd, kcp.rmt_wnd) + if kcp.nocwnd == 0 { + cwnd = _imin_(kcp.cwnd, cwnd) + } + + // sliding window, controlled by snd_nxt && 
sna_una+cwnd + newSegsCount := 0 + for k := range kcp.snd_queue { + if _itimediff(kcp.snd_nxt, kcp.snd_una+cwnd) >= 0 { + break + } + newseg := kcp.snd_queue[k] + newseg.conv = kcp.conv + newseg.cmd = IKCP_CMD_PUSH + newseg.sn = kcp.snd_nxt + kcp.snd_buf = append(kcp.snd_buf, newseg) + kcp.snd_nxt++ + newSegsCount++ + } + if newSegsCount > 0 { + kcp.snd_queue = kcp.remove_front(kcp.snd_queue, newSegsCount) + } + + // calculate resent + resent := uint32(kcp.fastresend) + if kcp.fastresend <= 0 { + resent = 0xffffffff + } + + // check for retransmissions + current := currentMs() + var change, lost, lostSegs, fastRetransSegs, earlyRetransSegs uint64 + minrto := int32(kcp.interval) + + ref := kcp.snd_buf[:len(kcp.snd_buf)] // for bounds check elimination + for k := range ref { + segment := &ref[k] + needsend := false + if segment.acked == 1 { + continue + } + if segment.xmit == 0 { // initial transmit + needsend = true + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + } else if _itimediff(current, segment.resendts) >= 0 { // RTO + needsend = true + if kcp.nodelay == 0 { + segment.rto += kcp.rx_rto + } else { + segment.rto += kcp.rx_rto / 2 + } + segment.resendts = current + segment.rto + lost++ + lostSegs++ + } else if segment.fastack >= resent { // fast retransmit + needsend = true + segment.fastack = 0 + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + change++ + fastRetransSegs++ + } else if segment.fastack > 0 && newSegsCount == 0 { // early retransmit + needsend = true + segment.fastack = 0 + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + change++ + earlyRetransSegs++ + } + + if needsend { + current = currentMs() // time update for a blocking call + segment.xmit++ + segment.ts = current + segment.wnd = seg.wnd + segment.una = seg.una + + size := len(buffer) - len(ptr) + need := IKCP_OVERHEAD + len(segment.data) + + if size+need > int(kcp.mtu) { + kcp.output(buffer, size) + ptr = buffer + } + + ptr = 
segment.encode(ptr) + copy(ptr, segment.data) + ptr = ptr[len(segment.data):] + + if segment.xmit >= kcp.dead_link { + kcp.state = 0xFFFFFFFF + } + } + + // get the nearest rto + if rto := _itimediff(segment.resendts, current); rto > 0 && rto < minrto { + minrto = rto + } + } + + // flash remain segments + size := len(buffer) - len(ptr) + if size > 0 { + kcp.output(buffer, size) + } + + // counter updates + sum := lostSegs + if lostSegs > 0 { + atomic.AddUint64(&DefaultSnmp.LostSegs, lostSegs) + } + if fastRetransSegs > 0 { + atomic.AddUint64(&DefaultSnmp.FastRetransSegs, fastRetransSegs) + sum += fastRetransSegs + } + if earlyRetransSegs > 0 { + atomic.AddUint64(&DefaultSnmp.EarlyRetransSegs, earlyRetransSegs) + sum += earlyRetransSegs + } + if sum > 0 { + atomic.AddUint64(&DefaultSnmp.RetransSegs, sum) + } + + // cwnd update + if kcp.nocwnd == 0 { + // update ssthresh + // rate halving, https://tools.ietf.org/html/rfc6937 + if change > 0 { + inflight := kcp.snd_nxt - kcp.snd_una + kcp.ssthresh = inflight / 2 + if kcp.ssthresh < IKCP_THRESH_MIN { + kcp.ssthresh = IKCP_THRESH_MIN + } + kcp.cwnd = kcp.ssthresh + resent + kcp.incr = kcp.cwnd * kcp.mss + } + + // congestion control, https://tools.ietf.org/html/rfc5681 + if lost > 0 { + kcp.ssthresh = cwnd / 2 + if kcp.ssthresh < IKCP_THRESH_MIN { + kcp.ssthresh = IKCP_THRESH_MIN + } + kcp.cwnd = 1 + kcp.incr = kcp.mss + } + + if kcp.cwnd < 1 { + kcp.cwnd = 1 + kcp.incr = kcp.mss + } + } + + return uint32(minrto) +} + +// Update updates state (call it repeatedly, every 10ms-100ms), or you can ask +// ikcp_check when to call it again (without ikcp_input/_send calling). +// 'current' - current timestamp in millisec. 
+func (kcp *KCP) Update() { + var slap int32 + + current := currentMs() + if kcp.updated == 0 { + kcp.updated = 1 + kcp.ts_flush = current + } + + slap = _itimediff(current, kcp.ts_flush) + + if slap >= 10000 || slap < -10000 { + kcp.ts_flush = current + slap = 0 + } + + if slap >= 0 { + kcp.ts_flush += kcp.interval + if _itimediff(current, kcp.ts_flush) >= 0 { + kcp.ts_flush = current + kcp.interval + } + kcp.flush(false) + } +} + +// Check determines when should you invoke ikcp_update: +// returns when you should invoke ikcp_update in millisec, if there +// is no ikcp_input/_send calling. you can call ikcp_update in that +// time, instead of call update repeatly. +// Important to reduce unnacessary ikcp_update invoking. use it to +// schedule ikcp_update (eg. implementing an epoll-like mechanism, +// or optimize ikcp_update when handling massive kcp connections) +func (kcp *KCP) Check() uint32 { + current := currentMs() + ts_flush := kcp.ts_flush + tm_flush := int32(0x7fffffff) + tm_packet := int32(0x7fffffff) + minimal := uint32(0) + if kcp.updated == 0 { + return current + } + + if _itimediff(current, ts_flush) >= 10000 || + _itimediff(current, ts_flush) < -10000 { + ts_flush = current + } + + if _itimediff(current, ts_flush) >= 0 { + return current + } + + tm_flush = _itimediff(ts_flush, current) + + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + diff := _itimediff(seg.resendts, current) + if diff <= 0 { + return current + } + if diff < tm_packet { + tm_packet = diff + } + } + + minimal = uint32(tm_packet) + if tm_packet >= tm_flush { + minimal = uint32(tm_flush) + } + if minimal >= kcp.interval { + minimal = kcp.interval + } + + return current + minimal +} + +// SetMtu changes MTU size, default is 1400 +func (kcp *KCP) SetMtu(mtu int) int { + if mtu < 50 || mtu < IKCP_OVERHEAD { + return -1 + } + buffer := make([]byte, (mtu+IKCP_OVERHEAD)*3) + if buffer == nil { + return -2 + } + kcp.mtu = uint32(mtu) + kcp.mss = kcp.mtu - IKCP_OVERHEAD + kcp.buffer 
= buffer + return 0 +} + +// NoDelay options +// fastest: ikcp_nodelay(kcp, 1, 20, 2, 1) +// nodelay: 0:disable(default), 1:enable +// interval: internal update timer interval in millisec, default is 100ms +// resend: 0:disable fast resend(default), 1:enable fast resend +// nc: 0:normal congestion control(default), 1:disable congestion control +func (kcp *KCP) NoDelay(nodelay, interval, resend, nc int) int { + if nodelay >= 0 { + kcp.nodelay = uint32(nodelay) + if nodelay != 0 { + kcp.rx_minrto = IKCP_RTO_NDL + } else { + kcp.rx_minrto = IKCP_RTO_MIN + } + } + if interval >= 0 { + if interval > 5000 { + interval = 5000 + } else if interval < 10 { + interval = 10 + } + kcp.interval = uint32(interval) + } + if resend >= 0 { + kcp.fastresend = int32(resend) + } + if nc >= 0 { + kcp.nocwnd = int32(nc) + } + return 0 +} + +// WndSize sets maximum window size: sndwnd=32, rcvwnd=32 by default +func (kcp *KCP) WndSize(sndwnd, rcvwnd int) int { + if sndwnd > 0 { + kcp.snd_wnd = uint32(sndwnd) + } + if rcvwnd > 0 { + kcp.rcv_wnd = uint32(rcvwnd) + } + return 0 +} + +// WaitSnd gets how many packet is waiting to be sent +func (kcp *KCP) WaitSnd() int { + return len(kcp.snd_buf) + len(kcp.snd_queue) +} + +// remove front n elements from queue +func (kcp *KCP) remove_front(q []segment, n int) []segment { + newn := copy(q, q[n:]) + return q[:newn] +} diff --git a/lib/kcp/kcp_test.go b/lib/kcp/kcp_test.go new file mode 100644 index 0000000..d06b061 --- /dev/null +++ b/lib/kcp/kcp_test.go @@ -0,0 +1,302 @@ +package kcp + +import ( + "bytes" + "container/list" + "encoding/binary" + "fmt" + "math/rand" + "sync" + "testing" + "time" +) + +func iclock() int32 { + return int32(currentMs()) +} + +type DelayPacket struct { + _ptr []byte + _size int + _ts int32 +} + +func (p *DelayPacket) Init(size int, src []byte) { + p._ptr = make([]byte, size) + p._size = size + copy(p._ptr, src[:size]) +} + +func (p *DelayPacket) ptr() []byte { return p._ptr } +func (p *DelayPacket) size() int { 
return p._size } +func (p *DelayPacket) ts() int32 { return p._ts } +func (p *DelayPacket) setts(ts int32) { p._ts = ts } + +type DelayTunnel struct{ *list.List } +type LatencySimulator struct { + current int32 + lostrate, rttmin, rttmax, nmax int + p12 DelayTunnel + p21 DelayTunnel + r12 *rand.Rand + r21 *rand.Rand +} + +// lostrate: 往返一周丢包率的百分比,默认 10% +// rttmin:rtt最小值,默认 60 +// rttmax:rtt最大值,默认 125 +//func (p *LatencySimulator)Init(int lostrate = 10, int rttmin = 60, int rttmax = 125, int nmax = 1000): +func (p *LatencySimulator) Init(lostrate, rttmin, rttmax, nmax int) { + p.r12 = rand.New(rand.NewSource(9)) + p.r21 = rand.New(rand.NewSource(99)) + p.p12 = DelayTunnel{list.New()} + p.p21 = DelayTunnel{list.New()} + p.current = iclock() + p.lostrate = lostrate / 2 // 上面数据是往返丢包率,单程除以2 + p.rttmin = rttmin / 2 + p.rttmax = rttmax / 2 + p.nmax = nmax +} + +// 发送数据 +// peer - 端点0/1,从0发送,从1接收;从1发送从0接收 +func (p *LatencySimulator) send(peer int, data []byte, size int) int { + rnd := 0 + if peer == 0 { + rnd = p.r12.Intn(100) + } else { + rnd = p.r21.Intn(100) + } + //println("!!!!!!!!!!!!!!!!!!!!", rnd, p.lostrate, peer) + if rnd < p.lostrate { + return 0 + } + pkt := &DelayPacket{} + pkt.Init(size, data) + p.current = iclock() + delay := p.rttmin + if p.rttmax > p.rttmin { + delay += rand.Int() % (p.rttmax - p.rttmin) + } + pkt.setts(p.current + int32(delay)) + if peer == 0 { + p.p12.PushBack(pkt) + } else { + p.p21.PushBack(pkt) + } + return 1 +} + +// 接收数据 +func (p *LatencySimulator) recv(peer int, data []byte, maxsize int) int32 { + var it *list.Element + if peer == 0 { + it = p.p21.Front() + if p.p21.Len() == 0 { + return -1 + } + } else { + it = p.p12.Front() + if p.p12.Len() == 0 { + return -1 + } + } + pkt := it.Value.(*DelayPacket) + p.current = iclock() + if p.current < pkt.ts() { + return -2 + } + if maxsize < pkt.size() { + return -3 + } + if peer == 0 { + p.p21.Remove(it) + } else { + p.p12.Remove(it) + } + maxsize = pkt.size() + copy(data, 
pkt.ptr()[:maxsize]) + return int32(maxsize) +} + +//===================================================================== +//===================================================================== + +// 模拟网络 +var vnet *LatencySimulator + +// 测试用例 +func test(mode int) { + // 创建模拟网络:丢包率10%,Rtt 60ms~125ms + vnet = &LatencySimulator{} + vnet.Init(10, 60, 125, 1000) + + // 创建两个端点的 kcp对象,第一个参数 conv是会话编号,同一个会话需要相同 + // 最后一个是 user参数,用来传递标识 + output1 := func(buf []byte, size int) { + if vnet.send(0, buf, size) != 1 { + } + } + output2 := func(buf []byte, size int) { + if vnet.send(1, buf, size) != 1 { + } + } + kcp1 := NewKCP(0x11223344, output1) + kcp2 := NewKCP(0x11223344, output2) + + current := uint32(iclock()) + slap := current + 20 + index := 0 + next := 0 + var sumrtt uint32 + count := 0 + maxrtt := 0 + + // 配置窗口大小:平均延迟200ms,每20ms发送一个包, + // 而考虑到丢包重发,设置最大收发窗口为128 + kcp1.WndSize(128, 128) + kcp2.WndSize(128, 128) + + // 判断测试用例的模式 + if mode == 0 { + // 默认模式 + kcp1.NoDelay(0, 10, 0, 0) + kcp2.NoDelay(0, 10, 0, 0) + } else if mode == 1 { + // 普通模式,关闭流控等 + kcp1.NoDelay(0, 10, 0, 1) + kcp2.NoDelay(0, 10, 0, 1) + } else { + // 启动快速模式 + // 第二个参数 nodelay-启用以后若干常规加速将启动 + // 第三个参数 interval为内部处理时钟,默认设置为 10ms + // 第四个参数 resend为快速重传指标,设置为2 + // 第五个参数 为是否禁用常规流控,这里禁止 + kcp1.NoDelay(1, 10, 2, 1) + kcp2.NoDelay(1, 10, 2, 1) + } + + buffer := make([]byte, 2000) + var hr int32 + + ts1 := iclock() + + for { + time.Sleep(1 * time.Millisecond) + current = uint32(iclock()) + kcp1.Update() + kcp2.Update() + + // 每隔 20ms,kcp1发送数据 + for ; current >= slap; slap += 20 { + buf := new(bytes.Buffer) + binary.Write(buf, binary.LittleEndian, uint32(index)) + index++ + binary.Write(buf, binary.LittleEndian, uint32(current)) + // 发送上层协议包 + kcp1.Send(buf.Bytes()) + //println("now", iclock()) + } + + // 处理虚拟网络:检测是否有udp包从p1->p2 + for { + hr = vnet.recv(1, buffer, 2000) + if hr < 0 { + break + } + // 如果 p2收到udp,则作为下层协议输入到kcp2 + kcp2.Input(buffer[:hr], true, false) + } + + // 处理虚拟网络:检测是否有udp包从p2->p1 + for { 
+ hr = vnet.recv(0, buffer, 2000) + if hr < 0 { + break + } + // 如果 p1收到udp,则作为下层协议输入到kcp1 + kcp1.Input(buffer[:hr], true, false) + //println("@@@@", hr, r) + } + + // kcp2接收到任何包都返回回去 + for { + hr = int32(kcp2.Recv(buffer[:10])) + // 没有收到包就退出 + if hr < 0 { + break + } + // 如果收到包就回射 + buf := bytes.NewReader(buffer) + var sn uint32 + binary.Read(buf, binary.LittleEndian, &sn) + kcp2.Send(buffer[:hr]) + } + + // kcp1收到kcp2的回射数据 + for { + hr = int32(kcp1.Recv(buffer[:10])) + buf := bytes.NewReader(buffer) + // 没有收到包就退出 + if hr < 0 { + break + } + var sn uint32 + var ts, rtt uint32 + binary.Read(buf, binary.LittleEndian, &sn) + binary.Read(buf, binary.LittleEndian, &ts) + rtt = uint32(current) - ts + + if sn != uint32(next) { + // 如果收到的包不连续 + //for i:=0;i<8 ;i++ { + //println("---", i, buffer[i]) + //} + println("ERROR sn ", count, "<->", next, sn) + return + } + + next++ + sumrtt += rtt + count++ + if rtt > uint32(maxrtt) { + maxrtt = int(rtt) + } + + //println("[RECV] mode=", mode, " sn=", sn, " rtt=", rtt) + } + + if next > 100 { + break + } + } + + ts1 = iclock() - ts1 + + names := []string{"default", "normal", "fast"} + fmt.Printf("%s mode result (%dms):\n", names[mode], ts1) + fmt.Printf("avgrtt=%d maxrtt=%d\n", int(sumrtt/uint32(count)), maxrtt) +} + +func TestNetwork(t *testing.T) { + test(0) // 默认模式,类似 TCP:正常模式,无快速重传,常规流控 + test(1) // 普通模式,关闭流控等 + test(2) // 快速模式,所有开关都打开,且关闭流控 +} + +func BenchmarkFlush(b *testing.B) { + kcp := NewKCP(1, func(buf []byte, size int) {}) + kcp.snd_buf = make([]segment, 1024) + for k := range kcp.snd_buf { + kcp.snd_buf[k].xmit = 1 + kcp.snd_buf[k].resendts = currentMs() + 10000 + } + b.ResetTimer() + b.ReportAllocs() + var mu sync.Mutex + for i := 0; i < b.N; i++ { + mu.Lock() + kcp.flush(false) + mu.Unlock() + } +} diff --git a/lib/kcp/sess.go b/lib/kcp/sess.go new file mode 100644 index 0000000..a60b7b8 --- /dev/null +++ b/lib/kcp/sess.go @@ -0,0 +1,963 @@ +package kcp + +import ( + "crypto/rand" + "encoding/binary" + 
"hash/crc32" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pkg/errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type errTimeout struct { + error +} + +func (errTimeout) Timeout() bool { return true } +func (errTimeout) Temporary() bool { return true } +func (errTimeout) Error() string { return "i/o timeout" } + +const ( + // 16-bytes nonce for each packet + nonceSize = 16 + + // 4-bytes packet checksum + crcSize = 4 + + // overall crypto header size + cryptHeaderSize = nonceSize + crcSize + + // maximum packet size + mtuLimit = 1500 + + // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory + rxFECMulti = 3 + + // accept backlog + acceptBacklog = 128 +) + +const ( + errBrokenPipe = "broken pipe" + errInvalidOperation = "invalid operation" +) + +var ( + // a system-wide packet buffer shared among sending, receiving and FEC + // to mitigate high-frequency memory allocation for packets + xmitBuf sync.Pool +) + +func init() { + xmitBuf.New = func() interface{} { + return make([]byte, mtuLimit) + } +} + +type ( + // UDPSession defines a KCP session implemented by UDP + UDPSession struct { + updaterIdx int // record slice index in updater + conn net.PacketConn // the underlying packet connection + kcp *KCP // KCP ARQ protocol + l *Listener // pointing to the Listener object if it's been accepted by a Listener + block BlockCrypt // block encryption object + + // kcp receiving is based on packets + // recvbuf turns packets into stream + recvbuf []byte + bufptr []byte + // header extended output buffer, if has header + ext []byte + + // FEC codec + fecDecoder *fecDecoder + fecEncoder *fecEncoder + + // settings + remote net.Addr // remote peer address + rd time.Time // read deadline + wd time.Time // write deadline + headerSize int // the header size additional to a KCP frame + ackNoDelay bool // send ack immediately for each incoming packet(testing purpose) + writeDelay bool // delay kcp.flush() for Write() for bulk transfer 
+ dup int // duplicate udp packets(testing purpose) + + // notifications + die chan struct{} // notify current session has Closed + chReadEvent chan struct{} // notify Read() can be called without blocking + chWriteEvent chan struct{} // notify Write() can be called without blocking + chReadError chan error // notify PacketConn.Read() have an error + chWriteError chan error // notify PacketConn.Write() have an error + + // nonce generator + nonce Entropy + + isClosed bool // flag the session has Closed + mu sync.Mutex + } + + setReadBuffer interface { + SetReadBuffer(bytes int) error + } + + setWriteBuffer interface { + SetWriteBuffer(bytes int) error + } +) + +// newUDPSession create a new udp session for client or server +func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn net.PacketConn, remote net.Addr, block BlockCrypt) *UDPSession { + sess := new(UDPSession) + sess.die = make(chan struct{}) + sess.nonce = new(nonceAES128) + sess.nonce.Init() + sess.chReadEvent = make(chan struct{}, 1) + sess.chWriteEvent = make(chan struct{}, 1) + sess.chReadError = make(chan error, 1) + sess.chWriteError = make(chan error, 1) + sess.remote = remote + sess.conn = conn + sess.l = l + sess.block = block + sess.recvbuf = make([]byte, mtuLimit) + + // FEC codec initialization + sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) + if sess.block != nil { + sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize) + } else { + sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0) + } + + // calculate additional header size introduced by FEC and encryption + if sess.block != nil { + sess.headerSize += cryptHeaderSize + } + if sess.fecEncoder != nil { + sess.headerSize += fecHeaderSizePlus2 + } + + // we only need to allocate extended packet buffer if we have the additional header + if sess.headerSize > 0 { + sess.ext = make([]byte, mtuLimit) + } + + sess.kcp = NewKCP(conv, func(buf 
[]byte, size int) { + if size >= IKCP_OVERHEAD { + sess.output(buf[:size]) + } + }) + sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize) + + // register current session to the global updater, + // which call sess.update() periodically. + updater.addSession(sess) + + if sess.l == nil { // it's a client connection + go sess.readLoop() + atomic.AddUint64(&DefaultSnmp.ActiveOpens, 1) + } else { + atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1) + } + currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1) + maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn) + if currestab > maxconn { + atomic.CompareAndSwapUint64(&DefaultSnmp.MaxConn, maxconn, currestab) + } + + return sess +} + +// Read implements net.Conn +func (s *UDPSession) Read(b []byte) (n int, err error) { + for { + s.mu.Lock() + if len(s.bufptr) > 0 { // copy from buffer into b + n = copy(b, s.bufptr) + s.bufptr = s.bufptr[n:] + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) + return n, nil + } + + if s.isClosed { + s.mu.Unlock() + return 0, errors.New(errBrokenPipe) + } + + if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp + if len(b) >= size { // receive data into 'b' directly + s.kcp.Recv(b) + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(size)) + return size, nil + } + + // if necessary resize the stream buffer to guarantee a sufficent buffer space + if cap(s.recvbuf) < size { + s.recvbuf = make([]byte, size) + } + + // resize the length of recvbuf to correspond to data size + s.recvbuf = s.recvbuf[:size] + s.kcp.Recv(s.recvbuf) + n = copy(b, s.recvbuf) // copy to 'b' + s.bufptr = s.recvbuf[n:] // pointer update + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) + return n, nil + } + + // deadline for current reading operation + var timeout *time.Timer + var c <-chan time.Time + if !s.rd.IsZero() { + if time.Now().After(s.rd) { + s.mu.Unlock() + return 0, errTimeout{} + } + + delay := s.rd.Sub(time.Now()) + timeout = 
time.NewTimer(delay) + c = timeout.C + } + s.mu.Unlock() + + // wait for read event or timeout + select { + case <-s.chReadEvent: + case <-c: + case <-s.die: + case err = <-s.chReadError: + if timeout != nil { + timeout.Stop() + } + return n, err + } + + if timeout != nil { + timeout.Stop() + } + } +} + +// Write implements net.Conn +func (s *UDPSession) Write(b []byte) (n int, err error) { + for { + s.mu.Lock() + if s.isClosed { + s.mu.Unlock() + return 0, errors.New(errBrokenPipe) + } + + // controls how much data will be sent to kcp core + // to prevent the memory from exhuasting + if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { + n = len(b) + for { + if len(b) <= int(s.kcp.mss) { + s.kcp.Send(b) + break + } else { + s.kcp.Send(b[:s.kcp.mss]) + b = b[s.kcp.mss:] + } + } + + // flush immediately if the queue is full + if s.kcp.WaitSnd() >= int(s.kcp.snd_wnd) || !s.writeDelay { + s.kcp.flush(false) + } + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n)) + return n, nil + } + + // deadline for current writing operation + var timeout *time.Timer + var c <-chan time.Time + if !s.wd.IsZero() { + if time.Now().After(s.wd) { + s.mu.Unlock() + return 0, errTimeout{} + } + delay := s.wd.Sub(time.Now()) + timeout = time.NewTimer(delay) + c = timeout.C + } + s.mu.Unlock() + + // wait for write event or timeout + select { + case <-s.chWriteEvent: + case <-c: + case <-s.die: + case err = <-s.chWriteError: + if timeout != nil { + timeout.Stop() + } + return n, err + } + + if timeout != nil { + timeout.Stop() + } + } +} + +// Close closes the connection. 
+func (s *UDPSession) Close() error { + // remove current session from updater & listener(if necessary) + updater.removeSession(s) + if s.l != nil { // notify listener + s.l.closeSession(s.remote) + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.isClosed { + return errors.New(errBrokenPipe) + } + close(s.die) + s.isClosed = true + atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0)) + if s.l == nil { // client socket close + return s.conn.Close() + } + return nil +} + +// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (s *UDPSession) LocalAddr() net.Addr { return s.conn.LocalAddr() } + +// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (s *UDPSession) RemoteAddr() net.Addr { return s.remote } + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (s *UDPSession) SetDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.rd = t + s.wd = t + s.notifyReadEvent() + s.notifyWriteEvent() + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (s *UDPSession) SetReadDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.rd = t + s.notifyReadEvent() + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. 
+func (s *UDPSession) SetWriteDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.wd = t + s.notifyWriteEvent() + return nil +} + +// SetWriteDelay delays write for bulk transfer until the next update interval +func (s *UDPSession) SetWriteDelay(delay bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.writeDelay = delay +} + +// SetWindowSize set maximum window size +func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) { + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.WndSize(sndwnd, rcvwnd) +} + +// SetMtu sets the maximum transmission unit(not including UDP header) +func (s *UDPSession) SetMtu(mtu int) bool { + if mtu > mtuLimit { + return false + } + + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.SetMtu(mtu - s.headerSize) + return true +} + +// SetStreamMode toggles the stream mode on/off +func (s *UDPSession) SetStreamMode(enable bool) { + s.mu.Lock() + defer s.mu.Unlock() + if enable { + s.kcp.stream = 1 + } else { + s.kcp.stream = 0 + } +} + +// SetACKNoDelay changes ack flush option, set true to flush ack immediately, +func (s *UDPSession) SetACKNoDelay(nodelay bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.ackNoDelay = nodelay +} + +// SetDUP duplicates udp packets for kcp output, for testing purpose only +func (s *UDPSession) SetDUP(dup int) { + s.mu.Lock() + defer s.mu.Unlock() + s.dup = dup +} + +// SetNoDelay calls nodelay() of kcp +// https://github.com/skywind3000/kcp/blob/master/README.en.md#protocol-configuration +func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) { + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.NoDelay(nodelay, interval, resend, nc) +} + +// SetDSCP sets the 6bit DSCP field of IP header, no effect if it's accepted from Listener +func (s *UDPSession) SetDSCP(dscp int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + if nc, ok := s.conn.(net.Conn); ok { + if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil { + return ipv6.NewConn(nc).SetTrafficClass(dscp) + } + return nil + } + } + return 
errors.New(errInvalidOperation) +} + +// SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener +func (s *UDPSession) SetReadBuffer(bytes int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + if nc, ok := s.conn.(setReadBuffer); ok { + return nc.SetReadBuffer(bytes) + } + } + return errors.New(errInvalidOperation) +} + +// SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener +func (s *UDPSession) SetWriteBuffer(bytes int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + if nc, ok := s.conn.(setWriteBuffer); ok { + return nc.SetWriteBuffer(bytes) + } + } + return errors.New(errInvalidOperation) +} + +// post-processing for sending a packet from kcp core +// steps: +// 0. Header extending +// 1. FEC packet generation +// 2. CRC32 integrity +// 3. Encryption +// 4. WriteTo kernel +func (s *UDPSession) output(buf []byte) { + var ecc [][]byte + + // 0. extend buf's header space(if necessary) + ext := buf + if s.headerSize > 0 { + ext = s.ext[:s.headerSize+len(buf)] + copy(ext[s.headerSize:], buf) + } + + // 1. FEC encoding + if s.fecEncoder != nil { + ecc = s.fecEncoder.encode(ext) + } + + // 2&3. crc32 & encryption + if s.block != nil { + s.nonce.Fill(ext[:nonceSize]) + checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:]) + binary.LittleEndian.PutUint32(ext[nonceSize:], checksum) + s.block.Encrypt(ext, ext) + + for k := range ecc { + s.nonce.Fill(ecc[k][:nonceSize]) + checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:]) + binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum) + s.block.Encrypt(ecc[k], ecc[k]) + } + } + + // 4. 
WriteTo kernel + nbytes := 0 + npkts := 0 + for i := 0; i < s.dup+1; i++ { + if n, err := s.conn.WriteTo(ext, s.remote); err == nil { + nbytes += n + npkts++ + } else { + s.notifyWriteError(err) + } + } + + for k := range ecc { + if n, err := s.conn.WriteTo(ecc[k], s.remote); err == nil { + nbytes += n + npkts++ + } else { + s.notifyWriteError(err) + } + } + atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) + atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) +} + +// kcp update, returns interval for next calling +func (s *UDPSession) update() (interval time.Duration) { + s.mu.Lock() + waitsnd := s.kcp.WaitSnd() + interval = time.Duration(s.kcp.flush(false)) * time.Millisecond + if s.kcp.WaitSnd() < waitsnd { + s.notifyWriteEvent() + } + s.mu.Unlock() + return +} + +// GetConv gets conversation id of a session +func (s *UDPSession) GetConv() uint32 { return s.kcp.conv } + +func (s *UDPSession) notifyReadEvent() { + select { + case s.chReadEvent <- struct{}{}: + default: + } +} + +func (s *UDPSession) notifyWriteEvent() { + select { + case s.chWriteEvent <- struct{}{}: + default: + } +} + +func (s *UDPSession) notifyWriteError(err error) { + select { + case s.chWriteError <- err: + default: + } +} + +func (s *UDPSession) kcpInput(data []byte) { + var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64 + + if s.fecDecoder != nil { + if len(data) > fecHeaderSize { // must be larger than fec header size + f := s.fecDecoder.decodeBytes(data) + if f.flag == typeData || f.flag == typeFEC { // header check + if f.flag == typeFEC { + fecParityShards++ + } + recovers := s.fecDecoder.decode(f) + + s.mu.Lock() + waitsnd := s.kcp.WaitSnd() + if f.flag == typeData { + if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 { + kcpInErrors++ + } + } + + for _, r := range recovers { + if len(r) >= 2 { // must be larger than 2bytes + sz := binary.LittleEndian.Uint16(r) + if int(sz) <= len(r) && sz >= 2 { + if ret := s.kcp.Input(r[2:sz], false, 
s.ackNoDelay); ret == 0 { + fecRecovered++ + } else { + kcpInErrors++ + } + } else { + fecErrs++ + } + } else { + fecErrs++ + } + } + + // to notify the readers to receive the data + if n := s.kcp.PeekSize(); n > 0 { + s.notifyReadEvent() + } + // to notify the writers when queue is shorter(e.g. ACKed) + if s.kcp.WaitSnd() < waitsnd { + s.notifyWriteEvent() + } + s.mu.Unlock() + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + } + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + } + } else { + s.mu.Lock() + waitsnd := s.kcp.WaitSnd() + if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { + kcpInErrors++ + } + if n := s.kcp.PeekSize(); n > 0 { + s.notifyReadEvent() + } + if s.kcp.WaitSnd() < waitsnd { + s.notifyWriteEvent() + } + s.mu.Unlock() + } + + atomic.AddUint64(&DefaultSnmp.InPkts, 1) + atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data))) + if fecParityShards > 0 { + atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards) + } + if kcpInErrors > 0 { + atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors) + } + if fecErrs > 0 { + atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs) + } + if fecRecovered > 0 { + atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) + } +} + +// the read loop for a client session +func (s *UDPSession) readLoop() { + buf := make([]byte, mtuLimit) + var src string + for { + if n, addr, err := s.conn.ReadFrom(buf); err == nil { + // make sure the packet is from the same source + if src == "" { // set source address + src = addr.String() + } else if addr.String() != src { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + continue + } + + if n >= s.headerSize+IKCP_OVERHEAD { + data := buf[:n] + dataValid := false + if s.block != nil { + s.block.Decrypt(data, data) + data = data[nonceSize:] + checksum := crc32.ChecksumIEEE(data[crcSize:]) + if checksum == binary.LittleEndian.Uint32(data) { + data = data[crcSize:] + dataValid = true + } else { + atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) + } + } else if 
s.block == nil { + dataValid = true + } + + if dataValid { + s.kcpInput(data) + } + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + } + } else { + s.chReadError <- err + return + } + } +} + +type ( + // Listener defines a server which will be waiting to accept incoming connections + Listener struct { + block BlockCrypt // block encryption + dataShards int // FEC data shard + parityShards int // FEC parity shard + fecDecoder *fecDecoder // FEC mock initialization + conn net.PacketConn // the underlying packet connection + + sessions map[string]*UDPSession // all sessions accepted by this Listener + sessionLock sync.Mutex + chAccepts chan *UDPSession // Listen() backlog + chSessionClosed chan net.Addr // session close queue + headerSize int // the additional header to a KCP frame + die chan struct{} // notify the listener has closed + rd atomic.Value // read deadline for Accept() + wd atomic.Value + } +) + +// monitor incoming data for all connections of server +func (l *Listener) monitor() { + // a cache for session object last used + var lastAddr string + var lastSession *UDPSession + buf := make([]byte, mtuLimit) + for { + if n, from, err := l.conn.ReadFrom(buf); err == nil { + if n >= l.headerSize+IKCP_OVERHEAD { + data := buf[:n] + dataValid := false + if l.block != nil { + l.block.Decrypt(data, data) + data = data[nonceSize:] + checksum := crc32.ChecksumIEEE(data[crcSize:]) + if checksum == binary.LittleEndian.Uint32(data) { + data = data[crcSize:] + dataValid = true + } else { + atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) + } + } else if l.block == nil { + dataValid = true + } + + if dataValid { + addr := from.String() + var s *UDPSession + var ok bool + + // the packets received from an address always come in batch, + // cache the session for next packet, without querying map. 
+ if addr == lastAddr { + s, ok = lastSession, true + } else { + l.sessionLock.Lock() + if s, ok = l.sessions[addr]; ok { + lastSession = s + lastAddr = addr + } + l.sessionLock.Unlock() + } + + if !ok { // new session + if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue + var conv uint32 + convValid := false + if l.fecDecoder != nil { + isfec := binary.LittleEndian.Uint16(data[4:]) + if isfec == typeData { + conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:]) + convValid = true + } + } else { + conv = binary.LittleEndian.Uint32(data) + convValid = true + } + + if convValid { // creates a new session only if the 'conv' field in kcp is accessible + s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, from, l.block) + s.kcpInput(data) + l.sessionLock.Lock() + l.sessions[addr] = s + l.sessionLock.Unlock() + l.chAccepts <- s + } + } + } else { + s.kcpInput(data) + } + } + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + } + } else { + return + } + } +} + +// SetReadBuffer sets the socket read buffer for the Listener +func (l *Listener) SetReadBuffer(bytes int) error { + if nc, ok := l.conn.(setReadBuffer); ok { + return nc.SetReadBuffer(bytes) + } + return errors.New(errInvalidOperation) +} + +// SetWriteBuffer sets the socket write buffer for the Listener +func (l *Listener) SetWriteBuffer(bytes int) error { + if nc, ok := l.conn.(setWriteBuffer); ok { + return nc.SetWriteBuffer(bytes) + } + return errors.New(errInvalidOperation) +} + +// SetDSCP sets the 6bit DSCP field of IP header +func (l *Listener) SetDSCP(dscp int) error { + if nc, ok := l.conn.(net.Conn); ok { + if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil { + return ipv6.NewConn(nc).SetTrafficClass(dscp) + } + return nil + } + return errors.New(errInvalidOperation) +} + +// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. 
+func (l *Listener) Accept() (net.Conn, error) { + return l.AcceptKCP() +} + +// AcceptKCP accepts a KCP connection +func (l *Listener) AcceptKCP() (*UDPSession, error) { + var timeout <-chan time.Time + if tdeadline, ok := l.rd.Load().(time.Time); ok && !tdeadline.IsZero() { + timeout = time.After(tdeadline.Sub(time.Now())) + } + + select { + case <-timeout: + return nil, &errTimeout{} + case c := <-l.chAccepts: + return c, nil + case <-l.die: + return nil, errors.New(errBrokenPipe) + } +} + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (l *Listener) SetDeadline(t time.Time) error { + l.SetReadDeadline(t) + l.SetWriteDeadline(t) + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (l *Listener) SetReadDeadline(t time.Time) error { + l.rd.Store(t) + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (l *Listener) SetWriteDeadline(t time.Time) error { + l.wd.Store(t) + return nil +} + +// Close stops listening on the UDP address. Already Accepted connections are not closed. +func (l *Listener) Close() error { + close(l.die) + return l.conn.Close() +} + +// closeSession notify the listener that a session has closed +func (l *Listener) closeSession(remote net.Addr) (ret bool) { + l.sessionLock.Lock() + defer l.sessionLock.Unlock() + if _, ok := l.sessions[remote.String()]; ok { + delete(l.sessions, remote.String()) + return true + } + return false +} + +// Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it. 
+func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() } + +// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp", +func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) } + +// ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption, +// rdataShards, parityShards defines Reed-Solomon Erasure Coding parametes +func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) { + udpaddr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + return nil, errors.Wrap(err, "net.ResolveUDPAddr") + } + conn, err := net.ListenUDP("udp", udpaddr) + if err != nil { + return nil, errors.Wrap(err, "net.ListenUDP") + } + + return ServeConn(block, dataShards, parityShards, conn) +} + +// ServeConn serves KCP protocol for a single packet connection. +func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) { + l := new(Listener) + l.conn = conn + l.sessions = make(map[string]*UDPSession) + l.chAccepts = make(chan *UDPSession, acceptBacklog) + l.chSessionClosed = make(chan net.Addr) + l.die = make(chan struct{}) + l.dataShards = dataShards + l.parityShards = parityShards + l.block = block + l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) + + // calculate header size + if l.block != nil { + l.headerSize += cryptHeaderSize + } + if l.fecDecoder != nil { + l.headerSize += fecHeaderSizePlus2 + } + + go l.monitor() + return l, nil +} + +// Dial connects to the remote address "raddr" on the network "udp" +func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) } + +// DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption +func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) 
(*UDPSession, error) { + // network type detection + udpaddr, err := net.ResolveUDPAddr("udp", raddr) + if err != nil { + return nil, errors.Wrap(err, "net.ResolveUDPAddr") + } + network := "udp4" + if udpaddr.IP.To4() == nil { + network = "udp" + } + + conn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, errors.Wrap(err, "net.DialUDP") + } + + return NewConn(raddr, block, dataShards, parityShards, conn) +} + +// NewConn establishes a session and talks KCP protocol over a packet connection. +func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { + udpaddr, err := net.ResolveUDPAddr("udp", raddr) + if err != nil { + return nil, errors.Wrap(err, "net.ResolveUDPAddr") + } + + var convid uint32 + binary.Read(rand.Reader, binary.LittleEndian, &convid) + return newUDPSession(convid, dataShards, parityShards, nil, conn, udpaddr, block), nil +} + +// monotonic reference time point +var refTime time.Time = time.Now() + +// currentMs returns current elasped monotonic milliseconds since program startup +func currentMs() uint32 { return uint32(time.Now().Sub(refTime) / time.Millisecond) } diff --git a/lib/kcp/sess_test.go b/lib/kcp/sess_test.go new file mode 100644 index 0000000..4fce29a --- /dev/null +++ b/lib/kcp/sess_test.go @@ -0,0 +1,475 @@ +package kcp + +import ( + "crypto/sha1" + "fmt" + "io" + "log" + "net" + "net/http" + _ "net/http/pprof" + "sync" + "testing" + "time" + + "golang.org/x/crypto/pbkdf2" +) + +const portEcho = "127.0.0.1:9999" +const portSink = "127.0.0.1:19999" +const portTinyBufferEcho = "127.0.0.1:29999" +const portListerner = "127.0.0.1:9998" + +var key = []byte("testkey") +var pass = pbkdf2.Key(key, []byte(portSink), 4096, 32, sha1.New) + +func init() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + + go echoServer() + go sinkServer() + go tinyBufferEchoServer() + println("beginning tests, encryption:salsa20, fec:10/3") +} + +func 
dialEcho() (*UDPSession, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(portEcho, block, 10, 3) + if err != nil { + panic(err) + } + + sess.SetStreamMode(true) + sess.SetStreamMode(false) + sess.SetStreamMode(true) + sess.SetWindowSize(1024, 1024) + sess.SetReadBuffer(16 * 1024 * 1024) + sess.SetWriteBuffer(16 * 1024 * 1024) + sess.SetStreamMode(true) + sess.SetNoDelay(1, 10, 2, 1) + sess.SetMtu(1400) + sess.SetMtu(1600) + sess.SetMtu(1400) + sess.SetACKNoDelay(true) + sess.SetACKNoDelay(false) + sess.SetDeadline(time.Now().Add(time.Minute)) + return sess, err +} + +func dialSink() (*UDPSession, error) { + sess, err := DialWithOptions(portSink, nil, 0, 0) + if err != nil { + panic(err) + } + + sess.SetStreamMode(true) + sess.SetWindowSize(1024, 1024) + sess.SetReadBuffer(16 * 1024 * 1024) + sess.SetWriteBuffer(16 * 1024 * 1024) + sess.SetStreamMode(true) + sess.SetNoDelay(1, 10, 2, 1) + sess.SetMtu(1400) + sess.SetACKNoDelay(false) + sess.SetDeadline(time.Now().Add(time.Minute)) + return sess, err +} + +func dialTinyBufferEcho() (*UDPSession, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(portTinyBufferEcho, block, 10, 3) + if err != nil { + panic(err) + } + return sess, err +} + +////////////////////////// +func listenEcho() (net.Listener, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(portEcho, block, 10, 3) +} +func listenTinyBufferEcho() (net.Listener, error) { + 
//block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(portTinyBufferEcho, block, 10, 3) +} + +func listenSink() (net.Listener, error) { + return ListenWithOptions(portSink, nil, 0, 0) +} + +func echoServer() { + l, err := listenEcho() + if err != nil { + panic(err) + } + + go func() { + kcplistener := l.(*Listener) + kcplistener.SetReadBuffer(4 * 1024 * 1024) + kcplistener.SetWriteBuffer(4 * 1024 * 1024) + kcplistener.SetDSCP(46) + for { + s, err := l.Accept() + if err != nil { + return + } + + // coverage test + s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024) + s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024) + go handleEcho(s.(*UDPSession)) + } + }() +} + +func sinkServer() { + l, err := listenSink() + if err != nil { + panic(err) + } + + go func() { + kcplistener := l.(*Listener) + kcplistener.SetReadBuffer(4 * 1024 * 1024) + kcplistener.SetWriteBuffer(4 * 1024 * 1024) + kcplistener.SetDSCP(46) + for { + s, err := l.Accept() + if err != nil { + return + } + + go handleSink(s.(*UDPSession)) + } + }() +} + +func tinyBufferEchoServer() { + l, err := listenTinyBufferEcho() + if err != nil { + panic(err) + } + + go func() { + for { + s, err := l.Accept() + if err != nil { + return + } + go handleTinyBufferEcho(s.(*UDPSession)) + } + }() +} + +/////////////////////////// + +func handleEcho(conn *UDPSession) { + conn.SetStreamMode(true) + conn.SetWindowSize(4096, 4096) + conn.SetNoDelay(1, 10, 2, 1) + conn.SetDSCP(46) + conn.SetMtu(1400) + conn.SetACKNoDelay(false) + conn.SetReadDeadline(time.Now().Add(time.Hour)) + conn.SetWriteDeadline(time.Now().Add(time.Hour)) + buf := make([]byte, 65536) + for { + n, err := conn.Read(buf) + if err != nil { + panic(err) + } + conn.Write(buf[:n]) + } +} + +func handleSink(conn *UDPSession) { + conn.SetStreamMode(true) + conn.SetWindowSize(4096, 4096) + 
conn.SetNoDelay(1, 10, 2, 1) + conn.SetDSCP(46) + conn.SetMtu(1400) + conn.SetACKNoDelay(false) + conn.SetReadDeadline(time.Now().Add(time.Hour)) + conn.SetWriteDeadline(time.Now().Add(time.Hour)) + buf := make([]byte, 65536) + for { + _, err := conn.Read(buf) + if err != nil { + panic(err) + } + } +} + +func handleTinyBufferEcho(conn *UDPSession) { + conn.SetStreamMode(true) + buf := make([]byte, 2) + for { + n, err := conn.Read(buf) + if err != nil { + panic(err) + } + conn.Write(buf[:n]) + } +} + +/////////////////////////// + +func TestTimeout(t *testing.T) { + cli, err := dialEcho() + if err != nil { + panic(err) + } + buf := make([]byte, 10) + + //timeout + cli.SetDeadline(time.Now().Add(time.Second)) + <-time.After(2 * time.Second) + n, err := cli.Read(buf) + if n != 0 || err == nil { + t.Fail() + } + cli.Close() +} + +func TestSendRecv(t *testing.T) { + cli, err := dialEcho() + if err != nil { + panic(err) + } + cli.SetWriteDelay(true) + cli.SetDUP(1) + const N = 100 + buf := make([]byte, 10) + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + cli.Write([]byte(msg)) + if n, err := cli.Read(buf); err == nil { + if string(buf[:n]) != msg { + t.Fail() + } + } else { + panic(err) + } + } + cli.Close() +} + +func TestTinyBufferReceiver(t *testing.T) { + cli, err := dialTinyBufferEcho() + if err != nil { + panic(err) + } + const N = 100 + snd := byte(0) + fillBuffer := func(buf []byte) { + for i := 0; i < len(buf); i++ { + buf[i] = snd + snd++ + } + } + + rcv := byte(0) + check := func(buf []byte) bool { + for i := 0; i < len(buf); i++ { + if buf[i] != rcv { + return false + } + rcv++ + } + return true + } + sndbuf := make([]byte, 7) + rcvbuf := make([]byte, 7) + for i := 0; i < N; i++ { + fillBuffer(sndbuf) + cli.Write(sndbuf) + if n, err := io.ReadFull(cli, rcvbuf); err == nil { + if !check(rcvbuf[:n]) { + t.Fail() + } + } else { + panic(err) + } + } + cli.Close() +} + +func TestClose(t *testing.T) { + cli, err := dialEcho() + if err != nil { + 
panic(err) + } + buf := make([]byte, 10) + + cli.Close() + if cli.Close() == nil { + t.Fail() + } + n, err := cli.Write(buf) + if n != 0 || err == nil { + t.Fail() + } + n, err = cli.Read(buf) + if n != 0 || err == nil { + t.Fail() + } + cli.Close() +} + +func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1024) + for i := 0; i < 1024; i++ { + go parallel_client(&wg) + } + wg.Wait() +} + +func parallel_client(wg *sync.WaitGroup) (err error) { + cli, err := dialEcho() + if err != nil { + panic(err) + } + + err = echo_tester(cli, 64, 64) + wg.Done() + return +} + +func BenchmarkEchoSpeed4K(b *testing.B) { + speedclient(b, 4096) +} + +func BenchmarkEchoSpeed64K(b *testing.B) { + speedclient(b, 65536) +} + +func BenchmarkEchoSpeed512K(b *testing.B) { + speedclient(b, 524288) +} + +func BenchmarkEchoSpeed1M(b *testing.B) { + speedclient(b, 1048576) +} + +func speedclient(b *testing.B, nbytes int) { + b.ReportAllocs() + cli, err := dialEcho() + if err != nil { + panic(err) + } + + if err := echo_tester(cli, nbytes, b.N); err != nil { + b.Fail() + } + b.SetBytes(int64(nbytes)) +} + +func BenchmarkSinkSpeed4K(b *testing.B) { + sinkclient(b, 4096) +} + +func BenchmarkSinkSpeed64K(b *testing.B) { + sinkclient(b, 65536) +} + +func BenchmarkSinkSpeed256K(b *testing.B) { + sinkclient(b, 524288) +} + +func BenchmarkSinkSpeed1M(b *testing.B) { + sinkclient(b, 1048576) +} + +func sinkclient(b *testing.B, nbytes int) { + b.ReportAllocs() + cli, err := dialSink() + if err != nil { + panic(err) + } + + sink_tester(cli, nbytes, b.N) + b.SetBytes(int64(nbytes)) +} + +func echo_tester(cli net.Conn, msglen, msgcount int) error { + buf := make([]byte, msglen) + for i := 0; i < msgcount; i++ { + // send packet + if _, err := cli.Write(buf); err != nil { + return err + } + + // receive packet + nrecv := 0 + for { + n, err := cli.Read(buf) + if err != nil { + return err + } else { + nrecv += n + if nrecv == msglen { + break + } + } + } + } + return nil 
+} + +func sink_tester(cli *UDPSession, msglen, msgcount int) error { + // sender + buf := make([]byte, msglen) + for i := 0; i < msgcount; i++ { + if _, err := cli.Write(buf); err != nil { + return err + } + } + return nil +} + +func TestSNMP(t *testing.T) { + t.Log(DefaultSnmp.Copy()) + t.Log(DefaultSnmp.Header()) + t.Log(DefaultSnmp.ToSlice()) + DefaultSnmp.Reset() + t.Log(DefaultSnmp.ToSlice()) +} + +func TestListenerClose(t *testing.T) { + l, err := ListenWithOptions(portListerner, nil, 10, 3) + if err != nil { + t.Fail() + } + l.SetReadDeadline(time.Now().Add(time.Second)) + l.SetWriteDeadline(time.Now().Add(time.Second)) + l.SetDeadline(time.Now().Add(time.Second)) + time.Sleep(2 * time.Second) + if _, err := l.Accept(); err == nil { + t.Fail() + } + + l.Close() + fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111") + if l.closeSession(fakeaddr) { + t.Fail() + } +} diff --git a/lib/kcp/snmp.go b/lib/kcp/snmp.go new file mode 100644 index 0000000..607118e --- /dev/null +++ b/lib/kcp/snmp.go @@ -0,0 +1,164 @@ +package kcp + +import ( + "fmt" + "sync/atomic" +) + +// Snmp defines network statistics indicator +type Snmp struct { + BytesSent uint64 // bytes sent from upper level + BytesReceived uint64 // bytes received to upper level + MaxConn uint64 // max number of connections ever reached + ActiveOpens uint64 // accumulated active open connections + PassiveOpens uint64 // accumulated passive open connections + CurrEstab uint64 // current number of established connections + InErrs uint64 // UDP read errors reported from net.PacketConn + InCsumErrors uint64 // checksum errors from CRC32 + KCPInErrors uint64 // packet iput errors reported from KCP + InPkts uint64 // incoming packets count + OutPkts uint64 // outgoing packets count + InSegs uint64 // incoming KCP segments + OutSegs uint64 // outgoing KCP segments + InBytes uint64 // UDP bytes received + OutBytes uint64 // UDP bytes sent + RetransSegs uint64 // accmulated retransmited segments + 
FastRetransSegs uint64 // accmulated fast retransmitted segments + EarlyRetransSegs uint64 // accmulated early retransmitted segments + LostSegs uint64 // number of segs infered as lost + RepeatSegs uint64 // number of segs duplicated + FECRecovered uint64 // correct packets recovered from FEC + FECErrs uint64 // incorrect packets recovered from FEC + FECParityShards uint64 // FEC segments received + FECShortShards uint64 // number of data shards that's not enough for recovery +} + +func newSnmp() *Snmp { + return new(Snmp) +} + +// Header returns all field names +func (s *Snmp) Header() []string { + return []string{ + "BytesSent", + "BytesReceived", + "MaxConn", + "ActiveOpens", + "PassiveOpens", + "CurrEstab", + "InErrs", + "InCsumErrors", + "KCPInErrors", + "InPkts", + "OutPkts", + "InSegs", + "OutSegs", + "InBytes", + "OutBytes", + "RetransSegs", + "FastRetransSegs", + "EarlyRetransSegs", + "LostSegs", + "RepeatSegs", + "FECParityShards", + "FECErrs", + "FECRecovered", + "FECShortShards", + } +} + +// ToSlice returns current snmp info as slice +func (s *Snmp) ToSlice() []string { + snmp := s.Copy() + return []string{ + fmt.Sprint(snmp.BytesSent), + fmt.Sprint(snmp.BytesReceived), + fmt.Sprint(snmp.MaxConn), + fmt.Sprint(snmp.ActiveOpens), + fmt.Sprint(snmp.PassiveOpens), + fmt.Sprint(snmp.CurrEstab), + fmt.Sprint(snmp.InErrs), + fmt.Sprint(snmp.InCsumErrors), + fmt.Sprint(snmp.KCPInErrors), + fmt.Sprint(snmp.InPkts), + fmt.Sprint(snmp.OutPkts), + fmt.Sprint(snmp.InSegs), + fmt.Sprint(snmp.OutSegs), + fmt.Sprint(snmp.InBytes), + fmt.Sprint(snmp.OutBytes), + fmt.Sprint(snmp.RetransSegs), + fmt.Sprint(snmp.FastRetransSegs), + fmt.Sprint(snmp.EarlyRetransSegs), + fmt.Sprint(snmp.LostSegs), + fmt.Sprint(snmp.RepeatSegs), + fmt.Sprint(snmp.FECParityShards), + fmt.Sprint(snmp.FECErrs), + fmt.Sprint(snmp.FECRecovered), + fmt.Sprint(snmp.FECShortShards), + } +} + +// Copy make a copy of current snmp snapshot +func (s *Snmp) Copy() *Snmp { + d := newSnmp() + d.BytesSent 
= atomic.LoadUint64(&s.BytesSent) + d.BytesReceived = atomic.LoadUint64(&s.BytesReceived) + d.MaxConn = atomic.LoadUint64(&s.MaxConn) + d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens) + d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens) + d.CurrEstab = atomic.LoadUint64(&s.CurrEstab) + d.InErrs = atomic.LoadUint64(&s.InErrs) + d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors) + d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors) + d.InPkts = atomic.LoadUint64(&s.InPkts) + d.OutPkts = atomic.LoadUint64(&s.OutPkts) + d.InSegs = atomic.LoadUint64(&s.InSegs) + d.OutSegs = atomic.LoadUint64(&s.OutSegs) + d.InBytes = atomic.LoadUint64(&s.InBytes) + d.OutBytes = atomic.LoadUint64(&s.OutBytes) + d.RetransSegs = atomic.LoadUint64(&s.RetransSegs) + d.FastRetransSegs = atomic.LoadUint64(&s.FastRetransSegs) + d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs) + d.LostSegs = atomic.LoadUint64(&s.LostSegs) + d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs) + d.FECParityShards = atomic.LoadUint64(&s.FECParityShards) + d.FECErrs = atomic.LoadUint64(&s.FECErrs) + d.FECRecovered = atomic.LoadUint64(&s.FECRecovered) + d.FECShortShards = atomic.LoadUint64(&s.FECShortShards) + return d +} + +// Reset values to zero +func (s *Snmp) Reset() { + atomic.StoreUint64(&s.BytesSent, 0) + atomic.StoreUint64(&s.BytesReceived, 0) + atomic.StoreUint64(&s.MaxConn, 0) + atomic.StoreUint64(&s.ActiveOpens, 0) + atomic.StoreUint64(&s.PassiveOpens, 0) + atomic.StoreUint64(&s.CurrEstab, 0) + atomic.StoreUint64(&s.InErrs, 0) + atomic.StoreUint64(&s.InCsumErrors, 0) + atomic.StoreUint64(&s.KCPInErrors, 0) + atomic.StoreUint64(&s.InPkts, 0) + atomic.StoreUint64(&s.OutPkts, 0) + atomic.StoreUint64(&s.InSegs, 0) + atomic.StoreUint64(&s.OutSegs, 0) + atomic.StoreUint64(&s.InBytes, 0) + atomic.StoreUint64(&s.OutBytes, 0) + atomic.StoreUint64(&s.RetransSegs, 0) + atomic.StoreUint64(&s.FastRetransSegs, 0) + atomic.StoreUint64(&s.EarlyRetransSegs, 0) + atomic.StoreUint64(&s.LostSegs, 0) + 
atomic.StoreUint64(&s.RepeatSegs, 0) + atomic.StoreUint64(&s.FECParityShards, 0) + atomic.StoreUint64(&s.FECErrs, 0) + atomic.StoreUint64(&s.FECRecovered, 0) + atomic.StoreUint64(&s.FECShortShards, 0) +} + +// DefaultSnmp is the global KCP connection statistics collector +var DefaultSnmp *Snmp + +func init() { + DefaultSnmp = newSnmp() +} diff --git a/lib/kcp/updater.go b/lib/kcp/updater.go new file mode 100644 index 0000000..9a90c82 --- /dev/null +++ b/lib/kcp/updater.go @@ -0,0 +1,104 @@ +package kcp + +import ( + "container/heap" + "sync" + "time" +) + +var updater updateHeap + +func init() { + updater.init() + go updater.updateTask() +} + +// entry contains a session update info +type entry struct { + ts time.Time + s *UDPSession +} + +// a global heap managed kcp.flush() caller +type updateHeap struct { + entries []entry + mu sync.Mutex + chWakeUp chan struct{} +} + +func (h *updateHeap) Len() int { return len(h.entries) } +func (h *updateHeap) Less(i, j int) bool { return h.entries[i].ts.Before(h.entries[j].ts) } +func (h *updateHeap) Swap(i, j int) { + h.entries[i], h.entries[j] = h.entries[j], h.entries[i] + h.entries[i].s.updaterIdx = i + h.entries[j].s.updaterIdx = j +} + +func (h *updateHeap) Push(x interface{}) { + h.entries = append(h.entries, x.(entry)) + n := len(h.entries) + h.entries[n-1].s.updaterIdx = n - 1 +} + +func (h *updateHeap) Pop() interface{} { + n := len(h.entries) + x := h.entries[n-1] + h.entries[n-1].s.updaterIdx = -1 + h.entries[n-1] = entry{} // manual set nil for GC + h.entries = h.entries[0 : n-1] + return x +} + +func (h *updateHeap) init() { + h.chWakeUp = make(chan struct{}, 1) +} + +func (h *updateHeap) addSession(s *UDPSession) { + h.mu.Lock() + heap.Push(h, entry{time.Now(), s}) + h.mu.Unlock() + h.wakeup() +} + +func (h *updateHeap) removeSession(s *UDPSession) { + h.mu.Lock() + if s.updaterIdx != -1 { + heap.Remove(h, s.updaterIdx) + } + h.mu.Unlock() +} + +func (h *updateHeap) wakeup() { + select { + case h.chWakeUp <- 
struct{}{}: + default: + } +} + +func (h *updateHeap) updateTask() { + var timer <-chan time.Time + for { + select { + case <-timer: + case <-h.chWakeUp: + } + + h.mu.Lock() + hlen := h.Len() + for i := 0; i < hlen; i++ { + entry := &h.entries[0] + if time.Now().After(entry.ts) { + interval := entry.s.update() + entry.ts = time.Now().Add(interval) + heap.Fix(h, 0) + } else { + break + } + } + + if hlen > 0 { + timer = time.After(h.entries[0].ts.Sub(time.Now())) + } + h.mu.Unlock() + } +} diff --git a/lib/log.go b/lib/lg/log.go similarity index 79% rename from lib/log.go rename to lib/lg/log.go index dcd36d6..41ddf74 100644 --- a/lib/log.go +++ b/lib/lg/log.go @@ -1,4 +1,4 @@ -package lib +package lg import ( "log" @@ -9,10 +9,10 @@ import ( var Log *log.Logger -func InitLogFile(f string, isStdout bool) { +func InitLogFile(f string, isStdout bool, logPath string) { var prefix string if !isStdout { - logFile, err := os.OpenFile(filepath.Join(GetLogPath(), f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766) + logFile, err := os.OpenFile(filepath.Join(logPath, f+"_log.txt"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0766) if err != nil { log.Fatalln("open file error !", err) } diff --git a/lib/pool.go b/lib/pool.go deleted file mode 100644 index b96941c..0000000 --- a/lib/pool.go +++ /dev/null @@ -1,43 +0,0 @@ -package lib - -import ( - "sync" -) - -const poolSize = 64 * 1024 -const poolSizeSmall = 100 -const poolSizeUdp = 1472 -const poolSizeCopy = 32 * 1024 - -var BufPool = sync.Pool{ - New: func() interface{} { - return make([]byte, poolSize) - }, -} - -var BufPoolUdp = sync.Pool{ - New: func() interface{} { - return make([]byte, poolSizeUdp) - }, -} -var BufPoolMax = sync.Pool{ - New: func() interface{} { - return make([]byte, poolSize) - }, -} -var BufPoolSmall = sync.Pool{ - New: func() interface{} { - return make([]byte, poolSizeSmall) - }, -} -var BufPoolCopy = sync.Pool{ - New: func() interface{} { - return make([]byte, poolSizeCopy) - }, -} - -func 
PutBufPoolCopy(buf []byte) { - if cap(buf) == poolSizeCopy { - BufPoolCopy.Put(buf[:poolSizeCopy]) - } -} diff --git a/lib/pool/pool.go b/lib/pool/pool.go new file mode 100644 index 0000000..7c42c2c --- /dev/null +++ b/lib/pool/pool.go @@ -0,0 +1,49 @@ +package pool + +import ( + "sync" +) + +const PoolSize = 64 * 1024 +const PoolSizeSmall = 100 +const PoolSizeUdp = 1472 +const PoolSizeCopy = 32 * 1024 + +var BufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, PoolSize) + }, +} + +var BufPoolUdp = sync.Pool{ + New: func() interface{} { + return make([]byte, PoolSizeUdp) + }, +} +var BufPoolMax = sync.Pool{ + New: func() interface{} { + return make([]byte, PoolSize) + }, +} +var BufPoolSmall = sync.Pool{ + New: func() interface{} { + return make([]byte, PoolSizeSmall) + }, +} +var BufPoolCopy = sync.Pool{ + New: func() interface{} { + return make([]byte, PoolSizeCopy) + }, +} + +func PutBufPoolCopy(buf []byte) { + if cap(buf) == PoolSizeCopy { + BufPoolCopy.Put(buf[:PoolSizeCopy]) + } +} + +func PutBufPoolUdp(buf []byte) { + if cap(buf) == PoolSizeUdp { + BufPoolUdp.Put(buf[:PoolSizeUdp]) + } +} diff --git a/lib/rate.go b/lib/rate/rate.go similarity index 98% rename from lib/rate.go rename to lib/rate/rate.go index 9b959f8..a689a79 100644 --- a/lib/rate.go +++ b/lib/rate/rate.go @@ -1,4 +1,4 @@ -package lib +package rate import ( "sync/atomic" diff --git a/lib/snappy/decode.go b/lib/snappy/decode.go new file mode 100644 index 0000000..72efb03 --- /dev/null +++ b/lib/snappy/decode.go @@ -0,0 +1,237 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "encoding/binary" + "errors" + "io" +) + +var ( + // ErrCorrupt reports that the input is invalid. + ErrCorrupt = errors.New("snappy: corrupt input") + // ErrTooLarge reports that the uncompressed length is too large. 
+ ErrTooLarge = errors.New("snappy: decoded block is too large") + // ErrUnsupported reports that the input isn't supported. + ErrUnsupported = errors.New("snappy: unsupported input") + + errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length") +) + +// DecodedLen returns the length of the decoded block. +func DecodedLen(src []byte) (int, error) { + v, _, err := decodedLen(src) + return v, err +} + +// decodedLen returns the length of the decoded block and the number of bytes +// that the length header occupied. +func decodedLen(src []byte) (blockLen, headerLen int, err error) { + v, n := binary.Uvarint(src) + if n <= 0 || v > 0xffffffff { + return 0, 0, ErrCorrupt + } + + const wordSize = 32 << (^uint(0) >> 32 & 1) + if wordSize == 32 && v > 0x7fffffff { + return 0, 0, ErrTooLarge + } + return int(v), n, nil +} + +const ( + decodeErrCodeCorrupt = 1 + decodeErrCodeUnsupportedLiteralLength = 2 +) + +// Decode returns the decoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire decoded block. +// Otherwise, a newly allocated slice will be returned. +// +// The dst and src must not overlap. It is valid to pass a nil dst. 
+func Decode(dst, src []byte) ([]byte, error) { + dLen, s, err := decodedLen(src) + if err != nil { + return nil, err + } + if dLen <= len(dst) { + dst = dst[:dLen] + } else { + dst = make([]byte, dLen) + } + switch decode(dst, src[s:]) { + case 0: + return dst, nil + case decodeErrCodeUnsupportedLiteralLength: + return nil, errUnsupportedLiteralLength + } + return nil, ErrCorrupt +} + +// NewReader returns a new Reader that decompresses from r, using the framing +// format described at +// https://github.com/google/snappy/blob/master/framing_format.txt +func NewReader(r io.Reader) *Reader { + return &Reader{ + r: r, + decoded: make([]byte, maxBlockSize), + buf: make([]byte, maxEncodedLenOfMaxBlockSize+checksumSize), + } +} + +// Reader is an io.Reader that can read Snappy-compressed bytes. +type Reader struct { + r io.Reader + err error + decoded []byte + buf []byte + // decoded[i:j] contains decoded bytes that have not yet been passed on. + i, j int + readHeader bool +} + +// Reset discards any buffered data, resets all state, and switches the Snappy +// reader to read from r. This permits reusing a Reader rather than allocating +// a new one. +func (r *Reader) Reset(reader io.Reader) { + r.r = reader + r.err = nil + r.i = 0 + r.j = 0 + r.readHeader = false +} + +func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) { + if _, r.err = io.ReadFull(r.r, p); r.err != nil { + if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) { + r.err = ErrCorrupt + } + return false + } + return true +} + +// Read satisfies the io.Reader interface. 
+func (r *Reader) Read(p []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + for { + if r.i < r.j { + n := copy(p, r.decoded[r.i:r.j]) + r.i += n + return n, nil + } + if !r.readFull(r.buf[:4], true) { + return 0, r.err + } + chunkType := r.buf[0] + if !r.readHeader { + if chunkType != chunkTypeStreamIdentifier { + r.err = ErrCorrupt + return 0, r.err + } + r.readHeader = true + } + chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 + if chunkLen > len(r.buf) { + r.err = ErrUnsupported + return 0, r.err + } + + // The chunk types are specified at + // https://github.com/google/snappy/blob/master/framing_format.txt + switch chunkType { + case chunkTypeCompressedData: + // Section 4.2. Compressed data (chunk type 0x00). + if chunkLen < checksumSize { + r.err = ErrCorrupt + return 0, r.err + } + buf := r.buf[:chunkLen] + if !r.readFull(buf, false) { + return 0, r.err + } + checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 + buf = buf[checksumSize:] + + n, err := DecodedLen(buf) + if err != nil { + r.err = err + return 0, r.err + } + if n > len(r.decoded) { + r.err = ErrCorrupt + return 0, r.err + } + if _, err := Decode(r.decoded, buf); err != nil { + r.err = err + return 0, r.err + } + if crc(r.decoded[:n]) != checksum { + r.err = ErrCorrupt + return 0, r.err + } + r.i, r.j = 0, n + continue + + case chunkTypeUncompressedData: + // Section 4.3. Uncompressed data (chunk type 0x01). + if chunkLen < checksumSize { + r.err = ErrCorrupt + return 0, r.err + } + buf := r.buf[:checksumSize] + if !r.readFull(buf, false) { + return 0, r.err + } + checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 + // Read directly into r.decoded instead of via r.buf. 
+ n := chunkLen - checksumSize + if n > len(r.decoded) { + r.err = ErrCorrupt + return 0, r.err + } + if !r.readFull(r.decoded[:n], false) { + return 0, r.err + } + if crc(r.decoded[:n]) != checksum { + r.err = ErrCorrupt + return 0, r.err + } + r.i, r.j = 0, n + continue + + case chunkTypeStreamIdentifier: + // Section 4.1. Stream identifier (chunk type 0xff). + if chunkLen != len(magicBody) { + r.err = ErrCorrupt + return 0, r.err + } + if !r.readFull(r.buf[:len(magicBody)], false) { + return 0, r.err + } + for i := 0; i < len(magicBody); i++ { + if r.buf[i] != magicBody[i] { + r.err = ErrCorrupt + return 0, r.err + } + } + continue + } + + if chunkType <= 0x7f { + // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f). + r.err = ErrUnsupported + return 0, r.err + } + // Section 4.4 Padding (chunk type 0xfe). + // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). + if !r.readFull(r.buf[:chunkLen], false) { + return 0, r.err + } + } +} diff --git a/lib/snappy/decode_amd64.go b/lib/snappy/decode_amd64.go new file mode 100644 index 0000000..fcd192b --- /dev/null +++ b/lib/snappy/decode_amd64.go @@ -0,0 +1,14 @@ +// Copyright 2016 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +package snappy + +// decode has the same semantics as in decode_other.go. +// +//go:noescape +func decode(dst, src []byte) int diff --git a/lib/snappy/decode_amd64.s b/lib/snappy/decode_amd64.s new file mode 100644 index 0000000..e6179f6 --- /dev/null +++ b/lib/snappy/decode_amd64.s @@ -0,0 +1,490 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The asm code generally follows the pure Go code in decode_other.go, except +// where marked with a "!!!". + +// func decode(dst, src []byte) int +// +// All local variables fit into registers. The non-zero stack size is only to +// spill registers and push args when issuing a CALL. The register allocation: +// - AX scratch +// - BX scratch +// - CX length or x +// - DX offset +// - SI &src[s] +// - DI &dst[d] +// + R8 dst_base +// + R9 dst_len +// + R10 dst_base + dst_len +// + R11 src_base +// + R12 src_len +// + R13 src_base + src_len +// - R14 used by doCopy +// - R15 used by doCopy +// +// The registers R8-R13 (marked with a "+") are set at the start of the +// function, and after a CALL returns, and are not otherwise modified. +// +// The d variable is implicitly DI - R8, and len(dst)-d is R10 - DI. +// The s variable is implicitly SI - R11, and len(src)-s is R13 - SI. +TEXT ·decode(SB), NOSPLIT, $48-56 + // Initialize SI, DI and R8-R13. + MOVQ dst_base+0(FP), R8 + MOVQ dst_len+8(FP), R9 + MOVQ R8, DI + MOVQ R8, R10 + ADDQ R9, R10 + MOVQ src_base+24(FP), R11 + MOVQ src_len+32(FP), R12 + MOVQ R11, SI + MOVQ R11, R13 + ADDQ R12, R13 + +loop: + // for s < len(src) + CMPQ SI, R13 + JEQ end + + // CX = uint32(src[s]) + // + // switch src[s] & 0x03 + MOVBLZX (SI), CX + MOVL CX, BX + ANDL $3, BX + CMPL BX, $1 + JAE tagCopy + + // ---------------------------------------- + // The code below handles literal tags. + + // case tagLiteral: + // x := uint32(src[s] >> 2) + // switch + SHRL $2, CX + CMPL CX, $60 + JAE tagLit60Plus + + // case x < 60: + // s++ + INCQ SI + +doLit: + // This is the end of the inner "switch", when we have a literal tag. + // + // We assume that CX == x and x fits in a uint32, where x is the variable + // used in the pure Go decode_other.go code. 
+ + // length = int(x) + 1 + // + // Unlike the pure Go code, we don't need to check if length <= 0 because + // CX can hold 64 bits, so the increment cannot overflow. + INCQ CX + + // Prepare to check if copying length bytes will run past the end of dst or + // src. + // + // AX = len(dst) - d + // BX = len(src) - s + MOVQ R10, AX + SUBQ DI, AX + MOVQ R13, BX + SUBQ SI, BX + + // !!! Try a faster technique for short (16 or fewer bytes) copies. + // + // if length > 16 || len(dst)-d < 16 || len(src)-s < 16 { + // goto callMemmove // Fall back on calling runtime·memmove. + // } + // + // The C++ snappy code calls this TryFastAppend. It also checks len(src)-s + // against 21 instead of 16, because it cannot assume that all of its input + // is contiguous in memory and so it needs to leave enough source bytes to + // read the next tag without refilling buffers, but Go's Decode assumes + // contiguousness (the src argument is a []byte). + CMPQ CX, $16 + JGT callMemmove + CMPQ AX, $16 + JLT callMemmove + CMPQ BX, $16 + JLT callMemmove + + // !!! Implement the copy from src to dst as a 16-byte load and store. + // (Decode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only length bytes, but that's + // OK. If the input is a valid Snappy encoding then subsequent iterations + // will fix up the overrun. Otherwise, Decode returns a nil []byte (and a + // non-nil error), so the overrun will be ignored. + // + // Note that on amd64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. 
+ MOVOU 0(SI), X0 + MOVOU X0, 0(DI) + + // d += length + // s += length + ADDQ CX, DI + ADDQ CX, SI + JMP loop + +callMemmove: + // if length > len(dst)-d || length > len(src)-s { etc } + CMPQ CX, AX + JGT errCorrupt + CMPQ CX, BX + JGT errCorrupt + + // copy(dst[d:], src[s:s+length]) + // + // This means calling runtime·memmove(&dst[d], &src[s], length), so we push + // DI, SI and CX as arguments. Coincidentally, we also need to spill those + // three registers to the stack, to save local variables across the CALL. + MOVQ DI, 0(SP) + MOVQ SI, 8(SP) + MOVQ CX, 16(SP) + MOVQ DI, 24(SP) + MOVQ SI, 32(SP) + MOVQ CX, 40(SP) + CALL runtime·memmove(SB) + + // Restore local variables: unspill registers from the stack and + // re-calculate R8-R13. + MOVQ 24(SP), DI + MOVQ 32(SP), SI + MOVQ 40(SP), CX + MOVQ dst_base+0(FP), R8 + MOVQ dst_len+8(FP), R9 + MOVQ R8, R10 + ADDQ R9, R10 + MOVQ src_base+24(FP), R11 + MOVQ src_len+32(FP), R12 + MOVQ R11, R13 + ADDQ R12, R13 + + // d += length + // s += length + ADDQ CX, DI + ADDQ CX, SI + JMP loop + +tagLit60Plus: + // !!! This fragment does the + // + // s += x - 58; if uint(s) > uint(len(src)) { etc } + // + // checks. In the asm version, we code it once instead of once per switch case. + ADDQ CX, SI + SUBQ $58, SI + MOVQ SI, BX + SUBQ R11, BX + CMPQ BX, R12 + JA errCorrupt + + // case x == 60: + CMPL CX, $61 + JEQ tagLit61 + JA tagLit62Plus + + // x = uint32(src[s-1]) + MOVBLZX -1(SI), CX + JMP doLit + +tagLit61: + // case x == 61: + // x = uint32(src[s-2]) | uint32(src[s-1])<<8 + MOVWLZX -2(SI), CX + JMP doLit + +tagLit62Plus: + CMPL CX, $62 + JA tagLit63 + + // case x == 62: + // x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 + MOVWLZX -3(SI), CX + MOVBLZX -1(SI), BX + SHLL $16, BX + ORL BX, CX + JMP doLit + +tagLit63: + // case x == 63: + // x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 + MOVL -4(SI), CX + JMP doLit + +// The code above handles literal tags. 
+// ---------------------------------------- +// The code below handles copy tags. + +tagCopy4: + // case tagCopy4: + // s += 5 + ADDQ $5, SI + + // if uint(s) > uint(len(src)) { etc } + MOVQ SI, BX + SUBQ R11, BX + CMPQ BX, R12 + JA errCorrupt + + // length = 1 + int(src[s-5])>>2 + SHRQ $2, CX + INCQ CX + + // offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24) + MOVLQZX -4(SI), DX + JMP doCopy + +tagCopy2: + // case tagCopy2: + // s += 3 + ADDQ $3, SI + + // if uint(s) > uint(len(src)) { etc } + MOVQ SI, BX + SUBQ R11, BX + CMPQ BX, R12 + JA errCorrupt + + // length = 1 + int(src[s-3])>>2 + SHRQ $2, CX + INCQ CX + + // offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8) + MOVWQZX -2(SI), DX + JMP doCopy + +tagCopy: + // We have a copy tag. We assume that: + // - BX == src[s] & 0x03 + // - CX == src[s] + CMPQ BX, $2 + JEQ tagCopy2 + JA tagCopy4 + + // case tagCopy1: + // s += 2 + ADDQ $2, SI + + // if uint(s) > uint(len(src)) { etc } + MOVQ SI, BX + SUBQ R11, BX + CMPQ BX, R12 + JA errCorrupt + + // offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + MOVQ CX, DX + ANDQ $0xe0, DX + SHLQ $3, DX + MOVBQZX -1(SI), BX + ORQ BX, DX + + // length = 4 + int(src[s-2])>>2&0x7 + SHRQ $2, CX + ANDQ $7, CX + ADDQ $4, CX + +doCopy: + // This is the end of the outer "switch", when we have a copy tag. + // + // We assume that: + // - CX == length && CX > 0 + // - DX == offset + + // if offset <= 0 { etc } + CMPQ DX, $0 + JLE errCorrupt + + // if d < offset { etc } + MOVQ DI, BX + SUBQ R8, BX + CMPQ BX, DX + JLT errCorrupt + + // if length > len(dst)-d { etc } + MOVQ R10, BX + SUBQ DI, BX + CMPQ CX, BX + JGT errCorrupt + + // forwardCopy(dst[d:d+length], dst[d-offset:]); d += length + // + // Set: + // - R14 = len(dst)-d + // - R15 = &dst[d-offset] + MOVQ R10, R14 + SUBQ DI, R14 + MOVQ DI, R15 + SUBQ DX, R15 + + // !!! Try a faster technique for short (16 or fewer bytes) forward copies. 
+ // + // First, try using two 8-byte load/stores, similar to the doLit technique + // above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is + // still OK if offset >= 8. Note that this has to be two 8-byte load/stores + // and not one 16-byte load/store, and the first store has to be before the + // second load, due to the overlap if offset is in the range [8, 16). + // + // if length > 16 || offset < 8 || len(dst)-d < 16 { + // goto slowForwardCopy + // } + // copy 16 bytes + // d += length + CMPQ CX, $16 + JGT slowForwardCopy + CMPQ DX, $8 + JLT slowForwardCopy + CMPQ R14, $16 + JLT slowForwardCopy + MOVQ 0(R15), AX + MOVQ AX, 0(DI) + MOVQ 8(R15), BX + MOVQ BX, 8(DI) + ADDQ CX, DI + JMP loop + +slowForwardCopy: + // !!! If the forward copy is longer than 16 bytes, or if offset < 8, we + // can still try 8-byte load stores, provided we can overrun up to 10 extra + // bytes. As above, the overrun will be fixed up by subsequent iterations + // of the outermost loop. + // + // The C++ snappy code calls this technique IncrementalCopyFastPath. Its + // commentary says: + // + // ---- + // + // The main part of this loop is a simple copy of eight bytes at a time + // until we've copied (at least) the requested amount of bytes. However, + // if d and d-offset are less than eight bytes apart (indicating a + // repeating pattern of length < 8), we first need to expand the pattern in + // order to get the correct results. For instance, if the buffer looks like + // this, with the eight-byte <d-offset> and <d> patterns marked as + // intervals: + // + // abxxxxxxxxxxxx + // [------] d-offset + // [------] d + // + // a single eight-byte copy from <d-offset> to <d> will repeat the pattern + // once, after which we can move two bytes without moving <d-offset>: + // + // ababxxxxxxxxxx + // [------] d-offset + // [------] d + // + // and repeat the exercise until the two no longer overlap.
+ // + // This allows us to do very well in the special case of one single byte + // repeated many times, without taking a big hit for more general cases. + // + // The worst case of extra writing past the end of the match occurs when + // offset == 1 and length == 1; the last copy will read from byte positions + // [0..7] and write to [4..11], whereas it was only supposed to write to + // position 1. Thus, ten excess bytes. + // + // ---- + // + // That "10 byte overrun" worst case is confirmed by Go's + // TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy + // and finishSlowForwardCopy algorithm. + // + // if length > len(dst)-d-10 { + // goto verySlowForwardCopy + // } + SUBQ $10, R14 + CMPQ CX, R14 + JGT verySlowForwardCopy + +makeOffsetAtLeast8: + // !!! As above, expand the pattern so that offset >= 8 and we can use + // 8-byte load/stores. + // + // for offset < 8 { + // copy 8 bytes from dst[d-offset:] to dst[d:] + // length -= offset + // d += offset + // offset += offset + // // The two previous lines together means that d-offset, and therefore + // // R15, is unchanged. + // } + CMPQ DX, $8 + JGE fixUpSlowForwardCopy + MOVQ (R15), BX + MOVQ BX, (DI) + SUBQ DX, CX + ADDQ DX, DI + ADDQ DX, DX + JMP makeOffsetAtLeast8 + +fixUpSlowForwardCopy: + // !!! Add length (which might be negative now) to d (implied by DI being + // &dst[d]) so that d ends up at the right place when we jump back to the + // top of the loop. Before we do that, though, we save DI to AX so that, if + // length is positive, copying the remaining length bytes will write to the + // right place. + MOVQ DI, AX + ADDQ CX, DI + +finishSlowForwardCopy: + // !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative + // length means that we overrun, but as above, that will be fixed up by + // subsequent iterations of the outermost loop. 
+ CMPQ CX, $0 + JLE loop + MOVQ (R15), BX + MOVQ BX, (AX) + ADDQ $8, R15 + ADDQ $8, AX + SUBQ $8, CX + JMP finishSlowForwardCopy + +verySlowForwardCopy: + // verySlowForwardCopy is a simple implementation of forward copy. In C + // parlance, this is a do/while loop instead of a while loop, since we know + // that length > 0. In Go syntax: + // + // for { + // dst[d] = dst[d - offset] + // d++ + // length-- + // if length == 0 { + // break + // } + // } + MOVB (R15), BX + MOVB BX, (DI) + INCQ R15 + INCQ DI + DECQ CX + JNZ verySlowForwardCopy + JMP loop + +// The code above handles copy tags. +// ---------------------------------------- + +end: + // This is the end of the "for s < len(src)". + // + // if d != len(dst) { etc } + CMPQ DI, R10 + JNE errCorrupt + + // return 0 + MOVQ $0, ret+48(FP) + RET + +errCorrupt: + // return decodeErrCodeCorrupt + MOVQ $1, ret+48(FP) + RET diff --git a/lib/snappy/decode_other.go b/lib/snappy/decode_other.go new file mode 100644 index 0000000..8c9f204 --- /dev/null +++ b/lib/snappy/decode_other.go @@ -0,0 +1,101 @@ +// Copyright 2016 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64 appengine !gc noasm + +package snappy + +// decode writes the decoding of src to dst. It assumes that the varint-encoded +// length of the decompressed bytes has already been read, and that len(dst) +// equals that length. +// +// It returns 0 on success or a decodeErrCodeXxx error code on failure. +func decode(dst, src []byte) int { + var d, s, offset, length int + for s < len(src) { + switch src[s] & 0x03 { + case tagLiteral: + x := uint32(src[s] >> 2) + switch { + case x < 60: + s++ + case x == 60: + s += 2 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. 
+ return decodeErrCodeCorrupt + } + x = uint32(src[s-1]) + case x == 61: + s += 3 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + x = uint32(src[s-2]) | uint32(src[s-1])<<8 + case x == 62: + s += 4 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 + case x == 63: + s += 5 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 + } + length = int(x) + 1 + if length <= 0 { + return decodeErrCodeUnsupportedLiteralLength + } + if length > len(dst)-d || length > len(src)-s { + return decodeErrCodeCorrupt + } + copy(dst[d:], src[s:s+length]) + d += length + s += length + continue + + case tagCopy1: + s += 2 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + length = 4 + int(src[s-2])>>2&0x7 + offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + + case tagCopy2: + s += 3 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + length = 1 + int(src[s-3])>>2 + offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8) + + case tagCopy4: + s += 5 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + return decodeErrCodeCorrupt + } + length = 1 + int(src[s-5])>>2 + offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24) + } + + if offset <= 0 || d < offset || length > len(dst)-d { + return decodeErrCodeCorrupt + } + // Copy from an earlier sub-slice of dst to a later sub-slice. 
Unlike + // the built-in copy function, this byte-by-byte copy always runs + // forwards, even if the slices overlap. Conceptually, this is: + // + // d += forwardCopy(dst[d:d+length], dst[d-offset:]) + for end := d + length; d != end; d++ { + dst[d] = dst[d-offset] + } + } + if d != len(dst) { + return decodeErrCodeCorrupt + } + return 0 +} diff --git a/lib/snappy/encode.go b/lib/snappy/encode.go new file mode 100644 index 0000000..8d393e9 --- /dev/null +++ b/lib/snappy/encode.go @@ -0,0 +1,285 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "encoding/binary" + "errors" + "io" +) + +// Encode returns the encoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire encoded block. +// Otherwise, a newly allocated slice will be returned. +// +// The dst and src must not overlap. It is valid to pass a nil dst. +func Encode(dst, src []byte) []byte { + if n := MaxEncodedLen(len(src)); n < 0 { + panic(ErrTooLarge) + } else if len(dst) < n { + dst = make([]byte, n) + } + + // The block starts with the varint-encoded length of the decompressed bytes. + d := binary.PutUvarint(dst, uint64(len(src))) + + for len(src) > 0 { + p := src + src = nil + if len(p) > maxBlockSize { + p, src = p[:maxBlockSize], p[maxBlockSize:] + } + if len(p) < minNonLiteralBlockSize { + d += emitLiteral(dst[d:], p) + } else { + d += encodeBlock(dst[d:], p) + } + } + return dst[:d] +} + +// inputMargin is the minimum number of extra input bytes to keep, inside +// encodeBlock's inner loop. On some architectures, this margin lets us +// implement a fast path for emitLiteral, where the copy of short (<= 16 byte) +// literals can be implemented as a single load to and store from a 16-byte +// register. 
That literal's actual length can be as short as 1 byte, so this +// can copy up to 15 bytes too much, but that's OK as subsequent iterations of +// the encoding loop will fix up the copy overrun, and this inputMargin ensures +// that we don't overrun the dst and src buffers. +const inputMargin = 16 - 1 + +// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that +// could be encoded with a copy tag. This is the minimum with respect to the +// algorithm used by encodeBlock, not a minimum enforced by the file format. +// +// The encoded output must start with at least a 1 byte literal, as there are +// no previous bytes to copy. A minimal (1 byte) copy after that, generated +// from an emitCopy call in encodeBlock's main loop, would require at least +// another inputMargin bytes, for the reason above: we want any emitLiteral +// calls inside encodeBlock's main loop to use the fast path if possible, which +// requires being able to overrun by inputMargin bytes. Thus, +// minNonLiteralBlockSize equals 1 + 1 + inputMargin. +// +// The C++ code doesn't use this exact threshold, but it could, as discussed at +// https://groups.google.com/d/topic/snappy-compression/oGbhsdIJSJ8/discussion +// The difference between Go (2+inputMargin) and C++ (inputMargin) is purely an +// optimization. It should not affect the encoded form. This is tested by +// TestSameEncodingAsCppShortCopies. +const minNonLiteralBlockSize = 1 + 1 + inputMargin + +// MaxEncodedLen returns the maximum length of a snappy block, given its +// uncompressed length. +// +// It will return a negative value if srcLen is too large to encode. 
+func MaxEncodedLen(srcLen int) int { + n := uint64(srcLen) + if n > 0xffffffff { + return -1 + } + // Compressed data can be defined as: + // compressed := item* literal* + // item := literal* copy + // + // The trailing literal sequence has a space blowup of at most 62/60 + // since a literal of length 60 needs one tag byte + one extra byte + // for length information. + // + // Item blowup is trickier to measure. Suppose the "copy" op copies + // 4 bytes of data. Because of a special check in the encoding code, + // we produce a 4-byte copy only if the offset is < 65536. Therefore + // the copy op takes 3 bytes to encode, and this type of item leads + // to at most the 62/60 blowup for representing literals. + // + // Suppose the "copy" op copies 5 bytes of data. If the offset is big + // enough, it will take 5 bytes to encode the copy op. Therefore the + // worst case here is a one-byte literal followed by a five-byte copy. + // That is, 6 bytes of input turn into 7 bytes of "compressed" data. + // + // This last factor dominates the blowup, so the final estimate is: + n = 32 + n + n/6 + if n > 0xffffffff { + return -1 + } + return int(n) +} + +var errClosed = errors.New("snappy: Writer is closed") + +// NewWriter returns a new Writer that compresses to w. +// +// The Writer returned does not buffer writes. There is no need to Flush or +// Close such a Writer. +// +// Deprecated: the Writer returned is not suitable for many small writes, only +// for few large writes. Use NewBufferedWriter instead, which is efficient +// regardless of the frequency and shape of the writes, and remember to Close +// that Writer when done. +func NewWriter(w io.Writer) *Writer { + return &Writer{ + w: w, + obuf: make([]byte, obufLen), + } +} + +// NewBufferedWriter returns a new Writer that compresses to w, using the +// framing format described at +// https://github.com/google/snappy/blob/master/framing_format.txt +// +// The Writer returned buffers writes. 
Users must call Close to guarantee all +// data has been forwarded to the underlying io.Writer. They may also call +// Flush zero or more times before calling Close. +func NewBufferedWriter(w io.Writer) *Writer { + return &Writer{ + w: w, + ibuf: make([]byte, 0, maxBlockSize), + obuf: make([]byte, obufLen), + } +} + +// Writer is an io.Writer that can write Snappy-compressed bytes. +type Writer struct { + w io.Writer + err error + + // ibuf is a buffer for the incoming (uncompressed) bytes. + // + // Its use is optional. For backwards compatibility, Writers created by the + // NewWriter function have ibuf == nil, do not buffer incoming bytes, and + // therefore do not need to be Flush'ed or Close'd. + ibuf []byte + + // obuf is a buffer for the outgoing (compressed) bytes. + obuf []byte + + // wroteStreamHeader is whether we have written the stream header. + wroteStreamHeader bool +} + +// Reset discards the writer's state and switches the Snappy writer to write to +// w. This permits reusing a Writer rather than allocating a new one. +func (w *Writer) Reset(writer io.Writer) { + w.w = writer + w.err = nil + if w.ibuf != nil { + w.ibuf = w.ibuf[:0] + } + w.wroteStreamHeader = false +} + +// Write satisfies the io.Writer interface. +func (w *Writer) Write(p []byte) (nRet int, errRet error) { + if w.ibuf == nil { + // Do not buffer incoming bytes. This does not perform or compress well + // if the caller of Writer.Write writes many small slices. This + // behavior is therefore deprecated, but still supported for backwards + // compatibility with code that doesn't explicitly Flush or Close. + return w.write(p) + } + + // The remainder of this method is based on bufio.Writer.Write from the + // standard library. + + for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil { + var n int + if len(w.ibuf) == 0 { + // Large write, empty buffer. + // Write directly from p to avoid copy. 
+ n, _ = w.write(p) + } else { + n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p) + w.ibuf = w.ibuf[:len(w.ibuf)+n] + w.Flush() + } + nRet += n + p = p[n:] + } + if w.err != nil { + return nRet, w.err + } + n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p) + w.ibuf = w.ibuf[:len(w.ibuf)+n] + nRet += n + return nRet, nil +} + +func (w *Writer) write(p []byte) (nRet int, errRet error) { + if w.err != nil { + return 0, w.err + } + for len(p) > 0 { + obufStart := len(magicChunk) + if !w.wroteStreamHeader { + w.wroteStreamHeader = true + copy(w.obuf, magicChunk) + obufStart = 0 + } + + var uncompressed []byte + if len(p) > maxBlockSize { + uncompressed, p = p[:maxBlockSize], p[maxBlockSize:] + } else { + uncompressed, p = p, nil + } + checksum := crc(uncompressed) + + // Compress the buffer, discarding the result if the improvement + // isn't at least 12.5%. + compressed := Encode(w.obuf[obufHeaderLen:], uncompressed) + chunkType := uint8(chunkTypeCompressedData) + chunkLen := 4 + len(compressed) + obufEnd := obufHeaderLen + len(compressed) + if len(compressed) >= len(uncompressed)-len(uncompressed)/8 { + chunkType = chunkTypeUncompressedData + chunkLen = 4 + len(uncompressed) + obufEnd = obufHeaderLen + } + + // Fill in the per-chunk header that comes before the body. 
+ w.obuf[len(magicChunk)+0] = chunkType + w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0) + w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8) + w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16) + w.obuf[len(magicChunk)+4] = uint8(checksum >> 0) + w.obuf[len(magicChunk)+5] = uint8(checksum >> 8) + w.obuf[len(magicChunk)+6] = uint8(checksum >> 16) + w.obuf[len(magicChunk)+7] = uint8(checksum >> 24) + + if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil { + w.err = err + return nRet, err + } + if chunkType == chunkTypeUncompressedData { + if _, err := w.w.Write(uncompressed); err != nil { + w.err = err + return nRet, err + } + } + nRet += len(uncompressed) + } + return nRet, nil +} + +// Flush flushes the Writer to its underlying io.Writer. +func (w *Writer) Flush() error { + if w.err != nil { + return w.err + } + if len(w.ibuf) == 0 { + return nil + } + w.write(w.ibuf) + w.ibuf = w.ibuf[:0] + return w.err +} + +// Close calls Flush and then closes the Writer. +func (w *Writer) Close() error { + w.Flush() + ret := w.err + if w.err == nil { + w.err = errClosed + } + return ret +} diff --git a/lib/snappy/encode_amd64.go b/lib/snappy/encode_amd64.go new file mode 100644 index 0000000..150d91b --- /dev/null +++ b/lib/snappy/encode_amd64.go @@ -0,0 +1,29 @@ +// Copyright 2016 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +package snappy + +// emitLiteral has the same semantics as in encode_other.go. +// +//go:noescape +func emitLiteral(dst, lit []byte) int + +// emitCopy has the same semantics as in encode_other.go. +// +//go:noescape +func emitCopy(dst []byte, offset, length int) int + +// extendMatch has the same semantics as in encode_other.go. +// +//go:noescape +func extendMatch(src []byte, i, j int) int + +// encodeBlock has the same semantics as in encode_other.go. 
+// +//go:noescape +func encodeBlock(dst, src []byte) (d int) diff --git a/lib/snappy/encode_amd64.s b/lib/snappy/encode_amd64.s new file mode 100644 index 0000000..adfd979 --- /dev/null +++ b/lib/snappy/encode_amd64.s @@ -0,0 +1,730 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The XXX lines assemble on Go 1.4, 1.5 and 1.7, but not 1.6, due to a +// Go toolchain regression. See https://github.com/golang/go/issues/15426 and +// https://github.com/golang/snappy/issues/29 +// +// As a workaround, the package was built with a known good assembler, and +// those instructions were disassembled by "objdump -d" to yield the +// 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15 +// style comments, in AT&T asm syntax. Note that rsp here is a physical +// register, not Go/asm's SP pseudo-register (see https://golang.org/doc/asm). +// The instructions were then encoded as "BYTE $0x.." sequences, which assemble +// fine on Go 1.6. + +// The asm code generally follows the pure Go code in encode_other.go, except +// where marked with a "!!!". + +// ---------------------------------------------------------------------------- + +// func emitLiteral(dst, lit []byte) int +// +// All local variables fit into registers. The register allocation: +// - AX len(lit) +// - BX n +// - DX return value +// - DI &dst[i] +// - R10 &lit[0] +// +// The 24 bytes of stack space is to call runtime·memmove. +// +// The unusual register allocation of local variables, such as R10 for the +// source pointer, matches the allocation used at the call site in encodeBlock, +// which makes it easier to manually inline this function. 
+TEXT ·emitLiteral(SB), NOSPLIT, $24-56 + MOVQ dst_base+0(FP), DI + MOVQ lit_base+24(FP), R10 + MOVQ lit_len+32(FP), AX + MOVQ AX, DX + MOVL AX, BX + SUBL $1, BX + + CMPL BX, $60 + JLT oneByte + CMPL BX, $256 + JLT twoBytes + +threeBytes: + MOVB $0xf4, 0(DI) + MOVW BX, 1(DI) + ADDQ $3, DI + ADDQ $3, DX + JMP memmove + +twoBytes: + MOVB $0xf0, 0(DI) + MOVB BX, 1(DI) + ADDQ $2, DI + ADDQ $2, DX + JMP memmove + +oneByte: + SHLB $2, BX + MOVB BX, 0(DI) + ADDQ $1, DI + ADDQ $1, DX + +memmove: + MOVQ DX, ret+48(FP) + + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // DI, R10 and AX as arguments. + MOVQ DI, 0(SP) + MOVQ R10, 8(SP) + MOVQ AX, 16(SP) + CALL runtime·memmove(SB) + RET + +// ---------------------------------------------------------------------------- + +// func emitCopy(dst []byte, offset, length int) int +// +// All local variables fit into registers. The register allocation: +// - AX length +// - SI &dst[0] +// - DI &dst[i] +// - R11 offset +// +// The unusual register allocation of local variables, such as R11 for the +// offset, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·emitCopy(SB), NOSPLIT, $0-48 + MOVQ dst_base+0(FP), DI + MOVQ DI, SI + MOVQ offset+24(FP), R11 + MOVQ length+32(FP), AX + +loop0: + // for length >= 68 { etc } + CMPL AX, $68 + JLT step1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVB $0xfe, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + SUBL $64, AX + JMP loop0 + +step1: + // if length > 64 { etc } + CMPL AX, $64 + JLE step2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVB $0xee, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + SUBL $60, AX + +step2: + // if length >= 12 || offset >= 2048 { goto step3 } + CMPL AX, $12 + JGE step3 + CMPL R11, $2048 + JGE step3 + + // Emit the remaining copy, encoded as 2 bytes. 
+ MOVB R11, 1(DI) + SHRL $8, R11 + SHLB $5, R11 + SUBB $4, AX + SHLB $2, AX + ORB AX, R11 + ORB $1, R11 + MOVB R11, 0(DI) + ADDQ $2, DI + + // Return the number of bytes written. + SUBQ SI, DI + MOVQ DI, ret+40(FP) + RET + +step3: + // Emit the remaining copy, encoded as 3 bytes. + SUBL $1, AX + SHLB $2, AX + ORB $2, AX + MOVB AX, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + + // Return the number of bytes written. + SUBQ SI, DI + MOVQ DI, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func extendMatch(src []byte, i, j int) int +// +// All local variables fit into registers. The register allocation: +// - DX &src[0] +// - SI &src[j] +// - R13 &src[len(src) - 8] +// - R14 &src[len(src)] +// - R15 &src[i] +// +// The unusual register allocation of local variables, such as R15 for a source +// pointer, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·extendMatch(SB), NOSPLIT, $0-48 + MOVQ src_base+0(FP), DX + MOVQ src_len+8(FP), R14 + MOVQ i+24(FP), R15 + MOVQ j+32(FP), SI + ADDQ DX, R14 + ADDQ DX, R15 + ADDQ DX, SI + MOVQ R14, R13 + SUBQ $8, R13 + +cmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMPQ SI, R13 + JA cmp1 + MOVQ (R15), AX + MOVQ (SI), BX + CMPQ AX, BX + JNE bsf + ADDQ $8, R15 + ADDQ $8, SI + JMP cmp8 + +bsf: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. The BSF instruction finds the + // least significant 1 bit, the amd64 architecture is little-endian, and + // the shift by 3 converts a bit index to a byte index. + XORQ AX, BX + BSFQ BX, BX + SHRQ $3, BX + ADDQ BX, SI + + // Convert from &src[ret] to ret. + SUBQ DX, SI + MOVQ SI, ret+40(FP) + RET + +cmp1: + // In src's tail, compare 1 byte at a time. 
+ CMPQ SI, R14 + JAE extendMatchEnd + MOVB (R15), AX + MOVB (SI), BX + CMPB AX, BX + JNE extendMatchEnd + ADDQ $1, R15 + ADDQ $1, SI + JMP cmp1 + +extendMatchEnd: + // Convert from &src[ret] to ret. + SUBQ DX, SI + MOVQ SI, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func encodeBlock(dst, src []byte) (d int) +// +// All local variables fit into registers, other than "var table". The register +// allocation: +// - AX . . +// - BX . . +// - CX 56 shift (note that amd64 shifts by non-immediates must use CX). +// - DX 64 &src[0], tableSize +// - SI 72 &src[s] +// - DI 80 &dst[d] +// - R9 88 sLimit +// - R10 . &src[nextEmit] +// - R11 96 prevHash, currHash, nextHash, offset +// - R12 104 &src[base], skip +// - R13 . &src[nextS], &src[len(src) - 8] +// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x +// - R15 112 candidate +// +// The second column (56, 64, etc) is the stack offset to spill the registers +// when calling other functions. We could pack this slightly tighter, but it's +// simpler to have a dedicated spill map independent of the function called. +// +// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An +// extra 56 bytes, to call other functions, and an extra 64 bytes, to spill +// local variables (registers) during calls gives 32768 + 56 + 64 = 32888. +TEXT ·encodeBlock(SB), 0, $32888-56 + MOVQ dst_base+0(FP), DI + MOVQ src_base+24(FP), SI + MOVQ src_len+32(FP), R14 + + // shift, tableSize := uint32(32-8), 1<<8 + MOVQ $24, CX + MOVQ $256, DX + +calcShift: + // for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 { + // shift-- + // } + CMPQ DX, $16384 + JGE varTable + CMPQ DX, R14 + JGE varTable + SUBQ $1, CX + SHLQ $1, DX + JMP calcShift + +varTable: + // var table [maxTableSize]uint16 + // + // In the asm code, unlike the Go code, we can zero-initialize only the + // first tableSize elements. 
Each uint16 element is 2 bytes and each MOVOU + // writes 16 bytes, so we can do only tableSize/8 writes instead of the + // 2048 writes that would zero-initialize all of table's 32768 bytes. + SHRQ $3, DX + LEAQ table-32768(SP), BX + PXOR X0, X0 + +memclr: + MOVOU X0, 0(BX) + ADDQ $16, BX + SUBQ $1, DX + JNZ memclr + + // !!! DX = &src[0] + MOVQ SI, DX + + // sLimit := len(src) - inputMargin + MOVQ R14, R9 + SUBQ $15, R9 + + // !!! Pre-emptively spill CX, DX and R9 to the stack. Their values don't + // change for the rest of the function. + MOVQ CX, 56(SP) + MOVQ DX, 64(SP) + MOVQ R9, 88(SP) + + // nextEmit := 0 + MOVQ DX, R10 + + // s := 1 + ADDQ $1, SI + + // nextHash := hash(load32(src, s), shift) + MOVL 0(SI), R11 + IMULL $0x1e35a7bd, R11 + SHRL CX, R11 + +outer: + // for { etc } + + // skip := 32 + MOVQ $32, R12 + + // nextS := s + MOVQ SI, R13 + + // candidate := 0 + MOVQ $0, R15 + +inner0: + // for { etc } + + // s := nextS + MOVQ R13, SI + + // bytesBetweenHashLookups := skip >> 5 + MOVQ R12, R14 + SHRQ $5, R14 + + // nextS = s + bytesBetweenHashLookups + ADDQ R14, R13 + + // skip += bytesBetweenHashLookups + ADDQ R14, R12 + + // if nextS > sLimit { goto emitRemainder } + MOVQ R13, AX + SUBQ DX, AX + CMPQ AX, R9 + JA emitRemainder + + // candidate = int(table[nextHash]) + // XXX: MOVWQZX table-32768(SP)(R11*2), R15 + // XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15 + BYTE $0x4e + BYTE $0x0f + BYTE $0xb7 + BYTE $0x7c + BYTE $0x5c + BYTE $0x78 + + // table[nextHash] = uint16(s) + MOVQ SI, AX + SUBQ DX, AX + + // XXX: MOVW AX, table-32768(SP)(R11*2) + // XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2) + BYTE $0x66 + BYTE $0x42 + BYTE $0x89 + BYTE $0x44 + BYTE $0x5c + BYTE $0x78 + + // nextHash = hash(load32(src, nextS), shift) + MOVL 0(R13), R11 + IMULL $0x1e35a7bd, R11 + SHRL CX, R11 + + // if load32(src, s) != load32(src, candidate) { continue } break + MOVL 0(SI), AX + MOVL (DX)(R15*1), BX + CMPL AX, BX + JNE inner0 + +fourByteMatch: + // As per 
the encode_other.go code: + // + // A 4-byte match has been found. We'll later see etc. + + // !!! Jump to a fast path for short (<= 16 byte) literals. See the comment + // on inputMargin in encode.go. + MOVQ SI, AX + SUBQ R10, AX + CMPQ AX, $16 + JLE emitLiteralFastPath + + // ---------------------------------------- + // Begin inline of the emitLiteral call. + // + // d += emitLiteral(dst[d:], src[nextEmit:s]) + + MOVL AX, BX + SUBL $1, BX + + CMPL BX, $60 + JLT inlineEmitLiteralOneByte + CMPL BX, $256 + JLT inlineEmitLiteralTwoBytes + +inlineEmitLiteralThreeBytes: + MOVB $0xf4, 0(DI) + MOVW BX, 1(DI) + ADDQ $3, DI + JMP inlineEmitLiteralMemmove + +inlineEmitLiteralTwoBytes: + MOVB $0xf0, 0(DI) + MOVB BX, 1(DI) + ADDQ $2, DI + JMP inlineEmitLiteralMemmove + +inlineEmitLiteralOneByte: + SHLB $2, BX + MOVB BX, 0(DI) + ADDQ $1, DI + +inlineEmitLiteralMemmove: + // Spill local variables (registers) onto the stack; call; unspill. + // + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // DI, R10 and AX as arguments. + MOVQ DI, 0(SP) + MOVQ R10, 8(SP) + MOVQ AX, 16(SP) + ADDQ AX, DI // Finish the "d +=" part of "d += emitLiteral(etc)". + MOVQ SI, 72(SP) + MOVQ DI, 80(SP) + MOVQ R15, 112(SP) + CALL runtime·memmove(SB) + MOVQ 56(SP), CX + MOVQ 64(SP), DX + MOVQ 72(SP), SI + MOVQ 80(SP), DI + MOVQ 88(SP), R9 + MOVQ 112(SP), R15 + JMP inner1 + +inlineEmitLiteralEnd: + // End inline of the emitLiteral call. + // ---------------------------------------- + +emitLiteralFastPath: + // !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2". + MOVB AX, BX + SUBB $1, BX + SHLB $2, BX + MOVB BX, (DI) + ADDQ $1, DI + + // !!! Implement the copy from lit to dst as a 16-byte load and store. + // (Encode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only len(lit) bytes, but that's + // OK. Subsequent iterations will fix up the overrun. 
+ // + // Note that on amd64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. + MOVOU 0(R10), X0 + MOVOU X0, 0(DI) + ADDQ AX, DI + +inner1: + // for { etc } + + // base := s + MOVQ SI, R12 + + // !!! offset := base - candidate + MOVQ R12, R11 + SUBQ R15, R11 + SUBQ DX, R11 + + // ---------------------------------------- + // Begin inline of the extendMatch call. + // + // s = extendMatch(src, candidate+4, s+4) + + // !!! R14 = &src[len(src)] + MOVQ src_len+32(FP), R14 + ADDQ DX, R14 + + // !!! R13 = &src[len(src) - 8] + MOVQ R14, R13 + SUBQ $8, R13 + + // !!! R15 = &src[candidate + 4] + ADDQ $4, R15 + ADDQ DX, R15 + + // !!! s += 4 + ADDQ $4, SI + +inlineExtendMatchCmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMPQ SI, R13 + JA inlineExtendMatchCmp1 + MOVQ (R15), AX + MOVQ (SI), BX + CMPQ AX, BX + JNE inlineExtendMatchBSF + ADDQ $8, R15 + ADDQ $8, SI + JMP inlineExtendMatchCmp8 + +inlineExtendMatchBSF: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. The BSF instruction finds the + // least significant 1 bit, the amd64 architecture is little-endian, and + // the shift by 3 converts a bit index to a byte index. + XORQ AX, BX + BSFQ BX, BX + SHRQ $3, BX + ADDQ BX, SI + JMP inlineExtendMatchEnd + +inlineExtendMatchCmp1: + // In src's tail, compare 1 byte at a time. + CMPQ SI, R14 + JAE inlineExtendMatchEnd + MOVB (R15), AX + MOVB (SI), BX + CMPB AX, BX + JNE inlineExtendMatchEnd + ADDQ $1, R15 + ADDQ $1, SI + JMP inlineExtendMatchCmp1 + +inlineExtendMatchEnd: + // End inline of the extendMatch call. + // ---------------------------------------- + + // ---------------------------------------- + // Begin inline of the emitCopy call. 
+ // + // d += emitCopy(dst[d:], base-candidate, s-base) + + // !!! length := s - base + MOVQ SI, AX + SUBQ R12, AX + +inlineEmitCopyLoop0: + // for length >= 68 { etc } + CMPL AX, $68 + JLT inlineEmitCopyStep1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVB $0xfe, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + SUBL $64, AX + JMP inlineEmitCopyLoop0 + +inlineEmitCopyStep1: + // if length > 64 { etc } + CMPL AX, $64 + JLE inlineEmitCopyStep2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVB $0xee, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + SUBL $60, AX + +inlineEmitCopyStep2: + // if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 } + CMPL AX, $12 + JGE inlineEmitCopyStep3 + CMPL R11, $2048 + JGE inlineEmitCopyStep3 + + // Emit the remaining copy, encoded as 2 bytes. + MOVB R11, 1(DI) + SHRL $8, R11 + SHLB $5, R11 + SUBB $4, AX + SHLB $2, AX + ORB AX, R11 + ORB $1, R11 + MOVB R11, 0(DI) + ADDQ $2, DI + JMP inlineEmitCopyEnd + +inlineEmitCopyStep3: + // Emit the remaining copy, encoded as 3 bytes. + SUBL $1, AX + SHLB $2, AX + ORB $2, AX + MOVB AX, 0(DI) + MOVW R11, 1(DI) + ADDQ $3, DI + +inlineEmitCopyEnd: + // End inline of the emitCopy call. + // ---------------------------------------- + + // nextEmit = s + MOVQ SI, R10 + + // if s >= sLimit { goto emitRemainder } + MOVQ SI, AX + SUBQ DX, AX + CMPQ AX, R9 + JAE emitRemainder + + // As per the encode_other.go code: + // + // We could immediately etc. 
+ + // x := load64(src, s-1) + MOVQ -1(SI), R14 + + // prevHash := hash(uint32(x>>0), shift) + MOVL R14, R11 + IMULL $0x1e35a7bd, R11 + SHRL CX, R11 + + // table[prevHash] = uint16(s-1) + MOVQ SI, AX + SUBQ DX, AX + SUBQ $1, AX + + // XXX: MOVW AX, table-32768(SP)(R11*2) + // XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2) + BYTE $0x66 + BYTE $0x42 + BYTE $0x89 + BYTE $0x44 + BYTE $0x5c + BYTE $0x78 + + // currHash := hash(uint32(x>>8), shift) + SHRQ $8, R14 + MOVL R14, R11 + IMULL $0x1e35a7bd, R11 + SHRL CX, R11 + + // candidate = int(table[currHash]) + // XXX: MOVWQZX table-32768(SP)(R11*2), R15 + // XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15 + BYTE $0x4e + BYTE $0x0f + BYTE $0xb7 + BYTE $0x7c + BYTE $0x5c + BYTE $0x78 + + // table[currHash] = uint16(s) + ADDQ $1, AX + + // XXX: MOVW AX, table-32768(SP)(R11*2) + // XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2) + BYTE $0x66 + BYTE $0x42 + BYTE $0x89 + BYTE $0x44 + BYTE $0x5c + BYTE $0x78 + + // if uint32(x>>8) == load32(src, candidate) { continue } + MOVL (DX)(R15*1), BX + CMPL R14, BX + JEQ inner1 + + // nextHash = hash(uint32(x>>16), shift) + SHRQ $8, R14 + MOVL R14, R11 + IMULL $0x1e35a7bd, R11 + SHRL CX, R11 + + // s++ + ADDQ $1, SI + + // break out of the inner1 for loop, i.e. continue the outer loop. + JMP outer + +emitRemainder: + // if nextEmit < len(src) { etc } + MOVQ src_len+32(FP), AX + ADDQ DX, AX + CMPQ R10, AX + JEQ encodeBlockEnd + + // d += emitLiteral(dst[d:], src[nextEmit:]) + // + // Push args. + MOVQ DI, 0(SP) + MOVQ $0, 8(SP) // Unnecessary, as the callee ignores it, but conservative. + MOVQ $0, 16(SP) // Unnecessary, as the callee ignores it, but conservative. + MOVQ R10, 24(SP) + SUBQ R10, AX + MOVQ AX, 32(SP) + MOVQ AX, 40(SP) // Unnecessary, as the callee ignores it, but conservative. + + // Spill local variables (registers) onto the stack; call; unspill. + MOVQ DI, 80(SP) + CALL ·emitLiteral(SB) + MOVQ 80(SP), DI + + // Finish the "d +=" part of "d += emitLiteral(etc)". 
+ ADDQ 48(SP), DI + +encodeBlockEnd: + MOVQ dst_base+0(FP), AX + SUBQ AX, DI + MOVQ DI, d+48(FP) + RET diff --git a/lib/snappy/encode_other.go b/lib/snappy/encode_other.go new file mode 100644 index 0000000..dbcae90 --- /dev/null +++ b/lib/snappy/encode_other.go @@ -0,0 +1,238 @@ +// Copyright 2016 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64 appengine !gc noasm + +package snappy + +func load32(b []byte, i int) uint32 { + b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line. + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load64(b []byte, i int) uint64 { + b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line. + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +// emitLiteral writes a literal chunk and returns the number of bytes written. +// +// It assumes that: +// dst is long enough to hold the encoded bytes +// 1 <= len(lit) && len(lit) <= 65536 +func emitLiteral(dst, lit []byte) int { + i, n := 0, uint(len(lit)-1) + switch { + case n < 60: + dst[0] = uint8(n)<<2 | tagLiteral + i = 1 + case n < 1<<8: + dst[0] = 60<<2 | tagLiteral + dst[1] = uint8(n) + i = 2 + default: + dst[0] = 61<<2 | tagLiteral + dst[1] = uint8(n) + dst[2] = uint8(n >> 8) + i = 3 + } + return i + copy(dst[i:], lit) +} + +// emitCopy writes a copy chunk and returns the number of bytes written. +// +// It assumes that: +// dst is long enough to hold the encoded bytes +// 1 <= offset && offset <= 65535 +// 4 <= length && length <= 65535 +func emitCopy(dst []byte, offset, length int) int { + i := 0 + // The maximum length for a single tagCopy1 or tagCopy2 op is 64 bytes. 
The + // threshold for this loop is a little higher (at 68 = 64 + 4), and the + // length emitted down below is is a little lower (at 60 = 64 - 4), because + // it's shorter to encode a length 67 copy as a length 60 tagCopy2 followed + // by a length 7 tagCopy1 (which encodes as 3+2 bytes) than to encode it as + // a length 64 tagCopy2 followed by a length 3 tagCopy2 (which encodes as + // 3+3 bytes). The magic 4 in the 64±4 is because the minimum length for a + // tagCopy1 op is 4 bytes, which is why a length 3 copy has to be an + // encodes-as-3-bytes tagCopy2 instead of an encodes-as-2-bytes tagCopy1. + for length >= 68 { + // Emit a length 64 copy, encoded as 3 bytes. + dst[i+0] = 63<<2 | tagCopy2 + dst[i+1] = uint8(offset) + dst[i+2] = uint8(offset >> 8) + i += 3 + length -= 64 + } + if length > 64 { + // Emit a length 60 copy, encoded as 3 bytes. + dst[i+0] = 59<<2 | tagCopy2 + dst[i+1] = uint8(offset) + dst[i+2] = uint8(offset >> 8) + i += 3 + length -= 60 + } + if length >= 12 || offset >= 2048 { + // Emit the remaining copy, encoded as 3 bytes. + dst[i+0] = uint8(length-1)<<2 | tagCopy2 + dst[i+1] = uint8(offset) + dst[i+2] = uint8(offset >> 8) + return i + 3 + } + // Emit the remaining copy, encoded as 2 bytes. + dst[i+0] = uint8(offset>>8)<<5 | uint8(length-4)<<2 | tagCopy1 + dst[i+1] = uint8(offset) + return i + 2 +} + +// extendMatch returns the largest k such that k <= len(src) and that +// src[i:i+k-j] and src[j:k] have the same contents. +// +// It assumes that: +// 0 <= i && i < j && j <= len(src) +func extendMatch(src []byte, i, j int) int { + for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { + } + return j +} + +func hash(u, shift uint32) uint32 { + return (u * 0x1e35a7bd) >> shift +} + +// encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It +// assumes that the varint-encoded length of the decompressed bytes has already +// been written. 
+// +// It also assumes that: +// len(dst) >= MaxEncodedLen(len(src)) && +// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize +func encodeBlock(dst, src []byte) (d int) { + // Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive. + // The table element type is uint16, as s < sLimit and sLimit < len(src) + // and len(src) <= maxBlockSize and maxBlockSize == 65536. + const ( + maxTableSize = 1 << 14 + // tableMask is redundant, but helps the compiler eliminate bounds + // checks. + tableMask = maxTableSize - 1 + ) + shift := uint32(32 - 8) + for tableSize := 1 << 8; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 { + shift-- + } + // In Go, all array elements are zero-initialized, so there is no advantage + // to a smaller tableSize per se. However, it matches the C++ algorithm, + // and in the asm versions of this code, we can get away with zeroing only + // the first tableSize elements. + var table [maxTableSize]uint16 + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := len(src) - inputMargin + + // nextEmit is where in src the next emitLiteral should start from. + nextEmit := 0 + + // The encoded form must start with a literal, as there are no previous + // bytes to copy, so we start looking for hash matches at s == 1. + s := 1 + nextHash := hash(load32(src, s), shift) + + for { + // Copied from the C++ snappy implementation: + // + // Heuristic match skipping: If 32 bytes are scanned with no matches + // found, start looking only at every other byte. If 32 more bytes are + // scanned (or skipped), look at every third byte, etc.. When a match + // is found, immediately go back to looking at every byte. 
This is a + // small loss (~5% performance, ~0.1% density) for compressible data + // due to more bookkeeping, but for non-compressible data (such as + // JPEG) it's a huge win since the compressor quickly "realizes" the + // data is incompressible and doesn't bother looking for matches + // everywhere. + // + // The "skip" variable keeps track of how many bytes there are since + // the last match; dividing it by 32 (ie. right-shifting by five) gives + // the number of bytes to move ahead for each iteration. + skip := 32 + + nextS := s + candidate := 0 + for { + s = nextS + bytesBetweenHashLookups := skip >> 5 + nextS = s + bytesBetweenHashLookups + skip += bytesBetweenHashLookups + if nextS > sLimit { + goto emitRemainder + } + candidate = int(table[nextHash&tableMask]) + table[nextHash&tableMask] = uint16(s) + nextHash = hash(load32(src, nextS), shift) + if load32(src, s) == load32(src, candidate) { + break + } + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + d += emitLiteral(dst[d:], src[nextEmit:s]) + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + base := s + + // Extend the 4-byte match as long as possible. 
+ // + // This is an inlined version of: + // s = extendMatch(src, candidate+4, s+4) + s += 4 + for i := candidate + 4; s < len(src) && src[i] == src[s]; i, s = i+1, s+1 { + } + + d += emitCopy(dst[d:], base-candidate, s-base) + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load64(src, s-1) + prevHash := hash(uint32(x>>0), shift) + table[prevHash&tableMask] = uint16(s - 1) + currHash := hash(uint32(x>>8), shift) + candidate = int(table[currHash&tableMask]) + table[currHash&tableMask] = uint16(s) + if uint32(x>>8) != load32(src, candidate) { + nextHash = hash(uint32(x>>16), shift) + s++ + break + } + } + } + +emitRemainder: + if nextEmit < len(src) { + d += emitLiteral(dst[d:], src[nextEmit:]) + } + return d +} diff --git a/lib/snappy/golden_test.go b/lib/snappy/golden_test.go new file mode 100644 index 0000000..e4496f9 --- /dev/null +++ b/lib/snappy/golden_test.go @@ -0,0 +1,1965 @@ +// Copyright 2016 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +// extendMatchGoldenTestCases is the i and j arguments, and the returned value, +// for every extendMatch call issued when encoding the +// testdata/Mark.Twain-Tom.Sawyer.txt file. It is used to benchmark the +// extendMatch implementation. 
+// +// It was generated manually by adding some print statements to the (pure Go) +// extendMatch implementation: +// +// func extendMatch(src []byte, i, j int) int { +// i0, j0 := i, j +// for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { +// } +// println("{", i0, ",", j0, ",", j, "},") +// return j +// } +// +// and running "go test -test.run=EncodeGoldenInput -tags=noasm". +var extendMatchGoldenTestCases = []struct { + i, j, want int +}{ + {11, 61, 62}, + {80, 81, 82}, + {86, 87, 101}, + {85, 133, 149}, + {152, 153, 162}, + {133, 168, 193}, + {168, 207, 225}, + {81, 255, 275}, + {278, 279, 283}, + {306, 417, 417}, + {373, 428, 430}, + {389, 444, 447}, + {474, 510, 512}, + {465, 533, 533}, + {47, 547, 547}, + {307, 551, 554}, + {420, 582, 587}, + {309, 604, 604}, + {604, 625, 625}, + {538, 629, 629}, + {328, 640, 640}, + {573, 645, 645}, + {319, 657, 657}, + {30, 664, 664}, + {45, 679, 680}, + {621, 684, 684}, + {376, 700, 700}, + {33, 707, 708}, + {601, 733, 733}, + {334, 744, 745}, + {625, 758, 759}, + {382, 763, 763}, + {550, 769, 771}, + {533, 789, 789}, + {804, 813, 813}, + {342, 841, 842}, + {742, 847, 847}, + {74, 852, 852}, + {810, 864, 864}, + {758, 868, 869}, + {714, 883, 883}, + {582, 889, 891}, + {61, 934, 935}, + {894, 942, 942}, + {939, 949, 949}, + {785, 956, 957}, + {886, 978, 978}, + {792, 998, 998}, + {998, 1005, 1005}, + {572, 1032, 1032}, + {698, 1051, 1053}, + {599, 1067, 1069}, + {1056, 1079, 1079}, + {942, 1089, 1090}, + {831, 1094, 1096}, + {1088, 1100, 1103}, + {732, 1113, 1114}, + {1037, 1118, 1118}, + {872, 1128, 1130}, + {1079, 1140, 1142}, + {332, 1162, 1162}, + {207, 1168, 1186}, + {1189, 1190, 1225}, + {105, 1229, 1230}, + {79, 1256, 1257}, + {1190, 1261, 1283}, + {255, 1306, 1306}, + {1319, 1339, 1358}, + {364, 1370, 1370}, + {955, 1378, 1380}, + {122, 1403, 1403}, + {1325, 1407, 1419}, + {664, 1423, 1424}, + {941, 1461, 1463}, + {867, 1477, 1478}, + {757, 1488, 1489}, + {1140, 1499, 1499}, + {31, 1506, 1506}, + {1487, 
1510, 1512}, + {1089, 1520, 1521}, + {1467, 1525, 1529}, + {1394, 1537, 1537}, + {1499, 1541, 1541}, + {367, 1558, 1558}, + {1475, 1564, 1564}, + {1525, 1568, 1571}, + {1541, 1582, 1583}, + {864, 1587, 1588}, + {704, 1597, 1597}, + {336, 1602, 1602}, + {1383, 1613, 1613}, + {1498, 1617, 1618}, + {1051, 1623, 1625}, + {401, 1643, 1645}, + {1072, 1654, 1655}, + {1067, 1667, 1669}, + {699, 1673, 1674}, + {1587, 1683, 1684}, + {920, 1696, 1696}, + {1505, 1710, 1710}, + {1550, 1723, 1723}, + {996, 1727, 1727}, + {833, 1733, 1734}, + {1638, 1739, 1740}, + {1654, 1744, 1744}, + {753, 1761, 1761}, + {1548, 1773, 1773}, + {1568, 1777, 1780}, + {1683, 1793, 1794}, + {948, 1801, 1801}, + {1666, 1805, 1808}, + {1502, 1814, 1814}, + {1696, 1822, 1822}, + {502, 1836, 1837}, + {917, 1843, 1843}, + {1733, 1854, 1855}, + {970, 1859, 1859}, + {310, 1863, 1863}, + {657, 1872, 1872}, + {1005, 1876, 1876}, + {1662, 1880, 1880}, + {904, 1892, 1892}, + {1427, 1910, 1910}, + {1772, 1929, 1930}, + {1822, 1937, 1940}, + {1858, 1949, 1950}, + {1602, 1956, 1956}, + {1150, 1962, 1962}, + {1504, 1966, 1967}, + {51, 1971, 1971}, + {1605, 1979, 1979}, + {1458, 1983, 1988}, + {1536, 2001, 2006}, + {1373, 2014, 2018}, + {1494, 2025, 2025}, + {1667, 2029, 2031}, + {1592, 2035, 2035}, + {330, 2045, 2045}, + {1376, 2053, 2053}, + {1991, 2058, 2059}, + {1635, 2065, 2065}, + {1992, 2073, 2074}, + {2014, 2080, 2081}, + {1546, 2085, 2087}, + {59, 2099, 2099}, + {1996, 2106, 2106}, + {1836, 2110, 2110}, + {2068, 2114, 2114}, + {1338, 2122, 2122}, + {1562, 2128, 2130}, + {1934, 2134, 2134}, + {2114, 2141, 2142}, + {977, 2149, 2150}, + {956, 2154, 2155}, + {1407, 2162, 2162}, + {1773, 2166, 2166}, + {883, 2171, 2171}, + {623, 2175, 2178}, + {1520, 2191, 2192}, + {1162, 2200, 2200}, + {912, 2204, 2204}, + {733, 2208, 2208}, + {1777, 2212, 2215}, + {1532, 2219, 2219}, + {718, 2223, 2225}, + {2069, 2229, 2229}, + {2207, 2245, 2246}, + {1139, 2264, 2264}, + {677, 2274, 2274}, + {2099, 2279, 2279}, + {1863, 2283, 
2283}, + {1966, 2305, 2306}, + {2279, 2313, 2313}, + {1628, 2319, 2319}, + {755, 2329, 2329}, + {1461, 2334, 2334}, + {2117, 2340, 2340}, + {2313, 2349, 2349}, + {1859, 2353, 2353}, + {1048, 2362, 2362}, + {895, 2366, 2366}, + {2278, 2373, 2373}, + {1884, 2377, 2377}, + {1402, 2387, 2392}, + {700, 2398, 2398}, + {1971, 2402, 2402}, + {2009, 2419, 2419}, + {1441, 2426, 2428}, + {2208, 2432, 2432}, + {2038, 2436, 2436}, + {932, 2443, 2443}, + {1759, 2447, 2448}, + {744, 2452, 2452}, + {1875, 2458, 2458}, + {2405, 2468, 2468}, + {1596, 2472, 2473}, + {1953, 2480, 2482}, + {736, 2487, 2487}, + {1913, 2493, 2493}, + {774, 2497, 2497}, + {1484, 2506, 2508}, + {2432, 2512, 2512}, + {752, 2519, 2519}, + {2497, 2523, 2523}, + {2409, 2528, 2529}, + {2122, 2533, 2533}, + {2396, 2537, 2538}, + {2410, 2547, 2548}, + {1093, 2555, 2560}, + {551, 2564, 2565}, + {2268, 2569, 2569}, + {1362, 2580, 2580}, + {1916, 2584, 2585}, + {994, 2589, 2590}, + {1979, 2596, 2596}, + {1041, 2602, 2602}, + {2104, 2614, 2616}, + {2609, 2621, 2628}, + {2329, 2638, 2638}, + {2211, 2657, 2658}, + {2638, 2662, 2667}, + {2578, 2676, 2679}, + {2153, 2685, 2686}, + {2608, 2696, 2697}, + {598, 2712, 2712}, + {2620, 2719, 2720}, + {1888, 2724, 2728}, + {2709, 2732, 2732}, + {1365, 2739, 2739}, + {784, 2747, 2748}, + {424, 2753, 2753}, + {2204, 2759, 2759}, + {812, 2768, 2769}, + {2455, 2773, 2773}, + {1722, 2781, 2781}, + {1917, 2792, 2792}, + {2705, 2799, 2799}, + {2685, 2806, 2807}, + {2742, 2811, 2811}, + {1370, 2818, 2818}, + {2641, 2830, 2830}, + {2512, 2837, 2837}, + {2457, 2841, 2841}, + {2756, 2845, 2845}, + {2719, 2855, 2855}, + {1423, 2859, 2859}, + {2849, 2863, 2865}, + {1474, 2871, 2871}, + {1161, 2875, 2876}, + {2282, 2880, 2881}, + {2746, 2888, 2888}, + {1783, 2893, 2893}, + {2401, 2899, 2900}, + {2632, 2920, 2923}, + {2422, 2928, 2930}, + {2715, 2939, 2939}, + {2162, 2943, 2943}, + {2859, 2947, 2947}, + {1910, 2951, 2951}, + {1431, 2955, 2956}, + {1439, 2964, 2964}, + {2501, 2968, 2969}, + 
{2029, 2973, 2976}, + {689, 2983, 2984}, + {1658, 2988, 2988}, + {1031, 2996, 2996}, + {2149, 3001, 3002}, + {25, 3009, 3013}, + {2964, 3023, 3023}, + {953, 3027, 3028}, + {2359, 3036, 3036}, + {3023, 3049, 3049}, + {2880, 3055, 3056}, + {2973, 3076, 3077}, + {2874, 3090, 3090}, + {2871, 3094, 3094}, + {2532, 3100, 3100}, + {2938, 3107, 3108}, + {350, 3115, 3115}, + {2196, 3119, 3121}, + {1133, 3127, 3129}, + {1797, 3134, 3150}, + {3032, 3158, 3158}, + {3016, 3172, 3172}, + {2533, 3179, 3179}, + {3055, 3187, 3188}, + {1384, 3192, 3193}, + {2799, 3199, 3199}, + {2126, 3203, 3207}, + {2334, 3215, 3215}, + {2105, 3220, 3221}, + {3199, 3229, 3229}, + {2891, 3233, 3233}, + {855, 3240, 3240}, + {1852, 3253, 3256}, + {2140, 3263, 3263}, + {1682, 3268, 3270}, + {3243, 3274, 3274}, + {924, 3279, 3279}, + {2212, 3283, 3283}, + {2596, 3287, 3287}, + {2999, 3291, 3291}, + {2353, 3295, 3295}, + {2480, 3302, 3304}, + {1959, 3308, 3311}, + {3000, 3318, 3318}, + {845, 3330, 3330}, + {2283, 3334, 3334}, + {2519, 3342, 3342}, + {3325, 3346, 3348}, + {2397, 3353, 3354}, + {2763, 3358, 3358}, + {3198, 3363, 3364}, + {3211, 3368, 3372}, + {2950, 3376, 3377}, + {3245, 3388, 3391}, + {2264, 3398, 3398}, + {795, 3403, 3403}, + {3287, 3407, 3407}, + {3358, 3411, 3411}, + {3317, 3415, 3415}, + {3232, 3431, 3431}, + {2128, 3435, 3437}, + {3236, 3441, 3441}, + {3398, 3445, 3446}, + {2814, 3450, 3450}, + {3394, 3466, 3466}, + {2425, 3470, 3470}, + {3330, 3476, 3476}, + {1612, 3480, 3480}, + {1004, 3485, 3486}, + {2732, 3490, 3490}, + {1117, 3494, 3495}, + {629, 3501, 3501}, + {3087, 3514, 3514}, + {684, 3518, 3518}, + {3489, 3522, 3524}, + {1760, 3529, 3529}, + {617, 3537, 3537}, + {3431, 3541, 3541}, + {997, 3547, 3547}, + {882, 3552, 3553}, + {2419, 3558, 3558}, + {610, 3562, 3563}, + {1903, 3567, 3569}, + {3005, 3575, 3575}, + {3076, 3585, 3586}, + {3541, 3590, 3590}, + {3490, 3594, 3594}, + {1899, 3599, 3599}, + {3545, 3606, 3606}, + {3290, 3614, 3615}, + {2056, 3619, 3620}, + {3556, 3625, 
3625}, + {3294, 3632, 3633}, + {637, 3643, 3644}, + {3609, 3648, 3650}, + {3175, 3658, 3658}, + {3498, 3665, 3665}, + {1597, 3669, 3669}, + {1983, 3673, 3673}, + {3215, 3682, 3682}, + {3544, 3689, 3689}, + {3694, 3698, 3698}, + {3228, 3715, 3716}, + {2594, 3720, 3722}, + {3573, 3726, 3726}, + {2479, 3732, 3735}, + {3191, 3741, 3742}, + {1113, 3746, 3747}, + {2844, 3751, 3751}, + {3445, 3756, 3757}, + {3755, 3766, 3766}, + {3421, 3775, 3780}, + {3593, 3784, 3786}, + {3263, 3796, 3796}, + {3469, 3806, 3806}, + {2602, 3815, 3815}, + {723, 3819, 3821}, + {1608, 3826, 3826}, + {3334, 3830, 3830}, + {2198, 3835, 3835}, + {2635, 3840, 3840}, + {3702, 3852, 3853}, + {3406, 3858, 3859}, + {3681, 3867, 3870}, + {3407, 3880, 3880}, + {340, 3889, 3889}, + {3772, 3893, 3893}, + {593, 3897, 3897}, + {2563, 3914, 3916}, + {2981, 3929, 3929}, + {1835, 3933, 3934}, + {3906, 3951, 3951}, + {1459, 3958, 3958}, + {3889, 3974, 3974}, + {2188, 3982, 3982}, + {3220, 3986, 3987}, + {3585, 3991, 3993}, + {3712, 3997, 4001}, + {2805, 4007, 4007}, + {1879, 4012, 4013}, + {3618, 4018, 4018}, + {1145, 4031, 4032}, + {3901, 4037, 4037}, + {2772, 4046, 4047}, + {2802, 4053, 4054}, + {3299, 4058, 4058}, + {3725, 4066, 4066}, + {2271, 4070, 4070}, + {385, 4075, 4076}, + {3624, 4089, 4090}, + {3745, 4096, 4098}, + {1563, 4102, 4102}, + {4045, 4106, 4111}, + {3696, 4115, 4119}, + {3376, 4125, 4126}, + {1880, 4130, 4130}, + {2048, 4140, 4141}, + {2724, 4149, 4149}, + {1767, 4156, 4156}, + {2601, 4164, 4164}, + {2757, 4168, 4168}, + {3974, 4172, 4172}, + {3914, 4178, 4178}, + {516, 4185, 4185}, + {1032, 4189, 4190}, + {3462, 4197, 4198}, + {3805, 4202, 4203}, + {3910, 4207, 4212}, + {3075, 4221, 4221}, + {3756, 4225, 4226}, + {1872, 4236, 4237}, + {3844, 4241, 4241}, + {3991, 4245, 4249}, + {2203, 4258, 4258}, + {3903, 4267, 4268}, + {705, 4272, 4272}, + {1896, 4276, 4276}, + {1955, 4285, 4288}, + {3746, 4302, 4303}, + {2672, 4311, 4311}, + {3969, 4317, 4317}, + {3883, 4322, 4322}, + {1920, 4339, 
4340}, + {3527, 4344, 4346}, + {1160, 4358, 4358}, + {3648, 4364, 4366}, + {2711, 4387, 4387}, + {3619, 4391, 4392}, + {1944, 4396, 4396}, + {4369, 4400, 4400}, + {2736, 4404, 4407}, + {2546, 4411, 4412}, + {4390, 4422, 4422}, + {3610, 4426, 4427}, + {4058, 4431, 4431}, + {4374, 4435, 4435}, + {3463, 4445, 4446}, + {1813, 4452, 4452}, + {3669, 4456, 4456}, + {3830, 4460, 4460}, + {421, 4464, 4465}, + {1719, 4471, 4471}, + {3880, 4475, 4475}, + {1834, 4485, 4487}, + {3590, 4491, 4491}, + {442, 4496, 4497}, + {4435, 4501, 4501}, + {3814, 4509, 4509}, + {987, 4513, 4513}, + {4494, 4518, 4521}, + {3218, 4526, 4529}, + {4221, 4537, 4537}, + {2778, 4543, 4545}, + {4422, 4552, 4552}, + {4031, 4558, 4559}, + {4178, 4563, 4563}, + {3726, 4567, 4574}, + {4027, 4578, 4578}, + {4339, 4585, 4587}, + {3796, 4592, 4595}, + {543, 4600, 4613}, + {2855, 4620, 4621}, + {2795, 4627, 4627}, + {3440, 4631, 4632}, + {4279, 4636, 4639}, + {4245, 4643, 4645}, + {4516, 4649, 4650}, + {3133, 4654, 4654}, + {4042, 4658, 4659}, + {3422, 4663, 4663}, + {4046, 4667, 4668}, + {4267, 4672, 4672}, + {4004, 4676, 4677}, + {2490, 4682, 4682}, + {2451, 4697, 4697}, + {3027, 4705, 4705}, + {4028, 4717, 4717}, + {4460, 4721, 4721}, + {2471, 4725, 4727}, + {3090, 4735, 4735}, + {3192, 4739, 4740}, + {3835, 4760, 4760}, + {4540, 4764, 4764}, + {4007, 4772, 4774}, + {619, 4784, 4784}, + {3561, 4789, 4791}, + {3367, 4805, 4805}, + {4490, 4810, 4811}, + {2402, 4815, 4815}, + {3352, 4819, 4822}, + {2773, 4828, 4828}, + {4552, 4832, 4832}, + {2522, 4840, 4841}, + {316, 4847, 4852}, + {4715, 4858, 4858}, + {2959, 4862, 4862}, + {4858, 4868, 4869}, + {2134, 4873, 4873}, + {578, 4878, 4878}, + {4189, 4889, 4890}, + {2229, 4894, 4894}, + {4501, 4898, 4898}, + {2297, 4903, 4903}, + {2933, 4909, 4909}, + {3008, 4913, 4913}, + {3153, 4917, 4917}, + {4819, 4921, 4921}, + {4921, 4932, 4933}, + {4920, 4944, 4945}, + {4814, 4954, 4955}, + {576, 4966, 4966}, + {1854, 4970, 4971}, + {1374, 4975, 4976}, + {3307, 4980, 
4980}, + {974, 4984, 4988}, + {4721, 4992, 4992}, + {4898, 4996, 4996}, + {4475, 5006, 5006}, + {3819, 5012, 5012}, + {1948, 5019, 5021}, + {4954, 5027, 5029}, + {3740, 5038, 5040}, + {4763, 5044, 5045}, + {1936, 5051, 5051}, + {4844, 5055, 5060}, + {4215, 5069, 5072}, + {1146, 5076, 5076}, + {3845, 5082, 5082}, + {4865, 5090, 5090}, + {4624, 5094, 5094}, + {4815, 5098, 5098}, + {5006, 5105, 5105}, + {4980, 5109, 5109}, + {4795, 5113, 5115}, + {5043, 5119, 5121}, + {4782, 5129, 5129}, + {3826, 5139, 5139}, + {3876, 5156, 5156}, + {3111, 5167, 5171}, + {1470, 5177, 5177}, + {4431, 5181, 5181}, + {546, 5189, 5189}, + {4225, 5193, 5193}, + {1672, 5199, 5201}, + {4207, 5205, 5209}, + {4220, 5216, 5217}, + {4658, 5224, 5225}, + {3295, 5235, 5235}, + {2436, 5239, 5239}, + {2349, 5246, 5246}, + {2175, 5250, 5250}, + {5180, 5257, 5258}, + {3161, 5263, 5263}, + {5105, 5272, 5272}, + {3552, 5282, 5282}, + {4944, 5299, 5300}, + {4130, 5312, 5313}, + {902, 5323, 5323}, + {913, 5327, 5327}, + {2987, 5333, 5334}, + {5150, 5344, 5344}, + {5249, 5348, 5348}, + {1965, 5358, 5359}, + {5330, 5364, 5364}, + {2012, 5373, 5377}, + {712, 5384, 5386}, + {5235, 5390, 5390}, + {5044, 5398, 5399}, + {564, 5406, 5406}, + {39, 5410, 5410}, + {4642, 5422, 5425}, + {4421, 5437, 5438}, + {2347, 5449, 5449}, + {5333, 5453, 5454}, + {4136, 5458, 5459}, + {3793, 5468, 5468}, + {2243, 5480, 5480}, + {4889, 5492, 5493}, + {4295, 5504, 5504}, + {2785, 5511, 5511}, + {2377, 5518, 5518}, + {3662, 5525, 5525}, + {5097, 5529, 5530}, + {4781, 5537, 5538}, + {4697, 5547, 5548}, + {436, 5552, 5553}, + {5542, 5558, 5558}, + {3692, 5562, 5562}, + {2696, 5568, 5569}, + {4620, 5578, 5578}, + {2898, 5590, 5590}, + {5557, 5596, 5618}, + {2797, 5623, 5625}, + {2792, 5629, 5629}, + {5243, 5633, 5633}, + {5348, 5637, 5637}, + {5547, 5643, 5643}, + {4296, 5654, 5655}, + {5568, 5662, 5662}, + {3001, 5670, 5671}, + {3794, 5679, 5679}, + {4006, 5685, 5686}, + {4969, 5690, 5692}, + {687, 5704, 5704}, + {4563, 5708, 5708}, 
+ {1723, 5738, 5738}, + {649, 5742, 5742}, + {5163, 5748, 5755}, + {3907, 5759, 5759}, + {3074, 5764, 5764}, + {5326, 5771, 5771}, + {2951, 5776, 5776}, + {5181, 5780, 5780}, + {2614, 5785, 5788}, + {4709, 5794, 5794}, + {2784, 5799, 5799}, + {5518, 5803, 5803}, + {4155, 5812, 5815}, + {921, 5819, 5819}, + {5224, 5823, 5824}, + {2853, 5830, 5836}, + {5776, 5840, 5840}, + {2955, 5844, 5845}, + {5745, 5853, 5853}, + {3291, 5857, 5857}, + {2988, 5861, 5861}, + {2647, 5865, 5865}, + {5398, 5869, 5870}, + {1085, 5874, 5875}, + {4906, 5881, 5881}, + {802, 5886, 5886}, + {5119, 5890, 5893}, + {5802, 5899, 5900}, + {3415, 5904, 5904}, + {5629, 5908, 5908}, + {3714, 5912, 5914}, + {5558, 5921, 5921}, + {2710, 5927, 5928}, + {1094, 5932, 5934}, + {2653, 5940, 5941}, + {4735, 5954, 5954}, + {5861, 5958, 5958}, + {1040, 5971, 5971}, + {5514, 5977, 5977}, + {5048, 5981, 5982}, + {5953, 5992, 5993}, + {3751, 5997, 5997}, + {4991, 6001, 6002}, + {5885, 6006, 6007}, + {5529, 6011, 6012}, + {4974, 6019, 6020}, + {5857, 6024, 6024}, + {3483, 6032, 6032}, + {3594, 6036, 6036}, + {1997, 6040, 6040}, + {5997, 6044, 6047}, + {5197, 6051, 6051}, + {1764, 6055, 6055}, + {6050, 6059, 6059}, + {5239, 6063, 6063}, + {5049, 6067, 6067}, + {5957, 6073, 6074}, + {1022, 6078, 6078}, + {3414, 6083, 6084}, + {3809, 6090, 6090}, + {4562, 6095, 6096}, + {5878, 6104, 6104}, + {594, 6108, 6109}, + {3353, 6115, 6116}, + {4992, 6120, 6121}, + {2424, 6125, 6125}, + {4484, 6130, 6130}, + {3900, 6134, 6135}, + {5793, 6139, 6141}, + {3562, 6145, 6145}, + {1438, 6152, 6153}, + {6058, 6157, 6158}, + {4411, 6162, 6163}, + {4590, 6167, 6171}, + {4748, 6175, 6175}, + {5517, 6183, 6184}, + {6095, 6191, 6192}, + {1471, 6203, 6203}, + {2643, 6209, 6210}, + {450, 6220, 6220}, + {5266, 6226, 6226}, + {2576, 6233, 6233}, + {2607, 6239, 6240}, + {5164, 6244, 6251}, + {6054, 6255, 6255}, + {1789, 6260, 6261}, + {5250, 6265, 6265}, + {6062, 6273, 6278}, + {5990, 6282, 6282}, + {3283, 6286, 6286}, + {5436, 6290, 6290}, + 
{6059, 6294, 6294}, + {5668, 6298, 6300}, + {3072, 6324, 6329}, + {3132, 6338, 6339}, + {3246, 6343, 6344}, + {28, 6348, 6349}, + {1503, 6353, 6355}, + {6067, 6359, 6359}, + {3384, 6364, 6364}, + {545, 6375, 6376}, + {5803, 6380, 6380}, + {5522, 6384, 6385}, + {5908, 6389, 6389}, + {2796, 6393, 6396}, + {4831, 6403, 6404}, + {6388, 6412, 6412}, + {6005, 6417, 6420}, + {4450, 6430, 6430}, + {4050, 6435, 6435}, + {5372, 6441, 6441}, + {4378, 6447, 6447}, + {6199, 6452, 6452}, + {3026, 6456, 6456}, + {2642, 6460, 6462}, + {6392, 6470, 6470}, + {6459, 6474, 6474}, + {2829, 6487, 6488}, + {2942, 6499, 6504}, + {5069, 6508, 6511}, + {5341, 6515, 6516}, + {5853, 6521, 6525}, + {6104, 6531, 6531}, + {5759, 6535, 6538}, + {4672, 6542, 6543}, + {2443, 6550, 6550}, + {5109, 6554, 6554}, + {6494, 6558, 6560}, + {6006, 6570, 6572}, + {6424, 6576, 6580}, + {4693, 6591, 6592}, + {6439, 6596, 6597}, + {3179, 6601, 6601}, + {5299, 6606, 6607}, + {4148, 6612, 6613}, + {3774, 6617, 6617}, + {3537, 6623, 6624}, + {4975, 6628, 6629}, + {3848, 6636, 6636}, + {856, 6640, 6640}, + {5724, 6645, 6645}, + {6632, 6651, 6651}, + {4630, 6656, 6658}, + {1440, 6662, 6662}, + {4281, 6666, 6667}, + {4302, 6671, 6672}, + {2589, 6676, 6677}, + {5647, 6681, 6687}, + {6082, 6691, 6693}, + {6144, 6698, 6698}, + {6103, 6709, 6710}, + {3710, 6714, 6714}, + {4253, 6718, 6721}, + {2467, 6730, 6730}, + {4778, 6734, 6734}, + {6528, 6738, 6738}, + {4358, 6747, 6747}, + {5889, 6753, 6753}, + {5193, 6757, 6757}, + {5797, 6761, 6761}, + {3858, 6765, 6766}, + {5951, 6776, 6776}, + {6487, 6781, 6782}, + {3282, 6786, 6787}, + {4667, 6797, 6799}, + {1927, 6803, 6806}, + {6583, 6810, 6810}, + {4937, 6814, 6814}, + {6099, 6824, 6824}, + {4415, 6835, 6836}, + {6332, 6840, 6841}, + {5160, 6850, 6850}, + {4764, 6854, 6854}, + {6814, 6858, 6859}, + {3018, 6864, 6864}, + {6293, 6868, 6869}, + {6359, 6877, 6877}, + {3047, 6884, 6886}, + {5262, 6890, 6891}, + {5471, 6900, 6900}, + {3268, 6910, 6912}, + {1047, 6916, 6916}, + 
{5904, 6923, 6923}, + {5798, 6933, 6938}, + {4149, 6942, 6942}, + {1821, 6946, 6946}, + {3599, 6952, 6952}, + {6470, 6957, 6957}, + {5562, 6961, 6961}, + {6268, 6965, 6967}, + {6389, 6971, 6971}, + {6596, 6975, 6976}, + {6553, 6980, 6981}, + {6576, 6985, 6989}, + {1375, 6993, 6993}, + {652, 6998, 6998}, + {4876, 7002, 7003}, + {5768, 7011, 7013}, + {3973, 7017, 7017}, + {6802, 7025, 7025}, + {6955, 7034, 7036}, + {6974, 7040, 7040}, + {5944, 7044, 7044}, + {6992, 7048, 7054}, + {6872, 7059, 7059}, + {2943, 7063, 7063}, + {6923, 7067, 7067}, + {5094, 7071, 7071}, + {4873, 7075, 7075}, + {5819, 7079, 7079}, + {5945, 7085, 7085}, + {1540, 7090, 7091}, + {2090, 7095, 7095}, + {5024, 7104, 7105}, + {6900, 7109, 7109}, + {6024, 7113, 7114}, + {6000, 7118, 7120}, + {2187, 7124, 7125}, + {6760, 7129, 7130}, + {5898, 7134, 7136}, + {7032, 7144, 7144}, + {4271, 7148, 7148}, + {3706, 7152, 7152}, + {6970, 7156, 7157}, + {7088, 7161, 7163}, + {2718, 7168, 7169}, + {5674, 7175, 7175}, + {4631, 7182, 7182}, + {7070, 7188, 7189}, + {6220, 7196, 7196}, + {3458, 7201, 7202}, + {2041, 7211, 7212}, + {1454, 7216, 7216}, + {5199, 7225, 7227}, + {3529, 7234, 7234}, + {6890, 7238, 7238}, + {3815, 7242, 7243}, + {5490, 7250, 7253}, + {6554, 7257, 7263}, + {5890, 7267, 7269}, + {6877, 7273, 7273}, + {4877, 7277, 7277}, + {2502, 7285, 7285}, + {1483, 7289, 7295}, + {7210, 7304, 7308}, + {6845, 7313, 7316}, + {7219, 7320, 7320}, + {7001, 7325, 7329}, + {6853, 7333, 7334}, + {6120, 7338, 7338}, + {6606, 7342, 7343}, + {7020, 7348, 7350}, + {3509, 7354, 7354}, + {7133, 7359, 7363}, + {3434, 7371, 7374}, + {2787, 7384, 7384}, + {7044, 7388, 7388}, + {6960, 7394, 7395}, + {6676, 7399, 7400}, + {7161, 7404, 7404}, + {7285, 7417, 7418}, + {4558, 7425, 7426}, + {4828, 7430, 7430}, + {6063, 7436, 7436}, + {3597, 7442, 7442}, + {914, 7446, 7446}, + {7320, 7452, 7454}, + {7267, 7458, 7460}, + {5076, 7464, 7464}, + {7430, 7468, 7469}, + {6273, 7473, 7474}, + {7440, 7478, 7487}, + {7348, 7491, 7494}, + 
{1021, 7510, 7510}, + {7473, 7515, 7515}, + {2823, 7519, 7519}, + {6264, 7527, 7527}, + {7302, 7531, 7531}, + {7089, 7535, 7535}, + {7342, 7540, 7541}, + {3688, 7547, 7551}, + {3054, 7558, 7560}, + {4177, 7566, 7567}, + {6691, 7574, 7575}, + {7156, 7585, 7586}, + {7147, 7590, 7592}, + {7407, 7598, 7598}, + {7403, 7602, 7603}, + {6868, 7607, 7607}, + {6636, 7611, 7611}, + {4805, 7617, 7617}, + {5779, 7623, 7623}, + {7063, 7627, 7627}, + {5079, 7632, 7632}, + {7377, 7637, 7637}, + {7337, 7641, 7642}, + {6738, 7655, 7655}, + {7338, 7659, 7659}, + {6541, 7669, 7671}, + {595, 7675, 7675}, + {7658, 7679, 7680}, + {7647, 7685, 7686}, + {2477, 7690, 7690}, + {5823, 7694, 7694}, + {4156, 7699, 7699}, + {5931, 7703, 7706}, + {6854, 7712, 7712}, + {4931, 7718, 7718}, + {6979, 7722, 7722}, + {5085, 7727, 7727}, + {6965, 7732, 7732}, + {7201, 7736, 7737}, + {3639, 7741, 7743}, + {7534, 7749, 7749}, + {4292, 7753, 7753}, + {3427, 7759, 7763}, + {7273, 7767, 7767}, + {940, 7778, 7778}, + {4838, 7782, 7785}, + {4216, 7790, 7792}, + {922, 7800, 7801}, + {7256, 7810, 7811}, + {7789, 7815, 7819}, + {7225, 7823, 7825}, + {7531, 7829, 7829}, + {6997, 7833, 7833}, + {7757, 7837, 7838}, + {4129, 7842, 7842}, + {7333, 7848, 7849}, + {6776, 7855, 7855}, + {7527, 7859, 7859}, + {4370, 7863, 7863}, + {4512, 7868, 7868}, + {5679, 7880, 7880}, + {3162, 7884, 7885}, + {3933, 7892, 7894}, + {7804, 7899, 7902}, + {6363, 7906, 7907}, + {7848, 7911, 7912}, + {5584, 7917, 7921}, + {874, 7926, 7926}, + {3342, 7930, 7930}, + {4507, 7935, 7937}, + {3672, 7943, 7944}, + {7911, 7948, 7949}, + {6402, 7956, 7956}, + {7940, 7960, 7960}, + {7113, 7964, 7964}, + {1073, 7968, 7968}, + {7740, 7974, 7974}, + {7601, 7978, 7982}, + {6797, 7987, 7988}, + {3528, 7994, 7995}, + {5483, 7999, 7999}, + {5717, 8011, 8011}, + {5480, 8017, 8017}, + {7770, 8023, 8030}, + {2452, 8034, 8034}, + {5282, 8047, 8047}, + {7967, 8051, 8051}, + {1128, 8058, 8066}, + {6348, 8070, 8070}, + {8055, 8077, 8077}, + {7925, 8081, 8086}, + 
{6810, 8090, 8090}, + {5051, 8101, 8101}, + {4696, 8109, 8110}, + {5129, 8119, 8119}, + {4449, 8123, 8123}, + {7222, 8127, 8127}, + {4649, 8131, 8134}, + {7994, 8138, 8138}, + {5954, 8148, 8148}, + {475, 8152, 8153}, + {7906, 8157, 8157}, + {7458, 8164, 8166}, + {7632, 8171, 8173}, + {3874, 8177, 8183}, + {4391, 8187, 8187}, + {561, 8191, 8191}, + {2417, 8195, 8195}, + {2357, 8204, 8204}, + {2269, 8216, 8218}, + {3968, 8222, 8222}, + {2200, 8226, 8227}, + {3453, 8247, 8247}, + {2439, 8251, 8252}, + {7175, 8257, 8257}, + {976, 8262, 8264}, + {4953, 8273, 8273}, + {4219, 8278, 8278}, + {6, 8285, 8291}, + {5703, 8295, 8296}, + {5272, 8300, 8300}, + {8037, 8304, 8304}, + {8186, 8314, 8314}, + {8304, 8318, 8318}, + {8051, 8326, 8326}, + {8318, 8330, 8330}, + {2671, 8334, 8335}, + {2662, 8339, 8339}, + {8081, 8349, 8350}, + {3328, 8356, 8356}, + {2879, 8360, 8362}, + {8050, 8370, 8371}, + {8330, 8375, 8376}, + {8375, 8386, 8386}, + {4961, 8390, 8390}, + {1017, 8403, 8405}, + {3533, 8416, 8416}, + {4555, 8422, 8422}, + {6445, 8426, 8426}, + {8169, 8432, 8432}, + {990, 8436, 8436}, + {4102, 8440, 8440}, + {7398, 8444, 8446}, + {3480, 8450, 8450}, + {6324, 8462, 8462}, + {7948, 8466, 8467}, + {5950, 8471, 8471}, + {5189, 8476, 8476}, + {4026, 8490, 8490}, + {8374, 8494, 8495}, + {4682, 8501, 8501}, + {7387, 8506, 8506}, + {8164, 8510, 8515}, + {4079, 8524, 8524}, + {8360, 8529, 8531}, + {7446, 8540, 8543}, + {7971, 8547, 8548}, + {4311, 8552, 8552}, + {5204, 8556, 8557}, + {7968, 8562, 8562}, + {7847, 8571, 8573}, + {8547, 8577, 8577}, + {5320, 8581, 8581}, + {8556, 8585, 8586}, + {8504, 8590, 8590}, + {7669, 8602, 8604}, + {5874, 8608, 8609}, + {5828, 8613, 8613}, + {7998, 8617, 8617}, + {8519, 8625, 8625}, + {7250, 8637, 8637}, + {426, 8641, 8641}, + {8436, 8645, 8645}, + {5986, 8649, 8656}, + {8157, 8660, 8660}, + {7182, 8665, 8665}, + {8421, 8675, 8675}, + {8509, 8681, 8681}, + {5137, 8688, 8689}, + {8625, 8694, 8695}, + {5228, 8701, 8702}, + {6661, 8714, 8714}, + 
{1010, 8719, 8719}, + {6648, 8723, 8723}, + {3500, 8728, 8728}, + {2442, 8735, 8735}, + {8494, 8740, 8741}, + {8171, 8753, 8755}, + {7242, 8763, 8764}, + {4739, 8768, 8769}, + {7079, 8773, 8773}, + {8386, 8777, 8777}, + {8624, 8781, 8787}, + {661, 8791, 8794}, + {8631, 8801, 8801}, + {7753, 8805, 8805}, + {4783, 8809, 8810}, + {1673, 8814, 8815}, + {6623, 8819, 8819}, + {4404, 8823, 8823}, + {8089, 8827, 8828}, + {8773, 8832, 8832}, + {5394, 8836, 8836}, + {6231, 8841, 8843}, + {1015, 8852, 8853}, + {6873, 8857, 8857}, + {6289, 8865, 8865}, + {8577, 8869, 8869}, + {8114, 8873, 8875}, + {8534, 8883, 8883}, + {3007, 8887, 8888}, + {8827, 8892, 8893}, + {4788, 8897, 8900}, + {5698, 8906, 8907}, + {7690, 8911, 8911}, + {6643, 8919, 8919}, + {7206, 8923, 8924}, + {7866, 8929, 8931}, + {8880, 8942, 8942}, + {8630, 8951, 8952}, + {6027, 8958, 8958}, + {7749, 8966, 8967}, + {4932, 8972, 8973}, + {8892, 8980, 8981}, + {634, 9003, 9003}, + {8109, 9007, 9008}, + {8777, 9012, 9012}, + {3981, 9016, 9017}, + {5723, 9025, 9025}, + {7662, 9034, 9038}, + {8955, 9042, 9042}, + {8070, 9060, 9062}, + {8910, 9066, 9066}, + {5363, 9070, 9071}, + {7699, 9075, 9076}, + {8991, 9081, 9081}, + {6850, 9085, 9085}, + {5811, 9092, 9094}, + {9079, 9098, 9102}, + {6456, 9106, 9106}, + {2259, 9111, 9111}, + {4752, 9116, 9116}, + {9060, 9120, 9123}, + {8090, 9127, 9127}, + {5305, 9131, 9132}, + {8623, 9137, 9137}, + {7417, 9141, 9141}, + {6564, 9148, 9149}, + {9126, 9157, 9158}, + {4285, 9169, 9170}, + {8698, 9174, 9174}, + {8869, 9178, 9178}, + {2572, 9182, 9183}, + {6482, 9188, 9190}, + {9181, 9201, 9201}, + {2968, 9208, 9209}, + {2506, 9213, 9215}, + {9127, 9219, 9219}, + {7910, 9225, 9227}, + {5422, 9235, 9239}, + {8813, 9244, 9246}, + {9178, 9250, 9250}, + {8748, 9255, 9255}, + {7354, 9265, 9265}, + {7767, 9269, 9269}, + {7710, 9281, 9283}, + {8826, 9288, 9290}, + {861, 9295, 9295}, + {4482, 9301, 9301}, + {9264, 9305, 9306}, + {8805, 9310, 9310}, + {4995, 9314, 9314}, + {6730, 9318, 9318}, + 
{7457, 9328, 9328}, + {2547, 9335, 9336}, + {6298, 9340, 9343}, + {9305, 9353, 9354}, + {9269, 9358, 9358}, + {6338, 9370, 9370}, + {7289, 9376, 9379}, + {5780, 9383, 9383}, + {7607, 9387, 9387}, + {2065, 9392, 9392}, + {7238, 9396, 9396}, + {8856, 9400, 9400}, + {8069, 9412, 9413}, + {611, 9420, 9420}, + {7071, 9424, 9424}, + {3089, 9430, 9431}, + {7117, 9435, 9438}, + {1976, 9445, 9445}, + {6640, 9449, 9449}, + {5488, 9453, 9453}, + {8739, 9457, 9459}, + {5958, 9466, 9466}, + {7985, 9470, 9470}, + {8735, 9475, 9475}, + {5009, 9479, 9479}, + {8073, 9483, 9484}, + {2328, 9490, 9491}, + {9250, 9495, 9495}, + {4043, 9502, 9502}, + {7712, 9506, 9506}, + {9012, 9510, 9510}, + {9028, 9514, 9515}, + {2190, 9521, 9524}, + {9029, 9528, 9528}, + {9519, 9532, 9532}, + {9495, 9536, 9536}, + {8527, 9540, 9540}, + {2137, 9550, 9550}, + {8419, 9557, 9557}, + {9383, 9561, 9562}, + {8970, 9575, 9578}, + {8911, 9582, 9582}, + {7828, 9595, 9596}, + {6180, 9600, 9600}, + {8738, 9604, 9607}, + {7540, 9611, 9612}, + {9599, 9616, 9618}, + {9187, 9623, 9623}, + {9294, 9628, 9629}, + {4536, 9639, 9639}, + {3867, 9643, 9643}, + {6305, 9648, 9648}, + {1617, 9654, 9657}, + {5762, 9666, 9666}, + {8314, 9670, 9670}, + {9666, 9674, 9675}, + {9506, 9679, 9679}, + {9669, 9685, 9686}, + {9683, 9690, 9690}, + {8763, 9697, 9698}, + {7468, 9702, 9702}, + {460, 9707, 9707}, + {3115, 9712, 9712}, + {9424, 9716, 9717}, + {7359, 9721, 9724}, + {7547, 9728, 9729}, + {7151, 9733, 9738}, + {7627, 9742, 9742}, + {2822, 9747, 9747}, + {8247, 9751, 9753}, + {9550, 9758, 9758}, + {7585, 9762, 9763}, + {1002, 9767, 9767}, + {7168, 9772, 9773}, + {6941, 9777, 9780}, + {9728, 9784, 9786}, + {9770, 9792, 9796}, + {6411, 9801, 9802}, + {3689, 9806, 9808}, + {9575, 9814, 9816}, + {7025, 9820, 9821}, + {2776, 9826, 9826}, + {9806, 9830, 9830}, + {9820, 9834, 9835}, + {9800, 9839, 9847}, + {9834, 9851, 9852}, + {9829, 9856, 9862}, + {1400, 9866, 9866}, + {3197, 9870, 9871}, + {9851, 9875, 9876}, + {9742, 9883, 9884}, + 
{3362, 9888, 9889}, + {9883, 9893, 9893}, + {5711, 9899, 9910}, + {7806, 9915, 9915}, + {9120, 9919, 9919}, + {9715, 9925, 9934}, + {2580, 9938, 9938}, + {4907, 9942, 9944}, + {6239, 9953, 9954}, + {6961, 9963, 9963}, + {5295, 9967, 9968}, + {1915, 9972, 9973}, + {3426, 9983, 9985}, + {9875, 9994, 9995}, + {6942, 9999, 9999}, + {6621, 10005, 10005}, + {7589, 10010, 10012}, + {9286, 10020, 10020}, + {838, 10024, 10024}, + {9980, 10028, 10031}, + {9994, 10035, 10041}, + {2702, 10048, 10051}, + {2621, 10059, 10059}, + {10054, 10065, 10065}, + {8612, 10073, 10074}, + {7033, 10078, 10078}, + {916, 10082, 10082}, + {10035, 10086, 10087}, + {8613, 10097, 10097}, + {9919, 10107, 10108}, + {6133, 10114, 10115}, + {10059, 10119, 10119}, + {10065, 10126, 10127}, + {7732, 10131, 10131}, + {7155, 10135, 10136}, + {6728, 10140, 10140}, + {6162, 10144, 10145}, + {4724, 10150, 10150}, + {1665, 10154, 10154}, + {10126, 10163, 10163}, + {9783, 10168, 10168}, + {1715, 10172, 10173}, + {7152, 10177, 10182}, + {8760, 10187, 10187}, + {7829, 10191, 10191}, + {9679, 10196, 10196}, + {9369, 10201, 10201}, + {2928, 10206, 10208}, + {6951, 10214, 10217}, + {5633, 10221, 10221}, + {7199, 10225, 10225}, + {10118, 10230, 10231}, + {9999, 10235, 10236}, + {10045, 10240, 10249}, + {5565, 10256, 10256}, + {9866, 10261, 10261}, + {10163, 10268, 10268}, + {9869, 10272, 10272}, + {9789, 10276, 10283}, + {10235, 10287, 10288}, + {10214, 10298, 10299}, + {6971, 10303, 10303}, + {3346, 10307, 10307}, + {10185, 10311, 10312}, + {9993, 10318, 10320}, + {2779, 10332, 10334}, + {1726, 10338, 10338}, + {741, 10354, 10360}, + {10230, 10372, 10373}, + {10260, 10384, 10385}, + {10131, 10389, 10398}, + {6946, 10406, 10409}, + {10158, 10413, 10420}, + {10123, 10424, 10424}, + {6157, 10428, 10429}, + {4518, 10434, 10434}, + {9893, 10438, 10438}, + {9865, 10442, 10446}, + {7558, 10454, 10454}, + {10434, 10460, 10460}, + {10064, 10466, 10468}, + {2703, 10472, 10474}, + {9751, 10478, 10479}, + {6714, 10485, 10485}, 
+ {8020, 10490, 10490}, + {10303, 10494, 10494}, + {3521, 10499, 10500}, + {9281, 10513, 10515}, + {6028, 10519, 10523}, + {9387, 10527, 10527}, + {7614, 10531, 10531}, + {3611, 10536, 10536}, + {9162, 10540, 10540}, + {10081, 10546, 10547}, + {10034, 10560, 10562}, + {6726, 10567, 10571}, + {8237, 10575, 10575}, + {10438, 10579, 10583}, + {10140, 10587, 10587}, + {5784, 10592, 10592}, + {9819, 10597, 10600}, + {10567, 10604, 10608}, + {9335, 10613, 10613}, + {8300, 10617, 10617}, + {10575, 10621, 10621}, + {9678, 10625, 10626}, + {9962, 10632, 10633}, + {10535, 10637, 10638}, + {8199, 10642, 10642}, + {10372, 10647, 10648}, + {10637, 10656, 10657}, + {10579, 10667, 10668}, + {10465, 10677, 10680}, + {6702, 10684, 10685}, + {10073, 10691, 10692}, + {4505, 10696, 10697}, + {9042, 10701, 10701}, + {6460, 10705, 10706}, + {10010, 10714, 10716}, + {10656, 10720, 10722}, + {7282, 10727, 10729}, + {2327, 10733, 10733}, + {2491, 10740, 10741}, + {10704, 10748, 10750}, + {6465, 10754, 10754}, + {10647, 10758, 10759}, + {10424, 10763, 10763}, + {10748, 10776, 10776}, + {10546, 10780, 10781}, + {10758, 10785, 10786}, + {10287, 10790, 10797}, + {10785, 10801, 10807}, + {10240, 10811, 10826}, + {9509, 10830, 10830}, + {2579, 10836, 10838}, + {9801, 10843, 10845}, + {7555, 10849, 10850}, + {10776, 10860, 10865}, + {8023, 10869, 10869}, + {10046, 10876, 10884}, + {10253, 10888, 10892}, + {9941, 10897, 10897}, + {7898, 10901, 10905}, + {6725, 10909, 10913}, + {10757, 10921, 10923}, + {10160, 10931, 10931}, + {10916, 10935, 10942}, + {10261, 10946, 10946}, + {10318, 10952, 10954}, + {5911, 10959, 10961}, + {10801, 10965, 10966}, + {10946, 10970, 10977}, + {10592, 10982, 10984}, + {9913, 10988, 10990}, + {8510, 10994, 10996}, + {9419, 11000, 11001}, + {6765, 11006, 11007}, + {10725, 11011, 11011}, + {5537, 11017, 11019}, + {9208, 11024, 11025}, + {5850, 11030, 11030}, + {9610, 11034, 11036}, + {8846, 11041, 11047}, + {9697, 11051, 11051}, + {1622, 11055, 11058}, + {2370, 11062, 
11062}, + {8393, 11067, 11067}, + {9756, 11071, 11071}, + {10172, 11076, 11076}, + {27, 11081, 11081}, + {7357, 11087, 11092}, + {8151, 11104, 11106}, + {6115, 11110, 11110}, + {10667, 11114, 11115}, + {11099, 11121, 11123}, + {10705, 11127, 11127}, + {8938, 11131, 11131}, + {11114, 11135, 11136}, + {1390, 11140, 11141}, + {10964, 11146, 11148}, + {11140, 11152, 11155}, + {9813, 11159, 11166}, + {624, 11171, 11172}, + {3118, 11177, 11179}, + {11029, 11184, 11186}, + {10186, 11190, 11190}, + {10306, 11196, 11196}, + {8665, 11201, 11201}, + {7382, 11205, 11205}, + {1100, 11210, 11210}, + {2337, 11216, 11217}, + {1609, 11221, 11223}, + {5763, 11228, 11229}, + {5220, 11233, 11233}, + {11061, 11241, 11241}, + {10617, 11246, 11246}, + {11190, 11250, 11251}, + {10144, 11255, 11256}, + {11232, 11260, 11260}, + {857, 11264, 11265}, + {10994, 11269, 11271}, + {3879, 11280, 11281}, + {11184, 11287, 11289}, + {9611, 11293, 11295}, + {11250, 11299, 11299}, + {4495, 11304, 11304}, + {7574, 11308, 11309}, + {9814, 11315, 11317}, + {1713, 11321, 11324}, + {1905, 11328, 11328}, + {8745, 11335, 11340}, + {8883, 11351, 11351}, + {8119, 11358, 11358}, + {1842, 11363, 11364}, + {11237, 11368, 11368}, + {8814, 11373, 11374}, + {5684, 11378, 11378}, + {11011, 11382, 11382}, + {6520, 11389, 11389}, + {11183, 11393, 11396}, + {1790, 11404, 11404}, + {9536, 11408, 11408}, + {11298, 11418, 11419}, + {3929, 11425, 11425}, + {5588, 11429, 11429}, + {8476, 11436, 11436}, + {4096, 11440, 11442}, + {11084, 11446, 11454}, + {10603, 11458, 11463}, + {7332, 11472, 11474}, + {7611, 11483, 11486}, + {4836, 11490, 11491}, + {10024, 11495, 11495}, + {4917, 11501, 11506}, + {6486, 11510, 11512}, + {11269, 11516, 11518}, + {3603, 11522, 11525}, + {11126, 11535, 11535}, + {11418, 11539, 11541}, + {11408, 11545, 11545}, + {9021, 11549, 11552}, + {6745, 11557, 11557}, + {5118, 11561, 11564}, + {7590, 11568, 11569}, + {4426, 11573, 11578}, + {9790, 11582, 11583}, + {6447, 11587, 11587}, + {10229, 11591, 
11594}, + {10457, 11598, 11598}, + {10168, 11604, 11604}, + {10543, 11608, 11608}, + {7404, 11612, 11612}, + {11127, 11616, 11616}, + {3337, 11620, 11620}, + {11501, 11624, 11628}, + {4543, 11633, 11635}, + {8449, 11642, 11642}, + {4943, 11646, 11648}, + {10526, 11652, 11654}, + {11620, 11659, 11659}, + {8927, 11664, 11669}, + {532, 11673, 11673}, + {10513, 11677, 11679}, + {10428, 11683, 11683}, + {10999, 11689, 11690}, + {9469, 11695, 11695}, + {3606, 11699, 11699}, + {9560, 11708, 11709}, + {1564, 11714, 11714}, + {10527, 11718, 11718}, + {3071, 11723, 11726}, + {11590, 11731, 11732}, + {6605, 11737, 11737}, + {11624, 11741, 11745}, + {7822, 11749, 11752}, + {5269, 11757, 11758}, + {1339, 11767, 11767}, + {1363, 11771, 11773}, + {3704, 11777, 11777}, + {10952, 11781, 11783}, + {6764, 11793, 11795}, + {8675, 11800, 11800}, + {9963, 11804, 11804}, + {11573, 11808, 11809}, + {9548, 11813, 11813}, + {11591, 11817, 11818}, + {11446, 11822, 11822}, + {9224, 11828, 11828}, + {3158, 11836, 11836}, + {10830, 11840, 11840}, + {7234, 11846, 11846}, + {11299, 11850, 11850}, + {11544, 11854, 11855}, + {11498, 11859, 11859}, + {10993, 11865, 11868}, + {9720, 11872, 11878}, + {10489, 11882, 11890}, + {11712, 11898, 11904}, + {11516, 11908, 11910}, + {11568, 11914, 11915}, + {10177, 11919, 11924}, + {11363, 11928, 11929}, + {10494, 11933, 11933}, + {9870, 11937, 11938}, + {9427, 11942, 11942}, + {11481, 11949, 11949}, + {6030, 11955, 11957}, + {11718, 11961, 11961}, + {10531, 11965, 11983}, + {5126, 11987, 11987}, + {7515, 11991, 11991}, + {10646, 11996, 11997}, + {2947, 12001, 12001}, + {9582, 12009, 12010}, + {6202, 12017, 12018}, + {11714, 12022, 12022}, + {9235, 12033, 12037}, + {9721, 12041, 12044}, + {11932, 12051, 12052}, + {12040, 12056, 12056}, + {12051, 12060, 12060}, + {11601, 12066, 12066}, + {8426, 12070, 12070}, + {4053, 12077, 12077}, + {4262, 12081, 12081}, + {9761, 12086, 12088}, + {11582, 12092, 12093}, + {10965, 12097, 12098}, + {11803, 12103, 12104}, + 
{11933, 12108, 12109}, + {10688, 12117, 12117}, + {12107, 12125, 12126}, + {6774, 12130, 12132}, + {6286, 12137, 12137}, + {9543, 12141, 12141}, + {12097, 12145, 12146}, + {10790, 12150, 12150}, + {10125, 12154, 12156}, + {12125, 12164, 12164}, + {12064, 12168, 12172}, + {10811, 12178, 12188}, + {12092, 12192, 12193}, + {10058, 12197, 12198}, + {11611, 12211, 12212}, + {3459, 12216, 12216}, + {10291, 12225, 12228}, + {12191, 12232, 12234}, + {12145, 12238, 12238}, + {12001, 12242, 12250}, + {3840, 12255, 12255}, + {12216, 12259, 12259}, + {674, 12272, 12272}, + {12141, 12276, 12276}, + {10766, 12280, 12280}, + {11545, 12284, 12284}, + {6496, 12290, 12290}, + {11381, 12294, 12295}, + {603, 12302, 12303}, + {12276, 12308, 12308}, + {11850, 12313, 12314}, + {565, 12319, 12319}, + {9351, 12324, 12324}, + {11822, 12328, 12328}, + {2691, 12333, 12334}, + {11840, 12338, 12338}, + {11070, 12343, 12343}, + {9510, 12347, 12347}, + {11024, 12352, 12353}, + {7173, 12359, 12359}, + {517, 12363, 12363}, + {6311, 12367, 12368}, + {11367, 12372, 12373}, + {12008, 12377, 12377}, + {11372, 12382, 12384}, + {11358, 12391, 12392}, + {11382, 12396, 12396}, + {6882, 12400, 12401}, + {11246, 12405, 12405}, + {8359, 12409, 12412}, + {10154, 12418, 12418}, + {12016, 12425, 12426}, + {8972, 12434, 12435}, + {10478, 12439, 12440}, + {12395, 12449, 12449}, + {11612, 12454, 12454}, + {12347, 12458, 12458}, + {10700, 12466, 12467}, + {3637, 12471, 12476}, + {1042, 12480, 12481}, + {6747, 12488, 12488}, + {12396, 12492, 12493}, + {9420, 12497, 12497}, + {11285, 12501, 12510}, + {4470, 12515, 12515}, + {9374, 12519, 12519}, + {11293, 12528, 12528}, + {2058, 12534, 12535}, + {6521, 12539, 12539}, + {12492, 12543, 12543}, + {3043, 12547, 12547}, + {2982, 12551, 12553}, + {11030, 12557, 12563}, + {7636, 12568, 12568}, + {9639, 12572, 12572}, + {12543, 12576, 12576}, + {5989, 12580, 12583}, + {11051, 12587, 12587}, + {1061, 12592, 12594}, + {12313, 12599, 12601}, + {11846, 12605, 12605}, + {12576, 
12609, 12609}, + {11040, 12618, 12625}, + {12479, 12629, 12629}, + {6903, 12633, 12633}, + {12322, 12639, 12639}, + {12253, 12643, 12645}, + {5594, 12651, 12651}, + {12522, 12655, 12655}, + {11703, 12659, 12659}, + {1377, 12665, 12665}, + {8022, 12669, 12669}, + {12280, 12674, 12674}, + {9023, 12680, 12681}, + {12328, 12685, 12685}, + {3085, 12689, 12693}, + {4700, 12698, 12698}, + {10224, 12702, 12702}, + {8781, 12706, 12706}, + {1651, 12710, 12710}, + {12458, 12714, 12714}, + {12005, 12718, 12721}, + {11908, 12725, 12726}, + {8202, 12733, 12733}, + {11708, 12739, 12740}, + {12599, 12744, 12745}, + {12284, 12749, 12749}, + {5285, 12756, 12756}, + {12055, 12775, 12777}, + {6919, 12782, 12782}, + {12242, 12786, 12786}, + {12009, 12790, 12790}, + {9628, 12794, 12796}, + {11354, 12801, 12802}, + {10225, 12806, 12807}, + {579, 12813, 12813}, + {8935, 12817, 12822}, + {8753, 12827, 12829}, + {11006, 12835, 12835}, + {858, 12841, 12845}, + {476, 12849, 12849}, + {7667, 12854, 12854}, + {12760, 12860, 12871}, + {11677, 12875, 12877}, + {12714, 12881, 12881}, + {12731, 12885, 12890}, + {7108, 12894, 12896}, + {1165, 12900, 12900}, + {4021, 12906, 12906}, + {10829, 12910, 12911}, + {12331, 12915, 12915}, + {8887, 12919, 12921}, + {11639, 12925, 12925}, + {7964, 12929, 12929}, + {12528, 12937, 12937}, + {8148, 12941, 12941}, + {12770, 12948, 12950}, + {12609, 12954, 12954}, + {12685, 12958, 12958}, + {2803, 12962, 12962}, + {9561, 12966, 12966}, + {6671, 12972, 12973}, + {12056, 12977, 12977}, + {6380, 12981, 12981}, + {12048, 12985, 12985}, + {11961, 12989, 12993}, + {3368, 12997, 12999}, + {6634, 13004, 13004}, + {6775, 13009, 13010}, + {12136, 13014, 13019}, + {10341, 13023, 13023}, + {13002, 13027, 13027}, + {10587, 13031, 13031}, + {10307, 13035, 13035}, + {12736, 13039, 13039}, + {12744, 13043, 13044}, + {6175, 13048, 13048}, + {9702, 13053, 13054}, + {662, 13059, 13061}, + {12718, 13065, 13068}, + {12893, 13072, 13075}, + {8299, 13086, 13091}, + {12604, 13095, 13096}, 
+ {12848, 13100, 13101}, + {12749, 13105, 13105}, + {12526, 13109, 13114}, + {9173, 13122, 13122}, + {12769, 13128, 13128}, + {13038, 13132, 13132}, + {12725, 13136, 13137}, + {12639, 13146, 13146}, + {9711, 13150, 13151}, + {12137, 13155, 13155}, + {13039, 13159, 13159}, + {4681, 13163, 13164}, + {12954, 13168, 13168}, + {13158, 13175, 13176}, + {13105, 13180, 13180}, + {10754, 13184, 13184}, + {13167, 13188, 13188}, + {12658, 13192, 13192}, + {4294, 13199, 13200}, + {11682, 13204, 13205}, + {11695, 13209, 13209}, + {11076, 13214, 13214}, + {12232, 13218, 13218}, + {9399, 13223, 13224}, + {12880, 13228, 13229}, + {13048, 13234, 13234}, + {9701, 13238, 13239}, + {13209, 13243, 13243}, + {3658, 13248, 13248}, + {3698, 13252, 13254}, + {12237, 13260, 13260}, + {8872, 13266, 13266}, + {12957, 13272, 13273}, + {1393, 13281, 13281}, + {2013, 13285, 13288}, + {4244, 13296, 13299}, + {9428, 13303, 13303}, + {12702, 13307, 13307}, + {13078, 13311, 13311}, + {6071, 13315, 13315}, + {3061, 13319, 13319}, + {2051, 13324, 13324}, + {11560, 13328, 13331}, + {6584, 13336, 13336}, + {8482, 13340, 13340}, + {5331, 13344, 13344}, + {4171, 13348, 13348}, + {8501, 13352, 13352}, + {9219, 13356, 13356}, + {9473, 13360, 13363}, + {12881, 13367, 13367}, + {13065, 13371, 13375}, + {2979, 13379, 13384}, + {1518, 13388, 13388}, + {11177, 13392, 13392}, + {9457, 13398, 13398}, + {12293, 13407, 13410}, + {3697, 13414, 13417}, + {10338, 13425, 13425}, + {13367, 13429, 13429}, + {11074, 13433, 13437}, + {4201, 13441, 13443}, + {1812, 13447, 13448}, + {13360, 13452, 13456}, + {13188, 13463, 13463}, + {9732, 13470, 13470}, + {11332, 13477, 13477}, + {9918, 13487, 13487}, + {6337, 13497, 13497}, + {13429, 13501, 13501}, + {11413, 13505, 13505}, + {4685, 13512, 13513}, + {13136, 13517, 13519}, + {7416, 13528, 13530}, + {12929, 13534, 13534}, + {11110, 13539, 13539}, + {11521, 13543, 13543}, + {12825, 13553, 13553}, + {13447, 13557, 13558}, + {12299, 13562, 13563}, + {9003, 13570, 13570}, + {12500, 
13577, 13577}, + {13501, 13581, 13581}, + {9392, 13586, 13586}, + {12454, 13590, 13590}, + {6189, 13595, 13595}, + {13053, 13599, 13599}, + {11881, 13604, 13604}, + {13159, 13608, 13608}, + {4894, 13612, 13612}, + {13221, 13621, 13621}, + {8950, 13625, 13625}, + {13533, 13629, 13629}, + {9633, 13633, 13633}, + {7892, 13637, 13639}, + {13581, 13643, 13643}, + {13616, 13647, 13649}, + {12794, 13653, 13654}, + {8919, 13659, 13659}, + {9674, 13663, 13663}, + {13577, 13668, 13668}, + {12966, 13672, 13672}, + {12659, 13676, 13683}, + {6124, 13688, 13688}, + {9225, 13693, 13695}, + {11833, 13702, 13702}, + {12904, 13709, 13717}, + {13647, 13721, 13722}, + {11687, 13726, 13727}, + {12434, 13731, 13732}, + {12689, 13736, 13742}, + {13168, 13746, 13746}, + {6151, 13751, 13752}, + {11821, 13756, 13757}, + {6467, 13764, 13764}, + {5730, 13769, 13769}, + {5136, 13780, 13780}, + {724, 13784, 13785}, + {13517, 13789, 13791}, + {640, 13795, 13796}, + {7721, 13800, 13802}, + {11121, 13806, 13807}, + {5791, 13811, 13815}, + {12894, 13819, 13819}, + {11100, 13824, 13824}, + {7011, 13830, 13830}, + {7129, 13834, 13837}, + {13833, 13841, 13841}, + {11276, 13847, 13847}, + {13621, 13853, 13853}, + {13589, 13862, 13863}, + {12989, 13867, 13867}, + {12789, 13871, 13871}, + {1239, 13875, 13875}, + {4675, 13879, 13881}, + {4686, 13885, 13885}, + {707, 13889, 13889}, + {5449, 13897, 13898}, + {13867, 13902, 13903}, + {10613, 13908, 13908}, + {13789, 13912, 13914}, + {4451, 13918, 13919}, + {9200, 13924, 13924}, + {2011, 13930, 13930}, + {11433, 13934, 13936}, + {4695, 13942, 13943}, + {9435, 13948, 13951}, + {13688, 13955, 13957}, + {11694, 13961, 13962}, + {5712, 13966, 13966}, + {5991, 13970, 13972}, + {13477, 13976, 13976}, + {10213, 13987, 13987}, + {11839, 13991, 13993}, + {12272, 13997, 13997}, + {6206, 14001, 14001}, + {13179, 14006, 14007}, + {2939, 14011, 14011}, + {12972, 14016, 14017}, + {13918, 14021, 14022}, + {7436, 14026, 14027}, + {7678, 14032, 14034}, + {13586, 14040, 
14040}, + {13347, 14044, 14044}, + {13109, 14048, 14051}, + {9244, 14055, 14057}, + {13315, 14061, 14061}, + {13276, 14067, 14067}, + {11435, 14073, 14074}, + {13853, 14078, 14078}, + {13452, 14082, 14082}, + {14044, 14087, 14087}, + {4440, 14091, 14095}, + {4479, 14100, 14103}, + {9395, 14107, 14109}, + {6834, 14119, 14119}, + {10458, 14123, 14124}, + {1429, 14129, 14129}, + {8443, 14135, 14135}, + {10365, 14140, 14140}, + {5267, 14145, 14145}, + {11834, 14151, 14153}, +} diff --git a/lib/snappy/snappy.go b/lib/snappy/snappy.go new file mode 100644 index 0000000..74a3668 --- /dev/null +++ b/lib/snappy/snappy.go @@ -0,0 +1,98 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package snappy implements the Snappy compression format. It aims for very +// high speeds and reasonable compression. +// +// There are actually two Snappy formats: block and stream. They are related, +// but different: trying to decompress block-compressed data as a Snappy stream +// will fail, and vice versa. The block format is the Decode and Encode +// functions and the stream format is the Reader and Writer types. +// +// The block format, the more common case, is used when the complete size (the +// number of bytes) of the original data is known upfront, at the time +// compression starts. The stream format, also known as the framing format, is +// for when that isn't always true. +// +// The canonical, C++ implementation is at https://github.com/google/snappy and +// it only implements the block format. +package snappy + +import ( + "hash/crc32" +) + +/* +Each encoded block begins with the varint-encoded length of the decoded data, +followed by a sequence of chunks. Chunks begin and end on byte boundaries. The +first byte of each chunk is broken into its 2 least and 6 most significant bits +called l and m: l ranges in [0, 4) and m ranges in [0, 64). 
l is the chunk tag. +Zero means a literal tag. All other values mean a copy tag. + +For literal tags: + - If m < 60, the next 1 + m bytes are literal bytes. + - Otherwise, let n be the little-endian unsigned integer denoted by the next + m - 59 bytes. The next 1 + n bytes after that are literal bytes. + +For copy tags, length bytes are copied from offset bytes ago, in the style of +Lempel-Ziv compression algorithms. In particular: + - For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12). + The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10 + of the offset. The next byte is bits 0-7 of the offset. + - For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65). + The length is 1 + m. The offset is the little-endian unsigned integer + denoted by the next 2 bytes. + - For l == 3, this tag is a legacy format that is no longer issued by most + encoders. Nonetheless, the offset ranges in [0, 1<<32) and the length in + [1, 65). The length is 1 + m. The offset is the little-endian unsigned + integer denoted by the next 4 bytes. +*/ +const ( + tagLiteral = 0x00 + tagCopy1 = 0x01 + tagCopy2 = 0x02 + tagCopy4 = 0x03 +) + +const ( + checksumSize = 4 + chunkHeaderSize = 4 + magicChunk = "\xff\x06\x00\x00" + magicBody + magicBody = "sNaPpY" + + // maxBlockSize is the maximum size of the input to encodeBlock. It is not + // part of the wire format per se, but some parts of the encoder assume + // that an offset fits into a uint16. + // + // Also, for the framing format (Writer type instead of Encode function), + // https://github.com/google/snappy/blob/master/framing_format.txt says + // that "the uncompressed data in a chunk must be no longer than 65536 + // bytes". + maxBlockSize = 65536 + + // maxEncodedLenOfMaxBlockSize equals MaxEncodedLen(maxBlockSize), but is + // hard coded to be a const instead of a variable, so that obufLen can also + // be a const. Their equivalence is confirmed by + // TestMaxEncodedLenOfMaxBlockSize. 
+ maxEncodedLenOfMaxBlockSize = 76490 + + obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize + obufLen = obufHeaderLen + maxEncodedLenOfMaxBlockSize +) + +const ( + chunkTypeCompressedData = 0x00 + chunkTypeUncompressedData = 0x01 + chunkTypePadding = 0xfe + chunkTypeStreamIdentifier = 0xff +) + +var crcTable = crc32.MakeTable(crc32.Castagnoli) + +// crc implements the checksum specified in section 3 of +// https://github.com/google/snappy/blob/master/framing_format.txt +func crc(b []byte) uint32 { + c := crc32.Update(0, crcTable, b) + return uint32(c>>15|c<<17) + 0xa282ead8 +} diff --git a/lib/snappy/snappy_test.go b/lib/snappy/snappy_test.go new file mode 100644 index 0000000..2712710 --- /dev/null +++ b/lib/snappy/snappy_test.go @@ -0,0 +1,1353 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "bytes" + "encoding/binary" + "flag" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" +) + +var ( + download = flag.Bool("download", false, "If true, download any missing files before running benchmarks") + testdataDir = flag.String("testdataDir", "testdata", "Directory containing the test data") + benchdataDir = flag.String("benchdataDir", "testdata/bench", "Directory containing the benchmark data") +) + +// goEncoderShouldMatchCppEncoder is whether to test that the algorithm used by +// Go's encoder matches byte-for-byte what the C++ snappy encoder produces, on +// this GOARCH. There is more than one valid encoding of any given input, and +// there is more than one good algorithm along the frontier of trading off +// throughput for output size. 
Nonetheless, we presume that the C++ encoder's +// algorithm is a good one and has been tested on a wide range of inputs, so +// matching that exactly should mean that the Go encoder's algorithm is also +// good, without needing to gather our own corpus of test data. +// +// The exact algorithm used by the C++ code is potentially endian dependent, as +// it puns a byte pointer to a uint32 pointer to load, hash and compare 4 bytes +// at a time. The Go implementation is endian agnostic, in that its output is +// the same (as little-endian C++ code), regardless of the CPU's endianness. +// +// Thus, when comparing Go's output to C++ output generated beforehand, such as +// the "testdata/pi.txt.rawsnappy" file generated by C++ code on a little- +// endian system, we can run that test regardless of the runtime.GOARCH value. +// +// When comparing Go's output to dynamically generated C++ output, i.e. the +// result of fork/exec'ing a C++ program, we can run that test only on +// little-endian systems, because the C++ output might be different on +// big-endian systems. The runtime package doesn't export endianness per se, +// but we can restrict this match-C++ test to common little-endian systems. 
+const goEncoderShouldMatchCppEncoder = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "arm" + +func TestMaxEncodedLenOfMaxBlockSize(t *testing.T) { + got := maxEncodedLenOfMaxBlockSize + want := MaxEncodedLen(maxBlockSize) + if got != want { + t.Fatalf("got %d, want %d", got, want) + } +} + +func cmp(a, b []byte) error { + if bytes.Equal(a, b) { + return nil + } + if len(a) != len(b) { + return fmt.Errorf("got %d bytes, want %d", len(a), len(b)) + } + for i := range a { + if a[i] != b[i] { + return fmt.Errorf("byte #%d: got 0x%02x, want 0x%02x", i, a[i], b[i]) + } + } + return nil +} + +func roundtrip(b, ebuf, dbuf []byte) error { + d, err := Decode(dbuf, Encode(ebuf, b)) + if err != nil { + return fmt.Errorf("decoding error: %v", err) + } + if err := cmp(d, b); err != nil { + return fmt.Errorf("roundtrip mismatch: %v", err) + } + return nil +} + +func TestEmpty(t *testing.T) { + if err := roundtrip(nil, nil, nil); err != nil { + t.Fatal(err) + } +} + +func TestSmallCopy(t *testing.T) { + for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for i := 0; i < 32; i++ { + s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb" + if err := roundtrip([]byte(s), ebuf, dbuf); err != nil { + t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err) + } + } + } + } +} + +func TestSmallRand(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i := range b { + b[i] = uint8(rng.Intn(256)) + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func TestSmallRegular(t *testing.T) { + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i := range b { + b[i] = uint8(i%10 + 'a') + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func TestInvalidVarint(t *testing.T) { + testCases := []struct { + desc string + input 
string + }{{ + "invalid varint, final byte has continuation bit set", + "\xff", + }, { + "invalid varint, value overflows uint64", + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00", + }, { + // https://github.com/google/snappy/blob/master/format_description.txt + // says that "the stream starts with the uncompressed length [as a + // varint] (up to a maximum of 2^32 - 1)". + "valid varint (as uint64), but value overflows uint32", + "\x80\x80\x80\x80\x10", + }} + + for _, tc := range testCases { + input := []byte(tc.input) + if _, err := DecodedLen(input); err != ErrCorrupt { + t.Errorf("%s: DecodedLen: got %v, want ErrCorrupt", tc.desc, err) + } + if _, err := Decode(nil, input); err != ErrCorrupt { + t.Errorf("%s: Decode: got %v, want ErrCorrupt", tc.desc, err) + } + } +} + +func TestDecode(t *testing.T) { + lit40Bytes := make([]byte, 40) + for i := range lit40Bytes { + lit40Bytes[i] = byte(i) + } + lit40 := string(lit40Bytes) + + testCases := []struct { + desc string + input string + want string + wantErr error + }{{ + `decodedLen=0; valid input`, + "\x00", + "", + nil, + }, { + `decodedLen=3; tagLiteral, 0-byte length; length=3; valid input`, + "\x03" + "\x08\xff\xff\xff", + "\xff\xff\xff", + nil, + }, { + `decodedLen=2; tagLiteral, 0-byte length; length=3; not enough dst bytes`, + "\x02" + "\x08\xff\xff\xff", + "", + ErrCorrupt, + }, { + `decodedLen=3; tagLiteral, 0-byte length; length=3; not enough src bytes`, + "\x03" + "\x08\xff\xff", + "", + ErrCorrupt, + }, { + `decodedLen=40; tagLiteral, 0-byte length; length=40; valid input`, + "\x28" + "\x9c" + lit40, + lit40, + nil, + }, { + `decodedLen=1; tagLiteral, 1-byte length; not enough length bytes`, + "\x01" + "\xf0", + "", + ErrCorrupt, + }, { + `decodedLen=3; tagLiteral, 1-byte length; length=3; valid input`, + "\x03" + "\xf0\x02\xff\xff\xff", + "\xff\xff\xff", + nil, + }, { + `decodedLen=1; tagLiteral, 2-byte length; not enough length bytes`, + "\x01" + "\xf4\x00", + "", + ErrCorrupt, + }, { + `decodedLen=3; 
tagLiteral, 2-byte length; length=3; valid input`, + "\x03" + "\xf4\x02\x00\xff\xff\xff", + "\xff\xff\xff", + nil, + }, { + `decodedLen=1; tagLiteral, 3-byte length; not enough length bytes`, + "\x01" + "\xf8\x00\x00", + "", + ErrCorrupt, + }, { + `decodedLen=3; tagLiteral, 3-byte length; length=3; valid input`, + "\x03" + "\xf8\x02\x00\x00\xff\xff\xff", + "\xff\xff\xff", + nil, + }, { + `decodedLen=1; tagLiteral, 4-byte length; not enough length bytes`, + "\x01" + "\xfc\x00\x00\x00", + "", + ErrCorrupt, + }, { + `decodedLen=1; tagLiteral, 4-byte length; length=3; not enough dst bytes`, + "\x01" + "\xfc\x02\x00\x00\x00\xff\xff\xff", + "", + ErrCorrupt, + }, { + `decodedLen=4; tagLiteral, 4-byte length; length=3; not enough src bytes`, + "\x04" + "\xfc\x02\x00\x00\x00\xff", + "", + ErrCorrupt, + }, { + `decodedLen=3; tagLiteral, 4-byte length; length=3; valid input`, + "\x03" + "\xfc\x02\x00\x00\x00\xff\xff\xff", + "\xff\xff\xff", + nil, + }, { + `decodedLen=4; tagCopy1, 1 extra length|offset byte; not enough extra bytes`, + "\x04" + "\x01", + "", + ErrCorrupt, + }, { + `decodedLen=4; tagCopy2, 2 extra length|offset bytes; not enough extra bytes`, + "\x04" + "\x02\x00", + "", + ErrCorrupt, + }, { + `decodedLen=4; tagCopy4, 4 extra length|offset bytes; not enough extra bytes`, + "\x04" + "\x03\x00\x00\x00", + "", + ErrCorrupt, + }, { + `decodedLen=4; tagLiteral (4 bytes "abcd"); valid input`, + "\x04" + "\x0cabcd", + "abcd", + nil, + }, { + `decodedLen=13; tagLiteral (4 bytes "abcd"); tagCopy1; length=9 offset=4; valid input`, + "\x0d" + "\x0cabcd" + "\x15\x04", + "abcdabcdabcda", + nil, + }, { + `decodedLen=8; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=4; valid input`, + "\x08" + "\x0cabcd" + "\x01\x04", + "abcdabcd", + nil, + }, { + `decodedLen=8; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=2; valid input`, + "\x08" + "\x0cabcd" + "\x01\x02", + "abcdcdcd", + nil, + }, { + `decodedLen=8; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 
offset=1; valid input`, + "\x08" + "\x0cabcd" + "\x01\x01", + "abcddddd", + nil, + }, { + `decodedLen=8; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=0; zero offset`, + "\x08" + "\x0cabcd" + "\x01\x00", + "", + ErrCorrupt, + }, { + `decodedLen=9; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=4; inconsistent dLen`, + "\x09" + "\x0cabcd" + "\x01\x04", + "", + ErrCorrupt, + }, { + `decodedLen=8; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=5; offset too large`, + "\x08" + "\x0cabcd" + "\x01\x05", + "", + ErrCorrupt, + }, { + `decodedLen=7; tagLiteral (4 bytes "abcd"); tagCopy1; length=4 offset=4; length too large`, + "\x07" + "\x0cabcd" + "\x01\x04", + "", + ErrCorrupt, + }, { + `decodedLen=6; tagLiteral (4 bytes "abcd"); tagCopy2; length=2 offset=3; valid input`, + "\x06" + "\x0cabcd" + "\x06\x03\x00", + "abcdbc", + nil, + }, { + `decodedLen=6; tagLiteral (4 bytes "abcd"); tagCopy4; length=2 offset=3; valid input`, + "\x06" + "\x0cabcd" + "\x07\x03\x00\x00\x00", + "abcdbc", + nil, + }} + + const ( + // notPresentXxx defines a range of byte values [0xa0, 0xc5) that are + // not present in either the input or the output. It is written to dBuf + // to check that Decode does not write bytes past the end of + // dBuf[:dLen]. + // + // The magic number 37 was chosen because it is prime. A more 'natural' + // number like 32 might lead to a false negative if, for example, a + // byte was incorrectly copied 4*8 bytes later. 
+ notPresentBase = 0xa0 + notPresentLen = 37 + ) + + var dBuf [100]byte +loop: + for i, tc := range testCases { + input := []byte(tc.input) + for _, x := range input { + if notPresentBase <= x && x < notPresentBase+notPresentLen { + t.Errorf("#%d (%s): input shouldn't contain %#02x\ninput: % x", i, tc.desc, x, input) + continue loop + } + } + + dLen, n := binary.Uvarint(input) + if n <= 0 { + t.Errorf("#%d (%s): invalid varint-encoded dLen", i, tc.desc) + continue + } + if dLen > uint64(len(dBuf)) { + t.Errorf("#%d (%s): dLen %d is too large", i, tc.desc, dLen) + continue + } + + for j := range dBuf { + dBuf[j] = byte(notPresentBase + j%notPresentLen) + } + g, gotErr := Decode(dBuf[:], input) + if got := string(g); got != tc.want || gotErr != tc.wantErr { + t.Errorf("#%d (%s):\ngot %q, %v\nwant %q, %v", + i, tc.desc, got, gotErr, tc.want, tc.wantErr) + continue + } + for j, x := range dBuf { + if uint64(j) < dLen { + continue + } + if w := byte(notPresentBase + j%notPresentLen); x != w { + t.Errorf("#%d (%s): Decode overrun: dBuf[%d] was modified: got %#02x, want %#02x\ndBuf: % x", + i, tc.desc, j, x, w, dBuf) + continue loop + } + } + } +} + +func TestDecodeCopy4(t *testing.T) { + dots := strings.Repeat(".", 65536) + + input := strings.Join([]string{ + "\x89\x80\x04", // decodedLen = 65545. + "\x0cpqrs", // 4-byte literal "pqrs". + "\xf4\xff\xff" + dots, // 65536-byte literal dots. + "\x13\x04\x00\x01\x00", // tagCopy4; length=5 offset=65540. + }, "") + + gotBytes, err := Decode(nil, []byte(input)) + if err != nil { + t.Fatal(err) + } + got := string(gotBytes) + want := "pqrs" + dots + "pqrs." + if len(got) != len(want) { + t.Fatalf("got %d bytes, want %d", len(got), len(want)) + } + if got != want { + for i := 0; i < len(got); i++ { + if g, w := got[i], want[i]; g != w { + t.Fatalf("byte #%d: got %#02x, want %#02x", i, g, w) + } + } + } +} + +// TestDecodeLengthOffset tests decoding an encoding of the form literal + +// copy-length-offset + literal. 
For example: "abcdefghijkl" + "efghij" + "AB". +func TestDecodeLengthOffset(t *testing.T) { + const ( + prefix = "abcdefghijklmnopqr" + suffix = "ABCDEFGHIJKLMNOPQR" + + // notPresentXxx defines a range of byte values [0xa0, 0xc5) that are + // not present in either the input or the output. It is written to + // gotBuf to check that Decode does not write bytes past the end of + // gotBuf[:totalLen]. + // + // The magic number 37 was chosen because it is prime. A more 'natural' + // number like 32 might lead to a false negative if, for example, a + // byte was incorrectly copied 4*8 bytes later. + notPresentBase = 0xa0 + notPresentLen = 37 + ) + var gotBuf, wantBuf, inputBuf [128]byte + for length := 1; length <= 18; length++ { + for offset := 1; offset <= 18; offset++ { + loop: + for suffixLen := 0; suffixLen <= 18; suffixLen++ { + totalLen := len(prefix) + length + suffixLen + + inputLen := binary.PutUvarint(inputBuf[:], uint64(totalLen)) + inputBuf[inputLen] = tagLiteral + 4*byte(len(prefix)-1) + inputLen++ + inputLen += copy(inputBuf[inputLen:], prefix) + inputBuf[inputLen+0] = tagCopy2 + 4*byte(length-1) + inputBuf[inputLen+1] = byte(offset) + inputBuf[inputLen+2] = 0x00 + inputLen += 3 + if suffixLen > 0 { + inputBuf[inputLen] = tagLiteral + 4*byte(suffixLen-1) + inputLen++ + inputLen += copy(inputBuf[inputLen:], suffix[:suffixLen]) + } + input := inputBuf[:inputLen] + + for i := range gotBuf { + gotBuf[i] = byte(notPresentBase + i%notPresentLen) + } + got, err := Decode(gotBuf[:], input) + if err != nil { + t.Errorf("length=%d, offset=%d; suffixLen=%d: %v", length, offset, suffixLen, err) + continue + } + + wantLen := 0 + wantLen += copy(wantBuf[wantLen:], prefix) + for i := 0; i < length; i++ { + wantBuf[wantLen] = wantBuf[wantLen-offset] + wantLen++ + } + wantLen += copy(wantBuf[wantLen:], suffix[:suffixLen]) + want := wantBuf[:wantLen] + + for _, x := range input { + if notPresentBase <= x && x < notPresentBase+notPresentLen { + t.Errorf("length=%d, 
offset=%d; suffixLen=%d: input shouldn't contain %#02x\ninput: % x", + length, offset, suffixLen, x, input) + continue loop + } + } + for i, x := range gotBuf { + if i < totalLen { + continue + } + if w := byte(notPresentBase + i%notPresentLen); x != w { + t.Errorf("length=%d, offset=%d; suffixLen=%d; totalLen=%d: "+ + "Decode overrun: gotBuf[%d] was modified: got %#02x, want %#02x\ngotBuf: % x", + length, offset, suffixLen, totalLen, i, x, w, gotBuf) + continue loop + } + } + for _, x := range want { + if notPresentBase <= x && x < notPresentBase+notPresentLen { + t.Errorf("length=%d, offset=%d; suffixLen=%d: want shouldn't contain %#02x\nwant: % x", + length, offset, suffixLen, x, want) + continue loop + } + } + + if !bytes.Equal(got, want) { + t.Errorf("length=%d, offset=%d; suffixLen=%d:\ninput % x\ngot % x\nwant % x", + length, offset, suffixLen, input, got, want) + continue + } + } + } + } +} + +const ( + goldenText = "Mark.Twain-Tom.Sawyer.txt" + goldenCompressed = goldenText + ".rawsnappy" +) + +func TestDecodeGoldenInput(t *testing.T) { + tDir := filepath.FromSlash(*testdataDir) + src, err := ioutil.ReadFile(filepath.Join(tDir, goldenCompressed)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + got, err := Decode(nil, src) + if err != nil { + t.Fatalf("Decode: %v", err) + } + want, err := ioutil.ReadFile(filepath.Join(tDir, goldenText)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if err := cmp(got, want); err != nil { + t.Fatal(err) + } +} + +func TestEncodeGoldenInput(t *testing.T) { + tDir := filepath.FromSlash(*testdataDir) + src, err := ioutil.ReadFile(filepath.Join(tDir, goldenText)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + got := Encode(nil, src) + want, err := ioutil.ReadFile(filepath.Join(tDir, goldenCompressed)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if err := cmp(got, want); err != nil { + t.Fatal(err) + } +} + +func TestExtendMatchGoldenInput(t *testing.T) { + tDir := 
filepath.FromSlash(*testdataDir) + src, err := ioutil.ReadFile(filepath.Join(tDir, goldenText)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + for i, tc := range extendMatchGoldenTestCases { + got := extendMatch(src, tc.i, tc.j) + if got != tc.want { + t.Errorf("test #%d: i, j = %5d, %5d: got %5d (= j + %6d), want %5d (= j + %6d)", + i, tc.i, tc.j, got, got-tc.j, tc.want, tc.want-tc.j) + } + } +} + +func TestExtendMatch(t *testing.T) { + // ref is a simple, reference implementation of extendMatch. + ref := func(src []byte, i, j int) int { + for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { + } + return j + } + + nums := []int{0, 1, 2, 7, 8, 9, 29, 30, 31, 32, 33, 34, 38, 39, 40} + for yIndex := 40; yIndex > 30; yIndex-- { + xxx := bytes.Repeat([]byte("x"), 40) + if yIndex < len(xxx) { + xxx[yIndex] = 'y' + } + for _, i := range nums { + for _, j := range nums { + if i >= j { + continue + } + got := extendMatch(xxx, i, j) + want := ref(xxx, i, j) + if got != want { + t.Errorf("yIndex=%d, i=%d, j=%d: got %d, want %d", yIndex, i, j, got, want) + } + } + } + } +} + +const snappytoolCmdName = "cmd/snappytool/snappytool" + +func skipTestSameEncodingAsCpp() (msg string) { + if !goEncoderShouldMatchCppEncoder { + return fmt.Sprintf("skipping testing that the encoding is byte-for-byte identical to C++: GOARCH=%s", runtime.GOARCH) + } + if _, err := os.Stat(snappytoolCmdName); err != nil { + return fmt.Sprintf("could not find snappytool: %v", err) + } + return "" +} + +func runTestSameEncodingAsCpp(src []byte) error { + got := Encode(nil, src) + + cmd := exec.Command(snappytoolCmdName, "-e") + cmd.Stdin = bytes.NewReader(src) + want, err := cmd.Output() + if err != nil { + return fmt.Errorf("could not run snappytool: %v", err) + } + return cmp(got, want) +} + +func TestSameEncodingAsCppShortCopies(t *testing.T) { + if msg := skipTestSameEncodingAsCpp(); msg != "" { + t.Skip(msg) + } + src := bytes.Repeat([]byte{'a'}, 20) + for i := 0; i <= len(src); i++ { + 
if err := runTestSameEncodingAsCpp(src[:i]); err != nil { + t.Errorf("i=%d: %v", i, err) + } + } +} + +func TestSameEncodingAsCppLongFiles(t *testing.T) { + if msg := skipTestSameEncodingAsCpp(); msg != "" { + t.Skip(msg) + } + bDir := filepath.FromSlash(*benchdataDir) + failed := false + for i, tf := range testFiles { + if err := downloadBenchmarkFiles(t, tf.filename); err != nil { + t.Fatalf("failed to download testdata: %s", err) + } + data := readFile(t, filepath.Join(bDir, tf.filename)) + if n := tf.sizeLimit; 0 < n && n < len(data) { + data = data[:n] + } + if err := runTestSameEncodingAsCpp(data); err != nil { + t.Errorf("i=%d: %v", i, err) + failed = true + } + } + if failed { + t.Errorf("was the snappytool program built against the C++ snappy library version " + + "d53de187 or later, commited on 2016-04-05? See " + + "https://github.com/google/snappy/commit/d53de18799418e113e44444252a39b12a0e4e0cc") + } +} + +// TestSlowForwardCopyOverrun tests the "expand the pattern" algorithm +// described in decode_amd64.s and its claim of a 10 byte overrun worst case. 
+func TestSlowForwardCopyOverrun(t *testing.T) { + const base = 100 + + for length := 1; length < 18; length++ { + for offset := 1; offset < 18; offset++ { + highWaterMark := base + d := base + l := length + o := offset + + // makeOffsetAtLeast8 + for o < 8 { + if end := d + 8; highWaterMark < end { + highWaterMark = end + } + l -= o + d += o + o += o + } + + // fixUpSlowForwardCopy + a := d + d += l + + // finishSlowForwardCopy + for l > 0 { + if end := a + 8; highWaterMark < end { + highWaterMark = end + } + a += 8 + l -= 8 + } + + dWant := base + length + overrun := highWaterMark - dWant + if d != dWant || overrun < 0 || 10 < overrun { + t.Errorf("length=%d, offset=%d: d and overrun: got (%d, %d), want (%d, something in [0, 10])", + length, offset, d, overrun, dWant) + } + } + } +} + +// TestEncodeNoiseThenRepeats encodes input for which the first half is very +// incompressible and the second half is very compressible. The encoded form's +// length should be closer to 50% of the original length than 100%. +func TestEncodeNoiseThenRepeats(t *testing.T) { + for _, origLen := range []int{256 * 1024, 2048 * 1024} { + src := make([]byte, origLen) + rng := rand.New(rand.NewSource(1)) + firstHalf, secondHalf := src[:origLen/2], src[origLen/2:] + for i := range firstHalf { + firstHalf[i] = uint8(rng.Intn(256)) + } + for i := range secondHalf { + secondHalf[i] = uint8(i >> 8) + } + dst := Encode(nil, src) + if got, want := len(dst), origLen*3/4; got >= want { + t.Errorf("origLen=%d: got %d encoded bytes, want less than %d", origLen, got, want) + } + } +} + +func TestFramingFormat(t *testing.T) { + // src is comprised of alternating 1e5-sized sequences of random + // (incompressible) bytes and repeated (compressible) bytes. 1e5 was chosen + // because it is larger than maxBlockSize (64k). 
+ src := make([]byte, 1e6) + rng := rand.New(rand.NewSource(1)) + for i := 0; i < 10; i++ { + if i%2 == 0 { + for j := 0; j < 1e5; j++ { + src[1e5*i+j] = uint8(rng.Intn(256)) + } + } else { + for j := 0; j < 1e5; j++ { + src[1e5*i+j] = uint8(i) + } + } + } + + buf := new(bytes.Buffer) + if _, err := NewWriter(buf).Write(src); err != nil { + t.Fatalf("Write: encoding: %v", err) + } + dst, err := ioutil.ReadAll(NewReader(buf)) + if err != nil { + t.Fatalf("ReadAll: decoding: %v", err) + } + if err := cmp(dst, src); err != nil { + t.Fatal(err) + } +} + +func TestWriterGoldenOutput(t *testing.T) { + buf := new(bytes.Buffer) + w := NewBufferedWriter(buf) + defer w.Close() + w.Write([]byte("abcd")) // Not compressible. + w.Flush() + w.Write(bytes.Repeat([]byte{'A'}, 150)) // Compressible. + w.Flush() + // The next chunk is also compressible, but a naive, greedy encoding of the + // overall length 67 copy as a length 64 copy (the longest expressible as a + // tagCopy1 or tagCopy2) plus a length 3 remainder would be two 3-byte + // tagCopy2 tags (6 bytes), since the minimum length for a tagCopy1 is 4 + // bytes. Instead, we could do it shorter, in 5 bytes: a 3-byte tagCopy2 + // (of length 60) and a 2-byte tagCopy1 (of length 7). + w.Write(bytes.Repeat([]byte{'B'}, 68)) + w.Write([]byte("efC")) // Not compressible. + w.Write(bytes.Repeat([]byte{'C'}, 20)) // Compressible. + w.Write(bytes.Repeat([]byte{'B'}, 20)) // Compressible. + w.Write([]byte("g")) // Not compressible. + w.Flush() + + got := buf.String() + want := strings.Join([]string{ + magicChunk, + "\x01\x08\x00\x00", // Uncompressed chunk, 8 bytes long (including 4 byte checksum). + "\x68\x10\xe6\xb6", // Checksum. + "\x61\x62\x63\x64", // Uncompressed payload: "abcd". + "\x00\x11\x00\x00", // Compressed chunk, 17 bytes long (including 4 byte checksum). + "\x5f\xeb\xf2\x10", // Checksum. + "\x96\x01", // Compressed payload: Uncompressed length (varint encoded): 150. 
+ "\x00\x41", // Compressed payload: tagLiteral, length=1, "A". + "\xfe\x01\x00", // Compressed payload: tagCopy2, length=64, offset=1. + "\xfe\x01\x00", // Compressed payload: tagCopy2, length=64, offset=1. + "\x52\x01\x00", // Compressed payload: tagCopy2, length=21, offset=1. + "\x00\x18\x00\x00", // Compressed chunk, 24 bytes long (including 4 byte checksum). + "\x30\x85\x69\xeb", // Checksum. + "\x70", // Compressed payload: Uncompressed length (varint encoded): 112. + "\x00\x42", // Compressed payload: tagLiteral, length=1, "B". + "\xee\x01\x00", // Compressed payload: tagCopy2, length=60, offset=1. + "\x0d\x01", // Compressed payload: tagCopy1, length=7, offset=1. + "\x08\x65\x66\x43", // Compressed payload: tagLiteral, length=3, "efC". + "\x4e\x01\x00", // Compressed payload: tagCopy2, length=20, offset=1. + "\x4e\x5a\x00", // Compressed payload: tagCopy2, length=20, offset=90. + "\x00\x67", // Compressed payload: tagLiteral, length=1, "g". + }, "") + if got != want { + t.Fatalf("\ngot: % x\nwant: % x", got, want) + } +} + +func TestEmitLiteral(t *testing.T) { + testCases := []struct { + length int + want string + }{ + {1, "\x00"}, + {2, "\x04"}, + {59, "\xe8"}, + {60, "\xec"}, + {61, "\xf0\x3c"}, + {62, "\xf0\x3d"}, + {254, "\xf0\xfd"}, + {255, "\xf0\xfe"}, + {256, "\xf0\xff"}, + {257, "\xf4\x00\x01"}, + {65534, "\xf4\xfd\xff"}, + {65535, "\xf4\xfe\xff"}, + {65536, "\xf4\xff\xff"}, + } + + dst := make([]byte, 70000) + nines := bytes.Repeat([]byte{0x99}, 65536) + for _, tc := range testCases { + lit := nines[:tc.length] + n := emitLiteral(dst, lit) + if !bytes.HasSuffix(dst[:n], lit) { + t.Errorf("length=%d: did not end with that many literal bytes", tc.length) + continue + } + got := string(dst[:n-tc.length]) + if got != tc.want { + t.Errorf("length=%d:\ngot % x\nwant % x", tc.length, got, tc.want) + continue + } + } +} + +func TestEmitCopy(t *testing.T) { + testCases := []struct { + offset int + length int + want string + }{ + {8, 04, "\x01\x08"}, + {8, 
11, "\x1d\x08"}, + {8, 12, "\x2e\x08\x00"}, + {8, 13, "\x32\x08\x00"}, + {8, 59, "\xea\x08\x00"}, + {8, 60, "\xee\x08\x00"}, + {8, 61, "\xf2\x08\x00"}, + {8, 62, "\xf6\x08\x00"}, + {8, 63, "\xfa\x08\x00"}, + {8, 64, "\xfe\x08\x00"}, + {8, 65, "\xee\x08\x00\x05\x08"}, + {8, 66, "\xee\x08\x00\x09\x08"}, + {8, 67, "\xee\x08\x00\x0d\x08"}, + {8, 68, "\xfe\x08\x00\x01\x08"}, + {8, 69, "\xfe\x08\x00\x05\x08"}, + {8, 80, "\xfe\x08\x00\x3e\x08\x00"}, + + {256, 04, "\x21\x00"}, + {256, 11, "\x3d\x00"}, + {256, 12, "\x2e\x00\x01"}, + {256, 13, "\x32\x00\x01"}, + {256, 59, "\xea\x00\x01"}, + {256, 60, "\xee\x00\x01"}, + {256, 61, "\xf2\x00\x01"}, + {256, 62, "\xf6\x00\x01"}, + {256, 63, "\xfa\x00\x01"}, + {256, 64, "\xfe\x00\x01"}, + {256, 65, "\xee\x00\x01\x25\x00"}, + {256, 66, "\xee\x00\x01\x29\x00"}, + {256, 67, "\xee\x00\x01\x2d\x00"}, + {256, 68, "\xfe\x00\x01\x21\x00"}, + {256, 69, "\xfe\x00\x01\x25\x00"}, + {256, 80, "\xfe\x00\x01\x3e\x00\x01"}, + + {2048, 04, "\x0e\x00\x08"}, + {2048, 11, "\x2a\x00\x08"}, + {2048, 12, "\x2e\x00\x08"}, + {2048, 13, "\x32\x00\x08"}, + {2048, 59, "\xea\x00\x08"}, + {2048, 60, "\xee\x00\x08"}, + {2048, 61, "\xf2\x00\x08"}, + {2048, 62, "\xf6\x00\x08"}, + {2048, 63, "\xfa\x00\x08"}, + {2048, 64, "\xfe\x00\x08"}, + {2048, 65, "\xee\x00\x08\x12\x00\x08"}, + {2048, 66, "\xee\x00\x08\x16\x00\x08"}, + {2048, 67, "\xee\x00\x08\x1a\x00\x08"}, + {2048, 68, "\xfe\x00\x08\x0e\x00\x08"}, + {2048, 69, "\xfe\x00\x08\x12\x00\x08"}, + {2048, 80, "\xfe\x00\x08\x3e\x00\x08"}, + } + + dst := make([]byte, 1024) + for _, tc := range testCases { + n := emitCopy(dst, tc.offset, tc.length) + got := string(dst[:n]) + if got != tc.want { + t.Errorf("offset=%d, length=%d:\ngot % x\nwant % x", tc.offset, tc.length, got, tc.want) + } + } +} + +func TestNewBufferedWriter(t *testing.T) { + // Test all 32 possible sub-sequences of these 5 input slices. 
+ // + // Their lengths sum to 400,000, which is over 6 times the Writer ibuf + // capacity: 6 * maxBlockSize is 393,216. + inputs := [][]byte{ + bytes.Repeat([]byte{'a'}, 40000), + bytes.Repeat([]byte{'b'}, 150000), + bytes.Repeat([]byte{'c'}, 60000), + bytes.Repeat([]byte{'d'}, 120000), + bytes.Repeat([]byte{'e'}, 30000), + } +loop: + for i := 0; i < 1< 0; { + i := copy(x, src) + x = x[i:] + } + return dst +} + +func benchWords(b *testing.B, n int, decode bool) { + // Note: the file is OS-language dependent so the resulting values are not + // directly comparable for non-US-English OS installations. + data := expand(readFile(b, "/usr/share/dict/words"), n) + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +func BenchmarkWordsDecode1e1(b *testing.B) { benchWords(b, 1e1, true) } +func BenchmarkWordsDecode1e2(b *testing.B) { benchWords(b, 1e2, true) } +func BenchmarkWordsDecode1e3(b *testing.B) { benchWords(b, 1e3, true) } +func BenchmarkWordsDecode1e4(b *testing.B) { benchWords(b, 1e4, true) } +func BenchmarkWordsDecode1e5(b *testing.B) { benchWords(b, 1e5, true) } +func BenchmarkWordsDecode1e6(b *testing.B) { benchWords(b, 1e6, true) } +func BenchmarkWordsEncode1e1(b *testing.B) { benchWords(b, 1e1, false) } +func BenchmarkWordsEncode1e2(b *testing.B) { benchWords(b, 1e2, false) } +func BenchmarkWordsEncode1e3(b *testing.B) { benchWords(b, 1e3, false) } +func BenchmarkWordsEncode1e4(b *testing.B) { benchWords(b, 1e4, false) } +func BenchmarkWordsEncode1e5(b *testing.B) { benchWords(b, 1e5, false) } +func BenchmarkWordsEncode1e6(b *testing.B) { benchWords(b, 1e6, false) } + +func BenchmarkRandomEncode(b *testing.B) { + rng := rand.New(rand.NewSource(1)) + data := make([]byte, 1<<20) + for i := range data { + data[i] = uint8(rng.Intn(256)) + } + benchEncode(b, data) +} + +// testFiles' values are copied directly from +// https://raw.githubusercontent.com/google/snappy/master/snappy_unittest.cc +// The label field is unused in snappy-go. 
+var testFiles = []struct { + label string + filename string + sizeLimit int +}{ + {"html", "html", 0}, + {"urls", "urls.10K", 0}, + {"jpg", "fireworks.jpeg", 0}, + {"jpg_200", "fireworks.jpeg", 200}, + {"pdf", "paper-100k.pdf", 0}, + {"html4", "html_x_4", 0}, + {"txt1", "alice29.txt", 0}, + {"txt2", "asyoulik.txt", 0}, + {"txt3", "lcet10.txt", 0}, + {"txt4", "plrabn12.txt", 0}, + {"pb", "geo.protodata", 0}, + {"gaviota", "kppkn.gtb", 0}, +} + +const ( + // The benchmark data files are at this canonical URL. + benchURL = "https://raw.githubusercontent.com/google/snappy/master/testdata/" +) + +func downloadBenchmarkFiles(b testing.TB, basename string) (errRet error) { + bDir := filepath.FromSlash(*benchdataDir) + filename := filepath.Join(bDir, basename) + if stat, err := os.Stat(filename); err == nil && stat.Size() != 0 { + return nil + } + + if !*download { + b.Skipf("test data not found; skipping %s without the -download flag", testOrBenchmark(b)) + } + // Download the official snappy C++ implementation reference test data + // files for benchmarking. 
+ if err := os.MkdirAll(bDir, 0777); err != nil && !os.IsExist(err) { + return fmt.Errorf("failed to create %s: %s", bDir, err) + } + + f, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create %s: %s", filename, err) + } + defer f.Close() + defer func() { + if errRet != nil { + os.Remove(filename) + } + }() + url := benchURL + basename + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download %s: %s", url, err) + } + defer resp.Body.Close() + if s := resp.StatusCode; s != http.StatusOK { + return fmt.Errorf("downloading %s: HTTP status code %d (%s)", url, s, http.StatusText(s)) + } + _, err = io.Copy(f, resp.Body) + if err != nil { + return fmt.Errorf("failed to download %s to %s: %s", url, filename, err) + } + return nil +} + +func benchFile(b *testing.B, i int, decode bool) { + if err := downloadBenchmarkFiles(b, testFiles[i].filename); err != nil { + b.Fatalf("failed to download testdata: %s", err) + } + bDir := filepath.FromSlash(*benchdataDir) + data := readFile(b, filepath.Join(bDir, testFiles[i].filename)) + if n := testFiles[i].sizeLimit; 0 < n && n < len(data) { + data = data[:n] + } + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +// Naming convention is kept similar to what snappy's C++ implementation uses. 
+func Benchmark_UFlat0(b *testing.B) { benchFile(b, 0, true) } +func Benchmark_UFlat1(b *testing.B) { benchFile(b, 1, true) } +func Benchmark_UFlat2(b *testing.B) { benchFile(b, 2, true) } +func Benchmark_UFlat3(b *testing.B) { benchFile(b, 3, true) } +func Benchmark_UFlat4(b *testing.B) { benchFile(b, 4, true) } +func Benchmark_UFlat5(b *testing.B) { benchFile(b, 5, true) } +func Benchmark_UFlat6(b *testing.B) { benchFile(b, 6, true) } +func Benchmark_UFlat7(b *testing.B) { benchFile(b, 7, true) } +func Benchmark_UFlat8(b *testing.B) { benchFile(b, 8, true) } +func Benchmark_UFlat9(b *testing.B) { benchFile(b, 9, true) } +func Benchmark_UFlat10(b *testing.B) { benchFile(b, 10, true) } +func Benchmark_UFlat11(b *testing.B) { benchFile(b, 11, true) } +func Benchmark_ZFlat0(b *testing.B) { benchFile(b, 0, false) } +func Benchmark_ZFlat1(b *testing.B) { benchFile(b, 1, false) } +func Benchmark_ZFlat2(b *testing.B) { benchFile(b, 2, false) } +func Benchmark_ZFlat3(b *testing.B) { benchFile(b, 3, false) } +func Benchmark_ZFlat4(b *testing.B) { benchFile(b, 4, false) } +func Benchmark_ZFlat5(b *testing.B) { benchFile(b, 5, false) } +func Benchmark_ZFlat6(b *testing.B) { benchFile(b, 6, false) } +func Benchmark_ZFlat7(b *testing.B) { benchFile(b, 7, false) } +func Benchmark_ZFlat8(b *testing.B) { benchFile(b, 8, false) } +func Benchmark_ZFlat9(b *testing.B) { benchFile(b, 9, false) } +func Benchmark_ZFlat10(b *testing.B) { benchFile(b, 10, false) } +func Benchmark_ZFlat11(b *testing.B) { benchFile(b, 11, false) } + +func BenchmarkExtendMatch(b *testing.B) { + tDir := filepath.FromSlash(*testdataDir) + src, err := ioutil.ReadFile(filepath.Join(tDir, goldenText)) + if err != nil { + b.Fatalf("ReadFile: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range extendMatchGoldenTestCases { + extendMatch(src, tc.i, tc.j) + } + } +} diff --git a/server/base.go b/server/base.go index 9c093d2..a9c2a5e 100644 --- a/server/base.go +++ b/server/base.go @@ -3,7 
+3,10 @@ package server import ( "errors" "github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/pool" "net" "net/http" "sync" @@ -13,8 +16,8 @@ import ( type server struct { id int bridge *bridge.Bridge - task *lib.Tunnel - config *lib.Config + task *file.Tunnel + config *file.Config errorContent []byte sync.Mutex } @@ -26,7 +29,7 @@ func (s *server) FlowAdd(in, out int64) { s.task.Flow.InletFlow += in } -func (s *server) FlowAddHost(host *lib.Host, in, out int64) { +func (s *server) FlowAddHost(host *file.Host, in, out int64) { s.Lock() defer s.Unlock() host.Flow.ExportFlow += out @@ -36,7 +39,7 @@ func (s *server) FlowAddHost(host *lib.Host, in, out int64) { //热更新配置 func (s *server) ResetConfig() bool { //获取最新数据 - task, err := lib.GetCsvDb().GetTask(s.task.Id) + task, err := file.GetCsvDb().GetTask(s.task.Id) if err != nil { return false } @@ -45,7 +48,7 @@ func (s *server) ResetConfig() bool { } s.task.UseClientCnf = task.UseClientCnf //使用客户端配置 - client, err := lib.GetCsvDb().GetClient(s.task.Client.Id) + client, err := file.GetCsvDb().GetClient(s.task.Client.Id) if s.task.UseClientCnf { if err == nil { s.config.U = client.Cnf.U @@ -62,11 +65,11 @@ func (s *server) ResetConfig() bool { } } s.task.Client.Rate = client.Rate - s.config.CompressDecode, s.config.CompressEncode = lib.GetCompressType(s.config.Compress) + s.config.CompressDecode, s.config.CompressEncode = common.GetCompressType(s.config.Compress) return true } -func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Conn, flow *lib.Flow) { +func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) { if rb != nil { if _, err := tunnel.SendMsg(rb, link); err != nil { c.Close() @@ -74,32 +77,32 @@ func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Co } flow.Add(len(rb), 0) } + + buf 
:= pool.BufPoolCopy.Get().([]byte) for { - buf := lib.BufPoolCopy.Get().([]byte) if n, err := c.Read(buf); err != nil { - tunnel.SendMsg([]byte(lib.IO_EOF), link) + tunnel.SendMsg([]byte(common.IO_EOF), link) break } else { if _, err := tunnel.SendMsg(buf[:n], link); err != nil { - lib.PutBufPoolCopy(buf) c.Close() break } - lib.PutBufPoolCopy(buf) flow.Add(n, 0) } } + pool.PutBufPoolCopy(buf) } func (s *server) writeConnFail(c net.Conn) { - c.Write([]byte(lib.ConnectionFailBytes)) + c.Write([]byte(common.ConnectionFailBytes)) c.Write(s.errorContent) } //权限认证 -func (s *server) auth(r *http.Request, c *lib.Conn, u, p string) error { - if u != "" && p != "" && !lib.CheckAuth(r, u, p) { - c.Write([]byte(lib.UnauthorizedBytes)) +func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error { + if u != "" && p != "" && !common.CheckAuth(r, u, p) { + c.Write([]byte(common.UnauthorizedBytes)) c.Close() return errors.New("401 Unauthorized") } diff --git a/server/http.go b/server/http.go index e31688f..ffa8fb1 100644 --- a/server/http.go +++ b/server/http.go @@ -3,9 +3,13 @@ package server import ( "bufio" "crypto/tls" - "github.com/astaxie/beego" + "github.com/cnlh/nps/lib/beego" "github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/lg" + "github.com/cnlh/nps/lib/common" + "log" "net/http" "net/http/httputil" "path/filepath" @@ -22,7 +26,7 @@ type httpServer struct { stop chan bool } -func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer { +func NewHttp(bridge *bridge.Bridge, c *file.Tunnel) *httpServer { httpPort, _ := beego.AppConfig.Int("httpProxyPort") httpsPort, _ := beego.AppConfig.Int("httpsProxyPort") pemPath := beego.AppConfig.String("pemPath") @@ -44,33 +48,33 @@ func NewHttp(bridge *bridge.Bridge, c *lib.Tunnel) *httpServer { func (s *httpServer) Start() error { var err error var http, https *http.Server - if s.errorContent, err = 
lib.ReadAllFromFile(filepath.Join(lib.GetRunPath(), "web", "static", "page", "error.html")); err != nil { + if s.errorContent, err = common.ReadAllFromFile(filepath.Join(common.GetRunPath(), "web", "static", "page", "error.html")); err != nil { s.errorContent = []byte("easyProxy 404") } if s.httpPort > 0 { http = s.NewServer(s.httpPort) go func() { - lib.Println("启动http监听,端口为", s.httpPort) + lg.Println("启动http监听,端口为", s.httpPort) err := http.ListenAndServe() if err != nil { - lib.Fatalln(err) + lg.Fatalln(err) } }() } if s.httpsPort > 0 { - if !lib.FileExists(s.pemPath) { - lib.Fatalf("ssl certFile文件%s不存在", s.pemPath) + if !common.FileExists(s.pemPath) { + lg.Fatalf("ssl certFile文件%s不存在", s.pemPath) } - if !lib.FileExists(s.keyPath) { - lib.Fatalf("ssl keyFile文件%s不存在", s.keyPath) + if !common.FileExists(s.keyPath) { + lg.Fatalf("ssl keyFile文件%s不存在", s.keyPath) } https = s.NewServer(s.httpsPort) go func() { - lib.Println("启动https监听,端口为", s.httpsPort) + lg.Println("启动https监听,端口为", s.httpsPort) err := https.ListenAndServeTLS(s.pemPath, s.keyPath) if err != nil { - lib.Fatalln(err) + lg.Fatalln(err) } }() } @@ -96,40 +100,41 @@ func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) { http.Error(w, "Hijacking not supported", http.StatusInternalServerError) return } - conn, _, err := hijacker.Hijack() + c, _, err := hijacker.Hijack() if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) } - s.process(lib.NewConn(conn), r) + s.process(conn.NewConn(c), r) } -func (s *httpServer) process(c *lib.Conn, r *http.Request) { +func (s *httpServer) process(c *conn.Conn, r *http.Request) { //多客户端域名代理 var ( isConn = true - link *lib.Link - host *lib.Host - tunnel *lib.Conn + lk *conn.Link + host *file.Host + tunnel *conn.Conn err error ) for { //首次获取conn if isConn { if host, err = GetInfoByHost(r.Host); err != nil { - lib.Printf("the host %s is not found !", r.Host) + lg.Printf("the host %s is not found !", r.Host) break } //流量限制 if 
host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) { break } - host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = lib.GetCompressType(host.Client.Cnf.Compress) + host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = common.GetCompressType(host.Client.Cnf.Compress) //权限控制 if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil { break } - link = lib.NewLink(host.Client.GetId(), lib.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil) - if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, link); err != nil { + lk = conn.NewLink(host.Client.GetId(), common.CONN_TCP, host.GetRandomTarget(), host.Client.Cnf.CompressEncode, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, c, host.Flow, nil, host.Client.Rate, nil) + if tunnel, err = s.bridge.SendLinkInfo(host.Client.Id, lk); err != nil { + log.Println(err) break } isConn = false @@ -140,13 +145,13 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) { } } //根据设定,修改header和host - lib.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String()) + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String()) b, err := httputil.DumpRequest(r, true) if err != nil { break } host.Flow.Add(len(b), 0) - if _, err := tunnel.SendMsg(b, link); err != nil { + if _, err := tunnel.SendMsg(b, lk); err != nil { c.Close() break } @@ -155,7 +160,7 @@ func (s *httpServer) process(c *lib.Conn, r *http.Request) { if isConn { s.writeConnFail(c.Conn) } else { - tunnel.SendMsg([]byte(lib.IO_EOF), link) + tunnel.SendMsg([]byte(common.IO_EOF), lk) } c.Close() diff --git a/server/server.go b/server/server.go index e765938..7a3ed97 100644 --- a/server/server.go +++ b/server/server.go @@ -3,7 +3,8 @@ package server import ( "errors" 
"github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/lg" "reflect" "strings" ) @@ -11,44 +12,41 @@ import ( var ( Bridge *bridge.Bridge RunList map[int]interface{} //运行中的任务 - startFinish chan bool ) func init() { RunList = make(map[int]interface{}) - startFinish = make(chan bool) } //从csv文件中恢复任务 func InitFromCsv() { - for _, v := range lib.GetCsvDb().Tasks { + for _, v := range file.GetCsvDb().Tasks { if v.Status { - lib.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort) + lg.Println("启动模式:", v.Mode, "监听端口:", v.TcpPort) AddTask(v) } } } //start a new server -func StartNewServer(bridgePort int, cnf *lib.Tunnel) { - Bridge = bridge.NewTunnel(bridgePort, RunList) +func StartNewServer(bridgePort int, cnf *file.Tunnel, bridgeType string) { + Bridge = bridge.NewTunnel(bridgePort, RunList, bridgeType) if err := Bridge.StartTunnel(); err != nil { - lib.Fatalln("服务端开启失败", err) + lg.Fatalln("服务端开启失败", err) } if svr := NewMode(Bridge, cnf); svr != nil { RunList[cnf.Id] = svr err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0] if err.Interface() != nil { - lib.Fatalln(err) + lg.Fatalln(err) } } else { - lib.Fatalln("启动模式不正确") + lg.Fatalln("启动模式不正确") } - } //new a server by mode name -func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} { +func NewMode(Bridge *bridge.Bridge, c *file.Tunnel) interface{} { switch c.Mode { case "tunnelServer": return NewTunnelModeServer(ProcessTunnel, Bridge, c) @@ -60,17 +58,15 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) interface{} { return NewUdpModeServer(Bridge, c) case "webServer": InitFromCsv() - t := &lib.Tunnel{ + t := &file.Tunnel{ TcpPort: 0, Mode: "httpHostServer", Target: "", - Config: &lib.Config{}, + Config: &file.Config{}, Status: true, } AddTask(t) return NewWebServer(Bridge) - case "hostServer": - return NewHostServer(c) case "httpHostServer": return NewHttp(Bridge, c) } @@ -81,11 +77,11 @@ func NewMode(Bridge *bridge.Bridge, c *lib.Tunnel) 
interface{} { func StopServer(id int) error { if v, ok := RunList[id]; ok { reflect.ValueOf(v).MethodByName("Close").Call(nil) - if t, err := lib.GetCsvDb().GetTask(id); err != nil { + if t, err := file.GetCsvDb().GetTask(id); err != nil { return err } else { t.Status = false - lib.GetCsvDb().UpdateTask(t) + file.GetCsvDb().UpdateTask(t) } return nil } @@ -93,13 +89,13 @@ func StopServer(id int) error { } //add task -func AddTask(t *lib.Tunnel) error { +func AddTask(t *file.Tunnel) error { if svr := NewMode(Bridge, t); svr != nil { RunList[t.Id] = svr go func() { err := reflect.ValueOf(svr).MethodByName("Start").Call(nil)[0] if err.Interface() != nil { - lib.Fatalln("服务端", t.Id, "启动失败,错误:", err) + lg.Fatalln("客户端", t.Id, "启动失败,错误:", err) delete(RunList, t.Id) } }() @@ -111,12 +107,12 @@ func AddTask(t *lib.Tunnel) error { //start task func StartTask(id int) error { - if t, err := lib.GetCsvDb().GetTask(id); err != nil { + if t, err := file.GetCsvDb().GetTask(id); err != nil { return err } else { AddTask(t) t.Status = true - lib.GetCsvDb().UpdateTask(t) + file.GetCsvDb().UpdateTask(t) } return nil } @@ -126,12 +122,12 @@ func DelTask(id int) error { if err := StopServer(id); err != nil { return err } - return lib.GetCsvDb().DelTask(id) + return file.GetCsvDb().DelTask(id) } //get key by host from x -func GetInfoByHost(host string) (h *lib.Host, err error) { - for _, v := range lib.GetCsvDb().Hosts { +func GetInfoByHost(host string) (h *file.Host, err error) { + for _, v := range file.GetCsvDb().Hosts { s := strings.Split(host, ":") if s[0] == v.Host { h = v @@ -143,10 +139,10 @@ func GetInfoByHost(host string) (h *lib.Host, err error) { } //get task list by page num -func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel, int) { - list := make([]*lib.Tunnel, 0) +func GetTunnel(start, length int, typeVal string, clientId int) ([]*file.Tunnel, int) { + list := make([]*file.Tunnel, 0) var cnt int - for _, v := range lib.GetCsvDb().Tasks { + for 
_, v := range file.GetCsvDb().Tasks { if (typeVal != "" && v.Mode != typeVal) || (typeVal == "" && clientId != v.Client.Id) { continue } @@ -171,13 +167,13 @@ func GetTunnel(start, length int, typeVal string, clientId int) ([]*lib.Tunnel, } //获取客户端列表 -func GetClientList(start, length int) (list []*lib.Client, cnt int) { - list, cnt = lib.GetCsvDb().GetClientList(start, length) +func GetClientList(start, length int) (list []*file.Client, cnt int) { + list, cnt = file.GetCsvDb().GetClientList(start, length) dealClientData(list) return } -func dealClientData(list []*lib.Client) { +func dealClientData(list []*file.Client) { for _, v := range list { if _, ok := Bridge.Client[v.Id]; ok { v.IsConnect = true @@ -186,13 +182,13 @@ func dealClientData(list []*lib.Client) { } v.Flow.InletFlow = 0 v.Flow.ExportFlow = 0 - for _, h := range lib.GetCsvDb().Hosts { + for _, h := range file.GetCsvDb().Hosts { if h.Client.Id == v.Id { v.Flow.InletFlow += h.Flow.InletFlow v.Flow.ExportFlow += h.Flow.ExportFlow } } - for _, t := range lib.GetCsvDb().Tasks { + for _, t := range file.GetCsvDb().Tasks { if t.Client.Id == v.Id { v.Flow.InletFlow += t.Flow.InletFlow v.Flow.ExportFlow += t.Flow.ExportFlow @@ -204,14 +200,14 @@ func dealClientData(list []*lib.Client) { //根据客户端id删除其所属的所有隧道和域名 func DelTunnelAndHostByClientId(clientId int) { - for _, v := range lib.GetCsvDb().Tasks { + for _, v := range file.GetCsvDb().Tasks { if v.Client.Id == clientId { DelTask(v.Id) } } - for _, v := range lib.GetCsvDb().Hosts { + for _, v := range file.GetCsvDb().Hosts { if v.Client.Id == clientId { - lib.GetCsvDb().DelHost(v.Host) + file.GetCsvDb().DelHost(v.Host) } } } @@ -223,9 +219,9 @@ func DelClientConnect(clientId int) { func GetDashboardData() map[string]int { data := make(map[string]int) - data["hostCount"] = len(lib.GetCsvDb().Hosts) - data["clientCount"] = len(lib.GetCsvDb().Clients) - list := lib.GetCsvDb().Clients + data["hostCount"] = len(file.GetCsvDb().Hosts) + data["clientCount"] = 
len(file.GetCsvDb().Clients) + list := file.GetCsvDb().Clients dealClientData(list) c := 0 var in, out int64 @@ -239,7 +235,7 @@ func GetDashboardData() map[string]int { data["clientOnlineCount"] = c data["inletFlowCount"] = int(in) data["exportFlowCount"] = int(out) - for _, v := range lib.GetCsvDb().Tasks { + for _, v := range file.GetCsvDb().Tasks { switch v.Mode { case "tunnelServer": data["tunnelServerCount"] += 1 diff --git a/server/socks5.go b/server/socks5.go index cc491d6..f440f02 100755 --- a/server/socks5.go +++ b/server/socks5.go @@ -4,7 +4,10 @@ import ( "encoding/binary" "errors" "github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/lg" "io" "net" "strconv" @@ -65,7 +68,7 @@ func (s *Sock5ModeServer) handleRequest(c net.Conn) { _, err := io.ReadFull(c, header) if err != nil { - lib.Println("illegal request", err) + lg.Println("illegal request", err) c.Close() return } @@ -135,18 +138,18 @@ func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) { addr := net.JoinHostPort(host, strconv.Itoa(int(port))) var ltype string if command == associateMethod { - ltype = lib.CONN_UDP + ltype = common.CONN_UDP } else { - ltype = lib.CONN_TCP + ltype = common.CONN_TCP } - link := lib.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, lib.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil) + link := conn.NewLink(s.task.Client.GetId(), ltype, addr, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, conn.NewConn(c), s.task.Flow, nil, s.task.Client.Rate, nil) if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil { c.Close() return } else { s.sendReply(c, succeeded) - s.linkCopy(link, lib.NewConn(c), nil, tunnel, s.task.Flow) + s.linkCopy(link, conn.NewConn(c), nil, tunnel, s.task.Flow) } return } @@ -162,7 +165,7 @@ func (s 
*Sock5ModeServer) handleBind(c net.Conn) { //udp func (s *Sock5ModeServer) handleUDP(c net.Conn) { - lib.Println("UDP Associate") + lg.Println("UDP Associate") /* +----+------+------+----------+----------+----------+ |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | @@ -175,7 +178,7 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) { // relay udp datagram silently, without any notification to the requesting client if buf[2] != 0 { // does not support fragmentation, drop it - lib.Println("does not support fragmentation, drop") + lg.Println("does not support fragmentation, drop") dummy := make([]byte, maxUDPPacketSize) c.Read(dummy) } @@ -187,13 +190,13 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) { func (s *Sock5ModeServer) handleConn(c net.Conn) { buf := make([]byte, 2) if _, err := io.ReadFull(c, buf); err != nil { - lib.Println("negotiation err", err) + lg.Println("negotiation err", err) c.Close() return } if version := buf[0]; version != 5 { - lib.Println("only support socks5, request from: ", c.RemoteAddr()) + lg.Println("only support socks5, request from: ", c.RemoteAddr()) c.Close() return } @@ -201,7 +204,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) { methods := make([]byte, nMethods) if len, err := c.Read(methods); len != int(nMethods) || err != nil { - lib.Println("wrong method") + lg.Println("wrong method") c.Close() return } @@ -210,7 +213,7 @@ func (s *Sock5ModeServer) handleConn(c net.Conn) { c.Write(buf) if err := s.Auth(c); err != nil { c.Close() - lib.Println("验证失败:", err) + lg.Println("验证失败:", err) return } } else { @@ -269,7 +272,7 @@ func (s *Sock5ModeServer) Start() error { if strings.Contains(err.Error(), "use of closed network connection") { break } - lib.Fatalln("accept error: ", err) + lg.Fatalln("accept error: ", err) } if !s.ResetConfig() { conn.Close() @@ -286,11 +289,11 @@ func (s *Sock5ModeServer) Close() error { } //new -func NewSock5ModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *Sock5ModeServer { +func 
NewSock5ModeServer(bridge *bridge.Bridge, task *file.Tunnel) *Sock5ModeServer { s := new(Sock5ModeServer) s.bridge = bridge s.task = task - s.config = lib.DeepCopyConfig(task.Config) + s.config = file.DeepCopyConfig(task.Config) if s.config.U != "" && s.config.P != "" { s.isVerify = true } else { diff --git a/server/tcp.go b/server/tcp.go index d9d1de3..8d9d0ac 100755 --- a/server/tcp.go +++ b/server/tcp.go @@ -2,9 +2,12 @@ package server import ( "errors" - "github.com/astaxie/beego" + "github.com/cnlh/nps/lib/beego" "github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/lg" "net" "path/filepath" "strings" @@ -17,12 +20,12 @@ type TunnelModeServer struct { } //tcp|http|host -func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *lib.Tunnel) *TunnelModeServer { +func NewTunnelModeServer(process process, bridge *bridge.Bridge, task *file.Tunnel) *TunnelModeServer { s := new(TunnelModeServer) s.bridge = bridge s.process = process s.task = task - s.config = lib.DeepCopyConfig(task.Config) + s.config = file.DeepCopyConfig(task.Config) return s } @@ -34,22 +37,22 @@ func (s *TunnelModeServer) Start() error { return err } for { - conn, err := s.listener.AcceptTCP() + c, err := s.listener.AcceptTCP() if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { break } - lib.Println(err) + lg.Println(err) continue } - go s.process(lib.NewConn(conn), s) + go s.process(conn.NewConn(c), s) } return nil } //与客户端建立通道 -func (s *TunnelModeServer) dealClient(c *lib.Conn, cnf *lib.Config, addr string, method string, rb []byte) error { - link := lib.NewLink(s.task.Client.GetId(), lib.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil) +func (s *TunnelModeServer) dealClient(c *conn.Conn, cnf *file.Config, addr string, method string, rb []byte) error { + 
link := conn.NewLink(s.task.Client.GetId(), common.CONN_TCP, addr, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil) if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil { c.Close() @@ -73,13 +76,13 @@ type WebServer struct { //开始 func (s *WebServer) Start() error { p, _ := beego.AppConfig.Int("httpport") - if !lib.TestTcpPort(p) { - lib.Fatalln("web管理端口", p, "被占用!") + if !common.TestTcpPort(p) { + lg.Fatalln("web管理端口", p, "被占用!") } beego.BConfig.WebConfig.Session.SessionOn = true - lib.Println("web管理启动,访问端口为", beego.AppConfig.String("httpport")) - beego.SetStaticPath("/static", filepath.Join(lib.GetRunPath(), "web", "static")) - beego.SetViewsPath(filepath.Join(lib.GetRunPath(), "web", "views")) + lg.Println("web管理启动,访问端口为", p) + beego.SetStaticPath("/static", filepath.Join(common.GetRunPath(), "web", "static")) + beego.SetViewsPath(filepath.Join(common.GetRunPath(), "web", "views")) beego.Run() return errors.New("web管理启动失败") } @@ -91,32 +94,10 @@ func NewWebServer(bridge *bridge.Bridge) *WebServer { return s } -//host -type HostServer struct { - server -} - -//开始 -func (s *HostServer) Start() error { - return nil -} - -func NewHostServer(task *lib.Tunnel) *HostServer { - s := new(HostServer) - s.task = task - s.config = lib.DeepCopyConfig(task.Config) - return s -} - -//close -func (s *HostServer) Close() error { - return nil -} - -type process func(c *lib.Conn, s *TunnelModeServer) error +type process func(c *conn.Conn, s *TunnelModeServer) error //tcp隧道模式 -func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error { +func ProcessTunnel(c *conn.Conn, s *TunnelModeServer) error { if !s.ResetConfig() { c.Close() return errors.New("流量超出") @@ -125,7 +106,7 @@ func ProcessTunnel(c *lib.Conn, s *TunnelModeServer) error { } //http代理模式 -func ProcessHttp(c *lib.Conn, s *TunnelModeServer) error { +func ProcessHttp(c *conn.Conn, s *TunnelModeServer) error { if !s.ResetConfig() { c.Close() return 
errors.New("流量超出") diff --git a/server/test.go b/server/test.go index 358ef0c..341852d 100644 --- a/server/test.go +++ b/server/test.go @@ -1,54 +1,78 @@ package server import ( - "github.com/astaxie/beego" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/beego" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/file" "log" "strconv" ) func TestServerConfig() { - var postArr []int - for _, v := range lib.GetCsvDb().Tasks { - isInArr(&postArr, v.TcpPort, v.Remark) + var postTcpArr []int + var postUdpArr []int + for _, v := range file.GetCsvDb().Tasks { + if v.Mode == "udpServer" { + isInArr(&postUdpArr, v.TcpPort, v.Remark, "udp") + } else { + isInArr(&postTcpArr, v.TcpPort, v.Remark, "tcp") + } } p, err := beego.AppConfig.Int("httpport") if err != nil { log.Fatalln("Getting web management port error :", err) } else { - isInArr(&postArr, p, "WebmManagement port") + isInArr(&postTcpArr, p, "Web Management port", "tcp") } + + if p := beego.AppConfig.String("bridgePort"); p != "" { + if port, err := strconv.Atoi(p); err != nil { + log.Fatalln("get Server and client communication ports error:", err) + } else if beego.AppConfig.String("bridgeType") == "kcp" { + isInArr(&postUdpArr, port, "Server and client communication ports", "udp") + } else { + isInArr(&postTcpArr, port, "Server and client communication ports", "tcp") + } + } + if p := beego.AppConfig.String("httpProxyPort"); p != "" { if port, err := strconv.Atoi(p); err != nil { log.Fatalln("get http port error:", err) } else { - isInArr(&postArr, port, "https port") + isInArr(&postTcpArr, port, "http port", "tcp") } } if p := beego.AppConfig.String("httpsProxyPort"); p != "" { if port, err := strconv.Atoi(p); err != nil { log.Fatalln("get https port error", err) } else { - if !lib.FileExists(beego.AppConfig.String("pemPath")) { log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath")) } - if 
!lib.FileExists(beego.AppConfig.String("ketPath")) { + if !common.FileExists(beego.AppConfig.String("keyPath")) { log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath")) } - isInArr(&postArr, port, "http port") + isInArr(&postTcpArr, port, "https port", "tcp") } } } -func isInArr(arr *[]int, val int, remark string) { +func isInArr(arr *[]int, val int, remark string, tp string) { for _, v := range *arr { if v == val { log.Fatalf("the port %d is reused,remark: %s", val, remark) } } - if !lib.TestTcpPort(val) { - log.Fatalf("open the %d port error ,remark: %s", val, remark) + if tp == "tcp" { + if !common.TestTcpPort(val) { + log.Fatalf("open the %d port error ,remark: %s", val, remark) + } + } else { + if !common.TestUdpPort(val) { + log.Fatalf("open the %d port error ,remark: %s", val, remark) + } } + *arr = append(*arr, val) return } diff --git a/server/udp.go b/server/udp.go index 9fc39c7..039c3fc 100755 --- a/server/udp.go +++ b/server/udp.go @@ -2,7 +2,10 @@ package server import ( "github.com/cnlh/nps/bridge" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/common" + "github.com/cnlh/nps/lib/conn" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/pool" "net" "strings" ) @@ -10,15 +13,15 @@ import ( type UdpModeServer struct { server listener *net.UDPConn - udpMap map[string]*lib.Conn + udpMap map[string]*conn.Conn } -func NewUdpModeServer(bridge *bridge.Bridge, task *lib.Tunnel) *UdpModeServer { +func NewUdpModeServer(bridge *bridge.Bridge, task *file.Tunnel) *UdpModeServer { s := new(UdpModeServer) s.bridge = bridge - s.udpMap = make(map[string]*lib.Conn) + s.udpMap = make(map[string]*conn.Conn) s.task = task - s.config = lib.DeepCopyConfig(task.Config) + s.config = file.DeepCopyConfig(task.Config) return s } @@ -29,7 +32,7 @@ func (s *UdpModeServer) Start() error { if err != nil { return err } - buf := lib.BufPoolUdp.Get().([]byte) + buf := pool.BufPoolUdp.Get().([]byte) for { n, addr, err := s.listener.ReadFromUDP(buf) if 
err != nil { @@ -47,13 +50,14 @@ func (s *UdpModeServer) Start() error { } func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) { - link := lib.NewLink(s.task.Client.GetId(), lib.CONN_UDP, s.task.Target, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, nil, s.task.Flow, s.listener, s.task.Client.Rate, addr) + link := conn.NewLink(s.task.Client.GetId(), common.CONN_UDP, s.task.Target, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, nil, s.task.Flow, s.listener, s.task.Client.Rate, addr) if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link); err != nil { return } else { s.task.Flow.Add(len(data), 0) tunnel.SendMsg(data, link) + pool.PutBufPoolUdp(data) } } diff --git a/web/controllers/base.go b/web/controllers/base.go index 6039af5..0431ac5 100755 --- a/web/controllers/base.go +++ b/web/controllers/base.go @@ -1,8 +1,8 @@ package controllers import ( - "github.com/astaxie/beego" - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/beego" + "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/server" "strconv" "strings" @@ -40,7 +40,7 @@ func (s *BaseController) display(tpl ...string) { } ip := s.Ctx.Request.Host if strings.LastIndex(ip, ":") > 0 { - arr := strings.Split(lib.GetHostByName(ip), ":") + arr := strings.Split(common.GetHostByName(ip), ":") s.Data["ip"] = arr[0] } s.Data["p"] = server.Bridge.TunnelPort diff --git a/web/controllers/client.go b/web/controllers/client.go index 85cc43e..907957f 100644 --- a/web/controllers/client.go +++ b/web/controllers/client.go @@ -1,7 +1,9 @@ package controllers import ( - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/crypt" + "github.com/cnlh/nps/lib/file" + "github.com/cnlh/nps/lib/rate" "github.com/cnlh/nps/server" ) @@ -28,29 +30,29 @@ func (s *ClientController) Add() { s.SetInfo("新增") s.display() } else { - t := &lib.Client{ - VerifyKey: lib.GetRandomString(16), - Id: lib.GetCsvDb().GetClientId(), + t := &file.Client{ + VerifyKey: 
crypt.GetRandomString(16), + Id: file.GetCsvDb().GetClientId(), Status: true, Remark: s.GetString("remark"), - Cnf: &lib.Config{ + Cnf: &file.Config{ U: s.GetString("u"), P: s.GetString("p"), Compress: s.GetString("compress"), Crypt: s.GetBoolNoErr("crypt"), }, RateLimit: s.GetIntNoErr("rate_limit"), - Flow: &lib.Flow{ + Flow: &file.Flow{ ExportFlow: 0, InletFlow: 0, FlowLimit: int64(s.GetIntNoErr("flow_limit")), }, } if t.RateLimit > 0 { - t.Rate = lib.NewRate(int64(t.RateLimit * 1024)) + t.Rate = rate.NewRate(int64(t.RateLimit * 1024)) t.Rate.Start() } - lib.GetCsvDb().NewClient(t) + file.GetCsvDb().NewClient(t) s.AjaxOk("添加成功") } } @@ -58,7 +60,7 @@ func (s *ClientController) GetClient() { if s.Ctx.Request.Method == "POST" { id := s.GetIntNoErr("id") data := make(map[string]interface{}) - if c, err := lib.GetCsvDb().GetClient(id); err != nil { + if c, err := file.GetCsvDb().GetClient(id); err != nil { data["code"] = 0 } else { data["code"] = 1 @@ -74,7 +76,7 @@ func (s *ClientController) Edit() { id := s.GetIntNoErr("id") if s.Ctx.Request.Method == "GET" { s.Data["menu"] = "client" - if c, err := lib.GetCsvDb().GetClient(id); err != nil { + if c, err := file.GetCsvDb().GetClient(id); err != nil { s.error() } else { s.Data["c"] = c @@ -82,7 +84,7 @@ func (s *ClientController) Edit() { s.SetInfo("修改") s.display() } else { - if c, err := lib.GetCsvDb().GetClient(id); err != nil { + if c, err := file.GetCsvDb().GetClient(id); err != nil { s.error() } else { c.Remark = s.GetString("remark") @@ -96,12 +98,12 @@ func (s *ClientController) Edit() { c.Rate.Stop() } if c.RateLimit > 0 { - c.Rate = lib.NewRate(int64(c.RateLimit * 1024)) + c.Rate = rate.NewRate(int64(c.RateLimit * 1024)) c.Rate.Start() } else { c.Rate = nil } - lib.GetCsvDb().UpdateClient(c) + file.GetCsvDb().UpdateClient(c) } s.AjaxOk("修改成功") } @@ -110,7 +112,7 @@ func (s *ClientController) Edit() { //更改状态 func (s *ClientController) ChangeStatus() { id := s.GetIntNoErr("id") - if client, err := 
lib.GetCsvDb().GetClient(id); err == nil { + if client, err := file.GetCsvDb().GetClient(id); err == nil { client.Status = s.GetBoolNoErr("status") if client.Status == false { server.DelClientConnect(client.Id) @@ -123,7 +125,7 @@ func (s *ClientController) ChangeStatus() { //删除客户端 func (s *ClientController) Del() { id := s.GetIntNoErr("id") - if err := lib.GetCsvDb().DelClient(id); err != nil { + if err := file.GetCsvDb().DelClient(id); err != nil { s.AjaxErr("删除失败") } server.DelTunnelAndHostByClientId(id) diff --git a/web/controllers/index.go b/web/controllers/index.go index ea25f18..462785b 100755 --- a/web/controllers/index.go +++ b/web/controllers/index.go @@ -1,7 +1,7 @@ package controllers import ( - "github.com/cnlh/nps/lib" + "github.com/cnlh/nps/lib/file" "github.com/cnlh/nps/server" ) @@ -72,27 +72,27 @@ func (s *IndexController) Add() { s.SetInfo("新增") s.display() } else { - t := &lib.Tunnel{ + t := &file.Tunnel{ TcpPort: s.GetIntNoErr("port"), Mode: s.GetString("type"), Target: s.GetString("target"), - Config: &lib.Config{ + Config: &file.Config{ U: s.GetString("u"), P: s.GetString("p"), Compress: s.GetString("compress"), Crypt: s.GetBoolNoErr("crypt"), }, - Id: lib.GetCsvDb().GetTaskId(), + Id: file.GetCsvDb().GetTaskId(), UseClientCnf: s.GetBoolNoErr("use_client"), Status: true, Remark: s.GetString("remark"), - Flow: &lib.Flow{}, + Flow: &file.Flow{}, } var err error - if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { + if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { s.AjaxErr(err.Error()) } - lib.GetCsvDb().NewTask(t) + file.GetCsvDb().NewTask(t) if err := server.AddTask(t); err != nil { s.AjaxErr(err.Error()) } else { @@ -103,7 +103,7 @@ func (s *IndexController) Add() { func (s *IndexController) GetOneTunnel() { id := s.GetIntNoErr("id") data := make(map[string]interface{}) - if t, err := lib.GetCsvDb().GetTask(id); err != nil { + if t, err := file.GetCsvDb().GetTask(id); 
err != nil { data["code"] = 0 } else { data["code"] = 1 @@ -115,7 +115,7 @@ func (s *IndexController) GetOneTunnel() { func (s *IndexController) Edit() { id := s.GetIntNoErr("id") if s.Ctx.Request.Method == "GET" { - if t, err := lib.GetCsvDb().GetTask(id); err != nil { + if t, err := file.GetCsvDb().GetTask(id); err != nil { s.error() } else { s.Data["t"] = t @@ -123,7 +123,7 @@ func (s *IndexController) Edit() { s.SetInfo("修改") s.display() } else { - if t, err := lib.GetCsvDb().GetTask(id); err != nil { + if t, err := file.GetCsvDb().GetTask(id); err != nil { s.error() } else { t.TcpPort = s.GetIntNoErr("port") @@ -137,10 +137,10 @@ func (s *IndexController) Edit() { t.Config.Crypt = s.GetBoolNoErr("crypt") t.UseClientCnf = s.GetBoolNoErr("use_client") t.Remark = s.GetString("remark") - if t.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { + if t.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { s.AjaxErr("修改失败") } - lib.GetCsvDb().UpdateTask(t) + file.GetCsvDb().UpdateTask(t) } s.AjaxOk("修改成功") } @@ -179,7 +179,7 @@ func (s *IndexController) HostList() { } else { start, length := s.GetAjaxParams() clientId := s.GetIntNoErr("client_id") - list, cnt := lib.GetCsvDb().GetHost(start, length, clientId) + list, cnt := file.GetCsvDb().GetHost(start, length, clientId) s.AjaxTable(list, cnt, cnt) } } @@ -200,7 +200,7 @@ func (s *IndexController) GetHost() { func (s *IndexController) DelHost() { host := s.GetString("host") - if err := lib.GetCsvDb().DelHost(host); err != nil { + if err := file.GetCsvDb().DelHost(host); err != nil { s.AjaxErr("删除失败") } s.AjaxOk("删除成功") @@ -213,19 +213,19 @@ func (s *IndexController) AddHost() { s.SetInfo("新增") s.display("index/hadd") } else { - h := &lib.Host{ + h := &file.Host{ Host: s.GetString("host"), Target: s.GetString("target"), HeaderChange: s.GetString("header"), HostChange: s.GetString("hostchange"), Remark: s.GetString("remark"), - Flow: &lib.Flow{}, + Flow: 
&file.Flow{}, } var err error - if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { + if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { s.AjaxErr("添加失败") } - lib.GetCsvDb().NewHost(h) + file.GetCsvDb().NewHost(h) s.AjaxOk("添加成功") } } @@ -251,9 +251,9 @@ func (s *IndexController) EditHost() { h.HostChange = s.GetString("hostchange") h.Remark = s.GetString("remark") h.TargetArr = nil - lib.GetCsvDb().UpdateHost(h) + file.GetCsvDb().UpdateHost(h) var err error - if h.Client, err = lib.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { + if h.Client, err = file.GetCsvDb().GetClient(s.GetIntNoErr("client_id")); err != nil { s.AjaxErr("修改失败") } } diff --git a/web/controllers/login.go b/web/controllers/login.go index fb0cc5d..53ea676 100755 --- a/web/controllers/login.go +++ b/web/controllers/login.go @@ -1,7 +1,7 @@ package controllers import ( - "github.com/astaxie/beego" + "github.com/cnlh/nps/lib/beego" ) type LoginController struct { diff --git a/web/routers/router.go b/web/routers/router.go index 5c820e2..869be07 100755 --- a/web/routers/router.go +++ b/web/routers/router.go @@ -1,7 +1,7 @@ package routers import ( - "github.com/astaxie/beego" + "github.com/cnlh/nps/lib/beego" "github.com/cnlh/nps/web/controllers" )