nps/core/process/https_proxy_test.go
2022-01-23 17:30:38 +08:00

121 lines
3.0 KiB
Go

package process
import (
"crypto/tls"
"crypto/x509/pkix"
"ehang.io/nps/core/action"
"ehang.io/nps/lib/cert"
"ehang.io/nps/lib/enet"
"fmt"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sync"
"testing"
)
// createCertOnce guards the one-time generation of the certificate and key
// files performed inside createCertFile.
var createCertOnce sync.Once
// createCertFile returns the paths of the certificate file and the key file
// used by the HTTPS proxy tests, in that order.
//
// On the first call the pair is produced with CreateRootCa and written to
// cert.pem and key.pem in os.TempDir(); createCertOnce makes every later call
// skip generation and only return the paths. Failures are reported on the t of
// the caller that triggered generation, and generation stops at the first
// failed step so a file is never written from the output of a failed one.
func createCertFile(t *testing.T) (string, string) {
	t.Helper()
	certPath := filepath.Join(os.TempDir(), "cert.pem")
	keyPath := filepath.Join(os.TempDir(), "key.pem")
	createCertOnce.Do(func() {
		g := cert.NewX509Generator(pkix.Name{
			Country:            []string{"CN"},
			Organization:       []string{"Ehang.io"},
			OrganizationalUnit: []string{"nps"},
			Province:           []string{"Beijing"},
			CommonName:         "nps",
			Locality:           []string{"Beijing"},
		})
		// Named certData/keyData so the imported cert package is not shadowed.
		certData, keyData, err := g.CreateRootCa()
		if !assert.NoError(t, err) {
			return
		}
		if !assert.NoError(t, os.WriteFile(certPath, certData, 0600)) {
			return
		}
		assert.NoError(t, os.WriteFile(keyPath, keyData, 0600))
	})
	return certPath, keyPath
}
// TestHttpsProxyProcess sends an HTTPS GET for /now on the address returned by
// startHttps through an HttpsProxyProcess that has no BasicAuth configured,
// and expects a 200 response.
func TestHttpsProxyProcess(t *testing.T) {
	sAddr, err := startHttps(t)
	if !assert.NoError(t, err) {
		return
	}
	certFilePath, keyFilePath := createCertFile(t)
	h := HttpsProxyProcess{
		HttpProxyProcess: HttpProxyProcess{},
		CertFile:         certFilePath,
		KeyFile:          keyFilePath,
	}
	ac := &action.LocalAction{}
	ac.Init()
	if !assert.NoError(t, h.Init(ac)) {
		return
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if !assert.NoError(t, err) {
		return
	}
	// Closing the listener at the end of the test unblocks Accept below so the
	// accept goroutine does not outlive the test.
	t.Cleanup(func() { _ = ln.Close() })
	go func() {
		for {
			c, err := ln.Accept()
			if err != nil {
				// Expected once the listener is closed during cleanup. Nothing
				// is reported on t here because the test may already be over;
				// a real accept failure surfaces as a failed request below.
				return
			}
			go h.ProcessConn(enet.NewReaderConn(c))
		}
	}()

	proxyAddr := ln.Addr().String()
	transport := &http.Transport{
		Proxy: func(_ *http.Request) (*url.URL, error) {
			return url.Parse(fmt.Sprintf("https://%s", proxyAddr))
		},
		// Certificate verification is skipped for both the proxy and the
		// target connection.
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
	}
	t.Cleanup(transport.CloseIdleConnections)
	client := &http.Client{Transport: transport}

	resp, err := client.Get(fmt.Sprintf("https://%s/now", sAddr))
	if !assert.NoError(t, err) {
		// resp is nil on error; do not touch it.
		return
	}
	defer resp.Body.Close()
	assert.Equal(t, 200, resp.StatusCode)
}
// TestHttpsProxyProcessBasic exercises an HttpsProxyProcess whose BasicAuth map
// contains the single pair aaa/bbb. A request through the proxy without
// credentials must fail, and the same request with aaa:bbb in the proxy URL
// must succeed with a 200 response.
func TestHttpsProxyProcessBasic(t *testing.T) {
	certFilePath, keyFilePath := createCertFile(t)
	sAddr, err := startHttps(t)
	if !assert.NoError(t, err) {
		return
	}
	h := HttpsProxyProcess{
		HttpProxyProcess: HttpProxyProcess{
			BasicAuth: map[string]string{"aaa": "bbb"},
		},
		CertFile: certFilePath,
		KeyFile:  keyFilePath,
	}
	ac := &action.LocalAction{}
	ac.Init()
	if !assert.NoError(t, h.Init(ac)) {
		return
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if !assert.NoError(t, err) {
		return
	}
	// Closing the listener at the end of the test unblocks Accept below so the
	// accept goroutine does not outlive the test.
	t.Cleanup(func() { _ = ln.Close() })
	go func() {
		for {
			c, err := ln.Accept()
			if err != nil {
				// Expected once the listener is closed during cleanup. Nothing
				// is reported on t here because the test may already be over;
				// a real accept failure surfaces as a failed request below.
				return
			}
			go h.ProcessConn(enet.NewReaderConn(c))
		}
	}()

	// clientVia builds a client that routes every request through rawProxyURL.
	// Each call gets its own transport so the proxy setting of a transport
	// that has already carried a request is never modified.
	clientVia := func(rawProxyURL string) *http.Client {
		transport := &http.Transport{
			Proxy: func(_ *http.Request) (*url.URL, error) {
				return url.Parse(rawProxyURL)
			},
			// Certificate verification is skipped for both the proxy and the
			// target connection.
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		}
		t.Cleanup(transport.CloseIdleConnections)
		return &http.Client{Transport: transport}
	}
	proxyAddr := ln.Addr().String()
	target := fmt.Sprintf("https://%s/now", sAddr)

	// Without credentials the request must be rejected.
	resp, err := clientVia(fmt.Sprintf("https://%s", proxyAddr)).Get(target)
	if !assert.Error(t, err) {
		// Unexpected success: resp is non-nil, release its body.
		resp.Body.Close()
	}

	// With the configured credentials the request must go through.
	resp, err = clientVia(fmt.Sprintf("https://%s:%s@%s", "aaa", "bbb", proxyAddr)).Get(target)
	if !assert.NoError(t, err) {
		// resp is nil on error; do not touch it.
		return
	}
	defer resp.Body.Close()
	assert.Equal(t, 200, resp.StatusCode)
}