Skip to content

Commit

Permalink
fixes bug in rate shifts for skyline when chaniging tree heights
Browse files Browse the repository at this point in the history
fixes a bug where the interval times for new rate shifts are not updated when updating the tree and the rate shifts are relative to the tree heights (for skyline only)
  • Loading branch information
nicfel committed Mar 10, 2023
1 parent 2f736c4 commit d6137ab
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 55 deletions.
40 changes: 39 additions & 1 deletion src/mascot/dynamics/RateShifts.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package mascot.dynamics;

import java.io.PrintStream;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -8,9 +9,10 @@
import beast.base.core.BEASTObject;
import beast.base.inference.CalculationNode;
import beast.base.core.Input;
import beast.base.core.Loggable;
import beast.base.evolution.tree.Tree;

public class RateShifts extends CalculationNode {
public class RateShifts extends CalculationNode implements Loggable {

final public Input<String> dateTimeFormatInput = new Input<>("dateFormat", "the date/time format to be parsed, (e.g., 'dd/M/yyyy')", "dd/M/yyyy");
public Input<List<Double>> valuesInput = new Input<>(
Expand Down Expand Up @@ -111,7 +113,43 @@ public boolean somethingIsDirty(){
private Class<?> getInputClass() {
return ((Double) 0.0).getClass();
}


@Override
public void init(PrintStream out) {
for (int i = 0; i < rateShifts.length; i++) {
out.print("RateShift."+i + "\t");

}

}


@Override
public void log(long sample, PrintStream out) {
for (int i = 0; i < rateShifts.length; i++) {
double val = 0;
if (isRelative) {
val = tree.getRoot().getHeight()*rateShifts[i];
}else{
val= rateShifts[i];
}

out.print(val + "\t");

}


}


@Override
public void close(PrintStream out) {
// TODO Auto-generated method stub

}



}

62 changes: 50 additions & 12 deletions src/mascot/dynamics/StructuredSkyline.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ public class StructuredSkyline extends Dynamics implements Loggable {


double[] intTimes;
double[] storedIntTimes;

int firstlargerzero;

boolean isForward = false;

boolean intTimesKnown = false;

RealParameter migration;

NeDynamicsList parametricFunction;
Expand Down Expand Up @@ -81,16 +84,7 @@ public void initAndValidate() {
break;
}
}
// initialize the intervals
intTimes = new double[rateShiftsInput.get().getDimension()-firstlargerzero];
for (int i=0; i < intTimes.length; i++) {
if (i==0) {
intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero);
}
else {
intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero)-rateShiftsInput.get().getValue(i-1+firstlargerzero);
}
}
computeIntTimes();

// set the Ne dimension for all skygrid dynamics.
for (int i = 0; i < parametricFunction.size(); i++)
Expand All @@ -116,11 +110,29 @@ public void initAndValidate() {

}

/**
private void computeIntTimes() {
// initialize the intervals
intTimes = new double[rateShiftsInput.get().getDimension()-firstlargerzero];
for (int i=0; i < intTimes.length; i++) {
if (i==0) {
intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero);
}
else {
intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero)-rateShiftsInput.get().getValue(i-1+firstlargerzero);
}
}

intTimesKnown=true;
}

/**
* Returns the time to the next interval.
*/
@Override
public double getInterval(int i) {
if (!intTimesKnown)
computeIntTimes();

if (i >= intTimes.length){
return Double.POSITIVE_INFINITY;
}else{
Expand All @@ -130,11 +142,17 @@ public double getInterval(int i) {

@Override
public double[] getIntervals() {
if (!intTimesKnown)
computeIntTimes();

return intTimes;
}

@Override
public double[] getCoalescentRate(int i){
if (!intTimesKnown)
computeIntTimes();

int intervalNr;
if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1)
intervalNr = rateShiftsInput.get().getDimension()-2;
Expand All @@ -156,6 +174,10 @@ public double[] getCoalescentRate(int i){

@Override
public double[] getBackwardsMigration(int i){
if (!intTimesKnown)
computeIntTimes();


int intervalNr;
if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1)
intervalNr = rateShiftsInput.get().getDimension()-2;
Expand Down Expand Up @@ -351,8 +373,11 @@ public boolean intervalIsDirty(int j) {
if(migration.isDirty(i))
return true;

if(rateShiftsInput.get().somethingIsDirty())
if(rateShiftsInput.get().somethingIsDirty()) {
intTimesKnown = false;
return true;
}


return false;
}
Expand Down Expand Up @@ -421,4 +446,17 @@ public void recalculate() {

}

public void store() {
storedIntTimes = new double[intTimes.length];
System.arraycopy(intTimes, 0, storedIntTimes, 0, intTimes.length);
super.store();
}

@Override
public void restore() {
System.arraycopy(storedIntTimes, 0, intTimes, 0, storedIntTimes.length);
super.restore();
}


}
49 changes: 7 additions & 42 deletions src/mascot/operators/NeSwapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
public class NeSwapper extends Operator {
public Input<List<RealParameter>> logNeInput = new Input<>(
"logNe", "input of the log effective population sizes", new ArrayList<>());
public Input<RealParameter> migrationInput = new Input<>(
"migration", "input of the migration rates");

int length;
int dim;
RealParameter mig;

int[][] dirs;

Expand All @@ -27,66 +24,34 @@ public class NeSwapper extends Operator {
public void initAndValidate() {
dim = logNeInput.get().size();
length = logNeInput.get().get(0).getDimension();
mig = migrationInput.get();
for (int i = 0; i < dim; i++) {
if (logNeInput.get().get(0).getDimension()!=length)
throw new IllegalArgumentException("all input paramter have to have the same dimension");
}
dirs = new int[dim][dim];

int c = 0;
for (int a = 0; a < dim; a++) {
for (int b = 0; b < dim; b++) {
if (a!=b) {
dirs[a][b] = c;
c++;
}
}
}

}
}

@Override
public double proposal() {
int nrSpots = Randomizer.nextInt(length)+1;
double add = Randomizer.nextGaussian();

// int nrSpots = 1;

int startSpot = 0;
if (nrSpots!=length) {
startSpot= Randomizer.nextInt(length-nrSpots+1);
}

int i = Randomizer.nextInt(dim);
// int j = Randomizer.nextInt(dim);
int j = Randomizer.nextInt(dim);

// while (i == j)
// j = Randomizer.nextInt(dim);
while (i == j)
j = Randomizer.nextInt(dim);


for (int a = 0; a < nrSpots; a++) {
int index = a+startSpot;
double val = logNeInput.get().get(i).getArrayValue(index);
double newValue = val+add;

if (newValue < logNeInput.get().get(i).getLower() || newValue > logNeInput.get().get(i).getUpper()) {
return Double.NEGATIVE_INFINITY;
}

logNeInput.get().get(i).setValue(index, newValue);


logNeInput.get().get(i).setValue(index, logNeInput.get().get(i).getArrayValue(index)+add);
// logNeInput.get().get(j).setValue(index, val);
logNeInput.get().get(i).setValue(index, logNeInput.get().get(j).getArrayValue(index));
logNeInput.get().get(j).setValue(index, val);
}


// for (int a = 0; a < nrSpots; a++) {
// int index = a+startSpot;
// double val = logNeInput.get().get(i).getArrayValue(index);
// logNeInput.get().get(i).setValue(index, logNeInput.get().get(j).getArrayValue(index));
// logNeInput.get().get(j).setValue(index, val);
// }

//
// List<Integer> froms = new ArrayList<>();
Expand Down

0 comments on commit d6137ab

Please sign in to comment.