Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

miri: implement square root without relying on host floats #4026

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,20 +218,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
=> {
let [f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f32()?;
// Using host floats (but it's fine, these operations do not have guaranteed precision).
let f_host = f.to_host();
// Using host floats except for sqrt (but it's fine, these operations do not have
// guaranteed precision).
let res = match intrinsic_name {
"sinf32" => f_host.sin(),
"cosf32" => f_host.cos(),
"sqrtf32" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
"expf32" => f_host.exp(),
"exp2f32" => f_host.exp2(),
"logf32" => f_host.ln(),
"log10f32" => f_host.log10(),
"log2f32" => f_host.log2(),
"sinf32" => f.to_host().sin().to_soft(),
"cosf32" => f.to_host().cos().to_soft(),
"sqrtf32" => math::sqrt(f),
"expf32" => f.to_host().exp().to_soft(),
"exp2f32" => f.to_host().exp2().to_soft(),
"logf32" => f.to_host().ln().to_soft(),
"log10f32" => f.to_host().log10().to_soft(),
"log2f32" => f.to_host().log2().to_soft(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
Expand All @@ -247,20 +246,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
=> {
let [f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f64()?;
// Using host floats (but it's fine, these operations do not have guaranteed precision).
let f_host = f.to_host();
// Using host floats except for sqrt (but it's fine, these operations do not have
// guaranteed precision).
let res = match intrinsic_name {
"sinf64" => f_host.sin(),
"cosf64" => f_host.cos(),
"sqrtf64" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
"expf64" => f_host.exp(),
"exp2f64" => f_host.exp2(),
"logf64" => f_host.ln(),
"log10f64" => f_host.log10(),
"log2f64" => f_host.log2(),
"sinf64" => f.to_host().sin().to_soft(),
"cosf64" => f.to_host().cos().to_soft(),
"sqrtf64" => math::sqrt(f),
"expf64" => f.to_host().exp().to_soft(),
"exp2f64" => f.to_host().exp2().to_soft(),
"logf64" => f.to_host().ln().to_soft(),
"log10f64" => f.to_host().log10().to_soft(),
"log2f64" => f.to_host().log2().to_soft(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
Expand Down
39 changes: 18 additions & 21 deletions src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,42 +104,39 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let ty::Float(float_ty) = op.layout.ty.kind() else {
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
// Using host floats (but it's fine, these operations do not have guaranteed precision).
// Using host floats except for sqrt (but it's fine, these operations do not
// have guaranteed precision).
match float_ty {
FloatTy::F16 => unimplemented!("f16_f128"),
FloatTy::F32 => {
let f = op.to_scalar().to_f32()?;
let f_host = f.to_host();
let res = match host_op {
"fsqrt" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
"fsin" => f_host.sin(),
"fcos" => f_host.cos(),
"fexp" => f_host.exp(),
"fexp2" => f_host.exp2(),
"flog" => f_host.ln(),
"flog2" => f_host.log2(),
"flog10" => f_host.log10(),
"fsqrt" => math::sqrt(f),
"fsin" => f.to_host().sin().to_soft(),
"fcos" => f.to_host().cos().to_soft(),
"fexp" => f.to_host().exp().to_soft(),
"fexp2" => f.to_host().exp2().to_soft(),
"flog" => f.to_host().ln().to_soft(),
"flog2" => f.to_host().log2().to_soft(),
"flog10" => f.to_host().log10().to_soft(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
FloatTy::F64 => {
let f = op.to_scalar().to_f64()?;
let f_host = f.to_host();
let res = match host_op {
"fsqrt" => f_host.sqrt(),
"fsin" => f_host.sin(),
"fcos" => f_host.cos(),
"fexp" => f_host.exp(),
"fexp2" => f_host.exp2(),
"flog" => f_host.ln(),
"flog2" => f_host.log2(),
"flog10" => f_host.log10(),
"fsqrt" => math::sqrt(f),
"fsin" => f.to_host().sin().to_soft(),
"fcos" => f.to_host().cos().to_soft(),
"fexp" => f.to_host().exp().to_soft(),
"fexp2" => f.to_host().exp2().to_soft(),
"flog" => f.to_host().ln().to_soft(),
"flog2" => f.to_host().log2().to_soft(),
"flog10" => f.to_host().log10().to_soft(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ mod eval;
mod helpers;
mod intrinsics;
mod machine;
mod math;
mod mono_hash_map;
mod operator;
mod provenance_gc;
Expand Down
164 changes: 164 additions & 0 deletions src/math.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
use rand::Rng as _;
use rand::distributions::Distribution as _;
use rustc_apfloat::Float as _;
use rustc_apfloat::ieee::IeeeFloat;

/// Perturbs `val` by a random relative error drawn from (-2^err_scale, 2^err_scale),
/// i.e. returns `val * (1 + err)` for a randomly chosen `err` in that interval.
///
/// The randomness is taken from the interpreter's RNG (`ecx.machine.rng`).
pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
    ecx: &mut crate::MiriInterpCx<'_>,
    val: F,
    err_scale: i32,
) -> F {
    let rng = ecx.machine.rng.get_mut();

    // Draw a uniformly distributed integer from [0, 2^PRECISION) ...
    let mantissa_range = rand::distributions::Uniform::new(0, 1 << F::PRECISION);
    let raw: u128 = mantissa_range.sample(rng);
    // ... and scale it by 2^(err_scale - PRECISION), giving a magnitude in [0, 2^err_scale).
    let shift = err_scale.strict_sub(F::PRECISION.try_into().unwrap());
    let magnitude = F::from_u128(raw).value.scalbn(shift);

    // Give the error a random sign.
    let err = if rng.gen::<bool>() { -magnitude } else { magnitude };

    // Multiply the value by (1 + err).
    let one = F::from_u128(1).value;
    let factor = (one + err).value;
    (val * factor).value
}

/// Computes the square root of `x` using only `rustc_apfloat` operations and `u128` integer
/// arithmetic (no host floats).
///
/// Zeros and NaNs are returned unchanged, every other negative input (including negative
/// infinity) yields NaN, and positive infinity yields infinity. For the remaining inputs the
/// root of the integer mantissa is computed one bit at a time, with one extra fractional bit
/// that decides the final round-to-nearest step.
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is specified in 754 right? It should probably just go in https://github.com/rust-lang/rustc_apfloat

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we have an open issue for that: rust-lang/rustc_apfloat#14.

However, last time we wanted to change rustc_apfloat the answer was that things should land in LLVM first and then rustc_apfloat will pick them up from there. That's a quite tedious process. Not sure if that is still the current policy.

match x.category() {
// preserve zero sign
rustc_apfloat::Category::Zero => x,
// propagate NaN
rustc_apfloat::Category::NaN => x,
// sqrt of negative number is NaN
// (zeros and NaNs were already matched above, so this guard only sees negative infinity
// and negative non-zero finite values)
_ if x.is_negative() => IeeeFloat::NAN,
// sqrt(∞) = ∞
// (only positive infinity reaches this arm; negative infinity hit the guard above)
rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
RalfJung marked this conversation as resolved.
Show resolved Hide resolved
rustc_apfloat::Category::Normal => {
// Floating point precision, excluding the integer bit
let prec = i32::try_from(S::PRECISION).unwrap() - 1;

// x = 2^(exp - prec) * mant
// where mant is an integer with prec+1 bits
// mant is a u128, which should be large enough for the largest prec (112 for f128)
let mut exp = x.ilogb();
let mut mant = x.scalbn(prec - exp).to_u128(128).value;

if exp % 2 != 0 {
// Make exponent even, so it can be divided by 2
// (doubling mant compensates for the decremented exponent, so the value
// 2^(exp - prec) * mant is unchanged)
exp -= 1;
mant <<= 1;
RalfJung marked this conversation as resolved.
Show resolved Hide resolved
}

// Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
// mant is treated here as a fixed point number with prec fractional bits.
// mant will be shifted left by one bit to have an extra fractional bit, which
// will be used to determine the rounding direction.

// res is the truncated sqrt of mant, where one bit is added at each iteration.
let mut res = 0u128;
// rem is the remainder with the current res
// rem_i = 2^i * ((mant<<1) - res_i^2)
// starting with res = 0, rem = mant<<1
let mut rem = mant << 1;
// s_i = 2*res_i
let mut s = 0u128;
// d is used to iterate over bits, from high to low (d_i = 2^(-i))
let mut d = 1u128 << (prec + 1);

// For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
// (res_i + b_j * 2^(-j))^2 <= mant<<1
// Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
// res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
// And rearranging the terms:
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i

// The loop ends once d has been shifted down to zero, i.e. after every bit position
// from the initial d down to the lowest bit has been probed.
while d != 0 {
// Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
// t = 2*res_i + 2^(-j)
let t = s + d;
if rem >= t {
// b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
res += d;
s += d + d;
rem -= t;
}
// Adjust rem for next iteration
rem <<= 1;
// Shift iterator
d >>= 1;
}

// Remove extra fractional bit from result, rounding to nearest.
// If the last bit is 0, then the nearest neighbor is definitely the lower one.
// If the last bit is 1, it sounds like this may either be a tie (if there's
// infinitely many 0s after this 1), or the nearest neighbor is the upper one.
// However, since square roots are either exact or irrational, and an exact root
// would lead to the last "extra" bit being 0, we can exclude a tie in this case.
// We therefore always round up if the last bit is 1. When the last bit is 0,
// adding 1 will not do anything since the shift will discard it.
res = (res + 1) >> 1;

// Build resulting value with res as mantissa and exp/2 as exponent
// (exp was made even above, so exp / 2 is exact)
IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
}
}
}

// Unit tests for the soft-float `sqrt` implementation above.
#[cfg(test)]
mod tests {
use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

use super::sqrt;

/// Checks `sqrt` against expected results for exactly representable roots (in all four
/// float formats) and for correctly rounded inexact roots (in single and double precision).
#[test]
fn test_sqrt() {
/// Parses `x` and `expected` as `IeeeFloat<S>` and asserts that `sqrt(x) == expected`.
/// `#[track_caller]` makes a failure point at the calling line rather than at this helper.
#[track_caller]
fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
let x: IeeeFloat<S> = x.parse().unwrap();
let expected: IeeeFloat<S> = expected.parse().unwrap();
let result = sqrt(x);
assert_eq!(result, expected);
}

/// Inputs whose square root is exactly representable, so the result must match exactly
/// for every float format `S` this is instantiated with.
fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
test::<S>("0", "0");
test::<S>("1", "1");
test::<S>("1.5625", "1.25");
test::<S>("2.25", "1.5");
test::<S>("4", "2");
test::<S>("5.0625", "2.25");
test::<S>("9", "3");
test::<S>("16", "4");
test::<S>("25", "5");
test::<S>("36", "6");
test::<S>("49", "7");
test::<S>("64", "8");
test::<S>("81", "9");
test::<S>("100", "10");

// Inputs and roots below 1.
test::<S>("0.5625", "0.75");
test::<S>("0.25", "0.5");
test::<S>("0.0625", "0.25");
test::<S>("0.00390625", "0.0625");
}

exact_tests::<HalfS>();
exact_tests::<SingleS>();
exact_tests::<DoubleS>();
exact_tests::<QuadS>();

// Inexact roots: the expected strings are the rounded results in the respective format.
test::<SingleS>("2", "1.4142135");
test::<DoubleS>("2", "1.4142135623730951");

test::<SingleS>("1.1", "1.0488088");
test::<DoubleS>("1.1", "1.0488088481701516");

test::<SingleS>("2.2", "1.4832398");
test::<DoubleS>("2.2", "1.4832396974191326");

// Very small inputs (below the smallest normal f32 / f64 value).
test::<SingleS>("1.22101e-40", "1.10499205e-20");
test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

// Very large inputs (the largest finite f32 / f64 value).
test::<SingleS>("3.4028235e38", "1.8446743e19");
test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
}
}
27 changes: 4 additions & 23 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use rand::Rng as _;
use rustc_abi::{ExternAbi, Size};
use rustc_apfloat::Float;
use rustc_apfloat::ieee::Single;
Expand Down Expand Up @@ -408,38 +407,20 @@ fn unary_op_f32<'tcx>(
let div = (Single::from_u128(1).value / op).value;
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
// inaccuracy of RCP.
let res = apply_random_float_error(ecx, div, -12);
let res = math::apply_random_float_error(ecx, div, -12);
interp_ok(Scalar::from_f32(res))
}
FloatUnaryOp::Rsqrt => {
let op = op.to_scalar().to_u32()?;
// FIXME using host floats
let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into());
let rsqrt = (Single::from_u128(1).value / sqrt).value;
let op = op.to_scalar().to_f32()?;
let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
// inaccuracy of RSQRT.
let res = apply_random_float_error(ecx, rsqrt, -12);
let res = math::apply_random_float_error(ecx, rsqrt, -12);
interp_ok(Scalar::from_f32(res))
}
}
}

/// Perturbs `val` by a random relative error drawn from (-2^err_scale, 2^err_scale),
/// i.e. returns `val * (1 + err)` for a randomly chosen `err` in that interval.
#[expect(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
fn apply_random_float_error<F: rustc_apfloat::Float>(
    ecx: &mut crate::MiriInterpCx<'_>,
    val: F,
    err_scale: i32,
) -> F {
    let rng = ecx.machine.rng.get_mut();

    // rand(0, 2^64) * 2^(err_scale - 64) = rand(0, 1) * 2^err_scale
    let raw = u128::from(rng.gen::<u64>());
    let magnitude = F::from_u128(raw).value.scalbn(err_scale.strict_sub(64));

    // Give the error a random sign.
    let err = if rng.gen::<bool>() { -magnitude } else { magnitude };

    // Multiply the value by (1 + err).
    let one = F::from_u128(1).value;
    (val * (one + err).value).value
}

/// Performs `which` operation on the first component of `op` and copies
/// the other components. The result is stored in `dest`.
fn unary_op_ss<'tcx>(
Expand Down
14 changes: 12 additions & 2 deletions tests/pass/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,10 +959,20 @@ pub fn libm() {
unsafe { ldexp(a, b) }
}

assert_approx_eq!(64f32.sqrt(), 8f32);
assert_approx_eq!(64f64.sqrt(), 8f64);
assert_eq!(64_f32.sqrt(), 8_f32);
assert_eq!(64_f64.sqrt(), 8_f64);
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert!((-5.0_f32).sqrt().is_nan());
assert!((-5.0_f64).sqrt().is_nan());
assert!(f32::NEG_INFINITY.sqrt().is_nan());
assert!(f64::NEG_INFINITY.sqrt().is_nan());
assert!(f32::NAN.sqrt().is_nan());
assert!(f64::NAN.sqrt().is_nan());

assert_approx_eq!(25f32.powi(-2), 0.0016f32);
assert_approx_eq!(23.2f64.powi(2), 538.24f64);
Expand Down