docs: add scoring feature documentation and related examples
jkawamoto committed Jan 8, 2025
1 parent 21c39ea commit 174138b
Showing 4 changed files with 60 additions and 5 deletions.
10 changes: 7 additions & 3 deletions examples/gpt-2.rs
@@ -1,6 +1,6 @@
// gpt-2.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
// Copyright (c) 2023-2025 Junpei Kawamoto
//
// This software is released under the MIT License.
//
@@ -42,7 +42,7 @@ use std::time;
use anyhow::Result;
use clap::Parser;

use ct2rs::{Config, Device, GenerationOptions, Generator};
use ct2rs::{Config, Device, GenerationOptions, Generator, ScoringOptions};

/// Generate text using GPT-2 models.
#[derive(Parser, Debug)]
@@ -83,7 +83,7 @@ fn main() -> Result<()> {

let now = time::Instant::now();
let res = g.generate_batch(
&[prompts],
&[prompts.clone()],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
@@ -98,5 +98,9 @@
}
println!("Time taken: {elapsed:?}");

// Scoring the prompts.
let scores = g.score_batch(&[prompts], &ScoringOptions::default())?;
println!("{:?}", scores[0]);

Ok(())
}
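
The new scoring call above prints `scores[0]` with its `Debug` formatting. As a minimal follow-up sketch, assuming `Generator::score_batch` returns the `ScoringResult` values defined later in this commit (src/sys/scoring.rs), the per-prompt summaries could be printed like this:

```
// Hypothetical continuation of the example above; `scores` is the
// Vec<ScoringResult> returned by g.score_batch(...).
for (i, score) in scores.iter().enumerate() {
    println!(
        "prompt {i}: {} tokens, cumulated: {:.4}, normalized: {:.4}",
        score.tokens.len(),
        score.cumulated_score(),  // sum of per-token log probabilities
        score.normalized_score(), // average log probability per token
    );
}
```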
13 changes: 12 additions & 1 deletion src/generator.rs
@@ -110,7 +110,7 @@ impl<T: Tokenizer> Generator<T> {
/// occurs during initialization, the function will return an error wrapped in the `Result`.
///
/// # Example
/// The following example creates a translator instance with the tokenizer provided by
/// The following example creates a generator instance with the tokenizer provided by
/// [tokenizers](https://huggingface.co/docs/tokenizers).
///
/// ```no_run
@@ -213,6 +213,17 @@ impl<T: Tokenizer> Generator<T> {
Ok(res)
}

/// Scores a batch of prompts.
///
/// # Arguments
/// * `prompts` - Batch of strings to score.
/// If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
///
pub fn score_batch<U>(
&self,
prompts: &[U],
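
A doc-style usage sketch for the `score_batch` documented above, in the spirit of the crate's other examples. The constructor call `Generator::new(path, &Config::default())` and the `&ScoringOptions` second parameter are assumptions based on the surrounding code, since the diff truncates the signature:

```
use anyhow::Result;

use ct2rs::{Config, Generator, ScoringOptions};

fn main() -> Result<()> {
    // Assumed constructor: a model directory containing the converted
    // model and its tokenizer files.
    let g = Generator::new("path/to/model", &Config::default())?;

    // Score whole prompt strings; include any special start/end tokens
    // the model expects directly in the text.
    let results = g.score_batch(
        &["Hello, my name is", "The quick brown fox"],
        &ScoringOptions::default(),
    )?;

    for r in &results {
        println!("{:?}", r);
    }
    Ok(())
}
```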
10 changes: 10 additions & 0 deletions src/sys/generator.rs
@@ -271,6 +271,16 @@ impl Generator {
.collect())
}

/// Scores a batch of tokens.
///
/// # Arguments
/// * `tokens` - Batch of tokens to score.
/// If the model expects special start or end tokens, they should also be added to this input.
/// * `options` - Settings applied to the scoring process.
///
/// # Returns
/// Returns a `Result` containing a vector of `ScoringResult` if successful,
/// or an error if the scoring fails.
pub fn score_batch<T: AsRef<str>>(
&self,
tokens: &[Vec<T>],
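
Unlike the high-level wrapper, the sys-level `score_batch` above takes already-tokenized input (`&[Vec<T>]`). A minimal sketch under stated assumptions: `sys::Generator::new(path, &Config::default())` as the constructor and a `&ScoringOptions` second argument, neither of which is fully shown in this diff, with purely illustrative token strings:

```
use anyhow::Result;

use ct2rs::Config;
use ct2rs::sys::{Generator, ScoringOptions};

fn main() -> Result<()> {
    // Assumed constructor shape; see the higher-level example above.
    let g = Generator::new("path/to/model", &Config::default())?;

    // One inner Vec per example; add any special tokens the model expects.
    // These token strings are illustrative, not real model vocabulary.
    let tokens = vec![vec!["<|endoftext|>", "Hello", ",", "world"]];
    let results = g.score_batch(&tokens, &ScoringOptions::default())?;

    for (token, score) in results[0].tokens.iter().zip(&results[0].tokens_score) {
        println!("{token}: {score:.4}");
    }
    Ok(())
}
```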
32 changes: 31 additions & 1 deletion src/sys/scoring.rs
@@ -6,25 +6,53 @@
//
// http://opensource.org/licenses/mit-license.php

//! Structures for scoring.
use super::BatchType;
pub use ffi::{ScoringOptions, ScoringResult};

#[cxx::bridge]
pub(crate) mod ffi {

/// `ScoringOptions` specifies configuration options for the scoring process.
///
/// # Examples
///
/// Example of creating a default `ScoringOptions`:
///
/// ```
/// use ct2rs::sys::ScoringOptions;
///
/// let opts = ScoringOptions::default();
/// # assert_eq!(opts.max_input_length, 1024);
/// # assert_eq!(opts.offset, 0);
/// # assert_eq!(opts.max_batch_size, 0);
/// # assert_eq!(opts.batch_type, Default::default());
/// ```
#[derive(Clone, Debug)]
pub struct ScoringOptions {
/// Truncate the inputs after this many tokens (set 0 to disable truncation).
/// (default: 1024)
pub max_input_length: usize,
/// Offset. (default: 0)
pub offset: i64,

/// The maximum batch size.
/// If the number of inputs is greater than `max_batch_size`,
/// the inputs are sorted by length and split by chunks of `max_batch_size` examples
/// so that the number of padding positions is minimized.
/// (default: 0)
max_batch_size: usize,
/// Whether `max_batch_size` is the number of `examples` or `tokens`.
batch_type: BatchType,
}

/// `ScoringResult` represents the result of a scoring process,
/// containing tokens and their respective scores.
#[derive(Clone, Debug)]
pub struct ScoringResult {
/// The scored tokens.
pub tokens: Vec<String>,
/// Log probability of each token.
pub tokens_score: Vec<f32>,
}

@@ -51,10 +79,12 @@ impl Default for ScoringOptions {
}

impl ScoringResult {
/// Calculates and returns the total sum of all token scores.
pub fn cumulated_score(&self) -> f32 {
self.tokens_score.iter().sum()
}

/// Computes the average score per token, returning 0.0 if there are no tokens.
pub fn normalized_score(&self) -> f32 {
let num_tokens = self.tokens_score.len();
if num_tokens == 0 {
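
To make the option fields and the two helpers concrete, a small doctest-style sketch in the same vein as the example already in this file: the `ScoringResult` is hand-built (both of its fields are `pub` in the diff above), and since the body of `normalized_score` is truncated, the average-per-token behaviour is taken from its doc comment:

```
use ct2rs::sys::{ScoringOptions, ScoringResult};

// Tighten truncation via the public fields; the non-pub fields
// (max_batch_size, batch_type) keep their defaults.
let mut opts = ScoringOptions::default();
opts.max_input_length = 512;
assert_eq!(opts.max_input_length, 512);

// Three tokens with log probabilities -1.0, -2.0 and -3.0.
let res = ScoringResult {
    tokens: vec!["Hello".into(), ",".into(), "world".into()],
    tokens_score: vec![-1.0, -2.0, -3.0],
};
assert_eq!(res.cumulated_score(), -6.0);  // -1.0 + -2.0 + -3.0
assert_eq!(res.normalized_score(), -2.0); // -6.0 / 3 tokens
```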
