########################################################################
#           Mathematical and Numerical Optimization (2020-2)           #
#                Computational Tomography (CT) with PDHG               #
########################################################################
from copy import copy
import numpy as np
import matplotlib.pyplot as plt

from skimage.data import shepp_logan_phantom
from skimage.transform import radon, iradon, rescale


# In Spyder, run `%matplotlib qt` first so the figure opens in a separate interactive window that can be updated during the loop.

########## Data Generation ##########
# Ground truth: the Shepp-Logan phantom (400 x 400 as shipped by scikit-image),
# downsampled by a factor of 0.5 to 200 x 200 to keep each iteration cheap.
# The phantom is a single-channel 2-D array, so no channel argument is needed:
# `multichannel=` was deprecated in scikit-image 0.19 and removed in 0.21
# (passing it raises TypeError), and the default already treats the input as
# single-channel on both old and new versions.
image = rescale(shepp_logan_phantom(), scale=.5, mode='reflect')
# 70 equally spaced projection angles in [0, 180) degrees.
theta = np.linspace(0., 180., 70, endpoint=False)
# Forward projection (sinogram). circle=False projects the whole square image
# instead of assuming it is zero outside the inscribed circle.
sino = radon(image=image, theta=theta, circle=False)
# Small additive Gaussian measurement noise.
sino += 1e-5 * np.random.normal(size=sino.shape)


########## Plot ##########
# Two panels side by side: the ground-truth image on the left, and on the
# right a placeholder image whose data the reconstruction loop replaces
# with the current iterate.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
ax1, ax2 = axes
fig.suptitle(f"CT using PDHG : Iteration {0}", fontsize=20)

ax1.imshow(image, cmap='gray')
ax1.set_title("original image")

update_im = ax2.imshow(np.ones_like(image), cmap='gray')
ax2.set_title("reconstructed image")

# Let the GUI event loop run briefly so the initial figure is drawn.
plt.pause(1)









########## PDHG parameters and variable initialization ##########
N = image.shape[0]  # image side length (the rescaled Shepp-Logan phantom is square)

# Iteration budget and the dual step sizes (alpha for u, beta for vx/vy).
max_iter = 1000
alpha = 1e-2
beta = 1e-4
# Weight of the total-variation regularizer in the objective.
lmbda = 1e-1

# Primal variable: start from small random noise rather than zeros.
x = np.random.normal(size=image.shape) * 1e-1
# Dual variables, all starting at zero: u lives in sinogram space, while
# vx / vy are the duals of the horizontal / vertical finite differences.
u = np.zeros_like(sino)
vx, vy = (np.zeros_like(image) for _ in range(2))


########## Image Reconstruction with PDHG ##########
# Each iteration performs:
#   1. a primal step on x,
#   2. the over-relaxation x_prime = 2*x_new - x_old,
#   3. dual steps on u (data term) and (vx, vy) (TV term), evaluated at x_prime.
# Difference operators slice with ":-1" / "1:", so they also work for
# non-square images.

# Box bound for the (scaled) TV dual variables; loop-invariant.
tv_bound = lmbda * alpha / beta

for i in range(max_iter):
    # Take a real copy of the current iterate. A plain `x_prev = x` only
    # aliases the array: the in-place `x -= ...` below would update both
    # names, and the extrapolation 2*x - x_prev would collapse to x.
    x_prev = x.copy()

    # Adjoint of the forward-difference operator applied to (vx, vy):
    # each dual minus its copy shifted by one pixel along its axis. The last
    # column of vx / last row of vy are never updated and stay zero.
    Dpv = vx + vy
    Dpv[:, 1:] -= vx[:, :-1]
    Dpv[1:, :] -= vy[:-1, :]

    # Primal update of the image.
    # NOTE(review): skimage's iradon applies filtered back-projection (ramp
    # filter by default), not the exact adjoint of radon, so this step is
    # effectively preconditioned -- confirm that is intended.
    x -= (1/alpha) * (alpha * iradon(u, theta=theta, circle=False) + beta * Dpv)

    # PDHG over-relaxation (extrapolation) step.
    x_prime = 2*x - x_prev

    # Dual update for the data-fidelity term 1/2 ||radon(x) - sino||^2.
    u = 1/(1+alpha) * (u + alpha * (radon(x_prime, theta=theta, circle=False) - sino))

    # Dual updates for the anisotropic TV term, projected onto
    # [-tv_bound, tv_bound].
    vx[:, :-1] = np.clip(vx[:, :-1] + beta * (x_prime[:, :-1] - x_prime[:, 1:]),
                         -tv_bound, tv_bound)
    vy[:-1, :] = np.clip(vy[:-1, :] + beta * (x_prime[:-1, :] - x_prime[1:, :]),
                         -tv_bound, tv_bound)

    # Objective at the current iterate: data misfit + lmbda * anisotropic TV.
    residual = radon(x, theta, circle=False) - sino
    tv = np.abs(x[:, :-1] - x[:, 1:]).sum() + np.abs(x[:-1, :] - x[1:, :]).sum()
    obj = 1/2*np.linalg.norm(residual)**2 + lmbda*tv
    print(f"iteration {i} obj {obj}")

    # Refresh the live reconstruction panel.
    update_im.set_data(x)
    update_im.autoscale()
    fig.suptitle(f"CT using PDHG : Iteration {i}", fontsize=20)
    fig.canvas.draw_idle()
    plt.pause(0.001)

plt.show()
