diff --git a/src/crypto/x509/x509.go b/src/crypto/x509/x509.go index 338b48861c92e0..1c8adb36da6d88 100644 --- a/src/crypto/x509/x509.go +++ b/src/crypto/x509/x509.go @@ -760,7 +760,16 @@ type Certificate struct { // CRL Distribution Points CRLDistributionPoints []string + // Policy identifiers with sub-oid values less than or equal to math.MaxInt32. + // When parsing a certificate, the certificate policy identifiers are unmarshaled into + // PolicyIdentifiers if all sub-oids are less than or equal to math.MaxInt32. + // If at least one sub-oid is greater than math.MaxInt32, policy identifiers are + // unmarshaled into PolicyIdentifiersExt. + // When generating a certificate, set either PolicyIdentifiers or PolicyIdentifiersExt + // but not both. PolicyIdentifiers []asn1.ObjectIdentifier + // Policy identifiers with sub-oid values greater than math.MaxInt32. + PolicyIdentifiersExt []asn1.ObjectIdentifierExt } // ErrUnsupportedAlgorithm results from attempting to perform an operation that @@ -943,7 +952,7 @@ type basicConstraints struct { // RFC 5280 4.2.1.4 type policyInformation struct { - Policy asn1.ObjectIdentifier + Policy asn1.ObjectIdentifierExt // policyQualifiers omitted } @@ -1517,9 +1526,25 @@ func parseCertificate(in *certificate) (*Certificate, error) { } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 certificate policies") } - out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, len(policies)) - for i, policy := range policies { - out.PolicyIdentifiers[i] = policy.Policy + s := make([]asn1.ObjectIdentifier, 0, len(policies)) + largeoids := false + for _, policy := range policies { + if oid, err1 := policy.Policy.GetObjectIdentifier(); err1 == nil { + // Add to PolicyIdentifiers if sub-oids are less than or equal to math.MaxInt32. + s = append(s, oid) + } else { + largeoids = true + } + } + if largeoids { + // The certificate contains at least one policy identifier that has a sub-oid + // value greater than math.MaxInt32. + out.PolicyIdentifiersExt = make([]asn1.ObjectIdentifierExt, len(policies)) + for i, policy := range policies { + out.PolicyIdentifiersExt[i] = policy.Policy + } + } else { + out.PolicyIdentifiers = s } default: @@ -1811,11 +1836,30 @@ func buildExtensions(template *Certificate, subjectIsEmpty bool, authorityKeyId n++ } + if len(template.PolicyIdentifiers) > 0 && len(template.PolicyIdentifiersExt) > 0 { + err = errors.New("x509: invalid template, cannot specify both PolicyIdentifiers and PolicyIdentifiersExt") + return + } + if len(template.PolicyIdentifiers) > 0 && !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { ret[n].Id = oidExtensionCertificatePolicies policies := make([]policyInformation, len(template.PolicyIdentifiers)) for i, policy := range template.PolicyIdentifiers { + policies[i].Policy = asn1.NewObjectIdentifierExt(policy) + } + ret[n].Value, err = asn1.Marshal(policies) + if err != nil { + return + } + n++ + } + + if len(template.PolicyIdentifiersExt) > 0 && + !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { + ret[n].Id = oidExtensionCertificatePolicies + policies := make([]policyInformation, len(template.PolicyIdentifiersExt)) + for i, policy := range template.PolicyIdentifiersExt { policies[i].Policy = policy } ret[n].Value, err = asn1.Marshal(policies) @@ -2067,6 +2111,7 @@ var emptyASN1Subject = []byte{0x30, 0} // - PermittedIPRanges // - PermittedURIDomains // - PolicyIdentifiers +// - PolicyIdentifiersExt // - SerialNumber // - SignatureAlgorithm // - Subject diff --git a/src/crypto/x509/x509_test.go b/src/crypto/x509/x509_test.go index 0141021504e137..052cc4d0759545 100644 --- a/src/crypto/x509/x509_test.go +++ b/src/crypto/x509/x509_test.go @@ -22,6 +22,7 @@ import ( "encoding/pem" "fmt" "internal/testenv" + "math" "math/big" "net" "net/url" @@ -686,6 +687,10 @@ func TestCreateSelfSignedCertificate(t *testing.T) { t.Errorf("%s: failed to parse policy identifiers: got:%#v want:%#v", test.name, cert.PolicyIdentifiers, template.PolicyIdentifiers) } + if len(cert.PolicyIdentifiersExt) > 0 { + t.Errorf("%s: unexpected PolicyIdentifiersExt value:%#v", test.name, template.PolicyIdentifiersExt) + } + if len(cert.PermittedDNSDomains) != 2 || cert.PermittedDNSDomains[0] != ".example.com" || cert.PermittedDNSDomains[1] != "example.com" { t.Errorf("%s: failed to parse name constraints: %#v", test.name, cert.PermittedDNSDomains) } @@ -806,6 +811,82 @@ func TestCreateSelfSignedCertificate(t *testing.T) { } } +func TestCreateSelfSignedCertificatePolicyIdentifier(t *testing.T) { + random := rand.Reader + + ecdsaPriv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ECDSA key: %s", err) + } + { + // Cannot have a certificate template with both PolicyIdentifiers and PolicyIdentifiersExt. + template := Certificate{ + SerialNumber: big.NewInt(-1), + SignatureAlgorithm: ECDSAWithSHA256, + PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + PolicyIdentifiersExt: []asn1.ObjectIdentifierExt{asn1.ObjectIdentifierExt{1, 2, 3, 4}}, + } + _, err := CreateCertificate(random, &template, &template, &ecdsaPriv.PublicKey, ecdsaPriv) + if err == nil { + t.Error("specifying PolicyIdentifiers and PolicyIdentifiersExt should fail") + } + } + { + // Certificate template with policy identifier that has sub-oids less than math.MaxInt32. + template := Certificate{ + SerialNumber: big.NewInt(-1), + SignatureAlgorithm: ECDSAWithSHA256, + PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + } + derBytes, err := CreateCertificate(random, &template, &template, &ecdsaPriv.PublicKey, ecdsaPriv) + if err != nil { + t.Errorf("failed to create certificate: %s", err) + } + cert, err := ParseCertificate(derBytes) + if err != nil { + t.Errorf("failed to parse certificate: %s", err) + } + if len(cert.PolicyIdentifiers) != 1 || !cert.PolicyIdentifiers[0].Equal(template.PolicyIdentifiers[0]) { + t.Errorf("failed to parse policy identifiers: got:%#v want:%#v", cert.PolicyIdentifiers, template.PolicyIdentifiers) + } + if len(cert.PolicyIdentifiersExt) > 0 { + t.Errorf("unexpected PolicyIdentifiersExt value:%#v", template.PolicyIdentifiersExt) + } + } + { + // Certificate template with policy identifier that has sub-oids greater than math.MaxInt32. + template := Certificate{ + SerialNumber: big.NewInt(-1), + SignatureAlgorithm: ECDSAWithSHA256, + PolicyIdentifiersExt: []asn1.ObjectIdentifierExt{ + asn1.ObjectIdentifierExt{1, 2, math.MaxInt32}, + asn1.ObjectIdentifierExt{1, 2, 1 << 60}, + asn1.ObjectIdentifierExt{1, 2, new(big.Int).Lsh(big.NewInt(1), 80)}, + asn1.ObjectIdentifierExt{1, 2, new(big.Int).Lsh(big.NewInt(1), 80), 1 << 60}, + }, + } + derBytes, err := CreateCertificate(random, &template, &template, &ecdsaPriv.PublicKey, ecdsaPriv) + if err != nil { + t.Errorf("failed to create certificate: %s", err) + } + cert, err := ParseCertificate(derBytes) + if err != nil { + t.Errorf("failed to parse certificate: %s", err) + } + if len(cert.PolicyIdentifiers) > 0 { + t.Errorf("unexpected PolicyIdentifiers value:%#v", template.PolicyIdentifiers) + } + if len(cert.PolicyIdentifiersExt) == 0 || len(cert.PolicyIdentifiersExt) != len(template.PolicyIdentifiersExt) { + t.Errorf("failed to parse policy identifiers: got:%#v want:%#v", cert.PolicyIdentifiersExt, template.PolicyIdentifiersExt) + } + for i, pi := range cert.PolicyIdentifiersExt { + if !pi.Equal(template.PolicyIdentifiersExt[i]) { + t.Errorf("failed to parse policy identifiers: got:%#v want:%#v", cert.PolicyIdentifiersExt, template.PolicyIdentifiersExt) + } + } + } +} + // Self-signed certificate using ECDSA with SHA1 & secp256r1 var ecdsaSHA1CertPem = ` -----BEGIN CERTIFICATE----- @@ -2189,6 +2270,66 @@ func TestCriticalNameConstraintWithUnknownType(t *testing.T) { } } +const certWithLargeSubOidPEM = ` +-----BEGIN CERTIFICATE----- +MIIFZjCCA06gAwIBAgITFgAAAAImoUeGgGDlrAAAAAAAAjANBgkqhkiG9w0BAQsF +ADAYMRYwFAYDVQQDEw1DaG9ydXNSb290LUNBMB4XDTE0MDMyNTIxMjYxNFoXDTM0 +MDMxOTIzMjAwNlowYTESMBAGCgmSJomT8ixkARkWAm56MRIwEAYKCZImiZPyLGQB +GRYCY28xFjAUBgoJkiaJk/IsZAEZFgZjaG9ydXMxHzAdBgNVBAMTFmNob3J1cy1D +U01TUFJQS0kxNTgtQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC3 +YrgjhkQJqYjL27DAbQim8c3wVldUURBWMLExycwGIbMCxyRFrQmMhqV3bBaKxixx +jIAZTVB7hha6HCuR/fkfYlNo1suiu1g7WIPxdecV23CuvOiVfTbI7j8LlijsVbKW +2jOy7LBSywaU58aPS95UqUfqtY53pWbFzQQu//hovqFFwk12mApu42SqmcupxS7/ +tmhkaC+wgliaiS8p+CJZGSBUekuVQNclLGqyYUeBlO3jjIwVZzh9qlLaEbO7NLG8 +k3A5w/9T3r195AmA4+sXKlj9nV9zS6Q8t6ygB6g6/Hr2sv8Xogi3AAR65HghSz2z +kRbANOUVuVtCMHfiJUYHAgMBAAGjggFeMIIBWjAQBgkrBgEEAYI3FQEEAwIBADAd +BgNVHQ4EFgQUPn4jsmFshpUupvUcLDmp2LpJCaowgYwGA1UdIASBhDCBgTB/Bg0q +JIH5z5nyYIUaQgEBMG4wOgYIKwYBBQUHAgIwLh4sAEwAZQBnAGEAbAAgAFAAbwBs +AGkAYwB5ACAAUwB0AGEAdABlAG0AZQBuAHQwMAYIKwYBBQUHAgEWJGh0dHA6Ly9j +cmwuY2hvcnVzLmNvLm56L3BraS9jcHMudHh0ADAZBgkrBgEEAYI3FAIEDB4KAFMA +dQBiAEMAQTALBgNVHQ8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAfBgNVHSMEGDAW +gBTPctdjhzmVPatJ73QufzslwP7q3DA+BgNVHR8ENzA1MDOgMaAvhi1odHRwOi8v +Y3JsLmNob3J1cy5jby5uei9wa2kvQ2hvcnVzUm9vdC1DQS5jcmwwDQYJKoZIhvcN +AQELBQADggIBAKFEeteUZqZXv95+hpjYFGj6NubVRmbmIH1DU2nydY+RdOZAhn6b +0ozXoTtprEoo5POUjNZOz7btr08SCbtYQsm4nL5NHj3JSuMj0jDlnn8Qs4yadk5D +3rTMOf6ZdwVqqZuctwfjlfXgOvPHnVsbUsK02x4b6yJqbbQu7KxxkVoSuneOWpHd +ZPNqF8aigupfTn5wylKFz2zW39yRQbu1Xbbe31xjqN6g0T/57+myf1j6PtyntmcX +8n31ZFLCtuC1uXEvN4Mlr0FGXoMpwzlysHzWejFWRQ7Oj7O9/pyzHFgxrlbZilp4 +7qDATM212smJDReEFaFTVR6CgA4xZC4xADL0SU/6MNa2vA4bg8bVlQ5XxBOyRBfq +PEItYXN3dp7c9medpAh1QauNpFZL7n4DA63X2zB97o6N8fyNSzJXp6x1qGUUgSSz +QSF65ypj/QDdRwczNmvBlSAFQoFEQpYJarJXPBj9859y1ZkDCctYz6lXUA1qjkCT +/mM6UTdnDD7LR3vQmQo6t8ydr4sQVV8O1ZJmAxHt+4qYg/UpHdPt63mneHPZaoE9 +N9tF+yS8w1F2Mw5YP6OyjZJy6o8kXR3I/jao4hvDVjfEzo4eSrucAsqwr3Q4kUgl +GYhu9NN9++fuqOoXuWBguh+3//dgMOXhgFVsWvx9fzLiQQqUc5ToGbJK +-----END CERTIFICATE-----` + +// X.509 certificates may contain policy identifiers that have +// sub-oid values greater than math.MaxInt32. ParseCertificate +// must not return an error when parsing such certificates. +func TestCertificateWithLargeSubOidInPolicyIdentifiers(t *testing.T) { + block, _ := pem.Decode([]byte(certWithLargeSubOidPEM)) + c, err := ParseCertificate(block.Bytes) + if err != nil { + t.Errorf("Failed to parse certificate: %v", err.Error()) + } + if len(c.PolicyIdentifiersExt) != 1 { + // The Certificate policy identifier is 1.2.36.67006527840.666.66.1.1 + // Because it does not fit in the PolicyIdentifiers field, it is + // unmarshaled in the PolicyIdentifiersExt field. + t.Errorf("PolicyIdentifiersExt expected 1 but got %d", len(c.PolicyIdentifiersExt)) + } + expected := "1.2.36.67006527840.666.66.1.1" + if c.PolicyIdentifiersExt[0].String() != expected { + t.Errorf("PolicyIdentifiersExt expected %s but got %s", expected, c.PolicyIdentifiersExt[0].String()) + } + if len(c.PolicyIdentifiers) > 0 { + // The Certificate policy identifier for this particular x509 certificate + // does not fit in the PolicyIdentifiers field, because suboid 67006527840 + // is more than 2^31 + t.Errorf("PolicyIdentifiers expected zero length but got %d", len(c.PolicyIdentifiers)) + } +} + const badIPMaskPEM = ` -----BEGIN CERTIFICATE----- MIICzzCCAbegAwIBAgICEjQwDQYJKoZIhvcNAQELBQAwHTEbMBkGA1UEAxMSQmFk diff --git a/src/encoding/asn1/asn1.go b/src/encoding/asn1/asn1.go index d809dde2781cd8..1ef2fb5e06ab39 100644 --- a/src/encoding/asn1/asn1.go +++ b/src/encoding/asn1/asn1.go @@ -127,7 +127,12 @@ func parseInt32(bytes []byte) (int32, error) { return int32(ret64), nil } -var bigOne = big.NewInt(1) +var ( + bigOne = big.NewInt(1) + bigTwo = big.NewInt(2) + bigForty = big.NewInt(40) + bigEighty = big.NewInt(80) +) // parseBigInt treats the given bytes as a big-endian, signed integer and returns // the result. @@ -221,6 +226,14 @@ var NullBytes = []byte{TagNull, 0} // An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER. type ObjectIdentifier []int +// A SubOid represents a single INTEGER value in a ASN.1 OBJECT IDENTIFIER. +// The type may be int, int64 or *big.Int. +type SubOid interface{} + +// An ObjectIdentifierExt represents an ASN.1 OBJECT IDENTIFIER +// with sub-oids that can be int, int64 or *big.Int. +type ObjectIdentifierExt []SubOid + // Equal reports whether oi and other represent the same identifier. func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool { if len(oi) != len(other) { @@ -251,7 +264,7 @@ func (oi ObjectIdentifier) String() string { // parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and // returns it. An object identifier is a sequence of variable length integers // that are assigned in a hierarchy. -func parseObjectIdentifier(bytes []byte) (s ObjectIdentifier, err error) { +func parseObjectIdentifier(bytes []byte) (s ObjectIdentifierExt, err error) { if len(bytes) == 0 { err = SyntaxError{"zero length OBJECT IDENTIFIER"} return @@ -259,36 +272,186 @@ func parseObjectIdentifier(bytes []byte) (s ObjectIdentifier, err error) { // In the worst case, we get two elements from the first byte (which is // encoded differently) and then every varint is a single byte long. - s = make([]int, len(bytes)+1) + s = ObjectIdentifierExt(make([]SubOid, len(bytes)+1)) // The first varint is 40*value1 + value2: // According to this packing, value1 can take the values 0, 1 and 2 only. // When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2, // then there are no restrictions on value2. - v, offset, err := parseBase128Int(bytes, 0) + val, offset, err := parseBase128Int(bytes, 0) if err != nil { return } - if v < 80 { - s[0] = v / 40 - s[1] = v % 40 - } else { - s[0] = 2 - s[1] = v - 80 + switch v := val.(type) { + case int: + if v < 80 { + s[0] = v / 40 + s[1] = v % 40 + } else { + s[0] = 2 + s[1] = v - 80 + } + case int64: + if v < 80 { + s[0] = v / 40 + s[1] = v % 40 + } else { + s[0] = 2 + s[1] = v - 80 + } + case *big.Int: + if v.Cmp(bigEighty) == -1 { + s[0] = new(big.Int).Div(v, bigForty) + s[1] = new(big.Int).Mod(v, bigForty) + } else { + s[0] = bigTwo + s[1] = new(big.Int).Sub(v, bigEighty) + } } i := 2 for ; offset < len(bytes); i++ { - v, offset, err = parseBase128Int(bytes, offset) + val, offset, err = parseBase128Int(bytes, offset) if err != nil { return } - s[i] = v + s[i] = val } s = s[0:i] return } +// NewObjectIdentifierExt creates and returns a ObjectIdentifierExt from a ObjectIdentifier. +func NewObjectIdentifierExt(oid ObjectIdentifier) (s ObjectIdentifierExt) { + s = make(ObjectIdentifierExt, len(oid), len(oid)) + for i, o := range oid { + s[i] = o + } + return +} + +// GetObjectIdentifier returns the object identifier as a slice of int. +// An error is returned if at least one of the sub oid is greater than math.MaxInt32. +func (oi ObjectIdentifierExt) GetObjectIdentifier() (s ObjectIdentifier, err error) { + s = make(ObjectIdentifier, len(oi), len(oi)) + for i, sub := range oi { + switch v := sub.(type) { + case int: + s[i] = v + case int64: + if v > math.MaxInt32 { + err = StructuralError{"base 128 integer too large"} + return + } + s[i] = int(v) + case *big.Int: + if !v.IsInt64() || v.Int64() > math.MaxInt32 { + err = StructuralError{"base 128 integer too large"} + return + } + s[i] = int(v.Int64()) + } + } + return +} + +// Equal reports whether oi and other represent the same identifier. +func (oi ObjectIdentifierExt) Equal(other interface{}) bool { + comp := func(i int, suboid interface{}) bool { + switch v1 := oi[i].(type) { + case int: + switch v2 := suboid.(type) { + case int: + if v1 != v2 { + return false + } + case int64: + if int64(v1) != v2 { + return false + } + case *big.Int: + if big.NewInt(int64(v1)).Cmp(v2) != 0 { + return false + } + } + case int64: + switch v2 := suboid.(type) { + case int: + if v1 != int64(v2) { + return false + } + case int64: + if v1 != v2 { + return false + } + case *big.Int: + if big.NewInt(v1).Cmp(v2) != 0 { + return false + } + } + case *big.Int: + switch v2 := suboid.(type) { + case int: + if v1.Cmp(big.NewInt(int64(v2))) != 0 { + return false + } + case int64: + if v1.Cmp(big.NewInt(v2)) != 0 { + return false + } + case *big.Int: + if v1.Cmp(v2) != 0 { + return false + } + } + } + return true + } + switch v := other.(type) { + case ObjectIdentifier: + if len(oi) != len(v) { + return false + } + for i := 0; i < len(oi); i++ { + if !comp(i, v[i]) { + return false + } + } + case ObjectIdentifierExt: + if len(oi) != len(v) { + return false + } + for i := 0; i < len(oi); i++ { + if !comp(i, v[i]) { + return false + } + } + default: + return false + } + return true +} + +func (oi ObjectIdentifierExt) String() string { + var s string + + for i, v := range oi { + if i > 0 { + s += "." + } + switch val := v.(type) { + case int: + s += strconv.FormatInt(int64(val), 10) + case int64: + s += strconv.FormatInt(val, 10) + case *big.Int: + s += val.String() + } + } + + return s +} + // ENUMERATED // An Enumerated is represented as a plain int. @@ -301,17 +464,22 @@ type Flag bool // parseBase128Int parses a base-128 encoded int from the given offset in the // given byte slice. It returns the value and the new offset. -func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) { +// The return value may be a int, int64 or big.Int depending on the size of the input data. +func parseBase128Int(bytes []byte, initOffset int) (ret SubOid, offset int, err error) { offset = initOffset var ret64 int64 + var retBigInt *big.Int for shifted := 0; offset < len(bytes); shifted++ { - // 5 * 7 bits per byte == 35 bits of data - // Thus the representation is either non-minimal or too large for an int32 - if shifted == 5 { - err = StructuralError{"base 128 integer too large"} - return + // 9 * 7 bits per byte == 63 bits of data + // Thus the representation is either non-minimal or too large for an int64. Use big.Int + if shifted >= 9 { + if shifted == 9 { + retBigInt = big.NewInt(ret64) + } + retBigInt.Lsh(retBigInt, 7) + } else { + ret64 <<= 7 } - ret64 <<= 7 b := bytes[offset] // integers should be minimally encoded, so the leading octet should // never be 0x80 @@ -319,13 +487,19 @@ func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) err = SyntaxError{"integer is not minimally encoded"} return } - ret64 |= int64(b & 0x7f) + if shifted >= 9 { + retBigInt.Or(retBigInt, big.NewInt(int64(b&0x7f))) + } else { + ret64 |= int64(b & 0x7f) + } offset++ if b&0x80 == 0 { - ret = int(ret64) - // Ensure that the returned value fits in an int on all platforms - if ret64 > math.MaxInt32 { - err = StructuralError{"base 128 integer too large"} + if ret64 <= math.MaxInt32 { + ret = int(ret64) + } else if shifted >= 9 { + ret = retBigInt + } else { + ret = ret64 } return } @@ -541,10 +715,16 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i // If the bottom five bits are set, then the tag number is actually base 128 // encoded afterwards if ret.tag == 0x1f { - ret.tag, offset, err = parseBase128Int(bytes, offset) + var v interface{} + v, offset, err = parseBase128Int(bytes, offset) if err != nil { return } + var ok bool + if ret.tag, ok = v.(int); !ok { + err = StructuralError{"base 128 integer too large"} + return + } // Tags should be encoded in minimal form. if ret.tag < 0x1f { err = SyntaxError{"non-minimal tag"} @@ -653,14 +833,15 @@ func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type } var ( - bitStringType = reflect.TypeOf(BitString{}) - objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) - enumeratedType = reflect.TypeOf(Enumerated(0)) - flagType = reflect.TypeOf(Flag(false)) - timeType = reflect.TypeOf(time.Time{}) - rawValueType = reflect.TypeOf(RawValue{}) - rawContentsType = reflect.TypeOf(RawContent(nil)) - bigIntType = reflect.TypeOf(new(big.Int)) + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + objectIdentifierExtType = reflect.TypeOf(ObjectIdentifierExt{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) + bigIntType = reflect.TypeOf(new(big.Int)) ) // invalidLength reports whether offset + length > sliceLength, or if the @@ -715,6 +896,13 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam result, err = parseBitString(innerBytes) case TagOID: result, err = parseObjectIdentifier(innerBytes) + if err == nil { + var r ObjectIdentifier + r, err = result.(ObjectIdentifierExt).GetObjectIdentifier() + if err == nil { + result = r + } + } case TagUTCTime: result, err = parseUTCTime(innerBytes) case TagGeneralizedTime: @@ -857,6 +1045,18 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam v.Set(reflect.ValueOf(result)) return case objectIdentifierType: + newSlice, err1 := parseObjectIdentifier(innerBytes) + v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) + if err1 == nil { + var intSlice []int + intSlice, err1 = newSlice.GetObjectIdentifier() + if err1 == nil { + reflect.Copy(v, reflect.ValueOf(intSlice)) + } + } + err = err1 + return + case objectIdentifierExtType: newSlice, err1 := parseObjectIdentifier(innerBytes) v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) if err1 == nil { @@ -1006,11 +1206,14 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam // canHaveDefaultValue reports whether k is a Kind that we will set a default // value for. (A signed integer, essentially.) -func canHaveDefaultValue(k reflect.Kind) bool { - switch k { +func canHaveDefaultValue(v reflect.Value) bool { + switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return true } + if v.IsValid() && v.Type() == reflect.TypeOf((*big.Int)(nil)) { + return true + } return false } @@ -1026,7 +1229,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { if params.defaultValue == nil { return } - if canHaveDefaultValue(v.Kind()) { + if canHaveDefaultValue(v) { v.SetInt(*params.defaultValue) } return diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go index 8daae97faad400..c0b4b2203f4076 100644 --- a/src/encoding/asn1/asn1_test.go +++ b/src/encoding/asn1/asn1_test.go @@ -225,19 +225,50 @@ func TestBitStringRightAlign(t *testing.T) { } } +type defaultValueTest struct { + v interface{} + ok bool +} + +var defaultValueTestData = []defaultValueTest{ + {int8(4), true}, + {int16(4), true}, + {int(4), true}, + {int32(4), true}, + {int64(4), true}, + {big.NewInt(4), true}, + {byte(4), false}, + {"abc", false}, + {nil, false}, +} + +func TestCanHaveDefaultValue(t *testing.T) { + for _, test := range defaultValueTestData { + if canHaveDefaultValue(reflect.ValueOf(test.v)) != test.ok { + t.Errorf("Bad result. Type '%v' can have default value: %v", reflect.ValueOf(test.v).Type(), test.ok) + } + } +} + type objectIdentifierTest struct { - in []byte - ok bool - out ObjectIdentifier // has base type[]int + in []byte + ok bool + downcastOk bool // True if the ObjectIdentifierExt can be downcast to ObjectIdentifier + out ObjectIdentifierExt // has base type []int, []int64 or []*big.Int } var objectIdentifierTestData = []objectIdentifierTest{ - {[]byte{}, false, []int{}}, - {[]byte{85}, true, []int{2, 5}}, - {[]byte{85, 0x02}, true, []int{2, 5, 2}}, - {[]byte{85, 0x02, 0xc0, 0x00}, true, []int{2, 5, 2, 0x2000}}, - {[]byte{0x81, 0x34, 0x03}, true, []int{2, 100, 3}}, - {[]byte{85, 0x02, 0xc0, 0x80, 0x80, 0x80, 0x80}, false, []int{}}, + {[]byte{}, false, true, ObjectIdentifierExt{}}, + {[]byte{85}, true, true, ObjectIdentifierExt{2, 5}}, + {[]byte{85, 0x02}, true, true, ObjectIdentifierExt{2, 5, 2}}, + {[]byte{85, 0x02, 0xc0, 0x00}, true, true, ObjectIdentifierExt{2, 5, 2, 0x2000}}, + {[]byte{0x81, 0x34, 0x03}, true, true, ObjectIdentifierExt{2, 100, 3}}, + {[]byte{85, 0x02, 0xc0, 0x80, 0x80, 0x80, 0x80}, false, true, ObjectIdentifierExt{}}, + // At least one sub-oid has a value higher than max int32 value, but less than max int64 value + {[]byte{0x2a, 0x24, 0x81, 0xf9, 0xcf, 0x99, 0xf2, 0x60, 0x85, 0x1a, 0x42, 0x01, 0x01}, true, false, ObjectIdentifierExt{1, 2, 36, int64(67006527840), 666, 66, 1, 1}}, + {[]byte{0x2a, 0xc0, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00, 0x01}, true, false, ObjectIdentifierExt{1, 2, int64(1 << 62), 1}}, + // At least one sub-oid has a value higher than max int64 value + {[]byte{0x2a, 0x84, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00, 0x01}, true, false, ObjectIdentifierExt{1, 2, new(big.Int).Lsh(big.NewInt(1), 65), 1}}, } func TestObjectIdentifier(t *testing.T) { @@ -251,11 +282,32 @@ func TestObjectIdentifier(t *testing.T) { t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) } } + oid, err1 := ret.GetObjectIdentifier() + if test.downcastOk { + if err1 != nil { + t.Errorf("#%d: Bad result: %v should have been converted to []int", i, ret) + } + if !ret.Equal(oid) { + t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out) + } + } + if !test.downcastOk && err1 == nil { + t.Errorf("#%d: Bad result: %v should not be converted to []int", i, ret) + } } if s := ObjectIdentifier([]int{1, 2, 3, 4}).String(); s != "1.2.3.4" { t.Errorf("bad ObjectIdentifier.String(). Got %s, want 1.2.3.4", s) } + if s := (ObjectIdentifierExt{1, 2, 3, 4}).String(); s != "1.2.3.4" { + t.Errorf("bad ObjectIdentifierExt.String(). Got %s, want 1.2.3.4", s) + } + if (ObjectIdentifierExt{1, 2, 3, 4}).Equal(nil) { + t.Errorf("bad ObjectIdentifierExt.Equal().") + } + if (ObjectIdentifierExt{1, 2, 3, 4}).Equal("abc") { + t.Errorf("bad ObjectIdentifierExt.Equal().") + } } type timeTest struct { @@ -635,6 +687,73 @@ func TestObjectIdentifierEqual(t *testing.T) { } } +type oiExtEqualTest struct { + first ObjectIdentifierExt + second interface{} + same bool +} + +var oiExtEqualTests = []oiExtEqualTest{ + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifierExt{1, 2, 3}, + true, + }, + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifier{1, 2, 3}, + true, + }, + { + ObjectIdentifierExt{1, 2, 1 << 61}, + ObjectIdentifierExt{1, 2, 1 << 61}, + true, + }, + { + ObjectIdentifierExt{1, 2, new(big.Int).Lsh(big.NewInt(1), 80)}, + ObjectIdentifierExt{1, 2, new(big.Int).Lsh(big.NewInt(1), 80)}, + true, + }, + { + ObjectIdentifierExt{1}, + ObjectIdentifierExt{1, 2, 3}, + false, + }, + { + ObjectIdentifierExt{1}, + ObjectIdentifier{1, 2, 3}, + false, + }, + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifierExt{10, 11, 12}, + false, + }, + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifier{10, 11, 12}, + false, + }, + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifierExt{10, 11, 12}, + false, + }, + { + ObjectIdentifierExt{1, 2, 3}, + ObjectIdentifier{10, 11, 12}, + false, + }, +} + +func TestObjectIdentifierExtEqual(t *testing.T) { + for _, o := range oiEqualTests { + if s := o.first.Equal(o.second); s != o.same { + t.Errorf("ObjectIdentifierExt.Equal: got: %t want: %t", s, o.same) + } + } +} + var derEncodedSelfSignedCert = Certificate{ TBSCertificate: TBSCertificate{ Version: 0, diff --git a/src/encoding/asn1/common.go b/src/encoding/asn1/common.go index e2aa8bd9c578e3..e078e60b9ab9f8 100644 --- a/src/encoding/asn1/common.go +++ b/src/encoding/asn1/common.go @@ -150,7 +150,7 @@ func getUniversalType(t reflect.Type) (matchAny bool, tagNumber int, isCompound, switch t { case rawValueType: return true, -1, false, true - case objectIdentifierType: + case objectIdentifierType, objectIdentifierExtType: return false, TagOID, false, true case bitStringType: return false, TagBitString, false, true diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go index 0d34d5aa1e8152..394436ccb68d60 100644 --- a/src/encoding/asn1/marshal.go +++ b/src/encoding/asn1/marshal.go @@ -8,6 +8,7 @@ import ( "bytes" "errors" "fmt" + "math" "math/big" "reflect" "sort" @@ -178,6 +179,19 @@ func base128IntLength(n int64) int { return l } +func base128BigIntLength(n *big.Int) int { + if n.Sign() == 0 { + return 1 + } + + l := 0 + for i := new(big.Int).Set(n); i.Sign() > 0; i.Rsh(i, 7) { + l++ + } + + return l +} + func appendBase128Int(dst []byte, n int64) []byte { l := base128IntLength(n) @@ -194,6 +208,28 @@ func appendBase128Int(dst []byte, n int64) []byte { return dst } +func appendBase128BigInt(dst []byte, n *big.Int) []byte { + l := base128BigIntLength(n) + + for i := l - 1; i >= 0; i-- { + b := new(big.Int).Rsh(n, uint(i*7)) + var o byte + if b.Sign() == 0 { + o = 0 + } else { + o = b.Bytes()[len(b.Bytes())-1] + } + o &= 0x7f + if i != 0 { + o |= 0x80 + } + + dst = append(dst, o) + } + + return dst +} + func makeBigInt(n *big.Int) (encoder, error) { if n == nil { return nil, StructuralError{"empty integer"} @@ -302,6 +338,84 @@ func (oid oidEncoder) Encode(dst []byte) { } } +type oidEncoderExt ObjectIdentifierExt + +// suboid12Encoding returns 40 * value1 + value2, either as int64 or *big.Int. +// First sub-oid is limited to values 0, 1, and 2. +// Second sub-oid is limited to the range 0 to 39 when value1 is 0 or 1. +func (oid oidEncoderExt) suboid12Encoding() interface{} { + var s interface{} + switch v := oid[0].(type) { + case int: + s = int64(v) * 40 + case int64: + // This cannot overflow because sub-oid is limited to values 0, 1, and 2. + s = v * 40 + case *big.Int: + // This can be converted to int64 because sub-oid is limited to values 0, 1, and 2. + s = v.Int64() * 40 + } + switch v := oid[1].(type) { + case int: + if math.MaxInt64-int64(v) >= s.(int64) { + s = s.(int64) + int64(v) + } else { + s = new(big.Int).Add(big.NewInt(s.(int64)), big.NewInt(int64(v))) + } + case int64: + if math.MaxInt64-int64(v) >= s.(int64) { + s = s.(int64) + int64(v) + } else { + s = new(big.Int).Add(big.NewInt(s.(int64)), big.NewInt(v)) + } + case *big.Int: + s = new(big.Int).Add(big.NewInt(s.(int64)), v) + } + return s +} + +func (oid oidEncoderExt) Len() int { + s := oid.suboid12Encoding() + var l int + switch v := s.(type) { + case int64: + l = base128IntLength(v) + case *big.Int: + l = base128BigIntLength(v) + } + for i := 2; i < len(oid); i++ { + switch v := oid[i].(type) { + case int: + l += base128IntLength(int64(v)) + case int64: + l += base128IntLength(v) + case *big.Int: + l += base128BigIntLength(v) + } + } + return l +} + +func (oid oidEncoderExt) Encode(dst []byte) { + s := oid.suboid12Encoding() + switch v := s.(type) { + case int64: + dst = appendBase128Int(dst[:0], v) + case *big.Int: + dst = appendBase128BigInt(dst[:0], v) + } + for i := 2; i < len(oid); i++ { + switch v := oid[i].(type) { + case int: + dst = appendBase128Int(dst, int64(v)) + case int64: + dst = appendBase128Int(dst, v) + case *big.Int: + dst = appendBase128BigInt(dst, v) + } + } +} + func makeObjectIdentifier(oid []int) (e encoder, err error) { if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { return nil, StructuralError{"invalid object identifier"} @@ -310,6 +424,59 @@ func makeObjectIdentifier(oid []int) (e encoder, err error) { return oidEncoder(oid), nil } +func makeObjectIdentifierExt(oid ObjectIdentifierExt) (e encoder, err error) { + isValid := true + if len(oid) < 2 { + isValid = false + } else { + validateOid2Range := false + switch v := oid[0].(type) { + case int: + if v > 2 { + isValid = false + } + if v < 2 { + validateOid2Range = true + } + case int64: + if v > 2 { + isValid = false + } + if v < 2 { + validateOid2Range = true + } + case *big.Int: + if v.Cmp(bigTwo) > 0 { + isValid = false + } + if v.Cmp(bigTwo) < 0 { + validateOid2Range = false + } + } + if isValid && validateOid2Range { + switch v := oid[1].(type) { + case int: + if v >= 40 { + isValid = false + } + case int64: + if v >= 40 { + isValid = false + } + case *big.Int: + if v.Cmp(bigForty) >= 0 { + isValid = false + } + } + } + } + if !isValid { + return nil, StructuralError{"invalid object identifier"} + } + + return oidEncoderExt(oid), nil +} + func makePrintableString(s string) (e encoder, err error) { for i := 0; i < len(s); i++ { // The asterisk is often used in PrintableString, even though @@ -472,6 +639,8 @@ func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error return bitStringEncoder(value.Interface().(BitString)), nil case objectIdentifierType: return makeObjectIdentifier(value.Interface().(ObjectIdentifier)) + case objectIdentifierExtType: + return makeObjectIdentifierExt(value.Interface().(ObjectIdentifierExt)) case bigIntType: return makeBigInt(value.Interface().(*big.Int)) } @@ -584,12 +753,11 @@ func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { return makeField(v.Elem(), params) } - if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { return bytesEncoder(nil), nil } - if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { + if params.optional && params.defaultValue != nil && canHaveDefaultValue(v) { defaultValue := reflect.New(v.Type()).Elem() defaultValue.SetInt(*params.defaultValue) diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go index 529052285f5950..4787e8ddb16176 100644 --- a/src/encoding/asn1/marshal_test.go +++ b/src/encoding/asn1/marshal_test.go @@ -7,6 +7,7 @@ package asn1 import ( "bytes" "encoding/hex" + "math" "math/big" "reflect" "strings" @@ -134,6 +135,21 @@ var marshalTests = []marshalTest{ {ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"}, {ObjectIdentifier([]int{1, 2, 840, 133549, 1, 1, 5}), "06092a864888932d010105"}, {ObjectIdentifier([]int{2, 100, 3}), "0603813403"}, + // Sub-oid value 67006527840 exceeds 2^31-1 and less than 2^63-1 + {ObjectIdentifierExt{1, 2, 36, 67006527840, 666, 66, 1, 1}, "060d2a2481f9cf99f260851a420101"}, + // Use cases that could cause overflow of (40 * value1 + value2) + {ObjectIdentifierExt{2, math.MaxInt64, 1, 1}, "060C8180808080808080804F0101"}, + {ObjectIdentifierExt{2, math.MaxInt64 - 79, 1, 1}, "060C818080808080808080000101"}, + {ObjectIdentifierExt{2, math.MaxInt64 - 80, 1, 1}, "060BFFFFFFFFFFFFFFFF7F0101"}, + {ObjectIdentifierExt{2, math.MaxInt64 - 81, 1, 1}, "060BFFFFFFFFFFFFFFFF7E0101"}, + // Sub-oid value is near 2^62, which is close to max int64 value + {ObjectIdentifierExt{1, 2, 1 << 62, 1}, "060b2ac0808080808080800001"}, + // Same OID as above, implemented as big.Int + {ObjectIdentifierExt{big.NewInt(1), big.NewInt(2), new(big.Int).Lsh(big.NewInt(1), 62), big.NewInt(1)}, "060b2ac0808080808080800001"}, + // Same OID as above, with a mix of int, int64 and big.Int + {ObjectIdentifierExt{big.NewInt(1), int64(2), new(big.Int).Lsh(big.NewInt(1), 62), int(1)}, "060b2ac0808080808080800001"}, + // Sub-oid value is more than max int64 value + {ObjectIdentifierExt{big.NewInt(1), big.NewInt(2), new(big.Int).Lsh(big.NewInt(1), 65), big.NewInt(1)}, "060c2a8480808080808080800001"}, {"test", "130474657374"}, { "" + @@ -254,6 +270,30 @@ func TestInvalidUTF8(t *testing.T) { } } +func TestInvalidOID(t *testing.T) { + var marshalTestsOID = []interface{}{ + // First sub-oid must be 0, 1 or 2. + ObjectIdentifier{-1, 999, 3}, + ObjectIdentifier{3, 999, 3}, + ObjectIdentifierExt{3, 999, 3}, + ObjectIdentifierExt{math.MaxInt64, 999, 3}, + ObjectIdentifierExt{big.NewInt(3), 999, 3}, + // Second sub-oid is limited to the range 0 to 39 when first sub-oid is 0 or 1. + ObjectIdentifier{0, 40, 3}, + ObjectIdentifier{1, 40, 3}, + ObjectIdentifierExt{0, 40, 3}, + ObjectIdentifierExt{1, 40, 3}, + ObjectIdentifierExt{0, math.MaxInt64, 3}, + ObjectIdentifierExt{1, math.MaxInt64, 3}, + } + for i, oid := range marshalTestsOID { + _, err := Marshal(oid) + if err == nil || "asn1: structure error: invalid object identifier" != err.Error() { + t.Errorf("#%d failed: %s", i, err) + } + } +} + func TestMarshalOID(t *testing.T) { var marshalTestsOID = []marshalTest{ {[]byte("\x06\x01\x30"), "0403060130"}, // bytes format returns a byte sequence \x04