nps/server/proxy/reverseproxy.go
snowie2000 a732febf3b fixed typo in test.go
replaced self-made http reverseproxy with a more robust and versatile one.
dynamically generate cert for client-server tls encryption
2020-04-15 10:59:48 +08:00

137 lines
2.9 KiB
Go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP reverse proxy handler
package proxy
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync"
)
// HTTPError is an error that also carries the HTTP status code to report
// for it. The Error method is promoted from the embedded error value.
type HTTPError struct {
	// error is the underlying error and supplies the message text.
	error
	// HTTPCode is the HTTP status code associated with the error.
	HTTPCode int
}
// NewHTTPError returns an error whose dynamic type is *HTTPError, with
// errmsg as its message and code as its HTTP status code.
func NewHTTPError(code int, errmsg string) error {
	httpErr := new(HTTPError)
	httpErr.HTTPCode = code
	httpErr.error = errors.New(errmsg)
	return httpErr
}
// ReverseProxy extends the standard library's httputil.ReverseProxy with a
// separate code path for WebSocket upgrade requests.
type ReverseProxy struct {
	// ReverseProxy is the embedded standard proxy; its fields and methods
	// are promoted onto this type.
	*httputil.ReverseProxy

	// WebSocketDialContext opens the backend connection used to relay a
	// WebSocket request. It has the same shape as a net.Dialer's DialContext.
	WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}
// IsWebsocketRequest reports whether req asks for a WebSocket upgrade, i.e.
// its Connection header lists the "upgrade" token and its Upgrade header
// lists the "websocket" token. Tokens are compared case-insensitively.
//
// Every field line of each header is inspected, so a request that sends
// "Connection: keep-alive" and "Connection: Upgrade" as two separate lines is
// recognised just like one that sends "Connection: keep-alive, Upgrade".
func IsWebsocketRequest(req *http.Request) bool {
	containsToken := func(name, token string) bool {
		// Index the map directly rather than using Header.Get, which
		// returns only the first field line and ignores the rest.
		for _, line := range req.Header[http.CanonicalHeaderKey(name)] {
			for _, item := range strings.Split(line, ",") {
				if strings.EqualFold(strings.TrimSpace(item), token) {
					return true
				}
			}
		}
		return false
	}
	return containsToken("Connection", "upgrade") && containsToken("Upgrade", "websocket")
}
// NewSingleHostReverseProxy returns a ReverseProxy wrapping the standard
// library's single-host reverse proxy for target. The proxy's ErrorHandler
// is set to this package's handler and WebSocketDialContext is left nil.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
	return NewReverseProxy(httputil.NewSingleHostReverseProxy(target))
}
// NewReverseProxy wraps an existing httputil.ReverseProxy. Because orp is
// embedded by pointer, installing the error handler also overwrites
// orp.ErrorHandler. WebSocketDialContext is left nil for the caller to set.
func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy {
	wrapped := new(ReverseProxy)
	wrapped.ReverseProxy = orp
	wrapped.ErrorHandler = wrapped.errHandler
	return wrapped
}
// errHandler is installed as the embedded proxy's ErrorHandler and turns a
// proxying error into an HTTP response:
//
//   - io.EOF (possibly wrapped) yields status 521 with an empty body;
//   - an *HTTPError (possibly wrapped) yields its HTTPCode;
//   - anything else yields 404 Not Found.
//
// In the last two cases the body is "error: " followed by the error text.
// errors.Is and errors.As are used so that errors wrapped with %w are
// classified the same way as bare ones.
func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) {
	if errors.Is(e, io.EOF) {
		rw.WriteHeader(521)
		// A waiting page could be written here:
		//rw.Write(getWaitingPageContent())
		return
	}

	status := http.StatusNotFound
	var httpErr *HTTPError
	if errors.As(e, &httpErr) {
		status = httpErr.HTTPCode
	}
	rw.WriteHeader(status)
	// The response is already committed; nothing useful can be done if
	// writing the body to the client fails.
	_, _ = rw.Write([]byte("error: " + e.Error()))
}
// ServeHTTP implements http.Handler. WebSocket upgrade requests are relayed
// by serveWebSocket; every other request is handled by the embedded
// httputil.ReverseProxy.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if !IsWebsocketRequest(req) {
		p.ReverseProxy.ServeHTTP(rw, req)
		return
	}
	p.serveWebSocket(rw, req)
}
// serveWebSocket relays a WebSocket upgrade request to the backend.
//
// It dials the backend through WebSocketDialContext, lets the Director
// rewrite the request, hijacks the client connection, replays the request
// to the backend in wire format and then copies bytes in both directions
// until either side stops.
//
// Failures before the hijack are reported to the client as 500 (no dialer
// configured, hijacking unsupported or failed) or 501 (dial failed). After
// the hijack no HTTP response can be written, so a failure simply closes
// both connections.
func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
	if p.WebSocketDialContext == nil {
		rw.WriteHeader(500)
		return
	}

	// The dialer is called with an empty address.
	// NOTE(review): presumably the dialer already knows its target; confirm.
	targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
	if err != nil {
		rw.WriteHeader(501)
		return
	}
	defer targetConn.Close()

	// Director is optional on httputil.ReverseProxy; avoid a nil call.
	if p.Director != nil {
		p.Director(req)
	}

	hijacker, ok := rw.(http.Hijacker)
	if !ok {
		rw.WriteHeader(500)
		return
	}
	conn, brw, errHijack := hijacker.Hijack()
	if errHijack != nil {
		rw.WriteHeader(500)
		return
	}
	defer conn.Close()

	// If the handshake cannot be delivered there is nothing to relay.
	if err := req.Write(targetConn); err != nil {
		return
	}

	// The server may already have read client bytes past the request into
	// its buffer; forward them first so they are not lost. Only buffered
	// bytes are consumed, so this never blocks on the client.
	if n := brw.Reader.Buffered(); n > 0 {
		if _, err := io.CopyN(targetConn, brw.Reader, int64(n)); err != nil {
			return
		}
	}

	Join(conn, targetConn)
}
// Join relays data between c1 and c2 in both directions and blocks until
// both directions have finished.
//
// inCount is the number of bytes copied from c2 to c1 and outCount the
// number copied from c1 to c2. Copy errors are discarded.
//
// As soon as one direction ends, both endpoints are closed, which unblocks
// the other direction. Each endpoint is closed exactly once, because
// io.Closer leaves the behaviour of a repeated Close undefined, and both are
// closed before Join returns.
func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) {
	var (
		wait      sync.WaitGroup
		closeOnce sync.Once
	)

	closeBoth := func() {
		closeOnce.Do(func() {
			// Best effort: there is no caller to report close errors to.
			_ = c1.Close()
			_ = c2.Close()
		})
	}

	pipe := func(to io.Writer, from io.Reader, count *int64) {
		// Deferred calls run last-in-first-out: close first, then signal.
		defer wait.Done()
		defer closeBoth()
		*count, _ = io.Copy(to, from)
	}

	wait.Add(2)
	go pipe(c1, c2, &inCount)
	go pipe(c2, c1, &outCount)
	wait.Wait()
	return inCount, outCount
}