Skip to content

Commit

Permalink
add visualisation report
Browse files Browse the repository at this point in the history
  • Loading branch information
ziadbkh committed May 20, 2024
1 parent 7a183c8 commit 17eef1c
Show file tree
Hide file tree
Showing 12 changed files with 239 additions and 149 deletions.
Empty file added assets/NO_FILE
Empty file.
28 changes: 28 additions & 0 deletions bin/extract_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python
import pickle
import os, sys
import argparse

def read_pkl(id, pkl_files):
    """Export MSA and pLDDT data from pickle files to TSV files.

    For every path in ``pkl_files``:

    * a path ending in ``features.pkl`` has its ``'msa'`` entry written to
      ``{id}_msa.tsv``, one tab-separated line per MSA row;
    * any other path has its ``'plddt'`` entry written as a single
      tab-separated line to ``{id}_lddt_{model_id}.tsv``, where ``model_id``
      is the file's base name with ``result_model_`` and ``_pred_0.pkl``
      removed.

    Output files are created in the current working directory.

    Args:
        id: Prefix for the output file names. (The name shadows the builtin
            ``id``; it is kept so existing callers keep working.)
        pkl_files: Iterable of paths to the pickle files to export.

    Raises:
        KeyError: If an expected ``'msa'`` / ``'plddt'`` entry is missing.
        OSError: If an input cannot be read or an output cannot be written.
    """
    for pkl_file in pkl_files:
        # NOTE: unpickling can execute arbitrary code; only pass files that
        # come from a trusted source.
        # The context manager guarantees the handle is closed even when
        # unpickling fails (the handle used to be leaked).
        with open(pkl_file, "rb") as in_f:
            dict_data = pickle.load(in_f)

        if pkl_file.endswith("features.pkl"):
            with open(f"{id}_msa.tsv", "w") as out_f:
                for val in dict_data['msa']:
                    out_f.write("\t".join(str(x) for x in val) + "\n")
        else:
            # Every non-feature pickle is treated as a per-model result file.
            model_id = (
                os.path.basename(pkl_file)
                .replace("result_model_", "")
                .replace("_pred_0.pkl", "")
            )
            with open(f"{id}_lddt_{model_id}.tsv", "w") as out_f:
                out_f.write("\t".join(str(x) for x in dict_data['plddt']) + "\n")


# Command-line entry point: parse the pickle paths and the sample name, then
# export their contents as TSV files.
# --pkls        one or more pickle files (required)
# --name        prefix for the generated file names (default: empty string)
# --output_dir  accepted by the parser but not used by this script
parser = argparse.ArgumentParser()
parser.add_argument('--pkls', dest='pkls', required=True, nargs="+")
parser.add_argument('--name', dest='name', default='')
parser.add_argument('--output_dir', dest='output_dir', default='')
args = parser.parse_args()

read_pkl(args.name, args.pkls)
118 changes: 66 additions & 52 deletions assets/generat_plots_2.py → bin/generat_plots_2.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,69 @@
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def generate_output_images(msa_path, plddt_paths, name, out_dir):
def generate_output_images(msa_path, plddt_paths, name, out_dir, in_type):
msa = []
with open(msa_path, 'r') as in_file:
for line in in_file:
msa.append([int(x) for x in line.strip().split()])
if not msa_path.endswith("NO_FILE"):
with open(msa_path, 'r') as in_file:
for line in in_file:
msa.append([int(x) for x in line.strip().split()])

seqid = []
for sequence in msa:
matches = [1.0 if first == other else 0.0 for first, other in zip(msa[0], sequence)]
seqid.append(sum(matches) / len(matches))

seqid_sort = sorted(range(len(seqid)), key=seqid.__getitem__)
seqid = []
for sequence in msa:
matches = [1.0 if first == other else 0.0 for first, other in zip(msa[0], sequence)]
seqid.append(sum(matches) / len(matches))

non_gaps = []
for sequence in msa:
non_gaps.append([float(num != 21) if num != 21 else float('nan') for num in sequence])

sorted_non_gaps = [non_gaps[i] for i in seqid_sort]
final = []
for sorted_seq, identity in zip(sorted_non_gaps, [seqid[i] for i in seqid_sort]):
final.append([value * identity if not isinstance(value, str) else value for value in sorted_seq])
seqid_sort = sorted(range(len(seqid)), key=seqid.__getitem__)

non_gaps = []
for sequence in msa:
non_gaps.append([float(num != 21) if num != 21 else float('nan') for num in sequence])

sorted_non_gaps = [non_gaps[i] for i in seqid_sort]
final = []
for sorted_seq, identity in zip(sorted_non_gaps, [seqid[i] for i in seqid_sort]):
final.append([value * identity if not isinstance(value, str) else value for value in sorted_seq])

##################################################################
plt.figure(figsize=(14, 14), dpi=100)
##################################################################
plt.title("Sequence coverage")
plt.imshow(final,
interpolation='nearest', aspect='auto',
cmap="rainbow_r", vmin=0, vmax=1, origin='lower')

column_counts = [0] * len(msa[0])
for col in range(len(msa[0])):
for row in msa:
if row[col] != 21:
column_counts[col] += 1

plt.plot(column_counts, color='black')
plt.xlim(-0.5, len(msa[0]) - 0.5)
plt.ylim(-0.5, len(msa) - 0.5)

plt.colorbar(label="Sequence identity to query", )
plt.xlabel("Positions")
plt.ylabel("Sequences")
plt.savefig(f"{out_dir}/{name+('_' if name else '')}seq_coverage.png")
##################################################################
plt.figure(figsize=(14, 14), dpi=100)
##################################################################
plt.title("Sequence coverage")
plt.imshow(final,
interpolation='nearest', aspect='auto',
cmap="rainbow_r", vmin=0, vmax=1, origin='lower')

column_counts = [0] * len(msa[0])
for col in range(len(msa[0])):
for row in msa:
if row[col] != 21:
column_counts[col] += 1

plt.plot(column_counts, color='black')
plt.xlim(-0.5, len(msa[0]) - 0.5)
plt.ylim(-0.5, len(msa) - 0.5)

plt.colorbar(label="Sequence identity to query", )
plt.xlabel("Positions")
plt.ylabel("Sequences")
plt.savefig(f"{out_dir}/{name+('_' if name else '')}seq_coverage.png")

##################################################################

##################################################################
plddt_per_model = OrderedDict()
plddt_paths_srt = plddt_paths
plddt_paths_srt.sort()
for plddt_path in plddt_paths_srt:
with open(plddt_path, 'r') as in_file:
plddt_per_model[os.path.basename(plddt_path)[:-4]] = [float(x) for x in in_file.read().strip().split()]
if in_type == "ESM-FOLD":
plddt_per_model[os.path.basename(plddt_path)[:-4]] = []
in_file.readline()
for line in in_file:
vals = line.strip().split()
#print(vals)
if len(vals) == 5:
plddt_per_model[os.path.basename(plddt_path)[:-4]].append(float(vals[-1].strip()))
else:
plddt_per_model[os.path.basename(plddt_path)[:-4]] = [float(x) for x in in_file.read().strip().split()]

plt.figure(figsize=(14, 14), dpi=100)
plt.title("Predicted LDDT per position")
Expand Down Expand Up @@ -104,7 +113,7 @@ def generate_output_images(msa_path, plddt_paths, name, out_dir):
plt.savefig(f"{out_dir}/{name+('_' if name else '')}PAE.png")
"""
##################################################################


def generate_plots(msa_path, plddt_paths, name, out_dir):
msa = []
Expand Down Expand Up @@ -174,18 +183,20 @@ def generate_plots(msa_path, plddt_paths, name, out_dir):

print("Starting..")
parser = argparse.ArgumentParser()
parser.add_argument('--msa',dest='msa',required=True)
parser.add_argument('--plddt',dest='plddt',required=True, nargs="+")
parser.add_argument('--pdb',dest='pdb',required=True, nargs="+")
parser.add_argument('--name',dest='name')
parser.add_argument('--type', dest='in_type')
parser.add_argument('--msa', dest='msa',required=True)
parser.add_argument('--plddt', dest='plddt',required=True, nargs="+")
parser.add_argument('--pdb', dest='pdb',required=True, nargs="+")
parser.add_argument('--name', dest='name')
parser.add_argument('--output_dir',dest='output_dir')
parser.add_argument('--html_template',dest='html_template')
parser.set_defaults(output_dir='')
parser.set_defaults(in_type='ESM-FOLD')
parser.set_defaults(name='')
args = parser.parse_args()


generate_output_images(args.msa, args.plddt, args.name, args.output_dir)
generate_output_images(args.msa, args.plddt, args.name, args.output_dir, args.in_type)

#generate_plots(args.msa, args.plddt, args.name, args.output_dir)

Expand All @@ -202,10 +213,13 @@ def generate_plots(msa_path, plddt_paths, name, out_dir):
i += 1

if True:
with open(f"{args.output_dir}/{args.name + ('_' if args.name else '')}seq_coverage.png", "rb") as in_file:
alphfold_template = alphfold_template.replace(f"seq_coverage.png", f"data:image/png;base64,{base64.b64encode(in_file.read()).decode('utf-8')}")

for i in range(0, 5):
if not args.msa.endswith("NO_FILE"):
with open(f"{args.output_dir}/{args.name + ('_' if args.name else '')}seq_coverage.png", "rb") as in_file:
alphfold_template = alphfold_template.replace("seq_coverage.png", f"data:image/png;base64,{base64.b64encode(in_file.read()).decode('utf-8')}")
else:
alphfold_template = alphfold_template.replace("seq_coverage.png","")

for i in range(0, len(args.plddt)):
with open(f"{args.output_dir}/{args.name + ('_' if args.name else '')}coverage_LDDT_{i}.png", "rb") as in_file:
alphfold_template = alphfold_template.replace(f"coverage_LDDT_{i}.png", f"data:image/png;base64,{base64.b64encode(in_file.read()).decode('utf-8')}")

Expand Down
26 changes: 16 additions & 10 deletions conf/gadi.config
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,28 @@ params {
input = 'https://raw.githubusercontent.com/nf-core/test-datasets/proteinfold/testdata/samplesheet/v1.0/samplesheet.csv'
alphafold2_db = '/g/data/if89/alphafold2/standard/'
use_dgxa100 = false
esmfold_params_path = '/g/data/if89/esm-fold/checkpoints'
}

process {
storage = "gdata/if89+scratch/${params.project}"
storage = "gdata/if89+scratch/${params.project}+gdata/${params.project}"

if (params.use_gpu) {
withName: 'RUN_ALPHAFOLD2_PRED|RUN_ALPHAFOLD2' {
if (params.use_dgxa100){
queue = "dgxa100"
cpus = 16
}else{
queue = "gpuvolta"
cpus = 12
}
gpus = 1

if (params.use_dgxa100){
queue = "dgxa100"
cpus = 16
}else{
queue = "gpuvolta"
cpus = 12
}
gpus = 1
}
}

withName: 'RUN_ESMFOLD' {
queue = "copyq"
cpus = 1
time = 10.h
}
}
2 changes: 1 addition & 1 deletion conf/test_full_esmfold.config
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ params {
mode = 'esmfold'
esmfold_model_preset = 'monomer'
input = 'https://raw.githubusercontent.com/nf-core/test-datasets/proteinfold/testdata/samplesheet/v1.0/samplesheet.csv'
esmfold_db = 's3://proteinfold-dataset/test-data/db/esmfold'
//esmfold_db = 's3://proteinfold-dataset/test-data/db/esmfold'
}
86 changes: 42 additions & 44 deletions main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -63,53 +63,50 @@ workflow NFCORE_PROTEINFOLD {
if(params.mode == "alphafold2") {
//
// SUBWORKFLOW: Prepare Alphafold2 DBs
//
PREPARE_ALPHAFOLD2_DBS (
params.alphafold2_db,
params.full_dbs,
params.bfd_path,
params.small_bfd_path,
params.alphafold2_params_path,
params.mgnify_path,
params.pdb70_path,
params.pdb_mmcif_path,
params.uniref30_alphafold2_path,
params.uniref90_path,
params.pdb_seqres_path,
params.uniprot_path,
params.bfd_link,
params.small_bfd_link,
params.alphafold2_params_link,
params.mgnify_link,
params.pdb70_link,
params.pdb_mmcif_link,
params.pdb_obsolete_link,
params.uniref30_alphafold2_link,
params.uniref90_link,
params.pdb_seqres_link,
params.uniprot_sprot_link,
params.uniprot_trembl_link
)
ch_versions = ch_versions.mix(PREPARE_ALPHAFOLD2_DBS.out.versions)

//
// WORKFLOW: Run nf-core/alphafold2 workflow
//

ch_params = Channel.fromPath( params.alphafold2_params_path )
ch_mgnify = Channel.fromPath( params.mgnify_path )
ch_pdb70 = Channel.fromPath( params.pdb70_path, type: 'dir' )
ch_mmcif_files = Channel.fromPath( params.pdb_mmcif_path, type: 'dir' )
ch_mmcif_obsolete = Channel.fromPath( params.pdb_mmcif_path, type: 'file' )
ch_mmcif = ch_mmcif_files.mix(ch_mmcif_obsolete)
ch_uniref30 = Channel.fromPath( params.uniref30_alphafold2_path, type: 'any' )
ch_uniref90 = Channel.fromPath( params.uniref90_path )
ch_pdb_seqres = Channel.fromPath( params.pdb_seqres_path )
ch_uniprot = Channel.fromPath( params.uniprot_path )
ch_small_bfd = Channel.fromPath( params.small_bfd_path)
ch_bfd = Channel.fromPath( params.bfd_path)

/*ch_params.view()
ch_params.first().view()
ch_bfd.ifEmpty([]).first().view()
ch_small_bfd.ifEmpty([]).first().view()
ch_uniref90.first().view()
ch_pdb_seqres.first().view()
ch_uniprot.first().view()
*/


ALPHAFOLD2 (
ch_versions,
params.full_dbs,
params.alphafold2_mode,
params.alphafold2_model_preset,
PREPARE_ALPHAFOLD2_DBS.out.params.first(),
PREPARE_ALPHAFOLD2_DBS.out.bfd.ifEmpty([]).first(),
PREPARE_ALPHAFOLD2_DBS.out.small_bfd.ifEmpty([]).first(),
PREPARE_ALPHAFOLD2_DBS.out.mgnify.first(),
PREPARE_ALPHAFOLD2_DBS.out.pdb70.first(),
PREPARE_ALPHAFOLD2_DBS.out.pdb_mmcif.first(),
PREPARE_ALPHAFOLD2_DBS.out.uniref30.first(),
PREPARE_ALPHAFOLD2_DBS.out.uniref90.first(),
PREPARE_ALPHAFOLD2_DBS.out.pdb_seqres.first(),
PREPARE_ALPHAFOLD2_DBS.out.uniprot.first()
ch_params.toList(),
ch_bfd.ifEmpty([]).first(),
ch_small_bfd.ifEmpty([]).first(),
ch_mgnify.first(),
ch_pdb70.first(),
ch_mmcif.toList(),
ch_uniref30.toList(),
ch_uniref90.first(),
ch_pdb_seqres.first(),
ch_uniprot.first()
)
ch_multiqc = ALPHAFOLD2.out.multiqc_report
ch_versions = ch_versions.mix(ALPHAFOLD2.out.versions)
Expand Down Expand Up @@ -157,21 +154,22 @@ workflow NFCORE_PROTEINFOLD {
//
// SUBWORKFLOW: Prepare esmfold DBs
//
PREPARE_ESMFOLD_DBS (
/*PREPARE_ESMFOLD_DBS (
params.esmfold_db,
params.esmfold_params_path,
params.esmfold_3B_v1,
params.esm2_t36_3B_UR50D,
params.esm2_t36_3B_UR50D_contact_regression
)
ch_versions = ch_versions.mix(PREPARE_ESMFOLD_DBS.out.versions)
)*/

//ch_versions = ch_versions.mix(PREPARE_ESMFOLD_DBS.out.versions)

//
// WORKFLOW: Run nf-core/esmfold workflow
//
ESMFOLD (
ch_versions,
PREPARE_ESMFOLD_DBS.out.params,
params.esmfold_params_path,
params.num_recycle
)
ch_multiqc = ESMFOLD.out.multiqc_report
Expand Down
18 changes: 9 additions & 9 deletions modules/local/generat_report.nf
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
process GENERATE_REPORT {
tag "$id"
tag "${meta.id}"
label 'process_single'
container "${ workflow.containerEngine == 'singularity' && !task.ext.singularity_pull_docker_container ?
'https://depot.galaxyproject.org/singularity/multiqc:1.21--pyhdfd78af_0' :
'biocontainers/multiqc:1.21--pyhdfd78af_0' }"

input:
tuple val(id), path(msa)
tuple val(id), path(lddt)
tuple val(id), path(pdb)
tuple val(meta_msa), path(msa)
tuple val(meta), path(lddt)
tuple val(meta), path(pdb)
path(template)
path(script)
val(output_type)

output:
tuple val(id), path ("*.html"), emit: report
tuple val(id), path ("*.png"), emit: images
tuple val(meta), path ("*.html"), emit: report
tuple val(meta), path ("*.png"), emit: images
//path "versions.yml", emit: versions

when:
Expand All @@ -22,7 +23,6 @@ process GENERATE_REPORT {
def args = task.ext.args ?: ''

"""
#export MPLCONFIGDIR=\$PBS_JOBFS
python ./generat_plots_2.py --msa ${msa} --plddt ${lddt.join(' ')} --pdb ${pdb.join(' ')} --html_template ${template} --output_dir ./ --name ${id} || true
generat_plots_2.py --type ${output_type} --msa ${msa} --plddt ${lddt.join(' ')} --pdb ${pdb.join(' ')} --html_template ${template} --output_dir ./ --name ${meta.id}
"""
}
Loading

0 comments on commit 17eef1c

Please sign in to comment.