Pythonで階層的クラスタリング
研究で階層的クラスタリングを扱う機会があり、いろんな記事を読みながら実装し、自分がよく使う内容を実装してまとめてみた。
(Qiitaに投稿するか迷ったけど、Tips的内容でもないし、この手のものは個人のほうがいいのかなと)
クラスタリングに関しては以下の論文がわかりやすい
実装には主にScipyを使用した。Spicyは科学技術計算系のモジュールで、研究でもフーリエ変換とかヒルベルト変換とかいろいろなところでお世話になってます
階層的クラスタリングをPythonで実装するのに見やすい記事を3つほど
特に一番下のSciPy Hierarchical Clustering and Dendrogram Tutorialは非常に実装するのに分かりやすい記事なので是非とも参考にして欲しい(特に、fancy_dendrogramって関数がデンドログラムを超キレイに描いてくれるから私も使用している)
私がしたかった分析は、階層的クラスタリングの、どの枝にどのラベルのデータがどのくらいの割合入っているのかを可視化させたかった
パッとわかる方法はないかなぁと考えて、結局ヒストグラム形式で可視化できるようにした。
def view_breakdown(label: np.ndarray, pred: np.ndarray, kind: str = "bar", name: str = False) -> pd.DataFrame: """ クラスタリングやクラス分類結果の内訳を表示 :param label: もとのラベル :param pred: 結果 :param kind: bar or barh :return: 内訳のdf """ df_index = sorted(list(set(pred)), key=int) df_columns = sorted(list(set(label))) report = pd.DataFrame(index=df_index, columns=df_columns).fillna(0) for i in range(0, len(pred)): report.ix[pred[i], label[i]] += 1 print(report) ratio = (report / report.sum()) ratio.plot(kind=kind) if kind == 'bar': plt.ylim(0, 1) plt.ylabel("Proportion") plt.xlabel("Cluster ID") else: plt.xlim(0, 1) plt.xlabel("Proportion") plt.ylabel("cluster ID") if name is not False: plt.savefig(name) plt.show() return report
クラスタリング対象のデータラベルと、クラスタリング結果のクラスタ番号を比較して、クラスタ○○にはラベル○○が何パーセントあるよっていうのをヒストグラムで可視化させている
試しに行った結果がこれ。データセットは有名なirisデータ。pythonならscikit-learnのモジュールでよびだせる。 特徴量は4次元で、各50個ずつ3種類のデータセットである
説明をすると、一番左の枝(クラスタ番号1)は、ラベル0のデータ100%で構成されていて、一番右の枝(クラスタ番号3)は、ラベル1のデータが95%くらいとラベル2のデータが30%くらいで構成されてるよと見れる。
コード全体はGithubに。