Skip to content

Commit

Permalink
Fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
emarinier committed Dec 6, 2024
1 parent e9badeb commit 74f9c6b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(self, file_blast_map, blast_database, pid_threshold, plength_thresh

def _create_hit(self, file, database_name, blast_record):
logger.debug("database_name=%s", database_name)
if (database_name == '16S_rrsD') or (database_name == '23S'):
if (database_name.startswith('16S_rrs') or database_name.startswith('16S-rrs') \
or (database_name == '23S')):
return PointfinderHitHSPRNA(file, blast_record)
elif ('promoter' in database_name):
return PointfinderHitHSPPromoter(file, blast_record, database_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_amr_gene_name(self):
# pointfinder/campylobacter/23S.fsa -> 23S_1_LR134511.1
# pointfinder/neisseria_gonorrhoeae/23S-rRNA-a1.fsa -> 23S-rRNA-a1_1_AE004969.1
if name.startswith("16S_rrs"): name = name.split("_")[0] + "_" + name.split("_")[1]
elif name.startswith("16S"): name = "16S"
elif name.startswith("16S-rrs"): name = name.split("_")[0].replace("-", "_", 1) # Ex: 16S-rrsD_1_CP049983.1
elif name.startswith("23S"): name = "23S"
else: name = name.split("_")[0]

Expand Down
16 changes: 8 additions & 8 deletions staramr/tests/integration/detection/test_AMRDetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def testPointfinderSalmonellaA67PSuccessNoPhenotype(self):
pointfinder_results = amr_detection.get_pointfinder_results()
self.assertEqual(len(pointfinder_results.index), 1, 'Wrong number of rows in result')

result = pointfinder_results[pointfinder_results['Gene'] == 'gyrA (A67P)']
result = pointfinder_results[pointfinder_results['Gene'] == 'gyrA_1_MH933946.1 (A67P)']
self.assertEqual(len(result.index), 1, 'Wrong number of results detected')
self.assertEqual(result.index[0], 'gyrA-A67P', msg='Wrong file')
self.assertEqual(result['Type'].iloc[0], 'codon', msg='Wrong type')
Expand All @@ -549,7 +549,7 @@ def testPointfinderSalmonellaA67PSuccessNoPhenotype(self):
self.assertEqual(len(records), 1, 'Wrong number of hit records')

expected_records = SeqIO.to_dict(SeqIO.parse(file, 'fasta'))
self.assertEqual(expected_records['gyrA'].seq.upper(), records['gyrA'].seq.upper(), "records don't match")
self.assertEqual(expected_records['gyrA'].seq.upper(), records['gyrA_1_MH933946.1'].seq.upper(), "records don't match")

def testPointfinderSalmonellaA67PDelEndSuccess(self):
pointfinder_database = PointfinderBlastDatabase(self.pointfinder_dir, 'salmonella')
Expand Down Expand Up @@ -729,7 +729,7 @@ def testPointfinderSalmonella_16S_rrSD_C1065T_Success(self):
self.assertEqual(len(records), 1, 'Wrong number of hit records')

expected_records = SeqIO.to_dict(SeqIO.parse(file, 'fasta'))
self.assertEqual(expected_records['16S_rrsD'].seq.upper(), records['16S_rrsD'].seq.upper(),
self.assertEqual(expected_records['16S_rrsD'].seq.upper(), records['16S-rrsD_1_CP049983.1'].seq.upper(),
"records don't match")


Expand Down Expand Up @@ -1016,9 +1016,9 @@ def testResfinderPointfinderSalmonella_16S_C1065T_gyrA_A67_beta_lactam_Success(s
records = SeqIO.to_dict(SeqIO.parse(hit_file, 'fasta'))
self.assertEqual(len(records), 2, 'Wrong number of hit records')
expected_records1 = SeqIO.to_dict(SeqIO.parse(path.join(self.test_data_dir, 'gyrA-A67P.fsa'), 'fasta'))
self.assertEqual(expected_records1['gyrA'].seq.upper(), records['gyrA'].seq.upper(), "records don't match")
self.assertEqual(expected_records1['gyrA'].seq.upper(), records['gyrA_1_MH933946.1'].seq.upper(), "records don't match")
expected_records2 = SeqIO.to_dict(SeqIO.parse(path.join(self.test_data_dir, '16S_rrsD-1T1065.fsa'), 'fasta'))
self.assertEqual(expected_records2['16S_rrsD'].seq.upper(), records['16S_rrsD'].seq.upper(),
self.assertEqual(expected_records2['16S_rrsD'].seq.upper(), records['16S-rrsD_1_CP049983.1'].seq.upper(),
"records don't match")

def testResfinderPointfinderSalmonellaExcludeGenesListSuccess(self):
Expand All @@ -1028,7 +1028,7 @@ def testResfinderPointfinderSalmonellaExcludeGenesListSuccess(self):
amr_detection = AMRDetectionResistance(self.resfinder_database, self.resfinder_drug_table,
self.cge_drug_table, blast_handler,
self.pointfinder_drug_table, pointfinder_database,
output_dir=self.outdir.name, genes_to_exclude=['gyrA'])
output_dir=self.outdir.name, genes_to_exclude=['gyrA_1_MH933946.1'])

file = path.join(self.test_data_dir, "16S_gyrA_beta-lactam.fsa")
files = [file]
Expand Down Expand Up @@ -1115,9 +1115,9 @@ def testResfinderPointfinderSalmonella_16Src_C1065T_gyrArc_A67_beta_lactam_Succe
records = SeqIO.to_dict(SeqIO.parse(hit_file, 'fasta'))
self.assertEqual(len(records), 2, 'Wrong number of hit records')
expected_records1 = SeqIO.to_dict(SeqIO.parse(path.join(self.test_data_dir, 'gyrA-A67P.fsa'), 'fasta'))
self.assertEqual(expected_records1['gyrA'].seq.upper(), records['gyrA'].seq.upper(), "records don't match")
self.assertEqual(expected_records1['gyrA'].seq.upper(), records['gyrA_1_MH933946.1'].seq.upper(), "records don't match")
expected_records2 = SeqIO.to_dict(SeqIO.parse(path.join(self.test_data_dir, '16S_rrsD-1T1065.fsa'), 'fasta'))
self.assertEqual(expected_records2['16S_rrsD'].seq.upper(), records['16S_rrsD'].seq.upper(),
self.assertEqual(expected_records2['16S_rrsD'].seq.upper(), records['16S-rrsD_1_CP049983.1'].seq.upper(),
"records don't match")

def testResfinderExcludeNonMatches(self):
Expand Down

0 comments on commit 74f9c6b

Please sign in to comment.