diff --git a/src/abstraction/transient.rs b/src/abstraction/transient.rs index 31ca21d1..584685b3 100644 --- a/src/abstraction/transient.rs +++ b/src/abstraction/transient.rs @@ -35,7 +35,7 @@ impl TransientObjectContext { if root_key_auth_size > 32 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } - if root_key_size < 1024 { + if root_key_size < 1024 || root_key_size > 4096 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } let mut context = Context::new(tcti)?; @@ -50,7 +50,7 @@ impl TransientObjectContext { let root_key_handle = context.create_primary_key( ESYS_TR_RH_OWNER, - &get_rsa_public(true, true, false, root_key_size.try_into().unwrap()), + &get_rsa_public(true, true, false, root_key_size.try_into().unwrap()), // should not fail on supported targets, given the checks above &root_key_auth, &[], &[], @@ -83,7 +83,7 @@ impl TransientObjectContext { if auth_size > 32 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } - if key_size < 1024 { + if key_size < 1024 || key_size > 4096 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } let key_auth = if auth_size > 0 { @@ -95,7 +95,7 @@ impl TransientObjectContext { self.set_session_attrs()?; let (key_priv, key_pub) = self.context.create_key( self.root_key_handle, - &get_rsa_public(false, false, true, key_size.try_into().unwrap()), + &get_rsa_public(false, false, true, key_size.try_into().unwrap()), // should not fail on valid targets, given the checks above &key_auth, &[], &[], @@ -122,7 +122,7 @@ impl TransientObjectContext { let pk = TPMU_PUBLIC_ID { rsa: TPM2B_PUBLIC_KEY_RSA { - size: public_key.len().try_into().unwrap(), + size: public_key.len().try_into().unwrap(), // should not fail on valid targets, given the checks above buffer: pk_buffer, }, }; @@ -131,7 +131,7 @@ impl TransientObjectContext { false, false, true, - u16::try_from(public_key.len()).unwrap() * 8u16, + u16::try_from(public_key.len()).unwrap() * 8u16, // should not fail on valid targets, given the checks above ); public.publicArea.unique = pk; @@ -160,7 +160,7 @@ impl TransientObjectContext { let key = match PublicIdUnion::from_public(&key_pub_id) { PublicIdUnion::Rsa(pub_key) => { let mut key = pub_key.buffer.to_vec(); - key.truncate(pub_key.size.try_into().unwrap()); + key.truncate(pub_key.size.try_into().unwrap()); // should not fail on supported targets key } _ => unimplemented!(), @@ -257,7 +257,7 @@ mod tests { let mut ctx = TransientObjectContext::new(Tcti::Mssim, 2048, 32, &[]).unwrap(); for _ in 0..4 { let (key, auth) = ctx.create_rsa_signing_key(2048, 16).unwrap(); - let mut signature = ctx.sign(key.clone(), &auth, &HASH).unwrap(); + let signature = ctx.sign(key.clone(), &auth, &HASH).unwrap(); let pub_key = ctx.read_public_key(key.clone()).unwrap(); let pub_key = ctx.load_external_rsa_public_key(&pub_key).unwrap(); ctx.verify_signature(pub_key, &HASH, signature).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 79db97e3..bd0f3219 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,10 +54,13 @@ use utils::{Signature, TpmaSession, TpmsContext}; #[macro_use] macro_rules! wrap_buffer { ($buf:expr, $buf_type:ty, $buf_size:expr) => {{ + if $buf.len() > $buf_size { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } let mut buffer = [0u8; $buf_size]; buffer[..$buf.len()].clone_from_slice(&$buf[..$buf.len()]); let mut buf_struct: $buf_type = Default::default(); - buf_struct.size = $buf.len().try_into().unwrap(); + buf_struct.size = $buf.len().try_into().unwrap(); // should not fail since the length is checked above buf_struct.buffer = buffer; buf_struct }}; @@ -115,7 +118,7 @@ impl Context { let ret = unsafe { tss2_esys::Esys_Initialize( &mut esys_context, - tcti_context.as_mut().unwrap().as_mut_ptr(), + tcti_context.as_mut().unwrap().as_mut_ptr(), // will not panic as per how tcti_context is initialised null_mut(), ) }; @@ -162,10 +165,6 @@ impl Context { symmetric: TPMT_SYM_DEF, auth_hash: TPMI_ALG_HASH, ) -> Result { - if nonce.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let nonce_caller = wrap_buffer!(nonce, TPM2B_NONCE, 64); let mut sess = ESYS_TR_NONE; @@ -218,30 +217,26 @@ impl Context { outside_info: &[u8], creation_pcrs: &[TPMS_PCR_SELECTION], ) -> Result { - if auth_value.len() > 64 - || initial_data.len() > 256 - || outside_info.len() > 64 - || creation_pcrs.len() > 16 - { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let sensitive_create = TPM2B_SENSITIVE_CREATE { size: std::mem::size_of::() .try_into() - .unwrap(), + .unwrap(), // will not fail on targets of at least 16 bits sensitive: TPMS_SENSITIVE_CREATE { userAuth: wrap_buffer!(auth_value, TPM2B_AUTH, 64), data: wrap_buffer!(initial_data, TPM2B_SENSITIVE_DATA, 256), }, }; - let outside_info = wrap_buffer!(outside_info, TPM2B_DATA, 64); + + if creation_pcrs.len() > 16 { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } + let mut creation_pcrs_buffer = [Default::default(); 16]; creation_pcrs_buffer[..creation_pcrs.len()] .clone_from_slice(&creation_pcrs[..creation_pcrs.len()]); let creation_pcrs = TPML_PCR_SELECTION { - count: creation_pcrs.len().try_into().unwrap(), + count: creation_pcrs.len().try_into().unwrap(), // will not fail given the len checks above pcrSelections: creation_pcrs_buffer, }; @@ -297,18 +292,10 @@ impl Context { outside_info: &[u8], creation_pcrs: &[TPMS_PCR_SELECTION], ) -> Result<(TPM2B_PRIVATE, TPM2B_PUBLIC)> { - if auth_value.len() > 64 - || initial_data.len() > 256 - || outside_info.len() > 64 - || creation_pcrs.len() > 16 - { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let sensitive_create = TPM2B_SENSITIVE_CREATE { size: std::mem::size_of::() .try_into() - .unwrap(), + .unwrap(), // will not fail on targets of at least 16 bits sensitive: TPMS_SENSITIVE_CREATE { userAuth: wrap_buffer!(auth_value, TPM2B_AUTH, 64), data: wrap_buffer!(initial_data, TPM2B_SENSITIVE_DATA, 256), @@ -317,11 +304,14 @@ impl Context { let outside_info = wrap_buffer!(outside_info, TPM2B_DATA, 64); + if creation_pcrs.len() > 16 { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } let mut creation_pcrs_buffer = [Default::default(); 16]; creation_pcrs_buffer[..creation_pcrs.len()] .clone_from_slice(&creation_pcrs[..creation_pcrs.len()]); let creation_pcrs = TPML_PCR_SELECTION { - count: creation_pcrs.len().try_into().unwrap(), + count: creation_pcrs.len().try_into().unwrap(), // will not fail given the len checks above pcrSelections: creation_pcrs_buffer, }; @@ -403,9 +393,6 @@ impl Context { scheme: TPMT_SIG_SCHEME, validation: &TPMT_TK_HASHCHECK, ) -> Result { - if digest.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } let mut signature = null_mut(); let digest = wrap_buffer!(digest, TPM2B_DIGEST, 64); let ret = unsafe { @@ -438,9 +425,6 @@ impl Context { digest: &[u8], signature: &TPMT_SIGNATURE, ) -> Result { - if digest.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } let mut validation = null_mut(); let digest = wrap_buffer!(digest, TPM2B_DIGEST, 64); let ret = unsafe { @@ -577,7 +561,7 @@ impl Context { let ret = Error::from_tss_rc(ret); if ret.is_success() { let context = unsafe { MBox::::from_raw(context) }; - Ok((*context).into()) + Ok((*context).try_into()?) } else { error!("Error in saving context: {}.", ret); Err(ret) @@ -612,7 +596,9 @@ impl Context { self.sessions.0, self.sessions.1, self.sessions.2, - num_bytes.try_into().unwrap(), + num_bytes + .try_into() + .or_else(|_| Err(Error::local_error(ErrorKind::WrongParamSize)))?, &mut buffer, ) }; @@ -621,7 +607,7 @@ impl Context { if ret.is_success() { let buffer = unsafe { MBox::from_raw(buffer) }; let mut random = buffer.buffer.to_vec(); - random.truncate(buffer.size.try_into().unwrap()); + random.truncate(buffer.size.try_into().unwrap()); // should not panic given the TryInto above Ok(random) } else { error!("Error in flushing context: {}.", ret); @@ -630,10 +616,6 @@ impl Context { } pub fn set_handle_auth(&mut self, handle: ESYS_TR, auth_value: &[u8]) -> Result<()> { - if auth_value.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let auth = wrap_buffer!(auth_value, TPM2B_AUTH, 64); let ret = unsafe { Esys_TR_SetAuth(self.mut_context(), handle, &auth) }; let ret = Error::from_tss_rc(ret); @@ -657,7 +639,7 @@ impl Context { } fn mut_context(&mut self) -> *mut ESYS_CONTEXT { - self.esys_context.as_mut().unwrap().as_mut_ptr() + self.esys_context.as_mut().unwrap().as_mut_ptr() // will only fail if called from Drop after .take() } } @@ -673,8 +655,8 @@ impl Drop for Context { } }); - let esys_context = self.esys_context.take().unwrap(); - let tcti_context = self.tcti_context.take().unwrap(); + let esys_context = self.esys_context.take().unwrap(); // should not fail based on how the context is initialised/used + let tcti_context = self.tcti_context.take().unwrap(); // should not fail based on how the context is initialised/used // Close the TCTI context. unsafe { diff --git a/src/response_code.rs b/src/response_code.rs index d8658c13..fe3df75b 100644 --- a/src/response_code.rs +++ b/src/response_code.rs @@ -328,7 +328,7 @@ impl std::fmt::Display for Tss2ResponseCode { if kind.is_none() { return write!(f, "response code not recognized"); } - match self.kind().unwrap() { + match self.kind().unwrap() { // should not panic, given the check above Tss2ResponseCodeKind::Success => write!(f, "success"), Tss2ResponseCodeKind::TpmVendorSpecific => write!(f, "vendor specific error: {}", self.error_number()), // Format Zero @@ -475,6 +475,8 @@ impl std::fmt::Display for Error { #[derive(Copy, Clone, PartialEq, Debug)] pub enum WrapperErrorKind { WrongParamSize, + ParamsMissing, + InconsistentParams, } impl std::fmt::Display for WrapperErrorKind { @@ -483,6 +485,13 @@ impl std::fmt::Display for WrapperErrorKind { WrapperErrorKind::WrongParamSize => { write!(f, "parameter provided is of the wrong size") } + WrapperErrorKind::ParamsMissing => { + write!(f, "some of the required parameters were not provided") + } + WrapperErrorKind::InconsistentParams => write!( + f, + "the provided parameters have inconsistent values or variants" + ), } } } diff --git a/src/utils.rs b/src/utils.rs index 964e2e28..445c1dfb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -71,7 +71,7 @@ impl Tpm2BPublicBuilder { self } - pub fn build(mut self) -> TPM2B_PUBLIC { + pub fn build(mut self) -> Result { match self.type_ { Some(TPM2_ALG_RSA) => { // RSA key @@ -80,9 +80,9 @@ impl Tpm2BPublicBuilder { if let Some(PublicParmsUnion::RsaDetail(parms)) = self.parameters { parameters = TPMU_PUBLIC_PARMS { rsaDetail: parms }; } else if self.parameters.is_none() { - panic!("No key parameters provided"); + return Err(Error::local_error(WrapperErrorKind::ParamsMissing)); } else { - panic!("Wrong parameter type provided"); + return Err(Error::local_error(WrapperErrorKind::InconsistentParams)); } if let Some(PublicIdUnion::Rsa(rsa_unique)) = self.unique { @@ -90,7 +90,7 @@ impl Tpm2BPublicBuilder { } else if self.unique.is_none() { unique = Default::default(); } else { - panic!("Wrong unique type provided"); + return Err(Error::local_error(WrapperErrorKind::InconsistentParams)); } if self.object_attributes.sign_encrypt() && self.object_attributes.decrypt() { @@ -111,19 +111,21 @@ impl Tpm2BPublicBuilder { parameters.rsaDetail.symmetric.algorithm = TPM2_ALG_NULL; } - TPM2B_PUBLIC { + Ok(TPM2B_PUBLIC { size: std::mem::size_of::() .try_into() - .expect("Failed to convert usize to u16"), + .expect("Failed to convert usize to u16"), // should not fail on valid targets publicArea: TPMT_PUBLIC { - type_: self.type_.expect("Object type not provided"), + type_: self + .type_ + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))?, nameAlg: self.name_alg, objectAttributes: self.object_attributes.0, authPolicy: self.auth_policy, parameters, unique, }, - } + }) } _ => unimplemented!(), } @@ -177,16 +179,16 @@ impl TpmsRsaParmsBuilder { self } - pub fn build(self) -> TPMS_RSA_PARMS { - TPMS_RSA_PARMS { + pub fn build(self) -> Result { + Ok(TPMS_RSA_PARMS { symmetric: self.symmetric, scheme: self .scheme - .expect("Scheme was not provided") + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))? .get_rsa_scheme(), keyBits: self.key_bits, exponent: self.exponent, - } + }) } } @@ -225,7 +227,7 @@ impl TpmtSymDefBuilder { self } - pub fn build_object(self) -> TPMT_SYM_DEF_OBJECT { + pub fn build_object(self) -> Result { let key_bits; let mode; match self.algorithm { @@ -263,11 +265,13 @@ impl TpmtSymDefBuilder { _ => unimplemented!(), } - TPMT_SYM_DEF_OBJECT { - algorithm: self.algorithm.expect("No algorithm provided"), + Ok(TPMT_SYM_DEF_OBJECT { + algorithm: self + .algorithm + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))?, keyBits: key_bits, mode, - } + }) } pub fn aes_256_cfb() -> TPMT_SYM_DEF { @@ -469,7 +473,7 @@ impl TryFrom for TPMT_SIGNATURE { rsassa: TPMS_SIGNATURE_RSA { hash: hash_alg, sig: TPM2B_PUBLIC_KEY_RSA { - size: len.try_into().expect("Failed to convert length to u16"), // Should never panic + size: len.try_into().expect("Failed to convert length to u16"), // Should never panic as per the check above buffer, }, }, @@ -517,18 +521,24 @@ pub struct TpmsContext { context_blob: Vec, } -impl From for TpmsContext { - fn from(tss2_context: TPMS_CONTEXT) -> Self { +impl TryFrom for TpmsContext { + type Error = Error; + + fn try_from(tss2_context: TPMS_CONTEXT) -> Result { let mut context = TpmsContext { sequence: tss2_context.sequence, saved_handle: tss2_context.savedHandle, hierarchy: tss2_context.hierarchy, context_blob: tss2_context.contextBlob.buffer.to_vec(), }; - context - .context_blob - .truncate(tss2_context.contextBlob.size.try_into().unwrap()); - context + context.context_blob.truncate( + tss2_context + .contextBlob + .size + .try_into() + .or_else(|_| Err(Error::local_error(WrapperErrorKind::WrongParamSize)))?, + ); + Ok(context) } } @@ -538,7 +548,7 @@ impl TryFrom for TPMS_CONTEXT { fn try_from(context: TpmsContext) -> Result { let buffer_size = context.context_blob.len(); if buffer_size > 5188 { - return Err(Error::from_tss_rc(TPM2_RC_SIZE)); + return Err(Error::local_error(WrapperErrorKind::WrongParamSize)); } let mut buffer = [0u8; 5188]; for (i, val) in context.context_blob.into_iter().enumerate() { @@ -549,7 +559,7 @@ impl TryFrom for TPMS_CONTEXT { savedHandle: context.saved_handle, hierarchy: context.hierarchy, contextBlob: TPM2B_CONTEXT_DATA { - size: buffer_size.try_into().unwrap(), + size: buffer_size.try_into().unwrap(), // should not panic given the check above buffer, }, }) @@ -563,7 +573,8 @@ pub fn get_rsa_public(restricted: bool, decrypt: bool, sign: bool, key_bits: u16 .with_symmetric(symmetric) .with_key_bits(key_bits) .with_scheme(scheme) - .build(); + .build() + .unwrap(); // should not fail as we control the params let mut object_attributes = ObjectAttributes(0); object_attributes.set_fixed_tpm(true); object_attributes.set_fixed_parent(true); @@ -579,4 +590,5 @@ pub fn get_rsa_public(restricted: bool, decrypt: bool, sign: bool, key_bits: u16 .with_object_attributes(object_attributes) .with_parms(PublicParmsUnion::RsaDetail(rsa_parms)) .build() + .unwrap() // should not fail as we control the params }