Closed
Description
I am using the example from #24 (comment) to predict function annotations:
```python
import torch
from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the client
client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device=torch.device("cuda"))

# Load the protein
protein = ProteinChain.from_rcsb("1utn")
protein = ESMProtein.from_protein_chain(protein)

# Predict function
protein_tensor = client.encode(protein)
inference_output = client.forward_and_sample(
    protein_tensor,
    SamplingConfig(
        sequence=SamplingTrackConfig(),
        structure=SamplingTrackConfig(),
        secondary_structure=SamplingTrackConfig(),
        sasa=SamplingTrackConfig(),
        function=SamplingTrackConfig(only_sample_masked_tokens=False),
    ),
)
protein_tensor_with_function = inference_output.protein_tensor
protein_with_function = client.decode(protein_tensor_with_function)
print(protein_with_function.function_annotations)
```

And I get the error (also seen by another user who commented after the previous issue was closed):
```
--> 198     raise ValueError("SASA does not start with 0 corresponding to BOS token")
    199 if sasa_tokens[-1] != 0:
    200     raise ValueError("SASA does not end with 0 corresponding to EOS token")

ValueError: SASA does not start with 0 corresponding to BOS token
```
I went to the relevant line in `decoding.py` and inserted a print statement. The first and last values of `input.sasa` are `inf` rather than 0, so the BOS/EOS check fails and the error is raised. As a quick fix, I could hard-code `decode_protein_tensor()` to set those values to zero, but I'm not sure whether that would be wise, or what the implications of this error are.
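A minimal sketch of that quick fix in plain Python, assuming the SASA track is a 1-D sequence whose first and last entries are the BOS and EOS positions. `pad_sasa_boundaries` and `check_sasa_boundaries` are hypothetical helpers, not part of the esm API; the second only mirrors the two checks visible in the traceback above. The workaround only silences the check and does not explain why `inf` ends up at those positions.

```python
import math
from typing import List, Sequence


def pad_sasa_boundaries(sasa: Sequence[float]) -> List[float]:
    """Hypothetical helper: zero out non-finite BOS/EOS SASA values.

    Assumes index 0 is BOS and index -1 is EOS. Interior (per-residue)
    values are copied unchanged, and the input is not modified.
    """
    values = list(sasa)
    if len(values) < 2:
        raise ValueError("SASA track must contain at least BOS and EOS positions")
    for idx in (0, -1):
        # Only replace inf/nan; a finite non-zero boundary is left as-is
        # so that a genuinely different problem still surfaces.
        if not math.isfinite(values[idx]):
            values[idx] = 0.0
    return values


def check_sasa_boundaries(sasa: Sequence[float]) -> None:
    """Hypothetical re-creation of the checks shown in the traceback."""
    if sasa[0] != 0:
        raise ValueError("SASA does not start with 0 corresponding to BOS token")
    if sasa[-1] != 0:
        raise ValueError("SASA does not end with 0 corresponding to EOS token")


if __name__ == "__main__":
    raw = [math.inf, 12.5, 40.0, math.inf]  # made-up values for illustration
    fixed = pad_sasa_boundaries(raw)
    check_sasa_boundaries(fixed)  # passes once the boundaries are 0
    print(fixed)
```

Restricting the replacement to the two boundary positions keeps any real per-residue values untouched, which is the main difference from blanket `inf` to 0 replacement over the whole track.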