From ead39fa9cfb99ea1c0859ab587ed538397dfc78a Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 25 Jul 2024 14:20:18 +0100 Subject: [PATCH 1/3] Move "hull" pointer to lineage Update the class definitions Remove unused hull argument msp_alloc_segment --- algorithms.py | 48 +++++++++++++++++++++++------------------------- lib/msprime.c | 38 +++++++++++++++++--------------------- lib/msprime.h | 2 +- 3 files changed, 41 insertions(+), 47 deletions(-) diff --git a/algorithms.py b/algorithms.py index dec9976e7..d5fcbb478 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1,6 +1,8 @@ """ Python version of the simulation algorithm. """ +from __future__ import annotations + import argparse import dataclasses import heapq @@ -111,6 +113,9 @@ def find(self, v): return j + 1 +# Once we drop support for 3.9 we can use slots=True to prevent +# writing extra attrs. +@dataclasses.dataclass # (slots=True) class Segment: """ A class representing a single segment. Each segment has a left @@ -118,17 +123,15 @@ class Segment: next, giving the next in the chain. """ - def __init__(self, index): - self.left = None - self.right = None - self.node = None - self.prev = None - self.next = None - self.population = None - self.label = 0 - self.index = index - self.hull = None - self.lineage = None + index: int + left: float = 0 + right: float = 0 + node: int = -1 + prev: Segment = None + next: Segment = None # noqa: A003 + lineage: Lineage = None + population: int = -1 + label: int = 0 def __repr__(self): return repr((self.left, self.right, self.node)) @@ -154,7 +157,7 @@ def get_hull(self): assert seg is not None while seg.prev is not None: seg = seg.prev - hull = seg.hull + hull = seg.lineage.hull return hull def get_left_index(self): @@ -166,11 +169,6 @@ def get_left_index(self): return index -@dataclasses.dataclass -class Lineage: - head: Segment - - class Population: """ Class representing a population in the simulation. 
@@ -729,6 +727,12 @@ def __repr__(self): return f"x:{self.x}, io:{self.insertion_order}" +@dataclasses.dataclass +class Lineage: + head: Segment + hull: Hull = None + + class OrderStatisticsTree: """ Bintrees AVL tree with added functionality to keep track of the rank @@ -1054,15 +1058,11 @@ def change_migration_matrix_element(self, pop_i, pop_j, rate): self.migration_matrix[pop_i][pop_j] = rate def alloc_hull(self, left, right, lineage): - alpha = lineage.head hull = self.hull_stack.pop() hull.left = left hull.right = right - while alpha.prev is not None: - alpha = alpha.prev - assert alpha is not None hull.lineage = lineage - alpha.hull = hull + lineage.hull = hull return hull def alloc_segment( @@ -1074,7 +1074,6 @@ def alloc_segment( prev=None, next=None, # noqa: A002 label=0, - hull=None, ): """ Pops a new segment off the stack and sets its properties. @@ -1087,7 +1086,6 @@ def alloc_segment( s.next = next s.prev = prev s.label = label - s.hull = hull return s def alloc_lineage(self, head): @@ -1688,7 +1686,7 @@ def migration_event(self, j, k): index = random.randint(0, source.get_num_ancestors(label) - 1) lineage = source.remove(index, label) x = lineage.head - hull = x.get_hull() + hull = lineage.hull assert (self.model == "smc_k") == (hull is not None) dest.add(lineage, label) if self.model == "smc_k": diff --git a/lib/msprime.c b/lib/msprime.c index a60b0dd46..1e2d482db 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -200,7 +200,7 @@ segment_get_hull(segment_t *seg) seg = seg->prev; } tsk_bug_assert(seg->lineage != NULL); - hull = seg->hull; + hull = seg->lineage->hull; tsk_bug_assert(hull->lineage == seg->lineage); return hull; @@ -838,8 +838,7 @@ msp_set_hull_block_size(msp_t *self, size_t block_size) static segment_t *MSP_WARN_UNUSED msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, - population_id_t population, label_id_t label, segment_t *prev, segment_t *next, - hull_t *hull) + population_id_t population, label_id_t label, 
segment_t *prev, segment_t *next) { segment_t *seg = NULL; @@ -878,7 +877,6 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, seg->value = value; seg->population = population; seg->label = label; - seg->hull = hull; out: return seg; } @@ -907,7 +905,7 @@ static segment_t *MSP_WARN_UNUSED msp_copy_segment(msp_t *self, const segment_t *seg) { return msp_alloc_segment(self, seg->left, seg->right, seg->value, seg->population, - seg->label, seg->prev, seg->next, seg->hull); + seg->label, seg->prev, seg->next); } static hull_t *MSP_WARN_UNUSED @@ -950,7 +948,7 @@ msp_alloc_hull(msp_t *self, double left, double right, lineage_t *lineage) hull->count = 0; hull->insertion_order = UINT64_MAX; tsk_bug_assert(lineage->head->prev == NULL); - lineage->head->hull = hull; + lineage->hull = hull; out: return hull; } @@ -1949,7 +1947,7 @@ msp_verify_hulls(msp_t *self) for (a = avl->head; a->next != NULL; a = a->next) { lin = (lineage_t *) a->item; x = lin->head; - hull_right = x->hull->right; + hull_right = lin->hull->right; hull_a.left = x->left; while (x->next != NULL) { x = x->next; @@ -2571,7 +2569,7 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, lineage_t *ind; segment_t *x, *y; double recomb_mass, gc_mass; - hull_t *hull, *new_hull, *h; + hull_t *hull, *new_hull; if (self->populations[dest_pop].state != MSP_POP_STATE_ACTIVE) { ret = MSP_ERR_POPULATION_INACTIVE_MOVE; @@ -2623,10 +2621,9 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, // new_hull = msp_alloc_hull(self, hull->left, hull->right, new_ind); // msp_free_hull(self, hull, ind->population, ind->label); //} - h = new_hull; for (x = ind->head; x != NULL; x = x->next) { - y = msp_alloc_segment(self, x->left, x->right, x->value, x->population, - dest_label, y, NULL, h); + y = msp_alloc_segment( + self, x->left, x->right, x->value, x->population, dest_label, y, NULL); if (x->prev == NULL) { ind->head = y; y->lineage = ind; @@ -2644,7 +2641,6 @@ 
msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, fenwick_set_value(&self->gc_mass_index[y->label], y->id, gc_mass); } msp_free_segment(self, x); - h = NULL; } } if (new_hull != NULL) { @@ -3004,8 +3000,8 @@ msp_dtwf_recombine( } else { tail = seg_tails[ix]; } - z = msp_alloc_segment(self, k, x->right, x->value, x->population, x->label, - tail, x->next, NULL); + z = msp_alloc_segment( + self, k, x->right, x->value, x->population, x->label, tail, x->next); if (z == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3284,7 +3280,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ if (y->left < breakpoint) { tsk_bug_assert(breakpoint < y->right); alpha = msp_alloc_segment(self, breakpoint, y->right, y->value, y->population, - y->label, NULL, y->next, NULL); + y->label, NULL, y->next); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3694,7 +3690,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l alpha->next = NULL; } else if (x->left != y->left) { alpha = msp_alloc_segment(self, x->left, y->left, x->value, - x->population, x->label, NULL, NULL, NULL); + x->population, x->label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3754,7 +3750,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l r = nm->position; } alpha = msp_alloc_segment( - self, l, r, v, population_id, label, NULL, NULL, NULL); + self, l, r, v, population_id, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3954,7 +3950,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, x = H[0]; if (node != NULL && next_l < x->right) { alpha = msp_alloc_segment(self, x->left, next_l, x->value, x->population, - x->label, NULL, NULL, NULL); + x->label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4023,7 +4019,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t 
population_id, r = nm->position; } alpha = msp_alloc_segment( - self, l, r, new_node_id, population_id, label, NULL, NULL, NULL); + self, l, r, new_node_id, population_id, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4379,7 +4375,7 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri } if (root_segments_head[root] == NULL) { seg = msp_alloc_segment( - self, left, right, root, population, label, NULL, NULL, NULL); + self, left, right, root, population, label, NULL, NULL); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4392,7 +4388,7 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri tail->right = right; } else { seg = msp_alloc_segment( - self, left, right, root, population, label, tail, NULL, NULL); + self, left, right, root, population, label, tail, NULL); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; diff --git a/lib/msprime.h b/lib/msprime.h index f27a22c6f..8abc0abf6 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -84,12 +84,12 @@ typedef struct segment_t_t { size_t id; struct segment_t_t *prev; struct segment_t_t *next; - struct hull_t_t *hull; struct lineage_t_t *lineage; } segment_t; typedef struct lineage_t_t { segment_t *head; + struct hull_t_t *hull; } lineage_t; typedef struct { From f5376bce7c83ff317ef271b980ab584abc19fbf6 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 26 Jul 2024 12:44:54 +0100 Subject: [PATCH 2/3] Move label attr to lineage Patch-up DTWF in python Patch up sweep code in Python Basic coalescent working in C Other basic models passing in C C code mostly working! 
Fixup python algorithms code --- algorithms.py | 126 +++++++++++++++++++++++-------------- lib/msprime.c | 168 +++++++++++++++++++++++++++++++------------------- lib/msprime.h | 2 +- 3 files changed, 187 insertions(+), 109 deletions(-) diff --git a/algorithms.py b/algorithms.py index d5fcbb478..716f882e3 100644 --- a/algorithms.py +++ b/algorithms.py @@ -131,7 +131,7 @@ class Segment: next: Segment = None # noqa: A003 lineage: Lineage = None population: int = -1 - label: int = 0 + # label: int = 0 def __repr__(self): return repr((self.left, self.right, self.node)) @@ -210,7 +210,7 @@ def print_state(self): for label, ancestors in enumerate(self._ancestors): print("\tLabel = ", label) for lineage in ancestors: - print("\t\t" + Segment.show_chain(lineage.head)) + print(f"\t\t{lineage}") def set_growth_rate(self, growth_rate, time): # TODO This doesn't work because we need to know what the time @@ -415,7 +415,7 @@ def add(self, individual, label=0): Inserts the specified individual into this population. 
""" assert isinstance(individual, Lineage) - assert individual.head.label == label + assert individual.label == label self._ancestors[label].append(individual) def __iter__(self): @@ -731,6 +731,21 @@ def __repr__(self): class Lineage: head: Segment hull: Hull = None + label: int = 0 + + def __str__(self): + s = ( + f"Lineage(id={hex(id(self))},label={self.label},hull={self.hull}," + f"head={self.head.index}," + f"chain={Segment.show_chain(self.head)})" + ) + return s + + def reset_segments(self): + x = self.head + while x is not None: + x.lineage = self + x = x.next class OrderStatisticsTree: @@ -888,13 +903,14 @@ def __init__( self.hulls[j + 1] = h self.hull_stack.append(h) self.P = [Population(id_, num_labels, max_segments, model) for id_ in range(N)] - if self.recomb_map.total_mass == 0: + mass_indexes_not_used = model in ["dtwf", "fixed_pedigree"] + if self.recomb_map.total_mass == 0 or mass_indexes_not_used: self.recomb_mass_index = None else: self.recomb_mass_index = [ FenwickTree(self.max_segments) for j in range(num_labels) ] - if self.gc_map.total_mass == 0: + if self.gc_map.total_mass == 0 or mass_indexes_not_used: self.gc_mass_index = None else: self.gc_mass_index = [ @@ -990,7 +1006,6 @@ def initialise(self, ts): if seg is not None: left_end = seg.left pop = seg.population - label = seg.label lineage = self.alloc_lineage(seg) self.P[seg.population].add(lineage) while seg is not None: @@ -1003,8 +1018,8 @@ def initialise(self, ts): if seg is not None: left_end = seg.left pop = seg.population - label = seg.label lineage = seg.lineage + label = lineage.label right_end = root_segments_tail[node].right new_hull = self.alloc_hull(left_end, right_end, lineage) # insert Hull @@ -1073,7 +1088,7 @@ def alloc_segment( population, prev=None, next=None, # noqa: A002 - label=0, + lineage=None, ): """ Pops a new segment off the stack and sets its properties. 
@@ -1085,12 +1100,16 @@ def alloc_segment( s.population = population s.next = next s.prev = prev - s.label = label + s.lineage = lineage return s - def alloc_lineage(self, head): - lineage = Lineage(head) - head.lineage = lineage + def alloc_lineage(self, head, *, label=0): + lineage = Lineage(head, label=label) + lineage.reset_segments() + x = head + while x is not None: + x.lineage = lineage + x = x.next return lineage def copy_segment(self, segment): @@ -1101,7 +1120,7 @@ def copy_segment(self, segment): population=segment.population, next=segment.next, prev=segment.prev, - label=segment.label, + lineage=segment.lineage, ) def free_segment(self, u): @@ -1110,9 +1129,9 @@ def free_segment(self, u): setting its weight to zero. """ if self.recomb_mass_index is not None: - self.recomb_mass_index[u.label].set_value(u.index, 0) + self.recomb_mass_index[u.lineage.label].set_value(u.index, 0) if self.gc_mass_index is not None: - self.gc_mass_index[u.label].set_value(u.index, 0) + self.gc_mass_index[u.lineage.label].set_value(u.index, 0) self.segment_stack.append(u) def free_hull(self, u): @@ -1169,6 +1188,16 @@ def store_edge(self, left, right, parent, child): tskit.Edge(left=left, right=right, parent=parent, child=child) ) + def add_lineage(self, lineage): + pop = lineage.head.population + self.P[pop].add(lineage, lineage.label) + # print("add", lineage) + x = lineage.head + while x is not None: + # print("\t", x.lineage) + assert x.lineage == lineage + x = x.next + def finalise(self): """ Finalises the simulation returns an msprime tree sequence object. 
@@ -1395,7 +1424,7 @@ def single_sweep_simulate(self): self.set_labels(lineage, 1) indices.append(idx) else: - assert lineage.head.label == 0 + assert lineage.label == 0 popped = 0 for i in indices: tmp = self.P[0].remove(i - popped, 0) @@ -1534,8 +1563,7 @@ def dtwf_generation(self): lin_pair = self.dtwf_recombine(child, parent_nodes) for lin in lin_pair: if lin is not None and lin != child: - pop.add(lin) - + self.add_lineage(lin) self.verify() # Collect segments inherited from the same individual for i, lin in enumerate(lin_pair): @@ -1722,12 +1750,12 @@ def set_segment_mass(self, seg): appropriately set before calling this function. """ if self.recomb_mass_index is not None: - mass_index = self.recomb_mass_index[seg.label] + mass_index = self.recomb_mass_index[seg.lineage.label] recomb_left_bound = self.get_recomb_left_bound(seg) recomb_mass = self.recomb_map.mass_between(recomb_left_bound, seg.right) mass_index.set_value(seg.index, recomb_mass) if self.gc_mass_index is not None: - mass_index = self.gc_mass_index[seg.label] + mass_index = self.gc_mass_index[seg.lineage.label] gc_left_bound = self.get_gc_left_bound(seg) gc_mass = self.gc_map.mass_between(gc_left_bound, seg.right) mass_index.set_value(seg.index, gc_mass) @@ -1737,18 +1765,16 @@ def set_labels(self, lineage, new_label): Move the specified lineage to the specified label. 
""" mass_indexes = [self.recomb_mass_index, self.gc_mass_index] + assert new_label != lineage.label segment = lineage.head while segment is not None: - masses = [] for mass_index in mass_indexes: if mass_index is not None: - masses.append(mass_index[segment.label].get_value(segment.index)) - mass_index[segment.label].set_value(segment.index, 0) - segment.label = new_label - for mass, mass_index in zip(masses, mass_indexes): - if mass_index is not None: - mass_index[segment.label].set_value(segment.index, mass) + mass = mass_index[lineage.label].get_value(segment.index) + mass_index[lineage.label].set_value(segment.index, 0) + mass_index[new_label].set_value(segment.index, mass) segment = segment.next + lineage.label = new_label def choose_breakpoint(self, mass_index, rate_map): assert mass_index.get_total() > 0 @@ -1799,7 +1825,7 @@ def hudson_recombination_event(self, label, return_heads=False): alpha = y lhs_tail = x - right_lineage = self.alloc_lineage(alpha) + right_lineage = self.alloc_lineage(alpha, label=label) if self.model == "smc_k": # modify original hull pop = alpha.population @@ -1821,6 +1847,7 @@ def hudson_recombination_event(self, label, return_heads=False): self.store_arg_edges(alpha) ret = None if return_heads: + # if True: x = lhs_tail # Seek back to the head of the x chain while x.prev is not None: @@ -1976,12 +2003,8 @@ def wiuf_gene_conversion_within_event(self, label): assert hull_left < hull_right hull_right = min(self.L, hull_right + self.hull_offset) hull = self.alloc_hull(hull_left, hull_right, lineage) - self.P[new_individual_head.population].add_hull( - new_individual_head.label, hull - ) - self.P[new_individual_head.population].add( - lineage, new_individual_head.label - ) + self.P[new_individual_head.population].add_hull(lineage.label, hull) + self.add_lineage(lineage) def wiuf_gene_conversion_left_event(self, label): """ @@ -2073,15 +2096,15 @@ def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): if 
sweep_site < rhs.left: if r < 1.0 - pop_freq: # move rhs to other population - self.P[rhs.population].remove_individual(right_lin, rhs.label) + self.P[rhs.population].remove_individual(right_lin, right_lin.label) self.set_labels(right_lin, 1 - label) - self.P[rhs.population].add(right_lin, rhs.label) + self.P[rhs.population].add(right_lin, right_lin.label) else: if r < 1.0 - pop_freq: # move lhs to other population - self.P[rhs.population].remove_individual(left_lin, lhs.label) + self.P[rhs.population].remove_individual(left_lin, left_lin.label) self.set_labels(left_lin, 1 - label) - self.P[lhs.population].add(left_lin, lhs.label) + self.P[lhs.population].add(left_lin, left_lin.label) def dtwf_generate_breakpoint(self, start): left_bound = start + 1 if self.discrete_genome else start @@ -2096,6 +2119,9 @@ def dtwf_recombine(self, lineage, ind_nodes): Chooses breakpoints and returns segments sorted by inheritance direction, by iterating through segment chain starting with x """ + # NOTE: the logic here around new lineages being generated + # is very convoluted, and could be done much more simply now + # we have the lineage objects. 
u = self.alloc_segment(-1, -1, -1, -1, None, None) v = self.alloc_segment(-1, -1, -1, -1, None, None) seg_tails = [u, v] @@ -2183,7 +2209,8 @@ def dtwf_recombine(self, lineage, ind_nodes): if seg is None: ret.append(None) else: - if seg.lineage is lineage: + if seg == lineage.head: + lineage.reset_segments() ret.append(lineage) else: ret.append(self.alloc_lineage(seg)) @@ -2247,7 +2274,6 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): x = X[0] if len(H) > 0 and H[0][0] < x.right: alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population) - alpha.label = label x.left = H[0][0] heapq.heappush(H, (x.left, x)) else: @@ -2307,6 +2333,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): ) z.next = alpha alpha.prev = z + alpha.lineage = new_lineage self.set_segment_mass(alpha) z = alpha if coalescence: @@ -2470,7 +2497,6 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): right=right, node=u, population=population_index, - label=label, ) if x.node != u: # required for dtwf and fixed_pedigree self.store_edge(left, right, u, x.node) @@ -2491,8 +2517,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. 
if alpha is not None: if z is None: - new_lineage = self.alloc_lineage(alpha) - pop.add(new_lineage, label) + new_lineage = self.alloc_lineage(alpha, label=label) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2504,6 +2529,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): ) z.next = alpha alpha.prev = z + alpha.lineage = new_lineage self.set_segment_mass(alpha) z = alpha @@ -2519,6 +2545,14 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): if coalescence: self.defrag_breakpoints() + if new_lineage is not None: + x = new_lineage.head + # TODO do this more efficiently + while x is not None: + x.lineage = new_lineage + x = x.next + self.add_lineage(new_lineage) + if new_lineage is not None and self.model == "smc_k": merged_head = new_lineage.head assert merged_head.prev is None @@ -2599,16 +2633,18 @@ def verify_segments(self): for label in range(self.num_labels): for lineage in pop.iter_label(label): assert isinstance(lineage, Lineage) + assert lineage.label == label head = lineage.head assert head.lineage is lineage assert head.prev is None prev = head u = head.next + # print("LIN", lineage) while u is not None: + assert u.lineage == lineage assert prev.next is u assert u.prev is prev assert u.left >= prev.right - assert u.label == head.label assert u.population == head.population prev = u u = u.next @@ -2728,7 +2764,7 @@ def verify(self): Checks that the state of the simulator is consistent. """ self.verify_segments() - if self.model != "fixed_pedigree": + if self.model not in ["fixed_pedigree", "dtwf"]: # The fixed_pedigree model doesn't maintain a bunch of stuff. # It would probably be simpler if it did. 
self.verify_overlaps() diff --git a/lib/msprime.c b/lib/msprime.c index 1e2d482db..50ce93580 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -450,7 +450,7 @@ msp_set_segment_mass(msp_t *self, segment_t *seg) if (self->recomb_mass_index != NULL) { left_bound = msp_get_recomb_left_bound(self, seg); mass = rate_map_mass_between(&self->recomb_map, left_bound, seg->right); - fenwick_set_value(&self->recomb_mass_index[seg->label], seg->id, mass); + fenwick_set_value(&self->recomb_mass_index[seg->lineage->label], seg->id, mass); } if (self->gc_mass_index != NULL) { /* NOTE: it looks like the gc_left_bound doesn't actually give us the @@ -458,7 +458,7 @@ msp_set_segment_mass(msp_t *self, segment_t *seg) * and use the same left bound for both. */ left_bound = msp_get_gc_left_bound(self, seg); mass = rate_map_mass_between(&self->gc_map, left_bound, seg->right); - fenwick_set_value(&self->gc_mass_index[seg->label], seg->id, mass); + fenwick_set_value(&self->gc_mass_index[seg->lineage->label], seg->id, mass); } } @@ -876,13 +876,22 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, seg->right = right; seg->value = value; seg->population = population; - seg->label = label; out: return seg; } +static void +lineage_reset_segments(lineage_t *self) +{ + segment_t *x; + + for (x = self->head; x != NULL; x = x->next) { + x->lineage = self; + } +} + static lineage_t *MSP_WARN_UNUSED -msp_alloc_lineage(msp_t *self, segment_t *head) +msp_alloc_lineage(msp_t *self, segment_t *head, label_id_t label) { lineage_t *lin = NULL; @@ -896,16 +905,20 @@ msp_alloc_lineage(msp_t *self, segment_t *head) goto out; } lin->head = head; - head->lineage = lin; + lin->label = label; + lineage_reset_segments(lin); out: return lin; } static segment_t *MSP_WARN_UNUSED -msp_copy_segment(msp_t *self, const segment_t *seg) +msp_copy_segment(msp_t *self, label_id_t label, const segment_t *seg) { - return msp_alloc_segment(self, seg->left, seg->right, seg->value, seg->population, - 
seg->label, seg->prev, seg->next); + segment_t *new_seg = msp_alloc_segment(self, seg->left, seg->right, seg->value, + seg->population, label, seg->prev, seg->next); + // FIXME check for NULL return value + new_seg->lineage = seg->lineage; + return new_seg; } static hull_t *MSP_WARN_UNUSED @@ -916,7 +929,7 @@ msp_alloc_hull(msp_t *self, double left, double right, lineage_t *lineage) uint32_t j; tsk_bug_assert(lineage != NULL); - label = lineage->head->label; + label = lineage->label; if (object_heap_empty(&self->hull_heap[label])) { if (object_heap_expand(&self->hull_heap[label]) != 0) { @@ -1271,12 +1284,13 @@ msp_get_segment(msp_t *self, size_t id, label_id_t label) static void msp_free_segment(msp_t *self, segment_t *seg) { - object_heap_free_object(&self->segment_heap[seg->label], seg); + label_id_t label = seg->lineage->label; + object_heap_free_object(&self->segment_heap[label], seg); if (self->recomb_mass_index != NULL) { - fenwick_set_value(&self->recomb_mass_index[seg->label], seg->id, 0); + fenwick_set_value(&self->recomb_mass_index[label], seg->id, 0); } if (self->gc_mass_index != NULL) { - fenwick_set_value(&self->gc_mass_index[seg->label], seg->id, 0); + fenwick_set_value(&self->gc_mass_index[label], seg->id, 0); } } @@ -1313,7 +1327,7 @@ hullend_adjust_insertion_order(hullend_t *h, avl_node_t *node) static inline avl_tree_t * msp_get_segment_population(msp_t *self, segment_t *u) { - return &self->populations[u->population].ancestors[u->label]; + return &self->populations[u->population].ancestors[u->lineage->label]; } static int MSP_WARN_UNUSED @@ -1327,14 +1341,16 @@ msp_insert_hull(msp_t *self, hull_t *hull) hullend_t query; hullend_t *hullend; fenwick_t *coal_mass_index; + label_id_t label; uint64_t num_starting_before_left, num_ending_before_left, count; /* setting hull->count requires two steps step 1: num_starting before hull->left */ tsk_bug_assert(hull != NULL); u = hull->lineage->head; - hulls_left = 
&self->populations[u->population].hulls_left[u->label]; - coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; + label = hull->lineage->label; + hulls_left = &self->populations[u->population].hulls_left[label]; + coal_mass_index = &self->populations[u->population].coal_mass_index[label]; /* insert hull into state */ node = msp_alloc_avl_node(self); if (node == NULL) { @@ -1363,7 +1379,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) } /* step 2: num ending before hull->left */ - hulls_right = &self->populations[u->population].hulls_right[u->label]; + hulls_right = &self->populations[u->population].hulls_right[label]; query.position = hull->left; query.insertion_order = UINT64_MAX; if (hulls_right->head == NULL) { @@ -1379,7 +1395,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) fenwick_set_value(coal_mass_index, hull->id, (double) count); /* insert hullend into state */ - hullend = msp_alloc_hullend(self, hull->right, u->label); + hullend = msp_alloc_hullend(self, hull->right, label); if (hullend == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -1406,11 +1422,13 @@ msp_remove_hull(msp_t *self, hull_t *hull) avl_tree_t *hulls_left, *hulls_right; fenwick_t *coal_mass_index; segment_t *u; + label_id_t label; u = hull->lineage->head; + label = hull->lineage->label; tsk_bug_assert(u != NULL); - hulls_left = &self->populations[u->population].hulls_left[u->label]; - coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; + hulls_left = &self->populations[u->population].hulls_left[label]; + coal_mass_index = &self->populations[u->population].coal_mass_index[label]; node = avl_search(hulls_left, hull); tsk_bug_assert(node != NULL); @@ -1439,7 +1457,7 @@ msp_remove_hull(msp_t *self, hull_t *hull) msp_free_avl_node(self, node); /* remove node from hulls_right */ - hulls_right = &self->populations[u->population].hulls_right[u->label]; + hulls_right = &self->populations[u->population].hulls_right[label]; query.position = 
hull->right; query.insertion_order = UINT64_MAX; c = avl_search_closest(hulls_right, &query, &query_node); @@ -1452,7 +1470,7 @@ msp_remove_hull(msp_t *self, hull_t *hull) node = query_node; avl_unlink_node(hulls_right, node); msp_free_avl_node(self, node); - msp_free_hullend(self, query_ptr, u->label); + msp_free_hullend(self, query_ptr, label); } static inline int MSP_WARN_UNUSED @@ -1539,7 +1557,7 @@ msp_print_segment_chain(msp_t *MSP_UNUSED(self), segment_t *head, FILE *out) tsk_bug_assert(lin != NULL); - fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, s->population, s->label); + fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, s->population, lin->label); while (s != NULL) { fprintf(out, "[(%.14g,%.14g) %d] ", s->left, s->right, (int) s->value); s = s->next; @@ -1635,12 +1653,13 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) while (node != NULL) { lin = (lineage_t *) node->item; u = lin->head; + tsk_bug_assert(lin->label == (label_id_t) k); tsk_bug_assert(u->lineage == lin); tsk_bug_assert(u->prev == NULL); while (u != NULL) { label_segments++; + tsk_bug_assert(u->lineage == lin); tsk_bug_assert(u->population == (population_id_t) j); - tsk_bug_assert(u->label == (label_id_t) k); tsk_bug_assert(u->left < u->right); tsk_bug_assert(u->right <= self->sequence_length); if (u->prev != NULL) { @@ -1727,7 +1746,6 @@ overlap_counter_alloc(overlap_counter_t *self, double seq_length, int initial_co overlaps->right = seq_length; overlaps->value = initial_count; overlaps->population = 0; - overlaps->label = 0; self->seq_length = seq_length; self->overlaps = overlaps; @@ -1780,7 +1798,6 @@ overlap_counter_split_segment(segment_t *seg, double breakpoint) right_seg->right = seg->right; right_seg->value = seg->value; right_seg->population = 0; - right_seg->label = 0; if (seg->next != NULL) { right_seg->next = seg->next; @@ -2597,7 +2614,7 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, goto out; } } - if (ind->head->label == 
dest_label) { + if (ind->label == dest_label) { /* Need to set the population and label for each segment. */ new_hull = hull; for (x = ind->head; x != NULL; x = x->next) { @@ -2632,16 +2649,17 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } if (self->recomb_mass_index != NULL) { recomb_mass - = fenwick_get_value(&self->recomb_mass_index[x->label], x->id); + = fenwick_get_value(&self->recomb_mass_index[ind->label], x->id); fenwick_set_value( - &self->recomb_mass_index[y->label], y->id, recomb_mass); + &self->recomb_mass_index[dest_label], y->id, recomb_mass); } if (self->gc_mass_index != NULL) { - gc_mass = fenwick_get_value(&self->gc_mass_index[x->label], x->id); - fenwick_set_value(&self->gc_mass_index[y->label], y->id, gc_mass); + gc_mass = fenwick_get_value(&self->gc_mass_index[ind->label], x->id); + fenwick_set_value(&self->gc_mass_index[dest_label], y->id, gc_mass); } msp_free_segment(self, x); } + ind->label = dest_label; } if (new_hull != NULL) { new_hull->lineage = ind; @@ -2650,6 +2668,7 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, goto out; } } + lineage_reset_segments(ind); ret = msp_insert_individual(self, ind); out: return ret; @@ -2976,6 +2995,7 @@ msp_dtwf_recombine( segment_t s1, s2; segment_t *seg_tails[] = { &s1, &s2 }; segment_t **rec_heads[MSP_MAX_PED_PLOIDY] = { u, v }; + const label_id_t label = 0; x = x_head; k = msp_dtwf_generate_breakpoint(self, x->left); @@ -3001,11 +3021,12 @@ msp_dtwf_recombine( tail = seg_tails[ix]; } z = msp_alloc_segment( - self, k, x->right, x->value, x->population, x->label, tail, x->next); + self, k, x->right, x->value, x->population, label, tail, x->next); if (z == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } + z->lineage = x->lineage; msp_set_segment_mass(self, z); tsk_bug_assert(z->left < z->right); if (x->next != NULL) { @@ -3049,8 +3070,11 @@ msp_dtwf_recombine( for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { y = *rec_heads[j]; + if (y == x_head) { + 
lineage_reset_segments(y->lineage); + } if (y != x_head && y != NULL) { - lin = msp_alloc_lineage(self, y); + lin = msp_alloc_lineage(self, y, label); if (lin == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3262,7 +3286,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ { int ret = 0; double breakpoint; - lineage_t *right_lineage; + lineage_t *left_lineage, *right_lineage; segment_t *x, *y, *alpha, *lhs_tail; hull_t *lhs_hull, *rhs_hull; double lhs_right, rhs_right; @@ -3276,15 +3300,17 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ goto out; } x = y->prev; + left_lineage = y->lineage; if (y->left < breakpoint) { tsk_bug_assert(breakpoint < y->right); - alpha = msp_alloc_segment(self, breakpoint, y->right, y->value, y->population, - y->label, NULL, y->next); + alpha = msp_alloc_segment( + self, breakpoint, y->right, y->value, y->population, label, NULL, y->next); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } + alpha->lineage = left_lineage; if (y->next != NULL) { y->next->prev = alpha; } @@ -3311,7 +3337,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ } tsk_bug_assert(alpha->left < alpha->right); msp_set_segment_mass(self, alpha); - right_lineage = msp_alloc_lineage(self, alpha); + right_lineage = msp_alloc_lineage(self, alpha, label); if (right_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3461,7 +3487,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) // ===== ==== α // ====== /* alpha = self->copy_segment(y) */ - alpha = msp_copy_segment(self, y); + alpha = msp_copy_segment(self, label, y); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3509,7 +3535,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) // ===== =========== // ... 
=== // z - head = msp_copy_segment(self, z); + head = msp_copy_segment(self, label, z); if (head == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3557,7 +3583,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) = GSL_MIN(reset_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); msp_reset_hull_right( - self, hull, hull->right, reset_right, y->population, y->label); + self, hull, hull->right, reset_right, y->population, label); } } @@ -3573,7 +3599,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) new_individual_head = head; } if (new_individual_head != NULL) { - new_lineage = msp_alloc_lineage(self, new_individual_head); + new_lineage = msp_alloc_lineage(self, new_individual_head, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3654,8 +3680,8 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l double l, r, l_min, r_max; avl_node_t *node; node_mapping_t *nm, search; - lineage_t *new_lineage; segment_t *x, *y, *z, *alpha, *beta, *merged_head; + lineage_t *new_lineage = NULL; hull_t *hull = NULL; x = a; @@ -3689,8 +3715,8 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l x = x->next; alpha->next = NULL; } else if (x->left != y->left) { - alpha = msp_alloc_segment(self, x->left, y->left, x->value, - x->population, x->label, NULL, NULL); + alpha = msp_alloc_segment( + self, x->left, y->left, x->value, x->population, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3783,15 +3809,11 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } if (alpha != NULL) { if (z == NULL) { - new_lineage = msp_alloc_lineage(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } - ret = msp_insert_individual(self, new_lineage); - if (ret != 0) { - goto out; - } merged_head = alpha; } else { if 
((self->additional_nodes & MSP_NODE_IS_CA_EVENT) @@ -3806,6 +3828,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l z->next = alpha; } alpha->prev = z; + alpha->lineage = new_lineage; msp_set_segment_mass(self, alpha); z = alpha; } @@ -3837,6 +3860,16 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l goto out; } } + + if (new_lineage != NULL) { + // TODO this could be done more efficiently by exhausing the + // x and y chains above + lineage_reset_segments(new_lineage); + ret = msp_insert_individual(self, new_lineage); + if (ret != 0) { + goto out; + } + } if (ret_merged_head != NULL) { *ret_merged_head = merged_head; } @@ -3949,8 +3982,8 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, if (h == 1) { x = H[0]; if (node != NULL && next_l < x->right) { - alpha = msp_alloc_segment(self, x->left, next_l, x->value, x->population, - x->label, NULL, NULL); + alpha = msp_alloc_segment( + self, x->left, next_l, x->value, x->population, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4050,15 +4083,11 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, if (alpha != NULL) { if (z == NULL) { merged_head = alpha; - new_lineage = msp_alloc_lineage(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } - ret = msp_insert_individual(self, new_lineage); - if (ret != 0) { - goto out; - } } else { if ((self->additional_nodes & MSP_NODE_IS_CA_EVENT) || (!self->coalescing_segments_only && coalescence)) { @@ -4071,10 +4100,21 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, z->next = alpha; } alpha->prev = z; + alpha->lineage = new_lineage; msp_set_segment_mass(self, alpha); z = alpha; } } + if (new_lineage != NULL) { + ret = msp_insert_individual(self, new_lineage); + if (ret != 0) { + goto out; + } + /* FIXME see 
note above about avoiding this by exausting + * the original chains */ + lineage_reset_segments(new_lineage); + } + if (coalescence) { if (!self->coalescing_segments_only) { ret = msp_store_arg_edges(self, z, new_node_id); @@ -4101,6 +4141,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, goto out; } } + if (ret_merged_head != NULL) { *ret_merged_head = merged_head; } @@ -4264,6 +4305,7 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea const segment_t *seg; double breakpoints[2]; int j; + const label_id_t label = 0; hull_t *hull = NULL; prev = NULL; @@ -4281,7 +4323,7 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea } } /* Copy the segment and insert into the global state */ - copy = msp_copy_segment(self, seg); + copy = msp_copy_segment(self, label, seg); if (copy == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4291,7 +4333,7 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea } copy->prev = prev; if (prev == NULL) { - lineage = msp_alloc_lineage(self, copy); + lineage = msp_alloc_lineage(self, copy, label); if (lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5175,7 +5217,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) // ===== ===== // ===== // α - alpha = msp_copy_segment(self, y); + alpha = msp_copy_segment(self, label, y); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5211,7 +5253,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } lhs_new_right = y->right; - lineage = msp_alloc_lineage(self, alpha); + lineage = msp_alloc_lineage(self, alpha, label); if (lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5234,7 +5276,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) = GSL_MIN(lhs_new_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); msp_reset_hull_right( - self, lhs_hull, lhs_old_right, lhs_new_right, 
y->population, y->label); + self, lhs_hull, lhs_old_right, lhs_new_right, y->population, label); // rhs tsk_bug_assert(alpha->left < lhs_old_right); @@ -6001,7 +6043,7 @@ static int msp_change_label(msp_t *self, segment_t *ind, label_id_t label) { int ret = 0; - avl_tree_t *pop = &self->populations[ind->population].ancestors[ind->label]; + avl_tree_t *pop = &self->populations[ind->population].ancestors[ind->lineage->label]; avl_node_t *node; /* Find the this individual in the AVL tree. */ @@ -6024,6 +6066,7 @@ msp_sweep_recombination_event( if (ret != 0) { goto out; } + tsk_bug_assert(lhs->lineage != NULL); tsk_bug_assert(rhs->lineage != NULL); @@ -6106,7 +6149,6 @@ msp_run_sweep(msp_t *self) if (ret != 0) { goto out; } - msp_verify(self, 0); ret = msp_sweep_initialise(self, allele_frequency[0]); if (ret != 0) { goto out; @@ -6115,7 +6157,6 @@ msp_run_sweep(msp_t *self) curr_step = 1; while (msp_get_num_ancestors(self) > 0 && curr_step < num_steps) { events++; - msp_verify(self, 0); /* Set pop sizes & rec_rates */ for (j = 0; j < self->num_labels; j++) { label = (label_id_t) j; @@ -7523,6 +7564,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) } } } + out: msp_safe_free(lineages); msp_safe_free(pi); diff --git a/lib/msprime.h b/lib/msprime.h index 8abc0abf6..18043a863 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -77,7 +77,6 @@ typedef tsk_id_t label_id_t; typedef struct segment_t_t { population_id_t population; - label_id_t label; double left; double right; tsk_id_t value; @@ -89,6 +88,7 @@ typedef struct segment_t_t { typedef struct lineage_t_t { segment_t *head; + label_id_t label; struct hull_t_t *hull; } lineage_t; From be2644d419b9ce60ea2a9c3bdcbff62e34369499 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 30 Jul 2024 16:59:08 +0100 Subject: [PATCH 3/3] Move population attr to lineage Finish moving population to Lineage in Python Get population on lineage working for C FIXUP python algorithms code fixup Pyhton c 
interface Tidy Python --- algorithms.py | 120 +++++++++++++++++---------------- lib/msprime.c | 141 +++++++++++++++++++++------------------ lib/msprime.h | 9 +-- msprime/_msprimemodule.c | 4 +- 4 files changed, 145 insertions(+), 129 deletions(-) diff --git a/algorithms.py b/algorithms.py index 716f882e3..ba7abc477 100644 --- a/algorithms.py +++ b/algorithms.py @@ -130,8 +130,6 @@ class Segment: prev: Segment = None next: Segment = None # noqa: A003 lineage: Lineage = None - population: int = -1 - # label: int = 0 def __repr__(self): return repr((self.left, self.right, self.node)) @@ -145,10 +143,11 @@ def show_chain(seg): return s[:-2] def __lt__(self, other): - return (self.left, self.right, self.population, self.node) < ( + # TODO not clear here why we need population in the key? + return (self.left, self.right, self.lineage.population, self.node) < ( other.left, other.right, - other.population, + other.lineage.population, self.node, ) @@ -730,12 +729,14 @@ def __repr__(self): @dataclasses.dataclass class Lineage: head: Segment + population: int hull: Hull = None label: int = 0 def __str__(self): s = ( - f"Lineage(id={hex(id(self))},label={self.label},hull={self.hull}," + f"Lineage(id={hex(id(self))}," + f"population={self.population},label={self.label},hull={self.hull}," f"head={self.head.index}," f"chain={Segment.show_chain(self.head)})" ) @@ -969,9 +970,12 @@ def __init__( def initialise(self, ts): root_time = np.max(self.tables.nodes.time) self.t = root_time - + # Note: this is done slightly differently to the C code, which + # stores the root segments so that we can implement sampling + # events easily. 
root_segments_head = [None for _ in range(ts.num_nodes)] root_segments_tail = [None for _ in range(ts.num_nodes)] + root_lineages = [None for _ in range(ts.num_nodes)] last_S = -1 for tree in ts.trees(): left, right = tree.interval @@ -985,7 +989,9 @@ def initialise(self, ts): for root in tree.roots: population = ts.node(root).population if root_segments_head[root] is None: - seg = self.alloc_segment(left, right, root, population) + seg = self.alloc_segment(left, right, root) + lineage = self.alloc_lineage(seg, population) + root_lineages[root] = lineage root_segments_head[root] = seg root_segments_tail[root] = seg else: @@ -996,29 +1002,29 @@ def initialise(self, ts): seg = self.alloc_segment( left, right, root, population, tail ) + seg.lineage = root_lineages[root] tail.next = seg root_segments_tail[root] = seg self.S[self.L] = -1 # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): - seg = root_segments_head[node] - if seg is not None: + lineage = root_lineages[node] + if lineage is not None: + seg = lineage.head left_end = seg.left - pop = seg.population - lineage = self.alloc_lineage(seg) - self.P[seg.population].add(lineage) while seg is not None: self.set_segment_mass(seg) seg = seg.next + self.add_lineage(lineage) if self.model == "smc_k": for node in range(ts.num_nodes): - seg = root_segments_head[node] - if seg is not None: + lineage = root_lineages[node] + if lineage is not None: + seg = lineage.head left_end = seg.left - pop = seg.population - lineage = seg.lineage + pop = lineage.population label = lineage.label right_end = root_segments_tail[node].right new_hull = self.alloc_hull(left_end, right_end, lineage) @@ -1085,7 +1091,7 @@ def alloc_segment( left, right, node, - population, + population=None, prev=None, next=None, # noqa: A002 lineage=None, @@ -1097,14 +1103,13 @@ def alloc_segment( s.left = left s.right = right s.node = node - s.population = population s.next = next s.prev = prev s.lineage = lineage return s 
- def alloc_lineage(self, head, *, label=0): - lineage = Lineage(head, label=label) + def alloc_lineage(self, head, population, *, label=0): + lineage = Lineage(head, population=population, label=label) lineage.reset_segments() x = head while x is not None: @@ -1117,7 +1122,6 @@ def copy_segment(self, segment): left=segment.left, right=segment.right, node=segment.node, - population=segment.population, next=segment.next, prev=segment.prev, lineage=segment.lineage, @@ -1189,7 +1193,7 @@ def store_edge(self, left, right, parent, child): ) def add_lineage(self, lineage): - pop = lineage.head.population + pop = lineage.population self.P[pop].add(lineage, lineage.label) # print("add", lineage) x = lineage.head @@ -1615,7 +1619,7 @@ def process_pedigree_common_ancestors(self, ind, ploid): # ancestor in this ploid of this individual. First we remove # them from the populations they are stored in: for _, seg in common_ancestors: - pop = self.P[seg.population] + pop = self.P[seg.lineage.population] pop.remove_individual(seg.lineage) # Merge together these lists of ancestral segments to create the @@ -1723,11 +1727,7 @@ def migration_event(self, j, k): if self.additional_nodes.value & msprime.NODE_IS_MIG_EVENT > 0: self.store_node(k, flags=msprime.NODE_IS_MIG_EVENT) self.store_arg_edges(x) - # Set the population id for each segment also. 
- u = x - while u is not None: - u.population = k - u = u.next + lineage.population = k def get_recomb_left_bound(self, seg): """ @@ -1794,6 +1794,8 @@ def hudson_recombination_event(self, label, return_heads=False): """ self.num_re_events += 1 y, bp = self.choose_breakpoint(self.recomb_mass_index[label], self.recomb_map) + left_lineage = y.lineage + assert left_lineage.label == label x = y.prev if y.left < bp: # x y @@ -1825,10 +1827,10 @@ def hudson_recombination_event(self, label, return_heads=False): alpha = y lhs_tail = x - right_lineage = self.alloc_lineage(alpha, label=label) + right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label) if self.model == "smc_k": # modify original hull - pop = alpha.population + pop = left_lineage.population lhs_hull = lhs_tail.get_hull() rhs_right = lhs_hull.right lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L) @@ -1836,14 +1838,14 @@ def hudson_recombination_event(self, label, return_heads=False): # create hull for alpha alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage) - self.P[alpha.population].add_hull(label, alpha_hull) + self.P[pop].add_hull(label, alpha_hull) self.set_segment_mass(alpha) - self.P[alpha.population].add(right_lineage, label) + self.add_lineage(right_lineage) if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: - self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT) + self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) - self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT) + self.store_node(right_lineage.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(alpha) ret = None if return_heads: @@ -1888,7 +1890,8 @@ def wiuf_gene_conversion_within_event(self, label): self.num_gc_events += 1 hull = y.get_hull() assert (self.model == "smc_k") == (hull is not None) - pop = y.population + lineage = y.lineage + pop = lineage.population reset_right = -1 # 
Process left break @@ -1998,12 +2001,12 @@ def wiuf_gene_conversion_within_event(self, label): elif head is not None: new_individual_head = head if new_individual_head is not None: - lineage = self.alloc_lineage(new_individual_head) + lineage = self.alloc_lineage(new_individual_head, pop) if self.model == "smc_k": assert hull_left < hull_right hull_right = min(self.L, hull_right + self.hull_offset) hull = self.alloc_hull(hull_left, hull_right, lineage) - self.P[new_individual_head.population].add_hull(lineage.label, hull) + self.P[lineage.population].add_hull(lineage.label, hull) self.add_lineage(lineage) def wiuf_gene_conversion_left_event(self, label): @@ -2033,7 +2036,8 @@ def wiuf_gene_conversion_left_event(self, label): self.num_gc_events += 1 x = y.prev - pop = y.population + lineage = y.lineage + pop = lineage.population lhs_hull = y.get_hull() assert (self.model == "smc_k") == (lhs_hull is not None) if y.left < bp: @@ -2081,30 +2085,30 @@ def wiuf_gene_conversion_left_event(self, label): self.set_segment_mass(alpha) assert alpha.prev is None - lineage = self.alloc_lineage(alpha) - self.P[alpha.population].add(lineage, label) + lineage = self.alloc_lineage(alpha, pop) + self.add_lineage(lineage) def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): """ Implements a recombination event in during a selective sweep. 
""" left_lin, right_lin = self.hudson_recombination_event(label, return_heads=True) - lhs = left_lin.head - rhs = right_lin.head r = random.random() - if sweep_site < rhs.left: + if sweep_site < right_lin.head.left: if r < 1.0 - pop_freq: # move rhs to other population - self.P[rhs.population].remove_individual(right_lin, right_lin.label) + self.P[right_lin.population].remove_individual( + right_lin, right_lin.label + ) self.set_labels(right_lin, 1 - label) - self.P[rhs.population].add(right_lin, right_lin.label) + self.P[right_lin.population].add(right_lin, right_lin.label) else: if r < 1.0 - pop_freq: # move lhs to other population - self.P[rhs.population].remove_individual(left_lin, left_lin.label) + self.P[left_lin.population].remove_individual(left_lin, left_lin.label) self.set_labels(left_lin, 1 - label) - self.P[lhs.population].add(left_lin, left_lin.label) + self.P[left_lin.population].add(left_lin, left_lin.label) def dtwf_generate_breakpoint(self, start): left_bound = start + 1 if self.discrete_genome else start @@ -2213,7 +2217,7 @@ def dtwf_recombine(self, lineage, ind_nodes): lineage.reset_segments() ret.append(lineage) else: - ret.append(self.alloc_lineage(seg)) + ret.append(self.alloc_lineage(seg, lineage.population)) return ret @@ -2245,9 +2249,8 @@ def bottleneck_event(self, pop_id, label, intensity): def store_additional_nodes_edges(self, flag, new_node_id, z): if self.additional_nodes.value & flag > 0: if new_node_id == -1: - new_node_id = self.store_node(z.population, flags=flag) - else: - self.update_node_flag(new_node_id, flag) + new_node_id = self.store_node(z.lineage.population) + self.update_node_flag(new_node_id, flag) self.store_arg_edges(z, new_node_id) return new_node_id @@ -2273,7 +2276,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): if len(X) == 1: x = X[0] if len(H) > 0 and H[0][0] < x.right: - alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population) + alpha = self.alloc_segment(x.left, H[0][0], x.node) 
x.left = H[0][0] heapq.heappush(H, (x.left, x)) else: @@ -2320,7 +2323,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - new_lineage = self.alloc_lineage(alpha) + new_lineage = self.alloc_lineage(alpha, pop_id) pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( @@ -2517,7 +2520,9 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - new_lineage = self.alloc_lineage(alpha, label=label) + new_lineage = self.alloc_lineage( + alpha, population_index, label=label + ) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2629,23 +2634,22 @@ def print_state(self, verify=False): self.verify() def verify_segments(self): - for pop in self.P: + for pop_index, pop in enumerate(self.P): for label in range(self.num_labels): for lineage in pop.iter_label(label): assert isinstance(lineage, Lineage) assert lineage.label == label + assert lineage.population == pop_index head = lineage.head assert head.lineage is lineage assert head.prev is None prev = head u = head.next - # print("LIN", lineage) while u is not None: assert u.lineage == lineage assert prev.next is u assert u.prev is prev assert u.left >= prev.right - assert u.population == head.population prev = u u = u.next @@ -2702,10 +2706,10 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): for pop_index, pop in enumerate(self.P): for lineage in pop.iter_label(label): u = lineage.head + assert lineage.population == pop_index assert u.prev is None left = compute_left_bound(u) while u is not None: - assert u.population == pop_index assert u.left < u.right left_bound = compute_left_bound(u) s = rate_map.mass_between(left_bound, u.right) diff --git a/lib/msprime.c 
b/lib/msprime.c index 50ce93580..bdfe5af1f 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -838,7 +838,8 @@ msp_set_hull_block_size(msp_t *self, size_t block_size) static segment_t *MSP_WARN_UNUSED msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, - population_id_t population, label_id_t label, segment_t *prev, segment_t *next) + population_id_t TSK_UNUSED(population), label_id_t label, segment_t *prev, + segment_t *next) { segment_t *seg = NULL; @@ -875,7 +876,6 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, seg->left = left; seg->right = right; seg->value = value; - seg->population = population; out: return seg; } @@ -891,7 +891,8 @@ lineage_reset_segments(lineage_t *self) } static lineage_t *MSP_WARN_UNUSED -msp_alloc_lineage(msp_t *self, segment_t *head, label_id_t label) +msp_alloc_lineage( + msp_t *self, segment_t *head, population_id_t population, label_id_t label) { lineage_t *lin = NULL; @@ -905,6 +906,7 @@ msp_alloc_lineage(msp_t *self, segment_t *head, label_id_t label) goto out; } lin->head = head; + lin->population = population; lin->label = label; lineage_reset_segments(lin); out: @@ -914,8 +916,8 @@ msp_alloc_lineage(msp_t *self, segment_t *head, label_id_t label) static segment_t *MSP_WARN_UNUSED msp_copy_segment(msp_t *self, label_id_t label, const segment_t *seg) { - segment_t *new_seg = msp_alloc_segment(self, seg->left, seg->right, seg->value, - seg->population, label, seg->prev, seg->next); + segment_t *new_seg = msp_alloc_segment( + self, seg->left, seg->right, seg->value, -1, label, seg->prev, seg->next); // FIXME check for NULL return value new_seg->lineage = seg->lineage; return new_seg; @@ -1327,7 +1329,7 @@ hullend_adjust_insertion_order(hullend_t *h, avl_node_t *node) static inline avl_tree_t * msp_get_segment_population(msp_t *self, segment_t *u) { - return &self->populations[u->population].ancestors[u->lineage->label]; + return 
&self->populations[u->lineage->population].ancestors[u->lineage->label]; } static int MSP_WARN_UNUSED @@ -1336,7 +1338,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) int c, ret = 0; avl_node_t *node, *query_node; avl_tree_t *hulls_left, *hulls_right; - segment_t *u; + population_id_t pop; hull_t *curr_hull; hullend_t query; hullend_t *hullend; @@ -1347,10 +1349,10 @@ msp_insert_hull(msp_t *self, hull_t *hull) /* setting hull->count requires two steps step 1: num_starting before hull->left */ tsk_bug_assert(hull != NULL); - u = hull->lineage->head; + pop = hull->lineage->population; label = hull->lineage->label; - hulls_left = &self->populations[u->population].hulls_left[label]; - coal_mass_index = &self->populations[u->population].coal_mass_index[label]; + hulls_left = &self->populations[pop].hulls_left[label]; + coal_mass_index = &self->populations[pop].coal_mass_index[label]; /* insert hull into state */ node = msp_alloc_avl_node(self); if (node == NULL) { @@ -1379,7 +1381,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) } /* step 2: num ending before hull->left */ - hulls_right = &self->populations[u->population].hulls_right[label]; + hulls_right = &self->populations[pop].hulls_right[label]; query.position = hull->left; query.insertion_order = UINT64_MAX; if (hulls_right->head == NULL) { @@ -1387,7 +1389,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) } else { c = avl_search_closest(hulls_right, &query, &query_node); /* query < node->item ==> c = -1 */ - num_ending_before_left = (uint64_t) avl_index(query_node) + (uint64_t)(c != -1); + num_ending_before_left = (uint64_t) avl_index(query_node) + (uint64_t) (c != -1); } /* set number of pairs coalescing with hull */ count = num_starting_before_left - num_ending_before_left; @@ -1423,12 +1425,14 @@ msp_remove_hull(msp_t *self, hull_t *hull) fenwick_t *coal_mass_index; segment_t *u; label_id_t label; + population_id_t pop; u = hull->lineage->head; label = hull->lineage->label; + pop = hull->lineage->population; 
tsk_bug_assert(u != NULL); - hulls_left = &self->populations[u->population].hulls_left[label]; - coal_mass_index = &self->populations[u->population].coal_mass_index[label]; + hulls_left = &self->populations[pop].hulls_left[label]; + coal_mass_index = &self->populations[pop].coal_mass_index[label]; node = avl_search(hulls_left, hull); tsk_bug_assert(node != NULL); @@ -1457,7 +1461,7 @@ msp_remove_hull(msp_t *self, hull_t *hull) msp_free_avl_node(self, node); /* remove node from hulls_right */ - hulls_right = &self->populations[u->population].hulls_right[label]; + hulls_right = &self->populations[pop].hulls_right[label]; query.position = hull->right; query.insertion_order = UINT64_MAX; c = avl_search_closest(hulls_right, &query, &query_node); @@ -1557,7 +1561,7 @@ msp_print_segment_chain(msp_t *MSP_UNUSED(self), segment_t *head, FILE *out) tsk_bug_assert(lin != NULL); - fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, s->population, lin->label); + fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, lin->population, lin->label); while (s != NULL) { fprintf(out, "[(%.14g,%.14g) %d] ", s->left, s->right, (int) s->value); s = s->next; @@ -1654,12 +1658,12 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) lin = (lineage_t *) node->item; u = lin->head; tsk_bug_assert(lin->label == (label_id_t) k); + tsk_bug_assert(lin->population == (population_id_t) j); tsk_bug_assert(u->lineage == lin); tsk_bug_assert(u->prev == NULL); while (u != NULL) { label_segments++; tsk_bug_assert(u->lineage == lin); - tsk_bug_assert(u->population == (population_id_t) j); tsk_bug_assert(u->left < u->right); tsk_bug_assert(u->right <= self->sequence_length); if (u->prev != NULL) { @@ -1745,11 +1749,9 @@ overlap_counter_alloc(overlap_counter_t *self, double seq_length, int initial_co overlaps->left = 0; overlaps->right = seq_length; overlaps->value = initial_count; - overlaps->population = 0; self->seq_length = seq_length; self->overlaps = overlaps; - out: return ret; } @@ -1797,7 
+1799,6 @@ overlap_counter_split_segment(segment_t *seg, double breakpoint) right_seg->left = breakpoint; right_seg->right = seg->right; right_seg->value = seg->value; - right_seg->population = 0; if (seg->next != NULL) { right_seg->next = seg->next; @@ -1889,7 +1890,7 @@ msp_verify_non_empty_populations(msp_t *self) for (avl_node = self->non_empty_populations.head; avl_node != NULL; avl_node = avl_node->next) { - j = (tsk_id_t)(intptr_t) avl_node->item; + j = (tsk_id_t) (intptr_t) avl_node->item; tsk_bug_assert(msp_get_num_population_ancestors(self, j) > 0); } @@ -2274,7 +2275,7 @@ msp_print_state(msp_t *self, FILE *out) } fprintf(out, "non_empty_populations = ["); for (a = self->non_empty_populations.head; a != NULL; a = a->next) { - j = (uint32_t)(intptr_t) a->item; + j = (uint32_t) (intptr_t) a->item; fprintf(out, "%d,", j); } fprintf(out, "]\n"); @@ -2615,18 +2616,17 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } } if (ind->label == dest_label) { - /* Need to set the population and label for each segment. */ new_hull = hull; - for (x = ind->head; x != NULL; x = x->next) { - if (self->store_migrations) { + if (self->store_migrations) { + for (x = ind->head; x != NULL; x = x->next) { ret = msp_record_migration( - self, x->left, x->right, x->value, x->population, dest_pop); + self, x->left, x->right, x->value, ind->population, dest_pop); if (ret != 0) { goto out; } } - x->population = dest_pop; } + ind->population = dest_pop; } else { /* Because we are changing to a different Fenwick tree we must allocate * new segments each time. 
*/ @@ -2640,7 +2640,11 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, //} for (x = ind->head; x != NULL; x = x->next) { y = msp_alloc_segment( - self, x->left, x->right, x->value, x->population, dest_label, y, NULL); + self, x->left, x->right, x->value, -1, dest_label, y, NULL); + if (y == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } if (x->prev == NULL) { ind->head = y; y->lineage = ind; @@ -2996,6 +3000,7 @@ msp_dtwf_recombine( segment_t *seg_tails[] = { &s1, &s2 }; segment_t **rec_heads[MSP_MAX_PED_PLOIDY] = { u, v }; const label_id_t label = 0; + const population_id_t population = x_head->lineage->population; x = x_head; k = msp_dtwf_generate_breakpoint(self, x->left); @@ -3020,8 +3025,7 @@ msp_dtwf_recombine( } else { tail = seg_tails[ix]; } - z = msp_alloc_segment( - self, k, x->right, x->value, x->population, label, tail, x->next); + z = msp_alloc_segment(self, k, x->right, x->value, -1, label, tail, x->next); if (z == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3074,7 +3078,7 @@ msp_dtwf_recombine( lineage_reset_segments(y->lineage); } if (y != x_head && y != NULL) { - lin = msp_alloc_lineage(self, y, label); + lin = msp_alloc_lineage(self, y, population, label); if (lin == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3086,8 +3090,7 @@ msp_dtwf_recombine( for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { ret = msp_store_additional_nodes_edges(self, *rec_heads[j], ind_nodes[j], - MSP_NODE_IS_RE_EVENT, (*rec_heads[j])->population, TSK_NULL, - &ind_nodes[j]); + MSP_NODE_IS_RE_EVENT, population, TSK_NULL, &ind_nodes[j]); if (ret < 0) { goto out; } @@ -3104,7 +3107,7 @@ msp_store_arg_recombination(msp_t *self, segment_t *lhs_tail, segment_t *rhs) /* Store the edges for the LHS */ ret = msp_store_node( - self, MSP_NODE_IS_RE_EVENT, self->time, lhs_tail->population, TSK_NULL); + self, MSP_NODE_IS_RE_EVENT, self->time, lhs_tail->lineage->population, TSK_NULL); if (ret < 0) { goto out; } @@ -3114,7 +3117,7 @@ 
msp_store_arg_recombination(msp_t *self, segment_t *lhs_tail, segment_t *rhs) } /* Store the edges for the RHS */ ret = msp_store_node( - self, MSP_NODE_IS_RE_EVENT, self->time, rhs->population, TSK_NULL); + self, MSP_NODE_IS_RE_EVENT, self->time, rhs->lineage->population, TSK_NULL); if (ret < 0) { goto out; } @@ -3137,8 +3140,8 @@ msp_store_arg_gene_conversion( if (tail != NULL || head != NULL) { tsk_bug_assert(alpha != NULL); /* Store the edges for tail & head */ - ret = msp_store_node( - self, MSP_NODE_IS_GC_EVENT, self->time, alpha->population, TSK_NULL); + ret = msp_store_node(self, MSP_NODE_IS_GC_EVENT, self->time, + alpha->lineage->population, TSK_NULL); if (ret < 0) { goto out; } @@ -3151,8 +3154,8 @@ msp_store_arg_gene_conversion( goto out; } /* Store the edges for the alpha section */ - ret = msp_store_node( - self, MSP_NODE_IS_GC_EVENT, self->time, alpha->population, TSK_NULL); + ret = msp_store_node(self, MSP_NODE_IS_GC_EVENT, self->time, + alpha->lineage->population, TSK_NULL); if (ret < 0) { goto out; } @@ -3305,7 +3308,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ if (y->left < breakpoint) { tsk_bug_assert(breakpoint < y->right); alpha = msp_alloc_segment( - self, breakpoint, y->right, y->value, y->population, label, NULL, y->next); + self, breakpoint, y->right, y->value, -1, label, NULL, y->next); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3337,7 +3340,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ } tsk_bug_assert(alpha->left < alpha->right); msp_set_segment_mass(self, alpha); - right_lineage = msp_alloc_lineage(self, alpha, label); + right_lineage = msp_alloc_lineage(self, alpha, left_lineage->population, label); if (right_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3354,7 +3357,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ = GSL_MIN(lhs_tail->right + self->model.params.smc_k_coalescent.hull_offset, 
self->sequence_length); msp_reset_hull_right( - self, lhs_hull, rhs_right, lhs_right, lhs_tail->population, label); + self, lhs_hull, rhs_right, lhs_right, left_lineage->population, label); /* create new hull for alpha */ rhs_hull = msp_alloc_hull(self, alpha->left, rhs_right, alpha->lineage); @@ -3420,6 +3423,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) hull_t *hull = NULL; double reset_right = 0.0; double tract_hull_left, tract_hull_right; + population_id_t population; tsk_bug_assert(self->gc_mass_index != NULL); self->num_gc_events++; @@ -3430,6 +3434,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) goto out; } + population = y->lineage->population; x = y->prev; /* generate tract length */ @@ -3583,7 +3588,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) = GSL_MIN(reset_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); msp_reset_hull_right( - self, hull, hull->right, reset_right, y->population, label); + self, hull, hull->right, reset_right, population, label); } } @@ -3599,7 +3604,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) new_individual_head = head; } if (new_individual_head != NULL) { - new_lineage = msp_alloc_lineage(self, new_individual_head, label); + new_lineage = msp_alloc_lineage(self, new_individual_head, population, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3716,7 +3721,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l alpha->next = NULL; } else if (x->left != y->left) { alpha = msp_alloc_segment( - self, x->left, y->left, x->value, x->population, label, NULL, NULL); + self, x->left, y->left, x->value, -1, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3809,7 +3814,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } if (alpha != NULL) { if (z == NULL) { - new_lineage = msp_alloc_lineage(self, alpha, label); + new_lineage = 
msp_alloc_lineage(self, alpha, population_id, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3983,7 +3988,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, x = H[0]; if (node != NULL && next_l < x->right) { alpha = msp_alloc_segment( - self, x->left, next_l, x->value, x->population, label, NULL, NULL); + self, x->left, next_l, x->value, -1, label, NULL, NULL); if (alpha == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4083,7 +4088,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, if (alpha != NULL) { if (z == NULL) { merged_head = alpha; - new_lineage = msp_alloc_lineage(self, alpha, label); + new_lineage = msp_alloc_lineage(self, alpha, population_id, label); if (new_lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4171,8 +4176,8 @@ msp_merge_n_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, for (a = Q->head; a != NULL; a = a->next) { u = (segment_t *) a->item; tsk_bug_assert(u->lineage != NULL); - if (u->population != population_id) { - current_pop = &self->populations[u->population]; + if (u->lineage->population != population_id) { + current_pop = &self->populations[u->lineage->population]; avl_node = avl_search(&current_pop->ancestors[label], u->lineage); tsk_bug_assert(avl_node != NULL); ret = msp_move_individual( @@ -4208,7 +4213,7 @@ msp_merge_n_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, *ret_merged_head = merged_head; } if (merged_head != NULL) { - tsk_bug_assert(merged_head->population == population_id); + tsk_bug_assert(merged_head->lineage->population == population_id); } out: return ret; @@ -4303,6 +4308,7 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea lineage_t *lineage; segment_t *copy, *prev; const segment_t *seg; + const tsk_id_t *restrict node_population = self->tables->nodes.population; double breakpoints[2]; int j; const label_id_t label = 0; @@ -4333,7 +4339,7 @@
msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea } copy->prev = prev; if (prev == NULL) { - lineage = msp_alloc_lineage(self, copy, label); + lineage = msp_alloc_lineage(self, copy, node_population[head->value], label); if (lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4378,10 +4384,11 @@ msp_insert_sample(msp_t *self, tsk_id_t node) { int ret = 0; segment_t *root_seg; + const tsk_id_t *restrict node_population = self->tables->nodes.population; population_t pop; root_seg = self->root_segments[node]; - pop = self->populations[root_seg->population]; + pop = self->populations[node_population[node]]; if (pop.state != MSP_POP_STATE_ACTIVE) { ret = MSP_ERR_POPULATION_INACTIVE_SAMPLE; goto out; @@ -4403,7 +4410,7 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri tsk_id_t root; segment_t *seg, *tail; population_id_t population; - const population_id_t *restrict node_population = self->tables->nodes.population; + const tsk_id_t *restrict node_population = self->tables->nodes.population; label_id_t label = 0; /* For now only support label 0 */ for (root = tsk_tree_get_left_root(tree); root != TSK_NULL; @@ -4416,8 +4423,7 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri goto out; } if (root_segments_head[root] == NULL) { - seg = msp_alloc_segment( - self, left, right, root, population, label, NULL, NULL); + seg = msp_alloc_segment(self, left, right, root, -1, label, NULL, NULL); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4429,13 +4435,13 @@ msp_allocate_root_segments(msp_t *self, tsk_tree_t *tree, double left, double ri if (tail->right == left) { tail->right = right; } else { - seg = msp_alloc_segment( - self, left, right, root, population, label, tail, NULL); + seg = msp_alloc_segment(self, left, right, root, -1, label, tail, NULL); if (seg == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } tail->next = seg; + /* seg->lineage = tail->lineage; */ 
root_segments_tail[root] = seg; } } @@ -5166,6 +5172,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) const double gc_left_total = msp_get_total_gc_left(self); double h = gsl_rng_uniform(self->rng) * gc_left_total; double tl, bp, lhs_old_right, lhs_new_right; + population_id_t population; lineage_t *lineage; segment_t *y, *x, *alpha; hull_t *rhs_hull; @@ -5174,6 +5181,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) lhs_hull = NULL; lineage = msp_find_gc_left_individual(self, label, h); assert(lineage != NULL); + population = lineage->population; y = lineage->head; assert(y != NULL); @@ -5253,7 +5261,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } lhs_new_right = y->right; - lineage = msp_alloc_lineage(self, alpha, label); + lineage = msp_alloc_lineage(self, alpha, population, label); if (lineage == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5276,7 +5284,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) = GSL_MIN(lhs_new_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); msp_reset_hull_right( - self, lhs_hull, lhs_old_right, lhs_new_right, y->population, label); + self, lhs_hull, lhs_old_right, lhs_new_right, y->lineage->population, label); // rhs tsk_bug_assert(alpha->left < lhs_old_right); @@ -5360,7 +5368,7 @@ msp_run_coalescent(msp_t *self, double max_time, unsigned long max_events) ca_pop_id = 0; for (avl_node = self->non_empty_populations.head; avl_node != NULL; avl_node = avl_node->next) { - pop_id = (tsk_id_t)(intptr_t) avl_node->item; + pop_id = (tsk_id_t) (intptr_t) avl_node->item; t_temp = self->get_common_ancestor_waiting_time(self, pop_id, label); if (t_temp < ca_t_wait) { ca_t_wait = t_temp; @@ -5374,7 +5382,7 @@ msp_run_coalescent(msp_t *self, double max_time, unsigned long max_events) mig_dest_pop = 0; for (avl_node = self->non_empty_populations.head; avl_node != NULL; avl_node = avl_node->next) { - pop_id_j = (tsk_id_t)(intptr_t) 
avl_node->item; + pop_id_j = (tsk_id_t) (intptr_t) avl_node->item; pop = &self->populations[pop_id_j]; n = avl_count(&pop->ancestors[label]); tsk_bug_assert(n > 0); @@ -6043,13 +6051,14 @@ static int msp_change_label(msp_t *self, segment_t *ind, label_id_t label) { int ret = 0; - avl_tree_t *pop = &self->populations[ind->population].ancestors[ind->lineage->label]; + avl_tree_t *pop + = &self->populations[ind->lineage->population].ancestors[ind->lineage->label]; avl_node_t *node; /* Find the this individual in the AVL tree. */ node = avl_search(pop, ind->lineage); tsk_bug_assert(node != NULL); - ret = msp_move_individual(self, node, pop, ind->population, label); + ret = msp_move_individual(self, node, pop, ind->lineage->population, label); return ret; } @@ -7488,7 +7497,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) for (u = 0; u < (tsk_id_t) n; u++) { lineages[u] = u; } - for (u = 0; u < (tsk_id_t)(2 * n); u++) { + for (u = 0; u < (tsk_id_t) (2 * n); u++) { pi[u] = TSK_NULL; } j = 0; @@ -8363,7 +8372,7 @@ genic_selection_generate_trajectory(sweep_t *self, msp_t *simulator, alpha = 2 * pop_size * trajectory.s; x = 1.0 - genic_selection_stochastic_forwards( - trajectory.dt, 1.0 - x, alpha, gsl_rng_uniform(rng)); + trajectory.dt, 1.0 - x, alpha, gsl_rng_uniform(rng)); /* need our recored traj to stay in bounds */ t += trajectory.dt; sim_time += trajectory.dt * pop_size * simulator->ploidy; diff --git a/lib/msprime.h b/lib/msprime.h index 18043a863..2b8293bc1 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -1,5 +1,5 @@ /* -** Copyright (C) 2015-2020 University of Oxford +** Copyright (C) 2015-2024 University of Oxford ** ** This file is part of msprime. ** @@ -76,17 +76,18 @@ typedef tsk_id_t population_id_t; typedef tsk_id_t label_id_t; typedef struct segment_t_t { - population_id_t population; + tsk_id_t value; + // TODO change to tsk_id_t or uint32? 
Same for hull_t + size_t id; double left; double right; - tsk_id_t value; - size_t id; struct segment_t_t *prev; struct segment_t_t *next; struct lineage_t_t *lineage; } segment_t; typedef struct lineage_t_t { + population_id_t population; segment_t *head; label_id_t label; struct hull_t_t *hull; diff --git a/msprime/_msprimemodule.c b/msprime/_msprimemodule.c index fcfbf1266..6b3954f02 100644 --- a/msprime/_msprimemodule.c +++ b/msprime/_msprimemodule.c @@ -2470,7 +2470,9 @@ Simulator_individual_to_python(Simulator *self, segment_t *ind) PyObject *t = NULL; size_t num_segments, j; segment_t *u; + lineage_t *lin = ind->lineage; + assert(lin != NULL); num_segments = 0; u = ind; while (u != NULL) { @@ -2485,7 +2487,7 @@ Simulator_individual_to_python(Simulator *self, segment_t *ind) j = 0; while (u != NULL) { t = Py_BuildValue("(d,d,I,I)", u->left, u->right, u->value, - u->population); + lin->population); if (t == NULL) { Py_DECREF(l); goto out;