Introduction to Image Generation AI: Theory and Practice of Diffusion Models with Python#
Section 07. Play with Diffusion Model#
Using diffusion models, centered on Stable Diffusion, we run state-of-the-art image generation techniques hands-on.
Lecture 28. LoRA#
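Using LoRA [Hu+ ICLR'22], we make it possible to learn a variety of styles in a memory-efficient way while leaving the original Stable Diffusion weights unchanged. Below, we follow the 🤗 Low-Rank Adaptation of Large Language Models (LoRA) guide step by step.
As of 2023/07/25, LoRA in huggingface/diffusers supports only the attention layers of UNet2DConditionModel. Fine-tuning the text encoder with LoRA is also supported, in part, for DreamBooth. Fine-tuning the DreamBooth text encoder generally gives better results, but it may increase the computational cost.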
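To make the idea concrete, here is a minimal pure-Python sketch of the low-rank update described in the LoRA paper: the frozen weight stays untouched, and only two small matrices (called `down` and `up` here) would be trained. The matrix sizes, the rank of 2, the `alpha` scale, and the helper names `matmul` and `lora_forward` are illustrative assumptions, not the diffusers implementation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (n x m) and b (m x p)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_forward(x: Matrix, w: Matrix, down: Matrix, up: Matrix,
                 alpha: float, rank: int) -> Matrix:
    """Return x @ W + (alpha / rank) * (x @ down) @ up."""
    base = matmul(x, w)                   # frozen pretrained path
    delta = matmul(matmul(x, down), up)   # low-rank path (the only trainable part)
    scale = alpha / rank
    return [[b + scale * d for b, d in zip(b_row, d_row)]
            for b_row, d_row in zip(base, delta)]


# Hypothetical sizes: 1 x 4 input, 4 x 3 frozen weight, rank-2 adapter.
x = [[1.0, 2.0, 3.0, 4.0]]
w = [[0.1 * (i + j) for j in range(3)] for i in range(4)]
down = [[0.5, -0.5], [0.25, 0.0], [0.0, 0.25], [-0.5, 0.5]]
up_zero = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # zero init: the adapter starts as a no-op
up_trained = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # stand-in for values after training

y_base = matmul(x, w)
y_init = lora_forward(x, w, down, up_zero, alpha=2.0, rank=2)
y_tuned = lora_forward(x, w, down, up_trained, alpha=2.0, rank=2)
print(y_base, y_init, y_tuned)
```

With `up` initialized to zeros the output equals that of the frozen model, so training starts from the pretrained behavior and only gradually departs from it.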
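Setup#
Check that a GPU is available#
Make sure this Colab notebook is running on a GPU runtime; image generation is much faster than on a CPU runtime. If the nvidia-smi command below fails, check that the runtime type has been changed correctly, referring to the GPU settings slide in the lecture materials or the Google Colab FAQ.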
!nvidia-smi
Sat Jul 29 02:49:30 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 42C P8 10W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
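The same check can also be done from Python. The following is a small sketch using only the standard library; the helper name `gpu_available` is hypothetical, and it only tests whether the `nvidia-smi` command exists and exits successfully.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if `nvidia-smi` is on PATH and exits with status 0."""
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


if not gpu_available():
    print("No GPU detected: switch the Colab runtime type to GPU and retry.")
```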
Clone and install the repository from huggingface/diffusers#
This time we use two Python scripts provided in the examples of huggingface/diffusers.
To use these scripts, clone the huggingface/diffusers repository, move into the cloned directory, and install the base dependencies with pip.
!git clone https://github.com/huggingface/diffusers.git
%cd /content/diffusers
!pip install .
Cloning into 'diffusers'...
remote: Enumerating objects: 32512, done.
remote: Total 32512 (delta 0), reused 0 (delta 0), pack-reused 32512
Receiving objects: 100% (32512/32512), 20.83 MiB | 22.96 MiB/s, done.
Resolving deltas: 100% (23967/23967), done.
/content/diffusers
Processing /content/diffusers
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: importlib-metadata in /usr/lib/python3/dist-packages (from diffusers==0.20.0.dev0) (4.6.4)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers==0.20.0.dev0) (3.12.2)
Collecting huggingface-hub>=0.13.2 (from diffusers==0.20.0.dev0)
Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 4.8 MB/s eta 0:00:00
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers==0.20.0.dev0) (1.22.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers==0.20.0.dev0) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers==0.20.0.dev0) (2.27.1)
Collecting safetensors>=0.3.1 (from diffusers==0.20.0.dev0)
Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 31.5 MB/s eta 0:00:00
Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from diffusers==0.20.0.dev0) (9.4.0)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.20.0.dev0) (2023.6.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.20.0.dev0) (4.65.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.20.0.dev0) (6.0.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.20.0.dev0) (4.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.13.2->diffusers==0.20.0.dev0) (23.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.20.0.dev0) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.20.0.dev0) (2023.7.22)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.20.0.dev0) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers==0.20.0.dev0) (3.4)
Building wheels for collected packages: diffusers
Building wheel for diffusers (pyproject.toml) ... done
Created wheel for diffusers: filename=diffusers-0.20.0.dev0-py3-none-any.whl size=1321056 sha256=75554b5a2ce81961996250c8c01d1f9c2db76ea1310abb74db240027c0be84e7
Stored in directory: /tmp/pip-ephem-wheel-cache-pn2f90cg/wheels/95/c5/3b/e1b4269f8a2584de57e75f949a185b48fc4144e9a91fc9965a
Successfully built diffusers
Installing collected packages: safetensors, huggingface-hub, diffusers
Successfully installed diffusers-0.20.0.dev0 huggingface-hub-0.16.4 safetensors-0.3.1
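After installation, it can be useful to confirm which versions were actually picked up. The sketch below uses the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical, and the package names are the ones that appear in the install log above.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


for name in ("diffusers", "huggingface-hub", "safetensors"):
    print(name, installed_version(name))
```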
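As preparation, we define the following function, which displays images in a grid so that the results are easy to check when several images are generated. The function is taken from the 🤗 Hugging Face Stable Diffusion blog post.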
from typing import List

from PIL import Image
from PIL.Image import Image as PilImage


def image_grid(imgs: List[PilImage], rows: int, cols: int) -> PilImage:
    # The number of images must fill the grid exactly.
    assert len(imgs) == rows * cols
    # All images are assumed to share the size of the first one.
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Fill the grid row by row, left to right.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
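The placement logic in `image_grid` boils down to mapping an image index to the top-left corner of its cell. The following sketch isolates that arithmetic without depending on Pillow; the helper name `grid_positions` and the 512 x 512 cell size are assumptions chosen for illustration.

```python
from typing import List, Tuple


def grid_positions(n: int, cols: int, w: int, h: int) -> List[Tuple[int, int]]:
    """Top-left paste coordinates for n images laid out row by row."""
    return [(i % cols * w, i // cols * h) for i in range(n)]


# Six 512 x 512 images in a 2 x 3 grid.
positions = grid_positions(n=6, cols=3, w=512, h=512)
print(positions)
# → [(0, 0), (512, 0), (1024, 0), (0, 512), (512, 512), (1024, 512)]
```

The column index is `i % cols` and the row index is `i // cols`, so images fill each row from left to right before moving down to the next row.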
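LoRA tuning for Text-to-Image#
Fine-tuning a model with LoRA is often called LoRA tuning, whereas conventional fine-tuning of the entire model is called full fine-tuning.
Full fine-tuning of a diffusion model such as Stable Diffusion, which has on the order of a billion parameters, is slow and expensive. LoRA tuning makes fine-tuning a diffusion model far easier and faster. LoRA runs in as little as 11 GB of GPU RAM on a Colab GPU, without relying on GPU memory-saving techniques such as efficient attention mechanisms or 8-bit optimizers.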
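Much of the saving comes from how few parameters are trained. As a rough sketch, a single linear layer with a `d_in` x `d_out` weight has `d_in * d_out` parameters under full fine-tuning, while a rank-`r` LoRA adapter adds only `r * (d_in + d_out)`. The 768 x 768 layer size and the rank of 4 below are assumptions chosen for illustration, not values read from the Stable Diffusion configuration.

```python
def full_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when the whole weight matrix is updated."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters of a rank-`rank` adapter (down: d_in x r, up: r x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out, rank = 768, 768, 4  # hypothetical layer size and rank
full = full_params(d_in, d_out)
lora = lora_params(d_in, d_out, rank)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
# → full: 589,824  lora: 6,144  ratio: 1.04%
```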
Installing dependencies#
In this section we use examples/text_to_image/train_text_to_image_lora.py to try LoRA tuning of a text-to-image model. Here we install the libraries that this script needs from the corresponding examples/text_to_image/requirements.txt.
%cd /content/diffusers/examples/text_to_image
!pip install -r requirements.txt
/content/diffusers/examples/text_to_image
Collecting accelerate>=0.16.0 (from -r requirements.txt (line 1))
Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 244.2/244.2 kB 4.2 MB/s eta 0:00:00
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.15.2+cu118)
Collecting transformers>=4.25.1 (from -r requirements.txt (line 3))
Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 18.1 MB/s eta 0:00:00
Collecting datasets (from -r requirements.txt (line 4))
Downloading datasets-2.14.1-py3-none-any.whl (492 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 492.4/492.4 kB 23.7 MB/s eta 0:00:00
Collecting ftfy (from -r requirements.txt (line 5))
Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.1/53.1 kB 5.8 MB/s eta 0:00:00
Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 6)) (2.12.3)
Requirement already satisfied: Jinja2 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 7)) (3.1.2)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (23.1)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (5.9.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (6.0.1)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (2.0.1+cu118)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->-r requirements.txt (line 2)) (2.27.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->-r requirements.txt (line 2)) (9.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.1)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (16.0.6)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (0.16.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (2022.10.31)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers>=4.25.1->-r requirements.txt (line 3))
Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 42.9 MB/s eta 0:00:00
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (0.3.1)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (4.65.0)
Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->-r requirements.txt (line 4)) (9.0.0)
Collecting dill<0.3.8,>=0.3.0 (from datasets->-r requirements.txt (line 4))
Downloading dill-0.3.7-py3-none-any.whl (115 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 12.6 MB/s eta 0:00:00
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->-r requirements.txt (line 4)) (1.5.3)
Collecting xxhash (from datasets->-r requirements.txt (line 4))
Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 212.5/212.5 kB 24.2 MB/s eta 0:00:00
Collecting multiprocess (from datasets->-r requirements.txt (line 4))
Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 15.2 MB/s eta 0:00:00
Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets->-r requirements.txt (line 4)) (2023.6.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->-r requirements.txt (line 4)) (3.8.5)
Requirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy->-r requirements.txt (line 5)) (0.2.6)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.4.0)
Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.56.2)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (2.17.3)
Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.0.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (3.4.4)
Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (3.20.3)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (67.7.2)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (0.7.1)
Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (2.3.6)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (0.41.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2->-r requirements.txt (line 7)) (2.1.3)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (23.1.0)
Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (2.0.12)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (6.0.4)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (4.0.2)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (1.9.2)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (1.4.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->-r requirements.txt (line 4)) (1.3.1)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (5.3.1)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (0.3.0)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (1.16.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (4.9)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard->-r requirements.txt (line 6)) (1.3.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (2023.7.22)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (3.4)
Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->-r requirements.txt (line 4)) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->-r requirements.txt (line 4)) (2022.7.1)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (0.5.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard->-r requirements.txt (line 6)) (3.2.2)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (1.3.0)
Installing collected packages: tokenizers, xxhash, ftfy, dill, multiprocess, transformers, datasets, accelerate
Successfully installed accelerate-0.21.0 datasets-2.14.1 dill-0.3.7 ftfy-6.1.1 multiprocess-0.70.15 tokenizers-0.13.3 transformers-4.31.0 xxhash-3.2.0
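This time, let's fine-tune stable-diffusion-v1-5 on Pokémon BLIP captions and generate Pokémon of our own.
We prepare the environment variables used when running the LoRA tuning script as follows:
MODEL_NAME: set to stable-diffusion-v1-5, the model used here
DATASET_NAME: set to lambdalabs/pokemon-blip-captions, the dataset to train on
OUTPUT_DIR: set to the location where the training results are saved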
%env MODEL_NAME=runwayml/stable-diffusion-v1-5
%env DATASET_NAME=lambdalabs/pokemon-blip-captions
%env OUTPUT_DIR=/sddata/finetune/lora/pokemon
env: MODEL_NAME=runwayml/stable-diffusion-v1-5
env: DATASET_NAME=lambdalabs/pokemon-blip-captions
env: OUTPUT_DIR=/sddata/finetune/lora/pokemon
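If these settings need to be read from Python code rather than from the shell, a check like the following can catch a missing variable early. This is a sketch under stated assumptions: the helper name `read_training_env` is hypothetical, and the example dictionary simply mirrors the three values set above.

```python
import os
from typing import Dict, Mapping

REQUIRED = ("MODEL_NAME", "DATASET_NAME", "OUTPUT_DIR")


def read_training_env(env: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Return the required settings, failing loudly if any is missing or empty."""
    missing = [key for key in REQUIRED if not env.get(key)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return {key: env[key] for key in REQUIRED}


# Example values mirroring the %env cells above.
example_env = {
    "MODEL_NAME": "runwayml/stable-diffusion-v1-5",
    "DATASET_NAME": "lambdalabs/pokemon-blip-captions",
    "OUTPUT_DIR": "/sddata/finetune/lora/pokemon",
}
config = read_training_env(example_env)
print(config["OUTPUT_DIR"])
```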
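Running LoRA tuning#
We are now ready to start LoRA tuning. The script accepts many options, as shown below, and a few of them deserve attention:
--max_train_steps: the number of training steps. The reference setting is 15,000 steps, which takes roughly 6 hours of training on Colab. In this exercise we set it to 10 just to confirm that everything works.
--learning_rate: LoRA tolerates a relatively high learning rate. The default is 1e-4, whereas ordinary fine-tuning typically uses 1e-5 to 1e-6.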
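Two consequences of these settings are worth working out by hand. With a per-device batch size of 1, 4 gradient accumulation steps, and a single process, each optimization step covers 4 images, so 10 steps see only 40 images. The second function sketches a plain cosine decay without warmup; that this matches what the cosine scheduler option does in detail is an assumption, and the function names are hypothetical.

```python
import math


def effective_batch_size(per_device: int, accumulation: int,
                         num_processes: int = 1) -> int:
    """Number of images contributing to one optimizer update."""
    return per_device * accumulation * num_processes


def cosine_lr(step: int, total_steps: int, base_lr: float) -> float:
    """Half-cosine decay from base_lr at step 0 to 0 at total_steps (no warmup)."""
    progress = step / total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


batch = effective_batch_size(per_device=1, accumulation=4)
images_seen = batch * 10  # 10 optimization steps
print(batch, images_seen)

for step in (0, 5, 10):
    print(step, cosine_lr(step, total_steps=10, base_lr=1e-4))
```

Under this schedule the learning rate reaches zero at the final step, so a very short run spends most of its steps at a reduced learning rate.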
!accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path="${MODEL_NAME}" \
--dataset_name="${DATASET_NAME}" \
--dataloader_num_workers=8 \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=10 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="${OUTPUT_DIR}" \
--checkpointing_steps=500 \
--validation_prompt="A pokemon with blue eyes." \
--seed=1337
2023-07-29 02:57:29.853137: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
The following values were not passed to `accelerate launch` and had defaults used instead:
`--num_processes` was set to a value of `1`
`--num_machines` was set to a value of `1`
`--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
2023-07-29 02:57:35.909420: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
07/29/2023 02:57:38 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: fp16
Downloading (…)cheduler_config.json: 100% 308/308 [00:00<00:00, 1.71MB/s]
{'clip_sample_range', 'thresholding', 'variance_type', 'sample_max_value', 'prediction_type', 'dynamic_thresholding_ratio', 'timestep_spacing'} was not found in config. Values will be initialized to default values.
Downloading (…)tokenizer/vocab.json: 100% 1.06M/1.06M [00:00<00:00, 16.4MB/s]
Downloading (…)tokenizer/merges.txt: 100% 525k/525k [00:00<00:00, 38.6MB/s]
Downloading (…)cial_tokens_map.json: 100% 472/472 [00:00<00:00, 2.39MB/s]
Downloading (…)okenizer_config.json: 100% 806/806 [00:00<00:00, 4.48MB/s]
Downloading (…)_encoder/config.json: 100% 617/617 [00:00<00:00, 3.23MB/s]
Downloading model.safetensors: 100% 492M/492M [00:03<00:00, 132MB/s]
Downloading (…)main/vae/config.json: 100% 547/547 [00:00<00:00, 3.33MB/s]
Downloading (…)ch_model.safetensors: 100% 335M/335M [00:02<00:00, 122MB/s]
{'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
Downloading (…)ain/unet/config.json: 100% 743/743 [00:00<00:00, 4.56MB/s]
Downloading (…)ch_model.safetensors: 100% 3.44G/3.44G [00:34<00:00, 98.3MB/s]
{'mid_block_only_cross_attention', 'cross_attention_norm', 'class_embed_type', 'class_embeddings_concat', 'resnet_time_scale_shift', 'resnet_out_scale_factor', 'use_linear_projection', 'time_embedding_dim', 'timestep_post_act', 'only_cross_attention', 'projection_class_embeddings_input_dim', 'num_attention_heads', 'mid_block_type', 'encoder_hid_dim_type', 'addition_embed_type', 'time_embedding_act_fn', 'upcast_attention', 'time_embedding_type', 'transformer_layers_per_block', 'conv_in_kernel', 'dual_cross_attention', 'resnet_skip_time_act', 'conv_out_kernel', 'addition_embed_type_num_heads', 'time_cond_proj_dim', 'num_class_embeds', 'addition_time_embed_dim', 'encoder_hid_dim'} was not found in config. Values will be initialized to default values.
Downloading readme: 100% 1.80k/1.80k [00:00<00:00, 12.5MB/s]
Downloading metadata: 100% 731/731 [00:00<00:00, 4.06MB/s]
Downloading data files: 0% 0/1 [00:00<?, ?it/s]
Downloading data: 0% 0.00/99.7M [00:00<?, ?B/s]
Downloading data: 4% 4.19M/99.7M [00:00<00:07, 12.9MB/s]
Downloading data: 13% 12.6M/99.7M [00:00<00:02, 33.0MB/s]
Downloading data: 21% 21.0M/99.7M [00:00<00:01, 46.3MB/s]
Downloading data: 29% 29.4M/99.7M [00:00<00:01, 54.9MB/s]
Downloading data: 38% 37.7M/99.7M [00:00<00:01, 60.4MB/s]
Downloading data: 46% 46.1M/99.7M [00:00<00:00, 64.8MB/s]
Downloading data: 55% 54.5M/99.7M [00:01<00:00, 66.1MB/s]
Downloading data: 63% 62.9M/99.7M [00:01<00:00, 68.0MB/s]
Downloading data: 72% 71.3M/99.7M [00:01<00:00, 69.3MB/s]
Downloading data: 80% 79.7M/99.7M [00:01<00:00, 70.0MB/s]
Downloading data: 88% 88.1M/99.7M [00:01<00:00, 70.1MB/s]
Downloading data: 100% 99.7M/99.7M [00:01<00:00, 61.5MB/s]
Downloading data files: 100% 1/1 [00:01<00:00, 1.62s/it]
Extracting data files: 100% 1/1 [00:00<00:00, 1090.00it/s]
Generating train split: 100% 833/833 [00:01<00:00, 824.49 examples/s]
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
warnings.warn(_create_warning_msg(
07/29/2023 02:58:59 - INFO - __main__ - ***** Running training *****
07/29/2023 02:58:59 - INFO - __main__ - Num examples = 833
07/29/2023 02:58:59 - INFO - __main__ - Num Epochs = 1
07/29/2023 02:58:59 - INFO - __main__ - Instantaneous batch size per device = 1
07/29/2023 02:58:59 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 4
07/29/2023 02:58:59 - INFO - __main__ - Gradient Accumulation steps = 4
07/29/2023 02:58:59 - INFO - __main__ - Total optimization steps = 10
Steps: 100% 10/10 [00:25<00:00, 1.65s/it, lr=0, step_loss=0.00373] 07/29/2023 02:59:25 - INFO - __main__ - Running validation...
Generating 4 images with prompt: A pokemon with blue eyes..
Downloading (…)ain/model_index.json: 100% 541/541 [00:00<00:00, 2.57MB/s]
Fetching 13 files: 0% 0/13 [00:00<?, ?it/s]
Downloading (…)rocessor_config.json: 100% 342/342 [00:00<00:00, 2.18MB/s]
Fetching 13 files: 8% 1/13 [00:00<00:04, 2.51it/s]
Downloading (…)_checker/config.json: 100% 4.72k/4.72k [00:00<00:00, 13.5MB/s]
Downloading model.safetensors: 0% 0.00/1.22G [00:00<?, ?B/s]
Downloading model.safetensors: 2% 21.0M/1.22G [00:00<00:08, 141MB/s]
Downloading model.safetensors: 4% 52.4M/1.22G [00:00<00:06, 191MB/s]
Downloading model.safetensors: 7% 83.9M/1.22G [00:00<00:05, 202MB/s]
Downloading model.safetensors: 9% 105M/1.22G [00:00<00:05, 200MB/s]
Downloading model.safetensors: 10% 126M/1.22G [00:00<00:05, 199MB/s]
Downloading model.safetensors: 13% 157M/1.22G [00:00<00:05, 203MB/s]
Downloading model.safetensors: 16% 189M/1.22G [00:00<00:04, 206MB/s]
Downloading model.safetensors: 17% 210M/1.22G [00:01<00:04, 202MB/s]
Downloading model.safetensors: 19% 231M/1.22G [00:01<00:04, 204MB/s]
Downloading model.safetensors: 21% 252M/1.22G [00:01<00:04, 204MB/s]
Downloading model.safetensors: 22% 273M/1.22G [00:01<00:04, 205MB/s]
Downloading model.safetensors: 24% 294M/1.22G [00:01<00:04, 203MB/s]
Downloading model.safetensors: 26% 315M/1.22G [00:01<00:04, 195MB/s]
Downloading model.safetensors: 28% 336M/1.22G [00:01<00:04, 193MB/s]
Downloading model.safetensors: 29% 357M/1.22G [00:01<00:04, 186MB/s]
Downloading model.safetensors: 31% 377M/1.22G [00:01<00:04, 182MB/s]
Downloading model.safetensors: 33% 398M/1.22G [00:02<00:04, 186MB/s]
Downloading model.safetensors: 34% 419M/1.22G [00:03<00:23, 34.4MB/s]
Downloading model.safetensors: 36% 440M/1.22G [00:04<00:17, 43.1MB/s]
Downloading model.safetensors: 38% 461M/1.22G [00:04<00:13, 54.6MB/s]
Downloading model.safetensors: 40% 482M/1.22G [00:04<00:11, 65.0MB/s]
Downloading model.safetensors: 41% 503M/1.22G [00:04<00:09, 77.9MB/s]
Downloading model.safetensors: 43% 524M/1.22G [00:04<00:07, 95.5MB/s]
Downloading model.safetensors: 45% 545M/1.22G [00:04<00:06, 111MB/s]
Downloading model.safetensors: 47% 566M/1.22G [00:04<00:05, 126MB/s]
Downloading model.safetensors: 48% 587M/1.22G [00:04<00:04, 140MB/s]
Downloading model.safetensors: 50% 608M/1.22G [00:05<00:04, 152MB/s]
Downloading model.safetensors: 52% 629M/1.22G [00:05<00:05, 112MB/s]
Downloading model.safetensors: 53% 650M/1.22G [00:05<00:04, 124MB/s]
Downloading model.safetensors: 55% 671M/1.22G [00:05<00:04, 135MB/s]
Downloading model.safetensors: 57% 692M/1.22G [00:05<00:03, 146MB/s]
Downloading model.safetensors: 59% 713M/1.22G [00:05<00:03, 153MB/s]
Downloading model.safetensors: 61% 744M/1.22G [00:05<00:02, 172MB/s]
Downloading model.safetensors: 63% 765M/1.22G [00:06<00:02, 179MB/s]
Downloading model.safetensors: 66% 797M/1.22G [00:06<00:02, 189MB/s]
Downloading model.safetensors: 67% 818M/1.22G [00:06<00:02, 194MB/s]
Downloading model.safetensors: 70% 849M/1.22G [00:06<00:01, 191MB/s]
Downloading model.safetensors: 72% 870M/1.22G [00:06<00:01, 188MB/s]
Downloading model.safetensors: 73% 891M/1.22G [00:06<00:01, 190MB/s]
Downloading model.safetensors: 76% 923M/1.22G [00:08<00:06, 43.7MB/s]
Downloading model.safetensors: 78% 944M/1.22G [00:10<00:11, 24.0MB/s]
Downloading model.safetensors: 79% 965M/1.22G [00:10<00:08, 30.4MB/s]
Downloading model.safetensors: 80% 975M/1.22G [00:10<00:07, 33.8MB/s]
Downloading model.safetensors: 82% 996M/1.22G [00:10<00:04, 45.0MB/s]
Downloading model.safetensors: 85% 1.03G/1.22G [00:11<00:02, 66.1MB/s]
Downloading model.safetensors: 86% 1.05G/1.22G [00:11<00:02, 78.0MB/s]
Downloading model.safetensors: 89% 1.08G/1.22G [00:11<00:01, 103MB/s]
Downloading model.safetensors: 91% 1.11G/1.22G [00:11<00:00, 126MB/s]
Downloading model.safetensors: 94% 1.14G/1.22G [00:11<00:00, 146MB/s]
Downloading model.safetensors: 97% 1.17G/1.22G [00:11<00:00, 161MB/s]
Downloading model.safetensors: 98% 1.20G/1.22G [00:11<00:00, 167MB/s]
Downloading model.safetensors: 100% 1.22G/1.22G [00:11<00:00, 101MB/s]
Fetching 13 files: 100% 13/13 [00:12<00:00, 1.04it/s]
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
Loading pipeline components...: 0% 0/7 [00:00<?, ?it/s]{'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
Loaded vae as AutoencoderKL from `vae` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 29% 2/7 [00:00<00:00, 5.70it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loaded safety_checker as StableDiffusionSafetyChecker from `safety_checker` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 43% 3/7 [00:01<00:02, 1.66it/s]Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of runwayml/stable-diffusion-v1-5.
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 71% 5/7 [00:04<00:02, 1.17s/it]Loaded feature_extractor as CLIPImageProcessor from `feature_extractor` subfolder of runwayml/stable-diffusion-v1-5.
{'prediction_type', 'timestep_spacing'} was not found in config. Values will be initialized to default values.
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 100% 7/7 [00:04<00:00, 1.43it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Model weights saved in /sddata/finetune/lora/pokemon/pytorch_lora_weights.bin
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
Loading pipeline components...: 0% 0/7 [00:00<?, ?it/s]{'mid_block_only_cross_attention', 'cross_attention_norm', 'class_embed_type', 'class_embeddings_concat', 'resnet_time_scale_shift', 'resnet_out_scale_factor', 'use_linear_projection', 'time_embedding_dim', 'timestep_post_act', 'only_cross_attention', 'projection_class_embeddings_input_dim', 'num_attention_heads', 'mid_block_type', 'encoder_hid_dim_type', 'addition_embed_type', 'time_embedding_act_fn', 'upcast_attention', 'time_embedding_type', 'transformer_layers_per_block', 'conv_in_kernel', 'dual_cross_attention', 'resnet_skip_time_act', 'conv_out_kernel', 'addition_embed_type_num_heads', 'time_cond_proj_dim', 'num_class_embeds', 'addition_time_embed_dim', 'encoder_hid_dim'} was not found in config. Values will be initialized to default values.
Loaded unet as UNet2DConditionModel from `unet` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 14% 1/7 [00:15<01:34, 15.83s/it]{'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
Loaded vae as AutoencoderKL from `vae` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 29% 2/7 [00:16<00:33, 6.70s/it]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loaded safety_checker as StableDiffusionSafetyChecker from `safety_checker` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 43% 3/7 [00:23<00:27, 6.96s/it]Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 57% 4/7 [00:23<00:12, 4.27s/it]Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 71% 5/7 [00:26<00:07, 3.85s/it]Loaded feature_extractor as CLIPImageProcessor from `feature_extractor` subfolder of runwayml/stable-diffusion-v1-5.
{'prediction_type', 'timestep_spacing'} was not found in config. Values will be initialized to default values.
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 100% 7/7 [00:26<00:00, 3.81s/it]
100% 30/30 [00:04<00:00, 6.32it/s]
100% 30/30 [00:04<00:00, 6.36it/s]
100% 30/30 [00:04<00:00, 6.35it/s]
100% 30/30 [00:04<00:00, 6.28it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
Steps: 100% 10/10 [02:00<00:00, 12.03s/it, lr=0, step_loss=0.00373]
Inference with the LoRA-tuned weights#
Based on the training results above, load the base model (runwayml/stable-diffusion-v1-5) with StableDiffusionPipeline and switch its scheduler to DPMSolverMultistepScheduler to prepare for inference.
import os
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
model_base = os.environ["MODEL_NAME"]
pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Next, load the fine-tuned LoRA module on top of the base model weights with the load_attn_procs function of the unet, then move the pipeline to the GPU to speed up inference.
lora_model_path = os.environ["OUTPUT_DIR"]
pipe.unet.load_attn_procs(lora_model_path)
pipe = pipe.to("cuda")
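When running inference with the pipeline, you can pass a scale parameter through the cross_attention_kwargs option, as in the calls that follow. This parameter controls how strongly the LoRA module influences the output.
A scale of 0 is equivalent to using only the original base model weights, without the LoRA module. Conversely, a scale of 1 applies the fine-tuned LoRA update at full strength on top of the base weights. Values between 0 and 1 interpolate between these two sets of weights.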
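As a rough illustration of what scale does, the sketch below treats a single weight matrix as a plain Python list and applies a low-rank update W' = W + scale * (up @ down). This follows the general LoRA formulation only; the helper names (matmul, apply_lora) and the toy values are assumptions made for this example and are not diffusers internals.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def apply_lora(w_base, lora_up, lora_down, scale):
    """Return W' = W + scale * (up @ down), the effective weight for a given scale."""
    delta = matmul(lora_up, lora_down)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(w_base, delta)
    ]


# Toy 2x2 base weight with a rank-1 LoRA update (illustrative values only).
w_base = [[1.0, 0.0], [0.0, 1.0]]
lora_up = [[2.0], [0.0]]    # shape (2, 1)
lora_down = [[1.0, 1.0]]    # shape (1, 2)

for scale in (0.0, 0.5, 1.0):
    print(scale, apply_lora(w_base, lora_up, lora_down, scale))
```

With scale 0 the result equals the base weight, and with scale 1 the whole low-rank update is added; intermediate values move linearly between the two.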
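Because of time constraints, the LoRA tuning above was not trained long enough, so no generated images are shown here. Running the code below, however, will generate your own Pokémon images learned through LoRA tuning.
The following example sets scale to 0.5, blending the LoRA-tuned weights and the base model weights half and half: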
pipe(
    prompt="A pokemon with blue eyes.",
    num_inference_steps=25,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
The following example uses the LoRA-tuned weights at full strength (when cross_attention_kwargs is omitted, scale defaults to 1):
pipe(
    prompt="A pokemon with blue eyes.",
    num_inference_steps=25,
    guidance_scale=7.5,
).images[0]
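The following is an example of image generation with sayakpaul/sd-model-finetuned-lora-t4, a LoRA module trained sufficiently on lambdalabs/pokemon-blip-captions, the dataset used for the LoRA tuning above.
The model card is loaded from the LoRA module's model ID via RepoCard.load to look up the base model. The base model is then loaded, and the LoRA module is loaded on top of it with load_attn_procs. From there, generation follows the same pipeline flow as before.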
from huggingface_hub.repocard import RepoCard
lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
# Look up the base model recorded in the LoRA module's model card
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
# Load the base model, then the LoRA module on top of it
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.unet.load_attn_procs(lora_model_id)
pipe = pipe.to("cuda")
# Fix the seed so that the result is reproducible
generator = torch.Generator().manual_seed(42)
pipe(
    prompt="A pokemon with blue eyes.",
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=generator,
).images[0]
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
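LoRA tuning with DreamBooth#
DreamBooth [Ruiz+ CVPR'23] is a fine-tuning technique for personalizing text-to-image models such as Stable Diffusion. From just a few images of a subject, it can generate photorealistic images of that subject in different contexts.
However, DreamBooth is very sensitive to hyperparameters and tends to overfit. The important hyperparameters to consider include those that affect training time (learning rate, number of training steps) and those that affect inference time (number of denoising steps, type of noise scheduler).
Installing dependencies#
This section uses examples/dreambooth/train_dreambooth_lora.py to try LoRA tuning of a DreamBooth model. First, install the libraries that this script needs from the corresponding examples/dreambooth/requirements.txt.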
%cd /content/diffusers/examples/dreambooth
!pip install -r requirements.txt
!pip install bitsandbytes # required for the 8-bit Adam optimizer
/content/diffusers/examples/dreambooth
Requirement already satisfied: accelerate>=0.16.0 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.21.0)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.15.2+cu118)
Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 3)) (4.31.0)
Requirement already satisfied: ftfy in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 4)) (6.1.1)
Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 5)) (2.12.3)
Requirement already satisfied: Jinja2 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 6)) (3.1.2)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (23.1)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (5.9.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (6.0.1)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.16.0->-r requirements.txt (line 1)) (2.0.1+cu118)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->-r requirements.txt (line 2)) (2.27.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->-r requirements.txt (line 2)) (9.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.1)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (16.0.6)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (0.16.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (2022.10.31)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (0.3.1)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->-r requirements.txt (line 3)) (4.65.0)
Requirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy->-r requirements.txt (line 4)) (0.2.6)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (1.4.0)
Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (1.56.2)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (2.17.3)
Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (1.0.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (3.4.4)
Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (3.20.3)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (67.7.2)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (0.7.1)
Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (2.3.6)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 5)) (0.41.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2->-r requirements.txt (line 6)) (2.1.3)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 5)) (5.3.1)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 5)) (0.3.0)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 5)) (1.16.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 5)) (4.9)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard->-r requirements.txt (line 5)) (1.3.1)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers>=4.25.1->-r requirements.txt (line 3)) (2023.6.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (2023.7.22)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->-r requirements.txt (line 2)) (3.4)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 5)) (0.5.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard->-r requirements.txt (line 5)) (3.2.2)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate>=0.16.0->-r requirements.txt (line 1)) (1.3.0)
Collecting bitsandbytes
Downloading bitsandbytes-0.41.0-py3-none-any.whl (92.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.6/92.6 MB 12.4 MB/s eta 0:00:00
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.0
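This time, let's fine-tune stable-diffusion-v1-5 with DreamBooth and LoRA using images of a corgi 🐶.
Prepare the environment variables used when running the LoRA tuning script as follows:
MODEL_NAME: the model to use, here stable-diffusion-v1-5
INSTANCE_DIR: the directory containing the corgi images to train on
OUTPUT_DIR: the directory where the training results are saved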
%env MODEL_NAME=runwayml/stable-diffusion-v1-5
%env INSTANCE_DIR=/sddata/instance-dir/
%env OUTPUT_DIR=/sddata/output-dir/
env: MODEL_NAME=runwayml/stable-diffusion-v1-5
env: INSTANCE_DIR=/sddata/instance-dir/
env: OUTPUT_DIR=/sddata/output-dir/
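Since both the training command and the later inference cells read these variables, it can help to check them once before starting. The sketch below is one possible way to do that; the helper name read_training_env is hypothetical and is not part of diffusers or of the training script.

```python
import os


def read_training_env(env=os.environ):
    """Collect the three variables used for training, failing early if one is missing."""
    names = ("MODEL_NAME", "INSTANCE_DIR", "OUTPUT_DIR")
    missing = [name for name in names if not env.get(name)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return {name: env[name] for name in names}


# Example with an explicit mapping holding the same values as the %env cell above.
settings = read_training_env(
    {
        "MODEL_NAME": "runwayml/stable-diffusion-v1-5",
        "INSTANCE_DIR": "/sddata/instance-dir/",
        "OUTPUT_DIR": "/sddata/output-dir/",
    }
)
print(settings["INSTANCE_DIR"])
```

Calling read_training_env() with no argument reads from the real process environment instead of the example mapping.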
Run the following to download the images from diffusers/dog-example into the location set in INSTANCE_DIR.
from diffusers.utils import load_image
from huggingface_hub import snapshot_download
instance_dir = os.environ["INSTANCE_DIR"]
snapshot_download(
    "diffusers/dog-example",
    local_dir=instance_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
# Keep only image files so that any non-image entries in the directory are skipped
image_exts = (".jpg", ".jpeg", ".png")
jpg_files = sorted(f for f in os.listdir(instance_dir) if f.lower().endswith(image_exts))
dog_examples = [
    load_image(os.path.join(instance_dir, jpg_file)).resize((256, 256))
    for jpg_file in jpg_files
]
image_grid(dog_examples, rows=1, cols=len(dog_examples))
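Running LoRA tuning#
We are now ready to start LoRA tuning. The script accepts many options, as shown below; a few of them deserve attention:
--max_train_steps: the number of training steps. The default is 500, but at that setting training takes roughly 6 hours on Colab. For this exercise we set it to 10 as a quick check that everything runs.
--checkpointing_steps: how often intermediate parameters are saved during training. The default is 100; we changed it to 5 to match the --max_train_steps value above.
--learning_rate: LoRA tolerates a relatively high learning rate. The default is 1e-4, whereas ordinary fine-tuning typically uses 1e-5 to 1e-6.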
!accelerate launch --mixed_precision="fp16" train_dreambooth_lora.py \
--pretrained_model_name_or_path="${MODEL_NAME}" \
--instance_data_dir="${INSTANCE_DIR}" \
--output_dir="${OUTPUT_DIR}" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=5 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=10 \
--seed="0" \
--use_8bit_adam
2023-07-29 03:16:10.102240: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
The following values were not passed to `accelerate launch` and had defaults used instead:
`--num_processes` was set to a value of `1`
`--num_machines` was set to a value of `1`
`--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
2023-07-29 03:16:18.325093: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
07/29/2023 03:16:22 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: fp16
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'clip_sample_range', 'timestep_spacing', 'prediction_type', 'thresholding', 'dynamic_thresholding_ratio', 'sample_max_value', 'variance_type'} was not found in config. Values will be initialized to default values.
{'force_upcast', 'scaling_factor'} was not found in config. Values will be initialized to default values.
{'addition_embed_type_num_heads', 'num_attention_heads', 'time_embedding_type', 'class_embeddings_concat', 'addition_time_embed_dim', 'use_linear_projection', 'resnet_time_scale_shift', 'cross_attention_norm', 'num_class_embeds', 'mid_block_only_cross_attention', 'addition_embed_type', 'dual_cross_attention', 'mid_block_type', 'time_embedding_dim', 'encoder_hid_dim', 'class_embed_type', 'encoder_hid_dim_type', 'upcast_attention', 'time_cond_proj_dim', 'resnet_skip_time_act', 'resnet_out_scale_factor', 'only_cross_attention', 'projection_class_embeddings_input_dim', 'conv_in_kernel', 'transformer_layers_per_block', 'timestep_post_act', 'conv_out_kernel', 'time_embedding_act_fn'} was not found in config. Values will be initialized to default values.
07/29/2023 03:16:46 - INFO - __main__ - ***** Running training *****
07/29/2023 03:16:46 - INFO - __main__ - Num examples = 5
07/29/2023 03:16:46 - INFO - __main__ - Num batches each epoch = 5
07/29/2023 03:16:46 - INFO - __main__ - Num Epochs = 2
07/29/2023 03:16:46 - INFO - __main__ - Instantaneous batch size per device = 1
07/29/2023 03:16:46 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 1
07/29/2023 03:16:46 - INFO - __main__ - Gradient Accumulation steps = 1
07/29/2023 03:16:46 - INFO - __main__ - Total optimization steps = 10
Steps: 50% 5/10 [00:06<00:04, 1.07it/s, loss=0.00415, lr=0.0001]07/29/2023 03:16:52 - INFO - accelerate.accelerator - Saving current state to /sddata/output-dir/checkpoint-5
Model weights saved in /sddata/output-dir/checkpoint-5/pytorch_lora_weights.bin
07/29/2023 03:16:52 - INFO - accelerate.checkpointing - Optimizer state saved in /sddata/output-dir/checkpoint-5/optimizer.bin
07/29/2023 03:16:52 - INFO - accelerate.checkpointing - Scheduler state saved in /sddata/output-dir/checkpoint-5/scheduler.bin
07/29/2023 03:16:52 - INFO - accelerate.checkpointing - Gradient scaler state saved in /sddata/output-dir/checkpoint-5/scaler.pt
07/29/2023 03:16:52 - INFO - accelerate.checkpointing - Random states saved in /sddata/output-dir/checkpoint-5/random_states_0.pkl
07/29/2023 03:16:52 - INFO - __main__ - Saved state to /sddata/output-dir/checkpoint-5
Steps: 100% 10/10 [00:10<00:00, 1.34it/s, loss=0.00344, lr=0.0001]07/29/2023 03:16:56 - INFO - accelerate.accelerator - Saving current state to /sddata/output-dir/checkpoint-10
Model weights saved in /sddata/output-dir/checkpoint-10/pytorch_lora_weights.bin
07/29/2023 03:16:56 - INFO - accelerate.checkpointing - Optimizer state saved in /sddata/output-dir/checkpoint-10/optimizer.bin
07/29/2023 03:16:56 - INFO - accelerate.checkpointing - Scheduler state saved in /sddata/output-dir/checkpoint-10/scheduler.bin
07/29/2023 03:16:56 - INFO - accelerate.checkpointing - Gradient scaler state saved in /sddata/output-dir/checkpoint-10/scaler.pt
07/29/2023 03:16:56 - INFO - accelerate.checkpointing - Random states saved in /sddata/output-dir/checkpoint-10/random_states_0.pkl
07/29/2023 03:16:56 - INFO - __main__ - Saved state to /sddata/output-dir/checkpoint-10
Steps: 100% 10/10 [00:10<00:00, 1.34it/s, loss=0.0684, lr=0.0001] Model weights saved in /sddata/output-dir/pytorch_lora_weights.bin
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
Loading pipeline components...: 0% 0/7 [00:00<?, ?it/s]{'addition_embed_type_num_heads', 'num_attention_heads', 'time_embedding_type', 'class_embeddings_concat', 'addition_time_embed_dim', 'use_linear_projection', 'resnet_time_scale_shift', 'cross_attention_norm', 'num_class_embeds', 'mid_block_only_cross_attention', 'addition_embed_type', 'dual_cross_attention', 'mid_block_type', 'time_embedding_dim', 'encoder_hid_dim', 'class_embed_type', 'encoder_hid_dim_type', 'upcast_attention', 'time_cond_proj_dim', 'resnet_skip_time_act', 'resnet_out_scale_factor', 'only_cross_attention', 'projection_class_embeddings_input_dim', 'conv_in_kernel', 'transformer_layers_per_block', 'timestep_post_act', 'conv_out_kernel', 'time_embedding_act_fn'} was not found in config. Values will be initialized to default values.
Loaded unet as UNet2DConditionModel from `unet` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 14% 1/7 [00:17<01:47, 17.87s/it]Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 29% 2/7 [00:17<00:37, 7.43s/it]Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 43% 3/7 [00:21<00:22, 5.67s/it]{'timestep_spacing', 'prediction_type'} was not found in config. Values will be initialized to default values.
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of runwayml/stable-diffusion-v1-5.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loaded safety_checker as StableDiffusionSafetyChecker from `safety_checker` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 71% 5/7 [00:29<00:09, 4.64s/it]Loaded feature_extractor as CLIPImageProcessor from `feature_extractor` subfolder of runwayml/stable-diffusion-v1-5.
{'force_upcast', 'scaling_factor'} was not found in config. Values will be initialized to default values.
Loaded vae as AutoencoderKL from `vae` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 100% 7/7 [00:31<00:00, 4.48s/it]
{'solver_type', 'lambda_min_clipped', 'timestep_spacing', 'variance_type', 'prediction_type', 'use_karras_sigmas', 'thresholding', 'algorithm_type', 'sample_max_value', 'dynamic_thresholding_ratio', 'lower_order_final', 'solver_order'} was not found in config. Values will be initialized to default values.
Loading unet.
Steps: 100% 10/10 [00:43<00:00, 4.33s/it, loss=0.0684, lr=0.0001]
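The figures reported in the log above (Num Epochs = 2, checkpoints saved at steps 5 and 10) follow from the command-line settings. The sketch below reproduces that arithmetic under the assumption that the script derives epochs from the number of optimizer updates per epoch in the usual way; training_schedule is a hypothetical helper written for this example, not a function of the training script.

```python
import math


def training_schedule(num_examples, train_batch_size, gradient_accumulation_steps,
                      max_train_steps, checkpointing_steps):
    """Derive the number of epochs and the checkpoint steps from the settings."""
    batches_per_epoch = math.ceil(num_examples / train_batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    num_epochs = math.ceil(max_train_steps / updates_per_epoch)
    checkpoints = [
        step for step in range(1, max_train_steps + 1)
        if step % checkpointing_steps == 0
    ]
    return num_epochs, checkpoints


# Values from the command and log above: 5 images, batch size 1,
# no gradient accumulation, 10 training steps, a checkpoint every 5 steps.
num_epochs, checkpoints = training_schedule(5, 1, 1, 10, 5)
print(num_epochs, checkpoints)
```

Changing --max_train_steps without adjusting --checkpointing_steps can leave a run with no intermediate checkpoints at all, which is why the two were changed together here.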
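The CLIP text encoder can also be fine-tuned with LoRA. In most cases this gives better results for only a small increase in compute. To fine-tune the text encoder with LoRA, pass --train_text_encoder when running train_dreambooth_lora.py.
Inference with the LoRA-tuned weights#
As explained for LoRA tuning of text-to-image models, the DreamBooth case works the same way: load the base model with StableDiffusionPipeline and build the inference pipeline.
Load the LoRA module of the fine-tuned DreamBooth model on top of the base model weights, then move the pipeline to the GPU to speed up inference. Recall that the scale value adjusts the balance between the base model and the LoRA module.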
model_base = os.environ["MODEL_NAME"]
lora_model_path = os.environ["OUTPUT_DIR"]
pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
pipe.unet.load_attn_procs(lora_model_path)
pipe = pipe.to("cuda")
image = pipe(
    "A picture of a sks dog in a bucket.",
    num_inference_steps=25,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
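Because of time constraints, the DreamBooth and LoRA tuning above was not trained long enough, so no generated images are shown here.
Running code like the following, however, can generate images that composite the corgi learned through DreamBooth and LoRA tuning.
The following is an example of image generation with patrickvonplaten/lora_dreambooth_dog_example, a LoRA module trained sufficiently with LoRA DreamBooth on diffusers/dog-example, the dataset used above.
According to the official documentation, load_lora_weights() is preferred over the load_attn_procs() used above for loading LoRA modules, because load_lora_weights() handles both of the following cases:
LoRA modules without separate identifiers for the U-Net and the text encoder (including patrickvonplaten/lora_dreambooth_dog_example used here)
LoRA modules with separate identifiers for the U-Net and the text encoder (such as sayakpaul/dreambooth)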
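To make the two cases concrete, the sketch below shows one way a loader could tell them apart: by checking whether parameter names carry a unet. or text_encoder. prefix. The prefixes, the example key names, and the split_lora_state_dict helper are all assumptions made for illustration; this is not the actual load_lora_weights() implementation.

```python
def split_lora_state_dict(state_dict):
    """Split LoRA parameters into U-Net and text-encoder groups by key prefix.

    Keys without either prefix are treated as U-Net parameters,
    which corresponds to the "no separate identifiers" case.
    """
    unet, text_encoder = {}, {}
    for key, value in state_dict.items():
        if key.startswith("text_encoder."):
            text_encoder[key[len("text_encoder."):]] = value
        elif key.startswith("unet."):
            unet[key[len("unet."):]] = value
        else:
            unet[key] = value
    return unet, text_encoder


# Hypothetical parameter names; the values stand in for weight tensors.
without_ids = {"attn1.to_q.lora.up.weight": 1}
with_ids = {
    "unet.attn1.to_q.lora.up.weight": 1,
    "text_encoder.layers.0.q_proj.lora.up.weight": 2,
}

print(split_lora_state_dict(without_ids))
print(split_lora_state_dict(with_ids))
```

In the first case everything ends up in the U-Net group, while in the second case each group receives only its own parameters.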
lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
# Look up the base model recorded in the LoRA module's model card
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
# load_lora_weights() is called on the pipeline, not on pipe.unet
pipe.load_lora_weights(lora_model_id)
pipe = pipe.to("cuda")
generator = torch.Generator().manual_seed(2)
pipe(
    "A picture of a sks dog in a bucket.",
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=generator,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
/usr/local/lib/python3.10/dist-packages/diffusers/loaders.py:1223: UserWarning: You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`.
warnings.warn(warn_message)
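The warning above suggests converting old-format LoRA weights by prefixing each parameter name with unet. The dictionary comprehension printed in the message appears garbled and is not valid Python as written, so the sketch below is one reading of its intent. The convert_old_lora_format helper and the example key are hypothetical.

```python
def convert_old_lora_format(old_state_dict, prefix="unet"):
    """Return a new dict whose keys are prefixed with 'unet.', as the warning suggests."""
    return {f"{prefix}.{name}": params for name, params in old_state_dict.items()}


# Hypothetical old-format entry; the value stands in for a weight tensor.
old_state_dict = {"attn1.to_q.lora.up.weight": 1}
new_state_dict = convert_old_lora_format(old_state_dict)
print(new_state_dict)
```

The original dictionary is left untouched, so the old and new formats can be compared before saving.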