客户端服务端分离

This commit is contained in:
刘河
2019-01-09 20:33:00 +08:00
parent dcd21f211d
commit 1f61b99387
46 changed files with 1062 additions and 1431 deletions

354
utils/conn.go Executable file
View File

@@ -0,0 +1,354 @@
package utils
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"github.com/golang/snappy"
"io"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const cryptKey = "1234567812345678"
type CryptConn struct {
conn net.Conn
crypt bool
}
func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
c := new(CryptConn)
c.conn = conn
c.crypt = crypt
return c
}
//加密写
func (s *CryptConn) Write(b []byte) (n int, err error) {
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
return
}
}
if b, err = GetLenBytes(b); err != nil {
return
}
_, err = s.conn.Write(b)
return
}
//解密读
func (s *CryptConn) Read(b []byte) (n int, err error) {
defer func() {
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
err = io.EOF
n = 0
}
}()
var lens int
var buf, bs []byte
c := NewConn(s.conn)
if lens, err = c.GetLen(); err != nil {
return
}
if buf, err = c.ReadLen(lens); err != nil {
return
}
if s.crypt {
if bs, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
return
}
} else {
bs = buf
}
n = len(bs)
copy(b, bs)
return
}
type SnappyConn struct {
w *snappy.Writer
r *snappy.Reader
crypt bool
}
func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
c := new(SnappyConn)
c.w = snappy.NewBufferedWriter(conn)
c.r = snappy.NewReader(conn)
c.crypt = crypt
return c
}
//snappy压缩写 包含加密
func (s *SnappyConn) Write(b []byte) (n int, err error) {
n = len(b)
if s.crypt {
if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
log.Println("encode crypt error:", err)
return
}
}
if _, err = s.w.Write(b); err != nil {
return
}
err = s.w.Flush()
return
}
//snappy压缩读 包含解密
func (s *SnappyConn) Read(b []byte) (n int, err error) {
defer func() {
if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
err = io.EOF
n = 0
}
}()
if n, err = s.r.Read(b); err != nil {
return
}
if s.crypt {
var bs []byte
if bs, err = AesDecrypt(b[:n], []byte(cryptKey)); err != nil {
log.Println("decode crypt error:", err)
return
}
n = len(bs)
copy(b, bs)
}
return
}
type Conn struct {
Conn net.Conn
}
//new conn
func NewConn(conn net.Conn) *Conn {
c := new(Conn)
c.Conn = conn
return c
}
//读取指定长度内容
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
if cLen > 65535 {
return nil, errors.New("长度错误")
}
buf := bufPool.Get().([]byte)[:cLen]
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
return buf, errors.New("读取指定长度错误" + err.Error())
}
return buf, nil
}
//获取长度
func (s *Conn) GetLen() (int, error) {
val, err := s.ReadLen(4)
if err != nil {
return 0, err
}
return GetLenByBytes(val)
}
//写入长度+内容 粘包
func (s *Conn) WriteLen(buf []byte) (int, error) {
var b []byte
var err error
if b, err = GetLenBytes(buf); err != nil {
return 0, err
}
return s.Write(b)
}
//读取flag
func (s *Conn) ReadFlag() (string, error) {
val, err := s.ReadLen(4)
if err != nil {
return "", err
}
return string(val), err
}
//读取host 连接地址 压缩类型
func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
retry:
lType, err := s.ReadLen(3)
if err != nil {
return
}
if typeStr = string(lType); typeStr == TEST_FLAG {
en, de, crypt, mux = s.GetConnInfoFromConn()
goto retry
}
cLen, err := s.GetLen()
if err != nil {
return
}
hostByte, err := s.ReadLen(cLen)
if err != nil {
return
}
host = string(hostByte)
return
}
//写连接类型 和 host地址
func (s *Conn) WriteHost(ltype string, host string) (int, error) {
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(ltype))
binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
binary.Write(raw, binary.LittleEndian, []byte(host))
return s.Write(raw.Bytes())
}
//设置连接为长连接
func (s *Conn) SetAlive() {
conn := s.Conn.(*net.TCPConn)
conn.SetReadDeadline(time.Time{})
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
}
//从tcp报文中解析出host连接类型等
func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
var b [32 * 1024]byte
var n int
if n, err = s.Read(b[:]); err != nil {
return
}
rb = b[:n]
r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
if err != nil {
log.Println("解析host出错", err)
return
}
hostPortURL, err := url.Parse(r.Host)
if err != nil {
return
}
if hostPortURL.Opaque == "443" { //https访问
address = r.Host + ":443"
} else { //http访问
if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口 默认80
address = r.Host + ":80"
} else {
address = r.Host
}
}
return
}
//单独读(加密|压缩)
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
if COMPRESS_SNAPY_DECODE == compress {
return NewSnappyConn(s.Conn, crypt).Read(b)
}
return NewCryptConn(s.Conn, crypt).Read(b)
}
//单独写(加密|压缩)
func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
if COMPRESS_SNAPY_ENCODE == compress {
return NewSnappyConn(s.Conn, crypt).Write(b)
}
return NewCryptConn(s.Conn, crypt).Write(b)
}
//写压缩方式,加密
func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
}
//获取压缩方式,是否加密
func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
buf, err := s.ReadLen(4)
if err != nil {
return
}
en, _ = strconv.Atoi(string(buf[0]))
de, _ = strconv.Atoi(string(buf[1]))
crypt = GetBoolByStr(string(buf[2]))
mux = GetBoolByStr(string(buf[3]))
return
}
//close
func (s *Conn) Close() error {
return s.Conn.Close()
}
//write
func (s *Conn) Write(b []byte) (int, error) {
return s.Conn.Write(b)
}
//read
func (s *Conn) Read(b []byte) (int, error) {
return s.Conn.Read(b)
}
//write error
func (s *Conn) WriteError() (int, error) {
return s.Write([]byte(RES_MSG))
}
//write sign flag
func (s *Conn) WriteSign() (int, error) {
return s.Write([]byte(RES_SIGN))
}
//write main
func (s *Conn) WriteMain() (int, error) {
return s.Write([]byte(WORK_MAIN))
}
//write chan
func (s *Conn) WriteChan() (int, error) {
return s.Write([]byte(WORK_CHAN))
}
//write test
func (s *Conn) WriteTest() (int, error) {
return s.Write([]byte(TEST_FLAG))
}
//write test
func (s *Conn) WriteSuccess() (int, error) {
return s.Write([]byte(CONN_SUCCESS))
}
//write test
func (s *Conn) WriteFail() (int, error) {
return s.Write([]byte(CONN_ERROR))
}
//获取长度+内容
func GetLenBytes(buf []byte) (b []byte, err error) {
raw := bytes.NewBuffer([]byte{})
if err = binary.Write(raw, binary.LittleEndian, int32(len(buf))); err != nil {
return
}
if err = binary.Write(raw, binary.LittleEndian, buf); err != nil {
return
}
b = raw.Bytes()
return
}
//解析出长度
func GetLenByBytes(buf []byte) (int, error) {
nlen := binary.LittleEndian.Uint32(buf)
if nlen <= 0 {
return 0, errors.New("数据长度错误")
}
return int(nlen), nil
}

82
utils/crypt.go Normal file
View File

@@ -0,0 +1,82 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"encoding/hex"
"errors"
"math/rand"
"time"
)
//en
func AesEncrypt(origData, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
origData = PKCS5Padding(origData, blockSize)
// origData = ZeroPadding(origData, block.BlockSize())
blockMode := cipher.NewCBCEncrypter(block, key[:blockSize])
crypted := make([]byte, len(origData))
// 根据CryptBlocks方法的说明如下方式初始化crypted也可以
// crypted := origData
blockMode.CryptBlocks(crypted, origData)
return crypted, nil
}
//de
func AesDecrypt(crypted, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
origData := make([]byte, len(crypted))
// origData := crypted
blockMode.CryptBlocks(origData, crypted)
err, origData = PKCS5UnPadding(origData)
// origData = ZeroUnPadding(origData)
return origData, err
}
//补全
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
//去补
func PKCS5UnPadding(origData []byte) (error, []byte) {
length := len(origData)
// 去掉最后一个字节 unpadding 次
unpadding := int(origData[length-1])
if (length - unpadding) < 0 {
return errors.New("len error"), nil
}
return nil, origData[:(length - unpadding)]
}
//生成32位md5字串
func Md5(s string) string {
h := md5.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
//生成随机验证密钥
func GetRandomString(l int) string {
str := "0123456789abcdefghijklmnopqrstuvwxyz"
bytes := []byte(str)
result := []byte{}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < l; i++ {
result = append(result, bytes[r.Intn(len(bytes))])
}
return string(result)
}

204
utils/util.go Executable file
View File

@@ -0,0 +1,204 @@
package utils
import (
"encoding/base64"
"io"
"log"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
)
const (
COMPRESS_NONE_ENCODE = iota
COMPRESS_NONE_DECODE
COMPRESS_SNAPY_ENCODE
COMPRESS_SNAPY_DECODE
VERIFY_EER = "vkey"
WORK_MAIN = "main"
WORK_CHAN = "chan"
RES_SIGN = "sign"
RES_MSG = "msg0"
CONN_SUCCESS = "sucs"
CONN_ERROR = "fail"
TEST_FLAG = "tst"
CONN_TCP = "tcp"
CONN_UDP = "udp"
Unauthorized_BYTES = `HTTP/1.1 401 Unauthorized
Content-Type: text/plain; charset=utf-8
WWW-Authenticate: Basic realm="easyProxy"
401 Unauthorized`
IO_EOF = "PROXYEOF"
)
//copy
func Relay(in, out net.Conn, compressType int, crypt, mux bool) {
switch compressType {
case COMPRESS_SNAPY_ENCODE:
copyBuffer(NewSnappyConn(in, crypt), out)
if mux {
out.Close()
NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
}
case COMPRESS_SNAPY_DECODE:
copyBuffer(in, NewSnappyConn(out, crypt))
if mux {
in.Close()
}
case COMPRESS_NONE_ENCODE:
copyBuffer(NewCryptConn(in, crypt), out)
if mux {
out.Close()
NewCryptConn(in, crypt).Write([]byte(IO_EOF))
}
case COMPRESS_NONE_DECODE:
copyBuffer(in, NewCryptConn(out, crypt))
if mux {
in.Close()
}
}
if !mux {
in.Close()
out.Close()
}
}
//判断压缩方式
func GetCompressType(compress string) (int, int) {
switch compress {
case "":
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
case "snappy":
return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
default:
log.Fatalln("数据压缩格式错误")
}
return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
}
//通过host获取对应的ip地址
func Gethostbyname(hostname string) string {
if !DomainCheck(hostname) {
return hostname
}
ips, _ := net.LookupIP(hostname)
if ips != nil {
for _, v := range ips {
if v.To4() != nil {
return v.String()
}
}
}
return ""
}
//检查是否是域名
func DomainCheck(domain string) bool {
var match bool
IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)"
NotLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}"
match, _ = regexp.MatchString(IsLine, domain)
if !match {
match, _ = regexp.MatchString(NotLine, domain)
}
return match
}
//检查basic认证
func CheckAuth(r *http.Request, user, passwd string) bool {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 {
return false
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return false
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return false
}
return pair[0] == user && pair[1] == passwd
}
//get bool by str
func GetBoolByStr(s string) bool {
switch s {
case "1", "true":
return true
}
return false
}
//get str by bool
func GetStrByBool(b bool) string {
if b {
return "1"
}
return "0"
}
//int
func GetIntNoerrByStr(str string) int {
i, _ := strconv.Atoi(str)
return i
}
var bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 65535)
},
}
// io.copy的优化版读取buffer长度原为32*1024与snappy不同导致读取出的内容存在差异不利于解密特此修改
func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
//TODO 回收问题
buf := bufPool.Get().([]byte)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
//连接重置 清空缓存区
func FlushConn(c net.Conn) {
c.SetReadDeadline(time.Now().Add(time.Second * 3))
buf := bufPool.Get().([]byte)
for {
if _, err := c.Read(buf); err != nil {
break
}
}
c.SetReadDeadline(time.Time{})
}
//简单的一个校验值
func Getverifyval(vkey string) string {
return Md5(vkey)
}