diff --git a/models/.DS_Store b/models/.DS_Store deleted file mode 100644 index 92789af..0000000 Binary files a/models/.DS_Store and /dev/null differ diff --git a/models/pos_egnn/posegnn/__init__.py b/models/pos_egnn/posegnn/__init__.py index 9fb1220..97b0cdc 100644 --- a/models/pos_egnn/posegnn/__init__.py +++ b/models/pos_egnn/posegnn/__init__.py @@ -1,3 +1,3 @@ -from . import calculator, encoder, model, ops, utils +from . import calculator, encoder, model, ops, utils, adapter -__all__ = ["calculator", "encoder", "model", "ops", "utils"] \ No newline at end of file +__all__ = ["calculator", "encoder", "model", "ops", "utils", "adapter"] \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/README.md b/models/pos_egnn/posegnn/adapter/README.md new file mode 100644 index 0000000..1d222f8 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/README.md @@ -0,0 +1,22 @@ +# PosEGNN + LoRA + +This adapter injects LoRA into mergeable linear layers of **PosEGNN** and exports merged weights that load into a plain `PosEGNN` with `strict=True`. + +## Usage + +```python +# 1) build and load the backbone +backbone = PosEGNN(checkpoint_dict["config"]) +backbone.load_state_dict(checkpoint_dict["state_dict"], strict=True) + +# 2) wrap with LoRA (post-activation linears are skipped automatically) +cfg = LoRAConfig(rank=8, alpha=8, dropout=0.0, freeze_base=True, merge_on_save=True) +model = PosEGNNLoRAModel(backbone, cfg) + +# 3) train or evaluate +out = model(batch) + +# 4) export merged weights that load into plain PosEGNN +merged = model.state_dict_backbone(merged=True) +plain = PosEGNN(checkpoint_dict["config"]) +plain.load_state_dict(merged, strict=True) \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/__init__.py b/models/pos_egnn/posegnn/adapter/__init__.py new file mode 100644 index 0000000..56528dd --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/__init__.py @@ -0,0 +1,11 @@ +from .config import LoRAConfig +from .layers import LoRALinear +from .model import PosEGNNLoRAModel +from .inject import apply_lora + +__all__ = [ + "LoRAConfig", + "LoRALinear", + "apply_lora", + "PosEGNNLoRAModel", +] diff --git a/models/pos_egnn/posegnn/adapter/config.py b/models/pos_egnn/posegnn/adapter/config.py new file mode 100644 index 0000000..2c68659 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/config.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Optional, Sequence + +@dataclass +class LoRAConfig: + rank: int = 8 + alpha: Optional[float] = None + dropout: float = 0.0 + merge_on_save: bool = True + freeze_base: bool = True + include_names: Optional[Sequence[str]] = None + exclude_names: Optional[Sequence[str]] = None + preset: Optional[str] = "posegnn" + log_skipped: bool = False \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/inject.py b/models/pos_egnn/posegnn/adapter/inject.py new file mode 100644 index 0000000..94f0aa9 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/inject.py @@ -0,0 +1,55 @@ +import re +import torch +import torch.nn as nn +from .layers import LoRALinear +from .config import LoRAConfig + +def apply_lora(model: nn.Module, cfg: LoRAConfig) -> tuple[int, int]: + """ + Replace leaf linear-like layers under include patterns with LoRA. + Safely wraps linears with internal norm/activation since LoRA is pre-activation. + Returns (num_scalar_wrapped, 0). 
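+
+    Illustrative sketch (assumes `model` is a loaded PosEGNN backbone; the default
+    preset "posegnn" targets the encoder.* and readout.* subtrees):
+
+        cfg = LoRAConfig(rank=8, alpha=8, freeze_base=True)
+        n_wrapped, _ = apply_lora(model, cfg)
+        print(f"wrapped {n_wrapped} linear layers")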
+ """ + include_patterns = list(cfg.include_names or []) + exclude_patterns = list(cfg.exclude_names or []) + if getattr(cfg, "preset", None) == "posegnn" and not include_patterns: + include_patterns = [r"^encoder\.", r"^readout\."] + + inc_re = [re.compile(p) for p in include_patterns] + exc_re = [re.compile(p) for p in exclude_patterns] + + def wants(name: str) -> bool: + if any(p.search(name) for p in exc_re): + return False + if inc_re and not any(p.search(name) for p in inc_re): + return False + return True + + def is_linear_like(m: nn.Module) -> bool: + w = getattr(m, "weight", None) + if isinstance(m, nn.Embedding): + return False + return isinstance(w, torch.Tensor) and w.ndim == 2 + + n_scalar = 0 + + for full_name, module in list(model.named_modules()): + if not is_linear_like(module): + continue + if not wants(full_name): + continue + + parent_name, _, child = full_name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + + # already wrapped guard + if hasattr(module, "base") and hasattr(module, "lora_A") and hasattr(module, "lora_B"): + continue + + wrapped = LoRALinear( + module, cfg.rank, cfg.alpha, cfg.dropout, cfg.merge_on_save, cfg.freeze_base + ) + setattr(parent, child, wrapped) + n_scalar += 1 + + return n_scalar, 0 \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/layers.py b/models/pos_egnn/posegnn/adapter/layers.py new file mode 100644 index 0000000..cb32fc5 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/layers.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + +def _init_lora(linear: nn.Linear, freeze_base: bool): + if freeze_base: + linear.weight.requires_grad = False + if linear.bias is not None: + linear.bias.requires_grad = False + +class LoRALinear(nn.Module): + """ + LoRA for linear layers applied pre-activation: + y = act( norm( (W x + b) + scaling * B(A(dropout(x))) ) ) + """ + def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float], + dropout: float, merge_on_save: bool, freeze_base: bool): + super().__init__() + assert isinstance(base_linear, nn.Linear) + self.base = base_linear + _init_lora(self.base, freeze_base) + + self.in_features = base_linear.in_features + self.out_features = base_linear.out_features + self.r = int(rank) + self.lora_alpha = float(alpha) if alpha is not None else float(self.r) + self.scaling = self.lora_alpha / max(self.r, 1) + self.enable_lora = True + self.merged = False + + # Optional submodules carried by custom Dense + self._norm = getattr(base_linear, "norm", None) + if not isinstance(self._norm, nn.Module): + self._norm = None + + self._post_act = getattr(base_linear, "activation", None) + self._has_post_act = self._post_act is not None and not isinstance(self._post_act, nn.Identity) + + # Always allow merge on save now that we inject pre-activation + self.merge_on_save = bool(merge_on_save) + + # LoRA adapters + self.lora_dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity() + self.lora_A = nn.Linear(self.in_features, self.r, bias=False) # down + self.lora_B = nn.Linear(self.r, self.out_features, bias=False) # up + + nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5) + nn.init.zeros_(self.lora_B.weight) + + if self.merge_on_save: + self._register_state_dict_hook(self._merge_on_state_dict) + self._register_load_state_dict_pre_hook(self._strict_fill_on_load, with_module=True) + + def _apply_activation(self, y): + if not self._has_post_act: + return y 
+ act = self._post_act + # support nn.Module or callable (e.g. torch.nn.functional.silu) + if isinstance(act, nn.Module): + return act(y) + if callable(act): + return act(y) + return y + + def forward(self, x): + # linear pre-activation + y = F.linear(x, self.base.weight, self.base.bias) + + # add LoRA delta pre-activation + if self.enable_lora and self.r > 0: + z = self.lora_dropout(x) + z = self.lora_A(z) + z = self.lora_B(z) + y = y + self.scaling * z + + # optional norm then activation + if self._norm is not None: + y = self._norm(y) + y = self._apply_activation(y) + return y + + @torch.no_grad() + def merged_weight(self): + # Always valid since injected pre-activation + return self.base.weight + self.scaling * (self.lora_B.weight @ self.lora_A.weight) + + def _merge_on_state_dict(self, module, state_dict, prefix, local_metadata): + # replace the tensor stored at base.weight with merged values + key_w = prefix + "base.weight" + if key_w in state_dict: + state_dict[key_w] = self.merged_weight() + # drop adapter tensors from the saved dict + state_dict.pop(prefix + "lora_A.weight", None) + state_dict.pop(prefix + "lora_B.weight", None) + return state_dict + + @torch.no_grad() + def _strict_fill_on_load(self, module, state_dict, prefix, local_metadata, strict, missing, unexpected, errors): + for k, ref in [(prefix + "lora_A.weight", self.lora_A.weight), + (prefix + "lora_B.weight", self.lora_B.weight)]: + if k not in state_dict: + state_dict[k] = torch.zeros_like(ref) \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/model.py b/models/pos_egnn/posegnn/adapter/model.py new file mode 100644 index 0000000..3f89464 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/model.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from collections import OrderedDict +from .config import LoRAConfig +from .inject import apply_lora + +def _is_lora_module(m: nn.Module) -> bool: + return hasattr(m, "base") and hasattr(m, "lora_A") and hasattr(m, "lora_B") + +class PosEGNNLoRAModel(nn.Module): + """ + Wrap a PosEGNN backbone, inject LoRA only into mergeable linear layers + (post-activation linears are skipped by the injector), and expose a merged export. 
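+
+    Typical usage (sketch; `backbone` is a PosEGNN already loaded from a checkpoint):
+
+        model = PosEGNNLoRAModel(backbone, LoRAConfig(rank=8, freeze_base=True))
+        optimizer = torch.optim.AdamW(model.lora_parameters(), lr=1e-4)  # adapters only
+        out = model(batch)               # forwards straight through the backbone
+        model.disable_adapter()          # temporarily fall back to the frozen base weights
+        model.enable_adapter()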
+ """ + def __init__(self, backbone: nn.Module, lora_config: LoRAConfig = LoRAConfig()): + super().__init__() + self.backbone = backbone + self.lora_config = lora_config + + # Injector must skip any module with a non-identity .activation + self.n_scalar, _ = apply_lora(self.backbone, self.lora_config) + + if getattr(lora_config, "freeze_base", False): + for p in self.backbone.parameters(): + p.requires_grad_(False) + for p in self.lora_parameters(): + p.requires_grad_(True) + + self._adapters_enabled = True + + def forward(self, *args, **kwargs): + return self.backbone(*args, **kwargs) + + def lora_parameters(self): + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + yield from m.lora_A.parameters() + yield from m.lora_B.parameters() + + @torch.no_grad() + def enable_adapter(self): + self._adapters_enabled = True + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + setattr(m, "enable_lora", True) + + @torch.no_grad() + def disable_adapter(self): + self._adapters_enabled = False + for _, m in self.backbone.named_modules(): + if _is_lora_module(m): + setattr(m, "enable_lora", False) + + @torch.no_grad() + def _export_merged_state_dict(self) -> OrderedDict: + """ + Build a plain PosEGNN state dict (original layout): + - merge adapter weights into '...weight' + - copy '...bias' + - copy other base.* params with 'base.' stripped + - drop all '...lora_*' + """ + sd_out = OrderedDict() + sd_in = self.backbone.state_dict() + wrapper_paths = {name for name, m in self.backbone.named_modules() if _is_lora_module(m)} + + def wrapper_path_for(key: str): + for p in wrapper_paths: + if key.startswith(p + "."): + return p + return None + + for key, val in sd_in.items(): + # drop adapter tensors + if ".lora_A." in key or ".lora_B." in key: + continue + + wp = wrapper_path_for(key) + if wp is not None: + if key.endswith(".base.weight"): + path = key[:-len(".base.weight")] + mod = self.backbone.get_submodule(path) # LoRA wrapper + sd_out[path + ".weight"] = mod.merged_weight() + continue + if key.endswith(".base.bias"): + sd_out[key.replace(".base.bias", ".bias")] = val + continue + prefix = f"{wp}.base." + if key.startswith(prefix): + sd_out[key.replace(prefix, f"{wp}.", 1)] = val + continue + continue # ignore other wrapper internals + + # passthrough for non-wrapper keys + sd_out[key] = val + + return sd_out + + @torch.no_grad() + def state_dict_backbone(self, merged: bool = False) -> OrderedDict: + """ + Return backbone-only weights. If merged=True, adapters are fused and + keys match the original PosEGNN layout. 
+ """ + if merged: + return self._export_merged_state_dict() + return self.backbone.state_dict() + + @torch.no_grad() + def save_pretrained_merged(self, path: str): + torch.save(self.state_dict_backbone(merged=True), path) + + @torch.no_grad() + def save_pretrained_adapted(self, path: str): + torch.save(self.state_dict(), path) + + def lora_report(self) -> str: + return f"LoRA injected - scalar layers: {self.n_scalar}" \ No newline at end of file diff --git a/models/pos_egnn/posegnn/adapter/test.ipynb b/models/pos_egnn/posegnn/adapter/test.ipynb new file mode 100644 index 0000000..ca18b41 --- /dev/null +++ b/models/pos_egnn/posegnn/adapter/test.ipynb @@ -0,0 +1,659 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# POS-EGNN ELoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import copy\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from ase import Atoms as ASEAtoms\n", + "from ase.io import read\n", + "from torch_geometric.data import Data, Batch\n", + "\n", + "import sys\n", + "sys.path.append('../../')\n", + "from posegnn.model import PosEGNN\n", + "from posegnn.adapter import PosEGNNLoRAModel, LoRAConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "atoms = read(\"../../inputs/3BPA.xyz\", index=0)\n", + "\n", + "def build_data_from_ase(atoms: ASEAtoms) -> Data:\n", + " z = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)\n", + " box = torch.tensor(atoms.get_cell().tolist()).unsqueeze(0).float()\n", + " pos = torch.tensor(atoms.get_positions().tolist()).float()\n", + " batch = torch.zeros(len(z), dtype=torch.long)\n", + " return Data(z=z, pos=pos, box=box, batch=batch, num_graphs=1)\n", + "\n", + "data = build_data_from_ase(atoms)\n", + "batch = Batch.from_data_list([data])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get LoRA model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LoRA injected - scalar layers: 69\n" + ] + } + ], + "source": [ + "# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn\n", + "checkpoint_dict = torch.load('../../to_delete/pytorch_model.bin', weights_only=True, map_location='cpu')\n", + "backbone = PosEGNN(checkpoint_dict[\"config\"])\n", + "backbone.load_state_dict(checkpoint_dict[\"state_dict\"], strict=True)\n", + "\n", + "cfg = LoRAConfig(\n", + " rank=16,\n", + " alpha=16, # uses alpha = rank by default\n", + " dropout=0.0,\n", + " merge_on_save=True, # saves merged weights for compatibility\n", + " freeze_base=True, # train only adapters\n", + " log_skipped=True\n", + ")\n", + "model = PosEGNNLoRAModel(backbone, cfg)\n", + "print(model.lora_report())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[merge] missing: 0 unexpected: 0\n", + "[merge] max abs diff = 0.000e+00\n", + " embedding_0: 0.000e+00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_38964/3267473328.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which 
uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state = torch.load(buf, map_location=\"cpu\")\n" + ] + } + ], + "source": [ + "# adapter-enabled reference\n", + "model.eval()\n", + "model.enable_adapter()\n", + "with torch.no_grad():\n", + " o_ref = model(batch)\n", + "\n", + "# merged export\n", + "import io\n", + "buf = io.BytesIO()\n", + "torch.save(model.state_dict_backbone(merged=True), buf)\n", + "buf.seek(0)\n", + "\n", + "plain = PosEGNN(checkpoint_dict[\"config\"])\n", + "state = torch.load(buf, map_location=\"cpu\")\n", + "missing, unexpected = plain.load_state_dict(state, strict=True)\n", + "print(\"[merge] missing:\", len(missing), \"unexpected:\", len(unexpected)) # expect 0, 0\n", + "\n", + "plain.eval()\n", + "with torch.no_grad():\n", + " o_merge = plain(batch)\n", + "\n", + "def _as32(x): return x.detach().to(torch.float32)\n", + "def _cmp(a, b):\n", + " if isinstance(a, dict) and isinstance(b, dict):\n", + " ks = sorted(set(a) & set(b))\n", + " diffs = {k: (_as32(a[k]) - _as32(b[k])).abs().max().item() for k in ks}\n", + " return max(diffs.values()) if diffs else 0.0, diffs\n", + " d = (_as32(a) - _as32(b)).abs().max().item()\n", + " return d, {\"_\": d}\n", + "\n", + "mx, per = _cmp(o_ref, o_merge)\n", + "print(f\"[merge] max abs diff = {mx:.3e}\")\n", + "for k, v in per.items():\n", + " print(f\" {k}: {v:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helpers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def random_SO3(dtype, device):\n", + " A = torch.randn(3, 3, dtype=dtype, device=device)\n", + " Q, _ = torch.linalg.qr(A)\n", + " if torch.det(Q) < 0:\n", + " Q[:, 0] = -Q[:, 0]\n", + " return Q\n", + "\n", + "def rotate_atoms_data(data: Data, R: torch.Tensor) -> Data:\n", + " # Rotate positions and box. 
Keep everything else.\n", + " pos = data.pos @ R.T\n", + " if hasattr(data, \"box\") and data.box is not None:\n", + " # box is shape [1, 3, 3] or [3, 3]\n", + " box = data.box\n", + " if box.dim() == 2:\n", + " box_rot = box @ R.T\n", + " box_rot = box_rot.unsqueeze(0)\n", + " else:\n", + " box_rot = box @ R.T\n", + " else:\n", + " box_rot = None\n", + " new = Data(\n", + " z=data.z.clone() if hasattr(data, \"z\") else None,\n", + " pos=pos,\n", + " box=box_rot,\n", + " batch=data.batch.clone() if hasattr(data, \"batch\") else None,\n", + " num_graphs=getattr(data, \"num_graphs\", None),\n", + " )\n", + " return new\n", + "\n", + "def act_block_lastdim(vec, R, block=3):\n", + " # Apply R to each 3-vector block along the last dim\n", + " # vec [..., 3k], R [3, 3]\n", + " c = vec.shape[-1]\n", + " assert c % block == 0, \"embedding last dim is not a multiple of 3\"\n", + " k = c // block\n", + " v = vec.view(*vec.shape[:-1], k, block)\n", + " vR = torch.einsum(\"...bi,ij->...bj\", v, R)\n", + " return vR.reshape(*vec.shape)\n", + "\n", + "@torch.no_grad()\n", + "def cosine(a, b, eps=1e-12):\n", + " num = (a * b).sum()\n", + " den = a.norm() * b.norm() + eps\n", + " return float((num / den).clamp(-1, 1))\n", + "\n", + "# ---------- 0) prep ----------\n", + "device = torch.device(\"cpu\")\n", + "model.eval()\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Determinism" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[determinism] embedding_0 max|Δ| = 0.000e+00\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " o1 = model(batch)\n", + " o2 = model(batch)\n", + "for k in o1.keys():\n", + " if torch.is_tensor(o1[k]):\n", + " diff = (o1[k] - o2[k]).abs().max().item()\n", + " print(f\"[determinism] {k:<20} max|Δ| = {diff:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check LoRA parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wrapped layers: 69\n", + "encoder.neighbor_embedding.distance_proj.dense_layers.0\n", + "encoder.neighbor_embedding.combine.dense_layers.0\n", + "encoder.neighbor_embedding.combine.dense_layers.1\n", + "encoder.edge_embedding.edge_up.dense_layers.0\n", + "encoder.edge_embedding.edge_up.dense_layers.1\n", + "encoder.gata.0.gamma_s.0\n", + "encoder.gata.0.gamma_s.1\n", + "encoder.gata.0.q_w\n", + "encoder.gata.0.k_w\n", + "encoder.gata.0.gamma_v.0\n" + ] + } + ], + "source": [ + "# how many layers got LoRA\n", + "wrapped = []\n", + "for name, m in model.backbone.named_modules():\n", + " if hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"):\n", + " wrapped.append(name)\n", + "len_wrapped = len(wrapped)\n", + "print(\"wrapped layers:\", len_wrapped)\n", + "print(\"\\n\".join(wrapped[:10]))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total trainable params: 690240\n", + "backbone.encoder.neighbor_embedding.distance_proj.dense_layers.0.lora_A.weight (16, 64)\n" + ] + } + ], + "source": [ + "total_trainable = 0\n", + "for n, p in model.named_parameters():\n", + " if p.requires_grad:\n", + " total_trainable += p.numel()\n", + "print(\"total trainable params:\", total_trainable)\n", + "\n", + "# show a few LoRA shapes\n", + "for n, p in 
model.named_parameters():\n", + " if \"lora_A\" in n or \"lora_B\" in n:\n", + " print(n, tuple(p.shape))\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rotation sanity on embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[embed] shape: (27, 256, 1, 4)\n", + "[embed] invariance max|Δ| = 3.338e-06\n" + ] + } + ], + "source": [ + "R = random_SO3(dtype=batch.pos.dtype, device=batch.pos.device)\n", + "batch_R = Batch.from_data_list([rotate_atoms_data(data, R)])\n", + "\n", + "with torch.no_grad():\n", + " out = model(batch)\n", + " out_R = model(batch_R)\n", + "\n", + "if \"embedding_0\" in out and torch.is_tensor(out[\"embedding_0\"]):\n", + " e = out[\"embedding_0\"]\n", + " eR = out_R[\"embedding_0\"]\n", + " print(f\"[embed] shape: {tuple(e.shape)}\")\n", + "\n", + " # Invariance check\n", + " inv_err = (e - eR).abs().max().item()\n", + " print(f\"[embed] invariance max|Δ| = {inv_err:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Energy invariance and force covariance (if energy available)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# energy_key = None\n", + "# for k in [\"energy\", \"y_energy\", \"E\", \"total_energy\"]:\n", + "# if k in out:\n", + "# energy_key = k\n", + "# break\n", + "\n", + "# if energy_key is not None:\n", + "# # Build fresh tensors with grad for force test\n", + "# d1 = Batch.from_data_list([build_data_from_ase(atoms)])\n", + "# d1.pos.requires_grad_(True)\n", + "# E1 = model(d1)[energy_key] # scalar\n", + "\n", + "# d2 = Batch.from_data_list([rotate_atoms_data(build_data_from_ase(atoms), R)])\n", + "# d2.pos.requires_grad_(True)\n", + "# E2 = model(d2)[energy_key] # scalar\n", + "\n", + "# # energy invariance\n", + "# e_err = (E2.detach() - E1.detach()).abs().item()\n", + "# print(f\"[energy] |E(Rx) - E(x)| = {e_err:.3e}\")\n", + "\n", + "# # forces = -dE/dx, covariance: F(Rx) = R F(x)\n", + "# (F1,) = torch.autograd.grad(E1, d1.pos, retain_graph=False)\n", + "# (F2,) = torch.autograd.grad(E2, d2.pos, retain_graph=False)\n", + "# F1_equiv = F1 @ R.T\n", + "# f_err = (F2 - F1_equiv).abs().max().item()\n", + "# print(f\"[forces] covariance max|Δ| = {f_err:.3e}\")\n", + "# else:\n", + "# print(\"[energy] no energy key found in outputs. Skipping energy/force checks.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LoRA merge correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[merge] missing: 0, unexpected: 0\n", + "[merge] max abs diff = 0.000e+00\n", + " embedding_0: 0.000e+00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/fg/r5cyn4ss41s84ytnzqjn68bw0000gn/T/ipykernel_38964/430695950.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state = torch.load(buf, map_location=\"cpu\")\n" + ] + } + ], + "source": [ + "import io\n", + "import torch\n", + "\n", + "def _as_float_tensor(x):\n", + " return x.detach().to(dtype=torch.float32)\n", + "\n", + "def _compare_outputs(a, b):\n", + " if isinstance(a, dict) and isinstance(b, dict):\n", + " keys = sorted(set(a.keys()) & set(b.keys()))\n", + " diffs = {}\n", + " for k in keys:\n", + " ta = _as_float_tensor(a[k])\n", + " tb = _as_float_tensor(b[k])\n", + " diffs[k] = (ta - tb).abs().max().item()\n", + " return max(diffs.values()) if diffs else 0.0, diffs\n", + " ta = _as_float_tensor(a)\n", + " tb = _as_float_tensor(b)\n", + " d = (ta - tb).abs().max().item()\n", + " return d, {\"_\": d}\n", + "\n", + "try:\n", + " # 0) make sure dropout is off for determinism\n", + " model.eval()\n", + "\n", + " # 1) reference with adapters ENABLED (this is what the merged backbone should match)\n", + " model.enable_adapter()\n", + " with torch.no_grad():\n", + " o_ref = model(batch)\n", + "\n", + " # 2) export merged backbone (plain PosEGNN keys) into an in-memory buffer\n", + " buf = io.BytesIO()\n", + " merged_sd = model.state_dict_backbone(merged=True)\n", + " torch.save(merged_sd, buf)\n", + " buf.seek(0)\n", + "\n", + " # 3) load merged weights into a fresh plain backbone\n", + " backbone2 = PosEGNN(checkpoint_dict[\"config\"])\n", + " state = torch.load(buf, map_location=\"cpu\")\n", + " missing, unexpected = backbone2.load_state_dict(state, strict=True)\n", + " print(f\"[merge] missing: {len(missing)}, unexpected: {len(unexpected)}\") # expect 0, 0\n", + "\n", + " # 4) numerics: merged backbone vs adapter-enabled model\n", + " backbone2.eval()\n", + " with torch.no_grad():\n", + " o_merge = backbone2(batch)\n", + "\n", + " max_diff, per_key = _compare_outputs(o_ref, o_merge)\n", + " print(f\"[merge] max abs diff = {max_diff:.3e}\")\n", + " if isinstance(per_key, dict):\n", + " for k, v in per_key.items():\n", + " print(f\" {k}: {v:.3e}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"[merge] test failed: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[audit] previously non-mergeable wrapped layers (have activation):\n", + " encoder.neighbor_embedding.combine.dense_layers.0 (activation=SiLU)\n", + " encoder.edge_embedding.edge_up.dense_layers.0 (activation=function)\n", + " encoder.gata.0.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.0.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.0.phik_w_ra (activation=SiLU)\n", + " encoder.gata.0.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.1.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.1.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.1.phik_w_ra (activation=SiLU)\n", + " encoder.gata.1.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.2.gamma_s.0 (activation=SiLU)\n", + " encoder.gata.2.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.2.phik_w_ra (activation=SiLU)\n", + " encoder.gata.2.edge_attr_up.dense_layers.0 (activation=SiLU)\n", + " encoder.gata.3.gamma_s.0 (activation=SiLU)\n", + " 
encoder.gata.3.gamma_v.0 (activation=SiLU)\n", + " encoder.gata.3.phik_w_ra (activation=SiLU)\n", + " encoder.eqff.0.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.1.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.2.gamma_m.0 (activation=SiLU)\n", + " encoder.eqff.3.gamma_m.0 (activation=SiLU)\n", + "count: 21\n" + ] + } + ], + "source": [ + "def audit_nonmergeable(model):\n", + " bad = []\n", + " for name, m in model.backbone.named_modules():\n", + " if hasattr(m, \"base\") and hasattr(m.base, \"weight\"):\n", + " act = getattr(m.base, \"activation\", None)\n", + " has_post = act is not None and not isinstance(act, torch.nn.Identity)\n", + " if has_post:\n", + " bad.append((name, type(act).__name__))\n", + " print(\"[audit] previously non-mergeable wrapped layers (have activation):\")\n", + " for n, a in bad:\n", + " print(f\" {n} (activation={a})\")\n", + " print(f\"count: {len(bad)}\")\n", + " return bad\n", + "\n", + "_ = audit_nonmergeable(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "encoder.neighbor_embedding.combine.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.edge_embedding.edge_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.gamma_s.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.gamma_v.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.0.phik_w_ra -> LoRALinear has LoRA: True\n", + "encoder.gata.0.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.gamma_s.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.gamma_v.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.1.phik_w_ra -> LoRALinear has LoRA: True\n", + "encoder.gata.1.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.gata.2.edge_attr_up.dense_layers.0 -> LoRALinear has LoRA: True\n", + "encoder.eqff.0.gamma_m.0 -> LoRALinear has LoRA: True\n", + "encoder.eqff.3.gamma_m.0 -> LoRALinear has LoRA: True\n" + ] + } + ], + "source": [ + "for name in [\n", + " \"encoder.neighbor_embedding.combine.dense_layers.0\",\n", + " \"encoder.edge_embedding.edge_up.dense_layers.0\",\n", + " \"encoder.gata.0.gamma_s.0\",\n", + " \"encoder.gata.0.gamma_v.0\",\n", + " \"encoder.gata.0.phik_w_ra\",\n", + " \"encoder.gata.0.edge_attr_up.dense_layers.0\",\n", + " \"encoder.gata.1.gamma_s.0\",\n", + " \"encoder.gata.1.gamma_v.0\",\n", + " \"encoder.gata.1.phik_w_ra\",\n", + " \"encoder.gata.1.edge_attr_up.dense_layers.0\",\n", + " \"encoder.gata.2.edge_attr_up.dense_layers.0\",\n", + " \"encoder.eqff.0.gamma_m.0\",\n", + " \"encoder.eqff.3.gamma_m.0\",\n", + "]:\n", + " m = model.backbone.get_submodule(name)\n", + " print(name, \"->\", type(m).__name__, \"has LoRA:\", hasattr(m, \"lora_A\") and hasattr(m, \"lora_B\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check requires grad" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[grads] LoRA params with grad: 106, base params with grad: 0\n" + ] + } + ], + "source": [ + "model.train()\n", + "for n, p in model.named_parameters():\n", + " if not p.requires_grad:\n", + " continue\n", + " p.grad = None\n", + "loss = 0.0\n", + "out_train = model(batch)\n", + "\n", + "if \"embedding_0\" in out_train and torch.is_tensor(out_train[\"embedding_0\"]):\n", + " loss = out_train[\"embedding_0\"].pow(2).mean()\n", + "else:\n", + " # fallback: sum of 
any float tensor in outputs\n", + " for v in out_train.values():\n", + " if torch.is_tensor(v) and v.dtype.is_floating_point:\n", + " loss = v.sum()\n", + " break\n", + "loss.backward()\n", + "\n", + "num_lora_grads = 0\n", + "num_base_grads = 0\n", + "for n, p in model.named_parameters():\n", + " if p.grad is None:\n", + " continue\n", + " if \"lora\" in n.lower():\n", + " num_lora_grads += 1\n", + " else:\n", + " num_base_grads += 1\n", + "print(f\"[grads] LoRA params with grad: {num_lora_grads}, base params with grad: {num_base_grads}\")\n", + "model.eval();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}