Skip to content

Commit

Permalink
Merge pull request #89 from jkawamoto/scoring
Browse files Browse the repository at this point in the history
Add score_batch to Generator
  • Loading branch information
jkawamoto authored Jan 8, 2025
2 parents 045a341 + 174138b commit 12b158a
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 33 deletions.
3 changes: 2 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// build.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -106,6 +106,7 @@ fn main() {
cxx_build::bridges([
"src/sys/types.rs",
"src/sys/config.rs",
"src/sys/scoring.rs",
"src/sys/translator.rs",
"src/sys/generator.rs",
"src/sys/storage_view.rs",
Expand Down
10 changes: 7 additions & 3 deletions examples/gpt-2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// gpt-2.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -42,7 +42,7 @@ use std::time;
use anyhow::Result;
use clap::Parser;

use ct2rs::{Config, Device, GenerationOptions, Generator};
use ct2rs::{Config, Device, GenerationOptions, Generator, ScoringOptions};

/// Generate text using GPT-2 models.
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -83,7 +83,7 @@ fn main() -> Result<()> {

let now = time::Instant::now();
let res = g.generate_batch(
&[prompts],
&[prompts.clone()],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
Expand All @@ -98,5 +98,9 @@ fn main() -> Result<()> {
}
println!("Time taken: {elapsed:?}");

// Scoring the prompts.
let scores = g.score_batch(&[prompts], &ScoringOptions::default())?;
println!("{:?}", scores[0]);

Ok(())
}
9 changes: 8 additions & 1 deletion include/generator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.h
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -21,6 +21,8 @@ struct GenerationOptions;
struct GenerationResult;
struct GenerationStepResult;
struct GenerationCallbackBox;
struct ScoringOptions;
struct ScoringResult;

class Generator {
private:
Expand All @@ -37,6 +39,11 @@ class Generator {
GenerationCallbackBox& callback
) const;

rust::Vec<ScoringResult> score_batch(
const rust::Vec<VecStr>& tokens,
const ScoringOptions& options
) const;

inline size_t num_queued_batches() const {
return this->impl->num_queued_batches();
}
Expand Down
81 changes: 59 additions & 22 deletions src/generator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -17,7 +17,7 @@ pub use sys::GenerationOptions;

use crate::tokenizer::encode_all;

use super::{sys, Config, GenerationStepResult, Tokenizer};
use super::{sys, Config, GenerationStepResult, ScoringOptions, ScoringResult, Tokenizer};

/// A text generator with a tokenizer.
///
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<T: Tokenizer> Generator<T> {
/// occurs during initialization, the function will return an error wrapped in the `Result`.
///
/// # Example
/// The following example creates a translator instance with the tokenizer provided by
/// The following example creates a generator instance with the tokenizer provided by
/// [tokenizers](https://huggingface.co/docs/tokenizers).
///
/// ```no_run
Expand Down Expand Up @@ -213,6 +213,29 @@ impl<T: Tokenizer> Generator<T> {
Ok(res)
}

/// Scores a batch of prompts.
///
/// Each prompt is tokenized with this generator's tokenizer (via `encode_all`)
/// and the token batch is forwarded to the underlying `sys` generator.
///
/// # Arguments
/// * `prompts` - Batch of strings to score.
/// If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
///
pub fn score_batch<U>(
    &self,
    prompts: &[U],
    options: &ScoringOptions,
) -> Result<Vec<ScoringResult>>
where
    U: AsRef<str>,
{
    // Tokenization happens first; `?` propagates any tokenizer error
    // before the FFI scoring call is made.
    self.generator
        .score_batch(&encode_all(&self.tokenizer, prompts)?, options)
}

/// Number of batches in the work queue.
#[inline]
pub fn num_queued_batches(&self) -> anyhow::Result<usize> {
Expand Down Expand Up @@ -242,16 +265,16 @@ impl<T: Tokenizer> Debug for Generator<T> {
#[cfg(feature = "hub")]
mod tests {
use super::Generator;
use crate::tokenizers::auto::Tokenizer;
use crate::{download_model, Config, Device, GenerationOptions};
use anyhow::Result;
use std::path::PathBuf;

const MODEL_ID: &str = "jkawamoto/gpt2-ct2";

#[test]
#[ignore]
fn test_generate() {
let model_path = download_model(MODEL_ID).unwrap();
let g = Generator::new(
&model_path,
fn new_generator(model_path: &PathBuf) -> Result<Generator<Tokenizer>> {
Generator::new(
model_path,
&Config {
device: if cfg!(feature = "cuda") {
Device::CUDA
Expand All @@ -261,7 +284,13 @@ mod tests {
..Default::default()
},
)
.unwrap();
}

#[test]
#[ignore]
fn test_generate() {
let model_path = download_model(MODEL_ID).unwrap();
let g = new_generator(&model_path).unwrap();

let prompt = "CTranslate2 is a library";
let res = g
Expand All @@ -278,22 +307,30 @@ mod tests {
assert!(res[0].0[0].starts_with(prompt));
}

// Integration test for `score_batch`: downloads a GPT-2 model from the
// Hugging Face Hub, so it is `#[ignore]`d by default (run with
// `cargo test -- --ignored`).
#[test]
#[ignore]
fn test_scoring() {
    let model_path = download_model(MODEL_ID).unwrap();
    let g = new_generator(&model_path).unwrap();

    let prompt = "CTranslate2 is a library";
    let res = g.score_batch(&[prompt], &Default::default()).unwrap();

    // Expected GPT-2 BPE split of the prompt ("Ġ" marks a leading space).
    // NOTE(review): the first BPE token of the prompt is absent from the
    // result — presumably it serves as context and is not itself scored;
    // confirm against CTranslate2's scoring semantics.
    assert_eq!(
        res[0].tokens,
        vec!["Trans", "late", "2", "Ġis", "Ġa", "Ġlibrary"]
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<String>>()
    );
    // A real model should never produce an exactly-zero normalized score.
    assert_ne!(res[0].normalized_score(), 0.0);
}

#[test]
#[ignore]
fn test_generator_debug() {
let model_path = download_model(MODEL_ID).unwrap();
let g = Generator::new(
&model_path,
&Config {
device: if cfg!(feature = "cuda") {
Device::CUDA
} else {
Device::CPU
},
..Default::default()
},
)
.unwrap();
let g = new_generator(&model_path).unwrap();

assert!(format!("{:?}", g).contains(model_path.file_name().unwrap().to_str().unwrap()));
}
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// lib.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -80,7 +80,10 @@ pub use generator::{GenerationOptions, Generator};
#[cfg_attr(docsrs, doc(cfg(feature = "hub")))]
pub use hub::download_model;
pub use result::GenerationStepResult;
pub use sys::{set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel};
pub use sys::{
set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel,
ScoringOptions, ScoringResult,
};
pub use tokenizer::Tokenizer;
pub use translator::{TranslationOptions, Translator};
#[cfg(feature = "whisper")]
Expand Down
4 changes: 3 additions & 1 deletion src/sys.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sys.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -47,13 +47,15 @@
pub use config::*;
pub use generator::*;
pub use scoring::*;
pub use storage_view::*;
pub use translator::*;
pub use types::*;
pub use whisper::*;

mod config;
mod generator;
mod scoring;
mod storage_view;
mod translator;
mod types;
Expand Down
29 changes: 28 additions & 1 deletion src/sys/generator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.cpp
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -97,3 +97,30 @@ Generator::generate_batch(

return res;
}

// Scores a batch of token sequences.
//
// Delegates to ctranslate2's score_batch_async and synchronously waits on
// every returned future, converting each result into the Rust-visible
// ScoringResult struct.
//
// `tokens`  - batch of tokenized inputs from the Rust side.
// `options` - scoring settings; max_input_length and offset are forwarded
//             into ctranslate2::ScoringOptions, while max_batch_size and
//             batch_type control how the async call batches the work.
//
// NOTE(review): future.get() may throw if scoring fails; this is expected
// to be translated into a Rust Result by the cxx bridge — confirm.
Vec<ScoringResult>
Generator::score_batch(
    const Vec<VecStr>& tokens,
    const ScoringOptions& options
) const {
    auto futures = this->impl->score_batch_async(
        from_rust(tokens),
        ctranslate2::ScoringOptions {
            options.max_input_length,
            options.offset,
        },
        options.max_batch_size,
        options.batch_type
    );

    Vec<ScoringResult> res;
    for (auto& future : futures) {
        // Blocks until this sequence has been scored.
        const auto& r = future.get();
        res.push_back(ScoringResult {
            to_rust(r.tokens),
            to_rust(r.tokens_score),
        });
    }

    return res;
}
35 changes: 33 additions & 2 deletions src/sys/generator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -17,7 +17,8 @@ use anyhow::{anyhow, Error, Result};
use cxx::UniquePtr;

use super::{
config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString, VecUSize,
config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, ScoringOptions, ScoringResult,
VecStr, VecString, VecUSize,
};

trait GenerationCallback {
Expand Down Expand Up @@ -93,6 +94,7 @@ mod ffi {
unsafe extern "C++" {
include!("ct2rs/include/generator.h");
include!("ct2rs/src/sys/types.rs.h");
include!("ct2rs/src/sys/scoring.rs.h");

type VecString = super::VecString;
type VecStr<'a> = super::VecStr<'a>;
Expand All @@ -102,6 +104,9 @@ mod ffi {
type BatchType = super::BatchType;
type GenerationStepResult = super::GenerationStepResult;

type ScoringOptions = super::ScoringOptions;
type ScoringResult = super::ScoringResult;

type Generator;

fn generator(model_path: &str, config: UniquePtr<Config>) -> Result<UniquePtr<Generator>>;
Expand All @@ -114,6 +119,12 @@ mod ffi {
callback: &mut GenerationCallbackBox,
) -> Result<Vec<GenerationResult>>;

fn score_batch(
self: &Generator,
tokens: &Vec<VecStr>,
options: &ScoringOptions,
) -> Result<Vec<ScoringResult>>;

fn num_queued_batches(self: &Generator) -> Result<usize>;

fn num_active_batches(self: &Generator) -> Result<usize>;
Expand Down Expand Up @@ -260,6 +271,26 @@ impl Generator {
.collect())
}

/// Scores a batch of tokens.
///
/// # Arguments
/// * `tokens` - Batch of tokens to score.
/// If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
pub fn score_batch<T: AsRef<str>>(
    &self,
    tokens: &[Vec<T>],
    options: &ScoringOptions,
) -> Result<Vec<ScoringResult>> {
    // Convert to the FFI-friendly VecStr representation before crossing
    // into C++; a failure on the C++ side comes back through the bridge
    // and is mapped into an anyhow::Error here.
    self.ptr
        .score_batch(&vec_ffi_vecstr(tokens), options)
        .map_err(Error::from)
}

/// Number of batches in the work queue.
#[inline]
pub fn num_queued_batches(&self) -> Result<usize> {
Expand Down
Loading

0 comments on commit 12b158a

Please sign in to comment.