画像生成 AI 入門: Python による拡散モデルの理論と実践#

Open In Colab

Section 07. Play with Diffusion Model#

Stable Diffusion を中心とした拡散モデルを用いて、最先端の画像生成技術を実際に動かして実践していきます。

Lecture 21. DreamBooth#

DreamBooth [Ruiz+ CVPR'23] を用いて Stable Diffusion に新しい概念(コンセプト; concept)を「教える」方法を紹介します。

セットアップ#

GPU が使用できるか確認#

本 Colab ノートブックを実行するために GPU ランタイムを使用していることを確認します。CPU ランタイムと比べて画像生成がより早くなります。以下の nvidia-smi コマンドが失敗する場合は再度講義資料の GPU 使用設定 のスライド説明や Google Colab の FAQ 等を参考にランタイムタイプが正しく変更されているか確認してください。

!nvidia-smi
Sun Jul 30 02:05:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

利用する Python ライブラリをインストール#

diffusers ライブラリをインストールすることで拡散モデルを簡単に使用できるようにします。diffusers ライブラリを動かす上で必要となるライブラリも追加でインストールします:

  • transformers: 拡散モデルにおいて核となる Transformer モデルが定義されているライブラリ

  • accelerate: transformers と連携してより高速な画像生成をサポートするライブラリ

  • xformers: accelerate と同様に、Transformer モデルをより効率的に扱い高速な画像生成をサポートするライブラリ

  • bitsandbytes: 通常単精度 float32 であるところを半精度 float16 よりも少ない 8-bit でのモデルの読み込みが可能なライブラリ

!pip install diffusers==0.16.1
!pip install transformers accelerate xformers bitsandbytes
Collecting diffusers==0.16.1
  Downloading diffusers-0.16.1-py3-none-any.whl (934 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 934.9/934.9 kB 7.5 MB/s eta 0:00:00
?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (9.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (3.12.2)
Collecting huggingface-hub>=0.13.2 (from diffusers==0.16.1)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 11.3 MB/s eta 0:00:00
?25hRequirement already satisfied: importlib-metadata in /usr/lib/python3/dist-packages (from diffusers==0.16.1) (4.6.4)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (1.22.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers==0.16.1) (2.27.1)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (2023.6.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.65.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (6.0.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (4.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.16.1) (23.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2023.7.22)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.16.1) (3.4)
Installing collected packages: huggingface-hub, diffusers
Successfully installed diffusers-0.16.1 huggingface-hub-0.16.4
Collecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 19.0 MB/s eta 0:00:00
?25hCollecting accelerate
  Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 244.2/244.2 kB 30.0 MB/s eta 0:00:00
?25hCollecting xformers
  Downloading xformers-0.0.20-cp310-cp310-manylinux2014_x86_64.whl (109.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 109.1/109.1 MB 1.9 MB/s eta 0:00:00
?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.41.0-py3-none-any.whl (92.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.6/92.6 MB 9.1 MB/s eta 0:00:00
?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 100.8 MB/s eta 0:00:00
?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 70.6 MB/s eta 0:00:00
?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)
Collecting pyre-extensions==0.0.29 (from xformers)
  Downloading pyre_extensions-0.0.29-py3-none-any.whl (12 kB)
Collecting typing-inspect (from pyre-extensions==0.0.29->xformers)
  Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pyre-extensions==0.0.29->xformers) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)
Collecting mypy-extensions>=0.3.0 (from typing-inspect->pyre-extensions==0.0.29->xformers)
  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Installing collected packages: tokenizers, safetensors, bitsandbytes, mypy-extensions, typing-inspect, transformers, pyre-extensions, xformers, accelerate
Successfully installed accelerate-0.21.0 bitsandbytes-0.41.0 mypy-extensions-1.0.0 pyre-extensions-0.0.29 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.31.0 typing-inspect-0.9.0 xformers-0.0.20

DreamBooth#

本セクションでは Dreambooth fine-tuning for Stable Diffusion using d🧨ffusers を参考に、dreambooth の動作を追っていきます。

前回の講義で取り扱った Textual Inversion とは異なり、本手法はモデル全体を学習させるため学習に利用できるパラメータ数が増えることになり結果的によりよい画像を生成することが可能です。一方で学習するパラメータが増えることで学習時間も長くなります。

まず準備として画像を複数生成した場合に結果を確認しやすいように、画像をグリッド上に表示する関数を以下のように定義します。この関数は 🤗 Hugging Face Stable Diffusion のブログ記事のものを利用しています。

from typing import List
from PIL import Image
from PIL.Image import Image as PilImage

def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

新たな概念をモデルに教えるための設定#

使用する事前学習済み拡散モデルを指定します。今回は runwayml/stable-diffusion-v1-5 を選択しました。

model_id = "runwayml/stable-diffusion-v1-5"

Google drive にデータを保存する設定を行います。Files タブから google drive に用意した画像を読み込むことも可能ですが、今回は google drive をマウントし、そこから drive 上に保存した画像を読み込む方法を検討します。

import os
from google.colab import drive

# /content/drive をマウントする
DRIVE_PATH = os.path.join(os.sep, "content", "drive")
print(f"Mount the following directory: {DRIVE_PATH}")

drive.mount(DRIVE_PATH)

#
# 本 notebook 用のデータを格納するディレクトリを作成する
# まずベースとなるディレクトリとして以下のようなディレクリを作成する:
# /content/drive/MyDrive/colab-notebooks/oloso/practice
#
base_dir_path = os.path.join(DRIVE_PATH, "MyDrive", "colab-notebooks", "coloso", "practice")
#
# 次に講義用のディレクトリを作成する。今回は第 20 講なので `lecture-21` と命名する:
# /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21
#
lecture_dir_path = os.path.join(base_dir_path, "lecture-21")

#
# 今回使用する学習画像を保存するディレクトリを作成する:
# /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images
#
sample_image_dir_path = os.path.join(lecture_dir_path, "sample-images")
print(f"The images will be saved in the following path: {sample_image_dir_path}")

# 上記のディレクトリが存在しない場合は作成する
if not os.path.exists(sample_image_dir_path):
    os.makedirs(sample_image_dir_path)
Mount the following directory: /content/drive
Mounted at /content/drive
The images will be saved in the following path: /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images

学習に使用するデータを用意します。今回は textual inversion にも登場している猫のおもちゃを huggingface dataset 上からダウンロードしてきます。ここではまず Colab 上でダウンロードしたのちに、google drive に保存する方法を取ります。以下の urls に学習したい概念の画像の URL を追加してください。dreambooth は数枚の画像で学習が可能であるため、3 〜 5 枚で十分です。

urls = [
    "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/1.jpeg",
    "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/4.jpeg",
    #
    # ここに更に画像を追加することができます
    #
    # "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/5.jpeg",
    # "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/6.jpeg",
    # "https://huggingface.co/datasets/diffusers/cat_toy_example/resolve/main/7.jpeg",
]

オンラインにある画像をダウンロードする関数を以下のように定義します。この関数を使って上記の urls で指定した画像をインターネット上からダウンロードします。

import requests

def download_image(url: str) -> PilImage:
    return Image.open(requests.get(url, stream=True).raw)

for i, url in enumerate(urls):
    image = download_image(url)
    image_filepath = os.path.join(sample_image_dir_path, f"{i}.jpg")
    print(f"The image is saved in the following path: {image_filepath}")
    image.save(image_filepath)
The image is saved in the following path: /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images/0.jpg
The image is saved in the following path: /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images/1.jpg
The image is saved in the following path: /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images/2.jpg
The image is saved in the following path: /content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sample-images/3.jpg

準備した画像を確認してみます。

images: List[PilImage] = []
for file_path in os.listdir(sample_image_dir_path):
    image_filepath = os.path.join(sample_image_dir_path, file_path)
    image = Image.open(image_filepath)
    image = image.resize((512, 512))
    images.append(image)

image_grid(images, rows=1, cols=len(images))
Output hidden; open in https://colab.research.google.com to view.

学習したい概念に対する dreambooth の設定を行います。特に dreambooth では事前学習によって得られている事前知識を保存するかを指定するパラメータが存在します。

  • instance_prompt には学習させたい概念を適切に説明し、なおかつ初期化トークン (initializer token) である <cat-toy> が含まれている必要があります。

  • is_prior_preservation には概念のクラス(例: おもちゃ、犬、絵画等)が保存されることを保証したい場合に使用します。学習時間が少し伸びますが、生成品質が向上します。

  • prior_preservation_class_prompt には概念のクラスを保存するために利用するプロンプトを指定します。

instance_prompt = "<cat-toy> toy"

is_prior_preservation = False
prior_preservation_class_prompt = "a photo of a cat clay toy"

以下では事前知識の保存に関するハイパーパラメータを指定します。

  • num_class_images には知識保存するクラスの画像を指定します。

  • prior_loss_weight には知識保存するクラスに対する損失の重みを指定します。

  • prior_preservation_class_folder には知識保存する画像の格納先を指定します。

num_class_images = 12
prior_loss_weight = 0.5
prior_preservation_class_dir = os.path.join(lecture_dir_path, "class-images")
# class_data_root = prior_preservation_class_dir
# class_prompt = prior_preservation_class_prompt
import torch

from typing import Optional, TypedDict
from torch.utils.data import Dataset
from torchvision import transforms
from transformers.tokenization_utils import PreTrainedTokenizer

#
# 加工したデータセットを辞書型のデータに加工する際に
# key の定義と対応する value の型アノテーションを宣言
#
# 以下のように宣言することで、想定とは異なるデータが
# 入ってきた場合にエラーを出すことができる
#
class ExampleRequired(TypedDict):
    instance_images: torch.Tensor
    instance_prompt_ids: torch.Tensor


class Example(ExampleRequired, total=False):
    class_images: torch.Tensor
    class_prompt_ids: torch.Tensor


class DreamBoothDataset(Dataset):
    def __init__(
        self,
        instance_data_root: str,
        instance_prompt: str,
        tokenizer: PreTrainedTokenizer,
        class_data_root: Optional[str] = None,
        class_prompt: Optional[str] = None,
        image_size: int = 512,
        is_center_crop: bool = False,
    ) -> None:
        self.image_size = image_size
        self.is_center_crop = is_center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = instance_data_root
        if not os.path.exists(self.instance_data_root):
            raise ValueError(
                f"The following `instance_data_root` does not exists: "
                f"{self.instance_data_root}"
            )

        self.instance_image_paths = [
            os.path.join(self.instance_data_root, file_path)
            for file_path in os.listdir(self.instance_data_root)
        ]
        self.num_instance_images = len(self.instance_image_paths)
        self.instance_prompt = instance_prompt
        self.dataset_length = self.num_instance_images

        #
        # class preservation まわりの設定
        #
        self.class_data_root = None
        if class_data_root is not None:
            self.class_data_root = class_data_root
            os.makedirs(self.class_data_root, exist_ok=True)

            self.class_image_paths = [
                os.path.join(self.class_data_root, file_path)
                for file_path in os.listdir(self.class_data_root)
            ]
            self.num_class_images = len(self.class_image_paths)
            self.data_length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

        #
        # 前処理の設定
        #
        transform_list = [
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size) if is_center_crop else transforms.RandomCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
        self.image_transforms = transforms.Compose(transform_list)

    def __len__(self) -> int:
        return self.dataset_length

    def load_image(self, image_path: str) -> torch.Tensor:
        image_pil = Image.open(image_path)
        if not image_pil.mode == "RGB":
            image_pil = image_pil.convert("RGB")

        # 前処理を実施
        return self.image_transforms(image_pil)

    def __getitem__(self, idx: int) -> Example:
        #
        # 画像の読み込み
        #
        instance_image_path = self.instance_image_paths[idx % self.num_instance_images]
        instance_image = self.load_image(instance_image_path)
        #
        # プロンプトのトークナイズ
        #
        instance_prompt_ids = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        # 入力するデータを辞書型に詰め込む
        example = {
            "instance_images": instance_image,
            "instance_prompt_ids": instance_prompt_ids
        }

        #
        # class preservation まわりの処理
        #
        if self.class_data_root:
            class_image = self.load_image(self.class_image_paths[idx % self.num_class_images])
            class_prompt_ids = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

            example["class_images"] = class_image
            example["class_prompt_ids"] = class_prompt_ids

        return example

class preservation で使用する画像を生成します。これを使用することで、元々生成出来ていた画像を忘れずに、新たに入力する概念を学習することが可能です。

import gc
from diffusers import StableDiffusionPipeline
from tqdm.auto import tqdm

def generate_class_images() -> None:
    if not os.path.exists(prior_preservation_class_dir):
        os.makedirs(prior_preservation_class_dir)
    cur_class_images = len(os.listdir(prior_preservation_class_dir))

    if cur_class_images >= num_class_images:
        return

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    pipe.enable_attention_slicing()
    pipe.set_progress_bar_config(disable=True)

    num_new_images = num_class_images - cur_class_images
    print(f"Number of class images to sample: {num_new_images}")

    for idx in tqdm(range(0, num_new_images, 2), desc="Generating class images"):
        images = pipe(
            prompt=prior_preservation_class_prompt,
            num_images_per_prompt=2,
        ).images

        for i, image in enumerate(images):
            image_path = os.path.join(prior_preservation_class_dir, f"{cur_class_images + idx + i}.jpg")
            image.save(image_path)
            print(f"The generated image is saved in the following path: {image_path}")

    pipe = None
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

generate_class_images()

Stable Diffusion で学習されたコンポーネントをそれぞれ読み込みます。

from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

学習データを、上記で定義した DreamBoothDataset を元に作成します。

train_dataset = DreamBoothDataset(
    instance_data_root=sample_image_dir_path,
    instance_prompt=instance_prompt,
    class_data_root=prior_preservation_class_dir if is_prior_preservation else None,
    class_prompt=prior_preservation_class_prompt,
    tokenizer=tokenizer,
    image_size=vae.config.sample_size,
    is_center_crop=True,
)

学習時に付与するノイズを制御するスケジューラを定義します。

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

学習の準備を行います。ハイパーパラメータを以下のように定義します。モデルを学習して生成結果を確認したときに結果が悪かった場合は、learning_ratemax_train_steps の調整を検討してみてください。

from dataclasses import dataclass

@dataclass
class Hyperparameter(object):
    learning_rate: float = 1e-6
    max_train_steps: int = 300
    train_text_encoder: bool = False
    train_batch_size: int = 1 # 事前知識の保存を行う場合はバッチサイズを 1 にしてください
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    mixed_precision: str = "fp16"
    use_8bit_adam: bool = True # `bitsandbytes` による 8-bit 最適化を利用
    seed: int = 19950815
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 100
    output_dir_path: str = os.path.join(lecture_dir_path, "sd-dreambooth-output")

hparams = Hyperparameter()
print(hparams)
Hyperparameter(learning_rate=1e-06, max_train_steps=300, train_text_encoder=False, train_batch_size=1, gradient_accumulation_steps=1, gradient_checkpointing=True, max_grad_norm=1.0, mixed_precision='fp16', use_8bit_adam=True, seed=19950815, lr_scheduler='constant', lr_warmup_steps=100, output_dir_path='/content/drive/MyDrive/colab-notebooks/coloso/practice/lecture-21/sd-dreambooth-output')
import itertools
import math

import bitsandbytes as bnb

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
import torch.nn.functional as F

logger = get_logger(__name__)

class BatchDict(TypedDict):
    input_ids: torch.Tensor
    pixel_values: torch.Tensor


def collate_fn(examples: List[Example]) -> BatchDict:
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # concat class and instance examples for prior preservation
    if is_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = tokenizer.pad(
        {"input_ids": input_ids},
        padding="max_length",
        return_tensors="pt",
        max_length=tokenizer.model_max_length
    ).input_ids

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


def training_function(text_encoder, vae, unet):

    # 学習の再現性を確保するために乱数の seed を固定
    set_seed(hparams.seed)

    accelerator = Accelerator(
        gradient_accumulation_steps=hparams.gradient_accumulation_steps,
        mixed_precision=hparams.mixed_precision,
    )

    # 2023/07 現在 `accelerate.accumulate` で 2 つのモデルを学習する際に
    # gradient accmuleration を使用することができません。これは `accelerate` でまもなく使用できるようになるようです
    # 対象の機能が導入されたら、将来的に以下のチェックを無効にすることを検討してください:
    if hparams.train_text_encoder and hparams.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # VAE のパラメータを固定
    vae.requires_grad_(False)
    # Text Encoder のパラメータを固定するか判定
    if not hparams.train_text_encoder:
        text_encoder.requires_grad_(False)

    if hparams.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if hparams.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Colab の T5 GPU のような、16 GB 以下の GPU RAM の場合は
    # fine-tuning 時のメモリ使用量を減らすために 8-bit の Adam optimizer を使用
    if hparams.use_8bit_adam:
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if hparams.train_text_encoder
        else unet.parameters()
    )

    optimizer = optimizer_class(
        params_to_optimize,
        lr=hparams.learning_rate,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=hparams.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    lr_scheduler = get_scheduler(
        hparams.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=hparams.lr_warmup_steps * hparams.gradient_accumulation_steps,
        num_training_steps=hparams.max_train_steps * hparams.gradient_accumulation_steps,
    )

    if hparams.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # VAE と Text Encoder を GPU に移動
    # Mixed Precision Training (混合精度学習) のために、`vae` と `text_encoder` の重みを
    # 半精度 (float16) にキャストします。これらのモデルは推論にのみ使用されるため
    # 単精度 (float32) の重みである必要はありません
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.decoder.to("cpu")
    if not hparams.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # 学習用データローダーのサイズが gradient accumulation の数によって変わる可能性があるため
    # ここで再度学習ステップ数を計算し直す
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / hparams.gradient_accumulation_steps)
    num_train_epochs = math.ceil(hparams.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = hparams.train_batch_size * accelerator.num_processes * hparams.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {hparams.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {hparams.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {hparams.max_train_steps}")

    progress_bar = tqdm(range(hparams.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # 画像を潜在データへ変換
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # 潜在データへ追加するノイズを取得
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # 各画像に対してランダムなタイムステップ数を取得
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # 各タイムステップにおけるノイズの大きさに従って
                # 潜在データにノイズを追加 (拡散過程)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # 条件付のためのプロンプトからテキストベクトルを取得
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # ノイズを予測
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # 予測タイプに応じた損失を計算
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # 概念の事前知識を忘却しないための学習
                if is_prior_preservation:
                    # 追加したノイズ `noise` と 予測したノイズ`noise_pred` を
                    # 2 つの部分に分けて、それぞれの部分で別々の損失を計算
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # 追加概念に対する損失を計算
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # 事前知識に対する損失を計算
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # 上記の損失を合算
                    loss = loss + prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if hparams.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(unet.parameters(), hparams.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            # accelerator がバックグラウンドで最適化工程を実行したかを確認
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= hparams.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # 学習したモデルを元に、pipeline を構築して保存
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_id,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(hparams.output_dir_path)

accelerate を用いて Colab notebook 上で効率的な学習を開始します。

import accelerate

accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad # Colab では RAM の制約があるため勾配に関する情報を削除
    torch.cuda.empty_cache()
Launching training on one GPU.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.

学習した概念を元に画像を生成する#

上記で学習した結果を StableDiffusionPipeline で読み込んで、instance_prompt を含んだプロンプトで新たな画像を生成させてみましょう。

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler.from_pretrained(hparams.output_dir_path, subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    hparams.output_dir_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

instance_prompt として設定した <cat-toy> を含んだプロンプトを指定して画像を生成させてみます。

prompt = "A <cat-toy> backpack"

num_samples = 2
num_rows = 1

generator = torch.Generator().manual_seed(19950815)

all_images = []
for _ in range(num_rows):
    images = pipe(
        prompt=prompt,
        num_images_per_prompt=num_samples,
        generator=generator,
        num_inference_steps=25,
        guidance_scale=9
    ).images
    all_images.extend(images)

image_grid(all_images, num_rows, num_samples)
../_images/51d1c3f7e6cc484b909e6b70e5d907d8a39040cb918fe9be368041faab03061a.png