diff --git a/af/examples/predict.ipynb b/af/examples/predict.ipynb index 0d858195..34a6d0ff 100644 --- a/af/examples/predict.ipynb +++ b/af/examples/predict.ipynb @@ -5,6 +5,7 @@ "colab": { "provenance": [], "gpuType": "T4", + "machine_shape": "hm", "include_colab_link": true }, "kernelspec": { @@ -24,7 +25,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -56,9 +57,10 @@ " aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n", " tar -xf alphafold_params_2022-12-06.tar -C params; touch params/done.txt )&\")\n", "\n", - " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@gamma\")\n", + " os.system(\"pip -q install git+https://github.com/gezmi/ColabDesign.git@gamma\") # TODO: change to sokrypton\n", " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n", " os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py\")\n", + " os.system(\"wget https://raw.githubusercontent.com/gezmi/ColabFold/actifptm/colabfold/alphafold/extended_metrics.py -O extended_metrics.py\") # TODO: change to sokrypton\n", " #os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py\")\n", "\n", " # install hhsuite\n", @@ -83,7 +85,7 @@ "from colabdesign.af.contrib.cyclic import add_cyclic_offset\n", "from colabdesign.shared.protein import _np_rmsd, _np_kabsch\n", "from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap\n", - "\n", + "import extended_metrics\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -190,7 +192,7 @@ "cell_type": "code", "source": [ "#@title prep_inputs\n", - "sequence = \"PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\" #@param {type:\"string\"}\n", + "sequence = \"PASQHFLSTSVQGPWERAISPNKVPYYINHETQTTCWDHPKM:KNMTPYRSPPPYVPP\" #@param {type:\"string\"}\n", "jobname = \"test\" #@param {type:\"string\"}\n", "\n", "copies 
= 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"10\", \"11\", \"12\"] {type:\"raw\"}\n", @@ -403,7 +405,7 @@ "cell_type": "code", "source": [ "#@title pre_analysis (optional)\n", - "analysis = \"none\" # @param [\"none\", \"coevolution\"]\n", + "analysis = \"coevolution\" # @param [\"none\", \"coevolution\"]\n", "dpi = 100 # @param [\"100\", \"200\", \"300\"] {type:\"raw\"}\n", "if analysis == \"coevolution\":\n", " coevol = get_coevolution(msa)\n", @@ -434,6 +436,13 @@ "num_msa = 512 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\"] {type:\"raw\"}\n", "num_extra_msa = 1024 #@param [\"1\",\"2\",\"4\",\"8\",\"16\",\"32\", \"64\", \"128\", \"256\", \"512\", \"1024\",\"2048\",\"4096\"] {type:\"raw\"}\n", "use_cluster_profile = True #@param {type:\"boolean\"}\n", + "\n", + "#@markdown Extended metrics (calculate pairwise ipTM, actifpTM and chain pTM)\n", + "calc_extended_metrics = True #@param {type:\"boolean\"}\n", + "\n", + "if extended_metrics:\n", + " debug=True\n", + "\n", "if model_type == \"monomer (ptm)\":\n", " use_multimer = False\n", " pseudo_multimer = False\n", @@ -526,8 +535,8 @@ " add_cyclic_offset(af,i_cyclic)" ], "metadata": { - "cellView": "form", - "id": "0AefVJipkQe3" + "id": "0AefVJipkQe3", + "cellView": "form" }, "execution_count": null, "outputs": [] @@ -617,6 +626,9 @@ " tag = f\"{model}_r{recycle}_seed{seed}\"\n", " if select_best_across_recycles:\n", " info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n", + " if calc_extended_metrics:\n", + " extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n", + " info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n", " af._save_results(save_best=True,\n", " best_metric=rank_by, metric_higher_better=True,\n", " verbose=False)\n", @@ -627,6 +639,9 @@ "\n", " if not select_best_across_recycles:\n", " 
info.append([tag,print_str,af.aux[\"log\"][rank_by]])\n", + " if calc_extended_metrics:\n", + " extended_ptms = extended_metrics.get_chain_and_interface_metrics(af.aux['debug']['outputs'], af._inputs['asym_id'])\n", + " info[-1].extend([extended_ptms['pairwise_iptm'], extended_ptms['pairwise_actifptm'], extended_ptms['per_chain_ptm']])\n", " af._save_results(save_best=True,\n", " best_metric=rank_by, metric_higher_better=True,\n", " verbose=False)\n", @@ -648,12 +663,18 @@ " plddt=aux_best[\"plddt\"].astype(np.float16),\n", " pae=aux_best[\"pae\"].astype(np.float16),\n", " tag=np.array(info[rank[0]][0]),\n", - " metrics=np.array(info[rank[0]][1]))\n", + " metrics=np.array(info[rank[0]][1]),\n", + " iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n", + " actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n", + " cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n", "np.savez_compressed(f\"{pdb_path}/all.npz\",\n", " plddt=np.array(af._tmp[\"traj\"][\"plddt\"], dtype=np.float16),\n", " pae=np.array(af._tmp[\"traj\"][\"pae\"], dtype=np.float16),\n", " tag=np.array([x[0] for x in info]),\n", - " metrics=np.array([x[1] for x in info]))\n", + " metrics=np.array([x[1] for x in info]),\n", + " iptm_pairwise=np.array(info[rank[0]][3]) if len(info[rank[0]]) > 3 else np.array([]),\n", + " actifptm_pairwise=np.array(info[rank[0]][4]) if len(info[rank[0]]) > 4 else np.array([]),\n", + " cptm=np.array(info[rank[0]][5]) if len(info[rank[0]]) > 5 else np.array([]))\n", "plot_3D(aux_best, Ls * copies, f\"{pdb_path}/best.pdf\", show=False)\n", "predict.plot_confidence(aux_best[\"plddt\"]*100, aux_best[\"pae\"], Ls * copies)\n", "plt.savefig(f\"{pdb_path}/best.png\", dpi=200, bbox_inches='tight')\n", @@ -663,8 +684,8 @@ "print(\"GC\",gc.collect())" ], "metadata": { - "cellView": "form", - "id": "xYXcKFPQyTQU" + "id": "xYXcKFPQyTQU", + "cellView": "form" }, "execution_count": null, 
"outputs": [] @@ -785,6 +806,20 @@ "execution_count": null, "outputs": [] }, + { + "cell_type": "code", + "source": [ + "#@title chainwise_and_pairwise_analysis\n", + "if calc_extended_metrics:\n", + " extended_metrics.plot_chain_pairwise_analysis(info, prefix='model_')" + ], + "metadata": { + "id": "WoLg1yqYDbiX", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, { "cell_type": "code", "source": [ diff --git a/colabdesign/af/alphafold/common/confidence.py b/colabdesign/af/alphafold/common/confidence.py index 1d566d86..6ddb09b4 100644 --- a/colabdesign/af/alphafold/common/confidence.py +++ b/colabdesign/af/alphafold/common/confidence.py @@ -111,7 +111,7 @@ def compute_predicted_aligned_error(logits, breaks, use_jnp=False): } def predicted_tm_score(logits, breaks, residue_weights = None, - asym_id = None, use_jnp=False): + asym_id = None, use_jnp=False, pair_residue_weights=None): """Computes predicted TM alignment or predicted interface TM alignment score. Args: @@ -122,6 +122,7 @@ def predicted_tm_score(logits, breaks, residue_weights = None, expectation. asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for ipTM calculation. + pair_residue_weights: [num_res, num_res] unnormalized weights for ifptm calculation Returns: ptm_score: The predicted TM alignment or the predicted iTM score. @@ -135,23 +136,23 @@ def predicted_tm_score(logits, breaks, residue_weights = None, # exp. resolved head's probability. if residue_weights is None: residue_weights = _np.ones(logits.shape[0]) - bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp) num_res = residue_weights.shape[0] # Clip num_res to avoid negative/undefined d0. - clipped_num_res = _np.maximum(residue_weights.sum(), 19) + clipped_num_res = _np.maximum(num_res, 19) # Compute d_0(num_res) as defined by TM-score, eqn. 
def numpy_callback(x, fn=np.sin):
    """Apply a host-side NumPy function to ``x`` via ``jax.pure_callback``.

    Args:
        x: array whose ``shape`` and ``dtype`` are used to declare the
            callback's output.
        fn: NumPy-compatible callable run on the host. It must return an
            array with the same shape and dtype as ``x``, because that is
            the output signature declared to JAX. Defaults to ``np.sin``,
            which is what this function always applied before the
            parameter existed.

    Returns:
        The result of ``fn(x)`` as produced by ``jax.pure_callback``.
    """
    # pure_callback needs the output shape & dtype declared up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(fn, result_shape, x)


def get_chain_indices(chain_boundaries):
    """Return the inclusive ``(start, end)`` index span of every chain.

    Args:
        chain_boundaries: 1-D array-like of per-residue chain identifiers
            (callers in this file pass ``asym_id``).

    Returns:
        A list of ``(first_index, last_index)`` tuples of Python ints, one
        per distinct identifier, ordered by sorted identifier value (the
        order ``np.unique`` yields). Empty input gives an empty list.

    Note:
        Only the first and last occurrence of each identifier are used, so
        a chain whose residues are not contiguous yields a span that also
        covers whatever lies in between.
    """
    chain_ids = np.asarray(chain_boundaries)
    spans = []
    for chain in np.unique(chain_ids):
        positions = np.flatnonzero(chain_ids == chain)
        spans.append((int(positions[0]), int(positions[-1])))
    return spans


def _chain_label(index):
    """Map a 0-based chain index to a unique label: A..Z, AA, AB, ..."""
    import string

    letters = string.ascii_uppercase
    label = ""
    remaining = index + 1  # bijective base-26, same scheme as spreadsheet columns
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, len(letters))
        label = letters[offset] + label
    return label


def get_pairwise_iptm(af, calculate_interface=False,
                      contact_cutoff=8, contact_threshold=0.6):
    """Compute interface pTM scores for every pair of chains.

    Reads the raw model outputs from ``af.aux['debug']['outputs']`` and the
    chain assignment from ``af._inputs['asym_id']``.

    Args:
        af: model object exposing ``aux['debug']['outputs']`` and
            ``_inputs['asym_id']``.
        calculate_interface: selects how the first returned score is
            computed for each chain pair.
            * ``True``: residues of the two chains whose inter-chain contact
              map entry is ``>= contact_threshold`` form the ``seq_mask``
              passed to ``get_ptm(..., interface=True)``. A pair with no
              such contact scores ``0``.
            * ``False``: the ``seq_mask`` covers both chains entirely and
              the inter-chain blocks of the contact map are passed to
              ``confidence.predicted_tm_score`` as ``pair_residue_weights``.
        contact_cutoff: cutoff forwarded to ``get_contact_map`` (default 8,
            the previously hard-coded value).
        contact_threshold: minimum contact-map value for a residue pair to
            count as a contact when ``calculate_interface`` is true
            (default 0.6, the previously hard-coded value).

    Returns:
        ``(pairwise_if_ptm, pairwise_iptm)``: two dicts keyed by
        ``"<label_i>-<label_j>"`` for every chain pair with ``i < j``.
        ``pairwise_iptm`` is always ``get_ptm(..., interface=True)`` with a
        ``seq_mask`` covering exactly the two chains. Labels are A..Z for
        the first 26 chains and continue AA, AB, ... afterwards, so keys
        stay unique (labels used to wrap around and overwrite each other).
    """
    from copy import deepcopy

    raw_outputs = af.aux['debug']['outputs']
    asym_id = af._inputs['asym_id']
    total_length = len(asym_id)

    # Source comment: interface defined with an 8A cutoff between Cb-s.
    # NOTE(review): presumably entries are contact probabilities - confirm
    # against get_contact_map.
    cmap = np.asarray(get_contact_map(raw_outputs, contact_cutoff))
    chain_spans = get_chain_indices(asym_id)

    pairwise_if_ptm = {}
    pairwise_iptm = {}

    for i, (start_i, end_i) in enumerate(chain_spans):
        rows = slice(start_i, end_i + 1)
        # Only j > i: skips self-pairs and mirrored duplicates.
        for j in range(i + 1, len(chain_spans)):
            start_j, end_j = chain_spans[j]
            cols = slice(start_j, end_j + 1)
            key = f"{_chain_label(i)}-{_chain_label(j)}"

            # NOTE(review): the deep copies are kept from the original as a
            # defensive measure; if get_ptm is confirmed not to mutate its
            # arguments they can be dropped to save memory and time.
            outputs = deepcopy(raw_outputs)

            if calculate_interface:
                contacts = np.where(cmap[rows, cols] >= contact_threshold)
                if contacts[0].size > 0:
                    # Shift block-local indices back to global residue
                    # positions; union1d returns them sorted and unique.
                    global_positions = np.union1d(contacts[0] + start_i,
                                                  contacts[1] + start_j)
                    seq_mask = np.zeros(total_length, dtype=float)
                    seq_mask[global_positions] = 1
                    inputs_ifptm = {'seq_mask': seq_mask, 'asym_id': asym_id}
                    pairwise_if_ptm[key] = get_ptm(inputs_ifptm, outputs,
                                                   interface=True)
                else:
                    pairwise_if_ptm[key] = 0
            else:
                # Keep only the two inter-chain blocks of the contact map;
                # every other residue pair gets zero weight.
                cmap_copy = np.zeros((total_length, total_length))
                cmap_copy[rows, cols] = cmap[rows, cols]
                cmap_copy[cols, rows] = cmap[cols, rows]

                seq_mask = np.zeros(total_length, dtype=float)
                seq_mask[rows] = 1
                seq_mask[cols] = 1

                pae = {"residue_weights": seq_mask,
                       **outputs["predicted_aligned_error"]}
                pae["asym_id"] = asym_id
                pairwise_if_ptm[key] = confidence.predicted_tm_score(
                    **pae, use_jnp=True, pair_residue_weights=cmap_copy)

            # Regular inter-chain ipTM restricted to this chain pair.
            outputs = deepcopy(raw_outputs)
            pair_mask = np.zeros(total_length, dtype=float)
            pair_mask[rows] = 1
            pair_mask[cols] = 1
            input_pairwise_iptm = {'seq_mask': pair_mask, 'asym_id': asym_id}
            pairwise_iptm[key] = get_ptm(input_pairwise_iptm, outputs,
                                         interface=True)

    return pairwise_if_ptm, pairwise_iptm
_af_loss, get_plddt, get_pae, get_ptm -from colabdesign.af.loss import get_contact_map +from colabdesign.af.loss import get_contact_map, get_dgram_bins from colabdesign.af.utils import _af_utils from colabdesign.af.design import _af_design from colabdesign.af.inputs import _af_inputs @@ -210,8 +210,8 @@ def _model(params, model_params, inputs, key): "ptm": get_ptm(inputs, outputs), "i_ptm": get_ptm(inputs, outputs, interface=True), "cmap": get_contact_map(outputs, opt["con"]["cutoff"]), - "i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]), - "prev": outputs["prev"]}) + "i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]), + "prev": outputs["prev"]}) ####################################################################### # LOSS