Skip to content

Commit 5213254

Browse files
authored
proto: store extension values according to protobuf data model (#746)
The current API represents scalar extension fields as *T and repeated extension fields as []T. However, this is not an accurate reflection of the protobuf data model. For scalars, pointers are usually used to represent nullability. However, in the case of extension scalars, there is no need to do so since presence information is captured by checking whether the field is in the extension map. For this reason, presence on extension scalars is not determined by checking whether the returned pointer is nil, but whether HasExtension reports true. For repeated fields, using []T means that the returned value is only a partially mutable value. You can swap out elements, but you cannot change the length of the original field value. On the other hand, the reflective API provides methods on repeated field values that permit operations that do change the length. Thus, using *[]T is a closer match to the protobuf data model. This CL changes the implementation of extension fields to always store T for scalars, and *[]T for repeated fields. However, for backwards compatibility, the API continues to provide *T and []T. In theory, this could break anyone that relies on memory aliasing for *T. However, this is unlikely since: * use of extensions themselves is relatively rare * if extensions are used, it is recommended practice to use a message as the field extension and not a scalar * relying on memory aliasing is generally not a good idiom to follow. The expected pattern is to call SetExtension to make it explicit that a mutation is happening on the message. * analysis within Google demonstrates that no one is relying on this behavior.
1 parent 951a149 commit 5213254

File tree

6 files changed

+141
-42
lines changed

6 files changed

+141
-42
lines changed

proto/clone_test.go

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,34 +67,50 @@ func init() {
6767
if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil {
6868
panic("SetExtension: " + err.Error())
6969
}
70+
if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_Text, proto.String("hello")); err != nil {
71+
panic("SetExtension: " + err.Error())
72+
}
73+
if err := proto.SetExtension(cloneTestMessage, pb.E_Greeting, []string{"one", "two"}); err != nil {
74+
panic("SetExtension: " + err.Error())
75+
}
7076
}
7177

7278
func TestClone(t *testing.T) {
79+
// Create a clone using a marshal/unmarshal roundtrip.
80+
vanilla := new(pb.MyMessage)
81+
b, err := proto.Marshal(cloneTestMessage)
82+
if err != nil {
83+
t.Errorf("unexpected Marshal error: %v", err)
84+
}
85+
if err := proto.Unmarshal(b, vanilla); err != nil {
86+
t.Errorf("unexpected Unarshal error: %v", err)
87+
}
88+
89+
// Create a clone using Clone and verify that it is equal to the original.
7390
m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
7491
if !proto.Equal(m, cloneTestMessage) {
7592
t.Fatalf("Clone(%v) = %v", cloneTestMessage, m)
7693
}
7794

78-
// Verify it was a deep copy.
79-
*m.Inner.Port++
80-
if proto.Equal(m, cloneTestMessage) {
81-
t.Error("Mutating clone changed the original")
82-
}
83-
// Byte fields and repeated fields should be copied.
84-
if &m.Pet[0] == &cloneTestMessage.Pet[0] {
85-
t.Error("Pet: repeated field not copied")
95+
// Mutate the clone, which should not affect the original.
96+
x1, err := proto.GetExtension(m, pb.E_Ext_More)
97+
if err != nil {
98+
t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Ext_More.Name, err)
8699
}
87-
if &m.Others[0] == &cloneTestMessage.Others[0] {
88-
t.Error("Others: repeated field not copied")
100+
x2, err := proto.GetExtension(m, pb.E_Ext_Text)
101+
if err != nil {
102+
t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Ext_Text.Name, err)
89103
}
90-
if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
91-
t.Error("Others[0].Value: bytes field not copied")
104+
x3, err := proto.GetExtension(m, pb.E_Greeting)
105+
if err != nil {
106+
t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Greeting.Name, err)
92107
}
93-
if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
94-
t.Error("RepBytes: repeated field not copied")
95-
}
96-
if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
97-
t.Error("RepBytes[0]: bytes field not copied")
108+
*m.Inner.Port++
109+
*(x1.(*pb.Ext)).Data = "blah blah"
110+
*(x2.(*string)) = "goodbye"
111+
x3.([]string)[0] = "zero"
112+
if !proto.Equal(cloneTestMessage, vanilla) {
113+
t.Fatalf("mutation on original detected:\ngot %v\nwant %v", cloneTestMessage, vanilla)
98114
}
99115
}
100116

proto/equal.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
246246
return false
247247
}
248248

249-
m1, m2 := e1.value, e2.value
249+
m1 := extensionAsLegacyType(e1.value)
250+
m2 := extensionAsLegacyType(e2.value)
250251

251252
if m1 == nil && m2 == nil {
252253
// Both have only encoded form.

proto/extensions.go

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,25 @@ type Extension struct {
185185
// extension will have only enc set. When such an extension is
186186
// accessed using GetExtension (or GetExtensions) desc and value
187187
// will be set.
188-
desc *ExtensionDesc
188+
desc *ExtensionDesc
189+
190+
// value is a concrete value for the extension field. Let the type of
191+
// desc.ExtensionType be the "API type" and the type of Extension.value
192+
// be the "storage type". The API type and storage type are the same except:
193+
// * For scalars (except []byte), the API type uses *T,
194+
// while the storage type uses T.
195+
// * For repeated fields, the API type uses []T, while the storage type
196+
// uses *[]T.
197+
//
198+
// The reason for the divergence is so that the storage type more naturally
199+
// matches what is expected when retrieving the values through the
200+
// protobuf reflection APIs.
201+
//
202+
// The value may only be populated if desc is also populated.
189203
value interface{}
190-
enc []byte
204+
205+
// enc is the raw bytes for the extension field.
206+
enc []byte
191207
}
192208

193209
// SetRawExtension is for testing only.
@@ -334,7 +350,7 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
334350
// descriptors with the same field number.
335351
return nil, errors.New("proto: descriptor conflict")
336352
}
337-
return e.value, nil
353+
return extensionAsLegacyType(e.value), nil
338354
}
339355

340356
if extension.ExtensionType == nil {
@@ -349,11 +365,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
349365

350366
// Remember the decoded version and drop the encoded version.
351367
// That way it is safe to mutate what we return.
352-
e.value = v
368+
e.value = extensionAsStorageType(v)
353369
e.desc = extension
354370
e.enc = nil
355371
emap[extension.Field] = e
356-
return e.value, nil
372+
return extensionAsLegacyType(e.value), nil
357373
}
358374

359375
// defaultExtensionValue returns the default value for extension.
@@ -500,7 +516,7 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
500516
}
501517

502518
extmap := epb.extensionsWrite()
503-
extmap[extension.Field] = Extension{desc: extension, value: value}
519+
extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)}
504520
return nil
505521
}
506522

@@ -541,3 +557,51 @@ func RegisterExtension(desc *ExtensionDesc) {
541557
func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
542558
return extensionMaps[reflect.TypeOf(pb).Elem()]
543559
}
560+
561+
// extensionAsLegacyType converts a value in the storage type to the API type.
562+
// See Extension.value.
563+
func extensionAsLegacyType(v interface{}) interface{} {
564+
switch rv := reflect.ValueOf(v); rv.Kind() {
565+
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
566+
// Represent primitive types as a pointer to the value.
567+
rv2 := reflect.New(rv.Type())
568+
rv2.Elem().Set(rv)
569+
v = rv2.Interface()
570+
case reflect.Ptr:
571+
// Represent slice types as the value itself.
572+
switch rv.Type().Elem().Kind() {
573+
case reflect.Slice:
574+
if rv.IsNil() {
575+
v = reflect.Zero(rv.Type().Elem()).Interface()
576+
} else {
577+
v = rv.Elem().Interface()
578+
}
579+
}
580+
}
581+
return v
582+
}
583+
584+
// extensionAsStorageType converts a value in the API type to the storage type.
585+
// See Extension.value.
586+
func extensionAsStorageType(v interface{}) interface{} {
587+
switch rv := reflect.ValueOf(v); rv.Kind() {
588+
case reflect.Ptr:
589+
// Represent primitive types as the value itself.
590+
switch rv.Type().Elem().Kind() {
591+
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
592+
if rv.IsNil() {
593+
v = reflect.Zero(rv.Type().Elem()).Interface()
594+
} else {
595+
v = rv.Elem().Interface()
596+
}
597+
}
598+
case reflect.Slice:
599+
// Represent slice types as a pointer to the value.
600+
if rv.Type().Elem().Kind() != reflect.Uint8 {
601+
rv2 := reflect.New(rv.Type())
602+
rv2.Elem().Set(rv)
603+
v = rv2.Interface()
604+
}
605+
}
606+
return v
607+
}

proto/pointer_reflect.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,13 @@ func toPointer(i *Message) pointer {
7979

8080
// toAddrPointer converts an interface to a pointer that points to
8181
// the interface data.
82-
func toAddrPointer(i *interface{}, isptr bool) pointer {
82+
func toAddrPointer(i *interface{}, isptr, deref bool) pointer {
8383
v := reflect.ValueOf(*i)
8484
u := reflect.New(v.Type())
8585
u.Elem().Set(v)
86+
if deref {
87+
u = u.Elem()
88+
}
8689
return pointer{v: u}
8790
}
8891

proto/pointer_unsafe.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,21 @@ func toPointer(i *Message) pointer {
8585

8686
// toAddrPointer converts an interface to a pointer that points to
8787
// the interface data.
88-
func toAddrPointer(i *interface{}, isptr bool) pointer {
88+
func toAddrPointer(i *interface{}, isptr, deref bool) (p pointer) {
8989
// Super-tricky - read or get the address of data word of interface value.
9090
if isptr {
9191
// The interface is of pointer type, thus it is a direct interface.
9292
// The data word is the pointer data itself. We take its address.
93-
return pointer{p: unsafe.Pointer(uintptr(unsafe.Pointer(i)) + ptrSize)}
93+
p = pointer{p: unsafe.Pointer(uintptr(unsafe.Pointer(i)) + ptrSize)}
94+
} else {
95+
// The interface is not of pointer type. The data word is the pointer
96+
// to the data.
97+
p = pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
9498
}
95-
// The interface is not of pointer type. The data word is the pointer
96-
// to the data.
97-
return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
99+
if deref {
100+
p.p = *(*unsafe.Pointer)(p.p)
101+
}
102+
return p
98103
}
99104

100105
// valToPointer converts v to a pointer. v must be of pointer type.

proto/table_marshal.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ type marshalElemInfo struct {
8787
sizer sizer
8888
marshaler marshaler
8989
isptr bool // elem is pointer typed, thus interface of this type is a direct interface (extension only)
90+
deref bool // dereference the pointer before operating on it; implies isptr
9091
}
9192

9293
var (
@@ -407,13 +408,22 @@ func (u *marshalInfo) getExtElemInfo(desc *ExtensionDesc) *marshalElemInfo {
407408
panic("tag is not an integer")
408409
}
409410
wt := wiretype(tags[0])
411+
if t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct {
412+
t = t.Elem()
413+
}
410414
sizer, marshaler := typeMarshaler(t, tags, false, false)
415+
var deref bool
416+
if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
417+
t = reflect.PtrTo(t)
418+
deref = true
419+
}
411420
e = &marshalElemInfo{
412421
wiretag: uint64(tag)<<3 | wt,
413422
tagsize: SizeVarint(uint64(tag) << 3),
414423
sizer: sizer,
415424
marshaler: marshaler,
416425
isptr: t.Kind() == reflect.Ptr,
426+
deref: deref,
417427
}
418428

419429
// update cache
@@ -2310,8 +2320,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
23102320
for _, k := range m.MapKeys() {
23112321
ki := k.Interface()
23122322
vi := m.MapIndex(k).Interface()
2313-
kaddr := toAddrPointer(&ki, false) // pointer to key
2314-
vaddr := toAddrPointer(&vi, valIsPtr) // pointer to value
2323+
kaddr := toAddrPointer(&ki, false, false) // pointer to key
2324+
vaddr := toAddrPointer(&vi, valIsPtr, false) // pointer to value
23152325
siz := keySizer(kaddr, 1) + valSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
23162326
n += siz + SizeVarint(uint64(siz)) + tagsize
23172327
}
@@ -2329,8 +2339,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
23292339
for _, k := range keys {
23302340
ki := k.Interface()
23312341
vi := m.MapIndex(k).Interface()
2332-
kaddr := toAddrPointer(&ki, false) // pointer to key
2333-
vaddr := toAddrPointer(&vi, valIsPtr) // pointer to value
2342+
kaddr := toAddrPointer(&ki, false, false) // pointer to key
2343+
vaddr := toAddrPointer(&vi, valIsPtr, false) // pointer to value
23342344
b = appendVarint(b, tag)
23352345
siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
23362346
b = appendVarint(b, uint64(siz))
@@ -2399,7 +2409,7 @@ func (u *marshalInfo) sizeExtensions(ext *XXX_InternalExtensions) int {
23992409
// the last time this function was called.
24002410
ei := u.getExtElemInfo(e.desc)
24012411
v := e.value
2402-
p := toAddrPointer(&v, ei.isptr)
2412+
p := toAddrPointer(&v, ei.isptr, ei.deref)
24032413
n += ei.sizer(p, ei.tagsize)
24042414
}
24052415
mu.Unlock()
@@ -2434,7 +2444,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
24342444

24352445
ei := u.getExtElemInfo(e.desc)
24362446
v := e.value
2437-
p := toAddrPointer(&v, ei.isptr)
2447+
p := toAddrPointer(&v, ei.isptr, ei.deref)
24382448
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
24392449
if !nerr.Merge(err) {
24402450
return b, err
@@ -2465,7 +2475,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
24652475

24662476
ei := u.getExtElemInfo(e.desc)
24672477
v := e.value
2468-
p := toAddrPointer(&v, ei.isptr)
2478+
p := toAddrPointer(&v, ei.isptr, ei.deref)
24692479
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
24702480
if !nerr.Merge(err) {
24712481
return b, err
@@ -2510,7 +2520,7 @@ func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
25102520

25112521
ei := u.getExtElemInfo(e.desc)
25122522
v := e.value
2513-
p := toAddrPointer(&v, ei.isptr)
2523+
p := toAddrPointer(&v, ei.isptr, ei.deref)
25142524
n += ei.sizer(p, 1) // message, tag = 3 (size=1)
25152525
}
25162526
mu.Unlock()
@@ -2553,7 +2563,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
25532563

25542564
ei := u.getExtElemInfo(e.desc)
25552565
v := e.value
2556-
p := toAddrPointer(&v, ei.isptr)
2566+
p := toAddrPointer(&v, ei.isptr, ei.deref)
25572567
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
25582568
if !nerr.Merge(err) {
25592569
return b, err
@@ -2591,7 +2601,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
25912601

25922602
ei := u.getExtElemInfo(e.desc)
25932603
v := e.value
2594-
p := toAddrPointer(&v, ei.isptr)
2604+
p := toAddrPointer(&v, ei.isptr, ei.deref)
25952605
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
25962606
b = append(b, 1<<3|WireEndGroup)
25972607
if !nerr.Merge(err) {
@@ -2621,7 +2631,7 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
26212631

26222632
ei := u.getExtElemInfo(e.desc)
26232633
v := e.value
2624-
p := toAddrPointer(&v, ei.isptr)
2634+
p := toAddrPointer(&v, ei.isptr, ei.deref)
26252635
n += ei.sizer(p, ei.tagsize)
26262636
}
26272637
return n
@@ -2656,7 +2666,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
26562666

26572667
ei := u.getExtElemInfo(e.desc)
26582668
v := e.value
2659-
p := toAddrPointer(&v, ei.isptr)
2669+
p := toAddrPointer(&v, ei.isptr, ei.deref)
26602670
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
26612671
if !nerr.Merge(err) {
26622672
return b, err

0 commit comments

Comments
 (0)