diff --git a/kr2r/Cargo.toml b/kr2r/Cargo.toml index 375b617..fc9c5b3 100644 --- a/kr2r/Cargo.toml +++ b/kr2r/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kr2r" -version = "0.6.0" +version = "0.6.1" edition = "2021" authors = ["eric9n@gmail.com"] diff --git a/kr2r/src/bin/annotate.rs b/kr2r/src/bin/annotate.rs index 73bc326..b4613b8 100644 --- a/kr2r/src/bin/annotate.rs +++ b/kr2r/src/bin/annotate.rs @@ -1,15 +1,13 @@ use clap::Parser; use kr2r::compact_hash::{CHTable, Compact, HashConfig, Row, Slot}; use kr2r::utils::{find_and_sort_files, open_file}; -// use std::collections::HashMap; -use rayon::prelude::*; +use seqkmer::buffer_read_parallel; use std::collections::HashMap; use std::fs::{File, OpenOptions}; use std::io::{self, BufReader, BufWriter, Read, Result, Write}; use std::path::Path; use std::path::PathBuf; use std::time::Instant; - // 定义每批次处理的 Slot 数量 pub const BATCH_SIZE: usize = 8 * 1024 * 1024; @@ -100,9 +98,7 @@ fn process_batch( where R: Read + Send, { - let slot_size = std::mem::size_of::>(); let row_size = std::mem::size_of::(); - let mut batch_buffer = vec![0u8; slot_size * batch_size]; let mut last_file_index: Option = None; let mut writer: Option> = None; @@ -111,21 +107,13 @@ where let idx_mask = hash_config.get_idx_mask(); let idx_bits = hash_config.get_idx_bits(); - while let Ok(bytes_read) = reader.read(&mut batch_buffer) { - if bytes_read == 0 { - break; - } // 文件末尾 - - // 处理读取的数据批次 - let slots_in_batch = bytes_read / slot_size; - - let slots = unsafe { - std::slice::from_raw_parts(batch_buffer.as_ptr() as *const Slot, slots_in_batch) - }; - - let result: HashMap> = slots - .par_iter() - .filter_map(|slot| { + buffer_read_parallel( + reader, + num_cpus::get(), + batch_size, + |dataset: &[Slot]| { + let mut results: HashMap> = HashMap::new(); + for slot in dataset { let indx = slot.idx & idx_mask; let compacted = slot.value.left(value_bits) as u32; let taxid = chtm.get_from_page(indx, compacted, page_index); @@ -137,48 +125,37 @@ where let left = slot.value.left(value_bits) as u32; let high = u32::combined(left, taxid, value_bits); let row = Row::new(high, seq_id, kmer_id as u32); - // let value = slot.to_b(high); - // let value_bytes = value.to_le_bytes(); // 将u64转换为[u8; 8] let value_bytes = row.as_slice(row_size); - Some((file_index, value_bytes.to_vec())) - } else { - None - } - }) - .fold( - || HashMap::new(), - |mut acc: HashMap>, (file_index, value_bytes)| { - acc.entry(file_index) + + results + .entry(file_index) .or_insert_with(Vec::new) .extend(value_bytes); - acc - }, - ) - .reduce( - || HashMap::new(), - |mut acc, h| { - for (k, mut v) in h { - acc.entry(k).or_insert_with(Vec::new).append(&mut v); + } + } + Some(results) + }, + |result| { + while let Some(Some(res)) = result.next() { + let mut file_indices: Vec<_> = res.keys().cloned().collect(); + file_indices.sort_unstable(); // 对file_index进行排序 + + for file_index in file_indices { + if let Some(bytes) = res.get(&file_index) { + write_to_file( + file_index, + bytes, + &mut last_file_index, + &mut writer, + &chunk_dir, + ) + .expect("write to file error"); } - acc - }, - ); - - let mut file_indices: Vec<_> = result.keys().cloned().collect(); - file_indices.sort_unstable(); // 对file_index进行排序 - - for file_index in file_indices { - if let Some(bytes) = result.get(&file_index) { - write_to_file( - file_index, - bytes, - &mut last_file_index, - &mut writer, - &chunk_dir, - )?; + } } - } - } + }, + ) + .expect("failed"); if let Some(w) = writer.as_mut() { w.flush()?; @@ -200,16 +177,13 @@ fn process_chunk_file>( let start = Instant::now(); let config = HashConfig::from_hash_header(&args.database.join("hash_config.k2d"))?; - let parition = hash_files.len(); - let chtm = if args.kraken_db_type { - CHTable::from_pair( - config, - &hash_files[page_index], - &hash_files[(page_index + 1) % parition], - )? - } else { - CHTable::from(config, &hash_files[page_index])? - }; + let chtm = CHTable::from_range( + config, + hash_files, + page_index, + page_index + 1, + args.kraken_db_type, + )?; // 计算持续时间 let duration = start.elapsed(); @@ -229,7 +203,6 @@ fn process_chunk_file>( pub fn run(args: Args) -> Result<()> { let chunk_files = find_and_sort_files(&args.chunk_dir, "sample", ".k2")?; - let hash_files = find_and_sort_files(&args.database, "hash", ".k2d")?; // 开始计时 diff --git a/kr2r/src/bin/direct.rs b/kr2r/src/bin/direct.rs index 5cd8f59..3d0f01c 100644 --- a/kr2r/src/bin/direct.rs +++ b/kr2r/src/bin/direct.rs @@ -26,6 +26,10 @@ pub struct Args { #[arg(long = "db", required = true)] pub database: PathBuf, + /// File path for outputting normal Kraken output. + #[clap(long = "output-dir", value_parser)] + pub kraken_output_dir: Option, + /// Enable paired-end processing. #[clap(short = 'P', long = "paired-end-processing", action)] pub paired_end_processing: bool, @@ -73,9 +77,9 @@ pub struct Args { #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())] pub num_threads: usize, - /// File path for outputting normal Kraken output. - #[clap(long = "output-dir", value_parser)] - pub kraken_output_dir: Option, + /// Enables use of a Kraken 2 compatible shared database. Default is false. + #[clap(long, default_value_t = false)] + pub kraken_db_type: bool, /// A list of input file paths (FASTA/FASTQ) to be processed by the classify program. /// Supports fasta or fastq format files (e.g., .fasta, .fastq) and gzip compressed files (e.g., .fasta.gz, .fastq.gz). @@ -98,7 +102,7 @@ fn process_seq( let partition_index = idx / chunk_size; let index = idx % chunk_size; - let taxid = chtable.get_from_page(index, compacted, partition_index + 1); + let taxid = chtable.get_from_page(index, compacted, partition_index); if taxid > 0 { let high = u32::combined(compacted, taxid, value_bits); let row = Row::new(high, 0, sort as u32 + 1 + offset as u32); @@ -348,7 +352,7 @@ pub fn run(args: Args) -> Result<()> { let start = Instant::now(); let meros = idx_opts.as_meros(); let hash_files = find_and_sort_files(&args.database, "hash", ".k2d")?; - let chtable = CHTable::from_hash_files(hash_config, hash_files)?; + let chtable = CHTable::from_hash_files(hash_config, &hash_files, args.kraken_db_type)?; process_files(args, meros, hash_config, &chtable, &taxo)?; let duration = start.elapsed(); diff --git a/kr2r/src/bin/kun.rs b/kr2r/src/bin/kun.rs index 5a31ab2..01713df 100644 --- a/kr2r/src/bin/kun.rs +++ b/kr2r/src/bin/kun.rs @@ -11,7 +11,7 @@ mod splitr; use kr2r::args::ClassifyArgs; use kr2r::args::{parse_size, Build}; -use kr2r::utils::find_and_sort_files; +use kr2r::utils::find_files; // use std::io::Result; use std::path::PathBuf; use std::time::Instant; @@ -180,9 +180,9 @@ fn main() -> Result<(), Box> { let start = Instant::now(); let splitr_args = splitr::Args::from(cmd_args.clone()); - let chunk_files = find_and_sort_files(&splitr_args.chunk_dir, "sample", ".k2")?; - let sample_files = find_and_sort_files(&splitr_args.chunk_dir, "sample", ".map")?; - let bin_files = find_and_sort_files(&splitr_args.chunk_dir, "sample", ".bin")?; + let chunk_files = find_files(&splitr_args.chunk_dir, "sample", ".k2"); + let sample_files = find_files(&splitr_args.chunk_dir, "sample", ".map"); + let bin_files = find_files(&splitr_args.chunk_dir, "sample", ".bin"); if !chunk_files.is_empty() || !sample_files.is_empty() || !bin_files.is_empty() { return Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/kr2r/src/bin/resolve.rs b/kr2r/src/bin/resolve.rs index 72d4baf..f0270b6 100644 --- a/kr2r/src/bin/resolve.rs +++ b/kr2r/src/bin/resolve.rs @@ -1,5 +1,4 @@ use clap::Parser; -use dashmap::{DashMap, DashSet}; use kr2r::classify::process_hitgroup; use kr2r::compact_hash::{HashConfig, Row}; use kr2r::readcounts::{TaxonCounters, TaxonCountersDash}; @@ -7,26 +6,25 @@ use kr2r::report::report_kraken_style; use kr2r::taxonomy::Taxonomy; use kr2r::utils::{find_and_sort_files, open_file}; use kr2r::HitGroup; -use rayon::prelude::*; -use seqkmer::{trim_pair_info, OptionPair}; -use std::collections::HashMap; +// use rayon::prelude::*; +use seqkmer::{buffer_map_parallel, trim_pair_info, OptionPair}; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{self, BufRead, BufReader, BufWriter, Read, Result, Write}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Mutex; use std::time::Instant; -const BATCH_SIZE: usize = 8 * 1024 * 1024; +const BATCH_SIZE: usize = 16 * 1024 * 1024; pub fn read_id_to_seq_map>( filename: P, -) -> Result)>> { +) -> Result)>> { let file = open_file(filename)?; let reader = BufReader::new(file); - let id_map = DashMap::new(); + let mut id_map = HashMap::new(); - reader.lines().par_bridge().for_each(|line| { + reader.lines().for_each(|line| { let line = line.expect("Could not read line"); let parts: Vec<&str> = line.trim().split_whitespace().collect(); if parts.len() >= 4 { @@ -102,82 +100,84 @@ pub struct Args { pub batch_size: usize, } +fn read_rows_from_file>(file_path: P) -> io::Result>> { + let file = File::open(file_path)?; + let mut reader = BufReader::new(file); + let mut buffer = [0u8; std::mem::size_of::()]; // 确保buffer的大小与Row结构体的大小一致 + let mut map: HashMap> = HashMap::new(); + + while reader.read_exact(&mut buffer).is_ok() { + let row: Row = unsafe { std::mem::transmute(buffer) }; // 将读取的字节直接转换为Row结构体 + map.entry(row.seq_id).or_default().push(row); // 插入到HashMap中 + } + + Ok(map) +} + fn process_batch>( sample_file: P, args: &Args, taxonomy: &Taxonomy, - id_map: &DashMap)>, - writer: &Mutex>, + id_map: &HashMap)>, + writer: &mut Box, value_mask: usize, -) -> Result<(TaxonCountersDash, usize, DashSet)> { - let file = open_file(sample_file)?; - let mut reader = BufReader::new(file); - let size = std::mem::size_of::(); - let mut batch_buffer = vec![0u8; size * BATCH_SIZE]; - - let hit_counts = DashMap::new(); - let hit_seq_id_set = DashSet::new(); +) -> Result<(TaxonCountersDash, usize, HashSet)> { + let hit_seq_id_set = HashSet::new(); let confidence_threshold = args.confidence_threshold; let minimum_hit_groups = args.minimum_hit_groups; - while let Ok(bytes_read) = reader.read(&mut batch_buffer) { - if bytes_read == 0 { - break; - } // 文件末尾 + let hit_counts: HashMap> = read_rows_from_file(sample_file)?; - // 处理读取的数据批次 - let slots_in_batch = bytes_read / size; - let slots = unsafe { - std::slice::from_raw_parts(batch_buffer.as_ptr() as *const Row, slots_in_batch) - }; - - slots.par_iter().for_each(|item| { - let seq_id = item.seq_id; - hit_seq_id_set.insert(seq_id); - hit_counts - .entry(seq_id) - .or_insert_with(Vec::new) - .push(*item) - }); - } - - // let writer = Mutex::new(writer); let classify_counter = AtomicUsize::new(0); let cur_taxon_counts = TaxonCountersDash::new(); - hit_counts.into_par_iter().for_each(|(k, mut rows)| { - if let Some(item) = id_map.get(&k) { - rows.sort_unstable(); - let dna_id = trim_pair_info(&item.0); - let range = OptionPair::from(((0, item.2), item.3.map(|size| (item.2, size + item.2)))); - let hits = HitGroup::new(rows, range); - - let hit_data = process_hitgroup( - &hits, - taxonomy, - &classify_counter, - hits.required_score(confidence_threshold), - minimum_hit_groups, - value_mask, - ); - - hit_data.3.iter().for_each(|(key, value)| { - cur_taxon_counts - .entry(*key) - .or_default() - .merge(value) - .unwrap(); - }); - - // 使用锁来同步写入 - let output_line = format!( - "{}\t{}\t{}\t{}\t{}\n", - hit_data.0, dna_id, hit_data.1, item.1, hit_data.2 - ); - let mut file = writer.lock().unwrap(); - file.write_all(output_line.as_bytes()).unwrap(); - } - }); + buffer_map_parallel( + &hit_counts, + num_cpus::get(), + |(k, rows)| { + if let Some(item) = id_map.get(&k) { + let mut rows = rows.to_owned(); + rows.sort_unstable(); + let dna_id = trim_pair_info(&item.0); + let range = + OptionPair::from(((0, item.2), item.3.map(|size| (item.2, size + item.2)))); + let hits = HitGroup::new(rows, range); + + let hit_data = process_hitgroup( + &hits, + taxonomy, + &classify_counter, + hits.required_score(confidence_threshold), + minimum_hit_groups, + value_mask, + ); + + hit_data.3.iter().for_each(|(key, value)| { + cur_taxon_counts + .entry(*key) + .or_default() + .merge(value) + .unwrap(); + }); + + // 使用锁来同步写入 + let output_line = format!( + "{}\t{}\t{}\t{}\t{}\n", + hit_data.0, dna_id, hit_data.1, item.1, hit_data.2 + ); + Some(output_line) + } else { + None + } + }, + |result| { + while let Some(Some(res)) = result.next() { + writer.write_all(res.as_bytes()).unwrap(); + } + }, + ) + .expect("failed"); + Ok(( cur_taxon_counts, classify_counter.load(Ordering::SeqCst), @@ -208,8 +208,9 @@ pub fn run(args: Args) -> Result<()> { for i in 0..partition { let sample_file = &sample_files[i]; let sample_id_map = read_id_to_seq_map(&sample_id_files[i])?; + let thread_sequences = sample_id_map.len(); - let writer: Box = match &args.kraken_output_dir { + let mut writer: Box = match &args.kraken_output_dir { Some(ref file_path) => { let filename = file_path.join(format!("output_{}.txt", i + 1)); let file = File::create(filename)?; @@ -217,31 +218,29 @@ pub fn run(args: Args) -> Result<()> { } None => Box::new(BufWriter::new(io::stdout())) as Box, }; - let writer = Mutex::new(writer); let (thread_taxon_counts, thread_classified, hit_seq_set) = process_batch::<&PathBuf>( sample_file, &args, &taxo, &sample_id_map, - &writer, + &mut writer, value_mask, )?; if args.full_output { sample_id_map .iter() - .filter(|item| !hit_seq_set.contains(item.key())) - .for_each(|item| { - let dna_id = trim_pair_info(&item.0); + .filter(|(key, _)| !hit_seq_set.contains(key)) + .for_each(|(_, value)| { + let dna_id = trim_pair_info(&value.0); // 假设 key 是 &str 类型 let output_line = format!( "U\t{}\t0\t{}\t{}\n", dna_id, - item.1, - if item.3.is_none() { "" } else { " |:| " } + value.1, + if value.3.is_none() { "" } else { " |:| " } ); - let mut file = writer.lock().unwrap(); - file.write_all(output_line.as_bytes()).unwrap(); + writer.write_all(output_line.as_bytes()).unwrap(); }); } diff --git a/kr2r/src/classify.rs b/kr2r/src/classify.rs index 71229bb..671c121 100644 --- a/kr2r/src/classify.rs +++ b/kr2r/src/classify.rs @@ -6,62 +6,6 @@ use seqkmer::SpaceDist; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; -// fn generate_hit_string( -// count: usize, -// rows: &Vec, -// taxonomy: &Taxonomy, -// value_mask: usize, -// offset: usize, -// ) -> String { -// let mut result = Vec::new(); -// let mut last_pos = 0; - -// for row in rows { -// let sort = row.kmer_id as usize; -// if sort < offset || sort >= offset + count { -// continue; -// } -// let adjusted_pos = row.kmer_id as usize - offset; - -// let value = row.value; -// let key = value.right(value_mask); -// let ext_code = taxonomy.nodes[key as usize].external_id; - -// if last_pos == 0 && adjusted_pos > 0 { -// result.push((0, adjusted_pos)); // 在开始处添加0 -// } else if adjusted_pos - last_pos > 1 { -// result.push((0, adjusted_pos - last_pos - 1)); // 在两个特定位置之间添加0 -// } -// if let Some(last) = result.last_mut() { -// if last.0 == ext_code { -// last.1 += 1; -// last_pos = adjusted_pos; -// continue; -// } -// } - -// // 添加当前key的计数 -// result.push((ext_code, 1)); -// last_pos = adjusted_pos; -// } - -// // 填充尾随0 -// if last_pos < count - 1 { -// if last_pos == 0 { -// result.push((0, count - last_pos)); -// } else { -// result.push((0, count - last_pos - 1)); -// } -// } - -// result -// .iter() -// .map(|i| format!("{}:{}", i.0, i.1)) -// .collect::>() -// .join(" ") -// } - -// &HashMap, pub fn resolve_tree( hit_counts: &HashMap, taxonomy: &Taxonomy, @@ -105,57 +49,6 @@ pub fn resolve_tree( max_taxon } -// pub fn add_hitlist_string( -// rows: &Vec, -// value_mask: usize, -// kmer_count1: usize, -// kmer_count2: Option, -// taxonomy: &Taxonomy, -// ) -> String { -// let result1 = generate_hit_string(kmer_count1, &rows, taxonomy, value_mask, 0); -// if let Some(count) = kmer_count2 { -// let result2 = generate_hit_string(count, &rows, taxonomy, value_mask, kmer_count1); -// format!("{} |:| {}", result1, result2) -// } else { -// format!("{}", result1) -// } -// } - -// pub fn count_values( -// rows: &Vec, -// value_mask: usize, -// kmer_count1: u32, -// ) -> (HashMap, TaxonCountersDash, usize) { -// let mut counts = HashMap::new(); - -// let mut hit_count: usize = 0; - -// let mut last_row: Row = Row::new(0, 0, 0); -// let cur_taxon_counts = TaxonCountersDash::new(); - -// for row in rows { -// let value = row.value; -// let key = value.right(value_mask); -// *counts.entry(key).or_insert(0) += 1; - -// // 如果切换到第2条seq,就重新计算 -// if last_row.kmer_id < kmer_count1 && row.kmer_id > kmer_count1 { -// last_row = Row::new(0, 0, 0); -// } -// if !(last_row.value == value && row.kmer_id - last_row.kmer_id == 1) { -// cur_taxon_counts -// .entry(key as u64) -// .or_default() -// .add_kmer(value as u64); -// hit_count += 1; -// } - -// last_row = *row; -// } - -// (counts, cur_taxon_counts, hit_count) -// } - fn stat_hits<'a>( hits: &HitGroup, counts: &mut HashMap, @@ -192,8 +85,6 @@ pub fn process_hitgroup( minimum_hit_groups: usize, value_mask: usize, ) -> (String, u64, String, TaxonCounters) { - // let value_mask = hash_config.value_mask; - let mut cur_taxon_counts = TaxonCounters::new(); let mut counts = HashMap::new(); let hit_groups = hits.capacity(); @@ -205,14 +96,6 @@ pub fn process_hitgroup( &mut cur_taxon_counts, ); - // cur_counts.iter().for_each(|(key, value)| { - // cur_taxon_counts - // .entry(*key) - // .or_default() - // .merge(value) - // .unwrap(); - // }); - let mut call = resolve_tree(&counts, taxonomy, required_score); if call > 0 && hit_groups < minimum_hit_groups { call = 0; diff --git a/kr2r/src/compact_hash.rs b/kr2r/src/compact_hash.rs index ab12aa6..288be41 100644 --- a/kr2r/src/compact_hash.rs +++ b/kr2r/src/compact_hash.rs @@ -409,60 +409,41 @@ pub struct CHTable { impl CHTable { pub fn from_hash_files + Debug>( config: HashConfig, - hash_files: Vec

, + hash_sorted_files: &Vec

, + kd_type: bool, ) -> Result { - let mut pages = vec![Page::default(); hash_files.len() + 1]; - for hash_file in hash_files { + let end = hash_sorted_files.len(); + Self::from_range(config, hash_sorted_files, 0, end, kd_type) + } + + pub fn from_range + Debug>( + config: HashConfig, + hash_sorted_files: &Vec

, + start: usize, + end: usize, + kd_type: bool, + ) -> Result { + let mut pages = vec![Page::default(); start]; + let parition = hash_sorted_files.len(); + for i in start..end { + let mut hash_file = &hash_sorted_files[i]; let mut page = read_page_from_file(&hash_file)?; let next_page = if page.data.last().map_or(false, |&x| x == 0) { + if kd_type { + hash_file = &hash_sorted_files[(i + 1) % parition] + } read_first_block_from_file(&hash_file)? } else { Page::default() }; page.merge(next_page); - if let Some(elem) = pages.get_mut(page.index) { - *elem = page; - } + pages.push(page); } let chtm = CHTable { config, pages }; Ok(chtm) } - pub fn from_pair + Debug>( - config: HashConfig, - chunk_file1: P, - chunk_file2: P, - ) -> Result { - let mut page = read_page_from_file(chunk_file1)?; - let next_page = if page.data.last().map_or(false, |&x| x == 0) { - read_first_block_from_file(chunk_file2)? - } else { - Page::default() - }; - page.merge(next_page); - let count = page.index; - let mut pages = vec![Page::default(); count - 1]; - pages.push(page); - let chtm: CHTable = CHTable { config, pages }; - Ok(chtm) - } - - pub fn from + Debug>(config: HashConfig, chunk_file1: P) -> Result { - let mut page = read_page_from_file(&chunk_file1)?; - let next_page = if page.data.last().map_or(false, |&x| x == 0) { - read_first_block_from_file(&chunk_file1)? - } else { - Page::default() - }; - page.merge(next_page); - let count = page.index; - let mut pages = vec![Page::default(); count - 1]; - pages.push(page); - let chtm = CHTable { config, pages }; - Ok(chtm) - } - pub fn get_from_page(&self, indx: usize, compacted: u32, page_index: usize) -> u32 { if let Some(page) = self.pages.get(page_index) { page.find_index( diff --git a/seqkmer/src/mmscanner.rs b/seqkmer/src/mmscanner.rs index 773cf98..b1a64be 100644 --- a/seqkmer/src/mmscanner.rs +++ b/seqkmer/src/mmscanner.rs @@ -264,24 +264,3 @@ pub fn scan_sequence<'a>( } } } - -pub fn tranfer_sequence<'a>( - sequence: &'a Base>, - meros: &'a Meros, -) -> Base> { - let func = |seq: &'a Vec| { - let cursor = Cursor::new(meros.l_mer, meros.mask); - let window = MinimizerWindow::new(meros.window_size()); - MinimizerIterator::new(seq, cursor, window, meros).collect() - }; - - match &sequence.body { - OptionPair::Pair(seq1, seq2) => Base::new( - sequence.header.clone(), - OptionPair::Pair(func(&seq1), func(&seq2)), - ), - OptionPair::Single(seq1) => { - Base::new(sequence.header.clone(), OptionPair::Single(func(&seq1))) - } - } -} diff --git a/seqkmer/src/parallel.rs b/seqkmer/src/parallel.rs index ab09045..ba4f859 100644 --- a/seqkmer/src/parallel.rs +++ b/seqkmer/src/parallel.rs @@ -4,6 +4,7 @@ use crate::seq::{Base, SeqFormat}; use crate::{detect_file_format, FastaReader, FastqReader, Meros}; use crossbeam_channel::{bounded, Receiver}; use scoped_threadpool::Pool; +use std::collections::HashMap; use std::io::Result; use std::sync::Arc; @@ -98,3 +99,128 @@ where Ok(()) } + +pub fn buffer_read_parallel( + reader: &mut R, + n_threads: usize, + buffer_size: usize, + work: W, + func: F, +) -> Result<()> +where + D: Send + Sized + Sync, + R: std::io::Read + Send, + O: Send, + Out: Send + Default, + W: Send + Sync + Fn(&[D]) -> Option, + F: FnOnce(&mut ParallelResult>) -> Out + Send, +{ + assert!(n_threads > 2); + let buffer_len = n_threads + 2; + let (sender, receiver) = bounded::<&[D]>(buffer_len); + let (done_send, done_recv) = bounded::>(buffer_len); + let receiver = Arc::new(receiver); // 使用 Arc 来共享 receiver + let done_send = Arc::new(done_send); + let mut pool = Pool::new(n_threads as u32); + + let slot_size = std::mem::size_of::(); + let mut parallel_result = ParallelResult { recv: done_recv }; + + pool.scoped(|pool_scope| { + // 生产者线程 + pool_scope.execute(move || { + let mut batch_buffer = vec![0u8; slot_size * buffer_size]; + + while let Ok(bytes_read) = reader.read(&mut batch_buffer) { + if bytes_read == 0 { + break; + } // 文件末尾 + + let slots_in_batch = bytes_read / slot_size; + let slots = unsafe { + std::slice::from_raw_parts(batch_buffer.as_ptr() as *const D, slots_in_batch) + }; + sender.send(slots).expect("Failed to send sequences"); + } + }); + + // 消费者线程 + for _ in 0..n_threads - 2 { + let receiver = Arc::clone(&receiver); + let work = &work; + let done_send = Arc::clone(&done_send); + pool_scope.execute(move || { + while let Ok(seqs) = receiver.recv() { + let output = work(seqs); + done_send.send(output).expect("Failed to send outputs"); + } + }); + } + + // 引用计数减掉一个,这样都子线程结束时, done_send还能完全释放 + drop(done_send); + pool_scope.execute(move || { + let _ = func(&mut parallel_result); + }); + + pool_scope.join_all(); + }); + + Ok(()) +} + +pub fn buffer_map_parallel( + map: &HashMap>, + n_threads: usize, + work: W, + func: F, +) -> Result<()> +where + D: Send + Sized + Sync, + O: Send, + Out: Send + Default, + W: Send + Sync + Fn((&u32, &Vec)) -> Option, + F: FnOnce(&mut ParallelResult>) -> Out + Send, +{ + assert!(n_threads > 2); + let buffer_len = n_threads + 2; + let (sender, receiver) = bounded::<(&u32, &Vec)>(buffer_len); + let (done_send, done_recv) = bounded::>(buffer_len); + let receiver = Arc::new(receiver); // 使用 Arc 来共享 receiver + let done_send = Arc::new(done_send); + let mut pool = Pool::new(n_threads as u32); + + let mut parallel_result = ParallelResult { recv: done_recv }; + + pool.scoped(|pool_scope| { + // 生产者线程 + pool_scope.execute(move || { + for entry in map { + sender.send(entry).expect("Failed to send sequences"); + } + }); + + // 消费者线程 + for _ in 0..n_threads - 2 { + let receiver = Arc::clone(&receiver); + let work = &work; + let done_send = Arc::clone(&done_send); + pool_scope.execute(move || { + while let Ok(seqs) = receiver.recv() { + let output = work(seqs); + done_send.send(output).expect("Failed to send outputs"); + } + }); + } + + // 引用计数减掉一个,这样都子线程结束时, done_send还能完全释放 + drop(done_send); + pool_scope.execute(move || { + let _ = func(&mut parallel_result); + }); + + pool_scope.join_all(); + }); + + Ok(()) +}