
Dev #1403

Open

mi804 wants to merge 8 commits into main from dev

Conversation

@mi804 (Collaborator) commented Apr 20, 2026

No description provided.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces "Diffusion Templates," a framework for controllable generation plugins in DiffSynth-Studio, specifically targeting the FLUX.2 model. It includes a new TemplatePipeline for managing multiple template models, updates to the attention mechanisms in Flux2DiT to support KV-cache and LoRA as control mediums, and comprehensive documentation and example scripts. Feedback highlights several critical issues: a regression in LoRA clearing logic that could lead to weight accumulation, significant performance bottlenecks in the lazy loading implementation and repeated signature inspections, and potential runtime errors due to missing file checks or device mismatches during tensor concatenation. Additionally, minor optimizations regarding list flattening complexity and local import overhead were suggested.

Comment on lines 324 to +337
if inputs_shared.get("positive_only_lora", None) is not None:
    self.clear_lora(verbose=0)
    self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if inputs_shared.get("positive_only_lora", None) is not None:
    self.clear_lora(verbose=0)

if cfg_scale != 1.0:
    if inputs_shared.get("positive_only_lora", None) is not None:
        self.clear_lora(verbose=0)
    # Negative side forward
    if inputs_shared.get("negative_only_lora", None) is not None:
        self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
    noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
    if inputs_shared.get("negative_only_lora", None) is not None:
        self.clear_lora(verbose=0)


high

The logic for positive_only_lora and negative_only_lora has changed in a way that likely introduces regressions. Previously, clear_lora was called before loading the pass-specific LoRA, ensuring it was the only active LoRA. Now, clear_lora is called after the forward pass. This has two major side effects:

  1. The pass-specific LoRA is now added to any existing global LoRAs instead of replacing them, which contradicts the 'only' naming.
  2. Any global LoRAs are permanently cleared from the model after the first pass that uses a pass-specific LoRA, which will affect all subsequent passes and inference steps.

Consider restoring the previous behavior or implementing a mechanism to save and restore the LoRA state.
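A save-and-restore mechanism along the lines the comment suggests could look like this sketch. It is pure Python and assumes a hypothetical `pipe.active_loras` list that tracks the currently loaded LoRA state dicts; `load_lora` and `clear_lora` are the pipeline methods from the diff.

```python
from contextlib import contextmanager

@contextmanager
def temporary_lora(pipe, state_dict):
    """Swap in a pass-specific LoRA, then restore the previous LoRA state.

    Sketch only: assumes `pipe.active_loras` (hypothetical) records every
    LoRA state dict currently applied to `pipe.dit`.
    """
    saved = list(pipe.active_loras)     # remember the global LoRAs
    pipe.clear_lora(verbose=0)          # make the pass-specific LoRA the only one
    pipe.load_lora(pipe.dit, state_dict=state_dict, verbose=0)
    try:
        yield
    finally:
        pipe.clear_lora(verbose=0)      # drop the pass-specific LoRA
        for sd in saved:                # restore the global LoRAs
            pipe.load_lora(pipe.dit, state_dict=sd, verbose=0)
```

With this, each forward pass would run inside `with temporary_lora(pipe, inputs_shared["positive_only_lora"]): ...`, so the pass-specific LoRA really is the only active one and the global LoRAs survive the step.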

Comment on lines +158 to +165
def fetch_model(self, model_id):
    if self.lazy_loading:
        model_config = self.model_configs[model_id]
        model_config.download_if_necessary()
        model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
    else:
        model = self.models[model_id]
    return model

high

The current implementation of lazy_loading re-imports and re-instantiates the template model on every call to fetch_model if it's not the currently loaded one. Since fetch_model is called during every inference step, this will cause a massive performance degradation. Models should be cached (e.g., in CPU memory) and only moved to the target device when needed.

Comment on lines +35 to +36
spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
module = importlib.util.module_from_spec(spec)

medium

The spec returned by importlib.util.spec_from_file_location can be None if the model.py file is missing at the specified path. Passing None to module_from_spec will cause an AttributeError. You should verify that the spec is valid before proceeding.

Suggested change

spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
if spec is None:
    raise FileNotFoundError(f"model.py not found in {path}")
module = importlib.util.module_from_spec(spec)

Comment on lines +114 to +115
k = torch.concat([kv[0] for kv in kv_list], dim=1)
v = torch.concat([kv[1] for kv in kv_list], dim=1)

medium

torch.concat will fail if the tensors in kv_list are on different devices (e.g., some on CPU due to offloading and some on CUDA). Given the VRAM management context of this project, it is safer to ensure all tensors are moved to a consistent device (e.g., self.device) before concatenation.

Suggested change

k = torch.concat([kv[0].to(self.device) for kv in kv_list], dim=1)
v = torch.concat([kv[1].to(self.device) for kv in kv_list], dim=1)

    return kv_cache_merged

def merge_template_cache(self, template_cache_list):
    params = sorted(list(set(sum([list(template_cache.keys()) for template_cache in template_cache_list], []))))

medium

Using sum(..., []) to flatten a list of keys has $O(N^2)$ time complexity relative to the number of template models. A nested list comprehension or using set.union is more efficient.

Suggested change

params = sorted(list(set(k for template_cache in template_cache_list for k in template_cache.keys())))
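As a standalone check with dummy caches, the single-pass comprehension yields the same sorted key list as the `sum(..., [])` version while avoiding the repeated list copies:

```python
# Two dummy template caches with overlapping keys.
template_cache_list = [{"a": 1, "b": 2}, {"b": 3, "c": 4}]

# O(N^2) flattening: sum() re-copies the accumulated list on every addition.
flat_sum = sorted(set(sum([list(tc.keys()) for tc in template_cache_list], [])))

# Single-pass alternative: one set comprehension over all caches.
flat_comp = sorted({k for tc in template_cache_list for k in tc})

assert flat_sum == flat_comp == ["a", "b", "c"]
```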

):
    template_cache = self.call_single_side(pipe=pipe, inputs=template_inputs or [])
    negative_template_cache = self.call_single_side(pipe=pipe, inputs=negative_template_inputs or [])
    required_params = list(inspect.signature(pipe.__call__).parameters.keys())

medium

Calling inspect.signature on every execution of __call__ is computationally expensive. This information should be retrieved once and cached to avoid unnecessary overhead during the inference loop.
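One cheap fix is to memoize the parameter lookup per callable, for example with `functools.lru_cache`. A sketch, where `call_parameters` is a hypothetical helper rather than part of the PR:

```python
import inspect
from functools import lru_cache

@lru_cache(maxsize=None)
def call_parameters(fn):
    # inspect.signature runs only on the first call per function object;
    # every later call is a cache lookup.
    return tuple(inspect.signature(fn).parameters.keys())

# Inside __call__, look up the unbound function so all pipeline instances
# of the same class share one cache entry:
#   required_params = list(call_parameters(type(pipe).__call__))
```

Keying on `type(pipe).__call__` (the plain function) rather than the bound method keeps the cache small: one entry per pipeline class instead of one per instance.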

inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 16, height // 16)), min_value=0, max_value=1)
inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
    from torchvision.transforms import GaussianBlur

medium

Importing GaussianBlur inside the process method adds overhead to every inference call. It is better to move this import to the top of the file.
