#include <cassert>
#include <climits>
#include <cmath>
#include <ctime>
#include <iostream>
#include <string>
#include <vector>
#include "mex.h"

using std::cout;
using std::endl;



// Unbuffered std::streambuf that forwards every character written to it to
// mexPrintf, so text sent to a std::ostream using it appears in MATLAB.
class mystream : public std::streambuf
{
protected:
  // Bulk write of exactly n characters from s (s need not be NUL-terminated).
  // The precision argument of "%.*s" must be an int: std::streamsize is wider
  // on LP64 platforms, and passing it straight through varargs is undefined.
  // Writes longer than INT_MAX are therefore split into int-sized chunks.
  virtual std::streamsize xsputn(const char *s, std::streamsize n) {
    std::streamsize remaining = n;
    while (remaining > 0) {
      const int chunk = remaining > INT_MAX ? INT_MAX : static_cast<int>(remaining);
      mexPrintf("%.*s", chunk, s);
      s += chunk;
      remaining -= chunk;
    }
    return n;
  }

  // Single-character write. The character is extracted with to_char_type, so
  // the result no longer depends on byte order (printing "&c" of an int only
  // works on little-endian hosts). Per the streambuf contract, success is
  // signalled by returning a value other than eof().
  virtual int overflow(int c = traits_type::eof()) {
    if (!traits_type::eq_int_type(c, traits_type::eof())) {
      const char ch = traits_type::to_char_type(c);
      mexPrintf("%.1s", &ch);
    }
    return traits_type::not_eof(c);
  }
};
// Scope guard: while an instance exists, std::cout writes into a mystream
// (and therefore to mexPrintf); the destructor puts the previous buffer back.
class scoped_redirect_cout
{
public:
	// rdbuf(p) installs p and hands back whichever buffer was active before.
	scoped_redirect_cout() : previous_(std::cout.rdbuf(&forwarder_)) {}
	~scoped_redirect_cout() { std::cout.rdbuf(previous_); }
private:
	// Declared first so it is fully constructed before previous_'s initializer installs it.
	mystream forwarder_;
	// Buffer std::cout was using before the redirect.
	std::streambuf *previous_;
};
// File-scope instance with static storage duration: the redirect is installed
// during static initialization and undone when static objects are destroyed.
static scoped_redirect_cout mycout_redirect;




//CUDA kernels
// Each kernel is launched from mexFunction as
//   <<< n*n / threads_per_block, threads_per_block >>>
// i.e. one thread per grid point, flat index i + n*j.
//
// uUpdate_*: copy d_u into d_u_temp, add mu times the forward-difference
// gradient of d_phi (divided by dx), shrink with threshold mu
// (l1: per entry; l2: over all K*d entries at a point; l12: per channel over
// the two spatial components), then write 2*u_new - u_old into d_u_temp.
__global__ void uUpdate_l1(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);
__global__ void uUpdate_l2(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);
__global__ void uUpdate_l12(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu);

// wUpdate_*: copy d_w into d_w_temp, add nu times the channel differences of
// d_phi (0-1 scaled by 1/c12, 1-2 by 1/c23, 0-2 by 1/c13), shrink with threshold
// alpha*nu (l1: per entry; l2: over the ell entries at a point), then write
// 2*w_new - w_old into d_w_temp.
__global__ void wUpdate_l1(const float* __restrict__  d_phi, float* d_w, float* d_w_temp, int n, float dx, float nu, float alpha);
__global__ void wUpdate_l2(const float* __restrict__  d_phi, float* d_w, float* d_w_temp, int n, float dx, float nu, float alpha);

// phiUpdate: phi += tau * (div(d_u_temp) + div_G(d_w_temp) + lambda1 - lambda0).
__global__ void phiUpdate(float* d_phi, const float* __restrict__ d_u_temp, const float* __restrict__ d_w_temp, const float* __restrict__  d_lambda0, const float* __restrict__  d_lambda1, int n, int d, float dx, float tau);


// Host-side terms of the returned distance, evaluated on the final u / w
// copied back from the device:
//   *_wass1  : sum of absolute values of all entries
//   *_wass2  : sum over grid points of the Euclidean norm of all entries at that point
//   u_wass12 : sum over grid points and channels of the norm of the two spatial components
float u_wass1(float* u, int n, int d);
float u_wass2(float* u, int n, int d);
float u_wass12(float* u, int n, int d);
float w_wass1(float* w, int n, int d);
float w_wass2(float* w, int n, int d);

// Aborts the MEX call with a MATLAB error if a CUDA runtime call failed.
// mxAssert (used previously) is only active in MEX files built with -g, so in
// normal builds a failed cudaMalloc/cudaMemcpy/kernel was silently ignored and
// the computation continued on invalid data. mexErrMsgIdAndTxt always raises.
void err_chk(cudaError err) {
  if (err != cudaSuccess) {
    mexErrMsgIdAndTxt("error:error", "CUDA error: %s", cudaGetErrorString(err));
  }
}


// Number of density channels: lambda0, lambda1 and phi hold n*n*K values.
# define K 3
// Number of w components per grid point; the kernels pair them with the
// channel pairs (0,1), (1,2), (0,2) via c_c12, c_c23, c_c13.
# define ell 3

// NOTE(review): c_n and c_d are never written or read anywhere in this file.
__constant__ int c_n;
__constant__ int c_d;
// Set from MEX inputs 3-5 via cudaMemcpyToSymbol; the w/phi kernels divide by them.
__constant__ float c_c12;
__constant__ float c_c23;
__constant__ float c_c13;

// Returns the first element of a MEX input that must be a non-empty double
// array; raises a MATLAB error otherwise. `index` is the 0-based prhs slot.
static double read_scalar_input(const mxArray *arr, int index) {
  if (!mxIsDouble(arr) || mxGetNumberOfElements(arr) < 1) {
    mexErrMsgIdAndTxt("error:error", "input %d must be a double scalar", index + 1);
  }
  return *mxGetPr(arr);
}

// Raises a MATLAB error unless `arr` is a double array with exactly
// `expected` elements (the kernels read n*n*K values from each density).
static void check_density_input(const mxArray *arr, int index, size_t expected) {
  if (!mxIsDouble(arr) || mxGetNumberOfElements(arr) != expected) {
    mexErrMsgIdAndTxt("error:error",
                      "input %d must be a double n-by-n-by-%d array", index + 1, K);
  }
}

/*
 * MEX gateway.
 *
 *   dist              = f(lambda0, lambda1, c12, c23, c13, alpha, mu, nu, tau, max_iter, u_norm, w_norm)
 *   [dist, U, W, phi] = f(...)
 *
 * Inputs (double):
 *   prhs[0], prhs[1] : lambda0, lambda1 with n*n*K elements; n = size(lambda0, 1)
 *                      must be a power of two no larger than 256.
 *   prhs[2..4]       : c12, c23, c13, copied to __constant__ memory for the w/phi kernels.
 *   prhs[5]          : alpha, weight of the w term in dist and part of the w shrink threshold.
 *   prhs[6..8]       : mu, nu, tau, passed to the u, w and phi kernels respectively.
 *   prhs[9]          : max_iter, number of iterations.
 *   prhs[10]         : u norm type: 1 = l1, 2 = l2, 3 = l12.
 *   prhs[11]         : w norm type: 1 = l1, 2 = l2.
 * Outputs:
 *   plhs[0]          : dist = u_wass*(U) + alpha * w_wass*(W) for the selected norms.
 *   plhs[1..3]       : (only when nlhs == 4) U (n x n x K x 2), W (n x n x ell), phi (n x n x K).
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray const *prhs[]) {
  const int d = 2;

  // Validate the call shape before touching any prhs entry.
  if ( nrhs!=12 ){
    mexErrMsgIdAndTxt("error:error",
                      "wrong number of inputs.");
  }
  if ( !(nlhs==1 || nlhs==4) ) {
    mexErrMsgIdAndTxt("error:error",
                      "output should be dist or [dist U W phi]");
  }

  // Grid size. Validated before any allocation so that an error raised here
  // cannot leak host memory.
  const int n = int(mxGetDimensions(prhs[0])[0]);
  if ( !( n==1 || n==2 || n==4 || n==8 || n==16 || n==32 || n==64 || n==128 || n==256) ) {
    mexErrMsgIdAndTxt("error:error",
                      "n has to be a power of 2 no larger than 256");
  }

  const int nn = n*n;            // grid points
  const int n_lambda = nn*K;     // elements of lambda0 / lambda1 / phi
  const int n_u = nn*K*d;        // elements of u
  const int n_w = nn*ell;        // elements of w

  // The conversion loop below reads n*n*K doubles from each density.
  check_density_input(prhs[0], 0, size_t(n_lambda));
  check_density_input(prhs[1], 1, size_t(n_lambda));

  const float c12 = float(read_scalar_input(prhs[2], 2));
  const float c23 = float(read_scalar_input(prhs[3], 3));
  const float c13 = float(read_scalar_input(prhs[4], 4));
  const float alpha = float(read_scalar_input(prhs[5], 5));
  const float mu = float(read_scalar_input(prhs[6], 6));
  const float nu = float(read_scalar_input(prhs[7], 7));
  const float tau = float(read_scalar_input(prhs[8], 8));
  const int max_iter = int(read_scalar_input(prhs[9], 9));
  const int u_norm_type = int(read_scalar_input(prhs[10], 10));
  const int w_norm_type = int(read_scalar_input(prhs[11], 11));

  // Reject unknown norm types up front. Previously they reached assert(false)
  // inside the loop, which is a no-op in release builds: the u/w kernels were
  // skipped and phiUpdate read uninitialized d_u_temp / d_w_temp.
  if (u_norm_type < 1 || u_norm_type > 3) {
    mexErrMsgIdAndTxt("error:error", "u norm type must be 1 (l1), 2 (l2) or 3 (l12)");
  }
  if (w_norm_type < 1 || w_norm_type > 2) {
    mexErrMsgIdAndTxt("error:error", "w norm type must be 1 (l1) or 2 (l2)");
  }

  // Grid spacing on [0,1]. For n == 1 this is 1/0 = inf (unchanged behavior).
  const float dx = 1. / float(n - 1);

  // Block size per grid size; every entry divides n*n exactly.
  int threads_per_block = 256;
  switch (n) {
    case 1:  threads_per_block = 1;   break;
    case 2:  threads_per_block = 4;   break;
    case 4:  threads_per_block = 16;  break;
    case 8:  threads_per_block = 64;  break;
    case 16: threads_per_block = 128; break;
    case 32: threads_per_block = 128; break;
    default: threads_per_block = 256; break;  // n = 64, 128, 256
  }
  const int num_blocks = nn / threads_per_block;
  // Invariant check (mxAssert would be compiled out in release builds).
  if (nn % threads_per_block != 0 || num_blocks > 65535) {
    mexErrMsgIdAndTxt("error:error",
                      "invalid launch configuration: threads_per_block must divide n*n "
                      "and the block count must not exceed 65535");
  }

  // Host buffers: single-precision copies of the densities, zero-initialized iterates.
  std::vector<float> lambda0(n_lambda);
  std::vector<float> lambda1(n_lambda);
  const double* lambda0_double = mxGetPr(prhs[0]);
  const double* lambda1_double = mxGetPr(prhs[1]);
  for (int ii = 0; ii < n_lambda; ii++) {
    lambda0[ii] = float(lambda0_double[ii]);
    lambda1[ii] = float(lambda1_double[ii]);
  }
  std::vector<float> u(n_u, 0.0f);
  std::vector<float> w(n_w, 0.0f);
  std::vector<float> phi(n_lambda, 0.0f);

  // Device buffers.
  float* d_lambda0;
  float* d_lambda1;
  float* d_u;
  float* d_w;
  float* d_u_temp;
  float* d_w_temp;
  float* d_phi;

  err_chk(cudaMalloc((void**)&d_lambda0, sizeof(float)*n_lambda));
  err_chk(cudaMalloc((void**)&d_lambda1, sizeof(float)*n_lambda));
  err_chk(cudaMalloc((void**)&d_u, sizeof(float)*n_u));
  err_chk(cudaMalloc((void**)&d_w, sizeof(float)*n_w));
  err_chk(cudaMalloc((void**)&d_u_temp, sizeof(float)*n_u));
  err_chk(cudaMalloc((void**)&d_w_temp, sizeof(float)*n_w));
  err_chk(cudaMalloc((void**)&d_phi, sizeof(float)*n_lambda));

  err_chk(cudaMemcpy(d_lambda0, &lambda0[0], sizeof(float)*n_lambda, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_lambda1, &lambda1[0], sizeof(float)*n_lambda, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_u, &u[0], sizeof(float)*n_u, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_w, &w[0], sizeof(float)*n_w, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_phi, &phi[0], sizeof(float)*n_lambda, cudaMemcpyHostToDevice));

  err_chk(cudaMemcpyToSymbol(c_c12, &c12, sizeof(float)));
  err_chk(cudaMemcpyToSymbol(c_c23, &c23, sizeof(float)));
  err_chk(cudaMemcpyToSymbol(c_c13, &c13, sizeof(float)));

  cout << "starting iteration" << endl;
  clock_t begin = clock();
  for (int iter = 0; iter < max_iter; iter++) {
    // u step, then w step, then a phi step that reads the extrapolated
    // copies the first two kernels leave in d_u_temp / d_w_temp. All launches
    // go to the default stream, so they execute in this order.
    if (u_norm_type == 1) {
      uUpdate_l1 <<< num_blocks, threads_per_block >>> (d_phi, d_u, d_u_temp, n, d, dx, mu);
    } else if (u_norm_type == 2) {
      uUpdate_l2 <<< num_blocks, threads_per_block >>> (d_phi, d_u, d_u_temp, n, d, dx, mu);
    } else {  // u_norm_type == 3 (validated above)
      uUpdate_l12 <<< num_blocks, threads_per_block >>> (d_phi, d_u, d_u_temp, n, d, dx, mu);
    }

    if (w_norm_type == 1) {
      wUpdate_l1 <<< num_blocks, threads_per_block >>> (d_phi, d_w, d_w_temp, n, dx, nu, alpha);
    } else {  // w_norm_type == 2 (validated above)
      wUpdate_l2 <<< num_blocks, threads_per_block >>> (d_phi, d_w, d_w_temp, n, dx, nu, alpha);
    }

    phiUpdate <<< num_blocks, threads_per_block >>> (d_phi, d_u_temp, d_w_temp, d_lambda0, d_lambda1, n, d, dx, tau);
  }
  // Kernel launches do not return errors. A failed launch is recorded and
  // reported by cudaGetLastError, which is checked once here to keep API calls
  // out of the hot loop. Errors raised while the kernels run surface at the sync.
  err_chk(cudaGetLastError());
  err_chk(cudaDeviceSynchronize());
  const float runtime = float(clock() - begin) / CLOCKS_PER_SEC;
  cout << "Total runtime is " << runtime << "s." << endl;
  if (max_iter > 0) {
    cout << "This is " << runtime / float(max_iter) * 1000 << "ms per iteration" << endl;
  }

  err_chk(cudaMemcpy(&u[0], d_u, sizeof(float)*n_u, cudaMemcpyDeviceToHost));
  err_chk(cudaMemcpy(&w[0], d_w, sizeof(float)*n_w, cudaMemcpyDeviceToHost));
  err_chk(cudaMemcpy(&phi[0], d_phi, sizeof(float)*n_lambda, cudaMemcpyDeviceToHost));

  // Objective: norm of u plus alpha times norm of w, for the selected norms.
  float dist_val = 0.0f;
  if (u_norm_type == 1) {
    dist_val += u_wass1(&u[0], n, d);
  } else if (u_norm_type == 2) {
    dist_val += u_wass2(&u[0], n, d);
  } else {
    dist_val += u_wass12(&u[0], n, d);
  }
  if (w_norm_type == 1) {
    dist_val += alpha*w_wass1(&w[0], n, d);
  } else {
    dist_val += alpha*w_wass2(&w[0], n, d);
  }

  plhs[0] = mxCreateDoubleScalar(dist_val);

  if (nlhs==4) {
    // mxCreateNumericArray takes mwSize dimensions (int or size_t depending on
    // the -largeArrayDims setting), so size_t arrays do not portably match.
    const mwSize n_mw = static_cast<mwSize>(n);
    mwSize u_dim[4] = {n_mw, n_mw, static_cast<mwSize>(K), static_cast<mwSize>(d)};
    mwSize w_dim[3] = {n_mw, n_mw, static_cast<mwSize>(ell)};
    mwSize phi_dim[3] = {n_mw, n_mw, static_cast<mwSize>(K)};
    plhs[1] = mxCreateNumericArray(4, u_dim, mxDOUBLE_CLASS, mxREAL);
    plhs[2] = mxCreateNumericArray(3, w_dim, mxDOUBLE_CLASS, mxREAL);
    plhs[3] = mxCreateNumericArray(3, phi_dim, mxDOUBLE_CLASS, mxREAL);

    double* u_out = mxGetPr(plhs[1]);
    double* w_out = mxGetPr(plhs[2]);
    double* phi_out = mxGetPr(plhs[3]);

    for (int ii = 0; ii < n_u; ii++)
      u_out[ii] = double(u[ii]);
    for (int ii = 0; ii < n_w; ii++)
      w_out[ii] = double(w[ii]);
    for (int ii = 0; ii < n_lambda; ii++)
      phi_out[ii] = double(phi[ii]);
  }

  cudaFree(d_lambda0);
  cudaFree(d_lambda1);
  cudaFree(d_u);
  cudaFree(d_w);
  cudaFree(d_u_temp);
  cudaFree(d_w_temp);
  cudaFree(d_phi);
  // Host buffers are std::vectors and are released automatically.
}






/*
 * u update with an entrywise (l1) shrink.
 *
 * Launch: 1-D grid, one thread per grid point with flat index
 * i_n_j = i + n*j = blockIdx.x*blockDim.x + threadIdx.x; threads with
 * i_n_j >= n*n exit immediately, so an overshooting grid is harmless.
 * Storage (from the indexing): d_phi channel ii at offset n*n*ii; d_u and
 * d_u_temp entry (channel ii, component r) at offset n*n*ii + n*n*K*r.
 *
 * On exit: d_u = shrink(u_old + mu*grad(phi), mu) entrywise, and
 * d_u_temp = 2*d_u - u_old (the extrapolation read by phiUpdate).
 */
__global__ void uUpdate_l1(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;
  const int i = i_n_j % n;
  const int j = i_n_j / n;

  // Keep the previous iterate: d_u_temp(i,j) = d_u(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  // Gradient step u += mu * grad(phi), forward differences; the derivative
  // along a direction is skipped on that direction's last row/column.
  if (i < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*0] += mu * (d_phi[i_n_j + 1 + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;
  if (j < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*1] += mu * (d_phi[i_n_j + n + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;

  // Entrywise shrink: scale by 0 if |u| <= mu, else by 1 - mu/|u|.
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    const float val = d_u[ind];
    const float mag = fabsf(val);
    const float factor = (mag <= mu) ? 0.0f : (1.0f - mu/mag);
    d_u[ind] = factor*val;
  }

  // Extrapolation: d_u_temp(i,j) = 2*d_u(i,j) - d_u_prev(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}

/*
 * u update with an l2 shrink over all K*d entries at a grid point.
 *
 * Launch: 1-D grid, one thread per grid point (flat index i + n*j); threads
 * with index >= n*n exit immediately. Storage: d_u / d_u_temp entry ii at
 * offset n*n*ii; d_phi channel ii at offset n*n*ii.
 *
 * On exit: d_u = (u_old + mu*grad(phi)) scaled by 0 if its Euclidean norm is
 * <= mu, else by 1 - mu/norm; d_u_temp = 2*d_u - u_old.
 */
__global__ void uUpdate_l2(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;
  const int i = i_n_j % n;
  const int j = i_n_j / n;

  // Keep the previous iterate: d_u_temp(i,j) = d_u(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  // Gradient step u += mu * grad(phi), forward differences (skipped on the last row/column).
  if (i < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*0] += mu * (d_phi[i_n_j + 1 + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;
  if (j < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*1] += mu * (d_phi[i_n_j + n + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;

  // l2 shrink over the whole K*d vector at this point. The norm comes from
  // sqrtf of a sum of squares, so it is never negative and needs no fabsf.
  float sq = 0.0f;
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    sq += d_u[ind]*d_u[ind];
  }
  const float norm = sqrtf(sq);
  const float factor = (norm <= mu) ? 0.0f : (1.0f - mu/norm);
  for (int ii = 0; ii < K*d; ii++) {
    d_u[i_n_j + nn*ii] *= factor;
  }

  // Extrapolation: d_u_temp(i,j) = 2*d_u(i,j) - d_u_prev(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}




/*
 * u update with a mixed l1/l2 ("l12") shrink: for each channel ii < K, the
 * two spatial components (offsets n*n*K*0 and n*n*K*1) are shrunk together
 * as a 2-vector.
 *
 * Launch: 1-D grid, one thread per grid point (flat index i + n*j); threads
 * with index >= n*n exit immediately.
 *
 * On exit: d_u holds the shrunk values and d_u_temp = 2*d_u - u_old.
 */
__global__ void uUpdate_l12(const float* __restrict__  d_phi, float* d_u, float* d_u_temp, int n, int d, float dx, float mu) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;
  const int i = i_n_j % n;
  const int j = i_n_j / n;

  // Keep the previous iterate: d_u_temp(i,j) = d_u(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = d_u[ind];
  }

  // Gradient step u += mu * grad(phi), forward differences (skipped on the last row/column).
  if (i < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*0] += mu * (d_phi[i_n_j + 1 + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;
  if (j < n-1)
    for (int ii = 0; ii < K; ii++)
      d_u[i_n_j + nn*ii + nn*K*1] += mu * (d_phi[i_n_j + n + nn*ii] - d_phi[i_n_j + nn*ii]) / dx;

  // Per-channel shrink of the (component 0, component 1) pair. The norm is
  // non-negative, so no fabsf is needed.
  for (int ii = 0; ii < K; ii++) {
    const int ind = i_n_j + nn*ii;
    const float a = d_u[ind + nn*K*0];
    const float b = d_u[ind + nn*K*1];
    const float norm = sqrtf(a*a + b*b);
    const float factor = (norm <= mu) ? 0.0f : (1.0f - mu/norm);
    d_u[ind + nn*K*0] *= factor;
    d_u[ind + nn*K*1] *= factor;
  }

  // Extrapolation: d_u_temp(i,j) = 2*d_u(i,j) - d_u_prev(i,j).
  for (int ii = 0; ii < K*d; ii++) {
    const int ind = i_n_j + nn*ii;
    d_u_temp[ind] = 2.0f*d_u[ind] - d_u_temp[ind];
  }
}




/*
 * w update with an entrywise (l1) shrink.
 *
 * Launch: 1-D grid, one thread per grid point (flat index i + n*j); threads
 * with index >= n*n exit immediately. Storage: d_w / d_w_temp entry ii at
 * offset n*n*ii (ii < ell); d_phi channel k at offset n*n*k.
 * `dx` is not used by this kernel.
 *
 * On exit: d_w = shrink(w_old + nu*grad_G(phi), alpha*nu) entrywise, where
 * grad_G gives (phi0-phi1)/c12, (phi1-phi2)/c23, (phi0-phi2)/c13; and
 * d_w_temp = 2*d_w - w_old.
 */
__global__ void wUpdate_l1(const float* __restrict__  d_phi, float* d_w, float* d_w_temp, int n, float dx, float nu, float alpha) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;

  // Keep the previous iterate: d_w_temp(i,j) = d_w(i,j).
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = d_w[ind];
  }

  // w += nu * grad_G(phi): channel differences at this grid point.
  d_w[i_n_j + nn*0] += nu/ c_c12*(d_phi[i_n_j + nn*0]-d_phi[i_n_j + nn*1]);
  d_w[i_n_j + nn*1] += nu/ c_c23*(d_phi[i_n_j + nn*1]-d_phi[i_n_j + nn*2]);
  d_w[i_n_j + nn*2] += nu/ c_c13*(d_phi[i_n_j + nn*0]-d_phi[i_n_j + nn*2]);

  // Entrywise shrink with threshold alpha*nu.
  const float thresh = alpha*nu;
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    const float val = d_w[ind];
    const float mag = fabsf(val);
    const float factor = (mag <= thresh) ? 0.0f : (1.0f - thresh/mag);
    d_w[ind] = factor*val;
  }

  // Extrapolation: d_w_temp(i,j) = 2*d_w(i,j) - d_w_prev(i,j).
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = 2.0f*d_w[ind] - d_w_temp[ind];
  }
}






/*
 * w update with an l2 shrink over the ell entries at a grid point.
 *
 * Launch: 1-D grid, one thread per grid point (flat index i + n*j); threads
 * with index >= n*n exit immediately. Storage as in wUpdate_l1.
 * `dx` is not used by this kernel.
 *
 * On exit: d_w = (w_old + nu*grad_G(phi)) scaled by 0 if its Euclidean norm
 * is <= alpha*nu, else by 1 - alpha*nu/norm; d_w_temp = 2*d_w - w_old.
 */
__global__ void wUpdate_l2(const float* __restrict__  d_phi, float* d_w, float* d_w_temp, int n, float dx, float nu, float alpha) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;

  // Keep the previous iterate: d_w_temp(i,j) = d_w(i,j).
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = d_w[ind];
  }

  // w += nu * grad_G(phi): channel differences at this grid point.
  d_w[i_n_j + nn*0] += nu/ c_c12*(d_phi[i_n_j + nn*0]-d_phi[i_n_j + nn*1]);
  d_w[i_n_j + nn*1] += nu/ c_c23*(d_phi[i_n_j + nn*1]-d_phi[i_n_j + nn*2]);
  d_w[i_n_j + nn*2] += nu/ c_c13*(d_phi[i_n_j + nn*0]-d_phi[i_n_j + nn*2]);

  // l2 shrink with threshold alpha*nu. The norm is non-negative (sqrtf of a
  // sum of squares), so no fabsf is needed.
  float sq = 0.0f;
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    sq += d_w[ind]*d_w[ind];
  }
  const float norm = sqrtf(sq);
  const float thresh = alpha*nu;
  const float factor = (norm <= thresh) ? 0.0f : (1.0f - thresh/norm);
  for (int ii = 0; ii < ell; ii++) {
    d_w[i_n_j + nn*ii] *= factor;
  }

  // Extrapolation: d_w_temp(i,j) = 2*d_w(i,j) - d_w_prev(i,j).
  for (int ii = 0; ii < ell; ii++) {
    const int ind = i_n_j + nn*ii;
    d_w_temp[ind] = 2.0f*d_w[ind] - d_w_temp[ind];
  }
}





/*
 * phi update:
 *
 *   phi += tau * ( div(u_temp) + div_G(w_temp) + lambda1 - lambda0 )
 *
 * where u_temp / w_temp are the extrapolations written by the u/w kernels.
 * div uses backward differences divided by dx; the previous-cell term is
 * omitted when i == 0 (component 0) or j == 0 (component 1). div_G combines
 * the three w entries with weights 1/c12, 1/c23, 1/c13 and the signs below.
 *
 * Launch: 1-D grid, one thread per grid point (flat index i + n*j); threads
 * with index >= n*n exit immediately. This kernel reads d_u_temp at the
 * (i-1, j) and (i, j-1) neighbours, so the u kernel must have completed for
 * the whole grid first. mexFunction ensures this by launching all kernels on
 * the default stream.
 */
__global__ void phiUpdate(float* d_phi, const float* __restrict__ d_u_temp, const float* __restrict__ d_w_temp, const float* __restrict__  d_lambda0, const float* __restrict__  d_lambda1, int n, int d, float dx, float tau) {
  const int i_n_j = blockDim.x*blockIdx.x + threadIdx.x;
  const int nn = n*n;
  if (i_n_j >= nn) return;
  const int i = i_n_j % n;
  const int j = i_n_j / n;

  // div_G(w_temp): one value per phi channel.
  float div_G_w[K];
  div_G_w[0] = - (1.0f/c_c12)*d_w_temp[i_n_j + nn*0] -  (1.0f/c_c13)*d_w_temp[i_n_j + nn*2];
  div_G_w[1] =  (1.0f/c_c12)*d_w_temp[i_n_j + nn*0] -  (1.0f/c_c23)*d_w_temp[i_n_j + nn*1];
  div_G_w[2] =  (1.0f/c_c23)*d_w_temp[i_n_j + nn*1] +  (1.0f/c_c13)*d_w_temp[i_n_j + nn*2];

  // div(u_temp): sum of this cell's components minus the previous cell's.
  float div_u[K];
  for (int ii = 0; ii < K; ii++)
    div_u[ii] = 0.0f;

  for (int r = 0; r < d; r++)
    for (int ii = 0; ii < K; ii++)
      div_u[ii] += d_u_temp[i_n_j + nn*ii + nn*K*r];

  if (i > 0)
    for (int ii = 0; ii < K; ii++)
      div_u[ii] -= d_u_temp[i_n_j - 1 + nn*ii + nn*K*0];

  if (j > 0)
    for (int ii = 0; ii < K; ii++)
      div_u[ii] -= d_u_temp[i_n_j - n + nn*ii + nn*K*1];

  for (int ii = 0; ii < K; ii++)
    div_u[ii] /= dx;

  // phi += tau * (div(u_temp) + div_G(w_temp) + lambda1 - lambda0)
  for (int ii = 0; ii < K; ii++) {
    const int ind = i_n_j + nn*ii;
    d_phi[ind] += tau * (div_u[ii] + div_G_w[ii] + d_lambda1[ind] - d_lambda0[ind]);
  }
}


// l1 norm of u: sum of |u[ii]| over all n*n*d*K entries.
// std::fabs is used explicitly because an unqualified abs(float) in host code
// may resolve to the C int abs(int) and truncate each entry. The sum is
// accumulated in double: on large grids a float accumulator loses precision
// once the running total dwarfs the individual terms.
float u_wass1(float* u, int n, int d) {
  const int count = n*n*d*K;
  double emd = 0.0;
  for (int ii = 0; ii < count; ii++)
    emd += std::fabs(u[ii]);
  return float(emd);
}


// Mixed l1/l2 norm of u: sum over the n*n*K (grid point, channel) pairs of the
// Euclidean norm of the two spatial components (offsets n*n*K*0 and n*n*K*1).
// Accumulated in double, consistent with u_wass1, to limit round-off on large
// grids. `d` is not used.
float u_wass12(float* u, int n, int d) {
  const int count = n*n*K;
  double emd = 0.0;
  for (int ii = 0; ii < count; ii++) {
    const double a = u[ii + n*n*K*0];
    const double b = u[ii + n*n*K*1];
    emd += std::sqrt(a*a + b*b);
  }
  return float(emd);
}


// Sum over the n*n grid points of the Euclidean norm of all d*K entries of u
// stored at that point (entry jj at offset n*n*jj). Accumulated in double,
// consistent with u_wass1, to limit round-off on large grids.
float u_wass2(float* u, int n, int d) {
  const int points = n*n;
  double emd = 0.0;
  for (int ii = 0; ii < points; ii++) {
    double sq = 0.0;
    for (int jj = 0; jj < d*K; jj++) {
      const double v = u[ii + points*jj];
      sq += v*v;
    }
    emd += std::sqrt(sq);
  }
  return float(emd);
}


// l1 norm of w: sum of |w[ii]| over all n*n*ell entries. `d` is not used.
// std::fabs is used explicitly because an unqualified abs(float) in host code
// may resolve to the C int abs(int) and truncate. The sum is accumulated in
// double to limit round-off on large grids.
float w_wass1(float* w, int n, int d) {
  const int count = n*n*ell;
  double emd = 0.0;
  for (int ii = 0; ii < count; ii++)
    emd += std::fabs(w[ii]);
  return float(emd);
}

// Sum over the n*n grid points of the Euclidean norm of the ell entries of w
// stored at that point (entry jj at offset n*n*jj). `d` is not used.
// Accumulated in double, consistent with w_wass1, to limit round-off.
float w_wass2(float* w, int n, int d) {
  const int points = n*n;
  double emd = 0.0;
  for (int ii = 0; ii < points; ii++) {
    double sq = 0.0;
    for (int jj = 0; jj < ell; jj++) {
      const double v = w[ii + points*jj];
      sq += v*v;
    }
    emd += std::sqrt(sq);
  }
  return float(emd);
}
