(旧) ヒノマルクのデータ分析ブログ

どんな企業、地域、国でも働けるようなスキルを身につけていくブログ

PythonやBIツールを使って分析と可視化をします

[Python] その2の2 ボストンの住宅価格を重回帰分析ver2で予測してみた

重回帰分析を以前やりましたが、 説明変数はマニュアル方式で 集計結果のグラフを見ながら選んでいました。

今回はバージョンアップして全ての説明変数の組み合わせを 試して一番精度がよいモデルを選択するプログラムを組みました。

精度の評価方法は色々ありますが、本記事ではMean Absolute Error(MAE)の値が 一番小さいモデルをベストモデルとして採用するようにしています。

※ MAE = 日本語では平均絶対誤差と呼ぶそうです。分析のコンペなどでたまに 評価方法で採用されています。

Checking Boston housing prices corrected dataset 2_2

パッケージのインポート
import packages

import seaborn as sns
import pandas as pd
import numpy as np

データの読み込み
reading data

df = pd.read_csv("http://lib.stat.cmu.edu/datasets/boston_corrected.txt",skiprows=9,sep="\t")

※ 22年2月1日更新: UnicodeDecodeErrorが発生する方はencoding='Windows-1252'を追加してください。
詳細は下記新ブログをご参照ください。


pd.DataFrame(df.columns)
Out[0]
0
0 OBS.
1 TOWN
2 TOWN#
3 TRACT
4 LON
5 LAT
6 MEDV
7 CMEDV
8 CRIM
9 ZN
10 INDUS
11 CHAS
12 NOX
13 RM
14 AGE
15 DIS
16 RAD
17 TAX
18 PTRATIO
19 B
20 LSTAT
df.head(3)
Out[0]
OBS. TOWN TOWN# TRACT LON LAT MEDV CMEDV CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
0 1 Nahant 0 2011 -70.955 42.2550 24.0 24.0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
1 2 Swampscott 1 2021 -70.950 42.2875 21.6 21.6 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
2 3 Swampscott 1 2022 -70.936 42.2830 34.7 34.7 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03

モデリング

重回帰分析でやるが、変数は組み合わせをすべて試して一番精度のよいモデルを選択する

サンプリング (訓練データ、テストデータ分割)

# 訓練データとテストデータに分割する。
from sklearn.model_selection import train_test_split
train, test = train_test_split(df, test_size=0.20,random_state=100)
# 訓練データの件数確認
train.count()["OBS."]
Out[0]
 404 
# テストデータの件数確認
test.count()["OBS."]
Out[0]
102 

訓練データ、テストデータ作成

# 分析に利用する変数に限定
# 本当だったら事前に変数選択で利用するカラムを限定しておく

anacols=[
  "CRIM"  # 1人当たりの犯罪数
, "ZN" #町別の25,000平方フィート(7600m2)以上の住居区画の割合
, "INDUS" #町別の非小売業が占める土地面積の割合
, "CHAS" #チャールズ川沿いかどうか
, "NOX" #町別の窒素酸化物の濃度
, "RM" #住居の平均部屋数
, "AGE" #持ち家住宅
, "DIS" #5つのボストン雇用センターへの重み付き距離
, "RAD" #町別の環状高速道路へのアクセスのしやすさ
, "TAX" #町別の$10,000ドルあたりの固定資産税率
, "PTRATIO" #町別の生徒と先生の比率
, "B" #1000*(黒人人口割合 - 0.63)2
, "LSTAT" #貧困人口割合
]

# 訓練データ
X_train = train[anacols]  # 説明変数
Y_train=train["CMEDV"] # 目的変数

# テストデータ
X_test = test[anacols] # 説明変数
Y_test=test["CMEDV"] # 目的変数

# 欠損処理
# nullがあれば0埋めする。平均値や最頻値でもいい
X_train = X_train.fillna(0)
Y_train = Y_train.fillna(0)
X_test = X_test.fillna(0)
Y_test = Y_test.fillna(0)
X_train
Out[0]
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
379 17.86670 0.0 18.10 0 0.6710 6.223 100.0 1.3861 24 666 20.2 393.74 21.78
311 0.79041 0.0 9.90 0 0.5440 6.122 52.8 2.6403 4 304 18.4 396.90 5.98
157 1.22358 0.0 19.58 0 0.6050 6.943 97.4 1.8773 5 403 14.7 363.43 4.59
244 0.20608 22.0 5.86 0 0.4310 5.593 76.5 7.9549 7 330 19.1 372.49 12.50
56 0.02055 85.0 0.74 0 0.4100 6.383 35.7 9.1876 2 313 17.3 396.90 5.77
... ... ... ... ... ... ... ... ... ... ... ... ... ...
343 0.02543 55.0 3.78 0 0.4840 6.696 56.4 5.7321 5 370 17.6 396.90 7.18
359 4.26131 0.0 18.10 0 0.7700 6.112 81.3 2.5091 24 666 20.2 390.74 12.67
323 0.28392 0.0 7.38 0 0.4930 5.708 74.3 4.7211 5 287 19.6 391.13 11.74
280 0.03578 20.0 3.33 0 0.4429 7.820 64.5 4.6947 5 216 14.9 387.31 3.76
8 0.21124 12.5 7.87 0 0.5240 5.631 100.0 6.0821 5 311 15.2 386.63 29.93

404 rows × 13 columns

Y_train
Out[0]
    379    10.2
    311    22.1
    157    41.3
    244    17.6
    56     24.7
           ... 
    343    23.9
    359    22.6
    323    18.5
    280    45.4
    8      16.5
    Name: CMEDV, Length: 404, dtype: float64

モデル作成 (精度が一番よいモデルを探索する)

import sys
import itertools
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error

comblist=[]
best_model=None
best_features=None
best_mae=sys.maxsize

# 変数の選択数 (1 ~ 最大選択可能数)
for i in range(1,len(anacols) + 1):

  # 変数の選択数に合わせた組み合わせを作成
  comblist = list(itertools.combinations(anacols,i))
  for featurecomb in comblist:
    # 重回帰モデル作成
    multi_regression = LinearRegression()
    multi_regression.fit(X_train[list(featurecomb)],Y_train)

    # テストデータに当てはめる
    yhat_test = multi_regression.predict(X_test[list(featurecomb)])

    # 精度(MAE) 他にも様々な評価方法がある
    mae = mean_absolute_error(Y_test, yhat_test)
    
    #一番よい精度のモデルを探索
    if  mae < best_mae:
      best_mae = mae
      best_features = featurecomb
      best_model = multi_regression

print(str(best_mae))
print(best_features)
Out[0]
    3.1841915998253896
    ('CRIM', 'ZN', 'NOX', 'RM', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT')

作成したモデルの確認

# 係数逆転現象の確認
pd.DataFrame({"name":X_train[list(best_features)].columns,"coefficients":best_model.coef_})
Out[0]
name coefficients
0 CRIM -0.087462
1 ZN 0.050444
2 NOX -15.698495
3 RM 3.656426
4 DIS -1.530007
5 RAD 0.312503
6 TAX -0.013988
7 PTRATIO -0.936410
8 B 0.009886
9 LSTAT -0.504771

ボストン雇用センターへの重み付き距離 (DIS) と 固定資産税 (TAX) の係数

がマイナスなのは少し気になるところ。

各説明変数と住宅価格の関係を見た限りの考察としては、

DISは距離が3マイル?以内までは正の強い相関がありそうで、

3マイル以上は無相関のような分布をしている。

TAXは一部の固定資産税が高い地域の住宅価格が低めなのが

負の相関となるような影響を与えていると思われる。

精度上は一番よいのだが、私だったらこのモデルは選ばず

2番目に精度がよい他のモデルを選択します

精度確認

# 自由度調整済みr2を算出
def adjusted_r2(X,Y,model):
   from sklearn.metrics import r2_score
   import numpy as np
   r_squared = r2_score(Y, model.predict(X))
   adjusted_r2 = 1 - (1-r_squared)*(len(Y)-1)/(len(Y)-X.shape[1]-1)
   return adjusted_r2

# 予測モデルの精度確認の各種指標を算出
def get_model_evaluations(X_train,Y_train,X_test,Y_test,model):
   from sklearn.metrics import explained_variance_score
   from sklearn.metrics import mean_absolute_error
   from sklearn.metrics import mean_squared_error
   from sklearn.metrics import mean_squared_log_error
   from sklearn.metrics import median_absolute_error

   # 評価指標確認
   # 参考: https://funatsu-lab.github.io/open-course-ware/basic-theory/accuracy-index/
   yhat_test = model.predict(X_test)
   return  "adjusted_r2(train)     :" + str(adjusted_r2(X_train,Y_train,model)) \
         , "adjusted_r2(test)      :" + str(adjusted_r2(X_test,Y_test,model)) \
         , "平均誤差率(test)       :" + str(np.mean(abs(Y_test / yhat_test - 1))) \
         , "MAE(test)              :" + str(mean_absolute_error(Y_test, yhat_test)) \
         , "MedianAE(test)         :" + str(median_absolute_error(Y_test, yhat_test)) \
         , "RMSE(test)             :" + str(np.sqrt(mean_squared_error(Y_test, yhat_test))) \
         , "RMSE(test) / MAE(test) :" + str(np.sqrt(mean_squared_error(Y_test, yhat_test)) / mean_absolute_error(Y_test, yhat_test)) #better if result = 1.253
get_model_evaluations(X_train[list(best_features)],Y_train,X_test[list(best_features)],Y_test,best_model)
Out[0]
    ('adjusted_r2(train)     :0.7224056607466813',
     'adjusted_r2(test)      :0.7386023143239528',
     '平均誤差率(test)       :0.15322170833625723',
     'MAE(test)              :3.1841915998253896',
     'MedianAE(test)         :2.501378017150664',
     'RMSE(test)             :4.7670541453730815',
     'RMSE(test) / MAE(test) :1.4971002830465638')

trainとtestの調整済みR2が同じくらいなのでオーバーフィットはしてなさそう。

# 描画設定
from matplotlib import rcParams
rcParams['xtick.labelsize'] = 12       # x軸のラベルのフォントサイズ
rcParams['ytick.labelsize'] = 12       # y軸のラベルのフォントサイズ
rcParams['figure.figsize'] = 18,8      # 画像サイズの変更(inch)

import matplotlib.pyplot as plt
from matplotlib import ticker
sns.set_style("whitegrid")             # seabornのスタイルセットの一つ
sns.set_color_codes()                  # デフォルトカラー設定 (deepになってる)

plt.figure()
ax = sns.regplot(x=Y_test, y=best_model.predict(X_test[list(best_features)]), fit_reg=False,color='#4F81BD')
ax.set_xlabel(u"CMEDV")
ax.set_ylabel(u"(Predicted) CMEDV")
ax.get_xaxis().set_major_formatter(ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
ax.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda y, p: format(int(y), ',')))
ax.plot([0,10,20,30,40,50],[0,10,20,30,40,50], linewidth=2, color="#C0504D",ls="--")

利用バージョンの確認

!python -V
Out[0]
    Python 3.6.9
!pip3 -V
Out[0]
    pip 19.3.1 from /usr/local/lib/python3.6/dist-packages/pip (python 3.6)
!pip3 freeze
Out[0]
    absl-py==0.10.0
    alabaster==0.7.12
    albumentations==0.1.12
    altair==4.1.0
    argon2-cffi==20.1.0
    asgiref==3.3.1
    astor==0.8.1
    astropy==4.1
    astunparse==1.6.3
    async-generator==1.10
    atari-py==0.2.6
    atomicwrites==1.4.0
    attrs==20.3.0
    audioread==2.1.9
    autograd==1.3
    Babel==2.9.0
    backcall==0.2.0
    beautifulsoup4==4.6.3
    bleach==3.2.1
    blis==0.4.1
    bokeh==2.1.1
    Bottleneck==1.3.2
    branca==0.4.1
    bs4==0.0.1
    CacheControl==0.12.6
    cachetools==4.2.0
    catalogue==1.0.0
    certifi==2020.12.5
    cffi==1.14.4
    chainer==7.4.0
    chardet==3.0.4
    click==7.1.2
    cloudpickle==1.3.0
    cmake==3.12.0
    cmdstanpy==0.9.5
    colorlover==0.3.0
    community==1.0.0b1
    contextlib2==0.5.5
    convertdate==2.2.0
    coverage==3.7.1
    coveralls==0.5
    crcmod==1.7
    cufflinks==0.17.3
    cvxopt==1.2.5
    cvxpy==1.0.31
    cycler==0.10.0
    cymem==2.0.5
    Cython==0.29.21
    daft==0.0.4
    dask==2.12.0
    dataclasses==0.8
    datascience==0.10.6
    debugpy==1.0.0
    decorator==4.4.2
    defusedxml==0.6.0
    descartes==1.1.0
    dill==0.3.3
    distributed==1.25.3
    Django==3.1.4
    dlib==19.18.0
    dm-tree==0.1.5
    docopt==0.6.2
    docutils==0.16
    dopamine-rl==1.0.5
    earthengine-api==0.1.238
    easydict==1.9
    ecos==2.0.7.post1
    editdistance==0.5.3
    en-core-web-sm==2.2.5
    entrypoints==0.3
    ephem==3.7.7.1
    et-xmlfile==1.0.1
    fa2==0.3.5
    fancyimpute==0.4.3
    fastai==1.0.61
    fastdtw==0.3.4
    fastprogress==1.0.0
    fastrlock==0.5
    fbprophet==0.7.1
    feather-format==0.4.1
    filelock==3.0.12
    firebase-admin==4.4.0
    fix-yahoo-finance==0.0.22
    Flask==1.1.2
    flatbuffers==1.12
    folium==0.8.3
    future==0.16.0
    gast==0.3.3
    GDAL==2.2.2
    gdown==3.6.4
    gensim==3.6.0
    geographiclib==1.50
    geopy==1.17.0
    gin-config==0.4.0
    glob2==0.7
    google==2.0.3
    google-api-core==1.16.0
    google-api-python-client==1.7.12
    google-auth==1.17.2
    google-auth-httplib2==0.0.4
    google-auth-oauthlib==0.4.2
    google-cloud-bigquery==1.21.0
    google-cloud-bigquery-storage==1.1.0
    google-cloud-core==1.0.3
    google-cloud-datastore==1.8.0
    google-cloud-firestore==1.7.0
    google-cloud-language==1.2.0
    google-cloud-storage==1.18.1
    google-cloud-translate==1.5.0
    google-colab==1.0.0
    google-pasta==0.2.0
    google-resumable-media==0.4.1
    googleapis-common-protos==1.52.0
    googledrivedownloader==0.4
    graphviz==0.10.1
    grpcio==1.32.0
    gspread==3.0.1
    gspread-dataframe==3.0.8
    gym==0.17.3
    h5py==2.10.0
    HeapDict==1.0.1
    holidays==0.10.4
    holoviews==1.13.5
    html5lib==1.0.1
    httpimport==0.5.18
    httplib2==0.17.4
    httplib2shim==0.0.3
    humanize==0.5.1
    hyperopt==0.1.2
    ideep4py==2.0.0.post3
    idna==2.10
    image==1.5.33
    imageio==2.4.1
    imagesize==1.2.0
    imbalanced-learn==0.4.3
    imblearn==0.0
    imgaug==0.2.9
    importlib-metadata==3.3.0
    importlib-resources==3.3.0
    imutils==0.5.3
    inflect==2.1.0
    iniconfig==1.1.1
    intel-openmp==2021.1.1
    intervaltree==2.1.0
    ipykernel==4.10.1
    ipython==5.5.0
    ipython-genutils==0.2.0
    ipython-sql==0.3.9
    ipywidgets==7.5.1
    itsdangerous==1.1.0
    jax==0.2.7
    jaxlib==0.1.57+cuda101
    jdcal==1.4.1
    jedi==0.17.2
    jieba==0.42.1
    Jinja2==2.11.2
    joblib==1.0.0
    jpeg4py==0.1.4
    jsonschema==2.6.0
    jupyter==1.0.0
    jupyter-client==5.3.5
    jupyter-console==5.2.0
    jupyter-core==4.7.0
    jupyterlab-pygments==0.1.2
    kaggle==1.5.10
    kapre==0.1.3.1
    Keras==2.4.3
    Keras-Preprocessing==1.1.2
    keras-vis==0.4.1
    kiwisolver==1.3.1
    knnimpute==0.1.0
    korean-lunar-calendar==0.2.1
    librosa==0.6.3
    lightgbm==2.2.3
    llvmlite==0.31.0
    lmdb==0.99
    lucid==0.3.8
    LunarCalendar==0.0.9
    lxml==4.2.6
    Markdown==3.3.3
    MarkupSafe==1.1.1
    matplotlib==3.2.2
    matplotlib-venn==0.11.6
    missingno==0.4.2
    mistune==0.8.4
    mizani==0.6.0
    mkl==2019.0
    mlxtend==0.14.0
    more-itertools==8.6.0
    moviepy==0.2.3.5
    mpmath==1.1.0
    msgpack==1.0.1
    multiprocess==0.70.11.1
    multitasking==0.0.9
    murmurhash==1.0.5
    music21==5.5.0
    natsort==5.5.0
    nbclient==0.5.1
    nbconvert==5.6.1
    nbformat==5.0.8
    nest-asyncio==1.4.3
    networkx==2.5
    nibabel==3.0.2
    nltk==3.2.5
    notebook==5.3.1
    np-utils==0.5.12.1
    numba==0.48.0
    numexpr==2.7.1
    numpy==1.19.4
    nvidia-ml-py3==7.352.0
    oauth2client==4.1.3
    oauthlib==3.1.0
    okgrade==0.4.3
    opencv-contrib-python==4.1.2.30
    opencv-python==4.1.2.30
    openpyxl==2.5.9
    opt-einsum==3.3.0
    osqp==0.6.1
    packaging==20.8
    palettable==3.3.0
    pandas==1.1.5
    pandas-datareader==0.9.0
    pandas-gbq==0.13.3
    pandas-profiling==1.4.1
    pandocfilters==1.4.3
    panel==0.9.7
    param==1.10.0
    parso==0.7.1
    pathlib==1.0.1
    patsy==0.5.1
    pexpect==4.8.0
    pickleshare==0.7.5
    Pillow==7.0.0
    pip-tools==4.5.1
    plac==1.1.3
    plotly==4.4.1
    plotnine==0.6.0
    pluggy==0.7.1
    portpicker==1.3.1
    prefetch-generator==1.0.1
    preshed==3.0.5
    prettytable==2.0.0
    progressbar2==3.38.0
    prometheus-client==0.9.0
    promise==2.3
    prompt-toolkit==1.0.18
    protobuf==3.12.4
    psutil==5.4.8
    psycopg2==2.7.6.1
    ptyprocess==0.6.0
    py==1.10.0
    pyarrow==0.14.1
    pyasn1==0.4.8
    pyasn1-modules==0.2.8
    pycocotools==2.0.2
    pycparser==2.20
    pyct==0.4.8
    pydata-google-auth==1.1.0
    pydot==1.3.0
    pydot-ng==2.0.0
    pydotplus==2.0.2
    PyDrive==1.3.1
    pyemd==0.5.1
    pyglet==1.5.0
    Pygments==2.6.1
    pygobject==3.26.1
    pymc3==3.7
    PyMeeus==0.3.7
    pymongo==3.11.2
    pymystem3==0.2.0
    PyOpenGL==3.1.5
    pyparsing==2.4.7
    pyrsistent==0.17.3
    pysndfile==1.3.8
    PySocks==1.7.1
    pystan==2.19.1.1
    pytest==3.6.4
    python-apt==1.6.5+ubuntu0.4
    python-chess==0.23.11
    python-dateutil==2.8.1
    python-louvain==0.14
    python-slugify==4.0.1
    python-utils==2.4.0
    pytz==2018.9
    pyviz-comms==0.7.6
    PyWavelets==1.1.1
    PyYAML==3.13
    pyzmq==20.0.0
    qtconsole==5.0.1
    QtPy==1.9.0
    regex==2019.12.20
    requests==2.23.0
    requests-oauthlib==1.3.0
    resampy==0.2.2
    retrying==1.3.3
    rpy2==3.2.7
    rsa==4.6
    scikit-image==0.16.2
    scikit-learn==0.22.2.post1
    scipy==1.4.1
    screen-resolution-extra==0.0.0
    scs==2.1.2
    seaborn==0.11.0
    Send2Trash==1.5.0
    setuptools-git==1.2
    Shapely==1.7.1
    simplegeneric==0.8.1
    six==1.15.0
    sklearn==0.0
    sklearn-pandas==1.8.0
    smart-open==4.0.1
    snowballstemmer==2.0.0
    sortedcontainers==2.3.0
    spacy==2.2.4
    Sphinx==1.8.5
    sphinxcontrib-serializinghtml==1.1.4
    sphinxcontrib-websupport==1.2.4
    SQLAlchemy==1.3.20
    sqlparse==0.4.1
    srsly==1.0.5
    statsmodels==0.10.2
    sympy==1.1.1
    tables==3.4.4
    tabulate==0.8.7
    tblib==1.7.0
    tensorboard==2.4.0
    tensorboard-plugin-wit==1.7.0
    tensorboardcolab==0.0.22
    tensorflow==2.4.0
    tensorflow-addons==0.8.3
    tensorflow-datasets==4.0.1
    tensorflow-estimator==2.4.0
    tensorflow-gcs-config==2.4.0
    tensorflow-hub==0.10.0
    tensorflow-metadata==0.26.0
    tensorflow-privacy==0.2.2
    tensorflow-probability==0.11.0
    termcolor==1.1.0
    terminado==0.9.1
    testpath==0.4.4
    text-unidecode==1.3
    textblob==0.15.3
    textgenrnn==1.4.1
    Theano==1.0.5
    thinc==7.4.0
    tifffile==2020.9.3
    toml==0.10.2
    toolz==0.11.1
    torch==1.7.0+cu101
    torchsummary==1.5.1
    torchtext==0.3.1
    torchvision==0.8.1+cu101
    tornado==5.1.1
    tqdm==4.41.1
    traitlets==4.3.3
    tweepy==3.6.0
    typeguard==2.7.1
    typing-extensions==3.7.4.3
    tzlocal==1.5.1
    umap-learn==0.4.6
    uritemplate==3.0.1
    urllib3==1.24.3
    vega-datasets==0.9.0
    wasabi==0.8.0
    wcwidth==0.2.5
    webencodings==0.5.1
    Werkzeug==1.0.1
    widgetsnbextension==3.5.1
    wordcloud==1.5.0
    wrapt==1.12.1
    xarray==0.15.1
    xgboost==0.90
    xkit==0.0.0
    xlrd==1.1.0
    xlwt==1.3.0
    yellowbrick==0.9.1
    zict==2.0.0
    zipp==3.4.0