diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 8bfb6315d..366df75f8 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -30,6 +30,7 @@
     convert_gemma_weights,
     convert_gpt2_weights,
     convert_gptj_weights,
+    convert_granite_weights,
     convert_llama_weights,
     convert_mingpt_weights,
     convert_mistral_weights,
@@ -263,6 +264,10 @@
     "google-t5/t5-base",
     "google-t5/t5-large",
     "ai-forever/mGPT",
+    "ibm-granite/granite-3.3-2b-instruct",
+    "ibm-granite/granite-3.3-2b-base",
+    "ibm-granite/granite-3.3-8b-instruct",
+    "ibm-granite/granite-3.3-8b-base",
 ]
 """Official model names for models on HuggingFace."""
 
@@ -719,6 +724,10 @@
     "google-t5/t5-base": ["t5-base"],
     "google-t5/t5-large": ["t5-large"],
     "ai-forever/mGPT": ["mGPT"],
+    "ibm-granite/granite-3.3-2b-instruct": ["granite-3.3-2b", "granite-3.3-2b-instruct"],
+    "ibm-granite/granite-3.3-2b-base": ["granite-3.3-2b-base"],
+    "ibm-granite/granite-3.3-8b-instruct": ["granite-3.3-8b", "granite-3.3-8b-instruct"],
+    "ibm-granite/granite-3.3-8b-base": ["granite-3.3-8b-base"],
 }
 """Model aliases for models on HuggingFace."""
 
@@ -1436,7 +1445,28 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "parallel_attn_mlp": False,
             "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
         }
-
+    elif architecture == "GraniteForCausalLM":
+        # Granite 3.3 models configuration
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": min(hf_config.max_position_embeddings, 2048),  # Cap context length for memory
+            "eps": hf_config.rms_norm_eps,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": "silu",
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "gated_mlp": True,
+            "final_rms": True,
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+            "rotary_base": getattr(hf_config, "rope_theta", 10000.0),
+            "use_attn_scale": True,
+        }
     elif official_model_name.startswith("google/gemma-2b"):
         # Architecture for Gemma 2b and Gemma 2b Instruct models
         cfg_dict = {
@@ -1986,6 +2016,9 @@ def get_pretrained_state_dict(
         state_dict = convert_gemma_weights(hf_model, cfg)
     elif cfg.original_architecture == "Gemma2ForCausalLM":
         state_dict = convert_gemma_weights(hf_model, cfg)
+    elif cfg.original_architecture == "GraniteForCausalLM":
+        state_dict = convert_granite_weights(hf_model, cfg)
+
     else:
         raise ValueError(
             f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
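
A minimal usage sketch (not part of the patch) of how the Granite checkpoints registered above would be loaded through the existing HookedTransformer.from_pretrained API; the prompt, dtype, and cached hook name are illustrative assumptions.

import torch
from transformer_lens import HookedTransformer

# Load one of the newly registered Granite checkpoints; the short alias
# "granite-3.3-2b" from the alias table above should resolve to the same model.
# bfloat16 is an assumption chosen to keep memory usage modest.
model = HookedTransformer.from_pretrained(
    "ibm-granite/granite-3.3-2b-instruct",
    dtype=torch.bfloat16,
)

# Forward pass with activation caching, as for any other supported model.
logits, cache = model.run_with_cache("The capital of France is")
print(logits.shape, cache["blocks.0.attn.hook_z"].shape)
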
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index c5ea9581b..23c36bbf8 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
 from .nanogpt import convert_nanogpt_weights
 from .t5 import convert_t5_weights
 from .neel_solu_old import convert_neel_solu_old_weights
+from .granite import convert_granite_weights
diff --git a/transformer_lens/pretrained/weight_conversions/granite.py b/transformer_lens/pretrained/weight_conversions/granite.py
new file mode 100644
index 000000000..e1451a985
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/granite.py
@@ -0,0 +1,94 @@
+from typing import cast
+
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_granite_weights(hf_model, cfg: HookedTransformerConfig):
+    """
+    Converts the weights of a Hugging Face GraniteForCausalLM model to the format
+    used by HookedTransformer.
+    """
+    state_dict = {}
+
+    # Token embeddings - move to the correct device
+    state_dict["embed.W_E"] = hf_model.model.embed_tokens.weight.to(device=cfg.device)
+
+    # The Granite architecture uses Grouped Query Attention
+    using_gqa = cfg.n_key_value_heads is not None
+    gqa_uscore = "_" if using_gqa else ""
+    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
+
+    for l in range(cfg.n_layers):
+        # LayerNorm 1 - move to the correct device
+        state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight.to(
+            device=cfg.device
+        )
+
+        # Attention weights
+        # Transpose the weights first, then rearrange
+        W_Q = hf_model.model.layers[l].self_attn.q_proj.weight.T
+        W_K = hf_model.model.layers[l].self_attn.k_proj.weight.T
+        W_V = hf_model.model.layers[l].self_attn.v_proj.weight.T
+        W_O = hf_model.model.layers[l].self_attn.o_proj.weight.T
+
+        # Reshape weights for the TransformerLens internal format
+        W_Q = einops.rearrange(W_Q, "m (n h) -> n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "m (n h) -> n m h", n=n_kv_heads)
+        W_V = einops.rearrange(W_V, "m (n h) -> n m h", n=n_kv_heads)
+        W_O = einops.rearrange(W_O, "(n h) m -> n h m", n=cfg.n_heads)
+
+        # Move weights to the correct device
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q.to(device=cfg.device)
+        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K.to(device=cfg.device)
+        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V.to(device=cfg.device)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)
+
+        # Attention biases (Granite models don't use biases, so we set them to zero)
+        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
+        )
+        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
+            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
+        )
+        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
+            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
+        )
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
+            cfg.d_model, dtype=cfg.dtype, device=cfg.device
+        )
+
+        # LayerNorm 2
+        state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[
+            l
+        ].post_attention_layernorm.weight.to(device=cfg.device)
+
+        # MLP weights for GatedMLP - move to the correct device
+        state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T.to(
+            device=cfg.device
+        )
+        state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T.to(
+            device=cfg.device
+        )
+        state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T.to(
+            device=cfg.device
+        )
+
+        # MLP biases (Granite models don't use biases, so we set them to zero)
+        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
+            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
+        )
+        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
+            cfg.d_model, dtype=cfg.dtype, device=cfg.device
+        )
+
+    # Final LayerNorm
+    state_dict["ln_final.w"] = hf_model.model.norm.weight.to(device=cfg.device)
+
+    # Unembedding weights
+    state_dict["unembed.W_U"] = hf_model.lm_head.weight.T.to(device=cfg.device)
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)
+
+    return state_dict
diff --git a/transformer_lens/utilities/activation_functions.py b/transformer_lens/utilities/activation_functions.py
index 6cc701360..91a9972f1 100644
--- a/transformer_lens/utilities/activation_functions.py
+++ b/transformer_lens/utilities/activation_functions.py
@@ -23,4 +23,5 @@
     "relu": F.relu,
     "gelu": F.gelu,
     "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
+    "swiglu": lambda x: F.silu(torch.chunk(x, 2, dim=-1)[0]) * torch.chunk(x, 2, dim=-1)[1],
 }
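
As a rough local sanity check for the conversion path added above, one could compare the converted model's logits against the reference Hugging Face implementation. This is a sketch, not part of the patch: the checkpoint, prompt, and dtype are placeholder assumptions, and the default weight-processing flags are disabled so the raw logits stay comparable.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

name = "ibm-granite/granite-3.3-2b-base"
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float32)

# Reuse the already-loaded HF model so convert_granite_weights is exercised
# directly; keep weights unprocessed to compare raw logits.
tl_model = HookedTransformer.from_pretrained(
    name,
    hf_model=hf_model,
    tokenizer=tokenizer,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    device="cpu",
)

tokens = tokenizer("Interpretability is", return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens)

# Report the worst-case divergence between the two implementations.
print("max |logit diff|:", (hf_logits - tl_logits).abs().max().item())
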