Skip to content

Commit

Permalink
Improve read waiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 7, 2023
1 parent 074dd24 commit d0914d7
Show file tree
Hide file tree
Showing 17 changed files with 287 additions and 312 deletions.
233 changes: 0 additions & 233 deletions common/badtls/badtls.go

This file was deleted.

14 changes: 0 additions & 14 deletions common/badtls/badtls_stub.go

This file was deleted.

22 changes: 0 additions & 22 deletions common/badtls/link.go

This file was deleted.

119 changes: 119 additions & 0 deletions common/badtls/read_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//go:build go1.21 && !without_badtls

package badtls

import (
"bytes"
"os"
"reflect"
"sync"
"unsafe"

"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/tls"
)

var _ N.ReadWaiter = (*ReadWaitConn)(nil)

type ReadWaitConn struct {
*tls.STDConn
halfAccess *sync.Mutex
rawInput *bytes.Buffer
input *bytes.Reader
hand *bytes.Buffer
readWaitOptions N.ReadWaitOptions
}

func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
stdConn, isSTDConn := conn.(*tls.STDConn)
if !isSTDConn {
return nil, os.ErrInvalid
}
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
rawHalfConn := rawConn.FieldByName("in")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn")
}
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half mutex")
}
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
rawRawInput := rawHalfConn.FieldByName("rawInput")
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid raw input")
}
rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
rawInput0 := rawHalfConn.FieldByName("input")
if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid input")
}
input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
rawHand := rawConn.FieldByName("hand")
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid hand")
}
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
rawReadRecord := rawConn.MethodByName("readRecord")
if !rawReadRecord.IsValid() {
return nil, E.New("badtls: invalid readRecord")
}
return &ReadWaitConn{
STDConn: stdConn,
halfAccess: halfAccess,
rawInput: rawInput,
input: input,
hand: hand,
}, nil
}

func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
err = c.Handshake()
if err != nil {
return
}
c.halfAccess.Lock()
defer c.halfAccess.Unlock()
for c.input.Len() == 0 {
err = tlsReadRecord(c.STDConn)
if err != nil {
return
}
for c.hand.Len() > 0 {
err = tlsHandlePostHandshakeMessage(c.STDConn)
if err != nil {
return
}
}
}
buffer = c.readWaitOptions.NewBuffer()
n, err := c.input.Read(buffer.FreeBytes())
if err != nil {
buffer.Release()
return
}
buffer.Truncate(n)

if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
c.rawInput.Bytes()[0] == 21 {
_ = tlsReadRecord(c.STDConn)
// return n, err // will be io.EOF on closeNotify
}

c.readWaitOptions.PostReturn(buffer)
return
}

//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
func tlsReadRecord(c *tls.STDConn) error

//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
Loading

0 comments on commit d0914d7

Please sign in to comment.