diff --git a/cached_downloader.go b/cached_downloader.go index cd25860..840d53e 100644 --- a/cached_downloader.go +++ b/cached_downloader.go @@ -23,7 +23,7 @@ type CachedDownloader interface { } func NoopTransform(source, destination string) (int64, error) { - err := os.Rename(source, destination) + err := replace(source, destination) if err != nil { return 0, err } diff --git a/cached_downloader_test.go b/cached_downloader_test.go index 351cd44..f9fa7a2 100644 --- a/cached_downloader_test.go +++ b/cached_downloader_test.go @@ -472,6 +472,9 @@ var _ = Describe("File cache", func() { )) url, _ = Url.Parse(server.URL() + "/A") + // windows timer increments every 15.6ms, without the sleep + // A, B & C will sometimes have the same timestamp + time.Sleep(16 * time.Millisecond) cache.Fetch(url, "A", NoopTransform, cancelChan) }) diff --git a/file_cache.go b/file_cache.go index f949ae0..792e8cb 100644 --- a/file_cache.go +++ b/file_cache.go @@ -80,7 +80,7 @@ func (e *fileCacheEntry) readCloser() (*CachedFile, error) { return readCloser, nil } -func (c *FileCache) Add(cacheKey string, sourcePath string, size int64, cachingInfo CachingInfoType) (*CachedFile, error) { +func (c *FileCache) Add(cacheKey, sourcePath string, size int64, cachingInfo CachingInfoType) (*CachedFile, error) { lock.Lock() defer lock.Unlock() @@ -105,7 +105,7 @@ func (c *FileCache) Add(cacheKey string, sourcePath string, size int64, cachingI uniqueName := fmt.Sprintf("%s-%d-%d", cacheKey, time.Now().UnixNano(), c.seq) cachePath := filepath.Join(c.cachedPath, uniqueName) - err := os.Rename(sourcePath, cachePath) + err := replace(sourcePath, cachePath) if err != nil { return nil, err } diff --git a/replace.go b/replace.go new file mode 100644 index 0000000..f9fc404 --- /dev/null +++ b/replace.go @@ -0,0 +1,11 @@ +// +build !windows + +package cacheddownloader + +import "os" + +// if you are wondering why we have this function, see `replace' +// implementation in replace_windows.go +func replace(src, dst string) error { + return os.Rename(src, dst) +} diff --git a/replace_windows.go b/replace_windows.go new file mode 100644 index 0000000..38ccfe4 --- /dev/null +++ b/replace_windows.go @@ -0,0 +1,45 @@ +package cacheddownloader + +import ( + "syscall" + "unsafe" +) + +// Replaces `dst' with `src' atomically. Under linux we only have to +// call os.Rename(), on windows os.Rename() will error if the +// destination exists already. The replace function serves as a +// unified interface on both platforms. +func replace(src, dst string) error { + kernel32, err := syscall.LoadLibrary("kernel32.dll") + if err != nil { + return err + } + defer syscall.FreeLibrary(kernel32) + moveFileExUnicode, err := syscall.GetProcAddress(kernel32, "MoveFileExW") + if err != nil { + return err + } + + srcString, err := syscall.UTF16PtrFromString(src) + if err != nil { + return err + } + + dstString, err := syscall.UTF16PtrFromString(dst) + if err != nil { + return err + } + + srcPtr := uintptr(unsafe.Pointer(srcString)) + dstPtr := uintptr(unsafe.Pointer(dstString)) + + MOVEFILE_REPLACE_EXISTING := 0x1 + flag := uintptr(MOVEFILE_REPLACE_EXISTING) + + _, _, callErr := syscall.Syscall(uintptr(moveFileExUnicode), 3, srcPtr, dstPtr, flag) + if callErr != 0 { + return callErr + } + + return nil +} diff --git a/tar_transformer.go b/tar_transformer.go index 006dbb8..30bbf47 100644 --- a/tar_transformer.go +++ b/tar_transformer.go @@ -110,6 +110,7 @@ func transformZipToTar(path, destPath string) (int64, error) { if err != nil { return 0, err } + defer dest.Close() zr, err := zip.OpenReader(path) if err != nil { @@ -136,6 +137,11 @@ func transformZipToTar(path, destPath string) (int64, error) { return 0, err } + err = zr.Close() + if err != nil { + return 0, err + } + err = os.Remove(path) if err != nil { return 0, err diff --git a/tar_transformer_windows_test.go b/tar_transformer_windows_test.go new file mode 100644 index 0000000..537f08d --- /dev/null +++ b/tar_transformer_windows_test.go @@ -0,0 +1,68 @@ +package cacheddownloader_test + +import ( + "io/ioutil" + "os" + "path/filepath" + + "github.com/pivotal-golang/archiver/extractor/test_helper" + . "github.com/pivotal-golang/cacheddownloader" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("TarTransformer", func() { + var ( + scratch string + + sourcePath string + destinationPath string + + transformedSize int64 + transformErr error + ) + + archiveFiles := []test_helper.ArchiveFile{ + {Name: "some-file", Body: "some-contents"}, + } + + BeforeEach(func() { + var err error + + scratch, err = ioutil.TempDir("", "tar-transformer-scratch") + Expect(err).ShouldNot(HaveOccurred()) + + destinationFile, err := ioutil.TempFile("", "destination") + Expect(err).ShouldNot(HaveOccurred()) + + err = destinationFile.Close() + Expect(err).ShouldNot(HaveOccurred()) + + destinationPath = destinationFile.Name() + }) + + AfterEach(func() { + err := os.RemoveAll(scratch) + Expect(err).ShouldNot(HaveOccurred()) + }) + + JustBeforeEach(func() { + transformedSize, transformErr = TarTransform(sourcePath, destinationPath) + }) + + Context("when the file is a .zip", func() { + BeforeEach(func() { + sourcePath = filepath.Join(scratch, "file.zip") + + test_helper.CreateZipArchive(sourcePath, archiveFiles) + }) + + It("closes the tarfile", func() { + // On Windows, you can't remove files that are still open. On Linux, you can. + err := os.Remove(destinationPath) + + Expect(err).ShouldNot(HaveOccurred()) + }) + }) +})