diff --git a/constants.go b/constants.go index 23288d3..d5cd0ea 100644 --- a/constants.go +++ b/constants.go @@ -10,6 +10,7 @@ const ( annotationOmitEmpty = "omitempty" annotationISO8601 = "iso8601" annotationSeperator = "," + annotationIgnore = "-" iso8601TimeFormat = "2006-01-02T15:04:05Z" diff --git a/node.go b/node.go index 3a0c02e..a472347 100644 --- a/node.go +++ b/node.go @@ -28,6 +28,38 @@ type Node struct { Links *Links `json:"links,omitempty"` } +func (n *Node) merge(node *Node) { + if node.Type != "" { + n.Type = node.Type + } + + if node.ID != "" { + n.ID = node.ID + } + + if node.ClientID != "" { + n.ClientID = node.ClientID + } + + if n.Attributes == nil && node.Attributes != nil { + n.Attributes = make(map[string]interface{}) + } + for k, v := range node.Attributes { + n.Attributes[k] = v + } + + if n.Relationships == nil && n.Relationships != nil { + n.Relationships = make(map[string]interface{}) + } + for k, v := range node.Relationships { + n.Relationships[k] = v + } + + if node.Links != nil { + n.Links = node.Links + } +} + // RelationshipOneNode is used to represent a generic has one JSON API relation type RelationshipOneNode struct { Data *Node `json:"data"` diff --git a/request.go b/request.go index 335ecb4..33b9b09 100644 --- a/request.go +++ b/request.go @@ -131,14 +131,28 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) for i := 0; i < modelValue.NumField(); i++ { fieldType := modelType.Field(i) - tag := fieldType.Tag.Get("jsonapi") + tag := fieldType.Tag.Get(annotationJSONAPI) + + // handles embedded structs + if isEmbeddedStruct(fieldType) { + if shouldIgnoreField(tag) { + continue + } + model := reflect.ValueOf(modelValue.Field(i).Addr().Interface()) + err := unmarshalNode(data, model, included) + if err != nil { + er = err + break + } + } + if tag == "" { continue } fieldValue := modelValue.Field(i) - args := strings.Split(tag, ",") + args := strings.Split(tag, annotationSeperator) if len(args) < 1 { er = ErrBadJSONAPIStructTag @@ -446,7 +460,8 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } // As a final catch-all, ensure types line up to avoid a runtime panic. - if fieldValue.Kind() != v.Kind() { + // Ignore interfaces since interfaces are poly + if fieldValue.Kind() != reflect.Interface && fieldValue.Kind() != v.Kind() { return ErrInvalidType } fieldValue.Set(reflect.ValueOf(val)) diff --git a/response.go b/response.go index c44cd3b..51ef157 100644 --- a/response.go +++ b/response.go @@ -211,6 +211,21 @@ func visitModelNode(model interface{}, included *map[string]*Node, for i := 0; i < modelValue.NumField(); i++ { structField := modelValue.Type().Field(i) tag := structField.Tag.Get(annotationJSONAPI) + + // handles embedded structs + if isEmbeddedStruct(structField) { + if shouldIgnoreField(tag) { + continue + } + model := modelValue.Field(i).Addr().Interface() + embNode, err := visitModelNode(model, included, sideload) + if err != nil { + er = err + break + } + node.merge(embNode) + } + if tag == "" { continue } @@ -517,3 +532,17 @@ func convertToSliceInterface(i *interface{}) ([]interface{}, error) { } return response, nil } + +func isEmbeddedStruct(sField reflect.StructField) bool { + if sField.Anonymous && sField.Type.Kind() == reflect.Struct { + return true + } + return false +} + +func shouldIgnoreField(japiTag string) bool { + if strings.HasPrefix(japiTag, annotationIgnore) { + return true + } + return false +} diff --git a/response_test.go b/response_test.go index 756fe87..e7edd00 100644 --- a/response_test.go +++ b/response_test.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "sort" + "strings" "testing" "time" ) @@ -59,7 +60,7 @@ func (b *Blog) JSONAPIRelationshipLinks(relation string) *Links { } type Post struct { - Blog + Blog `jsonapi:"-"` ID uint64 `jsonapi:"primary,posts"` BlogID int `jsonapi:"attr,blog_id"` ClientID string `jsonapi:"client-id"` @@ -829,6 +830,205 @@ func TestMarshalMany_InvalidIntefaceArgument(t *testing.T) { } } +func TestMergeNode(t *testing.T) { + parent := &Node{ + Type: "Good", + ID: "99", + Attributes: map[string]interface{}{"fizz": "buzz"}, + } + + child := &Node{ + Type: "Better", + ClientID: "1111", + Attributes: map[string]interface{}{"timbuk": 2}, + } + + expected := &Node{ + Type: "Better", + ID: "99", + ClientID: "1111", + Attributes: map[string]interface{}{"fizz": "buzz", "timbuk": 2}, + } + + parent.merge(child) + + if !reflect.DeepEqual(expected, parent) { + t.Errorf("Got %+v Expected %+v", parent, expected) + } +} + +func TestIsEmbeddedStruct(t *testing.T) { + type foo struct{} + + structType := reflect.TypeOf(foo{}) + stringType := reflect.TypeOf("") + if structType.Kind() != reflect.Struct { + t.Fatal("structType.Kind() is not a struct.") + } + if stringType.Kind() != reflect.String { + t.Fatal("stringType.Kind() is not a string.") + } + + type test struct { + scenario string + input reflect.StructField + expectedRes bool + } + + tests := []test{ + test{ + scenario: "success", + input: reflect.StructField{Anonymous: true, Type: structType}, + expectedRes: true, + }, + test{ + scenario: "wrong type", + input: reflect.StructField{Anonymous: true, Type: stringType}, + expectedRes: false, + }, + test{ + scenario: "not embedded", + input: reflect.StructField{Type: structType}, + expectedRes: false, + }, + } + + for _, test := range tests { + res := isEmbeddedStruct(test.input) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +func TestShouldIgnoreField(t *testing.T) { + type test struct { + scenario string + input string + expectedRes bool + } + + tests := []test{ + test{ + scenario: "opt-out", + input: annotationIgnore, + expectedRes: true, + }, + test{ + scenario: "no tag", + input: "", + expectedRes: false, + }, + test{ + scenario: "wrong tag", + input: "wrong,tag", + expectedRes: false, + }, + } + + for _, test := range tests { + res := shouldIgnoreField(test.input) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +func TestIsValidEmbeddedStruct(t *testing.T) { + type foo struct{} + + structType := reflect.TypeOf(foo{}) + stringType := reflect.TypeOf("") + if structType.Kind() != reflect.Struct { + t.Fatal("structType.Kind() is not a struct.") + } + if stringType.Kind() != reflect.String { + t.Fatal("stringType.Kind() is not a string.") + } + + type test struct { + scenario string + input reflect.StructField + expectedRes bool + } + + tests := []test{ + test{ + scenario: "success", + input: reflect.StructField{Anonymous: true, Type: structType}, + expectedRes: true, + }, + test{ + scenario: "opt-out", + input: reflect.StructField{Anonymous: true, Tag: "jsonapi:\"-\"", Type: structType}, + expectedRes: false, + }, + test{ + scenario: "wrong type", + input: reflect.StructField{Anonymous: true, Type: stringType}, + expectedRes: false, + }, + test{ + scenario: "not embedded", + input: reflect.StructField{Type: structType}, + expectedRes: false, + }, + } + + for _, test := range tests { + res := (isEmbeddedStruct(test.input) && !shouldIgnoreField(test.input.Tag.Get(annotationJSONAPI))) + if res != test.expectedRes { + t.Errorf("Scenario -> %s\nGot -> %v\nExpected -> %v\n", test.scenario, res, test.expectedRes) + } + } +} + +func TestMarshalUnmarshalCompositeStruct(t *testing.T) { + type Thing struct { + ID int `jsonapi:"primary,things"` + Fizz string `jsonapi:"attr,fizz"` + Buzz int `jsonapi:"attr,buzz"` + } + + type Model struct { + Thing + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + } + + model := &Model{} + model.ID = 1 + model.Fizz = "fizzy" + model.Buzz = 99 + model.Foo = "fooey" + model.Bar = "barry" + model.Bat = "batty" + + buf := bytes.NewBuffer(nil) + if err := MarshalOnePayload(buf, model); err != nil { + t.Fatal(err) + } + + // assert encoding from model to jsonapi output + expected := `{"data":{"type":"things","id":"1","attributes":{"bar":"barry","bat":"batty","buzz":99,"fizz":"fizzy","foo":"fooey"}}}` + actual := strings.TrimSpace(string(buf.Bytes())) + + if expected != actual { + t.Errorf("Got %+v Expected %+v", actual, expected) + } + + dst := &Model{} + if err := UnmarshalPayload(buf, dst); err != nil { + t.Fatal(err) + } + + // assert decoding from jsonapi output to model + if !reflect.DeepEqual(model, dst) { + t.Errorf("Got %#v Expected %#v", dst, model) + } +} + func testBlog() *Blog { return &Blog{ ID: 5,