From c2a8d77a1240eab19d3b409a467674306e83d3c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rin=20Cat=20=28=E9=88=B4=E7=8C=AB=29?= Date: Tue, 5 May 2026 09:42:17 +0900 Subject: [PATCH] Use new algorithm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rin Cat (鈴猫) --- Cargo.toml | 21 +- README.md | 110 ++++++ src/bin/bench.rs | 155 +++++++++ src/ed25519.rs | 13 +- src/lib.rs | 886 ++++++++++++++++++++++++++++++++++++----------- src/prime.rs | 96 +++-- src/rsa.rs | 250 +++++++++++++ src/sloth.rs | 111 ------ 8 files changed, 1297 insertions(+), 345 deletions(-) create mode 100644 src/bin/bench.rs create mode 100644 src/rsa.rs delete mode 100644 src/sloth.rs diff --git a/Cargo.toml b/Cargo.toml index 2988d64..fecf320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,24 +1,33 @@ [package] name = "chidori-pow" version = "0.1.0" -edition = "2021" +edition = "2024" [lib] crate-type = ["rlib", "cdylib"] +[[bin]] +name = "bench" +path = "src/bin/bench.rs" +required-features = ["native-bin"] + +[features] +native-bin = [] + [dependencies] num-traits = "0.2" num-integer = "0.1.46" -serde_json = "1.0" -serde = { version = "1.0.217", features = ["derive"] } -cfg-if = "1.0.0" +serde = { version = "1.0.228", features = ["derive"] } +bincode = "1.3" +cfg-if = "1.0.4" base64 = "0.22.1" wasm-bindgen = "0.2" +sha2 = "0.10" [target.'cfg(target_arch = "wasm32")'.dependencies] num-bigint = { version = "0.4", features = ["serde"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] num-bigint = { version="0.4", features = ["rand", "serde"] } -rand = "0.8" -ed25519-dalek = { version = "2.1.1", features = ["std", "rand_core"] } +rand = "0.10" +ed25519-dalek = { version = "2.2.0", features = ["std", "rand_core"] } diff --git a/README.md b/README.md index 5aa4dfa..a60f784 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,112 @@ # chidori-pow +Anti-bruteforce proof-of-work. + +## Design + +The challenge flow uses an RSA trapdoor repeated-squaring puzzle: + +- the server generates or loads RSA factors at process start; +- the signed binary challenge payload contains the public modulus and exact + difficulty step count; +- the client performs scheduled sequential modular squaring; +- the server verifies cheaply with the private trapdoor. + +The default RSA modulus is 2048 bits, and applications may set a different +modulus at initialization. Generated moduli use at least 512 bits. The default +difficulty is `450000` scheduled RSA work steps. Applications should choose +modulus and step counts from their own benchmarks and risk policy. + +Challenges are signed with Ed25519. The default builder generates a fresh +signing key and RSA factors at process start. Persist or inject keys and factors +if restart continuity matters. + +`ChallengerBuilder::build()` and `Challenger::issue_challenge()` return +`Result`s. Generated factors are clamped to the minimum modulus size. +`with_factors(p, q)` injects persisted RSA prime factors and validates that they +are distinct primes with a large enough product. + +The challenge string is `base64url(payload || signature)`. `payload` is the +bincode-serialized `Puzzle` bytes, `signature` is the fixed 64-byte Ed25519 +signature suffix, and the signature covers the raw payload bytes. + +## Binding Data + +Applications may bind their own opaque bytes into the puzzle without sending +those bytes in the challenge: + +```rust +let solution = solve_challenge(&challenge, app_binding_data); +challenger.verify_challenge(&challenge, &solution, expected_binding_data); +``` + +The library does not interpret `binding_data`; callers are responsible for +canonicalizing it. + +Recommended `binding_data` contents are app-specific request context such as flow +name, route, normalized username hash, CSRF token, form nonce, and a versioned +site-specific prefix. Do not put passwords or raw secrets in `binding_data`. + +Mismatched binding data fails exactly like an invalid proof-of-work solution. +Verification returns `false` for malformed, expired, replayed, mismatched, or +internally unavailable challenges. + +## Native Example + +```rust +use chidori_pow::{ChallengerBuilder, solve_challenge}; + +let binding_data = b"site-login-v1\0/login\0user-hash"; + +let challenger = ChallengerBuilder::new() + .with_modulus_bits(2048) + .with_difficulty(450_000) + .build()?; + +let challenge = challenger.issue_challenge()?; +let solution = solve_challenge(&challenge, binding_data); + +assert!(challenger.verify_challenge( + &challenge, + &solution, + binding_data, +)); +``` + +## Browser/WASM + +The wasm package exports: + +```ts +solve_challenge(challenge: string, binding_data: Uint8Array): string +``` + +The browser/app code is responsible for constructing the same canonical +`binding_data` bytes that the server will later use for verification. Pass an +empty `Uint8Array` when no app-specific binding data is needed. + +The browser solver checks the decoded puzzle before solving. It accepts modulus +sizes from 512 to 8192 bits and difficulty up to `10000000` scheduled steps. + +## Benchmark + +Run the native benchmark with: + +```sh +cargo run --release --features native-bin --bin bench -- \ + --modulus-bits 2048 \ + --difficulty 450000 \ + --rounds 5 \ + --binding login:user=a +``` + +The benchmark builds one challenger, then measures repeated issue/solve/verify +rounds and prints min/average/p50/max timings. + +## Replay Cache + +Solved challenge tickets are remembered for the current and previous validity +windows. The default cache capacity is `250000` tickets per window. When the +current window is full, verification fails closed instead of evicting entries, +so replay protection is preserved at the cost of rejecting additional solves +until the next rotation. diff --git a/src/bin/bench.rs b/src/bin/bench.rs new file mode 100644 index 0000000..b29cf2d --- /dev/null +++ b/src/bin/bench.rs @@ -0,0 +1,155 @@ +use chidori_pow::{ChallengerBuilder, solve_challenge}; +use std::env; +use std::time::{Duration, Instant}; + +struct Config { + modulus_bits: u64, + difficulty: u32, + rounds: u32, + binding_data: Vec, +} + +impl Config { + fn from_args() -> Self { + let mut config = Self { + modulus_bits: 2048, + difficulty: 450_000, + rounds: 5, + binding_data: b"bench:binding".to_vec(), + }; + + let mut args = env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--modulus-bits" => { + config.modulus_bits = parse_next(&mut args, "--modulus-bits"); + } + "--difficulty" => { + config.difficulty = parse_next(&mut args, "--difficulty"); + } + "--rounds" => { + config.rounds = parse_next::(&mut args, "--rounds").max(1); + } + "--binding" => { + config.binding_data = args + .next() + .unwrap_or_else(|| usage("--binding requires a value")) + .into_bytes(); + } + "--help" | "-h" => { + print_usage_and_exit(); + } + _ => usage(&format!("unknown argument: {arg}")), + } + } + + config + } +} + +fn parse_next(args: &mut impl Iterator, name: &str) -> T +where + T: std::str::FromStr, +{ + args.next() + .unwrap_or_else(|| usage(&format!("{name} requires a value"))) + .parse() + .unwrap_or_else(|_| usage(&format!("{name} has an invalid value"))) +} + +fn print_usage_and_exit() -> ! { + println!( + "Usage: cargo run --release --features native-bin --bin bench -- [--modulus-bits 2048] [--difficulty 450000] [--rounds 5] [--binding bench:binding]" + ); + std::process::exit(0); +} + +fn usage(message: &str) -> ! { + eprintln!("{message}"); + eprintln!( + "Usage: cargo run --release --features native-bin --bin bench -- [--modulus-bits 2048] [--difficulty 450000] [--rounds 5] [--binding bench:binding]" + ); + std::process::exit(2); +} + +fn timed(f: impl FnOnce() -> T) -> (T, Duration) { + let start = Instant::now(); + let value = f(); + (value, start.elapsed()) +} + +fn main() { + let config = Config::from_args(); + + println!("modulus_bits: {}", config.modulus_bits); + println!("difficulty_steps: {}", config.difficulty); + println!("rounds: {}", config.rounds); + println!("binding_data_len: {}", config.binding_data.len()); + + let (challenger, build_time) = timed(|| { + match ChallengerBuilder::new() + .with_modulus_bits(config.modulus_bits) + .with_difficulty(config.difficulty) + .build() + { + Ok(challenger) => challenger, + Err(error) => { + eprintln!("build failed: {error:?}"); + std::process::exit(1); + } + } + }); + + let mut issue_times = Vec::with_capacity(config.rounds as usize); + let mut solve_times = Vec::with_capacity(config.rounds as usize); + let mut verify_times = Vec::with_capacity(config.rounds as usize); + let mut challenge_len = 0; + let mut solution_len = 0; + + for _ in 0..config.rounds { + let (challenge, issue_time) = timed(|| match challenger.issue_challenge() { + Ok(challenge) => challenge, + Err(error) => { + eprintln!("challenge issue failed: {error:?}"); + std::process::exit(1); + } + }); + let (solution, solve_time) = timed(|| solve_challenge(&challenge, &config.binding_data)); + let (verified, verify_time) = + timed(|| challenger.verify_challenge(&challenge, &solution, &config.binding_data)); + + if !verified { + println!("verified: false"); + std::process::exit(1); + } + + challenge_len = challenge.len(); + solution_len = solution.len(); + issue_times.push(issue_time); + solve_times.push(solve_time); + verify_times.push(verify_time); + } + + println!("build_ms: {:.3}", build_time.as_secs_f64() * 1000.0); + print_stats("issue", &mut issue_times); + print_stats("solve", &mut solve_times); + print_stats("verify", &mut verify_times); + println!("challenge_len: {challenge_len}"); + println!("solution_len: {solution_len}"); + println!("verified: true"); +} + +fn print_stats(label: &str, times: &mut [Duration]) { + times.sort(); + + let sum = times.iter().fold(Duration::ZERO, |acc, value| acc + *value); + let avg = sum.as_secs_f64() * 1000.0 / times.len() as f64; + let min = times[0].as_secs_f64() * 1000.0; + let p50 = times[times.len() / 2].as_secs_f64() * 1000.0; + let max = times[times.len() - 1].as_secs_f64() * 1000.0; + + println!("{label}_min_ms: {min:.3}"); + println!("{label}_avg_ms: {avg:.3}"); + println!("{label}_p50_ms: {p50:.3}"); + println!("{label}_max_ms: {max:.3}"); +} diff --git a/src/ed25519.rs b/src/ed25519.rs index 4544207..2f9c0bf 100644 --- a/src/ed25519.rs +++ b/src/ed25519.rs @@ -7,9 +7,12 @@ pub struct ChallengeSigner { } impl ChallengeSigner { + pub fn generate_sign_key() -> [u8; 32] { + rand::random() + } + pub fn new() -> Self { - let signing_key = SigningKey::generate(&mut rand::thread_rng()); - ChallengeSigner { signing_key } + ChallengeSigner::new_from_bytes(ChallengeSigner::generate_sign_key()) } pub fn new_from_bytes(bytes: [u8; 32]) -> Self { @@ -26,3 +29,9 @@ impl ChallengeSigner { self.signing_key.verify(message, &signature_result).is_ok() } } + +impl Default for ChallengeSigner { + fn default() -> Self { + Self::new() + } +} diff --git a/src/lib.rs b/src/lib.rs index ba62057..c5b9c4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,40 +1,66 @@ -pub mod sloth; +pub mod rsa; -use base64::{prelude::BASE64_URL_SAFE, Engine}; -use num_bigint::BigUint; +use base64::{Engine, prelude::BASE64_URL_SAFE}; +pub use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use wasm_bindgen::prelude::*; +const MIN_RSA_MODULUS_BITS: u64 = 512; +const MAX_CHALLENGE_LEN: usize = 10240; +#[cfg(not(target_arch = "wasm32"))] +const MAX_SOLUTION_LEN: usize = 1024; +const MAX_BINDING_DATA_LEN: usize = 4096; +const MAX_SOLVE_MODULUS_BITS: u64 = 8192; +const MAX_SOLVE_DIFFICULTY_STEPS: u32 = 10_000_000; + cfg_if::cfg_if! { if #[cfg(target_arch = "wasm32")] { } else { use std::collections::HashSet; use std::sync::{Arc, RwLock}; use std::time::{SystemTime, UNIX_EPOCH}; - use rand::Rng; pub mod prime; pub mod ed25519; } } #[cfg(not(target_arch = "wasm32"))] -static DIFFICULTY_SCALE: f64 = 0.04; +const DEFAULT_RSA_MODULUS_BITS: u64 = 2048; +#[cfg(not(target_arch = "wasm32"))] +const DEFAULT_DIFFICULTY_STEPS: u32 = 450_000; +#[cfg(not(target_arch = "wasm32"))] +const DEFAULT_TICKET_CACHE_CAPACITY: usize = 250_000; #[derive(Serialize, Deserialize)] pub struct Puzzle { pub ticket: u64, pub issued_at: u64, - pub prime: BigUint, - pub challenge: BigUint, + pub modulus: BigUint, pub difficulty: u32, } +#[cfg(not(target_arch = "wasm32"))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChallengerBuildError { + EqualFactors, + ModulusTooSmall, + NonPrimeFactor, +} + +#[cfg(not(target_arch = "wasm32"))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChallengeIssueError { + Serialize, +} + #[cfg(not(target_arch = "wasm32"))] pub struct ChallengerBuilder { difficulty: Option, - algorithm: Option, + factors: Option<(BigUint, BigUint)>, + modulus_bits: Option, signer: Option, valid_time_window: Option, + ticket_cache_capacity: Option, } #[cfg(not(target_arch = "wasm32"))] @@ -42,19 +68,22 @@ impl ChallengerBuilder { pub fn new() -> Self { Self { difficulty: None, - algorithm: None, + factors: None, + modulus_bits: None, signer: None, valid_time_window: None, + ticket_cache_capacity: None, } } + /// Set the exact number of scheduled RSA work steps. pub fn with_difficulty(mut self, difficulty: u32) -> Self { - self.difficulty = Some((difficulty as f64 * DIFFICULTY_SCALE) as u32); + self.difficulty = Some(difficulty); self } - pub fn with_prime(mut self, prime: BigUint) -> Self { - self.algorithm = Some(crate::sloth::Sloth::new(prime)); + pub fn with_modulus_bits(mut self, modulus_bits: u64) -> Self { + self.modulus_bits = Some(modulus_bits.max(MIN_RSA_MODULUS_BITS)); self } @@ -68,167 +97,191 @@ impl ChallengerBuilder { self } - pub fn build(self) -> Challenger { - let difficulty = self - .difficulty - .unwrap_or((1000.0 * DIFFICULTY_SCALE) as u32); + pub fn with_ticket_cache_capacity(mut self, ticket_cache_capacity: usize) -> Self { + self.ticket_cache_capacity = Some(ticket_cache_capacity.max(1)); + self + } - let algorithm = self.algorithm.unwrap_or_else(|| { - crate::sloth::Sloth::new(crate::prime::generate_prime_mod_3_4(4096, 64)) - }); + pub fn with_factors(mut self, p: BigUint, q: BigUint) -> Self { + self.factors = Some((p, q)); + self + } - let signer = self - .signer - .unwrap_or_else(|| crate::ed25519::ChallengeSigner::new()); + pub fn generate_factor(bits: u64) -> BigUint { + crate::prime::generate_prime_mod_3_4(bits, 64) + } + + pub fn generate_factors(modulus_bits: u64) -> (BigUint, BigUint) { + generate_factors_with(modulus_bits, Self::generate_factor) + } + + pub fn generate_sign_key() -> [u8; 32] { + crate::ed25519::ChallengeSigner::generate_sign_key() + } + + pub fn build(self) -> Result { + let algorithm = match self.factors { + Some((p, q)) => { + validate_factors(&p, &q)?; + crate::rsa::RsaTrapdoor::from_factors(p, q) + } + None => crate::rsa::RsaTrapdoor::generate( + self.modulus_bits.unwrap_or(DEFAULT_RSA_MODULUS_BITS), + 64, + ), + }; + let difficulty = self.difficulty.unwrap_or(DEFAULT_DIFFICULTY_STEPS); + + let signer = self.signer.unwrap_or_default(); let tickets_current = Arc::new(RwLock::new(HashSet::new())); let tickets_previous = Arc::new(RwLock::new(HashSet::new())); let valid_time_window = self.valid_time_window.unwrap_or(300); - let last_rotation = Arc::new(RwLock::new( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(), - )); + let ticket_cache_capacity = self + .ticket_cache_capacity + .unwrap_or(DEFAULT_TICKET_CACHE_CAPACITY); + let last_rotation = Arc::new(RwLock::new(unix_now_secs())); - Challenger { + Ok(Challenger { difficulty, algorithm, signer, tickets_current, tickets_previous, valid_time_window, + ticket_cache_capacity, last_rotation, + }) + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn generate_factors_with( + modulus_bits: u64, + mut generate_factor: impl FnMut(u64) -> BigUint, +) -> (BigUint, BigUint) { + let modulus_bits = modulus_bits.max(MIN_RSA_MODULUS_BITS); + let factor_bits = modulus_bits.div_ceil(2); + + loop { + let p = generate_factor(factor_bits); + let q = generate_factor(factor_bits); + if q != p && (&p * &q).bits() >= modulus_bits { + return (p, q); } } } +#[cfg(not(target_arch = "wasm32"))] +impl Default for ChallengerBuilder { + fn default() -> Self { + Self::new() + } +} + #[cfg(not(target_arch = "wasm32"))] pub struct Challenger { pub difficulty: u32, - algorithm: crate::sloth::Sloth, + algorithm: crate::rsa::RsaTrapdoor, signer: crate::ed25519::ChallengeSigner, tickets_current: Arc>>, tickets_previous: Arc>>, valid_time_window: u64, + ticket_cache_capacity: usize, last_rotation: Arc>, } #[cfg(not(target_arch = "wasm32"))] impl Challenger { + /// Set the exact number of scheduled RSA work steps. pub fn set_difficulty(&mut self, difficulty: u32) { - self.difficulty = (difficulty as f64 * DIFFICULTY_SCALE) as u32; + self.difficulty = difficulty; } - pub fn issue_challenge(&mut self) -> String { - let (_x, y, p) = self.algorithm.create(self.difficulty); - + pub fn issue_challenge(&self) -> Result { let puzzle = Puzzle { - ticket: rand::thread_rng().gen(), - issued_at: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(), - prime: p, - challenge: y, + ticket: rand::random(), + issued_at: unix_now_secs(), + modulus: self.algorithm.modulus.clone(), difficulty: self.difficulty, }; - let json = serde_json::to_string(&puzzle).unwrap(); - let json_payload = BASE64_URL_SAFE.encode(json); - let signature = self.signer.sign(json_payload.as_bytes()); - let signature_payload = BASE64_URL_SAFE.encode(signature.to_bytes()); - - let challenge = json_payload + "." + signature_payload.as_str(); - challenge + self.issue_challenge_with(&puzzle, |puzzle| { + bincode::serialize(puzzle).map_err(|_| ChallengeIssueError::Serialize) + }) } - pub fn verify_challenge(&mut self, challenge: &str, solution: &str) -> bool { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); + fn issue_challenge_with( + &self, + puzzle: &Puzzle, + serialize: impl FnOnce(&Puzzle) -> Result, ChallengeIssueError>, + ) -> Result { + let payload_bytes = serialize(puzzle)?; + let signature = self.signer.sign(&payload_bytes); + let signature_bytes = signature.to_bytes(); + + let mut challenge_bytes = Vec::with_capacity(payload_bytes.len() + signature_bytes.len()); + challenge_bytes.extend_from_slice(&payload_bytes); + challenge_bytes.extend_from_slice(&signature_bytes); + + Ok(BASE64_URL_SAFE.encode(challenge_bytes)) + } + + pub fn verify_challenge(&self, challenge: &str, solution: &str, binding_data: &[u8]) -> bool { + let now = unix_now_secs(); let need_rotation = { - let last_rotation = self.last_rotation.read().unwrap(); - now - *last_rotation > self.valid_time_window + let last_rotation = match self.last_rotation.read() { + Ok(last_rotation) => last_rotation, + Err(_) => return false, + }; + now.saturating_sub(*last_rotation) > self.valid_time_window }; - if need_rotation { - let mut last_rotation = self.last_rotation.write().unwrap(); - let mut tickets_previous = self.tickets_previous.write().unwrap(); - let mut tickets_current = self.tickets_current.write().unwrap(); - *tickets_previous = std::mem::take(&mut *tickets_current); - *last_rotation = now; + if need_rotation && !self.rotate_tickets_if_needed(now) { + return false; } // Reject challenges that are too long - if challenge.len() > 10240 || solution.len() > 1024 { + if challenge.len() > MAX_CHALLENGE_LEN + || solution.len() > MAX_SOLUTION_LEN + || binding_data.len() > MAX_BINDING_DATA_LEN + { return false; } - // Split the challenge into the JSON payload and the signature - let parts: Vec<&str> = challenge.split('.').collect(); - if parts.len() != 2 { - return false; - } - - let json_payload = parts[0]; - let signature_payload = parts[1]; - - // Decode the signature - let signature_bytes: [u8; 64] = match BASE64_URL_SAFE.decode(signature_payload.as_bytes()) { - Ok(bytes) => match bytes.try_into() { - Ok(b) => b, - Err(_) => return false, - }, + let challenge_bytes = match BASE64_URL_SAFE.decode(challenge.as_bytes()) { + Ok(bytes) => bytes, Err(_) => return false, }; - // Verify the signature on the JSON payload - if !self - .signer - .verify(json_payload.as_bytes(), &signature_bytes) - { + if challenge_bytes.len() <= 64 { return false; } - // Decode the JSON payload - let puzzle: Puzzle = match BASE64_URL_SAFE.decode(json_payload.as_bytes()) { - Ok(bytes) => match serde_json::from_slice(&bytes) { - Ok(s) => s, - Err(_) => return false, - }, + let signature_offset = challenge_bytes.len() - 64; + let payload_bytes = &challenge_bytes[..signature_offset]; + let mut signature_bytes = [0_u8; 64]; + signature_bytes.copy_from_slice(&challenge_bytes[signature_offset..]); + + // Verify the signature on the raw binary payload. + if !self.signer.verify(payload_bytes, &signature_bytes) { + return false; + } + + let puzzle: Puzzle = match bincode::deserialize(payload_bytes) { + Ok(s) => s, Err(_) => return false, }; - // Reject challenges that are too old - if puzzle.issued_at < now - self.valid_time_window { + if puzzle.modulus != self.algorithm.modulus { return false; } - // Reject challenges that have already been solved - { - if self - .tickets_current - .read() - .unwrap() - .contains(&puzzle.ticket) - { - return false; - } - } - - { - if self - .tickets_previous - .read() - .unwrap() - .contains(&puzzle.ticket) - { - return false; - } + // Reject expired challenges. + if puzzle.issued_at < now.saturating_sub(self.valid_time_window) { + return false; } // Decode the solution @@ -237,165 +290,606 @@ impl Challenger { Err(_) => return false, }; - // Verify the solution - let solution_result = - self.algorithm - .verify(&solution, &puzzle.challenge, puzzle.difficulty); + let base = crate::rsa::derive_base(payload_bytes, binding_data, &puzzle.modulus); - if solution_result { - self.tickets_current.write().unwrap().insert(puzzle.ticket); + // Verify the solution + let solution_result = self.algorithm.verify( + &base, + &solution, + payload_bytes, + binding_data, + puzzle.difficulty, + ); + + if solution_result && !self.mark_ticket_used(puzzle.ticket) { + return false; } solution_result } + + fn mark_ticket_used(&self, ticket: u64) -> bool { + let tickets_previous = match self.tickets_previous.read() { + Ok(tickets_previous) => tickets_previous, + Err(_) => return false, + }; + if tickets_previous.contains(&ticket) { + return false; + } + + let mut tickets_current = match self.tickets_current.write() { + Ok(tickets_current) => tickets_current, + Err(_) => return false, + }; + if tickets_current.contains(&ticket) { + return false; + } + + if tickets_current.len() >= self.ticket_cache_capacity { + return false; + } + + tickets_current.insert(ticket); + true + } + + fn rotate_tickets_if_needed(&self, now: u64) -> bool { + let mut last_rotation = match self.last_rotation.write() { + Ok(last_rotation) => last_rotation, + Err(_) => return false, + }; + if now.saturating_sub(*last_rotation) <= self.valid_time_window { + return true; + } + + let mut tickets_previous = match self.tickets_previous.write() { + Ok(tickets_previous) => tickets_previous, + Err(_) => return false, + }; + let mut tickets_current = match self.tickets_current.write() { + Ok(tickets_current) => tickets_current, + Err(_) => return false, + }; + *tickets_previous = std::mem::take(&mut *tickets_current); + *last_rotation = now; + true + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn unix_now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +#[cfg(not(target_arch = "wasm32"))] +fn validate_factors(p: &BigUint, q: &BigUint) -> Result<(), ChallengerBuildError> { + if p == q { + return Err(ChallengerBuildError::EqualFactors); + } + + if (p * q).bits() < MIN_RSA_MODULUS_BITS { + return Err(ChallengerBuildError::ModulusTooSmall); + } + + if !crate::prime::is_probably_prime(p, 64) || !crate::prime::is_probably_prime(q, 64) { + return Err(ChallengerBuildError::NonPrimeFactor); + } + + Ok(()) } #[wasm_bindgen] -pub fn solve_challenge(challenge: &str) -> String { - let parts: Vec<&str> = challenge.split('.').collect(); - if parts.len() != 2 { +pub fn solve_challenge(challenge: &str, binding_data: &[u8]) -> String { + if challenge.len() > MAX_CHALLENGE_LEN || binding_data.len() > MAX_BINDING_DATA_LEN { return "ERROR".to_string(); } - let json_payload = parts[0]; - let puzzle: Puzzle = match BASE64_URL_SAFE.decode(json_payload.as_bytes()) { - Ok(bytes) => match serde_json::from_slice(&bytes) { - Ok(s) => s, - Err(_) => return "ERROR".to_string(), - }, + let challenge_bytes = match BASE64_URL_SAFE.decode(challenge.as_bytes()) { + Ok(bytes) => bytes, Err(_) => return "ERROR".to_string(), }; - let y = puzzle.challenge.clone(); - let p = puzzle.prime.clone(); - let t = puzzle.difficulty; + if challenge_bytes.len() <= 64 { + return "ERROR".to_string(); + } - let solution = crate::sloth::Sloth::decode(&y, &p, t); + let signature_offset = challenge_bytes.len() - 64; + let payload_bytes = &challenge_bytes[..signature_offset]; + + let puzzle: Puzzle = match bincode::deserialize(payload_bytes) { + Ok(s) => s, + Err(_) => return "ERROR".to_string(), + }; + + if !puzzle_looks_ok_for_solve(&puzzle) { + return "ERROR".to_string(); + } + + let base = crate::rsa::derive_base(payload_bytes, binding_data, &puzzle.modulus); + + let solution = crate::rsa::solve( + &base, + &puzzle.modulus, + payload_bytes, + binding_data, + puzzle.difficulty, + ); BASE64_URL_SAFE.encode(solution.to_bytes_be()) } +fn puzzle_looks_ok_for_solve(puzzle: &Puzzle) -> bool { + let modulus_bits = puzzle.modulus.bits(); + (MIN_RSA_MODULUS_BITS..=MAX_SOLVE_MODULUS_BITS).contains(&modulus_bits) + && puzzle.difficulty <= MAX_SOLVE_DIFFICULTY_STEPS +} + #[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod tests { use super::*; + use std::sync::RwLock; use std::thread; #[test] - fn test_challenge() { - let p = crate::prime::generate_prime_mod_3_4(4096, 64); - println!("prime: {}", p); - let mut challenger = ChallengerBuilder::new() - .with_prime(p) - .with_difficulty(1000) + fn test_default_difficulty() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + assert_eq!(DEFAULT_DIFFICULTY_STEPS, challenger.difficulty); + } + + #[test] + fn test_builder_helpers() { + let sign_key = ChallengerBuilder::generate_sign_key(); + let factor = ChallengerBuilder::generate_factor(16); + let (p, q) = ChallengerBuilder::generate_factors(512); + let mut challenger = ChallengerBuilder::default() + .with_sign_key(sign_key) + .with_factors(p, q) + .with_difficulty(10) + .build() + .unwrap(); + + assert!(factor.bits() >= 16); + assert_eq!(10, challenger.difficulty); + + challenger.set_difficulty(11); + assert_eq!(11, challenger.difficulty); + } + + #[test] + fn test_generate_factors_retries_until_valid_pair() { + let big_p: BigUint = (BigUint::from(1_u32) << 256) + BigUint::from(1_u32); + let big_q: BigUint = (BigUint::from(1_u32) << 256) + BigUint::from(3_u32); + let values = [ + big_p.clone(), + big_p.clone(), + BigUint::from(3_u32), + BigUint::from(5_u32), + big_p, + big_q, + ]; + let mut index = 0; + + let (p, q) = generate_factors_with(MIN_RSA_MODULUS_BITS, |_| { + let value = values[index].clone(); + index += 1; + value + }); + + assert_ne!(p, q); + assert!((&p * &q).bits() >= MIN_RSA_MODULUS_BITS); + assert_eq!(values.len(), index); + } + + #[test] + fn test_modulus_bits_are_configurable() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .with_difficulty(10) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); + + assert!(challenger.verify_challenge(&challenge, &solution, &[])); + } + + #[test] + fn test_modulus_bits_have_minimum() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1) + .with_difficulty(10) + .build() + .unwrap(); + + assert!(challenger.algorithm.modulus.bits() >= MIN_RSA_MODULUS_BITS); + } + + #[test] + fn test_invalid_factors_fail_build() { + let result = ChallengerBuilder::new() + .with_factors(BigUint::from(3_u32), BigUint::from(3_u32)) .build(); - let challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + + assert_eq!(Some(ChallengerBuildError::EqualFactors), result.err()); + + let result = ChallengerBuilder::new() + .with_factors(BigUint::from(3_u32), BigUint::from(5_u32)) + .build(); + + assert_eq!(Some(ChallengerBuildError::ModulusTooSmall), result.err()); + + let p = BigUint::from(1_u32) << 256; + let q = (BigUint::from(1_u32) << 256) + BigUint::from(2_u32); + let result = ChallengerBuilder::new().with_factors(p, q).build(); + + assert_eq!(Some(ChallengerBuildError::NonPrimeFactor), result.err()); + } + + #[test] + fn test_issue_challenge_reports_serialize_error() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + let puzzle = Puzzle { + ticket: 1, + issued_at: unix_now_secs(), + modulus: challenger.algorithm.modulus.clone(), + difficulty: 1, + }; + + assert_eq!( + Err(ChallengeIssueError::Serialize), + challenger.issue_challenge_with(&puzzle, |_| Err(ChallengeIssueError::Serialize)) + ); + } + + #[test] + fn test_rotation_rechecks_after_write_lock() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .with_valid_time_window(300) + .build() + .unwrap(); + let now = unix_now_secs(); + *challenger.last_rotation.write().unwrap() = now - 301; + + assert!(challenger.rotate_tickets_if_needed(now)); + assert_eq!(now, *challenger.last_rotation.read().unwrap()); + + assert!(challenger.rotate_tickets_if_needed(now)); + assert_eq!(now, *challenger.last_rotation.read().unwrap()); + } + + #[test] + fn test_poisoned_locks_fail_closed() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + poison_lock(&challenger.last_rotation); + assert!(!challenger.verify_challenge("", "", &[])); + assert!(!challenger.rotate_tickets_if_needed(unix_now_secs())); + + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + poison_lock(&challenger.tickets_previous); + assert!(!challenger.mark_ticket_used(1)); + + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + poison_lock(&challenger.tickets_current); + assert!(!challenger.mark_ticket_used(1)); + + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + let now = unix_now_secs(); + *challenger.last_rotation.write().unwrap() = now - 301; + poison_lock(&challenger.tickets_previous); + assert!(!challenger.rotate_tickets_if_needed(now)); + + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .build() + .unwrap(); + let now = unix_now_secs(); + *challenger.last_rotation.write().unwrap() = now - 301; + poison_lock(&challenger.tickets_current); + assert!(!challenger.verify_challenge("", "", &[])); + } + + #[test] + fn test_ticket_cache_capacity_fails_closed() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(512) + .with_difficulty(10) + .with_ticket_cache_capacity(1) + .build() + .unwrap(); + + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); + assert!(challenger.verify_challenge(&challenge, &solution, &[])); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); + + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); + } + + #[test] + fn test_solve_rejects_implausible_payload() { + let puzzle = Puzzle { + ticket: 1, + issued_at: 1, + modulus: BigUint::from(7_u32), + difficulty: 1, + }; + let challenge = unsigned_challenge_for(&puzzle); + + assert_eq!("ERROR", solve_challenge(&challenge, &[])); + + let puzzle = Puzzle { + ticket: 1, + issued_at: 1, + modulus: (BigUint::from(1_u32) << (MIN_RSA_MODULUS_BITS as usize - 1)) + | BigUint::from(1_u32), + difficulty: MAX_SOLVE_DIFFICULTY_STEPS + 1, + }; + let challenge = unsigned_challenge_for(&puzzle); + + assert_eq!("ERROR", solve_challenge(&challenge, &[])); + } + + #[test] + fn test_solve_rejects_malformed_inputs() { + assert_eq!( + "ERROR", + solve_challenge(&"A".repeat(MAX_CHALLENGE_LEN + 1), &[]) + ); + assert_eq!( + "ERROR", + solve_challenge("rin", &vec![0_u8; MAX_BINDING_DATA_LEN + 1]) + ); + assert_eq!("ERROR", solve_challenge("not base64!", &[])); + assert_eq!( + "ERROR", + solve_challenge(&BASE64_URL_SAFE.encode([0_u8; 64]), &[]) + ); + assert_eq!( + "ERROR", + solve_challenge(&BASE64_URL_SAFE.encode([0_u8; 65]), &[]) + ); + } + + #[test] + fn test_challenge() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(2048) + .with_difficulty(450_000) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); println!("challenge: {}", challenge); println!("solution: {}", solution); - assert!(challenger.verify_challenge(&challenge, &solution)); + assert!(challenger.verify_challenge(&challenge, &solution, &[])); } #[test] fn test_challenge_incorrect() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) - .build(); - let challenge = challenger.issue_challenge(); - let mut solution = solve_challenge(&challenge); + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let mut solution = solve_challenge(&challenge, &[]); solution.replace_range(1..4, "abcd"); - assert_eq!(false, challenger.verify_challenge(&challenge, &solution)); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); } #[test] - fn test_challenge_invaild_challenge() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) - .build(); - let mut challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + fn test_challenge_invalid_challenge() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let mut challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); challenge.replace_range(1..4, "abcd"); - assert_eq!(false, challenger.verify_challenge(&challenge, &solution)); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); } #[test] - fn test_challenge_invaild_signature() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) - .build(); - let mut challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + fn test_challenge_invalid_signature() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let mut challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); challenge.replace_range(challenge.len() - 5..challenge.len() - 1, "abcd"); - assert_eq!(false, challenger.verify_challenge(&challenge, &solution)); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); + + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); + let mut bytes = BASE64_URL_SAFE.decode(challenge.as_bytes()).unwrap(); + let last = bytes.len() - 1; + bytes[last] ^= 1; + let challenge = BASE64_URL_SAFE.encode(bytes); + + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); } #[test] fn test_challenge_reuse_attack() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) - .build(); - let challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); - assert!(challenger.verify_challenge(&challenge, &solution)); - assert_eq!(false, challenger.verify_challenge(&challenge, &solution)); + assert!(challenger.verify_challenge(&challenge, &solution, &[])); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); } #[test] - fn test_challenge_invaild_data() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) - .build(); + fn test_challenge_previous_ticket_reuse_attack() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let puzzle = puzzle_from_challenge(&challenge); + let solution = solve_challenge(&challenge, &[]); - assert_eq!( - false, - challenger.verify_challenge(&"".to_string(), &"".to_string()) - ); - assert_eq!( - false, - challenger.verify_challenge(&"rin".to_string(), &"cat".to_string()) - ); + challenger + .tickets_previous + .write() + .unwrap() + .insert(puzzle.ticket); + + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); + } + + #[test] + fn test_challenge_invalid_data() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + + assert!(!challenger.verify_challenge(&"A".repeat(MAX_CHALLENGE_LEN + 1), "", &[])); + assert!(!challenger.verify_challenge("", &"A".repeat(MAX_SOLUTION_LEN + 1), &[])); + assert!(!challenger.verify_challenge("", "", &vec![0_u8; MAX_BINDING_DATA_LEN + 1])); + assert!(!challenger.verify_challenge("", "", &[])); + assert!(!challenger.verify_challenge("rin", "cat", &[])); + + let challenge = signed_raw_challenge_for(&challenger, &[0_u8]); + assert!(!challenger.verify_challenge(&challenge, "", &[])); + + let sign_key = ChallengerBuilder::generate_sign_key(); + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .with_sign_key(sign_key) + .build() + .unwrap(); + let other = ChallengerBuilder::new() + .with_modulus_bits(1024) + .with_sign_key(sign_key) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); + + assert!(!other.verify_challenge(&challenge, &solution, &[])); } #[test] fn test_challenge_expired() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) .with_valid_time_window(1) - .build(); + .build() + .unwrap(); - let challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); thread::sleep(std::time::Duration::from_secs(2)); - assert_eq!(false, challenger.verify_challenge(&challenge, &solution)); + assert!(!challenger.verify_challenge(&challenge, &solution, &[])); } #[test] fn test_challenge_tickets() { - let mut challenger = ChallengerBuilder::new() - .with_prime(crate::prime::generate_prime_mod_3_4(2048, 64)) + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) .with_valid_time_window(5) - .build(); + .build() + .unwrap(); // Normal case - let challenge = challenger.issue_challenge(); - let solution = solve_challenge(&challenge); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, &[]); - assert_eq!(true, challenger.verify_challenge(&challenge, &solution)); + assert!(challenger.verify_challenge(&challenge, &solution, &[])); // expired case - let challenge1 = challenger.issue_challenge(); - let solution1 = solve_challenge(&challenge1); + let challenge1 = challenger.issue_challenge().unwrap(); + let solution1 = solve_challenge(&challenge1, &[]); thread::sleep(std::time::Duration::from_secs(6)); // Normal case - let challenge2 = challenger.issue_challenge(); - let solution2 = solve_challenge(&challenge2); + let challenge2 = challenger.issue_challenge().unwrap(); + let solution2 = solve_challenge(&challenge2, &[]); - assert_eq!(false, challenger.verify_challenge(&challenge1, &solution1)); - assert_eq!(true, challenger.verify_challenge(&challenge2, &solution2)); + assert!(!challenger.verify_challenge(&challenge1, &solution1, &[])); + assert!(challenger.verify_challenge(&challenge2, &solution2, &[])); + } + + #[test] + fn test_challenge_binding_data_is_bound() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, b"login:user=a"); + + assert!(challenger.verify_challenge(&challenge, &solution, b"login:user=a")); + } + + #[test] + fn test_challenge_binding_data_mismatch() { + let challenger = ChallengerBuilder::new() + .with_modulus_bits(1024) + .build() + .unwrap(); + let challenge = challenger.issue_challenge().unwrap(); + let solution = solve_challenge(&challenge, b"login:user=a"); + + assert!(!challenger.verify_challenge(&challenge, &solution, b"login:user=b")); + } + + fn unsigned_challenge_for(puzzle: &Puzzle) -> String { + let mut challenge_bytes = bincode::serialize(puzzle).unwrap(); + challenge_bytes.extend_from_slice(&[0_u8; 64]); + BASE64_URL_SAFE.encode(challenge_bytes) + } + + fn signed_raw_challenge_for(challenger: &Challenger, payload: &[u8]) -> String { + let signature = challenger.signer.sign(payload); + let mut challenge_bytes = payload.to_vec(); + challenge_bytes.extend_from_slice(&signature.to_bytes()); + BASE64_URL_SAFE.encode(challenge_bytes) + } + + fn puzzle_from_challenge(challenge: &str) -> Puzzle { + let challenge_bytes = BASE64_URL_SAFE.decode(challenge.as_bytes()).unwrap(); + let payload_bytes = &challenge_bytes[..challenge_bytes.len() - 64]; + bincode::deserialize(payload_bytes).unwrap() + } + + fn poison_lock(lock: &RwLock) { + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let _guard = lock.write().unwrap(); + panic!("poison lock"); + })); + assert!(result.is_err()); } } diff --git a/src/prime.rs b/src/prime.rs index f1dd9e6..a2bb9ff 100644 --- a/src/prime.rs +++ b/src/prime.rs @@ -1,7 +1,7 @@ -use num_bigint::{BigUint, RandBigInt, ToBigUint}; +use num_bigint::BigUint; use num_integer::Integer; use num_traits::One; -use rand::thread_rng; +use rand::RngExt; /// Miller–Rabin primality test /// https://en.wikipedia.org/wiki/Miller-Rabin_primality_test @@ -10,17 +10,17 @@ use rand::thread_rng; /// This is a basic implementation for odd candidate `n` greater than 2. pub fn is_probably_prime(n: &BigUint, k: u32) -> bool { // Handle small numbers. - if n == &2u32.to_biguint().unwrap() { + if n == &BigUint::from(2_u32) || n == &BigUint::from(3_u32) { return true; } - if n < &2u32.to_biguint().unwrap() || n.is_even() { + if n < &BigUint::from(2_u32) || n.is_even() { return false; } // Write n − 1 as d * 2^r. let one: BigUint = One::one(); - let two: BigUint = 2u32.to_biguint().unwrap(); + let two = BigUint::from(2_u32); let n_minus_one = n - &one; @@ -33,11 +33,11 @@ pub fn is_probably_prime(n: &BigUint, k: u32) -> bool { } // Try k random witnesses. - let mut rng = thread_rng(); + let mut rng = rand::rng(); 'witness_loop: for _ in 0..k { // Choose a random a in [2, n-2] - let a = rng.gen_biguint_range(&two, &(n - &two)); + let a = random_biguint_range(&mut rng, &two, &n_minus_one); // Compute x = a^d mod n. let mut x = a.modpow(&d, n); @@ -60,14 +60,53 @@ pub fn is_probably_prime(n: &BigUint, k: u32) -> bool { true } +fn random_biguint_range( + rng: &mut R, + low: &BigUint, + high: &BigUint, +) -> BigUint { + if low >= high { + return low.clone(); + } + + let bits = high.bits(); + let byte_len = bits.div_ceil(8) as usize; + let excess_bits = byte_len * 8 - bits as usize; + let mut bytes = vec![0_u8; byte_len]; + + loop { + rng.fill(&mut bytes); + + if excess_bits > 0 { + bytes[0] &= 0xff >> excess_bits; + } + + let candidate = BigUint::from_bytes_be(&bytes); + if &candidate >= low && &candidate < high { + return candidate; + } + } +} + pub fn generate_prime_mod_3_4(bits: u64, rounds: u32) -> BigUint { - let mut rng = thread_rng(); - let four: BigUint = 4u32.to_biguint().unwrap(); - let three: BigUint = 3u32.to_biguint().unwrap(); + let bits = bits.max(2); + + let mut rng = rand::rng(); + let byte_len = bits.div_ceil(8) as usize; + let excess_bits = byte_len * 8 - bits as usize; + let mut bytes = vec![0_u8; byte_len]; + let four = BigUint::from(4_u32); + let three = BigUint::from(3_u32); loop { // Generate a random number of the required bit-length. - let mut candidate = rng.gen_biguint(bits); + rng.fill(&mut bytes); + + if excess_bits > 0 { + bytes[0] &= 0xff >> excess_bits; + } + + let mut candidate = BigUint::from_bytes_be(&bytes); // Force the most significant bit to ensure it is exactly `bits` long. candidate.set_bit(bits - 1, true); @@ -78,21 +117,10 @@ pub fn generate_prime_mod_3_4(bits: u64, rounds: u32) -> BigUint { // Compute candidate mod 4. let rem = &candidate % &four; if rem != three { - // Calculate the adjustment needed. - // We want candidate + adj ≡ 3 (mod 4); that is, adj ≡ (3 - rem) mod 4. - let adj = if three >= rem { - &three - &rem - } else { - &three + &four - &rem - }; + let adj = &three - &rem; candidate += &adj; } - // Now candidate mod 4 should equal 3. - if &candidate % &four != three { - continue; // Should not happen, but be safe. - } - // Use Miller–Rabin to test candidate for primality. if is_probably_prime(&candidate, rounds) { return candidate; @@ -108,6 +136,9 @@ mod tests { /// This test has chance of false positives since it's probabilistic. #[test] fn test_is_probably_prime() { + assert!(is_probably_prime(&BigUint::from(2_u32), 64)); + assert!(is_probably_prime(&BigUint::from(3_u32), 64)); + let primes: Vec = vec![ BigUint::parse_bytes(b"723646214863847842402314246121044767400617866176733021174245260448070467161519753555151305391831172396032179266088879736498934532967238875067731186605319314486487094813782345277515046149035823394700558031365128080643117834402421935144013956523482034192169360458395261772557972018417296402072764848759", 10).unwrap(), BigUint::parse_bytes(b"268263962333296278340388301081833650583348564229009436402694247537863120457689419730666587700766950987474095807568632415133217325374788947918148879191904084190506645642271822238922848768059332086841470078498514866531550241226838886983034780850983546266212727522552823301120544245546076442247019621281", 10).unwrap(), @@ -122,7 +153,7 @@ mod tests { ]; for p in primes.iter() { - assert!(is_probably_prime(&p, 64)); + assert!(is_probably_prime(p, 64)); } } @@ -142,20 +173,25 @@ mod tests { ]; for c in composites.iter() { - assert!(!is_probably_prime(&c, 64)); + assert!(!is_probably_prime(c, 64)); } } #[test] fn test_generate_prime_mod_3_4() { - let bits = 2048; + let bits = 9; let rounds = 64; let prime = generate_prime_mod_3_4(bits, rounds); assert!(is_probably_prime(&prime, rounds)); assert_eq!(prime.bits(), bits); - assert_eq!( - &prime % 4u32.to_biguint().unwrap(), - 3u32.to_biguint().unwrap() - ); + assert_eq!(&prime % BigUint::from(4_u32), BigUint::from(3_u32)); + } + + #[test] + fn test_random_biguint_range_handles_empty_range() { + let mut rng = rand::rng(); + let value = BigUint::from(7_u32); + + assert_eq!(value, random_biguint_range(&mut rng, &value, &value)); } } diff --git a/src/rsa.rs b/src/rsa.rs new file mode 100644 index 0000000..551b7d6 --- /dev/null +++ b/src/rsa.rs @@ -0,0 +1,250 @@ +use num_bigint::BigUint; +use num_traits::One; +use sha2::{Digest, Sha256, Sha512}; + +const DOMAIN: &[u8] = b"chidori-pow-rsa-v1"; +const SCHEDULE_DOMAIN: &[u8] = b"chidori-pow-rsa-schedule-v1"; +const SCHEDULE_BLOCK_STEPS: u32 = 64; + +#[cfg(not(target_arch = "wasm32"))] +pub struct RsaTrapdoor { + pub modulus: BigUint, + lambda: BigUint, +} + +#[cfg(not(target_arch = "wasm32"))] +impl RsaTrapdoor { + pub fn generate(modulus_bits: u64, rounds: u32) -> Self { + generate_with(modulus_bits, |bits| { + crate::prime::generate_prime_mod_3_4(bits, rounds) + }) + } + + pub fn from_factors(p: BigUint, q: BigUint) -> Self { + use num_integer::Integer; + use num_traits::One; + + let modulus = &p * &q; + let lambda = (&p - BigUint::one()).lcm(&(&q - BigUint::one())); + + Self { modulus, lambda } + } + + pub fn verify( + &self, + base: &BigUint, + solution: &BigUint, + payload: &[u8], + binding_data: &[u8], + difficulty: u32, + ) -> bool { + let exponent = derive_exponent(payload, binding_data, difficulty, &self.lambda); + let expected = base.modpow(&exponent, &self.modulus); + &expected == solution + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn generate_with(modulus_bits: u64, mut generate_prime: impl FnMut(u64) -> BigUint) -> RsaTrapdoor { + let modulus_bits = modulus_bits.max(crate::MIN_RSA_MODULUS_BITS); + let factor_bits = modulus_bits.div_ceil(2); + + loop { + let p = generate_prime(factor_bits); + let q = generate_prime(factor_bits); + if q == p { + continue; + } + + let trapdoor = RsaTrapdoor::from_factors(p, q); + if trapdoor.modulus.bits() < modulus_bits { + continue; + } + + return trapdoor; + } +} + +pub fn solve( + base: &BigUint, + modulus: &BigUint, + payload: &[u8], + binding_data: &[u8], + difficulty: u32, +) -> BigUint { + let mut solution = base.clone(); + let schedule = Schedule::new(payload, binding_data); + + for step in 0..difficulty { + solution = (&solution * &solution) % modulus; + if schedule.cube_after_square(step) { + solution = (&solution * &solution * &solution) % modulus; + } + } + solution +} + +#[cfg(not(target_arch = "wasm32"))] +fn derive_exponent( + payload: &[u8], + binding_data: &[u8], + difficulty: u32, + modulus: &BigUint, +) -> BigUint { + let mut exponent = BigUint::one(); + let schedule = Schedule::new(payload, binding_data); + let two = BigUint::from(2_u32); + let mut step = 0_u32; + + while step < difficulty { + let square_steps = SCHEDULE_BLOCK_STEPS.min(difficulty - step); + exponent *= two.modpow(&BigUint::from(square_steps), modulus); + exponent %= modulus; + + if schedule.cube_after_square(step) { + exponent *= 3_u32; + exponent %= modulus; + } + + step += square_steps; + } + + exponent +} + +struct Schedule { + seed: [u8; 32], +} + +impl Schedule { + fn new(payload: &[u8], binding_data: &[u8]) -> Self { + let mut hasher = Sha256::new(); + hasher.update(SCHEDULE_DOMAIN); + hasher.update((payload.len() as u64).to_be_bytes()); + hasher.update(payload); + hasher.update((binding_data.len() as u64).to_be_bytes()); + hasher.update(binding_data); + + Self { + seed: hasher.finalize().into(), + } + } + + fn cube_after_square(&self, step: u32) -> bool { + if !step.is_multiple_of(SCHEDULE_BLOCK_STEPS) { + return false; + } + + let block = step / SCHEDULE_BLOCK_STEPS; + let byte = self.seed[(block as usize) % self.seed.len()]; + let bit = (byte >> (block % 8)) & 1; + bit == 1 + } +} + +pub fn derive_base(payload: &[u8], binding_data: &[u8], modulus: &BigUint) -> BigUint { + if modulus <= &BigUint::one() { + return BigUint::one(); + } + + let byte_len = modulus.bits().div_ceil(8) as usize + 16; + let mut bytes = Vec::with_capacity(byte_len); + let mut counter = 0_u64; + + while bytes.len() < byte_len { + let mut hasher = Sha512::new(); + hasher.update(DOMAIN); + hasher.update((payload.len() as u64).to_be_bytes()); + hasher.update(payload); + hasher.update((binding_data.len() as u64).to_be_bytes()); + hasher.update(binding_data); + hasher.update(counter.to_be_bytes()); + bytes.extend_from_slice(&hasher.finalize()); + counter += 1; + } + + bytes.truncate(byte_len); + + let one = BigUint::one(); + let range = modulus - &one; + (BigUint::from_bytes_be(&bytes) % range) + one +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::*; + + #[test] + fn test_rsa_trapdoor() { + let rsa = RsaTrapdoor::generate(512, 16); + let payload = b"payload"; + let binding_data = b"app-owned binding data"; + let difficulty = 10; + + let base = derive_base(payload, binding_data, &rsa.modulus); + let solution = solve(&base, &rsa.modulus, payload, binding_data, difficulty); + + assert!(rsa.verify(&base, &solution, payload, binding_data, difficulty)); + } + + #[test] + fn test_generate_retries_equal_factors() { + let p: BigUint = (BigUint::one() << 256) + BigUint::one(); + let q: BigUint = (BigUint::one() << 256) + BigUint::from(3_u32); + + let values = [ + p.clone(), + p.clone(), + BigUint::from(3_u32), + BigUint::from(5_u32), + p, + q, + ]; + let mut index = 0; + let rsa = generate_with(crate::MIN_RSA_MODULUS_BITS, |_| { + let value = values[index].clone(); + index += 1; + value + }); + + assert!(rsa.modulus.bits() >= crate::MIN_RSA_MODULUS_BITS); + assert_eq!(values.len(), index); + } + + #[test] + fn test_derive_base_handles_invalid_modulus() { + assert_eq!( + BigUint::one(), + derive_base(b"payload", b"binding", &BigUint::one()) + ); + } + + #[test] + fn test_binding_data_is_bound() { + let rsa = RsaTrapdoor::generate(512, 16); + let payload = b"payload"; + let difficulty = 10; + + let base = derive_base(payload, b"expected", &rsa.modulus); + let solution = solve(&base, &rsa.modulus, payload, b"expected", difficulty); + + let wrong_base = derive_base(payload, b"wrong", &rsa.modulus); + assert!(!rsa.verify(&wrong_base, &solution, payload, b"wrong", difficulty)); + } + + #[test] + fn test_schedule_changes_solution() { + let rsa = RsaTrapdoor::generate(512, 16); + let payload = b"payload"; + let difficulty = 128; + + let base_a = derive_base(payload, b"a", &rsa.modulus); + let base_b = derive_base(payload, b"b", &rsa.modulus); + let solution_a = solve(&base_a, &rsa.modulus, payload, b"a", difficulty); + let solution_b = solve(&base_b, &rsa.modulus, payload, b"b", difficulty); + + assert_ne!(solution_a, solution_b); + assert!(rsa.verify(&base_a, &solution_a, payload, b"a", difficulty)); + assert!(rsa.verify(&base_b, &solution_b, payload, b"b", difficulty)); + } +} diff --git a/src/sloth.rs b/src/sloth.rs deleted file mode 100644 index 13419a0..0000000 --- a/src/sloth.rs +++ /dev/null @@ -1,111 +0,0 @@ -use num_bigint::{BigUint, ToBigUint}; -use num_traits::One; - -#[cfg(not(target_arch = "wasm32"))] -use rand::thread_rng; - -#[cfg(not(target_arch = "wasm32"))] -use num_bigint::RandBigInt; - -/// Sloth is a verifiable delay function (VDF) that is designed to enforce -/// a predetermined amount of sequential work. - -pub struct Sloth { - /// The modulus `p` is a large prime number. - pub p: BigUint, -} - -impl Sloth { - /// Create a new Sloth VDF with the given modulus `p` and iteration count `t`. - #[cfg(not(target_arch = "wasm32"))] - pub fn new(p: BigUint) -> Self { - Sloth { p } - } - - /// Encode a message `x` using the Sloth VDF. - #[cfg(not(target_arch = "wasm32"))] - pub fn encode(&self, x: &BigUint, t: u32) -> BigUint { - let mut y = x.clone(); - - // Repeatedly square `x` T times mod `p`. - for _ in 0..t { - // y = y^2 mod p - y = (&y * &y) % &self.p; - } - y - } - - /// Decode a message `y` using the Sloth VDF. - pub fn decode(y: &BigUint, p: &BigUint, t: u32) -> BigUint { - let mut x = y.clone(); - for _ in 0..t { - x = Sloth::mod_sqrt(&x, p); - } - x - } - - /// Create a new tuple `(x, y)` where `x` is a random secret number - /// and `y` is the encoded challenge value. - #[cfg(not(target_arch = "wasm32"))] - pub fn create(&self, t: u32) -> (BigUint, BigUint, BigUint) { - let x = thread_rng().gen_biguint_range(&BigUint::one(), &self.p); - let y = self.encode(&x, t); - (x, y, self.p.clone()) - } - - /// Verify that the encoded value `y` was computed from the secret `x`. - #[cfg(not(target_arch = "wasm32"))] - pub fn verify(&self, x: &BigUint, y: &BigUint, t: u32) -> bool { - let check = self.encode(x, t); - &check == y - } - - /// Compute a modular square root when p ≡ 3 mod 4: - /// sqrt(a) mod p = a^((p+1)/4) mod p - /// We'll pick the "lower root" consistently to ensure uniqueness. - fn mod_sqrt(a: &BigUint, p: &BigUint) -> BigUint { - // exponent = (p + 1) / 4 - let exp = (p + BigUint::one()) / 4_u32.to_biguint().unwrap(); - let root = a.modpow(&exp, p); - - // A prime p ≡ 3 mod 4 has exactly two roots: `r` and `p-r`. - // We'll choose the smaller one to ensure uniqueness. - let p_minus_root = (p - &root) % p; - if root <= p_minus_root { - root - } else { - p_minus_root - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sloth() { - let p = crate::prime::generate_prime_mod_3_4(2048, 64); - let t = 100; - let sloth = Sloth::new(p); - - let (x, y, p) = sloth.create(t); - assert!(sloth.verify(&x, &y, t)); - - let decoded = Sloth::decode(&y, &p, t); - assert!(sloth.verify(&decoded, &y, t)); - } - - #[test] - fn test_sloth_incorrect() { - let p = crate::prime::generate_prime_mod_3_4(2048, 64); - let t = 100; - let sloth = Sloth::new(p); - - let (x, y, _p) = sloth.create(t); - assert!(sloth.verify(&x, &y, t)); - - let decoded = 1024_u32.to_biguint().unwrap(); - assert!(!sloth.verify(&decoded, &y, t)); - } -}