Conversation
Collaborator
Does LoRA support the I2V pipelines as well?
Force-pushed from 9290a6e to e1b7221
Collaborator (Author)
Added examples of I2V support.
entrpn reviewed on Jan 23, 2026
entrpn (Collaborator) left a comment:
Can this implementation load multiple LoRAs at once?
src/maxdiffusion/models/lora_nnx.py (outdated diff context):

    return jnp.array(v)

    def parse_lora_dict(state_dict):
Collaborator
Do you know which LoRA formats are supported by this function? There are a couple of LoRA trainers out there; might want to specify in a comment or README which ones we're specifically targeting (diffusers, or others).
Collaborator (Author)
Added a comment that it supports the ComfyUI and AI Toolkit LoRA formats.
Collaborator (Author)
Now supports multiple LoRAs at once. Example added to the description.
Force-pushed from 7f018e4 to 9b5051c
entrpn previously approved these changes on Jan 28, 2026
Collaborator
@Perseus14 please squash your commits and make sure the linter tests pass. Other than that, looks good.
entrpn approved these changes on Jan 28, 2026
prishajain1 approved these changes on Jan 28, 2026
Summary
This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.
Unlike previous implementations in this codebase that rely on flax.linen, this implementation leverages flax.nnx. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features

1. Transition to flax.nnx

WAN models in MaxDiffusion are implemented using flax.nnx. To support LoRA:
- We implemented a native NNX loader rather than wrapping linen modules.
- The loader traverses the model graph (nnx.iter_graph) to identify target layers (nnx.Linear, nnx.Conv, nnx.Embed, nnx.LayerNorm) and merges LoRA weights directly into the kernel values (a sketch follows this list).
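For illustration, here is a minimal sketch of the traversal-and-merge idea. It is not the PR's actual merge code: the `lora_weights` mapping, the helper names, and the weight shapes/scaling are assumptions made for this example only.

```python
# Illustrative sketch only; assumes `lora_weights` maps a dotted module path
# to (down, up) arrays with shapes (rank, in_features) and (out_features, rank).
import jax
import jax.numpy as jnp
from flax import nnx


@jax.jit
def _merge_kernel(kernel, down, up, scale):
  # Standard LoRA update W' = W + scale * (up @ down), transposed to match
  # nnx.Linear's (in_features, out_features) kernel layout.
  return kernel + scale * jnp.matmul(down.T, up.T)


def merge_lora_sketch(model: nnx.Module, lora_weights: dict, scale: float = 1.0):
  # Walk every node in the NNX graph and patch matching Linear kernels in-place.
  for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Linear):
      key = ".".join(str(p) for p in path)
      if key in lora_weights:
        down, up = lora_weights[key]
        node.kernel.value = _merge_kernel(node.kernel.value, down, up, scale)
```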
2. Robust Weight Merging Strategy

This implementation solves several critical distributed training/inference challenges:
- On-device merging (jax.jit): To avoid ShardingMismatch and DeviceArray errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (_compute_and_add_*_jit). This ensures weight updates occur efficiently on-device across the TPU mesh.
- Uses jax.dlpack where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead (sketched below).
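As a rough sketch of both points, assuming a plain kernel-plus-delta addition (the real _compute_and_add_*_jit functions cover more cases than this):

```python
# Illustrative sketch only, not the loader's actual code.
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def torch_to_jax(t: torch.Tensor) -> jax.Array:
  try:
    # Recent JAX versions accept any object implementing __dlpack__ here,
    # which avoids an extra host copy for the LoRA tensors.
    return jax.dlpack.from_dlpack(t.detach().contiguous())
  except Exception:
    # Fallback path: round-trip through NumPy (copies, but always works).
    return jnp.asarray(t.detach().cpu().numpy())


@jax.jit
def _compute_and_add(kernel, delta):
  # Adding the delta under jit keeps the result on-device and lets XLA
  # follow the existing sharding of `kernel` across the TPU mesh.
  return kernel + delta
```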
3. Advanced LoRA Support

Beyond standard Linear rank reduction, this PR supports:
- … diff weights before device-side merging.
- Difference injections (diff, diff_b): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes (see the sketch after this list).
- Support for text_embedding, time_embedding, and LayerNorm/RMSNorm scales and biases.
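A minimal sketch of how such offsets could be applied on-device; the helper names echo the *_jit naming above but are illustrative only:

```python
# Illustrative sketch only; `diff` / `diff_b` are the checkpoint's offset
# tensors, already converted to JAX arrays matching the parameter shapes.
import jax


@jax.jit
def _apply_diff(kernel, diff):
  # A `diff` entry is a full-parameter offset: W' = W + diff.
  return kernel + diff


@jax.jit
def _apply_diff_b(bias, diff_b):
  # A `diff_b` entry tunes the bias directly: b' = b + diff_b.
  return bias + diff_b
```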
4. Scanned vs. Unscanned Layers

MaxDiffusion supports enabling jax.scan for transformer layers via the scan_layers: True configuration flag. This improves training memory efficiency by stacking weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:
- Unscanned layers: the merge_lora() function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (_compute_and_add_single_jit).
- Scanned layers: the merge_lora_for_scanned() function is used. It detects which parameters are stacked (e.g., kernel.ndim > 2) and which are not.
  - Stacked parameters are merged via _compute_and_add_scanned_jit. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer.
  - Non-stacked parameters (e.g., embeddings, proj_out) are merged individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching (a sketch of the stacked merge follows).
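For the stacked case, a minimal sketch under assumed shapes and names (not the actual _compute_and_add_scanned_jit):

```python
# Illustrative sketch; assumes the scanned kernel is stacked as (L, in, out)
# and the per-layer LoRA factors as (L, rank, in) / (L, out, rank).
import jax
import jax.numpy as jnp


@jax.jit
def compute_and_add_scanned(kernel, down, up, scale):
  # delta[l] = scale * (up[l] @ down[l]).T for every layer l in one call.
  delta = jnp.einsum("lri,lor->lio", down, up)
  return kernel + scale * delta
```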
Files Added / Modified
- src/maxdiffusion/models/lora_nnx.py: [NEW] Core logic. Contains the JIT merge functions, parse_lora_dict, and the graph traversal logic (merge_lora, merge_lora_for_scanned) to inject weights into NNX modules.
- src/maxdiffusion/loaders/wan_lora_nnx_loader.py: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions (see the sketch below).
- src/maxdiffusion/generate_wan.py: Updated the generation pipeline to identify if lora is enabled and trigger the loading sequence before inference.
- src/maxdiffusion/lora_conversion_utils.py: Updated translate_wan_nnx_path_to_diffusers_lora to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.
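To make the orchestration concrete, a hypothetical sketch of the flow; the function signatures here are guesses for illustration, not the actual wan_lora_nnx_loader API:

```python
# Hypothetical end-to-end flow; parse_lora_dict / merge_lora signatures are
# assumptions for illustration, and key translation / scanned dispatch are omitted.
from safetensors.numpy import load_file

from maxdiffusion.models.lora_nnx import merge_lora, parse_lora_dict


def load_and_merge(transformer, lora_path: str, scale: float = 1.0):
  state_dict = load_file(lora_path)             # raw ComfyUI / AI Toolkit keys
  lora_weights = parse_lora_dict(state_dict)    # normalized LoRA entries (assumed signature)
  merge_lora(transformer, lora_weights, scale)  # unscanned merge path (assumed signature)
```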
Testing

Scenario 2: Validation of Multiple LoRA weights
- WAN2.1: distill_lora and divine_power_lora
- WAN2.2: distill_lora and orbit_shot_lora