#include <cassert>
#include <cmath>
#include <cstring>
#include <ctime>
#include <iostream>
#include <string>
#include "mex.h"

using std::cout;
using std::endl;




// Stream buffer that forwards everything written to it to mexPrintf, so that
// text sent through an std::ostream using this buffer shows up in the MATLAB
// command window.
class mystream : public std::streambuf
{
protected:
  // Writes the n characters starting at s. The "%.*s" precision argument must
  // be an int, while std::streamsize is usually wider, so the data is emitted
  // in int-sized pieces instead of passing n straight through the varargs.
  virtual std::streamsize xsputn(const char *s, std::streamsize n) {
    const std::streamsize max_chunk = 1 << 20;
    std::streamsize remaining = n;
    while (remaining > 0) {
      const int chunk = static_cast<int>(remaining < max_chunk ? remaining : max_chunk);
      mexPrintf("%.*s", chunk, s);
      s += chunk;
      remaining -= chunk;
    }
    return n;
  }

  // Writes a single character. The int is narrowed to a char first so that
  // "%.1s" reads the character itself rather than whichever byte of the int
  // happens to come first in memory.
  virtual int overflow(int c=EOF) {
    if (c != EOF) {
      const char ch = static_cast<char>(c);
      mexPrintf("%.1s", &ch);
    }
    return 1;
  }
};
// Scope guard for std::cout: on construction it installs a mystream buffer as
// std::cout's stream buffer and remembers the previous one; on destruction it
// puts the previous buffer back.
class scoped_redirect_cout
{
public:
	// rdbuf(new_buf) returns the buffer that was installed before the call.
	// mout is declared first, so it is already constructed when its address is taken.
	scoped_redirect_cout() : mout(), old_buf(std::cout.rdbuf(&mout)) {}
	~scoped_redirect_cout() { std::cout.rdbuf(old_buf); }
private:
	mystream mout;
	std::streambuf *old_buf;
};
// File-scope instance: the redirection is active for as long as this object lives.
static scoped_redirect_cout mycout_redirect;




//CUDA kernels
// Forward declarations; the definitions appear further down in this file.

// u-update kernels, one per shrinkage variant (l1, l2, nuc). Each reads d_Phi
// and updates d_u and d_u_temp in place.
__global__ void uUpdate_l1(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);
__global__ void uUpdate_l2(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);
__global__ void uUpdate_nuc(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);

// w-update kernels, one per shrinkage variant (l1, l2, nuc). Each reads
// d_L_arr and d_Phi and updates d_w and d_w_temp in place.
__global__ void wUpdate_l1(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu);
__global__ void wUpdate_l2(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu);
__global__ void wUpdate_nuc(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu);

// Phi-update kernel: reads d_u_temp, d_w_temp, d_L_arr, d_Lambda0 and
// d_Lambda1 and updates d_Phi in place.
__global__ void PhiUpdate(const float* __restrict__ d_L_arr, float* d_Phi, const float *d_u_temp, const float *d_w_temp, const float* __restrict__ d_Lambda0, const float* __restrict__ d_Lambda1, int n, int d, int ell, float dx, float tau);


// Host-side reductions used to compute the returned distance value.
float wass1(float* u, int n, int d);
float wass2(float* u, int n, int d);




// Checks the status returned by a CUDA runtime call.
//
// On failure this raises a MATLAB error carrying the CUDA error string, which
// aborts the MEX call and returns control to MATLAB. mxAssert is not used
// here because it is only active in debug (-g) MEX builds; in an optimized
// build it would let execution continue after a failed CUDA call.
//
// err: status code returned by a CUDA runtime API call.
void err_chk(cudaError err) {
  if (err != cudaSuccess) {
    mexErrMsgIdAndTxt("error:error", "CUDA error: %s", cudaGetErrorString(err));
  }
}





// Reads one MEX argument as a real double scalar. Raises a MATLAB error that
// names the argument (`what`) when it is not one.
static double mexarg_real_scalar(const mxArray* arg, const char* what) {
  if (!mxIsDouble(arg) || mxIsComplex(arg) || mxGetNumberOfElements(arg) != 1) {
    mexErrMsgIdAndTxt("error:error", "%s must be a real double scalar.", what);
  }
  return *mxGetPr(arg);
}

// True when `arg` is a real double array holding exactly `count` elements.
static bool mexarg_is_real_double_of_size(const mxArray* arg, size_t count) {
  return mxIsDouble(arg) && !mxIsComplex(arg) && mxGetNumberOfElements(arg) == count;
}

// MEX entry point.
//
// Inputs (prhs), 10 in total:
//   0: Lambda0      real double, n*n*3*3 elements (n = size along the first dimension)
//   1: Lambda1      real double, n*n*3*3 elements
//   2: L            real double, 3 x 3 x ell (a plain 3 x 3 matrix is taken as ell = 1)
//   3: alpha        scalar
//   4: mu           scalar
//   5: nu           scalar
//   6: tau          scalar
//   7: max_iter     scalar, number of iterations
//   8: u_norm_type  scalar, 1, 2 or 3 (selects uUpdate_l1 / _l2 / _nuc)
//   9: w_norm_type  scalar, 1, 2 or 3 (selects wUpdate_l1 / _l2 / _nuc)
//
// Outputs (plhs): either dist alone, or [dist U W Phi] where U is
// n x n x 3 x 3 x 2, W is n x n x 3 x 3 x ell and Phi is n x n x 3 x 3.
//
// n must be a power of two between 1 and 256.
//
// All argument validation happens before any host or device allocation so
// that a rejected call does not leak memory.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray const *prhs[]) {
  const int d = 2;

  //---------------------------------------------------------------------
  // argument validation (must precede every access to prhs[...])
  if ( nrhs!=10 ){
    mexErrMsgIdAndTxt("error:error",
                      "wrong number of inputs.");
  }

  if ( !(nlhs==1 || nlhs==4) ) {
    mexErrMsgIdAndTxt("error:error",
                      "output should be dist or [dist U W Phi]");
  }

  const float alpha = float(mexarg_real_scalar(prhs[3], "alpha"));
  const float mu = float(mexarg_real_scalar(prhs[4], "mu"));
  const float nu = float(mexarg_real_scalar(prhs[5], "nu"));
  const float tau = float(mexarg_real_scalar(prhs[6], "tau"));
  const int max_iter = int(mexarg_real_scalar(prhs[7], "max_iter"));
  const int u_norm_type = int(mexarg_real_scalar(prhs[8], "u_norm_type"));
  const int w_norm_type = int(mexarg_real_scalar(prhs[9], "w_norm_type"));

  if (u_norm_type < 1 || u_norm_type > 3 || w_norm_type < 1 || w_norm_type > 3) {
    mexErrMsgIdAndTxt("error:error",
                      "u_norm_type and w_norm_type must be 1, 2 or 3.");
  }

  if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0])) {
    mexErrMsgIdAndTxt("error:error",
                      "Lambda0 must be a real double array.");
  }
  const int n = int(mxGetDimensions(prhs[0])[0]);

  if ( !( n==1 || n==2 || n==4 || n==8 || n==16 || n==32 || n==64 || n==128 || n==256) ) {
    mexErrMsgIdAndTxt("error:error",
                      "n has to be a power of 2");
  }

  const size_t npts = size_t(n) * size_t(n);
  const size_t phi_len = npts * 3 * 3;

  // Only the element count is enforced, which is what the flat indexing below needs.
  if (!mexarg_is_real_double_of_size(prhs[0], phi_len) ||
      !mexarg_is_real_double_of_size(prhs[1], phi_len)) {
    mexErrMsgIdAndTxt("error:error",
                      "Lambda0 and Lambda1 must be real double arrays with n*n*3*3 elements.");
  }

  if (!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2])) {
    mexErrMsgIdAndTxt("error:error",
                      "L must be 3x3xell.");
  }
  const mwSize L_ndim = mxGetNumberOfDimensions(prhs[2]);
  const mwSize* L_dims = mxGetDimensions(prhs[2]);
  // MATLAB drops trailing singleton dimensions, so a 3x3x1 input arrives as a 2-D 3x3 matrix.
  if ((L_ndim != 2 && L_ndim != 3) || L_dims[0] != 3 || L_dims[1] != 3) {
    mexErrMsgIdAndTxt("error:error",
                      "L must be 3x3xell.");
  }
  const int ell = (L_ndim == 3) ? int(L_dims[2]) : 1;

  if (ell <= 0) {
    mexErrMsgIdAndTxt("error:error",
                      "L must be 3x3xell.");
  }

  // Launch configuration: one thread per grid point.
  int threads_per_block;
  if (n <= 8) {
    threads_per_block = n*n;      // 1, 4, 16, 64
  } else if (n <= 32) {
    threads_per_block = 128;      // n == 16, 32
  } else {
    threads_per_block = 256;      // n == 64, 128, 256
  }

  if (npts % size_t(threads_per_block) != 0) {
    mexErrMsgIdAndTxt("error:error",
                      "threads_per_block must divide the number of grid points");
  }
  const int num_blocks = int(npts / size_t(threads_per_block));
  if (num_blocks > 65535) {
    mexErrMsgIdAndTxt("error:error",
                      "maximum block number exceeded; threads_per_block should be bigger");
  }
  // end of validation
  //---------------------------------------------------------------------

  // Same arithmetic as the literal 1. / float(n - 1): divide in double, then narrow.
  const float dx = float(1.0 / double(n - 1));

  const size_t L_len = size_t(3) * 3 * size_t(ell);
  const size_t u_len = phi_len * size_t(d);
  const size_t w_len = phi_len * size_t(ell);

  const double* lambda0_double = mxGetPr(prhs[0]);
  const double* lambda1_double = mxGetPr(prhs[1]);
  const double* L_double = mxGetPr(prhs[2]);

  //convert input to float
  float* Lambda0 = new float[phi_len];
  float* Lambda1 = new float[phi_len];
  for (size_t idx = 0; idx < phi_len; idx++) {
    Lambda0[idx] = float(lambda0_double[idx]);
    Lambda1[idx] = float(lambda1_double[idx]);
  }

  float* L_arr = new float[L_len];
  for (size_t idx = 0; idx < L_len; idx++) {
    L_arr[idx] = float(L_double[idx]);
  }

  // host buffers that receive the iterates after the loop
  float* u = new float[u_len];
  float* w = new float[w_len];
  float* Phi = new float[phi_len];

  //create CUDA pointers
  float* d_L_arr;
  float* d_Lambda0;
  float* d_Lambda1;
  float* d_u;
  float* d_w;
  float* d_u_temp;
  float* d_w_temp;
  float* d_Phi;

  err_chk(cudaMalloc((void**)&d_L_arr, sizeof(float)*L_len));
  err_chk(cudaMalloc((void**)&d_Lambda0, sizeof(float)*phi_len));
  err_chk(cudaMalloc((void**)&d_Lambda1, sizeof(float)*phi_len));
  err_chk(cudaMalloc((void**)&d_u, sizeof(float)*u_len));
  err_chk(cudaMalloc((void**)&d_w, sizeof(float)*w_len));
  err_chk(cudaMalloc((void**)&d_u_temp, sizeof(float)*u_len));
  err_chk(cudaMalloc((void**)&d_w_temp, sizeof(float)*w_len));
  err_chk(cudaMalloc((void**)&d_Phi, sizeof(float)*phi_len));

  err_chk(cudaMemcpy(d_L_arr, L_arr, sizeof(float)*L_len, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_Lambda0, Lambda0, sizeof(float)*phi_len, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_Lambda1, Lambda1, sizeof(float)*phi_len, cudaMemcpyHostToDevice));

  // The iterates start at zero; an all-zero byte pattern is 0.0f.
  err_chk(cudaMemset(d_u, 0, sizeof(float)*u_len));
  err_chk(cudaMemset(d_w, 0, sizeof(float)*w_len));
  err_chk(cudaMemset(d_Phi, 0, sizeof(float)*phi_len));

  const int num_streams = 2;
  cudaStream_t streams[num_streams];
  for (int s = 0; s < num_streams; s++) {
    err_chk(cudaStreamCreate(&streams[s]));
  }

  cout << "starting iteration" << endl;
  clock_t begin = clock();
  for (int k = 0; k < max_iter; k++) {
    // u update on streams[0]
    switch (u_norm_type) {
      case 1:
        uUpdate_l1 <<< num_blocks, threads_per_block, 0, streams[0] >>> (d_Phi, d_u, d_u_temp, n, d, dx, mu);
        break;
      case 2:
        uUpdate_l2 <<< num_blocks, threads_per_block, 0, streams[0] >>> (d_Phi, d_u, d_u_temp, n, d, dx, mu);
        break;
      default: // 3, validated above
        uUpdate_nuc <<< num_blocks, threads_per_block, 0, streams[0] >>> (d_Phi, d_u, d_u_temp, n, d, dx, mu);
        break;
    }

    // w update on streams[1]
    switch (w_norm_type) {
      case 1:
        wUpdate_l1 <<< num_blocks, threads_per_block, 0, streams[1] >>> (d_L_arr, d_Phi, d_w, d_w_temp, alpha, n, ell, dx, nu);
        break;
      case 2:
        wUpdate_l2 <<< num_blocks, threads_per_block, 0, streams[1] >>> (d_L_arr, d_Phi, d_w, d_w_temp, alpha, n, ell, dx, nu);
        break;
      default: // 3, validated above
        wUpdate_nuc <<< num_blocks, threads_per_block, 0, streams[1] >>> (d_L_arr, d_Phi, d_w, d_w_temp, alpha, n, ell, dx, nu);
        break;
    }

    // PhiUpdate consumes d_u_temp and d_w_temp and is launched on the default stream.
    // NOTE(review): ordering against streams[0]/streams[1] relies on legacy
    // default-stream semantics - confirm the build does not use
    // --default-stream per-thread.
    PhiUpdate <<< num_blocks, threads_per_block >>> (d_L_arr, d_Phi, d_u_temp, d_w_temp, d_Lambda0, d_Lambda1, n, d, ell, dx, tau);
  }
  err_chk(cudaGetLastError());        // launch errors (kernel launches return nothing)
  err_chk(cudaDeviceSynchronize());   // errors raised while the kernels executed
  float runtime = float(clock() - begin) / CLOCKS_PER_SEC;
  cout << "Total runtime is " << runtime << "s." << endl;
  if (max_iter > 0) {
    cout << "This is " << runtime / float(max_iter) * 1000 << "ms per iteration" << endl;
  }

  err_chk(cudaMemcpy(u, d_u, sizeof(float)*u_len, cudaMemcpyDeviceToHost));
  err_chk(cudaMemcpy(w, d_w, sizeof(float)*w_len, cudaMemcpyDeviceToHost));
  err_chk(cudaMemcpy(Phi, d_Phi, sizeof(float)*phi_len, cudaMemcpyDeviceToHost));

  float dist = 0.0f;
  if (u_norm_type == 1) {
    dist += wass1(u,n,d);
  } else if (u_norm_type == 2) {
    dist += wass2(u,n,d);
  } else {
    mexWarnMsgIdAndTxt("warning:warning",
                       "no distance term is implemented for u_norm_type 3; it contributes 0 to dist.");
  }

  if (w_norm_type == 1) {
    dist += alpha*wass1(w,n,ell);
  } else if (w_norm_type == 2) {
    dist += alpha*wass2(w,n,ell);
  } else {
    mexWarnMsgIdAndTxt("warning:warning",
                       "no distance term is implemented for w_norm_type 3; it contributes 0 to dist.");
  }
  plhs[0] = mxCreateDoubleScalar(dist);

  if (nlhs==4) {
    // mxCreateNumericArray takes mwSize dimensions
    const mwSize u_dim[5] = {mwSize(n), mwSize(n), 3, 3, mwSize(d)};
    const mwSize w_dim[5] = {mwSize(n), mwSize(n), 3, 3, mwSize(ell)};
    const mwSize phi_dim[4] = {mwSize(n), mwSize(n), 3, 3};
    plhs[1] = mxCreateNumericArray(5, u_dim, mxDOUBLE_CLASS, mxREAL);
    plhs[2] = mxCreateNumericArray(5, w_dim, mxDOUBLE_CLASS, mxREAL);
    plhs[3] = mxCreateNumericArray(4, phi_dim, mxDOUBLE_CLASS, mxREAL);

    double* u_out = mxGetPr(plhs[1]);
    double* w_out = mxGetPr(plhs[2]);
    double* phi_out = mxGetPr(plhs[3]);

    for (size_t ii = 0; ii < u_len; ii++)
      u_out[ii] = double(u[ii]);
    for (size_t ii = 0; ii < w_len; ii++)
      w_out[ii] = double(w[ii]);
    for (size_t ii = 0; ii < phi_len; ii++)
      phi_out[ii] = double(Phi[ii]);
  }

  // release everything acquired above; streams would otherwise leak on every call
  for (int s = 0; s < num_streams; s++) {
    cudaStreamDestroy(streams[s]);
  }

  cudaFree(d_L_arr);
  cudaFree(d_Lambda0);
  cudaFree(d_Lambda1);
  cudaFree(d_u);
  cudaFree(d_w);
  cudaFree(d_u_temp);
  cudaFree(d_w_temp);
  cudaFree(d_Phi);

  delete[] Lambda0;
  delete[] Lambda1;
  delete[] Phi;
  delete[] L_arr;
  delete[] u;
  delete[] w;

}






// Sum of absolute values of the first 3*3*d*n*n entries of u.
//
// u: host array with at least 3*3*d*n*n floats.
// n, d: sizes that determine how many entries are summed.
// Returns the sum, accumulated in single precision.
//
// std::fabs is used instead of an unqualified abs: when only the integer
// overload of abs is visible, abs(float) silently truncates its argument.
float wass1(float* u, int n, int d) {
  const int len = 3*3*d*n*n;
  float ret = 0;
  for (int ii = 0; ii < len; ii++) {
    ret += std::fabs(u[ii]);
  }
  return ret;
}



// Sum over the n*n points of the Euclidean norm of the 3*3*d values stored
// for that point. Value e of point p is read from u[p + n*n*e].
//
// u: host array with at least 3*3*d*n*n floats.
// Returns the accumulated sum of per-point norms.
float wass2(float* u, int n, int d) {
  const int num_points = n*n;
  const int entries_per_point = 3*3*d;

  float total = 0.0f;
  for (int p = 0; p < num_points; ++p) {
    float sum_sq = 0.0f;
    for (int e = 0; e < entries_per_point; ++e) {
      sum_sq += u[p + num_points*e]*u[p + num_points*e];
    }
    total += sqrt(sum_sq);
  }
  return total;
}


// Entrywise soft-thresholding of the 3*3*d values belonging to one point.
// Value e of the point is M[i_n_j + n*n*e]. A value whose magnitude is at
// most mu becomes 0; any other value is scaled by (1 - mu/|value|).
inline __device__ void shrink1(float* M, float mu, int i_n_j, int d, int n) {
  const int count = 3*3*d;
  for (int e = 0; e < count; ++e) {
    const int idx = i_n_j + n*n*e;
    const float magnitude = abs(M[idx]);
    M[idx] = (magnitude <= mu) ? 0.0f : M[idx] * (1-mu/magnitude);
  }
}



// Group soft-thresholding of the 3*3*d values belonging to one point.
// Value e of the point is M[i_n_j + n*n*e]. The Euclidean norm of all the
// values is computed first; every value is then multiplied by 0 when that
// norm is at most mu, and by (1 - mu/norm) otherwise.
inline __device__ void shrink2(float* M, float mu, int i_n_j, int d, int n) {
  const int count = 3*3*d;

  // squared norm of this point's values
  float norm = 0.0f;
  for (int e = 0; e < count; ++e) {
    const int idx = i_n_j + n*n*e;
    norm += M[idx]*M[idx];
  }
  norm = sqrt(norm);

  const float scale = (abs(norm) <= mu) ? 0.0f : (1-mu/abs(norm));

  for (int e = 0; e < count; ++e) {
    const int idx = i_n_j + n*n*e;
    M[idx] *= scale;
  }
}


// Shrinkage step used by uUpdate_nuc and wUpdate_nuc.
// The body is empty: M is left unchanged and none of the parameters are used,
// so those kernels currently apply no shrinkage at all.
// NOTE(review): presumably a placeholder for a nuclear-norm shrinkage that was
// never written - confirm before relying on norm type 3.
inline __device__ void shrinkNuc(float* M, float mu, int i_n_j, int d, int n) {

}



// u update with entrywise (shrink1) thresholding; one thread per grid point.
//
// Indexing: the thread's flat point index is i_n_j = i + n*j, with
// i = i_n_j % n and j = i_n_j / n. Entry ii (0..8) of component r at that
// point is stored at i_n_j + n*n*ii + n*n*3*3*r.
//
// Per point:
//   1. copy the current u into d_u_temp,
//   2. add mu * (forward difference of Phi) / dx to component 0 (difference
//      along i, skipped when i == n-1) and to component 1 (difference along
//      j, skipped when j == n-1),
//   3. apply shrink1 with threshold mu to all 3*3*d values,
//   4. overwrite d_u_temp with 2*u_new - u_old.
//
// Components 0 and 1 are addressed directly, so d must be at least 2.
// d_Phi is read at neighbouring points and is never written here.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void uUpdate_l1(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_u_temp(i,j) = d_u(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  //u_ij = u_ij + \mu (\nabla \Phi)_ij
  //component 0: only when i < n-1
  if ( (i_n_j%n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii] += mu * (d_Phi[i_n_j + 1 + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  //component 1: only when j < n-1
  if ( (i_n_j/n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii + nn*3*3] += mu * (d_Phi[i_n_j + n + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  shrink1(d_u, mu, i_n_j, d, n);

  //d_u_temp(i,j) = 2.0f*d_u(i,j)-d_u_prev(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}

// u update with group (shrink2) thresholding; one thread per grid point.
//
// Indexing: the thread's flat point index is i_n_j = i + n*j, with
// i = i_n_j % n and j = i_n_j / n. Entry ii (0..8) of component r at that
// point is stored at i_n_j + n*n*ii + n*n*3*3*r.
//
// Per point:
//   1. copy the current u into d_u_temp,
//   2. add mu * (forward difference of Phi) / dx to component 0 (difference
//      along i, skipped when i == n-1) and to component 1 (difference along
//      j, skipped when j == n-1),
//   3. apply shrink2 with threshold mu to all 3*3*d values,
//   4. overwrite d_u_temp with 2*u_new - u_old.
//
// Components 0 and 1 are addressed directly, so d must be at least 2.
// d_Phi is read at neighbouring points and is never written here, which is
// why it is declared const __restrict__.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void uUpdate_l2(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_u_temp(i,j) = d_u(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  //u_ij = u_ij + \mu (\nabla \Phi)_ij
  //component 0: only when i < n-1
  if ( (i_n_j%n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii] += mu * (d_Phi[i_n_j + 1 + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  //component 1: only when j < n-1
  if ( (i_n_j/n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii + nn*3*3] += mu * (d_Phi[i_n_j + n + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  shrink2(d_u, mu, i_n_j, d, n);

  //d_u_temp(i,j) = 2.0f*d_u(i,j)-d_u_prev(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}




// u update that calls shrinkNuc as its shrinkage step; one thread per grid
// point.
//
// Indexing: the thread's flat point index is i_n_j = i + n*j, with
// i = i_n_j % n and j = i_n_j / n. Entry ii (0..8) of component r at that
// point is stored at i_n_j + n*n*ii + n*n*3*3*r.
//
// Per point:
//   1. copy the current u into d_u_temp,
//   2. add mu * (forward difference of Phi) / dx to component 0 (difference
//      along i, skipped when i == n-1) and to component 1 (difference along
//      j, skipped when j == n-1),
//   3. call shrinkNuc with threshold mu,
//   4. overwrite d_u_temp with 2*u_new - u_old.
//
// Components 0 and 1 are addressed directly, so d must be at least 2.
// d_Phi is read at neighbouring points and is never written here.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void uUpdate_nuc(const float* __restrict__ d_Phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_u_temp(i,j) = d_u(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  //u_ij = u_ij + \mu (\nabla \Phi)_ij
  //component 0: only when i < n-1
  if ( (i_n_j%n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii] += mu * (d_Phi[i_n_j + 1 + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  //component 1: only when j < n-1
  if ( (i_n_j/n) < n-1)
    for (int ii = 0; ii < 3*3; ii++)
      d_u[i_n_j + nn*ii + nn*3*3] += mu * (d_Phi[i_n_j + n + nn*ii] - d_Phi[i_n_j + nn*ii]) / dx;

  shrinkNuc(d_u, mu, i_n_j, d, n);

  //d_u_temp(i,j) = 2.0f*d_u(i,j)-d_u_prev(i,j)
  for (int ii = 0; ii < 3*3*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}





// w update with entrywise (shrink1) thresholding; one thread per grid point.
//
// Indexing: i_n_j is the flat point index. Entry (k,l) of the 3x3 block r of
// w at that point is stored at i_n_j + n*n*k + n*n*3*l + n*n*3*3*r; entry
// (k,l) of Phi at i_n_j + n*n*k + n*n*3*l; entry (k,l) of L_r at
// k + 3*l + 3*3*r.
//
// Per point:
//   1. copy the current w into d_w_temp,
//   2. for every r, add nu * (L_r * Phi - Phi * L_r) to block r of w
//      (3x3 matrix products at this point),
//   3. apply shrink1 with threshold alpha*nu to all 3*3*ell values,
//   4. overwrite d_w_temp with 2*w_new - w_old.
//
// d_Phi and d_L_arr are only read. dx is not used by this kernel.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void wUpdate_l1(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_w_temp(i,j) = d_w(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = d_w[ind];
  }

  //room for optimization with __shared__ and __constant__
  //w_i = w_i + \nu \nabla_L Phi
  for (int r = 0; r < ell; r++) {
    for (int k = 0; k < 3; k++) {
      for (int l = 0; l < 3; l++) {
        // accumulate in a local and store once; same operations in the same
        // order as updating d_w in place
        const int w_ind = i_n_j + nn*k + nn*3*l + nn*3*3*r;
        float acc = d_w[w_ind];
        for (int s = 0; s < 3; s++) {
          acc += nu * d_L_arr[k + 3*s + 3*3*r] * d_Phi[i_n_j + nn*s + nn*3*l];
          acc -= nu * d_Phi[i_n_j + nn*k + nn*3*s] * d_L_arr[s + 3*l + 3*3*r];
        }
        d_w[w_ind] = acc;
      }
    }
  }

  shrink1(d_w, alpha*nu, i_n_j, ell, n);

  //d_w_temp(i,j) = 2.0f*d_w(i,j)-d_w_prev(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = 2.0f*d_w[ind] - d_w_temp[ind];
  }
}



// w update with group (shrink2) thresholding; one thread per grid point.
//
// Indexing: i_n_j is the flat point index. Entry (k,l) of the 3x3 block r of
// w at that point is stored at i_n_j + n*n*k + n*n*3*l + n*n*3*3*r; entry
// (k,l) of Phi at i_n_j + n*n*k + n*n*3*l; entry (k,l) of L_r at
// k + 3*l + 3*3*r.
//
// Per point:
//   1. copy the current w into d_w_temp,
//   2. for every r, add nu * (L_r * Phi - Phi * L_r) to block r of w
//      (3x3 matrix products at this point),
//   3. apply shrink2 with threshold alpha*nu to all 3*3*ell values,
//   4. overwrite d_w_temp with 2*w_new - w_old.
//
// d_Phi and d_L_arr are only read. dx is not used by this kernel.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void wUpdate_l2(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_w_temp(i,j) = d_w(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = d_w[ind];
  }

  //room for optimization with __shared__ and __constant__
  //w_i = w_i + \nu \nabla_L Phi
  for (int r = 0; r < ell; r++) {
    for (int k = 0; k < 3; k++) {
      for (int l = 0; l < 3; l++) {
        // accumulate in a local and store once; same operations in the same
        // order as updating d_w in place
        const int w_ind = i_n_j + nn*k + nn*3*l + nn*3*3*r;
        float acc = d_w[w_ind];
        for (int s = 0; s < 3; s++) {
          acc += nu * d_L_arr[k + 3*s + 3*3*r] * d_Phi[i_n_j + nn*s + nn*3*l];
          acc -= nu * d_Phi[i_n_j + nn*k + nn*3*s] * d_L_arr[s + 3*l + 3*3*r];
        }
        d_w[w_ind] = acc;
      }
    }
  }

  shrink2(d_w, alpha*nu, i_n_j, ell, n);

  //d_w_temp(i,j) = 2.0f*d_w(i,j)-d_w_prev(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = 2.0f*d_w[ind] - d_w_temp[ind];
  }
}


// w update that calls shrinkNuc as its shrinkage step; one thread per grid
// point.
//
// Indexing: i_n_j is the flat point index. Entry (k,l) of the 3x3 block r of
// w at that point is stored at i_n_j + n*n*k + n*n*3*l + n*n*3*3*r; entry
// (k,l) of Phi at i_n_j + n*n*k + n*n*3*l; entry (k,l) of L_r at
// k + 3*l + 3*3*r.
//
// Per point:
//   1. copy the current w into d_w_temp,
//   2. for every r, add nu * (L_r * Phi - Phi * L_r) to block r of w
//      (3x3 matrix products at this point),
//   3. call shrinkNuc with threshold alpha*nu,
//   4. overwrite d_w_temp with 2*w_new - w_old.
//
// d_Phi and d_L_arr are only read. dx is not used by this kernel.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void wUpdate_nuc(const float* __restrict__ d_L_arr, const float* __restrict__ d_Phi, float* d_w, float* d_w_temp, float alpha, int n, int ell, float dx, float nu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //d_w_temp(i,j) = d_w(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = d_w[ind];
  }

  //room for optimization with __shared__ and __constant__
  //w_i = w_i + \nu \nabla_L Phi
  for (int r = 0; r < ell; r++) {
    for (int k = 0; k < 3; k++) {
      for (int l = 0; l < 3; l++) {
        // accumulate in a local and store once; same operations in the same
        // order as updating d_w in place
        const int w_ind = i_n_j + nn*k + nn*3*l + nn*3*3*r;
        float acc = d_w[w_ind];
        for (int s = 0; s < 3; s++) {
          acc += nu * d_L_arr[k + 3*s + 3*3*r] * d_Phi[i_n_j + nn*s + nn*3*l];
          acc -= nu * d_Phi[i_n_j + nn*k + nn*3*s] * d_L_arr[s + 3*l + 3*3*r];
        }
        d_w[w_ind] = acc;
      }
    }
  }

  shrinkNuc(d_w, alpha*nu, i_n_j, ell, n);

  //d_w_temp(i,j) = 2.0f*d_w(i,j)-d_w_prev(i,j)
  for (int ii = 0; ii < 3*3*ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = 2.0f*d_w[ind] - d_w_temp[ind];
  }
}







// Phi update; one thread per grid point.
//
// Indexing: i_n_j = i + n*j is the flat point index (i = i_n_j % n,
// j = i_n_j / n). Entry (k,l) of a 3x3 block at that point is stored at
// i_n_j + n*n*k + n*n*3*l, with a further offset of n*n*3*3*r for block r of
// d_u_temp / d_w_temp. Entry (k,l) of L_r is at k + 3*l + 3*3*r.
//
// Per point:
//   div_L_w = sum over r of (w_temp_r * L_r - L_r * w_temp_r)   (3x3 products)
//   div_u   = ( sum over r < d of u_temp_r(i,j)
//               - u_temp_0(i-1,j)  [only when i > 0]
//               - u_temp_1(i,j-1)  [only when j > 0] ) / dx
//   Phi    += tau * (div_u + div_L_w + Lambda1 - Lambda0)
//
// d_u_temp is read at neighbouring points, so the kernels that write it must
// have finished before this one runs. Components 0 and 1 of d_u_temp are
// addressed directly, so d must be at least 2.
// Threads whose index is >= n*n return immediately, so the launch may cover
// more threads than there are points.
__global__ void PhiUpdate(const float* __restrict__ d_L_arr, float* d_Phi, const float *d_u_temp, const float *d_w_temp, const float* __restrict__ d_Lambda0, const float* __restrict__ d_Lambda1, int n, int d, int ell, float dx, float tau) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn)
    return;

  //div_L_w = div_L (w_temp)
  float div_L_w[3*3];
  for (int ii = 0; ii < 3*3; ii++)
    div_L_w[ii] = 0.0f;

  for (int r = 0; r < ell; r++) {
    for (int k = 0; k < 3; k++) {
      for (int l = 0; l < 3; l++) {
        for (int s = 0; s < 3; s++) {
          div_L_w[k + 3*l] -= d_L_arr[k + 3*s + 3*3*r] * d_w_temp[i_n_j + nn*s + nn*3*l + nn*3*3*r];
          div_L_w[k + 3*l] += d_w_temp[i_n_j + nn*k + nn*3*s + nn*3*3*r] * d_L_arr[s + 3*l + 3*3*r];
        }
      }
    }
  }

  //div_u = div(u_temp)
  float div_u[3*3];
  for (int ii = 0; ii < 3*3; ii++)
    div_u[ii] = 0.0f;

  for (int r = 0; r < d; r++)
    for (int ii = 0; ii < 3*3; ii++)
      div_u[ii] += d_u_temp[i_n_j + nn*ii + nn*3*3*r];

  //backward neighbour along i, component 0 (only when i > 0)
  if ((i_n_j%n)>0)
    for (int ii = 0; ii < 3*3; ii++)
      div_u[ii] -= d_u_temp[i_n_j - 1 + nn*ii];

  //backward neighbour along j, component 1 (only when j > 0)
  if ((i_n_j/n)>0)
    for (int ii = 0; ii < 3*3; ii++)
      div_u[ii] -= d_u_temp[i_n_j - n + nn*ii + nn*3*3];

  for (int ii = 0; ii < 3*3; ii++)
    div_u[ii] /= dx;

  //Phi = Phi + tau * ( div_u + div_L_w + Lambda1 - Lambda0 )
  for (int ii = 0; ii < 3*3; ii++) {
    const int ind = i_n_j + nn*ii;
    d_Phi[ind] += tau * (div_u[ii] + div_L_w[ii] + d_Lambda1[ind] - d_Lambda0[ind]);
  }
}