Enable Textual Inversion with Stable Diffusion Pipeline via Optimum-Intel

Introduction

Stable Diffusion (SD) is a state-of-the-art latent text-to-image diffusion model that generates photorealistic images from text. Recently, many fine-tuning techniques, such as Textual Inversion and Low-Rank Adaptation (LoRA), have been proposed to create custom Stable Diffusion pipelines for personalized image generation. We’ve already published a blog on enabling LoRA with the Stable Diffusion + ControlNet pipeline.

In this blog, we will focus on enabling pre-trained textual inversion with Stable Diffusion via Optimum-Intel. The feature is available in the latest Optimum-Intel, and documentation is available here.

Textual Inversion is a technique for capturing novel concepts from a small number of example images in a way that can later be used to control text-to-image pipelines. It does so by learning new “words” in the embedding space of the pipeline’s text encoder.
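
To make the idea concrete, here is a minimal sketch in plain Python rather than the actual text encoder code. It uses a toy vocabulary and embedding table, and a hypothetical helper `add_concept` that registers a new pseudo-word and appends its learned vector as an extra row. The vocabulary, the vector values, and the helper names are all illustrative assumptions.

```python
# Toy illustration of "learning a new word"; all names and values are made up.
vocab = {"a": 0, "photo": 1, "of": 2, "backpack": 3}
# One embedding row per token id (real SD 1.5 rows hold 768 values).
embedding_table = [
    [0.1, 0.0, 0.2],
    [0.3, 0.1, 0.0],
    [0.0, 0.2, 0.1],
    [0.5, 0.4, 0.3],
]


def add_concept(token, vector):
    """Register a new pseudo-word and append its learned embedding row."""
    if token in vocab:
        raise ValueError(f"token {token!r} already exists")
    if len(vector) != len(embedding_table[0]):
        raise ValueError("embedding width must match the table")
    token_id = len(embedding_table)  # next free id
    vocab[token] = token_id
    embedding_table.append(list(vector))
    return token_id


def embed(prompt):
    """Look up one embedding row per whitespace-separated token."""
    return [embedding_table[vocab[word]] for word in prompt.lower().split()]


new_id = add_concept("<cat-toy>", [0.9, 0.8, 0.7])
vectors = embed("a <cat-toy> backpack")
print(new_id, len(embedding_table), vectors[1])
```

The rest of the table is untouched: only one row is added, which is why a learned concept can be shipped as a very small file.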

Figure 1. Textual Inversion sample: injecting user-specific concepts into new scenes

As Figure 1 shows, you can teach new concepts to a model such as Stable Diffusion for personalized image generation using just 3-5 images.

Hugging Face Diffusers and Stable Diffusion Web UI provide useful tools and guides to train and save custom textual inversion embeddings. Pre-trained textual inversion embeddings are widely available in sd-concepts-library and civitai, and can be loaded for inference with the StableDiffusionPipeline using PyTorch as the runtime backend.

Here is an example that loads the pre-trained textual inversion embedding sd-concepts-library/cat-toy and runs inference with the PyTorch backend.

from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")

Optimum-Intel provides an interface between the Hugging Face Transformers and Diffusers libraries and the OpenVINO™ runtime to accelerate end-to-end pipelines on Intel architectures.

Figure 2. Two approaches to enable textual inversion with Stable Diffusion

As Figure 2 shows, two approaches are available to enable textual inversion with Stable Diffusion via Optimum-Intel.

Although approach 1 seems quite straightforward and does not need any code modification in Optimum-Intel, it requires re-exporting the ONNX model and then converting it to an OpenVINO™ IR model whenever the SD baseline model is merged with a new textual inversion.

Instead, we propose approach 2, which lets OVStableDiffusionPipelineBase load pre-trained textual inversion embeddings at runtime to save disk storage while keeping flexibility.

  • Save disk storage: We only need to store one SD baseline model converted to OpenVINO™ IR (e.g., SD 1.5, ~5 GB) and multiple textual inversion embeddings (~10-100 KB each), instead of multiple SD OpenVINO™ IR models with textual inversion embeddings merged (~n × 5 GB). This matters because disk storage is limited, especially for edge/client use cases.
  • Flexibility: We can quickly load (multiple) pre-trained textual inversion embeddings into the SD baseline model at runtime, which supports combining embeddings and avoids altering the baseline model.
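
The storage argument can be checked with back-of-the-envelope arithmetic. The sketch below uses the approximate sizes quoted above (about 5 GB per SD 1.5 IR model, up to about 100 KB per embedding). The count of ten concepts and the function names are illustrative assumptions.

```python
# Rough sizes taken from the figures quoted above; treat them as approximations.
BASE_MODEL_GB = 5.0      # SD 1.5 converted to OpenVINO IR
EMBEDDING_KB = 100.0     # upper end of the 10-100 KB range
KB_PER_GB = 1024 * 1024


def merged_models_gb(num_concepts):
    """Approach 1: one full IR model per merged textual inversion."""
    return num_concepts * BASE_MODEL_GB


def runtime_loading_gb(num_concepts):
    """Approach 2: one baseline IR model plus small embedding files."""
    return BASE_MODEL_GB + num_concepts * EMBEDDING_KB / KB_PER_GB


n = 10  # hypothetical number of concepts
print(f"approach 1: {merged_models_gb(n):.1f} GB")
print(f"approach 2: {runtime_loading_gb(n):.3f} GB")
```

For ten concepts this comes to roughly 50 GB for approach 1 versus just over 5 GB for approach 2.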

How to enable textual inversion at runtime?

We implemented OVTextualInversionLoaderMixin based on diffusers.loaders.TextualInversionLoaderMixin with the following features:

  • Load and parse textual embeddings saved as *.bin, *.pt, or *.safetensors files into a list of Tensors.
  • Update the tokenizer with new token ids for the new “words” and expand the vocabulary size.
  • Update the text encoder embeddings via the InsertTextEmbedding class, which is based on an OpenVINO™ ngraph transformation.

For the implementation details of OVTextualInversionLoaderMixin, please refer to here
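
The parsing step in the first bullet can be sketched as follows. This is a simplified illustration, not the Optimum-Intel code. It assumes two common file layouts: a single `{token: embedding}` entry, as saved by Diffusers training scripts, and the `name` / `string_to_param` layout used by Stable Diffusion Web UI. Plain lists stand in for tensors, and the helper name `parse_embeddings` and the `_1`, `_2` suffixes for multi-vector embeddings are assumptions made for illustration.

```python
def parse_embeddings(state_dict, token=None):
    """Return a list of (token, vector) pairs from a loaded embedding file.

    `state_dict` mimics the dict obtained after loading a *.bin, *.pt, or
    *.safetensors file; vectors are plain lists standing in for tensors.
    """
    if "string_to_param" in state_dict:
        # Web UI style layout: the name and the weights are stored separately.
        loaded_token = state_dict["name"]
        embedding = state_dict["string_to_param"]["*"]
    elif len(state_dict) == 1:
        # Diffusers style layout: a single {token: embedding} entry.
        loaded_token, embedding = next(iter(state_dict.items()))
    else:
        raise ValueError("unrecognized textual inversion file layout")

    token = token or loaded_token
    # A flat list is one vector; a list of lists holds several vectors.
    vectors = embedding if isinstance(embedding[0], list) else [embedding]
    tokens = [token] + [f"{token}_{i}" for i in range(1, len(vectors))]
    return list(zip(tokens, vectors))


single = parse_embeddings({"<cat-toy>": [0.1, 0.2, 0.3]})
multi = parse_embeddings(
    {"name": "charturnerv2", "string_to_param": {"*": [[0.1, 0.2], [0.3, 0.4]]}}
)
print(single)
print(multi)
```

Each resulting pair then needs a new token id in the tokenizer and a new row in the text encoder's embedding table.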

Here is the sample code for InsertTextEmbedding class:

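# Imports assumed for the OpenVINO Python API of this period;
# adjust the module paths to your installed version if needed.
from openvino.runtime import Type
from openvino.runtime import opset11 as ops
from openvino.runtime.passes import Matcher, MatcherPass, WrapType

# Friendly name of the Constant node that holds the token embeddings.
# Value taken from the node name shown in Figure 4; it may differ between model exports.
TEXTUAL_INVERSION_EMBEDDING_KEY = "text_model.embeddings.token_embedding.weight"


class InsertTextEmbedding(MatcherPass):
    r"""
    OpenVINO ngraph transformation for inserting pre-trained textual inversion embedding to text encoder
    """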

    def __init__(self, token_ids_and_embeddings):
        MatcherPass.__init__(self)
        self.model_changed = False
        param = WrapType("opset1.Constant")

        def callback(matcher: Matcher) -> bool:
            root = matcher.get_match_root()
            if root.get_friendly_name() == TEXTUAL_INVERSION_EMBEDDING_KEY:
                add_ti = root
                consumers = matcher.get_match_value().get_target_inputs()
                for token_id, embedding in token_ids_and_embeddings:
                    ti_weights = ops.constant(embedding, Type.f32, name=str(token_id))
                    ti_weights_unsqueeze = ops.unsqueeze(ti_weights, axes=0)
                    add_ti = ops.concat(
                        nodes=[add_ti, ti_weights_unsqueeze],
                        axis=0,
                        name=f"{TEXTUAL_INVERSION_EMBEDDING_KEY}.textual_inversion_{token_id}",
                    )

                for consumer in consumers:
                    consumer.replace_source_output(add_ti.output(0))

                # Use new operation for additional matching
                self.register_new_node(add_ti)

            # Root node wasn't replaced or changed
            return False

        self.register_matcher(Matcher(param, "InsertTextEmbedding"), callback)

The InsertTextEmbedding class uses the OpenVINO™ ngraph MatcherPass mechanism to insert a subgraph into the model. Please note that MatcherPass can only filter layers by type, so we run two phases of filtering to find the layer that matches the pre-defined key in the model:

  • Filter all Constant layers to trigger the callback function.
  • Filter by layer name using the pre-defined key “TEXTUAL_INVERSION_EMBEDDING_KEY” in the callback function.

If the root name matches the pre-defined key, we loop over all parsed pairs of textual inversion embedding and token id, and for each pair create a subgraph (Constant + Unsqueeze + Concat) from OpenVINO™ operation sets to insert into the text encoder model. Finally, we redirect the consumers of the root node's output to the last node in the subgraph.
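
The two-phase filtering and the rewiring of consumers can be mimicked without OpenVINO. The following sketch is a plain-Python analogy under stated assumptions: the `Node` class, the `insert_text_embeddings` function, and the toy graph are hypothetical stand-ins and do not use the real MatcherPass API. The key string is the Const node name reported for the SD 1.5 text encoder in this post.

```python
from dataclasses import dataclass, field

EMBEDDING_KEY = "text_model.embeddings.token_embedding.weight"


@dataclass(eq=False)  # compare nodes by identity, as graph nodes would be
class Node:
    """Hypothetical stand-in for a graph node (not an OpenVINO class)."""
    op_type: str
    name: str
    inputs: list = field(default_factory=list)


def insert_text_embeddings(nodes, token_ids_and_embeddings):
    """Append a Constant -> Unsqueeze -> Concat chain per embedding."""
    new_nodes = []
    for root in nodes:
        # Phase 1: the matcher can only select nodes by type.
        if root.op_type != "Constant":
            continue
        # Phase 2: the callback narrows the match down by name.
        if root.name != EMBEDDING_KEY:
            continue
        consumers = [n for n in nodes if root in n.inputs]
        last = root
        for token_id, _embedding in token_ids_and_embeddings:
            weights = Node("Constant", str(token_id))
            unsqueezed = Node("Unsqueeze", f"unsqueeze_{token_id}", [weights])
            concat_name = f"{EMBEDDING_KEY}.textual_inversion_{token_id}"
            last = Node("Concat", concat_name, [last, unsqueezed])
            new_nodes += [weights, unsqueezed, last]
        # Point every original consumer at the end of the new chain.
        for consumer in consumers:
            consumer.inputs = [last if i is root else i for i in consumer.inputs]
    nodes.extend(new_nodes)
    return nodes


table = Node("Constant", EMBEDDING_KEY)
other = Node("Constant", "position_embedding")
gather = Node("Gather", "token_lookup", [table])
graph = insert_text_embeddings([table, other, gather], [(49408, [0.0] * 4)])
print(gather.inputs[0].op_type, gather.inputs[0].name)
```

The unrelated Constant passes the type filter but is rejected by the name check, so only the embedding table gains the extra chain.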

Figure 3. Overview of the InsertTextEmbedding OpenVINO™ ngraph transformation

Figure 3 demonstrates the workflow of the InsertTextEmbedding OpenVINO™ ngraph transformation. The left part shows the subgraph in the SD 1.5 baseline text encoder model, where the text embedding is a Constant node with shape [49408, 768]. The first dimension matches the original tokenizer's vocabulary size (49408), and the second dimension is the feature length of each text embedding.

When we load one or more textual inversions, all textual inversion embeddings are parsed as a list of tensors with shape [768], and each textual inversion constant is unsqueezed and concatenated with the original text embeddings. The right part shows the result of applying the InsertTextEmbedding ngraph transformation to the original text encoder; the green rectangle represents the merged textual inversion subgraph.

Figure 4. Three phases of the SD 1.5 text encoder subgraph with a single textual inversion, visualized in Netron.

As Figure 4 shows, in the first phase the original text embedding (marked with a blue rectangle) is stored in the Const node “text_model.embeddings.token_embedding.weight” with shape [49408, 768]. In the second phase, the InsertTextEmbedding ngraph transformation creates a new subgraph (marked with a red rectangle). In the third phase, during model compilation, the OpenVINO™ ConstantFolding transformation folds the new subgraph into a single Const node (marked with a green rectangle) with the new shape [49409, 768].
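
The effect of the third phase can be illustrated with a toy constant-folding routine. This is a conceptual sketch only, assuming a hypothetical nested-tuple expression format and a hypothetical `fold` function. It is not how OpenVINO implements ConstantFolding, and the table is shrunk to a few rows to keep it light.

```python
def fold(expr):
    """Recursively evaluate an expression whose leaves are all constants.

    Expressions are tuples: ("const", value), ("unsqueeze", expr), or
    ("concat", expr, expr). A vector constant is a flat list; a table
    constant is a list of rows.
    """
    kind = expr[0]
    if kind == "const":
        return expr[1]
    if kind == "unsqueeze":
        return [fold(expr[1])]  # shape [D] -> [1, D]
    if kind == "concat":
        return fold(expr[1]) + fold(expr[2])  # stack along axis 0
    raise ValueError(f"unknown op {kind!r}")


def shape(table):
    """Return [rows, columns] of a list-of-rows table."""
    return [len(table), len(table[0])]


vocab_size, width = 8, 4  # toy sizes; SD 1.5 uses 49408 and 768
original = ("const", [[0.0] * width for _ in range(vocab_size)])
new_word = ("const", [1.0] * width)

# Phase 2: the inserted subgraph, still made of separate nodes.
subgraph = ("concat", original, ("unsqueeze", new_word))
# Phase 3: folding collapses it into one constant with one extra row.
folded = ("const", fold(subgraph))
print(shape(original[1]), "->", shape(folded[1]))
```

With the real sizes, the same bookkeeping turns [49408, 768] into [49409, 768] for a single one-vector textual inversion.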

Stable Diffusion Textual Inversion Sample

Here are textual inversion examples verified with the Stable Diffusion v1.5, Stable Diffusion v2.1, and Stable Diffusion XL 1.0 Base pipelines using the latest Optimum-Intel.

Setup Environment

conda create -n optimum-intel python=3.10
conda activate optimum-intel
python -m pip install "optimum-intel[openvino]"@git+https://github.com/huggingface/optimum-intel.git
python -m pip install transformers diffusers safetensors
python -m pip install "invisible-watermark>=0.2.0"

Run SD 1.5 + Cat-Toy Textual Inversion Example

from optimum.intel import OVStableDiffusionPipeline
import numpy as np

model_id = "runwayml/stable-diffusion-v1-5"
prompt = "A <cat-toy> back-pack"
np.random.seed(42)

# Run pipeline without textual inversion
pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile=False)
pipe.compile()
image1 = pipe(prompt, num_inference_steps=50).images[0]
image1.save("sd_v1.5_without_cat_toy_ti.png")

# Run pipeline with textual inversion
pipe.clear_requests()
pipe.load_textual_inversion("sd_concepts/cat-toy", "<cat-toy>")
pipe.compile()
image2 = pipe(prompt, num_inference_steps=50).images[0]
image2.save("sd_v1.5_with_cat_toy_ti.png")

Figure 5. The left image shows the generation result of the SD 1.5 baseline, while the right image shows the generation result of the SD 1.5 baseline + Cat-Toy textual inversion.

Run SD 2.1 + Midjourney 2.0 Textual Inversion Example

from optimum.intel import OVStableDiffusionPipeline
import numpy as np

model_id = "stabilityai/stable-diffusion-2-1"
prompt = "A <midjourney> style photo of an astronaut riding a horse on mars"
np.random.seed(42)

# Run pipeline without midjourney textual inversion
pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile=False, cache_dir=None)
pipe.compile()
image1 = pipe(prompt, num_inference_steps=50).images[0]
image1.save("sd_v2.1_without_midjourney_ti.png")

# Run pipeline with midjourney textual inversion
pipe.clear_requests()
pipe.load_textual_inversion("midjourney_sd_2_0", "<midjourney>")
pipe.compile()
image2 = pipe(prompt, num_inference_steps=50).images[0]
image2.save("sd_v2.1_with_midjourney_ti.png")

Figure 6. The left image shows the generation result of the SD 2.1 baseline, while the right image shows the generation result of SD 2.1 + Midjourney 2.0 textual inversion.

Run SDXL 1.0 Base + CharTurnerV2 Textual Inversion Example

from optimum.intel import OVStableDiffusionXLPipeline
import numpy as np

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a beautiful woman wearing a red jacket and black shirt, best quality, intricate details."
np.random.seed(112)

pipe = OVStableDiffusionXLPipeline.from_pretrained(model_id, export=False, compile=False, cache_dir=None)

# Run pipeline without textual inversion
pipe.compile()
image1 = pipe(prompt, num_inference_steps=50).images[0]
image1.save("sdxl_base_1.0_without_charturnerv2_ti.png")

# Run pipeline with textual inversion
pipe.clear_requests()
pipe.load_textual_inversion("./charturnerv2.pt", "charturnerv2")
pipe.compile()
image2 = pipe(prompt, num_inference_steps=50).images[0]
image2.save("sdxl_base_1.0_with_charturnerv2_ti.png")

Figure 7. The left image shows the generation result of the SDXL 1.0 Base baseline, while the right image shows the generation result of SDXL 1.0 Base + CharTurnerV2 textual inversion.

Conclusion

In this blog, we proposed loading textual inversion embeddings into the Stable Diffusion pipeline at runtime to save disk storage while keeping flexibility.

  • Implemented OVTextualInversionLoaderMixin to update the tokenizer with additional token ids and update the text encoder with the InsertTextEmbedding OpenVINO™ ngraph transformation.
  • Provided sample code to load textual inversion with SD 1.5, SD 2.1, and SDXL 1.0 Base and run inference with Optimum-Intel.

References

An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion

Optimum-Intel Text-to-Image with Textual Inversion

Hugging Face Textual Inversion