Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add score_batch to Generator #89

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// build.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -106,6 +106,7 @@ fn main() {
cxx_build::bridges([
"src/sys/types.rs",
"src/sys/config.rs",
"src/sys/scoring.rs",
"src/sys/translator.rs",
"src/sys/generator.rs",
"src/sys/storage_view.rs",
Expand Down
10 changes: 7 additions & 3 deletions examples/gpt-2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// gpt-2.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -42,7 +42,7 @@ use std::time;
use anyhow::Result;
use clap::Parser;

use ct2rs::{Config, Device, GenerationOptions, Generator};
use ct2rs::{Config, Device, GenerationOptions, Generator, ScoringOptions};

/// Generate text using GPT-2 models.
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -83,7 +83,7 @@ fn main() -> Result<()> {

let now = time::Instant::now();
let res = g.generate_batch(
&[prompts],
&[prompts.clone()],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
Expand All @@ -98,5 +98,9 @@ fn main() -> Result<()> {
}
println!("Time taken: {elapsed:?}");

// Scoring the prompts.
let scores = g.score_batch(&[prompts], &ScoringOptions::default())?;
println!("{:?}", scores[0]);

Ok(())
}
9 changes: 8 additions & 1 deletion include/generator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.h
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -21,6 +21,8 @@ struct GenerationOptions;
struct GenerationResult;
struct GenerationStepResult;
struct GenerationCallbackBox;
struct ScoringOptions;
struct ScoringResult;

class Generator {
private:
Expand All @@ -37,6 +39,11 @@ class Generator {
GenerationCallbackBox& callback
) const;

rust::Vec<ScoringResult> score_batch(
const rust::Vec<VecStr>& tokens,
const ScoringOptions& options
) const;

inline size_t num_queued_batches() const {
return this->impl->num_queued_batches();
}
Expand Down
81 changes: 59 additions & 22 deletions src/generator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -17,7 +17,7 @@ pub use sys::GenerationOptions;

use crate::tokenizer::encode_all;

use super::{sys, Config, GenerationStepResult, Tokenizer};
use super::{sys, Config, GenerationStepResult, ScoringOptions, ScoringResult, Tokenizer};

/// A text generator with a tokenizer.
///
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<T: Tokenizer> Generator<T> {
/// occurs during initialization, the function will return an error wrapped in the `Result`.
///
/// # Example
/// The following example creates a translator instance with the tokenizer provided by
/// The following example creates a generator instance with the tokenizer provided by
/// [tokenizers](https://huggingface.co/docs/tokenizers).
///
/// ```no_run
Expand Down Expand Up @@ -213,6 +213,29 @@ impl<T: Tokenizer> Generator<T> {
Ok(res)
}

/// Scores a batch of prompts.
///
/// Each prompt is first tokenized with this generator's tokenizer (via
/// `encode_all`) and the resulting token batch is scored by the model.
///
/// # Arguments
/// * `prompts` - Batch of strings to score.
///   If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
///
pub fn score_batch<U>(
    &self,
    prompts: &[U],
    options: &ScoringOptions,
) -> Result<Vec<ScoringResult>>
where
    U: AsRef<str>,
{
    self.generator
        .score_batch(&encode_all(&self.tokenizer, prompts)?, options)
}

/// Number of batches in the work queue.
#[inline]
pub fn num_queued_batches(&self) -> anyhow::Result<usize> {
Expand Down Expand Up @@ -242,16 +265,16 @@ impl<T: Tokenizer> Debug for Generator<T> {
#[cfg(feature = "hub")]
mod tests {
use super::Generator;
use crate::tokenizers::auto::Tokenizer;
use crate::{download_model, Config, Device, GenerationOptions};
use anyhow::Result;
use std::path::PathBuf;

const MODEL_ID: &str = "jkawamoto/gpt2-ct2";

#[test]
#[ignore]
fn test_generate() {
let model_path = download_model(MODEL_ID).unwrap();
let g = Generator::new(
&model_path,
fn new_generator(model_path: &PathBuf) -> Result<Generator<Tokenizer>> {
Generator::new(
model_path,
&Config {
device: if cfg!(feature = "cuda") {
Device::CUDA
Expand All @@ -261,7 +284,13 @@ mod tests {
..Default::default()
},
)
.unwrap();
}

#[test]
#[ignore]
fn test_generate() {
let model_path = download_model(MODEL_ID).unwrap();
let g = new_generator(&model_path).unwrap();

let prompt = "CTranslate2 is a library";
let res = g
Expand All @@ -278,22 +307,30 @@ mod tests {
assert!(res[0].0[0].starts_with(prompt));
}

#[test]
#[ignore]
fn test_scoring() {
    // Build a generator from the downloaded test model and score one prompt.
    let model_path = download_model(MODEL_ID).unwrap();
    let generator = new_generator(&model_path).unwrap();

    let prompt = "CTranslate2 is a library";
    let results = generator
        .score_batch(&[prompt], &Default::default())
        .unwrap();

    // The tokenizer splits the prompt into these subword tokens
    // ("Ġ" marks a leading space in the GPT-2 vocabulary).
    let expected: Vec<String> = ["Trans", "late", "2", "Ġis", "Ġa", "Ġlibrary"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    assert_eq!(results[0].tokens, expected);

    // A real score should never be exactly zero.
    assert_ne!(results[0].normalized_score(), 0.0);
}

#[test]
#[ignore]
fn test_generator_debug() {
let model_path = download_model(MODEL_ID).unwrap();
let g = Generator::new(
&model_path,
&Config {
device: if cfg!(feature = "cuda") {
Device::CUDA
} else {
Device::CPU
},
..Default::default()
},
)
.unwrap();
let g = new_generator(&model_path).unwrap();

assert!(format!("{:?}", g).contains(model_path.file_name().unwrap().to_str().unwrap()));
}
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// lib.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -80,7 +80,10 @@ pub use generator::{GenerationOptions, Generator};
#[cfg_attr(docsrs, doc(cfg(feature = "hub")))]
pub use hub::download_model;
pub use result::GenerationStepResult;
pub use sys::{set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel};
pub use sys::{
set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel,
ScoringOptions, ScoringResult,
};
pub use tokenizer::Tokenizer;
pub use translator::{TranslationOptions, Translator};
#[cfg(feature = "whisper")]
Expand Down
4 changes: 3 additions & 1 deletion src/sys.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sys.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -47,13 +47,15 @@

pub use config::*;
pub use generator::*;
pub use scoring::*;
pub use storage_view::*;
pub use translator::*;
pub use types::*;
pub use whisper::*;

mod config;
mod generator;
mod scoring;
mod storage_view;
mod translator;
mod types;
Expand Down
29 changes: 28 additions & 1 deletion src/sys/generator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.cpp
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand Down Expand Up @@ -97,3 +97,30 @@ Generator::generate_batch(

return res;
}

// Scores a batch of token sequences with the underlying CTranslate2 model
// and returns the results in the Rust-visible representation.
Vec<ScoringResult>
Generator::score_batch(
    const Vec<VecStr>& tokens,
    const ScoringOptions& options
) const {
    // Submit the whole batch for asynchronous scoring; one future per entry.
    auto futures = this->impl->score_batch_async(
        from_rust(tokens),
        // NOTE(review): aggregate-initialization — the field order here
        // (max_input_length, offset) must match the declaration order of
        // ctranslate2::ScoringOptions; re-check when upgrading CTranslate2.
        ctranslate2::ScoringOptions {
            options.max_input_length,
            options.offset,
        },
        options.max_batch_size,
        options.batch_type
    );

    Vec<ScoringResult> res;
    for (auto& future : futures) {
        // future.get() blocks until this entry's score is ready.
        const auto& r = future.get();
        res.push_back(ScoringResult {
            to_rust(r.tokens),
            to_rust(r.tokens_score),
        });
    }

    return res;
}
35 changes: 33 additions & 2 deletions src/sys/generator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// generator.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
Expand All @@ -17,7 +17,8 @@ use anyhow::{anyhow, Error, Result};
use cxx::UniquePtr;

use super::{
config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString, VecUSize,
config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, ScoringOptions, ScoringResult,
VecStr, VecString, VecUSize,
};

trait GenerationCallback {
Expand Down Expand Up @@ -93,6 +94,7 @@ mod ffi {
unsafe extern "C++" {
include!("ct2rs/include/generator.h");
include!("ct2rs/src/sys/types.rs.h");
include!("ct2rs/src/sys/scoring.rs.h");

type VecString = super::VecString;
type VecStr<'a> = super::VecStr<'a>;
Expand All @@ -102,6 +104,9 @@ mod ffi {
type BatchType = super::BatchType;
type GenerationStepResult = super::GenerationStepResult;

type ScoringOptions = super::ScoringOptions;
type ScoringResult = super::ScoringResult;

type Generator;

fn generator(model_path: &str, config: UniquePtr<Config>) -> Result<UniquePtr<Generator>>;
Expand All @@ -114,6 +119,12 @@ mod ffi {
callback: &mut GenerationCallbackBox,
) -> Result<Vec<GenerationResult>>;

fn score_batch(
self: &Generator,
tokens: &Vec<VecStr>,
options: &ScoringOptions,
) -> Result<Vec<ScoringResult>>;

fn num_queued_batches(self: &Generator) -> Result<usize>;

fn num_active_batches(self: &Generator) -> Result<usize>;
Expand Down Expand Up @@ -260,6 +271,26 @@ impl Generator {
.collect())
}

/// Scores a batch of tokens.
///
/// # Arguments
/// * `tokens` - Batch of tokens to score.
///   If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
pub fn score_batch<T: AsRef<str>>(
    &self,
    tokens: &[Vec<T>],
    options: &ScoringOptions,
) -> Result<Vec<ScoringResult>> {
    // Convert the batch into the FFI-friendly representation and delegate
    // to the C++ bridge; cxx-layer errors are mapped into anyhow errors.
    self.ptr
        .score_batch(&vec_ffi_vecstr(tokens), options)
        .map_err(Error::from)
}

/// Number of batches in the work queue.
#[inline]
pub fn num_queued_batches(&self) -> Result<usize> {
Expand Down
Loading
Loading