% ka9e
% 1/4/2015 - 9:05 AM
%
% curve fitting

%%% Least-squares curve fitting of noisy samples of sin(x), comparing the
%%% plain least-squares solution with an L2 (ridge) regularized one.
%%%
%%% references:
%%% (1) Pattern Recognition and Machine Learning (C. M. Bishop; Japanese edition)
%%% (2) http://yuki-koyama.hatenablog.com/entry/2014/05/04/132552
%%% (3) http://taku-k.hatenablog.com/entry/2013/11/16/203644

% clear all;

N = 100;   % number of samples
RES = 200; % resolution of output curves
L = 1;     % regularization coefficient (lambda) of the ridge solutions

x = linspace(-5, 5, N);
y = sin(x) + 0.25 * randn(size(x));   % noisy targets
X = linspace(-5, 5, RES);             % points at which fitted curves are drawn

% Standard normal density, same as normpdf(r) with mean 0 and sigma 1.
% Written inline so the script does not need normpdf, which lives in the
% Octave statistics package / MATLAB Statistics Toolbox.
gauss = @(r) exp(-0.5 * r .^ 2) / sqrt(2 * pi);

%--- Gaussian RBF basis: one basis function centred on each sample x(j) ---

% Design matrix h(i, j) = gauss(|x(i) - x(j)|), N x N.
h = gauss(abs(bsxfun(@minus, x', x)));

w = pinv(h) * y';                        % least-squares weights
w2 = (L * eye(N) + h' * h) \ (h' * y');  % ridge: (L*I + h'*h) * w2 = h' * y'

% Basis evaluated at the plot points: ND(i, j) = gauss(|X(i) - x(j)|), RES x N.
ND = gauss(abs(bsxfun(@minus, X', x)));
Y = ND * w;
Y2 = ND * w2;

figure(1)
subplot(2, 1, 1);
plot(x, y, 'o', 'MarkerSize', 5, X, Y, X, Y2);
axis([-5, 5, -1.5, 1.5]);
title('Gaussian')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');

%--- Polynomial basis: phi_j(x) = x^(j-1), j = 1..M ---
M = 12; % number of basis functions, i.e. a polynomial of degree M - 1
powers = 0 : M - 1;

% Design matrix h(i, j) = x(i)^(j-1), N x M.
% NOTE: with |x| up to 5 the columns reach 5^(M-1), so h' * h is badly
% conditioned; pinv and the ridge term keep the solves stable.
h = bsxfun(@power, x', powers);

w = pinv(h) * y';                        % least-squares weights
w2 = (L * eye(M) + h' * h) \ (h' * y');  % ridge: (L*I + h'*h) * w2 = h' * y'

% Basis evaluated at the plot points: X2(j, i) = X(i)^(j-1), M x RES.
X2 = bsxfun(@power, X, powers');

Y = w' * X2;
Y2 = w2' * X2;

subplot(2, 1, 2);
plot(x, y, 'o', 'MarkerSize', 5, X, Y, X, Y2);
axis([-5, 5, -1.5, 1.5]);
title('Polynomial')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');

%%% Two-layer neural network (tanh hidden layer, linear output) trained by
%%% per-sample gradient descent to fit noisy samples of sin(x).
%%%
%%% references:
%%% (1) Pattern Recognition and Machine Learning (C. M. Bishop; Japanese edition)
%%% (2) http://aidiary.hatenablog.com/entries/2014/01/22

clear all;

N = 100;     % number of training samples
EPS = 0.01;  % threshold for a disabled |err2 - err1| stopping rule
ETA = 0.1;   % learning rate
LOOP = 500;  % number of passes (epochs) over the training set

D = 2; % input units (a constant bias input plus x)
M = 4; % hidden units
K = 1; % output units

X = linspace(-5, 5, N);
% T = sin(X);
T = sin(X) + 0.25 * randn(size(X));   % noisy training targets

w1 = randn(M, D);   % input -> hidden weights
w2 = randn(K, M);   % hidden -> output weights

function err = sum_sq_error(x, t, w1, w2)
  % Sum-of-squares error E = 1/2 * sum((y - t).^2) of the two-layer network
  % y = w2 * tanh(w1 * x), summed over every output and every sample, so the
  % result is a scalar for any number of output units.
  %
  % x  : network inputs, one column per sample
  % t  : targets, compatible in shape with the network output y
  % w1 : input -> hidden weights
  % w2 : hidden -> output weights
  %
  % The output is named `err` rather than `error` so the builtin error()
  % is not shadowed inside the function.
  z = tanh(w1 * x);
  y = w2 * z;
  r = y - t;
  err = sum(r(:) .^ 2) / 2;
end

% Inputs for the whole training set: column n is [1; X(n)]
% (a constant bias input followed by the sample position).
xs = vertcat(ones(size(X)), X);
errs = zeros(1, LOOP);   % training error after each epoch

% Alternative stopping rule (disabled): loop while abs(err2 - err1) > EPS,
% updating err1 = err2; err2 = sum_sq_error(xs, T, w1, w2); after each epoch.

for epoch = 1:LOOP
  % One epoch of per-sample (stochastic) gradient descent.
  for n = 1:N
    x = xs(:, n);

    % Forward pass.
    z = tanh(w1 * x);   % hidden activations
    y = w2 * z;         % network output

    % Backward pass: output error, then back-propagation through tanh
    % using tanh'(a) = 1 - tanh(a)^2, i.e. d1(j) = (1 - z(j)^2) * sum_k w2(k, j) * d2(k).
    d2 = y - T(n);
    d1 = (1 - z .^ 2) .* (w2' * d2);

    % Gradient steps; d1 was computed from w2 before w2 is updated.
    w1 -= ETA * d1 * x';
    w2 -= ETA * d2 * z';
  end
  errs(epoch) = sum_sq_error(xs, T, w1, w2);
end

% Network output over the training inputs.
Y = w2 * tanh(w1 * xs);

% Use a separate figure: figure(1) already holds the curve-fitting subplots,
% and plotting into it would overwrite its current (Polynomial) axes.
figure(2)
plot(X, T, 'o', X, Y)
% plot(errs(50:LOOP))