diff --git a/src/mascot/dynamics/RateShifts.java b/src/mascot/dynamics/RateShifts.java index 9335266..ae1ad9b 100644 --- a/src/mascot/dynamics/RateShifts.java +++ b/src/mascot/dynamics/RateShifts.java @@ -1,5 +1,6 @@ package mascot.dynamics; +import java.io.PrintStream; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; @@ -8,9 +9,10 @@ import beast.base.core.BEASTObject; import beast.base.inference.CalculationNode; import beast.base.core.Input; +import beast.base.core.Loggable; import beast.base.evolution.tree.Tree; -public class RateShifts extends CalculationNode { +public class RateShifts extends CalculationNode implements Loggable { final public Input dateTimeFormatInput = new Input<>("dateFormat", "the date/time format to be parsed, (e.g., 'dd/M/yyyy')", "dd/M/yyyy"); public Input> valuesInput = new Input<>( @@ -111,7 +113,43 @@ public boolean somethingIsDirty(){ private Class getInputClass() { return ((Double) 0.0).getClass(); } + + + @Override + public void init(PrintStream out) { + for (int i = 0; i < rateShifts.length; i++) { + out.print("RateShift."+i + "\t"); + + } + + } + + + @Override + public void log(long sample, PrintStream out) { + for (int i = 0; i < rateShifts.length; i++) { + double val = 0; + if (isRelative) { + val = tree.getRoot().getHeight()*rateShifts[i]; + }else{ + val= rateShifts[i]; + } + + out.print(val + "\t"); + + } + + + } + + + @Override + public void close(PrintStream out) { + // TODO Auto-generated method stub + + } } + diff --git a/src/mascot/dynamics/StructuredSkyline.java b/src/mascot/dynamics/StructuredSkyline.java index 0084385..6237d40 100644 --- a/src/mascot/dynamics/StructuredSkyline.java +++ b/src/mascot/dynamics/StructuredSkyline.java @@ -35,11 +35,14 @@ public class StructuredSkyline extends Dynamics implements Loggable { double[] intTimes; + double[] storedIntTimes; int firstlargerzero; boolean isForward = false; + boolean intTimesKnown = false; + RealParameter migration; NeDynamicsList parametricFunction; @@ -81,16 +84,7 @@ public void initAndValidate() { break; } } - // initialize the intervals - intTimes = new double[rateShiftsInput.get().getDimension()-firstlargerzero]; - for (int i=0; i < intTimes.length; i++) { - if (i==0) { - intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero); - } - else { - intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero)-rateShiftsInput.get().getValue(i-1+firstlargerzero); - } - } + computeIntTimes(); // set the Ne dimension for all skygrid dynamics. for (int i = 0; i < parametricFunction.size(); i++) @@ -116,11 +110,29 @@ public void initAndValidate() { } - /** + private void computeIntTimes() { + // initialize the intervals + intTimes = new double[rateShiftsInput.get().getDimension()-firstlargerzero]; + for (int i=0; i < intTimes.length; i++) { + if (i==0) { + intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero); + } + else { + intTimes[i] = rateShiftsInput.get().getValue(i+firstlargerzero)-rateShiftsInput.get().getValue(i-1+firstlargerzero); + } + } + + intTimesKnown=true; + } + + /** * Returns the time to the next interval. */ @Override public double getInterval(int i) { + if (!intTimesKnown) + computeIntTimes(); + if (i >= intTimes.length){ return Double.POSITIVE_INFINITY; }else{ @@ -130,11 +142,17 @@ public double getInterval(int i) { @Override public double[] getIntervals() { + if (!intTimesKnown) + computeIntTimes(); + return intTimes; } @Override public double[] getCoalescentRate(int i){ + if (!intTimesKnown) + computeIntTimes(); + int intervalNr; if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1) intervalNr = rateShiftsInput.get().getDimension()-2; @@ -156,6 +174,10 @@ public double[] getCoalescentRate(int i){ @Override public double[] getBackwardsMigration(int i){ + if (!intTimesKnown) + computeIntTimes(); + + int intervalNr; if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1) intervalNr = rateShiftsInput.get().getDimension()-2; @@ -351,8 +373,11 @@ public boolean intervalIsDirty(int j) { if(migration.isDirty(i)) return true; - if(rateShiftsInput.get().somethingIsDirty()) + if(rateShiftsInput.get().somethingIsDirty()) { + intTimesKnown = false; return true; + } + return false; } @@ -421,4 +446,17 @@ public void recalculate() { } + public void store() { + storedIntTimes = new double[intTimes.length]; + System.arraycopy(intTimes, 0, storedIntTimes, 0, intTimes.length); + super.store(); + } + + @Override + public void restore() { + System.arraycopy(storedIntTimes, 0, intTimes, 0, storedIntTimes.length); + super.restore(); + } + + } \ No newline at end of file diff --git a/src/mascot/operators/NeSwapper.java b/src/mascot/operators/NeSwapper.java index ce43770..f11a4f6 100644 --- a/src/mascot/operators/NeSwapper.java +++ b/src/mascot/operators/NeSwapper.java @@ -13,12 +13,9 @@ public class NeSwapper extends Operator { public Input> logNeInput = new Input<>( "logNe", "input of the log effective population sizes", new ArrayList<>()); - public Input migrationInput = new Input<>( - "migration", "input of the migration rates"); int length; int dim; - RealParameter mig; int[][] dirs; @@ -27,31 +24,15 @@ public class NeSwapper extends Operator { public void initAndValidate() { dim = logNeInput.get().size(); length = logNeInput.get().get(0).getDimension(); - mig = migrationInput.get(); for (int i = 0; i < dim; i++) { if (logNeInput.get().get(0).getDimension()!=length) throw new IllegalArgumentException("all input paramter have to have the same dimension"); - } - dirs = new int[dim][dim]; - - int c = 0; - for (int a = 0; a < dim; a++) { - for (int b = 0; b < dim; b++) { - if (a!=b) { - dirs[a][b] = c; - c++; - } - } - } - + } } @Override public double proposal() { int nrSpots = Randomizer.nextInt(length)+1; - double add = Randomizer.nextGaussian(); - -// int nrSpots = 1; int startSpot = 0; if (nrSpots!=length) { @@ -59,34 +40,18 @@ public double proposal() { } int i = Randomizer.nextInt(dim); -// int j = Randomizer.nextInt(dim); + int j = Randomizer.nextInt(dim); -// while (i == j) -// j = Randomizer.nextInt(dim); + while (i == j) + j = Randomizer.nextInt(dim); + for (int a = 0; a < nrSpots; a++) { int index = a+startSpot; double val = logNeInput.get().get(i).getArrayValue(index); - double newValue = val+add; - - if (newValue < logNeInput.get().get(i).getLower() || newValue > logNeInput.get().get(i).getUpper()) { - return Double.NEGATIVE_INFINITY; - } - - logNeInput.get().get(i).setValue(index, newValue); - - - logNeInput.get().get(i).setValue(index, logNeInput.get().get(i).getArrayValue(index)+add); -// logNeInput.get().get(j).setValue(index, val); + logNeInput.get().get(i).setValue(index, logNeInput.get().get(j).getArrayValue(index)); + logNeInput.get().get(j).setValue(index, val); } - - -// for (int a = 0; a < nrSpots; a++) { -// int index = a+startSpot; -// double val = logNeInput.get().get(i).getArrayValue(index); -// logNeInput.get().get(i).setValue(index, logNeInput.get().get(j).getArrayValue(index)); -// logNeInput.get().get(j).setValue(index, val); -// } // // List froms = new ArrayList<>();