Skip to content

Commit

Permalink
add test (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdidion authored Jun 5, 2021
1 parent 99388da commit b13ce11
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 38 deletions.
52 changes: 26 additions & 26 deletions atropos/commands/trim/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,32 @@

class Writers(object):
"""Manages writing to one or more outputs.
Args:
force_create: Whether empty output files should be created.
"""
def __init__(self, force_create=[]):
self.writers = {}
self.force_create = force_create
self.suffix = None

def get_writer(self, file_desc, compressed=False):
"""Create the writer for a file descriptor if it does not already
exist.
Args:
file_desc: File descriptor. If `compressed==True`, this is a tuple
(path, mode), otherwise it's only a path.
compressed: Whether data has already been compressed.
Returns:
The writer.
"""
if compressed:
path, mode = file_desc
else:
path = file_desc

if path not in self.writers:
if self.suffix:
real_path = add_suffix_to_path(path, self.suffix)
Expand All @@ -44,12 +44,12 @@ def get_writer(self, file_desc, compressed=False):
self.writers[path] = open_output(real_path, mode)
else:
self.writers[path] = xopen(real_path, "w")

return self.writers[path]

def write_result(self, result, compressed=False):
"""Write results to output.
Args:
result: Dict with keys being file descriptors and values being data
(either bytes or strings). Strings are expected to already have
Expand All @@ -58,33 +58,33 @@ def write_result(self, result, compressed=False):
"""
for file_desc, data in result.items():
self.write(file_desc, data, compressed)

def write(self, file_desc, data, compressed=False):
"""Write data to output. If the specified path has not been seen before,
the output is opened.
Args:
file_desc: File descriptor. If `compressed==True`, this is a tuple
(path, mode), otherwise it's only a path.
data: The data to write.
compressed: Whether data has already been compressed.
"""
self.get_writer(file_desc, compressed).write(data)

def close(self):
"""Close all outputs.
"""
for path in self.force_create:
if path not in self.writers and path != STDOUT:
with open_output(path, "w"):
with xopen(path, "w"):
pass
for writer in self.writers.values():
if writer not in (sys.stdout, sys.stderr):
writer.close()

class Formatters(object):
"""Manages multiple formatters.
Args:
output: The output file name template.
seq_formatter_args: Additional arguments to pass to the formatter
Expand All @@ -98,24 +98,24 @@ def __init__(self, output, seq_formatter_args):
self.mux_formatters = {}
self.info_formatters = []
self.discarded = 0

def add_seq_formatter(self, filter_type, file1, file2=None):
"""Add a formatter.
Args:
filter_type: The type of filter that triggers writing with the
formatter.
file1, file2: The output file(s).
"""
self.seq_formatters[filter_type] = create_seq_formatter(
file1, file2, **self.seq_formatter_args)

def add_info_formatter(self, formatter):
"""Add a formatter for one of the delimited detail files
(rest, info, wildcard).
"""
self.info_formatters.append(formatter)

def get_mux_formatter(self, name):
"""Returns the formatter associated with the given name (barcode) when
running in multiplexed mode.
Expand All @@ -126,19 +126,19 @@ def get_mux_formatter(self, name):
self.mux_formatters[name] = create_seq_formatter(
path, **self.seq_formatter_args)
return self.mux_formatters[name]

def get_seq_formatters(self):
"""Returns a set containing all formatters that have handled at least
one record.
"""
return (
set(f for f in self.seq_formatters.values() if f.written > 0) |
set(f for f in self.mux_formatters.values() if f.written > 0))

def format(self, result, dest, read1, read2=None):
"""Format read(s) and add to a result dict. Also writes info records
to any registered info formatters.
Args:
result: The result dict.
dest: The destination (filter type).
Expand All @@ -152,12 +152,12 @@ def format(self, result, dest, read1, read2=None):
self.seq_formatters[dest].format(result, read1, read2)
else:
self.discarded += 1

for fmtr in self.info_formatters:
fmtr.format(result, read1)
if read2:
fmtr.format(result, read2)

def summarize(self):
"""Returns a summary dict.
"""
Expand All @@ -171,20 +171,20 @@ def summarize(self):

class DelimFormatter(object):
"""Base class for formatters that write to a delimited file.
Args:
path: The output file path.
delim: The field delimiter.
"""
def __init__(self, path, delim=' '):
self.path = path
self.delim = delim

def format(self, result, read):
"""Format a read and add it to `result`.
"""
raise NotImplementedError()

def _format(self, result, fields):
result[self.path].append("".join((
self.delim.join(str(f) for f in fields),
Expand All @@ -204,7 +204,7 @@ class InfoFormatter(DelimFormatter):
"""
def __init__(self, path):
super(InfoFormatter, self).__init__(path, delim='\t')

def format(self, result, read):
if read.match:
for match_info in read.match_info:
Expand Down
12 changes: 9 additions & 3 deletions tests/test_paired.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# coding: utf-8
import gzip
import os
import shutil
from unittest import TestCase
Expand Down Expand Up @@ -610,10 +611,15 @@ def test_issue68(self):

@pytest.mark.timeout(10)
def test_issue122(self):
# test that the empty fastq.gz files are valid gzip files
def callback(aligner, infiles, outfiles, result):
for out in outfiles:
with gzip.open(out) as z:
assert z.read() == ""
run_paired(
"--threads 2 --preserve-order --no-default-adapters -a TTAGACATAT -A CAGTGGAGTA",
in1="empty.fastq",
in2="empty.fastq",
expected1="empty.fastq",
expected2="empty.fastq",
)
expected1="empty.fastq.gz",
expected2="empty.fastq.gz",
)
28 changes: 19 additions & 9 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from contextlib import contextmanager
from importlib import import_module
import os
from subprocess import check_output, CalledProcessError
import sys
import tempfile
import traceback
import urllib.request
from atropos.commands import get_command

from atropos.io import xopen

@contextmanager
def redirect_stderr():
Expand Down Expand Up @@ -39,19 +41,27 @@ def cutpath(path):


def files_equal(path1, path2):
# return os.system("diff -u {0} {1}".format(path1, path2)) == 0
with open(path1, 'r') as i1, open(path2, 'r') as i2:
print("<[{}]>".format(i1.read()))
print("<[{}]>".format(i2.read()))
from subprocess import check_output, CalledProcessError

temp1 = tempfile.mkstemp()[1]
temp2 = tempfile.mkstemp()[1]
try:
check_output("diff -u {0} {1}".format(path1, path2), shell=True)
with xopen(path1, 'r') as i1, xopen(path2, 'r') as i2:
# write contents to temp files in case the files are compressed
content1 = i1.read()
content2 = i2.read()
print("<[{}]>".format(content1))
print("<[{}]>".format(content2))
with open(temp1, "w") as out:
out.write(content1)
with open(temp2, "w") as out:
out.write(content2)
check_output("diff -u {0} {1}".format(temp1, temp2), shell=True)
return True

except CalledProcessError as e:
print("Diff: <{}>".format(e.output.decode("utf-8")))
return False
finally:
os.remove(temp1)
os.remove(temp2)


def run(
Expand Down

0 comments on commit b13ce11

Please sign in to comment.