diff --git a/archives.go b/archives.go index 790ccab..db59eee 100644 --- a/archives.go +++ b/archives.go @@ -100,10 +100,20 @@ func FilesFromDisk(ctx context.Context, options *FromDiskOptions, filenames map[ var linkTarget string if isSymlink(info) { if options != nil && options.FollowSymlinks { + originalFilename := filename filename, info, err = followSymlink(filename) if err != nil { return err } + if info.IsDir() { + symlinkDirFiles, err := FilesFromDisk(ctx, options, map[string]string{filename: nameInArchive}) + if err != nil { + return fmt.Errorf("getting files from symlink directory %s dereferenced to %s: %w", originalFilename, linkTarget, err) + } + + files = append(files, symlinkDirFiles...) + return nil + } } else { // preserve symlinks linkTarget, err = os.Readlink(filename) @@ -128,6 +138,7 @@ func FilesFromDisk(ctx context.Context, options *FromDiskOptions, filenames map[ } files = append(files, file) + return nil }) if walkErr != nil { diff --git a/archives_test.go b/archives_test.go index 623bf0c..b2058e1 100644 --- a/archives_test.go +++ b/archives_test.go @@ -1,11 +1,13 @@ package archives import ( + "context" "fmt" "os" "path/filepath" "reflect" "runtime" + "sort" "strings" "testing" ) @@ -266,6 +268,13 @@ func TestNameOnDiskToNameInArchive(t *testing.T) { } } +func fixSeparators(path string) string { + if runtime.GOOS == "windows" { + return strings.ReplaceAll(path, "/", "\\") + } + return path +} + func TestFollowSymlink(t *testing.T) { // Create temp directory for tests tmpDir := t.TempDir() @@ -510,3 +519,46 @@ func TestFollowSymlink(t *testing.T) { } }) } + +func TestFilesFromDisk_SymlinkOutsideFileNamesMap(t *testing.T) { + tmpDir := t.TempDir() + otherTmpDir := t.TempDir() + + testDirName := "test_dir" + testDir := filepath.Join(otherTmpDir, testDirName) + if err := os.Mkdir(testDir, 0755); err != nil { + t.Fatal(err) + } + + testFileName := "test.txt" + testFile := filepath.Join(testDir, testFileName) + if err := os.WriteFile(testFile, []byte("test content"), 0644); err != nil { + t.Fatal(err) + } + + symlinkDirName := "symlink_dir" + symlinkDir := filepath.Join(tmpDir, symlinkDirName) + if err := os.Symlink(testDir, symlinkDir); err != nil { + t.Fatal(err) + } + + files, err := FilesFromDisk(context.Background(), &FromDiskOptions{ + FollowSymlinks: true, + }, map[string]string{symlinkDir: ""}) + if err != nil { + t.Fatal(err) + } + + sort.Slice(files, func(i, j int) bool { + return files[i].NameInArchive < files[j].NameInArchive + }) + + if files[0].NameInArchive != symlinkDirName { + t.Fatalf("expected file name '%s', got '%s'", symlinkDirName, files[0].NameInArchive) + } + + testFilePath := fmt.Sprintf("%s/%s", symlinkDirName, testFileName) + if files[1].NameInArchive != testFilePath { + t.Fatalf("expected file name '%s', got '%s'", testFilePath, files[1].NameInArchive) + } +}