Skip to content

Commit cff3e75

Browse files
aglbradfitz
authored andcommitted
crypto/tls: add Config.GetConfigForClient
GetConfigForClient allows the tls.Config to be updated on a per-client basis. Fixes #16066. Fixes #15707. Fixes #15699. Change-Id: I2c675a443d557f969441226729f98502b38901ea Reviewed-on: https://go-review.googlesource.com/30790 Run-TryBot: Adam Langley <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 7e2bf95 commit cff3e75

File tree

4 files changed

+225
-32
lines changed

4 files changed

+225
-32
lines changed

src/crypto/tls/common.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,27 @@ type Config struct {
303303
// If GetCertificate is nil or returns nil, then the certificate is
304304
// retrieved from NameToCertificate. If NameToCertificate is nil, the
305305
// first element of Certificates will be used.
306-
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
306+
GetCertificate func(*ClientHelloInfo) (*Certificate, error)
307+
308+
// GetConfigForClient, if not nil, is called after a ClientHello is
309+
// received from a client. It may return a non-nil Config in order to
310+
// change the Config that will be used to handle this connection. If
311+
// the returned Config is nil, the original Config will be used. The
312+
// Config returned by this callback may not be subsequently modified.
313+
//
314+
// If GetConfigForClient is nil, the Config passed to Server() will be
315+
// used for all connections.
316+
//
317+
// Uniquely for the fields in the returned Config, session ticket keys
318+
// will be duplicated from the original Config if not set.
319+
// Specifically, if SetSessionTicketKeys was called on the original
320+
// config but not on the returned config then the ticket keys from the
321+
// original config will be copied into the new config before use.
322+
// Otherwise, if SessionTicketKey was set in the original config but
323+
// not in the returned config then it will be copied into the returned
324+
// config before use. If neither of those cases applies then the key
325+
// material from the returned config will be used for session tickets.
326+
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
307327

308328
// RootCAs defines the set of root certificate authorities
309329
// that clients use when verifying server certificates.
@@ -398,13 +418,17 @@ type Config struct {
398418

399419
serverInitOnce sync.Once // guards calling (*Config).serverInit
400420

401-
// mutex protects sessionTicketKeys
421+
// mutex protects sessionTicketKeys and originalConfig.
402422
mutex sync.RWMutex
403423
// sessionTicketKeys contains zero or more ticket keys. If the length
404424
// is zero, SessionTicketsDisabled must be true. The first key is used
405425
// for new tickets and any subsequent keys can be used to decrypt old
406426
// tickets.
407427
sessionTicketKeys []ticketKey
428+
// originalConfig is set to the Config that was passed to Server if
429+
// this Config is returned by a GetConfigForClient callback. It's used
430+
// by serverInit in order to copy session ticket keys if needed.
431+
originalConfig *Config
408432
}
409433

410434
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
@@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
434458
// Clone returns a shallow clone of c.
435459
// Only the exported fields are copied.
436460
func (c *Config) Clone() *Config {
461+
var sessionTicketKeys []ticketKey
462+
c.mutex.RLock()
463+
sessionTicketKeys = c.sessionTicketKeys
464+
c.mutex.RUnlock()
465+
437466
return &Config{
438467
Rand: c.Rand,
439468
Time: c.Time,
440469
Certificates: c.Certificates,
441470
NameToCertificate: c.NameToCertificate,
442471
GetCertificate: c.GetCertificate,
472+
GetConfigForClient: c.GetConfigForClient,
443473
RootCAs: c.RootCAs,
444474
NextProtos: c.NextProtos,
445475
ServerName: c.ServerName,
@@ -457,6 +487,8 @@ func (c *Config) Clone() *Config {
457487
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
458488
Renegotiation: c.Renegotiation,
459489
KeyLogWriter: c.KeyLogWriter,
490+
sessionTicketKeys: sessionTicketKeys,
491+
// originalConfig is deliberately not duplicated.
460492
}
461493
}
462494

@@ -465,6 +497,11 @@ func (c *Config) serverInit() {
465497
return
466498
}
467499

500+
var originalConfig *Config
501+
c.mutex.Lock()
502+
originalConfig, c.originalConfig = c.originalConfig, nil
503+
c.mutex.Unlock()
504+
468505
alreadySet := false
469506
for _, b := range c.SessionTicketKey {
470507
if b != 0 {
@@ -474,13 +511,21 @@ func (c *Config) serverInit() {
474511
}
475512

476513
if !alreadySet {
477-
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
514+
if originalConfig != nil {
515+
copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:])
516+
} else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
478517
c.SessionTicketsDisabled = true
479518
return
480519
}
481520
}
482521

483-
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
522+
if originalConfig != nil {
523+
originalConfig.mutex.RLock()
524+
c.sessionTicketKeys = originalConfig.sessionTicketKeys
525+
originalConfig.mutex.RUnlock()
526+
} else {
527+
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
528+
}
484529
}
485530

486531
func (c *Config) ticketKeys() []ticketKey {

src/crypto/tls/handshake_server.go

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ type serverHandshakeState struct {
3737
// serverHandshake performs a TLS handshake as a server.
3838
// c.out.Mutex <= L; c.handshakeMutex <= L.
3939
func (c *Conn) serverHandshake() error {
40-
config := c.config
41-
4240
// If this is the first server handshake, we generate a random key to
4341
// encrypt the tickets with.
44-
config.serverInitOnce.Do(config.serverInit)
42+
c.config.serverInitOnce.Do(c.config.serverInit)
4543

4644
hs := serverHandshakeState{
4745
c: c,
@@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error {
112110
// readClientHello reads a ClientHello message from the client and decides
113111
// whether we will perform session resumption.
114112
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
115-
config := hs.c.config
116113
c := hs.c
117114

118115
msg, err := c.readHandshake()
@@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
125122
c.sendAlert(alertUnexpectedMessage)
126123
return false, unexpectedMessageError(hs.clientHello, msg)
127124
}
128-
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
125+
126+
clientHelloInfo := &ClientHelloInfo{
127+
CipherSuites: hs.clientHello.cipherSuites,
128+
ServerName: hs.clientHello.serverName,
129+
SupportedCurves: hs.clientHello.supportedCurves,
130+
SupportedPoints: hs.clientHello.supportedPoints,
131+
}
132+
133+
if c.config.GetConfigForClient != nil {
134+
if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil {
135+
c.sendAlert(alertInternalError)
136+
return false, err
137+
} else if newConfig != nil {
138+
newConfig.mutex.Lock()
139+
newConfig.originalConfig = c.config
140+
newConfig.mutex.Unlock()
141+
142+
newConfig.serverInitOnce.Do(newConfig.serverInit)
143+
c.config = newConfig
144+
}
145+
}
146+
147+
c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
129148
if !ok {
130149
c.sendAlert(alertProtocolVersion)
131150
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
@@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
135154
hs.hello = new(serverHelloMsg)
136155

137156
supportedCurve := false
138-
preferredCurves := config.curvePreferences()
157+
preferredCurves := c.config.curvePreferences()
139158
Curves:
140159
for _, curve := range hs.clientHello.supportedCurves {
141160
for _, supported := range preferredCurves {
@@ -171,7 +190,7 @@ Curves:
171190

172191
hs.hello.vers = c.vers
173192
hs.hello.random = make([]byte, 32)
174-
_, err = io.ReadFull(config.rand(), hs.hello.random)
193+
_, err = io.ReadFull(c.config.rand(), hs.hello.random)
175194
if err != nil {
176195
c.sendAlert(alertInternalError)
177196
return false, err
@@ -196,20 +215,15 @@ Curves:
196215
} else {
197216
// Although sending an empty NPN extension is reasonable, Firefox has
198217
// had a bug around this. Best to send nothing at all if
199-
// config.NextProtos is empty. See
218+
// c.config.NextProtos is empty. See
200219
// https://golang.org/issue/5445.
201-
if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
220+
if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
202221
hs.hello.nextProtoNeg = true
203-
hs.hello.nextProtos = config.NextProtos
222+
hs.hello.nextProtos = c.config.NextProtos
204223
}
205224
}
206225

207-
hs.cert, err = config.getCertificate(&ClientHelloInfo{
208-
CipherSuites: hs.clientHello.cipherSuites,
209-
ServerName: hs.clientHello.serverName,
210-
SupportedCurves: hs.clientHello.supportedCurves,
211-
SupportedPoints: hs.clientHello.supportedPoints,
212-
})
226+
hs.cert, err = c.config.getCertificate(clientHelloInfo)
213227
if err != nil {
214228
c.sendAlert(alertInternalError)
215229
return false, err
@@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
354368
}
355369

356370
func (hs *serverHandshakeState) doFullHandshake() error {
357-
config := hs.c.config
358371
c := hs.c
359372

360373
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
361374
hs.hello.ocspStapling = true
362375
}
363376

364-
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled
377+
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
365378
hs.hello.cipherSuite = hs.suite.id
366379

367380
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
368-
if config.ClientAuth == NoClientCert {
381+
if c.config.ClientAuth == NoClientCert {
369382
// No need to keep a full record of the handshake if client
370383
// certificates won't be used.
371384
hs.finishedHash.discardHandshakeBuffer()
@@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
394407
}
395408

396409
keyAgreement := hs.suite.ka(c.vers)
397-
skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello)
410+
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
398411
if err != nil {
399412
c.sendAlert(alertHandshakeFailure)
400413
return err
@@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
406419
}
407420
}
408421

409-
if config.ClientAuth >= RequestClientCert {
422+
if c.config.ClientAuth >= RequestClientCert {
410423
// Request a client certificate
411424
certReq := new(certificateRequestMsg)
412425
certReq.certificateTypes = []byte{
@@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
423436
// to our request. When we know the CAs we trust, then
424437
// we can send them down, so that the client can choose
425438
// an appropriate certificate to give to us.
426-
if config.ClientCAs != nil {
427-
certReq.certificateAuthorities = config.ClientCAs.Subjects()
439+
if c.config.ClientCAs != nil {
440+
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
428441
}
429442
hs.finishedHash.Write(certReq.marshal())
430443
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
@@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
452465
var ok bool
453466
// If we requested a client certificate, then the client must send a
454467
// certificate message, even if it's empty.
455-
if config.ClientAuth >= RequestClientCert {
468+
if c.config.ClientAuth >= RequestClientCert {
456469
if certMsg, ok = msg.(*certificateMsg); !ok {
457470
c.sendAlert(alertUnexpectedMessage)
458471
return unexpectedMessageError(certMsg, msg)
@@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
461474

462475
if len(certMsg.certificates) == 0 {
463476
// The client didn't actually send a certificate
464-
switch config.ClientAuth {
477+
switch c.config.ClientAuth {
465478
case RequireAnyClientCert, RequireAndVerifyClientCert:
466479
c.sendAlert(alertBadCertificate)
467480
return errors.New("tls: client didn't provide a certificate")
@@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error {
487500
}
488501
hs.finishedHash.Write(ckx.marshal())
489502

490-
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
503+
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
491504
if err != nil {
492505
c.sendAlert(alertHandshakeFailure)
493506
return err
494507
}
495508
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
496-
if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
509+
if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
497510
c.sendAlert(alertInternalError)
498511
return err
499512
}

0 commit comments

Comments
 (0)