Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
bacf5ce
returning dgram
gezmi Feb 11, 2024
f467bdf
added calculation for interface ptm
gezmi Feb 11, 2024
44ded08
returning if_ptm between all chain combinations
gezmi Feb 11, 2024
864ae44
correct import of get_ifptm
gezmi Feb 11, 2024
ffa0c74
corrected if_ptm calculations
gezmi Feb 11, 2024
e48b3d7
updated if_ptm function call
gezmi Feb 11, 2024
ac78513
make get_ifptm compatible with jax
gezmi Feb 11, 2024
8478135
updated chain_indices function compatible with jax
gezmi Feb 11, 2024
673152e
getting chain indices from other source
gezmi Feb 11, 2024
db3fe86
geeting chain indices from other source
gezmi Feb 11, 2024
30a4e50
Update model.py
gezmi Feb 11, 2024
1cba10f
removed if_ptm code from model
gezmi Feb 11, 2024
43ccdab
working get_ifptm
gezmi Feb 11, 2024
62b45be
working get_ifptm
gezmi Feb 11, 2024
34c081a
get_ifptm works but gives unexpected results
gezmi Feb 17, 2024
c7d1030
modified confidence to debug if_ptm
gezmi Feb 17, 2024
5972815
Update confidence.py
gezmi Feb 19, 2024
0a88639
updated ifptm in case no interface exists
gezmi Feb 19, 2024
d08bc19
Update loss.py
gezmi Apr 3, 2024
e3201a2
Update loss.py
gezmi Apr 9, 2024
1228576
Update model.py
gezmi Apr 9, 2024
24c68f5
Update confidence.py
gezmi Apr 9, 2024
9a72ff5
Merge branch 'sokrypton:gamma' into gamma
gezmi Apr 9, 2024
fee5405
Revert "fixing jax.random.PRNGKeyArray error"
gezmi Apr 9, 2024
ad556f7
Revert "Update confidence.py"
gezmi Apr 9, 2024
ebc7c96
Revert "Update model.py"
gezmi Apr 9, 2024
3027899
Revert "Update loss.py"
gezmi Apr 9, 2024
0455f6e
Revert "Revert "fixing jax.random.PRNGKeyArray error""
gezmi Apr 9, 2024
7594bec
remove error from confidence
gezmi Apr 10, 2024
07e4d0d
working function for getting if_ptm
gezmi Apr 11, 2024
83bbb0b
ptm updated to accommodate if_ptm
gezmi Apr 11, 2024
3bb4a24
indentation correct
gezmi Apr 11, 2024
c1ccdff
update to accept custom pair_mask
gezmi Jul 7, 2024
dff410d
update with pairwise iptm and ifptm
gezmi Jul 7, 2024
f4027d7
added 2 different ifptm calculations
gezmi Jul 7, 2024
c8cf98f
calculate tm_score with contact probabilities
gezmi Jul 9, 2024
30dc15b
update loss for ifptm
gezmi Jul 9, 2024
694cc6b
probability mask usage corrected
gezmi Jul 9, 2024
1cfa7c3
changed input to ifptm calculation
gezmi Jul 9, 2024
e2c8414
changing back to pair-residue mask
gezmi Jul 9, 2024
77f4426
Update loss.py with correct input
gezmi Jul 9, 2024
e1172ce
Update confidence.py
gezmi Jul 9, 2024
61a0c59
Created using Colab
gezmi Oct 27, 2024
73be301
added interface metrics to colabdesign
Oct 27, 2024
4ae4f82
Merge remote-tracking branch 'origin/gamma' into gamma
Oct 27, 2024
c312f53
added interface analysis
gezmi Oct 28, 2024
e7f16c7
added interface metrics to colabdesign
Oct 28, 2024
3806a99
Merge remote-tracking branch 'origin/gamma' into gamma
Oct 28, 2024
ff0249b
added chain ptm calculation
Oct 28, 2024
f695435
finished new ptm plots
gezmi Oct 28, 2024
e617dd2
corrected upper and lower triangle
gezmi Oct 28, 2024
d1b6721
renamed cell, cleared output
gezmi Oct 28, 2024
3afe6dc
renamed to actifptm
Oct 29, 2024
803216a
Merge remote-tracking branch 'origin/gamma' into gamma
Oct 29, 2024
4557594
created plotting for pariwise ptm-s and chain ptm
gezmi Oct 29, 2024
860122e
Rename chains_ptms.py to extended_metrics.py
gezmi Nov 5, 2024
5b30946
adding extended metrics from colabfold
gezmi Nov 5, 2024
752e475
Delete extended_metrics, replace from colabfold
gezmi Nov 5, 2024
d14e96a
fixed calculation of extended metrics
gezmi Nov 5, 2024
f155762
running all models by default
gezmi Nov 20, 2024
053c439
Merge branch 'sokrypton:gamma' into gamma
gezmi Nov 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions af/examples/predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"colab": {
"provenance": [],
"gpuType": "T4",
"machine_shape": "hm",
"include_colab_link": true
},
"kernelspec": {
Expand All @@ -24,7 +25,7 @@
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabDesign/blob/gamma/af/examples/predict.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/gezmi/ColabDesign/blob/gamma/af/examples/predict.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
Expand Down Expand Up @@ -56,9 +57,10 @@
" aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n",
" tar -xf alphafold_params_2022-12-06.tar -C params; touch params/done.txt )&\")\n",
"\n",
" os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@gamma\")\n",
" os.system(\"pip -q install git+https://github.com/gezmi/ColabDesign.git@gamma\") # TODO: change to sokrypton\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n",
" os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py\")\n",
" os.system(\"wget https://raw.githubusercontent.com/gezmi/ColabFold/actifptm/colabfold/alphafold/extended_metrics.py -O extended_metrics.py\") # TODO: change to sokrypton\n",
" #os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py\")\n",
"\n",
" # install hhsuite\n",
Expand All @@ -83,7 +85,7 @@
"from colabdesign.af.contrib.cyclic import add_cyclic_offset\n",
"from colabdesign.shared.protein import _np_rmsd, _np_kabsch\n",
"from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap\n",
"\n",
"import extended_metrics\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -190,7 +192,7 @@
"cell_type": "code",
"source": [
"#@title prep_inputs\n",
"sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n",
"sequence = \"PASQHFLSTSVQGPWERAISPNKVPYYINHETQTTCWDHPKM:KNMTPYRSPPPYVPP\" #@param {type:\"string\"}\n",
"jobname = \"test\" #@param {type:\"string\"}\n",
"\n",
"copies = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"11\", \"12\"] {type:\"raw\"}\n",
Expand Down Expand Up @@ -403,7 +405,7 @@
"cell_type": "code",
"source": [
"#@title pre_analysis (optional)\n",
"analysis = \"none\" # @param [\"none\", \"coevolution\"]\n",
"analysis = \"coevolution\" # @param [\"none\", \"coevolution\"]\n",
"dpi = 100 # @param [\"100\", \"200\", \"300\"] {type:\"raw\"}\n",
"if analysis == \"coevolution\":\n",
" coevol = get_coevolution(msa)\n",
Expand Down Expand Up @@ -434,6 +436,13 @@
"num_msa = 512 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\"] {type:\"raw\"}\n",
"num_extra_msa = 1024 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\", \"1024\",\"2048\",\"4096\"] {type:\"raw\"}\n",
"use_cluster_profile = True #@param {type:\"boolean\"}\n",
"\n",
"#@markdown Extended metrics (calculate pairwise ipTM, actifpTM and chain pTM)\n",
"calc_extended_metrics = True #@param {type:\"boolean\"}\n",
"\n",
"if extended_metrics:\n",
" debug=True\n",
"\n",
"if model_type == \"monomer (ptm)\":\n",
" use_multimer = False\n",
" pseudo_multimer = False\n",
Expand Down Expand Up @@ -526,8 +535,8 @@
" add_cyclic_offset(af,i_cyclic)"
],
"metadata": {
"cellView": "form",
"id": "0AefVJipkQe3"
"id": "0AefVJipkQe3",
"cellView": "form"
},
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -617,6 +626,9 @@
" tag = f\"{model}_r{recycle}_seed{seed}\"\n",
" if select_best_across_recycles:\n",
" info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n",
" if calc_extended_metrics:\n",
" extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n",
" info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n",
" af._save_results(save_best=True,\n",
" best_metric=rank_by, metric_higher_better=True,\n",
" verbose=False)\n",
Expand All @@ -627,6 +639,9 @@
"\n",
" if not select_best_across_recycles:\n",
" info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n",
" if calc_extended_metrics:\n",
" extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n",
" info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n",
" af._save_results(save_best=True,\n",
" best_metric=rank_by, metric_higher_better=True,\n",
" verbose=False)\n",
Expand All @@ -648,12 +663,18 @@
" plddt=aux_best[\"plddt\"].astype(np.float16),\n",
" pae=aux_best[\"pae\"].astype(np.float16),\n",
" tag=np.array(info[rank[0]][0]),\n",
" metrics=np.array(info[rank[0]][1]))\n",
" metrics=np.array(info[rank[0]][1]),\n",
" iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n",
" actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n",
" cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n",
"np.savez_compressed(f\"{pdb_path}/all.npz\",\n",
" plddt=np.array(af._tmp[\"traj\"][\"plddt\"], dtype=np.float16),\n",
" pae=np.array(af._tmp[\"traj\"][\"pae\"], dtype=np.float16),\n",
" tag=np.array([x[0] for x in info]),\n",
" metrics=np.array([x[1] for x in info]))\n",
" metrics=np.array([x[1] for x in info]),\n",
" iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n",
" actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n",
" cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n",
"plot_3D(aux_best, Ls * copies, f\"{pdb_path}/best.pdf\", show=False)\n",
"predict.plot_confidence(aux_best[\"plddt\"]*100, aux_best[\"pae\"], Ls * copies)\n",
"plt.savefig(f\"{pdb_path}/best.png\", dpi=200, bbox_inches='tight')\n",
Expand All @@ -663,8 +684,8 @@
"print(\"GC\",gc.collect())"
],
"metadata": {
"cellView": "form",
"id": "xYXcKFPQyTQU"
"id": "xYXcKFPQyTQU",
"cellView": "form"
},
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -785,6 +806,20 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title chainwise_and_pairwise_analysis\n",
"if calc_extended_metrics:\n",
" extended_metrics.plot_chain_pairwise_analysis(info, prefix='model_')"
],
"metadata": {
"id": "WoLg1yqYDbiX",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
Expand Down
17 changes: 10 additions & 7 deletions colabdesign/af/alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute_predicted_aligned_error(logits, breaks, use_jnp=False):
}

def predicted_tm_score(logits, breaks, residue_weights = None,
asym_id = None, use_jnp=False):
asym_id = None, use_jnp=False, pair_residue_weights=None):
"""Computes predicted TM alignment or predicted interface TM alignment score.

Args:
Expand All @@ -122,6 +122,7 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation.
pair_residue_weights: [num_res, num_res] unnormalized weights for ifptm calculation

Returns:
ptm_score: The predicted TM alignment or the predicted iTM score.
Expand All @@ -135,23 +136,23 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
# exp. resolved head's probability.
if residue_weights is None:
residue_weights = _np.ones(logits.shape[0])

bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
num_res = residue_weights.shape[0]

# Clip num_res to avoid negative/undefined d0.
clipped_num_res = _np.maximum(residue_weights.sum(), 19)
clipped_num_res = _np.maximum(num_res, 19)

# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

# Convert logits to probs.
probs = _softmax(logits, axis=-1)

# TM-Score term for every bin.
tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0))

# E_distances tm(distance).
predicted_tm_term = (probs * tm_per_bin).sum(-1)

Expand All @@ -162,8 +163,10 @@ def predicted_tm_score(logits, breaks, residue_weights = None,

predicted_tm_term *= pair_mask

pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
# If normed_residue_mask is provided (e.g. for if_ptm with contact probabilities),
# it should not be overwritten
if pair_residue_weights is None:
pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True))
per_alignment = (predicted_tm_term * normed_residue_mask).sum(-1)

return (per_alignment * residue_weights).max()
return (per_alignment * residue_weights).max()
90 changes: 89 additions & 1 deletion colabdesign/af/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,4 +454,92 @@ def or_masks(*m):
mask = m[0]
for n in range(1,len(m)):
mask = jnp.logical_or(mask,m[n])
return mask
return mask

def numpy_callback(x):
# Need to forward-declare the shape & dtype of the expected output.
result_shape = jax.core.ShapedArray(x.shape, x.dtype)
return jax.pure_callback(np.sin, result_shape, x)

def get_chain_indices(chain_boundaries):
"""Returns a list of tuples indicating the start and end indices for each chain."""
chain_starts_ends = []
unique_chains = np.unique(chain_boundaries)
for chain in unique_chains:
positions = np.where(chain_boundaries == chain)[0]
chain_starts_ends.append((positions[0], positions[-1]))
return chain_starts_ends

def get_pairwise_iptm(af, calculate_interface=False):
import string
from copy import deepcopy
cmap = get_contact_map(af.aux['debug']['outputs'], 8) # Define interface with 8A between Cb-s

# Initialize seq_mask to all False
inputs_ifptm = {}
input_pairwise_iptm = {}
pairwise_iptm = {}

# Prepare a dictionary to collect results
pairwise_if_ptm = {}
chain_starts_ends = get_chain_indices(af._inputs['asym_id'])

# Generate chain labels (A, B, C, ...)
chain_labels = list(string.ascii_uppercase)

total_length = len(af._inputs['asym_id'])
for i, (start_i, end_i) in enumerate(chain_starts_ends):
chain_label_i = chain_labels[i % len(chain_labels)] # Wrap around if more than 26 chains
for j, (start_j, end_j) in enumerate(chain_starts_ends):
chain_label_j = chain_labels[j % len(chain_labels)] # Wrap around if more than 26 chains
if i < j: # Avoid self-comparison and duplicate comparisons
outputs = deepcopy(af.aux['debug']['outputs'])
key = f"{chain_label_i}-{chain_label_j}"

if calculate_interface:
contacts = np.where(cmap[start_i:end_i+1, start_j:end_j+1] >= 0.6)

if contacts[0].size > 0: # If there are contacts
# Convert local chain positions back to global positions using JAX
global_i_positions = contacts[0] + start_i
global_j_positions = contacts[1] + start_j
global_positions = list(set(np.concatenate((global_i_positions, global_j_positions))))
global_positions = np.array(global_positions, dtype=int)
global_positions.sort()

# Initialize new input dictionary
inputs_ifptm['seq_mask'] = np.full(total_length, 0, dtype=float)
inputs_ifptm['asym_id'] = af._inputs['asym_id']
# Update seq_mask for these positions to True within inputs
inputs_ifptm['seq_mask'][global_positions] = 1
# Call get_ptm with updated inputs and outputs
pairwise_if_ptm[key] = get_ptm(inputs_ifptm, outputs, interface=True)
else:
pairwise_if_ptm[key] = 0
else:
cmap_copy = np.zeros((total_length, total_length))
cmap_copy[start_i:end_i+1, start_j:end_j+1] = cmap[start_i:end_i+1, start_j:end_j+1]
cmap_copy[start_j:end_j+1, start_i:end_i+1] = cmap[start_j:end_j+1, start_i:end_i+1]

# Initialize new input dictionary
inputs_ifptm['seq_mask'] = np.full(total_length, 0, dtype=float)
inputs_ifptm['asym_id'] = af._inputs['asym_id']
# Update seq_mask for these positions to True within inputs
inputs_ifptm['seq_mask'][np.concatenate((np.arange(start_i,end_i+1),
np.arange(start_j,end_j+1)))] = 1

# Call get_ptm with updated inputs and outputs
pae = {"residue_weights":inputs_ifptm["seq_mask"],
**outputs["predicted_aligned_error"]}
pae["asym_id"] = inputs_ifptm["asym_id"]
pairwise_if_ptm[key] = confidence.predicted_tm_score(**pae, use_jnp=True, pair_residue_weights=cmap_copy)

# Also adding regular i_ptm (interchain), pairwise
outputs = deepcopy(af.aux['debug']['outputs'])
input_pairwise_iptm['seq_mask'] = np.full(len(af._inputs['asym_id']), 0, dtype=float)
input_pairwise_iptm['asym_id'] = af._inputs['asym_id']
input_pairwise_iptm['seq_mask'][np.concatenate((np.arange(start_i,end_i+1),
np.arange(start_j,end_j+1)))] = 1
pairwise_iptm[key] = get_ptm(input_pairwise_iptm, outputs, interface=True)

return pairwise_if_ptm, pairwise_iptm
6 changes: 3 additions & 3 deletions colabdesign/af/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from colabdesign.af.prep import _af_prep
from colabdesign.af.loss import _af_loss, get_plddt, get_pae, get_ptm
from colabdesign.af.loss import get_contact_map
from colabdesign.af.loss import get_contact_map, get_dgram_bins
from colabdesign.af.utils import _af_utils
from colabdesign.af.design import _af_design
from colabdesign.af.inputs import _af_inputs
Expand Down Expand Up @@ -210,8 +210,8 @@ def _model(params, model_params, inputs, key):
"ptm": get_ptm(inputs, outputs),
"i_ptm": get_ptm(inputs, outputs, interface=True),
"cmap": get_contact_map(outputs, opt["con"]["cutoff"]),
"i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]),
"prev": outputs["prev"]})
"i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]),
"prev": outputs["prev"]})

#######################################################################
# LOSS
Expand Down