From 99db861c4c3a0bf51d9abc3638dacab095f90516 Mon Sep 17 00:00:00 2001 From: Luis Pinto Date: Sun, 12 Oct 2025 12:48:05 -0400 Subject: [PATCH 1/2] LoRA added w/ layer exceptions Signed-off-by: Luis Pinto --- models/.DS_Store | Bin 6148 -> 0 bytes models/pos_egnn/posegnn/__init__.py | 4 +- models/pos_egnn/posegnn/adapter/README.md | 50 ++ models/pos_egnn/posegnn/adapter/__init__.py | 11 + models/pos_egnn/posegnn/adapter/config.py | 14 + models/pos_egnn/posegnn/adapter/inject.py | 68 +++ models/pos_egnn/posegnn/adapter/layers.py | 78 +++ models/pos_egnn/posegnn/adapter/model.py | 118 ++++ models/pos_egnn/posegnn/adapter/test.ipynb | 583 ++++++++++++++++++++ 9 files changed, 924 insertions(+), 2 deletions(-) delete mode 100644 models/.DS_Store create mode 100644 models/pos_egnn/posegnn/adapter/README.md create mode 100644 models/pos_egnn/posegnn/adapter/__init__.py create mode 100644 models/pos_egnn/posegnn/adapter/config.py create mode 100644 models/pos_egnn/posegnn/adapter/inject.py create mode 100644 models/pos_egnn/posegnn/adapter/layers.py create mode 100644 models/pos_egnn/posegnn/adapter/model.py create mode 100644 models/pos_egnn/posegnn/adapter/test.ipynb diff --git a/models/.DS_Store b/models/.DS_Store deleted file mode 100644 index 92789af487f8ccf3d007aa264e246f83edaaa21f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKK~BR!4D^;BQk6py7moXd9=U}oa6>}s13)D-sgx)YfqVYI4|og@;suP?R!G!F za6%QbCGTv!_O3ID;+Tls>NQ;wEr=+BGEQa~4hXNa?#Vy~-*{Hbu3;YSa{`^-cb9HRhTA(zDKcC!e?2 zj^@-ZJ zsM##hT+l{iz!)$FItKXr;GvAMVk;Ov9T-9j0Gz@c1as~sIL0f+imf14AWlMo66& tuple[int, int]: + """ + Replace leaf linear-like layers under include patterns with LoRA. + Skips any module that has a non-identity .activation to guarantee mergeability. + Returns (num_scalar_wrapped, 0). 
+ """ + include_patterns = list(cfg.include_names or []) + exclude_patterns = list(cfg.exclude_names or []) + if getattr(cfg, "preset", None) == "posegnn" and not include_patterns: + include_patterns = [r"^encoder\.", r"^readout\."] + + inc_re = [re.compile(p) for p in include_patterns] + exc_re = [re.compile(p) for p in exclude_patterns] + + def wants(name: str) -> bool: + if any(p.search(name) for p in exc_re): + return False + if inc_re and not any(p.search(name) for p in inc_re): + return False + return True + + def is_linear_like(m: nn.Module) -> bool: + w = getattr(m, "weight", None) + if isinstance(m, nn.Embedding): + return False + return isinstance(w, torch.Tensor) and w.ndim == 2 + + def has_post_act(m: nn.Module) -> bool: + act = getattr(m, "activation", None) + return (act is not None) and (not isinstance(act, nn.Identity)) + + n_scalar = 0 + skipped = [] # <— track skipped post-activation linears + + for full_name, module in list(model.named_modules()): + if not is_linear_like(module): + continue + if not wants(full_name): + continue + if has_post_act(module): + skipped.append(full_name) # <— record and skip + continue + + parent_name, _, child = full_name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + + if hasattr(module, "base") and hasattr(module, "lora_A") and hasattr(module, "lora_B"): + continue + + wrapped = LoRALinear( + module, cfg.rank, cfg.alpha, cfg.dropout, cfg.merge_on_save, cfg.freeze_base + ) + setattr(parent, child, wrapped) + n_scalar += 1 + + if getattr(cfg, "log_skipped", False) and skipped: + print("[lora] skipped post-activation linears:") + for n in skipped: + print(" -", n) + + return n_scalar, 0 \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/layers.py b/models/pos_egnn/posegnn/adapter/layers.py new file mode 100644 index 0000000..974002c --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/layers.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +from typing import Optional + +def _init_lora(linear: nn.Linear, freeze_base: bool): + if freeze_base: + linear.weight.requires_grad = False + if linear.bias is not None: + linear.bias.requires_grad = False + +class LoRALinear(nn.Module): + """ + LoRA for linear layers: + y = base(x) + scaling * B(A(dropout(x))) + """ + def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float], + dropout: float, merge_on_save: bool, freeze_base: bool): + super().__init__() + assert isinstance(base_linear, nn.Linear) + self.base = base_linear + _init_lora(self.base, freeze_base) + + self.in_features = base_linear.in_features + self.out_features = base_linear.out_features + self.r = int(rank) + self.lora_alpha = float(alpha) if alpha is not None else float(self.r) + self.scaling = self.lora_alpha / max(self.r, 1) + self.enable_lora = True + self.merged = False + + self._post_act = getattr(base_linear, "activation", None) + self._has_post_act = self._post_act is not None and not isinstance(self._post_act, nn.Identity) + self.merge_on_save = bool(merge_on_save and not self._has_post_act) + + self.lora_dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity() + self.lora_A = nn.Linear(self.in_features, self.r, bias=False) + self.lora_B = nn.Linear(self.r, self.out_features, bias=False) + + nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5) + nn.init.zeros_(self.lora_B.weight) + + if self.merge_on_save: + self._register_state_dict_hook(self._merge_on_state_dict) + 
self._register_load_state_dict_pre_hook(self._strict_fill_on_load, with_module=True) + + def forward(self, x): + y = self.base(x) + if self._has_post_act: + y = self._post_act(y) + if self.enable_lora and self.r > 0: + z = self.lora_dropout(x) + z = self.lora_A(z) + z = self.lora_B(z) + y = y + self.scaling * z + return y + + @torch.no_grad() + def merged_weight(self): + if self._has_post_act: + return self.base.weight + return self.base.weight + self.scaling * (self.lora_B.weight @ self.lora_A.weight) + + def _merge_on_state_dict(self, module, state_dict, prefix, local_metadata): + # replace the tensor stored at base.weight with merged values + key_w = prefix + "base.weight" + if key_w in state_dict: + state_dict[key_w] = self.merged_weight() + # drop adapter tensors from the saved dict + state_dict.pop(prefix + "lora_A.weight", None) + state_dict.pop(prefix + "lora_B.weight", None) + return state_dict + + @torch.no_grad() + def _strict_fill_on_load(self, module, state_dict, prefix, local_metadata, strict, missing, unexpected, errors): + for k, ref in [(prefix + "lora_A.weight", self.lora_A.weight), + (prefix + "lora_B.weight", self.lora_B.weight)]: + if k not in state_dict: + state_dict[k] = torch.zeros_like(ref) \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/model.py b/models/pos_egnn/posegnn/adapter/model.py new file mode 100644 index 0000000..3f89464 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/model.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from collections import OrderedDict +from .config import LoRAConfig +from .inject import apply_lora + +def _is_lora_module(m: nn.Module) -> bool: + return hasattr(m, "base") and hasattr(m, "lora_A") and hasattr(m, "lora_B") + +class PosEGNNLoRAModel(nn.Module): + """ + Wrap a PosEGNN backbone, inject LoRA only into mergeable linear layers + (post-activation linears are skipped by the injector), and expose a merged export. + """ + def __init__(self, backbone: nn.Module, lora_config: LoRAConfig = LoRAConfig()): + super().__init__() + self.backbone = backbone + self.lora_config = lora_config + + # Injector must skip any module with a non-identity .activation + self.n_scalar, _ = apply_lora(self.backbone, self.lora_config) + + if getattr(lora_config, "freeze_base", False): + for p in self.backbone.parameters(): + p.requires_grad_(False) + for p in self.lora_parameters(): + p.requires_grad_(True) + + self._adapters_enabled = True + + def forward(self, *args, **kwargs): + return self.backbone(*args, **kwargs) + + def lora_parameters(self): + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + yield from m.lora_A.parameters() + yield from m.lora_B.parameters() + + @torch.no_grad() + def enable_adapter(self): + self._adapters_enabled = True + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + setattr(m, "enable_lora", True) + + @torch.no_grad() + def disable_adapter(self): + self._adapters_enabled = False + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + setattr(m, "enable_lora", False) + + @torch.no_grad() + def _export_merged_state_dict(self) -> OrderedDict: + """ + Build a plain PosEGNN state dict (original layout): + - merge adapter weights into '...weight' + - copy '...bias' + - copy other base.* params with 'base.' 
stripped + - drop all '...lora_*' + """ + sd_out = OrderedDict() + sd_in = self.backbone.state_dict() + wrapper_paths = {name for name, m in self.backbone.named_modules() if _is_lora_module(m)} + + def wrapper_path_for(key: str): + for p in wrapper_paths: + if key.startswith(p + "."): + return p + return None + + for key, val in sd_in.items(): + # drop adapter tensors + if ".lora_A." in key or ".lora_B." in key: + continue + + wp = wrapper_path_for(key) + if wp is not None: + if key.endswith(".base.weight"): + path = key[:-len(".base.weight")] + mod = self.backbone.get_submodule(path) # LoRA wrapper + sd_out[path + ".weight"] = mod.merged_weight() + continue + if key.endswith(".base.bias"): + sd_out[key.replace(".base.bias", ".bias")] = val + continue + prefix = f"{wp}.base." + if key.startswith(prefix): + sd_out[key.replace(prefix, f"{wp}.", 1)] = val + continue + continue # ignore other wrapper internals + + # passthrough for non-wrapper keys + sd_out[key] = val + + return sd_out + + @torch.no_grad() + def state_dict_backbone(self, merged: bool = False) -> OrderedDict: + """ + Return backbone-only weights. If merged=True, adapters are fused and + keys match the original PosEGNN layout. + """ + if merged: + return self._export_merged_state_dict() + return self.backbone.state_dict() + + @torch.no_grad() + def save_pretrained_merged(self, path: str): + torch.save(self.state_dict_backbone(merged=True), path) + + @torch.no_grad() + def save_pretrained_adapted(self, path: str): + torch.save(self.state_dict(), path) + + def lora_report(self) -> str: + return f"LoRA injected - scalar layers: {self.n_scalar}" \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/test.ipynb b/models/pos_egnn/posegnn/adapter/test.ipynb new file mode 100644 index 0000000..1feeb8d --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/test.ipynb @@ -0,0 +1,583 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# POS-EGNN ELoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from ase import Atoms as ASEAtoms\n", + "from ase.io import read\n", + "from torch_geometric.data import Data, Batch\n", + "\n", + "import sys\n", + "sys.path.append('../../')\n", + "from posegnn.model import PosEGNN\n", + "from posegnn.adapter import PosEGNNLoRAModel, LoRAConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "atoms = read(\"../../inputs/3BPA.xyz\", index=0)\n", + "\n", + "def build_data_from_ase(atoms: ASEAtoms) -> Data:\n", + " z = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)\n", + " box = torch.tensor(atoms.get_cell().tolist()).unsqueeze(0).float()\n", + " pos = torch.tensor(atoms.get_positions().tolist()).float()\n", + " batch = torch.zeros(len(z), dtype=torch.long)\n", + " return Data(z=z, pos=pos, box=box, batch=batch, num_graphs=1)\n", + "\n", + "data = build_data_from_ase(atoms)\n", + "batch = Batch.from_data_list([data])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get LoRA model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[lora] skipped post-activation linears:\n", + " - 
encoder.neighbor_embedding.combine.dense_layers.0\n", + " - encoder.edge_embedding.edge_up.dense_layers.0\n", + " - encoder.gata.0.gamma_s.0\n", + " - encoder.gata.0.gamma_v.0\n", + " - encoder.gata.0.phik_w_ra\n", + " - encoder.gata.0.edge_attr_up.dense_layers.0\n", + " - encoder.gata.1.gamma_s.0\n", + " - encoder.gata.1.gamma_v.0\n", + " - encoder.gata.1.phik_w_ra\n", + " - encoder.gata.1.edge_attr_up.dense_layers.0\n", + " - encoder.gata.2.gamma_s.0\n", + " - encoder.gata.2.gamma_v.0\n", + " - encoder.gata.2.phik_w_ra\n", + " - encoder.gata.2.edge_attr_up.dense_layers.0\n", + " - encoder.gata.3.gamma_s.0\n", + " - encoder.gata.3.gamma_v.0\n", + " - encoder.gata.3.phik_w_ra\n", + " - encoder.eqff.0.gamma_m.0\n", + " - encoder.eqff.1.gamma_m.0\n", + " - encoder.eqff.2.gamma_m.0\n", + " - encoder.eqff.3.gamma_m.0\n", + "LoRA injected - scalar layers: 48\n" + ] + } + ], + "source": [ + "# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn\n", + "checkpoint_dict = torch.load('../../pytorch_model.bin', weights_only=True, map_location='cpu')\n", + "backbone = PosEGNN(checkpoint_dict[\"config\"])\n", + "backbone.load_state_dict(checkpoint_dict[\"state_dict\"], strict=True)\n", + "\n", + "cfg = LoRAConfig(\n", + " rank=16,\n", + " alpha=16, # uses alpha = rank by default\n", + " dropout=0.0,\n", + " merge_on_save=True, # saves merged weights for compatibility\n", + " freeze_base=True, # train only adapters\n", + " log_skipped=True\n", + ")\n", + "model = PosEGNNLoRAModel(backbone, cfg)\n", + "print(model.lora_report())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helpers" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def random_SO3(dtype, device):\n", + " A = torch.randn(3, 3, dtype=dtype, device=device)\n", + " Q, _ = torch.linalg.qr(A)\n", + " if torch.det(Q) < 0:\n", + " Q[:, 0] = -Q[:, 0]\n", + " return Q\n", + "\n", + "def rotate_atoms_data(data: Data, R: torch.Tensor) -> Data:\n", + " # Rotate positions and box. 
Keep everything else.\n", + " pos = data.pos @ R.T\n", + " if hasattr(data, \"box\") and data.box is not None:\n", + " # box is shape [1, 3, 3] or [3, 3]\n", + " box = data.box\n", + " if box.dim() == 2:\n", + " box_rot = box @ R.T\n", + " box_rot = box_rot.unsqueeze(0)\n", + " else:\n", + " box_rot = box @ R.T\n", + " else:\n", + " box_rot = None\n", + " new = Data(\n", + " z=data.z.clone() if hasattr(data, \"z\") else None,\n", + " pos=pos,\n", + " box=box_rot,\n", + " batch=data.batch.clone() if hasattr(data, \"batch\") else None,\n", + " num_graphs=getattr(data, \"num_graphs\", None),\n", + " )\n", + " return new\n", + "\n", + "def act_block_lastdim(vec, R, block=3):\n", + " # Apply R to each 3-vector block along the last dim\n", + " # vec [..., 3k], R [3, 3]\n", + " c = vec.shape[-1]\n", + " assert c % block == 0, \"embedding last dim is not a multiple of 3\"\n", + " k = c // block\n", + " v = vec.view(*vec.shape[:-1], k, block)\n", + " vR = torch.einsum(\"...bi,ij->...bj\", v, R)\n", + " return vR.reshape(*vec.shape)\n", + "\n", + "@torch.no_grad()\n", + "def cosine(a, b, eps=1e-12):\n", + " num = (a * b).sum()\n", + " den = a.norm() * b.norm() + eps\n", + " return float((num / den).clamp(-1, 1))\n", + "\n", + "# ---------- 0) prep ----------\n", + "device = torch.device(\"cpu\")\n", + "model.eval()\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Determinism" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[determinism] embedding_0 max|Δ| = 0.000e+00\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " o1 = model(batch)\n", + " o2 = model(batch)\n", + "for k in o1.keys():\n", + " if torch.is_tensor(o1[k]):\n", + " diff = (o1[k] - o2[k]).abs().max().item()\n", + " print(f\"[determinism] {k:<20} max|Δ| = {diff:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check LoRA parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wrapped layers: 48\n", + "encoder.neighbor_embedding.distance_proj.dense_layers.0\n", + "encoder.neighbor_embedding.combine.dense_layers.1\n", + "encoder.edge_embedding.edge_up.dense_layers.1\n", + "encoder.gata.0.gamma_s.1\n", + "encoder.gata.0.q_w\n", + "encoder.gata.0.k_w\n", + "encoder.gata.0.gamma_v.1\n", + "encoder.gata.0.vecq_w\n", + "encoder.gata.0.veck_w.0\n", + "encoder.gata.0.veck_w.1\n" + ] + } + ], + "source": [ + "# how many layers got LoRA\n", + "wrapped = []\n", + "for name, m in model.backbone.named_modules():\n", + " if hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"):\n", + " wrapped.append(name)\n", + "len_wrapped = len(wrapped)\n", + "print(\"wrapped layers:\", len_wrapped)\n", + "print(\"\\n\".join(wrapped[:10]))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total trainable params: 504896\n", + "backbone.encoder.neighbor_embedding.distance_proj.dense_layers.0.lora_A.weight (16, 64)\n" + ] + } + ], + "source": [ + "total_trainable = 0\n", + "for n, p in model.named_parameters():\n", + " if p.requires_grad:\n", + " total_trainable += p.numel()\n", + "print(\"total trainable params:\", total_trainable)\n", + "\n", + "# show a few LoRA shapes\n", + "for n, p in model.named_parameters():\n", + " if \"lora_A\" in 
n or \"lora_B\" in n:\n", + " print(n, tuple(p.shape))\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rotation sanity on embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[embed] shape: (27, 256, 1, 4)\n", + "[embed] invariance max|Δ| = 3.338e-06\n" + ] + } + ], + "source": [ + "R = random_SO3(dtype=batch.pos.dtype, device=batch.pos.device)\n", + "batch_R = Batch.from_data_list([rotate_atoms_data(data, R)])\n", + "\n", + "with torch.no_grad():\n", + " out = model(batch)\n", + " out_R = model(batch_R)\n", + "\n", + "if \"embedding_0\" in out and torch.is_tensor(out[\"embedding_0\"]):\n", + " e = out[\"embedding_0\"]\n", + " eR = out_R[\"embedding_0\"]\n", + " print(f\"[embed] shape: {tuple(e.shape)}\")\n", + "\n", + " # Invariance check\n", + " inv_err = (e - eR).abs().max().item()\n", + " print(f\"[embed] invariance max|Δ| = {inv_err:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Energy invariance and force covariance (if energy available)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# energy_key = None\n", + "# for k in [\"energy\", \"y_energy\", \"E\", \"total_energy\"]:\n", + "# if k in out:\n", + "# energy_key = k\n", + "# break\n", + "\n", + "# if energy_key is not None:\n", + "# # Build fresh tensors with grad for force test\n", + "# d1 = Batch.from_data_list([build_data_from_ase(atoms)])\n", + "# d1.pos.requires_grad_(True)\n", + "# E1 = model(d1)[energy_key] # scalar\n", + "\n", + "# d2 = Batch.from_data_list([rotate_atoms_data(build_data_from_ase(atoms), R)])\n", + "# d2.pos.requires_grad_(True)\n", + "# E2 = model(d2)[energy_key] # scalar\n", + "\n", + "# # energy invariance\n", + "# e_err = (E2.detach() - E1.detach()).abs().item()\n", + "# print(f\"[energy] |E(Rx) - E(x)| = {e_err:.3e}\")\n", + "\n", + "# # forces = -dE/dx, covariance: F(Rx) = R F(x)\n", + "# (F1,) = torch.autograd.grad(E1, d1.pos, retain_graph=False)\n", + "# (F2,) = torch.autograd.grad(E2, d2.pos, retain_graph=False)\n", + "# F1_equiv = F1 @ R.T\n", + "# f_err = (F2 - F1_equiv).abs().max().item()\n", + "# print(f\"[forces] covariance max|Δ| = {f_err:.3e}\")\n", + "# else:\n", + "# print(\"[energy] no energy key found in outputs. Skipping energy/force checks.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LoRA merge correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[merge] missing: 0, unexpected: 0\n", + "[merge] max abs diff = 0.000e+00\n", + " embedding_0: 0.000e+00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_31481/430695950.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state = torch.load(buf, map_location=\"cpu\")\n" + ] + } + ], + "source": [ + "import io\n", + "import torch\n", + "\n", + "def _as_float_tensor(x):\n", + " return x.detach().to(dtype=torch.float32)\n", + "\n", + "def _compare_outputs(a, b):\n", + " if isinstance(a, dict) and isinstance(b, dict):\n", + " keys = sorted(set(a.keys()) & set(b.keys()))\n", + " diffs = {}\n", + " for k in keys:\n", + " ta = _as_float_tensor(a[k])\n", + " tb = _as_float_tensor(b[k])\n", + " diffs[k] = (ta - tb).abs().max().item()\n", + " return max(diffs.values()) if diffs else 0.0, diffs\n", + " ta = _as_float_tensor(a)\n", + " tb = _as_float_tensor(b)\n", + " d = (ta - tb).abs().max().item()\n", + " return d, {\"_\": d}\n", + "\n", + "try:\n", + " # 0) make sure dropout is off for determinism\n", + " model.eval()\n", + "\n", + " # 1) reference with adapters ENABLED (this is what the merged backbone should match)\n", + " model.enable_adapter()\n", + " with torch.no_grad():\n", + " o_ref = model(batch)\n", + "\n", + " # 2) export merged backbone (plain PosEGNN keys) into an in-memory buffer\n", + " buf = io.BytesIO()\n", + " merged_sd = model.state_dict_backbone(merged=True)\n", + " torch.save(merged_sd, buf)\n", + " buf.seek(0)\n", + "\n", + " # 3) load merged weights into a fresh plain backbone\n", + " backbone2 = PosEGNN(checkpoint_dict[\"config\"])\n", + " state = torch.load(buf, map_location=\"cpu\")\n", + " missing, unexpected = backbone2.load_state_dict(state, strict=True)\n", + " print(f\"[merge] missing: {len(missing)}, unexpected: {len(unexpected)}\") # expect 0, 0\n", + "\n", + " # 4) numerics: merged backbone vs adapter-enabled model\n", + " backbone2.eval()\n", + " with torch.no_grad():\n", + " o_merge = backbone2(batch)\n", + "\n", + " max_diff, per_key = _compare_outputs(o_ref, o_merge)\n", + " print(f\"[merge] max abs diff = {max_diff:.3e}\")\n", + " if isinstance(per_key, dict):\n", + " for k, v in per_key.items():\n", + " print(f\" {k}: {v:.3e}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"[merge] test failed: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[wrapped + post-act] 0\n" + ] + } + ], + "source": [ + "def find_nonmergeable_wrapped(mod):\n", + " bad = []\n", + " for name, m in model.backbone.named_modules():\n", + " if hasattr(m, \"base\") and hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"):\n", + " act = getattr(m.base, \"activation\", None)\n", + " if act is not None and not isinstance(act, torch.nn.Identity):\n", + " bad.append(name)\n", + " return bad\n", + "\n", + "bad = find_nonmergeable_wrapped(model)\n", + "print(\"[wrapped + post-act]\", len(bad))\n", + "for n in bad[:10]:\n", + " print(\" \", n)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[audit] non-mergeable wrapped layers (have activation):\n", + "count: 0\n" + ] + } + ], + "source": [ + "def audit_nonmergeable(model):\n", + " bad = []\n", + " for name, m in model.backbone.named_modules():\n", + " if 
hasattr(m, \"base\") and hasattr(m.base, \"weight\"):\n", + " act = getattr(m.base, \"activation\", None)\n", + " has_post = act is not None and not isinstance(act, torch.nn.Identity)\n", + " if has_post:\n", + " bad.append((name, type(act).__name__))\n", + " print(\"[audit] non-mergeable wrapped layers (have activation):\")\n", + " for n, a in bad:\n", + " print(f\" {n} (activation={a})\")\n", + " print(f\"count: {len(bad)}\")\n", + " return bad\n", + "\n", + "_ = audit_nonmergeable(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check requires grad" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[grads] LoRA params with grad: 74, base params with grad: 0\n" + ] + } + ], + "source": [ + "model.train()\n", + "for n, p in model.named_parameters():\n", + " if not p.requires_grad:\n", + " continue\n", + " p.grad = None\n", + "loss = 0.0\n", + "out_train = model(batch)\n", + "\n", + "if \"embedding_0\" in out_train and torch.is_tensor(out_train[\"embedding_0\"]):\n", + " loss = out_train[\"embedding_0\"].pow(2).mean()\n", + "else:\n", + " # fallback: sum of any float tensor in outputs\n", + " for v in out_train.values():\n", + " if torch.is_tensor(v) and v.dtype.is_floating_point:\n", + " loss = v.sum()\n", + " break\n", + "loss.backward()\n", + "\n", + "num_lora_grads = 0\n", + "num_base_grads = 0\n", + "for n, p in model.named_parameters():\n", + " if p.grad is None:\n", + " continue\n", + " if \"lora\" in n.lower():\n", + " num_lora_grads += 1\n", + " else:\n", + " num_base_grads += 1\n", + "print(f\"[grads] LoRA params with grad: {num_lora_grads}, base params with grad: {num_base_grads}\")\n", + "model.eval();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 569772bf0990d0dbf9ea213f878d2387ecbaaaf3 Mon Sep 17 00:00:00 2001 From: Luis Pinto Date: Mon, 13 Oct 2025 18:52:57 -0400 Subject: [PATCH 2/2] posegnn LoRA added to all layers Signed-off-by: Luis Pinto --- models/pos_egnn/posegnn/adapter/README.md | 28 --- models/pos_egnn/posegnn/adapter/inject.py | 17 +- models/pos_egnn/posegnn/adapter/layers.py | 45 ++++- models/pos_egnn/posegnn/adapter/test.ipynb | 212 ++++++++++++++------- 4 files changed, 181 insertions(+), 121 deletions(-) diff --git a/models/pos_egnn/posegnn/adapter/README.md b/models/pos_egnn/posegnn/adapter/README.md index d0234ea..1d222f8 100644 --- a/models/pos_egnn/posegnn/adapter/README.md +++ b/models/pos_egnn/posegnn/adapter/README.md @@ -2,34 +2,6 @@ This adapter injects LoRA into mergeable linear layers of **PosEGNN** and exports merged weights that load into a plain `PosEGNN` with `strict=True`. -## Skipped layers - -These layers have a built-in activation inside their Dense block, which makes algebraic merging incorrect. They are always skipped so that merged exports match adapter-enabled outputs exactly. 
- -- `encoder.neighbor_embedding.combine.dense_layers.0` -- `encoder.edge_embedding.edge_up.dense_layers.0` -- `encoder.gata.0.gamma_s.0` -- `encoder.gata.0.gamma_v.0` -- `encoder.gata.0.phik_w_ra` -- `encoder.gata.0.edge_attr_up.dense_layers.0` -- `encoder.gata.1.gamma_s.0` -- `encoder.gata.1.gamma_v.0` -- `encoder.gata.1.phik_w_ra` -- `encoder.gata.1.edge_attr_up.dense_layers.0` -- `encoder.gata.2.gamma_s.0` -- `encoder.gata.2.gamma_v.0` -- `encoder.gata.2.phik_w_ra` -- `encoder.gata.2.edge_attr_up.dense_layers.0` -- `encoder.gata.3.gamma_s.0` -- `encoder.gata.3.gamma_v.0` -- `encoder.gata.3.phik_w_ra` -- `encoder.eqff.0.gamma_m.0` -- `encoder.eqff.1.gamma_m.0` -- `encoder.eqff.2.gamma_m.0` -- `encoder.eqff.3.gamma_m.0` - -Skipping only affects where LoRA is attached. The base model behavior is unchanged. - ## Usage ```python diff --git a/models/pos_egnn/posegnn/adapter/inject.py b/models/pos_egnn/posegnn/adapter/inject.py index becbff6..94f0aa9 100644 --- a/models/pos_egnn/posegnn/adapter/inject.py +++ b/models/pos_egnn/posegnn/adapter/inject.py @@ -1,4 +1,3 @@ -# inject.py import re import torch import torch.nn as nn @@ -8,7 +7,7 @@ def apply_lora(model: nn.Module, cfg: LoRAConfig) -> tuple[int, int]: """ Replace leaf linear-like layers under include patterns with LoRA. - Skips any module that has a non-identity .activation to guarantee mergeability. + Safely wraps linears with internal norm/activation since LoRA is pre-activation. Returns (num_scalar_wrapped, 0). """ include_patterns = list(cfg.include_names or []) @@ -32,25 +31,18 @@ def is_linear_like(m: nn.Module) -> bool: return False return isinstance(w, torch.Tensor) and w.ndim == 2 - def has_post_act(m: nn.Module) -> bool: - act = getattr(m, "activation", None) - return (act is not None) and (not isinstance(act, nn.Identity)) - n_scalar = 0 - skipped = [] # <— track skipped post-activation linears for full_name, module in list(model.named_modules()): if not is_linear_like(module): continue if not wants(full_name): continue - if has_post_act(module): - skipped.append(full_name) # <— record and skip - continue parent_name, _, child = full_name.rpartition(".") parent = model.get_submodule(parent_name) if parent_name else model + # already wrapped guard if hasattr(module, "base") and hasattr(module, "lora_A") and hasattr(module, "lora_B"): continue @@ -60,9 +52,4 @@ def has_post_act(m: nn.Module) -> bool: setattr(parent, child, wrapped) n_scalar += 1 - if getattr(cfg, "log_skipped", False) and skipped: - print("[lora] skipped post-activation linears:") - for n in skipped: - print(" -", n) - return n_scalar, 0 \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/layers.py b/models/pos_egnn/posegnn/adapter/layers.py index 974002c..cb32fc5 100644 --- a/models/pos_egnn/posegnn/adapter/layers.py +++ b/models/pos_egnn/posegnn/adapter/layers.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F from typing import Optional def _init_lora(linear: nn.Linear, freeze_base: bool): @@ -10,8 +11,8 @@ def _init_lora(linear: nn.Linear, freeze_base: bool): class LoRALinear(nn.Module): """ - LoRA for linear layers: - y = base(x) + scaling * B(A(dropout(x))) + LoRA for linear layers applied pre-activation: + y = act( norm( (W x + b) + scaling * B(A(dropout(x))) ) ) """ def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float], dropout: float, merge_on_save: bool, freeze_base: bool): @@ -28,13 +29,21 @@ def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float], 
self.enable_lora = True self.merged = False + # Optional submodules carried by custom Dense + self._norm = getattr(base_linear, "norm", None) + if not isinstance(self._norm, nn.Module): + self._norm = None + self._post_act = getattr(base_linear, "activation", None) self._has_post_act = self._post_act is not None and not isinstance(self._post_act, nn.Identity) - self.merge_on_save = bool(merge_on_save and not self._has_post_act) + # Always allow merge on save now that we inject pre-activation + self.merge_on_save = bool(merge_on_save) + + # LoRA adapters self.lora_dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity() - self.lora_A = nn.Linear(self.in_features, self.r, bias=False) - self.lora_B = nn.Linear(self.r, self.out_features, bias=False) + self.lora_A = nn.Linear(self.in_features, self.r, bias=False) # down + self.lora_B = nn.Linear(self.r, self.out_features, bias=False) # up nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5) nn.init.zeros_(self.lora_B.weight) @@ -43,21 +52,37 @@ def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float], self._register_state_dict_hook(self._merge_on_state_dict) self._register_load_state_dict_pre_hook(self._strict_fill_on_load, with_module=True) + def _apply_activation(self, y): + if not self._has_post_act: + return y + act = self._post_act + # support nn.Module or callable (e.g. torch.nn.functional.silu) + if isinstance(act, nn.Module): + return act(y) + if callable(act): + return act(y) + return y + def forward(self, x): - y = self.base(x) - if self._has_post_act: - y = self._post_act(y) + # linear pre-activation + y = F.linear(x, self.base.weight, self.base.bias) + + # add LoRA delta pre-activation if self.enable_lora and self.r > 0: z = self.lora_dropout(x) z = self.lora_A(z) z = self.lora_B(z) y = y + self.scaling * z + + # optional norm then activation + if self._norm is not None: + y = self._norm(y) + y = self._apply_activation(y) return y @torch.no_grad() def merged_weight(self): - if self._has_post_act: - return self.base.weight + # Always valid since injected pre-activation return self.base.weight + self.scaling * (self.lora_B.weight @ self.lora_A.weight) def _merge_on_state_dict(self, module, state_dict, prefix, local_metadata): diff --git a/models/pos_egnn/posegnn/adapter/test.ipynb b/models/pos_egnn/posegnn/adapter/test.ipynb index 1feeb8d..ca18b41 100644 --- a/models/pos_egnn/posegnn/adapter/test.ipynb +++ b/models/pos_egnn/posegnn/adapter/test.ipynb @@ -14,6 +14,7 @@ "outputs": [], "source": [ "import math\n", + "import copy\n", "import torch\n", "import torch.nn as nn\n", "\n", @@ -69,35 +70,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "[lora] skipped post-activation linears:\n", - " - encoder.neighbor_embedding.combine.dense_layers.0\n", - " - encoder.edge_embedding.edge_up.dense_layers.0\n", - " - encoder.gata.0.gamma_s.0\n", - " - encoder.gata.0.gamma_v.0\n", - " - encoder.gata.0.phik_w_ra\n", - " - encoder.gata.0.edge_attr_up.dense_layers.0\n", - " - encoder.gata.1.gamma_s.0\n", - " - encoder.gata.1.gamma_v.0\n", - " - encoder.gata.1.phik_w_ra\n", - " - encoder.gata.1.edge_attr_up.dense_layers.0\n", - " - encoder.gata.2.gamma_s.0\n", - " - encoder.gata.2.gamma_v.0\n", - " - encoder.gata.2.phik_w_ra\n", - " - encoder.gata.2.edge_attr_up.dense_layers.0\n", - " - encoder.gata.3.gamma_s.0\n", - " - encoder.gata.3.gamma_v.0\n", - " - encoder.gata.3.phik_w_ra\n", - " - encoder.eqff.0.gamma_m.0\n", - " - encoder.eqff.1.gamma_m.0\n", - " - encoder.eqff.2.gamma_m.0\n", - " - 
encoder.eqff.3.gamma_m.0\n", - "LoRA injected - scalar layers: 48\n" + "LoRA injected - scalar layers: 69\n" ] } ], "source": [ "# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn\n", - "checkpoint_dict = torch.load('../../pytorch_model.bin', weights_only=True, map_location='cpu')\n", + "checkpoint_dict = torch.load('../../to_delete/pytorch_model.bin', weights_only=True, map_location='cpu')\n", "backbone = PosEGNN(checkpoint_dict[\"config\"])\n", "backbone.load_state_dict(checkpoint_dict[\"state_dict\"], strict=True)\n", "\n", @@ -113,6 +92,66 @@ "print(model.lora_report())" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[merge] missing: 0 unexpected: 0\n", + "[merge] max abs diff = 0.000e+00\n", + " embedding_0: 0.000e+00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_38964/3267473328.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state = torch.load(buf, map_location=\"cpu\")\n" + ] + } + ], + "source": [ + "# adapter-enabled reference\n", + "model.eval()\n", + "model.enable_adapter()\n", + "with torch.no_grad():\n", + " o_ref = model(batch)\n", + "\n", + "# merged export\n", + "import io\n", + "buf = io.BytesIO()\n", + "torch.save(model.state_dict_backbone(merged=True), buf)\n", + "buf.seek(0)\n", + "\n", + "plain = PosEGNN(checkpoint_dict[\"config\"])\n", + "state = torch.load(buf, map_location=\"cpu\")\n", + "missing, unexpected = plain.load_state_dict(state, strict=True)\n", + "print(\"[merge] missing:\", len(missing), \"unexpected:\", len(unexpected)) # expect 0, 0\n", + "\n", + "plain.eval()\n", + "with torch.no_grad():\n", + " o_merge = plain(batch)\n", + "\n", + "def _as32(x): return x.detach().to(torch.float32)\n", + "def _cmp(a, b):\n", + " if isinstance(a, dict) and isinstance(b, dict):\n", + " ks = sorted(set(a) & set(b))\n", + " diffs = {k: (_as32(a[k]) - _as32(b[k])).abs().max().item() for k in ks}\n", + " return max(diffs.values()) if diffs else 0.0, diffs\n", + " d = (_as32(a) - _as32(b)).abs().max().item()\n", + " return d, {\"_\": d}\n", + "\n", + "mx, per = _cmp(o_ref, o_merge)\n", + "print(f\"[merge] max abs diff = {mx:.3e}\")\n", + "for k, v in per.items():\n", + " print(f\" {k}: {v:.3e}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -122,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -186,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -216,24 +255,24 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "wrapped layers: 48\n", + "wrapped layers: 69\n", "encoder.neighbor_embedding.distance_proj.dense_layers.0\n", + "encoder.neighbor_embedding.combine.dense_layers.0\n", "encoder.neighbor_embedding.combine.dense_layers.1\n", + "encoder.edge_embedding.edge_up.dense_layers.0\n", "encoder.edge_embedding.edge_up.dense_layers.1\n", + "encoder.gata.0.gamma_s.0\n", "encoder.gata.0.gamma_s.1\n", "encoder.gata.0.q_w\n", "encoder.gata.0.k_w\n", - "encoder.gata.0.gamma_v.1\n", - "encoder.gata.0.vecq_w\n", - "encoder.gata.0.veck_w.0\n", - "encoder.gata.0.veck_w.1\n" + "encoder.gata.0.gamma_v.0\n" ] } ], @@ -250,14 +289,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "total trainable params: 504896\n", + "total trainable params: 690240\n", "backbone.encoder.neighbor_embedding.distance_proj.dense_layers.0.lora_A.weight (16, 64)\n" ] } @@ -285,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -324,7 +363,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -367,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -383,7 +422,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_31481/430695950.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_38964/430695950.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", " state = torch.load(buf, map_location=\"cpu\")\n" ] } @@ -447,63 +486,100 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[wrapped + post-act] 0\n" + "[audit] previously non-mergeable wrapped layers (have activation):\n", + " encoder.neighbor_embedding.combine.dense_layers.0 (activation=SiLU)\n", + " encoder.edge_embedding.edge_up.dense_layers.0 (activation=function)\n", + " encoder.gata.0.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.0.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.0.phik_w_ra (activation=SiLU)\n", + " encoder.gata.0.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.1.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.1.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.1.phik_w_ra (activation=SiLU)\n", + " encoder.gata.1.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.2.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.2.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.2.phik_w_ra (activation=SiLU)\n", + " encoder.gata.2.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.3.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.3.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.3.phik_w_ra (activation=SiLU)\n", + " encoder.eqff.0.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.1.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.2.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.3.gamma_m.0 (activation=SiLU)\n", + "count: 21\n" ] } ], "source": [ - "def find_nonmergeable_wrapped(mod):\n", + "def audit_nonmergeable(model):\n", " bad = []\n", " for name, m in model.backbone.named_modules():\n", - " if hasattr(m, \"base\") and hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"):\n", + " if hasattr(m, \"base\") and hasattr(m.base, \"weight\"):\n", " act = 
getattr(m.base, \"activation\", None)\n", - " if act is not None and not isinstance(act, torch.nn.Identity):\n", - " bad.append(name)\n", + " has_post = act is not None and not isinstance(act, torch.nn.Identity)\n", + " if has_post:\n", + " bad.append((name, type(act).__name__))\n", + " print(\"[audit] previously non-mergeable wrapped layers (have activation):\")\n", + " for n, a in bad:\n", + " print(f\" {n} (activation={a})\")\n", + " print(f\"count: {len(bad)}\")\n", " return bad\n", "\n", - "bad = find_nonmergeable_wrapped(model)\n", - "print(\"[wrapped + post-act]\", len(bad))\n", - "for n in bad[:10]:\n", - " print(\" \", n)" + "_ = audit_nonmergeable(model)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[audit] non-mergeable wrapped layers (have activation):\n", - "count: 0\n" + "encoder.neighbor_embedding.combine.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.edge_embedding.edge_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.gamma_s.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.gamma_v.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.phik_w_ra -> LoRALinear has LoRA: True\n", + "encoder.gata.0.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.gamma_s.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.gamma_v.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.phik_w_ra -> LoRALinear has LoRA: True\n", + "encoder.gata.1.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.2.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.eqff.0.gamma_m.0 -> LoRALinear has LoRA: True\n", + "encoder.eqff.3.gamma_m.0 -> LoRALinear has LoRA: True\n" ] } ], "source": [ - "def audit_nonmergeable(model):\n", - " bad = []\n", - " for name, m in model.backbone.named_modules():\n", - " if hasattr(m, \"base\") and hasattr(m.base, \"weight\"):\n", - " act = getattr(m.base, \"activation\", None)\n", - " has_post = act is not None and not isinstance(act, torch.nn.Identity)\n", - " if has_post:\n", - " bad.append((name, type(act).__name__))\n", - " print(\"[audit] non-mergeable wrapped layers (have activation):\")\n", - " for n, a in bad:\n", - " print(f\" {n} (activation={a})\")\n", - " print(f\"count: {len(bad)}\")\n", - " return bad\n", - "\n", - "_ = audit_nonmergeable(model)" + "for name in [\n", + " \"encoder.neighbor_embedding.combine.dense_layers.0\",\n", + " \"encoder.edge_embedding.edge_up.dense_layers.0\",\n", + " \"encoder.gata.0.gamma_s.0\",\n", + " \"encoder.gata.0.gamma_v.0\",\n", + " \"encoder.gata.0.phik_w_ra\",\n", + " \"encoder.gata.0.edge_attr_up.dense_layers.0\",\n", + " \"encoder.gata.1.gamma_s.0\",\n", + " \"encoder.gata.1.gamma_v.0\",\n", + " \"encoder.gata.1.phik_w_ra\",\n", + " \"encoder.gata.1.edge_attr_up.dense_layers.0\",\n", + " \"encoder.gata.2.edge_attr_up.dense_layers.0\",\n", + " \"encoder.eqff.0.gamma_m.0\",\n", + " \"encoder.eqff.3.gamma_m.0\",\n", + "]:\n", + " m = model.backbone.get_submodule(name)\n", + " print(name, \"->\", type(m).__name__, \"has LoRA:\", hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"))" ] }, { @@ -515,14 +591,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[grads] LoRA params with grad: 74, base params with grad: 0\n" + "[grads] LoRA params with grad: 106, base params with grad: 0\n" ] } ],
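For reference, a minimal end-to-end sketch of how the adapter API introduced by this patch might be used. The checkpoint and structure paths, the single-structure "dataset", the AdamW settings, and the embedding_0-based objective are placeholders; only `PosEGNN`, `PosEGNNLoRAModel`, `LoRAConfig`, `lora_parameters()`, `lora_report()`, and `save_pretrained_merged()` come from the patch itself.

```python
# Minimal adapter fine-tuning sketch (placeholders: paths, optimizer, toy objective).
import torch
from ase.io import read
from torch_geometric.data import Batch, Data

from posegnn.adapter import LoRAConfig, PosEGNNLoRAModel
from posegnn.model import PosEGNN

# Checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn (path is a placeholder).
ckpt = torch.load("pytorch_model.bin", weights_only=True, map_location="cpu")
backbone = PosEGNN(ckpt["config"])
backbone.load_state_dict(ckpt["state_dict"], strict=True)

model = PosEGNNLoRAModel(
    backbone,
    LoRAConfig(rank=16, alpha=16, dropout=0.0, merge_on_save=True, freeze_base=True),
)
print(model.lora_report())

# Toy "dataset": one structure, converted the same way as in test.ipynb.
atoms = read("3BPA.xyz", index=0)  # placeholder structure file
data = Data(
    z=torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long),
    pos=torch.tensor(atoms.get_positions(), dtype=torch.float32),
    box=torch.tensor(atoms.get_cell().tolist(), dtype=torch.float32).unsqueeze(0),
    batch=torch.zeros(len(atoms), dtype=torch.long),
)
loader = [Batch.from_data_list([data])]

# Only the LoRA A/B matrices are optimized; the frozen base weights receive no gradients.
optimizer = torch.optim.AdamW(model.lora_parameters(), lr=1e-4)
model.train()
for batch in loader:
    optimizer.zero_grad()
    loss = model(batch)["embedding_0"].pow(2).mean()  # placeholder objective
    loss.backward()
    optimizer.step()

# Export with adapters fused into the base weights; loads into a plain PosEGNN with strict=True.
model.eval()
model.save_pretrained_merged("posegnn_lora_merged.pt")
```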