%% Build the CUDA MEX solver from somt_mex.cu (-g: debuggable MEX file)
mexcuda('-g', 'somt_mex.cu')

%% Densities from images: dark pixels carry mass, each normalised to sum 1
n=256;

% rho0 generator
img0 = imread('cat1.png');
if size(img0, 3) == 3
    % rgb2gray interprets a 2-D input as a colormap (and errors), so only
    % convert images that actually have three colour channels.
    img0 = rgb2gray(img0);
end
% imcomplement equals 255 - x for uint8, and is also correct for uint16 /
% logical PNGs, where a literal 255 would be wrong.
rho0 = double(imcomplement(img0));
rho0 = rho0 / sum(sum(rho0));
if ~isequal(size(rho0), [n n])
    warning('cat1.png is %dx%d but the plotting cells assume %dx%d.', ...
        size(rho0, 1), size(rho0, 2), n, n);
end


% rho1 generator
img1 = imread('cat2.png');
if size(img1, 3) == 3
    img1 = rgb2gray(img1);
end
rho1 = double(imcomplement(img1));
rho1 = rho1/sum(sum(rho1));
if ~isequal(size(rho1), [n n])
    warning('cat2.png is %dx%d but the plotting cells assume %dx%d.', ...
        size(rho1, 1), size(rho1, 2), n, n);
end



%% Synthetic densities: one disc (rho0) and four corner discs (rho1)
n = 256;
[XX, YY] = meshgrid(linspace(-2, 2, n), linspace(-2, 2, n));

% rho0: indicator of the radius-0.3 disc centred at the origin,
% normalised so its entries sum to 1.
rho0 = zeros(n, n);
rho0(sqrt(XX.^2 + YY.^2) < 0.3) = 1;
rho0 = rho0 / sum(sum(rho0));

% rho1: indicators of four radius-0.2 discs centred at (+-1, +-1),
% normalised so the entries sum to 1. The list order matches the original
% cell, so II ends up holding the (-1,-1) disc mask as before.
centres = [1 1; -1 1; 1 -1; -1 -1];
rho1 = zeros(n, n);
for c = 1:size(centres, 1)
    II = sqrt((XX - centres(c, 1)).^2 + (YY - centres(c, 2)).^2) < 0.2;
    rho1(II) = 1;
end
rho1 = rho1 / sum(sum(rho1));


%% Solve: distance, flux m and potential phi between rho0 and rho1
% Take the options struct from SOMT_options(), then override two fields.
options = SOMT_options();
options.norm_type = 'l1 norm';
options.nu = 1;
[dist, m, phi] = SOMT(rho0, rho1, options);
dist   % echo the computed distance to the command window

%% m plotter: coarse quiver of m over a colour overlay of rho0 and rho1
n=256;
blk = 8;            % fine-grid cells per coarse arrow, along each side
assert(mod(n, blk) == 0, 'n must be a multiple of the block size');
nc = n / blk;       % coarse grid size (32 for n=256, blk=8)

% Sum each blk-by-blk tile of the two components of m. The reshape splits the
% row index into (within-tile, tile) and the column index likewise, so
% summing dims 1 and 3 leaves one tile total per coarse entry.
% As before, m(:,:,1) feeds mmx and m(:,:,2) feeds mmy.
blockSum = @(A) reshape(sum(sum(reshape(A, blk, nc, blk, nc), 1), 3), nc, nc);
mmx = blockSum(m(1:n, 1:n, 1));
mmy = blockSum(m(1:n, 1:n, 2));

% Overlay rho0 (RGB 50/132/191) and rho1 (RGB 255/232/0), each scaled
% to its own peak.
c0 = [50 132 191] / 255;
c1 = [255 232 0] / 255;
r0 = rho0 / max(max(rho0));
r1 = rho1 / max(max(rho1));
I = zeros(n, n, 3);
for ch = 1:3
    I(:,:,ch) = r0 * c0(ch) + r1 * c1(ch);
end
% Where the densities overlap, channel sums exceed 1 (red reaches 305/255).
% Truecolor double image data must lie in [0,1], so clamp.
I = min(I, 1);
imagesc([-2,2],[-2,2],I);

hold on;
% Put each arrow at the centre of the tile it summarises: the mean of that
% tile's fine-grid coordinates. linspace(-2,2,nc) would put the first and
% last arrows on the domain edge, off-centre from their tiles.
xc = mean(reshape(linspace(-2, 2, n), blk, nc), 1);
[XX,YY] = meshgrid(xc, xc);
% Legacy form for a 2n-by-n layout of m:
%quiver(XX,YY,m(1:256,:),m((n+1):(2*n),:), 1)
% Horizontal arrow component is mmy, vertical is mmx.
quiver(XX,YY,mmy,mmx, 1,'color',[1, .3, .3])

daspect([max(daspect)*[1 1] 1])
pbaspect([1 1 1])

ax = gca;
set(ax,'XTick',[])
set(ax,'YTick',[])

% Make the outer box square, then fit the plot box to the tight inset so
% the image fills the available area.
ax.OuterPosition(3)=ax.OuterPosition(4);
outerpos = ax.OuterPosition;
ti = ax.TightInset;
left = outerpos(1) + ti(1);
bottom = outerpos(2) + ti(2);
ax_width = outerpos(3) - ti(1) - ti(3);
ax_height = outerpos(4) - ti(2) - ti(4);
ax.Position = [left bottom ax_width ax_height];



%% rho1-rho0 plotter
% Size the grid from rho0 itself, so this cell does not rely on n from
% another cell and still works for image inputs that are not 256x256.
[ny, nx] = size(rho0);
[XX,YY] = meshgrid(linspace(-2,2,nx),linspace(-2,2,ny));
figure();
surf(XX,YY,rho1-rho0)
shading interp
xlabel('x'); ylabel('y')
title('rho1-rho0 plot')

%% Phi plotter: lit surface of the phi returned by SOMT on [-2,2]^2
n = 256;
[XX, YY] = meshgrid(linspace(-2, 2, n));   % square grid: same vector on both axes
figure;
surf(XX, YY, phi)
shading interp
light
lighting gouraud
xlabel('x');
ylabel('y')
title('\Phi plot')