diff --git a/src/LocalCommandDeclarations.h b/src/LocalCommandDeclarations.h
index 689d039a..34af9d2d 100644
--- a/src/LocalCommandDeclarations.h
+++ b/src/LocalCommandDeclarations.h
@@ -17,5 +17,6 @@
 extern int databaseReport(int argc, const char **argv, const Command& command);
 extern int mapping2taxon(int argc, const char **argv, const Command& command);
 extern int expand_diffidx(int argc, const char **argv, const Command& command);
 extern int makeAAoffset(int argc, const char **argv, const Command& command);
+extern int extract(int argc, const char **argv, const Command& command);
 #endif //ADCLASSIFIER2_LOCALCOMMANDDECLARATIONS_H
diff --git a/src/commons/KmerExtractor.cpp b/src/commons/KmerExtractor.cpp
index 593ebe16..6988edaf 100644
--- a/src/commons/KmerExtractor.cpp
+++ b/src/commons/KmerExtractor.cpp
@@ -317,7 +317,10 @@ void KmerExtractor::loadChunkOfReads(KSeqWrapper *kseq,
             kseq->ReadEntry();
             queryList[processedQueryNum].queryLength2 = LocalUtil::getMaxCoveredLength((int) kseq->entry.sequence.l);
 
-            if (emptyReads[i]) { continue; }
+            if (emptyReads[i]) {
+                count ++;
+                continue;
+            }
 
             // Check if the read is too short
             int kmerCnt = LocalUtil::getQueryKmerNumber((int) kseq->entry.sequence.l, spaceNum);
diff --git a/src/commons/LocalParameters.cpp b/src/commons/LocalParameters.cpp
index 77db0e6a..29667bbd 100644
--- a/src/commons/LocalParameters.cpp
+++ b/src/commons/LocalParameters.cpp
@@ -159,6 +159,13 @@
                   typeid(float),
                   (void *) &tieRatio,
                   "^0(\\.[0-9]+)?|1(\\.0+)?$"),
+        TARGET_TAX_ID(TARGET_TAX_ID_ID,
+                      "--tax-id",
+                      "Tax. ID of clade to be extracted",
+                      "Tax. ID of clade to be extracted",
+                      typeid(int),
+                      (void *) &targetTaxId,
+                      "^[0-9]+$"),
         LIBRARY_PATH(LIBRARY_PATH_ID,
                      "--library-path",
                      "Path to library where the FASTA files are stored",
@@ -385,6 +392,11 @@
     classify.push_back(&TIE_RATIO);
     // classify.push_back(&MIN_SS_MATCH);
 
+    // extract
+    extract.push_back(&TAXONOMY_PATH);
+    extract.push_back(&SEQ_MODE);
+    extract.push_back(&TARGET_TAX_ID);
+
     // filter
     filter.push_back(&PARAM_THREADS);
     filter.push_back(&SEQ_MODE);
diff --git a/src/commons/LocalParameters.h b/src/commons/LocalParameters.h
index d778cd0b..5203e281 100644
--- a/src/commons/LocalParameters.h
+++ b/src/commons/LocalParameters.h
@@ -22,6 +22,7 @@ class LocalParameters : public Parameters {
     }
 
     std::vector<MMseqsParameter*> classify;
+    std::vector<MMseqsParameter*> extract;
    std::vector<MMseqsParameter*> filter;
     std::vector<MMseqsParameter*> exclusiontest_hiv;
     std::vector<MMseqsParameter*> seqHeader2TaxId;
@@ -59,6 +60,9 @@ class LocalParameters : public Parameters {
     PARAMETER(MIN_SS_MATCH)
     PARAMETER(TIE_RATIO)
 
+    // extract
+    PARAMETER(TARGET_TAX_ID)
+
     // DB build parameters
     PARAMETER(LIBRARY_PATH)
     PARAMETER(TAXONOMY_PATH)
@@ -109,6 +113,9 @@ class LocalParameters : public Parameters {
     int minSSMatch;
     float tieRatio;
 
+    // Extract
+    int targetTaxId;
+
     // Database creation
     std::string tinfoPath;
     std::string libraryPath;
diff --git a/src/commons/Reporter.cpp b/src/commons/Reporter.cpp
index 7c813058..23f6b663 100644
--- a/src/commons/Reporter.cpp
+++ b/src/commons/Reporter.cpp
@@ -2,6 +2,7 @@
 #include "taxonomyreport.cpp"
 
 Reporter::Reporter(const LocalParameters &par, NcbiTaxonomy *taxonomy) : taxonomy(taxonomy) {
+    if (par.targetTaxId != 0) {return;}
     if (par.contamList == "") { // classify module
         if (par.seqMode == 2) {
             outDir = par.filenames[3];
@@ -13,10 +14,7 @@ Reporter::Reporter(const LocalParameters &par, NcbiTaxonomy *taxonomy) : taxonom
         // Output file names
         reportFileName = outDir + + "/" + jobId + "_report.tsv";
         readClassificationFileName = outDir + "/" + jobId + "_classifications.tsv";
-    }
-
-
-
+    }
 }
 
 void Reporter::openReadClassificationFile() {
@@ -105,4 +103,98 @@ unsigned int Reporter::cladeCountVal(const std::unordered_map<TaxID, TaxonCounts> &map, TaxID key) {
     } else {
         return it->second.cladeCount;
     }
-}
\ No newline at end of file
+}
+
+void Reporter::getReadsClassifiedToClade(TaxID cladeId,
+                                         const string &readClassificationFileName,
+                                         vector<size_t> &readIdxs) {
+    FILE *results = fopen(readClassificationFileName.c_str(), "r");
+    if (!results) {
+        perror("Failed to open read-by-read classification file");
+        return;
+    }
+
+    char line[4096];
+    size_t idx = 0;
+    // int classification;
+
+    while (fgets(line, sizeof(line), results)) {
+        int taxId;
+        if (sscanf(line, "%*s %*s %d", &taxId) == 1) {
+            if (taxonomy->IsAncestor(cladeId, taxId)) {
+                readIdxs.push_back(idx);
+            }
+        }
+        idx++;
+    }
+
+    fclose(results);
+}
+
+void Reporter::printSpecifiedReads(const vector<size_t> & readIdxs,
+                                   const string & readFileName,
+                                   const string & outFileName) {
+    // Check FASTA or FASTQ
+    KSeqWrapper* tempKseq = KSeqFactory(readFileName.c_str());
+    tempKseq->ReadEntry();
+    bool isFasta = tempKseq->entry.qual.l == 0;
+    delete tempKseq;
+
+    KSeqWrapper* kseq = KSeqFactory(readFileName.c_str());
+
+    FILE *outFile = fopen(outFileName.c_str(), "w");
+    if (!outFile) {
+        perror("Failed to open file");
+        delete kseq; // do not leak the reader on failure
+        return;
+    }
+
+    // Nothing to extract; readIdxs[idx] below would be out of bounds.
+    if (readIdxs.empty()) {
+        fclose(outFile);
+        delete kseq;
+        return;
+    }
+
+    size_t readCnt = 0;
+    size_t idx = 0;
+
+    if (isFasta) {
+        while (kseq->ReadEntry()) {
+            if (readCnt == readIdxs[idx]) {
+                fprintf(outFile, ">%s\n%s\n", kseq->entry.name.s, kseq->entry.sequence.s);
+                idx++;
+                if (idx == readIdxs.size()) {
+                    break;
+                }
+            }
+            readCnt++;
+        }
+    } else {
+        while (kseq->ReadEntry()) {
+            if (readCnt == readIdxs[idx]) {
+                fprintf(outFile, "@%s", kseq->entry.name.s);
+                if (kseq->entry.comment.l > 0) {
+                    fprintf(outFile, " %s\n", kseq->entry.comment.s);
+                } else {
+                    fprintf(outFile, "\n");
+                }
+                fprintf(outFile, "%s\n", kseq->entry.sequence.s);
+                fprintf(outFile, "+%s", kseq->entry.name.s);
+                if (kseq->entry.comment.l > 0) {
+                    fprintf(outFile, " %s\n", kseq->entry.comment.s);
+                } else {
+                    fprintf(outFile, "\n");
+                }
+                fprintf(outFile, "%s\n", kseq->entry.qual.s);
+
+                idx++;
+                if (idx == readIdxs.size()) {
+                    break;
+                }
+            }
+            readCnt++;
+        }
+    }
+    fclose(outFile); // flush and close the extracted-reads file
+    delete kseq;
+}
\ No newline at end of file
diff --git a/src/commons/Reporter.h b/src/commons/Reporter.h
index 4745bf26..5b5af7e6 100644
--- a/src/commons/Reporter.h
+++ b/src/commons/Reporter.h
@@ -6,10 +6,10 @@
 #include <unordered_map>
 #include "NcbiTaxonomy.h"
 #include "LocalParameters.h"
+#include "KSeqWrapper.h"
 
 using namespace std;
 
-
 class Reporter {
 private:
     string outDir;
@@ -35,6 +35,15 @@ class Reporter {
     unsigned int cladeCountVal(const std::unordered_map<TaxID, TaxonCounts> &map,
                                TaxID key);
 
+    // Extract reads classified to a specific clade
+    void getReadsClassifiedToClade(TaxID cladeId,
+                                   const string &readClassificationFileName,
+                                   vector<size_t> &readIdxs);
+
+    void printSpecifiedReads(const vector<size_t> & readIdxs,
+                             const string & readFileName,
+                             const string & outFileName);
+
     // Setter
     void setReportFileName(const string &reportFileName) {
         Reporter::reportFileName = reportFileName;
diff --git a/src/metabuli.cpp b/src/metabuli.cpp
index b523abce..4326ecfb 100644
--- a/src/metabuli.cpp
+++ b/src/metabuli.cpp
@@ -69,10 +69,19 @@
         "Jaebeom Kim <jbeom0731@gmail.com>",
         "<i:FASTA/Q> <i:DBDIR> <o:OUTDIR> <job ID>",
         CITATION_SPACEPHARER,
-        {{"FASTA", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA | DbType::VARIADIC, &DbValidator::flatfile},
+        {{"FASTA/Q", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA | DbType::VARIADIC, &DbValidator::flatfile},
          {"DB dir", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::directory},
          {"out dir", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::directory},
          {"job ID", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::flatfile}}},
+    {"extract", extract, &localPar.extract, COMMAND_MAIN,
+     "It extracts reads classified into a certain taxon. It should be used after classification.",
+     nullptr,
+     "Jaebeom Kim <jbeom0731@gmail.com>",
+     "<i:FASTA/Q> <i:read-by-read result> <i:DBDIR>",
+     CITATION_SPACEPHARER,
+     {{"FASTA/Q", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA | DbType::VARIADIC, &DbValidator::flatfile},
+      {"read-by-read result", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::flatfile},
+      {"DB dir", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::directory}}},
     {"grade", grade, &localPar.grade, COMMAND_EXPERT,
      "Grade the classification result (only for benchmarking)",
      nullptr,
diff --git a/src/workflow/CMakeLists.txt b/src/workflow/CMakeLists.txt
index 66d1fb0c..2f1589e1 100644
--- a/src/workflow/CMakeLists.txt
+++ b/src/workflow/CMakeLists.txt
@@ -4,4 +4,5 @@ set(workflow_source_files
         workflow/add_to_library.cpp
         workflow/build.cpp
         workflow/filter.cpp
+        workflow/extract.cpp
         PARENT_SCOPE)
\ No newline at end of file
diff --git a/src/workflow/extract.cpp b/src/workflow/extract.cpp
new file mode 100644
index 00000000..21dc44f9
--- /dev/null
+++ b/src/workflow/extract.cpp
@@ -0,0 +1,89 @@
+#include "LocalParameters.h"
+#include "FileUtil.h"
+#include "common.h"
+#include "Reporter.h"
+
+void setExtractDefaults(LocalParameters & par){
+    par.taxonomyPath = "" ;
+    par.seqMode = 2;
+    par.targetTaxId = 0;
+}
+
+// Build "<base>_<taxId><ext>" next to the query file. A trailing ".gz" is
+// dropped (the extracted reads are written uncompressed) and file names
+// without any '.' are handled instead of letting substr(npos) throw.
+static string getExtractOutFileName(const string & queryFileName, int targetTaxId) {
+    string suffix = "_" + to_string(targetTaxId);
+    size_t lastDotPos = queryFileName.find_last_of('.');
+    if (lastDotPos == string::npos) {
+        return queryFileName + suffix;
+    }
+    if (queryFileName.substr(lastDotPos) == ".gz") {
+        string withoutGz = queryFileName.substr(0, lastDotPos);
+        lastDotPos = withoutGz.find_last_of('.');
+        if (lastDotPos == string::npos) {
+            return withoutGz + suffix;
+        }
+        return withoutGz.substr(0, lastDotPos) + suffix + withoutGz.substr(lastDotPos);
+    }
+    return queryFileName.substr(0, lastDotPos) + suffix + queryFileName.substr(lastDotPos);
+}
+
+int extract(int argc, const char **argv, const Command& command)
+{
+    LocalParameters & par = LocalParameters::getLocalInstance();
+    setExtractDefaults(par);
+    par.parseParameters(argc, argv, command, true, Parameters::PARSE_ALLOW_EMPTY, 0);
+
+    if (par.seqMode == 2) {
+        // Check if the second argument is a directory
+        if (FileUtil::directoryExists(par.filenames[2].c_str())) {
+            cerr << "For '--seq-mode 2', please provide two query files." << endl;
+            exit(1);
+        }
+    } else {
+        // Check if the second argument is file
+        if (!FileUtil::directoryExists(par.filenames[2].c_str())) {
+            cerr << "For '--seq-mode 1' and '--seq-mode 3', please provide one query file." << endl;
+            exit(1);
+        }
+    }
+
+    if (par.targetTaxId == 0) {
+        cerr << "Please provide a target taxon ID with --tax-id parameter." << endl;
+        exit(1);
+    }
+
+    string classificationFileName = par.filenames[1 + (par.seqMode == 2)];
+    string dbDir = par.filenames[2 + (par.seqMode == 2)];
+    TaxID targetTaxID = par.targetTaxId;
+
+    cout << "Loading taxonomy ... " << endl;
+    NcbiTaxonomy *taxonomy = loadTaxonomy(dbDir, par.taxonomyPath);
+    Reporter reporter(par, taxonomy);
+
+    vector<size_t> readIdxs;
+
+    cout << "Extracting reads classified to taxon " << targetTaxID << " ... " << flush;
+    reporter.getReadsClassifiedToClade(targetTaxID, classificationFileName, readIdxs);
+    cout << "done." << endl;
+
+    string queryFileName = par.filenames[0];
+    string outFileName = getExtractOutFileName(queryFileName, targetTaxID);
+
+    cout << "Writing extracted reads to " << outFileName << " ... " << flush;
+    reporter.printSpecifiedReads(readIdxs, queryFileName, outFileName);
+    cout << "done." << endl;
+
+    // Repeat for the second file of a paired-end run
+    if (par.seqMode == 2) {
+        queryFileName = par.filenames[1];
+        outFileName = getExtractOutFileName(queryFileName, targetTaxID);
+        cout << "Writing extracted reads to " << outFileName << " ... " << flush;
+        reporter.printSpecifiedReads(readIdxs, queryFileName, outFileName);
+        cout << "done." << endl;
+    }
+
+    delete taxonomy;
+    return 0;
+}
\ No newline at end of file