Skip to content

Commit

Permalink
allow also scalar/scalar ops
Browse files Browse the repository at this point in the history
  • Loading branch information
bertiqwerty committed Feb 15, 2024
1 parent d22e06f commit 0966743
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 12 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion rormula-rs/src/expression/ops_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ pub fn op_scalar(a: Value, b: Value, op: &impl Fn(f64, f64) -> f64) -> Value {
}
Value::Array(mem::take(arr))
};
let sc_vs_sc = |sc1, sc2| Value::Scalar(op(sc1, sc2));

match (a, b) {
(Value::Array(mut arr), Value::Scalar(sc)) => arr_vs_sc(&mut arr, sc),
(Value::Scalar(sc), Value::Array(mut arr)) => sc_vs_arr(sc, &mut arr),
_ => Value::Error("scalar op can only be applied to matrix and scalar".to_string()),
(Value::Scalar(sc1), Value::Scalar(sc2)) => sc_vs_sc(sc1, sc2),
_ => Value::Error(
"scalar op can only be applied to matrix and scalar or scalar and scalar".to_string(),
),
}
}

Expand Down
7 changes: 7 additions & 0 deletions rormula-rs/tests/test_rormula.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ fn test_arithmetic() {
let expr_ref = ExprArithmetic::parse(s_ref).unwrap();
let rev_val = expr_ref.eval(&[Value::Array(Array2d::ones(5, 1))]).unwrap();
assert_eq!(res, rev_val);
let s = "4*3";
let expr = ExprArithmetic::parse(s).unwrap();
let res = expr.eval_vec(vec![]).unwrap();
let sc_ref = Value::Scalar(12.0);
assert_eq!(res, sc_ref);
let s = "5/3 * alpha / beta * (0.2 / 200.0 / (29.22+gamma+epsilon+phi) / 7500)";
let _ = ExprArithmetic::parse(s).unwrap();
}
#[test]
fn test_restrict() {
Expand Down
2 changes: 1 addition & 1 deletion rormula/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "rormula"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20.0", features = ["generate-import-lib"] }
pyo3 = { version = "0.20.2", features = ["generate-import-lib"] }
numpy = "0.20.0"
rormula-rs = { path = "../rormula-rs" }

Expand Down
8 changes: 8 additions & 0 deletions rormula/test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def eval_asdf():
res = rormula.eval_asdf(df)
assert np.allclose(res.to_numpy().item(), (5.0 - 2.5) / 4.0)

data = np.ones((100, 6))
df = pd.DataFrame(
data=data, columns=["alpha", "beta", "gamma", "delta", "epsilon", "phi"]
)
s = "5/3 * alpha / beta * (0.2 / 200.0 / (29.22+gamma+epsilon+phi) / 1000)"
rormula = Arithmetic(s, "testslash")
res = rormula.eval_asdf(df)


if __name__ == "__main__":
test_arithmetic()

0 comments on commit 0966743

Please sign in to comment.