画像生成 AI 入門: Python による拡散モデルの理論と実践#

Open In Colab

Section 07. Play with Diffusion Model#

Stable Diffusion を中心とした拡散モデルを用いて、最先端の画像生成技術を実際に動かして実践していきます。

Lecture 24. Prompt-to-Prompt#

Prompt-to-Prompt [Hertz+ ICLR'23] を用いて Stable Diffusion で生成した画像の編集を実現します。

セットアップ#

GPU が使用できるか確認#

本 Colab ノートブックを実行するために GPU ランタイムを使用していることを確認します。CPU ランタイムと比べて画像生成がより早くなります。以下の nvidia-smi コマンドが失敗する場合は再度講義資料の GPU 使用設定 のスライド説明や Google Colab の FAQ 等を参考にランタイムタイプが正しく変更されているか確認してください。

!nvidia-smi
Mon Jul 24 11:11:42 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

利用する Python ライブラリをインストール#

diffusers ライブラリをインストールすることで拡散モデルを簡単に使用できるようにします。diffusers ライブラリを動かす上で必要となるライブラリも追加でインストールします:

  • transformers: 拡散モデルにおいて核となる Transformer モデルが定義されているライブラリ

  • accelerate: transformers と連携してより高速な画像生成をサポートするライブラリ

!pip install diffusers==0.4.1
!pip install transformers accelerate
Requirement already satisfied: diffusers==0.4.1 in /usr/local/lib/python3.10/dist-packages (0.4.1)
Requirement already satisfied: importlib-metadata in /usr/lib/python3/dist-packages (from diffusers==0.4.1) (4.6.4)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (3.12.2)
Requirement already satisfied: huggingface-hub>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (0.16.4)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (1.22.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (2.27.1)
Requirement already satisfied: Pillow<10.0 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.4.1) (8.4.0)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.10.0->diffusers==0.4.1) (2023.6.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.10.0->diffusers==0.4.1) (4.65.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.10.0->diffusers==0.4.1) (6.0.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.10.0->diffusers==0.4.1) (4.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.10.0->diffusers==0.4.1) (23.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.4.1) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.4.1) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.4.1) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.4.1) (3.4)
Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)
Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.21.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)

Prompt-to-Prompt による Cross Attention の制御#

本セクションでは Stable Diffusion をベースに、prompt-to-prompt による cross attention の制御に関する動作を確認します。

まず準備として画像を複数生成した場合に結果を確認しやすいように、画像をグリッド上に表示する関数を以下のように定義します。この関数は 🤗 Hugging Face Stable Diffusion のブログ記事のものを利用しています。

from typing import List
from PIL import Image
from PIL.Image import Image as PilImage

def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

以下、Improving Generative Images with Instructions: Prompt-to-Prompt Image Editing with Cross Attention Control - wandb 🪄🐝 を参考に動作を追っていきます。まず初めに Stable Diffusion を構成するコンポーネントを読み込んでいきます。なお本 notebook では runwayml/stable-diffusion-v1-5 を使用します。

import torch
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel

# CLIP tokenizer と text encoder の読み込み
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch.float16)
clip = clip_model.text_model

# StableDiffusion のコンポーネントの読み込み
model_path_diffusion = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", revision="fp16", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", revision="fp16", torch_dtype=torch.float16)

# 各コンポーネントをそれぞれ GPU へ移動
device = "cuda"
unet.to(device)
vae.to(device)
clip.to(device)

print("Loaded all models")
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loaded all models

Cross attention を制御するための関数を以下のように定義していきます。

from typing import Tuple

def init_attention_weights(weight_tuples: List[Tuple[int, int]]) -> None:
    tokens_length = clip_tokenizer.model_max_length
    weights = torch.ones(tokens_length)

    for i, w in weight_tuples:
        if i < tokens_length and i >= 0:
            weights[i] = w

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.last_attn_slice_weights = weights.to(device)
        if module_name == "CrossAttention" and "attn1" in name:
            module.last_attn_slice_weights = None

次にオリジナルのプロンプトと編集後のプロンプトの差分から cross attention を制御する関数を定義します。

from difflib import SequenceMatcher
from transformers.tokenization_utils_base import BatchEncoding

def init_attention_edit(tokens: BatchEncoding, tokens_edit: BatchEncoding):

    tokens_length = clip_tokenizer.model_max_length
    mask = torch.zeros(tokens_length)
    indices_target = torch.arange(tokens_length, dtype=torch.long)
    indices = torch.zeros(tokens_length, dtype=torch.long)

    tokens = tokens.input_ids.numpy()[0]
    tokens_edit = tokens_edit.input_ids.numpy()[0]

    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
        if b0 < tokens_length:
            if name == "equal" or (name == "replace" and a1-a0 == b1-b0):
                mask[b0:b1] = 1
                indices[b0:b1] = indices_target[a0:a1]

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.last_attn_slice_mask = mask.to(device)
            module.last_attn_slice_indices = indices.to(device)
        if module_name == "CrossAttention" and "attn1" in name:
            module.last_attn_slice_mask = None
            module.last_attn_slice_indices = None

Stable Diffusion の UNet における cross attention の計算を以下の関数で差し替えます。

def init_attention_func():
    #ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
    def new_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> torch.Tensor:

        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attention_scores.softmax(dim=-1)
        # compute attention output

        if self.use_last_attn_slice:
            if self.last_attn_slice_mask is not None:
                new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
            else:
                attn_slice = self.last_attn_slice

            self.use_last_attn_slice = False

        if self.save_last_attn_slice:
            self.last_attn_slice = attn_slice
            self.save_last_attn_slice = False

        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
            attn_slice = attn_slice * self.last_attn_slice_weights
            self.use_last_attn_weights = False

        hidden_states = torch.matmul(attn_slice, value)
        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def new_sliced_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        sequence_length: int,
        dim: int
    ) -> torch.Tensor:

        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
                else:
                    attn_slice = self.last_attn_slice

                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
            module._attention = new_attention.__get__(module, type(module))

更に上記で差し替えた cross attention に細かく制御を可能にする属性値を導入していきます。

def use_last_tokens_attention(use: bool = True) -> None:
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.use_last_attn_slice = use

def use_last_tokens_attention_weights(use: bool = True) -> None:
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.use_last_attn_weights = use

def use_last_self_attention(use: bool = True) -> None:
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn1" in name:
            module.use_last_attn_slice = use

def save_last_tokens_attention(save: bool = True) -> None:
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.save_last_attn_slice = save

def save_last_self_attention(save: bool = True) -> None:
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn1" in name:
            module.save_last_attn_slice = save

上記で定義した関数を組み込んだ Prompt-to-Prompt を実現する関数を以下のように定義します。

import random

import numpy as np

from typing import Optional

from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm

@torch.no_grad()
def stable_diffusion(
    # オリジナルのプロンプト
    prompt: str ="",
    # 編集後のプロンプト
    prompt_edit: Optional[str] = None,

    # 編集後のプロンプトにおける各トークンの重みの指定
    prompt_edit_token_weights: Optional[List[Tuple[int, int]]] = None,

    # 最初のプロンプトに対する影響度合い
    # 局所的な特徴(細かな詳細やテクスチャ)を制御
    prompt_edit_tokens_start: float = 0.0,
    # 大域的な特徴(大雑把な特徴や一般的なシーン構成)を制御
    prompt_edit_tokens_end: float = 1.0,

    # オリジナルのプロンプトから生成された画像に対する影響度合い
    # 局所的な特徴を制御
    prompt_edit_spatial_start: float = 0.0,
    # 大域的な特徴を制御
    prompt_edit_spatial_end: float =1.0,

    # その他、text2image のハイパーパラメータ群
    guidance_scale: float = 7.5,
    steps: int = 50,
    seed: Optional[str] = None,
    width: int = 512,
    height: int = 512,
    init_image: Optional[PilImage] = None,
    init_image_strength: float = 0.5
) -> PilImage:
    # 編集後のプロンプトにおけるトークンの重みが指定されていない場合は
    # 空のリストで初期化しておく
    if prompt_edit_token_weights is None:
        prompt_edit_token_weights = []

    # モデル内部で画像サイズの不一致防ぐために
    # Stable Diffusion に合うよう生成画像のサイズを 64 の倍数に変更する
    width = width - width % 64
    height = height - height % 64

    # 乱数の seed が指定されていない (= None) の場合
    # seed をランダムに設定して固定する
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.cuda.manual_seed(seed)

    # ノイズスケジューラを設定する
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000
    )
    scheduler.set_timesteps(steps)

    # === image2image 条件下の場合 ===
    # 初期状態として画像が入力された場合に前処理を行う
    if init_image is not None:
        # 画像をリサイズしてテンソルにしつつ、
        # そのテンソルの形を変更: numpy (b, h, w, c) -> torch (b, c, h, w)
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        # RGB の 3 チャンネル以上ある場合(例: アルファチャンネル)、
        # 今回使用している Stable Diffusion を含め多くの拡散モデルは
        # アルファチャンネルをサポートしていないため、アルファチャンネルの
        # 不透明度を考慮して元々のチャンネルに合成
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        # 画像テンソルを GPU へ移動
        init_image = init_image.to(device)

        # 画像テンソルを元に、潜在データをサンプリング
        with torch.autocast(device):
            init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)

    # === image2image 条件下ではない場合 ===
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0

    # ガウスノイズを生成
    noise = torch.randn(init_latent.shape, generator=generator, device=device)
    # latent = noise * scheduler.init_noise_sigma
    latent = scheduler.add_noise(init_latent, noise, torch.tensor([scheduler.timesteps[t_start]], device=device)).to(device)

    # CLIP text encoder による条件付けの計算
    with torch.autocast(device):
        tokens_unconditional = clip_tokenizer(
            text="",
            padding="max_length",
            max_length=clip_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
            return_overflowing_tokens=True
        )
        embedding_unconditional = clip(
            tokens_unconditional.input_ids.to(device)
        ).last_hidden_state

        tokens_conditional = clip_tokenizer(
            text=prompt,
            padding="max_length",
            max_length=clip_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
            return_overflowing_tokens=True
        )
        embedding_conditional = clip(
            tokens_conditional.input_ids.to(device)
        ).last_hidden_state

        # プロンプト編集による条件付けの計算
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(
                text=prompt_edit,
                padding="max_length",
                max_length=clip_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
                return_overflowing_tokens=True
            )
            embedding_conditional_edit = clip(
                tokens_conditional_edit.input_ids.to(device)
            ).last_hidden_state

            init_attention_edit(tokens_conditional, tokens_conditional_edit)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]

        # === denoising step ===
        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + i

            latent_model_input = latent
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # ノイズの残差を予測
            noise_pred_uncond = unet(
                latent_model_input, t, encoder_hidden_states=embedding_unconditional
            ).sample

            # Cross attention の計算の準備
            if prompt_edit is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                # 編集されていないプロンプトに対して重みを適用する
                use_last_tokens_attention_weights()

            # プロンプトの条件を考慮したノイズ残差の予測と cross attention の計算の保存を実施
            noise_pred_cond = unet(
                latent_model_input, t, encoder_hidden_states=embedding_conditional
            ).sample

            # Cross attention の計算に変更を加える
            if prompt_edit is not None:
                t_scale = t / scheduler.num_train_timesteps
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                # 編集後のプロンプトに対して重みを適用する
                use_last_tokens_attention_weights()

                # 編集後のプロンプトの条件を考慮したノイズ残差の予測と cross attention の計算を保存
                noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional_edit).sample

            # Classifier-free guidance を適用
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latent = scheduler.step(noise_pred, t_index, latent).prev_sample

        # 潜在データのスケールを調整して VAE でデコードする
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    # 生成画像の後処理を適用
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype("uint8")
    return Image.fromarray(image)

編集対象のプロンプト内のトークンを確認するための関数を以下のようにして定義します。

def prompt_token(prompt: str, index: int) -> str:
    tokens = clip_tokenizer(
        prompt,
        padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index+1])

以下のプロンプトにおいて 2 番目のトークンは fantasy であることがわかります。これを参考に fantasy を他の単語に入れ替えたり、影響度合いを調整したりすることが Prompt-to-Prompt の利点になります。

prompt = "A fantasy landscape with a pine forest, trending on artstation"
prompt_token(prompt, index=2)
'fantasy'

では A fantasy landscape with a pine forest, trending on artstation というプロンプトに対してまずは Stable Diffusion を用いて画像を生成させてみます。

common_kwargs = {"seed": 2483964025, "width": 768}

prompt = "A fantasy landscape with a pine forest, trending on artstation"
image_origin = stable_diffusion(prompt, **common_kwargs)
image_origin
../_images/998aa95e5aac0b5b3265c0ca0d2f223eeb2fca31479a3d11d778ce8450687c72.png

次に生成した画像に対して fantasy 要素を消してみようと思います。fantasy は 2 番目のトークンであったため、2 を指定しつつ、その影響度を下げる -8 を重みとして指定します。

image_edited = stable_diffusion(
    prompt="A fantasy landscape with a pine forest, trending on artstation",
    prompt_edit_token_weights=[(2, -8)],
    **common_kwargs
)
image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

以上のように、よりリアルな森の画像が生成されました。

次に winter という単語をプロンプトに追加することで、生成画像を冬の景色にしてみましょう。

prompt =      "A fantasy landscape with a pine forest, trending on artstation"
prompt_edit = "A winter fantasy landscape with a pine forest, trending on artstation"

image_edited = stable_diffusion(
    prompt=prompt,
    prompt_edit=prompt_edit,
    **common_kwargs,
)
image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

指定したとおりに元々の生成画像に対して雪が降ったような画像を生成することができました。

次は水彩画のようなスタイルで生成画像を変更してもらおうと思います。

prompt      = "A fantasy landscape with a pine forest, trending on artstation"
prompt_edit = "A watercolor painting of a landscape with a pine forest, trending on artstation"

image_edited = stable_diffusion(
    prompt=prompt,
    prompt_edit=prompt_edit,
    **common_kwargs,
)
image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

オリジナルの画像に対して水彩画のようなスタイルの画像が生成されました。

次は霧 fog を生成画像から取り除くように編集してみます。fog の位置は prompt_token 関数を使って 9 番目であると特定できます。Prompt-to-Prompt 用に用意した stable_diffusion 関数の prompt_edit_token_weight 引数に [(fog の位置, そのトークンの重み)] の形式で、fog の影響が小さくなるように -6 を設定してみました。

prompt =      "A fantasy landscape with a pine forest, trending on artstation"
prompt_edit = "A fantasy landscape with a pine forest with fog, trending on artstation"

target_token = prompt_token(prompt_edit, 9)
print(f"Target token of the edit: {target_token}")

image_edited = stable_diffusion(
    prompt=prompt,
    prompt_edit=prompt_edit,
    prompt_edit_token_weights=[(9, -6)],
    **common_kwargs,
)
image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

元々の生成画像から霧を取り除いたような画像が生成されました。

次は霧 fog に加えて、岩 rock も取り除いてみましょう。prompt_token 関数で対象のトークンの位置を確認し、prompt_edit_token_weights にトークン位置とその重みを設定して画像を生成してみます。

prompt      = "A fantasy landscape with a pine forest, trending on artstation"
prompt_edit = "A fantasy landscape with a pine forest with fog and rocks, trending on artstation"

target_token = prompt_token(prompt_edit, 9)
print(f"Target token of the edit: {target_token}")

target_token = prompt_token(prompt_edit, 11)
print(f"Target token of the edit: {target_token}")

image_edited = stable_diffusion(
    prompt=prompt,
    prompt_edit=prompt_edit,
    prompt_edit_token_weights=[(9, -6), (11, -6)],
    **common_kwargs,
)
image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

元々の生成画像から霧と岩を取り除いたような画像が生成されました。

最後に生成画像に川を追加してみましょう。prompt_edit_spatial_end では大域的な編集をより大きな値を設定することで反映させることが可能です。

prompt      = "A fantasy landscape with a pine forest, trending on artstation"
prompt_edit = "A fantasy landscape with a pine forest and a river, trending on artstation"

image_edited = stable_diffusion(
    prompt=prompt,
    prompt_edit=prompt_edit,
    prompt_edit_spatial_end=0.8,
    **common_kwargs,
)

image_grid([image_origin, image_edited], rows=1, cols=2)
Output hidden; open in https://colab.research.google.com to view.

以上のようにして Prompt-to-Prompt の動作を確認しました。元々のプロンプトを参考に単語の入れ替えや token, attention の重み付けを変えるなどして柔軟に生成画像を編集することが可能であることを確認しました。