Skip to content

Commit e51d002

Browse files
committed
net/proto2: remove <message>.ExtensionMap() from generated messages
Turn generated message struct field XXX_Extensions map[int32]Extension into an embedded proto.InternalExtensions struct InternalExtensions is a struct without any exported fields and methods. This effectively makes the representation of the extension map private. The proto package can access InternalExtensions by checking that the generated struct has the method 'extmap() proto.InternalExtensions'. Also lock accesses to the extension map. This change bumps the Go protobuf generated code version number. Any .pb.go files generated with this version of the proto package or later will require this version or later of the proto package to compile.
1 parent cd85f19 commit e51d002

File tree

15 files changed

+274
-82
lines changed

15 files changed

+274
-82
lines changed

jsonpb/jsonpb.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,14 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
233233
}
234234

235235
// Handle proto2 extensions.
236-
if ep, ok := v.(extendableProto); ok {
236+
if ep, ok := v.(proto.Message); ok {
237237
extensions := proto.RegisteredExtensions(v)
238-
extensionMap := ep.ExtensionMap()
239238
// Sort extensions for stable output.
240-
ids := make([]int32, 0, len(extensionMap))
241-
for id := range extensionMap {
239+
ids := make([]int32, 0, len(extensions))
240+
for id, desc := range extensions {
241+
if !proto.HasExtension(ep, desc) {
242+
continue
243+
}
242244
ids = append(ids, id)
243245
}
244246
sort.Sort(int32Slice(ids))
@@ -767,13 +769,6 @@ func acceptedJSONFieldNames(prop *proto.Properties) fieldNames {
767769
return opts
768770
}
769771

770-
// extendableProto is an interface implemented by any protocol buffer that may be extended.
771-
type extendableProto interface {
772-
proto.Message
773-
ExtensionRangeArray() []proto.ExtensionRange
774-
ExtensionMap() map[int32]proto.Extension
775-
}
776-
777772
// Writer wrapper inspired by https://blog.golang.org/errors-are-values
778773
type errWriter struct {
779774
writer io.Writer

proto/clone.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) {
8484
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
8585
}
8686

87-
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
88-
emOut := out.Addr().Interface().(extendableProto)
89-
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
87+
if emIn, ok := extendable(in.Addr().Interface()); ok {
88+
emOut, _ := extendable(out.Addr().Interface())
89+
mIn, muIn := emIn.extensionsRead()
90+
if mIn != nil {
91+
mOut := emOut.extensionsWrite()
92+
muIn.Lock()
93+
mergeExtension(mOut, mIn)
94+
muIn.Unlock()
95+
}
9096
}
9197

9298
uf := in.FieldByName("XXX_unrecognized")

proto/decode.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
390390
if !ok {
391391
// Maybe it's an extension?
392392
if prop.extendable {
393-
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
393+
if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
394394
if err = o.skip(st, tag, wire); err == nil {
395-
ext := e.ExtensionMap()[int32(tag)] // may be missing
395+
extmap := e.extensionsWrite()
396+
ext := extmap[int32(tag)] // may be missing
396397
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
397-
e.ExtensionMap()[int32(tag)] = ext
398+
extmap[int32(tag)] = ext
398399
}
399400
continue
400401
}

proto/encode.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
10731073

10741074
// Encode an extension map.
10751075
func (o *Buffer) enc_map(p *Properties, base structPointer) error {
1076-
v := *structPointer_ExtMap(base, p.field)
1077-
if err := encodeExtensionMap(v); err != nil {
1076+
exts := structPointer_ExtMap(base, p.field)
1077+
if err := encodeExtensionsMap(*exts); err != nil {
10781078
return err
10791079
}
1080+
1081+
return o.enc_map_body(*exts)
1082+
}
1083+
1084+
func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
1085+
exts := structPointer_Extensions(base, p.field)
1086+
if err := encodeExtensions(exts); err != nil {
1087+
return err
1088+
}
1089+
v, _ := exts.extensionsRead()
1090+
1091+
return o.enc_map_body(v)
1092+
}
1093+
1094+
func (o *Buffer) enc_map_body(v map[int32]Extension) error {
10801095
// Fast-path for common cases: zero or one extensions.
10811096
if len(v) <= 1 {
10821097
for _, e := range v {
@@ -1099,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
10991114
}
11001115

11011116
func size_map(p *Properties, base structPointer) int {
1102-
v := *structPointer_ExtMap(base, p.field)
1103-
return sizeExtensionMap(v)
1117+
v := structPointer_ExtMap(base, p.field)
1118+
return extensionsMapSize(*v)
1119+
}
1120+
1121+
func size_exts(p *Properties, base structPointer) int {
1122+
v := structPointer_Extensions(base, p.field)
1123+
return extensionsSize(v)
11041124
}
11051125

11061126
// Encode a map field.

proto/equal.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ func equalStruct(v1, v2 reflect.Value) bool {
121121
}
122122
}
123123

124-
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
125-
em2 := v2.FieldByName("XXX_extensions")
126-
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
124+
if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
125+
em2 := v2.FieldByName("XXX_InternalExtensions")
126+
if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
127127
return false
128128
}
129129
}
@@ -223,8 +223,10 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
223223
}
224224

225225
// base is the struct type that the extensions are based on.
226-
// em1 and em2 are extension maps.
227-
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
226+
// x1 and x2 are InternalExtensions.
227+
func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
228+
em1, _ := x1.extensionsRead()
229+
em2, _ := x2.extensionsRead()
228230
if len(em1) != len(em2) {
229231
return false
230232
}

0 commit comments

Comments
 (0)