%% Build the CUDA MEX file used by the solver
% '-g' is passed through to the MEX build (debug build with symbols, per the
% mex/mexcuda option list).
mexcuda('-g', 'vomt_mex.cu')

%% Synthetic test case: one white disk -> three single-colour disks
% Grid of n-by-n pixels spanning [-2,2] x [-2,2]. With meshgrid, XX varies
% along columns (x) and YY along rows (y), so XX(i,j) = xx(j), YY(i,j) = yy(i).
n = 256;
xx = linspace(-2, 2, n);
yy = linspace(-2, 2, n);
[XX, YY] = meshgrid(xx, yy);

% rho0: disk of radius 0.3 centred at the origin, identical in all three
% channels, normalised to unit total mass over the whole n-by-n-by-3 array.
rho0 = repmat(double(sqrt(XX.^2 + YY.^2) < 0.3), [1 1 3]);
rho0 = rho0 / sum(rho0(:));

% rho1: three disks of radius 0.2, one per channel:
%   channel 1 at (x, y) = ( 1,  1)
%   channel 2 at (x, y) = (-1,  1)
%   channel 3 at (x, y) = ( 1, -1)
% normalised to unit total mass over all channels together.
rho1 = cat(3, ...
    double(sqrt((XX - 1).^2 + (YY - 1).^2) < 0.2), ...
    double(sqrt((XX + 1).^2 + (YY - 1).^2) < 0.2), ...
    double(sqrt((XX - 1).^2 + (YY + 1).^2) < 0.2));
rho1 = rho1 / sum(rho1(:));
%% Image-based test case: three ink splats -> colour ring
n = 256;

% Ink mask: invert the grayscale splat image so ink is bright on a zero
% background, then crop a 45x45 patch containing one splat.
% NOTE(review): "255 -" assumes ink_splat.png decodes to uint8 -- confirm.
c = double(255 - rgb2gray(imread('./ink_splat.png')));
ink = c(26:70, 186:230);

% rho0: the same 45x45 ink patch placed at a different location in each
% channel (R, G, B), then normalised to unit total mass.
rho0R = zeros(n, n);
rho0R(106:150, 166:210) = ink;
rho0G = zeros(n, n);
rho0G(151:195, 76:120) = ink;
rho0B = zeros(n, n);
rho0B(56:100, 86:130) = ink;
rho0 = cat(3, rho0R, rho0G, rho0B);
rho0 = rho0 / sum(rho0(:));

% rho1: the target RGB image used directly as a density. It must have exactly
% the same shape as rho0; fail here with a clear message rather than later
% inside the solver or the plot.
rho1 = double(imread('./ColorRing.png'));
if ~isequal(size(rho1), size(rho0))
    error('ColorRing.png is %s but rho0 is %s; the two densities must match.', ...
        mat2str(size(rho1)), mat2str(size(rho0)));
end
rho1 = rho1 / sum(rho1(:));

%% Solve the transport problem
% Solver settings: start from VOMT_options() defaults and override four fields.
options = VOMT_options();
options.max_iter = 1.5e5;
options.tau = 5;
options.u_norm_type = 'l12 norm';
options.w_norm_type = 'l2 norm';

% NOTE(review): the meaning of the positional arguments (1, 1, 1, 0.1) is
% defined inside VOMT, which is not visible here -- confirm and name them.
[d_val, u, w, phi] = VOMT(rho0, rho1, 1, 1, 1, 0.1, options);
 
%% Plot rho0/rho1 overlay with tile-summed flux arrows
% Draws the peak-normalised source (rho0) and target (rho1) densities as one
% RGB image. For each colour channel it then draws the field u, summed over
% blk-by-blk pixel tiles, as a quiver plot in that channel's colour.
% NOTE(review): u is indexed as u(row, col, channel, component). Component 2
% is drawn as the horizontal arrow component and component 1 as the vertical
% one, exactly as in the original script -- confirm this matches VOMT's
% output convention.
n   = size(rho0, 1);  % grid size (previously hard-coded as 256)
blk = 8;              % tile edge length in pixels
if mod(n, blk) ~= 0
    error('Grid size n = %d is not divisible by the tile size %d.', n, blk);
end
m = n / blk;          % tiles per side (previously hard-coded as 32)

% Each term is scaled to a peak of 1, so where rho0 and rho1 overlap the sum
% can reach 2. Truecolor double CData must lie in [0,1], so clip.
I = rho0 / max(rho0(:)) + rho1 / max(rho1(:));
I = min(I, 1);
imagesc([-2, 2], [-2, 2], I);
hold on;

% imagesc puts pixel centres at linspace(-2,2,n). Draw each arrow at the
% centre of the tile it summarises, not at linspace(-2,2,m), which is off by
% (blk-1)/2 pixels.
pix  = linspace(-2, 2, n);
tile = mean(reshape(pix, blk, m), 1);
[XX, YY] = meshgrid(tile, tile);

% Sum an n-by-n matrix over non-overlapping blk-by-blk tiles -> m-by-m,
% with entry (p,q) holding the sum of tile row p, tile column q.
tileSum = @(A) reshape(sum(sum(reshape(A, blk, m, blk, m), 1), 3), m, m);

chanColor = eye(3);   % channel 1 red, 2 green, 3 blue
for ch = 1:3
    mmx = tileSum(u(1:n, 1:n, ch, 1));
    mmy = tileSum(u(1:n, 1:n, ch, 2));
    quiver(XX, YY, mmy, mmx, 1, 'color', chanColor(ch, :));
end
hold on;

% Equal data units on both axes, square plot box, no tick marks.
daspect([max(daspect) * [1 1] 1]);
pbaspect([1 1 1]);

ax = gca;
set(ax, 'XTick', [], 'YTick', []);

% Make the outer box square, then shrink the inner Position by the tight
% inset so the axes fill the figure with minimal margin.
ax.OuterPosition(3) = ax.OuterPosition(4);
outerpos = ax.OuterPosition;
ti = ax.TightInset;
left = outerpos(1) + ti(1);
bottom = outerpos(2) + ti(2);
ax_width = outerpos(3) - ti(1) - ti(3);
ax_height = outerpos(4) - ti(2) - ti(4);
ax.Position = [left bottom ax_width ax_height];