Linear regression (degree=1), not customizable.
## linear regression for hws
def regression_linear(namex:str,x:'np array',namey:str, y:'np array',
                      path_output:str='',isfit_intercept:bool=True,isplot:bool=True,stitle:str='') -> dict:
    '''
    Fit a straight line (degree 1) of y on x, draw the chart and return the fit summary.

    The chart (scatter of the data, fitted line and a text box with the equation,
    the scores and the number of samples) is always built; isplot only selects
    whether it is shown on screen or written to a file.

    namex -- name of x variable (used as x axis label).
    x -- data of x variable (1-D, same length as y).
    namey -- name of y variable (used as y axis label).
    y -- data of y variable.
    path_output -- path of chart file, only used when isplot is False (default "").
    isfit_intercept -- if True the intercept is estimated; if False the line is forced
                       through the origin (default True).
    isplot -- if is True, display by screen, if it is False store into path_output
              (default True). When False and path_output is empty only a warning is printed.
    stitle -- extra information to be included in the plot title.
    return -- dict with the keys 'polynomial' ([slope, intercept]), 'R2' and
              'correlation (Pearson)'.
    raises -- ValueError if x is empty or x and y have different lengths.
    '''
    import numpy as np
    from sklearn import linear_model
    import matplotlib.pyplot as plt

    # fail early with a clear message instead of deep inside scikit-learn
    # (which raises ValueError for the same conditions)
    nsamples = len(x)
    if nsamples == 0:
        raise ValueError('regression_linear: x is empty.')
    if nsamples != len(y):
        raise ValueError('regression_linear: x and y must have the same length (%s != %s).'
                         % (nsamples, len(y)))

    ## FIT
    # scikit-learn expects a 2-D design matrix: one column, one row per sample
    X = np.reshape(x, (nsamples, 1))
    regr = linear_model.LinearRegression(fit_intercept=isfit_intercept)
    regr.fit(X, y)
    # fitted values, used to draw the regression line
    yhat = regr.predict(X)

    # regression info
    results = {
        'polynomial': [regr.coef_[0], regr.intercept_],   # [slope, intercept]
        'R2': regr.score(X, y),                           # coefficient of determination
        'correlation (Pearson)': np.corrcoef(x, y)[0, 1], # correlation index
    }

    ## PLOT
    fig, ax = plt.subplots(figsize=(6, 7))
    try:
        # data and fitted line
        ax.scatter(x, y, s=30, facecolors='blue', edgecolors='black')
        ax.plot(x, yhat, label='fit', color="red")
        # title, axes labels and grid
        ax.set_title('Linear Regression: %s'%stitle, fontsize=14)
        ax.set_xlabel(namex)
        ax.set_ylabel(namey)
        ax.xaxis.grid(True)
        ax.yaxis.grid(True)
        # both axes share the range [0, largest value of x and y]
        # NOTE: data below 0 is therefore outside the visible area.
        vmax = np.max([np.max(x), np.max(y)])
        ax.set_xlim([0., vmax])
        ax.set_ylim([0., vmax])
        # text box with equation, scores and sample size
        textstr = 'Y = %.5f * X + %.5f \nR2 = %.5f \nR(Pearson) = %.5f \nN = %s'%(
            results['polynomial'][0],
            results['polynomial'][1],
            results['R2'],
            results['correlation (Pearson)'],
            nsamples
            )
        # the axes position is read BEFORE subplots_adjust so the box lands in the
        # free band that the adjustment opens below the chart
        bottom = ax.get_position().y0
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax.text(0.22, (bottom - 0.36), textstr, transform=ax.transAxes, fontsize=14,
                verticalalignment='center', bbox=props)
        # leave room below the chart for the text box
        fig.subplots_adjust(bottom=0.25)

        # display or save
        if isplot:
            plt.show()
        elif path_output != '':
            fig.savefig(path_output, transparent=False)
        else:
            print('WARNING: it was not included as argument a valid path output file.')
    finally:
        # always release the figure, also when showing/saving fails; closing the
        # figure object itself avoids touching (or creating) any other figure
        plt.close(fig)

    # return info results
    return results