Skip to content

Commit

Permalink
Merge pull request #40 from samhorsfield96/min_ORF_error
Browse files Browse the repository at this point in the history
Min ORF error
  • Loading branch information
johnlees authored Sep 20, 2024
2 parents 4127601 + 9ed0ba3 commit bee0340
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 67 deletions.
13 changes: 8 additions & 5 deletions src/ORF_scoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,15 @@ std::vector<float> score_TIS (const std::vector<std::tuple<std::string, std::str
{
encoded.push_back(nuc_encode(c).out());
}
torch::Tensor t = torch::tensor(encoded, {torch::kInt64});

if (!tensor_size)
// ensure sequence is padded if too short
const int len_diff = 32 - encoded.size();
for (int pad = 0; pad < len_diff; pad++)
{
tensor_size = t.size(0);
encoded.push_back(0);
}

torch::Tensor t = torch::tensor(encoded, {torch::kInt64});

padded_stack.push_back(std::move(t));
pos_idx.push_back(pos);
Expand All @@ -101,10 +104,10 @@ std::vector<float> score_TIS (const std::vector<std::tuple<std::string, std::str

if (!pos_idx.empty())
{
// pad tensor if only single sequence
// pad tensor to 32 bp if only single sequence, scoring guaranteed to be on 32 length vector
if (pos_idx.size() == 1)
{
torch::Tensor zeroes = torch::zeros({tensor_size}, torch::kInt64);
torch::Tensor zeroes = torch::zeros({32}, torch::kInt64);
padded_stack.push_back(std::move(zeroes));
}
torch::Tensor pred = predict(TIS_model, torch::stack(padded_stack), false);
Expand Down
3 changes: 2 additions & 1 deletion src/call_ORFs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ void generate_ORFs(const int& colour_ID,
std::string start_site_AA = (translate(start_site_DNA)).aa();
site_hash = hasher{}(start_site_AA);

const int num_kmers = start_site_AA.size() - aa_kmer;
// ensure if small start found, can still generate sequence
int num_kmers = start_site_AA.size() > aa_kmer ? start_site_AA.size() - aa_kmer : 1;

site_coverage.resize(num_kmers);

Expand Down
95 changes: 35 additions & 60 deletions src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,35 @@ std::vector<std::pair<Kmer, bool>> get_neighbours (const T& neighbour_iterator)
return neighbour_vector;
}

void calc_start_freq (std::string& start_site_AA,
const boost::dynamic_bitset<>& full_unitig_colour,
tbb::concurrent_unordered_map<std::string, tbb::concurrent_unordered_set<int>>& start_freq_set,
const int& aa_kmer,
const size_t& nb_colours)
{
// ensure if small start found, can still generate sequence
int num_kmers = start_site_AA.size() > aa_kmer ? start_site_AA.size() - aa_kmer : 1;

std::vector<std::string> AA_kmers(num_kmers);

for (int kmer_index = 0; kmer_index < num_kmers; ++kmer_index)
{
AA_kmers[kmer_index] = get_kmers(start_site_AA, kmer_index, aa_kmer);
}

// add colours to start_freq_set
for (int i = 0; i < nb_colours; i++)
{
if ((bool)full_unitig_colour[i])
{
for (const auto& entry_aa : AA_kmers)
{
start_freq_set[entry_aa].insert(i);
}
}
}
}

template <class T, class U, bool is_const>
void analyse_unitigs_binary (ColoredCDBG<MyUnitigMap>& ccdbg,
UnitigMap<DataAccessor<T>, DataStorage<U>, is_const> um,
Expand Down Expand Up @@ -290,37 +319,10 @@ void analyse_unitigs_binary (ColoredCDBG<MyUnitigMap>& ccdbg,
// pull out start codon positions
for (const auto& pos : found_indices)
{
if (unitig.size() - pos >= kmer)
{
std::string start_site_DNA = unitig.substr(pos, kmer);
std::string start_site_AA = (translate(start_site_DNA)).aa();

if (start_site_AA.find('*') != std::string::npos)
{
continue;
}

const int num_kmers = start_site_AA.size() - aa_kmer;

std::vector<std::string> AA_kmers(num_kmers);

for (int kmer_index = 0; kmer_index < num_kmers; ++kmer_index)
{
AA_kmers[kmer_index] = get_kmers(start_site_AA, kmer_index, aa_kmer);
}
std::string start_site_DNA = unitig.substr(pos, kmer);
std::string start_site_AA = (translate(start_site_DNA)).aa();

// add colours to start_freq_set
for (int i = 0; i < nb_colours; i++)
{
if ((bool)full_unitig_colour[i])
{
for (const auto& entry_aa : AA_kmers)
{
start_freq_set[entry_aa].insert(i);
}
}
}
}
calc_start_freq (start_site_AA, full_unitig_colour, start_freq_set, aa_kmer, nb_colours);
}
}

Expand All @@ -341,37 +343,10 @@ void analyse_unitigs_binary (ColoredCDBG<MyUnitigMap>& ccdbg,
// pull out start codon positions
for (const auto& pos : found_indices)
{
if (unitig.size() - pos >= kmer)
{
std::string start_site_DNA = rev_unitig.substr(pos, kmer);
std::string start_site_AA = (translate(start_site_DNA)).aa();

if (start_site_AA.find('*') != std::string::npos)
{
continue;
}
std::string start_site_DNA = rev_unitig.substr(pos, kmer);
std::string start_site_AA = (translate(start_site_DNA)).aa();

const int num_kmers = start_site_AA.size() - aa_kmer;

std::vector<std::string> AA_kmers(num_kmers);

for (int kmer_index = 0; kmer_index < num_kmers; ++kmer_index)
{
AA_kmers[kmer_index] = get_kmers(start_site_AA, kmer_index, aa_kmer);
}

// add colours to start_freq_set
for (int i = 0; i < nb_colours; i++)
{
if ((bool)full_unitig_colour[i])
{
for (const auto& entry_aa : AA_kmers)
{
start_freq_set[entry_aa].insert(i);
}
}
}
}
calc_start_freq (start_site_AA, full_unitig_colour, start_freq_set, aa_kmer, nb_colours);
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ boost::dynamic_bitset<> generate_colours(const UnitigMap<DataAccessor<T>, DataSt
const size_t nb_colours,
const size_t position);

void calc_start_freq (const std::string& start_site_AA,
const boost::dynamic_bitset<>& full_unitig_colour,
tbb::concurrent_unordered_map<std::string, tbb::concurrent_unordered_set<int>>& start_freq_set,
const int& aa_kmer,
const size_t& nb_colours);

template <class T, class U, bool is_const>
void analyse_unitigs_binary (ColoredCDBG<MyUnitigMap>& ccdbg,
UnitigMap<DataAccessor<T>, DataStorage<U>, is_const> um,
Expand Down
4 changes: 4 additions & 0 deletions src/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class translate
{
const std::string codon = dna_seq.substr(i, 3);
aa_seq_ += tMap_[codon];
// break if stop codon present
if (tMap_[codon] == '*') {
break;
}
}
};
std::string aa () {return aa_seq_;};
Expand Down
2 changes: 1 addition & 1 deletion test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

sys.stderr.write("Running reference build workflow with annotation\n")
subprocess.run(
python_cmd + " ../ggcaller-runner.py --refs pneumo_CL_group2.txt --kmer 31 --out test_dir --max-path-length 5000 --clean-mode strict --min-orf-length 100 --max-ORF-overlap 55 --alignment core --aligner def --annotation fast --evalue 0.0001 --search-radius 3000 --save",
python_cmd + " ../ggcaller-runner.py --refs pneumo_CL_group2.txt --kmer 13 --out test_dir --max-path-length 5000 --clean-mode strict --min-orf-length 0 --max-ORF-overlap 55 --alignment core --aligner def --annotation fast --evalue 0.0001 --search-radius 3000 --save",
shell=True,
check=True)

Expand Down

0 comments on commit bee0340

Please sign in to comment.