Skip to content

Commit

Permalink
Stage helper for dataset building.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706513090
  • Loading branch information
agutkin committed Dec 16, 2024
1 parent a017a15 commit 5f2fa99
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 0 deletions.
172 changes: 172 additions & 0 deletions protoscribe/evolution/stages/build_dataset_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A stage script responsible for building dataset for a particular round.
Typically there is one corpus preparation stage for each round. Between the
rounds the setup needs to be different. When the initial corpus is created
in the first round, we need to generate the language. The subsequent rounds
need to use this language from the first round unchanged. At the same time,
each round needs to use the updated set of categories and glyphs from the
previous round of evolution.
"""

from collections.abc import Sequence
import logging
import os
from typing import Any

from absl import app
from absl import flags
from protoscribe.corpus.builder import build_dataset as builder_lib
from protoscribe.evolution.stages import common_flags
from protoscribe.utils import file_utils

import glob
import os

FLAGS = flags.FLAGS


def _setup_builder(round_data_dir: str) -> list[tuple[str, Any]]:
  """Sets up builder environment and updates the relevant flags.

  For round 0 this only creates the round directory. For later rounds it also
  verifies that the outputs of the previous round are present, copies the
  language definitions of the previous round into the new round directory and
  collects the flags pointing at the categories and spellings produced by the
  previous round.

  Args:
    round_data_dir: Data directory for the current round. Must not exist yet.

  Returns:
    A list of `(flag_name, flag_value)` pairs for the given round necessary
    for running the builder. These flags pick up the administrative and
    non-administrative categories lists, and the spellings created from the
    previous round for round > 0. The list is empty for round 0.

  Raises:
    ValueError: If `round_data_dir` already exists, or, for round > 0, if the
      previous round's data directory or its `inference_extensions`
      subdirectory does not exist.
  """
  categories_flags: list[tuple[str, Any]] = []

  # Figure out the locations for the data and perform the necessary sanity
  # checks. An existing round directory means an earlier run already produced
  # results here, so refuse to touch it.
  if os.path.isdir(round_data_dir):
    raise ValueError(
        f"Directory `{round_data_dir}` already exists: Cowardly unwilling to "
        "overwrite previous experiment."
    )
  round_id = common_flags.ROUND.value
  if round_id > 0:
    previous_data_dir = common_flags.previous_data_dir()
    if not os.path.isdir(previous_data_dir):
      raise ValueError(
          f"Directory `{previous_data_dir}` does not exist: did you run the "
          f"previous round {round_id - 1} needed for round {round_id}?"
      )

    # Next we check to see if we have correctly generated spelling extensions
    # in output directory `inference_extensions` on the previous generation's
    # run.
    extensions_dir = os.path.join(previous_data_dir, "inference_extensions")
    if not os.path.isdir(extensions_dir):
      # TODO: Revisit this when we get to Round 1, since actually the
      # *language* does not change. The only thing that changes is that more
      # of these will acquire spellings, meaning that we need to update the
      # glyphs, plus what gets put into the training versus held-out data.
      raise ValueError(
          f"Directory `{extensions_dir}` does not exist: did you run the "
          f"previous round {round_id - 1} needed for round {round_id}?"
      )

    # Prepare data for new round: make new round directory and copy over the
    # language definitions from the previous round. Creating the nested
    # `language` directory also creates `round_data_dir` itself.
    logging.info("Making %s ...", round_data_dir)
    language_dir = os.path.join(round_data_dir, "language")
    os.makedirs(language_dir, exist_ok=True)
    file_utils.copy_dir(
        os.path.join(previous_data_dir, "language"), language_dir
    )

    # Pick up categories and spellings.
    categories_flags.extend([
        (
            "administrative_categories",
            os.path.join(extensions_dir, "administrative_categories.txt"),
        ),
        (
            "non_administrative_categories",
            os.path.join(extensions_dir, "non_administrative_categories.txt"),
        ),
        ("concept_spellings", os.path.join(extensions_dir, "spellings.tsv")),
        ("prefer_concept_svg", "true"),
    ])

    # Check for directory containing SVG glyph extensions.
    # NOTE(review): `round_data_dir` was verified not to exist at the top of
    # this function and only its `language` subdirectory has been created
    # since, so this check looks like it cannot succeed here. Confirm whether
    # the previous round's directory was intended.
    extensions_svg_dir = os.path.join(round_data_dir, "glyph_extensions_svg")
    if os.path.isdir(extensions_svg_dir):
      categories_flags.append(
          ("extension_glyphs_svg_dir", extensions_svg_dir),
      )

  # At this stage it is safe to do this again. For round 0 nothing above has
  # created the directory yet.
  if not os.path.isdir(round_data_dir):
    os.makedirs(round_data_dir, exist_ok=True)
    logging.info(
        "Created `%s` for outputs for round %d.", round_data_dir, round_id
    )

  return categories_flags


def _run_builder(app_flags: list[tuple[str, Any]]) -> None:
  """Applies the locally computed flag values and runs the dataset builder.

  Args:
    app_flags: Pairs of flag name and flag value filled in by this script.
      Each value is handed to the `parse` method of the corresponding entry in
      the global `FLAGS` registry. Any other flags passed to this script by
      the caller are already parsed.
  """
  logging.info("Final local flags: %s", app_flags)
  for name, value in app_flags:
    flag = FLAGS[name]
    flag.parse(value)
  builder_lib.build_dataset()


def main(argv: Sequence[str]) -> None:
  """Prepares the data directory for the current round and builds the dataset.

  Args:
    argv: Command-line arguments left over after flag parsing. Only the
      program name is expected.

  Raises:
    app.UsageError: If any positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Prepare the environment first. The language is only generated in the very
  # first round; later rounds reuse it.
  output_dir = common_flags.round_data_dir()
  categories_flags = _setup_builder(output_dir)
  is_first_round = common_flags.ROUND.value == 0
  logging.info("Categories and spellings flags: %s", categories_flags)

  # Everything not listed here keeps the defaults from `builder_lib` or the
  # values handed to this binary directly by the calling scripts.
  logging.info("Done with setup. Running dataset builder ...")
  _run_builder(app_flags=[
      ("generate_language", is_first_round),
      ("output_dir", output_dir),
      ("probability_of_supercategory_glyph", 0.0),
      ("logtostderr", True),
      *categories_flags,
  ])


if __name__ == "__main__":
  # Temporarily set the output directory flag required by the vanilla builder
  # to some temporary value. This happens before `app.run` so that the flag
  # already has a value when the command line is processed. The placeholder
  # is going to be overwritten programmatically by the implementation above,
  # which sets `output_dir` to the data directory of the current round.
  FLAGS.output_dir = "tmp"

  app.run(main)
66 changes: 66 additions & 0 deletions protoscribe/evolution/stages/common_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flags common to all the stages."""

import os

from absl import flags

# Base directory under which the per-experiment directories are placed (see
# `experiment_dir`). It has no default value and is marked as required.
DEFAULT_BASE_DIR = flags.DEFINE_string(
    "default_base_dir", None,
    "Default base directory.",
    required=True
)

# Index of the current round, starting at 0. Its string form is also used as
# the name of the per-round data directory.
ROUND = flags.DEFINE_integer(
    "round", 0,
    "Identifies which round of the experiment we are running. Note that to run "
    "round N, for N>0, round N-1 must have been run."
)

# Selects the semantics model; restricted to the listed choices.
SEMANTIC_MODEL = flags.DEFINE_enum(
    "semantic_model", "concepts",
    [
        "concepts",
        "vision"
    ],
    "Type of the semantics model to use."
)

# Selects the phonetic model; restricted to the listed choices.
PHONETIC_MODEL = flags.DEFINE_enum(
    "phonetic_model", "phonemes",
    [
        "phonemes",
        "logmel-spectrum",
    ],
    "Type of the phonetic model to use."
)


def experiment_dir() -> str:
  """Returns fully-qualified experiment directory path.

  The path is the base directory flag joined with the value of the
  `experiment_name` flag, which is not defined in this module.

  Raises:
    ValueError: If the experiment name is empty or unset.
  """
  name = flags.FLAGS.experiment_name
  if name:
    return os.path.join(DEFAULT_BASE_DIR.value, name)
  raise ValueError("Experiment name is not provided with --experiment_name!")


def round_data_dir() -> str:
  """Returns fully-qualified path to the dataset for this round.

  The directory is named after the current round number and lives directly
  under the experiment directory.
  """
  base_dir = experiment_dir()
  return os.path.join(base_dir, str(ROUND.value))


def previous_data_dir() -> str:
  """Returns fully-qualified path to the previous round's data.

  The directory is named after the current round number minus one and lives
  directly under the experiment directory. No check is made here that the
  current round is greater than zero.
  """
  base_dir = experiment_dir()
  previous_round = ROUND.value - 1
  return os.path.join(base_dir, str(previous_round))
28 changes: 28 additions & 0 deletions protoscribe/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def copy_file(src_path: str, dst_path: str) -> None:
def copy_src_file(source_dir: str, file_name: str, output_dir: str) -> None:
"""Copy a source file to a target directory.
Target directory must exist.
Args:
source_dir: Source directory.
file_name: File name or path in a `source_dir` to copy.
Expand All @@ -78,6 +80,8 @@ def copy_src_file(source_dir: str, file_name: str, output_dir: str) -> None:
def copy_full_path(file_path: str, output_dir: str) -> None:
"""Copies a file provided by the full path to target directory.
Target directory must exist.
Args:
file_path: Fully-qualified file path.
output_dir: Output directory.
Expand All @@ -93,6 +97,8 @@ def copy_full_path(file_path: str, output_dir: str) -> None:
def copy_files(paths: list[str], target_dir: str) -> None:
"""Copies files to a target directory.
Target directory must exist.
Args:
paths: List of file paths.
target_dir: Target directory for copying.
Expand All @@ -107,3 +113,25 @@ def copy_files(paths: list[str], target_dir: str) -> None:
for source_path, target_path in paths:
logging.info("Copying %s -> %s ...", source_path, target_path)
shutil.copy(source_path, target_path)


def copy_dir(source_dir: str, target_dir: str) -> None:
  """Copies the files directly inside a source directory to a target directory.

  Important: This is NOT a recursive copy. Only non-directory entries found
  immediately under `source_dir` are copied; subdirectories are skipped.
  Entries whose names start with a dot are not matched by the `*` pattern and
  are therefore skipped as well.

  Args:
    source_dir: Source directory whose top-level files will be copied. Must
      exist.
    target_dir: Target directory for copying. It is expected to exist already;
      this function does not create it.

  Raises:
    ValueError: If `source_dir` is not an existing directory.
  """
  if not os.path.isdir(source_dir):
    raise ValueError(f"Source directory {source_dir} does not exist!")

  logging.info("Copying `%s` to `%s` ...", source_dir, target_dir)
  # Escape the directory name so that glob metacharacters occurring in it
  # (`*`, `?`, `[`) are matched literally. Without this a directory such as
  # `run[1]` would silently yield no files. Only the trailing `*` is a
  # wildcard.
  pattern = os.path.join(glob.escape(source_dir), "*")
  source_paths = [
      path for path in glob.glob(pattern) if not os.path.isdir(path)
  ]
  copy_files(source_paths, target_dir)

0 comments on commit 5f2fa99

Please sign in to comment.