Skip to content

Commit 597b87a

Browse files
Fixed random data generation
1 parent 19adada commit 597b87a

File tree

3 files changed

+98
-55
lines changed

3 files changed

+98
-55
lines changed

internal/getters/getters.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ type Getter interface {
1919
}
2020

2121
type RandomInt struct {
22-
mask uint64
22+
mask int64
2323
allowNull bool
2424
}
2525

2626
func (r *RandomInt) Value() interface{} {
2727
rand.Seed(time.Now().UnixNano())
28-
return uint64(rand.Int63n(10e8)) & r.mask
28+
return rand.Int63n(r.mask)
2929
}
3030

31-
func NewRandomInt(mask uint64, allowNull bool) Getter {
31+
func NewRandomInt(mask int64, allowNull bool) Getter {
3232
return &RandomInt{mask, allowNull}
3333
}
3434

@@ -39,7 +39,6 @@ type RandomIntRange struct {
3939
}
4040

4141
func (r *RandomIntRange) Value() interface{} {
42-
rand.Seed(time.Now().UnixNano())
4342
limit := r.max - r.min + 1
4443
return r.min + rand.Int63n(limit)
4544
}
@@ -54,7 +53,6 @@ type RandomDecimal struct {
5453
}
5554

5655
func (r *RandomDecimal) Value() interface{} {
57-
rand.Seed(time.Now().UnixNano())
5856
f := rand.Float64() * float64(rand.Int63n(int64(math.Pow10(int(r.size)))))
5957
format := fmt.Sprintf("%%%0.1ff", r.size)
6058
return fmt.Sprintf(format, f)
@@ -73,7 +71,6 @@ type RandomString struct {
7371
}
7472

7573
func (r *RandomString) Value() interface{} {
76-
rand.Seed(time.Now().UnixNano())
7774
if r.allowNull && rand.Int63n(100) < nilFrequency {
7875
return nil
7976
}

main.go

Lines changed: 91 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import (
66
"log"
77
"net/url"
88
"os"
9+
"runtime"
910
"strings"
1011
"sync"
12+
"sync/atomic"
1113
"time"
1214

1315
"github.com/Percona-Lab/random_data_load/internal/getters"
@@ -34,13 +36,13 @@ var (
3436
rows = app.Arg("rows", "Number of rows to insert").Required().Int()
3537

3638
validFunctions = []string{"int", "string", "date", "date_in_range"}
37-
masks = map[string]uint64{
39+
masks = map[string]int64{
3840
"tinyint": 0XF,
3941
"smallint": 0xFF,
40-
"mediumint": 0xFFF,
41-
"int": 0xFFFF,
42-
"integer": 0xFFFF,
43-
"bigint": 0xFFFFFFFF,
42+
"mediumint": 0x7FFFF,
43+
"int": 0x7FFFFFFF,
44+
"integer": 0x7FFFFFFF,
45+
"bigint": 0x7FFFFFFFFFFFFFFF,
4446
}
4547
)
4648

@@ -57,12 +59,13 @@ func main() {
5759
}
5860

5961
dsn := Config{
60-
User: *user,
61-
Passwd: *pass,
62-
Addr: address,
63-
Net: net,
64-
DBName: *dbName,
65-
ParseTime: true,
62+
User: *user,
63+
Passwd: *pass,
64+
Addr: address,
65+
Net: net,
66+
DBName: *dbName,
67+
ParseTime: true,
68+
ClientFoundRows: true,
6669
}
6770

6871
db, err := sql.Open("mysql", dsn.FormatDSN())
@@ -91,98 +94,138 @@ func main() {
9194
fmt.Println(sql)
9295
}
9396

94-
stmt, err := db.Prepare(sql)
95-
if err != nil {
96-
log.Printf("cannot prepare %q: %s", sql, err)
97-
os.Exit(1)
98-
}
99-
defer stmt.Close()
100-
97+
var wg sync.WaitGroup
98+
var okRowsCount int64
10199
values := makeValueFuncs(table.Fields)
100+
resultsChan := make(chan int)
101+
102102
rowsChan := makeRowsChan(*rows, values)
103103

104104
if *maxThreads < 1 {
105105
*maxThreads = 1
106106
}
107-
var wg sync.WaitGroup
108107

109108
log.Println("Starting")
110109

110+
bar := uiprogress.AddBar(*rows).AppendCompleted()
111111
uiprogress.Start()
112-
bar := uiprogress.AddBar(*rows).PrependElapsed().AppendCompleted()
113112

114-
fields, placeholders := getFieldsAndPlaceholders(table.Fields)
113+
// This go-routine keeps track of how many rows were actually inserted
114+
// by the bulk inserts since one or more rows could generate duplicated
115+
// keys so, not allways the number of inserted rows = number of rows in
116+
// the bulk insert
117+
go func() {
118+
for okCount := range resultsChan {
119+
bar.Set(bar.Current() + okCount)
120+
atomic.AddInt64(&okRowsCount, int64(okCount))
121+
}
122+
}()
123+
115124
for i := 0; i < *maxThreads; i++ {
116125
wg.Add(1)
117-
go runInsert(*dbName, *tableName, *bulkSize, fields, placeholders, db, rowsChan, bar, &wg)
126+
go runInsert(db, table, *bulkSize, rowsChan, resultsChan, &wg)
118127
}
119128
wg.Wait()
129+
130+
// Let the counter go-rutine to run for the last time
131+
runtime.Gosched()
132+
close(resultsChan)
133+
134+
if okRowsCount != int64(*rows) {
135+
loadExtraRows(db, table, int64(*rows)-okRowsCount, values)
136+
bar.Set(*rows)
137+
}
138+
}
139+
140+
func loadExtraRows(db *sql.DB, table *tableparser.Table, rows int64, values []getters.Getter) {
141+
var okCount int64
142+
for okCount < rows {
143+
vals := make([]interface{}, len(values))
144+
for j, val := range values {
145+
vals[j] = val.Value()
146+
}
147+
148+
if err := runOneInsert(db, table, vals); err != nil {
149+
continue
150+
}
151+
okCount++
152+
}
120153
}
121154

122155
func makeRowsChan(rows int, values []getters.Getter) chan []interface{} {
123156
preloadCount := 10000
124157
if rows < preloadCount {
125158
preloadCount = rows
126159
}
160+
127161
rowsChan := make(chan []interface{}, preloadCount)
128162
go func() {
129-
vals := make([]interface{}, len(values))
130163
for i := 0; i < rows; i++ {
131-
for i, val := range values {
132-
vals[i] = val.Value()
164+
vals := make([]interface{}, len(values))
165+
for j, val := range values {
166+
vals[j] = val.Value()
133167
}
134168
rowsChan <- vals
135169
}
136170
close(rowsChan)
137171
}()
138-
139172
return rowsChan
140173
}
141174

142-
func runInsert(dbName string, tableName string, bulkSize int, fieldNames []string,
143-
placeholders []string, db *sql.DB, valsChan chan []interface{},
144-
bar *uiprogress.Bar, wg *sync.WaitGroup) {
145-
baseSQL := fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES ",
146-
backticks(dbName),
147-
backticks(tableName),
148-
strings.Join(fieldNames, ","),
175+
func runInsert(db *sql.DB, table *tableparser.Table, bulkSize int, valsChan chan []interface{},
176+
resultsChan chan int, wg *sync.WaitGroup) {
177+
//
178+
fields, placeholders := getFieldsAndPlaceholders(table.Fields)
179+
baseSQL := fmt.Sprintf("INSERT IGNORE INTO %s (%s) VALUES ",
180+
backticks(table.Name),
181+
strings.Join(fields, ","),
149182
)
150183
separator := ""
151184
sql := baseSQL
152-
bulkVals := make([]interface{}, 0, len(fieldNames))
153-
count := 0
185+
bulkVals := []interface{}{}
186+
var count int
154187

155188
for vals := range valsChan {
156189
sql += separator + "(" + strings.Join(placeholders, ",") + ")"
157190
separator = ", "
158-
bar.Incr()
159191
bulkVals = append(bulkVals, vals...)
160192
count++
161193
if count < bulkSize {
162194
continue
163195
}
164-
_, err := db.Exec(sql, bulkVals...)
165-
if err != nil {
166-
log.Printf("Error inserting values: %s\n", err)
167-
}
196+
result, _ := db.Exec(sql, bulkVals...)
197+
rowsAffected, _ := result.RowsAffected()
198+
resultsChan <- int(rowsAffected)
168199
separator = ""
169200
sql = baseSQL
170-
bulkVals = nil
201+
bulkVals = []interface{}{}
171202
count = 0
172203
}
173-
if count > 0 {
174-
_, err := db.Exec(sql, bulkVals...)
175-
if err != nil {
176-
log.Printf("Error inserting values: %s\n", err)
177-
}
204+
if count > 0 && len(bulkVals) > 0 {
205+
result, _ := db.Exec(sql, bulkVals...)
206+
rowsAffected, _ := result.RowsAffected()
207+
resultsChan <- int(rowsAffected)
178208
}
179209
wg.Done()
180210
}
181211

212+
func runOneInsert(db *sql.DB, table *tableparser.Table, vals []interface{}) error {
213+
fields, placeholders := getFieldsAndPlaceholders(table.Fields)
214+
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
215+
backticks(table.Name),
216+
strings.Join(fields, ","),
217+
strings.Join(placeholders, ","),
218+
)
219+
if _, err := db.Exec(query, vals...); err != nil {
220+
return err
221+
}
222+
return nil
223+
}
224+
182225
func makeValueFuncs(fields []tableparser.Field) []getters.Getter {
183226
var values []getters.Getter
184227
for _, field := range fields {
185-
if !field.AllowsNull && !field.Default.Valid && field.Key == "PRI" &&
228+
if !field.AllowsNull && field.Key == "PRI" &&
186229
strings.Contains(field.Extra, "auto_increment") {
187230
continue
188231
}
@@ -225,7 +268,7 @@ func getFieldsAndPlaceholders(fields []tableparser.Field) ([]string, []string) {
225268
continue
226269
}
227270
fieldNames = append(fieldNames, backticks(field.Name))
228-
if !field.AllowsNull && !field.Default.Valid && field.Key == "PRI" &&
271+
if !field.AllowsNull && field.Key == "PRI" &&
229272
strings.Contains(field.Extra, "auto_increment") {
230273
placeHolders = append(placeHolders, "NULL")
231274
} else {

tableparser/tableparser.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
)
1212

1313
type Table struct {
14+
Name string
1415
Fields []Field
1516
Indexes []Index
1617
}
@@ -32,7 +33,9 @@ type Field struct {
3233
}
3334

3435
func Parse(db *sql.DB, tableName string) (*Table, error) {
35-
table := &Table{}
36+
table := &Table{
37+
Name: tableName,
38+
}
3639
var err error
3740
table.Fields, err = parseTable(db, tableName)
3841
if err != nil {

0 commit comments

Comments
 (0)