画像生成 AI 入門: Python による拡散モデルの理論と実践#

Open In Colab

Section 07. Play with Diffusion Model#

Stable Diffusion を中心とした拡散モデルを用いて、最先端の画像生成技術を実際に動かして実践していきます。

Lecture 25. InstructPix2Pix#

InstructPix2Pix [Brooks+ CVPR'23] を用いて、テキストプロンプトにのみによる画像編集を実現します。

セットアップ#

GPU が使用できるか確認#

本 Colab ノートブックを実行するために GPU ランタイムを使用していることを確認します。CPU ランタイムと比べて画像生成がより早くなります。以下の nvidia-smi コマンドが失敗する場合は再度講義資料の GPU 使用設定 のスライド説明や Google Colab の FAQ 等を参考にランタイムタイプが正しく変更されているか確認してください。

!nvidia-smi
Wed Jul 19 13:53:32 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

利用する Python ライブラリをインストール#

diffusers ライブラリをインストールすることで拡散モデルを簡単に使用できるようにします。diffusers ライブラリを動かす上で必要となるライブラリも追加でインストールします:

  • transformers: 拡散モデルにおいて核となる Transformer モデルが定義されているライブラリ

  • accelerate: transformers と連携してより高速な画像生成をサポートするライブラリ

!pip install diffusers==0.16.1
!pip install transformers accelerate
Requirement already satisfied: diffusers==0.16.1 in /usr/local/lib/python3.10/dist-packages (0.16.1)
Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (8.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (3.12.2)
Requirement already satisfied: huggingface-hub>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (0.16.4)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (6.8.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (1.22.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2.27.1)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (2023.6.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.65.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (23.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata->diffusers==0.16.1) (3.16.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (3.4)
Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)
Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.21.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)

InstructPix2Pix を扱うパイプラインを構築#

本セクションでは StableDiffusionInstructPix2PixPipeline を使用して InstructPix2Pix パイプラインの動作を確認します。

以下、InstructPix2Pix: Learning to Follow Image Editing Instructions を参考に、InstructPix2Pix の動作を追っていきます。

InstructPix2Pix は Stable Diffusion をベースにテキストで画像編集の指示が出せるモデルです。

まず準備として画像を複数生成した場合に結果を確認しやすいように、画像をグリッド上に表示する関数を以下のように定義します。この関数は 🤗 Hugging Face Stable Diffusion のブログ記事のものを利用しています。

from typing import List
from PIL import Image
from PIL.Image import Image as PilImage

def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

今回使用するサンプル画像を InstructPix2Pix の著者のレポジトリからダウンロードしてきます。以下ではミケランジェロのダビデ像の画像を使用します。

from diffusers.utils import load_image

original_image = load_image("https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg")
original_image
../_images/e9928f12a663dd297ad0e10a4f354044f0d75508d62057353fc3394c0fc6a04b.png

InstructPix2Pix が使えるように StableDiffusionInstructPix2PixPipeline で学習済みパラメータである timbrooks/instruct-pix2pix を読み込みます。

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    revision="fp16",
    safety_checker=None, # NSFW 用のフィルターを今回は使わないようにします
)
pipe = pipe.to("cuda")

# attention 計算の最適化を実施
pipe.enable_attention_slicing()
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .

InstructPix2Pix の論文でも紹介されているプロンプト turn him into cyborg を試してみます。オリジナルの画像とその編集指示プロンプトをパイプラインに入力することで、指示した通りのスタイルへと画像を編集します。

ここで image_guidance_scale はオリジナル画像をどれだけ反映するかのパラメータです。大きければ大きいほどオリジナル画像に近い画像が生成されます

edit_instruction = "turn him into cyborg"

generator = torch.Generator().manual_seed(19950815)

edited_image = pipe(
    prompt=edit_instruction,
    image=original_image,
    generator=generator,
    num_inference_steps=30,
    image_guidance_scale=1,
).images[0]

image_grid([original_image, edited_image], rows=1, cols=2)
../_images/5c639a555f1754a7ec21e4ed56c87056dca912a4b0f921796de5671b54d112f5.png

編集指示テキストの強さ s_T と オリジナル画像の強さ s_I をそれぞれ変化させたときに生成される画像を比較してみましょう。パイプラインの呼び出しを簡略化する以下の関数を使って比較してみます。

def generate_image(s_T: float, s_I: float) -> PilImage:
    generator = torch.Generator().manual_seed(19950815)

    return pipe(
        prompt=edit_instruction,
        image=original_image,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=s_T,
        image_guidance_scale=s_I
    ).images[0]

s_Ts_I をそれぞれ変化させて、画像を生成させてみます。

import itertools

s_T_list = [3, 7.5, 15]
s_I_list = [1.0, 1.2, 1.6]
combinations = list(itertools.product(s_I_list, s_T_list))

images = []
for (s_I, s_T) in combinations:
    print(f"Generate image with s_T = {s_T} and s_I = {s_I}")
    images.append(generate_image(s_T=s_T, s_I=s_I))
Generate image with s_T = 3 and s_I = 1.0
Generate image with s_T = 7.5 and s_I = 1.0
Generate image with s_T = 15 and s_I = 1.0
Generate image with s_T = 3 and s_I = 1.2
Generate image with s_T = 7.5 and s_I = 1.2
Generate image with s_T = 15 and s_I = 1.2
Generate image with s_T = 3 and s_I = 1.6
Generate image with s_T = 7.5 and s_I = 1.6
Generate image with s_T = 15 and s_I = 1.6

生成結果について確認します。

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 10))

for i, ((s_T, s_I), image) in enumerate(zip(combinations, images)):
    ax = fig.add_subplot(3, 3, i + 1)
    ax.imshow(image)
    ax.set_title(f"$s_T = {s_T}$ and $s_I = {s_I}$")
    ax.axis("off")
plt.show()
../_images/3a63abe2229991798647667114c16169580970e700f9eee39d043c8e369edf71.png
  • 横方向で見ると: 左から右へ、指示テキストの影響が強くなっていく

  • 縦方向で見ると: 上から下へ、オリジナル画像の影響が強くなっていく