diff --git a/.gitignore b/.gitignore index 7f1ba869d..1d81fa66b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,6 @@ # exclude scratch files **/_* -docs/package-lock.json \ No newline at end of file +docs/package-lock.json + +.env diff --git a/Project.toml b/Project.toml index 372b15dbb..66ffb43bf 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OpenAI = "e9f21f70-7185-4079-aca2-91159181367c" +Pinecone = "ee90fdae-f7f0-4648-8b00-9c0307cf46d9" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" diff --git a/src/Experimental/RAGTools/RAGTools.jl b/src/Experimental/RAGTools/RAGTools.jl index fe4b80788..69647a1e6 100644 --- a/src/Experimental/RAGTools/RAGTools.jl +++ b/src/Experimental/RAGTools/RAGTools.jl @@ -32,12 +32,12 @@ include("api_services.jl") include("rag_interface.jl") -export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, CandidateChunks, RAGResult +export ChunkIndex, ChunkKeywordsIndex, ChunkEmbeddingsIndex, PineconeIndex, CandidateChunks, CandidateWithChunks, RAGResult export MultiIndex, SubChunkIndex, MultiCandidateChunks include("types.jl") export build_index, get_chunks, get_embeddings, get_keywords, get_tags, SimpleIndexer, - KeywordsIndexer + KeywordsIndexer, PineconeIndexer include("preparation.jl") include("rank_gpt.jl") diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl index 89c638190..709be3708 100644 --- a/src/Experimental/RAGTools/generation.jl +++ b/src/Experimental/RAGTools/generation.jl @@ -37,7 +37,8 @@ context = build_context(ContextEnumerator(), index, candidates; chunks_window_ma ``` """ function build_context(contexter::ContextEnumerator, - index::AbstractDocumentIndex, candidates::AbstractCandidateChunks; + 
index::AbstractDocumentIndex, + candidates::AbstractCandidateChunks; verbose::Bool = true, chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...) ## Checks @@ -63,6 +64,35 @@ function build_context(contexter::ContextEnumerator, return context end +""" + build_context(contexter::ContextEnumerator, + index::AbstractManagedIndex, candidates::AbstractCandidateWithChunks; + verbose::Bool = true, + chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...) + + build_context!(contexter::ContextEnumerator, + index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...) + +Dispatch for `AbstractManagedIndex` with `AbstractCandidateWithChunks`. +""" +function build_context(contexter::ContextEnumerator, + index::AbstractManagedIndex, + candidates::AbstractCandidateWithChunks; + verbose::Bool = true, kwargs...) + context = String[] + for (i, _) in enumerate(positions(candidates)) + ## select the right index + id = candidates isa MultiCandidateChunks ? candidates.index_ids[i] : + candidates.index_id + index_ = index isa AbstractChunkIndex ? index : index[id] + isnothing(index_) && continue + + chunks_ = chunks(candidates) + push!(context, "$(i). $(join(chunks_, "\n"))") + end + return context +end + function build_context!(contexter::AbstractContextBuilder, index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...) throw(ArgumentError("Contexter $(typeof(contexter)) not implemented")) @@ -74,6 +104,11 @@ function build_context!(contexter::ContextEnumerator, result.context = build_context(contexter, index, result.reranked_candidates; kwargs...) return result end +function build_context!(contexter::ContextEnumerator, + index::AbstractManagedIndex, result::AbstractRAGResult; kwargs...) + result.context = build_context(contexter, index, result.reranked_candidates; kwargs...) 
+ return result +end ## First step: Answerer @@ -139,6 +174,42 @@ function answer!( return result end +""" + answer!( + answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult; + model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true, + template::Symbol = :RAGAnswerFromContext, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Dispatch for `AbstractManagedIndex`. +""" +function answer!( + answerer::SimpleAnswerer, index::AbstractManagedIndex, result::AbstractRAGResult; + model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true, + template::Symbol = :RAGAnswerFromContext, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + ## Checks + placeholders = only(aitemplates(template)).variables # only one template should be found + @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(template) is not suitable. It must have placeholders: `question` and `context`." + ## + (; context, question) = result + conv = aigenerate(template; question, + context = join(context, "\n\n"), model, verbose = false, + return_all = true, + kwargs...) + msg = conv[end] + result.answer = strip(msg.content) + result.conversations[:answer] = conv + ## Increment the cost tracker + Threads.atomic_add!(cost_tracker, msg.cost) + verbose && + @info "Done generating the answer. Cost: \$$(round(msg.cost,digits=3))" + + return result +end + ## Refine """ NoRefiner <: AbstractRefiner @@ -162,11 +233,12 @@ Refines the answer by executing a web search using the Tavily API. This method a struct TavilySearchRefiner <: AbstractRefiner end function refine!( - refiner::AbstractRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::AbstractRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...) 
throw(ArgumentError("Refiner $(typeof(refiner)) not implemented")) end + """ refine!( refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; @@ -175,7 +247,7 @@ end Simple no-op function for `refine!`. It simply copies the `result.answer` and `result.conversations[:answer]` without any changes. """ function refine!( - refiner::NoRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::NoRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...) result.final_answer = result.answer if haskey(result.conversations, :answer) @@ -184,9 +256,10 @@ function refine!( return result end + """ refine!( - refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Bool = true, model::AbstractString = PT.MODEL_CHAT, template::Symbol = :RAGAnswerRefiner, @@ -210,7 +283,7 @@ This method uses the same context as the original answer, however, it can be mod - `cost_tracker`: An atomic counter to track the cost of the operation. 
""" function refine!( - refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::SimpleRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Bool = true, model::AbstractString = PT.MODEL_CHAT, template::Symbol = :RAGAnswerRefiner, @@ -238,9 +311,10 @@ function refine!( return result end + """ refine!( - refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Bool = true, model::AbstractString = PT.MODEL_CHAT, include_answer::Bool = true, @@ -288,7 +362,7 @@ pprint(result) ``` """ function refine!( - refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult; + refiner::TavilySearchRefiner, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Bool = true, model::AbstractString = PT.MODEL_CHAT, include_answer::Bool = true, @@ -353,13 +427,13 @@ Overload this method to add custom postprocessing steps, eg, logging, saving con """ struct NoPostprocessor <: AbstractPostprocessor end -function postprocess!(postprocessor::AbstractPostprocessor, index::AbstractDocumentIndex, +function postprocess!(postprocessor::AbstractPostprocessor, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...) throw(ArgumentError("Postprocessor $(typeof(postprocessor)) not implemented")) end function postprocess!( - ::NoPostprocessor, index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...) + ::NoPostprocessor, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; kwargs...) 
return result end @@ -394,7 +468,7 @@ end """ generate!( - generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult; + generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Integer = 1, api_kwargs::NamedTuple = NamedTuple(), contexter::AbstractContextBuilder = generator.contexter, @@ -459,7 +533,7 @@ result = generate!(index, result) ``` """ function generate!( - generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult; + generator::AbstractGenerator, index::Union{AbstractDocumentIndex, AbstractManagedIndex}, result::AbstractRAGResult; verbose::Integer = 1, api_kwargs::NamedTuple = NamedTuple(), contexter::AbstractContextBuilder = generator.contexter, @@ -524,8 +598,9 @@ function Base.show(io::IO, cfg::AbstractRAGConfig) dump(io, cfg; maxdepth = 2) end +# TODO: add example for Pinecone """ - airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex; + airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex}; question::AbstractString, verbose::Integer = 1, return_all::Bool = false, api_kwargs::NamedTuple = NamedTuple(), @@ -644,11 +719,35 @@ result = airag(cfg, multi_index; question, return_all=true) # Pretty-print the result PT.pprint(result) +``` + +Example for Pinecone. + +```julia +import LinearAlgebra, Unicode, SparseArrays +using Pinecone + +# configure your Pinecone API key, index and namespace + +docs_files = ... # files containing docs that you want to upsert to Pinecone +metadata = [Dict{String, Any}("source" => file) for file in docs_files] # replace with your docs' sources +index_pinecone = RT.build_index( + RT.PineconeIndexer(), + docs_files; + pinecone_context, # API key wrapped with `Pinecone.jl` + pinecone_index, + pinecone_namespace, + metadata, + upsert = true +) + +question = "How do I multiply two vectors in Julia?" 
+result = RT.airag(index_pinecone; question) ``` For easier manipulation of nested kwargs, see utilities `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`. """ -function airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex; +function airag(cfg::AbstractRAGConfig, index::Union{AbstractDocumentIndex, AbstractManagedIndex}; question::AbstractString, verbose::Integer = 1, return_all::Bool = false, api_kwargs::NamedTuple = NamedTuple(), @@ -693,6 +792,10 @@ const DEFAULT_RAG_CONFIG = RAGConfig() function airag(index::AbstractDocumentIndex; question::AbstractString, kwargs...) return airag(DEFAULT_RAG_CONFIG, index; question, kwargs...) end +const DEFAULT_RAG_CONFIG_PINECONE = RAGConfig(PineconeIndexer(), PineconeRetriever(), AdvancedGenerator()) +function airag(index::AbstractManagedIndex; question::AbstractString, kwargs...) + return airag(DEFAULT_RAG_CONFIG_PINECONE, index; question, kwargs...) +end # Special method to pretty-print the airag results function PT.pprint(io::IO, airag_result::Tuple{PT.AIMessage, AbstractRAGResult}, diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index d687cc4fd..07ee96bf7 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -134,6 +134,19 @@ It uses `TextChunker`, `KeywordsProcessor`, and `NoTagger` as default chunker, p tagger::AbstractTagger = NoTagger() end +""" + PineconeIndexer <: AbstractIndexBuilder + +Pinecone index to be returned by `build_index`. + +It uses `FileChunker`, `BatchEmbedder` and `NoTagger` as default chunker, embedder and tagger. +""" +@kwdef mutable struct PineconeIndexer <: AbstractIndexBuilder + chunker::AbstractChunker = FileChunker() + embedder::AbstractEmbedder = BatchEmbedder() + tagger::AbstractTagger = NoTagger() +end + ### Functions ## "Build an index for RAG (Retriever-Augmented Generation) applications. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" 
@@ -701,6 +714,156 @@ function build_index( return index end +# TODO: where to put these? +using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3, init_v3, Index, PineconeVector, upsert +using UUIDs: UUIDs, uuid4 +""" + build_index( + indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString}; + metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(), + pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""), + pinecone_index::Pinecone.PineconeIndexv3 = nothing, + pinecone_namespace::AbstractString = "", + upsert::Bool = true, + verbose::Integer = 1, + index_id = gensym(pinecone_namespace), + chunker::AbstractChunker = indexer.chunker, + chunker_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = indexer.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = indexer.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + api_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) + +Builds a `PineconeIndex` containing a Pinecone context (API key, index and namespace). +The index stores the document chunks and their embeddings (and potentially other information). + +The function processes each file or document (depending on `chunker`), splits its content into chunks, embeds these chunks +and then combines this information into a retrievable index. The chunks and embeddings are upsert to Pinecone using +the provided Pinecone context (unless the `upsert` flag is set to `false`). + +# Arguments +- `indexer::PineconeIndexer`: The indexing logic for Pinecone operations. +- `files_or_docs`: A vector of valid file paths to be indexed (chunked and embedded). +- `metadata::Vector{Dict{String, Any}}`: A vector of metadata attributed to each docs file, given as dictionaries with `String` keys. Default is empty vector. +- `pinecone_context::Pinecone.PineconeContextv3`: The Pinecone API key generated using Pinecone.jl. Must be specified. 
+- `pinecone_index::Pinecone.PineconeIndexv3`: The Pinecone index generated using Pinecone.jl. Must be specified. +- `pinecone_namespace::AbstractString`: The Pinecone namespace associated with `pinecone_index`. +- `upsert::Bool = true`: A flag specifying whether to upsert the chunks and embeddings to Pinecone. Defaults to `true`. +- `verbose`: An Integer specifying the verbosity of the logs. Default is `1` (high-level logging). `0` is disabled. +- `index_id`: A unique identifier for the index. Default is a generated symbol. +- `chunker`: The chunker logic to use for splitting the documents. Default is `FileChunker()`. +- `chunker_kwargs`: Parameters to be provided to the `get_chunks` function. Useful to change the `separators` or `max_length`. + - `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs`. +- `embedder`: The embedder logic to use for embedding the chunks. Default is `BatchEmbedder()`. +- `embedder_kwargs`: Parameters to be provided to the `get_embeddings` function. Useful to change the `target_batch_size_length` or reduce asyncmap tasks `ntasks`. + - `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`. +- `tagger`: The tagger logic to use for extracting tags from the chunks. Default is `NoTagger()`, ie, skip tag extraction. There are also `PassthroughTagger` and `OpenTagger`. +- `tagger_kwargs`: Parameters to be provided to the `get_tags` function. + - `model`: The model to use for tags extraction. Default is `PT.MODEL_CHAT`. + - `template`: A template to be used for tags extraction. Default is `:RAGExtractMetadataShort`. + - `tags`: A vector of vectors of strings directly providing the tags for each chunk. Applicable for `tagger::PassthroughTagger`. +- `api_kwargs`: Parameters to be provided to the API endpoint. Shared across all API calls if provided. +- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. 
Useful to pass the total cost to the parent call. + +# Returns +- `PineconeIndex`: An object containing the compiled index of chunks, embeddings, tags, vocabulary, sources and metadata, together with the Pinecone connection data. + +See also: `PineconeIndex`, `get_chunks`, `get_embeddings`, `get_tags`, `CandidateWithChunks`, `find_closest`, `find_tags`, `rerank`, `retrieve`, `generate!`, `airag` + +# Examples +```julia +using Pinecone + +# Prepare the Pinecone connection data +pinecone_context = Pinecone.init_v3(ENV["PINECONE_API_KEY"]) +pindex = ENV["PINECONE_INDEX"] +pinecone_index = !isempty(pindex) ? Pinecone.Index(pinecone_context, pindex) : nothing +namespace = "my-namespace" + +# Add metadata about the sources in Pinecone +metadata = [Dict{String, Any}("source" => doc_file) for doc_file in docs_files] + +# Build the index. By default, the chunks and embeddings get upserted to Pinecone. +const RT = PromptingTools.Experimental.RAGTools +index_pinecone = RT.build_index( + RT.PineconeIndexer(), + docs_files; + pinecone_context = pinecone_context, + pinecone_index = pinecone_index, + pinecone_namespace = namespace, + metadata = metadata +) +``` +# Notes +- If you get errors about exceeding embedding input sizes, first check the `max_length` in your chunks. + If that does NOT resolve the issue, try changing the `embedder_kwargs`. + In particular, reducing the `target_batch_size_length` parameter (eg, 10_000) and number of tasks `ntasks=1`. + Some providers cannot handle large batch sizes (eg, Databricks). 
+ +""" +function build_index( + indexer::PineconeIndexer, files_or_docs::Vector{<:AbstractString}; + metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}(), + pinecone_context::Pinecone.PineconeContextv3 = Pinecone.init_v3(""), + pinecone_index::Pinecone.PineconeIndexv3 = nothing, + pinecone_namespace::AbstractString = "", + upsert::Bool = true, + verbose::Integer = 1, + index_id = gensym(pinecone_namespace), + chunker::AbstractChunker = indexer.chunker, + chunker_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = indexer.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = indexer.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + api_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) + @assert !isempty(pinecone_context.apikey) && !isnothing(pinecone_index) "Pinecone context and index not set" + + ## Split into chunks + chunks, sources = get_chunks(chunker, files_or_docs; + chunker_kwargs...) + ## Get metadata for each chunk + if isempty(metadata) + metadata = [Dict{String, Any}() for _ in sources] + else + metadata = [metadata[findfirst(f -> f == source, files_or_docs)] for source in sources] + [metadata[idx]["content"] = chunk for (idx, chunk) in enumerate(chunks)] + end + + ## Embed chunks + embeddings = get_embeddings(embedder, chunks; + verbose = (verbose > 1), + cost_tracker, + api_kwargs, embedder_kwargs...) + + ## Extract tags + tags_extracted = get_tags(tagger, chunks; + verbose = (verbose > 1), + cost_tracker, + api_kwargs, tagger_kwargs...) 
+ # Build the sparse matrix and the vocabulary + tags, tags_vocab = build_tags(tagger, tags_extracted) + + # Upsert to Pinecone + if upsert + embeddings_arr = [embeddings[:,i] for i in axes(embeddings,2)] + for (idx, emb) in enumerate(embeddings_arr) + pinevector = Pinecone.PineconeVector(string(UUIDs.uuid4()), emb, metadata[idx]) + Pinecone.upsert(pinecone_context, pinecone_index, [pinevector], pinecone_namespace) + @info "Upsert #$idx complete" + end + end + + index = PineconeIndex(; id = index_id, pinecone_context, pinecone_index, pinecone_namespace, chunks, embeddings, tags, tags_vocab, metadata, sources) + + (verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))" + + return index +end + # Convenience for easy index creation """ ChunkKeywordsIndex( diff --git a/src/Experimental/RAGTools/rag_interface.jl b/src/Experimental/RAGTools/rag_interface.jl index 15c08fecc..8556ff8df 100644 --- a/src/Experimental/RAGTools/rag_interface.jl +++ b/src/Experimental/RAGTools/rag_interface.jl @@ -138,6 +138,8 @@ abstract type AbstractTagger <: AbstractIndexingMethod end ### Index itself - return type of `build_index` abstract type AbstractDocumentIndex end +abstract type AbstractManagedIndex end + """ AbstractMultiIndex <: AbstractDocumentIndex @@ -177,6 +179,8 @@ Return type from `find_closest` and `find_tags` functions. 
""" abstract type AbstractCandidateChunks end +abstract type AbstractCandidateWithChunks end + # Main supertype for retrieval customizations abstract type AbstractRetrievalMethod end diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 343b67aa3..b9ab38977 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -262,6 +262,67 @@ function find_closest( end end +""" + find_closest( + finder::AbstractSimilarityFinder, index::PineconeIndex, + query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; + top_k::Int = 10, kwargs...) + +Finds the indices of chunks that are closest to query embedding (`query_emb`) by querying Pinecone. + +Returns only `top_k` closest indices. +""" +function find_closest( + finder::AbstractSimilarityFinder, index::PineconeIndex, + query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[]; + top_n::Int = 10, kwargs...) 
+ # get Pinecone info + pinecone_context = index.pinecone_context + pinecone_index = index.pinecone_index + pinecone_namespace = index.pinecone_namespace + + # query candidates + pinecone_results = Pinecone.query(pinecone_context, pinecone_index, + Vector{Float32}(query_emb), top_n, pinecone_namespace, false, true) + pinecone_results_json = JSON3.read(pinecone_results) + matches = pinecone_results_json.matches + + # get the chunks / metadata / sources / scores + positions = [1 for _ in matches] # TODO: change this + scores = sort([m.score for m in matches], rev=true) + chunks = [m.metadata.content for m in matches] + metadata = [JSON3.read(JSON3.write(m.metadata), Dict{String, Any}) for m in matches] + # TODO: metadata might not have `source`, change this + sources = [m.metadata.source for m in matches] + + return CandidateWithChunks( + index_id = index.id, + positions = positions, + scores = Vector{Float32}(scores), + chunks = Vector{String}(chunks), + metadata = metadata, + sources = Vector{String}(sources)) +end + +# Dispatch to find scores for multiple embeddings +function find_closest( + finder::AbstractSimilarityFinder, index::PineconeIndex, + query_emb::AbstractMatrix{<:Real}, query_tokens::AbstractVector{<:AbstractVector{<:AbstractString}} = Vector{Vector{String}}(); + top_k::Int = 100, top_n::Int = 10, + kwargs...) + ## reduce top_k since we have more than one query + top_k_ = top_k ÷ size(query_emb, 2) + ## simply vcat together (gets sorted from the highest similarity to the lowest) + if isempty(query_tokens) + mapreduce( + c -> find_closest(finder, index, c; top_k = top_k_, top_n = top_n, kwargs...), vcat, eachcol(query_emb)) + else + @assert length(query_tokens)==size(query_emb, 2) "Length of `query_tokens` must be equal to the number of columns in `query_emb`." 
+ mapreduce( + (emb, tok) -> find_closest(finder, index, emb, tok; top_k = top_k_, top_n = top_n, kwargs...), vcat, eachcol(query_emb), query_tokens) + end +end + ### For MultiIndex function find_closest( finder::MultiFinder, index::AbstractMultiIndex, @@ -563,7 +624,7 @@ function find_tags(method::AllTagFilter, index::AbstractChunkIndex, end """ - find_tags(method::NoTagFilter, index::AbstractChunkIndex, + find_tags(method::NoTagFilter, index::Union{AbstractChunkIndex, AbstractManagedIndex}, tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: Union{ AbstractString, Regex, Nothing}} @@ -571,12 +632,16 @@ end Returns all chunks in the index, ie, no filtering, so we simply return `nothing` (easier for dispatch). """ -function find_tags(method::NoTagFilter, index::AbstractChunkIndex, - tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: - Union{ +function find_tags( + method::NoTagFilter, index::Union{AbstractChunkIndex, + AbstractManagedIndex}, + tags::Union{T, AbstractVector{<:T}}; + kwargs...) where {T <: + Union{ AbstractString, Regex, Nothing}} return nothing end + ## Multi-index implementation -- logic differs within each index and then we simply vcat them together function find_tags(method::Union{AnyTagFilter, AllTagFilter}, index::AbstractMultiIndex, tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: @@ -683,6 +748,15 @@ function rerank(reranker::NoReranker, return first(candidates, top_n) end +function rerank(reranker::NoReranker, + index::AbstractManagedIndex, + question::AbstractString, + candidates::AbstractCandidateWithChunks; + top_n::Integer = length(candidates), + kwargs...) 
+ return first(candidates, top_n) +end + """ rerank( reranker::CohereReranker, index::AbstractDocumentIndex, question::AbstractString, @@ -943,6 +1017,32 @@ Compared to SimpleRetriever, it adds rephrasing the query and reranking the resu reranker::AbstractReranker = CohereReranker() end +""" + PineconeRetriever <: AbstractRetriever + +Dispatch for `retrieve` for Pinecone. + +# Fields +- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` - uses `NoRephraser` +- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` (see Preparation Stage for more details) - uses `BatchEmbedder` +- `processor::AbstractProcessor`: the processor method, dispatching `get_keywords` (see Preparation Stage for more details) - uses `NoProcessor` +- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` - uses `CosineSimilarity` +- `tagger::AbstractTagger`: the tag generating method, dispatching `get_tags` (see Preparation Stage for more details) - uses `NoTagger` +- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` - uses `NoTagFilter` +- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` - uses `NoReranker` +""" +@kwdef mutable struct PineconeRetriever <: AbstractRetriever + rephraser::AbstractRephraser = NoRephraser() + # TODO: BatchEmbedder? 
+ embedder::AbstractEmbedder = BatchEmbedder() + processor::AbstractProcessor = NoProcessor() + # TODO: actually do something with this; Pinecone allows choosing finder + finder::AbstractSimilarityFinder = CosineSimilarity() + tagger::AbstractTagger = NoTagger() + filter::AbstractTagFilter = NoTagFilter() + reranker::AbstractReranker = NoReranker() +end + """ retrieve(retriever::AbstractRetriever, index::AbstractChunkIndex, @@ -1157,6 +1257,134 @@ function retrieve(retriever::AbstractRetriever, return result end +""" + retrieve(retriever::PineconeRetriever, + index::PineconeIndex, + question::AbstractString; + verbose::Integer = 1, + top_k::Integer = 100, + top_n::Integer = 10, + api_kwargs::NamedTuple = NamedTuple(), + rephraser::AbstractRephraser = retriever.rephraser, + rephraser_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = retriever.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + processor::AbstractProcessor = retriever.processor, + processor_kwargs::NamedTuple = NamedTuple(), + finder::AbstractSimilarityFinder = retriever.finder, + finder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = retriever.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + filter::AbstractTagFilter = retriever.filter, + filter_kwargs::NamedTuple = NamedTuple(), + reranker::AbstractReranker = retriever.reranker, + reranker_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Dispatch method for `PineconeIndex`. 
+""" +function retrieve(retriever::PineconeRetriever, + index::PineconeIndex, + question::AbstractString; + verbose::Integer = 1, + top_k::Integer = 100, + top_n::Integer = 10, + api_kwargs::NamedTuple = NamedTuple(), + rephraser::AbstractRephraser = retriever.rephraser, + rephraser_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = retriever.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + processor::AbstractProcessor = retriever.processor, + processor_kwargs::NamedTuple = NamedTuple(), + finder::AbstractSimilarityFinder = retriever.finder, + finder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = retriever.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + filter::AbstractTagFilter = retriever.filter, + filter_kwargs::NamedTuple = NamedTuple(), + reranker::AbstractReranker = retriever.reranker, + reranker_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + ## Rephrase into one or more questions + rephraser_kwargs_ = isempty(api_kwargs) ? rephraser_kwargs : + merge(rephraser_kwargs, (; api_kwargs)) + rephrased_questions = rephrase( + rephraser, question; verbose = (verbose > 1), cost_tracker, rephraser_kwargs_...) + + ## Embed one or more rephrased questions + embeddings = if HasEmbeddings(index) + embedder_kwargs_ = isempty(api_kwargs) ? embedder_kwargs : + merge(embedder_kwargs, (; api_kwargs)) + embeddings = get_embeddings(embedder, rephrased_questions; + verbose = (verbose > 1), cost_tracker, embedder_kwargs_...) + else + embeddings = hcat([Float32[] for _ in rephrased_questions]...) 
+ end + + ## Preprocess into keyword tokens if we're running BM25 + keywords = if HasKeywords(index) + ## Return only keywords, not DTM + keywords = get_keywords(processor, rephrased_questions; + verbose = (verbose > 1), processor_kwargs..., return_keywords = true) + ## Send warning for common error + verbose >= 1 && (keywords isa AbstractVector{<:AbstractVector{<:AbstractString}} || + @warn "Processed Keywords is not a vector of tokenized queries. Have you used the correct processor? (provided: $(typeof(processor))).") + keywords + else + [String[] for _ in rephrased_questions] + end + + finder_kwargs_ = isempty(api_kwargs) ? finder_kwargs : + merge(finder_kwargs, (; api_kwargs)) + emb_candidates = find_closest(finder, index, embeddings, keywords; + verbose = (verbose > 1), top_k, top_n, finder_kwargs_...) + + ## Tagging - if you provide them explicitly, use tagger `PassthroughTagger` and `tagger_kwargs = (;tags = ...)` + tagger_kwargs_ = isempty(api_kwargs) ? tagger_kwargs : + merge(tagger_kwargs, (; api_kwargs)) + tags = get_tags(tagger, rephrased_questions; verbose = (verbose > 1), + cost_tracker, tagger_kwargs_...) + + filter_kwargs_ = isempty(api_kwargs) ? filter_kwargs : + merge(filter_kwargs, (; api_kwargs)) + tag_candidates = find_tags( + filter, index, tags; verbose = (verbose > 1), filter_kwargs_...) + + ## Combine the two sets of candidates, looks for intersection (hard filter)! + # With tagger=NoTagger() get_tags returns `nothing` find_tags simply passes it through to skip the intersection + filtered_candidates = isnothing(tag_candidates) ? emb_candidates : + (emb_candidates & tag_candidates) + ## TODO: Future implementation should be to apply tag filtering BEFORE the find_closest, + ## but that requires implementing `view(::Index,...)` to provide only a subset of the embeddings to the subsequent functionality. 
+ ## Also, find_closest is so fast & cheap that it doesn't matter at current scale/maturity of the use cases + + ## Reranking + reranker_kwargs_ = isempty(api_kwargs) ? reranker_kwargs : + merge(reranker_kwargs, (; api_kwargs)) + reranked_candidates = rerank(reranker, index, question, filtered_candidates; + top_n, verbose = (verbose > 1), cost_tracker, reranker_kwargs_...) + + verbose > 0 && + @info "Retrieval done. Total cost: \$$(round(cost_tracker[], digits=2))." + + result = RAGResult(; + question, + answer = nothing, + rephrased_questions, + final_answer = nothing, + ## Ensure chunks and sources are sorted + context = collect(index[reranked_candidates, :chunks, sorted = true]), + sources = collect(index[reranked_candidates, :sources, sorted = true]), + emb_candidates, + tag_candidates, + filtered_candidates, + reranked_candidates) + + return result +end + # Set default behavior DEFAULT_RETRIEVER = SimpleRetriever() function retrieve(index::AbstractChunkIndex, question::AbstractString; diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 3c3b970f2..1dcfcdc0b 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -3,6 +3,7 @@ using Base: parent ### Shared methods Base.parent(index::AbstractDocumentIndex) = index +Base.parent(index::AbstractManagedIndex) = index indexid(index::AbstractDocumentIndex) = index.id chunkdata(index::AbstractChunkIndex) = index.chunkdata "Access chunkdata for a subset of chunks, `chunk_idx` is a vector of chunk indices in the index" @@ -134,6 +135,61 @@ chunkdata(index::ChunkEmbeddingsIndex) = embeddings(index) # For backward compatibility const ChunkIndex = ChunkEmbeddingsIndex + +# TODO: where to put these? +indexid(index::AbstractManagedIndex) = index.id +chunks(index::AbstractManagedIndex) = index.chunks +sources(index::AbstractManagedIndex) = index.sources + +# TODO: what about this? 
using Pinecone: Pinecone, PineconeContextv3, PineconeIndexv3

"""
    PineconeIndex

Main struct for storing document chunks and their embeddings along with the necessary Pinecone context for connecting to Pinecone.

# Fields
- `id::Symbol`: unique identifier of each index (a symbol of the Pinecone index namespace)
- `pinecone_context::Pinecone.PineconeContextv3`: Pinecone connection context (carries the API key)
- `pinecone_index::Pinecone.PineconeIndexv3`: Pinecone index
- `pinecone_namespace::String`: name of the namespace inside the Pinecone index
- `chunks::Vector{<:Union{Nothing, AbstractString}}`: underlying document chunks / snippets. Defaults to an empty `String[]`
- `embeddings::Union{Nothing, Matrix{<:Real}}`: for semantic search
- `tags::Union{Nothing, AbstractMatrix{<:Bool}}`: for exact search, filtering, etc. This is often a sparse matrix indicating which chunks have the given `tag` (see `tags_vocab` for the position lookup)
- `tags_vocab::Union{Nothing, Vector{<:AbstractString}}`: vocabulary for the `tags` matrix (each column in `tags` is one item in `tags_vocab` and rows are the chunks)
- `sources::Union{Nothing, Vector{<:AbstractString}}`: sources of the chunks (`nothing` when not provided)
- `metadata::Vector{Dict{String, Any}}`: metadata for each chunk/embedding stored in Pinecone
"""
@kwdef struct PineconeIndex{
    T1 <: Union{Nothing, AbstractString},
    T2 <: Union{Nothing, Matrix{<:Real}},
    T3 <: Union{Nothing, AbstractMatrix{<:Bool}}
} <: AbstractManagedIndex
    # TODO: id should be a combination of index + namespace?
    id::Symbol # namespace
    # TODO: these should not be v3, maybe?
    pinecone_context::Pinecone.PineconeContextv3
    pinecone_index::Pinecone.PineconeIndexv3
    pinecone_namespace::String
    # underlying document chunks / snippets
    # The default must be a `Vector`: the previous `nothing` default could never
    # satisfy `Vector{T1}`, so the keyword constructor threw unless `chunks` was passed.
    chunks::Vector{T1} = String[]
    # for semantic search
    embeddings::T2 = nothing
    # for exact search, filtering, etc.
    # expected to be some sparse structure, eg, sparse matrix or nothing
    # column oriented, ie, each column is one item in `tags_vocab` and rows are the chunks
    tags::T3 = nothing
    tags_vocab::Union{Nothing, Vector{<:AbstractString}} = nothing
    sources::Union{Nothing, Vector{<:AbstractString}} = nothing
    # metadata for each chunk
    # TODO: should be changed to `extras`? but different type -- this needs to be vector of dicts
    metadata::Vector{Dict{String, Any}} = Vector{Dict{String, Any}}()
end
HasKeywords(::PineconeIndex) = false
HasEmbeddings(::PineconeIndex) = true
embeddings(index::PineconeIndex) = index.embeddings

abstract type AbstractDocumentTermMatrix end

"""
    SubManagedIndex

Provides the same functionality for `AbstractManagedIndex` as `SubChunkIndex` does for `AbstractChunkIndex`.

# Fields
- `parent::T`: the wrapped managed index
- `positions::Vector{Int}`: positions (in the parent index) of the chunks exposed by this view
"""
@kwdef struct SubManagedIndex{T <: AbstractManagedIndex} <: AbstractManagedIndex
    parent::T
    positions::Vector{Int}
end

## Accessors -- everything except `positions` is delegated to the parent index
indexid(index::SubManagedIndex) = parent(index) |> indexid
positions(index::SubManagedIndex) = index.positions
Base.parent(index::SubManagedIndex) = index.parent
HasEmbeddings(index::SubManagedIndex) = HasEmbeddings(parent(index))
HasKeywords(index::SubManagedIndex) = HasKeywords(parent(index))

# TODO: see which of these are needed
"Chunks of the parent index restricted to `positions(index)` (returned as a view)."
Base.@propagate_inbounds function chunks(index::SubManagedIndex)
    return view(chunks(parent(index)), positions(index))
end
"Sources of the parent index restricted to `positions(index)` (returned as a view)."
Base.@propagate_inbounds function sources(index::SubManagedIndex)
    return view(sources(parent(index)), positions(index))
end
"Chunk data of the parent index restricted to `positions(index)`."
Base.@propagate_inbounds function chunkdata(index::SubManagedIndex)
    return chunkdata(parent(index), positions(index))
end
"Chunk data for a subset of the view; `chunk_idx` are positions within this sub-index."
Base.@propagate_inbounds function chunkdata(
        index::SubManagedIndex, chunk_idx::AbstractVector{<:Integer})
    ## We need this accessor because different chunk indices can have chunks in different dimensions!!
    index_chunk_idx = translate_positions_to_parent(index, chunk_idx)
    ## `intersect` keeps the order of `positions(index)`, not of `chunk_idx`
    pos = intersect(positions(index), index_chunk_idx)
    return chunkdata(parent(index), pos)
end
"Embeddings of the selected chunks (columns); throws `ArgumentError` if the parent has none."
function embeddings(index::SubManagedIndex)
    if HasEmbeddings(index)
        return view(embeddings(parent(index)), :, positions(index))
    else
        throw(ArgumentError("`embeddings` not implemented for $(typeof(index))"))
    end
end
"Rows of the parent's `tags` matrix for the selected chunks, or `nothing` if there are no tags."
function tags(index::SubManagedIndex)
    tagsdata = tags(parent(index))
    isnothing(tagsdata) && return nothing
    return view(tagsdata, positions(index), :)
end
function tags_vocab(index::SubManagedIndex)
    return tags_vocab(parent(index))
end
## NOTE(review): relies on an `extras` method existing for the parent index type;
## none is defined in this section (`PineconeIndex` stores `metadata`) -- confirm.
function extras(index::SubManagedIndex)
    extrasdata = extras(parent(index))
    isnothing(extrasdata) && return nothing
    return view(extrasdata, positions(index))
end
## Fallback: views over different parent types cannot be concatenated
function Base.vcat(i1::SubManagedIndex, i2::SubManagedIndex)
    throw(ArgumentError("vcat not implemented for type $(typeof(i1)) and $(typeof(i2))"))
end
"Concatenate two views of the same parent index; throws `ArgumentError` if the parent ids differ."
function Base.vcat(i1::T, i2::T) where {T <: SubManagedIndex}
    ## Check if can be merged
    if indexid(parent(i1)) != indexid(parent(i2))
        throw(ArgumentError("Parent indices must be the same (provided: $(indexid(parent(i1))) and $(indexid(parent(i2))))"))
    end
    ## Must stay a `SubManagedIndex` (was `SubChunkIndex`, which wraps chunk indexes, not managed ones)
    return SubManagedIndex(parent(i1), vcat(positions(i1), positions(i2)))
end
"Same view with duplicate positions removed (first occurrence kept)."
function Base.unique(index::SubManagedIndex)
    ## Must stay a `SubManagedIndex` (was `SubChunkIndex`)
    return SubManagedIndex(parent(index), unique(positions(index)))
end
function Base.length(index::SubManagedIndex)
    return length(positions(index))
end
function Base.isempty(index::SubManagedIndex)
    return isempty(positions(index))
end
function Base.show(io::IO, index::SubManagedIndex)
    print(io,
        "A view of $(typeof(parent(index))|>nameof) (id: $(indexid(parent(index)))) with $(length(index)) chunks")
end
"Translate positions within the sub-index (`pos`) to positions in the parent index."
Base.@propagate_inbounds function translate_positions_to_parent(
        index::SubManagedIndex, pos::AbstractVector{<:Integer})
    sub_positions = positions(index)
    return sub_positions[pos]
end

# # CandidateChunks for Retrieval
"""
    CandidateWithChunks

Similar to `CandidateChunks`, but for `AbstractManagedIndex`. It's the result of the retrieval stage of RAG.

# Fields
- `index_id::Symbol`: the id of the index from which the candidates are drawn
- `positions::Vector{Int}`: the positions of the candidates in the index (ie, `5` refers to the 5th chunk in the index - `chunks(index)[5]`)
- `scores::Vector{Float32}`: the similarity scores of the candidates from the query (higher is better)
- `chunks::Vector{String}`: the chunks retrieved for a given question
- `metadata::AbstractVector`: metadata corresponding to `chunks`
- `sources::Vector{String}`: sources corresponding to `chunks`
"""
@kwdef struct CandidateWithChunks{TP <: Integer, TD <: Real} <:
              AbstractCandidateWithChunks
    index_id::Symbol
    positions::Vector{TP} = Int[]
    scores::Vector{TD} = Float32[]
    ## fields obtained "per question"
    chunks::Vector{String} = String[]
    metadata::AbstractVector = Dict{String, Any}[]
    sources::Vector{String} = String[]
end
# TODO: see which can be removed
indexid(cc::CandidateWithChunks) = cc.index_id
positions(cc::CandidateWithChunks) = cc.positions
scores(cc::CandidateWithChunks) = cc.scores
chunks(cc::CandidateWithChunks) = cc.chunks
metadata(cc::CandidateWithChunks) = cc.metadata
sources(cc::CandidateWithChunks) = cc.sources
Base.length(cc::CandidateWithChunks) = length(cc.positions)

"""
    Base.first(cc::CandidateWithChunks, k::Integer)

Return a new `CandidateWithChunks` holding the `k` highest-scoring candidates of `cc`
(fewer if `cc` has fewer), ordered by descending score.

`chunks`, `metadata` and `sources` are subset with the same permutation as `positions`
and `scores`, so every payload stays aligned with its candidate. A payload vector whose
length differs from `scores` (eg, left at its empty default) is passed through unchanged.
"""
function Base.first(cc::CandidateWithChunks, k::Integer)
    top_idxs = first(sortperm(scores(cc), rev = true), k)
    n_candidates = length(scores(cc))
    ## Only reorder payloads that are stored 1:1 with the candidates
    pick(v) = length(v) == n_candidates ? v[top_idxs] : v
    return CandidateWithChunks(
        indexid(cc), positions(cc)[top_idxs], scores(cc)[top_idxs],
        pick(chunks(cc)), pick(metadata(cc)), pick(sources(cc)))
end
"""
    Base.view(index::AbstractManagedIndex, cc::CandidateWithChunks)

Build a `SubManagedIndex` over the parent of `index` exposing the positions in `cc`.
If `cc` was drawn from a different index (ids differ), the view is empty.
Throws `BoundsError` (unless bounds checks are elided) when a position falls outside
the parent's chunks.
"""
Base.@propagate_inbounds function Base.view(
        index::AbstractManagedIndex, cc::CandidateWithChunks)
    root = parent(index)
    cand_positions = positions(cc)
    @boundscheck begin
        parent_chunks = chunks(root)
        if !checkbounds(Bool, axes(parent_chunks, 1), cand_positions)
            ## Report only the extremes of the attempted range, not the whole position array
            throw(BoundsError(parent_chunks, extrema(cand_positions)))
        end
    end
    selected = indexid(index) == indexid(cc) ? cand_positions : Int[]
    return SubManagedIndex(root, selected)
end
## A view of a view: delegate to the intersecting constructor below
Base.@propagate_inbounds function Base.view(
        index::SubManagedIndex, cc::CandidateWithChunks)
    return SubManagedIndex(index, cc)
end

"""
    SubManagedIndex(index::SubManagedIndex, cc::CandidateWithChunks)

Narrow an existing view to the positions of `cc` that it already contains.
The result is empty when `cc` belongs to a different index.
"""
Base.@propagate_inbounds function SubManagedIndex(
        index::SubManagedIndex, cc::CandidateWithChunks)
    same_index = indexid(index) == indexid(cc)
    cand_positions = same_index ? positions(cc) : Int[]
    kept = intersect(cand_positions, positions(index))
    root = parent(index)
    @boundscheck begin
        parent_chunks = chunks(root)
        if !checkbounds(Bool, axes(parent_chunks, 1), kept)
            ## Report only the extremes of the attempted range, not the whole position array
            throw(BoundsError(parent_chunks, extrema(kept)))
        end
    end
    return SubManagedIndex(root, kept)
end

## Getindex

"""
    Base.getindex(pidx::AbstractManagedIndex, candidate::CandidateWithChunks,
        field::Symbol = :chunks; sorted::Bool = false)

Look up `field` (`:chunks`, `:chunkdata`/`:embeddings`, `:sources` or `:scores`) for the
candidates in `candidate`. With `sorted = true` the items are requested in descending
score order. If `candidate` comes from a different index, an empty container is returned
(or `nothing` when the index has no chunk data).
"""
function Base.getindex(pidx::AbstractManagedIndex,
        candidate::CandidateWithChunks{TP, TD},
        field::Symbol = :chunks; sorted::Bool = false) where {TP <: Integer, TD <: Real}
    @assert field in [:chunks, :embeddings, :chunkdata, :sources, :scores] "Only `chunks`, `embeddings`, `chunkdata`, `sources`, `scores` fields are supported for now"
    ## `:embeddings` is only a compatibility alias for `:chunkdata`
    key = field == :embeddings ? :chunkdata : field

    if indexid(pidx) != indexid(candidate)
        ## Candidate belongs to another index -> empty result of the matching type
        key == :chunks && return eltype(chunks(pidx))[]
        key == :sources && return eltype(sources(pidx))[]
        key == :scores && return TD[]
        key == :chunkdata || return nothing
        data = chunkdata(pidx)
        isnothing(data) && return nothing
        empty_dims = ntuple(_ -> 0, ndims(data))
        return typeof(data)(undef, empty_dims)
    end

    ## Sort if requested
    order = if sorted
        sortperm(scores(candidate), rev = true)
    else
        eachindex(scores(candidate))
    end
    sub_index = view(pidx, candidate)
    key == :chunks && return chunks(sub_index)[order]
    ## If embeddings, chunks are columns; if keywords (DTM), chunks are rows
    key == :chunkdata && return chunkdata(sub_index, order)
    key == :sources && return sources(sub_index)[order]
    key == :scores && return scores(candidate)[order]
    return nothing
end

## Lookup by id: an index answers only to its own id
function Base.getindex(index::AbstractChunkIndex, id::Symbol)
    id == indexid(index) ? index : nothing
end
function Base.getindex(index::AbstractManagedIndex, id::Symbol)
    return indexid(index) == id ? index : nothing
end
index : nothing +end function Base.getindex(index::AbstractMultiIndex, id::Symbol) id == indexid(index) && return index idx = findfirst(x -> indexid(x) == id, indexes(index)) @@ -971,13 +1215,13 @@ See also: `pprint` (pretty printing), `annotate_support` (for annotating the ans final_answer::Union{Nothing, AbstractString} = nothing context::Vector{<:AbstractString} = String[] sources::Vector{<:AbstractString} = String[] - emb_candidates::Union{CandidateChunks, MultiCandidateChunks} = CandidateChunks( + emb_candidates::Union{CandidateChunks, CandidateWithChunks, MultiCandidateChunks} = CandidateChunks( index_id = :NOTINDEX, positions = Int[], scores = Float32[]) - tag_candidates::Union{Nothing, CandidateChunks, MultiCandidateChunks} = CandidateChunks( + tag_candidates::Union{Nothing, CandidateChunks, CandidateWithChunks, MultiCandidateChunks} = CandidateChunks( index_id = :NOTINDEX, positions = Int[], scores = Float32[]) - filtered_candidates::Union{CandidateChunks, MultiCandidateChunks} = CandidateChunks( + filtered_candidates::Union{CandidateChunks, CandidateWithChunks, MultiCandidateChunks} = CandidateChunks( index_id = :NOTINDEX, positions = Int[], scores = Float32[]) - reranked_candidates::Union{CandidateChunks, MultiCandidateChunks} = CandidateChunks( + reranked_candidates::Union{CandidateChunks, CandidateWithChunks, MultiCandidateChunks} = CandidateChunks( index_id = :NOTINDEX, positions = Int[], scores = Float32[]) conversations::Dict{Symbol, Vector{<:AbstractMessage}} = Dict{ Symbol, Vector{<:AbstractMessage}}() diff --git a/templates/persona-task/JuliaRAGAssistant.json b/templates/persona-task/JuliaRAGAssistant.json new file mode 100644 index 000000000..7b3c272f5 --- /dev/null +++ b/templates/persona-task/JuliaRAGAssistant.json @@ -0,0 +1,23 @@ +[ + { + "content": "Template Metadata", + "description": "For asking questions for Julia in a RAG context. 
Placeholders: `question` and `context`", + "version": "1", + "source": "", + "_type": "metadatamessage" + }, + { + "content": "Act as a world-class Julia language programmer with access to the latest Julia-related knowledge via Context Information. \n\n**Instructions:**\n- Answer the question based only on the provided Context.\n- Be precise and answer only when you're confident in the high quality of your answer.\n- Be brief and concise.\n\n**Context Information:**\n---\n{{context}}\n---\n", + "variables": [ + "context" + ], + "_type": "systemmessage" + }, + { + "content": "# Question\n\n{{question}}\n\n\n\n# Answer\n\n", + "variables": [ + "question" + ], + "_type": "usermessage" + } + ] \ No newline at end of file