Binary file removed models/.DS_Store
Binary file not shown.
4 changes: 2 additions & 2 deletions models/pos_egnn/posegnn/__init__.py
@@ -1,3 +1,3 @@
-from . import calculator, encoder, model, ops, utils
+from . import calculator, encoder, model, ops, utils, adapter

-__all__ = ["calculator", "encoder", "model", "ops", "utils"]
+__all__ = ["calculator", "encoder", "model", "ops", "utils", "adapter"]
22 changes: 22 additions & 0 deletions models/pos_egnn/posegnn/adapter/README.md
@@ -0,0 +1,22 @@
# PosEGNN + LoRA

This adapter injects LoRA into mergeable linear layers of **PosEGNN** and exports merged weights that load into a plain `PosEGNN` with `strict=True`.

## Usage

```python
# imports below assume this package layout; adjust the paths to your setup
import torch
from posegnn.model import PosEGNN
from posegnn.adapter import LoRAConfig, PosEGNNLoRAModel

# checkpoint file name is illustrative; the dict must contain "config" and "state_dict"
checkpoint_dict = torch.load("posegnn_checkpoint.pt", map_location="cpu")

# 1) build and load the backbone
backbone = PosEGNN(checkpoint_dict["config"])
backbone.load_state_dict(checkpoint_dict["state_dict"], strict=True)

# 2) wrap with LoRA (linears with an internal norm/activation are handled safely,
#    since the LoRA delta is added pre-activation)
cfg = LoRAConfig(rank=8, alpha=8, dropout=0.0, freeze_base=True, merge_on_save=True)
model = PosEGNNLoRAModel(backbone, cfg)

# 3) train or evaluate (batch is a prepared PosEGNN input)
out = model(batch)

# 4) export merged weights that load into plain PosEGNN
merged = model.state_dict_backbone(merged=True)
plain = PosEGNN(checkpoint_dict["config"])
plain.load_state_dict(merged, strict=True)
```
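The wrapper also exposes save helpers (see `model.py` below); a minimal sketch, with file names chosen only for illustration:

```python
# fuse the adapters and save weights in the original PosEGNN key layout
model.save_pretrained_merged("posegnn_merged.pt")

# save the wrapper's own state_dict; with merge_on_save=True the state_dict
# hooks fold the LoRA tensors into base.weight at save time
model.save_pretrained_adapted("posegnn_lora_wrapper.pt")
```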
11 changes: 11 additions & 0 deletions models/pos_egnn/posegnn/adapter/__init__.py
@@ -0,0 +1,11 @@
from .config import LoRAConfig
from .layers import LoRALinear
from .model import PosEGNNLoRAModel
from .inject import apply_lora

__all__ = [
"LoRAConfig",
"LoRALinear",
"apply_lora",
"PosEGNNLoRAModel",
]
14 changes: 14 additions & 0 deletions models/pos_egnn/posegnn/adapter/config.py
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class LoRAConfig:
rank: int = 8
alpha: Optional[float] = None
dropout: float = 0.0
merge_on_save: bool = True
freeze_base: bool = True
include_names: Optional[Sequence[str]] = None
exclude_names: Optional[Sequence[str]] = None
preset: Optional[str] = "posegnn"
log_skipped: bool = False
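As a quick illustration of how these fields drive injection (the import path assumes this package layout, and the module-name regexes are hypothetical):

```python
from posegnn.adapter import LoRAConfig  # import path assumed from this package layout

# With the default preset="posegnn" and no include_names, the injector falls
# back to the patterns r"^encoder\." and r"^readout\." (see inject.py).
cfg_default = LoRAConfig(rank=8, alpha=16, dropout=0.05)

# Explicit include/exclude patterns are matched with re.search against full
# module names, and excludes take precedence over includes. The module names
# implied by these regexes are hypothetical examples.
cfg_custom = LoRAConfig(
    rank=4,
    include_names=[r"^encoder\.layers\."],
    exclude_names=[r"\.gate$"],
)
```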
55 changes: 55 additions & 0 deletions models/pos_egnn/posegnn/adapter/inject.py
@@ -0,0 +1,55 @@
import re
import torch
import torch.nn as nn
from .layers import LoRALinear
from .config import LoRAConfig

def apply_lora(model: nn.Module, cfg: LoRAConfig) -> tuple[int, int]:
"""
Replace leaf linear-like layers under include patterns with LoRA.
Safely wraps linears with internal norm/activation since LoRA is pre-activation.
Returns (num_scalar_wrapped, 0).
"""
include_patterns = list(cfg.include_names or [])
exclude_patterns = list(cfg.exclude_names or [])
if getattr(cfg, "preset", None) == "posegnn" and not include_patterns:
include_patterns = [r"^encoder\.", r"^readout\."]

inc_re = [re.compile(p) for p in include_patterns]
exc_re = [re.compile(p) for p in exclude_patterns]

def wants(name: str) -> bool:
if any(p.search(name) for p in exc_re):
return False
if inc_re and not any(p.search(name) for p in inc_re):
return False
return True

def is_linear_like(m: nn.Module) -> bool:
w = getattr(m, "weight", None)
if isinstance(m, nn.Embedding):
return False
return isinstance(w, torch.Tensor) and w.ndim == 2

n_scalar = 0

for full_name, module in list(model.named_modules()):
if not is_linear_like(module):
continue
if not wants(full_name):
continue

parent_name, _, child = full_name.rpartition(".")
parent = model.get_submodule(parent_name) if parent_name else model

# already wrapped guard
if hasattr(module, "base") and hasattr(module, "lora_A") and hasattr(module, "lora_B"):
continue

wrapped = LoRALinear(
module, cfg.rank, cfg.alpha, cfg.dropout, cfg.merge_on_save, cfg.freeze_base
)
setattr(parent, child, wrapped)
n_scalar += 1

return n_scalar, 0
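A small sketch of `apply_lora` on a toy module, assuming the import path of this package; the toy attribute names are chosen only so they match the include pattern:

```python
import torch.nn as nn
from posegnn.adapter import LoRAConfig, apply_lora  # import path assumed

# Toy stand-in for a backbone; "encoder" and "head" are hypothetical names.
toy = nn.Module()
toy.encoder = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))
toy.head = nn.Linear(16, 1)

# Passing include_names skips the "posegnn" preset defaults.
cfg = LoRAConfig(rank=4, include_names=[r"^encoder\."])
n_wrapped, _ = apply_lora(toy, cfg)  # second element of the return is always 0

print(n_wrapped)                      # 2: both encoder linears were wrapped
print(type(toy.encoder[0]).__name__)  # LoRALinear
print(type(toy.head).__name__)        # Linear -- filtered out by the include pattern
```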
103 changes: 103 additions & 0 deletions models/pos_egnn/posegnn/adapter/layers.py
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

def _init_lora(linear: nn.Linear, freeze_base: bool):
if freeze_base:
linear.weight.requires_grad = False
if linear.bias is not None:
linear.bias.requires_grad = False

class LoRALinear(nn.Module):
"""
LoRA for linear layers applied pre-activation:
y = act( norm( (W x + b) + scaling * B(A(dropout(x))) ) )
"""
def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float],
dropout: float, merge_on_save: bool, freeze_base: bool):
super().__init__()
assert isinstance(base_linear, nn.Linear)
self.base = base_linear
_init_lora(self.base, freeze_base)

self.in_features = base_linear.in_features
self.out_features = base_linear.out_features
self.r = int(rank)
self.lora_alpha = float(alpha) if alpha is not None else float(self.r)
self.scaling = self.lora_alpha / max(self.r, 1)
self.enable_lora = True
self.merged = False

# Optional submodules carried by custom Dense
self._norm = getattr(base_linear, "norm", None)
if not isinstance(self._norm, nn.Module):
self._norm = None

self._post_act = getattr(base_linear, "activation", None)
self._has_post_act = self._post_act is not None and not isinstance(self._post_act, nn.Identity)

        # Merging on save is always valid because the LoRA delta is added pre-activation
self.merge_on_save = bool(merge_on_save)

# LoRA adapters
self.lora_dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()
self.lora_A = nn.Linear(self.in_features, self.r, bias=False) # down
self.lora_B = nn.Linear(self.r, self.out_features, bias=False) # up

nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5)
nn.init.zeros_(self.lora_B.weight)

if self.merge_on_save:
self._register_state_dict_hook(self._merge_on_state_dict)
self._register_load_state_dict_pre_hook(self._strict_fill_on_load, with_module=True)

def _apply_activation(self, y):
if not self._has_post_act:
return y
act = self._post_act
# support nn.Module or callable (e.g. torch.nn.functional.silu)
if isinstance(act, nn.Module):
return act(y)
if callable(act):
return act(y)
return y

def forward(self, x):
# linear pre-activation
y = F.linear(x, self.base.weight, self.base.bias)

# add LoRA delta pre-activation
if self.enable_lora and self.r > 0:
z = self.lora_dropout(x)
z = self.lora_A(z)
z = self.lora_B(z)
y = y + self.scaling * z

# optional norm then activation
if self._norm is not None:
y = self._norm(y)
y = self._apply_activation(y)
return y

@torch.no_grad()
def merged_weight(self):
# Always valid since injected pre-activation
return self.base.weight + self.scaling * (self.lora_B.weight @ self.lora_A.weight)

def _merge_on_state_dict(self, module, state_dict, prefix, local_metadata):
# replace the tensor stored at base.weight with merged values
key_w = prefix + "base.weight"
if key_w in state_dict:
state_dict[key_w] = self.merged_weight()
# drop adapter tensors from the saved dict
state_dict.pop(prefix + "lora_A.weight", None)
state_dict.pop(prefix + "lora_B.weight", None)
return state_dict

@torch.no_grad()
def _strict_fill_on_load(self, module, state_dict, prefix, local_metadata, strict, missing, unexpected, errors):
for k, ref in [(prefix + "lora_A.weight", self.lora_A.weight),
(prefix + "lora_B.weight", self.lora_B.weight)]:
if k not in state_dict:
state_dict[k] = torch.zeros_like(ref)
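A minimal sketch of `LoRALinear` in isolation (import path assumed), checking the zero-initialized adapter and the merge-on-save behaviour implemented by the hooks above:

```python
import torch
import torch.nn as nn
from posegnn.adapter import LoRALinear  # import path assumed

base = nn.Linear(8, 8)
wrapped = LoRALinear(base, rank=2, alpha=2, dropout=0.0,
                     merge_on_save=True, freeze_base=True)

x = torch.randn(4, 8)
# lora_B is zero-initialized, so before training the wrapper reproduces the base layer
assert torch.allclose(wrapped(x), base(x))
assert torch.allclose(wrapped.merged_weight(), base.weight)

# with merge_on_save=True the state_dict hook folds the adapters into base.weight
sd = wrapped.state_dict()
assert "lora_A.weight" not in sd
assert torch.allclose(sd["base.weight"], wrapped.merged_weight())
```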
118 changes: 118 additions & 0 deletions models/pos_egnn/posegnn/adapter/model.py
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
from collections import OrderedDict
from .config import LoRAConfig
from .inject import apply_lora

def _is_lora_module(m: nn.Module) -> bool:
return hasattr(m, "base") and hasattr(m, "lora_A") and hasattr(m, "lora_B")

class PosEGNNLoRAModel(nn.Module):
"""
    Wrap a PosEGNN backbone, inject LoRA into its linear layers, and expose a merged
    export. Linears carrying an internal norm/activation are wrapped safely because
    the LoRA delta is added pre-activation, so every wrapped layer stays mergeable.
"""
def __init__(self, backbone: nn.Module, lora_config: LoRAConfig = LoRAConfig()):
super().__init__()
self.backbone = backbone
self.lora_config = lora_config

        # The injector adds the LoRA delta pre-activation, so modules with an
        # internal norm/activation are wrapped safely rather than skipped
self.n_scalar, _ = apply_lora(self.backbone, self.lora_config)

if getattr(lora_config, "freeze_base", False):
for p in self.backbone.parameters():
p.requires_grad_(False)
for p in self.lora_parameters():
p.requires_grad_(True)

self._adapters_enabled = True

def forward(self, *args, **kwargs):
return self.backbone(*args, **kwargs)

def lora_parameters(self):
for _, m in self.backbone.named_modules():
if _is_lora_module(m):
yield from m.lora_A.parameters()
yield from m.lora_B.parameters()

@torch.no_grad()
def enable_adapter(self):
self._adapters_enabled = True
for _, m in self.backbone.named_modules():
if _is_lora_module(m):
setattr(m, "enable_lora", True)

@torch.no_grad()
def disable_adapter(self):
self._adapters_enabled = False
for _, m in self.backbone.named_modules():
if _is_lora_module(m):
setattr(m, "enable_lora", False)

@torch.no_grad()
def _export_merged_state_dict(self) -> OrderedDict:
"""
Build a plain PosEGNN state dict (original layout):
- merge adapter weights into '...weight'
- copy '...bias'
- copy other base.* params with 'base.' stripped
- drop all '...lora_*'
"""
sd_out = OrderedDict()
sd_in = self.backbone.state_dict()
wrapper_paths = {name for name, m in self.backbone.named_modules() if _is_lora_module(m)}

def wrapper_path_for(key: str):
for p in wrapper_paths:
if key.startswith(p + "."):
return p
return None

for key, val in sd_in.items():
# drop adapter tensors
if ".lora_A." in key or ".lora_B." in key:
continue

wp = wrapper_path_for(key)
if wp is not None:
if key.endswith(".base.weight"):
path = key[:-len(".base.weight")]
mod = self.backbone.get_submodule(path) # LoRA wrapper
sd_out[path + ".weight"] = mod.merged_weight()
continue
if key.endswith(".base.bias"):
sd_out[key.replace(".base.bias", ".bias")] = val
continue
prefix = f"{wp}.base."
if key.startswith(prefix):
sd_out[key.replace(prefix, f"{wp}.", 1)] = val
continue
continue # ignore other wrapper internals

# passthrough for non-wrapper keys
sd_out[key] = val

return sd_out

@torch.no_grad()
def state_dict_backbone(self, merged: bool = False) -> OrderedDict:
"""
Return backbone-only weights. If merged=True, adapters are fused and
keys match the original PosEGNN layout.
"""
if merged:
return self._export_merged_state_dict()
return self.backbone.state_dict()

@torch.no_grad()
def save_pretrained_merged(self, path: str):
torch.save(self.state_dict_backbone(merged=True), path)

@torch.no_grad()
def save_pretrained_adapted(self, path: str):
torch.save(self.state_dict(), path)

def lora_report(self) -> str:
return f"LoRA injected - scalar layers: {self.n_scalar}"
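A short sketch of the wrapper in a fine-tuning setup, assuming a loaded `backbone` and an input `batch` as in the README (both are placeholders here):

```python
import torch
from posegnn.adapter import LoRAConfig, PosEGNNLoRAModel  # import paths assumed

# `backbone` is a loaded PosEGNN and `batch` a prepared input, as in the README.
model = PosEGNNLoRAModel(backbone, LoRAConfig(rank=8, freeze_base=True))
print(model.lora_report())

# With freeze_base=True only the adapter tensors are trainable, so an optimizer
# built over lora_parameters() covers every trainable weight.
optimizer = torch.optim.AdamW(model.lora_parameters(), lr=1e-4)

# Toggling the adapters runs the same forward pass with or without the LoRA
# deltas, which is useful for before/after comparisons on one batch.
model.disable_adapter()
baseline = model(batch)
model.enable_adapter()
adapted = model(batch)
```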