Skip to content

Commit

Permalink
Merge pull request #2 from soulteary/feat/smart-scan
Browse files Browse the repository at this point in the history
fear: smart config scanner
  • Loading branch information
soulteary authored Dec 23, 2024
2 parents 326287a + 541b993 commit c047a98
Show file tree
Hide file tree
Showing 7 changed files with 880 additions and 38 deletions.
10 changes: 10 additions & 0 deletions internal/define/define.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ type YAMLOutput struct {
Default map[string]string `yaml:"default,omitempty"`
Groups map[string]GroupConfig `yaml:",inline"`
}

var ExcludePatterns = []string{
"known_hosts",
"authorized_keys",
"*.pub",
"id_*",
"*.key",
"*.pem",
"*.ppk",
}
32 changes: 9 additions & 23 deletions internal/fn/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,33 +99,19 @@ func DetectStringType(input string) string {
}

func GetPathContent(src string) ([]byte, error) {
srcInfo, err := os.Stat(src)
configFiles, err := ReadSSHConfigs(src)
if err != nil {
return nil, fmt.Errorf("can not get source info: %v", err)
return nil, err
}
if len(configFiles.Configs) == 0 {
return nil, fmt.Errorf("no valid SSH config found in %s", src)
}

var content []byte

if srcInfo.IsDir() {
files, err := os.ReadDir(src)
if err != nil {
return nil, fmt.Errorf("can not read source directory: %v", err)
}

for _, file := range files {
if !file.IsDir() {
filePath := filepath.Join(src, file.Name())
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("can not read file %s: %v", filePath, err)
}
content = append(content, fileContent...)
}
}
} else {
content, err = os.ReadFile(src)
if err != nil {
return nil, fmt.Errorf("can not read source file: %v", err)
for filePath := range configFiles.Configs {
fileContent, err := os.ReadFile(filePath)
if err == nil {
content = append(content, fileContent...)
}
}
return content, nil
Expand Down
119 changes: 114 additions & 5 deletions internal/fn/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

Define "github.com/soulteary/ssh-config/internal/define"
"github.com/soulteary/ssh-config/internal/fn"
Fn "github.com/soulteary/ssh-config/internal/fn"
)

Expand Down Expand Up @@ -478,8 +479,8 @@ func TestGetPathContent(t *testing.T) {

file1 := filepath.Join(multiDir, "file1.txt")
file2 := filepath.Join(multiDir, "file2.txt")
content1 := []byte("Content of file 1")
content2 := []byte("Content of file 2")
content1 := []byte("Host test1")
content2 := []byte("Host test2")

err = os.WriteFile(file1, content1, 0644)
if err != nil {
Expand Down Expand Up @@ -532,7 +533,7 @@ func TestGetPathContent(t *testing.T) {

_, err = Fn.GetPathContent(dirWithUnreadableFile)
if err == nil {
t.Error("Expected error for directory with unreadable file, got nil")
t.Error("Expected error for no valid SSH config found in, got nil", err)
}

unreadableFile2 := filepath.Join(tempDir, "unreadable_single.txt")
Expand All @@ -550,8 +551,31 @@ func TestGetPathContent(t *testing.T) {
_, err = Fn.GetPathContent(unreadableFile2)
if err == nil {
t.Error("Expected error for unreadable single file, got nil")
} else if !strings.Contains(err.Error(), "can not read source file") {
t.Errorf("Expected error message to contain 'can not read source file', got: %v", err)
} else if !strings.Contains(err.Error(), "no valid SSH config found in") {
t.Errorf("Expected error message to contain 'no valid SSH config found in', got: %v", err)
}

dirWithCorruptFile := filepath.Join(tempDir, "dir_with_corrupt")
err = os.Mkdir(dirWithCorruptFile, 0755)
if err != nil {
t.Fatalf("Failed to create dir_with_corrupt: %v", err)
}

normalFile := filepath.Join(dirWithCorruptFile, "normal.txt")
err = os.WriteFile(normalFile, []byte("Normal content"), 0644)
if err != nil {
t.Fatalf("Failed to create normal file: %v", err)
}

corruptFile := filepath.Join(dirWithCorruptFile, "corrupt.txt")
err = os.Symlink("/nonexistent/file", corruptFile)
if err != nil {
t.Fatalf("Failed to create corrupt file: %v", err)
}

_, err = Fn.GetPathContent(dirWithCorruptFile)
if err == nil {
t.Fatalf("Expected error for directory with corrupt file, got nil")
}
}

Expand Down Expand Up @@ -752,3 +776,88 @@ func TestTidyLastEmptyLines(t *testing.T) {
})
}
}

//

// SSHConfig 模拟原始结构
type SSHConfig struct {
Configs map[string]interface{}
}

// 创建测试用的配置文件目录
func createTestConfigDir(t *testing.T) (string, error) {
tmpDir := t.TempDir()

// 创建测试文件1
test1Path := filepath.Join(tmpDir, "test1.txt")
err := os.WriteFile(test1Path, []byte("Host abc"), 0644)
if err != nil {
return "", err
}

// 创建测试文件2
test2Path := filepath.Join(tmpDir, "test2.txt")
err = os.WriteFile(test2Path, []byte("Host def"), 0644)
if err != nil {
return "", err
}

// 创建一个无权限的文件
noPermFile := filepath.Join(tmpDir, "no_perm.txt")
err = os.WriteFile(noPermFile, []byte("no permission"), 0644)
if err != nil {
return "", err
}
err = os.Chmod(noPermFile, 0000) // 移除所有权限
if err != nil {
return "", err
}

return tmpDir, nil
}

func TestGetPathContent2(t *testing.T) {
// 测试场景1: 成功读取文件
t.Run("Success case", func(t *testing.T) {
tmpDir, err := createTestConfigDir(t)
if err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}

content, err := fn.GetPathContent(tmpDir)

// 验证无权限文件的内容没有被包含
if strings.Contains(string(content), "no permission") {
t.Error("Content should not contain 'no permission' as the file is not readable")
}
})

// 测试场景2: 配置目录不存在
t.Run("Non-existent directory", func(t *testing.T) {
content, err := fn.GetPathContent("non_existent_dir")
if err == nil {
t.Error("Expected an error, got nil")
}
if content != nil {
t.Error("Expected nil content")
}
})

// 测试场景3: 文件读取失败(权限问题)
t.Run("File read error due to permissions", func(t *testing.T) {
tmpDir, err := createTestConfigDir(t)
if err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}

content, err := fn.GetPathContent(tmpDir)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

// 验证无权限文件的内容没有被包含
if strings.Contains(string(content), "no permission") {
t.Error("Content should not contain 'no permission' as the file is not readable")
}
})
}
186 changes: 186 additions & 0 deletions internal/fn/scanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package fn

import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/soulteary/ssh-config/internal/define"
)

type ConfigFile struct {
Path string
Content []string
Hosts map[string]map[string]string
}

type SSHConfig struct {
Configs map[string]*ConfigFile // key: 配置文件路径
}

func IsExcluded(filename string) bool {
filename = strings.ToLower(filename)

for _, pattern := range define.ExcludePatterns {
if matched, _ := filepath.Match(pattern, filename); matched {
return true
}
}

return false
}

func IsConfigFile(path string) bool {
// read file first few lines to determine if it's SSH config file format
file, err := os.Open(path)
if err != nil {
return false
}
defer file.Close()

scanner := bufio.NewScanner(file)
lineCount := 0
validLines := 0

// check first 5 lines
for scanner.Scan() && lineCount < 5 {
line := strings.TrimSpace(scanner.Text())
lineCount++

if line == "" || strings.HasPrefix(line, "#") {
continue
}

parts := strings.Fields(line)
if len(parts) >= 2 {
key := strings.ToLower(parts[0])
switch key {
case "host", "hostname", "user", "port", "identityfile", "proxycommand":
validLines++
}
}
}

return validLines > 0
}

func ReadSSHConfigs(sshPath string) (*SSHConfig, error) {
config := &SSHConfig{
Configs: make(map[string]*ConfigFile),
}

info, err := os.Stat(sshPath)
if err != nil {
return nil, fmt.Errorf("failed to get path object info: %v", err)
}

if !info.IsDir() {
configFile := ReadSingleConfig(sshPath)
if configFile != nil {
config.Configs[sshPath] = configFile
}
return config, nil
}

err = filepath.Walk(sshPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if info.IsDir() {
return nil
}

if IsExcluded(info.Name()) {
return nil
}

if !IsConfigFile(path) {
return nil
}

configFile := ReadSingleConfig(path)
if configFile != nil {
config.Configs[path] = configFile
}
return nil
})

if err != nil {
return nil, fmt.Errorf("failed to walk directory: %v", err)
}

return config, nil
}

func ReadSingleConfig(path string) *ConfigFile {
file, err := os.Open(path)
if err != nil {
return nil
}
defer file.Close()

config := &ConfigFile{
Path: path,
Hosts: make(map[string]map[string]string),
}

scanner := bufio.NewScanner(file)
var currentHost string
var content []string

for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
content = append(content, line)

if line == "" || strings.HasPrefix(line, "#") {
continue
}

parts := strings.Fields(line)
if len(parts) == 2 {
key := strings.ToLower(parts[0])
value := strings.Join(parts[1:], " ")

if key == "host" {
currentHost = value
config.Hosts[currentHost] = make(map[string]string)
} else if currentHost != "" {
config.Hosts[currentHost][key] = value
}
}
}

if err := scanner.Err(); err != nil {
return nil
}

config.Content = content
return config
}

func (c *SSHConfig) GetHostConfig(host string) map[string]map[string]string {
results := make(map[string]map[string]string)

for path, config := range c.Configs {
if hostConfig, exists := config.Hosts[host]; exists {
results[path] = hostConfig
}
}

return results
}

func (c *SSHConfig) PrintConfigs() {
for path, config := range c.Configs {
fmt.Printf("\n=== 配置文件: %s ===\n", path)
for host, hostConfig := range config.Hosts {
fmt.Printf("\nHost %s:\n", host)
for key, value := range hostConfig {
fmt.Printf(" %s = %s\n", key, value)
}
}
}
}
Loading

0 comments on commit c047a98

Please sign in to comment.