diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..8219c61b --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,20 @@ +name: CI +on: push +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff==0.6.8 + - name: Run Ruff + run: ruff check --output-format=github . + - name: Check imports + run: ruff check --select I --output-format=github . + - name: Check formatting + run: ruff format --check . diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..dfd2119c --- /dev/null +++ b/.gitignore @@ -0,0 +1,236 @@ +# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python +# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python + +# Ignore output files +output/ + +# Ignore version file which is generated dynamically +src/flux/_version.py diff --git a/README.md b/README.md index 29e21253..e73c1f13 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,9 @@ # FLUX -by Black Forest Labs: https://blackforestlabs.ai +by Black Forest Labs: https://blackforestlabs.ai. Documentation for our API can be found here: [docs.bfl.ml](https://docs.bfl.ml/). ![grid](assets/grid.jpg) -This repo contains minimal inference code to run text-to-image and image-to-image with our Flux latent rectified flow transformers. - -### Inference partners - -We are happy to partner with [Replicate](https://replicate.com/), [FAL](https://fal.ai/) and [Mystic](https://www.mystic.ai). You can sample our models using their services. -Below we list relevant links. - -Replicate: - -- https://replicate.com/collections/flux -- https://replicate.com/black-forest-labs/flux-pro -- https://replicate.com/black-forest-labs/flux-dev -- https://replicate.com/black-forest-labs/flux-schnell - -FAL: - -- https://fal.ai/models/fal-ai/flux-pro -- https://fal.ai/models/fal-ai/flux/dev -- https://fal.ai/models/fal-ai/flux/schnell - -Mystic: - -- https://www.mystic.ai/black-forest-labs -- https://www.mystic.ai/black-forest-labs/flux1-pro -- https://www.mystic.ai/black-forest-labs/flux1-dev -- https://www.mystic.ai/black-forest-labs/flux1-schnell +This repo contains minimal inference code to run image generation & editing with our Flux models. ## Local installation @@ -42,105 +17,32 @@ pip install -e ".[all]" ### Models -We are offering three models: - -- `FLUX.1 [pro]` the base model, available via API -- `FLUX.1 [dev]` guidance-distilled variant -- `FLUX.1 [schnell]` guidance and step-distilled variant - -| Name | HuggingFace repo | License | md5sum | -| ------------------ | ------------------------------------------------------- | --------------------------------------------------------------------- | -------------------------------- | -| `FLUX.1 [schnell]` | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) | a9e1e277b9b16add186f38e3f5a34044 | -| `FLUX.1 [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | a6bd8c16dfc23db6aee2f63a2eba78c0 | -| `FLUX.1 [pro]` | Only available in our API. | - -The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in either of the two HuggingFace repos above. They are the same for both models. - -## Usage - -The weights will be downloaded automatically from HuggingFace once you start one of the demos. To download `FLUX.1 [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). -If you have downloaded the model weights manually, you can specify the downloaded paths via environment-variables: - -```bash -export FLUX_SCHNELL= -export FLUX_DEV= -export AE= -``` - -For interactive sampling run - -```bash -python -m flux --name --loop -``` - -Or to generate a single sample run - -```bash -python -m flux --name \ - --height --width \ - --prompt "" -``` - -We also provide a streamlit demo that does both text-to-image and image-to-image. The demo can be run via - -```bash -streamlit run demo_st.py -``` - -We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo: - -```bash -python demo_gr.py --name flux-schnell --device cuda -``` - -Options: - -- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev") -- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu") -- `--offload`: Offload model to CPU when not in use -- `--share`: Create a public link to your demo - -To run the demo with the dev model and create a public link: - -```bash -python demo_gr.py --name flux-dev --share -``` - -## Diffusers integration - -`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it: - -```shell -pip install git+https://github.com/huggingface/diffusers.git -``` - -Then you can use `FluxPipeline` to run the model - -```python -import torch -from diffusers import FluxPipeline - -model_id = "black-forest-labs/FLUX.1-schnell" #you can also use `black-forest-labs/FLUX.1-dev` - -pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) -pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power - -prompt = "A cat holding a sign that says hello world" -seed = 42 -image = pipe( - prompt, - output_type="pil", - num_inference_steps=4, #use a larger number if you are using [dev] - generator=torch.Generator("cpu").manual_seed(seed) -).images[0] -image.save("flux-schnell.png") -``` - -To learn more check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation +We are offering an extensive suite of models. For more information about the invidual models, please refer to the link under **Usage**. + +| Name | Usage | HuggingFace repo | License | +| --------------------------- | ---------------------------------------------------------- | ------------------------------------------------------------- | --------------------------------------------------------------------- | +| `FLUX.1 [schnell]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) | +| `FLUX.1 [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Fill [dev]` | [In/Out-painting](docs/fill.md) | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Canny [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Depth [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Canny [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Depth [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 Redux [dev]` | [Image variation](docs/image-variation.md) | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | +| `FLUX.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX1.1 [pro]` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX1.1 [pro] Ultra/raw` | [Text to Image](docs/text-to-image.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX.1 Fill [pro]` | [In/Out-painting](docs/fill.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX.1 Canny [pro]` | [Structural Conditioning](docs/controlnet.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX.1 Depth [pro]` | [Structural Conditioning](docs/controlnet.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX1.1 Redux [pro]` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | +| `FLUX1.1 Redux [pro] Ultra` | [Image variation](docs/image-variation.md) | [Available in our API.](https://docs.bfl.ml/) | + +The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above. ## API usage -Our API offers access to the pro model. It is documented here: +Our API offers access to our models. It is documented here: [docs.bfl.ml](https://docs.bfl.ml/). In this repository we also offer an easy python interface. To use this, you @@ -157,8 +59,8 @@ Usage from python: from flux.api import ImageRequest # this will create an api request directly but not block until the generation is finished -request = ImageRequest("A beautiful beach") -# or: request = ImageRequest("A beautiful beach", api_key="your_key_here") +request = ImageRequest("A beautiful beach", name="flux.1.1-pro") +# or: request = ImageRequest("A beautiful beach", name="flux.1.1-pro", api_key="your_key_here") # any of the following will block until the generation is finished request.url diff --git a/assets/cup.png b/assets/cup.png new file mode 100644 index 00000000..3a810e4b Binary files /dev/null and b/assets/cup.png differ diff --git a/assets/cup_mask.png b/assets/cup_mask.png new file mode 100644 index 00000000..36f75c6e Binary files /dev/null and b/assets/cup_mask.png differ diff --git a/assets/docs/canny.png b/assets/docs/canny.png new file mode 100644 index 00000000..ef1f3ce3 Binary files /dev/null and b/assets/docs/canny.png differ diff --git a/assets/docs/depth.png b/assets/docs/depth.png new file mode 100644 index 00000000..c16d3523 Binary files /dev/null and b/assets/docs/depth.png differ diff --git a/assets/docs/inpainting.png b/assets/docs/inpainting.png new file mode 100644 index 00000000..2072c6dd Binary files /dev/null and b/assets/docs/inpainting.png differ diff --git a/assets/docs/outpainting.png b/assets/docs/outpainting.png new file mode 100644 index 00000000..2b5ada06 Binary files /dev/null and b/assets/docs/outpainting.png differ diff --git a/assets/docs/redux.png b/assets/docs/redux.png new file mode 100644 index 00000000..6b350485 Binary files /dev/null and b/assets/docs/redux.png differ diff --git a/assets/robot.webp b/assets/robot.webp new file mode 100644 index 00000000..090b3cc4 Binary files /dev/null and b/assets/robot.webp differ diff --git a/demo_gr.py b/demo_gr.py index 3393886a..33bb72dd 100644 --- a/demo_gr.py +++ b/demo_gr.py @@ -1,13 +1,12 @@ import os import time -from io import BytesIO import uuid -import torch import gradio as gr import numpy as np +import torch from einops import rearrange -from PIL import Image, ExifTags +from PIL import ExifTags, Image from transformers import pipeline from flux.cli import SamplingOptions @@ -16,6 +15,7 @@ NSFW_THRESHOLD = 0.85 + def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): t5 = load_t5(device, max_length=256 if is_schnell else 512) clip = load_clip(device) @@ -24,8 +24,9 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool) nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) return model, ae, t5, clip, nsfw_classifier + class FluxGenerator: - def __init__(self, model_name: str, device: str, offload: bool): + def __init__(self, model_name: str, device: str, offload: bool, use_compile: bool = False): self.device = torch.device(device) self.offload = offload self.model_name = model_name @@ -36,7 +37,8 @@ def __init__(self, model_name: str, device: str, offload: bool): offload=self.offload, is_schnell=self.is_schnell, ) - self.model = torch.compile(self.model) + if use_compile: + self.model = torch.compile(self.model) @torch.inference_mode() def generate_image( @@ -72,7 +74,7 @@ def generate_image( if init_image is not None: if isinstance(init_image, np.ndarray): init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0 - init_image = init_image.unsqueeze(0) + init_image = init_image.unsqueeze(0) init_image = init_image.to(self.device) init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) if self.offload: @@ -153,37 +155,49 @@ def generate_image( exif_data[ExifTags.Base.Model] = self.model_name if add_sampling_metadata: exif_data[ExifTags.Base.ImageDescription] = prompt - + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) return img, str(opts.seed), filename, None else: return None, str(opts.seed), None, "Your generated image may contain NSFW content." -def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False): + +def create_demo( + model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False +): generator = FluxGenerator(model_name, device, offload) is_schnell = model_name == "flux-schnell" with gr.Blocks() as demo: gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}") - + with gr.Row(): with gr.Column(): - prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture") + prompt = gr.Textbox( + label="Prompt", + value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture', + ) do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell) init_image = gr.Image(label="Input Image", visible=False) - image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False) - + image2image_strength = gr.Slider( + 0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False + ) + with gr.Accordion("Advanced Options", open=False): width = gr.Slider(128, 8192, 1360, step=16, label="Width") height = gr.Slider(128, 8192, 768, step=16, label="Height") num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps") - guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell) + guidance = gr.Slider( + 1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell + ) seed = gr.Textbox(-1, label="Seed (-1 for random)") - add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True) - + add_sampling_metadata = gr.Checkbox( + label="Add sampling parameters to metadata?", value=True + ) + generate_btn = gr.Button("Generate") - + with gr.Column(): output_image = gr.Image(label="Generated Image") seed_output = gr.Number(label="Used Seed") @@ -200,17 +214,33 @@ def update_img2img(do_img2img): generate_btn.click( fn=generator.generate_image, - inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata], + inputs=[ + width, + height, + num_steps, + guidance, + seed, + prompt, + init_image, + image2image_strength, + add_sampling_metadata, + ], outputs=[output_image, seed_output, download_btn, warning_text], ) return demo + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description="Flux") - parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name") - parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use" + ) parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") parser.add_argument("--share", action="store_true", help="Create a public link to your demo") args = parser.parse_args() diff --git a/demo_st.py b/demo_st.py index 0c52564d..f6708891 100644 --- a/demo_st.py +++ b/demo_st.py @@ -25,6 +25,7 @@ ) NSFW_THRESHOLD = 0.85 +CHECK_NSFW = False @st.cache_resource() @@ -242,36 +243,36 @@ def decrement_counter(): img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] - # if nsfw_score < NSFW_THRESHOLD: - buffer = BytesIO() - exif_data = Image.Exif() - if init_image is None: - exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + if not CHECK_NSFW or nsfw_score < NSFW_THRESHOLD: + buffer = BytesIO() + exif_data = Image.Exif() + if init_image is None: + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + else: + exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0) + + img_bytes = buffer.getvalue() + if save_samples: + print(f"Saving {fn}") + with open(fn, "wb") as file: + file.write(img_bytes) + idx += 1 + + st.session_state["samples"] = { + "prompt": opts.prompt, + "img": img, + "seed": opts.seed, + "bytes": img_bytes, + } + opts.seed = None else: - exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" - exif_data[ExifTags.Base.Make] = "Black Forest Labs" - exif_data[ExifTags.Base.Model] = name - if add_sampling_metadata: - exif_data[ExifTags.Base.ImageDescription] = prompt - img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0) - - img_bytes = buffer.getvalue() - if save_samples: - print(f"Saving {fn}") - with open(fn, "wb") as file: - file.write(img_bytes) - idx += 1 - - st.session_state["samples"] = { - "prompt": opts.prompt, - "img": img, - "seed": opts.seed, - "bytes": img_bytes, - } - opts.seed = None - # else: - # st.warning("Your generated image may contain NSFW content.") - # st.session_state["samples"] = None + st.warning("Your generated image may contain NSFW content.") + st.session_state["samples"] = None samples = st.session_state.get("samples", None) if samples is not None: diff --git a/demo_st_fill.py b/demo_st_fill.py new file mode 100644 index 00000000..ddba6688 --- /dev/null +++ b/demo_st_fill.py @@ -0,0 +1,487 @@ +import os +import re +import tempfile +import time +from glob import iglob +from io import BytesIO + +import numpy as np +import streamlit as st +import torch +from einops import rearrange +from PIL import ExifTags, Image +from st_keyup import st_keyup +from streamlit_drawable_canvas import st_canvas +from transformers import pipeline + +from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack +from flux.util import embed_watermark, load_ae, load_clip, load_flow_model, load_t5 + +NSFW_THRESHOLD = 0.85 + + +def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, zoom_up=0, zoom_down=0, overlap=0): + """Adds a black border around the image with individual side control and mask overlap""" + orig_width, orig_height = image.size + + # Calculate padding for each side (in pixels) + left_pad = int(orig_width * zoom_left) + right_pad = int(orig_width * zoom_right) + top_pad = int(orig_height * zoom_up) + bottom_pad = int(orig_height * zoom_down) + + # Calculate overlap in pixels + overlap_left = int(orig_width * overlap) + overlap_right = int(orig_width * overlap) + overlap_top = int(orig_height * overlap) + overlap_bottom = int(orig_height * overlap) + + # If using the all-sides zoom, add it to each side + if zoom_all > 1.0: + extra_each_side = (zoom_all - 1.0) / 2 + left_pad += int(orig_width * extra_each_side) + right_pad += int(orig_width * extra_each_side) + top_pad += int(orig_height * extra_each_side) + bottom_pad += int(orig_height * extra_each_side) + + # Calculate new dimensions (ensure they're multiples of 32) + new_width = 32 * round((orig_width + left_pad + right_pad) / 32) + new_height = 32 * round((orig_height + top_pad + bottom_pad) / 32) + + # Create new image with black border + bordered_image = Image.new("RGB", (new_width, new_height), (0, 0, 0)) + # Paste original image in position + paste_x = left_pad + paste_y = top_pad + bordered_image.paste(image, (paste_x, paste_y)) + + # Create mask (white where the border is, black where the original image was) + mask = Image.new("L", (new_width, new_height), 255) # White background + # Paste black rectangle with overlap adjustment + mask.paste( + 0, + ( + paste_x + overlap_left, # Left edge moves right + paste_y + overlap_top, # Top edge moves down + paste_x + orig_width - overlap_right, # Right edge moves left + paste_y + orig_height - overlap_bottom, # Bottom edge moves up + ), + ) + + return bordered_image, mask + + +@st.cache_resource() +def get_models(name: str, device: torch.device, offload: bool): + t5 = load_t5(device, max_length=128) + clip = load_clip(device) + model = load_flow_model(name, device="cpu" if offload else device) + ae = load_ae(name, device="cpu" if offload else device) + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + return model, ae, t5, clip, nsfw_classifier + + +def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -> Image.Image: + width, height = img.size + mp = (width * height) / 1_000_000 # Current megapixels + + if min_mp <= mp <= max_mp: + # Even if MP is in range, ensure dimensions are multiples of 32 + new_width = int(32 * round(width / 32)) + new_height = int(32 * round(height / 32)) + if new_width != width or new_height != height: + return img.resize((new_width, new_height), Image.Resampling.LANCZOS) + return img + + # Calculate scaling factor + if mp < min_mp: + scale = (min_mp / mp) ** 0.5 + else: # mp > max_mp + scale = (max_mp / mp) ** 0.5 + + new_width = int(32 * round(width * scale / 32)) + new_height = int(32 * round(height * scale / 32)) + + return img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + +def clear_canvas_state(): + """Clear all canvas-related state""" + keys_to_clear = ["canvas", "last_image_dims"] + for key in keys_to_clear: + if key in st.session_state: + del st.session_state[key] + + +def set_new_image(img: Image.Image): + """Safely set a new image and clear relevant state""" + st.session_state["current_image"] = img + clear_canvas_state() + st.rerun() + + +def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image: + """Downscale image by a given factor while maintaining 32-pixel multiple dimensions""" + if scale_factor >= 1.0: + return img + + width, height = img.size + new_width = int(32 * round(width * scale_factor / 32)) + new_height = int(32 * round(height * scale_factor / 32)) + + # Ensure minimum dimensions + new_width = max(64, new_width) # minimum 64 pixels + new_height = max(64, new_height) # minimum 64 pixels + + return img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + +@torch.inference_mode() +def main( + device: str = "cuda" if torch.cuda.is_available() else "cpu", + offload: bool = False, + output_dir: str = "output", +): + torch_device = torch.device(device) + st.title("Flux Fill: Inpainting & Outpainting") + + # Model selection and loading + name = "flux-dev-fill" + if not st.checkbox("Load model", False): + return + + try: + model, ae, t5, clip, nsfw_classifier = get_models( + name, + device=torch_device, + offload=offload, + ) + except Exception as e: + st.error(f"Error loading models: {e}") + return + + # Mode selection + mode = st.radio("Select Mode", ["Inpainting", "Outpainting"]) + + # Image handling - either from previous generation or new upload + if "input_image" in st.session_state: + image = st.session_state["input_image"] + del st.session_state["input_image"] + set_new_image(image) + st.write("Continuing from previous result") + else: + uploaded_image = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"]) + if uploaded_image is None: + st.warning("Please upload an image") + return + + if ( + "current_image_name" not in st.session_state + or st.session_state["current_image_name"] != uploaded_image.name + ): + try: + image = Image.open(uploaded_image).convert("RGB") + st.session_state["current_image_name"] = uploaded_image.name + set_new_image(image) + except Exception as e: + st.error(f"Error loading image: {e}") + return + else: + image = st.session_state.get("current_image") + if image is None: + st.error("Error: Image state is invalid. Please reupload the image.") + clear_canvas_state() + return + + # Add downscale control + with st.expander("Image Size Control"): + current_mp = (image.size[0] * image.size[1]) / 1_000_000 + st.write(f"Current image size: {image.size[0]}x{image.size[1]} ({current_mp:.1f}MP)") + + scale_factor = st.slider( + "Downscale Factor", + min_value=0.1, + max_value=1.0, + value=1.0, + step=0.1, + help="1.0 = original size, 0.5 = half size, etc.", + ) + + if scale_factor < 1.0 and st.button("Apply Downscaling"): + image = downscale_image(image, scale_factor) + set_new_image(image) + st.rerun() + + # Resize image with validation + try: + original_mp = (image.size[0] * image.size[1]) / 1_000_000 + image = resize(image) + width, height = image.size + current_mp = (width * height) / 1_000_000 + + if width % 32 != 0 or height % 32 != 0: + st.error("Error: Image dimensions must be multiples of 32") + return + + st.write(f"Image dimensions: {width}x{height} pixels") + if original_mp != current_mp: + st.write( + f"Image has been resized from {original_mp:.1f}MP to {current_mp:.1f}MP to stay within bounds (0.5MP - 2MP)" + ) + except Exception as e: + st.error(f"Error processing image: {e}") + return + + if mode == "Outpainting": + # Outpainting controls + zoom_all = st.slider("Zoom Out Amount (All Sides)", min_value=1.0, max_value=3.0, value=1.0, step=0.1) + + with st.expander("Advanced Zoom Controls"): + st.info("These controls add additional zoom to specific sides") + col1, col2 = st.columns(2) + with col1: + zoom_left = st.slider("Left", min_value=0.0, max_value=1.0, value=0.0, step=0.1) + zoom_right = st.slider("Right", min_value=0.0, max_value=1.0, value=0.0, step=0.1) + with col2: + zoom_up = st.slider("Up", min_value=0.0, max_value=1.0, value=0.0, step=0.1) + zoom_down = st.slider("Down", min_value=0.0, max_value=1.0, value=0.0, step=0.1) + + overlap = st.slider("Overlap", min_value=0.01, max_value=0.25, value=0.01, step=0.01) + + # Generate bordered image and mask + image_for_generation, mask = add_border_and_mask( + image, + zoom_all=zoom_all, + zoom_left=zoom_left, + zoom_right=zoom_right, + zoom_up=zoom_up, + zoom_down=zoom_down, + overlap=overlap, + ) + width, height = image_for_generation.size + + # Show preview + col1, col2 = st.columns(2) + with col1: + st.image(image_for_generation, caption="Image with Border") + with col2: + st.image(mask, caption="Mask (white areas will be generated)") + + else: # Inpainting mode + # Canvas setup with dimension tracking + canvas_key = f"canvas_{width}_{height}" + if "last_image_dims" not in st.session_state: + st.session_state.last_image_dims = (width, height) + elif st.session_state.last_image_dims != (width, height): + clear_canvas_state() + st.session_state.last_image_dims = (width, height) + st.rerun() + + try: + canvas_result = st_canvas( + fill_color="rgba(255, 255, 255, 0.0)", + stroke_width=st.slider("Brush size", 1, 500, 50), + stroke_color="#fff", + background_image=image, + height=height, + width=width, + drawing_mode="freedraw", + key=canvas_key, + display_toolbar=True, + ) + except Exception as e: + st.error(f"Error creating canvas: {e}") + clear_canvas_state() + st.rerun() + return + + # Sampling parameters + num_steps = int(st.number_input("Number of steps", min_value=1, value=50)) + guidance = float(st.number_input("Guidance", min_value=1.0, value=30.0)) + seed_str = st.text_input("Seed") + if seed_str.isdecimal(): + seed = int(seed_str) + else: + st.info("No seed set, using random seed") + seed = None + + save_samples = st.checkbox("Save samples?", True) + add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True) + + # Prompt input + prompt = st_keyup("Enter a prompt", value="", debounce=300, key="interactive_text") + + # Setup output path + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + idx = len(fns) + + if st.button("Generate"): + valid_input = False + + if mode == "Inpainting" and canvas_result.image_data is not None: + valid_input = True + # Create mask from canvas + try: + mask = Image.fromarray(canvas_result.image_data) + mask = mask.getchannel("A") # Get alpha channel + mask_array = np.array(mask) + mask_array = (mask_array > 0).astype(np.uint8) * 255 + mask = Image.fromarray(mask_array) + image_for_generation = image + except Exception as e: + st.error(f"Error creating mask: {e}") + return + + elif mode == "Outpainting": + valid_input = True + # image_for_generation and mask are already set above + + if not valid_input: + st.error("Please draw a mask or configure outpainting settings") + return + + # Create temporary files + with ( + tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img, + tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_mask, + ): + try: + image_for_generation.save(tmp_img.name) + mask.save(tmp_mask.name) + except Exception as e: + st.error(f"Error saving temporary files: {e}") + return + + try: + # Generate inpainting/outpainting + rng = torch.Generator(device="cpu") + if seed is None: + seed = rng.seed() + + print(f"Generating with seed {seed}:\n{prompt}") + t0 = time.perf_counter() + + x = get_noise( + 1, + height, + width, + device=torch_device, + dtype=torch.bfloat16, + seed=seed, + ) + + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) + + inp = prepare_fill( + t5, + clip, + x, + prompt=prompt, + ae=ae, + img_cond_path=tmp_img.name, + mask_path=tmp_mask.name, + ) + + timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True) + + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + x = denoise(model, **inp, timesteps=timesteps, guidance=guidance) + + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + x = unpack(x.float(), height, width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + # Process and display result + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < NSFW_THRESHOLD: + buffer = BytesIO() + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;inpainting;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0) + + img_bytes = buffer.getvalue() + if save_samples: + fn = output_name.format(idx=idx) + print(f"Saving {fn}") + with open(fn, "wb") as file: + file.write(img_bytes) + + st.session_state["samples"] = { + "prompt": prompt, + "img": img, + "seed": seed, + "bytes": img_bytes, + } + else: + st.warning("Your generated image may contain NSFW content.") + st.session_state["samples"] = None + + except Exception as e: + st.error(f"Error during generation: {e}") + return + finally: + # Clean up temporary files + try: + os.unlink(tmp_img.name) + os.unlink(tmp_mask.name) + except Exception as e: + print(f"Error cleaning up temporary files: {e}") + + # Display results + samples = st.session_state.get("samples", None) + if samples is not None: + st.image(samples["img"], caption=samples["prompt"]) + col1, col2 = st.columns(2) + with col1: + st.download_button( + "Download full-resolution", + samples["bytes"], + file_name="generated.jpg", + mime="image/jpg", + ) + with col2: + if st.button("Continue from this image"): + # Store the generated image + new_image = samples["img"] + # Clear ALL canvas state + clear_canvas_state() + if "samples" in st.session_state: + del st.session_state["samples"] + # Set as current image + st.session_state["current_image"] = new_image + st.rerun() + + st.write(f"Seed: {samples['seed']}") + + +if __name__ == "__main__": + st.set_page_config(layout="wide") + main() diff --git a/docs/fill.md b/docs/fill.md new file mode 100644 index 00000000..c73bf1f4 --- /dev/null +++ b/docs/fill.md @@ -0,0 +1,44 @@ +## Models + +FLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless edits that integrate naturally with existing images. + +| Name | HuggingFace repo | License | sha256sum | +| ------------------- | -------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `FLUX.1 Fill [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 03e289f530df51d014f48e675a9ffa2141bc003259bf5f25d75b957e920a41ca | +| `FLUX.1 Fill [pro]` | Only available in our API. | + +## Examples + +![inpainting](../assets/docs/inpainting.png) +![outpainting](../assets/docs/outpainting.png) + +## Open-weights usage + +The weights will be downloaded automatically from HuggingFace once you start one of the demos. To download `FLUX.1 Fill [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). Alternatively, if you have downloaded the model weights manually from [here](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev), you can specify the downloaded paths via environment variables: + +```bash +export FLUX_DEV_FILL= +export AE= +``` + +For interactive sampling run + +```bash +python -m src.flux.cli_fill --loop +``` + +Or to generate a single sample run + +```bash +python -m src.flux.cli_fill \ + --img_cond_path \ + --img_cond_mask +``` + +The input_mask should be an image of the same size as the conditioning image that only contains black and white pixels; see [an example mask](../assets/cup_mask.png) for [this image](../assets/cup.png). + +We also provide an interactive streamlit demo. The demo can be run via + +```bash +streamlit run demo_st_fill.py +``` diff --git a/docs/image-variation.md b/docs/image-variation.md new file mode 100644 index 00000000..a15511d4 --- /dev/null +++ b/docs/image-variation.md @@ -0,0 +1,33 @@ +## Models + +FLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which can be used to generate image variations. +In addition, FLUX.1 Redux [pro] is available in our API and, augmenting the [dev] adapter, the API endpoint allows users to modify an image given a textual description. The feature is supported in our latest model FLUX1.1 [pro] Ultra, allowing for combining input images and text prompts to create high-quality 4-megapixel outputs with flexible aspect ratios. + +| Name | HuggingFace repo | License | sha256sum | +| --------------------------- | ----------------------------------------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `FLUX.1 Redux [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | a1b3bdcb4bdc58ce04874b9ca776d61fc3e914bb6beab41efb63e4e2694dca45 | +| `FLUX.1 Redux [pro]` | [Available in our API.](https://docs.bfl.ml/) Supports image variations. | +| `FLUX1.1 Redux [pro] Ultra` | [Available in our API.](https://docs.bfl.ml/) Supports image variations based on a text prompt. | + +## Examples + +![redux](../assets/docs/redux.png) + +## Open-weights usage + +The text-to-image base model weights and the autoencoder weights will be downloaded automatically from HuggingFace once you start the demo. To download `FLUX.1 [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). You need to manually download the adapter weights from [here](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) and specify them via an environment variable `export FLUX_REDUX=`. In general, you may specify any manually downloaded weights via environment variables: + +```bash +export FLUX_REDUX= +export FLUX_SCHNELL= +export FLUX_DEV= +export AE= +``` + +For interactive sampling run + +```bash +python -m src.flux.cli_redux --loop --name +``` + +where `name` is one of `flux-dev` or `flux-schnell`. diff --git a/docs/structural-conditioning.md b/docs/structural-conditioning.md new file mode 100644 index 00000000..4cbc16da --- /dev/null +++ b/docs/structural-conditioning.md @@ -0,0 +1,40 @@ +## Models + +Structural conditioning uses canny edge or depth detection to maintain precise control during image transformations. By preserving the original image's structure through edge or depth maps, users can make text-guided edits while keeping the core composition intact. This is particularly effective for retexturing images. We release four variations: two based on edge maps (full model and LoRA for FLUX.1 [dev]) and two based on depth maps (full model and LoRA for FLUX.1 [dev]). + +| Name | HuggingFace repo | License | sha256sum | +| ------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `FLUX.1 Canny [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 996876670169591cb412b937fbd46ea14cbed6933aef17c48a2dcd9685c98cdb | +| `FLUX.1 Depth [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 41360d1662f44ca45bc1b665fe6387e91802f53911001630d970a4f8be8dac21 | +| `FLUX.1 Canny [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 8eaa21b9c43d5e7242844deb64b8cf22ae9010f813f955ca8c05f240b8a98f7e | +| `FLUX.1 Depth [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 1938b38ea0fdd98080fa3e48beb2bedfbc7ad102d8b65e6614de704a46d8b907 | +| `FLUX.1 Canny [pro]` | [Available in our API](https://docs.bfl.ml/). | +| `FLUX.1 Depth [pro]` | [Available in our API](https://docs.bfl.ml/). | + +## Examples + +![canny](../assets/docs/canny.png) +![depth](../assets/docs/depth.png) + +## Open-weights usage + +The full model weights (`FLUX.1 Canny [dev], Flux.1 Depth [dev], FLUX.1 [dev], and the autoencoder) will be downloaded automatically from HuggingFace once you start one of the demos. To download them, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). The LoRA weights are not downloaded automatically, but can be downloaded manually [here (Canny)](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) and [here (Depth)](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora). You may specify any manually downloaded weights via environment variables: (**necessary for LoRAs**): + +```bash +export FLUX_DEV_DEPTH= +export FLUX_DEV_CANNY= +export FLUX_DEV_DEPTH_LORA= +export FLUX_DEV_CANNY_LORA= +export FLUX_REDUX= +export FLUX_SCHNELL= +export FLUX_DEV= +export AE= +``` + +For interactive sampling run + +```bash +python -m src.flux.cli_control --loop --name +``` + +where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`. diff --git a/docs/text-to-image.md b/docs/text-to-image.md new file mode 100644 index 00000000..172004ed --- /dev/null +++ b/docs/text-to-image.md @@ -0,0 +1,93 @@ +## Models + +We currently offer four text-to-image models. `FLUX1.1 [pro]` is our most capable model which can generate images at up to 4MP while maintaining an impressive generation time of only 10 seconds per sample. + +| Name | HuggingFace repo | License | sha256sum | +| ------------------------- | ------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | +| `FLUX.1 [schnell]` | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) | 9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a | +| `FLUX.1 [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 | +| `FLUX.1 [pro]` | [Available in our API](https://docs.bfl.ml/). | +| `FLUX1.1 [pro]` | [Available in our API](https://docs.bfl.ml/). | +| `FLUX1.1 [pro] Ultra/raw` | [Available in our API](https://docs.bfl.ml/). | + +## Open-weights usage + +The weights will be downloaded automatically from HuggingFace once you start one of the demos. To download `FLUX.1 [dev]`, you will need to be logged in, see [here](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). +If you have downloaded the model weights manually, you can specify the downloaded paths via environment-variables: + +```bash +export FLUX_SCHNELL= +export FLUX_DEV= +export AE= +``` + +For interactive sampling run + +```bash +python -m flux --name --loop +``` + +Or to generate a single sample run + +```bash +python -m flux --name \ + --height --width \ + --prompt "" +``` + +We also provide a streamlit demo that does both text-to-image and image-to-image. The demo can be run via + +```bash +streamlit run demo_st.py +``` + +We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo: + +```bash +python demo_gr.py --name flux-schnell --device cuda +``` + +Options: + +- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev") +- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu") +- `--offload`: Offload model to CPU when not in use +- `--share`: Create a public link to your demo + +To run the demo with the dev model and create a public link: + +```bash +python demo_gr.py --name flux-dev --share +``` + +## Diffusers integration + +`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it: + +```shell +pip install git+https://github.com/huggingface/diffusers.git +``` + +Then you can use `FluxPipeline` to run the model + +```python +import torch +from diffusers import FluxPipeline + +model_id = "black-forest-labs/FLUX.1-schnell" #you can also use `black-forest-labs/FLUX.1-dev` + +pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) +pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power + +prompt = "A cat holding a sign that says hello world" +seed = 42 +image = pipe( + prompt, + output_type="pil", + num_inference_steps=4, #use a larger number if you are using [dev] + generator=torch.Generator("cpu").manual_seed(seed) +).images[0] +image.save("flux-schnell.png") +``` + +To learn more check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation diff --git a/model_licenses/LICENSE-FLUX1-dev b/model_licenses/LICENSE-FLUX1-dev index d91cf0bc..3c87409f 100644 --- a/model_licenses/LICENSE-FLUX1-dev +++ b/model_licenses/LICENSE-FLUX1-dev @@ -1,5 +1,5 @@ FLUX.1 [dev] Non-Commercial License -Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] text-to-image AI model and its elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI model made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). +Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models, including FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA and FLUX.1 Depth [dev] LoRA, and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings: a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License. diff --git a/pyproject.toml b/pyproject.toml index ccf0f77a..fb242dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,14 +9,33 @@ requires-python = ">=3.10" license = { file = "LICENSE.md" } dynamic = ["version"] dependencies = [ + "torch == 2.5.1", + "torchvision", "einops", "fire >= 0.6.0", "huggingface-hub", "safetensors", + "sentencepiece", "transformers", "tokenizers", + "protobuf", "requests", "invisible-watermark", + "ruff == 0.6.8", +] + +[project.optional-dependencies] +streamlit = [ + "streamlit", + "streamlit-drawable-canvas", + "streamlit-keyup", +] +gradio = [ + "gradio", +] +all = [ + "flux[streamlit]", + "flux[gradio]", ] [project.scripts] diff --git a/src/flux/__init__.py b/src/flux/__init__.py index 43c365a4..dddc6a38 100644 --- a/src/flux/__init__.py +++ b/src/flux/__init__.py @@ -1,6 +1,8 @@ try: - from ._version import version as __version__ # type: ignore - from ._version import version_tuple + from ._version import ( + version as __version__, # type: ignore + version_tuple, + ) except ImportError: __version__ = "unknown (no version information available)" version_tuple = (0, 0, "unknown", "noinfo") diff --git a/src/flux/api.py b/src/flux/api.py index b08202ad..6a608840 100644 --- a/src/flux/api.py +++ b/src/flux/api.py @@ -6,7 +6,12 @@ import requests from PIL import Image -API_ENDPOINT = "https://api.bfl.ml" +API_URL = "https://api.bfl.ml" +API_ENDPOINTS = { + "flux.1-pro": "flux-pro", + "flux.1-dev": "flux-dev", + "flux.1.1-pro": "flux-pro-1.1", +} class ApiException(Exception): @@ -31,13 +36,18 @@ def __repr__(self) -> str: class ImageRequest: def __init__( self, + # api inputs prompt: str, - width: int = 1024, - height: int = 1024, - name: str = "flux.1-pro", - num_steps: int = 50, - prompt_upsampling: bool = False, + name: str = "flux.1.1-pro", + width: int | None = None, + height: int | None = None, + num_steps: int | None = None, + prompt_upsampling: bool | None = None, seed: int | None = None, + guidance: float | None = None, + interval: float | None = None, + safety_tolerance: int | None = None, + # behavior of this class validate: bool = True, launch: bool = True, api_key: str | None = None, @@ -45,46 +55,67 @@ def __init__( """ Manages an image generation request to the API. + All parameters not specified will use the API defaults. + Args: - prompt: Prompt to sample - width: Width of the image in pixel - height: Height of the image in pixel - name: Name of the model - num_steps: Number of network evaluations - prompt_upsampling: Use prompt upsampling - seed: Fix the generation seed + prompt: Text prompt for image generation. + width: Width of the generated image in pixels. Must be a multiple of 32. + height: Height of the generated image in pixels. Must be a multiple of 32. + name: Which model version to use + num_steps: Number of steps for the image generation process. + prompt_upsampling: Whether to perform upsampling on the prompt. + seed: Optional seed for reproducibility. + guidance: Guidance scale for image generation. + safety_tolerance: Tolerance level for input and output moderation. + Between 0 and 6, 0 being most strict, 6 being least strict. validate: Run input validation launch: Directly launches request api_key: Your API key if not provided by the environment Raises: - ValueError: For invalid input + ValueError: For invalid input, when `validate` ApiException: For errors raised from the API """ if validate: - if name not in ["flux.1-pro"]: + if name not in API_ENDPOINTS.keys(): raise ValueError(f"Invalid model {name}") - elif width % 32 != 0: + elif width is not None and width % 32 != 0: raise ValueError(f"width must be divisible by 32, got {width}") - elif not (256 <= width <= 1440): + elif width is not None and not (256 <= width <= 1440): raise ValueError(f"width must be between 256 and 1440, got {width}") - elif height % 32 != 0: + elif height is not None and height % 32 != 0: raise ValueError(f"height must be divisible by 32, got {height}") - elif not (256 <= height <= 1440): + elif height is not None and not (256 <= height <= 1440): raise ValueError(f"height must be between 256 and 1440, got {height}") - elif not (1 <= num_steps <= 50): + elif num_steps is not None and not (1 <= num_steps <= 50): raise ValueError(f"steps must be between 1 and 50, got {num_steps}") - + elif guidance is not None and not (1.5 <= guidance <= 5.0): + raise ValueError(f"guidance must be between 1.5 and 4, got {guidance}") + elif interval is not None and not (1.0 <= interval <= 4.0): + raise ValueError(f"interval must be between 1 and 4, got {interval}") + elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0): + raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}") + + if name == "flux.1-dev": + if interval is not None: + raise ValueError("Interval is not supported for flux.1-dev") + if name == "flux.1.1-pro": + if interval is not None or num_steps is not None or guidance is not None: + raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro") + + self.name = name self.request_json = { "prompt": prompt, "width": width, "height": height, - "variant": name, "steps": num_steps, "prompt_upsampling": prompt_upsampling, + "seed": seed, + "guidance": guidance, + "interval": interval, + "safety_tolerance": safety_tolerance, } - if seed is not None: - self.request_json["seed"] = seed + self.request_json = {key: value for key, value in self.request_json.items() if value is not None} self.request_id: str | None = None self.result: dict | None = None @@ -105,7 +136,7 @@ def request(self): if self.request_id is not None: return response = requests.post( - f"{API_ENDPOINT}/v1/image", + f"{API_URL}/v1/{API_ENDPOINTS[self.name]}", headers={ "accept": "application/json", "x-key": self.api_key, @@ -126,7 +157,7 @@ def retrieve(self) -> dict: self.request() while self.result is None: response = requests.get( - f"{API_ENDPOINT}/v1/get_result", + f"{API_URL}/v1/get_result", headers={ "accept": "application/json", "x-key": self.api_key, diff --git a/src/flux/cli.py b/src/flux/cli.py index 4e92960d..d44e459a 100644 --- a/src/flux/cli.py +++ b/src/flux/cli.py @@ -6,16 +6,16 @@ import torch import torch._inductor.config as inductor_config -from einops import rearrange from fire import Fire -from PIL import Image +from transformers import pipeline from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack -from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5 +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, save_image_without_nsfw_check -NSFW_THRESHOLD = 0.85 +CHECK_NSFW = False TORCH_COMPILE = os.getenv("TORCH_COMPILE", "0") == "1" + @dataclass class SamplingOptions: prompt: str @@ -25,37 +25,6 @@ class SamplingOptions: guidance: float seed: int | None -class CudaTimer: - """ - A static context manager class for measuring execution time of PyTorch code - using CUDA events. It synchronizes GPU operations to ensure accurate time measurements. - """ - - def __init__(self, name="", precision=5, display=False): - self.name = name - self.precision = precision - self.display = display - - def __enter__(self): - torch.cuda.synchronize() - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - self.start_event.record() - return self - - def __exit__(self, *exc): - self.end_event.record() - torch.cuda.synchronize() - # Convert from ms to s - self.elapsed_time = self.start_event.elapsed_time(self.end_event) * 1e-3 - - if self.display: - print(f"{self.name}: {self.elapsed_time:.{self.precision}f} s") - - def get_elapsed_time(self): - """Returns the elapsed time in microseconds.""" - return self.elapsed_time - def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" @@ -111,7 +80,7 @@ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: continue _, steps = prompt.split() options.num_steps = int(steps) - print(f"Setting seed to {options.num_steps}") + print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None @@ -130,11 +99,17 @@ def main( width: int = 1360, height: int = 768, seed: int | None = None, - prompt: str = "A tree", + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int | None = None, loop: bool = False, guidance: float = 3.5, + offload: bool = False, output_dir: str = "output", + add_sampling_metadata: bool = True, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a @@ -154,12 +129,13 @@ def main( guidance: guidance value used for guidance distillation add_sampling_metadata: Add the prompt to the image Exif metadata """ - device = torch.device("cuda") + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) if name not in configs: available = ", ".join(configs.keys()) raise ValueError(f"Got unknown model name: {name}, chose from {available}") + torch_device = torch.device(device) if num_steps is None: num_steps = 4 if name == "flux-schnell" else 50 @@ -179,10 +155,10 @@ def main( idx = 0 # init all components - t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512) - clip = load_clip(device=device) - model = load_flow_model(name, device=device) - ae = load_ae(name, device=device) + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) if TORCH_COMPILE: # torch._inductor.list_options() @@ -208,69 +184,67 @@ def main( if loop: opts = parse_prompt(opts) - # warmup - for _ in range(3): + + while opts is not None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + # prepare input x = get_noise( 1, opts.height, opts.width, - device=device, + device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) opts.seed = None - + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) inp = prepare(t5, clip, x, prompt=opts.prompt) timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + # denoise initial noise - # with torch.no_grad(): - # with torch.profiler.profile(on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), with_stack=True) as prof: x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) - - with torch.autocast(device_type=device.type, dtype=torch.bfloat16): + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) - fn = output_name.format(idx=idx) - print(f"Saving {fn}") - # bring into PIL format and save - x = x.clamp(-1, 1) - x = embed_watermark(x.float()) - x = rearrange(x[0], "c h w -> h w c") + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() - img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) - img.save(fn, exif=Image.Exif(), quality=95, subsampling=0) - idx += 1 + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") - with CudaTimer(display=False) as timer: - if opts.seed is None: - opts.seed = rng.seed() - print(f"Generating with seed {opts.seed}:\n{opts.prompt}") - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - opts.seed = None - inp = prepare(t5, clip, x, prompt=opts.prompt) - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=device.type, dtype=torch.bfloat16): - x = ae.decode(x) + if CHECK_NSFW: + idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt) + else: + idx = save_image_without_nsfw_check(name, output_name, idx, x, add_sampling_metadata, prompt) - print(f"Inference time: {timer.get_elapsed_time()}") + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None def app(): diff --git a/src/flux/cli_control.py b/src/flux/cli_control.py new file mode 100644 index 00000000..cd83c89e --- /dev/null +++ b/src/flux/cli_control.py @@ -0,0 +1,347 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from transformers import pipeline + +from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder +from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + lora_scale: float | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]: + changed = False + + if options is None: + return None, changed + + user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the lora scale or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/q"): + print("Quitting") + return None, changed + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.lora_scale = float(prompt) + changed = True + return options, changed + + +@torch.inference_mode() +def main( + name: str, + width: int = 1024, + height: int = 1024, + seed: int | None = None, + prompt: str = "a robot made out of gold", + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int = 50, + loop: bool = False, + guidance: float | None = None, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/robot.webp", + lora_scale: float | None = 0.85, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + assert name in [ + "flux-dev-canny", + "flux-dev-depth", + "flux-dev-canny-lora", + "flux-dev-depth-lora", + ], f"Got unknown model name: {name}" + if guidance is None: + if name in ["flux-dev-canny", "flux-dev-canny-lora"]: + guidance = 30.0 + elif name in ["flux-dev-depth", "flux-dev-depth-lora"]: + guidance = 10.0 + else: + raise NotImplementedError() + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + # set lora scale + if "lora" in name and lora_scale is not None: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(lora_scale) + + if name in ["flux-dev-depth", "flux-dev-depth-lora"]: + img_embedder = DepthImageEncoder(torch_device) + elif name in ["flux-dev-canny", "flux-dev-canny-lora"]: + img_embedder = CannyImageEncoder(torch_device) + else: + raise NotImplementedError() + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + lora_scale=lora_scale, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + if "lora" in name: + opts, changed = parse_lora_scale(opts) + if changed: + # update the lora scale: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(opts.lora_scale) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) + inp = prepare_control( + t5, + clip, + x, + prompt=opts.prompt, + ae=ae, + encoder=img_embedder, + img_cond_path=opts.img_cond_path, + ) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs and AE to CPU, load model to gpu + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + if "lora" in name: + opts, changed = parse_lora_scale(opts) + if changed: + # update the lora scale: + for _, module in model.named_modules(): + if hasattr(module, "set_scale"): + module.set_scale(opts.lora_scale) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/src/flux/cli_fill.py b/src/flux/cli_fill.py new file mode 100644 index 00000000..415c0420 --- /dev/null +++ b/src/flux/cli_fill.py @@ -0,0 +1,334 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from PIL import Image +from transformers import pipeline + +from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + img_mask_path: str + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + else: + with Image.open(img_cond_path) as img: + width, height = img.size + + if width % 32 != 0 or height % 32 != 0: + print(f"Image dimensions must be divisible by 32, got {width}x{height}") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning mask or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_mask_path = input(user_question) + + if img_mask_path.startswith("/"): + if img_mask_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_mask_path.startswith("/h"): + print(f"Got invalid command '{img_mask_path}'\n{usage}") + print(usage) + continue + + if img_mask_path == "": + break + + if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_mask_path}' does not exist or is not a valid image file") + continue + else: + with Image.open(img_mask_path) as img: + width, height = img.size + + if width % 32 != 0 or height % 32 != 0: + print(f"Image dimensions must be divisible by 32, got {width}x{height}") + continue + else: + with Image.open(options.img_cond_path) as img_cond: + img_cond_width, img_cond_height = img_cond.size + + if width != img_cond_width or height != img_cond_height: + print( + f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}" + ) + continue + + options.img_mask_path = img_mask_path + break + + return options + + +@torch.inference_mode() +def main( + seed: int | None = None, + prompt: str = "a white paper cup", + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int = 50, + loop: bool = False, + guidance: float = 30.0, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/cup.png", + img_mask_path: str = "assets/cup_mask.png", +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. This demo assumes that the conditioning image and mask have + the same shape and that height and width are divisible by 32. + + Args: + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + img_mask_path: path to conditioning mask (jpeg/png/webp + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + name = "flux-dev-fill" + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=128) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + with Image.open(img_cond_path) as img: + width, height = img.size + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + img_mask_path=img_mask_path, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + with Image.open(opts.img_cond_path) as img: + width, height = img.size + opts.height = height + opts.width = width + + opts = parse_img_mask_path(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch.device) + inp = prepare_fill( + t5, + clip, + x, + prompt=opts.prompt, + ae=ae, + img_cond_path=opts.img_cond_path, + mask_path=opts.img_mask_path, + ) + + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs and AE to CPU, load model to gpu + if offload: + t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + with Image.open(opts.img_cond_path) as img: + width, height = img.size + opts.height = height + opts.width = width + + opts = parse_img_mask_path(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/src/flux/cli_redux.py b/src/flux/cli_redux.py new file mode 100644 index 00000000..6c03435a --- /dev/null +++ b/src/flux/cli_redux.py @@ -0,0 +1,279 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from fire import Fire +from transformers import pipeline + +from flux.modules.image_embedders import ReduxImageEncoder +from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack +from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image + + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + img_cond_path: str + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Leave this field empty to do nothing " + "or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting number of steps to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + return options + + +def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: + if options is None: + return None + + user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the conditioning image or write a command starting with a slash:\n" + "- '/q' to quit" + ) + + while True: + img_cond_path = input(user_question) + + if img_cond_path.startswith("/"): + if img_cond_path.startswith("/q"): + print("Quitting") + return None + else: + if not img_cond_path.startswith("/h"): + print(f"Got invalid command '{img_cond_path}'\n{usage}") + print(usage) + continue + + if img_cond_path == "": + break + + if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + print(f"File '{img_cond_path}' does not exist or is not a valid image file") + continue + + options.img_cond_path = img_cond_path + break + + return options + + +@torch.inference_mode() +def main( + name: str = "flux-dev", + width: int = 1360, + height: int = 768, + seed: int | None = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 2.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, + img_cond_path: str = "assets/robot.webp", +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + img_cond_path: path to conditioning image (jpeg/png/webp) + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + img_embedder = ReduxImageEncoder(torch_device) + + rng = torch.Generator(device="cpu") + prompt = "" + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + img_cond_path=img_cond_path, + ) + + if loop: + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare_redux( + t5, + clip, + x, + prompt=opts.prompt, + encoder=img_embedder, + img_cond_path=opts.img_cond_path, + ) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + print(f"Done in {t1 - t0:.1f}s") + + idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt) + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + opts = parse_img_cond_path(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/src/flux/math.py b/src/flux/math.py index 24af4625..2927f460 100644 --- a/src/flux/math.py +++ b/src/flux/math.py @@ -1,15 +1,14 @@ import os import torch +from torch import Tensor from torch.nn.functional import scaled_dot_product_attention -import xformers.ops -import xformers.ops.fmha as fmha from einops import rearrange -from torch import Tensor -from triton.ops import attention as attention_triton -def compiled_xformers_flash_hopper(q, k, v): +def _compiled_xformers_flash_hopper(q, k, v): + import xformers.ops + torch_custom_op_compile = os.getenv("TORCH_CUSTOM_OP_COMPILE", "0") == "1" if torch_custom_op_compile: @@ -22,7 +21,7 @@ def compiled_xformers_flash_hopper(q, k, v): xformers_flash3 = xformers.ops.fmha.flash3.FwOp() softmax_scale = q.size(-1) ** -0.5 - return fmha.memory_efficient_attention_forward( # noqa: E731 + return xformers.ops.fmha.memory_efficient_attention_forward( # noqa: E731 q, k, v, @@ -38,16 +37,32 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: q, k = apply_rope(q, k, pe) if xformers_flash3: + if torch_sdpa or triton_attention: + print( + "Warning: xformers_flash3 is enabled, but torch_sdpa or triton_attention is also enabled. " + "Please remain only one of them." + ) + q = q.permute(0, 2, 1, 3) # B, H, S, D k = k.permute(0, 2, 1, 3) # B, H, S, D v = v.permute(0, 2, 1, 3) # B, H, S, D - x = compiled_xformers_flash_hopper(q, k, v).permute(0,2,1,3) - if torch_sdpa: + x = _compiled_xformers_flash_hopper(q, k, v).permute(0,2,1,3) + elif torch_sdpa: + if triton_attention: + print( + "Warning: torch_sdpa is enabled, but triton_attention is also enabled. " + "Please remain only one of them." + ) + x = scaled_dot_product_attention(q, k, v) - if triton_attention: + elif triton_attention: + from triton.ops import attention as attention_triton + softmax_scale = q.size(-1) ** -0.5 x = attention_triton(q, k, v, True, softmax_scale) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") return x diff --git a/src/flux/model.py b/src/flux/model.py index f33ab832..46527131 100644 --- a/src/flux/model.py +++ b/src/flux/model.py @@ -3,14 +3,31 @@ import torch from torch import Tensor, nn -from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding) - +from flux.modules.layers import ( + EmbedND, + LastLayer, + MLPEmbedder, + timestep_embedding, +) +from flux.modules.lora import LinearLora, replace_linear_with_lora + +try: + import triton_kernels + from triton_kernels import SingleStreamBlock, DoubleStreamBlock +except ImportError: + print("Triton kernels not found, using flux native implementation.") + from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock +except ModuleNotFoundError: + print("Triton kernels not found, using flux native implementation.") + from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock +except Exception as e: + print(f"Error: {e}") + from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock @dataclass class FluxParams: in_channels: int + out_channels: int vec_in_dim: int context_in_dim: int hidden_size: int @@ -34,7 +51,7 @@ def __init__(self, params: FluxParams): self.params = params self.in_channels = params.in_channels - self.out_channels = self.in_channels + self.out_channels = params.out_channels if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -110,3 +127,27 @@ def forward( img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + +class FluxLoraWrapper(Flux): + def __init__( + self, + lora_rank: int = 128, + lora_scale: float = 1.0, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.lora_rank = lora_rank + + replace_linear_with_lora( + self, + max_rank=lora_rank, + scale=lora_scale, + ) + + def set_lora_scale(self, scale: float) -> None: + for module in self.modules(): + if isinstance(module, LinearLora): + module.set_scale(scale=scale) diff --git a/src/flux/modules/conditioner.py b/src/flux/modules/conditioner.py index 7cdd8818..e60297e4 100644 --- a/src/flux/modules/conditioner.py +++ b/src/flux/modules/conditioner.py @@ -1,6 +1,5 @@ from torch import Tensor, nn -from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, - T5Tokenizer) +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class HFEmbedder(nn.Module): diff --git a/src/flux/modules/image_embedders.py b/src/flux/modules/image_embedders.py new file mode 100644 index 00000000..e7177d2f --- /dev/null +++ b/src/flux/modules/image_embedders.py @@ -0,0 +1,103 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange, repeat +from PIL import Image +from safetensors.torch import load_file as load_sft +from torch import nn +from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel + +from flux.util import print_load_warning + + +class DepthImageEncoder: + depth_model_name = "LiheYoung/depth-anything-large-hf" + + def __init__(self, device): + self.device = device + self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device) + self.processor = AutoProcessor.from_pretrained(self.depth_model_name) + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + hw = img.shape[-2:] + + img = torch.clamp(img, -1.0, 1.0) + img_byte = ((img + 1.0) * 127.5).byte() + + img = self.processor(img_byte, return_tensors="pt")["pixel_values"] + depth = self.depth_model(img.to(self.device)).predicted_depth + depth = repeat(depth, "b h w -> b 3 h w") + depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True) + + depth = depth / 127.5 - 1.0 + return depth + + +class CannyImageEncoder: + def __init__( + self, + device, + min_t: int = 50, + max_t: int = 200, + ): + self.device = device + self.min_t = min_t + self.max_t = max_t + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + assert img.shape[0] == 1, "Only batch size 1 is supported" + + img = rearrange(img[0], "c h w -> h w c") + img = torch.clamp(img, -1.0, 1.0) + img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8) + + # Apply Canny edge detection + canny = cv2.Canny(img_np, self.min_t, self.max_t) + + # Convert back to torch tensor and reshape + canny = torch.from_numpy(canny).float() / 127.5 - 1.0 + canny = rearrange(canny, "h w -> 1 1 h w") + canny = repeat(canny, "b 1 ... -> b 3 ...") + return canny.to(self.device) + + +class ReduxImageEncoder(nn.Module): + siglip_model_name = "google/siglip-so400m-patch14-384" + + def __init__( + self, + device, + redux_dim: int = 1152, + txt_in_features: int = 4096, + redux_path: str | None = os.getenv("FLUX_REDUX"), + dtype=torch.bfloat16, + ) -> None: + assert redux_path is not None, "Redux path must be provided" + + super().__init__() + + self.redux_dim = redux_dim + self.device = device if isinstance(device, torch.device) else torch.device(device) + self.dtype = dtype + + with self.device: + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) + + sd = load_sft(redux_path, device=str(device)) + missing, unexpected = self.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + + self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype) + self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name) + + def __call__(self, x: Image.Image) -> torch.Tensor: + imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True) + + _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state + + projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x))) + + return projected_x diff --git a/src/flux/modules/lora.py b/src/flux/modules/lora.py new file mode 100644 index 00000000..556027e8 --- /dev/null +++ b/src/flux/modules/lora.py @@ -0,0 +1,94 @@ +import torch +from torch import nn + + +def replace_linear_with_lora( + module: nn.Module, + max_rank: int, + scale: float = 1.0, +) -> None: + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + new_lora = LinearLora( + in_features=child.in_features, + out_features=child.out_features, + bias=child.bias, + rank=max_rank, + scale=scale, + dtype=child.weight.dtype, + device=child.weight.device, + ) + + new_lora.weight = child.weight + new_lora.bias = child.bias if child.bias is not None else None + + setattr(module, name, new_lora) + else: + replace_linear_with_lora( + module=child, + max_rank=max_rank, + scale=scale, + ) + + +class LinearLora(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + rank: int, + dtype: torch.dtype, + device: torch.device, + lora_bias: bool = True, + scale: float = 1.0, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias is not None, + device=device, + dtype=dtype, + *args, + **kwargs, + ) + + assert isinstance(scale, float), "scale must be a float" + + self.scale = scale + self.rank = rank + self.lora_bias = lora_bias + self.dtype = dtype + self.device = device + + if rank > (new_rank := min(self.out_features, self.in_features)): + self.rank = new_rank + + self.lora_A = nn.Linear( + in_features=in_features, + out_features=self.rank, + bias=False, + dtype=dtype, + device=device, + ) + self.lora_B = nn.Linear( + in_features=self.rank, + out_features=out_features, + bias=self.lora_bias, + dtype=dtype, + device=device, + ) + + def set_scale(self, scale: float) -> None: + assert isinstance(scale, float), "scalar value must be a float" + self.scale = scale + + def forward(self, input: torch.Tensor) -> torch.Tensor: + base_out = super().forward(input) + + _lora_out_B = self.lora_B(self.lora_A(input)) + lora_update = _lora_out_B * self.scale + + return base_out + lora_update diff --git a/src/flux/sampling.py b/src/flux/sampling.py index da37b49e..048b76cf 100644 --- a/src/flux/sampling.py +++ b/src/flux/sampling.py @@ -1,13 +1,16 @@ import math from typing import Callable -import time +import numpy as np import torch from einops import rearrange, repeat +from PIL import Image from torch import Tensor from .model import Flux +from .modules.autoencoder import AutoEncoder from .modules.conditioner import HFEmbedder +from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder def get_noise( @@ -46,25 +49,152 @@ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[st if isinstance(prompt, str): prompt = [prompt] - - # start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - # start.record() txt = t5(prompt) - # end.record() - # torch.cuda.synchronize() - # print(f"t5 time: {start.elapsed_time(end):.2f} ms") - if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) - # start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - # start.record() vec = clip(prompt) - # end.record() - # torch.cuda.synchronize() - # print(f"clip time: {start.elapsed_time(end):.2f} ms") + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def prepare_control( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + ae: AutoEncoder, + encoder: DepthImageEncoder | CannyImageEncoder, + img_cond_path: str, +) -> dict[str, Tensor]: + # load and encode the conditioning image + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + + width = w * 8 + height = h * 8 + img_cond = img_cond.resize((width, height), Image.LANCZOS) + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + + with torch.no_grad(): + img_cond = encoder(img_cond) + img_cond = ae.encode(img_cond) + + img_cond = img_cond.to(torch.bfloat16) + img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + return_dict = prepare(t5, clip, img, prompt) + return_dict["img_cond"] = img_cond + return return_dict + + +def prepare_fill( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + ae: AutoEncoder, + img_cond_path: str, + mask_path: str, +) -> dict[str, Tensor]: + # load and encode the conditioning image and the mask + bs, _, _, _ = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + img_cond = np.array(img_cond) + img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 + img_cond = rearrange(img_cond, "h w c -> 1 c h w") + + mask = Image.open(mask_path).convert("L") + mask = np.array(mask) + mask = torch.from_numpy(mask).float() / 255.0 + mask = rearrange(mask, "h w -> 1 1 h w") + with torch.no_grad(): + img_cond = img_cond.to(img.device) + mask = mask.to(img.device) + img_cond = img_cond * (1 - mask) + img_cond = ae.encode(img_cond) + mask = mask[:, 0, :, :] + mask = mask.to(torch.bfloat16) + mask = rearrange( + mask, + "b (h ph) (w pw) -> b (ph pw) h w", + ph=8, + pw=8, + ) + mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if mask.shape[0] == 1 and bs > 1: + mask = repeat(mask, "1 ... -> bs ...", bs=bs) + + img_cond = img_cond.to(torch.bfloat16) + img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + img_cond = torch.cat((img_cond, mask), dim=-1) + + return_dict = prepare(t5, clip, img, prompt) + return_dict["img_cond"] = img_cond.to(img.device) + return return_dict + + +def prepare_redux( + t5: HFEmbedder, + clip: HFEmbedder, + img: Tensor, + prompt: str | list[str], + encoder: ReduxImageEncoder, + img_cond_path: str, +) -> dict[str, Tensor]: + bs, _, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img_cond = Image.open(img_cond_path).convert("RGB") + with torch.no_grad(): + img_cond = encoder(img_cond) + + img_cond = img_cond.to(torch.bfloat16) + if img_cond.shape[0] == 1 and bs > 1: + img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + txt = torch.cat((txt, img_cond.to(txt)), dim=-2) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) @@ -119,13 +249,15 @@ def denoise( # sampling parameters timesteps: list[float], guidance: float = 4.0, + # extra img tokens + img_cond: Tensor | None = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) pred = model( - img=img, + img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, diff --git a/src/flux/util.py b/src/flux/util.py index 77fc76c0..f2acee1d 100644 --- a/src/flux/util.py +++ b/src/flux/util.py @@ -5,18 +5,84 @@ from einops import rearrange from huggingface_hub import hf_hub_download from imwatermark import WatermarkEncoder +from PIL import ExifTags, Image from safetensors.torch import load_file as load_sft -from flux.model import Flux, FluxParams +from flux.model import Flux, FluxLoraWrapper, FluxParams from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams from flux.modules.conditioner import HFEmbedder +def save_image_without_nsfw_check( + name: str, + output_name: str, + idx: int, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, +) -> int: + fn = output_name.format(idx=idx) + print(f"Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + + return idx + + +def save_image( + nsfw_classifier, + name: str, + output_name: str, + idx: int, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, + nsfw_threshold: float = 0.85, +) -> int: + fn = output_name.format(idx=idx) + print(f"Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < nsfw_threshold: + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + return idx + + @dataclass class ModelSpec: params: FluxParams ae_params: AutoEncoderParams ckpt_path: str | None + lora_path: str | None ae_path: str | None repo_id: str | None repo_flow: str | None @@ -29,8 +95,10 @@ class ModelSpec: repo_flow="flux1-dev.safetensors", repo_ae="ae.safetensors", ckpt_path=os.getenv("FLUX_DEV"), + lora_path=None, params=FluxParams( in_channels=64, + out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, @@ -61,8 +129,10 @@ class ModelSpec: repo_flow="flux1-schnell.safetensors", repo_ae="ae.safetensors", ckpt_path=os.getenv("FLUX_SCHNELL"), + lora_path=None, params=FluxParams( in_channels=64, + out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, @@ -88,6 +158,176 @@ class ModelSpec: shift_factor=0.1159, ), ), + "flux-dev-canny": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Canny-dev", + repo_flow="flux1-canny-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_CANNY"), + lora_path=None, + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-canny-lora": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + lora_path=os.getenv("FLUX_DEV_CANNY_LORA"), + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-depth": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Depth-dev", + repo_flow="flux1-depth-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_DEPTH"), + lora_path=None, + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-depth-lora": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + lora_path=os.getenv("FLUX_DEV_DEPTH_LORA"), + params=FluxParams( + in_channels=128, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fill": ModelSpec( + repo_id="black-forest-labs/FLUX.1-Fill-dev", + repo_flow="flux1-fill-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FILL"), + lora_path=None, + params=FluxParams( + in_channels=384, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), } @@ -102,10 +342,13 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None: print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) -def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): +def load_flow_model( + name: str, device: str | torch.device = "cuda", hf_download: bool = True, verbose: bool = False +) -> Flux: # Loading Flux print("Init model") ckpt_path = configs[name].ckpt_path + lora_path = configs[name].lora_path if ( ckpt_path is None and configs[name].repo_id is not None @@ -115,14 +358,27 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) with torch.device("meta" if ckpt_path is not None else device): - model = Flux(configs[name].params).to(torch.bfloat16) + if lora_path is not None: + model = FluxLoraWrapper(params=configs[name].params).to(torch.bfloat16) + else: + model = Flux(configs[name].params).to(torch.bfloat16) if ckpt_path is not None: print("Loading checkpoint") # load_sft doesn't support torch.device sd = load_sft(ckpt_path, device=str(device)) + sd = optionally_expand_state_dict(model, sd) missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) - print_load_warning(missing, unexpected) + if verbose: + print_load_warning(missing, unexpected) + + if configs[name].lora_path is not None: + print("Loading LoRA") + lora_sd = load_sft(configs[name].lora_path, device=str(device)) + # loading the lora params + overwriting scale values in the norms + missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True) + if verbose: + print_load_warning(missing, unexpected) return model @@ -157,6 +413,25 @@ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = return ae +def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict: + """ + Optionally expand the state dict to match the model's parameters shapes. + """ + for name, param in model.named_parameters(): + if name in state_dict: + if state_dict[name].shape != param.shape: + print( + f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}." + ) + # expand with zeros: + expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) + slices = tuple(slice(0, dim) for dim in state_dict[name].shape) + expanded_state_dict_weight[slices] = state_dict[name] + state_dict[name] = expanded_state_dict_weight + + return state_dict + + class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark