Skip to content

Commit 479e802

Browse files
authored
Add upsert tests (#17)
* add upsert tests
1 parent e853ed6 commit 479e802

File tree

2 files changed

+257
-23
lines changed

2 files changed

+257
-23
lines changed

tests/passed-tests.txt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,15 +370,25 @@ TestUpdatesWithStructPointer
370370
TestUpdateCustomDataType
371371
TestBatchUpdateSlice
372372
#TestMixedSaveBatch
373-
#TestUpsert
374-
TestUpsertSlice
375373
#TestDistinctComputedColumn
376374
TestDistinctWithVaryingCase
377375
#TestDistinctWithAggregation
376+
TestUpsert
377+
TestUpsertSlice
378378
TestUpsertWithSave
379379
TestFindOrInitialize
380-
TestFindOrCreate
381380
TestUpdateWithMissWhere
381+
TestUpsertCompositePK
382+
TestUpsertIgnoreColumn
383+
TestUpsertReturning
384+
TestUpsertNullValues
385+
TestUpsertSliceMixed
386+
TestUpsertWithExpressions
387+
TestUpsertPrimaryKeyNotUpdated
388+
TestUpsertWithNullUnique
389+
TestUpsertLargeBatch
390+
TestUpsertFromSubquery
391+
TestFindOrCreate
382392
BenchmarkCreate
383393
BenchmarkFind
384394
BenchmarkScan

tests/upsert_test.go

Lines changed: 244 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
package tests
4040

4141
import (
42+
"fmt"
4243
"regexp"
4344
"testing"
4445

@@ -52,7 +53,6 @@ import (
5253
)
5354

5455
func TestUpsert(t *testing.T) {
55-
t.Skip()
5656
lang := Language{Code: "upsert", Name: "Upsert"}
5757
if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil {
5858
t.Fatalf("failed to upsert, got %v", err)
@@ -64,10 +64,10 @@ func TestUpsert(t *testing.T) {
6464
}
6565

6666
var langs []Language
67-
if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil {
68-
t.Errorf("no error should happen when find languages with code, but got %v", err)
67+
if err := DB.Find(&langs, "\"code\" = ?", lang.Code).Error; err != nil {
68+
t.Fatalf("no error should happen when find languages with code, but got %v", err)
6969
} else if len(langs) != 1 {
70-
t.Errorf("should only find only 1 languages, but got %+v", langs)
70+
t.Fatalf("should only find only 1 languages, but got %+v", langs)
7171
}
7272

7373
lang3 := Language{Code: "upsert", Name: "Upsert"}
@@ -78,12 +78,12 @@ func TestUpsert(t *testing.T) {
7878
t.Fatalf("failed to upsert, got %v", err)
7979
}
8080

81-
if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil {
82-
t.Errorf("no error should happen when find languages with code, but got %v", err)
81+
if err := DB.Find(&langs, "\"code\" = ?", lang.Code).Error; err != nil {
82+
t.Fatalf("no error should happen when find languages with code, but got %v", err)
8383
} else if len(langs) != 1 {
84-
t.Errorf("should only find only 1 languages, but got %+v", langs)
84+
t.Fatalf("should only find only 1 languages, but got %+v", langs)
8585
} else if langs[0].Name != "upsert-new" {
86-
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
86+
t.Fatalf("should update name on conflict, but got name %+v", langs[0].Name)
8787
}
8888

8989
lang = Language{Code: "upsert", Name: "Upsert-Newname"}
@@ -92,27 +92,25 @@ func TestUpsert(t *testing.T) {
9292
}
9393

9494
var result Language
95-
if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
95+
if err := DB.Find(&result, "\"code\" = ?", lang.Code).Error; err != nil || result.Name != lang.Name {
9696
t.Fatalf("failed to upsert, got name %v", result.Name)
9797
}
9898

99-
if name := DB.Dialector.Name(); name != "sqlserver" {
100-
type RestrictedLanguage struct {
101-
Code string `gorm:"primarykey"`
102-
Name string
103-
Lang string `gorm:"<-:create"`
104-
}
99+
type RestrictedLanguage struct {
100+
Code string `gorm:"primarykey"`
101+
Name string
102+
Lang string `gorm:"<-:create"`
103+
}
105104

106-
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
107-
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) {
108-
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
109-
}
105+
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
106+
if !regexp.MustCompile(`MERGE INTO "restricted_languages".*WHEN MATCHED THEN UPDATE SET "name"="excluded"."name".*INSERT \("code","name","lang"\)`).MatchString(r.Statement.SQL.String()) {
107+
t.Fatalf("Table with escape character, got %v", r.Statement.SQL.String())
110108
}
111109

112110
user := *GetUser("upsert_on_conflict", Config{})
113111
user.Age = 20
114112
if err := DB.Create(&user).Error; err != nil {
115-
t.Errorf("failed to create user, got error %v", err)
113+
t.Fatalf("failed to create user, got error %v", err)
116114
}
117115

118116
var user2 User
@@ -124,6 +122,8 @@ func TestUpsert(t *testing.T) {
124122
} else {
125123
var user3 User
126124
DB.First(&user3, user.ID)
125+
fmt.Printf("%d\n", user3.UpdatedAt.UnixNano())
126+
fmt.Printf("%d\n", user2.UpdatedAt.UnixNano())
127127
if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() {
128128
t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt)
129129
}
@@ -369,3 +369,227 @@ func TestUpdateWithMissWhere(t *testing.T) {
369369
t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String())
370370
}
371371
}
372+
373+
type CompositeLang struct {
374+
Code string `gorm:"primaryKey;size:100"`
375+
Lang string `gorm:"primaryKey;size:10"`
376+
Name string
377+
}
378+
379+
func TestUpsertCompositePK(t *testing.T) {
380+
langs := []CompositeLang{
381+
{Code: "c1", Lang: "en", Name: "English"},
382+
{Code: "c1", Lang: "fr", Name: "French"},
383+
}
384+
385+
DB.Migrator().DropTable(&CompositeLang{})
386+
DB.Migrator().AutoMigrate(&CompositeLang{})
387+
388+
if err := DB.Create(&langs).Error; err != nil {
389+
t.Fatalf("failed to insert composite PK: %v", err)
390+
}
391+
392+
for i := range langs {
393+
langs[i].Name = langs[i].Name + "_updated"
394+
}
395+
396+
if err := DB.Clauses(clause.OnConflict{
397+
UpdateAll: true,
398+
}).Create(&langs).Error; err != nil {
399+
t.Fatalf("failed to upsert composite PK: %v", err)
400+
}
401+
402+
for _, expected := range langs {
403+
var result CompositeLang
404+
if err := DB.First(&result, "\"code\" = ? AND \"lang\" = ?", expected.Code, expected.Lang).Error; err != nil {
405+
t.Fatalf("failed to fetch row for %+v: %v", expected, err)
406+
}
407+
if result.Name != expected.Name {
408+
t.Fatalf("expected %v, got %v", expected.Name, result.Name)
409+
}
410+
}
411+
412+
DB.Migrator().DropTable(&CompositeLang{})
413+
}
414+
415+
func TestUpsertPrimaryKeyNotUpdated(t *testing.T) {
416+
lang := Language{Code: "pk1", Name: "Name1"}
417+
DB.Create(&lang)
418+
419+
lang.Code = "pk2" // try changing PK
420+
DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang)
421+
422+
var result Language
423+
DB.First(&result, "\"code\" = ?", "pk1")
424+
if result.Name != "Name1" {
425+
t.Fatalf("expected original row untouched, got %v", result)
426+
}
427+
}
428+
429+
type LangWithIgnore struct {
430+
Code string `gorm:"primaryKey"`
431+
Name string
432+
Lang string `gorm:"<-:create"` // should not be updated
433+
}
434+
435+
func TestUpsertIgnoreColumn(t *testing.T) {
436+
DB.Migrator().DropTable(&LangWithIgnore{})
437+
DB.Migrator().AutoMigrate(&LangWithIgnore{})
438+
lang := LangWithIgnore{Code: "upsert_ignore", Name: "OldName", Lang: "en"}
439+
DB.Create(&lang)
440+
441+
lang.Name = "NewName"
442+
lang.Lang = "fr"
443+
DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang)
444+
445+
var result LangWithIgnore
446+
DB.First(&result, "\"code\" = ?", lang.Code)
447+
if result.Name != "NewName" {
448+
t.Fatalf("expected Name updated, got %v", result.Name)
449+
}
450+
if result.Lang != "en" {
451+
t.Fatalf("Lang should not be updated, got %v", result.Lang)
452+
}
453+
DB.Migrator().DropTable(&LangWithIgnore{})
454+
}
455+
456+
func TestUpsertNullValues(t *testing.T) {
457+
lang := Language{Code: "upsert_null", Name: ""}
458+
DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang)
459+
460+
var result Language
461+
DB.First(&result, "\"code\" = ?", lang.Code)
462+
if result.Name != "" {
463+
t.Fatalf("expected empty Name, got %v", result.Name)
464+
}
465+
}
466+
467+
func TestUpsertWithNullUnique(t *testing.T) {
468+
type NullLang struct {
469+
Code *string `gorm:"uniqueIndex"`
470+
Name string
471+
}
472+
DB.Migrator().DropTable(&NullLang{})
473+
DB.Migrator().AutoMigrate(&NullLang{})
474+
475+
DB.Create(&NullLang{Code: nil, Name: "First"})
476+
477+
if err := DB.Clauses(clause.OnConflict{
478+
Columns: []clause.Column{{Name: "code"}},
479+
DoUpdates: clause.AssignmentColumns([]string{"name"}),
480+
}).Create(&NullLang{Code: nil, Name: "Second"}).Error; err != nil {
481+
t.Fatalf("unexpected error on upsert with NULL: %v", err)
482+
}
483+
484+
var count int64
485+
DB.Model(&NullLang{}).Count(&count)
486+
if count != 2 {
487+
t.Fatalf("expected 2 rows due to NULL uniqueness, got %d", count)
488+
}
489+
}
490+
491+
func TestUpsertSliceMixed(t *testing.T) {
492+
DB.Create(&Language{Code: "m1", Name: "Old1"})
493+
langs := []Language{
494+
{Code: "m1", Name: "New1"}, // exists
495+
{Code: "m2", Name: "New2"}, // new
496+
}
497+
498+
DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&langs)
499+
500+
var l1, l2 Language
501+
DB.First(&l1, "\"code\" = ?", "m1")
502+
DB.First(&l2, "\"code\" = ?", "m2")
503+
if l1.Name != "New1" || l2.Name != "New2" {
504+
t.Fatalf("batch mixed upsert failed: %+v, %+v", l1, l2)
505+
}
506+
}
507+
508+
func TestUpsertWithExpressions(t *testing.T) {
509+
lang := Language{Code: "expr1", Name: "Name1"}
510+
DB.Create(&lang)
511+
512+
DB.Clauses(clause.OnConflict{
513+
Columns: []clause.Column{{Name: "code"}},
514+
DoUpdates: clause.Assignments(map[string]interface{}{
515+
"name": gorm.Expr("UPPER(?)", "newname"),
516+
}),
517+
}).Create(&lang)
518+
519+
var result Language
520+
DB.First(&result, "\"code\" = ?", "expr1")
521+
if result.Name != "NEWNAME" {
522+
t.Fatalf("expected NEWNAME, got %v", result.Name)
523+
}
524+
}
525+
526+
func TestUpsertLargeBatch(t *testing.T) {
527+
var langs []Language
528+
for i := 0; i < 1000; i++ {
529+
langs = append(langs, Language{Code: fmt.Sprintf("lb_%d", i), Name: "Name"})
530+
}
531+
if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs).Error; err != nil {
532+
t.Fatalf("failed large batch insert: %v", err)
533+
}
534+
}
535+
536+
func TestUpsertFromSubquery(t *testing.T) {
537+
DB.Migrator().DropTable(&Language{})
538+
if err := DB.AutoMigrate(&Language{}); err != nil {
539+
t.Fatalf("failed to migrate: %v", err)
540+
}
541+
542+
initial := []Language{
543+
{Code: "en", Name: "English"},
544+
{Code: "fr", Name: "French - Old"}, // Will be updated
545+
{Code: "es", Name: "Spanish - Old"}, // Will be updated
546+
}
547+
if err := DB.Create(&initial).Error; err != nil {
548+
t.Fatalf("failed to seed: %v", err)
549+
}
550+
551+
updates := []Language{
552+
{Code: "fr", Name: "French - Updated"},
553+
{Code: "es", Name: "Spanish - Updated"},
554+
{Code: "de", Name: "German"}, // New record
555+
}
556+
557+
for _, update := range updates {
558+
err := DB.Clauses(clause.OnConflict{
559+
Columns: []clause.Column{{Name: "code"}},
560+
DoUpdates: clause.AssignmentColumns([]string{"name"}),
561+
}).Create(&update).Error
562+
563+
if err != nil {
564+
t.Fatalf("failed upsert: %v", err)
565+
}
566+
}
567+
568+
var results []Language
569+
if err := DB.Order("\"code\"").Find(&results).Error; err != nil {
570+
t.Fatalf("failed to query results: %v", err)
571+
}
572+
573+
expected := []Language{
574+
{Code: "de", Name: "German"}, // inserted
575+
{Code: "en", Name: "English"}, // unchanged
576+
{Code: "es", Name: "Spanish - Updated"}, // updated
577+
{Code: "fr", Name: "French - Updated"}, // updated
578+
}
579+
580+
if len(results) != len(expected) {
581+
t.Errorf("expected %d rows, got %d", len(expected), len(results))
582+
}
583+
584+
for i := range expected {
585+
if i >= len(results) {
586+
t.Errorf("missing row %d: expected (%s, %s)", i, expected[i].Code, expected[i].Name)
587+
continue
588+
}
589+
if results[i].Code != expected[i].Code || results[i].Name != expected[i].Name {
590+
t.Errorf("row %d mismatch: expected (%s, %s), got (%s, %s)",
591+
i, expected[i].Code, expected[i].Name,
592+
results[i].Code, results[i].Name)
593+
}
594+
}
595+
}

0 commit comments

Comments
 (0)