more robust hijack (#5)
Haoming02 committed Aug 20, 2024
1 parent 72bc138 commit 117619b
Showing 4 changed files with 50 additions and 50 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -4,16 +4,16 @@ This is an Extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), w
> **ie.** This is not just a postprocessing filter
## How to Use
- Attach the **ReSharpen** node between `Empty Latent` and `KSampler` nodes
- Attach the **ReSharpen** node between `Empty Latent` and your `Sampler` node of choice
- Adjust the **details** slider:
- **Positive** values cause the images to be noisy
- **Negative** values cause the images to be blurry

> Don't use values too close to `1` or `-1`, as it will become distorted
> Values too large or small may cause the result to become distorted!
### Important:
- `Ancestral` samplers *(**eg.** `Euler a`)* do **not** work.
- The **enable** is "global." If you want to disable it during later part of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **ReSharpen** node and set it to disable.
- The effect is "global," meaning if you want to disable it during other parts of the workflow *(**eg.** during `Hires. Fix`)*, you need to add another **ReSharpen** node and set the `details` to `0` again.

## Examples

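For reference, the wiring described in the updated usage section places the ReSharpen node between the empty latent and the sampler. A minimal sketch in ComfyUI's API prompt format follows; the node IDs, the `EmptyLatentImage`/`KSampler` entries, and their abbreviated input lists are illustrative assumptions, and only the `Resharpen` class type comes from this extension.

```python
# Hypothetical API-format prompt fragment -- node IDs and the EmptyLatentImage /
# KSampler entries are illustrative; only "Resharpen" is registered by this extension.
prompt = {
    "1": {"class_type": "EmptyLatentImage",
          "inputs": {"width": 512, "height": 512, "batch_size": 1}},
    "2": {"class_type": "Resharpen",
          "inputs": {"latent": ["1", 0],   # latent from the Empty Latent node
                     "details": 0.5}},     # positive values = noisier / more detail
    "3": {"class_type": "KSampler",        # remaining sampler inputs omitted for brevity
          "inputs": {"latent_image": ["2", 0]}},
}
```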
18 changes: 8 additions & 10 deletions __init__.py
@@ -1,28 +1,26 @@
from .resharpen import ReSharpen, disable
from .resharpen import ReSharpen, disable_resharpen
from functools import wraps
from typing import Callable
import execution

NODE_CLASS_MAPPINGS = {"Resharpen": ReSharpen}

NODE_DISPLAY_NAME_MAPPINGS = {"Resharpen": "ReSharpen"}


def find_node(prompt: dict) -> bool:
"""Find any ReSharpen Node"""

for k, v in prompt.items():
if v["class_type"] == "Resharpen":
return True

return False
return any(v.get("class_type") == "Resharpen" for v in prompt.values())


original_validate = execution.validate_prompt
original_validate: Callable = execution.validate_prompt


def hijack_validate(prompt):
@wraps(original_validate)
def hijack_validate(prompt: dict) -> Callable:

if not find_node(prompt):
disable()
disable_resharpen()

return original_validate(prompt)

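To see how this hijack behaves in isolation, here is a hedged, self-contained sketch of the same pattern. `original_validate` and `disable_resharpen` are local stand-ins (the real ones come from ComfyUI's `execution` module and from `resharpen.py`), and the fake validator's return value is made up.

```python
from functools import wraps

_strength = 0.5  # stand-in for resharpen.RESHARPEN_STRENGTH

def disable_resharpen():
    """Stand-in for resharpen.disable_resharpen(): reset the strength."""
    global _strength
    _strength = 0.0

def original_validate(prompt: dict):
    """Stand-in for execution.validate_prompt; the real signature/return may differ."""
    return True

def find_node(prompt: dict) -> bool:
    # Same scan as above: is any queued node a "Resharpen" node?
    return any(v.get("class_type") == "Resharpen" for v in prompt.values())

@wraps(original_validate)
def hijack_validate(prompt: dict):
    # If the queued workflow has no ReSharpen node, clear any leftover strength
    # before handing the prompt to the original validator.
    if not find_node(prompt):
        disable_resharpen()
    return original_validate(prompt)

hijack_validate({"1": {"class_type": "KSampler", "inputs": {}}})
print(_strength)  # 0.0 -- the leftover 0.5 was cleared: no Resharpen node was queued
```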
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-resharpen"
description = "Manipulate the details of generations"
version = "1.0.1"
version = "1.1.0"
license = { text = "MIT License" }
dependencies = []

74 changes: 38 additions & 36 deletions resharpen.py
@@ -1,46 +1,51 @@
import comfy
from functools import wraps
from typing import Callable
import latent_preview
import torch

isEnabled = False
traj_cache = None
strength = 0.0

ORIGINAL_SAMPLE = comfy.sample.sample
ORIGINAL_SAMPLE_CUSTOM = comfy.sample.sample_custom
ORIGINAL_PREP: Callable = latent_preview.prepare_callback

RESHARPEN_STRENGTH: float = 0.0
LATENT_CACHE: torch.Tensor = None

def disable():
global isEnabled
isEnabled = False

def disable_resharpen():
"""Reset the ReSharpen Strength"""
global RESHARPEN_STRENGTH
RESHARPEN_STRENGTH = 0.0

def hijack(SAMPLE):

def sample_center(*args, **kwargs):
original_callback = kwargs["callback"]
def hijack(PREP) -> Callable:

def hijack_callback(step, x0, x, total_steps):
global isEnabled
global traj_cache
global strength
@wraps(PREP)
def prep_callback(*args, **kwargs):
original_callback: Callable = PREP(*args, **kwargs)
if not RESHARPEN_STRENGTH:
return original_callback

print("[ReSharpen] Enabled~")

if not isEnabled:
@torch.inference_mode()
@wraps(original_callback)
def hijack_callback(step, x0, x, total_steps):
if not RESHARPEN_STRENGTH:
return original_callback(step, x0, x, total_steps)

if traj_cache is not None:
delta = x.detach().clone() - traj_cache
x += delta * strength
global LATENT_CACHE
if LATENT_CACHE is not None:
delta = x.detach().clone() - LATENT_CACHE
x += delta * RESHARPEN_STRENGTH

traj_cache = x.detach().clone()
LATENT_CACHE = x.detach().clone()
return original_callback(step, x0, x, total_steps)

kwargs["callback"] = hijack_callback
return SAMPLE(*args, **kwargs)
return hijack_callback

return sample_center
return prep_callback


comfy.sample.sample = hijack(ORIGINAL_SAMPLE)
comfy.sample.sample_custom = hijack(ORIGINAL_SAMPLE_CUSTOM)
latent_preview.prepare_callback = hijack(ORIGINAL_PREP)


class ReSharpen:
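The heart of the new callback is the trajectory nudge between consecutive sampler steps. Below is a hedged, standalone re-creation of just that arithmetic; nothing here is imported from `resharpen.py`, and the latent shape, the fake step updates, and the loop are illustrative only.

```python
import torch

strength = 0.5 / -10.0     # what a `details` of 0.5 maps to, per ReSharpen.hook below
latent_cache = None        # stand-in for LATENT_CACHE

def resharpen_step(x: torch.Tensor) -> torch.Tensor:
    """Nudge the latent along (or against) its step-to-step change, like hijack_callback."""
    global latent_cache
    if latent_cache is not None:
        delta = x.detach().clone() - latent_cache  # how far this step moved the latent
        x = x + delta * strength                   # amplify or dampen that movement
    latent_cache = x.detach().clone()              # remember this step for the next one
    return x

x = torch.randn(1, 4, 64, 64)                      # toy latent
for _ in range(4):                                 # fake denoising steps
    x = resharpen_step(x - 0.1 * torch.randn_like(x))
```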
@@ -49,10 +54,9 @@ def INPUT_TYPES(s):
return {
"required": {
"latent": ("LATENT",),
"enable": ("BOOLEAN", {"default": False}),
"details": (
"FLOAT",
{"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1},
{"default": 0.0, "min": -2.0, "max": 2.0, "step": 0.1},
),
}
}
@@ -61,14 +65,12 @@ def INPUT_TYPES(s):
FUNCTION = "hook"
CATEGORY = "latent"

def hook(self, latent, enable, details):
global isEnabled
isEnabled = enable
def hook(self, latent, details: float):

global RESHARPEN_STRENGTH
RESHARPEN_STRENGTH = details / -10.0

if isEnabled:
global traj_cache
traj_cache = None
global strength
strength = details / -10.0
global LATENT_CACHE
LATENT_CACHE = None

return (latent,)

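As a closing note on the node side, here is a tiny stand-alone sketch of the `hook()` contract, with class attributes standing in for the module-level globals and a placeholder dict as the latent: the slider value is folded into the internal strength, the cache is cleared, and the latent passes through untouched. This is also why a second node with `details` set to `0` switches the effect back off, as the updated README advises.

```python
class ReSharpenSketch:
    """Toy stand-in for the ReSharpen node above (not the registered class)."""
    strength = 0.0   # stands in for RESHARPEN_STRENGTH
    cache = None     # stands in for LATENT_CACHE

    def hook(self, latent, details: float):
        ReSharpenSketch.strength = details / -10.0  # same mapping as ReSharpen.hook
        ReSharpenSketch.cache = None                # start the trajectory cache fresh
        return (latent,)                            # the latent itself passes through

node = ReSharpenSketch()
node.hook({"samples": None}, details=1.0)
print(ReSharpenSketch.strength)   # -0.1: the effect is active for the next run
node.hook({"samples": None}, details=0.0)
print(ReSharpenSketch.strength)   #  0.0: a second node with details=0 disables it again
```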