珈琲とバイクとときどきプログラミング

狐が好きな編入生的な人間のブログです

Deep Count: Fruit Counting Based on Deep Simulated Learning の プログラムとか

Deep Count: Fruit Counting Based on Deep Simulated Learning

この前の記事で読んだとかいってた論文を見ながら実際にやってみて、精度でるんか? みたいなものを一旦出したいとのことなので手を動かしてみた(pythonクソ初心者です)

www.mdpi.com

環境

python 3.6.3

やったこと

花実の数を数えるにあたり、データセット(画像+写ってる花実の数)を用意するのがとりあえずめんどいから、synthetic imagesで代用したよ〜との記述があったのですが、

そのプログラムとか一ミリもなかったので、とりあえずpython練習がてらプログラムを書きました

f:id:yosshi0774:20180429021643p:plain

jupyter notebookとかで実行するのを考えてるのでガバガバです。

numpyとかのおかげでかなり簡単にできるみたいですね。

別のプログラムで使うことを考えてcsvに結果(ファイル名、花実数)を吐き出すようにしました

import numpy as np
import matplotlib.pyplot as plt
import cv2
import csv
from numpy.random import *

# 画像の幅
width = 130
# 最小の花実の直径
min_r = 5
# 最大の花実の直径
max_r = 13
# 最大の花実の個数
max_count = 20
# 最低の花実の個数
min_count = 0
# ベースとなる緑(RGB)(乱数により揺れるので、適当で大丈夫)
green_base = (72, 107, 12)
# ベースとなる茶色
brown_base = (120, 101, 86)
# 花の色
flower_color =  (237, 223, 29)
# 作る画像の枚数
create_count = 1000
# データの保存先
data_path = "./data/dst/"
# 画像の名前
image_name = "train_image"
# 拡張子
extension = ".jpg"
# 背景画像
back_image = np.full((height, width, 3), 0, dtype=np.uint8)
# 背景色のベースカラー(緑、茶色)
color_pattern = np.array([green_base, brown_base])
# 10飛びのarrayを作成
list = np.array(range(0,width + 1,10))


with open('train_images.csv', 'w') as f:
    writer = csv.writer(f)
    # headerの書き込み
    writer.writerow(["file_name", "count"])
    for i in range(create_count):
        # 半径10の円で 埋めていく
        for x in list:
            for y in list:
                cv2.circle(back_image, (x, y), 10, (color_pattern[randint(0, 2)] + randint(0, 20,(1,3))[0]).tolist(), thickness=-1)

        # ガウシアンフィルターを書けて平滑化
        blured_back_image = cv2.GaussianBlur(back_image, (7, 7), 0)

        # 花実の数を生成
        flower_count = randint(min_count , max_count)

        for j in range(flower_count):
            #花実の大きさをランダムで生成
            flower_radius = randint(min_r, max_r)
            position = (randint(0, width + 1), randint(0, width + 1))
            cv2.circle(blured_back_image, position, flower_radius, flower_color, thickness= -1)

        # bgr になっているので rgb に変換する
        im_rgb = cv2.cvtColor(blured_back_image, cv2.COLOR_BGR2RGB)
        
        save_name = image_name + str(i) + extension
        cv2.imwrite(data_path + save_name, im_rgb)
        writer.writerow([save_name, flower_count])

色とかも適当にベースカラーきめてrandで適当なvarianceを付けるみたなことをしてて

ほんとにこれで大丈夫か?みたいなところはあります。

実行結果

画像達は一応できてそうです

f:id:yosshi0774:20180429022707p:plain

csvもfile名とそれの花実の数を吐き出されていますね。

$ head -n 10 train_images.csv
file_name,count
train_image0.jpg,9
train_image1.jpg,16
train_image2.jpg,11
train_image3.jpg,11
train_image4.jpg,2
train_image5.jpg,11
train_image6.jpg,17
train_image7.jpg,15
train_image8.jpg,18

研究が環境系なのでAWSとかへの知見がなく、研究費がおりないとか言われたので、一旦Google Colab とかでtrainは走らせたいと思います

自分でマシン組めとか言われましたが、流石にしんどいので(時間あればやりたいですが...)

https://colab.research.google.com/

google drive とかに↑で作った画像とcsv置いて、API叩けばgoogle colab上にfetchしてこれるみたいなので、それでいいかなーと思ってます