Matplotlibで大量の画像を出力すると落ちる
前回までにメモリの使用量や処理時間など、プログラムによるパソコンへの負荷の計測方法を色々と試してきました。
ということでそろそろ本題。
元々発生していた問題としては、「matplotlibで大量の画像を出力するとメモリ不足で落ちる」という問題です。
とある事情で、何万枚もの画像を出力することになった時、800枚程度のところでどうしてもプログラムが落ちてしまい、継続して処理できなかったのです。
その時に色々と調べて出てきた情報がこちら。
こちらの記事によると、「savefig」をすると、メモリリークしてしまい、どんどん使用しているメモリ量が増加してしまうという現象があるとのこと。
またその対処法としては、画像エリアの初期化は一度だけ行い(つまり fig = plt.figure()を一度だけ)、画像が重ならないように「plt.cla()」で画像エリアを消去するという方法でした。
ということで自分の環境でも同じようにメモリリークが解消されるのか試してみました。
メモリリークすることの確認
まずはメモリリークしていることを確認しましょう。
確認するために作成したプログラムがこちら。
import matplotlib.pyplot as plt
import random
import os
import datetime
import psutil
import csv
num_graph = 1000
data_num = 1000
data_range = [-10, 10]
default_dirpath = os.getcwd()
outputgraph_dirname = 'graph'
outputgraph_dirpath = os.path.join(default_dirpath, outputgraph_dirname)
start_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
output_filepath = os.path.join(default_dirpath, f'{start_time}.csv')
def randomData(min_val, max_val, data_num):
x_list = []; y_list = []
for i in range(data_num):
x_list.append(i)
y_list.append(random.randint(min_val, max_val))
return x_list, y_list
def graphMake(x_list, y_list, graph_no, output_filepath):
fig = plt.figure()
plt.clf()
plt.plot(x_list, y_list)
outputfig_filepath = os.path.join(outputgraph_dirpath, f'{graph_no}.png')
plt.savefig(outputfig_filepath)
memorySave(output_filepath)
def memorySave(output_filepath):
timenow = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S')
mem = psutil.virtual_memory()
used = mem.used
percent = mem.percent
row = [timenow, used, percent]
if not os.path.exists(output_filepath):
with open(output_filepath, 'w') as f_in:
writer = csv.writer(f_in)
header = ['Time', 'Used', 'PercentUsed']
writer.writerow(header)
writer.writerow(row)
elif os.path.exists(output_filepath):
with open(output_filepath, 'a') as f_in:
writer = csv.writer(f_in)
writer.writerow(row)
def main():
for i in range(num_graph):
x_list, y_list = randomData(data_range[0], data_range[1], data_num)
graphMake(x_list, y_list, i, output_filepath)
if __name__ == '__main__':
main()
プログラムの解説
プログラムの流れ
このプログラムでは
- プロット用の値をランダムに取得する
- グラフ化し、グラフ画像を保存する
- メモリ使用量、使用率を取得し、CSVファイルとして保存する
という流れを指定した数だけ繰り返します。
ライブラリのインポートと設定部分
import matplotlib.pyplot as plt
import random
import os
import datetime
import psutil
import csv
num_graph = 1000
data_num = 1000
data_range = [-10, 10]
default_dirpath = os.getcwd()
outputgraph_dirname = 'graph'
outputgraph_dirpath = os.path.join(default_dirpath, outputgraph_dirname)
start_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
output_filepath = os.path.join(default_dirpath, f'{start_time}.csv')
今回使用するライブラリは「matplotlib」、「random」、「os」、「datetime」、「psutil」、「csv」です。
「num_graph」がグラフを出力する枚数、「data_num」は一つのグラフに出力する値の数、「data_range」は値の最大値、最小値でこの間の数字がランダムに取得されます。
そして「outputgraph_dirpath」でグラフを出力するパスの設定、「start_time」で開始時間の取得、「output_filepath」でメモリ使用量、使用率を出力するファイルパスの設定を行なっています。
randomData関数
こちらは指定した数だけ、ランダムにX値、Y値を作成し、そのリストを返す関数です。
def randomData(min_val, max_val, data_num):
x_list = []; y_list = []
for i in range(data_num):
x_list.append(i)
y_list.append(random.randint(min_val, max_val))
return x_list, y_list
graphMake関数
こちらは作成したX値のリスト、Y値のリストからグラフを作成し、保存する関数です。
def graphMake(x_list, y_list, graph_no, output_filepath):
fig = plt.figure()
plt.clf()
plt.plot(x_list, y_list)
outputfig_filepath = os.path.join(outputgraph_dirpath, f'{graph_no}.png')
plt.savefig(outputfig_filepath)
memorySave(output_filepath)
最後に「memorySave関数」を呼び出すことで、メモリの使用量、使用率を保存しています。
memorySave関数
こちらはメモリの使用率、使用量をCSVファイルとして保存するための関数です。
def memorySave(output_filepath):
timenow = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S')
mem = psutil.virtual_memory()
used = mem.used
percent = mem.percent
row = [timenow, used, percent]
if not os.path.exists(output_filepath):
with open(output_filepath, 'w') as f_in:
writer = csv.writer(f_in)
header = ['Time', 'Used', 'PercentUsed']
writer.writerow(header)
writer.writerow(row)
elif os.path.exists(output_filepath):
with open(output_filepath, 'a') as f_in:
writer = csv.writer(f_in)
writer.writerow(row)
この関数に関しては、こちらの記事で解説していますので、良かったらどうぞ。
main関数
最後にmain関数ですが、指定したグラフの枚数分だけ、上記のrandomData関数とgraphMake関数を実行しています。
def main():
for i in range(num_graph):
x_list, y_list = randomData(data_range[0], data_range[1], data_num)
graphMake(x_list, y_list, i, output_filepath)
if __name__ == '__main__':
main()
実行してみた結果
上記のプログラムを実行して、出力されたメモリの使用量、使用率のデータをこちらの記事で作成したプログラムでグラフ化してみました。
メモリ使用量も使用率も70秒あたりまで直線的に伸びていき、その後メモリ使用量は減少しているのに、メモリ使用率は一旦下がって増加するを繰り返しています。
メモリ使用率も使用量も同じように変動するはずですが、もしかしたらMacOSとしてメモリ使用量が多くなりすぎた場合は調整が入って、見かけ上使用率が落ちるなんてことがあるのかもしれません。
細かい内容は分かりませんが、重要なのはメモリ使用率としては、40%未満から80%近くまで増加しているということです。
今回のプログラムの枚数では落ちませんでしたが、このままグラフ出力が続くとそのうちにメモリがパンクして落ちるということが予想されます。
メモリリークしないだろうプログラム
次に最初に紹介した記事にあった「画像エリアの初期化は一度だけ行い(つまり fig = plt.figure()を一度だけ)、画像が重ならないように「plt.cla()」で画像エリアを消去するという方法」を試してみましょう。
プログラムの変更点
変更するのはgraphMake関数とmain関数の部分です。
変更後のgraphMake関数
def graphMake(x_list, y_list, graph_no, output_filepath):
plt.cla()
plt.plot(x_list, y_list)
outputfig_filepath = os.path.join(outputgraph_dirpath, f'{graph_no}.png')
plt.savefig(outputfig_filepath)
memorySave(output_filepath)
「fig = plt.figure()」を削除し、グラフエリア全体をクリアする「plt.clf()」の代わりに、グラフエリアの特定の軸データをクリアする「plt.cla()」を追加します。
変更後のmain関数
def main():
fig = plt.figure()
for i in range(num_graph):
x_list, y_list = randomData(data_range[0], data_range[1], data_num)
graphMake(x_list, y_list, i, output_filepath)
if __name__ == '__main__':
main()
main関数には、graphMake関数で削除した「fig = plt.figure()」をfor文より前に追加します。
これでグラフエリアは最初に一回だけ初期化され、グラフエリアをクリアして再利用されます。
実行してみた結果
実行してみた結果がこちらです。
こちらの場合はメモリ使用量とメモリ使用率が連動している形になりました。
そしてメモリ使用率としては、開始時が31%程度、終了時が33.5%程度とほとんど増加していないのが分かります。
ということでこれでメモリリークによるプログラムが落ちる現象は防げそうです。
plt.cla()の代わりにplt.clf()は使えないのか?
ここでふと疑問に思ったのは、グラフエリアの特定の軸のデータをクリアするplt.cla()ではなく、グラフエリア全体をクリアするplt.clf()は使えないのかということです。
plt.clf()でもグラフエリア自体は確保されたままで、プロットが消されるだけなので、多分使えるのではないでしょうか。
ということで試してみました。
プログラムの変更点
変更後のgraphMake関数
def graphMake(x_list, y_list, graph_no, output_filepath):
plt.clf()
plt.plot(x_list, y_list)
outputfig_filepath = os.path.join(outputgraph_dirpath, f'{graph_no}.png')
plt.savefig(outputfig_filepath)
memorySave(output_filepath)
先ほど変更したplt.cla()を元に戻して、plt.clf()にしました。
実行してみた結果
実行してみた結果がこちらです。
メモリ使用量自体は細かく変動していますが、変動幅が小さいため、メモリ使用率は40.7%と40.8%を行ったりきたりしています。
つまりplt.clf()でも問題なく使えるということで、重要なのは「fig = plt.figure()」を繰り返さない、つまりグラフエリアの作成を繰り返さないということのようです。
これで大量のグラフを出力できるようになりました。
ただ大量のグラフを出力するには、時間が掛かってしまいます。
そこで次回は処理時間を短縮するため、並列処理のやり方を勉強してみましょう。
ではでは今回はこんな感じで。
コメント