Skip to content

Commit 8e142df

Browse files
r3code13k
andauthored
Improved migration creation + Windows fix (#352)
* CLI: Improve migration creation. * Validates migration version on creation, to avoid creation of duplicated versions * Uses `os.OpenFile` with `O_CREATE|O_EXCL` to create files to avoid file collisions * Uses `filepath.Join` to concatenate paths, making `cleanPath()` not necessary * Prints generated filenames * Fixes #238 * Supersedes #250 * CLI: change `createCmd` to return error and accept `print` parameter * feat: print out an abs path Better for Windows OS when specify -dir as /subdir, user will see C:/subdir/0001_name.up.sql rather than /subdir/0001_name.up.sql * test: fixed abs path test fail for OS Windows abs path tests fails because filepath.IsAbs() treats `/subdir` path as invalid abs path at windows when drive letter is not present * feat: print absolute path for created files Better for Windows OS systems where `/path` can be interpreted in different ways depending on working dir * test: corrected tests for OS Windows OS Windows has different interpretation of `/path`, it depends on working dir. If working dir D:\test it interprets `/path` as `D:\path` * test: fixed `dir invalid` test Linux OS has less restriction on a filepath than Windows, so path invalid in windows is perfectly valid for Linux. The only invalid dir name in Linux is one ending with null string terminator (\000) * refac(cli): *Cmd() now returns an error and not uses log.fatalErr(err) * docs: added godoc, migarate usage updated * refac(cli): code refactored * refac: removed unnecessary path covert * docs: comment added * test: fixed code review issue, noErrorExpected var removed #352 (comment) * docs: fixed godoc Co-authored-by: Kiyoshi '13k' Murata <[email protected]>
1 parent 9b3db6c commit 8e142df

File tree

4 files changed

+377
-145
lines changed

4 files changed

+377
-145
lines changed

internal/cli/commands.go

Lines changed: 121 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,184 +3,223 @@ package cli
33
import (
44
"errors"
55
"fmt"
6-
"github.com/golang-migrate/migrate/v4"
7-
_ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again
8-
_ "github.com/golang-migrate/migrate/v4/source/file"
96
"os"
107
"path/filepath"
118
"strconv"
129
"strings"
1310
"time"
11+
12+
"github.com/golang-migrate/migrate/v4"
13+
_ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again
14+
_ "github.com/golang-migrate/migrate/v4/source/file"
1415
)
1516

16-
func nextSeq(matches []string, dir string, seqDigits int) (string, error) {
17+
var (
18+
errInvalidSequenceWidth = errors.New("Digits must be positive")
19+
errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive")
20+
errInvalidTimeFormat = errors.New("Time format may not be empty")
21+
)
22+
23+
func nextSeqVersion(matches []string, seqDigits int) (string, error) {
1724
if seqDigits <= 0 {
18-
return "", errors.New("Digits must be positive")
25+
return "", errInvalidSequenceWidth
1926
}
2027

21-
nextSeq := 1
28+
nextSeq := uint64(1)
29+
2230
if len(matches) > 0 {
2331
filename := matches[len(matches)-1]
24-
matchSeqStr := strings.TrimPrefix(filename, dir)
32+
matchSeqStr := filepath.Base(filename)
2533
idx := strings.Index(matchSeqStr, "_")
34+
2635
if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit
27-
return "", errors.New("Malformed migration filename: " + filename)
36+
return "", fmt.Errorf("Malformed migration filename: %s", filename)
2837
}
29-
matchSeqStr = matchSeqStr[0:idx]
38+
3039
var err error
31-
nextSeq, err = strconv.Atoi(matchSeqStr)
40+
matchSeqStr = matchSeqStr[0:idx]
41+
nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64)
42+
3243
if err != nil {
3344
return "", err
3445
}
46+
3547
nextSeq++
3648
}
37-
if nextSeq <= 0 {
38-
return "", errors.New("Next sequence number must be positive")
39-
}
4049

41-
nextSeqStr := strconv.Itoa(nextSeq)
42-
if len(nextSeqStr) > seqDigits {
43-
return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", nextSeqStr, seqDigits)
44-
}
45-
padding := seqDigits - len(nextSeqStr)
46-
if padding > 0 {
47-
nextSeqStr = strings.Repeat("0", padding) + nextSeqStr
50+
version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits)
51+
52+
if len(version) > seqDigits {
53+
return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits)
4854
}
49-
return nextSeqStr, nil
55+
56+
return version, nil
5057
}
5158

52-
// cleanDir normalizes the provided directory
53-
func cleanDir(dir string) string {
54-
dir = filepath.Clean(dir)
55-
switch dir {
56-
case ".":
57-
return ""
58-
case "/":
59-
return dir
59+
func timeVersion(startTime time.Time, format string) (version string, err error) {
60+
switch format {
61+
case "":
62+
err = errInvalidTimeFormat
63+
case "unix":
64+
version = strconv.FormatInt(startTime.Unix(), 10)
65+
case "unixNano":
66+
version = strconv.FormatInt(startTime.UnixNano(), 10)
6067
default:
61-
return dir + "/"
68+
version = startTime.Format(format)
6269
}
70+
71+
return
6372
}
6473

6574
// createCmd (meant to be called via a CLI command) creates a new migration
66-
func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int) {
67-
dir = cleanDir(dir)
68-
var base string
75+
func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
6976
if seq && format != defaultTimeFormat {
70-
log.fatalErr(errors.New("The seq and format options are mutually exclusive"))
77+
return errIncompatibleSeqAndFormat
7178
}
79+
80+
var version string
81+
var err error
82+
83+
dir = filepath.Clean(dir)
84+
ext = "." + strings.TrimPrefix(ext, ".")
85+
7286
if seq {
73-
if seqDigits <= 0 {
74-
log.fatalErr(errors.New("Digits must be positive"))
75-
}
76-
matches, err := filepath.Glob(dir + "*" + ext)
87+
matches, err := filepath.Glob(filepath.Join(dir, "*"+ext))
88+
7789
if err != nil {
78-
log.fatalErr(err)
90+
return err
7991
}
80-
nextSeqStr, err := nextSeq(matches, dir, seqDigits)
92+
93+
version, err = nextSeqVersion(matches, seqDigits)
94+
8195
if err != nil {
82-
log.fatalErr(err)
96+
return err
8397
}
84-
base = fmt.Sprintf("%v%v_%v.", dir, nextSeqStr, name)
8598
} else {
86-
switch format {
87-
case "":
88-
log.fatal("Time format may not be empty")
89-
case "unix":
90-
base = fmt.Sprintf("%v%v_%v.", dir, startTime.Unix(), name)
91-
case "unixNano":
92-
base = fmt.Sprintf("%v%v_%v.", dir, startTime.UnixNano(), name)
93-
default:
94-
base = fmt.Sprintf("%v%v_%v.", dir, startTime.Format(format), name)
99+
version, err = timeVersion(startTime, format)
100+
101+
if err != nil {
102+
return err
95103
}
96104
}
97105

98-
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
99-
log.fatalErr(err)
106+
versionGlob := filepath.Join(dir, version+"_*"+ext)
107+
matches, err := filepath.Glob(versionGlob)
108+
109+
if err != nil {
110+
return err
111+
}
112+
113+
if len(matches) > 0 {
114+
return fmt.Errorf("duplicate migration version: %s", version)
115+
}
116+
117+
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
118+
return err
119+
}
120+
121+
for _, direction := range []string{"up", "down"} {
122+
basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext)
123+
filename := filepath.Join(dir, basename)
124+
125+
if err = createFile(filename); err != nil {
126+
return err
127+
}
128+
129+
if print {
130+
absPath, _ := filepath.Abs(filename)
131+
log.Println(absPath)
132+
}
100133
}
101134

102-
createFile(base + "up" + ext)
103-
createFile(base + "down" + ext)
135+
return nil
104136
}
105137

106-
func createFile(fname string) {
107-
if _, err := os.Create(fname); err != nil {
108-
log.fatalErr(err)
138+
func createFile(filename string) error {
139+
// create exclusive (fails if file already exists)
140+
// os.Create() specifies 0666 as the FileMode, so we're doing the same
141+
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
142+
143+
if err != nil {
144+
return err
109145
}
146+
147+
return f.Close()
110148
}
111149

112-
func gotoCmd(m *migrate.Migrate, v uint) {
150+
func gotoCmd(m *migrate.Migrate, v uint) error {
113151
if err := m.Migrate(v); err != nil {
114152
if err != migrate.ErrNoChange {
115-
log.fatalErr(err)
116-
} else {
117-
log.Println(err)
153+
return err
118154
}
155+
log.Println(err)
119156
}
157+
return nil
120158
}
121159

122-
func upCmd(m *migrate.Migrate, limit int) {
160+
func upCmd(m *migrate.Migrate, limit int) error {
123161
if limit >= 0 {
124162
if err := m.Steps(limit); err != nil {
125163
if err != migrate.ErrNoChange {
126-
log.fatalErr(err)
127-
} else {
128-
log.Println(err)
164+
return err
129165
}
166+
log.Println(err)
130167
}
131168
} else {
132169
if err := m.Up(); err != nil {
133170
if err != migrate.ErrNoChange {
134-
log.fatalErr(err)
135-
} else {
136-
log.Println(err)
171+
return err
137172
}
173+
log.Println(err)
138174
}
139175
}
176+
return nil
140177
}
141178

142-
func downCmd(m *migrate.Migrate, limit int) {
179+
func downCmd(m *migrate.Migrate, limit int) error {
143180
if limit >= 0 {
144181
if err := m.Steps(-limit); err != nil {
145182
if err != migrate.ErrNoChange {
146-
log.fatalErr(err)
147-
} else {
148-
log.Println(err)
183+
return err
149184
}
185+
log.Println(err)
150186
}
151187
} else {
152188
if err := m.Down(); err != nil {
153189
if err != migrate.ErrNoChange {
154-
log.fatalErr(err)
155-
} else {
156-
log.Println(err)
190+
return err
157191
}
192+
log.Println(err)
158193
}
159194
}
195+
return nil
160196
}
161197

162-
func dropCmd(m *migrate.Migrate) {
198+
func dropCmd(m *migrate.Migrate) error {
163199
if err := m.Drop(); err != nil {
164-
log.fatalErr(err)
200+
return err
165201
}
202+
return nil
166203
}
167204

168-
func forceCmd(m *migrate.Migrate, v int) {
205+
func forceCmd(m *migrate.Migrate, v int) error {
169206
if err := m.Force(v); err != nil {
170-
log.fatalErr(err)
207+
return err
171208
}
209+
return nil
172210
}
173211

174-
func versionCmd(m *migrate.Migrate) {
212+
func versionCmd(m *migrate.Migrate) error {
175213
v, dirty, err := m.Version()
176214
if err != nil {
177-
log.fatalErr(err)
215+
return err
178216
}
179217
if dirty {
180218
log.Printf("%v (dirty)\n", v)
181219
} else {
182220
log.Println(v)
183221
}
222+
return nil
184223
}
185224

186225
// numDownMigrationsFromArgs returns an int for number of migrations to apply

0 commit comments

Comments
 (0)