[Python] その2の2 ボストンの住宅価格を重回帰分析ver2で予測してみた
重回帰分析を以前やりましたが、 説明変数はマニュアル方式で 集計結果のグラフを見ながら選んでいました。
今回はバージョンアップして全ての説明変数の組み合わせを 試して一番精度がよいモデルを選択するプログラムを組みました。
精度の評価方法は色々ありますが、本記事ではMean Absolute Error(MAE)の値が 一番小さいモデルをベストモデルとして採用するようにしています。
※ MAE = 日本語では平均絶対誤差と呼ぶそうです。分析のコンペなどでたまに 評価方法で採用されています。
Checking Boston housing prices corrected dataset 2_2
パッケージのインポート
import packages
import seaborn as sns import pandas as pd import numpy as np
データの読み込み
reading data
df = pd.read_csv("http://lib.stat.cmu.edu/datasets/boston_corrected.txt",skiprows=9,sep="\t")
※ 22年2月1日更新: UnicodeDecodeErrorが発生する方はencoding='Windows-1252'を追加してください。 詳細は下記新ブログをご参照ください。
pd.DataFrame(df.columns)
0 | |
---|---|
0 | OBS. |
1 | TOWN |
2 | TOWN# |
3 | TRACT |
4 | LON |
5 | LAT |
6 | MEDV |
7 | CMEDV |
8 | CRIM |
9 | ZN |
10 | INDUS |
11 | CHAS |
12 | NOX |
13 | RM |
14 | AGE |
15 | DIS |
16 | RAD |
17 | TAX |
18 | PTRATIO |
19 | B |
20 | LSTAT |
df.head(3)
OBS. | TOWN | TOWN# | TRACT | LON | LAT | MEDV | CMEDV | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | Nahant | 0 | 2011 | -70.955 | 42.2550 | 24.0 | 24.0 | 0.00632 | 18.0 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 |
1 | 2 | Swampscott | 1 | 2021 | -70.950 | 42.2875 | 21.6 | 21.6 | 0.02731 | 0.0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 |
2 | 3 | Swampscott | 1 | 2022 | -70.936 | 42.2830 | 34.7 | 34.7 | 0.02729 | 0.0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 |
モデリング
重回帰分析でやるが、変数は組み合わせをすべて試して一番精度のよいモデルを選択する
サンプリング (訓練データ、テストデータ分割)
# 訓練データとテストデータに分割する。 from sklearn.model_selection import train_test_split train, test = train_test_split(df, test_size=0.20,random_state=100)
# 訓練データの件数確認 train.count()["OBS."]
404
# テストデータの件数確認 test.count()["OBS."]
102
訓練データ、テストデータ作成
# 分析に利用する変数に限定 # 本当だったら事前に変数選択で利用するカラムを限定しておく anacols=[ "CRIM" # 1人当たりの犯罪数 , "ZN" #町別の25,000平方フィート(7600m2)以上の住居区画の割合 , "INDUS" #町別の非小売業が占める土地面積の割合 , "CHAS" #チャールズ川沿いかどうか , "NOX" #町別の窒素酸化物の濃度 , "RM" #住居の平均部屋数 , "AGE" #持ち家住宅 , "DIS" #5つのボストン雇用センターへの重み付き距離 , "RAD" #町別の環状高速道路へのアクセスのしやすさ , "TAX" #町別の$10,000ドルあたりの固定資産税率 , "PTRATIO" #町別の生徒と先生の比率 , "B" #1000*(黒人人口割合 - 0.63)2 , "LSTAT" #貧困人口割合 ] # 訓練データ X_train = train[anacols] # 説明変数 Y_train=train["CMEDV"] # 目的変数 # テストデータ X_test = test[anacols] # 説明変数 Y_test=test["CMEDV"] # 目的変数 # 欠損処理 # nullがあれば0埋めする。平均値や最頻値でもいい X_train = X_train.fillna(0) Y_train = Y_train.fillna(0) X_test = X_test.fillna(0) Y_test = Y_test.fillna(0)
X_train
CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
379 | 17.86670 | 0.0 | 18.10 | 0 | 0.6710 | 6.223 | 100.0 | 1.3861 | 24 | 666 | 20.2 | 393.74 | 21.78 |
311 | 0.79041 | 0.0 | 9.90 | 0 | 0.5440 | 6.122 | 52.8 | 2.6403 | 4 | 304 | 18.4 | 396.90 | 5.98 |
157 | 1.22358 | 0.0 | 19.58 | 0 | 0.6050 | 6.943 | 97.4 | 1.8773 | 5 | 403 | 14.7 | 363.43 | 4.59 |
244 | 0.20608 | 22.0 | 5.86 | 0 | 0.4310 | 5.593 | 76.5 | 7.9549 | 7 | 330 | 19.1 | 372.49 | 12.50 |
56 | 0.02055 | 85.0 | 0.74 | 0 | 0.4100 | 6.383 | 35.7 | 9.1876 | 2 | 313 | 17.3 | 396.90 | 5.77 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
343 | 0.02543 | 55.0 | 3.78 | 0 | 0.4840 | 6.696 | 56.4 | 5.7321 | 5 | 370 | 17.6 | 396.90 | 7.18 |
359 | 4.26131 | 0.0 | 18.10 | 0 | 0.7700 | 6.112 | 81.3 | 2.5091 | 24 | 666 | 20.2 | 390.74 | 12.67 |
323 | 0.28392 | 0.0 | 7.38 | 0 | 0.4930 | 5.708 | 74.3 | 4.7211 | 5 | 287 | 19.6 | 391.13 | 11.74 |
280 | 0.03578 | 20.0 | 3.33 | 0 | 0.4429 | 7.820 | 64.5 | 4.6947 | 5 | 216 | 14.9 | 387.31 | 3.76 |
8 | 0.21124 | 12.5 | 7.87 | 0 | 0.5240 | 5.631 | 100.0 | 6.0821 | 5 | 311 | 15.2 | 386.63 | 29.93 |
404 rows × 13 columns
Y_train
379 10.2 311 22.1 157 41.3 244 17.6 56 24.7 ... 343 23.9 359 22.6 323 18.5 280 45.4 8 16.5 Name: CMEDV, Length: 404, dtype: float64
モデル作成 (精度が一番よいモデルを探索する)
import sys import itertools from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_absolute_error comblist=[] best_model=None best_features=None best_mae=sys.maxsize # 変数の選択数 (1 ~ 最大選択可能数) for i in range(1,len(anacols) + 1): # 変数の選択数に合わせた組み合わせを作成 comblist = list(itertools.combinations(anacols,i)) for featurecomb in comblist: # 重回帰モデル作成 multi_regression = LinearRegression() multi_regression.fit(X_train[list(featurecomb)],Y_train) # テストデータに当てはめる yhat_test = multi_regression.predict(X_test[list(featurecomb)]) # 精度(MAE) 他にも様々な評価方法がある mae = mean_absolute_error(Y_test, yhat_test) #一番よい精度のモデルを探索 if mae < best_mae: best_mae = mae best_features = featurecomb best_model = multi_regression print(str(best_mae)) print(best_features)
3.1841915998253896 ('CRIM', 'ZN', 'NOX', 'RM', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT')
作成したモデルの確認
# 係数逆転現象の確認 pd.DataFrame({"name":X_train[list(best_features)].columns,"coefficients":best_model.coef_})
name | coefficients | |
---|---|---|
0 | CRIM | -0.087462 |
1 | ZN | 0.050444 |
2 | NOX | -15.698495 |
3 | RM | 3.656426 |
4 | DIS | -1.530007 |
5 | RAD | 0.312503 |
6 | TAX | -0.013988 |
7 | PTRATIO | -0.936410 |
8 | B | 0.009886 |
9 | LSTAT | -0.504771 |
ボストン雇用センターへの重み付き距離 (DIS) と 固定資産税 (TAX) の係数
がマイナスなのは少し気になるところ。
各説明変数と住宅価格の関係を見た限りの考察としては、
DISは距離が3マイル?以内までは正の強い相関がありそうで、
3マイル以上は無相関のような分布をしている。
TAXは一部の固定資産税が高い地域の住宅価格が低めなのが
負の相関となるような影響を与えていると思われる。
精度上は一番よいのだが、私だったらこのモデルは選ばず
2番目に精度がよい他のモデルを選択します
精度確認
# 自由度調整済みr2を算出 def adjusted_r2(X,Y,model): from sklearn.metrics import r2_score import numpy as np r_squared = r2_score(Y, model.predict(X)) adjusted_r2 = 1 - (1-r_squared)*(len(Y)-1)/(len(Y)-X.shape[1]-1) return adjusted_r2 # 予測モデルの精度確認の各種指標を算出 def get_model_evaluations(X_train,Y_train,X_test,Y_test,model): from sklearn.metrics import explained_variance_score from sklearn.metrics import mean_absolute_error from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_log_error from sklearn.metrics import median_absolute_error # 評価指標確認 # 参考: https://funatsu-lab.github.io/open-course-ware/basic-theory/accuracy-index/ yhat_test = model.predict(X_test) return "adjusted_r2(train) :" + str(adjusted_r2(X_train,Y_train,model)) \ , "adjusted_r2(test) :" + str(adjusted_r2(X_test,Y_test,model)) \ , "平均誤差率(test) :" + str(np.mean(abs(Y_test / yhat_test - 1))) \ , "MAE(test) :" + str(mean_absolute_error(Y_test, yhat_test)) \ , "MedianAE(test) :" + str(median_absolute_error(Y_test, yhat_test)) \ , "RMSE(test) :" + str(np.sqrt(mean_squared_error(Y_test, yhat_test))) \ , "RMSE(test) / MAE(test) :" + str(np.sqrt(mean_squared_error(Y_test, yhat_test)) / mean_absolute_error(Y_test, yhat_test)) #better if result = 1.253
get_model_evaluations(X_train[list(best_features)],Y_train,X_test[list(best_features)],Y_test,best_model)
('adjusted_r2(train) :0.7224056607466813', 'adjusted_r2(test) :0.7386023143239528', '平均誤差率(test) :0.15322170833625723', 'MAE(test) :3.1841915998253896', 'MedianAE(test) :2.501378017150664', 'RMSE(test) :4.7670541453730815', 'RMSE(test) / MAE(test) :1.4971002830465638')
trainとtestの調整済みR2が同じくらいなのでオーバーフィットはしてなさそう。
# 描画設定 from matplotlib import rcParams rcParams['xtick.labelsize'] = 12 # x軸のラベルのフォントサイズ rcParams['ytick.labelsize'] = 12 # y軸のラベルのフォントサイズ rcParams['figure.figsize'] = 18,8 # 画像サイズの変更(inch) import matplotlib.pyplot as plt from matplotlib import ticker sns.set_style("whitegrid") # seabornのスタイルセットの一つ sns.set_color_codes() # デフォルトカラー設定 (deepになってる) plt.figure() ax = sns.regplot(x=Y_test, y=best_model.predict(X_test[list(best_features)]), fit_reg=False,color='#4F81BD') ax.set_xlabel(u"CMEDV") ax.set_ylabel(u"(Predicted) CMEDV") ax.get_xaxis().set_major_formatter(ticker.FuncFormatter(lambda x, p: format(int(x), ','))) ax.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda y, p: format(int(y), ','))) ax.plot([0,10,20,30,40,50],[0,10,20,30,40,50], linewidth=2, color="#C0504D",ls="--")
利用バージョンの確認
!python -V
Python 3.6.9
!pip3 -V
pip 19.3.1 from /usr/local/lib/python3.6/dist-packages/pip (python 3.6)
!pip3 freeze
absl-py==0.10.0 alabaster==0.7.12 albumentations==0.1.12 altair==4.1.0 argon2-cffi==20.1.0 asgiref==3.3.1 astor==0.8.1 astropy==4.1 astunparse==1.6.3 async-generator==1.10 atari-py==0.2.6 atomicwrites==1.4.0 attrs==20.3.0 audioread==2.1.9 autograd==1.3 Babel==2.9.0 backcall==0.2.0 beautifulsoup4==4.6.3 bleach==3.2.1 blis==0.4.1 bokeh==2.1.1 Bottleneck==1.3.2 branca==0.4.1 bs4==0.0.1 CacheControl==0.12.6 cachetools==4.2.0 catalogue==1.0.0 certifi==2020.12.5 cffi==1.14.4 chainer==7.4.0 chardet==3.0.4 click==7.1.2 cloudpickle==1.3.0 cmake==3.12.0 cmdstanpy==0.9.5 colorlover==0.3.0 community==1.0.0b1 contextlib2==0.5.5 convertdate==2.2.0 coverage==3.7.1 coveralls==0.5 crcmod==1.7 cufflinks==0.17.3 cvxopt==1.2.5 cvxpy==1.0.31 cycler==0.10.0 cymem==2.0.5 Cython==0.29.21 daft==0.0.4 dask==2.12.0 dataclasses==0.8 datascience==0.10.6 debugpy==1.0.0 decorator==4.4.2 defusedxml==0.6.0 descartes==1.1.0 dill==0.3.3 distributed==1.25.3 Django==3.1.4 dlib==19.18.0 dm-tree==0.1.5 docopt==0.6.2 docutils==0.16 dopamine-rl==1.0.5 earthengine-api==0.1.238 easydict==1.9 ecos==2.0.7.post1 editdistance==0.5.3 en-core-web-sm==2.2.5 entrypoints==0.3 ephem==3.7.7.1 et-xmlfile==1.0.1 fa2==0.3.5 fancyimpute==0.4.3 fastai==1.0.61 fastdtw==0.3.4 fastprogress==1.0.0 fastrlock==0.5 fbprophet==0.7.1 feather-format==0.4.1 filelock==3.0.12 firebase-admin==4.4.0 fix-yahoo-finance==0.0.22 Flask==1.1.2 flatbuffers==1.12 folium==0.8.3 future==0.16.0 gast==0.3.3 GDAL==2.2.2 gdown==3.6.4 gensim==3.6.0 geographiclib==1.50 geopy==1.17.0 gin-config==0.4.0 glob2==0.7 google==2.0.3 google-api-core==1.16.0 google-api-python-client==1.7.12 google-auth==1.17.2 google-auth-httplib2==0.0.4 google-auth-oauthlib==0.4.2 google-cloud-bigquery==1.21.0 google-cloud-bigquery-storage==1.1.0 google-cloud-core==1.0.3 google-cloud-datastore==1.8.0 google-cloud-firestore==1.7.0 google-cloud-language==1.2.0 google-cloud-storage==1.18.1 google-cloud-translate==1.5.0 google-colab==1.0.0 google-pasta==0.2.0 google-resumable-media==0.4.1 googleapis-common-protos==1.52.0 googledrivedownloader==0.4 graphviz==0.10.1 grpcio==1.32.0 gspread==3.0.1 gspread-dataframe==3.0.8 gym==0.17.3 h5py==2.10.0 HeapDict==1.0.1 holidays==0.10.4 holoviews==1.13.5 html5lib==1.0.1 httpimport==0.5.18 httplib2==0.17.4 httplib2shim==0.0.3 humanize==0.5.1 hyperopt==0.1.2 ideep4py==2.0.0.post3 idna==2.10 image==1.5.33 imageio==2.4.1 imagesize==1.2.0 imbalanced-learn==0.4.3 imblearn==0.0 imgaug==0.2.9 importlib-metadata==3.3.0 importlib-resources==3.3.0 imutils==0.5.3 inflect==2.1.0 iniconfig==1.1.1 intel-openmp==2021.1.1 intervaltree==2.1.0 ipykernel==4.10.1 ipython==5.5.0 ipython-genutils==0.2.0 ipython-sql==0.3.9 ipywidgets==7.5.1 itsdangerous==1.1.0 jax==0.2.7 jaxlib==0.1.57+cuda101 jdcal==1.4.1 jedi==0.17.2 jieba==0.42.1 Jinja2==2.11.2 joblib==1.0.0 jpeg4py==0.1.4 jsonschema==2.6.0 jupyter==1.0.0 jupyter-client==5.3.5 jupyter-console==5.2.0 jupyter-core==4.7.0 jupyterlab-pygments==0.1.2 kaggle==1.5.10 kapre==0.1.3.1 Keras==2.4.3 Keras-Preprocessing==1.1.2 keras-vis==0.4.1 kiwisolver==1.3.1 knnimpute==0.1.0 korean-lunar-calendar==0.2.1 librosa==0.6.3 lightgbm==2.2.3 llvmlite==0.31.0 lmdb==0.99 lucid==0.3.8 LunarCalendar==0.0.9 lxml==4.2.6 Markdown==3.3.3 MarkupSafe==1.1.1 matplotlib==3.2.2 matplotlib-venn==0.11.6 missingno==0.4.2 mistune==0.8.4 mizani==0.6.0 mkl==2019.0 mlxtend==0.14.0 more-itertools==8.6.0 moviepy==0.2.3.5 mpmath==1.1.0 msgpack==1.0.1 multiprocess==0.70.11.1 multitasking==0.0.9 murmurhash==1.0.5 music21==5.5.0 natsort==5.5.0 nbclient==0.5.1 nbconvert==5.6.1 nbformat==5.0.8 nest-asyncio==1.4.3 networkx==2.5 nibabel==3.0.2 nltk==3.2.5 notebook==5.3.1 np-utils==0.5.12.1 numba==0.48.0 numexpr==2.7.1 numpy==1.19.4 nvidia-ml-py3==7.352.0 oauth2client==4.1.3 oauthlib==3.1.0 okgrade==0.4.3 opencv-contrib-python==4.1.2.30 opencv-python==4.1.2.30 openpyxl==2.5.9 opt-einsum==3.3.0 osqp==0.6.1 packaging==20.8 palettable==3.3.0 pandas==1.1.5 pandas-datareader==0.9.0 pandas-gbq==0.13.3 pandas-profiling==1.4.1 pandocfilters==1.4.3 panel==0.9.7 param==1.10.0 parso==0.7.1 pathlib==1.0.1 patsy==0.5.1 pexpect==4.8.0 pickleshare==0.7.5 Pillow==7.0.0 pip-tools==4.5.1 plac==1.1.3 plotly==4.4.1 plotnine==0.6.0 pluggy==0.7.1 portpicker==1.3.1 prefetch-generator==1.0.1 preshed==3.0.5 prettytable==2.0.0 progressbar2==3.38.0 prometheus-client==0.9.0 promise==2.3 prompt-toolkit==1.0.18 protobuf==3.12.4 psutil==5.4.8 psycopg2==2.7.6.1 ptyprocess==0.6.0 py==1.10.0 pyarrow==0.14.1 pyasn1==0.4.8 pyasn1-modules==0.2.8 pycocotools==2.0.2 pycparser==2.20 pyct==0.4.8 pydata-google-auth==1.1.0 pydot==1.3.0 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 pyemd==0.5.1 pyglet==1.5.0 Pygments==2.6.1 pygobject==3.26.1 pymc3==3.7 PyMeeus==0.3.7 pymongo==3.11.2 pymystem3==0.2.0 PyOpenGL==3.1.5 pyparsing==2.4.7 pyrsistent==0.17.3 pysndfile==1.3.8 PySocks==1.7.1 pystan==2.19.1.1 pytest==3.6.4 python-apt==1.6.5+ubuntu0.4 python-chess==0.23.11 python-dateutil==2.8.1 python-louvain==0.14 python-slugify==4.0.1 python-utils==2.4.0 pytz==2018.9 pyviz-comms==0.7.6 PyWavelets==1.1.1 PyYAML==3.13 pyzmq==20.0.0 qtconsole==5.0.1 QtPy==1.9.0 regex==2019.12.20 requests==2.23.0 requests-oauthlib==1.3.0 resampy==0.2.2 retrying==1.3.3 rpy2==3.2.7 rsa==4.6 scikit-image==0.16.2 scikit-learn==0.22.2.post1 scipy==1.4.1 screen-resolution-extra==0.0.0 scs==2.1.2 seaborn==0.11.0 Send2Trash==1.5.0 setuptools-git==1.2 Shapely==1.7.1 simplegeneric==0.8.1 six==1.15.0 sklearn==0.0 sklearn-pandas==1.8.0 smart-open==4.0.1 snowballstemmer==2.0.0 sortedcontainers==2.3.0 spacy==2.2.4 Sphinx==1.8.5 sphinxcontrib-serializinghtml==1.1.4 sphinxcontrib-websupport==1.2.4 SQLAlchemy==1.3.20 sqlparse==0.4.1 srsly==1.0.5 statsmodels==0.10.2 sympy==1.1.1 tables==3.4.4 tabulate==0.8.7 tblib==1.7.0 tensorboard==2.4.0 tensorboard-plugin-wit==1.7.0 tensorboardcolab==0.0.22 tensorflow==2.4.0 tensorflow-addons==0.8.3 tensorflow-datasets==4.0.1 tensorflow-estimator==2.4.0 tensorflow-gcs-config==2.4.0 tensorflow-hub==0.10.0 tensorflow-metadata==0.26.0 tensorflow-privacy==0.2.2 tensorflow-probability==0.11.0 termcolor==1.1.0 terminado==0.9.1 testpath==0.4.4 text-unidecode==1.3 textblob==0.15.3 textgenrnn==1.4.1 Theano==1.0.5 thinc==7.4.0 tifffile==2020.9.3 toml==0.10.2 toolz==0.11.1 torch==1.7.0+cu101 torchsummary==1.5.1 torchtext==0.3.1 torchvision==0.8.1+cu101 tornado==5.1.1 tqdm==4.41.1 traitlets==4.3.3 tweepy==3.6.0 typeguard==2.7.1 typing-extensions==3.7.4.3 tzlocal==1.5.1 umap-learn==0.4.6 uritemplate==3.0.1 urllib3==1.24.3 vega-datasets==0.9.0 wasabi==0.8.0 wcwidth==0.2.5 webencodings==0.5.1 Werkzeug==1.0.1 widgetsnbextension==3.5.1 wordcloud==1.5.0 wrapt==1.12.1 xarray==0.15.1 xgboost==0.90 xkit==0.0.0 xlrd==1.1.0 xlwt==1.3.0 yellowbrick==0.9.1 zict==2.0.0 zipp==3.4.0