razhangwei
2/11/2015 - 5:39 PM

sklearn: tune parameters using cross-validation

sklearn: tune parameters using cross-validation

"""
This file uses cross validation to tune the parameters.
"""
import multiprocessing
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, ParameterGrid
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm

from model.hmm import ConstrainedMixHMM
from utils.utils import load_data, send_email

# Global display settings: seaborn grid theme plus 3-digit precision for
# printed numpy arrays and pandas objects.
sns.set_style('whitegrid')
np.set_printoptions(precision=3)
# Use the fully qualified option name. The bare 'precision' pattern is
# ambiguous on recent pandas releases (it also matches
# 'styler.format.precision'), which makes set_option raise OptionError.
pd.set_option('display.precision', 3)

# Task name and the folders used by this script: data is read from
# `input_folder`; fitted models are written below `output_folder`.
task_name = "HMMCrossValidation"
input_folder = '../data/output'
output_folder = '../data/output/' + task_name
figure_folder = '../figure/' + task_name

# Keyword arguments forwarded verbatim to `load_data`.
data_config = {
    'path': input_folder + '/RNZS_CC_F_include_missing.csv',
    'aMCI_only': True,
    'split_aMCI_naMCI': False,
    'split_early_clinic_MCI': True,
    'include_missing_x': True,
    'include_missing_y': True,
    'impute_missing': False,
    'reviewed_label_only': False,
}

# Settings read by `fit_model`:
#   run_times - number of fits per model, one per random_state 0..run_times-1
#   verbose   - print a message before fitting
#   plot      - draw a density plot of the scores of the fitted models
run_config = {
    'run_times': 10,
    'verbose': False,
    'plot': False,
}

# Base keyword arguments for ConstrainedMixHMM; 'n_components' and
# 'transmat_type' are filled in from the parameter grid before each fit.
model_config = {
    'monotonic_state': True,
    'covariance_type': 'diag',
}

# Parameter grid expanded with ParameterGrid; every combination is fitted.
param_grid = {
    'transmat_type': ['upper-bidiagonal'],
    # list(...) is required on Python 3, where range() is not a list and
    # cannot be concatenated with one; on Python 2 it is a harmless copy.
    'n_components': list(range(3, 15)) + [17, 20, 24],
}

def _fit(X, lengths, random_state, config):
    """Construct one ConstrainedMixHMM and fit it to the given sequences.

    ``config`` is unpacked into the constructor as keyword arguments next to
    ``random_state``; the result of ``fit(X, lengths)`` is returned.
    """
    hmm = ConstrainedMixHMM(random_state=random_state, **config)
    return hmm.fit(X, lengths)


def fit_model(parallel, X, lengths, config):
    """Fit the model several times from different random starts; keep the best.

    One fit is scheduled per ``random_state`` in
    ``range(run_config['run_times'])``, and the model with the highest score
    is returned.

    Parameters
    ----------
    parallel : callable
        Executor that consumes the iterable of ``delayed(_fit)(...)`` tasks
        and returns the fitted models (a ``joblib.Parallel`` instance in
        this file).

    X, lengths
        Passed through unchanged to ``_fit``.

    config : dict
        model config

    Returns
    -------
    The fitted model whose last ``monitor_.history`` entry is the largest.
    """

    if run_config['verbose']:
        # Single-argument call form prints identically on Python 2 and 3.
        print("Fitting the model...")

    models = parallel(
        delayed(_fit)(X, lengths, i, config)
        for i in range(run_config['run_times'])
    )

    # Last value recorded by each model's monitor; the plot title below
    # labels these values as log-likelihoods.
    scores = [m.monitor_.history[-1] for m in models]

    if run_config['plot']:
        plt.figure()
        sns.kdeplot(np.array(scores))
        # Mark the best score on the x axis.
        plt.plot(np.max(scores), 0, 'ro')
        plt.title('kernel density esitmation of log-likelihood')

    return models[np.argmax(scores)]


def run_cross_validation(n_splits, notification=False):
    """ Fit one model per parameter setting and cross-validation fold

    For every combination in ``param_grid`` and every fold, the model is
    loaded from ``output_folder`` when a saved copy exists; otherwise it is
    fitted on the training part of the fold and saved.

    Parameters
    ----------
    n_splits : int
        number of folds for cross validation

    notification : boolean
        Whether to report progress by calling ``send_email`` before and
        after each new fit
    """

    X = load_data(**data_config)

    # Reuse previously saved fold assignments when present so that reruns
    # work on the same folds; otherwise draw new shuffled folds and save them.
    splits_path = "%s/cv_splits.pkl" % output_folder
    try:
        k_splits = joblib.load(splits_path)
    except IOError:
        k_splits = list(KFold(n_splits, shuffle=True).split(X.index.levels[0]))
        joblib.dump(k_splits, splits_path)
        print(k_splits)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:

        for params in tqdm(ParameterGrid(param_grid)):

            print("Fitting %r using cross-validation..." % params)

            # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the
            # progress bar can read the number of folds from the list.
            for i, (train_index, _) in enumerate(tqdm(k_splits)):

                # The folds index the first level of X's index; select the
                # rows of the training entries and count rows per entry.
                DBID_train = X.index.levels[0][train_index].tolist()
                X_train = X.loc[DBID_train]
                lengths_train = X_train.groupby(level=0).size().values

                # load or train the model
                model_path = "%s/cv=%d_transmat=%s_n=%d.pkl" % (
                    output_folder, i, params['transmat_type'], params['n_components'])

                try:
                    model = ConstrainedMixHMM.load(model_path)

                except IOError:
                    if notification:
                        send_email("Cross Validation",
                                   "Start fitting %d-fold with %r." % (i, params))

                    # Note: this updates the shared module-level model_config.
                    model_config['n_components'] = params['n_components']
                    model_config['transmat_type'] = params['transmat_type']

                    model = fit_model(parallel, X_train,
                                      lengths_train, model_config)

                    model.save(model_path)

                    if notification:
                        send_email("Cross Validation",
                                   "Finished fitting %d-fold with %r." % (i, params))


def fit_model_on_whole_dataset():
    """ Fit one model per parameter setting on the whole dataset

    For every combination in ``param_grid`` the model is loaded from
    ``output_folder`` when a saved copy exists; otherwise it is fitted on
    all of the loaded data and saved.
    """

    X = load_data(**data_config)
    # Number of rows per entry of the first index level.
    lengths = X.groupby(level=0).size().values

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:

        for params in tqdm(ParameterGrid(param_grid)):

            # Single-argument call form prints identically on Python 2 and 3.
            print("fitting %r ..." % params)

            # load or train the model
            filename = "%s/model_type=%s_n=%d.pkl" % (output_folder,
                params['transmat_type'], params['n_components'])

            try:
                model = ConstrainedMixHMM.load(filename)

            except IOError:

                # Note: this updates the shared module-level model_config.
                model_config['n_components'] = params['n_components']
                model_config['transmat_type'] = params['transmat_type']

                model = fit_model(parallel, X, lengths, model_config)

                model.save(filename)


if __name__ == "__main__":
    # Validate the command line explicitly: an `assert` is removed when
    # Python runs with -O and gives no hint about the expected argument.
    if len(sys.argv) != 2:
        sys.exit("usage: %s MODE  (MODE must contain 'CV' and/or 'WHOLE')"
                 % sys.argv[0])

    mode = sys.argv[1]

    # Substring tests: a single argument containing both words runs both steps.
    if "CV" in mode:
        run_cross_validation(n_splits=5, notification=True)

    if "WHOLE" in mode:
        fit_model_on_whole_dataset()