nps/component/component_test.go
2022-01-23 17:30:38 +08:00

194 lines
6.0 KiB
Go

package component
import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"ehang.io/nps/component/bridge"
"ehang.io/nps/component/client"
"ehang.io/nps/lib/cert"
"ehang.io/nps/lib/pb"
"github.com/lucas-clemente/quic-go"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"io/ioutil"
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
var createCertOnce sync.Once
func createCertFile(t *testing.T) (string, string) {
createCertOnce.Do(func() {
g := cert.NewX509Generator(pkix.Name{
Country: []string{"CN"},
Organization: []string{"Ehang.io"},
OrganizationalUnit: []string{"nps"},
Province: []string{"Beijing"},
CommonName: "nps",
Locality: []string{"Beijing"},
})
cert, key, err := g.CreateRootCa()
assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_cert.pem"), cert, 0600))
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "root_key.pem"), key, 0600))
assert.NoError(t, g.InitRootCa(cert, key))
cert, key, err = g.CreateCert("bridge.nps.ehang.io")
assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_cert.pem"), cert, 0600))
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "bridge_key.pem"), key, 0600))
cert, key, err = g.CreateCert("client.nps.ehang.io")
assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_cert.pem"), cert, 0600))
assert.NoError(t, os.WriteFile(filepath.Join(os.TempDir(), "client_key.pem"), key, 0600))
})
return filepath.Join(os.TempDir(), "cert.pem"), filepath.Join(os.TempDir(), "key.pem")
}
func TestTcpConnect(t *testing.T) {
createCertFile(t)
bridgeLn, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
assert.NoError(t, err)
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(buf)
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
assert.NoError(t, err)
bridgeTlsConfig := &tls.Config{
Certificates: []tls.Certificate{crt},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
}
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
assert.NoError(t, err)
clientConfig := &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{crt},
}
go func() {
assert.NoError(t, bridge.StartTcpBridge(bridgeLn, bridgeTlsConfig, func(s string) bool {
return true
}, func(s string) bool {
sn, err := cert.GetCertSnFromConfig(clientConfig)
assert.NoError(t, err)
assert.Equal(t, sn, s)
return true
}))
}()
var c *client.Client
go func() {
id, err := cert.GetCertSnFromConfig(clientConfig)
assert.NoError(t, err)
creator := client.TcpTunnelCreator{}
connId := uuid.NewV1().String()
controlLn, err := creator.NewMux(bridgeLn.Addr().String(),
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
assert.NoError(t, err)
dataLn, err := creator.NewMux(bridgeLn.Addr().String(),
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
assert.NoError(t, err)
c = client.NewClient(controlLn, dataLn)
c.Run()
}()
timeout := time.NewTimer(time.Second * 30)
ticker := time.NewTicker(time.Millisecond * 100)
for {
select {
case <-ticker.C:
if c != nil && c.HasPong() {
c.Close()
return
}
case <-timeout.C:
t.Fail()
return
}
}
}
func TestQUICConnect(t *testing.T) {
createCertFile(t)
bridgePacketConn, err := net.ListenPacket("udp", "127.0.0.1:0")
assert.NoError(t, err)
buf, err := ioutil.ReadFile(filepath.Join(os.TempDir(), "root_cert.pem"))
assert.NoError(t, err)
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(buf)
crt, err := tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "bridge_cert.pem"), filepath.Join(os.TempDir(), "bridge_key.pem"))
assert.NoError(t, err)
bridgeTlsConfig := &tls.Config{
Certificates: []tls.Certificate{crt},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
NextProtos: []string{"quic-nps"},
}
crt, err = tls.LoadX509KeyPair(filepath.Join(os.TempDir(), "client_cert.pem"), filepath.Join(os.TempDir(), "client_key.pem"))
assert.NoError(t, err)
clientConfig := &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{crt},
NextProtos: []string{"quic-nps"},
}
go func() {
assert.NoError(t, bridge.StartQUICBridge(bridgePacketConn, bridgeTlsConfig, &quic.Config{
MaxIncomingStreams: 1000000,
MaxIncomingUniStreams: 1000000,
MaxIdleTimeout: time.Minute,
KeepAlive: true,
}, func(s string) bool {
sn, err := cert.GetCertSnFromConfig(clientConfig)
assert.NoError(t, err)
assert.Equal(t, sn, s)
return true
}))
}()
var c *client.Client
go func() {
id, err := cert.GetCertSnFromConfig(clientConfig)
assert.NoError(t, err)
creator := client.QUICTunnelCreator{}
connId := uuid.NewV1().String()
controlLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: true}}}, clientConfig)
assert.NoError(t, err)
dataLn, err := creator.NewMux(bridgePacketConn.LocalAddr().String(),
&pb.ConnRequest{Id: id, ConnType: &pb.ConnRequest_NpcInfo{NpcInfo: &pb.NpcInfo{TunnelId: connId, IsControlTunnel: false}}}, clientConfig)
assert.NoError(t, err)
assert.NotEmpty(t, dataLn)
assert.NotEmpty(t, controlLn)
c = client.NewClient(controlLn, dataLn)
c.Run()
}()
timeout := time.NewTimer(time.Second * 30)
ticker := time.NewTicker(time.Millisecond * 100)
for {
select {
case <-ticker.C:
if c != nil && c.HasPong() {
c.Close()
return
}
case <-timeout.C:
t.Fail()
return
}
}
}