diff --git a/esm/layers/transformer_stack.py b/esm/layers/transformer_stack.py
index 0b587819..bef1706b 100644
--- a/esm/layers/transformer_stack.py
+++ b/esm/layers/transformer_stack.py
@@ -37,8 +37,10 @@ def __init__(
         ffn_type: str = "swiglu",  # swiglu | gelu
         expansion_ratio: float = 8 / 3,
         use_flash_attn: bool = False,
+        return_hidden_states: bool = False,
     ):
         super().__init__()
+        self.return_hidden_states = return_hidden_states
         self.blocks = nn.ModuleList(
             [
                 UnifiedTransformerBlock(
@@ -90,5 +92,6 @@ def forward(
         hiddens = []
         for block in self.blocks:
             x = block(x, sequence_id, affine, affine_mask, chain_id)
-            hiddens.append(x)
+            if self.return_hidden_states:
+                hiddens.append(x)
         return self.norm(x), x, hiddens
diff --git a/esm/models/esm3.py b/esm/models/esm3.py
index cbe02ddd..78034189 100644
--- a/esm/models/esm3.py
+++ b/esm/models/esm3.py
@@ -165,7 +165,8 @@ def __init__(self, d_model: int):
         self.function_head = RegressionHead(d_model, 260 * 8)
         self.residue_head = RegressionHead(d_model, 1478)
 
-    def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
+    def forward(self, x: torch.Tensor, last_hidden_state: torch.Tensor) -> ESMOutput:
+        embeddings = x.clone()
         sequence_logits = self.sequence_head(x)
         structure_logits = self.structure_head(x)
         secondary_structure_logits = self.ss8_head(x)
@@ -182,7 +183,7 @@ def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
             sasa_logits=sasa_logits,
             function_logits=function_logits,
             residue_logits=residue_logits,
-            embeddings=embed,
+            embeddings=embeddings,
         )
 
 
@@ -376,10 +377,10 @@ def forward(
             function_tokens,
             residue_annotation_tokens,
         )
-        x, embedding, _ = self.transformer(
+        x, last_hidden_states, _ = self.transformer(
             x, sequence_id, affine, affine_mask, chain_id
         )
-        return self.output_heads(x, embedding)
+        return self.output_heads(x, last_hidden_states)
 
     # The following methods are for the ESM3InferenceClient interface
     def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
diff --git a/esm/models/esmc.py b/esm/models/esmc.py
index 0d3438f6..020407e2 100644
--- a/esm/models/esmc.py
+++ b/esm/models/esmc.py
@@ -58,11 +58,13 @@ def __init__(
         n_layers: int,
         tokenizer: EsmSequenceTokenizer,
         use_flash_attn: bool = True,
+        return_hidden_states: bool = False,
     ):
         super().__init__()
         self.embed = nn.Embedding(64, d_model)
         self._use_flash_attn = is_flash_attn_available and use_flash_attn
+        self.return_hidden_states = return_hidden_states
 
         self.transformer = TransformerStack(
             d_model,
             n_heads,
@@ -70,6 +72,7 @@ def __init__(
             n_layers,
             n_layers_geom=0,
             use_flash_attn=self._use_flash_attn,
+            return_hidden_states=self.return_hidden_states,
         )
 
         self.sequence_head = RegressionHead(d_model, 64)
@@ -164,7 +167,8 @@ def forward(
         ]
 
         # Stack hidden states into a [n_layers, B, L, D] matrix.
-        hiddens = torch.stack(hiddens, dim=0)  # type: ignore
+        if len(hiddens):
+            hiddens = torch.stack(hiddens, dim=0)  # type: ignore
 
         sequence_logits = self.sequence_head(x)
         output = ESMCOutput(
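
A minimal usage sketch of the new flag, not part of the diff: it assumes the diff above is applied, that EsmSequenceTokenizer is importable from esm.tokenization.sequence_tokenizer, that ESMCOutput exposes a hidden_states field as in the released esm package, and that the 300M-scale dimensions shown are purely illustrative. With the default return_hidden_states=False, the hiddens list stays empty, so no per-layer activations are retained.

# Hedged sketch: exercising return_hidden_states on a randomly initialized ESMC
# (a real run would load pretrained weights instead of constructing the module directly).
import torch

from esm.models.esmc import ESMC
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer  # assumed import path

tokenizer = EsmSequenceTokenizer()
model = ESMC(
    d_model=960,            # illustrative 300M-scale sizes, not taken from this diff
    n_heads=15,
    n_layers=30,
    tokenizer=tokenizer,
    use_flash_attn=False,
    return_hidden_states=True,  # new flag introduced by this diff
).eval()

sequence_tokens = torch.tensor([tokenizer.encode("MKTAYIAKQR")])
with torch.no_grad():
    output = model(sequence_tokens=sequence_tokens)

# With the flag on, per-layer activations are collected and stacked into a
# [n_layers, B, L, D] tensor (assuming ESMCOutput.hidden_states as in the
# released package); with the flag off the list stays empty and no extra
# activation memory is held.
print(output.hidden_states.shape)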