GANsを実装してみる
初めまして。こんにちは。
今回は、最近流行っているGANsを自分でも実装してみたので、その記録をこちらに残しておこうと思います。
はじめに
Generative adversarial networls (GANs)は画像生成手法の一つで、GeneratorとDiscriminatorと呼ばれる二つのモデルを競わせながら学習させます。
Generatorは偽物画像を生成し、Discriminatorは偽物画像と本物画像を識別します。GeneratorはDiscriminatorを騙せるように、Discriminatorは正しく識別できるように学習するため、学習が上手くいけば本物画像に近い偽物画像を生成できるようになります。
今回は、手書き数字画像のデータセットであるMNISTと、自分で撮り溜めておいたご飯の画像のデータセットの二つを使用して、画像生成を行なってみました。
モデル
アーキテクチャとしては、ネットワークの中に転置畳み込み層/畳み込み層を組み込んだDeep convolutional generative adversarial networks (DCGANs)を使用しました。
Generatorは5層の転置畳み込み層、Discriminatorは5層の畳み込み層からなります。Generatorの活性化関数にはReLU及びTanhを使用し、DiscriminatorにはLeakyReLUを使用しました。また中間層にはBatch normalizationを挿入しました。
GANsの欠点の一つとして学習の不安定性があります。今回私も、初めの方はDiscriminatorが圧勝してしまい、かなり苦しめられました。というのも、学習初期段階では、Generatorの生成する画像の精度が低く、所謂ノイズ画像となります。Discriminatorからすれば、ノイズ画像と本物画像の識別は容易にできてしまうわけで、このまま学習が進み収束してしまうと、ノイズ画像しか生成できなくなってしまうのです。
そこで、Discriminatorにハンデを課すべく、何点か工夫を施しました。
- GeneratorとDiscriminatorの重み更新頻度の調整
- Discriminatorの重み1回の更新に対してGeneratorの重みを複数回更新するようにしました。
- LSGANsの使用
- ロス関数はLeast squares generative adversarial networks (LSGANs)で使われているロスを使用しました。
- 正解/不正解ラベルにノイズ混入
- 不正解は0、正解は1という01のラベリングではなく、不正解は0~0.3、正解は0.7~1.0とある程度幅を持たせたラベルを使用しました。
- Discriminatorの最終畳み込み層の前にDropoutを挿入
結果
データセットとして、まずはMNISTを使用して学習させました。実際の生成結果を以下に示します。
MNISTは「黒背景に白文字」と画像間のばらつきが少なく、学習データも60000枚と十分にあるため、生成も上手くいっているように思えます。
次に、自分で撮り溜めていた、過去3年分のご飯画像を学習させてみました。
画像の枚数は600枚程度と、かなり少なかったので、Data augumentationを施しています。
なんとなく、ご飯っぽい画像は生成できていますが、一部歪んだり、色が混ざったりしていて、よくよく見ると出来はイマイチに感じます。学習データの枚数が約600枚と少なかったこと、モデル構造がシンプルであったことが原因として考えられます。ただ、個人的には、逆にこれだけ少ないデータ、かつシンプルなモデル構造でも、ある程度の精度の画像が作れることがわかり、少し驚きました。
おわりに
初めてGANsを実装してみましたが、画像数600枚程度でも"っぽい"画像が生成できることがわかりました。
一方で、より解像度の高い画像を生成しようとすると、よりたくさんの学習データや、最新のアーキテクチャの導入が必要になってきそうです。
(GANsとはまた系譜の違うモデルですが、)テキストからかなり解像度の高い画像を生成するStable Diffusion modelなんかは話題にもなっているので、引き続き画像生成分野の勉強をしていこうと思います。
参考文献
- Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
- Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).
- Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).
- Mao, Xudong, et al. "Least squares generative adversarial networks." Proceedings of the IEEE international conference on computer vision. 2017.
- cedro-blog, “PyTorchでConditional GANをやってみる”, http://cedro3.com/ai/pytorch-conditional-gan/, 2019
- Kaggle, “GAN in Pytorch with FID”, https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid/notebook
- Github, “FID score for PyTorch”, https://github.com/mseitzer/pytorch-fid
- Githubm "stylegan2-pytorch", https://github.com/rosinality/stylegan2-pytorch/blob/master/inception.py
SORACOMのGPSマルチユニットを使ってデータをカスタマイズしてみる
初めまして。こんにちは。私は今、工学部に所属する大学生です。今年から一人暮らしをはじめ、今までやったことのないことに挑戦しようとして、なかなか上手くいかず、もがいている、そんな生活を送っています。
(ブログも初めて書くのですが、緊張しています。)
今回SORACOMのIoTデバイスを使って、自転車の管理システムみたいなものを作ってみたので、紹介したいと思います。
前提知識
まずは、製作するにあたって、私自身が持っていた前提知識(のようなもの)を紹介しようと思います。「あ、これくらいの知識だけでもなんとかなるんだ」と思っていただけたら嬉しいです。
お分かりの通り、「工学部」とは名ばかりで、まだまだ"そっち系"においてはてんで初心者なのですが、一夏の挑戦ということで、今回SORACOMのGPSマルチユニットを使って、今話題(?)のIoTシステムの一端に触れてみることにしました。
作成しようとしたもの
一口にIoTシステムと言っても色々なものがあると思います。例えば、ドアにIoT技術を搭載して開閉状態を記録したり、冷蔵庫の温度をモニターしたりなど。。
今回私が使用するGPSマルチユニットは位置情報(GPS)、温度、湿度、加速度の4つを計測してデータをクラウド上に送信してくれます。
これらのデータを使って、何かIoTシステムを作れないか考えました。結構考えました。行き詰まってしまったので、気分転換に自転車に跨り、サイクリングしながら考えました。そんな折にふと思いついたのが、今回のアイデアです。
自転車にGPSマルチユニットを取り付けることで、位置情報や速度情報なんかを記録できないか。また、温度センサーなんかも使用して、熱中症の注意喚起やスピード出し過ぎの注意喚起などを行えないか。
そういうわけで、今回私は、
- パソコンから位置情報や速度情報を見れるようにする
- 状況に応じて適宜LINE等から注意喚起のメッセージを送る
というシステムを作ってみようと思いました。
よく卒論で尋ねられると噂の「この研究なんの役に立つの?」じゃないけれど、有用性や将来性を有り体に言えば、
- 自転車の盗難防止 (盗まれても現在地を割り出せる)
- サイクリングの記録
- 事故防止
なんかに繋がるのではないでしょうか。
作成
SORACOM lagoon
SORACOMのデバイスを使ってみて、一番驚いたのがSORACOM lagoonというサービスです。このサービスを使えば、デバイスから送信されたデータを手軽にグラフ化できたり、データを用いてアラート(通知)を送ることができるのです。その他色々カスタマイズができて本当に素晴らしい!感動です。
温度や加速度の時間変化のグラフ化、地図と組み合わせた位置情報のトレースなんかはlagoonを使えば容易に実現できるので、こちらを使用することに。その他、速度や移動距離など、一度こちらでデータ処理しなければいけないものについては、コードを書いて、処理と表示を実装することにしました。
大雑把な仕組み
当初の予定では、(lagoonに対抗して、じゃないけれど) 速度や距離表示についても、webアプリっぽいものを作ってスマホからでも自由に記録を見られるようにしたかったのですが、なんせそういった経験がないもので。。pythonのdjangoを触ってみたりもしたのですが、結局よくわからないのと時間がそんなになかったということで今回は断念。
とりあえず、lagoonで表示できない速度や移動距離に関しては、
プログラムを起動 → サイクリングへgo → 自動でデータ取得・更新 → 可視化
みたいなことをやってくれるコードを書いてみることにしました。なお、言語としてはpythonを採用しました。
データの取得
まずはデータの取得。SORACOMのデバイスに取り付けられた各種センサーが取得したデータは、SORACOM Harvestなるサービスへと送られるとのこと。そこで、APIを使用して、このHarvestからデータを取得しました。
Harvestから取得したデータはjson形式となっており、一部デコードの必要があったりと、データ形式の変換に苦労しましたが、なんとか欲しいデータを得ることができました。コードは以下の通りです。
import requests import time import json import base64 import ast import datetime time_list = [] #datetime.datetim型 temp_list = [] #温度 lat = [] #緯度 lon = [] #経度 headers = { 'X-Soracom-API-Key': '(自分のAPIキー)', 'X-Soracom-Token': '(自分のAPI トークン)', } response = requests.get('https://api.soracom.io/v1/subscribers/(SIMの番号)/data', headers=headers).json() # 初め全部で10個のデータを取得 先頭から最新順に並んでいる for i in range(len(response)): # データを取得した時刻 time_list.append(datetime.datetime.fromtimestamp(response[i]['time']*0.001)) # 取得したデータのうちkey:'content'に属するものをstr型からdict型に変換 text = ast.literal_eval(response[i]['content']) # byteデータをdecode (バイト型に変換) tmp = base64.b64decode(text["payload"]) # 取得したデータ(バイト型)を、str型、dict型の順に変換 tmp2 = ast.literal_eval(tmp.decode()) # 各配列に値を追加 lat.append(tmp2['lat']) lon.append(tmp2['lon']) temp_list.append(tmp2['temp'])
データ処理
次は取得したデータの処理。今回は速度と移動距離を求めたい。。(と言っても、センサーの精度も考えると"概算"程度の精度となってしまいますが。)
加速度を積分することで、速度、距離を求めるという方法も考えましたが、加速度センサーの向きや、重力加速度など、考慮しなければならないことが多々あったので、今回は、位置情報(経度と緯度)を用いて、概算することにしました。
以下の公式を使用すると、経度と緯度の差から2点間距離を求められるとのことだったので、これにより距離を求め、時間で割って速度を算出しました。
地点A(経度, 緯度)、地点B(経度, 緯度の距離d (rは地球の半径とする)
データの取得とデータ処理を行ってくれる関数を以下のように定義しました。
# データ取得 def get_data(): global count global time_list, temp_list, lat, lon, speed_list, distance, distance_list headers = { # 8/11 14:33 'X-Soracom-API-Key': '(自分のAPIキー)', 'X-Soracom-Token': '(自分のAPI トークン)', } response = requests.get('https://api.soracom.io/v1/subscribers/(SIMの番号)/data', headers=headers).json() # 初期処理 if count==0: for i in range(len(response)): # 時刻を記録 (UNIX時間からdatetime型に変換) time_list.append(datetime.datetime.fromtimestamp(response[i]['time']*0.001)) # 取得したデータのうちkey:'content'に属するものをstr型からdict型に変換 text = ast.literal_eval(response[i]['content']) # byteデータをdecode (バイト型に変換) tmp = base64.b64decode(text["payload"]) # 取得したデータ(バイト型)を、str型、dict型の順に変換 tmp2 = ast.literal_eval(tmp.decode()) # 各配列に値を追加 lat.append(tmp2['lat']) lon.append(tmp2['lon']) temp_list.append(tmp2['temp']) # 最初に取得した10個のデータは新しい順なので逆順にする(以降append()で追加するので古い順に) time_list.reverse() lat.reverse() lon.reverse() temp_list.reverse() # 距離計算 (速度計算で使用するので各記録ごとの変位はdistance_list[km]で記憶) for j in range(len(lat)-1): distance_list.append(R * math.acos(math.sin(math.radians(lat[j]))*math.sin(math.radians(lat[j+1])) + math.cos(math.radians(lat[j]))*math.cos(math.radians(lat[j+1]))*math.cos(math.radians(lon[j]-lon[j+1])))) for k in range(len(distance_list)): distance = distance + distance_list[k] # 速度計算 (km/h) for l in range(len(distance_list)): dt = time_list[l+1]-time_list[l] speed_list.append(distance_list[l] / (dt.total_seconds()/3600)) count += 1 # 2回目以降の処理 (データの追加および更新) else : time_tmp = datetime.datetime.fromtimestamp(response[0]['time']*0.001) # データが更新されているかチェック if(time_list[-1]!=time_tmp): print('\nchange') # 更新されたデータ text2 = ast.literal_eval(response[0]['content']) tmp3 = base64.b64decode(text2["payload"]) latest_data = ast.literal_eval(tmp3.decode()) # データを追加 time_list.append(time_tmp) lat.append(latest_data['lat']) lon.append(latest_data['lon']) temp_list.append(latest_data['temp']) # 距離計算 distance_list.append(R * math.acos(math.sin(math.radians(lat[-2]))*math.sin(math.radians(lat[-1])) + math.cos(math.radians(lat[-2]))*math.cos(math.radians(lat[-1]))*math.cos(math.radians(lon[-2]-lon[-1])))) distance = distance + distance_list[-1] # 速度計算 dt = time_list[-1]-time_list[-2] speed_list.append(distance_list[-1] / (dt.total_seconds()/3600)) else: print('\nno change') count += 1
グラフ化・自動化・その他機能
処理したデータに対して、matplotlibを使用してグラフを作成し、tkinterを用いてGUIとしてwindow上に表示させました。
データやそれに伴うグラフの自動更新に関してはtkinterに用意されているafter()という関数を使用して実行しました。本来ならリアルタイムでの更新を行いたかったのですが、そもそもSORACOMのGPSマルチユニットのデータ送信間隔が最短でも1分ということだったので、1分ごとに更新することにしました。
処理の流れとしては、
1分ごとにAPIでデータにアクセス → データが更新されていれば新たに取得(リストに追加) → 自動でグラフを書き換える
みたいな感じですかね。。
コードはこんな感じになりました。
# グラフ表示のためのインスタンス fig = plt.Figure() # グラフ表示 def make_graph(): global distance fig.clf() # データ取得 get_data() # 速度のグラフを描写 ax1 = fig.add_subplot(211) ax1.plot(time_list[1:], speed_list, marker="o") ax1.set_title("speed") ax1.set_ylabel("speed [km/h]") xaxis_1 = ax1.xaxis xaxis_1.set_major_formatter(DateFormatter('%H:%M')) #x軸表示を時刻のみに調整 # 距離のグラフを描写 ax2 = fig.add_subplot(212) ax2.plot(time_list[1:], sum_up(distance_list), marker="o") ax2.set_title("distance") ax2.set_ylabel("distance [km]") ax2.set_xlabel("time [s]") xaxis_2 = ax2.xaxis xaxis_2.set_major_formatter(DateFormatter('%H:%M')) #x軸表示を時刻のみに調整 # 距離表示 distance_text = str(distance) total_distance_text = "total distance: " + distance_text[0:7] + " (km)" adtext_distance = tkinter.Label(root, text=total_distance_text, fg="white", bg="black", font=font1) adtext_distance.place(x=190, y=730) # 更新日時表示 update_date_text = "update date: " + str(time_list[-1].date()) + " " + str(time_list[-1].hour) + ":" + str(time_list[-1].minute) adtext_update_date = tkinter.Label(root, text=update_date_text) adtext_update_date.place(x=380, y=70) fig.canvas.draw() root.after(timespan*1000,make_graph)
その他機能として、LINEへの通知も実装しました。LINE notifyを使用して、取得した速度や温度がある閾値を超えたら通知が行くようにしたのですが、LINE notifyの手軽さに感動しました。すごいですね、これ。
LINEで通知してくれる関数は以下のように定義しました。
# LINEで通知 def alert(): global speed_list, temp_list, alert_count # LINE通知の準備 #alert_count = 0 line_notify_api = 'https://notify-api.line.me/api/notify' line_notify_token = '(取得したLINEのAPIトークン)' message_1 = 'スピード注意' message_2 = '熱中症注意: 少し休憩しましょう' line_payload_1 = {'message': message_1} line_payload_2 = {'message': message_2} line_headers = {'Authorization': 'Bearer ' + line_notify_token} # 速度オーバーの時LINEで通知 if(speed_list[-1]>15): requests.post(line_notify_api, data=line_payload_1, headers=line_headers) else: print("safe speed") # 熱中症換気 if(alert_count>30 and temp_list[-1]>=30): requests.post(line_notify_api, data=line_payload_2, headers=line_headers) alert_count = 0 else: print("safe temperature") alert_count += 1
成果
さて、これでとりあえずプロトタイプ(のプロトタイプのプロトタイプ)、のようなものは完成したので実際に使ってみました。
まずはプログラムを起動させます。
$ python3 final.py
あとは、PCのスリープモード設定をoffにして(じゃないと途中で処理が停止してしまう)、GPSマルチユニットを持ってサイクリングへgo。今回は思い切って皇居の周辺まで行ってみました。
帰宅してPCの画面を確認すると、しっかり記録がとれてました!
右側がSORACOM lagoonを使用して書き出したグラフ、左側が、私がデータを処理して書き出したグラフです。速度解析に関してはもう少し工夫が必要な気もしますが、とりあえずそれっぽいのが書けました!
また、今回は、30分以上の走行と気温30度以上で熱中症注意の通知、時速15km以上で速度注意の通知をラインから送る設定にしました。加えて、SORACOM lagoonの方でも、加速度が一定以上になるとLINEで通知が送られて来る設定にしておいたので、こんな感じで結構通知が来ました。