Skip to content

Commit

Permalink
rework client to prevent after-close usage, and support perm at open
Browse files Browse the repository at this point in the history
  • Loading branch information
puellanivis committed Jan 19, 2024
1 parent 22452ea commit d1903fb
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 110 deletions.
19 changes: 17 additions & 2 deletions attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name }
func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }

// Mode returns file mode bits.
func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() }

// ModTime returns the last modification time of the file.
func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() }

// IsDir returns true if the file is a directory.
func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
Expand All @@ -56,6 +56,21 @@ type FileStat struct {
Extended []StatExtended
}

// ModTime returns the Mtime SFTP file attribute converted to a time.Time
func (fs *FileStat) ModTime() time.Time {
return time.Unix(int64(fs.Mtime), 0)
}

// AccessTime returns the Atime SFTP file attribute converted to a time.Time
func (fs *FileStat) AccessTime() time.Time {
return time.Unix(int64(fs.Atime), 0)
}

// FileMode returns the Mode SFTP file attribute converted to an os.FileMode
func (fs *FileStat) FileMode() os.FileMode {
return toFileMode(fs.Mode)
}

// StatExtended contains additional, extended information for a FileStat.
type StatExtended struct {
ExtType string
Expand Down
156 changes: 131 additions & 25 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
}

const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
Expand Down Expand Up @@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
Expand Down Expand Up @@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error {
// returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY))
return c.open(path, toPflags(os.O_RDONLY))
}

// OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f))
return c.open(path, toPflags(f))
}

func (c *Client) open(path string, pflags uint32) (*File, error) {
Expand Down Expand Up @@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error {
type File struct {
c *Client
path string
handle string

mu sync.Mutex
mu sync.RWMutex
handle string
offset int64 // current offset within remote file
}

// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any.
func (f *File) Close() error {
return f.c.close(f.handle)
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}

handle := f.handle
f.handle = ""

return f.c.close(handle)
}

// Name returns the name of the file as presented to Open or Create.
Expand All @@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.ReadAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.readAt(b, f.offset)
f.offset += int64(n)
return n, err
}
Expand Down Expand Up @@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.readAt(b, off)
}

func (f *File) readAt(b []byte, off int64) (int, error) {
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
Expand Down Expand Up @@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
Expand Down Expand Up @@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}
}

func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return nil, os.ErrClosed
}

return f.stat()
}

// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle)
if err != nil {
return nil, err
Expand All @@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.WriteAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.writeAt(b, f.offset)
f.offset += int64(n)
return n, err
}
Expand Down Expand Up @@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.writeAt(b, off)
}

func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket {
// We can do this in one write.
return f.writeChunkAt(nil, b, off)
Expand Down Expand Up @@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
//
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.readFromWithConcurrency(r, concurrency)
}

func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
Expand Down Expand Up @@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.useConcurrentWrites {
var remain int64
switch r := r.(type) {
Expand All @@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {

if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}

if remain > int64(f.c.maxPacket) {
Expand All @@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}

return f.ReadFromWithConcurrency(r, int(concurrency64))
return f.readFromWithConcurrency(r, int(concurrency64))
}
}

Expand Down Expand Up @@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

switch whence {
case io.SeekStart:
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
fi, err := f.Stat()
fi, err := f.stat()
if err != nil {
return f.offset, err
}
Expand All @@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {

// Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid)
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
}

// Chmod changes the permissions of the current file.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the [email protected] extension.
func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}


id := f.c.nextID()
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Expand All @@ -1957,15 +2072,6 @@ func (f *File) Sync() error {
}
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
Expand All @@ -1990,7 +2096,7 @@ func normaliseError(err error) error {

// flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func flags(f int) uint32 {
func toPflags(f int) uint32 {
var out uint32
switch f & os.O_WRONLY {
case os.O_WRONLY:
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ var flagsTests = []struct {

func TestFlags(t *testing.T) {
for i, tt := range flagsTests {
got := flags(tt.flags)
got := toPflags(tt.flags)
if got != tt.want {
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
}
Expand Down
Loading

0 comments on commit d1903fb

Please sign in to comment.