Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jobs to study alignments #558

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9a7506c
Fix DumpAlignmentJob flow
Icemole Nov 19, 2024
e07caff
Add job to dump text/alignment pairs for all segments
Icemole Nov 19, 2024
57db395
Revert changes in branch to main's
Icemole Nov 19, 2024
258b31d
Black
Icemole Nov 19, 2024
528bb29
Remove job from wrong location
Icemole Nov 19, 2024
dd5bec0
Add job to correct location, add PlotViterbiAlignmentJob
Icemole Nov 19, 2024
8217e5d
Fixes
Icemole Nov 19, 2024
597dfa4
More fixes
Icemole Nov 19, 2024
464ebc8
Fix when alignment is empty
Icemole Nov 19, 2024
cbdab57
Black
Icemole Nov 19, 2024
b61e063
Add file for faulty/empty alignment seqtags
Icemole Nov 19, 2024
b26a326
Work
Icemole Nov 20, 2024
15df7fe
Remove original author from docstring
Icemole Nov 20, 2024
4a703f7
PlotViterbiAlignmentJob: add functionality to plot subset of seq tags
Icemole Nov 20, 2024
cc07c0c
DumpSegmentTextAlignmentJob: always compress output csv
Icemole Nov 20, 2024
dbc4b09
DumpSegmentTextAlignmentJob: add functionality to plot subset of seq …
Icemole Nov 20, 2024
3c85bdd
More work
Icemole Nov 20, 2024
6ed4814
Fix uopen call
Icemole Nov 20, 2024
d155c35
Don't interpolate plot
Icemole Nov 20, 2024
7bb0009
alignment_files -> alignment_caches
Icemole Nov 20, 2024
899a68a
Add full orth function
Icemole Dec 10, 2024
6db360c
Black
Icemole Dec 10, 2024
3143196
Revert _orth change
Icemole Dec 10, 2024
7a550d4
Shorten full_orth code
Icemole Dec 10, 2024
f34fcfd
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
31c6488
Fix race condition in DumpSegmentTextAlignmentJob
Icemole Dec 18, 2024
ba20708
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
0b16c52
Merge remote-tracking branch 'origin/use-orth-function' into segment-…
Icemole Dec 18, 2024
923a821
Use full orth instead of only orth
Icemole Dec 18, 2024
5754ad7
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 265 additions & 2 deletions mm/alignment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
"get_seq_tag_to_alignment_mapping",
"AlignmentJob",
"DumpAlignmentJob",
"PlotAlignmentJob",
"AMScoresFromAlignmentLogJob",
"ComputeTimeStampErrorJob",
"DumpSegmentTextAlignmentJob",
"PlotViterbiAlignmentJob",
]

import logging
Expand All @@ -12,19 +15,40 @@
import shutil
import statistics
import xml.etree.ElementTree as ET
from typing import Callable, Counter, List, Optional, Tuple, Union
from typing import Callable, Counter, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from sisyphus import *

Path = setup_path(__package__)

import i6_core.lib.corpus as corpus
import i6_core.lib.rasr_cache as rasr_cache
import i6_core.rasr as rasr
import i6_core.util as util

from .flow import alignment_flow, dump_alignment_flow


# Alignment of one sequence: a list of (timestamp, allophone_id, hmm_state, alignment_weight) tuples.
_AlignmentType = List[Tuple[int, int, int, float]]
# Mapping from sequence tag to the alignment of that sequence.
_SeqTagToAlignmentType = Dict[str, _AlignmentType]


def get_seq_tag_to_alignment_mapping(
    alignment_cache: rasr_cache.FileArchive,
) -> _SeqTagToAlignmentType:
    """
    Read every alignment stored in an opened alignment cache.

    :param alignment_cache: Opened alignment cache from which to extract the alignments.
    :return: Mapping from sequence tags to alignments (by frame).
        The alignments are a list of tuples (timestamp, allophone_id, hmm_state, alignment_weight).
    """
    seq_tag_to_alignment = {}
    for entry_name in alignment_cache.ft.keys():
        # Entries whose name ends in ".attribs" are not read as alignments.
        if entry_name.endswith(".attribs"):
            continue
        seq_tag_to_alignment[entry_name] = alignment_cache.read(entry_name, "align")
    return seq_tag_to_alignment


class AlignmentJob(rasr.RasrCommand, Job):
"""
Align a dataset with the given feature scorer.
Expand Down Expand Up @@ -147,7 +171,6 @@ def run(self, task_id):
)

def plot(self):
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -789,3 +812,243 @@ def plot(self):
plt.xticks(rotation=45)

plt.savefig(plot_file)


class DumpSegmentTextAlignmentJob(Job):
    """
    Dumps all text and alignments for the given corpus and alignment files
    in a human-readable format defined as follows:
    ```
    <seq-tag>
    <text>
    <alignment-index-0> <start-0> <end-0> <allophone-0>.<hmm-state-0> <weight-0>
    <alignment-index-1> <start-1> <end-1> <allophone-1>.<hmm-state-1> <weight-1>
    ...
    ```
    Only sequences that are present both in an alignment cache and in the corpus are dumped.

    Every alignment cache is handled by its own `run` task, which writes intermediate files
    (paths relative to the current working directory). The `merge` task validates the requested
    sequence tags and concatenates the intermediate dumps into the final output.
    """

    def __init__(
        self,
        corpus_file: tk.Path,
        alignment_caches: List[tk.Path],
        allophone_file: tk.Path,
        seq_tags_to_dump: Optional[tk.Path] = None,
        frame_size: float = 0.25,
        frame_step: float = 0.1,
    ):
        """
        :param corpus_file: Corpus file to get the text from.
        :param alignment_caches: Alignment files to get the alignments from.
            Must support `len()` and integer indexing, since one task is created per cache.
            Must correspond to the corpus given in :param:`corpus_file` for the job to work properly.
        :param allophone_file: Allophone file with which the alignments given in :param:`alignment_caches` were dumped.
        :param seq_tags_to_dump: File with one sequence tag per line, listing the specific sequence tags to dump.
            By default, dump all sequences given in :param:`alignment_caches`.
        :param frame_size: Frame size. Only used to calculate the timestamps of the alignments.
        :param frame_step: Frame step. Only used to calculate the timestamps of the alignments.
        """
        # NOTE(review): common feature extraction setups use a 25 ms window with a 10 ms shift
        # (0.025 / 0.01 when given in seconds) - please confirm the defaults 0.25 / 0.1 are intended.
        self.corpus_file = corpus_file
        self.alignment_caches = alignment_caches
        self.allophone_file = allophone_file
        self.seq_tags_to_dump = seq_tags_to_dump
        self.frame_size = frame_size
        self.frame_step = frame_step

        self.out_text_alignment_pairs = self.output_path("segment_txt_alignment.txt.gz")

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        # One `run` task per alignment cache (task ids are 1-based), followed by a single merge.
        yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))
        yield Task("merge", resume="merge")

    def _read_requested_seq_tags(self) -> Optional[List[str]]:
        """
        :return: Requested sequence tags in file order (blank lines and duplicates removed),
            or `None` if no restriction was given.
        """
        if self.seq_tags_to_dump is None:
            return None
        with util.uopen(self.seq_tags_to_dump.get_path(), "rt") as f:
            stripped_lines = (line.strip() for line in f)
            return list(dict.fromkeys(tag for tag in stripped_lines if tag))

    def run(self, task_id):
        """
        Dump the text/alignment pairs of the `task_id`-th (1-based) alignment cache
        to `intermediate_segment_txt_alignment.<task_id>.txt.gz`.
        """
        # Get the alignment information: seq_tag -> alignment.
        align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
        align_cache.setAllophones(self.allophone_file.get_path())
        seq_tag_to_alignments = get_seq_tag_to_alignment_mapping(align_cache)

        # Get the corpus information: seq_tag -> text.
        c = corpus.Corpus()
        c.load(self.corpus_file.get_path())
        seq_tag_to_text = {seq_tag: segment.full_orth() for seq_tag, segment in c.get_segment_mapping().items()}

        requested_seq_tags = self._read_requested_seq_tags()
        if requested_seq_tags is None:
            seq_tags_to_dump = list(seq_tag_to_alignments)
        else:
            # A requested tag may live in a different alignment cache, so its absence from *this*
            # cache is not an error here. The tags found are recorded so that `merge` can check
            # that every requested tag was found in at least one cache.
            seq_tags_to_dump = [seq_tag for seq_tag in requested_seq_tags if seq_tag in seq_tag_to_alignments]
            with open(f"found_seq_tags.{task_id}.txt", "wt") as f:
                f.writelines(f"{seq_tag}\n" for seq_tag in seq_tags_to_dump)

        with util.uopen(f"intermediate_segment_txt_alignment.{task_id}.txt.gz", "wt") as f:
            # Iterate in a fixed order (cache/file order) so that the output is reproducible.
            for seq_tag in seq_tags_to_dump:
                if seq_tag not in seq_tag_to_text:
                    # No text available for this sequence in the corpus: nothing to pair it with.
                    continue
                lines = [f"{seq_tag}\n", f"{seq_tag_to_text[seq_tag]}\n"]
                lines.extend(
                    f"{align_idx} "
                    f"{(self.frame_step * align_idx):.3f} "
                    f"{(self.frame_step * align_idx + self.frame_size):.3f} "
                    f"{align_cache.allophones[allo_id]}.{hmm_state} "
                    f"{weight:.3f}\n"
                    for (align_idx, allo_id, hmm_state, weight) in seq_tag_to_alignments[seq_tag]
                )
                # Blank line terminates the record of this sequence.
                lines.append("\n")
                f.writelines(lines)

    def merge(self):
        """
        Concatenate the intermediate dumps of all `run` tasks into the final output file.

        :raises ValueError: If some sequence tag requested via `seq_tags_to_dump`
            was not found in any of the alignment caches.
        """
        task_ids = range(1, len(self.alignment_caches) + 1)

        requested_seq_tags = self._read_requested_seq_tags()
        if requested_seq_tags is not None:
            found_seq_tags = set()
            for task_id in task_ids:
                with open(f"found_seq_tags.{task_id}.txt", "rt") as f:
                    found_seq_tags.update(line.rstrip("\n") for line in f)
            missing_seq_tags = [seq_tag for seq_tag in requested_seq_tags if seq_tag not in found_seq_tags]
            if missing_seq_tags:
                raise ValueError(
                    f"The sequence tags {missing_seq_tags} provided in seq_tags_to_dump "
                    "are not in the provided alignment files."
                )

        with util.uopen(self.out_text_alignment_pairs.get_path(), "wt") as f_out:
            for task_id in task_ids:
                with util.uopen(f"intermediate_segment_txt_alignment.{task_id}.txt.gz", "rt") as f_in:
                    f_out.writelines(f_in)
                # Extra separator line between the dumps of consecutive caches.
                f_out.write("\n")


class PlotViterbiAlignmentJob(Job):
    """
    Plots the alignments of each segment in the specified alignment files.

    Every alignment cache is handled by its own `run` task. Because these tasks may run in
    parallel, each one writes its list of empty alignments to a task-specific intermediate file
    (path relative to the current working directory); the `merge` task then validates the
    requested sequence tags and assembles the final `empty_alignment_seq_tags.txt` output.
    """

    def __init__(
        self,
        alignment_caches: List[tk.Path],
        allophone_file: tk.Path,
        seq_tags_to_plot: Optional[tk.Path] = None,
        corpus_file: Optional[tk.Path] = None,
    ):
        """
        :param alignment_caches: Alignment files to be plotted.
            Must support `len()` and integer indexing, since one task is created per cache.
        :param allophone_file: Allophone file used in the alignment process.
        :param seq_tags_to_plot: File with one sequence tag per line, listing the specific sequence tags to plot.
            By default, plot all sequences given in :param:`alignment_caches`.
        :param corpus_file: Corpus used to generate the alignments. By default, the plots have no title.
            If provided, the plots will have the text from the respective segment as title,
            whenever the segment is available in the corpus. This should only be given for convenience.
        """
        self.alignment_caches = alignment_caches
        self.allophone_file = allophone_file
        self.seq_tags_to_plot = seq_tags_to_plot
        self.corpus_file = corpus_file

        self.out_plot_dir = self.output_path("plots", directory=True)
        self.out_empty_alignment_seq_tags = self.output_path("empty_alignment_seq_tags.txt")

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        # One `run` task per alignment cache (task ids are 1-based), followed by a single merge.
        yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))
        yield Task("merge", resume="merge")

    def _read_requested_seq_tags(self) -> Optional[List[str]]:
        """
        :return: Requested sequence tags in file order (blank lines and duplicates removed),
            or `None` if no restriction was given.
        """
        if self.seq_tags_to_plot is None:
            return None
        with util.uopen(self.seq_tags_to_plot.get_path(), "rt") as f:
            stripped_lines = (line.strip() for line in f)
            return list(dict.fromkeys(tag for tag in stripped_lines if tag))

    def extract_phoneme_sequence(self, alignment: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Collapse runs of identical consecutive labels.

        :param alignment: Monophone alignment, for instance: `np.array(["a", "a", "b", ...])`.
        :return: Monophone sequence (ordered as given),
            as well as the indices corresponding to the monophone sequence from the Viterbi alignment,
            i.e. for every frame the position of its label within the returned monophone sequence.
        """
        alignment = np.asarray(alignment)
        if alignment.size == 0:
            # Nothing to collapse; avoids indexing position -1 of an empty array below.
            return alignment, np.zeros(0, dtype=np.int64)

        # Index of the last frame of every run of identical labels.
        boundaries = np.concatenate(
            [
                np.where(alignment[:-1] != alignment[1:])[0],
                [len(alignment) - 1],  # manually add boundary of last allophone
            ]
        )

        # Run length = distance between consecutive run ends (the first run starts after index -1).
        lengths = boundaries - np.concatenate([[-1], boundaries[:-1]])
        phonemes = alignment[boundaries]
        monotonic_idx_alignment = np.repeat(np.arange(len(phonemes)), lengths)
        return phonemes, monotonic_idx_alignment

    def make_viterbi_matrix(self, label_idx_seq: np.ndarray) -> np.ndarray:
        """
        :param label_idx_seq: For every frame, the index of the label it is aligned to.
        :return: Matrix corresponding to the Viterbi alignment, of shape `(max(label_idx_seq) + 1, num_frames)`
            and dtype float32, holding 1.0 at `[label_idx_seq[t], t]` for every frame `t` and 0.0 elsewhere.
            An empty input yields an empty `(0, 0)` matrix.
        """
        label_idx_seq = np.asarray(label_idx_seq, dtype=np.int64)
        num_frames = len(label_idx_seq)
        if num_frames == 0:
            return np.zeros((0, 0), dtype=np.float32)

        num_labels = int(label_idx_seq.max()) + 1
        viterbi_matrix = np.zeros((num_labels, num_frames), dtype=np.float32)
        viterbi_matrix[label_idx_seq, np.arange(num_frames)] = 1.0
        return viterbi_matrix

    def plot(self, viterbi_matrix: np.ndarray, allophone_sequence: List[str], file_name: str, title: str = ""):
        """
        Render the alignment matrix and store it as `<file_name>.png` below the plot output directory.

        :param viterbi_matrix: Matrix to be plotted, corresponding to the Viterbi alignment.
        :param allophone_sequence: Allophone sequence (Y-axis tick labels).
        :param file_name: File name where to store the plot, relative to `<job>/output/plots/`.
        :param title: Optional title to add to the image. By default there will be no title.
        """
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")

        import matplotlib.pyplot as plt

        num_labels, num_frames = np.shape(viterbi_matrix)

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xlabel("Frame")
        ax.xaxis.set_label_coords(0.98, -0.03)
        ax.set_xbound(0, num_frames - 1)
        ax.set_ybound(-0.5, num_labels - 0.5)

        ax.set_yticks(np.arange(num_labels))
        ax.set_yticklabels(allophone_sequence)

        ax.set_title(title)

        ax.imshow(viterbi_matrix, cmap="Blues", interpolation="none", aspect="auto", origin="lower")

        # The plot will be purposefully divided into subdirectories.
        plot_path = os.path.join(self.out_plot_dir.get_path(), f"{file_name}.png")
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        fig.savefig(plot_path)
        # Release the figure; otherwise figures accumulate over all plotted sequences.
        plt.close(fig)

    def run(self, task_id):
        """
        Plot the alignments of the `task_id`-th (1-based) alignment cache and record the sequence tags
        with an empty alignment in `empty_alignment_seq_tags.<task_id>.txt`.
        """
        align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
        align_cache.setAllophones(self.allophone_file.get_path())
        seq_tag_to_alignments = get_seq_tag_to_alignment_mapping(align_cache)

        seq_tag_to_text = {}
        if self.corpus_file is not None:
            c = corpus.Corpus()
            c.load(self.corpus_file.get_path())
            seq_tag_to_text = {seq_tag: segment.full_orth() for seq_tag, segment in c.get_segment_mapping().items()}

        requested_seq_tags = self._read_requested_seq_tags()
        if requested_seq_tags is None:
            seq_tags_to_plot = list(seq_tag_to_alignments)
        else:
            # A requested tag may live in a different alignment cache, so its absence from *this*
            # cache is not an error here. The tags found are recorded so that `merge` can check
            # that every requested tag was found in at least one cache.
            seq_tags_to_plot = [seq_tag for seq_tag in requested_seq_tags if seq_tag in seq_tag_to_alignments]
            with open(f"found_seq_tags.{task_id}.txt", "wt") as f:
                f.writelines(f"{seq_tag}\n" for seq_tag in seq_tags_to_plot)

        empty_alignment_seq_tags = []
        for seq_tag in seq_tags_to_plot:
            alignment = seq_tag_to_alignments[seq_tag]
            # In some rare cases, the alignment doesn't have to reach a satisfactory end.
            # In these cases, the final alignment is empty. Skip those cases.
            if len(alignment) == 0:
                empty_alignment_seq_tags.append(seq_tag)
                continue

            # Get the central part of each allophone (everything before the first "{").
            # Alignment items are (timestamp, allophone_id, hmm_state, weight) tuples.
            center_allophones = np.array(
                [align_cache.allophones[allo_id].split("{")[0] for (_, allo_id, _, _) in alignment]
            )
            phonemes, alignment_indices = self.extract_phoneme_sequence(center_allophones)
            viterbi_matrix = self.make_viterbi_matrix(alignment_indices)
            self.plot(viterbi_matrix, phonemes, file_name=seq_tag, title=seq_tag_to_text.get(seq_tag, ""))

        # Task-specific file: parallel tasks must not write to the shared job output directly.
        with open(f"empty_alignment_seq_tags.{task_id}.txt", "wt") as f:
            f.writelines(f"{seq_tag}\n" for seq_tag in empty_alignment_seq_tags)

    def merge(self):
        """
        Collect the empty-alignment sequence tags of all `run` tasks into the final output file.

        :raises ValueError: If some sequence tag requested via `seq_tags_to_plot`
            was not found in any of the alignment caches.
        """
        task_ids = range(1, len(self.alignment_caches) + 1)

        requested_seq_tags = self._read_requested_seq_tags()
        if requested_seq_tags is not None:
            found_seq_tags = set()
            for task_id in task_ids:
                with open(f"found_seq_tags.{task_id}.txt", "rt") as f:
                    found_seq_tags.update(line.rstrip("\n") for line in f)
            missing_seq_tags = [seq_tag for seq_tag in requested_seq_tags if seq_tag not in found_seq_tags]
            if missing_seq_tags:
                raise ValueError(
                    f"The sequence tags {missing_seq_tags} provided in seq_tags_to_plot "
                    "are not in the provided alignment files."
                )

        with open(self.out_empty_alignment_seq_tags.get_path(), "wt") as f_out:
            for task_id in task_ids:
                with open(f"empty_alignment_seq_tags.{task_id}.txt", "rt") as f_in:
                    f_out.writelines(f_in)
Loading