mirror of
https://github.com/ehang-io/nps.git
synced 2025-09-02 11:56:53 +00:00
客户端服务端分离
This commit is contained in:
354
utils/conn.go
Executable file
354
utils/conn.go
Executable file
@@ -0,0 +1,354 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/golang/snappy"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const cryptKey = "1234567812345678"
|
||||
|
||||
type CryptConn struct {
|
||||
conn net.Conn
|
||||
crypt bool
|
||||
}
|
||||
|
||||
func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
|
||||
c := new(CryptConn)
|
||||
c.conn = conn
|
||||
c.crypt = crypt
|
||||
return c
|
||||
}
|
||||
|
||||
//加密写
|
||||
func (s *CryptConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if b, err = GetLenBytes(b); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = s.conn.Write(b)
|
||||
return
|
||||
}
|
||||
|
||||
//解密读
|
||||
func (s *CryptConn) Read(b []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
|
||||
err = io.EOF
|
||||
n = 0
|
||||
}
|
||||
}()
|
||||
var lens int
|
||||
var buf, bs []byte
|
||||
c := NewConn(s.conn)
|
||||
if lens, err = c.GetLen(); err != nil {
|
||||
return
|
||||
}
|
||||
if buf, err = c.ReadLen(lens); err != nil {
|
||||
return
|
||||
}
|
||||
if s.crypt {
|
||||
if bs, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
bs = buf
|
||||
}
|
||||
n = len(bs)
|
||||
copy(b, bs)
|
||||
return
|
||||
}
|
||||
|
||||
type SnappyConn struct {
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
crypt bool
|
||||
}
|
||||
|
||||
func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
|
||||
c := new(SnappyConn)
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
c.crypt = crypt
|
||||
return c
|
||||
}
|
||||
|
||||
//snappy压缩写 包含加密
|
||||
func (s *SnappyConn) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
if s.crypt {
|
||||
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
|
||||
log.Println("encode crypt error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if _, err = s.w.Write(b); err != nil {
|
||||
return
|
||||
}
|
||||
err = s.w.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
//snappy压缩读 包含解密
|
||||
func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
|
||||
err = io.EOF
|
||||
n = 0
|
||||
}
|
||||
}()
|
||||
if n, err = s.r.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
if s.crypt {
|
||||
var bs []byte
|
||||
if bs, err = AesDecrypt(b[:n], []byte(cryptKey)); err != nil {
|
||||
log.Println("decode crypt error:", err)
|
||||
return
|
||||
}
|
||||
n = len(bs)
|
||||
copy(b, bs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Conn net.Conn
|
||||
}
|
||||
|
||||
//new conn
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
c := new(Conn)
|
||||
c.Conn = conn
|
||||
return c
|
||||
}
|
||||
|
||||
//读取指定长度内容
|
||||
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
|
||||
if cLen > 65535 {
|
||||
return nil, errors.New("长度错误")
|
||||
}
|
||||
buf := bufPool.Get().([]byte)[:cLen]
|
||||
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
|
||||
return buf, errors.New("读取指定长度错误" + err.Error())
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
//获取长度
|
||||
func (s *Conn) GetLen() (int, error) {
|
||||
val, err := s.ReadLen(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return GetLenByBytes(val)
|
||||
}
|
||||
|
||||
//写入长度+内容 粘包
|
||||
func (s *Conn) WriteLen(buf []byte) (int, error) {
|
||||
var b []byte
|
||||
var err error
|
||||
if b, err = GetLenBytes(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Write(b)
|
||||
}
|
||||
|
||||
//读取flag
|
||||
func (s *Conn) ReadFlag() (string, error) {
|
||||
val, err := s.ReadLen(4)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(val), err
|
||||
}
|
||||
|
||||
//读取host 连接地址 压缩类型
|
||||
func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
|
||||
retry:
|
||||
lType, err := s.ReadLen(3)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if typeStr = string(lType); typeStr == TEST_FLAG {
|
||||
en, de, crypt, mux = s.GetConnInfoFromConn()
|
||||
goto retry
|
||||
}
|
||||
cLen, err := s.GetLen()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hostByte, err := s.ReadLen(cLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
host = string(hostByte)
|
||||
return
|
||||
}
|
||||
|
||||
//写连接类型 和 host地址
|
||||
func (s *Conn) WriteHost(ltype string, host string) (int, error) {
|
||||
raw := bytes.NewBuffer([]byte{})
|
||||
binary.Write(raw, binary.LittleEndian, []byte(ltype))
|
||||
binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
|
||||
binary.Write(raw, binary.LittleEndian, []byte(host))
|
||||
return s.Write(raw.Bytes())
|
||||
}
|
||||
|
||||
//设置连接为长连接
|
||||
func (s *Conn) SetAlive() {
|
||||
conn := s.Conn.(*net.TCPConn)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
conn.SetKeepAlive(true)
|
||||
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
||||
}
|
||||
|
||||
//从tcp报文中解析出host,连接类型等
|
||||
func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
|
||||
var b [32 * 1024]byte
|
||||
var n int
|
||||
if n, err = s.Read(b[:]); err != nil {
|
||||
return
|
||||
}
|
||||
rb = b[:n]
|
||||
r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
|
||||
if err != nil {
|
||||
log.Println("解析host出错:", err)
|
||||
return
|
||||
}
|
||||
hostPortURL, err := url.Parse(r.Host)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if hostPortURL.Opaque == "443" { //https访问
|
||||
address = r.Host + ":443"
|
||||
} else { //http访问
|
||||
if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
|
||||
address = r.Host + ":80"
|
||||
} else {
|
||||
address = r.Host
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//单独读(加密|压缩)
|
||||
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
|
||||
if COMPRESS_SNAPY_DECODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt).Read(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt).Read(b)
|
||||
}
|
||||
|
||||
//单独写(加密|压缩)
|
||||
func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
|
||||
if COMPRESS_SNAPY_ENCODE == compress {
|
||||
return NewSnappyConn(s.Conn, crypt).Write(b)
|
||||
}
|
||||
return NewCryptConn(s.Conn, crypt).Write(b)
|
||||
}
|
||||
|
||||
//写压缩方式,加密
|
||||
func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
|
||||
s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
|
||||
}
|
||||
|
||||
//获取压缩方式,是否加密
|
||||
func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
|
||||
buf, err := s.ReadLen(4)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
en, _ = strconv.Atoi(string(buf[0]))
|
||||
de, _ = strconv.Atoi(string(buf[1]))
|
||||
crypt = GetBoolByStr(string(buf[2]))
|
||||
mux = GetBoolByStr(string(buf[3]))
|
||||
return
|
||||
}
|
||||
|
||||
//close
|
||||
func (s *Conn) Close() error {
|
||||
return s.Conn.Close()
|
||||
}
|
||||
|
||||
//write
|
||||
func (s *Conn) Write(b []byte) (int, error) {
|
||||
return s.Conn.Write(b)
|
||||
}
|
||||
|
||||
//read
|
||||
func (s *Conn) Read(b []byte) (int, error) {
|
||||
return s.Conn.Read(b)
|
||||
}
|
||||
|
||||
//write error
|
||||
func (s *Conn) WriteError() (int, error) {
|
||||
return s.Write([]byte(RES_MSG))
|
||||
}
|
||||
|
||||
//write sign flag
|
||||
func (s *Conn) WriteSign() (int, error) {
|
||||
return s.Write([]byte(RES_SIGN))
|
||||
}
|
||||
|
||||
//write main
|
||||
func (s *Conn) WriteMain() (int, error) {
|
||||
return s.Write([]byte(WORK_MAIN))
|
||||
}
|
||||
|
||||
//write chan
|
||||
func (s *Conn) WriteChan() (int, error) {
|
||||
return s.Write([]byte(WORK_CHAN))
|
||||
}
|
||||
|
||||
//write test
|
||||
func (s *Conn) WriteTest() (int, error) {
|
||||
return s.Write([]byte(TEST_FLAG))
|
||||
}
|
||||
|
||||
//write test
|
||||
func (s *Conn) WriteSuccess() (int, error) {
|
||||
return s.Write([]byte(CONN_SUCCESS))
|
||||
}
|
||||
|
||||
//write test
|
||||
func (s *Conn) WriteFail() (int, error) {
|
||||
return s.Write([]byte(CONN_ERROR))
|
||||
}
|
||||
|
||||
//获取长度+内容
|
||||
func GetLenBytes(buf []byte) (b []byte, err error) {
|
||||
raw := bytes.NewBuffer([]byte{})
|
||||
if err = binary.Write(raw, binary.LittleEndian, int32(len(buf))); err != nil {
|
||||
return
|
||||
}
|
||||
if err = binary.Write(raw, binary.LittleEndian, buf); err != nil {
|
||||
return
|
||||
}
|
||||
b = raw.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
//解析出长度
|
||||
func GetLenByBytes(buf []byte) (int, error) {
|
||||
nlen := binary.LittleEndian.Uint32(buf)
|
||||
if nlen <= 0 {
|
||||
return 0, errors.New("数据长度错误")
|
||||
}
|
||||
return int(nlen), nil
|
||||
}
|
82
utils/crypt.go
Normal file
82
utils/crypt.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
//en
|
||||
func AesEncrypt(origData, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
origData = PKCS5Padding(origData, blockSize)
|
||||
// origData = ZeroPadding(origData, block.BlockSize())
|
||||
blockMode := cipher.NewCBCEncrypter(block, key[:blockSize])
|
||||
crypted := make([]byte, len(origData))
|
||||
// 根据CryptBlocks方法的说明,如下方式初始化crypted也可以
|
||||
// crypted := origData
|
||||
blockMode.CryptBlocks(crypted, origData)
|
||||
return crypted, nil
|
||||
}
|
||||
|
||||
//de
|
||||
func AesDecrypt(crypted, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
|
||||
origData := make([]byte, len(crypted))
|
||||
// origData := crypted
|
||||
blockMode.CryptBlocks(origData, crypted)
|
||||
err, origData = PKCS5UnPadding(origData)
|
||||
// origData = ZeroUnPadding(origData)
|
||||
return origData, err
|
||||
}
|
||||
|
||||
//补全
|
||||
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(ciphertext)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padtext...)
|
||||
}
|
||||
|
||||
//去补
|
||||
func PKCS5UnPadding(origData []byte) (error, []byte) {
|
||||
length := len(origData)
|
||||
// 去掉最后一个字节 unpadding 次
|
||||
unpadding := int(origData[length-1])
|
||||
if (length - unpadding) < 0 {
|
||||
return errors.New("len error"), nil
|
||||
}
|
||||
return nil, origData[:(length - unpadding)]
|
||||
}
|
||||
|
||||
//生成32位md5字串
|
||||
func Md5(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
//生成随机验证密钥
|
||||
func GetRandomString(l int) string {
|
||||
str := "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
bytes := []byte(str)
|
||||
result := []byte{}
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for i := 0; i < l; i++ {
|
||||
result = append(result, bytes[r.Intn(len(bytes))])
|
||||
}
|
||||
return string(result)
|
||||
}
|
204
utils/util.go
Executable file
204
utils/util.go
Executable file
@@ -0,0 +1,204 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
COMPRESS_NONE_ENCODE = iota
|
||||
COMPRESS_NONE_DECODE
|
||||
COMPRESS_SNAPY_ENCODE
|
||||
COMPRESS_SNAPY_DECODE
|
||||
VERIFY_EER = "vkey"
|
||||
WORK_MAIN = "main"
|
||||
WORK_CHAN = "chan"
|
||||
RES_SIGN = "sign"
|
||||
RES_MSG = "msg0"
|
||||
CONN_SUCCESS = "sucs"
|
||||
CONN_ERROR = "fail"
|
||||
TEST_FLAG = "tst"
|
||||
CONN_TCP = "tcp"
|
||||
CONN_UDP = "udp"
|
||||
Unauthorized_BYTES = `HTTP/1.1 401 Unauthorized
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
WWW-Authenticate: Basic realm="easyProxy"
|
||||
|
||||
401 Unauthorized`
|
||||
IO_EOF = "PROXYEOF"
|
||||
)
|
||||
|
||||
//copy
|
||||
func Relay(in, out net.Conn, compressType int, crypt, mux bool) {
|
||||
switch compressType {
|
||||
case COMPRESS_SNAPY_ENCODE:
|
||||
copyBuffer(NewSnappyConn(in, crypt), out)
|
||||
if mux {
|
||||
out.Close()
|
||||
NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
|
||||
}
|
||||
case COMPRESS_SNAPY_DECODE:
|
||||
copyBuffer(in, NewSnappyConn(out, crypt))
|
||||
if mux {
|
||||
in.Close()
|
||||
}
|
||||
case COMPRESS_NONE_ENCODE:
|
||||
copyBuffer(NewCryptConn(in, crypt), out)
|
||||
if mux {
|
||||
out.Close()
|
||||
NewCryptConn(in, crypt).Write([]byte(IO_EOF))
|
||||
}
|
||||
case COMPRESS_NONE_DECODE:
|
||||
copyBuffer(in, NewCryptConn(out, crypt))
|
||||
if mux {
|
||||
in.Close()
|
||||
}
|
||||
}
|
||||
if !mux {
|
||||
in.Close()
|
||||
out.Close()
|
||||
}
|
||||
}
|
||||
|
||||
//判断压缩方式
|
||||
func GetCompressType(compress string) (int, int) {
|
||||
switch compress {
|
||||
case "":
|
||||
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
|
||||
case "snappy":
|
||||
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
|
||||
default:
|
||||
log.Fatalln("数据压缩格式错误")
|
||||
}
|
||||
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
|
||||
}
|
||||
|
||||
//通过host获取对应的ip地址
|
||||
func Gethostbyname(hostname string) string {
|
||||
if !DomainCheck(hostname) {
|
||||
return hostname
|
||||
}
|
||||
ips, _ := net.LookupIP(hostname)
|
||||
if ips != nil {
|
||||
for _, v := range ips {
|
||||
if v.To4() != nil {
|
||||
return v.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
//检查是否是域名
|
||||
func DomainCheck(domain string) bool {
|
||||
var match bool
|
||||
IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)"
|
||||
NotLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}"
|
||||
match, _ = regexp.MatchString(IsLine, domain)
|
||||
if !match {
|
||||
match, _ = regexp.MatchString(NotLine, domain)
|
||||
}
|
||||
return match
|
||||
}
|
||||
|
||||
//检查basic认证
|
||||
func CheckAuth(r *http.Request, user, passwd string) bool {
|
||||
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
||||
if len(s) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return false
|
||||
}
|
||||
return pair[0] == user && pair[1] == passwd
|
||||
}
|
||||
|
||||
//get bool by str
|
||||
func GetBoolByStr(s string) bool {
|
||||
switch s {
|
||||
case "1", "true":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//get str by bool
|
||||
func GetStrByBool(b bool) string {
|
||||
if b {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
//int
|
||||
func GetIntNoerrByStr(str string) int {
|
||||
i, _ := strconv.Atoi(str)
|
||||
return i
|
||||
}
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 65535)
|
||||
},
|
||||
}
|
||||
// io.copy的优化版,读取buffer长度原为32*1024,与snappy不同,导致读取出的内容存在差异,不利于解密,特此修改
|
||||
func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
//TODO 回收问题
|
||||
buf := bufPool.Get().([]byte)
|
||||
for {
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
|
||||
//连接重置 清空缓存区
|
||||
func FlushConn(c net.Conn) {
|
||||
c.SetReadDeadline(time.Now().Add(time.Second * 3))
|
||||
buf := bufPool.Get().([]byte)
|
||||
for {
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
//简单的一个校验值
|
||||
func Getverifyval(vkey string) string {
|
||||
return Md5(vkey)
|
||||
}
|
Reference in New Issue
Block a user