From 21c39ea548a521146785fc6211600fe19ed3179f Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Tue, 7 Jan 2025 23:54:51 -0600 Subject: [PATCH 1/2] feat: add score_batch to Generator --- build.rs | 3 +- include/generator.h | 9 ++++- src/generator.rs | 68 +++++++++++++++++++++---------- src/lib.rs | 7 +++- src/sys.rs | 4 +- src/sys/generator.cpp | 29 +++++++++++++- src/sys/generator.rs | 25 +++++++++++- src/sys/scoring.rs | 93 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 209 insertions(+), 29 deletions(-) create mode 100644 src/sys/scoring.rs diff --git a/build.rs b/build.rs index 5d79320..64fad3d 100644 --- a/build.rs +++ b/build.rs @@ -1,6 +1,6 @@ // build.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -106,6 +106,7 @@ fn main() { cxx_build::bridges([ "src/sys/types.rs", "src/sys/config.rs", + "src/sys/scoring.rs", "src/sys/translator.rs", "src/sys/generator.rs", "src/sys/storage_view.rs", diff --git a/include/generator.h b/include/generator.h index e2a441c..50cf430 100644 --- a/include/generator.h +++ b/include/generator.h @@ -1,6 +1,6 @@ // generator.h // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. 
// @@ -21,6 +21,8 @@ struct GenerationOptions; struct GenerationResult; struct GenerationStepResult; struct GenerationCallbackBox; +struct ScoringOptions; +struct ScoringResult; class Generator { private: @@ -37,6 +39,11 @@ class Generator { GenerationCallbackBox& callback ) const; + rust::Vec score_batch( + const rust::Vec& tokens, + const ScoringOptions& options + ) const; + inline size_t num_queued_batches() const { return this->impl->num_queued_batches(); } diff --git a/src/generator.rs b/src/generator.rs index 1632d7c..a2215d0 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -1,6 +1,6 @@ // generator.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -17,7 +17,7 @@ pub use sys::GenerationOptions; use crate::tokenizer::encode_all; -use super::{sys, Config, GenerationStepResult, Tokenizer}; +use super::{sys, Config, GenerationStepResult, ScoringOptions, ScoringResult, Tokenizer}; /// A text generator with a tokenizer. /// @@ -213,6 +213,18 @@ impl Generator { Ok(res) } + pub fn score_batch( + &self, + prompts: &[U], + options: &ScoringOptions, + ) -> Result> + where + U: AsRef, + { + self.generator + .score_batch(&encode_all(&self.tokenizer, prompts)?, options) + } + /// Number of batches in the work queue. 
#[inline] pub fn num_queued_batches(&self) -> anyhow::Result { @@ -242,16 +254,16 @@ impl Debug for Generator { #[cfg(feature = "hub")] mod tests { use super::Generator; + use crate::tokenizers::auto::Tokenizer; use crate::{download_model, Config, Device, GenerationOptions}; + use anyhow::Result; + use std::path::PathBuf; const MODEL_ID: &str = "jkawamoto/gpt2-ct2"; - #[test] - #[ignore] - fn test_generate() { - let model_path = download_model(MODEL_ID).unwrap(); - let g = Generator::new( - &model_path, + fn new_generator(model_path: &PathBuf) -> Result> { + Generator::new( + model_path, &Config { device: if cfg!(feature = "cuda") { Device::CUDA @@ -261,7 +273,13 @@ mod tests { ..Default::default() }, ) - .unwrap(); + } + + #[test] + #[ignore] + fn test_generate() { + let model_path = download_model(MODEL_ID).unwrap(); + let g = new_generator(&model_path).unwrap(); let prompt = "CTranslate2 is a library"; let res = g @@ -278,22 +296,30 @@ mod tests { assert!(res[0].0[0].starts_with(prompt)); } + #[test] + #[ignore] + fn test_scoring() { + let model_path = download_model(MODEL_ID).unwrap(); + let g = new_generator(&model_path).unwrap(); + + let prompt = "CTranslate2 is a library"; + let res = g.score_batch(&[prompt], &Default::default()).unwrap(); + + assert_eq!( + res[0].tokens, + vec!["Trans", "late", "2", "Ġis", "Ġa", "Ġlibrary"] + .iter() + .map(|s| s.to_string()) + .collect::>() + ); + assert_ne!(res[0].normalized_score(), 0.0); + } + #[test] #[ignore] fn test_generator_debug() { let model_path = download_model(MODEL_ID).unwrap(); - let g = Generator::new( - &model_path, - &Config { - device: if cfg!(feature = "cuda") { - Device::CUDA - } else { - Device::CPU - }, - ..Default::default() - }, - ) - .unwrap(); + let g = new_generator(&model_path).unwrap(); assert!(format!("{:?}", g).contains(model_path.file_name().unwrap().to_str().unwrap())); } diff --git a/src/lib.rs b/src/lib.rs index 17a02ed..30284e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ // 
lib.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -80,7 +80,10 @@ pub use generator::{GenerationOptions, Generator}; #[cfg_attr(docsrs, doc(cfg(feature = "hub")))] pub use hub::download_model; pub use result::GenerationStepResult; -pub use sys::{set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel}; +pub use sys::{ + set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel, + ScoringOptions, ScoringResult, +}; pub use tokenizer::Tokenizer; pub use translator::{TranslationOptions, Translator}; #[cfg(feature = "whisper")] diff --git a/src/sys.rs b/src/sys.rs index 8bfbb5e..27ce1b5 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1,6 +1,6 @@ // sys.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -47,6 +47,7 @@ pub use config::*; pub use generator::*; +pub use scoring::*; pub use storage_view::*; pub use translator::*; pub use types::*; @@ -54,6 +55,7 @@ pub use whisper::*; mod config; mod generator; +mod scoring; mod storage_view; mod translator; mod types; diff --git a/src/sys/generator.cpp b/src/sys/generator.cpp index f191014..479f9e9 100644 --- a/src/sys/generator.cpp +++ b/src/sys/generator.cpp @@ -1,6 +1,6 @@ // generator.cpp // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. 
// @@ -97,3 +97,30 @@ Generator::generate_batch( return res; } + +Vec +Generator::score_batch( + const Vec& tokens, + const ScoringOptions& options +) const { + auto futures = this->impl->score_batch_async( + from_rust(tokens), + ctranslate2::ScoringOptions { + options.max_input_length, + options.offset, + }, + options.max_batch_size, + options.batch_type + ); + + Vec res; + for (auto& future : futures) { + const auto& r = future.get(); + res.push_back(ScoringResult { + to_rust(r.tokens), + to_rust(r.tokens_score), + }); + } + + return res; +} diff --git a/src/sys/generator.rs b/src/sys/generator.rs index 9a6c77a..43a8445 100644 --- a/src/sys/generator.rs +++ b/src/sys/generator.rs @@ -1,6 +1,6 @@ // generator.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -17,7 +17,8 @@ use anyhow::{anyhow, Error, Result}; use cxx::UniquePtr; use super::{ - config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString, VecUSize, + config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, ScoringOptions, ScoringResult, + VecStr, VecString, VecUSize, }; trait GenerationCallback { @@ -93,6 +94,7 @@ mod ffi { unsafe extern "C++" { include!("ct2rs/include/generator.h"); include!("ct2rs/src/sys/types.rs.h"); + include!("ct2rs/src/sys/scoring.rs.h"); type VecString = super::VecString; type VecStr<'a> = super::VecStr<'a>; @@ -102,6 +104,9 @@ mod ffi { type BatchType = super::BatchType; type GenerationStepResult = super::GenerationStepResult; + type ScoringOptions = super::ScoringOptions; + type ScoringResult = super::ScoringResult; + type Generator; fn generator(model_path: &str, config: UniquePtr) -> Result>; @@ -114,6 +119,12 @@ mod ffi { callback: &mut GenerationCallbackBox, ) -> Result>; + fn score_batch( + self: &Generator, + tokens: &Vec, + options: &ScoringOptions, + ) -> Result>; + fn num_queued_batches(self: &Generator) -> Result; fn 
num_active_batches(self: &Generator) -> Result; @@ -260,6 +271,16 @@ impl Generator { .collect()) } + pub fn score_batch>( + &self, + tokens: &[Vec], + options: &ScoringOptions, + ) -> Result> { + self.ptr + .score_batch(&vec_ffi_vecstr(tokens), options) + .map_err(Error::from) + } + /// Number of batches in the work queue. #[inline] pub fn num_queued_batches(&self) -> Result { diff --git a/src/sys/scoring.rs b/src/sys/scoring.rs new file mode 100644 index 0000000..3300c54 --- /dev/null +++ b/src/sys/scoring.rs @@ -0,0 +1,93 @@ +// scoring.rs +// +// Copyright (c) 2023-2025 Junpei Kawamoto +// +// This software is released under the MIT License. +// +// http://opensource.org/licenses/mit-license.php + +use super::BatchType; +pub use ffi::{ScoringOptions, ScoringResult}; + +#[cxx::bridge] +pub(crate) mod ffi { + + #[derive(Clone, Debug)] + pub struct ScoringOptions { + /// Truncate the inputs after this many tokens (set 0 to disable truncation). + pub max_input_length: usize, + pub offset: i64, + + max_batch_size: usize, + batch_type: BatchType, + } + + #[derive(Clone, Debug)] + pub struct ScoringResult { + pub tokens: Vec, + pub tokens_score: Vec, + } + + struct _dummy { + _vec_scoring_result: Vec, + } + + unsafe extern "C++" { + include!("ct2rs/include/config.h"); + + type BatchType = super::BatchType; + } +} + +impl Default for ScoringOptions { + fn default() -> Self { + Self { + max_input_length: 1024, + offset: 0, + max_batch_size: 0, + batch_type: Default::default(), + } + } +} + +impl ScoringResult { + pub fn cumulated_score(&self) -> f32 { + self.tokens_score.iter().sum() + } + + pub fn normalized_score(&self) -> f32 { + let num_tokens = self.tokens_score.len(); + if num_tokens == 0 { + return 0.0; + } + self.cumulated_score() / num_tokens as f32 + } +} + +#[cfg(test)] +mod tests { + use crate::sys::scoring::ScoringResult; + + const EPSILON: f32 = 1e-6; + + #[test] + fn test_scoring_result() { + let res = ScoringResult { + tokens: vec!["a".to_string(), 
"b".to_string(), "c".to_string()], + tokens_score: vec![1.0, 2.0, 3.0], + }; + assert!((res.cumulated_score() - 6.0).abs() < EPSILON); + assert!((res.normalized_score() - 2.0).abs() < EPSILON); + } + + #[test] + fn test_empty_scoring_result() { + let res = ScoringResult { + tokens: vec![], + tokens_score: vec![], + }; + + assert_eq!(res.cumulated_score(), 0.0); + assert_eq!(res.normalized_score(), 0.0); + } +} From 174138ba8d75f7031efb53c2edefe149034bf69e Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Wed, 8 Jan 2025 01:38:56 -0600 Subject: [PATCH 2/2] docs: add scoring feature documentation and related examples --- examples/gpt-2.rs | 10 +++++++--- src/generator.rs | 13 ++++++++++++- src/sys/generator.rs | 10 ++++++++++ src/sys/scoring.rs | 32 +++++++++++++++++++++++++++++++- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/examples/gpt-2.rs b/examples/gpt-2.rs index 6130122..6b66135 100644 --- a/examples/gpt-2.rs +++ b/examples/gpt-2.rs @@ -1,6 +1,6 @@ // gpt-2.rs // -// Copyright (c) 2023-2024 Junpei Kawamoto +// Copyright (c) 2023-2025 Junpei Kawamoto // // This software is released under the MIT License. // @@ -42,7 +42,7 @@ use std::time; use anyhow::Result; use clap::Parser; -use ct2rs::{Config, Device, GenerationOptions, Generator}; +use ct2rs::{Config, Device, GenerationOptions, Generator, ScoringOptions}; /// Generate text using GPT-2 models. #[derive(Parser, Debug)] @@ -83,7 +83,7 @@ fn main() -> Result<()> { let now = time::Instant::now(); let res = g.generate_batch( - &[prompts], + &[prompts.clone()], &GenerationOptions { max_length: 30, sampling_topk: 10, @@ -98,5 +98,9 @@ fn main() -> Result<()> { } println!("Time taken: {elapsed:?}"); + // Scoring the prompts. 
+ let scores = g.score_batch(&[prompts], &ScoringOptions::default())?; + println!("{:?}", scores[0]); + Ok(()) } diff --git a/src/generator.rs b/src/generator.rs index a2215d0..73ad29e 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -110,7 +110,7 @@ impl Generator { /// occurs during initialization, the function will return an error wrapped in the `Result`. /// /// # Example - /// The following example creates a translator instance with the tokenizer provided by + /// The following example creates a generator instance with the tokenizer provided by /// [tokenizers](https://huggingface.co/docs/tokenizers). /// /// ```no_run @@ -213,6 +213,17 @@ impl Generator { Ok(res) } + /// Scores a batch of prompts. + /// + /// # Arguments + /// * `prompts` - Batch of strings to score. + /// If the model expects special start or end tokens, they should also be added to this input. + /// * `options` - Settings applied to the scoring process. + /// + /// # Returns + /// Returns a `Result` containing a vector of `ScoringResult` if successful, + /// or an error if the scoring fails. + /// pub fn score_batch( &self, prompts: &[U], diff --git a/src/sys/generator.rs b/src/sys/generator.rs index 43a8445..9e4374d 100644 --- a/src/sys/generator.rs +++ b/src/sys/generator.rs @@ -271,6 +271,16 @@ impl Generator { .collect()) } + /// Scores a batch of tokens. + /// + /// # Arguments + /// * `tokens` - Batch of tokens to score. + /// If the model expects special start or end tokens, they should also be added to this input. + /// * `options` - Settings applied to the scoring process. + /// + /// # Returns + /// Returns a `Result` containing a vector of `ScoringResult` if successful, + /// or an error if the scoring fails. pub fn score_batch>( &self, tokens: &[Vec], diff --git a/src/sys/scoring.rs b/src/sys/scoring.rs index 3300c54..b45d0da 100644 --- a/src/sys/scoring.rs +++ b/src/sys/scoring.rs @@ -6,25 +6,53 @@ // // http://opensource.org/licenses/mit-license.php +//!
Structures for scoring. + use super::BatchType; pub use ffi::{ScoringOptions, ScoringResult}; #[cxx::bridge] pub(crate) mod ffi { + /// `ScoringOptions` specifies configuration options for the scoring process. + /// + /// # Examples + /// + /// Example of creating a default `ScoringOptions`: + /// + /// ``` + /// use ct2rs::sys::ScoringOptions; + /// + /// let opts = ScoringOptions::default(); + /// # assert_eq!(opts.max_input_length, 1024); + /// # assert_eq!(opts.offset, 0); + /// # assert_eq!(opts.max_batch_size, 0); + /// # assert_eq!(opts.batch_type, Default::default()); + /// ``` #[derive(Clone, Debug)] pub struct ScoringOptions { /// Truncate the inputs after this many tokens (set 0 to disable truncation). + /// (default: 1024) pub max_input_length: usize, + /// Offset. (default: 0) pub offset: i64, - + /// The maximum batch size. + /// If the number of inputs is greater than `max_batch_size`, + /// the inputs are sorted by length and split by chunks of `max_batch_size` examples + /// so that the number of padding positions is minimized. + /// (default: 0) max_batch_size: usize, + /// Whether `max_batch_size` is the number of `examples` or `tokens`. batch_type: BatchType, } + /// `ScoringResult` represents the result of a scoring process, + /// containing tokens and their respective scores. #[derive(Clone, Debug)] pub struct ScoringResult { + /// The scored tokens. pub tokens: Vec, + /// Log probability of each token. pub tokens_score: Vec, } @@ -51,10 +79,12 @@ impl Default for ScoringOptions { } impl ScoringResult { + /// Calculates and returns the total sum of all token scores. pub fn cumulated_score(&self) -> f32 { self.tokens_score.iter().sum() } + /// Computes the average score per token, returning 0.0 if there are no tokens. pub fn normalized_score(&self) -> f32 { let num_tokens = self.tokens_score.len(); if num_tokens == 0 {