Introduction to Image Generation AI: Theory and Practice of Diffusion Models with Python

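Section 07. Play with Diffusion Model

Using diffusion models, with Stable Diffusion at the center, we run state-of-the-art image generation techniques hands-on.

Lecture 29. Safe Latent Diffusion

This lecture puts into practice Safe Latent Diffusion [Schramowski+ CVPR'23], a method for suppressing the generation of inappropriate images. Safe Latent Diffusion extends the classifier-free guidance framework with a mechanism that steers generation away from a predefined set of "inappropriate content". Note, however, that inappropriate images can still be generated depending on the random seed and the hyperparameter settings.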

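Setup

Check that a GPU is available

Make sure this Colab notebook is running on a GPU runtime, which generates images faster than a CPU runtime. If the nvidia-smi command below fails, check that the runtime type has been changed correctly by revisiting the "GPU usage settings" slides in the lecture materials or the Google Colab FAQ.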

!nvidia-smi
Wed Jul 26 07:20:52 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
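
The same check can also be scripted from Python. The following is a minimal sketch that assumes only that the nvidia-smi command is on the PATH when a GPU runtime is active; gpu_runtime_available is a hypothetical helper name introduced here, not part of any library.

```python
import shutil
import subprocess


def gpu_runtime_available() -> bool:
    """Return True if the nvidia-smi command exists and exits successfully."""
    # Hypothetical helper: it only checks that nvidia-smi runs without error.
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


print("GPU runtime detected:", gpu_runtime_available())
```

If this prints False on Colab, the runtime type is most likely still set to CPU.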

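Install the required Python libraries

Installing the diffusers library makes diffusion models easy to use. We also install the additional libraries that diffusers needs in order to run:

  • transformers: the library that defines the Transformer models that diffusion pipelines build on (for Stable Diffusion, the CLIP text encoder)

  • accelerate: a library that works together with transformers to support faster image generation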

!pip install diffusers==0.16.1
!pip install transformers accelerate
Successfully installed diffusers-0.16.1 huggingface-hub-0.16.4
Successfully installed accelerate-0.21.0 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.31.0
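
Because only diffusers is pinned to a specific version, it can be useful to record which versions actually ended up installed. The following is a minimal sketch using the standard library's importlib.metadata; installed_versions is a hypothetical helper name introduced for illustration.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in installed_versions(["diffusers", "transformers", "accelerate"]).items():
    print(f"{name}: {ver or 'not installed'}")
```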

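Build a pipeline for Safe Latent Diffusion

In this section we use StableDiffusionPipelineSafe to see how a Safe Latent Diffusion pipeline built on Stable Diffusion behaves.

As preparation, we first define a function that displays images in a grid, so that results are easy to inspect when several images are generated. The function is taken from the 🤗 Hugging Face Stable Diffusion blog post.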

from typing import List
from PIL import Image
from PIL.Image import Image as PilImage

def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
    # The number of images must fill the grid exactly
    assert len(imgs) == rows * cols

    # All images are assumed to share the size of the first one
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    # Paste the images left to right, top to bottom
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
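
The index arithmetic used for the paste position can be checked without any images. The sketch below computes only the top-left corner of each cell; grid_boxes is a hypothetical helper introduced for illustration and is not part of the blog post's code.

```python
from typing import List, Tuple


def grid_boxes(n_images: int, cols: int, w: int, h: int) -> List[Tuple[int, int]]:
    """Top-left paste position of each image, in row-major order."""
    # i % cols selects the column, i // cols selects the row.
    return [(i % cols * w, i // cols * h) for i in range(n_images)]


print(grid_boxes(4, cols=2, w=512, h=512))
# → [(0, 0), (512, 0), (0, 512), (512, 512)]
```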

Safe Latent Diffusion built on Stable Diffusion can be loaded easily with StableDiffusionPipelineSafe, as follows.

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionPipelineSafe

model_id = "runwayml/stable-diffusion-v1-5"

# Load the pipeline
pipe_safe = StableDiffusionPipelineSafe.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe_safe = pipe_safe.to("cuda")
text_encoder/model.safetensors not found
/usr/local/lib/python3.10/dist-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.
  warnings.warn(
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.

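Inspecting the safety concept that suppresses inappropriate concepts

StableDiffusionPipelineSafe has a mechanism, based on classifier-free guidance, for suppressing inappropriate concepts. The safety_concept attribute shows which inappropriate concepts the pipeline targets.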

pipe_safe.safety_concept
'an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty'
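
To make the idea of guidance away from the safety concept concrete, here is a toy sketch that works on plain Python floats instead of noise-prediction tensors. It is only loosely modelled on how Safe Latent Diffusion extends classifier-free guidance and is not the library's implementation: the function names are hypothetical, the numbers are made up, and the warm-up and momentum that the real pipeline also uses are left out.

```python
from typing import List


def cfg(eps_uncond: float, eps_text: float, guidance_scale: float) -> float:
    """Plain classifier-free guidance for a single element."""
    return eps_uncond + guidance_scale * (eps_text - eps_uncond)


def safe_cfg(
    eps_uncond: float,
    eps_text: float,
    eps_safety: float,
    guidance_scale: float,
    sld_guidance_scale: float,
    sld_threshold: float,
) -> float:
    """Classifier-free guidance with an extra term pushing away from the safety concept."""
    diff = eps_text - eps_safety
    # Assumption of this sketch: the safety scale grows with |diff| but is capped at 1.
    scale = min(abs(diff) * sld_guidance_scale, 1.0)
    # Assumption of this sketch: no safety guidance where diff reaches the threshold.
    if diff >= sld_threshold:
        scale = 0.0
    safety_term = (eps_safety - eps_uncond) * scale
    return eps_uncond + guidance_scale * ((eps_text - eps_uncond) - safety_term)


# Made-up per-element noise estimates (unconditional, prompt, safety concept).
eps_uncond: List[float] = [0.0, 0.0, 0.0, 0.0]
eps_text: List[float] = [1.0, 0.5, -0.25, 0.0]
eps_safety: List[float] = [1.0, 0.0, 0.25, 0.0]

plain = [cfg(u, t, 6.0) for u, t in zip(eps_uncond, eps_text)]
safe = [
    safe_cfg(u, t, s, 6.0, sld_guidance_scale=1000.0, sld_threshold=0.01)
    for u, t, s in zip(eps_uncond, eps_text, eps_safety)
]
print("plain CFG:", plain)
print("safe CFG: ", safe)
```

In this toy example only the third element changes: it is moved further in the direction opposite to the safety-concept estimate, while the other elements keep their plain classifier-free guidance values.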

Image generation with StableDiffusionPipelineSafe

As with the pipelines used so far, StableDiffusionPipelineSafe generates an image from a prompt, a random seed, and the classifier-free guidance parameter, as shown below. The prompt is an example given for StableDiffusionPipelineSafe in huggingface/diffusers. Note that this example prompt does not necessarily describe inappropriate content.

prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"

seed = 42
generator = torch.Generator().manual_seed(seed)
guidance_scale = 6

image_safe = pipe_safe(
    prompt=prompt,
    generator=generator,
    guidance_scale=guidance_scale,
).images[0]

image_safe
../_images/56da1024a3e4c6cf7502af8a8f607c54d485624605b491bbad77aed06af70a41.png
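
Fixing the seed is what makes a generation reproducible and lets different pipelines be compared on equal terms. The sketch below illustrates the principle with the standard library's random.Random as a stand-in for torch.Generator; sample_noise is a hypothetical helper, and the values it draws have nothing to do with the noise PyTorch would produce.

```python
import random
from typing import List


def sample_noise(seed: int, n: int = 4) -> List[float]:
    """Draw n Gaussian values from a freshly seeded generator."""
    rng = random.Random(seed)  # a new generator, analogous to re-creating torch.Generator
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


# Same seed, fresh generator: identical draws.
print(sample_noise(42) == sample_noise(42))

# Different seed: different draws.
print(sample_noise(42) == sample_noise(43))

# Reusing one generator continues its sequence, so the second draw differs.
shared = random.Random(42)
first_draw = [shared.gauss(0.0, 1.0) for _ in range(4)]
second_draw = [shared.gauss(0.0, 1.0) for _ in range(4)]
print(first_draw == second_draw)
```

This is why a generator that has already been consumed by one call should be re-created with the same seed before a run that is meant to be comparable.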

Comparison with the original Stable Diffusion

We load the original Stable Diffusion with StableDiffusionPipeline and generate an image from the same prompt as above. The pipeline is loaded in the same way as before.

from diffusers import StableDiffusionPipeline

pipe_unsafe = StableDiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)
pipe_unsafe = pipe_unsafe.to("cuda")
text_encoder/model.safetensors not found
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.

Now let's generate an image with the loaded StableDiffusionPipeline, using the prompt tried above.

generator = torch.Generator().manual_seed(seed)

image_unsafe = pipe_unsafe(
    prompt=prompt, generator=generator, guidance_scale=guidance_scale
).images[0]

image_grid([image_unsafe, image_safe], rows=1, cols=2)
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
../_images/be2f66404464642b4a920a13ae88a7c809f79779acefc5aed3839607ea65eb40.png

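The image generated by the original Stable Diffusion was caught by the NSFW (not safe for work) filter and, as the warning above states, replaced with a black image. Safe Latent Diffusion, in contrast, produced an image that the NSFW filter judged to be safe.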
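
Rather than relying on the warning message, the filter result can also be inspected programmatically. The sketch below assumes that the pipeline output carries an nsfw_content_detected list alongside images, as the Stable Diffusion pipeline outputs in diffusers do; flagged_indices is a hypothetical helper, and a SimpleNamespace stands in for a real pipeline output so that the example runs without a GPU.

```python
from types import SimpleNamespace
from typing import Any, List


def flagged_indices(output: Any) -> List[int]:
    """Indices of the images that the NSFW filter flagged."""
    # The attribute may be None when no safety checker is attached.
    flags = getattr(output, "nsfw_content_detected", None) or []
    return [i for i, flagged in enumerate(flags) if flagged]


# Stand-in for a pipeline output with two images, the first one flagged.
fake_output = SimpleNamespace(
    images=["image-0", "image-1"],
    nsfw_content_detected=[True, False],
)
print(flagged_indices(fake_output))
```

With a real pipeline, the object returned by the pipeline call would be passed in before taking .images[0].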

Image generation under multiple safety configurations

We now compare the four hyperparameter settings (safety configurations) presented in the Safe Latent Diffusion paper: WEAK, MEDIUM, STRONG, and MAX.

from diffusers.pipelines.stable_diffusion_safe import SafetyConfig

print(f"WEAK:   {SafetyConfig.WEAK}")
print(f"MEDIUM: {SafetyConfig.MEDIUM}")
print(f"STRONG: {SafetyConfig.STRONG}")
print(f"MAX:    {SafetyConfig.MAX}")
WEAK:   {'sld_warmup_steps': 15, 'sld_guidance_scale': 20, 'sld_threshold': 0.0, 'sld_momentum_scale': 0.0, 'sld_mom_beta': 0.0}
MEDIUM: {'sld_warmup_steps': 10, 'sld_guidance_scale': 1000, 'sld_threshold': 0.01, 'sld_momentum_scale': 0.3, 'sld_mom_beta': 0.4}
STRONG: {'sld_warmup_steps': 7, 'sld_guidance_scale': 2000, 'sld_threshold': 0.025, 'sld_momentum_scale': 0.5, 'sld_mom_beta': 0.7}
MAX:    {'sld_warmup_steps': 0, 'sld_guidance_scale': 5000, 'sld_threshold': 1.0, 'sld_momentum_scale': 0.5, 'sld_mom_beta': 0.7}
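
The roles of sld_warmup_steps, sld_momentum_scale, and sld_mom_beta are easier to see in a scalar toy loop. The sketch below rests on two stated assumptions rather than on the library's exact code: during the first sld_warmup_steps steps the safety term is computed and accumulates momentum but is not applied, and the momentum is an exponential moving average weighted by sld_mom_beta. applied_safety_terms is a hypothetical helper and the input values are made up.

```python
from typing import List, Sequence


def applied_safety_terms(
    raw_terms: Sequence[float],
    sld_warmup_steps: int,
    sld_momentum_scale: float,
    sld_mom_beta: float,
) -> List[float]:
    """Safety term actually applied at each denoising step (scalar toy model)."""
    momentum = 0.0
    applied: List[float] = []
    for step, raw in enumerate(raw_terms):
        # Add the accumulated momentum to the raw safety term.
        term = raw + sld_momentum_scale * momentum
        # Update the momentum as an exponential moving average.
        momentum = sld_mom_beta * momentum + (1.0 - sld_mom_beta) * term
        # During warm-up the term is tracked but not applied.
        applied.append(term if step >= sld_warmup_steps else 0.0)
    return applied


print(applied_safety_terms([1.0, 1.0, 1.0], sld_warmup_steps=2,
                           sld_momentum_scale=0.5, sld_mom_beta=0.5))
# → [0.0, 0.0, 1.4375]
```

Under these assumptions, a shorter warm-up (as in STRONG and MAX) lets safety guidance act earlier in the denoising process, and a larger momentum scale amplifies a safety signal that persists across steps.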

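By default, StableDiffusionPipelineSafe uses the SafetyConfig.MEDIUM hyperparameters. The two images generated below, one without any safety configuration arguments and one with SafetyConfig.MEDIUM passed explicitly, are therefore expected to match.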

prompt = (
    "the four horsewomen of the apocalypse, "
    "painting by tom of finland, gaston bussiere, "
    "craig mullins, j. c. leyendecker"
)

generator = torch.Generator().manual_seed(seed)

image_default = pipe_safe(
    prompt=prompt,
    generator=generator,
    guidance_scale=guidance_scale,
).images[0]

generator = torch.Generator().manual_seed(seed)

image_medium = pipe_safe(
    prompt=prompt,
    generator=generator,
    guidance_scale=guidance_scale,
    **SafetyConfig.MEDIUM,  # pass the MEDIUM hyperparameters explicitly
).images[0]

image_grid([image_default, image_medium], rows=1, cols=2)
../_images/4bbd3f14d552a4d49b953898b9515e80382dee1115feac987e7ac41f7d1333bb.png
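
The reason the default call and the call with SafetyConfig.MEDIUM behave the same is ordinary keyword-argument unpacking. The sketch below reproduces this with a hypothetical stand-in function, fake_pipe, whose defaults are assumed to equal the MEDIUM values printed earlier; it only returns the settings it would use and generates nothing.

```python
from typing import Any, Dict

# Values copied from the SafetyConfig output shown earlier in this notebook.
MEDIUM = {"sld_warmup_steps": 10, "sld_guidance_scale": 1000, "sld_threshold": 0.01,
          "sld_momentum_scale": 0.3, "sld_mom_beta": 0.4}
STRONG = {"sld_warmup_steps": 7, "sld_guidance_scale": 2000, "sld_threshold": 0.025,
          "sld_momentum_scale": 0.5, "sld_mom_beta": 0.7}


def fake_pipe(
    prompt: str,
    sld_warmup_steps: int = 10,
    sld_guidance_scale: float = 1000,
    sld_threshold: float = 0.01,
    sld_momentum_scale: float = 0.3,
    sld_mom_beta: float = 0.4,
) -> Dict[str, Any]:
    """Hypothetical stand-in for the pipeline call: returns the safety settings in effect."""
    return {
        "sld_warmup_steps": sld_warmup_steps,
        "sld_guidance_scale": sld_guidance_scale,
        "sld_threshold": sld_threshold,
        "sld_momentum_scale": sld_momentum_scale,
        "sld_mom_beta": sld_mom_beta,
    }


# ** unpacks the dictionary into keyword arguments.
print(fake_pipe("a prompt") == fake_pipe("a prompt", **MEDIUM))
print(fake_pipe("a prompt") == fake_pipe("a prompt", **STRONG))
```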

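Below are the results of generating images with Safe Latent Diffusion under the MEDIUM, STRONG, and MAX hyperparameter settings. In this example, the stronger the setting, the less likely inappropriate content appears to be generated.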

safety_configs = {
    # "WEAK": SafetyConfig.WEAK,  # excluded here because it can produce inappropriate images
    "MEDIUM": SafetyConfig.MEDIUM,
    "STRONG": SafetyConfig.STRONG,
    "MAX": SafetyConfig.MAX,
}

generated_images = []
for config_type, safety_config in safety_configs.items():
    print(f"Generating images based on the following safety config: {config_type} -> {safety_config}")

    generator = torch.Generator().manual_seed(seed)
    generated_image = pipe_safe(
        prompt=prompt,
        generator=generator,
        guidance_scale=guidance_scale,
        **safety_config,
    ).images[0]
    generated_images.append(generated_image)

image_grid(generated_images, rows=1, cols=len(generated_images))
Generating images based on the following safety config: MEDIUM -> {'sld_warmup_steps': 10, 'sld_guidance_scale': 1000, 'sld_threshold': 0.01, 'sld_momentum_scale': 0.3, 'sld_mom_beta': 0.4}
Generating images based on the following safety config: STRONG -> {'sld_warmup_steps': 7, 'sld_guidance_scale': 2000, 'sld_threshold': 0.025, 'sld_momentum_scale': 0.5, 'sld_mom_beta': 0.7}
Generating images based on the following safety config: MAX -> {'sld_warmup_steps': 0, 'sld_guidance_scale': 5000, 'sld_threshold': 1.0, 'sld_momentum_scale': 0.5, 'sld_mom_beta': 0.7}
../_images/66d60fb2ae228a6d5b384e7e1a108a57b17c1118ebfbf75a13309eb6520209d7.png