mirror of
https://github.com/chai2010/advanced-go-programming-book.git
synced 2025-05-24 12:32:21 +00:00
83 lines
2.4 KiB
Go
83 lines
2.4 KiB
Go
// Copyright 2015 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// +build go1.6
|
|
|
|
package http2
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net/http"
|
|
)
|
|
|
|
func configureTransport(t1 *http.Transport) (*Transport, error) {
|
|
connPool := new(clientConnPool)
|
|
t2 := &Transport{
|
|
ConnPool: noDialClientConnPool{connPool},
|
|
t1: t1,
|
|
}
|
|
connPool.t = t2
|
|
if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
|
|
return nil, err
|
|
}
|
|
if t1.TLSClientConfig == nil {
|
|
t1.TLSClientConfig = new(tls.Config)
|
|
}
|
|
if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
|
|
t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
|
|
}
|
|
if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
|
|
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
|
|
}
|
|
upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
|
|
addr := authorityAddr("https", authority)
|
|
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
|
|
go c.Close()
|
|
return erringRoundTripper{err}
|
|
} else if !used {
|
|
// Turns out we don't need this c.
|
|
// For example, two goroutines made requests to the same host
|
|
// at the same time, both kicking off TCP dials. (since protocol
|
|
// was unknown)
|
|
go c.Close()
|
|
}
|
|
return t2
|
|
}
|
|
if m := t1.TLSNextProto; len(m) == 0 {
|
|
t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
|
|
"h2": upgradeFn,
|
|
}
|
|
} else {
|
|
m["h2"] = upgradeFn
|
|
}
|
|
return t2, nil
|
|
}
|
|
|
|
// registerHTTPSProtocol calls Transport.RegisterProtocol but
|
|
// converting panics into errors.
|
|
func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err error) {
|
|
defer func() {
|
|
if e := recover(); e != nil {
|
|
err = fmt.Errorf("%v", e)
|
|
}
|
|
}()
|
|
t.RegisterProtocol("https", rt)
|
|
return nil
|
|
}
|
|
|
|
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
|
|
// if there's already has a cached connection to the host.
|
|
// (The field is exported so it can be accessed via reflect from net/http; tested
|
|
// by TestNoDialH2RoundTripperType)
|
|
type noDialH2RoundTripper struct{ *Transport }
|
|
|
|
func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
res, err := rt.Transport.RoundTrip(req)
|
|
if isNoCachedConnError(err) {
|
|
return nil, http.ErrSkipAltProtocol
|
|
}
|
|
return res, err
|
|
}
|