jmquintana79
5/18/2017 - 8:09 AM

Linear Regression (degree=1)

Linear Regression (degree=1); not customizable.

## linear regression for hws
def regression_linear(namex:str,x:'np array',namey:str, y:'np array',
                      path_output:str='',isfit_intercept:bool=True,isplot:bool=True,stitle:str='') -> dict:
    '''
    Fit a degree-1 linear regression of y on x, return its statistics and
    either display the scatter + fitted-line chart or save it to a file.

    namex -- name of x variable (x-axis label).
    x -- data of x variable (1-D, same length as y).
    namey -- name of y variable (y-axis label).
    y -- data of y variable.
    path_output -- path of chart file; only used when isplot is False (default "").
    isfit_intercept -- if True the intercept is estimated; if False the line is
                       forced through the origin (intercept fixed at 0) (default True).
    isplot -- if True, display the chart on screen; if False, save it to
              path_output (default True).
    stitle -- extra information appended to the plot title.
    return -- dict with the regression results:
              'polynomial': [slope, intercept],
              'R2': R**2 of the fitted line evaluated on (x, y),
              'correlation (Pearson)': Pearson correlation coefficient of x and y.
    '''
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import linear_model

    # sklearn expects a 2-D feature matrix: one row per sample, one column
    X = np.reshape(x, (len(x), 1))

    # fit the model and evaluate it on the same (training) data
    regr = linear_model.LinearRegression(fit_intercept=isfit_intercept)
    regr.fit(X, y)
    yhat = regr.predict(X)

    results = {
        'polynomial': [regr.coef_[0], regr.intercept_],  # [slope, intercept]
        'R2': regr.score(X, y),  # R**2
        'correlation (Pearson)': np.corrcoef(x, y)[0, 1],
    }

    ## PLOT
    fig, ax = plt.subplots(figsize=(6, 7))
    # try/finally: release the figure even if show()/savefig() raises;
    # otherwise pyplot keeps a reference to it and it is never freed
    try:
        ax.scatter(x, y, s=30, facecolors='blue', edgecolors='black')
        ax.set_title('Linear Regression: %s' % stitle, fontsize=14)
        ax.set_xlabel(namex)
        ax.set_ylabel(namey)
        ax.xaxis.grid(True)
        ax.yaxis.grid(True)
        ax.plot(x, yhat, label='fit', color="red")

        # same range on both axes; start at 0, but extend below it when the
        # data contain negative values (a fixed 0 lower limit would hide them)
        vmax = np.max([np.max(x), np.max(y)])
        vmin = min(0., np.min([np.min(x), np.min(y)]))
        ax.set_xlim([vmin, vmax])
        ax.set_ylim([vmin, vmax])

        # summary text box
        textstr = 'Y = %.5f * X + %.5f \nR2 = %.5f \nR(Pearson) = %.5f \nN = %s' % (
            results['polynomial'][0],
            results['polynomial'][1],
            results['R2'],
            results['correlation (Pearson)'],
            len(x),
        )
        # NOTE(review): position.y0 is a figure-fraction coordinate but is used
        # here with ax.transAxes; the -0.36 offset looks hand-tuned to place
        # the box under the axes -- re-check if figsize or layout change.
        bottom = ax.get_position().y0
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax.text(0.22, (bottom - 0.36), textstr, transform=ax.transAxes, fontsize=14,
                verticalalignment='center', bbox=props)

        # leave room at the bottom of the figure for the text box
        fig.subplots_adjust(bottom=0.25)

        if isplot:
            plt.show()
        elif path_output != '':
            fig.savefig(path_output, transparent=False)
        else:
            print('WARNING: it was not included as argument a valid path output file.')
    finally:
        plt.close(fig)

    return results