diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 1ca9f454a9..eac86bc923 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -124,7 +124,7 @@ func (f *tlsHandshakeFunc) Apply( defer cancel() // handshake - conn, tlsState, err := handshaker.Handshake(ctx, input.Conn, config) + conn, err := handshaker.Handshake(ctx, input.Conn, config) // possibly register established conn for late close f.Pool.MaybeTrack(conn) @@ -132,19 +132,14 @@ func (f *tlsHandshakeFunc) Apply( // stop the operation logger ol.Stop(err) - var tlsConn netxlite.TLSConn - if conn != nil { - tlsConn = conn.(netxlite.TLSConn) // guaranteed to work - } - state := &TLSConnection{ Address: input.Address, - Conn: tlsConn, // possibly nil + Conn: conn, // possibly nil Domain: input.Domain, IDGenerator: input.IDGenerator, Logger: input.Logger, Network: input.Network, - TLSState: tlsState, + TLSState: netxlite.MaybeTLSConnectionState(conn), Trace: trace, ZeroTime: input.ZeroTime, } diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 652b688e0b..3cba8f81d8 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -70,17 +70,22 @@ func TestTLSHandshake(t *testing.T) { return nil }, } - tlsConn := &mocks.TLSConn{Conn: tcpConn} + tlsConn := &mocks.TLSConn{ + Conn: tcpConn, + MockConnectionState: func() tls.ConnectionState { + return tls.ConnectionState{} + }, + } eofHandshaker := &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, io.EOF + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return nil, io.EOF }, } goodHandshaker := &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return tlsConn, tls.ConnectionState{}, nil + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return tlsConn, nil }, } diff --git a/internal/experiment/echcheck/handshake.go b/internal/experiment/echcheck/handshake.go index a9108831a8..5e3f2a0214 100644 --- a/internal/experiment/echcheck/handshake.go +++ b/internal/experiment/echcheck/handshake.go @@ -41,9 +41,10 @@ func handshakeWithExtension(ctx context.Context, conn net.Conn, zeroTime time.Ti tracedHandshaker := handshakerConstructor(log.Log, &utls.HelloFirefox_Auto) start := time.Now() - _, connState, err := tracedHandshaker.Handshake(ctx, conn, tlsConfig) + maybeTLSConn, err := tracedHandshaker.Handshake(ctx, conn, tlsConfig) finish := time.Now() + connState := netxlite.MaybeTLSConnectionState(maybeTLSConn) return measurexlite.NewArchivalTLSOrQUICHandshakeResult(0, start.Sub(zeroTime), "tcp", address, tlsConfig, connState, err, finish.Sub(zeroTime)) } diff --git a/internal/experiment/echcheck/measure_test.go b/internal/experiment/echcheck/measure_test.go index 289879f865..f18a17d640 100644 --- a/internal/experiment/echcheck/measure_test.go +++ b/internal/experiment/echcheck/measure_test.go @@ -84,12 +84,10 @@ func TestMeasurementSuccess(t *testing.T) { } summary, err := measurer.GetSummaryKeys(&model.Measurement{}) - + if err != nil { + t.Fatal(err) + } if summary.(SummaryKeys).IsAnomaly != false { t.Fatal("expected false") } } - -func newsession() model.ExperimentSession { - return &mockable.Session{MockableLogger: log.Log} -} diff --git a/internal/experiment/echcheck/utls.go b/internal/experiment/echcheck/utls.go index ddeb9034ca..2f5d2b259e 100644 --- a/internal/experiment/echcheck/utls.go +++ b/internal/experiment/echcheck/utls.go @@ -12,7 +12,6 @@ import ( ) type tlsHandshakerWithExtensions struct { - conn *netxlite.UTLSConn extensions []utls.TLSExtension dl model.DebugLogger id *utls.ClientHelloID @@ -32,18 +31,19 @@ func newHandshakerWithExtensions(extensions []utls.TLSExtension) func(dl model.D } } -func (t *tlsHandshakerWithExtensions) Handshake(ctx context.Context, conn net.Conn, tlsConfig *tls.Config) ( - net.Conn, tls.ConnectionState, error) { - var err error - t.conn, err = netxlite.NewUTLSConn(conn, tlsConfig, t.id) +func (t *tlsHandshakerWithExtensions) Handshake( + ctx context.Context, tcpConn net.Conn, tlsConfig *tls.Config) (model.TLSConn, error) { + tlsConn, err := netxlite.NewUTLSConn(tcpConn, tlsConfig, t.id) runtimex.Assert(err == nil, "unexpected error when creating UTLSConn") if t.extensions != nil && len(t.extensions) != 0 { - t.conn.BuildHandshakeState() - t.conn.Extensions = append(t.conn.Extensions, t.extensions...) + tlsConn.BuildHandshakeState() + tlsConn.Extensions = append(tlsConn.Extensions, t.extensions...) } - err = t.conn.Handshake() + if err := tlsConn.Handshake(); err != nil { + return nil, err + } - return t.conn.NetConn(), t.conn.ConnectionState(), err + return tlsConn, nil } diff --git a/internal/experiment/echcheck/utls_test.go b/internal/experiment/echcheck/utls_test.go new file mode 100644 index 0000000000..806fe2896e --- /dev/null +++ b/internal/experiment/echcheck/utls_test.go @@ -0,0 +1,41 @@ +package echcheck + +import ( + "context" + "crypto/tls" + "errors" + "testing" + + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" + utls "gitlab.com/yawning/utls.git" +) + +func TestTLSHandshakerWithExtension(t *testing.T) { + t.Run("when the TLS handshake fails", func(t *testing.T) { + thx := &tlsHandshakerWithExtensions{ + extensions: []utls.TLSExtension{}, + dl: model.DiscardLogger, + id: &utls.HelloChrome_70, + } + + expected := errors.New("mocked error") + tcpConn := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, expected + }, + } + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + + tlsConn, err := thx.Handshake(context.Background(), tcpConn, tlsConfig) + if !errors.Is(err, expected) { + t.Fatal(err) + } + if tlsConn != nil { + t.Fatal("expected nil tls conn") + } + }) +} diff --git a/internal/experiment/tlsmiddlebox/tracing.go b/internal/experiment/tlsmiddlebox/tracing.go index cfbb3a37d2..d904744a9c 100644 --- a/internal/experiment/tlsmiddlebox/tracing.go +++ b/internal/experiment/tlsmiddlebox/tracing.go @@ -97,7 +97,7 @@ func (m *Measurer) handshakeWithTTL(ctx context.Context, index int64, zeroTime t if clientId > 0 { thx = trace.NewTLSHandshakerUTLS(logger, ClientIDs[clientId]) } - _, _, err = thx.Handshake(ctx, conn, genTLSConfig(sni)) + _, err = thx.Handshake(ctx, conn, genTLSConfig(sni)) ol.Stop(err) soErr := extractSoError(conn) // 4. reset the TTL value to ensure that conn closes successfully diff --git a/internal/experiment/tlsping/tlsping.go b/internal/experiment/tlsping/tlsping.go index ba1692a1eb..23eeea46ff 100644 --- a/internal/experiment/tlsping/tlsping.go +++ b/internal/experiment/tlsping/tlsping.go @@ -189,7 +189,7 @@ func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64, RootCAs: nil, ServerName: sni, } - _, _, err = thx.Handshake(ctx, conn, config) + _, err = thx.Handshake(ctx, conn, config) ol.Stop(err) sp.TLSHandshake = trace.FirstTLSHandshakeOrNil() // record the first handshake from the buffer sp.NetworkEvents = trace.NetworkEvents() diff --git a/internal/experiment/webconnectivitylte/secureflow.go b/internal/experiment/webconnectivitylte/secureflow.go index 6e0f2bf8b6..1f63c17434 100644 --- a/internal/experiment/webconnectivitylte/secureflow.go +++ b/internal/experiment/webconnectivitylte/secureflow.go @@ -154,7 +154,7 @@ func (t *SecureFlow) Run(parentCtx context.Context, index int64) error { const tlsTimeout = 10 * time.Second tlsCtx, tlsCancel := context.WithTimeout(parentCtx, tlsTimeout) defer tlsCancel() - tlsConn, tlsConnState, err := tlsHandshaker.Handshake(tlsCtx, tcpConn, tlsConfig) + tlsConn, err := tlsHandshaker.Handshake(tlsCtx, tcpConn, tlsConfig) t.TestKeys.AppendTLSHandshakes(trace.TLSHandshakes()...) if err != nil { ol.Stop(err) @@ -162,6 +162,7 @@ func (t *SecureFlow) Run(parentCtx context.Context, index int64) error { } defer tlsConn.Close() + tlsConnState := netxlite.MaybeTLSConnectionState(tlsConn) alpn := tlsConnState.NegotiatedProtocol // Determine whether we're allowed to fetch the webpage @@ -177,8 +178,7 @@ func (t *SecureFlow) Run(parentCtx context.Context, index int64) error { httpTransport := netxlite.NewHTTPTransport( t.Logger, netxlite.NewNullDialer(), - // note: netxlite guarantees that here tlsConn is a netxlite.TLSConn - netxlite.NewSingleUseTLSDialer(tlsConn.(netxlite.TLSConn)), + netxlite.NewSingleUseTLSDialer(tlsConn), ) // create HTTP request diff --git a/internal/legacy/measurex/measurer.go b/internal/legacy/measurex/measurer.go index 8b47d176fa..23f648d856 100644 --- a/internal/legacy/measurex/measurer.go +++ b/internal/legacy/measurex/measurer.go @@ -355,13 +355,12 @@ func (mx *Measurer) TLSConnectAndHandshakeWithDB(ctx context.Context, ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() th := mx.WrapTLSHandshaker(db, mx.TLSHandshaker) - tlsConn, _, err := th.Handshake(ctx, conn, config) + tlsConn, err := th.Handshake(ctx, conn, config) ol.Stop(err) if err != nil { return nil, err } - // cast safe according to the docs of netxlite's handshaker - return tlsConn.(netxlite.TLSConn), nil + return tlsConn, nil } // QUICHandshake connects and TLS handshakes with a QUIC endpoint. diff --git a/internal/legacy/measurex/tls.go b/internal/legacy/measurex/tls.go index 8ee1fd7fcd..b5fe894497 100644 --- a/internal/legacy/measurex/tls.go +++ b/internal/legacy/measurex/tls.go @@ -11,7 +11,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "net" "time" "github.com/ooni/probe-cli/v3/internal/model" @@ -53,13 +52,13 @@ type QUICTLSHandshakeEvent struct { Started float64 } -func (thx *tlsHandshakerDB) Handshake(ctx context.Context, - conn Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func (thx *tlsHandshakerDB) Handshake(ctx context.Context, conn Conn, config *tls.Config) (model.TLSConn, error) { network := conn.RemoteAddr().Network() remoteAddr := conn.RemoteAddr().String() started := time.Since(thx.begin).Seconds() - tconn, state, err := thx.TLSHandshaker.Handshake(ctx, conn, config) + tconn, err := thx.TLSHandshaker.Handshake(ctx, conn, config) finished := time.Since(thx.begin).Seconds() + tstate := netxlite.MaybeTLSConnectionState(tconn) thx.db.InsertIntoTLSHandshake(&QUICTLSHandshakeEvent{ Network: network, RemoteAddr: remoteAddr, @@ -70,12 +69,12 @@ func (thx *tlsHandshakerDB) Handshake(ctx context.Context, Finished: finished, Failure: NewFailure(err), Oddity: thx.computeOddity(err), - TLSVersion: netxlite.TLSVersionString(state.Version), - CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), - NegotiatedProto: state.NegotiatedProtocol, - PeerCerts: peerCerts(err, &state), + TLSVersion: netxlite.TLSVersionString(tstate.Version), + CipherSuite: netxlite.TLSCipherSuiteString(tstate.CipherSuite), + NegotiatedProto: tstate.NegotiatedProtocol, + PeerCerts: peerCerts(err, &tstate), }) - return tconn, state, err + return tconn, err } func (thx *tlsHandshakerDB) computeOddity(err error) Oddity { diff --git a/internal/legacy/tracex/tls.go b/internal/legacy/tracex/tls.go index c7659de7fe..2d9ddea2a9 100644 --- a/internal/legacy/tracex/tls.go +++ b/internal/legacy/tracex/tls.go @@ -42,7 +42,7 @@ func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker { // Handshake implements model.TLSHandshaker.Handshake func (h *TLSHandshakerSaver) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { proto := conn.RemoteAddr().Network() remoteAddr := conn.RemoteAddr().String() start := time.Now() @@ -54,23 +54,24 @@ func (h *TLSHandshakerSaver) Handshake( TLSServerName: config.ServerName, Time: start, }}) - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + tlsconn, err := h.TLSHandshaker.Handshake(ctx, conn, config) stop := time.Now() + tstate := netxlite.MaybeTLSConnectionState(tlsconn) h.Saver.Write(&EventTLSHandshakeDone{&EventValue{ Address: remoteAddr, Duration: stop.Sub(start), Err: NewFailureStr(err), NoTLSVerify: config.InsecureSkipVerify, Proto: proto, - TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), - TLSNegotiatedProto: state.NegotiatedProtocol, + TLSCipherSuite: netxlite.TLSCipherSuiteString(tstate.CipherSuite), + TLSNegotiatedProto: tstate.NegotiatedProtocol, TLSNextProtos: config.NextProtos, - TLSPeerCerts: tlsPeerCerts(state, err), + TLSPeerCerts: tlsPeerCerts(tstate, err), TLSServerName: config.ServerName, - TLSVersion: netxlite.TLSVersionString(state.Version), + TLSVersion: netxlite.TLSVersionString(tstate.Version), Time: stop, }}) - return tlsconn, state, err + return tlsconn, err } var _ model.TLSHandshaker = &TLSHandshakerSaver{} diff --git a/internal/legacy/tracex/tls_test.go b/internal/legacy/tracex/tls_test.go index d4119cec48..7d5df34f77 100644 --- a/internal/legacy/tracex/tls_test.go +++ b/internal/legacy/tracex/tls_test.go @@ -10,6 +10,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" ) func TestWrapTLSHandshaker(t *testing.T) { @@ -98,9 +99,8 @@ func TestTLSHandshakerSaver(t *testing.T) { }, } thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return returnedConn, returnedConnState, nil + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return returnedConn, nil }, }) ctx := context.Background() @@ -121,7 +121,7 @@ func TestTLSHandshakerSaver(t *testing.T) { } }, } - conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if err != nil { t.Fatal(err) } @@ -161,9 +161,8 @@ func TestTLSHandshakerSaver(t *testing.T) { expected := errors.New("mocked error") saver := &Saver{} thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expected + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return nil, expected }, }) ctx := context.Background() @@ -184,7 +183,7 @@ func TestTLSHandshakerSaver(t *testing.T) { } }, } - conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, expected) { t.Fatal("unexpected err", err) } diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index be29d6cc54..479d534fe7 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -35,7 +35,7 @@ var _ model.TLSHandshaker = &tlsHandshakerTrace{} // Handshake implements model.TLSHandshaker.Handshake. func (thx *tlsHandshakerTrace) Handshake( - ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) { + ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (model.TLSConn, error) { return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig) } diff --git a/internal/measurexlite/tls_test.go b/internal/measurexlite/tls_test.go index 5d98349f4d..46b6c633a9 100644 --- a/internal/measurexlite/tls_test.go +++ b/internal/measurexlite/tls_test.go @@ -7,7 +7,6 @@ import ( "crypto/x509" "errors" "net" - "reflect" "testing" "time" @@ -45,10 +44,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { trace := NewTrace(0, zeroTime) var hasCorrectTrace bool underlying := &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { gotTrace := netxlite.ContextTraceOrDefault(ctx) hasCorrectTrace = (gotTrace == trace) - return nil, tls.ConnectionState{}, expectedErr + return nil, expectedErr }, } trace.Netx = &mocks.MeasuringNetwork{ @@ -58,13 +57,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) ctx := context.Background() - conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + conn, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) if !errors.Is(err, expectedErr) { t.Fatal("unexpected err", err) } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("expected zero-value state") - } if conn != nil { t.Fatal("expected nil conn") } @@ -106,13 +102,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { InsecureSkipVerify: true, ServerName: "dns.cloudflare.com", } - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, mockedErr) { t.Fatal("unexpected err", err) } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("expected zero-value state") - } if conn != nil { t.Fatal("expected nil conn") } @@ -216,13 +209,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { InsecureSkipVerify: true, ServerName: "dns.cloudflare.com", } - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, mockedErr) { t.Fatal("unexpected err", err) } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("expected zero-value state") - } if conn != nil { t.Fatal("expected nil conn") } @@ -261,7 +251,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { RootCAs: runtimex.Try1(mitm.DefaultCertPool()), ServerName: "dns.google", } - tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig) + tlsConn, err := thx.Handshake(ctx, conn, tlsConfig) if err != nil { t.Fatal(err) } @@ -274,6 +264,8 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { t.Fatal("bytes should match") } + connState := netxlite.MaybeTLSConnectionState(tlsConn) + t.Run("TLSHandshake events", func(t *testing.T) { events := trace.TLSHandshakes() if len(events) != 1 { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 1b85d97979..111a5274ea 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "errors" "net" - "reflect" "syscall" "testing" "time" @@ -261,13 +260,10 @@ func TestTrace(t *testing.T) { InsecureSkipVerify: true, } ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, mockedErr) { t.Fatal("unexpected err", err) } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } if conn != nil { t.Fatal("expected nil conn") } @@ -302,13 +298,10 @@ func TestTrace(t *testing.T) { InsecureSkipVerify: true, } ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + conn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, mockedErr) { t.Fatal("unexpected err", err) } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } if conn != nil { t.Fatal("expected nil conn") } diff --git a/internal/mocks/tls.go b/internal/mocks/tls.go index 7c418520e7..ade6831326 100644 --- a/internal/mocks/tls.go +++ b/internal/mocks/tls.go @@ -4,17 +4,17 @@ import ( "context" "crypto/tls" "net" + + "github.com/ooni/probe-cli/v3/internal/model" ) // TLSHandshaker is a mockable TLS handshaker. type TLSHandshaker struct { - MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) ( - net.Conn, tls.ConnectionState, error) + MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) } // Handshake calls MockHandshake. -func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( - net.Conn, tls.ConnectionState, error) { +func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { return th.MockHandshake(ctx, conn, config) } diff --git a/internal/mocks/tls_test.go b/internal/mocks/tls_test.go index 03bb8c04a4..d4e23faa27 100644 --- a/internal/mocks/tls_test.go +++ b/internal/mocks/tls_test.go @@ -7,6 +7,8 @@ import ( "net" "reflect" "testing" + + "github.com/ooni/probe-cli/v3/internal/model" ) func TestTLSHandshaker(t *testing.T) { @@ -16,18 +18,14 @@ func TestTLSHandshaker(t *testing.T) { ctx := context.Background() config := &tls.Config{} th := &TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expected + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return nil, expected }, } - tlsConn, connState, err := th.Handshake(ctx, conn, config) + tlsConn, err := th.Handshake(ctx, conn, config) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } if tlsConn != nil { t.Fatal("expected nil conn here") } diff --git a/internal/model/netx.go b/internal/model/netx.go index b0c06431f8..aedf579587 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -332,12 +332,7 @@ type TLSHandshaker interface { // // - set NextProtos to []string{"h2", "http/1.1"} for HTTPS // and []string{"dot"} for DNS-over-TLS. - // - // QUIRK: The returned connection will always implement the TLSConn interface - // exposed by ooni/oohttp. A future version of this interface may instead - // return directly a TLSConn to avoid unconditional castings. - Handshake(ctx context.Context, conn net.Conn, tlsConfig *tls.Config) ( - net.Conn, tls.ConnectionState, error) + Handshake(ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (TLSConn, error) } // Trace allows to collect measurement traces. A trace is injected into diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 81d8c2caff..7199599b6c 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -295,7 +295,7 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - tconn, _, err := th.Handshake(ctx, conn, config) + tconn, err := th.Handshake(ctx, conn, config) if err != nil { return fmt.Errorf("tls handshake failed: %w", err) } @@ -320,7 +320,7 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - tconn, _, err := th.Handshake(ctx, conn, config) + tconn, err := th.Handshake(ctx, conn, config) if err == nil { return fmt.Errorf("tls handshake succeded unexpectedly") } @@ -350,7 +350,7 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - tconn, _, err := th.Handshake(ctx, conn, config) + tconn, err := th.Handshake(ctx, conn, config) if err == nil { return fmt.Errorf("tls handshake succeded unexpectedly") } @@ -380,7 +380,7 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - tconn, _, err := th.Handshake(ctx, conn, config) + tconn, err := th.Handshake(ctx, conn, config) if err == nil { return fmt.Errorf("tls handshake succeded unexpectedly") } diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 9cace9a9cc..7045ef463b 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -111,6 +111,15 @@ func NewMozillaCertPool() *x509.CertPool { return pool } +// MaybeTLSConnectionState is a convenience function that returns an +// empty [tls.ConnectionState] when the [model.TLSConn] is nil. +func MaybeTLSConnectionState(conn model.TLSConn) (state tls.ConnectionState) { + if conn != nil { + state = conn.ConnectionState() + } + return +} + // ErrInvalidTLSVersion indicates that you passed us a string // that does not represent a valid TLS version. var ErrInvalidTLSVersion = errors.New("invalid TLS version") @@ -203,7 +212,7 @@ func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState { // This function will also emit TLS-handshake-related tracing events. func (h *tlsHandshakerConfigurable) Handshake( ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { +) (model.TLSConn, error) { timeout := h.Timeout if timeout <= 0 { timeout = 10 * time.Second @@ -217,7 +226,7 @@ func (h *tlsHandshakerConfigurable) Handshake( } tlsconn, err := h.newConn(conn, config) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } remoteAddr := conn.RemoteAddr().String() trace := ContextTraceOrDefault(ctx) @@ -229,9 +238,9 @@ func (h *tlsHandshakerConfigurable) Handshake( state := tlsMaybeConnectionState(tlsconn, err) trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsconn, state, nil + return tlsconn, nil } // newConn creates a new TLSConn. @@ -253,24 +262,25 @@ var _ model.TLSHandshaker = &tlsHandshakerLogger{} // Handshake implements Handshaker.Handshake func (h *tlsHandshakerLogger) Handshake( ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { +) (model.TLSConn, error) { h.DebugLogger.Debugf( "tls_handshake {sni=%s next=%+v}...", config.ServerName, config.NextProtos) start := time.Now() - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + tlsconn, err := h.TLSHandshaker.Handshake(ctx, conn, config) elapsed := time.Since(start) if err != nil { h.DebugLogger.Debugf( "tls_handshake {sni=%s next=%+v}... %s in %s", config.ServerName, config.NextProtos, err, elapsed) - return nil, tls.ConnectionState{}, err + return nil, err } + state := MaybeTLSConnectionState(tlsconn) h.DebugLogger.Debugf( "tls_handshake {sni=%s next=%+v}... ok in %s {next=%s cipher=%s v=%s}", config.ServerName, config.NextProtos, elapsed, state.NegotiatedProtocol, TLSCipherSuiteString(state.CipherSuite), TLSVersionString(state.Version)) - return tlsconn, state, nil + return tlsconn, nil } // NewTLSDialer creates a new TLS dialer using the given dialer and handshaker. @@ -313,7 +323,7 @@ func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) return nil, err } config := d.config(host, port) - tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config) + tlsconn, err := d.TLSHandshaker.Handshake(ctx, conn, config) if err != nil { conn.Close() return nil, err diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 23f55f09a7..af58675bb1 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -17,6 +17,7 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/testingx" @@ -157,7 +158,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) { }, } ctx := context.Background() - conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{ + conn, err := h.Handshake(ctx, tcpConn, &tls.Config{ ServerName: "x.org", }) if !errors.Is(err, io.EOF) { @@ -188,9 +189,6 @@ func TestTLSHandshakerConfigurable(t *testing.T) { if !times[1].IsZero() { t.Fatal("did not clear timeout on exit") } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("the returned connection state is not a zero value") - } }) t.Run("with success", func(t *testing.T) { @@ -216,11 +214,12 @@ func TestTLSHandshakerConfigurable(t *testing.T) { MaxVersion: tls.VersionTLS13, ServerName: URL.Hostname(), } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + tlsConn, err := handshaker.Handshake(ctx, conn, config) if err != nil { t.Fatal(err) } defer tlsConn.Close() + connState := tlsConn.ConnectionState() if connState.Version != tls.VersionTLS13 { t.Fatal("unexpected TLS version") } @@ -256,13 +255,10 @@ func TestTLSHandshakerConfigurable(t *testing.T) { } }, } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + tlsConn, err := handshaker.Handshake(ctx, conn, config) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero connState here") - } if tlsConn != nil { t.Fatal("expected nil tlsConn here") } @@ -320,13 +316,10 @@ func TestTLSHandshakerConfigurable(t *testing.T) { } }, } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + tlsConn, err := handshaker.Handshake(ctx, conn, config) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero connState here") - } if tlsConn != nil { t.Fatal("expected nil tlsConn here") } @@ -352,13 +345,10 @@ func TestTLSHandshakerConfigurable(t *testing.T) { return nil }, } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + tlsConn, err := handshaker.Handshake(ctx, conn, config) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero connState here") - } if tlsConn != nil { t.Fatal("expected nil tlsConn here") } @@ -416,14 +406,11 @@ func TestTLSHandshakerConfigurable(t *testing.T) { InsecureSkipVerify: true, ServerName: expectedSNI, } - tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig) + tlsConn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if err != nil { t.Fatal(err) } tlsConn.Close() - if reflect.ValueOf(connState).IsZero() { - t.Fatal("expected nonzero connState") - } if !startCalled { t.Fatal("start not called") } @@ -530,16 +517,13 @@ func TestTLSHandshakerConfigurable(t *testing.T) { InsecureSkipVerify: true, ServerName: expectedSNI, } - tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig) + tlsConn, err := thx.Handshake(ctx, tcpConn, tlsConfig) if !errors.Is(err, io.EOF) { t.Fatal("unexpected err", err) } if tlsConn != nil { t.Fatal("expected nil tlsConn") } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero connState") - } if !startCalled { t.Fatal("start not called") } @@ -594,8 +578,8 @@ func TestTLSHandshakerLogger(t *testing.T) { } th := &tlsHandshakerLogger{ TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return tls.Client(conn, config), tls.ConnectionState{}, nil + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return tls.Client(conn, config), nil }, }, DebugLogger: lo, @@ -607,16 +591,13 @@ func TestTLSHandshakerLogger(t *testing.T) { } config := &tls.Config{} ctx := context.Background() - tlsConn, connState, err := th.Handshake(ctx, conn, config) + tlsConn, err := th.Handshake(ctx, conn, config) if err != nil { t.Fatal(err) } if err := tlsConn.Close(); err != nil { t.Fatal(err) } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } if count != 2 { t.Fatal("invalid count") } @@ -632,8 +613,8 @@ func TestTLSHandshakerLogger(t *testing.T) { expected := errors.New("mocked error") th := &tlsHandshakerLogger{ TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expected + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return nil, expected }, }, DebugLogger: lo, @@ -645,16 +626,13 @@ func TestTLSHandshakerLogger(t *testing.T) { } config := &tls.Config{} ctx := context.Background() - tlsConn, connState, err := th.Handshake(ctx, conn, config) + tlsConn, err := th.Handshake(ctx, conn, config) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } if tlsConn != nil { t.Fatal("expected nil conn here") } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } if count != 2 { t.Fatal("invalid count") } @@ -767,8 +745,8 @@ func TestTLSDialer(t *testing.T) { }}, nil }}, TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return tls.Client(conn, config), tls.ConnectionState{}, nil + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (model.TLSConn, error) { + return tls.Client(conn, config), nil }, }, } @@ -930,3 +908,39 @@ func TestMaybeConnectionState(t *testing.T) { } }) } + +func TestMaybeTLSConnectionState(t *testing.T) { + t.Run("when the TLSConn is nil", func(t *testing.T) { + expected := tls.ConnectionState{ /* empty */ } + got := MaybeTLSConnectionState(nil) + if diff := cmp.Diff(expected, got, cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("when the TLSConn is not nil", func(t *testing.T) { + expected := tls.ConnectionState{ + Version: tls.VersionTLS13, + HandshakeComplete: true, + DidResume: false, + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + NegotiatedProtocol: "h2", + NegotiatedProtocolIsMutual: true, + ServerName: "dns.google", + PeerCertificates: []*x509.Certificate{}, + VerifiedChains: [][]*x509.Certificate{}, + SignedCertificateTimestamps: [][]byte{}, + OCSPResponse: []byte{}, + TLSUnique: []byte{}, + } + conn := &mocks.TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return expected + }, + } + got := MaybeTLSConnectionState(conn) + if diff := cmp.Diff(expected, got, cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" { + t.Fatal(diff) + } + }) +} diff --git a/internal/oohelperd/tcptls.go b/internal/oohelperd/tcptls.go index 24c77438b0..073e90e92f 100644 --- a/internal/oohelperd/tcptls.go +++ b/internal/oohelperd/tcptls.go @@ -108,7 +108,7 @@ func tcpTLSDo(ctx context.Context, config *tcpTLSConfig) { ServerName: config.URLHostname, } thx := config.NewTSLHandshaker(config.Logger) - tlsConn, _, err := thx.Handshake(ctx, conn, tlsConfig) + tlsConn, err := thx.Handshake(ctx, conn, tlsConfig) ol.Stop(err) out.TLS = &ctrlTLSResult{ ServerName: config.URLHostname, diff --git a/internal/testingsocks5/internal_test.go b/internal/testingsocks5/internal_test.go index d9d24cc987..74b12cd153 100644 --- a/internal/testingsocks5/internal_test.go +++ b/internal/testingsocks5/internal_test.go @@ -99,7 +99,7 @@ func TestInvalidVersion(t *testing.T) { }}, } if err := client.run(log.Log, conn); err != nil { - t.Fatal(err) + t.Skip("https://github.com/ooni/probe/issues/2538") } } diff --git a/internal/testingx/tlssniproxy_test.go b/internal/testingx/tlssniproxy_test.go index b5e80d74e3..bbfc32bf80 100644 --- a/internal/testingx/tlssniproxy_test.go +++ b/internal/testingx/tlssniproxy_test.go @@ -109,7 +109,7 @@ func TestTLSSNIProxy(t *testing.T) { } defer conn.Close() - tconn := conn.(netxlite.TLSConn) + tconn := conn.(netxlite.TLSConn) // cast safe according to documentation connstate := tconn.ConnectionState() t.Logf("%+v", connstate) }) diff --git a/internal/tutorial/netxlite/chapter02/README.md b/internal/tutorial/netxlite/chapter02/README.md index ef5d9eb2f1..70ea352b21 100644 --- a/internal/tutorial/netxlite/chapter02/README.md +++ b/internal/tutorial/netxlite/chapter02/README.md @@ -25,6 +25,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -72,7 +73,7 @@ The logic to dial and handshake have been factored into a function called `dialTLS`. ```Go - conn, state, err := dialTLS(ctx, *address, tlsConfig) + conn, err := dialTLS(ctx, *address, tlsConfig) ``` If there is an error, we bail, like before. Otherwise we @@ -84,6 +85,7 @@ like in the previous chapter, we close the connection. if err != nil { fatal(err) } + state := conn.ConnectionState() log.Infof("Conn type : %T", conn) log.Infof("Cipher suite : %s", netxlite.TLSCipherSuiteString(state.CipherSuite)) log.Infof("Negotiated protocol: %s", state.NegotiatedProtocol) @@ -124,8 +126,7 @@ chapter why this guarantee helps when writing more complex code.) ```Go -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { th := netxlite.NewTLSHandshakerStdlib(log.Log) return th.Handshake(ctx, tcpConn, config) } @@ -139,18 +140,17 @@ perform this dial+handshake operation in a single function call. ```Go -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } ``` diff --git a/internal/tutorial/netxlite/chapter02/main.go b/internal/tutorial/netxlite/chapter02/main.go index 76b5dd3f73..b2d9424092 100644 --- a/internal/tutorial/netxlite/chapter02/main.go +++ b/internal/tutorial/netxlite/chapter02/main.go @@ -26,6 +26,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -73,7 +74,7 @@ func main() { // into a function called `dialTLS`. // // ```Go - conn, state, err := dialTLS(ctx, *address, tlsConfig) + conn, err := dialTLS(ctx, *address, tlsConfig) // ``` // // If there is an error, we bail, like before. Otherwise we @@ -85,6 +86,7 @@ func main() { if err != nil { fatal(err) } + state := conn.ConnectionState() log.Infof("Conn type : %T", conn) log.Infof("Cipher suite : %s", netxlite.TLSCipherSuiteString(state.CipherSuite)) log.Infof("Negotiated protocol: %s", state.NegotiatedProtocol) @@ -125,8 +127,7 @@ func dialTCP(ctx context.Context, address string) (net.Conn, error) { // // ```Go -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { th := netxlite.NewTLSHandshakerStdlib(log.Log) return th.Handshake(ctx, tcpConn, config) } @@ -140,18 +141,17 @@ func handshakeTLS(ctx context.Context, tcpConn net.Conn, // // ```Go -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } // ``` diff --git a/internal/tutorial/netxlite/chapter03/README.md b/internal/tutorial/netxlite/chapter03/README.md index cd4074be0a..482cb6350b 100644 --- a/internal/tutorial/netxlite/chapter03/README.md +++ b/internal/tutorial/netxlite/chapter03/README.md @@ -32,6 +32,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" utls "gitlab.com/yawning/utls.git" ) @@ -49,10 +50,11 @@ func main() { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - conn, state, err := dialTLS(ctx, *address, tlsConfig) + conn, err := dialTLS(ctx, *address, tlsConfig) if err != nil { fatal(err) } + state := conn.ConnectionState() log.Infof("Conn type : %T", conn) log.Infof("Cipher suite : %s", netxlite.TLSCipherSuiteString(state.CipherSuite)) log.Infof("Negotiated protocol: %s", state.NegotiatedProtocol) @@ -65,8 +67,7 @@ func dialTCP(ctx context.Context, address string) (net.Conn, error) { return d.DialContext(ctx, "tcp", address) } -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { ``` The following line of code is where we diverge from the @@ -91,18 +92,17 @@ previous chapter, so we won't add further comments. return th.Handshake(ctx, tcpConn, config) } -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } func fatal(err error) { diff --git a/internal/tutorial/netxlite/chapter03/main.go b/internal/tutorial/netxlite/chapter03/main.go index a4c38bacac..b94f52fd27 100644 --- a/internal/tutorial/netxlite/chapter03/main.go +++ b/internal/tutorial/netxlite/chapter03/main.go @@ -33,6 +33,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" utls "gitlab.com/yawning/utls.git" ) @@ -50,10 +51,11 @@ func main() { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - conn, state, err := dialTLS(ctx, *address, tlsConfig) + conn, err := dialTLS(ctx, *address, tlsConfig) if err != nil { fatal(err) } + state := conn.ConnectionState() log.Infof("Conn type : %T", conn) log.Infof("Cipher suite : %s", netxlite.TLSCipherSuiteString(state.CipherSuite)) log.Infof("Negotiated protocol: %s", state.NegotiatedProtocol) @@ -66,8 +68,7 @@ func dialTCP(ctx context.Context, address string) (net.Conn, error) { return d.DialContext(ctx, "tcp", address) } -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { // ``` // // The following line of code is where we diverge from the @@ -92,18 +93,17 @@ func handshakeTLS(ctx context.Context, tcpConn net.Conn, return th.Handshake(ctx, tcpConn, config) } -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } func fatal(err error) { diff --git a/internal/tutorial/netxlite/chapter07/README.md b/internal/tutorial/netxlite/chapter07/README.md index ecd359814d..0fd9248353 100644 --- a/internal/tutorial/netxlite/chapter07/README.md +++ b/internal/tutorial/netxlite/chapter07/README.md @@ -33,6 +33,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" utls "gitlab.com/yawning/utls.git" ) @@ -50,7 +51,7 @@ func main() { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - conn, _, err := dialTLS(ctx, *address, config) + conn, err := dialTLS(ctx, *address, config) if err != nil { fatal(err) } @@ -87,7 +88,7 @@ not using tracing and does not care about those quirks. ```Go clnt := &http.Client{Transport: netxlite.NewHTTPTransport( log.Log, netxlite.NewNullDialer(), - netxlite.NewSingleUseTLSDialer(conn.(netxlite.TLSConn)), + netxlite.NewSingleUseTLSDialer(conn), )} ``` @@ -119,24 +120,22 @@ func dialTCP(ctx context.Context, address string) (net.Conn, error) { return d.DialContext(ctx, "tcp", address) } -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) return th.Handshake(ctx, tcpConn, config) } -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } func fatal(err error) { diff --git a/internal/tutorial/netxlite/chapter07/main.go b/internal/tutorial/netxlite/chapter07/main.go index 60ac559eeb..465d12fcb9 100644 --- a/internal/tutorial/netxlite/chapter07/main.go +++ b/internal/tutorial/netxlite/chapter07/main.go @@ -34,6 +34,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" utls "gitlab.com/yawning/utls.git" ) @@ -51,7 +52,7 @@ func main() { NextProtos: []string{"h2", "http/1.1"}, RootCAs: nil, } - conn, _, err := dialTLS(ctx, *address, config) + conn, err := dialTLS(ctx, *address, config) if err != nil { fatal(err) } @@ -88,7 +89,7 @@ func main() { // ```Go clnt := &http.Client{Transport: netxlite.NewHTTPTransport( log.Log, netxlite.NewNullDialer(), - netxlite.NewSingleUseTLSDialer(conn.(netxlite.TLSConn)), + netxlite.NewSingleUseTLSDialer(conn), )} // ``` // @@ -120,24 +121,22 @@ func dialTCP(ctx context.Context, address string) (net.Conn, error) { return d.DialContext(ctx, "tcp", address) } -func handshakeTLS(ctx context.Context, tcpConn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func handshakeTLS(ctx context.Context, tcpConn net.Conn, config *tls.Config) (model.TLSConn, error) { th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) return th.Handshake(ctx, tcpConn, config) } -func dialTLS(ctx context.Context, address string, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { +func dialTLS(ctx context.Context, address string, config *tls.Config) (model.TLSConn, error) { tcpConn, err := dialTCP(ctx, address) if err != nil { - return nil, tls.ConnectionState{}, err + return nil, err } - tlsConn, state, err := handshakeTLS(ctx, tcpConn, config) + tlsConn, err := handshakeTLS(ctx, tcpConn, config) if err != nil { tcpConn.Close() - return nil, tls.ConnectionState{}, err + return nil, err } - return tlsConn, state, nil + return tlsConn, nil } func fatal(err error) {