Updated DocumentTermMatrix implementation
svilupp authored Sep 10, 2024
1 parent 1ef3ba3 commit a125fcd
Showing 6 changed files with 80 additions and 19 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 ### Fixed
 
+## [0.54.0]
+
+### Updated
+- Improved the performance of BM25/Keywords-based indices for >10M documents. Introduced new kwargs `min_term_freq` and `max_terms` in `RT.get_keywords` to reduce the size of the vocabulary. See `?RT.get_keywords` for more information.
+
 ## [0.53.0]
 
 ### Added
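For context on the CHANGELOG entry above, a minimal usage sketch of the new kwargs (the corpus and parameter values here are hypothetical; `LinearAlgebra`/`SparseArrays` and `Snowball` must be loaded so the relevant package extensions activate):

```julia
using PromptingTools
using LinearAlgebra, SparseArrays # activates RAGToolsExperimentalExt
using Snowball                    # activates SnowballPromptingToolsExt
const RT = PromptingTools.Experimental.RAGTools

# Hypothetical corpus; in practice these would be your chunked documents
docs = ["first chunk of a large corpus", "second chunk of a large corpus"]
processor = RT.KeywordsProcessor()

# Keep only terms seen at least twice, capped at the 100k most frequent terms
dtm = RT.get_keywords(processor, docs; min_term_freq = 2, max_terms = 100_000)
```

Pruning the vocabulary shrinks both the column count of the document-term matrix and the lookup dictionary, which is where the memory and speed win for very large corpora comes from.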
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "PromptingTools"
 uuid = "670122d1-24a8-4d70-bfce-740807c42192"
 authors = ["J S @svilupp and contributors"]
-version = "0.53.0"
+version = "0.54.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
48 changes: 38 additions & 10 deletions ext/RAGToolsExperimentalExt.jl
@@ -110,41 +110,69 @@ function Base.hcat(d1::RT.DocumentTermMatrix{<:AbstractSparseMatrix},
 end
 
 """
-    document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
+    RT.document_term_matrix(
+        documents::AbstractVector{<:AbstractVector{T}};
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString}
 
 Builds a sparse matrix of term frequencies and document lengths from the given vector of documents wrapped in type `DocumentTermMatrix`.
 
 Expects a vector of preprocessed (tokenized) documents, where each document is a vector of strings (clean tokens).
 
 Returns: `DocumentTermMatrix`
 
+# Arguments
+- `documents`: A vector of documents, where each document is a vector of terms (clean tokens).
+- `min_term_freq`: The minimum frequency a term must have to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice will be included.
+- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included.
+
 # Example
 ```
 documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
 dtm = document_term_matrix(documents)
 ```
 """
-function RT.document_term_matrix(documents::AbstractVector{<:AbstractVector{<:AbstractString}})
-    T = eltype(documents) |> eltype
-    vocab = convert(Vector{T}, unique(vcat(documents...)))
-    vocab_lookup = Dict{T, Int}(t => i for (i, t) in enumerate(vocab))
+function RT.document_term_matrix(
+        documents::AbstractVector{<:AbstractVector{T}};
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int)) where {T <: AbstractString}
+    ## Calculate term frequencies, sort descending
+    counts = Dict{T, Int}()
+    @inbounds for doc in documents
+        for term in doc
+            counts[term] = get(counts, term, 0) + 1
+        end
+    end
+    counts = sort(collect(counts), by = x -> -x[2]) |> Base.Fix2(first, max_terms) |>
+             Base.Fix1(filter!, x -> x[2] >= min_term_freq)
+    ## Create vocabulary
+    vocab = convert(Vector{T}, getindex.(counts, 1))
+    vocab_lookup = Dict{T, Int}(term => i for (i, term) in enumerate(vocab))
     N = length(documents)
     doc_freq = zeros(Int, length(vocab))
-    term_freq = spzeros(Float32, N, length(vocab))
     doc_lengths = zeros(Float32, N)
+    ## Term frequency matrix to be recorded via its sparse entries: I, J, V
+    # term_freq = spzeros(Float32, N, length(vocab))
+    I, J, V = Int[], Int[], Float32[]
+
+    unique_terms = Set{eltype(vocab)}()
+    sizehint!(unique_terms, 1000)
     for di in eachindex(documents)
-        unique_terms = Set{eltype(vocab)}()
+        empty!(unique_terms)
         doc = documents[di]
-        for t in doc
+        @inbounds for t in doc
             doc_lengths[di] += 1
-            tid = vocab_lookup[t]
-            term_freq[di, tid] += 1
+            tid = get(vocab_lookup, t, nothing)
+            tid === nothing && continue
+            push!(I, di)
+            push!(J, tid)
+            push!(V, 1.0f0)
             if !(t in unique_terms)
                 doc_freq[tid] += 1
                 push!(unique_terms, t)
             end
         end
     end
+    ## combine repeated terms with `+`
+    term_freq = sparse(I, J, V, N, length(vocab), +)
     idf = @. log(1.0f0 + (N - doc_freq + 0.5f0) / (doc_freq + 0.5f0))
     sumdl = sum(doc_lengths)
     doc_rel_length = sumdl == 0 ? zeros(Float32, N) : doc_lengths ./ (sumdl / N)
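The core of the change above is twofold: the vocabulary is pruned by global term frequency before the matrix is built, and the term-frequency matrix is assembled from coordinate triplets in a single `sparse` call rather than by incrementing a sparse matrix element by element (each in-place write to a `SparseMatrixCSC` can force a reshuffle of its compressed storage). A standalone sketch of the same pattern, with toy inputs:

```julia
using SparseArrays

documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"]]
min_term_freq, max_terms = 1, 100

# 1) Global term counts, sorted descending, truncated to `max_terms`,
#    then filtered by `min_term_freq`
counts = Dict{String, Int}()
for doc in documents, term in doc
    counts[term] = get(counts, term, 0) + 1
end
sorted = sort(collect(counts), by = x -> -x[2])
kept = filter!(x -> x[2] >= min_term_freq, first(sorted, max_terms))
vocab = first.(kept)
vocab_lookup = Dict(term => i for (i, term) in enumerate(vocab))

# 2) One (row, column, value) triplet per token occurrence;
#    `sparse` merges duplicate coordinates with the combine function `+`
I, J, V = Int[], Int[], Float32[]
for (di, doc) in enumerate(documents), t in doc
    tid = get(vocab_lookup, t, nothing) # `nothing` for pruned terms
    tid === nothing && continue
    push!(I, di); push!(J, tid); push!(V, 1.0f0)
end
term_freq = sparse(I, J, V, length(documents), length(vocab), +)
```

The `idf` line that follows in the diff is the standard BM25 inverse document frequency, `log(1 + (N - n + 0.5) / (n + 0.5))` for a term occurring in `n` of `N` documents.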
15 changes: 8 additions & 7 deletions ext/SnowballPromptingToolsExt.jl
@@ -12,12 +12,14 @@ using Snowball
 RT._stem(stemmer::Snowball.Stemmer, text::AbstractString) = Snowball.stem(stemmer, text)
 
 """
-    get_keywords(processor::KeywordsProcessor, docs::AbstractVector{<:AbstractString};
+    RT.get_keywords(
+        processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
         verbose::Bool = true,
         stemmer = nothing,
-        stopwords::Set{String} = Set(STOPWORDS),
+        stopwords::Set{String} = Set(RT.STOPWORDS),
         return_keywords::Bool = false,
         min_length::Integer = 3,
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int),
         kwargs...)
 
 Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stemmer` and `stopwords`.
@@ -29,6 +31,8 @@ Generate a `DocumentTermMatrix` from a vector of `docs` using the provided `stem
 - `stopwords`: A set of stopwords to remove. Default is `Set(STOPWORDS)`.
 - `return_keywords`: A boolean flag for returning the keywords. Default is `false`. Useful for query processing at search time.
 - `min_length`: The minimum length of the keywords. Default is `3`.
+- `min_term_freq`: The minimum frequency a term must have to be included in the vocabulary, eg, `min_term_freq = 2` means only terms that appear at least twice will be included.
+- `max_terms`: The maximum number of terms to include in the vocabulary, eg, `max_terms = 100` means only the 100 most frequent terms will be included.
 """
 function RT.get_keywords(
         processor::RT.KeywordsProcessor, docs::AbstractVector{<:AbstractString};
@@ -37,16 +41,13 @@ function RT.get_keywords(
         stopwords::Set{String} = Set(RT.STOPWORDS),
         return_keywords::Bool = false,
         min_length::Integer = 3,
+        min_term_freq::Int = 1, max_terms::Int = typemax(Int),
         kwargs...)
     ## check if extension is available
     ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt)
     if isnothing(ext)
         error("You need to also import LinearAlgebra and SparseArrays to use this function")
     end
-    ## ext = Base.get_extension(PromptingTools, :SnowballPromptingToolsExt)
-    ## if isnothing(ext)
-    ##     error("You need to also import Snowball.jl to use this function")
-    ## end
     ## Preprocess text into tokens
     stemmer = !isnothing(stemmer) ? stemmer : Snowball.Stemmer("english")
     # Single-threaded as stemmer is not thread-safe
@@ -56,7 +57,7 @@ function RT.get_keywords(
     return_keywords && return keywords
 
     ## Create DTM
-    dtm = RT.document_term_matrix(keywords)
+    dtm = RT.document_term_matrix(keywords; min_term_freq, max_terms)
 
     verbose && @info "Done processing DocumentTermMatrix."
     return dtm
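One note on the `return_keywords` flag documented above: at query time you usually want the identical preprocessing (stemming plus stopword removal) without paying for a DTM build, which is exactly what the early `return_keywords && return keywords` exit provides. A sketch with a hypothetical query:

```julia
using PromptingTools
using LinearAlgebra, SparseArrays, Snowball
const RT = PromptingTools.Experimental.RAGTools

processor = RT.KeywordsProcessor()
# Returns the stemmed, stopword-filtered tokens for each input string,
# skipping the DocumentTermMatrix construction entirely
query_keywords = RT.get_keywords(
    processor, ["How do I build a keyword index?"]; return_keywords = true)
```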
2 changes: 1 addition & 1 deletion src/Experimental/RAGTools/preparation.jl
@@ -207,7 +207,7 @@ function get_chunks(chunker::AbstractChunker,
         # split into chunks by recursively trying the separators provided
         # if you want to start simple - just do `split(text,"\n\n")`
         doc_chunks = PT.recursive_splitter(doc_raw, separators; max_length) .|> strip |>
-                     x -> filter(!isempty, x)
+                     Base.Fix1(filter!, !isempty)
         # skip if no chunks found
         isempty(doc_chunks) && continue
         append!(output_chunks, doc_chunks)
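The one-line change above replaces an anonymous function with `Base.Fix1`, which pins the first argument of a function: `Base.Fix1(filter!, !isempty)` is the callable `x -> filter!(!isempty, x)`. Note it also switches to the in-place `filter!`, which is safe here because the broadcast over `strip` has just produced a fresh vector. A quick illustration:

```julia
# Base.Fix1(f, a) behaves as x -> f(a, x)
keep_nonempty = Base.Fix1(filter!, !isempty)

chunks = ["a chunk", "", "another chunk", ""]
keep_nonempty(chunks) # mutates and returns `chunks`
@assert chunks == ["a chunk", "another chunk"]
```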
26 changes: 26 additions & 0 deletions test/Experimental/RAGTools/preparation.jl
@@ -115,6 +115,32 @@ end
     @test Set(dtm.vocab) == Set(["this", "test", "document", "anoth", "more", "text"])
     @test size(dtm.tf) == (2, 6)
 
+    # Test for KeywordsProcessor with min_term_freq and max_terms
+    docs_freq = [
+        "apple banana cherry apple",
+        "banana date fig grape",
+        "apple banana cherry date",
+        "elephant fig grape"
+    ]
+    processor_freq = KeywordsProcessor()
+
+    # Test with min_term_freq = 2
+    dtm_freq = get_keywords(processor_freq, docs_freq; min_term_freq = 2)
+    @test Set(dtm_freq.vocab) ==
+          Set(["appl", "banana", "cherri", "date", "fig", "grape"])
+    @test size(dtm_freq.tf) == (4, 6)
+
+    # Test with max_terms = 3
+    dtm_max = get_keywords(processor_freq, docs_freq; max_terms = 3)
+    @test length(dtm_max.vocab) == 3
+    @test size(dtm_max.tf) == (4, 3)
+
+    # Test with both min_term_freq = 2 and max_terms = 2
+    dtm_both = get_keywords(processor_freq, docs_freq; min_term_freq = 2, max_terms = 2)
+    @test length(dtm_both.vocab) == 2
+    @test size(dtm_both.tf) == (4, 2)
+    @test all(sum(dtm_both.tf, dims = 1) .>= 2)
+
     # Test for KeywordsProcessor with custom stemmer and stopwords
     custom_stemmer = Snowball.Stemmer("french")
     dtm_custom = get_keywords(
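As a sanity check on the expectations in these tests, the stemmed term counts for `docs_freq` can be tallied directly (ignoring the stopword and `min_length` filters, which do not affect these particular tokens; the exact stems are whatever Snowball's English stemmer produces):

```julia
using Snowball

stemmer = Snowball.Stemmer("english")
docs_freq = [
    "apple banana cherry apple",
    "banana date fig grape",
    "apple banana cherry date",
    "elephant fig grape"
]
counts = Dict{String, Int}()
for doc in docs_freq, token in split(doc)
    stem = Snowball.stem(stemmer, token)
    counts[stem] = get(counts, stem, 0) + 1
end
# Expected tally: appl => 3, banana => 3, cherri => 2, date => 2,
#                 fig => 2, grape => 2, eleph => 1
# so `min_term_freq = 2` drops only the elephant stem (6 terms survive),
# and `max_terms = 3` keeps the three most frequent stems.
```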
