From 9b92b3e87a4c432e0a35b562683eac925e10740f Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 10:40:37 -0300 Subject: [PATCH 01/12] Basic SVE implementation --- Cargo.toml | 3 + build.rs | 8 + c_code/.clang-format | 6 + c_code/Makefile | 43 ++++ c_code/src/main.c | 46 ++++ c_code/src/test_sve.c | 538 ++++++++++++++++++++++++++++++++++++++++++ c_code/src/test_sve.h | 15 ++ src/hash/rpo/mod.rs | 15 +- src/hash/rpo/tests.rs | 37 ++- 9 files changed, 709 insertions(+), 2 deletions(-) create mode 100644 build.rs create mode 100644 c_code/.clang-format create mode 100644 c_code/Makefile create mode 100644 c_code/src/main.c create mode 100644 c_code/src/test_sve.c create mode 100644 c_code/src/test_sve.h diff --git a/Cargo.toml b/Cargo.toml index 8c9a3b5ca..e72324bc7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,6 @@ winter_utils = { version = "0.6", package = "winter-utils", default-features = f criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.1.0" rand_utils = { version = "0.6", package = "winter-rand-utils" } + +[build-dependencies] +cc = "1.0.79" diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..e27219088 --- /dev/null +++ b/build.rs @@ -0,0 +1,8 @@ +fn main() { + println!("cargo:rerun-if-changed=c_code/src/test_sve.c"); + println!("cargo:rerun-if-changed=c_code/src/test_sve.h"); + cc::Build::new() + .file("c_code/src/test_sve.c") + .flag("-march=armv8-a+sve") + .compile("sve"); +} diff --git a/c_code/.clang-format b/c_code/.clang-format new file mode 100644 index 000000000..80be9b5c1 --- /dev/null +++ b/c_code/.clang-format @@ -0,0 +1,6 @@ +UseTab: ForIndentation +IndentWidth: 8 +BreakBeforeBraces: Allman +AllowShortIfStatementsOnASingleLine: false +IndentCaseLabels: false +ColumnLimit: 120 diff --git a/c_code/Makefile b/c_code/Makefile new file mode 100644 index 000000000..918079ad3 --- /dev/null +++ b/c_code/Makefile @@ -0,0 +1,43 @@ +.PHONY: clean fmt + +TARGET=test_sve + +CC=cc 
+CFLAGS=-std=c99 -march=armv8-a+sve -Wall -Wextra -pedantic -g -O0 +LN_FLAGS= + +BUILD_DIR=./build +SRC_DIR=./src + +SOURCE = $(wildcard $(SRC_DIR)/*.c) +HEADERS = $(wildcard $(SRC_DIR)/*.h) +OBJECTS = $(patsubst $(SRC_DIR)/%.c, $(BUILD_DIR)/%.o, $(SOURCE)) + +# Gcc/Clang will create these .d files containing dependencies. +DEP = $(OBJECTS:%.o=%.d) + +default: $(TARGET) + +$(TARGET): $(BUILD_DIR)/$(TARGET) + +$(BUILD_DIR)/$(TARGET): $(OBJECTS) + $(CC) $(CFLAGS) $(LN_FLAGS) $^ -o $@ + +-include $(DEP) + +# The potential dependency on header files is covered +# by calling `-include $(DEP)`. +# The -MMD flag additionally creates a .d file with +# the same name as the .o file. +$(BUILD_DIR)/%.o: $(SRC_DIR)/%.c + @mkdir -p $(@D) + $(CC) $(CFLAGS) -MMD -c $< -o $@ + +test: $(TARGET) + $(BUILD_DIR)/$(TARGET) + +clean: + rm -rf $(BUILD_DIR) + +fmt: + clang-format --style=file -i $(SOURCE) $(HEADERS) diff --git a/c_code/src/main.c b/c_code/src/main.c new file mode 100644 index 000000000..b28595742 --- /dev/null +++ b/c_code/src/main.c @@ -0,0 +1,46 @@ +#include "test_sve.h" + +int main() +{ + // TEST SHIFT LEFT + uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_left(x, y, result); + print_array(STATE_WIDTH, result); + + // TEST SHIFT RIGHT + uint64_t x_1[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y_1[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result_1[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_right(x_1, y_1, result_1); + print_array(STATE_WIDTH, result_1); + + // TEST ADD + uint64_t x_2[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y_2[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result_2[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t
overflowed_2[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_add(x_2, y_2, result_2, overflowed_2); + print_array(STATE_WIDTH, result_2); + print_array(STATE_WIDTH, overflowed_2); + + // TEST SUBSTRACT + uint64_t x_3[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y_3[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + uint64_t result_3[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t overflowed_3[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract(x_3, y_3, result_3, overflowed_3); + print_array(STATE_WIDTH, result_3); + print_array(STATE_WIDTH, overflowed_3); + + // TEST SUBSTRACT AS u32 + uint64_t x_4[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y_4[STATE_WIDTH] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + uint64_t result_4[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract_as_u32(x_4, y_4, result_4); + print_array(STATE_WIDTH, result_4); + + return 0; +} diff --git a/c_code/src/test_sve.c b/c_code/src/test_sve.c new file mode 100644 index 000000000..3ad63ead2 --- /dev/null +++ b/c_code/src/test_sve.c @@ -0,0 +1,538 @@ +#include "test_sve.h" +#include +#include +#include +#include + +#define ZERO_ARRAY \ + { \ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 \ + } + +const uint64_t ONES[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +const uint64_t ZEROES[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +const uint64_t THIRTY_TWOS[STATE_WIDTH] = {32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}; + +bool will_sum_overflow(uint64_t a, uint64_t b) +{ + if ((UINT_MAX - a) < b) + { + return true; + } + + return false; +} + +bool will_sub_overflow(uint64_t a, uint64_t b) { return a < b; } + +void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + 
do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svlsl_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svlsr_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + svbool_t addition_overflowed; + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svuint64_t addition_result = svadd_z(pg, x_vec, y_vec); + svst1(pg, &result[i], addition_result); + + svuint64_t one_vec = svld1(pg, &ONES[i]); + + addition_overflowed = svcmplt(pg, addition_result, svmax_z(pg, x_vec, y_vec)); + svst1(addition_overflowed, &overflowed[i], one_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + svbool_t substraction_underflowed; + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svsub_z(pg, x_vec, y_vec)); + + svuint64_t one_vec = svld1(pg, &ONES[i]); + + substraction_underflowed = svcmplt_u64(pg, x_vec, y_vec); + svst1(substraction_underflowed, &underflowed[i], one_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // 
[1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint32_t x_vec = svld1(pg, (uint32_t *)&x[i]); + svuint32_t y_vec = svld1(pg, (uint32_t *)&y[i]); + svst1(pg, (uint32_t *)&result[i], svsub_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svmul_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t x_vec = svld1(pg, &x[i]); + svuint64_t y_vec = svld1(pg, &y[i]); + svst1(pg, &result[i], svmulh_z(pg, x_vec, y_vec)); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) +{ + uint64_t e[STATE_WIDTH] = ZERO_ARRAY; + uint64_t a[STATE_WIDTH]; + uint64_t x_shifted[STATE_WIDTH]; + + sve_shift_left(x, THIRTY_TWOS, x_shifted); + sve_add(x, x_shifted, a, e); + + uint64_t a_shifted[STATE_WIDTH]; + sve_shift_right(a, THIRTY_TWOS, a_shifted); + + uint64_t b[STATE_WIDTH]; + uint64_t _unused[STATE_WIDTH]; + sve_substract(a, a_shifted, b, _unused); + sve_substract(b, e, b, _unused); + + uint64_t r[STATE_WIDTH]; + uint64_t c[STATE_WIDTH] = ZERO_ARRAY; + + 
sve_substract(y, b, r, c); + + uint64_t minus_c[STATE_WIDTH] = ZERO_ARRAY; + sve_substract_as_u32(ZEROES, c, minus_c); + + sve_substract(r, minus_c, result, _unused); +} + +void sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], uint64_t *result) +{ + uint64_t low[STATE_WIDTH]; + uint64_t high[STATE_WIDTH]; + + sve_multiply_low(a, b, low); + sve_multiply_high(a, b, high); + + sve_mont_red_cst(low, high, result); +} + +void sve_square(uint64_t *a) { sve_multiply_montgomery_form_felts(a, a, a); } + +void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) +{ + int64_t i = 0; + svbool_t pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); + do + { + svuint64_t a_vec = svld1(pg, &a[i]); + svst1(pg, &copy[i], a_vec); + + i += svcntd(); + pg = svwhilelt_b64(i, (int64_t)STATE_WIDTH); // [1] + } while (svptest_any(svptrue_b64(), pg)); +} + +void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 3; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 6; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i = 0; i < 12; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + sve_copy(base, result); + + // Square each element of `result` M number of times + for (int i
= 0; i < 31; i++) + { + sve_square(result); + } + + sve_multiply_montgomery_form_felts(result, tail, result); +} + +void sve_apply_inv_sbox(uint64_t state[STATE_WIDTH]) +{ + uint64_t t1[STATE_WIDTH]; + sve_copy(state, t1); + + sve_square(t1); + + uint64_t t2[STATE_WIDTH]; + sve_copy(t1, t2); + + sve_square(t2); + + uint64_t t3[STATE_WIDTH]; + sve_exp_acc_3(t2, t2, t3); + + uint64_t t4[STATE_WIDTH]; + sve_exp_acc_6(t3, t3, t4); + + uint64_t t5[STATE_WIDTH]; + sve_exp_acc_12(t4, t4, t5); + + uint64_t t6[STATE_WIDTH]; + sve_exp_acc_6(t5, t3, t6); + + uint64_t t7[STATE_WIDTH]; + sve_exp_acc_31(t6, t6, t7); + + sve_square(t7); + uint64_t a[STATE_WIDTH]; + sve_multiply_montgomery_form_felts(t7, t6, a); + sve_square(a); + sve_square(a); + + uint64_t b[STATE_WIDTH]; + sve_multiply_montgomery_form_felts(t1, t2, b); + sve_multiply_montgomery_form_felts(b, state, b); + + sve_multiply_montgomery_form_felts(a, b, state); +} + +// /// Montgomery reduction (constant time) +// #[inline(always)] +// const fn mont_red_cst(x: u128) ->u64 +// { +// // See reference above for a description of the following implementation. 
+// let xl = x as u64; +// let xh = (x >> 64) as u64; +// let(a, e) = xl.overflowing_add(xl << 32); + +// let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); + +// let(r, c) = xh.overflowing_sub(b); +// r.wrapping_sub(0u32.wrapping_sub(c as u32)as u64) +// } +uint64_t mont_red_cst(__uint128_t x) +{ + uint64_t xl = (uint64_t)x; + uint64_t xh = x >> 64; + + bool e = will_sum_overflow(xl, xl << 32); + uint64_t a = xl + (xl << 32); + + uint64_t b = (a - (a >> 32)) - e; + + bool c = will_sub_overflow(xh, b); + uint64_t r = xh - b; + + return r - (uint64_t)((uint32_t)0 - (uint32_t)c); +} + +// #[inline] +// fn mul(self, rhs: Self) -> Self { +// Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128))) +// } +uint64_t multiply_montgomery_form_felts(uint64_t a, uint64_t b) +{ + __uint128_t a_casted = (__uint128_t)a; + __uint128_t b_casted = (__uint128_t)b; + + return mont_red_cst(a_casted * b_casted); +} + +uint64_t square(uint64_t a) { return multiply_montgomery_form_felts(a, a); } + +// #[inline(always)] +// fn exp_acc( +// base: [B; N], +// tail: [B; N], +// ) -> [B; N] { +// let mut result = base; +// for _ in 0..M { +// result.iter_mut().for_each(|r| *r = r.square()); +// } +// result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); +// result +// } +void exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. 
+ for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 12; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 31; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. 
+ for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +// #[inline(always)] +// fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { +// // compute base^10540996611094048183 using 72 multiplications per array element +// // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 +// // compute base^10 + +// let mut t1 = *state; + +// t1.iter_mut().for_each(|t| *t = t.square()); + +// // compute base^100 +// let mut t2 = t1; + +// t2.iter_mut().for_each(|t| *t = t.square()); +// // compute base^100100 + +// let t3 = Self::exp_acc::(t2, t2); +// // compute base^100100100100 + +// let t4 = Self::exp_acc::(t3, t3); +// // compute base^100100100100100100100100 + +// let t5 = Self::exp_acc::(t4, t4); +// // compute base^100100100100100100100100100100 + +// let t6 = Self::exp_acc::(t5, t3); +// // compute base^1001001001001001001001001001000100100100100100100100100100100 + +// let t7 = Self::exp_acc::(t6, t6); +// // compute base^1001001001001001001001001001000110110110110110110110110110110111 + +// for (i, s) in state.iter_mut().enumerate() { +// let a = (t7[i].square() * t6[i]).square().square(); +// let b = t1[i] * t2[i] * *s; +// *s = a * b; +// } +// } +void apply_inv_sbox_c(uint64_t state[STATE_WIDTH]) +{ + uint64_t t1[STATE_WIDTH]; + + // Square each element of state, call it t1 + for (int j = 0; j < STATE_WIDTH; j++) + { + t1[j] = square(state[j]); + } + + uint64_t t2[STATE_WIDTH]; + + // Square each element of t1, call it t2 + for (int j = 0; j < STATE_WIDTH; j++) + { + t2[j] = square(t1[j]); + } + + // Call exp_acc_3(t2, t2), call it t3 + uint64_t t3[STATE_WIDTH]; + exp_acc_3(t2, t2, t3); + + // Call exp_acc_6(t3, t3), call it t4 + uint64_t t4[STATE_WIDTH]; + exp_acc_6(t3, t3, t4); + + // Call exp_acc_12(t4, t4), call it t5 + uint64_t t5[STATE_WIDTH]; + exp_acc_12(t4, t4, t5); + + // Call exp_acc_6(t5, t3), call it t6 + uint64_t t6[STATE_WIDTH]; + exp_acc_6(t5, t3, 
t6); + + // Call exp_acc_31(t6, t6), call it t7 + uint64_t t7[STATE_WIDTH]; + exp_acc_31(t6, t6, t7); + + for (int i = 0; i < STATE_WIDTH; i++) + { + uint64_t a = square(square((multiply_montgomery_form_felts((square(t7[i])), t6[i])))); + uint64_t b = multiply_montgomery_form_felts(multiply_montgomery_form_felts(t1[i], t2[i]), state[i]); + + state[i] = multiply_montgomery_form_felts(a, b); + } +} + +void print_array(size_t len, uint64_t arr[len]) +{ + printf("["); + for (size_t i = 0; i < len; i++) + { + printf("%lu ", arr[i]); + } + + printf("]\n"); +} diff --git a/c_code/src/test_sve.h b/c_code/src/test_sve.h new file mode 100644 index 000000000..305fe5e3c --- /dev/null +++ b/c_code/src/test_sve.h @@ -0,0 +1,15 @@ +#include +#include +#ifdef __ARM_FEATURE_SVE +#include +#endif /* __ARM_FEATURE_SVE */ + +#define STATE_WIDTH 12 + +void print_array(size_t len, uint64_t arr[len]); +void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed); +void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed); +void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result); +void sve_apply_inv_sbox(uint64_t state[STATE_WIDTH]); diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index 95f2c979d..7161985ed 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -10,6 +10,12 @@ use mds_freq::mds_multiply_freq; #[cfg(test)] mod tests; +#[link(name = "sve", kind = "static")] +extern "C" { + fn apply_inv_sbox_c(state: *mut std::ffi::c_ulong); + fn sve_apply_inv_sbox(state: *mut std::ffi::c_ulong); +} + // CONSTANTS // ================================================================================================ @@ -351,7 +357,14 @@ impl Rpo256 { 
// apply second half of RPO round Self::apply_mds(state); Self::add_constants(state, &ARK2[round]); - Self::apply_inv_sbox(state); + // Self::apply_inv_sbox(state); + let mut state_inner: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + state_inner[i] = state[i].inner(); + } + unsafe { + sve_apply_inv_sbox(state_inner.as_mut_ptr()); + } } // HELPER FUNCTIONS diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index d0f68890b..832e913e3 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -30,17 +30,52 @@ fn test_sbox() { assert_eq!(expected, actual); } +#[link(name = "sve", kind = "static")] +extern "C" { + fn apply_inv_sbox_c(state: *mut std::ffi::c_ulong); + fn sve_apply_inv_sbox(state: *mut std::ffi::c_ulong); +} + #[test] fn test_inv_sbox() { - let state = [Felt::new(rand_value()); STATE_WIDTH]; + let state: [Felt; STATE_WIDTH] = [ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]; let mut expected = state; expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA)); let mut actual = state; + let mut actual_c: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c[i] = actual[i].inner(); + } + + let mut actual_c_sve: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c_sve[i] = actual[i].inner(); + } Rpo256::apply_inv_sbox(&mut actual); + unsafe { + apply_inv_sbox_c(actual_c.as_mut_ptr()); + sve_apply_inv_sbox(actual_c_sve.as_mut_ptr()); + } + let actual_as_u64_vec: Vec<u64> = actual.iter().map(|s| s.inner()).collect(); assert_eq!(expected, actual); + assert_eq!(actual_as_u64_vec, actual_c); + assert_eq!(actual_as_u64_vec, actual_c_sve); } #[test] From fdfd0d157dcd45104506cd6601c1f3eafdc9932f Mon Sep 17
00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:18:36 -0300 Subject: [PATCH 02/12] Reorder project --- build.rs | 10 +- c_code/src/main.c | 46 ---- {c_code => sve}/.clang-format | 0 {c_code => sve}/Makefile | 0 sve/src/inv_sbox.c | 253 ++++++++++++++++++ sve/src/inv_sbox.h | 1 + .../src/test_sve.c => sve/src/sve_inv_sbox.c | 250 +---------------- .../src/test_sve.h => sve/src/sve_inv_sbox.h | 0 sve/src/test.c | 95 +++++++ 9 files changed, 357 insertions(+), 298 deletions(-) delete mode 100644 c_code/src/main.c rename {c_code => sve}/.clang-format (100%) rename {c_code => sve}/Makefile (100%) create mode 100644 sve/src/inv_sbox.c create mode 100644 sve/src/inv_sbox.h rename c_code/src/test_sve.c => sve/src/sve_inv_sbox.c (54%) rename c_code/src/test_sve.h => sve/src/sve_inv_sbox.h (100%) create mode 100644 sve/src/test.c diff --git a/build.rs b/build.rs index e27219088..a30f75a9b 100644 --- a/build.rs +++ b/build.rs @@ -1,8 +1,12 @@ fn main() { - println!("cargo:rerun-if-changed=c_code/src/test_sve.c"); - println!("cargo:rerun-if-changed=c_code/src/test_sve.h"); + println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.c"); + println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.h"); + println!("cargo:rerun-if-changed=sve/src/inv_sbox.c"); + println!("cargo:rerun-if-changed=sve/src/inv_sbox.h"); + cc::Build::new() - .file("c_code/src/test_sve.c") + .file("sve/src/sve_inv_sbox.c") + .file("sve/src/inv_sbox.c") .flag("-march=armv8-a+sve") .compile("sve"); } diff --git a/c_code/src/main.c b/c_code/src/main.c deleted file mode 100644 index b28595742..000000000 --- a/c_code/src/main.c +++ /dev/null @@ -1,46 +0,0 @@ -#include "test_sve.h" - -int main() -{ - // TEST SHIFT LEFT - uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; - uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_shift_left(x, y, result); - print_array(STATE_WIDTH, result); - -
// TEST SHIFT RIGHT - uint64_t x_1[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; - uint64_t y_1[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - uint64_t result_1[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_shift_right(x_1, y_1, result_1); - print_array(STATE_WIDTH, result_1); - - // TEST ADD - uint64_t x_2[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - uint64_t y_2[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - uint64_t result_2[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - uint64_t overflowed_2[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_add(x_2, y_2, result_2, overflowed_2); - print_array(STATE_WIDTH, result_2); - print_array(STATE_WIDTH, overflowed_2); - - // TEST SUBSTRACT - uint64_t x_3[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - uint64_t y_3[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, - UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; - uint64_t result_3[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - uint64_t overflowed_3[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_substract(x_3, y_3, result_3, overflowed_3); - print_array(STATE_WIDTH, result_3); - print_array(STATE_WIDTH, overflowed_3); - - // TEST SUBSTRACT AS u32 - uint64_t x_4[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - uint64_t y_4[STATE_WIDTH] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - uint64_t result_4[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_substract_as_u32(x_4, y_4, result_4); - print_array(STATE_WIDTH, result_4); - - return 0; -} diff --git a/c_code/.clang-format b/sve/.clang-format similarity index 100% rename from c_code/.clang-format rename to sve/.clang-format diff --git a/c_code/Makefile b/sve/Makefile similarity index 100% rename from c_code/Makefile rename to sve/Makefile diff --git a/sve/src/inv_sbox.c b/sve/src/inv_sbox.c new file mode 100644 index 
000000000..eb58f0311 --- /dev/null +++ b/sve/src/inv_sbox.c @@ -0,0 +1,253 @@ +#include "inv_sbox.h" +#include +#include +#include +#include + +bool will_sum_overflow(uint64_t a, uint64_t b) +{ + if ((UINT64_MAX - a) < b) // operands are 64-bit, so the bound must be UINT64_MAX + { + return true; + } + + return false; +} + +bool will_sub_overflow(uint64_t a, uint64_t b) { return a < b; } + +// /// Montgomery reduction (constant time) +// #[inline(always)] +// const fn mont_red_cst(x: u128) ->u64 +// { +// // See reference above for a description of the following implementation. +// let xl = x as u64; +// let xh = (x >> 64) as u64; +// let(a, e) = xl.overflowing_add(xl << 32); + +// let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); + +// let(r, c) = xh.overflowing_sub(b); +// r.wrapping_sub(0u32.wrapping_sub(c as u32)as u64) +// } +uint64_t mont_red_cst(__uint128_t x) +{ + uint64_t xl = (uint64_t)x; + uint64_t xh = x >> 64; + + bool e = will_sum_overflow(xl, xl << 32); + uint64_t a = xl + (xl << 32); + + uint64_t b = (a - (a >> 32)) - e; + + bool c = will_sub_overflow(xh, b); + uint64_t r = xh - b; + + return r - (uint64_t)((uint32_t)0 - (uint32_t)c); +} + +// #[inline] +// fn mul(self, rhs: Self) -> Self { +// Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128))) +// } +uint64_t multiply_montgomery_form_felts(uint64_t a, uint64_t b) +{ + __uint128_t a_casted = (__uint128_t)a; + __uint128_t b_casted = (__uint128_t)b; + + return mont_red_cst(a_casted * b_casted); +} + +uint64_t square(uint64_t a) { return multiply_montgomery_form_felts(a, a); } + +// #[inline(always)] +// fn exp_acc( +// base: [B; N], +// tail: [B; N], +// ) -> [B; N] { +// let mut result = base; +// for _ in 0..M { +// result.iter_mut().for_each(|r| *r = r.square()); +// } +// result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); +// result +// } +void exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + //
Square each element of `result` M number of times + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 6; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 12; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +void exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +{ + // Copy `base` into `result` + for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = base[i]; + } + + // Square each element of `result` M number of times + for (int i = 0; i < 31; i++) + { + for (int j = 0; j < STATE_WIDTH; j++) + { + result[j] = square(result[j]); + } + } + + // Multiply each element of result by its corresponding tail element. 
+ for (int i = 0; i < STATE_WIDTH; i++) + { + result[i] = multiply_montgomery_form_felts(result[i], tail[i]); + } +} + +// #[inline(always)] +// fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { +// // compute base^10540996611094048183 using 72 multiplications per array element +// // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 +// // compute base^10 + +// let mut t1 = *state; + +// t1.iter_mut().for_each(|t| *t = t.square()); + +// // compute base^100 +// let mut t2 = t1; + +// t2.iter_mut().for_each(|t| *t = t.square()); +// // compute base^100100 + +// let t3 = Self::exp_acc::(t2, t2); +// // compute base^100100100100 + +// let t4 = Self::exp_acc::(t3, t3); +// // compute base^100100100100100100100100 + +// let t5 = Self::exp_acc::(t4, t4); +// // compute base^100100100100100100100100100100 + +// let t6 = Self::exp_acc::(t5, t3); +// // compute base^1001001001001001001001001001000100100100100100100100100100100 + +// let t7 = Self::exp_acc::(t6, t6); +// // compute base^1001001001001001001001001001000110110110110110110110110110110111 + +// for (i, s) in state.iter_mut().enumerate() { +// let a = (t7[i].square() * t6[i]).square().square(); +// let b = t1[i] * t2[i] * *s; +// *s = a * b; +// } +// } +void apply_inv_sbox_c(uint64_t state[STATE_WIDTH]) +{ + uint64_t t1[STATE_WIDTH]; + + // Square each element of state, call it t1 + for (int j = 0; j < STATE_WIDTH; j++) + { + t1[j] = square(state[j]); + } + + uint64_t t2[STATE_WIDTH]; + + // Square each element of t1, call it t2 + for (int j = 0; j < STATE_WIDTH; j++) + { + t2[j] = square(t1[j]); + } + + // Call exp_acc_3(t2, t2), call it t3 + uint64_t t3[STATE_WIDTH]; + exp_acc_3(t2, t2, t3); + + // Call exp_acc_6(t3, t3), call it t4 + uint64_t t4[STATE_WIDTH]; + exp_acc_6(t3, t3, t4); + + // Call exp_acc_12(t4, t4), call it t5 + uint64_t t5[STATE_WIDTH]; + exp_acc_12(t4, t4, t5); + + // Call exp_acc_6(t5, t3), call it t6 + uint64_t t6[STATE_WIDTH]; + exp_acc_6(t5, t3, 
t6); + + // Call exp_acc_31(t6, t6), call it t7 + uint64_t t7[STATE_WIDTH]; + exp_acc_31(t6, t6, t7); + + for (int i = 0; i < STATE_WIDTH; i++) + { + uint64_t a = square(square((multiply_montgomery_form_felts((square(t7[i])), t6[i])))); + uint64_t b = multiply_montgomery_form_felts(multiply_montgomery_form_felts(t1[i], t2[i]), state[i]); + + state[i] = multiply_montgomery_form_felts(a, b); + } +} diff --git a/sve/src/inv_sbox.h b/sve/src/inv_sbox.h new file mode 100644 index 000000000..43feb36ea --- /dev/null +++ b/sve/src/inv_sbox.h @@ -0,0 +1 @@ +#define STATE_WIDTH 12 diff --git a/c_code/src/test_sve.c b/sve/src/sve_inv_sbox.c similarity index 54% rename from c_code/src/test_sve.c rename to sve/src/sve_inv_sbox.c index 3ad63ead2..04e771efc 100644 --- a/c_code/src/test_sve.c +++ b/sve/src/sve_inv_sbox.c @@ -1,4 +1,4 @@ -#include "test_sve.h" +#include "sve_inv_sbox.h" #include #include #include @@ -13,18 +13,6 @@ const uint64_t ONES[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; const uint64_t ZEROES[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const uint64_t THIRTY_TWOS[STATE_WIDTH] = {32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}; -bool will_sum_overflow(uint64_t a, uint64_t b) -{ - if ((UINT_MAX - a) < b) - { - return true; - } - - return false; -} - -bool will_sub_overflow(uint64_t a, uint64_t b) { return a < b; } - void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) { int64_t i = 0; @@ -290,242 +278,6 @@ void sve_apply_inv_sbox(uint64_t state[STATE_WIDTH]) sve_multiply_montgomery_form_felts(a, b, state); } -// /// Montgomery reduction (constant time) -// #[inline(always)] -// const fn mont_red_cst(x: u128) ->u64 -// { -// // See reference above for a description of the following implementation. 
-// let xl = x as u64; -// let xh = (x >> 64) as u64; -// let(a, e) = xl.overflowing_add(xl << 32); - -// let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64); - -// let(r, c) = xh.overflowing_sub(b); -// r.wrapping_sub(0u32.wrapping_sub(c as u32)as u64) -// } -uint64_t mont_red_cst(__uint128_t x) -{ - uint64_t xl = (uint64_t)x; - uint64_t xh = x >> 64; - - bool e = will_sum_overflow(xl, xl << 32); - uint64_t a = xl + (xl << 32); - - uint64_t b = (a - (a >> 32)) - e; - - bool c = will_sub_overflow(xh, b); - uint64_t r = xh - b; - - return r - (uint64_t)((uint32_t)0 - (uint32_t)c); -} - -// #[inline] -// fn mul(self, rhs: Self) -> Self { -// Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128))) -// } -uint64_t multiply_montgomery_form_felts(uint64_t a, uint64_t b) -{ - __uint128_t a_casted = (__uint128_t)a; - __uint128_t b_casted = (__uint128_t)b; - - return mont_red_cst(a_casted * b_casted); -} - -uint64_t square(uint64_t a) { return multiply_montgomery_form_felts(a, a); } - -// #[inline(always)] -// fn exp_acc( -// base: [B; N], -// tail: [B; N], -// ) -> [B; N] { -// let mut result = base; -// for _ in 0..M { -// result.iter_mut().for_each(|r| *r = r.square()); -// } -// result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); -// result -// } -void exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) -{ - // Copy `base` into `result` - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = base[i]; - } - - // Square each element of `result` M number of times - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < STATE_WIDTH; j++) - { - result[j] = square(result[j]); - } - } - - // Multiply each element of result by its corresponding tail element. 
- for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = multiply_montgomery_form_felts(result[i], tail[i]); - } -} - -void exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) -{ - // Copy `base` into `result` - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = base[i]; - } - - // Square each element of `result` M number of times - for (int i = 0; i < 6; i++) - { - for (int j = 0; j < STATE_WIDTH; j++) - { - result[j] = square(result[j]); - } - } - - // Multiply each element of result by its corresponding tail element. - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = multiply_montgomery_form_felts(result[i], tail[i]); - } -} - -void exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) -{ - // Copy `base` into `result` - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = base[i]; - } - - // Square each element of `result` M number of times - for (int i = 0; i < 12; i++) - { - for (int j = 0; j < STATE_WIDTH; j++) - { - result[j] = square(result[j]); - } - } - - // Multiply each element of result by its corresponding tail element. - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = multiply_montgomery_form_felts(result[i], tail[i]); - } -} - -void exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) -{ - // Copy `base` into `result` - for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = base[i]; - } - - // Square each element of `result` M number of times - for (int i = 0; i < 31; i++) - { - for (int j = 0; j < STATE_WIDTH; j++) - { - result[j] = square(result[j]); - } - } - - // Multiply each element of result by its corresponding tail element. 
- for (int i = 0; i < STATE_WIDTH; i++) - { - result[i] = multiply_montgomery_form_felts(result[i], tail[i]); - } -} - -// #[inline(always)] -// fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { -// // compute base^10540996611094048183 using 72 multiplications per array element -// // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 -// // compute base^10 - -// let mut t1 = *state; - -// t1.iter_mut().for_each(|t| *t = t.square()); - -// // compute base^100 -// let mut t2 = t1; - -// t2.iter_mut().for_each(|t| *t = t.square()); -// // compute base^100100 - -// let t3 = Self::exp_acc::(t2, t2); -// // compute base^100100100100 - -// let t4 = Self::exp_acc::(t3, t3); -// // compute base^100100100100100100100100 - -// let t5 = Self::exp_acc::(t4, t4); -// // compute base^100100100100100100100100100100 - -// let t6 = Self::exp_acc::(t5, t3); -// // compute base^1001001001001001001001001001000100100100100100100100100100100 - -// let t7 = Self::exp_acc::(t6, t6); -// // compute base^1001001001001001001001001001000110110110110110110110110110110111 - -// for (i, s) in state.iter_mut().enumerate() { -// let a = (t7[i].square() * t6[i]).square().square(); -// let b = t1[i] * t2[i] * *s; -// *s = a * b; -// } -// } -void apply_inv_sbox_c(uint64_t state[STATE_WIDTH]) -{ - uint64_t t1[STATE_WIDTH]; - - // Square each element of state, call it t1 - for (int j = 0; j < STATE_WIDTH; j++) - { - t1[j] = square(state[j]); - } - - uint64_t t2[STATE_WIDTH]; - - // Square each element of t1, call it t2 - for (int j = 0; j < STATE_WIDTH; j++) - { - t2[j] = square(t1[j]); - } - - // Call exp_acc_3(t2, t2), call it t3 - uint64_t t3[STATE_WIDTH]; - exp_acc_3(t2, t2, t3); - - // Call exp_acc_6(t3, t3), call it t4 - uint64_t t4[STATE_WIDTH]; - exp_acc_6(t3, t3, t4); - - // Call exp_acc_12(t4, t4), call it t5 - uint64_t t5[STATE_WIDTH]; - exp_acc_12(t4, t4, t5); - - // Call exp_acc_6(t5, t3), call it t6 - uint64_t t6[STATE_WIDTH]; - exp_acc_6(t5, t3, 
t6); - - // Call exp_acc_31(t6, t6), call it t7 - uint64_t t7[STATE_WIDTH]; - exp_acc_31(t6, t6, t7); - - for (int i = 0; i < STATE_WIDTH; i++) - { - uint64_t a = square(square((multiply_montgomery_form_felts((square(t7[i])), t6[i])))); - uint64_t b = multiply_montgomery_form_felts(multiply_montgomery_form_felts(t1[i], t2[i]), state[i]); - - state[i] = multiply_montgomery_form_felts(a, b); - } -} - void print_array(size_t len, uint64_t arr[len]) { printf("["); diff --git a/c_code/src/test_sve.h b/sve/src/sve_inv_sbox.h similarity index 100% rename from c_code/src/test_sve.h rename to sve/src/sve_inv_sbox.h diff --git a/sve/src/test.c b/sve/src/test.c new file mode 100644 index 000000000..800a2b84a --- /dev/null +++ b/sve/src/test.c @@ -0,0 +1,95 @@ +#include "sve_inv_sbox.h" +#include + +int main() +{ + test_sve_shift_left(); + test_sve_shift_right(); + test_sve_add(); + test_sve_substract(); + + return 0; +} + +void test_sve_shift_left() +{ + uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_left(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected = { + 0, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, + }; + assert(result == expected); +} + +void test_sve_shift_right() +{ + uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_shift_right(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected = { + 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, + }; + assert(result == expected); +} + +void test_sve_add() +{ + uint64_t x[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + uint64_t 
result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t overflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_add(x, y, result, overflowed); + print_array(STATE_WIDTH, result); + print_array(STATE_WIDTH, overflowed); + + uint64_t expected_result = { + 0, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + }; + uint64_t expected_overflowed = { + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + assert(result == expected_result); + assert(overflowed == expected_overflowed); +} + +void test_sve_substract() +{ + uint64_t x[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, + UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint64_t unverflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract(x, y, result, unverflowed); + print_array(STATE_WIDTH, result); + print_array(STATE_WIDTH, unverflowed); + + uint64_t expected_result = { + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + }; + uint64_t expected_unverflowed = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }; + assert(result == expected_result); + assert(unverflowed == expected_unverflowed); +} + +void test_sve_substract_as_u32() +{ + uint64_t x[STATE_WIDTH] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint64_t y[STATE_WIDTH] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + uint64_t result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract_as_u32(x, y, result); + print_array(STATE_WIDTH, result); + + uint64_t expected_result = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }; + assert(result == expected_result); +} From 02ff25f632f369a09597aaf53aa1f096a554905f Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:20:03 -0300 Subject: [PATCH 03/12] Fix tests --- sve/src/test.c | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sve/src/test.c 
b/sve/src/test.c index 800a2b84a..2ec9126d6 100644 --- a/sve/src/test.c +++ b/sve/src/test.c @@ -19,7 +19,7 @@ void test_sve_shift_left() sve_shift_left(x, y, result); print_array(STATE_WIDTH, result); - uint64_t expected = { + uint64_t expected[STATE_WIDTH] = { 0, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, }; assert(result == expected); @@ -33,7 +33,7 @@ void test_sve_shift_right() sve_shift_right(x, y, result); print_array(STATE_WIDTH, result); - uint64_t expected = { + uint64_t expected[STATE_WIDTH] = { 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, }; assert(result == expected); @@ -49,10 +49,10 @@ void test_sve_add() print_array(STATE_WIDTH, result); print_array(STATE_WIDTH, overflowed); - uint64_t expected_result = { + uint64_t expected_result[STATE_WIDTH] = { 0, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, }; - uint64_t expected_overflowed = { + uint64_t expected_overflowed[STATE_WIDTH] = { 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; assert(result == expected_result); @@ -70,10 +70,10 @@ void test_sve_substract() print_array(STATE_WIDTH, result); print_array(STATE_WIDTH, unverflowed); - uint64_t expected_result = { + uint64_t expected_result[STATE_WIDTH] = { 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, }; - uint64_t expected_unverflowed = { + uint64_t expected_unverflowed[STATE_WIDTH] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, }; assert(result == expected_result); @@ -88,7 +88,7 @@ void test_sve_substract_as_u32() sve_substract_as_u32(x, y, result); print_array(STATE_WIDTH, result); - uint64_t expected_result = { + uint64_t expected_result[STATE_WIDTH] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, }; assert(result == expected_result); From d44f00d409fb8535f6f97180210515347191f2e1 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:20:49 -0300 Subject: [PATCH 04/12] More test fixes --- sve/src/test.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sve/src/test.c b/sve/src/test.c index 2ec9126d6..6621754a4 100644 --- a/sve/src/test.c +++ 
b/sve/src/test.c @@ -1,6 +1,11 @@ #include "sve_inv_sbox.h" #include +void test_sve_shift_left(); +void test_sve_shift_right(); +void test_sve_add(); +void test_sve_substract(); + int main() { test_sve_shift_left(); From ec6a649c72d1a14d94548bf79eda977bed620d90 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:23:42 -0300 Subject: [PATCH 05/12] More test fixes --- sve/src/test.c | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/sve/src/test.c b/sve/src/test.c index 6621754a4..3eb86c6ad 100644 --- a/sve/src/test.c +++ b/sve/src/test.c @@ -16,6 +16,14 @@ int main() return 0; } +void assert_array_equality(const uint64_t result[STATE_WIDTH], const uint64_t expected[STATE_WIDTH]) +{ + for (int i = 0; i < STATE_WIDTH; i++) + { + assert(result[i] == expected[i]); + } +} + void test_sve_shift_left() { uint64_t x[STATE_WIDTH] = {0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}; @@ -27,7 +35,7 @@ void test_sve_shift_left() uint64_t expected[STATE_WIDTH] = { 0, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, }; - assert(result == expected); + assert_array_equality(result, expected); } void test_sve_shift_right() @@ -41,7 +49,7 @@ void test_sve_shift_right() uint64_t expected[STATE_WIDTH] = { 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, }; - assert(result == expected); + assert_array_equality(result, expected); } void test_sve_add() @@ -60,8 +68,8 @@ void test_sve_add() uint64_t expected_overflowed[STATE_WIDTH] = { 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; - assert(result == expected_result); - assert(overflowed == expected_overflowed); + assert_array_equality(result, expected_result); + assert_array_equality(overflowed, expected_overflowed); } void test_sve_substract() @@ -70,19 +78,19 @@ void test_sve_substract() uint64_t y[STATE_WIDTH] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; uint64_t 
result[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - uint64_t unverflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - sve_substract(x, y, result, unverflowed); + uint64_t underflowed[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + sve_substract(x, y, result, underflowed); print_array(STATE_WIDTH, result); - print_array(STATE_WIDTH, unverflowed); + print_array(STATE_WIDTH, underflowed); uint64_t expected_result[STATE_WIDTH] = { 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, }; - uint64_t expected_unverflowed[STATE_WIDTH] = { + uint64_t expected_underflowed[STATE_WIDTH] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, }; - assert(result == expected_result); - assert(unverflowed == expected_unverflowed); + assert_array_equality(result, expected_result); + assert_array_equality(underflowed, expected_underflowed); } void test_sve_substract_as_u32() @@ -96,5 +104,5 @@ void test_sve_substract_as_u32() uint64_t expected_result[STATE_WIDTH] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, }; - assert(result == expected_result); + assert_array_equality(result, expected_result); } From 34464b82d32db557ee460a94ffb4ced24ed18976 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:46:27 -0300 Subject: [PATCH 06/12] Add feature flag --- Cargo.toml | 2 ++ build.rs | 6 ++++++ src/hash/rpo/mod.rs | 21 +++++++++++++-------- src/hash/rpo/tests.rs | 40 ++++++++++++++++++++++++---------------- sve/src/sve_inv_sbox.c | 2 +- 5 files changed, 46 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e72324bc7..9151d931c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,11 @@ harness = false [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +sve_backend = [] [dependencies] blake3 = { version = "1.3", default-features = false } +cfg-if = "1.0.0" winter_crypto = { version = "0.6", package = 
"winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } diff --git a/build.rs b/build.rs index a30f75a9b..662eca343 100644 --- a/build.rs +++ b/build.rs @@ -1,4 +1,10 @@ fn main() { + #[cfg(feature = "sve_backend")] + compile_sve_backend(); +} + +#[cfg(feature = "sve_backend")] +fn compile_sve_backend() { println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.c"); println!("cargo:rerun-if-changed=sve/src/sve_inv_sbox.h"); println!("cargo:rerun-if-changed=sve/src/inv_sbox.h"); diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index 7161985ed..db459a653 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -10,9 +10,9 @@ use mds_freq::mds_multiply_freq; #[cfg(test)] mod tests; +#[cfg(feature = "sve_backend")] #[link(name = "sve", kind = "static")] extern "C" { - fn apply_inv_sbox_c(state: *mut std::ffi::c_ulong); fn sve_apply_inv_sbox(state: *mut std::ffi::c_ulong); } @@ -357,13 +357,18 @@ impl Rpo256 { // apply second half of RPO round Self::apply_mds(state); Self::add_constants(state, &ARK2[round]); - // Self::apply_inv_sbox(state); - let mut state_inner: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; - for i in 0..STATE_WIDTH { - state_inner[i] = state[i].inner(); - } - unsafe { - sve_apply_inv_sbox(state_inner.as_mut_ptr()); + cfg_if::cfg_if! 
{ + if #[cfg(feature = "sve_backend")] { + let mut state_inner: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + state_inner[i] = state[i].inner(); + } + unsafe { + sve_apply_inv_sbox(state_inner.as_mut_ptr()); + } + } else { + Self::apply_inv_sbox(state); + } } } diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index 832e913e3..fccd0e40f 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -30,6 +30,7 @@ fn test_sbox() { assert_eq!(expected, actual); } +#[cfg(feature = "sve_backend")] #[link(name = "sve", kind = "static")] extern "C" { fn apply_inv_sbox_c(state: *mut std::ffi::c_ulong); @@ -57,25 +58,32 @@ fn test_inv_sbox() { expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA)); let mut actual = state; - let mut actual_c: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; - for i in 0..STATE_WIDTH { - actual_c[i] = actual[i].inner(); - } - - let mut actual_c_sve: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; - for i in 0..STATE_WIDTH { - actual_c_sve[i] = actual[i].inner(); - } Rpo256::apply_inv_sbox(&mut actual); - unsafe { - apply_inv_sbox_c(actual_c.as_mut_ptr()); - sve_apply_inv_sbox(actual_c_sve.as_mut_ptr()); - } - let actual_as_u64_vec: Vec = actual.iter().map(|s| s.inner()).collect(); assert_eq!(expected, actual); - assert_eq!(actual_as_u64_vec, actual_c); - assert_eq!(actual_as_u64_vec, actual_c_sve); + + cfg_if::cfg_if! 
{ + if #[cfg(feature = "sve_backend")] { + let mut actual_c: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c[i] = actual[i].inner(); + } + + let mut actual_c_sve: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; + for i in 0..STATE_WIDTH { + actual_c_sve[i] = actual[i].inner(); + } + unsafe { + apply_inv_sbox_c(actual_c.as_mut_ptr()); + sve_apply_inv_sbox(actual_c_sve.as_mut_ptr()); + } + + let actual_as_u64_vec: Vec = actual.iter().map(|s| s.inner()).collect(); + assert_eq!(actual_as_u64_vec, actual_c); + assert_eq!(actual_as_u64_vec, actual_c_sve); + + } + } } #[test] diff --git a/sve/src/sve_inv_sbox.c b/sve/src/sve_inv_sbox.c index 04e771efc..11015cdff 100644 --- a/sve/src/sve_inv_sbox.c +++ b/sve/src/sve_inv_sbox.c @@ -283,7 +283,7 @@ void print_array(size_t len, uint64_t arr[len]) printf("["); for (size_t i = 0; i < len; i++) { - printf("%lu ", arr[i]); + printf("%llu ", arr[i]); } printf("]\n"); From 0cea5d9eff3335adb8b2198a163774c202d821ca Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 11:59:03 -0300 Subject: [PATCH 07/12] Fix tests --- src/hash/rpo/tests.rs | 8 ++++---- sve/src/sve_inv_sbox.c | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index fccd0e40f..be905f826 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -56,11 +56,7 @@ fn test_inv_sbox() { let mut expected = state; expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA)); - let mut actual = state; - Rpo256::apply_inv_sbox(&mut actual); - - assert_eq!(expected, actual); cfg_if::cfg_if! 
{ if #[cfg(feature = "sve_backend")] { @@ -84,6 +80,10 @@ fn test_inv_sbox() { } } + + Rpo256::apply_inv_sbox(&mut actual); + + assert_eq!(expected, actual); } #[test] diff --git a/sve/src/sve_inv_sbox.c b/sve/src/sve_inv_sbox.c index 11015cdff..04e771efc 100644 --- a/sve/src/sve_inv_sbox.c +++ b/sve/src/sve_inv_sbox.c @@ -283,7 +283,7 @@ void print_array(size_t len, uint64_t arr[len]) printf("["); for (size_t i = 0; i < len; i++) { - printf("%llu ", arr[i]); + printf("%lu ", arr[i]); } printf("]\n"); From aa5cebd568d781f1db3143ded32a9a6996166d7e Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 12:20:48 -0300 Subject: [PATCH 08/12] Hopefully fix tests --- src/hash/rpo/tests.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index be905f826..b82760cfa 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -74,10 +74,9 @@ fn test_inv_sbox() { sve_apply_inv_sbox(actual_c_sve.as_mut_ptr()); } - let actual_as_u64_vec: Vec = actual.iter().map(|s| s.inner()).collect(); - assert_eq!(actual_as_u64_vec, actual_c); - assert_eq!(actual_as_u64_vec, actual_c_sve); - + let expected_as_u64_vec: Vec = expected.iter().map(|s| s.inner()).collect(); + assert_eq!(expected_as_u64_vec, actual_c); + assert_eq!(expected_as_u64_vec, actual_c_sve); } } From 64945f41d5a33aa9164ed62b0694204b7f95e6a0 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 17:27:40 -0300 Subject: [PATCH 09/12] Avoid copying state when calling the sve version of apply_inv_sbox --- src/hash/rpo/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index db459a653..ee67dbc46 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -1,5 +1,7 @@ use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO}; use core::{convert::TryInto, ops::Range}; +#[cfg(feature = "sve_backend")] +use 
winter_math::fields::f64::BaseElement; mod digest; pub use digest::RpoDigest; @@ -364,7 +366,7 @@ impl Rpo256 { state_inner[i] = state[i].inner(); } unsafe { - sve_apply_inv_sbox(state_inner.as_mut_ptr()); + sve_apply_inv_sbox(std::mem::transmute::<*mut BaseElement, *mut u64>(state.as_mut_ptr())); } } else { Self::apply_inv_sbox(state); From 1fb40446005b2e1454b11280d1c323ea11a26bd2 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 17:28:55 -0300 Subject: [PATCH 10/12] Forgot to delete copying code --- src/hash/rpo/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index ee67dbc46..77ecb8a8b 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -361,10 +361,6 @@ impl Rpo256 { Self::add_constants(state, &ARK2[round]); cfg_if::cfg_if! { if #[cfg(feature = "sve_backend")] { - let mut state_inner: [u64; STATE_WIDTH] = [0; STATE_WIDTH]; - for i in 0..STATE_WIDTH { - state_inner[i] = state[i].inner(); - } unsafe { sve_apply_inv_sbox(std::mem::transmute::<*mut BaseElement, *mut u64>(state.as_mut_ptr())); } From ec6a0547ad5041c41a9d12aa1448db57b64bca1f Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 17:42:42 -0300 Subject: [PATCH 11/12] Try inlining C functions --- sve/src/sve_inv_sbox.c | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/sve/src/sve_inv_sbox.c b/sve/src/sve_inv_sbox.c index 04e771efc..a9216b1a0 100644 --- a/sve/src/sve_inv_sbox.c +++ b/sve/src/sve_inv_sbox.c @@ -13,6 +13,34 @@ const uint64_t ONES[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; const uint64_t ZEROES[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const uint64_t THIRTY_TWOS[STATE_WIDTH] = {32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}; +void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t 
*result) + __attribute__((always_inline)); +void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) + __attribute__((always_inline)); +void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) + __attribute__((always_inline)); +void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) __attribute__((always_inline)); +void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); +void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) + __attribute__((always_inline)); + void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) { int64_t i = 0; From ff7e339f38e17def1a67ec2cc5366e63319a1ac2 Mon Sep 17 00:00:00 2001 From: Javier Chatruc Date: Tue, 27 Jun 2023 17:49:59 -0300 Subject: [PATCH 12/12] Actually inline stuff --- sve/src/sve_inv_sbox.c | 30 +++++++++++++++--------------- 1 file changed, 15 
insertions(+), 15 deletions(-) diff --git a/sve/src/sve_inv_sbox.c b/sve/src/sve_inv_sbox.c index a9216b1a0..c4fc502c8 100644 --- a/sve/src/sve_inv_sbox.c +++ b/sve/src/sve_inv_sbox.c @@ -13,32 +13,32 @@ const uint64_t ONES[STATE_WIDTH] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; const uint64_t ZEROES[STATE_WIDTH] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const uint64_t THIRTY_TWOS[STATE_WIDTH] = {32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}; -void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +inline void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +inline void sve_shift_right(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) +inline void sve_add(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *overflowed) __attribute__((always_inline)); -void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) +inline void sve_substract(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result, uint64_t *underflowed) __attribute__((always_inline)); -void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) - __attribute__((always_inline)); -void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +inline void sve_substract_as_u32(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) +inline void sve_multiply_low(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], 
uint64_t *result) __attribute__((always_inline)); -void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) +inline void sve_multiply_high(const uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], uint64_t *result) +inline void sve_mont_red_cst(uint64_t x[STATE_WIDTH], uint64_t y[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) __attribute__((always_inline)); -void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +inline void sve_multiply_montgomery_form_felts(const uint64_t a[STATE_WIDTH], const uint64_t b[STATE_WIDTH], + uint64_t *result) __attribute__((always_inline)); +inline void sve_copy(const uint64_t a[STATE_WIDTH], uint64_t *copy) __attribute__((always_inline)); +inline void sve_exp_acc_3(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +inline void sve_exp_acc_6(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +inline void sve_exp_acc_12(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); -void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) +inline void sve_exp_acc_31(uint64_t base[STATE_WIDTH], uint64_t tail[STATE_WIDTH], uint64_t *result) __attribute__((always_inline)); void sve_shift_left(uint64_t x[STATE_WIDTH], const uint64_t y[STATE_WIDTH], uint64_t *result)