Commit 4f0b50c

fix tokens count migration
emrgnt-cmplxty committed Jan 22, 2025
1 parent 739c8b3 commit 4f0b50c
Showing 1 changed file with 59 additions and 46 deletions.
py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py
@@ -13,16 +13,19 @@
 import sqlalchemy as sa
 import tiktoken
 from alembic import op
-from sqlalchemy import text
+from sqlalchemy import inspect, text
 
 # revision identifiers, used by Alembic.
 revision = "3efc7b3b1b3d"
-down_revision = "7eb70560f406"  # Make sure this matches your newest migration
+down_revision = "7eb70560f406"
 branch_labels = None
 depends_on = None
 
 logger = logging.getLogger("alembic.runtime.migration")
 
+# Get project name from environment variable, defaulting to 'r2r_default'
+project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default")
+
 
 def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int:
     """
@@ -37,56 +40,71 @@ def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int:
     return len(encoding.encode(text))
 
 
-def upgrade() -> None:
+def check_if_upgrade_needed() -> bool:
+    """Check if the upgrade has already been applied"""
     connection = op.get_bind()
+    inspector = inspect(connection)
 
-    # 1) Check if column 'total_tokens' already exists in 'documents'
-    #    If not, we'll create it with a default of 0.
-    #    (If you want the default to be NULL instead of 0, adjust as needed.)
-    insp = sa.inspect(connection)
-    columns = insp.get_columns(
-        "documents"
-    )  # uses default schema or your schema
-    col_names = [col["name"] for col in columns]
-    if "total_tokens" not in col_names:
-        logger.info("Adding 'total_tokens' column to 'documents' table...")
-        op.add_column(
-            "documents",
-            sa.Column(
-                "total_tokens",
-                sa.Integer(),
-                nullable=False,
-                server_default="0",
-            ),
+    # Check if documents table exists in the correct schema
+    if not inspector.has_table("documents", schema=project_name):
+        logger.info(
+            f"Migration not needed: '{project_name}.documents' table doesn't exist"
         )
-    else:
+        return False
+
+    # Check if total_tokens column already exists
+    columns = {
+        col["name"]
+        for col in inspector.get_columns("documents", schema=project_name)
+    }
+
+    if "total_tokens" in columns:
         logger.info(
-            "Column 'total_tokens' already exists in 'documents' table, skipping add-column step."
+            "Migration not needed: documents table already has total_tokens column"
         )
+        return False
 
-    # 2) Fill in 'total_tokens' for each document by summing the tokens from all chunks
-    #    We do this in batches to avoid loading too much data at once.
+    logger.info("Migration needed: documents table needs total_tokens column")
+    return True
+
+
+def upgrade() -> None:
+    if not check_if_upgrade_needed():
+        return
+
+    connection = op.get_bind()
+
+    # Add the total_tokens column
+    logger.info("Adding 'total_tokens' column to 'documents' table...")
+    op.add_column(
+        "documents",
+        sa.Column(
+            "total_tokens",
+            sa.Integer(),
+            nullable=False,
+            server_default="0",
+        ),
+        schema=project_name,
+    )
+
+    # Process documents in batches
     BATCH_SIZE = 500
 
-    # a) Count how many documents we have
+    # Count total documents
     logger.info("Determining how many documents need updating...")
-    doc_count_query = text("SELECT COUNT(*) FROM documents")
+    doc_count_query = text(f"SELECT COUNT(*) FROM {project_name}.documents")
     total_docs = connection.execute(doc_count_query).scalar() or 0
     logger.info(f"Total documents found: {total_docs}")
 
     if total_docs == 0:
         logger.info("No documents found, nothing to update.")
         return
 
-    # b) We'll iterate over documents in pages of size BATCH_SIZE
     pages = math.ceil(total_docs / BATCH_SIZE)
     logger.info(
         f"Updating total_tokens in {pages} batches of up to {BATCH_SIZE} documents..."
     )
 
-    # Optionally choose a Tiktoken model via environment variable
-    # or just default if none is set
     default_model = os.getenv("R2R_TOKCOUNT_MODEL", "gpt-3.5-turbo")
 
     offset = 0
@@ -95,12 +113,12 @@ def upgrade() -> None:
             f"Processing batch {page_idx + 1} / {pages} (OFFSET={offset}, LIMIT={BATCH_SIZE})"
         )
 
-        # c) Fetch the IDs of the next batch of documents
+        # Fetch next batch of document IDs
         batch_docs_query = text(
             f"""
             SELECT id
-            FROM documents
-            ORDER BY id  -- or ORDER BY created_at, if you prefer chronological
+            FROM {project_name}.documents
+            ORDER BY id
             LIMIT :limit_val
             OFFSET :offset_val
         """
@@ -109,20 +127,18 @@ def upgrade() -> None:
             batch_docs_query, {"limit_val": BATCH_SIZE, "offset_val": offset}
         ).fetchall()
 
-        # If no results, break early
         if not batch_docs:
             break
 
         doc_ids = [row["id"] for row in batch_docs]
         offset += BATCH_SIZE
 
-        # d) For each document in this batch, sum up tokens from the chunks table
+        # Process each document in the batch
        for doc_id in doc_ids:
-            # Get all chunk text for this doc
             chunks_query = text(
-                """
+                f"""
                 SELECT data
-                FROM chunks
+                FROM {project_name}.chunks
                 WHERE document_id = :doc_id
             """
             )
@@ -137,10 +153,10 @@ def upgrade() -> None:
                 chunk_text, model=default_model
             )
 
-            # e) Update total_tokens for this doc
+            # Update total_tokens for this document
             update_query = text(
-                """
-                UPDATE documents
+                f"""
+                UPDATE {project_name}.documents
                 SET total_tokens = :tokcount
                 WHERE id = :doc_id
             """
@@ -155,11 +171,8 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
-    """
-    If you want to remove the total_tokens column on downgrade, do so here.
-    Otherwise, you can leave it in place.
-    """
+    """Remove the total_tokens column on downgrade"""
     logger.info(
         "Dropping column 'total_tokens' from 'documents' table (downgrade)."
     )
-    op.drop_column("documents", "total_tokens")
+    op.drop_column("documents", "total_tokens", schema=project_name)
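
Note: the body of count_tokens_for_text is collapsed in this diff; only its signature and final return statement are visible. Below is a minimal standalone sketch of the token counting the migration relies on — the try/except fallback to cl100k_base is an assumption, not something the diff confirms:

import tiktoken

def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int:
    # Resolve the tokenizer for the given model name; fall back to the
    # cl100k_base encoding if tiktoken does not recognize the model.
    # (Assumed fallback -- the actual function body is elided in this diff.)
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_tokens_for_text("hello world"))  # prints a small token count, e.g. 2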

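Since project_name is read from the environment at module import time, R2R_PROJECT_NAME must be set before Alembic loads this revision. A minimal sketch of applying the migration programmatically, assuming a standard Alembic setup with an alembic.ini in the working directory ("my_project" is a hypothetical schema name):

import os
from alembic.config import main as alembic_main

# The revision module reads R2R_PROJECT_NAME at import time, so set it first.
os.environ["R2R_PROJECT_NAME"] = "my_project"  # hypothetical schema name

# Equivalent to running `alembic upgrade head` from the shell.
alembic_main(argv=["upgrade", "head"])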