36 changes: 35 additions & 1 deletion transformer_lens/loading_from_pretrained.py
@@ -30,6 +30,7 @@
convert_gemma_weights,
convert_gpt2_weights,
convert_gptj_weights,
convert_granite_weights,
convert_llama_weights,
convert_mingpt_weights,
convert_mistral_weights,
@@ -263,6 +264,10 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"ibm-granite/granite-3.3-2b-instruct",
"ibm-granite/granite-3.3-2b-base",
"ibm-granite/granite-3.3-8b-instruct",
"ibm-granite/granite-3.3-8b-base",
]
"""Official model names for models on HuggingFace."""

@@ -719,6 +724,10 @@
"google-t5/t5-base": ["t5-base"],
"google-t5/t5-large": ["t5-large"],
"ai-forever/mGPT": ["mGPT"],
"ibm-granite/granite-3.3-2b-instruct": ["granite-3.3-2b", "granite-3.3-2b-instruct"],
"ibm-granite/granite-3.3-2b-base": ["granite-3.3-2b-base"],
"ibm-granite/granite-3.3-8b-instruct": ["granite-3.3-8b", "granite-3.3-8b-instruct"],
"ibm-granite/granite-3.3-8b-base": ["granite-3.3-8b-base"],
}
"""Model aliases for models on HuggingFace."""

@@ -1436,7 +1445,29 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"parallel_attn_mlp": False,
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
}

elif architecture == "GraniteForCausalLM":
# Granite 3.3 models configuration
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": min(hf_config.max_position_embeddings, 2048), # Cap context length for memory
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": "silu",
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"n_key_value_heads": hf_config.num_key_value_heads,
"gated_mlp": True,
"final_rms": True,
"rotary_adjacent_pairs": False,
"gated_mlp": True, ## Remove this later
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
"rotary_base": getattr(hf_config, "rope_theta", 10000.0),
"use_attn_scale": True,
}
elif official_model_name.startswith("google/gemma-2b"):
# Architecture for Gemma 2b and Gemma 2b Instruct models
cfg_dict = {
@@ -1986,6 +2017,9 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "GraniteForCausalLM":
state_dict = convert_granite_weights(hf_model, cfg)

else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
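With the official name and aliases registered above, a Granite checkpoint loads through the usual entry point. A minimal sketch, assuming the HuggingFace weights are downloadable and fit on the chosen device (CPU here purely for illustration):

from transformer_lens import HookedTransformer

# Aliases such as "granite-3.3-2b" resolve to "ibm-granite/granite-3.3-2b-instruct".
model = HookedTransformer.from_pretrained("ibm-granite/granite-3.3-2b-instruct", device="cpu")
logits = model("Granite says hello", return_type="logits")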
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .granite import convert_granite_weights
94 changes: 94 additions & 0 deletions transformer_lens/pretrained/weight_conversions/granite.py
@@ -0,0 +1,94 @@
from typing import cast

import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_granite_weights(hf_model, cfg: HookedTransformerConfig):
"""
Converts the weights of a Hugging Face GraniteForCausalLM model to the format
used by HookedTransformer.
"""
state_dict = {}

# Token Embeddings - move to the correct device
state_dict["embed.W_E"] = hf_model.model.embed_tokens.weight.to(device=cfg.device)

# The Granite architecture uses Grouped Query Attention (GQA): when n_key_value_heads is set,
# the grouped K/V parameters are stored under underscore-prefixed keys (_W_K, _W_V, _b_K, _b_V).
using_gqa = cfg.n_key_value_heads is not None
gqa_uscore = "_" if using_gqa else ""
n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

for l in range(cfg.n_layers):
# LayerNorm 1 - move to the correct device
state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight.to(
device=cfg.device
)

# Attention weights
# Transpose the weights first, then rearrange
W_Q = hf_model.model.layers[l].self_attn.q_proj.weight.T
W_K = hf_model.model.layers[l].self_attn.k_proj.weight.T
W_V = hf_model.model.layers[l].self_attn.v_proj.weight.T
W_O = hf_model.model.layers[l].self_attn.o_proj.weight.T

# Reshape weights for TransformerLens internal format
W_Q = einops.rearrange(W_Q, "m (n h) -> n m h", n=cfg.n_heads)
W_K = einops.rearrange(W_K, "m (n h) -> n m h", n=n_kv_heads)
W_V = einops.rearrange(W_V, "m (n h) -> n m h", n=n_kv_heads)
W_O = einops.rearrange(W_O, "(n h) m -> n h m", n=cfg.n_heads)
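# After these rearranges, W_Q has shape (n_heads, d_model, d_head), W_K/W_V have shape
# (n_kv_heads, d_model, d_head), and W_O has shape (n_heads, d_head, d_model),
# matching HookedTransformer's per-head parameter layout.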

# Move weights to the correct device
state_dict[f"blocks.{l}.attn.W_Q"] = W_Q.to(device=cfg.device)
state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K.to(device=cfg.device)
state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V.to(device=cfg.device)
state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)

# Attention biases (Granite models don't use biases, so we set them to zero)
state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
)
state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
)
state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
)
state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
cfg.d_model, dtype=cfg.dtype, device=cfg.device
)

# LayerNorm 2
state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[
l
].post_attention_layernorm.weight.to(device=cfg.device)

# MLP weights for GatedMLP - move to the correct device
state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T.to(
device=cfg.device
)
state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T.to(
device=cfg.device
)
state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T.to(
device=cfg.device
)

# MLP biases (Granite models don't use biases, so we set them to zero)
state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
)
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
cfg.d_model, dtype=cfg.dtype, device=cfg.device
)

# Final LayerNorm
state_dict["ln_final.w"] = hf_model.model.norm.weight.to(device=cfg.device)

# Unembedding weights
state_dict["unembed.W_U"] = hf_model.lm_head.weight.T.to(device=cfg.device)
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

return state_dict
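The converter is normally reached through get_pretrained_state_dict, but it can also be exercised on its own. A minimal sketch, assuming a transformers version that provides GraniteForCausalLM and enough memory to load the model on CPU in float32 (the model name and final print are illustrative):

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import loading_from_pretrained as loading
from transformer_lens.pretrained.weight_conversions import convert_granite_weights

name = "ibm-granite/granite-3.3-2b-instruct"
hf_model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float32)
cfg = loading.get_pretrained_model_config(name, device="cpu")  # picks up the Granite branch added above
state_dict = convert_granite_weights(hf_model, cfg)
print(state_dict["blocks.0.attn.W_Q"].shape)  # (n_heads, d_model, d_head)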
1 change: 1 addition & 0 deletions transformer_lens/utilities/activation_functions.py
@@ -23,4 +23,5 @@
"relu": F.relu,
"gelu": F.gelu,
"gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
"swiglu": lambda x: F.silu(torch.chunk(x, 2, dim=-1)[0]) * torch.chunk(x, 2, dim=-1)[1],
}
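The new "swiglu" entry splits its input in half along the last dimension and uses the SiLU of the first half to gate the second, so the output is half as wide as the input. A quick illustrative check (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

swiglu = lambda x: F.silu(torch.chunk(x, 2, dim=-1)[0]) * torch.chunk(x, 2, dim=-1)[1]
x = torch.randn(2, 8)
assert swiglu(x).shape == (2, 4)  # the last dimension is halved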