mecobalamin’s diary

人間万事塞翁が馬

Pythonを使って人口ピラミッドのグラフを作る

Pythonのseabornを使って以下のような人口ピラミッドのグラフを作りたい
f:id:mecobalamin:20210214160109p:plain

元のデータ

man woman
90+ 47.0 128.0
80 - 89 174.0 205.0
70 - 79 321.0 257.0
60- 69 562.0 393.0
50 - 59 604.0 442.0
40 - 49 730.0 501.0
30 - 39 670.0 478.0
20 - 29 895.0 717.0
10 - 19 309.0 258.0
0 - 9 123.0 132.0
Undisclosed 2.0 8.0
Missing number NaN NaN

ここのやり方を参考にさせてもらった
Matplotlib - Python3のmatplotlibを使ったヒストグラムの作図|teratail

上記の表はdf_ageという変数に入っているとする
型はpandasのDataFrame

print(df_age)
                  man  woman
90+              47.0  128.0
80 - 89         174.0  205.0
70 - 79         321.0  257.0
60- 69          562.0  393.0
50 - 59         604.0  442.0
40 - 49         730.0  501.0
30 - 39         670.0  478.0
20 - 29         895.0  717.0
10 - 19         309.0  258.0
0 - 9           123.0  132.0
Undisclosed       2.0    8.0
Missing number    NaN    NaN

print(type(df_age))
<class 'pandas.core.frame.DataFrame'>

男女のデータを一つのヒストグラムに描画する
今回キモになっているのは以下の3つ

  1. データの反転
  2. 一つの図に2つのグラフを書き込む
  3. 軸の書き換え

データの反転
参考にしたサイトにも説明されているが
片方のデータを負の値に反転させている

df_age["man"] *= -1

一つの図に2つのグラフを書き込む
seabornのbarplotを使っている
zip関数を使ってman/womanのデータを
うまく選択している

for name, color in zip(age_names, age_colors):
        sns.barplot(x = name, y = df_age.index, 
            data = df_age, color = color, label = name,
            orient = 'h', order = df_age.index, 
            ax = ax)

zip関数
Python, zip関数の使い方: 複数のリストの要素をまとめて取得 | note.nkmk.me

軸の書き換え
横軸が負->正となるので
グラフの中心を0として
両側に正の値になるように
ラベルを書き換える

ax.set_xticklabels(['1000', '750', '500', '250', '0', '250', '500', '750', '1000'])


コードを以下にまとめる

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style = 'whitegrid')
fig, ax = plt.subplots(figsize = (7, 5))
plt.subplots_adjust(left = 0.2, right = 0.85, bottom = 0.05, top = 1.0)

df_age["man"] *= -1

age_colors = ["#4169e1", "#ff1493"]
age_names = df_age.columns

for name, color in zip(age_names, age_colors):
    sns.barplot(x = name, y = df_age.index, data = df_age, color = color, label = name,
    orient = 'h', order = df_age.index, ax = ax)

ax.set_xlabel("")
ax.set_ylabel("age", fontsize = 12)
ax.set_xlim(-1000, 1000)
ax.set_xticklabels(['1000', '750', '500', '250', '0', '250', '500', '750', '1000'])
ax.legend(loc = 'lower left')

plt.savefig('graph.png')
plt.close()