diff --git a/af/examples/predict.ipynb b/af/examples/predict.ipynb
index 0d858195..34a6d0ff 100644
--- a/af/examples/predict.ipynb
+++ b/af/examples/predict.ipynb
@@ -5,6 +5,7 @@
"colab": {
"provenance": [],
"gpuType": "T4",
+ "machine_shape": "hm",
"include_colab_link": true
},
"kernelspec": {
@@ -24,7 +25,7 @@
"colab_type": "text"
},
"source": [
- "
"
+ "
"
]
},
{
@@ -56,9 +57,10 @@
" aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n",
" tar -xf alphafold_params_2022-12-06.tar -C params; touch params/done.txt )&\")\n",
"\n",
- " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@gamma\")\n",
+ " os.system(\"pip -q install git+https://github.com/gezmi/ColabDesign.git@gamma\") # TODO: change to sokrypton\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n",
" os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py\")\n",
+ " os.system(\"wget https://raw.githubusercontent.com/gezmi/ColabFold/actifptm/colabfold/alphafold/extended_metrics.py -O extended_metrics.py\") # TODO: change to sokrypton\n",
" #os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py\")\n",
"\n",
" # install hhsuite\n",
@@ -83,7 +85,7 @@
"from colabdesign.af.contrib.cyclic import add_cyclic_offset\n",
"from colabdesign.shared.protein import _np_rmsd, _np_kabsch\n",
"from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap\n",
- "\n",
+ "import extended_metrics\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
@@ -190,7 +192,7 @@
"cell_type": "code",
"source": [
"#@title prep_inputs\n",
- "sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n",
+ "sequence = \"PASQHFLSTSVQGPWERAISPNKVPYYINHETQTTCWDHPKM:KNMTPYRSPPPYVPP\" #@param {type:\"string\"}\n",
"jobname = \"test\" #@param {type:\"string\"}\n",
"\n",
"copies = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"11\", \"12\"] {type:\"raw\"}\n",
@@ -403,7 +405,7 @@
"cell_type": "code",
"source": [
"#@title pre_analysis (optional)\n",
- "analysis = \"none\" # @param [\"none\", \"coevolution\"]\n",
+ "analysis = \"coevolution\" # @param [\"none\", \"coevolution\"]\n",
"dpi = 100 # @param [\"100\", \"200\", \"300\"] {type:\"raw\"}\n",
"if analysis == \"coevolution\":\n",
" coevol = get_coevolution(msa)\n",
@@ -434,6 +436,13 @@
"num_msa = 512 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\"] {type:\"raw\"}\n",
"num_extra_msa = 1024 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\", \"1024\",\"2048\",\"4096\"] {type:\"raw\"}\n",
"use_cluster_profile = True #@param {type:\"boolean\"}\n",
+ "\n",
+ "#@markdown Extended metrics (calculate pairwise ipTM, actifpTM and chain pTM)\n",
+ "calc_extended_metrics = True #@param {type:\"boolean\"}\n",
+ "\n",
+ "if extended_metrics:\n",
+ " debug=True\n",
+ "\n",
"if model_type == \"monomer (ptm)\":\n",
" use_multimer = False\n",
" pseudo_multimer = False\n",
@@ -526,8 +535,8 @@
" add_cyclic_offset(af,i_cyclic)"
],
"metadata": {
- "cellView": "form",
- "id": "0AefVJipkQe3"
+ "id": "0AefVJipkQe3",
+ "cellView": "form"
},
"execution_count": null,
"outputs": []
@@ -617,6 +626,9 @@
" tag = f\"{model}_r{recycle}_seed{seed}\"\n",
" if select_best_across_recycles:\n",
" info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n",
+ " if calc_extended_metrics:\n",
+ " extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n",
+ " info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n",
" af._save_results(save_best=True,\n",
" best_metric=rank_by, metric_higher_better=True,\n",
" verbose=False)\n",
@@ -627,6 +639,9 @@
"\n",
" if not select_best_across_recycles:\n",
" info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n",
+ " if calc_extended_metrics:\n",
+ " extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n",
+ " info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n",
" af._save_results(save_best=True,\n",
" best_metric=rank_by, metric_higher_better=True,\n",
" verbose=False)\n",
@@ -648,12 +663,18 @@
" plddt=aux_best[\"plddt\"].astype(np.float16),\n",
" pae=aux_best[\"pae\"].astype(np.float16),\n",
" tag=np.array(info[rank[0]][0]),\n",
- " metrics=np.array(info[rank[0]][1]))\n",
+ " metrics=np.array(info[rank[0]][1]),\n",
+ " iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n",
+ " actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n",
+ " cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n",
"np.savez_compressed(f\"{pdb_path}/all.npz\",\n",
" plddt=np.array(af._tmp[\"traj\"][\"plddt\"], dtype=np.float16),\n",
" pae=np.array(af._tmp[\"traj\"][\"pae\"], dtype=np.float16),\n",
" tag=np.array([x[0] for x in info]),\n",
- " metrics=np.array([x[1] for x in info]))\n",
+ " metrics=np.array([x[1] for x in info]),\n",
+ " iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n",
+ " actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n",
+ " cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n",
"plot_3D(aux_best, Ls * copies, f\"{pdb_path}/best.pdf\", show=False)\n",
"predict.plot_confidence(aux_best[\"plddt\"]*100, aux_best[\"pae\"], Ls * copies)\n",
"plt.savefig(f\"{pdb_path}/best.png\", dpi=200, bbox_inches='tight')\n",
@@ -663,8 +684,8 @@
"print(\"GC\",gc.collect())"
],
"metadata": {
- "cellView": "form",
- "id": "xYXcKFPQyTQU"
+ "id": "xYXcKFPQyTQU",
+ "cellView": "form"
},
"execution_count": null,
"outputs": []
@@ -785,6 +806,20 @@
"execution_count": null,
"outputs": []
},
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title chainwise_and_pairwise_analysis\n",
+ "if calc_extended_metrics:\n",
+ " extended_metrics.plot_chain_pairwise_analysis(info, prefix='model_')"
+ ],
+ "metadata": {
+ "id": "WoLg1yqYDbiX",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
{
"cell_type": "code",
"source": [
diff --git a/colabdesign/af/alphafold/common/confidence.py b/colabdesign/af/alphafold/common/confidence.py
index 1d566d86..6ddb09b4 100644
--- a/colabdesign/af/alphafold/common/confidence.py
+++ b/colabdesign/af/alphafold/common/confidence.py
@@ -111,7 +111,7 @@ def compute_predicted_aligned_error(logits, breaks, use_jnp=False):
}
def predicted_tm_score(logits, breaks, residue_weights = None,
- asym_id = None, use_jnp=False):
+ asym_id = None, use_jnp=False, pair_residue_weights=None):
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
@@ -122,6 +122,7 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation.
+ pair_residue_weights: [num_res, num_res] unnormalized weights for ifptm calculation
Returns:
ptm_score: The predicted TM alignment or the predicted iTM score.
@@ -135,23 +136,23 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
# exp. resolved head's probability.
if residue_weights is None:
residue_weights = _np.ones(logits.shape[0])
-
bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
num_res = residue_weights.shape[0]
# Clip num_res to avoid negative/undefined d0.
- clipped_num_res = _np.maximum(residue_weights.sum(), 19)
+ clipped_num_res = _np.maximum(num_res, 19)
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
-
+
# Convert logits to probs.
probs = _softmax(logits, axis=-1)
# TM-Score term for every bin.
tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0))
+
# E_distances tm(distance).
predicted_tm_term = (probs * tm_per_bin).sum(-1)
@@ -162,8 +163,10 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
predicted_tm_term *= pair_mask
- pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
+ # If normed_residue_mask is provided (e.g. for if_ptm with contact probabilities),
+ # it should not be overwritten
+ if pair_residue_weights is None:
+ pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True))
per_alignment = (predicted_tm_term * normed_residue_mask).sum(-1)
-
- return (per_alignment * residue_weights).max()
\ No newline at end of file
+ return (per_alignment * residue_weights).max()
diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py
index e5cb7ae8..a0b5b35a 100644
--- a/colabdesign/af/loss.py
+++ b/colabdesign/af/loss.py
@@ -454,4 +454,92 @@ def or_masks(*m):
mask = m[0]
for n in range(1,len(m)):
mask = jnp.logical_or(mask,m[n])
- return mask
\ No newline at end of file
+ return mask
+
+def numpy_callback(x):
+ # Need to forward-declare the shape & dtype of the expected output.
+ result_shape = jax.core.ShapedArray(x.shape, x.dtype)
+ return jax.pure_callback(np.sin, result_shape, x)
+
+def get_chain_indices(chain_boundaries):
+ """Returns a list of tuples indicating the start and end indices for each chain."""
+ chain_starts_ends = []
+ unique_chains = np.unique(chain_boundaries)
+ for chain in unique_chains:
+ positions = np.where(chain_boundaries == chain)[0]
+ chain_starts_ends.append((positions[0], positions[-1]))
+ return chain_starts_ends
+
+def get_pairwise_iptm(af, calculate_interface=False):
+ import string
+ from copy import deepcopy
+ cmap = get_contact_map(af.aux['debug']['outputs'], 8) # Define interface with 8A between Cb-s
+
+ # Initialize seq_mask to all False
+ inputs_ifptm = {}
+ input_pairwise_iptm = {}
+ pairwise_iptm = {}
+
+ # Prepare a dictionary to collect results
+ pairwise_if_ptm = {}
+ chain_starts_ends = get_chain_indices(af._inputs['asym_id'])
+
+ # Generate chain labels (A, B, C, ...)
+ chain_labels = list(string.ascii_uppercase)
+
+ total_length = len(af._inputs['asym_id'])
+ for i, (start_i, end_i) in enumerate(chain_starts_ends):
+ chain_label_i = chain_labels[i % len(chain_labels)] # Wrap around if more than 26 chains
+ for j, (start_j, end_j) in enumerate(chain_starts_ends):
+ chain_label_j = chain_labels[j % len(chain_labels)] # Wrap around if more than 26 chains
+ if i < j: # Avoid self-comparison and duplicate comparisons
+ outputs = deepcopy(af.aux['debug']['outputs'])
+ key = f"{chain_label_i}-{chain_label_j}"
+
+ if calculate_interface:
+ contacts = np.where(cmap[start_i:end_i+1, start_j:end_j+1] >= 0.6)
+
+ if contacts[0].size > 0: # If there are contacts
+ # Convert local chain positions back to global positions using JAX
+ global_i_positions = contacts[0] + start_i
+ global_j_positions = contacts[1] + start_j
+ global_positions = list(set(np.concatenate((global_i_positions, global_j_positions))))
+ global_positions = np.array(global_positions, dtype=int)
+ global_positions.sort()
+
+ # Initialize new input dictionary
+ inputs_ifptm['seq_mask'] = np.full(total_length, 0, dtype=float)
+ inputs_ifptm['asym_id'] = af._inputs['asym_id']
+ # Update seq_mask for these positions to True within inputs
+ inputs_ifptm['seq_mask'][global_positions] = 1
+ # Call get_ptm with updated inputs and outputs
+ pairwise_if_ptm[key] = get_ptm(inputs_ifptm, outputs, interface=True)
+ else:
+ pairwise_if_ptm[key] = 0
+ else:
+ cmap_copy = np.zeros((total_length, total_length))
+ cmap_copy[start_i:end_i+1, start_j:end_j+1] = cmap[start_i:end_i+1, start_j:end_j+1]
+ cmap_copy[start_j:end_j+1, start_i:end_i+1] = cmap[start_j:end_j+1, start_i:end_i+1]
+
+ # Initialize new input dictionary
+ inputs_ifptm['seq_mask'] = np.full(total_length, 0, dtype=float)
+ inputs_ifptm['asym_id'] = af._inputs['asym_id']
+ # Update seq_mask for these positions to True within inputs
+ inputs_ifptm['seq_mask'][np.concatenate((np.arange(start_i,end_i+1),
+ np.arange(start_j,end_j+1)))] = 1
+
+ # Call get_ptm with updated inputs and outputs
+ pae = {"residue_weights":inputs_ifptm["seq_mask"],
+ **outputs["predicted_aligned_error"]}
+ pae["asym_id"] = inputs_ifptm["asym_id"]
+ pairwise_if_ptm[key] = confidence.predicted_tm_score(**pae, use_jnp=True, pair_residue_weights=cmap_copy)
+
+ # Also adding regular i_ptm (interchain), pairwise
+ outputs = deepcopy(af.aux['debug']['outputs'])
+ input_pairwise_iptm['seq_mask'] = np.full(len(af._inputs['asym_id']), 0, dtype=float)
+ input_pairwise_iptm['asym_id'] = af._inputs['asym_id']
+ input_pairwise_iptm['seq_mask'][np.concatenate((np.arange(start_i,end_i+1),
+ np.arange(start_j,end_j+1)))] = 1
+ pairwise_iptm[key] = get_ptm(input_pairwise_iptm, outputs, interface=True)
+
+ return pairwise_if_ptm, pairwise_iptm
diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py
index 49bb07a2..a8e2b97a 100644
--- a/colabdesign/af/model.py
+++ b/colabdesign/af/model.py
@@ -12,7 +12,7 @@
from colabdesign.af.prep import _af_prep
from colabdesign.af.loss import _af_loss, get_plddt, get_pae, get_ptm
-from colabdesign.af.loss import get_contact_map
+from colabdesign.af.loss import get_contact_map, get_dgram_bins
from colabdesign.af.utils import _af_utils
from colabdesign.af.design import _af_design
from colabdesign.af.inputs import _af_inputs
@@ -210,8 +210,8 @@ def _model(params, model_params, inputs, key):
"ptm": get_ptm(inputs, outputs),
"i_ptm": get_ptm(inputs, outputs, interface=True),
"cmap": get_contact_map(outputs, opt["con"]["cutoff"]),
- "i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]),
- "prev": outputs["prev"]})
+ "i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]),
+ "prev": outputs["prev"]})
#######################################################################
# LOSS