画像生成 AI 入門: Python による拡散モデルの理論と実践#
Section 07. Play with Diffusion Model#
Stable Diffusion を中心とした拡散モデルを用いて、最先端の画像生成技術を実際に動かして実践していきます。
Lecture 22. Attend-and-Excite#
Attend-and-Excite [Chefer+ SIGGRAPH'23] を用いて Stable Diffusion 等の生成モデルが苦手としている複数の物体や性質を指定しているプロンプトに対する忠実な画像生成を実現します。
セットアップ#
GPU が使用できるか確認#
本 Colab ノートブックを実行するために GPU ランタイムを使用していることを確認します。CPU ランタイムと比べて画像生成がより早くなります。以下の nvidia-smi
コマンドが失敗する場合は再度講義資料の GPU 使用設定
のスライド説明や Google Colab の FAQ 等を参考にランタイムタイプが正しく変更されているか確認してください。
!nvidia-smi
Sun Jul 16 14:39:33 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 64C P8 12W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
利用する Python ライブラリをインストール#
diffusers ライブラリをインストールすることで拡散モデルを簡単に使用できるようにします。diffusers ライブラリを動かす上で必要となるライブラリも追加でインストールします:
transformers: 拡散モデルにおいて核となる Transformer モデルが定義されているライブラリ
accelerate: transformers と連携してより高速な画像生成をサポートするライブラリ
!pip install diffusers==0.16.1
!pip install transformers accelerate
Collecting diffusers==0.16.1
Downloading diffusers-0.16.1-py3-none-any.whl (934 kB)
?25l ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/934.9 kB ? eta -:--:--
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 934.9/934.9 kB 36.7 MB/s eta 0:00:00
?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (8.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (3.12.2)
Collecting huggingface-hub>=0.13.2 (from diffusers==0.16.1)
Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
?25l ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/268.8 kB ? eta -:--:--
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 33.0 MB/s eta 0:00:00
?25hCollecting importlib-metadata (from diffusers==0.16.1)
Downloading importlib_metadata-6.8.0-py3-none-any.whl (22 kB)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (1.22.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2.27.1)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (2023.6.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.65.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (23.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata->diffusers==0.16.1) (3.16.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (3.4)
Installing collected packages: importlib-metadata, huggingface-hub, diffusers
Successfully installed diffusers-0.16.1 huggingface-hub-0.16.4 importlib-metadata-6.8.0
Collecting transformers
Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 54.8 MB/s eta 0:00:00
?25hCollecting accelerate
Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 244.2/244.2 kB 27.3 MB/s eta 0:00:00
?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 116.0 MB/s eta 0:00:00
?25hCollecting safetensors>=0.3.1 (from transformers)
Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 88.0 MB/s eta 0:00:00
?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)
Installing collected packages: tokenizers, safetensors, transformers, accelerate
Successfully installed accelerate-0.21.0 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.30.2
Attend-and-Excite を扱うパイプラインを構築#
本セクションでは StableDiffusionAttendAndExcitePipeline
を使用して Attend-and-Excite パイプラインの動作を確認します。
まず準備として画像を複数生成した場合に結果を確認しやすいように、画像をグリッド上に表示する関数を以下のように定義します。この関数は 🤗 Hugging Face Stable Diffusion のブログ記事のものを利用しています。
from typing import List
from PIL import Image
from PIL.Image import Image as PilImage
def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
今回もこれまで同様に事前学習済みモデルとして runwayml/stable-diffusion-v1-5
を使用します。以下、素の Stable Diffusion と Attend-and-Excite による Stable Diffusion それぞれのパイプラインを読み込みます。
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionAttendAndExcitePipeline
model_id = "runwayml/stable-diffusion-v1-5"
#
# 素の Stable Diffusion (SD) の読み込み
#
pipe_sd = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe_sd = pipe_sd.to("cuda")
#
# Stable Diffusion をベースとした Attend & Excite (AE) の読み込み
#
pipe_ae = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe_ae = pipe_ae.to("cuda")
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
入力するプロンプトと乱数の seed を指定します。今回はプロンプトとして Stable Diffusion が潜在的に苦手とする複数の物体の生成を指定してみます。
prompt = "A horse and a dog"
seed = 42
素の Stable Diffusion で画像を生成してみます。生成時に乱数の seed を固定することを忘れないようにします。
generator = torch.Generator().manual_seed(seed)
images_sd = pipe_sd(
prompt,
num_images_per_prompt=2,
generator=generator
).images
Stable Diffusion で生成した画像を確認します。プロンプトとして A horse and a dog
を指定しましたが、以下の画像ではどうでしょうか?
gen_result_sd = image_grid(images_sd, rows=1, cols=2)
gen_result_sd
次に Stable Diffusion を元にした Attend-and-Excite で画像を生成させてみます。
# `get_indices` 関数を使用して、対象のトークン(horse と dog)のインデックスを調べる
# 2 と 5 がそれぞれ horse と dog であることを確認
print(f"Indicies: {pipe_ae.get_indices(prompt)}")
# 上記で調べたトークンのインデックスを指定
token_indices = [2, 5]
generator = torch.Generator().manual_seed(seed)
# Attend-and-Excite パイプラインによって画像を生成
images_ae = pipe_ae(
prompt,
num_images_per_prompt=2,
generator=generator,
#
# Additional arguments for Attend-and-Excite
# 対象のトークンを指定
#
token_indices=token_indices,
).images
Indicies: {0: '<|startoftext|>', 1: 'a</w>', 2: 'horse</w>', 3: 'and</w>', 4: 'a</w>', 5: 'dog</w>', 6: '<|endoftext|>'}
/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Attend-and-Excite で生成した画像を確認します。プロンプトとして A horse and a dog
を指定しましたが、以下の画像ではどうでしょうか?正しく犬と馬が生成されているように見えます。
gen_result_ae = image_grid(images_ae, rows=1, cols=2)
gen_result_ae
最後に素の Stable Diffusion と Attend-and-Excite の生成結果を比較してみましょう。
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
fig = plt.figure(figsize=(20, 5))
grid = ImageGrid(
fig,
rect=111,
nrows_ncols=(1, 2),
axes_pad=0.1,
)
fig.suptitle(f"Prompt: {prompt}")
grid[0].imshow(gen_result_sd)
grid[0].axis("off")
grid[0].set_title("Stable Diffusion")
grid[1].imshow(gen_result_ae)
grid[1].axis("off")
grid[1].set_title("Stable Diffusion with Attend-and-Excite")
Text(0.5, 1.0, 'Stable Diffusion with Attend-and-Excite')