-
Notifications
You must be signed in to change notification settings - Fork 380
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rework client to prevent after-close usage, and support perm at open
- Loading branch information
1 parent
22452ea
commit d1903fb
Showing
9 changed files
with
381 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie | |
// read/write at the same time. For those services you will need to use | ||
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. | ||
func (c *Client) Create(path string) (*File, error) { | ||
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) | ||
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) | ||
} | ||
|
||
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt | ||
|
@@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error { | |
} | ||
} | ||
|
||
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { | ||
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error { | ||
id := c.nextID() | ||
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ | ||
ID: id, | ||
|
@@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error { | |
// returned file can be used for reading; the associated file descriptor | ||
// has mode O_RDONLY. | ||
func (c *Client) Open(path string) (*File, error) { | ||
return c.open(path, flags(os.O_RDONLY)) | ||
return c.open(path, toPflags(os.O_RDONLY)) | ||
} | ||
|
||
// OpenFile is the generalized open call; most users will use Open or | ||
// Create instead. It opens the named file with specified flag (O_RDONLY | ||
// etc.). If successful, methods on the returned File can be used for I/O. | ||
func (c *Client) OpenFile(path string, f int) (*File, error) { | ||
return c.open(path, flags(f)) | ||
return c.open(path, toPflags(f)) | ||
} | ||
|
||
func (c *Client) open(path string, pflags uint32) (*File, error) { | ||
|
@@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error { | |
type File struct { | ||
c *Client | ||
path string | ||
handle string | ||
|
||
mu sync.Mutex | ||
mu sync.RWMutex | ||
handle string | ||
offset int64 // current offset within remote file | ||
} | ||
|
||
// Close closes the File, rendering it unusable for I/O. It returns an | ||
// error, if any. | ||
func (f *File) Close() error { | ||
return f.c.close(f.handle) | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
handle := f.handle | ||
f.handle = "" | ||
|
||
return f.c.close(handle) | ||
} | ||
|
||
// Name returns the name of the file as presented to Open or Create. | ||
|
@@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
n, err := f.ReadAt(b, f.offset) | ||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
n, err := f.readAt(b, f.offset) | ||
f.offset += int64(n) | ||
return n, err | ||
} | ||
|
@@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { | |
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, | ||
// so the file offset is not altered during the read. | ||
func (f *File) ReadAt(b []byte, off int64) (int, error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.readAt(b, off) | ||
} | ||
|
||
func (f *File) readAt(b []byte, off int64) (int, error) { | ||
if len(b) <= f.c.maxPacket { | ||
// This should be able to be serviced with 1/2 requests. | ||
// So, just do it directly. | ||
|
@@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
if f.c.disableConcurrentReads { | ||
return f.writeToSequential(w) | ||
} | ||
|
@@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { | |
} | ||
} | ||
|
||
func (f *File) Stat() (os.FileInfo, error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return nil, os.ErrClosed | ||
} | ||
|
||
return f.stat() | ||
} | ||
|
||
// Stat returns the FileInfo structure describing file. If there is an | ||
// error. | ||
func (f *File) Stat() (os.FileInfo, error) { | ||
func (f *File) stat() (os.FileInfo, error) { | ||
fs, err := f.c.fstat(f.handle) | ||
if err != nil { | ||
return nil, err | ||
|
@@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
n, err := f.WriteAt(b, f.offset) | ||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
n, err := f.writeAt(b, f.offset) | ||
f.offset += int64(n) | ||
return n, err | ||
} | ||
|
@@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { | |
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, | ||
// so the file offset is not altered during the write. | ||
func (f *File) WriteAt(b []byte, off int64) (written int, err error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.writeAt(b, off) | ||
} | ||
|
||
func (f *File) writeAt(b []byte, off int64) (written int, err error) { | ||
if len(b) <= f.c.maxPacket { | ||
// We can do this in one write. | ||
return f.writeChunkAt(nil, b, off) | ||
|
@@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { | |
// | ||
// Otherwise, the given concurrency will be capped by the Client's max concurrency. | ||
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.readFromWithConcurrency(r, concurrency) | ||
} | ||
|
||
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { | ||
// Split the write into multiple maxPacket sized concurrent writes. | ||
// This allows writes with a suitably large reader | ||
// to transfer data at a much faster rate due to overlapping round trip times. | ||
|
@@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
if f.c.useConcurrentWrites { | ||
var remain int64 | ||
switch r := r.(type) { | ||
|
@@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
|
||
if remain < 0 { | ||
// We can strongly assert that we want default max concurrency here. | ||
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) | ||
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests) | ||
} | ||
|
||
if remain > int64(f.c.maxPacket) { | ||
|
@@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
concurrency64 = int64(f.c.maxConcurrentRequests) | ||
} | ||
|
||
return f.ReadFromWithConcurrency(r, int(concurrency64)) | ||
return f.readFromWithConcurrency(r, int(concurrency64)) | ||
} | ||
} | ||
|
||
|
@@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
switch whence { | ||
case io.SeekStart: | ||
case io.SeekCurrent: | ||
offset += f.offset | ||
case io.SeekEnd: | ||
fi, err := f.Stat() | ||
fi, err := f.stat() | ||
if err != nil { | ||
return f.offset, err | ||
} | ||
|
@@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { | |
|
||
// Chown changes the uid/gid of the current file. | ||
func (f *File) Chown(uid, gid int) error { | ||
return f.c.Chown(f.path, uid, gid) | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{ | ||
UID: uint32(uid), | ||
GID: uint32(gid), | ||
}) | ||
} | ||
|
||
// Chmod changes the permissions of the current file. | ||
// | ||
// See Client.Chmod for details. | ||
func (f *File) Chmod(mode os.FileMode) error { | ||
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) | ||
} | ||
|
||
// Truncate sets the size of the current file. Although it may be safely assumed | ||
// that if the size is less than its current size it will be truncated to fit, | ||
// the SFTP protocol does not specify what behavior the server should do when setting | ||
// size greater than the current size. | ||
// We send a SSH_FXP_FSETSTAT here since we have a file handle | ||
func (f *File) Truncate(size int64) error { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size)) | ||
} | ||
|
||
// Sync requests a flush of the contents of a File to stable storage. | ||
// | ||
// Sync requires the server to support the [email protected] extension. | ||
func (f *File) Sync() error { | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
|
||
id := f.c.nextID() | ||
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ | ||
ID: id, | ||
|
@@ -1957,15 +2072,6 @@ func (f *File) Sync() error { | |
} | ||
} | ||
|
||
// Truncate sets the size of the current file. Although it may be safely assumed | ||
// that if the size is less than its current size it will be truncated to fit, | ||
// the SFTP protocol does not specify what behavior the server should do when setting | ||
// size greater than the current size. | ||
// We send a SSH_FXP_FSETSTAT here since we have a file handle | ||
func (f *File) Truncate(size int64) error { | ||
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size)) | ||
} | ||
|
||
// normaliseError normalises an error into a more standard form that can be | ||
// checked against stdlib errors like io.EOF or os.ErrNotExist. | ||
func normaliseError(err error) error { | ||
|
@@ -1990,7 +2096,7 @@ func normaliseError(err error) error { | |
|
||
// flags converts the flags passed to OpenFile into ssh flags. | ||
// Unsupported flags are ignored. | ||
func flags(f int) uint32 { | ||
func toPflags(f int) uint32 { | ||
var out uint32 | ||
switch f & os.O_WRONLY { | ||
case os.O_WRONLY: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.