diff --git a/Cargo.toml b/Cargo.toml index 8c9a3b5ca..9151d931c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,11 @@ harness = false [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +sve_backend = [] [dependencies] blake3 = { version = "1.3", default-features = false } +cfg-if = "1.0.0" winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } @@ -38,3 +40,6 @@ winter_utils = { version = "0.6", package = "winter-utils", default-features = f criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.1.0" rand_utils = { version = "0.6", package = "winter-rand-utils" } + +[build-dependencies] +cc = "1.0.79" diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..662eca343 --- /dev/null +++ b/build.rs @@ -0,0 +1,18 @@ +fn main() { + #[cfg(feature = "sve_backend")] + compile_sve_backend(); +} + +#[cfg(feature = "sve_backend")] +fn compile_sve_backend() { + println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.c"); + println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.h"); + println!("cargo:rerun-if-changed=sve/src/inv_sbox.h"); + println!("cargo:rerun-if-changed=sve/src/inv_sbox.c"); + + cc::Build::new() + .file("sve/src/sve_inv_sbox.c") + .file("sve/src/inv_sbox.c") + .flag("-march=armv8-a+sve") + .compile("sve"); +} diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index 95f2c979d..77ecb8a8b 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -1,5 +1,7 @@ use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO}; use core::{convert::TryInto, ops::Range}; +#[cfg(feature = "sve_backend")] +use winter_math::fields::f64::BaseElement; mod
digest; pub use digest::RpoDigest; @@ -10,6 +12,12 @@ use mds_freq::mds_multiply_freq; #[cfg(test)] mod tests; +#[cfg(feature = "sve_backend")] +#[link(name = "sve", kind = "static")] +extern "C" { + fn sve_apply_inv_sbox(state: *mut std::ffi::c_ulong); +} + // CONSTANTS // ================================================================================================ @@ -351,7 +359,15 @@ impl Rpo256 { // apply second half of RPO round Self::apply_mds(state); Self::add_constants(state, &ARK2[round]); - Self::apply_inv_sbox(state); + cfg_if::cfg_if! { + if #[cfg(feature = "sve_backend")] { + unsafe { + sve_apply_inv_sbox(std::mem::transmute::<*mut BaseElement, *mut u64>(state.as_mut_ptr())); + } + } else { + Self::apply_inv_sbox(state); + } + } } // HELPER FUNCTIONS diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index d0f68890b..b82760cfa 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -30,14 +30,57 @@ fn test_sbox() { assert_eq!(expected, actual); } +#[cfg(feature = "sve_backend")] +#[link(name = "sve", kind = "static")] +extern "C" { + fn apply_inv_sbox_c(state: *mut std::ffi::c_ulong); + fn sve_apply_inv_sbox(state: *mut std::ffi::c_ulong); +} + #[test] fn test_inv_sbox() { - let state = [Felt::new(rand_value()); STATE_WIDTH]; + let state: [Felt; STATE_WIDTH] = [ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]; let mut expected = state; expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA)); let mut actual = state; + + cfg_if::cfg_if!
{ + if #[cfg(feature = "sve_backend")] { + let mut actual_c: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c[i] = actual[i].inner(); + } + + let mut actual_c_sve: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c_sve[i] = actual[i].inner(); + } + unsafe { + apply_inv_sbox_c(actual_c.as_mut_ptr()); + sve_apply_inv_sbox(actual_c_sve.as_mut_ptr()); + } + + let expected_as_u64_vec: Vec = expected.iter().map(|s| s.inner()).collect(); + assert_eq!(expected_as_u64_vec, actual_c); + assert_eq!(expected_as_u64_vec, actual_c_sve); + } + } + Rpo256::apply_inv_sbox(&mut actual); assert_eq!(expected, actual); diff --git a/sve/.clang-format b/sve/.clang-format new file mode 100644 index 000000000..80be9b5c1 --- /dev/null +++ b/sve/.clang-format @@ -0,0 +1,6 @@ +UseTab: ForIndentation +IndentWidth: 8 +BreakBeforeBraces: Allman +AllowShortIfStatementsOnASingleLine: false +IndentCaseLabels: false +ColumnLimit: 120 diff --git a/sve/Makefile b/sve/Makefile new file mode 100644 index 000000000..918079ad3 --- /dev/null +++ b/sve/Makefile @@ -0,0 +1,43 @@ +.PHONY: clean fmt + +TARGET=test_sve + +CC=cc +CFLAGS=-std=c99 -march=armv8-a+sve -Wall -Wextra -pedantic -g -O0 +LN_FLAGS= + +BUILD_DIR=./build +SRC_DIR=./src + +SOURCE = $(wildcard $(SRC_DIR)/*.c) +HEADERS = $(wildcard $(SRC_DIR)/*.h) +OBJECTS = $(patsubst $(SRC_DIR)/%.c, $(BUILD_DIR)/%.o, $(SOURCE)) + +# Gcc/Clang will create these .d files containing dependencies. +DEP = $(OBJECTS:%.o=%.d) + +default: $(TARGET) + +$(TARGET): $(BUILD_DIR)/$(TARGET) + +$(BUILD_DIR)/$(TARGET): $(OBJECTS) + $(CC) $(CFLAGS) $(LN_FLAGS) $^ -o $@ + +-include $(DEP) + +# The potential dependency on header files is covered +# by calling `-include $(DEP)`. +# The -MMD flags additionaly creates a .d file with +# the same name as the .o file. 
+$(BUILD_DIR)/%.o: $(SRC_DIR)/%.c + @mkdir -p $(@D) + $(CC) $(CFLAGS) -MMD -c $< -o $@ + +test: $(TARGET) + $(BUILD_DIR)/$(TARGET) + +clean: + rm -rf $(BUILD_DIR) + +fmt: + clang-format --style=file -i $(SOURCE) $(HEADERS) diff --git a/sve/src/inv_sbox.c b/sve/src/inv_sbox.c new file mode 100644 index 000000000..eb58f0311 --- /dev/null +++ b/sve/src/inv_sbox.c @@ -0,0 +1,253 @@ +#include "inv_sbox.h" +#include +#include +#include +#include + +bool will_sum_overflow(uint64_t a, uint64_t b) +{ + if ((UINT64_MAX - a) < b) + { + return true; + } + + return false; +} + +bool will_sub_overflow(uint64_t a, uint64_t b) { return a < b; } + +// /// Montgomery reduction (constant time) +// #[inline(always)] +// const fn mont_red_cst(x: u128) ->u64 +// { +// // See reference above for a description of the following implementation. +// let xl = x as u64; +// let xh = (x >> 64) as u64; +// let(a, e) = xl.overflowing_add(xl << 32); + +// let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); + +// let(r, c) = xh.overflowing_sub(b); +// r.wrapping_sub(0u32.wrapping_sub(c as u32)as u64) +// } +uint64_t mont_red_cst(__uint128_t x) +{ + uint64_t xl = (uint64_t)x; + uint64_t xh = x >> 64; + + bool e = will_sum_overflow(xl, xl << 32); + uint64_t a = xl + (xl << 32); + + uint64_t b = (a - (a >> 32)) - e; + + bool c = will_sub_overflow(xh, b); + uint64_t r = xh - b; + + return r - (uint64_t)((uint32_t)0 - (uint32_t)c); +} + +// #[inline] +// fn mul(self, rhs: Self) -> Self { +// Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128))) +// } +uint64_t multiply_montgomery_form_felts(uint64_t a, uint64_t b) +{ + __uint128_t a_casted = (__uint128_t)a; + __uint128_t b_casted = (__uint128_t)b; + + return mont_red_cst(a_casted * b_casted); +} + +uint64_t square(uint64_t a) { return multiply_montgomery_form_felts(a, a); } + +// #[inline(always)] +// fn exp_acc( +// base: [B; N], +// tail: [B; N], +// ) -> [B; N] { +// let mut result = base; +// for _ in 0..M { +// result.iter_mut().for_each(|r| *r
= r.square()); +// } +// result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); +// result +// } +void exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 12; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. 
+ for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 31; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +// #[inline(always)] +// fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { +// // compute base^10540996611094048183 using 72 multiplications per array element +// // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 +// // compute base^10 + +// let mut t1 = *state; + +// t1.iter_mut().for_each(|t| *t = t.square()); + +// // compute base^100 +// let mut t2 = t1; + +// t2.iter_mut().for_each(|t| *t = t.square()); +// // compute base^100100 + +// let t3 = Self::exp_acc::(t2, t2); +// // compute base^100100100100 + +// let t4 = Self::exp_acc::(t3, t3); +// // compute base^100100100100100100100100 + +// let t5 = Self::exp_acc::(t4, t4); +// // compute base^100100100100100100100100100100 + +// let t6 = Self::exp_acc::(t5, t3); +// // compute base^1001001001001001001001001001000100100100100100100100100100100 + +// let t7 = Self::exp_acc::(t6, t6); +// // compute base^1001001001001001001001001001000110110110110110110110110110110111 + +// for (i, s) in state.iter_mut().enumerate() { +// let a = (t7[i].square() * t6[i]).square().square(); +// let b = t1[i] * t2[i] * *s; +// *s = a * b; +// } +// } +void apply_inv_sbox_c(uint64_t state[STATE_WIDTH]) +{ + uint64_t t1[STATE_WIDTH]; + + // Square each element of state, call it t1 + for (int j = 0; j < STATE_WIDTH; 
j++) + { + t1[j] = square(state[j]); + } + + uint64_t t2[STATE_WIDTH]; + + // Square each element of t1, call it t2 + for (int j = 0; j < STATE_WIDTH; j++) + { + t2[j] = square(t1[j]); + } + + // Call exp_acc_3(t2, t2), call it t3 + uint64_t t3[STATE_WIDTH]; + exp_acc_3(t2, t2, t3); + + // Call exp_acc_6(t3, t3), call it t4 + uint64_t t4[STATE_WIDTH]; + exp_acc_6(t3, t3, t4); + + // Call exp_acc_12(t4, t4), call it t5 + uint64_t t5[STATE_WIDTH]; + exp_acc_12(t4, t4, t5); + + // Call exp_acc_6(t5, t3), call it t6 + uint64_t t6[STATE_WIDTH]; + exp_acc_6(t5, t3, t6); + + // Call exp_acc_31(t6, t6), call it t7 + uint64_t t7[STATE_WIDTH]; + exp_acc_31(t6, t6, t7); + + for (int i = 0; i < STATE_WIDTH; i++) + { + uint64_t a = square(square((multiply_montgomery_form_felts((square(t7[i])), t6[i])))); + uint64_t b = multiply_montgomery_form_felts(multiply_montgomery_form_felts(t1[i], t2[i]), state[i]); + + state[i] = multiply_montgomery_form_felts(a, b); + } +} diff --git a/sve/src/inv_sbox.h b/sve/src/inv_sbox.h new file mode 100644 index 000000000..43feb36ea --- /dev/null +++ b/sve/src/inv_sbox.h @@ -0,0 +1 @@ +#define STATE_WIDTH 12 diff --git a/sve/src/sve_inv_sbox.c b/sve/src/sve_inv_sbox.c new file mode 100644 index 000000000..c4fc502c8 --- /dev/null +++ b/sve/src/sve_inv_sbox.c @@ -0,0 +1,318 @@ +#include "sve_inv_sbox.h" +#include +#include +#include +#include + +#define ZERO_ARRAY \ + { \ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 \ + } + +const uint64_t ONES[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +const uint64_t ZEROES[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +const uint64_t THIRTY_TWOS[STATE_WIDTH] = {32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}; + +inline void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void 
sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) + __attribute__((always_inline)); +inline void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) + __attribute__((always_inline)); +inline void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], + uint64_t *result) __attribute__((always_inline)); +inline void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) __attribute__((always_inline)); +inline void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +inline void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); + +void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svlsl_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = 
svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svlsr_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + svbool_t addition_overflowed; + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svuint64_t addition_result = svadd_z(pg, x_vec, y_vec); + svst1(pg, &result[i], addition_result); + + svuint64_t one_vec = svld1(pg, &ONES[i]); + + addition_overflowed = svcmplt(pg, addition_result, svmax_z(pg, x_vec, y_vec)); + svst1(addition_overflowed, &overflowed[i], one_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + svbool_t substraction_underflowed; + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svsub_z(pg, x_vec, y_vec)); + + svuint64_t one_vec = svld1(pg, &ONES[i]); + + substraction_underflowed = svcmplt_u64(pg, x_vec, y_vec); + svst1(substraction_underflowed, &underflowed[i], one_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + 
int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint32_t x_vec = svld1(pg, (uint32_t *)&x[i]); + svuint32_t y_vec = svld1(pg, (uint32_t *)&y[i]); + svst1(pg, (uint32_t *)&result[i], svsub_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svmul_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svmulh_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) +{ + uint64_t e[STATE_WIDTH] = ZERO_ARRAY; + uint64_t a[STATE_WIDTH]; + uint64_t x_shifted[STATE_WIDTH]; + + sve_shift_left(x, THIRTY_TWOS, x_shifted); + sve_add(x, x_shifted, a, e); + + uint64_t a_shifted[STATE_WIDTH]; + sve_shift_right(a, THIRTY_TWOS, a_shifted); + + uint64_t b[STATE_WIDTH]; + uint64_t _unused[STATE_WIDTH]; + sve_substract(a, a_shifted, b, _unused); + sve_substract(b, e, b, _unused); + + uint64_t r[STATE_WIDTH]; + uint64_t c[STATE_WIDTH] = ZERO_ARRAY; + + sve_substract(y, b, r, c); + + uint64_t minus_c[STATE_WIDTH] = ZERO_ARRAY; + sve_substract_as_u32(ZEROES, c, minus_c); + + sve_substract(r, minus_c, result, _unused); +} + +void 
sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], uint64_t *result) +{ + uint64_t low[STATE_WIDTH]; + uint64_t high[STATE_WIDTH]; + + sve_multiply_low(a, b, low); + sve_multiply_high(a, b, high); + + sve_mont_red_cst(low, high, result); +} + +void sve_square(uint64_t *a) { sve_multiply_montgomery_form_felts(a, a, a); } + +void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t a_vec = svld1(pg, &a[i]); + svst1(pg, ©[i], a_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 3; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 6; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 12; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 31; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_apply_inv_sbox(uint64_t state[STATE_WIDTH]) +{ + 
uint64_t t1[STATE_WIDTH]; + sve_copy(state, t1); + + sve_square(t1); + + uint64_t t2[STATE_WIDTH]; + sve_copy(t1, t2); + + sve_square(t2); + + uint64_t t3[STATE_WIDTH]; + sve_exp_acc_3(t2, t2, t3); + + uint64_t t4[STATE_WIDTH]; + sve_exp_acc_6(t3, t3, t4); + + uint64_t t5[STATE_WIDTH]; + sve_exp_acc_12(t4, t4, t5); + + uint64_t t6[STATE_WIDTH]; + sve_exp_acc_6(t5, t3, t6); + + uint64_t t7[STATE_WIDTH]; + sve_exp_acc_31(t6, t6, t7); + + sve_square(t7); + uint64_t a[STATE_WIDTH]; + sve_multiply_montgomery_form_felts(t7, t6, a); + sve_square(a); + sve_square(a); + + uint64_t b[STATE_WIDTH]; + sve_multiply_montgomery_form_felts(t1, t2, b); + sve_multiply_montgomery_form_felts(b, state, b); + + sve_multiply_montgomery_form_felts(a, b, state); +} + +void print_array(size_t len, uint64_t arr[len]) +{ + printf("["); + for (size_t i = 0; i < len; i++) + { + printf("%lu ", arr[i]); + } + + printf("]\n"); +} diff --git a/sve/src/sve_inv_sbox.h b/sve/src/sve_inv_sbox.h new file mode 100644 index 000000000..305fe5e3c --- /dev/null +++ b/sve/src/sve_inv_sbox.h @@ -0,0 +1,15 @@ +#include +#include +#ifdef __ARM_FEATURE_SVE +#include +#endif /* __ARM_FEATURE_SVE */ + +#define STATE_WIDTH 12 + +void print_array(size_t len, uint64_t arr[len]); +void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed); +void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed); +void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_apply_inv_sbox(uint64_t state[STATE_WIDTH]); diff --git a/sve/src/test.c b/sve/src/test.c new file mode 100644 index 000000000..3eb86c6ad --- /dev/null +++ b/sve/src/test.c @@ -0,0 +1,108 @@ +#include "sve_inv_sbox.h" 
+#include + +void test_sve_shift_left(); +void test_sve_shift_right(); +void test_sve_add(); +void test_sve_substract(); + +int main() +{ + test_sve_shift_left(); + test_sve_shift_right(); + test_sve_add(); + test_sve_substract(); + + return 0; +} + +void assert_array_equality(const uint64_t result[STATE_WIDTH], const uint64_t expected[STATE_WIDTH]) +{ + for (int i = 0; i < STATE_WIDTH; i++) + { + assert(result[i] == expected[i]); + } +} + +void test_sve_shift_left() +{ + uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_left(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected[STATE_WIDTH] = { + 0, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, + }; + assert_array_equality(result, expected); +} + +void test_sve_shift_right() +{ + uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_right(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected[STATE_WIDTH] = { + 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, + }; + assert_array_equality(result, expected); +} + +void test_sve_add() +{ + uint64_t x[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t overflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_add(x, y, result, overflowed); + print_array(STATE_WIDTH, result); + print_array(STATE_WIDTH, overflowed); + + uint64_t expected_result[STATE_WIDTH] = { + 0, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + }; + uint64_t expected_overflowed[STATE_WIDTH] = { + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + 
assert_array_equality(result, expected_result); + assert_array_equality(overflowed, expected_overflowed); +} + +void test_sve_substract() +{ + uint64_t x[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t underflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract(x, y, result, underflowed); + print_array(STATE_WIDTH, result); + print_array(STATE_WIDTH, underflowed); + + uint64_t expected_result[STATE_WIDTH] = { + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + }; + uint64_t expected_underflowed[STATE_WIDTH] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }; + assert_array_equality(result, expected_result); + assert_array_equality(underflowed, expected_underflowed); +} + +void test_sve_substract_as_u32() +{ + uint64_t x[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract_as_u32(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected_result[STATE_WIDTH] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }; + assert_array_equality(result, expected_result); +}