mecobalamin’s diary

人間万事塞翁が馬

https://help.hatenablog.com/entry/developer-option

pythonでirisのデータセットを使う

pythonでもirisのデータセットを使える
mecobalamin.hatenablog.com

こちらのサイトで紹介されている
Sklearnを使ってみる1 - ぴろの狂人日記
iris以外のデータセットも使える
scikit-learnのサンプルデータセットの一覧と使い方 | note.nkmk.me


今回はirisのデータセットの内容を確認し
統計処理を行ってグラフを作成する


まずデータセットを確認してみた
irisのデータセットはscikit-learnに含まれている
そこでscikit-learnのdatasetsをインポートする

from sklearn import datasets

irisのデータセットを読み込む

iris = datasets.load_iris()

データセットの型はsklearn.utils.Bunchで
辞書の書式でデータが記録されている

print(type(iris))
print(iris.keys())
print(iris.values())

結果がこちら

<class 'sklearn.utils.Bunch'>
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
dict_values([array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
~~~省略~~~
'petal length (cm)', 'petal width (cm)'], 
'D:\\Python\\Python37-32\\lib\\site-packages\\sklearn\\datasets\\data\\iris.csv'])

データセットについての説明もある

print(iris.DESCR)

データの統計値もある

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================


irisのデータセットをpandasのデータフレームに変換する
その時dataのカラム名をfeature_namesを当てる
またtargetをirisの種名に変えてdataのカラムに追加する

df = pd.DataFrame(iris.data, columns = iris.feature_names)
df['target'] = iris.target_names[iris.target]
print(df.head())

出力結果はこんな感じ
(見辛いのではてな記法で表組みした)

sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa

~~~以下略~~~

統計値を確認

print(df.mean())

当然だけどiris.DESCRで表示される結果と同じになる
繰り上がりがちょっと変な気がする

sepal length (cm)    5.843333
sepal width (cm)     3.057333
petal length (cm)    3.758000
petal width (cm)     1.199333
dtype: float64

統計値をまとめて出すこともできる

print({}.format(df.describe().T))

比較しやすいように行と列を入れ替えてある

                   count      mean       std  min  25%   50%  75%  max
sepal length (cm)  150.0  5.843333  0.828066  4.3  5.1  5.80  6.4  7.9
sepal width (cm)   150.0  3.057333  0.435866  2.0  2.8  3.00  3.3  4.4
petal length (cm)  150.0  3.758000  1.765298  1.0  1.6  4.35  5.1  6.9
petal width (cm)   150.0  1.199333  0.762238  0.1  0.3  1.30  1.8  2.5

pandasのgoupbyで種ごとの統計値を計算する

df_grouped = df.groupby(['target'])
print(df_grouped.mean())
            sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
target                                                               
setosa                  5.006             3.428              1.462             0.246
versicolor              5.936             2.770              4.260             1.326
virginica               6.588             2.974              5.552             2.026


結果を棒グラフとヒートマップにして
画像として保存する

current_dpi = mpl.rcParams['figure.dpi']
print(current_dpi)

path = os.getcwd()
path = os.chdir(os.path.dirname(os.path.abspath(__file__)))
path = os.getcwd()

plt.figure()
df_grouped.mean().T.plot(kind = 'bar', yerr = df_grouped.std().T, rot = 0)
plt.savefig(path + '\\' + 'bar_graph_mean.png', dpi = current_dpi * 1.5)
plt.close()

plt.figure()
plt.figure(figsize=(8, 6))
sns.heatmap(df_grouped.mean().T, square = True, cmap = 'plasma')
plt.savefig(path + '\\' + 'heatmap_mean.png', dpi = current_dpi * 1.5)
plt.close()

グラフは以下の通り

ヒートマップのカラーマップは以下のサイトを参考にした
https://matplotlib.org/tutorials/colors/colormaps.html#grayscale-conversion


実際のコードはこんな感じ

import os

from pandas import Series, DataFrame
import pandas as pd

from sklearn import datasets

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option('display.max_columns', 10)

iris = datasets.load_iris()
print(type(iris))
print(iris.keys())
print(iris.DESCR)

df = pd.DataFrame(iris.data, columns = iris.feature_names)
df['target'] = iris.target_names[iris.target]

print(df.head())
print(df.columns)

df_grouped = df.groupby(['target'])
df_mean = df_grouped.mean()
print(df_mean)

current_dpi = mpl.rcParams['figure.dpi']
print(current_dpi)

path = os.getcwd()
path = os.chdir(os.path.dirname(os.path.abspath(__file__)))
path = os.getcwd()

plt.figure()
df_grouped.mean().T.plot(kind = 'bar', yerr = df_grouped.std().T, rot = 0)
plt.savefig(path + '\\' + 'bar_graph_mean.png', dpi = current_dpi * 1.5)
plt.close()

plt.figure(figsize=(8, 6))
sns.heatmap(df_grouped.mean().T, square = True, cmap = 'plasma')
plt.savefig(path + '\\' + 'heatmap_mean.png', dpi = current_dpi * 1.5)
plt.close()