Skip to content

Commit

Permalink
Merge pull request #91 from k0sproject/cleanup
Browse files Browse the repository at this point in the history
Clean up, restructure
  • Loading branch information
kke authored Feb 27, 2023
2 parents 889cdc2 + 4380adc commit fdd82b2
Show file tree
Hide file tree
Showing 39 changed files with 579 additions and 622 deletions.
7 changes: 4 additions & 3 deletions cmd/rigtest/rigtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/k0sproject/rig/os"
"github.com/k0sproject/rig/os/registry"
_ "github.com/k0sproject/rig/os/support"
"github.com/k0sproject/rig/pkg/rigfs"
"github.com/kevinburke/ssh_config"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -235,7 +236,7 @@ func main() {
require.False(t, h.Configurer.FileExist(h, fn))

testFileSize := int64(1 << (10 * 2)) // 1MB
fsyses := []rig.FS{h.Fsys(), h.SudoFsys()}
fsyses := []rigfs.Fsys{h.Fsys(), h.SudoFsys()}

for idx, fsys := range fsyses {
t.Run("fsys functions (%d) on %s", idx+1, h)
Expand All @@ -244,7 +245,7 @@ func main() {
shasum := sha256.New()
reader := io.TeeReader(origin, shasum)

destf, err := fsys.OpenFile(fn, rig.ModeCreate, 0644)
destf, err := fsys.OpenFile(fn, rigfs.ModeCreate, 0644)
require.NoError(t, err, "open file")

n, err := io.Copy(destf, reader)
Expand All @@ -262,7 +263,7 @@ func main() {

require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote")

destf, err = fsys.OpenFile(fn, rig.ModeRead, 0644)
destf, err = fsys.OpenFile(fn, rigfs.ModeRead, 0644)
require.NoError(t, err, "open file for read")

readSha := sha256.New()
Expand Down
80 changes: 25 additions & 55 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"crypto/sha256"
"fmt"
"io"
"io/fs"
"os"
"strings"

Expand All @@ -16,12 +15,12 @@ import (
"github.com/k0sproject/rig/exec"
"github.com/k0sproject/rig/log"
rigos "github.com/k0sproject/rig/os"
"github.com/k0sproject/rig/pkg/rigfs"
)

var _ rigos.Host = &Connection{}
var _ rigos.Host = (*Connection)(nil)

// Waiter is an interface that has a Wait() function that blocks until a command is finished
type Waiter interface {
type waiter interface {
Wait() error
}

Expand All @@ -30,7 +29,7 @@ type client interface {
Disconnect()
IsWindows() bool
Exec(string, ...exec.Option) error
ExecStreams(string, io.ReadCloser, io.Writer, io.Writer, ...exec.Option) (Waiter, error)
ExecStreams(string, io.ReadCloser, io.Writer, io.Writer, ...exec.Option) (waiter, error)
ExecInteractive(string) error
String() string
Protocol() string
Expand Down Expand Up @@ -81,29 +80,8 @@ type Connection struct {

client client `yaml:"-"`
sudofunc sudofn
fsys FS
sudofsys FS
}

// File is a file on a remote host
type File interface {
Seek(int64, int) (int64, error)
CopyFromN(io.Reader, int64, io.Writer) (int64, error)
Copy(io.Writer) (int, error)
Write([]byte) (int, error)
Read([]byte) (int, error)
Stat() (fs.FileInfo, error)
Close() error
}

// FS is a fs.FS compatible filesystem interface for filesystems on remote hosts
type FS interface {
Open(name string) (fs.File, error)
OpenFile(name string, mode FileMode, perm int) (File, error)
Stat(name string) (fs.FileInfo, error)
Sha256(name string) (string, error)
ReadDir(name string) ([]fs.DirEntry, error)
Delete(name string) error
fsys rigfs.Fsys
sudofsys rigfs.Fsys
}

// SetDefaults sets a connection
Expand Down Expand Up @@ -175,26 +153,18 @@ func (c Connection) String() string {
}

// Fsys returns a fs.FS compatible filesystem interface for accessing files on remote hosts
func (c *Connection) Fsys() FS {
func (c *Connection) Fsys() rigfs.Fsys {
if c.fsys == nil {
if c.IsWindows() {
c.fsys = newWindowsFsys(c)
} else {
c.fsys = newUnixFsys(c)
}
c.fsys = rigfs.NewFsys(c)
}

return c.fsys
}

// SudoFsys returns a fs.FS compatible filesystem interface for accessing files on remote hosts with sudo permissions
func (c *Connection) SudoFsys() FS {
func (c *Connection) SudoFsys() rigfs.Fsys {
if c.sudofsys == nil {
if c.IsWindows() {
c.sudofsys = newWindowsFsys(c, exec.Sudo(c))
} else {
c.sudofsys = newUnixFsys(c, exec.Sudo(c))
}
c.sudofsys = rigfs.NewFsys(c, exec.Sudo(c))
}

return c.sudofsys
Expand All @@ -212,13 +182,13 @@ func (c *Connection) IsWindows() bool {

// ExecStreams executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that
// blocks until the command finishes and returns an error if the exit code is not zero.
func (c Connection) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (Waiter, error) {
func (c Connection) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (rigfs.Waiter, error) {
if err := c.checkConnected(); err != nil {
return nil, ErrNotConnected.Wrapf("exec streams")
return nil, fmt.Errorf("%w: exec with streams: %w", ErrCommandFailed, err)
}
waiter, err := c.client.ExecStreams(cmd, stdin, stdout, stderr, opts...)
if err != nil {
return nil, ErrCommandFailed.Wrapf("exec (with streams): %w", err)
return nil, fmt.Errorf("%w: exec with streams: %w", ErrCommandFailed, err)
}
return waiter, nil
}
Expand All @@ -230,7 +200,7 @@ func (c Connection) Exec(cmd string, opts ...exec.Option) error {
}

if err := c.client.Exec(cmd, opts...); err != nil {
return ErrCommandFailed.Wrapf("client exec: %w", err)
return fmt.Errorf("%w: client exec: %w", ErrCommandFailed, err)
}

return nil
Expand All @@ -252,14 +222,14 @@ func (c Connection) ExecOutput(cmd string, opts ...exec.Option) (string, error)
func (c *Connection) Connect() error {
if c.client == nil {
if err := defaults.Set(c); err != nil {
return ErrValidationFailed.Wrapf("set defaults: %w", err)
return fmt.Errorf("%w: set defaults: %w", ErrValidationFailed, err)
}
}

if err := c.client.Connect(); err != nil {
c.client = nil
log.Debugf("%s: failed to connect: %v", c, err)
return ErrNotConnected.Wrapf("client connect: %w", err)
return fmt.Errorf("%w: client connect: %w", ErrNotConnected, err)
}

if c.OSVersion == nil {
Expand Down Expand Up @@ -339,7 +309,7 @@ func (c *Connection) configureSudo() {
// Sudo formats a command string to be run with elevated privileges
func (c Connection) Sudo(cmd string) (string, error) {
if c.sudofunc == nil {
return "", ErrSudoRequired.Wrapf("user is not an administrator and passwordless access elevation has not been configured")
return "", fmt.Errorf("%w: user is not an administrator and passwordless access elevation has not been configured", ErrSudoRequired)
}

return c.sudofunc(cmd), nil
Expand All @@ -366,7 +336,7 @@ func (c Connection) ExecInteractive(cmd string) error {
}

if err := c.client.ExecInteractive(cmd); err != nil {
return ErrCommandFailed.Wrapf("client exec interactive: %w", err)
return fmt.Errorf("%w: client exec interactive: %w", ErrCommandFailed, err)
}

return nil
Expand All @@ -388,36 +358,36 @@ func (c *Connection) Upload(src, dst string, opts ...exec.Option) error {
}
local, err := os.Open(src)
if err != nil {
return ErrInvalidPath.Wrap(err)
return fmt.Errorf("%w: %w", ErrInvalidPath, err)
}
defer local.Close()

stat, err := local.Stat()
if err != nil {
return ErrInvalidPath.Wrapf("stat local file %s: %w", src, err)
return fmt.Errorf("%w: stat local file %s: %w", ErrInvalidPath, src, err)
}

shasum := sha256.New()

fsys := c.Fsys()
remote, err := fsys.OpenFile(dst, ModeCreate, int(stat.Mode()))
remote, err := fsys.OpenFile(dst, rigfs.ModeCreate, rigfs.FileMode(stat.Mode()))
if err != nil {
return ErrInvalidPath.Wrapf("open remote file for writing: %w", err)
return fmt.Errorf("%w: open remote file %s for writing: %w", ErrInvalidPath, dst, err)
}
defer remote.Close()

if _, err := remote.CopyFromN(local, stat.Size(), shasum); err != nil {
return ErrUploadFailed.Wrapf("copy file to remote host: %w", err)
return fmt.Errorf("%w: copy file %s to remote host: %w", ErrUploadFailed, dst, err)
}

log.Debugf("%s: post-upload validate checksum of %s", c, dst)
remoteSum, err := fsys.Sha256(dst)
if err != nil {
return ErrUploadFailed.Wrapf("validate checksum of %s: %w", dst, err)
return fmt.Errorf("%w: validate %s checksum: %w", ErrUploadFailed, dst, err)
}

if remoteSum != fmt.Sprintf("%x", shasum.Sum(nil)) {
return ErrUploadFailed.Wrapf("checksum mismatch")
return fmt.Errorf("%w: checksum mismatch", ErrUploadFailed)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (m *mockClient) Exec(cmd string, opts ...exec.Option) error {

return nil
}
func (m *mockClient) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (Waiter, error) {
func (m *mockClient) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) {
return nil, fmt.Errorf("not implemented")
}

Expand Down
28 changes: 13 additions & 15 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
package rig

import (
"github.com/k0sproject/rig/errstring"
)
import "errors"

var (
ErrOS = errstring.New("local os") // ErrOS is returned when an action fails on local OS
ErrInvalidPath = errstring.New("invalid path") // ErrInvalidPath is returned when a path is invalid
ErrValidationFailed = errstring.New("validation failed") // ErrValidationFailed is returned when a validation fails
ErrSudoRequired = errstring.New("sudo required") // ErrSudoRequired is returned when sudo is required
ErrNotFound = errstring.New("not found") // ErrNotFound is returned when a resource is not found
ErrNotImplemented = errstring.New("not implemented") // ErrNotImplemented is returned when a feature is not implemented
ErrNotSupported = errstring.New("not supported") // ErrNotSupported is returned when a feature is not supported
ErrAuthFailed = errstring.New("authentication failed") // ErrAuthFailed is returned when authentication fails
ErrUploadFailed = errstring.New("upload failed") // ErrUploadFailed is returned when an upload fails
ErrNotConnected = errstring.New("not connected") // ErrNotConnected is returned when a connection is not established
ErrCantConnect = errstring.New("can't connect") // ErrCantConnect is returned when a connection is not established and retrying will fail
ErrCommandFailed = errstring.New("command failed") // ErrCommandFailed is returned when a command fails
ErrOS = errors.New("local os") // ErrOS is returned when an action fails on local OS
ErrInvalidPath = errors.New("invalid path") // ErrInvalidPath is returned when a path is invalid
ErrValidationFailed = errors.New("validation failed") // ErrValidationFailed is returned when a validation fails
ErrSudoRequired = errors.New("sudo required") // ErrSudoRequired is returned when sudo is required
ErrNotFound = errors.New("not found") // ErrNotFound is returned when a resource is not found
ErrNotImplemented = errors.New("not implemented") // ErrNotImplemented is returned when a feature is not implemented
ErrNotSupported = errors.New("not supported") // ErrNotSupported is returned when a feature is not supported
ErrAuthFailed = errors.New("authentication failed") // ErrAuthFailed is returned when authentication fails
ErrUploadFailed = errors.New("upload failed") // ErrUploadFailed is returned when an upload fails
ErrNotConnected = errors.New("not connected") // ErrNotConnected is returned when a connection is not established
ErrCantConnect = errors.New("can't connect") // ErrCantConnect is returned when a connection is not established and retrying will fail
ErrCommandFailed = errors.New("command failed") // ErrCommandFailed is returned when a command fails
)
75 changes: 0 additions & 75 deletions errors_test.go

This file was deleted.

Loading

0 comments on commit fdd82b2

Please sign in to comment.