Skip to content

Commit

Permalink
update HNSW to use scale_modification
Browse files Browse the repository at this point in the history
Add scale_modification in HNSW to have HubNSW
  • Loading branch information
jianshu93 authored Jan 11, 2025
1 parent c4b9b1e commit f02113f
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions src/bin/annembed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,22 @@
//!
//! hnsw is an optional subcommand to change default parameters of the Hnsw structure. See [hnsw_rs](https://crates.io/crates/hnsw_rs).
//!
//! - Parameters for embedding.
//! The options are optional and give access to some fields of the [EmbedderParams] structure.
//!
//! --batch : optional, an integer giving the number of batches to run. Defaults to 15.
//! --stepg : optional, a float value, the initial gradient step. Default is 2.
//! --scale : optional, a float value, the scale modification factor. Default is 1.
//! --nbsample : optional, the number of edge samplings. Default is 10.
//! --layer : optional, in case of hierarchical embedding, the number of the lower layer considered for running the preliminary step.
//! Default is 0, meaning a one-pass embedding.
//! --dim : optional, dimension of the embedding. Defaults to 2.
//!
//! --quality : optional, asks for quality estimation.
//! --sampling : optional, for large data defines the fraction of sampled data as 1./sampling
//! - Parameters for the embedding part are all optional. The options give access to some fields of the [EmbedderParams] structure.
//! --batch : optional, an integer giving the number of batches to run. Defaults to 15.
//! --stepg : optional, a float value, the initial gradient step. Default is 2.
//! --scale : optional, a float value, the scale modification factor. Default is 1.
//! --nbsample : optional, the number of edge samplings. Default is 10.
//! --layer : optional, in case of hierarchical embedding, the number of the lower layer considered for running the preliminary step.
//! Default is 0, meaning a one-pass embedding.
//! --dim : optional, dimension of the embedding. Defaults to 2.
//! --quality : optional, asks for quality estimation.
//! --sampling : optional, for large data defines the fraction of sampled data as 1./sampling.
//!
//! - Parameters for the hnsw subcommand. For more details see [hnsw_rs](https://crates.io/crates/hnsw_rs).
//! --nbconn : defines the number of connections by node in a layer. Can range from 4 to 64, or more if necessary and if there is enough memory.
//! --dist : name of distance to use: "DistL1", "DistL2", "DistCosine", "DistJeffreys".
//! --ef : controls the width of the search, a good guess is between 24 and 64, or more if necessary.
//! --knbn : the number of nodes to use in retrieval requests.
//! --nbconn : defines the number of connections by node in a layer. Can range from 4 to 64, or more if necessary and if there is enough memory.
//! --dist : name of distance to use: "DistL1", "DistL2", "DistCosine", "DistJeffreys".
//! --ef : controls the width of the search, a good guess is between 24 and 64, or more if necessary.
//! --knbn : the number of nodes to use in retrieval requests.
//!
//! The csv file must have one record by vector to embed. The default delimiter is ','.
//! The output is a csv file with embedded vectors.
Expand Down Expand Up @@ -55,25 +52,29 @@ pub struct HnswParams {
knbn: usize,
/// distance to use in Hnsw. Default is "DistL2". Other choices are "DistL1", "DistCosine", "DistJeffreys".
distance: String,
/// scale modification factor, must be in [0.2, 1]
scale_modification : f64,
} // end of struct HnswParams

// Constructors for `HnswParams`.
// NOTE(review): this text is a rendered diff, so both the pre-commit and the
// post-commit signature lines appear back to back (`my_default` / `default`,
// and the 4-argument / 5-argument `new`). Only the second line of each pair
// belongs to the committed file.
impl HnswParams {
/// Returns the built-in parameter set: `max_conn` = 48, `ef_c` = 400,
/// `knbn` = 10, `distance` = "DistL2" and `scale_modification` = 1.0.
///
/// `parse_hnsw_cmd` starts from this value and then overwrites the fields
/// supplied on the command line.
// NOTE(review): this is an inherent method named `default`, not an
// implementation of the `Default` trait — confirm that is intended.
pub fn my_default() -> Self {
pub fn default() -> Self {
HnswParams {
max_conn: 48,
ef_c: 400,
knbn: 10,
distance: String::from("DistL2"),
// 1.0 is also the default value of the `scale_modify_f` command line option.
scale_modification: 1.0,
}
}

/// Builds an `HnswParams` from explicit values, storing each argument in
/// the field of the same name.
///
/// No validation is performed here: in particular `scale_modification` is
/// stored as given, without checking the documented [0.2, 1] range.
// Presumably allowed as unused because nothing in this binary calls it — TODO confirm.
#[allow(unused)]
pub fn new(max_conn: usize, ef_c: usize, knbn: usize, distance: String) -> Self {
pub fn new(max_conn: usize, ef_c: usize, knbn: usize, distance: String, scale_modification: f64) -> Self {
HnswParams {
max_conn,
ef_c,
knbn,
distance,
scale_modification,
}
}
} // end impl block
Expand All @@ -98,10 +99,11 @@ impl Default for QualityParams {
fn parse_hnsw_cmd(matches: &ArgMatches) -> Result<HnswParams, anyhow::Error> {
log::debug!("in parse_hnsw_cmd");

let mut hnswparams = HnswParams::my_default();
let mut hnswparams = HnswParams::default();
hnswparams.max_conn = *matches.get_one::<usize>("nbconn").unwrap();
hnswparams.ef_c = *matches.get_one::<usize>("ef").unwrap();
hnswparams.knbn = *matches.get_one::<usize>("knbn").unwrap();
hnswparams.scale_modification = *matches.get_one::<f64>("scale_modification").unwrap();

match matches.get_one::<String>("dist") {
Some(str) => match str.as_str() {
Expand Down Expand Up @@ -169,6 +171,7 @@ pub fn main() {
let embedparams: EmbedderParams;
//
let hnswcmd = Command::new("hnsw")
.about("Build HNSW graph")
.arg(Arg::new("dist")
.long("dist")
.short('d')
Expand All @@ -193,15 +196,25 @@ pub fn main() {
.required(true)
.action(ArgAction::Set)
.value_parser(clap::value_parser!(usize))
.help("search factor"));
.help("search factor"))
.arg(Arg::new("scale_modification")
.long("scale_modify_f")
.help("scale modification factor in HNSW or HubNSW, must be in [0.2,1]")
.value_name("scale_modify")
.default_value("1.0")
.action(ArgAction::Set)
.value_parser(clap::value_parser!(f64))
);

//
// Now the command line
// ===================
//
let matches = Command::new("annembed")
// .subcommand_required(true)
.version("0.1.7")
.arg_required_else_help(true)
.about("Non-linear Dimension Reduction/Embedding via Approximate Nearest Neighbor Graph")
.arg(
Arg::new("csvfile")
.long("csv")
Expand Down Expand Up @@ -311,7 +324,7 @@ pub fn main() {
}
}
} else {
hnswparams = HnswParams::my_default();
hnswparams = HnswParams::default();
}
log::debug!("hnswparams : {:?}", hnswparams);

Expand Down Expand Up @@ -433,13 +446,14 @@ where
{
//
let nb_data = data_with_id.len();
let hnsw = Hnsw::<f64, Dist>::new(
let mut hnsw = Hnsw::<f64, Dist>::new(
hnswparams.max_conn,
nb_data,
nb_layer,
hnswparams.ef_c,
Dist::default(),
);
hnsw.modify_level_scale(hnswparams.scale_modification);
hnsw.parallel_insert(data_with_id);
hnsw.dump_layer_info();
let kgraph = kgraph_from_hnsw_all(&hnsw, hnswparams.knbn).unwrap();
Expand Down Expand Up @@ -496,13 +510,14 @@ where
{
//
let nb_data = data_with_id.len();
let hnsw = Hnsw::<f64, Dist>::new(
let mut hnsw = Hnsw::<f64, Dist>::new(
hnswparams.max_conn,
nb_data,
nb_layer,
hnswparams.ef_c,
Dist::default(),
);
hnsw.modify_level_scale(hnswparams.scale_modification);
hnsw.parallel_insert(data_with_id);
hnsw.dump_layer_info();
KGraphProjection::<f64>::new(&hnsw, hnswparams.knbn, layer_proj)
Expand Down

0 comments on commit f02113f

Please sign in to comment.