ka9e
1/4/2015 - 9:05 AM

## curve fitting

curve fitting

%%% reference :
%%% (1) パターン認識と機械学習 (Pattern Recognition and Machine Learning, C. M. Bishop)
%%% (2) http://yuki-koyama.hatenablog.com/entry/2014/05/04/132552
%%% (3) http://taku-k.hatenablog.com/entry/2013/11/16/203644

% clear all;

% ---- Problem setup -------------------------------------------------------
N = 100;    % number of noisy training samples
RES = 200;  % number of points at which the fitted curves are evaluated
L = 1;      % weight of the L * eye(...) term in the regularized solve below

%rbf = 'gauss';

% Training data: sin(x) on [-5, 5] with additive noise 0.25 * randn.
x = linspace(-5, 5, N);
y = sin(x) + 0.25 * randn(size(x));
% Evaluation grid for drawing the fitted curves.
X = linspace(-5, 5, RES);

%if strcmp(rbf, 'gauss')
% Design matrix, h(i, j) = normpdf(|x(i) - x(j)|): one basis function
% centred on every training sample, evaluated at every training sample.
h = normpdf(abs(bsxfun(@minus, x', x)));

% Weights from the pseudo-inverse (plain least squares).
w = pinv(h) * y';
% Weights from the regularized normal equations
% (L * I + h' * h) * w2 = h' * y'; note that (y * h)' equals h' * y'.
w2 = (L * eye(N) + h' * h) \ (y * h)';

% dist(r, j) = |x(j) - X(r)|: distance from evaluation point r to centre j.
dist = abs(bsxfun(@minus, x, X'));

% Evaluate every basis function on the grid and mix with each weight vector.
ND = normpdf(dist);
Y = ND * w;
Y2 = ND * w2;
%end

% Upper panel: samples plus both fitted curves.
figure(1)
subplot(2, 1, 1);
plot(x, y, 'o', 'MarkerSize', 5, X, Y, X, Y2);
axis([-5, 5, -1.5, 1.5]);
title('Gaussian')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');

%if strcmp(rbf, 'poly')
% Polynomial basis fit of the same samples (x, y), drawn on the grid X.
% Reads x, y, X, N, RES and L from the workspace.
M = 12; % number of monomial basis terms 1, x, ..., x^(M-1), i.e. degree M - 1

% Design matrix over the training inputs: h(n, j) = x(n)^(j-1), N-by-M.
h = bsxfun(@power, x', 0 : M - 1);

% Plain least squares through the pseudo-inverse.
w = pinv(h) * y';
% Regularized normal equations (L * I + h' * h) * w2 = h' * y';
% (y * h)' is the same vector as h' * y'.
w2 = (L * eye(M) + h' * h) \ (y * h)';

% Basis evaluated on the plotting grid: X2(j, r) = X(r)^(j-1).
% Built directly at its final M-by-RES size so it never has to be regrown.
X2 = bsxfun(@power, (0 : M - 1)', X);

Y = w' * X2;
Y2 = w2' * X2;
%end

% Lower panel: samples plus both fitted polynomials.
subplot(2, 1, 2);
plot(x, y, 'o', 'MarkerSize', 5, X, Y, X, Y2);
axis([-5, 5, -1.5, 1.5]);
title('Polynomial')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');


%%% reference :
%%% (1) パターン認識と機械学習 (Pattern Recognition and Machine Learning, C. M. Bishop)
%%% (2) http://aidiary.hatenablog.com/entries/2014/01/22

% Setup for a small two-layer network (tanh hidden layer, linear output)
% trained on noisy samples of sin(x).
clear all;

N = 100;     % number of training samples
EPS = 0.01;  % not used by any active code in this block
ETA = 0.1;   % step size applied to each weight update
LOOP = 500;  % number of passes over the training set

D = 2; % input size: the constant 1 plus the scalar sample
M = 4; % number of hidden units
K = 1; % number of outputs

% Inputs on [-5, 5] and their noisy targets.
X = linspace(-5, 5, N);
% T = sin(X);
T = sin(X) + 0.25 * randn(size(X));

% Random initial weights: w1 maps input -> hidden, w2 maps hidden -> output.
w1 = randn(M, D);
w2 = randn(K, M);

function err = sum_sq_error(x, t, w1, w2)
  % SUM_SQ_ERROR  Half the sum of squared errors of the two-layer network.
  %
  %   err = sum_sq_error(x, t, w1, w2)
  %
  %   x  : inputs, one column per sample (multiplied on the left by w1)
  %   t  : targets, same size as the network output w2 * tanh(w1 * x)
  %   w1 : first-layer weights
  %   w2 : second-layer weights
  %
  %   Returns the scalar sum((y - t).^2) / 2 taken over every output and
  %   every sample, where y = w2 * tanh(w1 * x).

  z = tanh(w1 * x);   % hidden activations
  y = w2 * z;         % linear output layer
  residual = y - t;
  % Flatten before summing so the result is a scalar even when the output
  % has more than one row (a plain sum would only collapse one dimension).
  err = sum(residual(:) .^ 2) / 2;
end

% Per-sample gradient-descent training of the two-layer network, followed by
% a plot of the fitted curve. Reads X, T, N, LOOP, ETA, w1, w2 and
% sum_sq_error from the workspace and updates w1 and w2 in place.

% Design matrix for the whole data set: a row of ones stacked on the inputs,
% so column n is the network input [1; X(n)].
xs = vertcat(ones(size(X)), X);
errs = zeros(1, LOOP);  % errs(k) = training error after pass k

%err1 = 0;
%err2 = Inf;

for epoch = 1:LOOP
%while abs(err2 - err1) > EPS
  for n = 1:N
    x = xs(:, n);
    % x = vertcat(ones(size(X)), X);

    % Forward pass.
    z = tanh(w1 * x);
    y = w2 * z;

    % Backward pass: output error, then hidden-layer error using
    % d/da tanh(a) = 1 - tanh(a)^2.
    d2 = y - T(n);
    d1 = (1 - z.^2) .* w2' * d2;
    % d1(j) = (1 - z(j)^2) * w2(j) * d2;

    % d1 was computed from the pre-update w2, so the update order is free.
    w1 = w1 - ETA * d1 * x';
    w2 = w2 - ETA * d2 * z';
  end

  errs(epoch) = sum_sq_error(xs, T, w1, w2);
  %err1 = err2;
  %err2 = sum_sq_error(xs, T, w1, w2);
end

% Network output over all training inputs with the final weights.
Z = tanh(w1 * xs);
Y = w2 * Z;

% NOTE(review): figure 1 is reused; if it already holds subplots this draws
% into whichever axes is current.
figure(1)
plot(X, T, 'o', X, Y)
% plot(errs(50:LOOP))