Conversation
Code Review
This pull request introduces "Diffusion Templates," a framework for controllable generation plugins in DiffSynth-Studio, specifically targeting the FLUX.2 model. It includes a new TemplatePipeline for managing multiple template models, updates to the attention mechanisms in Flux2DiT to support KV-cache and LoRA as control mediums, and comprehensive documentation and example scripts. Feedback highlights several critical issues: a regression in LoRA clearing logic that could lead to weight accumulation, significant performance bottlenecks in the lazy loading implementation and repeated signature inspections, and potential runtime errors due to missing file checks or device mismatches during tensor concatenation. Additionally, minor optimizations regarding list flattening complexity and local import overhead were suggested.
```python
if inputs_shared.get("positive_only_lora", None) is not None:
    self.clear_lora(verbose=0)
    self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if inputs_shared.get("positive_only_lora", None) is not None:
    self.clear_lora(verbose=0)

if cfg_scale != 1.0:
    if inputs_shared.get("positive_only_lora", None) is not None:
        self.clear_lora(verbose=0)
    # Negative side forward
    if inputs_shared.get("negative_only_lora", None) is not None:
        self.load_lora(self.dit, state_dict=inputs_shared["negative_only_lora"], verbose=0)
    noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
    if inputs_shared.get("negative_only_lora", None) is not None:
        self.clear_lora(verbose=0)
```
The logic for positive_only_lora and negative_only_lora has changed in a way that likely introduces regressions. Previously, clear_lora was called before loading the pass-specific LoRA, ensuring it was the only active LoRA. Now, clear_lora is called after the forward pass. This has two major side effects:
- The pass-specific LoRA is now added to any existing global LoRAs instead of replacing them, which contradicts the 'only' naming.
- Any global LoRAs are permanently cleared from the model after the first pass that uses a pass-specific LoRA, which will affect all subsequent passes and inference steps.
Consider restoring the previous behavior or implementing a mechanism to save and restore the LoRA state.
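One possible shape for a save-and-restore mechanism is a context manager that snapshots the state a pass-specific LoRA will touch and restores it after the forward pass. The sketch below operates on a plain weight dict for illustration; `temporary_lora` is a hypothetical helper, and real code would act on `pipe.dit` via `load_lora`/`clear_lora`:

```python
from contextlib import contextmanager

@contextmanager
def temporary_lora(weights: dict, lora_delta: dict):
    """Apply a pass-specific LoRA delta only for the duration of one
    forward pass, then restore the prior state (including any global
    LoRAs already merged in). Hypothetical helper, not DiffSynth API."""
    backup = {k: weights[k] for k in lora_delta}
    for k, d in lora_delta.items():
        weights[k] = weights[k] + d
    try:
        yield weights
    finally:
        weights.update(backup)

weights = {"w": 1.0}                      # imagine a global LoRA already merged
with temporary_lora(weights, {"w": 0.5}):
    assert weights["w"] == 1.5            # pass-specific LoRA active for this pass
assert weights["w"] == 1.0                # global state intact for later passes
```

This avoids both side effects above: the pass-specific LoRA never stacks on top of global LoRAs mid-pass in an unintended way, and the global state survives into subsequent inference steps.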
```python
def fetch_model(self, model_id):
    if self.lazy_loading:
        model_config = self.model_configs[model_id]
        model_config.download_if_necessary()
        model = load_template_model(model_config.path, torch_dtype=self.torch_dtype, device=self.device)
    else:
        model = self.models[model_id]
    return model
```
The current implementation of lazy_loading re-imports and re-instantiates the template model on every call to fetch_model if it's not the currently loaded one. Since fetch_model is called during every inference step, this will cause a massive performance degradation. Models should be cached (e.g., in CPU memory) and only moved to the target device when needed.
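A cached variant could look roughly like this. The class below is a minimal stand-in, not the real `TemplatePipeline`: the loader replaces `load_template_model`, and the per-call device move is only indicated in a comment.

```python
class CachedTemplatePipeline:
    """Minimal sketch of a cached lazy loader: each template model is
    instantiated at most once (kept e.g. in CPU memory) and reused on
    every subsequent fetch. Stand-in names, not DiffSynth-Studio API."""
    def __init__(self, model_configs, loader):
        self.model_configs = model_configs
        self._loader = loader          # stand-in for load_template_model
        self._model_cache = {}         # model_id -> cached model
        self.load_calls = 0

    def fetch_model(self, model_id):
        if model_id not in self._model_cache:
            self.load_calls += 1       # expensive path runs once per id
            self._model_cache[model_id] = self._loader(self.model_configs[model_id])
        # real code would move the cached model to the target device here,
        # e.g. `return self._model_cache[model_id].to(self.device)`
        return self._model_cache[model_id]

pipe = CachedTemplatePipeline({"a": "path/a"}, loader=lambda path: f"model@{path}")
pipe.fetch_model("a")
pipe.fetch_model("a")
assert pipe.load_calls == 1            # second fetch hits the cache
```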
```python
spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
module = importlib.util.module_from_spec(spec)
```
The spec returned by importlib.util.spec_from_file_location can be None if the model.py file is missing at the specified path. Passing None to module_from_spec will cause an AttributeError. You should verify that the spec is valid before proceeding.
Suggested change:

```python
spec = importlib.util.spec_from_file_location("template_model", os.path.join(path, "model.py"))
if spec is None:
    raise FileNotFoundError(f"model.py not found in {path}")
module = importlib.util.module_from_spec(spec)
```
```python
k = torch.concat([kv[0] for kv in kv_list], dim=1)
v = torch.concat([kv[1] for kv in kv_list], dim=1)
```
torch.concat will fail if the tensors in kv_list are on different devices (e.g., some on CPU due to offloading and some on CUDA). Given the VRAM management context of this project, it is safer to ensure all tensors are moved to a consistent device (e.g., self.device) before concatenation.
Suggested change:

```python
k = torch.concat([kv[0].to(self.device) for kv in kv_list], dim=1)
v = torch.concat([kv[1].to(self.device) for kv in kv_list], dim=1)
```
```python
    return kv_cache_merged

def merge_template_cache(self, template_cache_list):
    params = sorted(list(set(sum([list(template_cache.keys()) for template_cache in template_cache_list], []))))
```
Using sum(..., []) to flatten a list of keys has quadratic complexity, since each step copies the accumulated list; a flat generator expression or set().union(...) is more efficient.
Suggested change:

```python
params = sorted(list(set(k for template_cache in template_cache_list for k in template_cache.keys())))
```
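For illustration, both forms produce the same sorted key list; the difference is that `sum(..., [])` rebuilds the accumulated list at every step:

```python
caches = [{"b": 1, "a": 2}, {"a": 3, "c": 4}]

# sum(..., []) copies the accumulated list on every addition: O(n^2) overall.
params_sum = sorted(set(sum([list(c.keys()) for c in caches], [])))

# A flat generator (or set().union(*caches)) touches each key once: O(n).
params_gen = sorted(set(k for c in caches for k in c))

assert params_sum == params_gen == ["a", "b", "c"]
```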
```python
):
    template_cache = self.call_single_side(pipe=pipe, inputs=template_inputs or [])
    negative_template_cache = self.call_single_side(pipe=pipe, inputs=negative_template_inputs or [])
    required_params = list(inspect.signature(pipe.__call__).parameters.keys())
```
```python
inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 16, height // 16)), min_value=0, max_value=1)
inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
    from torchvision.transforms import GaussianBlur
```
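On the local-import point from the summary: a `from torchvision.transforms import GaussianBlur` inside the function body re-resolves the import on every call (cheap after the first import, since modules are cached in `sys.modules`, but still nonzero overhead in a hot path). Hoisting the import to module level is the usual fix; the pattern, shown with a stand-in stdlib module instead of torchvision:

```python
import sys

def blur_local(x):
    from math import sqrt          # stand-in for the torchvision import;
    return sqrt(x)                 # re-resolved via sys.modules on every call

from math import sqrt              # module level: resolved once at import time

def blur_hoisted(x):
    return sqrt(x)

assert blur_local(4) == blur_hoisted(4) == 2.0
assert "math" in sys.modules       # later imports are only cache lookups
```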