From 082a45e455eb00e903ff059c24ce1e367bd0df2f Mon Sep 17 00:00:00 2001 From: emharsha1812 Date: Tue, 8 Jul 2025 15:26:36 +0530 Subject: [PATCH 1/4] Support for Granite3.3 --- test_granite.py | 40 +++++++++++ transformer_lens/loading_from_pretrained.py | 47 ++++++++++++- .../pretrained/weight_conversions/__init__.py | 1 + .../pretrained/weight_conversions/granite.py | 70 +++++++++++++++++++ .../utilities/activation_functions.py | 1 + 5 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 test_granite.py create mode 100644 transformer_lens/pretrained/weight_conversions/granite.py diff --git a/test_granite.py b/test_granite.py new file mode 100644 index 000000000..ad104fbf0 --- /dev/null +++ b/test_granite.py @@ -0,0 +1,40 @@ +import torch +from transformer_lens import HookedTransformer + +# The Hugging Face model name you added +MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct" + +print(f"--- Loading model: {MODEL_NAME} ---") + +# Set the device +device = "cuda" if torch.cuda.is_available() else "cpu" +print(f"Using device: {device}") + +# Load the model from Hugging Face +# This will use all the code you've just written! +model = HookedTransformer.from_pretrained( + MODEL_NAME, + device=device, + # You might need to use a lower-precision dtype if you're low on VRAM + # torch_dtype=torch.bfloat16 +) + +print("\n--- Model Configuration ---") +# Print the config to double-check that your parameters were loaded correctly +print(model.cfg) + +print("\n--- Running Generation Test ---") +prompt = "The best programming language is" + +# Generate some text +# The model will return token IDs, so we use .to_string() to convert back to text +output = model.generate( + prompt, + max_new_tokens=10, + temperature=0.7, +) + +print(f"\nPrompt: '{prompt}'") +print(f"Generated text: '{output}'") + +print("\n--- ✅ Test Complete ---") \ No newline at end of file diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..1e87f98a3 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -44,6 +44,7 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, + convert_granite_weights ) OFFICIAL_MODEL_NAMES = [ @@ -263,6 +264,7 @@ "google-t5/t5-base", "google-t5/t5-large", "ai-forever/mGPT", + "ibm-granite/granite-3.3-2b-instruct" ] """Official model names for models on HuggingFace.""" @@ -719,6 +721,7 @@ "google-t5/t5-base": ["t5-base"], "google-t5/t5-large": ["t5-large"], "ai-forever/mGPT": ["mGPT"], + "ibm-granite/granite-3.3-2b-instruct":["granite-3.3-2b","granite-3.3"] } """Model aliases for models on HuggingFace.""" @@ -794,6 +797,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): architecture = "Gemma2ForCausalLM" elif "gemma" in official_model_name.lower(): architecture = "GemmaForCausalLM" + # elif "granite" in official_model_name.lower(): + # architecture="GraniteForCausalLM" else: huggingface_token = os.environ.get("HF_TOKEN", "") hf_config = AutoConfig.from_pretrained( @@ -1436,7 +1441,44 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } - + # elif architecture=="GraniteForCausalLM": + # ### Architecture for Granite 3.3 2b and Granite 3.3 2b instruct models + # cfg_dict = { + # "d_model": 2048, + # "d_head": 64, + # "n_heads": 32, + # "d_mlp": 8192, + # "n_layers": 40, + # "n_ctx": 2048, + # "eps": 1e-5, + # "d_vocab": 49152, + # "act_fn": 
"swiglu", # Based on common practice for Granite-like models + # "normalization_type": "RMS", + # "positional_embedding_type": "rotary", + # "gated_mlp":True, + # "rotary_dim": 64, + # "final_rms":True + # } + elif architecture == "GraniteForCausalLM": + # This block now correctly assumes `hf_config` has been loaded + cfg_dict = { + "d_model": hf_config.hidden_size, + "n_layers": hf_config.num_hidden_layers, + "n_heads": hf_config.num_attention_heads, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": "swiglu", + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "n_key_value_heads": hf_config.num_key_value_heads, + "gated_mlp": True, + "final_rms": True, + # rotary_dim is often d_head + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + } elif official_model_name.startswith("google/gemma-2b"): # Architecture for Gemma 2b and Gemma 2b Instruct models cfg_dict = { @@ -1986,6 +2028,9 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture=="GraniteForCausalLM": + state_dict=convert_granite_weights(hf_model,cfg) + else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..23c36bbf8 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .granite import convert_granite_weights diff --git a/transformer_lens/pretrained/weight_conversions/granite.py b/transformer_lens/pretrained/weight_conversions/granite.py new file mode 100644 index 000000000..d6f26b293 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/granite.py @@ -0,0 +1,70 @@ +# In transformer_lens/pretrained/weight_conversions/granite.py + +from typing import cast +import einops +import torch +from transformers import GraniteForCausalLM +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + +def convert_granite_weights( + hf_model: GraniteForCausalLM, cfg: HookedTransformerConfig +) -> dict[str, torch.Tensor]: + """ + Converts the weights of a Hugging Face GraniteForCausalLM model to the format + used by HookedTransformer, correctly handling Grouped-Query Attention (GQA) + and weight transpositions. 
+ """ + state_dict = {} + + # Token Embeddings + state_dict["embed.W_E"] = hf_model.model.embed_tokens.weight + + # Safely get the number of key-value heads for GQA + # This is the number of heads for the Key and Value projections + n_kv_heads = cast(int, cfg.n_key_value_heads) + + for l in range(cfg.n_layers): + # LayerNorm 1 (before attention) + state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight + + # Attention weights + W_Q = hf_model.model.layers[l].self_attn.q_proj.weight + W_K = hf_model.model.layers[l].self_attn.k_proj.weight + W_V = hf_model.model.layers[l].self_attn.v_proj.weight + W_O = hf_model.model.layers[l].self_attn.o_proj.weight + + # Reshape weights for TransformerLens internal format. + + # W_Q uses the main number of heads (n_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( + W_Q, "(n h) m -> n m h", n=cfg.n_heads, h=cfg.d_head + ) + + # W_K and W_V use the smaller number of heads for GQA (n_kv_heads) + # This is the line that fixes the bug. + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( + W_K, "(n h) m -> n m h", n=n_kv_heads, h=cfg.d_head + ) + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( + W_V, "(n h) m -> n m h", n=n_kv_heads, h=cfg.d_head + ) + + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( + W_O, "m (n h) -> n h m", n=cfg.n_heads, h=cfg.d_head + ) + + # LayerNorm 2 (before MLP) + state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[l].post_attention_layernorm.weight + + # MLP weights (transpose is necessary) + state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T + + # Final LayerNorm + state_dict["ln_final.w"] = hf_model.model.norm.weight + + # Unembedding weights (transpose is necessary) + state_dict["unembed.W_U"] = hf_model.lm_head.weight.T + + return state_dict \ No newline at end of file diff --git a/transformer_lens/utilities/activation_functions.py b/transformer_lens/utilities/activation_functions.py index 6cc701360..91a9972f1 100644 --- a/transformer_lens/utilities/activation_functions.py +++ b/transformer_lens/utilities/activation_functions.py @@ -23,4 +23,5 @@ "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"), + "swiglu": lambda x: F.silu(torch.chunk(x, 2, dim=-1)[0]) * torch.chunk(x, 2, dim=-1)[1], } From 0d31a7bf990df522845226f70a0fc456ccc47af8 Mon Sep 17 00:00:00 2001 From: emharsha1812 Date: Sun, 13 Jul 2025 18:10:43 +0530 Subject: [PATCH 2/4] Added Granite 3.3. Support --- test_granite.py | 40 -------- transformer_lens/loading_from_pretrained.py | 45 ++++----- .../pretrained/weight_conversions/granite.py | 98 +++++++++++-------- 3 files changed, 72 insertions(+), 111 deletions(-) delete mode 100644 test_granite.py diff --git a/test_granite.py b/test_granite.py deleted file mode 100644 index ad104fbf0..000000000 --- a/test_granite.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from transformer_lens import HookedTransformer - -# The Hugging Face model name you added -MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct" - -print(f"--- Loading model: {MODEL_NAME} ---") - -# Set the device -device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using device: {device}") - -# Load the model from Hugging Face -# This will use all the code you've just written! 
-model = HookedTransformer.from_pretrained( - MODEL_NAME, - device=device, - # You might need to use a lower-precision dtype if you're low on VRAM - # torch_dtype=torch.bfloat16 -) - -print("\n--- Model Configuration ---") -# Print the config to double-check that your parameters were loaded correctly -print(model.cfg) - -print("\n--- Running Generation Test ---") -prompt = "The best programming language is" - -# Generate some text -# The model will return token IDs, so we use .to_string() to convert back to text -output = model.generate( - prompt, - max_new_tokens=10, - temperature=0.7, -) - -print(f"\nPrompt: '{prompt}'") -print(f"Generated text: '{output}'") - -print("\n--- ✅ Test Complete ---") \ No newline at end of file diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 1e87f98a3..117f65476 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -264,7 +264,10 @@ "google-t5/t5-base", "google-t5/t5-large", "ai-forever/mGPT", - "ibm-granite/granite-3.3-2b-instruct" + "ibm-granite/granite-3.3-2b-instruct", + "ibm-granite/granite-3.3-2b-base", + "ibm-granite/granite-3.3-8b-instruct", + "ibm-granite/granite-3.3-8b-base" ] """Official model names for models on HuggingFace.""" @@ -721,7 +724,10 @@ "google-t5/t5-base": ["t5-base"], "google-t5/t5-large": ["t5-large"], "ai-forever/mGPT": ["mGPT"], - "ibm-granite/granite-3.3-2b-instruct":["granite-3.3-2b","granite-3.3"] + "ibm-granite/granite-3.3-2b-instruct": ["granite-3.3-2b", "granite-3.3-2b-instruct"], + "ibm-granite/granite-3.3-2b-base": ["granite-3.3-2b-base"], + "ibm-granite/granite-3.3-8b-instruct": ["granite-3.3-8b", "granite-3.3-8b-instruct"], + "ibm-granite/granite-3.3-8b-base": ["granite-3.3-8b-base"] } """Model aliases for models on HuggingFace.""" @@ -797,8 +803,6 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): architecture = "Gemma2ForCausalLM" elif "gemma" in official_model_name.lower(): architecture = "GemmaForCausalLM" - # elif "granite" in official_model_name.lower(): - # architecture="GraniteForCausalLM" else: huggingface_token = os.environ.get("HF_TOKEN", "") hf_config = AutoConfig.from_pretrained( @@ -1441,43 +1445,28 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } - # elif architecture=="GraniteForCausalLM": - # ### Architecture for Granite 3.3 2b and Granite 3.3 2b instruct models - # cfg_dict = { - # "d_model": 2048, - # "d_head": 64, - # "n_heads": 32, - # "d_mlp": 8192, - # "n_layers": 40, - # "n_ctx": 2048, - # "eps": 1e-5, - # "d_vocab": 49152, - # "act_fn": "swiglu", # Based on common practice for Granite-like models - # "normalization_type": "RMS", - # "positional_embedding_type": "rotary", - # "gated_mlp":True, - # "rotary_dim": 64, - # "final_rms":True - # } elif architecture == "GraniteForCausalLM": - # This block now correctly assumes `hf_config` has been loaded + # Granite 3.3 models configuration cfg_dict = { "d_model": hf_config.hidden_size, - "n_layers": hf_config.num_hidden_layers, - "n_heads": hf_config.num_attention_heads, "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, "d_mlp": hf_config.intermediate_size, - "n_ctx": hf_config.max_position_embeddings, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": min(hf_config.max_position_embeddings, 2048), # Cap context length for memory "eps": 
hf_config.rms_norm_eps, "d_vocab": hf_config.vocab_size, - "act_fn": "swiglu", + "act_fn": "silu", "normalization_type": "RMS", "positional_embedding_type": "rotary", "n_key_value_heads": hf_config.num_key_value_heads, "gated_mlp": True, "final_rms": True, - # rotary_dim is often d_head + "rotary_adjacent_pairs": False, + "gated_mlp":True, ## Remove this later "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "rotary_base": getattr(hf_config, 'rope_theta', 10000.0), + "use_attn_scale": True, } elif official_model_name.startswith("google/gemma-2b"): # Architecture for Gemma 2b and Gemma 2b Instruct models diff --git a/transformer_lens/pretrained/weight_conversions/granite.py b/transformer_lens/pretrained/weight_conversions/granite.py index d6f26b293..93abe5f3e 100644 --- a/transformer_lens/pretrained/weight_conversions/granite.py +++ b/transformer_lens/pretrained/weight_conversions/granite.py @@ -1,70 +1,82 @@ -# In transformer_lens/pretrained/weight_conversions/granite.py - from typing import cast import einops import torch -from transformers import GraniteForCausalLM from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -def convert_granite_weights( - hf_model: GraniteForCausalLM, cfg: HookedTransformerConfig -) -> dict[str, torch.Tensor]: + +def convert_granite_weights(hf_model, cfg: HookedTransformerConfig): """ Converts the weights of a Hugging Face GraniteForCausalLM model to the format - used by HookedTransformer, correctly handling Grouped-Query Attention (GQA) - and weight transpositions. + used by HookedTransformer """ state_dict = {} - # Token Embeddings - state_dict["embed.W_E"] = hf_model.model.embed_tokens.weight + # Token Embeddings - move to the correct device + state_dict["embed.W_E"] = hf_model.model.embed_tokens.weight.to(device=cfg.device) - # Safely get the number of key-value heads for GQA - # This is the number of heads for the Key and Value projections - n_kv_heads = cast(int, cfg.n_key_value_heads) + # Granite architecture use Grouped Query Attention + using_gqa = cfg.n_key_value_heads is not None + gqa_uscore = "_" if using_gqa else "" + n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) for l in range(cfg.n_layers): - # LayerNorm 1 (before attention) - state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight + # LayerNorm 1 - move to the correct device + state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight.to(device=cfg.device) # Attention weights - W_Q = hf_model.model.layers[l].self_attn.q_proj.weight - W_K = hf_model.model.layers[l].self_attn.k_proj.weight - W_V = hf_model.model.layers[l].self_attn.v_proj.weight - W_O = hf_model.model.layers[l].self_attn.o_proj.weight + # Transpose the weights first, then rearrange + W_Q = hf_model.model.layers[l].self_attn.q_proj.weight.T + W_K = hf_model.model.layers[l].self_attn.k_proj.weight.T + W_V = hf_model.model.layers[l].self_attn.v_proj.weight.T + W_O = hf_model.model.layers[l].self_attn.o_proj.weight.T - # Reshape weights for TransformerLens internal format. 
+ # Reshape weights for TransformerLens internal format + W_Q = einops.rearrange(W_Q, "m (n h) -> n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "m (n h) -> n m h", n=n_kv_heads) + W_V = einops.rearrange(W_V, "m (n h) -> n m h", n=n_kv_heads) + W_O = einops.rearrange(W_O, "(n h) m -> n h m", n=cfg.n_heads) + + # Move weights to the correct device + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q.to(device=cfg.device) + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K.to(device=cfg.device) + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V.to(device=cfg.device) + state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) - # W_Q uses the main number of heads (n_heads) - state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( - W_Q, "(n h) m -> n m h", n=cfg.n_heads, h=cfg.d_head + # Attention biases (Granite models don't use biases, so we set them to zero) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) - - # W_K and W_V use the smaller number of heads for GQA (n_kv_heads) - # This is the line that fixes the bug. - state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( - W_K, "(n h) m -> n m h", n=n_kv_heads, h=cfg.d_head + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( + n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) - state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( - W_V, "(n h) m -> n m h", n=n_kv_heads, h=cfg.d_head + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( + n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) - - state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( - W_O, "m (n h) -> n h m", n=cfg.n_heads, h=cfg.d_head + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device ) - # LayerNorm 2 (before MLP) - state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[l].post_attention_layernorm.weight + # LayerNorm 2 + state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[l].post_attention_layernorm.weight.to(device=cfg.device) + + # MLP weights for GatedMLP - move to the correct device + state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T.to(device=cfg.device) + state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T.to(device=cfg.device) + state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T.to(device=cfg.device) - # MLP weights (transpose is necessary) - state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T + # MLP biases (Granite models don't use biases, so we set them to zero) + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( + cfg.d_mlp, dtype=cfg.dtype, device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) - # Final LayerNorm - state_dict["ln_final.w"] = hf_model.model.norm.weight + # Final LayerNorm + state_dict["ln_final.w"] = hf_model.model.norm.weight.to(device=cfg.device) - # Unembedding weights (transpose is necessary) - state_dict["unembed.W_U"] = hf_model.lm_head.weight.T + # Unembedding weights + state_dict["unembed.W_U"] = hf_model.lm_head.weight.T.to(device=cfg.device) + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) return state_dict \ No newline at end of file From 
bb53d1d34a6be5c37b57938775d7f7631b2e9f7f Mon Sep 17 00:00:00 2001 From: emharsha1812 Date: Sun, 13 Jul 2025 18:27:13 +0530 Subject: [PATCH 3/4] Fixed imports --- transformer_lens/pretrained/weight_conversions/granite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_lens/pretrained/weight_conversions/granite.py b/transformer_lens/pretrained/weight_conversions/granite.py index 93abe5f3e..c303a19a0 100644 --- a/transformer_lens/pretrained/weight_conversions/granite.py +++ b/transformer_lens/pretrained/weight_conversions/granite.py @@ -1,6 +1,9 @@ from typing import cast import einops + + import torch + from transformer_lens.HookedTransformerConfig import HookedTransformerConfig From 6fa6bea5208f318ac8a7e8fa7f1650348f4bacb6 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 22 Jul 2025 19:40:53 +0200 Subject: [PATCH 4/4] ran format --- transformer_lens/loading_from_pretrained.py | 20 +++++------ .../pretrained/weight_conversions/granite.py | 33 ++++++++++++------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 117f65476..366df75f8 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -30,6 +30,7 @@ convert_gemma_weights, convert_gpt2_weights, convert_gptj_weights, + convert_granite_weights, convert_llama_weights, convert_mingpt_weights, convert_mistral_weights, @@ -44,7 +45,6 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, - convert_granite_weights ) OFFICIAL_MODEL_NAMES = [ @@ -266,8 +266,8 @@ "ai-forever/mGPT", "ibm-granite/granite-3.3-2b-instruct", "ibm-granite/granite-3.3-2b-base", - "ibm-granite/granite-3.3-8b-instruct", - "ibm-granite/granite-3.3-8b-base" + "ibm-granite/granite-3.3-8b-instruct", + "ibm-granite/granite-3.3-8b-base", ] """Official model names for models on HuggingFace.""" @@ -727,7 +727,7 @@ "ibm-granite/granite-3.3-2b-instruct": ["granite-3.3-2b", "granite-3.3-2b-instruct"], "ibm-granite/granite-3.3-2b-base": ["granite-3.3-2b-base"], "ibm-granite/granite-3.3-8b-instruct": ["granite-3.3-8b", "granite-3.3-8b-instruct"], - "ibm-granite/granite-3.3-8b-base": ["granite-3.3-8b-base"] + "ibm-granite/granite-3.3-8b-base": ["granite-3.3-8b-base"], } """Model aliases for models on HuggingFace.""" @@ -1456,16 +1456,16 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "n_ctx": min(hf_config.max_position_embeddings, 2048), # Cap context length for memory "eps": hf_config.rms_norm_eps, "d_vocab": hf_config.vocab_size, - "act_fn": "silu", + "act_fn": "silu", "normalization_type": "RMS", "positional_embedding_type": "rotary", "n_key_value_heads": hf_config.num_key_value_heads, "gated_mlp": True, "final_rms": True, "rotary_adjacent_pairs": False, - "gated_mlp":True, ## Remove this later + "gated_mlp": True, ## Remove this later "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, - "rotary_base": getattr(hf_config, 'rope_theta', 10000.0), + "rotary_base": getattr(hf_config, "rope_theta", 10000.0), "use_attn_scale": True, } elif official_model_name.startswith("google/gemma-2b"): @@ -2017,9 +2017,9 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) - elif cfg.original_architecture=="GraniteForCausalLM": - state_dict=convert_granite_weights(hf_model,cfg) - + elif cfg.original_architecture == "GraniteForCausalLM": + 
state_dict = convert_granite_weights(hf_model, cfg) + else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/granite.py b/transformer_lens/pretrained/weight_conversions/granite.py index c303a19a0..e1451a985 100644 --- a/transformer_lens/pretrained/weight_conversions/granite.py +++ b/transformer_lens/pretrained/weight_conversions/granite.py @@ -1,7 +1,6 @@ from typing import cast -import einops - +import einops import torch from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -24,7 +23,9 @@ def convert_granite_weights(hf_model, cfg: HookedTransformerConfig): for l in range(cfg.n_layers): # LayerNorm 1 - move to the correct device - state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight.to(device=cfg.device) + state_dict[f"blocks.{l}.ln1.w"] = hf_model.model.layers[l].input_layernorm.weight.to( + device=cfg.device + ) # Attention weights # Transpose the weights first, then rearrange @@ -44,7 +45,7 @@ def convert_granite_weights(hf_model, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K.to(device=cfg.device) state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V.to(device=cfg.device) state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) - + # Attention biases (Granite models don't use biases, so we set them to zero) state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device @@ -59,13 +60,21 @@ def convert_granite_weights(hf_model, cfg: HookedTransformerConfig): cfg.d_model, dtype=cfg.dtype, device=cfg.device ) - # LayerNorm 2 - state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[l].post_attention_layernorm.weight.to(device=cfg.device) + # LayerNorm 2 + state_dict[f"blocks.{l}.ln2.w"] = hf_model.model.layers[ + l + ].post_attention_layernorm.weight.to(device=cfg.device) # MLP weights for GatedMLP - move to the correct device - state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T.to(device=cfg.device) - state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T.to(device=cfg.device) - state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T.to(device=cfg.device) + state_dict[f"blocks.{l}.mlp.W_in"] = hf_model.model.layers[l].mlp.up_proj.weight.T.to( + device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.W_gate"] = hf_model.model.layers[l].mlp.gate_proj.weight.T.to( + device=cfg.device + ) + state_dict[f"blocks.{l}.mlp.W_out"] = hf_model.model.layers[l].mlp.down_proj.weight.T.to( + device=cfg.device + ) # MLP biases (Granite models don't use biases, so we set them to zero) state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( @@ -75,11 +84,11 @@ def convert_granite_weights(hf_model, cfg: HookedTransformerConfig): cfg.d_model, dtype=cfg.dtype, device=cfg.device ) - # Final LayerNorm + # Final LayerNorm state_dict["ln_final.w"] = hf_model.model.norm.weight.to(device=cfg.device) - + # Unembedding weights state_dict["unembed.W_U"] = hf_model.lm_head.weight.T.to(device=cfg.device) state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) - return state_dict \ No newline at end of file + return state_dict
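
A minimal sanity-check sketch for the conversion added in these patches, assuming the ibm-granite/granite-3.3-2b-instruct checkpoint used in the since-removed test_granite.py and enough memory to hold the model twice in float32. The prompt, the tolerance-free log-probability comparison, and the variable names are illustrative only; a large gap between the two outputs would point at conversion details still to be handled (for example Granite's scaling multipliers, if present in the config) rather than at the loading path itself.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reference implementation straight from Hugging Face.
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Same checkpoint routed through convert_granite_weights via from_pretrained.
tl_model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)

prompt = "The best programming language is"
tokens = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    hf_logits = hf_model(**tokens).logits
    tl_logits = tl_model(tokens["input_ids"])

# HookedTransformer.from_pretrained applies weight processing by default
# (e.g. folding LayerNorm, centering the unembedding), so compare
# log-probabilities rather than raw logits.
hf_logprobs = torch.log_softmax(hf_logits.float(), dim=-1)
tl_logprobs = torch.log_softmax(tl_logits.float(), dim=-1)
print("max |delta log-prob|:", (hf_logprobs - tl_logprobs).abs().max().item())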