GANsを実装してみる
初めまして。こんにちは。
今回は、最近流行っているGANsを自分でも実装してみたので、その記録をこちらに残しておこうと思います。
はじめに
Generative adversarial networls (GANs)は画像生成手法の一つで、GeneratorとDiscriminatorと呼ばれる二つのモデルを競わせながら学習させます。
Generatorは偽物画像を生成し、Discriminatorは偽物画像と本物画像を識別します。GeneratorはDiscriminatorを騙せるように、Discriminatorは正しく識別できるように学習するため、学習が上手くいけば本物画像に近い偽物画像を生成できるようになります。
今回は、手書き数字画像のデータセットであるMNISTと、自分で撮り溜めておいたご飯の画像のデータセットの二つを使用して、画像生成を行なってみました。
モデル
アーキテクチャとしては、ネットワークの中に転置畳み込み層/畳み込み層を組み込んだDeep convolutional generative adversarial networks (DCGANs)を使用しました。
Generatorは5層の転置畳み込み層、Discriminatorは5層の畳み込み層からなります。Generatorの活性化関数にはReLU及びTanhを使用し、DiscriminatorにはLeakyReLUを使用しました。また中間層にはBatch normalizationを挿入しました。
GANsの欠点の一つとして学習の不安定性があります。今回私も、初めの方はDiscriminatorが圧勝してしまい、かなり苦しめられました。というのも、学習初期段階では、Generatorの生成する画像の精度が低く、所謂ノイズ画像となります。Discriminatorからすれば、ノイズ画像と本物画像の識別は容易にできてしまうわけで、このまま学習が進み収束してしまうと、ノイズ画像しか生成できなくなってしまうのです。
そこで、Discriminatorにハンデを課すべく、何点か工夫を施しました。
- GeneratorとDiscriminatorの重み更新頻度の調整
- Discriminatorの重み1回の更新に対してGeneratorの重みを複数回更新するようにしました。
- LSGANsの使用
- ロス関数はLeast squares generative adversarial networks (LSGANs)で使われているロスを使用しました。
- 正解/不正解ラベルにノイズ混入
- 不正解は0、正解は1という01のラベリングではなく、不正解は0~0.3、正解は0.7~1.0とある程度幅を持たせたラベルを使用しました。
- Discriminatorの最終畳み込み層の前にDropoutを挿入
結果
データセットとして、まずはMNISTを使用して学習させました。実際の生成結果を以下に示します。
MNISTは「黒背景に白文字」と画像間のばらつきが少なく、学習データも60000枚と十分にあるため、生成も上手くいっているように思えます。
次に、自分で撮り溜めていた、過去3年分のご飯画像を学習させてみました。
画像の枚数は600枚程度と、かなり少なかったので、Data augumentationを施しています。
なんとなく、ご飯っぽい画像は生成できていますが、一部歪んだり、色が混ざったりしていて、よくよく見ると出来はイマイチに感じます。学習データの枚数が約600枚と少なかったこと、モデル構造がシンプルであったことが原因として考えられます。ただ、個人的には、逆にこれだけ少ないデータ、かつシンプルなモデル構造でも、ある程度の精度の画像が作れることがわかり、少し驚きました。
おわりに
初めてGANsを実装してみましたが、画像数600枚程度でも"っぽい"画像が生成できることがわかりました。
一方で、より解像度の高い画像を生成しようとすると、よりたくさんの学習データや、最新のアーキテクチャの導入が必要になってきそうです。
(GANsとはまた系譜の違うモデルですが、)テキストからかなり解像度の高い画像を生成するStable Diffusion modelなんかは話題にもなっているので、引き続き画像生成分野の勉強をしていこうと思います。
参考文献
- Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
- Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).
- Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).
- Mao, Xudong, et al. "Least squares generative adversarial networks." Proceedings of the IEEE international conference on computer vision. 2017.
- cedro-blog, “PyTorchでConditional GANをやってみる”, http://cedro3.com/ai/pytorch-conditional-gan/, 2019
- Kaggle, “GAN in Pytorch with FID”, https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid/notebook
- Github, “FID score for PyTorch”, https://github.com/mseitzer/pytorch-fid
- Githubm "stylegan2-pytorch", https://github.com/rosinality/stylegan2-pytorch/blob/master/inception.py