Skip to content

Commit

Permalink
Decide to use os.Mkdir when creating chain layer temp directories.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712408022
  • Loading branch information
Mario Leyva authored and copybara-github committed Jan 6, 2025
1 parent 56e5c3b commit 3c4c362
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 51 deletions.
1 change: 1 addition & 0 deletions artifact/image/layerscanning/image/file_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type fileNode struct {
originLayerID string
isWhiteout bool
virtualPath string
targetPath string
mode fs.FileMode
file *os.File
}
Expand Down
116 changes: 75 additions & 41 deletions artifact/image/layerscanning/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ import (
"github.com/google/go-containerregistry/pkg/v1/tarball"
scalibrImage "github.com/google/osv-scalibr/artifact/image"
"github.com/google/osv-scalibr/artifact/image/pathtree"
"github.com/google/osv-scalibr/artifact/image/symlink"
"github.com/google/osv-scalibr/artifact/image/whiteout"
"github.com/google/osv-scalibr/log"
)

const (
Expand Down Expand Up @@ -172,10 +174,8 @@ func FromV1Image(v1Image v1.Image, config *Config) (*Image, error) {
continue
}

// Create the chain layer directory if it doesn't exist.
dirPath := path.Join(tempPath, chainLayer.latestLayer.DiffID())

// TODO b/378491191 - Determine if an error should be thrown if the directory already exists. If
// so, we can probably use os.MkdirAll instead.
if err := os.Mkdir(dirPath, dirPermission); err != nil && !errors.Is(err, fs.ErrExist) {
return &outputImage, fmt.Errorf("failed to create chain layer directory: %w", err)
}
Expand Down Expand Up @@ -322,92 +322,126 @@ func fillChainLayerWithFilesFromTar(img *Image, tarReader *tar.Reader, originLay
// TODO: b/377553499 - Escape invalid characters on windows that's valid on linux
realFilePath := filepath.Join(dirPath, filepath.Clean(cleanedFilePath))

var fileMode fs.FileMode
// Write out the file/dir to disk.
var newNode *fileNode
switch header.Typeflag {
case tar.TypeDir:
fileMode, err = img.handleDir(realFilePath, tarReader, header)
if err != nil {
return fmt.Errorf("failed to handle directory: %w", err)
}

newNode, err = img.handleDir(realFilePath, virtualPath, originLayerID, tarReader, header, tombstone)
case tar.TypeReg:
// default:
newNode, err = img.handleFile(realFilePath, virtualPath, originLayerID, tarReader, header, tombstone)
case tar.TypeSymlink, tar.TypeLink:
newNode, err = img.handleSymlink(realFilePath, virtualPath, originLayerID, tarReader, header, tombstone)
default:
// TODO: b/374769529 - Handle symlinks.
// Assume if it's not a directory, it's a normal file.
fileMode, err = img.handleFile(realFilePath, tarReader, header)
if err != nil {
return fmt.Errorf("failed to handle file: %w", err)
}
log.Warnf("unsupported file type: %v", header.Typeflag)
continue
}

if err != nil {
return fmt.Errorf("failed to handle tar entry with path %s: %w", virtualPath, err)
}

// In each outer loop, a layer is added to each relevant output chainLayer slice. Because the
// outer loop is looping backwards (latest layer first), we ignore any files that are already in
// each chainLayer, as they would have been overwritten.
fillChainLayersWithVirtualPath(img, chainLayersToFill, originLayerID, virtualPath, tombstone, fileMode)
fillChainLayersWithFileNode(chainLayersToFill, newNode)
}

return nil
}

// handleSymlink builds the fileNode that represents a link entry from a layer tarball. Nothing is
// written to disk: the link exists only as a node whose mode is fs.ModeSymlink and whose
// targetPath holds the link target as an absolute virtual path.
//
// It returns an error when the header has an empty Linkname, or when symlink.TargetOutsideRoot
// reports that the target escapes the root. A relative target is resolved against the directory
// that contains virtualPath; an absolute target is stored unchanged.
//
// realFilePath and tarReader are not used here; they are kept so the signature stays parallel to
// the other tar entry handlers.
func (img *Image) handleSymlink(realFilePath, virtualPath, originLayerID string, tarReader *tar.Reader, header *tar.Header, isWhiteout bool) (*fileNode, error) {
	targetPath := header.Linkname
	if targetPath == "" {
		return nil, errors.New("symlink header has no target path")
	}

	if symlink.TargetOutsideRoot(virtualPath, targetPath) {
		// NOTE(review): the log line says "skipping", but an error is returned and the visible
		// caller appears to propagate it rather than skip the entry — confirm which is intended.
		log.Warnf("Found symlink that points outside the root, skipping: %q -> %q", virtualPath, targetPath)
		return nil, fmt.Errorf("symlink points outside the root: %q -> %q", virtualPath, targetPath)
	}

	// Link names inside a tar archive are slash-separated on every platform, so they are resolved
	// with the path package rather than path/filepath. On Windows, filepath.IsAbs rejects "/usr/bin"
	// and filepath.Join would emit backslashes, producing a target that can never be looked up.
	// Assumes virtualPath is slash-separated as well — TODO confirm against the caller.
	if !path.IsAbs(targetPath) {
		targetPath = path.Join(path.Dir(virtualPath), targetPath)
	}

	return &fileNode{
		extractDir:    img.ExtractDir,
		originLayerID: originLayerID,
		virtualPath:   virtualPath,
		targetPath:    targetPath,
		isWhiteout:    isWhiteout,
		mode:          fs.ModeSymlink,
	}, nil
}

// handleDir creates the directory specified by path, if it doesn't exist.
func (img *Image) handleDir(path string, tarReader *tar.Reader, header *tar.Header) (fs.FileMode, error) {
if _, err := os.Stat(path); err != nil {
if err := os.MkdirAll(path, dirPermission); err != nil {
return 0, fmt.Errorf("failed to create directory with path %s: %w", path, err)
func (img *Image) handleDir(realFilePath, virtualPath, originLayerID string, tarReader *tar.Reader, header *tar.Header, isWhiteout bool) (*fileNode, error) {
if _, err := os.Stat(realFilePath); err != nil {
if err := os.MkdirAll(realFilePath, dirPermission); err != nil {
return nil, fmt.Errorf("failed to create directory with realFilePath %s: %w", realFilePath, err)
}
}
return fs.FileMode(header.Mode) | fs.ModeDir, nil
return &fileNode{
extractDir: img.ExtractDir,
originLayerID: originLayerID,
virtualPath: virtualPath,
isWhiteout: isWhiteout,
mode: fs.FileMode(header.Mode) | fs.ModeDir,
}, nil
}

// handleFile creates the file specified by path, and then copies the contents of the tarReader into
// the file.
func (img *Image) handleFile(path string, tarReader *tar.Reader, header *tar.Header) (fs.FileMode, error) {
func (img *Image) handleFile(realFilePath, virtualPath, originLayerID string, tarReader *tar.Reader, header *tar.Header, isWhiteout bool) (*fileNode, error) {
// Write all files as read/writable by the current user, inaccessible by anyone else
// Actual permission bits are stored in FileNode
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, filePermission)
f, err := os.OpenFile(realFilePath, os.O_CREATE|os.O_RDWR, filePermission)

if err != nil {
return 0, err
return nil, err
}
defer f.Close()

numBytes, err := io.Copy(f, io.LimitReader(tarReader, img.maxFileBytes))
if numBytes >= img.maxFileBytes || errors.Is(err, io.EOF) {
return 0, ErrFileReadLimitExceeded
return nil, ErrFileReadLimitExceeded
}

if err != nil {
return 0, fmt.Errorf("unable to copy file: %w", err)
return nil, fmt.Errorf("unable to copy file: %w", err)
}

return fs.FileMode(header.Mode), nil
return &fileNode{
extractDir: img.ExtractDir,
originLayerID: originLayerID,
virtualPath: virtualPath,
isWhiteout: isWhiteout,
mode: fs.FileMode(header.Mode),
}, nil
}

// fillChainLayersWithVirtualPath fills the chain layers with the virtual path.
func fillChainLayersWithVirtualPath(img *Image, chainLayers []*chainLayer, originLayerID, virtualPath string, isWhiteout bool, fileMode fs.FileMode) {
for _, chainLayer := range chainLayers {
// fillChainLayersWithFileNode fills the chain layers with a new fileNode.
func fillChainLayersWithFileNode(chainLayersToFill []*chainLayer, newNode *fileNode) {
virtualPath := newNode.virtualPath
for _, chainLayer := range chainLayersToFill {
if node := chainLayer.fileNodeTree.Get(virtualPath); node != nil {
// A newer version of the file already exists on a later chainLayer.
// Since we do not want to overwrite a later layer with information
// written in an earlier layer, skip this file.
continue
}

// check for a whited out parent directory
// Check for a whited out parent directory.
if inWhiteoutDir(chainLayer, virtualPath) {
// The entire directory has been deleted, so no need to save this file
// The entire directory has been deleted, so no need to save this file.
continue
}

// Add the file to the chain layer. If there is an error, then we fail open.
// TODO: b/379154069 - Add logging for fail open errors.
chainLayer.fileNodeTree.Insert(virtualPath, &fileNode{
extractDir: img.ExtractDir,
originLayerID: originLayerID,
virtualPath: virtualPath,
isWhiteout: isWhiteout,
mode: fileMode,
})
chainLayer.fileNodeTree.Insert(virtualPath, newNode)
}
}

Expand Down
8 changes: 5 additions & 3 deletions artifact/image/layerscanning/image/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,11 @@ func TestFromTarball(t *testing.T) {
if errors.Is(gotErr, tc.wantErr) {
return
}
t.Fatalf("Load(%v) returned error: %v, want error: %v", tc.tarPath, gotErr, tc.wantErr)
t.Fatalf("FromTarball(%v) returned error: %v, want error: %v", tc.tarPath, gotErr, tc.wantErr)
}

if gotErr != nil {
t.Fatalf("Load(%v) returned unexpected error: %v", tc.tarPath, gotErr)
t.Fatalf("FromTarball(%v) returned unexpected error: %v", tc.tarPath, gotErr)
}

chainLayers, err := gotImage.ChainLayers()
Expand Down Expand Up @@ -387,18 +387,20 @@ func TestFromV1Image(t *testing.T) {
// chainLayerEntries.
func compareChainLayerEntries(t *testing.T, gotChainLayer image.ChainLayer, wantChainLayerEntries chainLayerEntries) {
t.Helper()

chainfs := gotChainLayer.FS()

for _, filepathContentPair := range wantChainLayerEntries.filepathContentPairs {
gotFile, err := chainfs.Open(filepathContentPair.filepath)
if err != nil {
t.Fatalf("Open(%v) returned error: %v", filepathContentPair.filepath, err)
}
defer gotFile.Close()

contentBytes, err := io.ReadAll(gotFile)
if err != nil {
t.Fatalf("ReadAll(%v) returned error: %v", filepathContentPair.filepath, err)
}

gotContent := string(contentBytes[:])
if diff := cmp.Diff(gotContent, filepathContentPair.content); diff != "" {
t.Errorf("Open(%v) returned incorrect content: got %s, want %s", filepathContentPair.filepath, gotContent, filepathContentPair.content)
Expand Down
41 changes: 36 additions & 5 deletions artifact/image/layerscanning/image/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ var (
ErrDiffIDMissingFromLayer = errors.New("failed to get diffID from v1 layer")
// ErrUncompressedReaderMissingFromLayer is returned when the uncompressed reader is missing from a v1 layer.
ErrUncompressedReaderMissingFromLayer = errors.New("failed to get uncompressed reader from v1 layer")

// DefaultMaxSymlinkDepth is the default maximum symlink depth.
DefaultMaxSymlinkDepth = 3
)

// ========================================================
Expand Down Expand Up @@ -109,7 +112,8 @@ type chainLayer struct {
func (chainLayer *chainLayer) FS() scalibrfs.FS {
// root should be "/" given we are dealing with file paths.
return &FS{
tree: chainLayer.fileNodeTree,
tree: chainLayer.fileNodeTree,
maxSymlinkDepth: DefaultMaxSymlinkDepth,
}
}

Expand All @@ -129,7 +133,24 @@ func (chainLayer *chainLayer) Layer() image.Layer {

// FS implements the scalibrfs.FS interface that will be used when scanning for inventory.
type FS struct {
tree *pathtree.Node[fileNode]
tree *pathtree.Node[fileNode]
maxSymlinkDepth int
}

// resolveSymlink follows node through the virtual filesystem until it reaches a node that is not
// a symlink, and returns that node.
//
// depth is the remaining budget of resolution steps. It is checked before the node is inspected,
// so a call with depth 0 fails even when node is a regular file. An error is returned when the
// budget is exhausted or when a link target cannot be found in the tree.
func (chainfs FS) resolveSymlink(node *fileNode, depth int) (*fileNode, error) {
	if depth == 0 {
		return nil, errors.New("symlink depth exceeded")
	}
	if node.mode != fs.ModeSymlink {
		return node, nil
	}

	linkedNode, err := chainfs.getFileNode(node.targetPath)
	if err != nil {
		// Report the path that was actually looked up. linkedNode is the result of the failed
		// lookup and must not be dereferenced here.
		return nil, fmt.Errorf("failed to get file node with virtual path %s: %w", node.targetPath, err)
	}
	return chainfs.resolveSymlink(linkedNode, depth-1)
}

// Open opens a file from the virtual filesystem.
Expand All @@ -138,16 +159,26 @@ func (chainfs FS) Open(name string) (fs.File, error) {
if err != nil {
return nil, fmt.Errorf("failed to get file node to open %s: %w", name, err)
}
return fileNode, nil

resolvedNode, err := chainfs.resolveSymlink(fileNode, chainfs.maxSymlinkDepth)
if err != nil {
return nil, fmt.Errorf("failed to resolve symlink for file node %s: %w", fileNode.virtualPath, err)
}
return resolvedNode, nil
}

// Stat returns a FileInfo object describing the file found at name.
func (chainfs *FS) Stat(name string) (fs.FileInfo, error) {
fileNode, err := chainfs.getFileNode(name)
node, err := chainfs.getFileNode(name)
if err != nil {
return nil, fmt.Errorf("failed to get file node to stat %s: %w", name, err)
}
return fileNode.Stat()

resolvedNode, err := chainfs.resolveSymlink(node, chainfs.maxSymlinkDepth)
if err != nil {
return nil, fmt.Errorf("failed to resolve symlink for file node %s: %w", node.virtualPath, err)
}
return resolvedNode.Stat()
}

// ReadDir returns the directory entries found at path name.
Expand Down
6 changes: 4 additions & 2 deletions artifact/image/layerscanning/image/layer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ func setUpEmptyChainFS(t *testing.T) FS {
t.Helper()

return FS{
tree: pathtree.NewNode[fileNode](),
tree: pathtree.NewNode[fileNode](),
maxSymlinkDepth: DefaultMaxSymlinkDepth,
}
}

Expand All @@ -483,7 +484,8 @@ func setUpChainFS(t *testing.T) (FS, string) {
tempDir := t.TempDir()

chainfs := FS{
tree: pathtree.NewNode[fileNode](),
tree: pathtree.NewNode[fileNode](),
maxSymlinkDepth: DefaultMaxSymlinkDepth,
}

vfsMap := map[string]*fileNode{
Expand Down

0 comments on commit 3c4c362

Please sign in to comment.