From d1903fbd460e9a8105bae72fcdf492a4999b4cee Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 00:20:23 +0000 Subject: [PATCH] rework client to prevent after-close usage, and support perm at open --- attrs.go | 19 ++++- client.go | 156 +++++++++++++++++++++++++++++++------ client_test.go | 2 +- packet.go | 93 ++++++++++++++++++---- packet_test.go | 29 +++++-- request-attrs.go | 6 -- server.go | 87 +++++++++------------ server_integration_test.go | 17 ++-- server_test.go | 82 ++++++++++++++++++- 9 files changed, 381 insertions(+), 110 deletions(-) diff --git a/attrs.go b/attrs.go index 758cd4ff..74ac03b7 100644 --- a/attrs.go +++ b/attrs.go @@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name } func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) } // Mode returns file mode bits. -func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) } +func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() } // ModTime returns the last modification time of the file. -func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) } +func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() } // IsDir returns true if the file is a directory. func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } @@ -56,6 +56,21 @@ type FileStat struct { Extended []StatExtended } +// ModTime returns the Mtime SFTP file attribute converted to a time.Time +func (fs *FileStat) ModTime() time.Time { + return time.Unix(int64(fs.Mtime), 0) +} + +// AccessTime returns the Atime SFTP file attribute converted to a time.Time +func (fs *FileStat) AccessTime() time.Time { + return time.Unix(int64(fs.Atime), 0) +} + +// FileMode returns the Mode SFTP file attribute converted to an os.FileMode +func (fs *FileStat) FileMode() os.FileMode { + return toFileMode(fs.Mode) +} + // StatExtended contains additional, extended information for a FileStat. type StatExtended struct { ExtType string diff --git a/client.go b/client.go index a3b8e22b..1d55aaea 100644 --- a/client.go +++ b/client.go @@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie // read/write at the same time. For those services you will need to use // `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. func (c *Client) Create(path string) (*File, error) { - return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) + return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) } const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt @@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error { } } -func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { +func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error { id := c.nextID() typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ ID: id, @@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error { // returned file can be used for reading; the associated file descriptor // has mode O_RDONLY. func (c *Client) Open(path string) (*File, error) { - return c.open(path, flags(os.O_RDONLY)) + return c.open(path, toPflags(os.O_RDONLY)) } // OpenFile is the generalized open call; most users will use Open or // Create instead. It opens the named file with specified flag (O_RDONLY // etc.). If successful, methods on the returned File can be used for I/O. func (c *Client) OpenFile(path string, f int) (*File, error) { - return c.open(path, flags(f)) + return c.open(path, toPflags(f)) } func (c *Client) open(path string, pflags uint32) (*File, error) { @@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error { type File struct { c *Client path string - handle string - mu sync.Mutex + mu sync.RWMutex + handle string offset int64 // current offset within remote file } // Close closes the File, rendering it unusable for I/O. It returns an // error, if any. func (f *File) Close() error { - return f.c.close(f.handle) + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed + } + + handle := f.handle + f.handle = "" + + return f.c.close(handle) } // Name returns the name of the file as presented to Open or Create. @@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - n, err := f.ReadAt(b, f.offset) + if f.handle == "" { + return 0, os.ErrClosed + } + + n, err := f.readAt(b, f.offset) f.offset += int64(n) return n, err } @@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { // the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, // so the file offset is not altered during the read. func (f *File) ReadAt(b []byte, off int64) (int, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.readAt(b, off) +} + +func (f *File) readAt(b []byte, off int64) (int, error) { if len(b) <= f.c.maxPacket { // This should be able to be serviced with 1/2 requests. // So, just do it directly. @@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + if f.c.disableConcurrentReads { return f.writeToSequential(w) } @@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } } +func (f *File) Stat() (os.FileInfo, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return nil, os.ErrClosed + } + + return f.stat() +} + // Stat returns the FileInfo structure describing file. If there is an // error. -func (f *File) Stat() (os.FileInfo, error) { +func (f *File) stat() (os.FileInfo, error) { fs, err := f.c.fstat(f.handle) if err != nil { return nil, err @@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - n, err := f.WriteAt(b, f.offset) + if f.handle == "" { + return 0, os.ErrClosed + } + + n, err := f.writeAt(b, f.offset) f.offset += int64(n) return n, err } @@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { // the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, // so the file offset is not altered during the write. func (f *File) WriteAt(b []byte, off int64) (written int, err error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.writeAt(b, off) +} + +func (f *File) writeAt(b []byte, off int64) (written int, err error) { if len(b) <= f.c.maxPacket { // We can do this in one write. return f.writeChunkAt(nil, b, off) @@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { // // Otherwise, the given concurrency will be capped by the Client's max concurrency. func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.readFromWithConcurrency(r, concurrency) +} + +func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { // Split the write into multiple maxPacket sized concurrent writes. // This allows writes with a suitably large reader // to transfer data at a much faster rate due to overlapping round trip times. @@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + if f.c.useConcurrentWrites { var remain int64 switch r := r.(type) { @@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { if remain < 0 { // We can strongly assert that we want default max concurrency here. - return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) + return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests) } if remain > int64(f.c.maxPacket) { @@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { concurrency64 = int64(f.c.maxConcurrentRequests) } - return f.ReadFromWithConcurrency(r, int(concurrency64)) + return f.readFromWithConcurrency(r, int(concurrency64)) } } @@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + switch whence { case io.SeekStart: case io.SeekCurrent: offset += f.offset case io.SeekEnd: - fi, err := f.Stat() + fi, err := f.stat() if err != nil { return f.offset, err } @@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { // Chown changes the uid/gid of the current file. func (f *File) Chown(uid, gid int) error { - return f.c.Chown(f.path, uid, gid) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{ + UID: uint32(uid), + GID: uint32(gid), + }) } // Chmod changes the permissions of the current file. // // See Client.Chmod for details. func (f *File) Chmod(mode os.FileMode) error { - return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) +} + +// Truncate sets the size of the current file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +// We send a SSH_FXP_FSETSTAT here since we have a file handle +func (f *File) Truncate(size int64) error { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size)) } // Sync requests a flush of the contents of a File to stable storage. // // Sync requires the server to support the fsync@openssh.com extension. func (f *File) Sync() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed + } + + id := f.c.nextID() typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ ID: id, @@ -1957,15 +2072,6 @@ func (f *File) Sync() error { } } -// Truncate sets the size of the current file. Although it may be safely assumed -// that if the size is less than its current size it will be truncated to fit, -// the SFTP protocol does not specify what behavior the server should do when setting -// size greater than the current size. -// We send a SSH_FXP_FSETSTAT here since we have a file handle -func (f *File) Truncate(size int64) error { - return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size)) -} - // normaliseError normalises an error into a more standard form that can be // checked against stdlib errors like io.EOF or os.ErrNotExist. func normaliseError(err error) error { @@ -1990,7 +2096,7 @@ func normaliseError(err error) error { // flags converts the flags passed to OpenFile into ssh flags. // Unsupported flags are ignored. -func flags(f int) uint32 { +func toPflags(f int) uint32 { var out uint32 switch f & os.O_WRONLY { case os.O_WRONLY: diff --git a/client_test.go b/client_test.go index 4577ca22..dda8af2b 100644 --- a/client_test.go +++ b/client_test.go @@ -81,7 +81,7 @@ var flagsTests = []struct { func TestFlags(t *testing.T) { for i, tt := range flagsTests { - got := flags(tt.flags) + got := toPflags(tt.flags) if got != tt.want { t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) } diff --git a/packet.go b/packet.go index 1232ff1e..2fea2bef 100644 --- a/packet.go +++ b/packet.go @@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte { flags, fileStat := fileStatFromInfo(fi) b = marshalUint32(b, flags) + + return marshalFileStat(b, flags, fileStat) +} + +func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte { if flags&sshFileXferAttrSize != 0 { b = marshalUint64(b, fileStat.Size) } @@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte { } func marshal(b []byte, v interface{}) []byte { - if v == nil { - return b - } switch v := v.(type) { + case nil: + return b case uint8: return append(b, v) case uint32: @@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte { return marshalUint64(b, v) case string: return marshalString(b, v) + case []byte: + return append(b, v...) case os.FileInfo: return marshalFileInfo(b, v) default: @@ -180,8 +186,6 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { } if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { fs.UID, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { fs.GID, b, _ = unmarshalUint32Safe(b) } if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { @@ -681,12 +685,13 @@ type sshFxpOpenPacket struct { ID uint32 Path string Pflags uint32 - Flags uint32 // ignored + Flags uint32 + Attrs interface{} } func (p *sshFxpOpenPacket) id() uint32 { return p.ID } -func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) { l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 + 4 @@ -698,7 +703,20 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Flags) - return b, nil + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } + + return b, marshal(nil, p.Attrs), nil +} + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { @@ -709,12 +727,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { return err } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { return err - } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { return err } + p.Attrs = b return nil } +func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + type sshFxpReadPacket struct { ID uint32 Len uint32 @@ -943,9 +974,15 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) - payload := marshal(nil, p.Attrs) + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } - return b, payload, nil + return b, marshal(nil, p.Attrs), nil } func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { @@ -964,9 +1001,15 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalString(b, p.Handle) b = marshalUint32(b, p.Flags) - payload := marshal(nil, p.Attrs) + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } - return b, payload, nil + return b, marshal(nil, p.Attrs), nil } func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { @@ -987,6 +1030,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { return nil } +func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { var err error if p.ID, b, err = unmarshalUint32Safe(b); err != nil { @@ -1000,6 +1055,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { return nil } +func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + type sshFxpHandlePacket struct { ID uint32 Handle string diff --git a/packet_test.go b/packet_test.go index cbee5e4c..6278ca4f 100644 --- a/packet_test.go +++ b/packet_test.go @@ -376,7 +376,7 @@ func TestSendPacket(t *testing.T) { packet: &sshFxpOpenPacket{ ID: 1, Path: "/foo", - Pflags: flags(os.O_RDONLY), + Pflags: toPflags(os.O_RDONLY), }, want: []byte{ 0x0, 0x0, 0x0, 0x15, @@ -387,6 +387,26 @@ func TestSendPacket(t *testing.T) { 0x0, 0x0, 0x0, 0x0, }, }, + { + packet: &sshFxpOpenPacket{ + ID: 3, + Path: "/foo", + Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC), + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x19, + 0x3, + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x1a, + 0x0, 0x0, 0x0, 0x4, + 0x0, 0x0, 0x1, 0xed, + }, + }, { packet: &sshFxpWritePacket{ ID: 124, @@ -409,10 +429,7 @@ func TestSendPacket(t *testing.T) { ID: 31, Path: "/bar", Flags: sshFileXferAttrUIDGID, - Attrs: struct { - UID uint32 - GID uint32 - }{ + Attrs: &FileStat{ UID: 1000, GID: 100, }, @@ -611,7 +628,7 @@ func BenchmarkMarshalOpen(b *testing.B) { benchMarshal(b, &sshFxpOpenPacket{ ID: 1, Path: "/home/test/some/random/path", - Pflags: flags(os.O_RDONLY), + Pflags: toPflags(os.O_RDONLY), }) } diff --git a/request-attrs.go b/request-attrs.go index b5c95b4a..c86539cc 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -3,7 +3,6 @@ package sftp // Methods on the Request object to make working with the Flags bitmasks and // Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write // request and AttrFlags() and Attributes() when working with SetStat requests. -import "os" // FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags // (https://golang.org/pkg/os/#pkg-constants). @@ -50,11 +49,6 @@ func (r *Request) AttrFlags() FileAttrFlags { return newFileAttrFlags(r.Flags) } -// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode -func (a FileStat) FileMode() os.FileMode { - return os.FileMode(a.Mode) -} - // Attributes parses file attributes byte blob and return them in a // FileStat object. func (r *Request) Attributes() *FileStat { diff --git a/server.go b/server.go index 2e419f59..6e53e264 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,6 @@ import ( "strconv" "sync" "syscall" - "time" ) const ( @@ -462,7 +461,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) + mode := os.FileMode(0o644) + // Like OpenSSH, we only handle permissions here, if the file is being created. + // Otherwise, the permissions are ignored. + if p.Flags & sshFileXferAttrPermissions != 0 { + fs := p.unmarshalFileStat(p.Flags) + mode = fs.FileMode() & os.ModePerm + } + + f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode) if err != nil { return statusFromError(p.ID, err) } @@ -496,43 +503,32 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { } func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { - // additional unmarshalling is required for each possibility here - b := p.Attrs.([]byte) - var err error + path := svr.toLocalPath(p.Path) - p.Path = svr.toLocalPath(p.Path) + debug("setstat name %q", path) + + fs := p.unmarshalFileStat(p.Flags) + + var err error - debug("setstat name \"%s\"", p.Path) if (p.Flags & sshFileXferAttrSize) != 0 { - var size uint64 - if size, b, err = unmarshalUint64Safe(b); err == nil { - err = os.Truncate(p.Path, int64(size)) + if err == nil { + err = os.Truncate(path, int64(fs.Size)) } } if (p.Flags & sshFileXferAttrPermissions) != 0 { - var mode uint32 - if mode, b, err = unmarshalUint32Safe(b); err == nil { - err = os.Chmod(p.Path, os.FileMode(mode)) + if err == nil { + err = os.Chmod(path, fs.FileMode()) } } if (p.Flags & sshFileXferAttrACmodTime) != 0 { - var atime uint32 - var mtime uint32 - if atime, b, err = unmarshalUint32Safe(b); err != nil { - } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { - } else { - atimeT := time.Unix(int64(atime), 0) - mtimeT := time.Unix(int64(mtime), 0) - err = os.Chtimes(p.Path, atimeT, mtimeT) + if err == nil { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { - var uid uint32 - var gid uint32 - if uid, b, err = unmarshalUint32Safe(b); err != nil { - } else if gid, _, err = unmarshalUint32Safe(b); err != nil { - } else { - err = os.Chown(p.Path, int(uid), int(gid)) + if err == nil { + err = os.Chown(path, int(fs.UID), int(fs.GID)) } } @@ -545,41 +541,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { return statusFromError(p.ID, EBADF) } - // additional unmarshalling is required for each possibility here - b := p.Attrs.([]byte) + path := f.Name() + + debug("fsetstat name %q", path) + + fs := p.unmarshalFileStat(p.Flags) + var err error - debug("fsetstat name \"%s\"", f.Name()) if (p.Flags & sshFileXferAttrSize) != 0 { - var size uint64 - if size, b, err = unmarshalUint64Safe(b); err == nil { - err = f.Truncate(int64(size)) + if err == nil { + err = f.Truncate(int64(fs.Size)) } } if (p.Flags & sshFileXferAttrPermissions) != 0 { - var mode uint32 - if mode, b, err = unmarshalUint32Safe(b); err == nil { - err = f.Chmod(os.FileMode(mode)) + if err == nil { + err = f.Chmod(fs.FileMode()) } } if (p.Flags & sshFileXferAttrACmodTime) != 0 { - var atime uint32 - var mtime uint32 - if atime, b, err = unmarshalUint32Safe(b); err != nil { - } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { - } else { - atimeT := time.Unix(int64(atime), 0) - mtimeT := time.Unix(int64(mtime), 0) - err = os.Chtimes(f.Name(), atimeT, mtimeT) + if err == nil { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { - var uid uint32 - var gid uint32 - if uid, b, err = unmarshalUint32Safe(b); err != nil { - } else if gid, _, err = unmarshalUint32Safe(b); err != nil { - } else { - err = f.Chown(int(uid), int(gid)) + if err == nil { + err = f.Chown(int(fs.UID), int(fs.GID)) } } diff --git a/server_integration_test.go b/server_integration_test.go index 407d38a2..74a6f8a1 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -591,18 +591,25 @@ ls -l /usr/bin/ goWords := spaceRegex.Split(goLine, -1) opWords := spaceRegex.Split(opLine, -1) // some fields are allowed to be different.. - // words[2] and [3] as these are users & groups - // words[1] as the link count for directories like proc is unstable // during testing as processes are created/destroyed. - // words[7] as timestamp on dirs can very for things like /tmp for j, goWord := range goWords { if j >= len(opWords) { bad = true break } opWord := opWords[j] - if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 { - bad = true + if goWord != opWord { + switch j { + case 1, 2, 3, 7: + // words[1] as the link count for directories like proc is unstable + // words[2] and [3] as these are users & groups + // words[7] as timestamps on dirs can vary for things like /tmp + case 8: + // words[8] can either have full path or just the filename + bad = !strings.HasSuffix(opWord, "/" + goWord) + default: + bad = true + } } } } diff --git a/server_test.go b/server_test.go index 87beece5..110e0dee 100644 --- a/server_test.go +++ b/server_test.go @@ -178,21 +178,22 @@ func TestOpenStatRace(t *testing.T) { // openpacket finishes to fast to trigger race in tests // need to add a small sleep on server to openpackets somehow tmppath := path.Join(os.TempDir(), "stat_race") - pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) ch := make(chan result, 3) id1 := client.nextID() + id2 := client.nextID() client.dispatchRequest(ch, &sshFxpOpenPacket{ ID: id1, Path: tmppath, Pflags: pflags, }) - id2 := client.nextID() client.dispatchRequest(ch, &sshFxpLstatPacket{ ID: id2, Path: tmppath, }) testreply := func(id uint32) { r := <-ch + require.NoError(t, r.err) switch r.typ { case sshFxpAttrs, sshFxpHandle: // ignore case sshFxpStatus: @@ -208,6 +209,83 @@ func TestOpenStatRace(t *testing.T) { checkServerAllocator(t, server) } +func TestOpenWithPermissions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + skipIfWindows(t) + + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + tmppath := path.Join(os.TempDir(), "open_permissions") + defer os.Remove(tmppath) + + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + + id1 := client.nextID() + id2 := client.nextID() + + typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id1, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o745, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id1, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + stat, err := os.Stat(tmppath) + if err != nil { + t.Fatal("unexpected error:", err) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v was expecting 0o745", stat.Mode()) + } + + // Existing files should not have their permissions changed. + typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id2, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id2, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v, was expecting unchanged 0o745", stat.Mode()) + } + + checkServerAllocator(t, server) +} + // Ensure that proper error codes are returned for non existent files, such // that they are mapped back to a 'not exists' error on the client side. func TestStatNonExistent(t *testing.T) {