#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
using std::cout;
using std::endl;

// Checks the status returned by a CUDA runtime call.
//
// err: status code to inspect.
//
// Returns normally when err is cudaSuccess. Otherwise prints the text from
// cudaGetErrorString(err) and terminates the process.
void err_chk(cudaError err) {
  if (err == cudaSuccess)
    return;

  // Diagnostics go to stderr so they are not lost in buffered/redirected stdout.
  std::cerr << cudaGetErrorString(err) << std::endl;

  // assert(false) is compiled out when NDEBUG is defined, which would let a
  // failed CUDA call pass silently in release builds; abort unconditionally.
  std::abort();
}


// Returns the sum of |m[k]| over the first n*n*d entries of m.
//
// m: host-readable array holding at least n*n*d floats.
// n: grid side length (n*n cells per component).
// d: number of components stored per cell.
//
// Returns 0 when n*n*d is not positive.
float wass_l1(float* m, int n, int d) {
  // Widen before multiplying so n*n*d cannot overflow int on large grids.
  const long long total = static_cast<long long>(n) * n * d;

  float emd = 0.0f;
  for (long long k = 0; k < total; k++) {
    // std::fabs always selects a floating-point overload; an unqualified
    // abs() can bind to the integer ::abs and truncate every term.
    emd += std::fabs(m[k]);
  }
  return emd;
}


// Returns the sum, over the n*n cells, of the Euclidean norm of each cell's
// d-component vector.
//
// m: host-readable array holding at least n*n*d floats; component v of cell
//    ii is read from m[ii + n*n*v].
// n: grid side length.
// d: number of components per cell.
//
// Returns 0 when n is 0.
float wass_l2(float* m, int n, int d) {
  // Widen before multiplying so n*n and the component offsets cannot
  // overflow int on large grids.
  const long long cells = static_cast<long long>(n) * n;

  float emd = 0.0f;
  for (long long ii = 0; ii < cells; ii++) {
    float sq_norm = 0.0f;
    for (int v = 0; v < d; v++) {
      const float c = m[ii + cells * v];
      sq_norm += c * c;
    }
    // std::sqrt picks the float overload, avoiding a round trip via double.
    emd += std::sqrt(sq_norm);
  }
  return emd;
}







// Per-cell update of the two-component field m with component-wise shrinkage.
//
// Each thread handles one cell ii = i + n*j of an n-by-n grid, using only the
// x dimension of the launch (blockDim.x*blockIdx.x + threadIdx.x). Threads
// whose index falls outside the n*n cells do nothing, so the launch may be
// rounded up to a whole number of blocks.
//
// d_Phi:    n*n values, read only.
// d_m:      2*n*n values; x component at [0, n*n), y component at
//           [n*n, 2*n*n). Updated in place.
// d_m_temp: same layout as d_m; receives 2*m_new - m_old.
// n:        grid side length.
// dx:       grid spacing used as the finite-difference divisor.
// mu:       step size multiplying the difference of d_Phi, and also the
//           shrinkage threshold.
__global__ void mUpdate_l1(const float* __restrict__ d_Phi, float* d_m, float* d_m_temp, int n, float dx, float mu) {
  const int N = n*n;
  const int ii = blockDim.x*blockIdx.x + threadIdx.x;

  // Guard the grid tail. Without it, surplus threads would overwrite the
  // y plane (ii in [N, 2N)) or access memory past the buffers. The check
  // also comes before the % and / by n so n == 0 never divides by zero.
  if (ii >= N)
    return;

  const int i = ii % n;
  const int j = ii / n;

  // Keep the incoming values; they are needed for d_m_temp at the end.
  const float old_m_x = d_m[ii + N * 0];
  const float old_m_y = d_m[ii + N * 1];
  float new_m_x = old_m_x;
  float new_m_y = old_m_y;

  //m_i = m_i + \mu \nabla \Phi
  // Forward differences; in the last column / last row the neighbour does
  // not exist and the corresponding component is left unchanged.
  const float phi_here = d_Phi[ii];
  if (i < n - 1)
    new_m_x += mu * (d_Phi[(i + 1) + n*j] - phi_here) / dx;
  if (j < n - 1)
    new_m_y += mu * (d_Phi[i + n*(j + 1)] - phi_here) / dx;

  //Shrink1
  // Each component is scaled independently: zeroed when its magnitude is at
  // most mu, otherwise scaled by (1 - mu/|component|).
  const float mag_x = fabsf(new_m_x);
  new_m_x *= (mag_x <= mu) ? 0.0f : (1.0f - mu / mag_x);
  const float mag_y = fabsf(new_m_y);
  new_m_y *= (mag_y <= mu) ? 0.0f : (1.0f - mu / mag_y);

  d_m[ii + N * 0] = new_m_x;
  d_m[ii + N * 1] = new_m_y;

  //m_temp = 2m - m_temp
  d_m_temp[ii + N * 0] = 2.0f * new_m_x - old_m_x;
  d_m_temp[ii + N * 1] = 2.0f * new_m_y - old_m_y;
}


// Per-cell update of the two-component field m with shrinkage applied to the
// Euclidean norm of the (x, y) pair.
//
// Each thread handles one cell ii = i + n*j of an n-by-n grid, using only the
// x dimension of the launch (blockDim.x*blockIdx.x + threadIdx.x). Threads
// whose index falls outside the n*n cells do nothing, so the launch may be
// rounded up to a whole number of blocks.
//
// d_Phi:    n*n values, read only.
// d_m:      2*n*n values; x component at [0, n*n), y component at
//           [n*n, 2*n*n). Updated in place.
// d_m_temp: same layout as d_m; receives 2*m_new - m_old.
// n:        grid side length.
// dx:       grid spacing used as the finite-difference divisor.
// mu:       step size multiplying the difference of d_Phi, and also the
//           shrinkage threshold.
__global__ void mUpdate_l2(const float* __restrict__ d_Phi, float* d_m, float* d_m_temp, int n, float dx, float mu) {
  const int N = n*n;
  const int ii = blockDim.x*blockIdx.x + threadIdx.x;

  // Guard the grid tail. Without it, surplus threads would overwrite the
  // y plane (ii in [N, 2N)) or access memory past the buffers. The check
  // also comes before the % and / by n so n == 0 never divides by zero.
  if (ii >= N)
    return;

  const int i = ii % n;
  const int j = ii / n;

  // Keep the incoming values; they are needed for d_m_temp at the end.
  const float old_m_x = d_m[ii + N * 0];
  const float old_m_y = d_m[ii + N * 1];
  float new_m_x = old_m_x;
  float new_m_y = old_m_y;

  //m_i = m_i + \mu \nabla \Phi
  // Forward differences; in the last column / last row the neighbour does
  // not exist and the corresponding component is left unchanged.
  const float phi_here = d_Phi[ii];
  if (i < n - 1)
    new_m_x += mu * (d_Phi[(i + 1) + n*j] - phi_here) / dx;
  if (j < n - 1)
    new_m_y += mu * (d_Phi[i + n*(j + 1)] - phi_here) / dx;

  //Shrink2
  // Both components share one factor derived from the vector norm: zero when
  // the norm is at most mu, otherwise (1 - mu/norm).
  const float norm = sqrtf(new_m_x*new_m_x + new_m_y*new_m_y);
  const float shrink_factor = (norm <= mu) ? 0.0f : (1.0f - mu / norm);
  new_m_x *= shrink_factor;
  new_m_y *= shrink_factor;

  d_m[ii + N * 0] = new_m_x;
  d_m[ii + N * 1] = new_m_y;

  //m_temp = 2m - m_temp
  d_m_temp[ii + N * 0] = 2.0f * new_m_x - old_m_x;
  d_m_temp[ii + N * 1] = 2.0f * new_m_y - old_m_y;
}



// Per-cell update of Phi from the discrete divergence of d_m_temp and the
// difference of the two lambda arrays:
//   Phi[ii] += tau * (div(m_temp)[ii] + lambda1[ii] - lambda0[ii])
//
// Each thread handles one cell ii = i + n*j of an n-by-n grid, using only the
// x dimension of the launch (blockDim.x*blockIdx.x + threadIdx.x). Threads
// whose index falls outside the n*n cells do nothing, so the launch may be
// rounded up to a whole number of blocks.
//
// d_Phi:     n*n values, updated in place.
// d_m_temp:  2*n*n values, read only; x component at [0, n*n), y component
//            at [n*n, 2*n*n).
// d_lambda0: n*n values, read only.
// d_lambda1: n*n values, read only.
// n:         grid side length.
// dx:        grid spacing used as the finite-difference divisor.
// tau:       step size applied to the update.
__global__ void PhiUpdate(float* d_Phi, const float* __restrict__ d_m_temp, const float* __restrict__ d_lambda0, const float* __restrict__ d_lambda1, int n, float dx, float tau) {
  const int N = n*n;
  const int ii = blockDim.x*blockIdx.x + threadIdx.x;

  // Guard the grid tail so surplus threads never touch memory outside the
  // n*n cells. The check also comes before the % and / by n so n == 0 never
  // divides by zero.
  if (ii >= N)
    return;

  const int i = ii % n;
  const int j = ii / n;

  //divm = divergence * m_temp
  // Backward differences; the neighbour outside the grid (i == 0 or j == 0)
  // is treated as zero.

  //x-gradient on m_x
  const float m_x_minus = (i >= 1) ? d_m_temp[(i - 1) + n*j + N * 0] : 0.0f;
  float divm = (d_m_temp[ii + N * 0] - m_x_minus) / dx;

  //y-gradient on m_y
  const float m_y_minus = (j >= 1) ? d_m_temp[i + n*(j - 1) + N * 1] : 0.0f;
  divm += (d_m_temp[ii + N * 1] - m_y_minus) / dx;

  //Phi = Phi + tau ( divergence * m_temp + lambda1 - lambda0 )
  d_Phi[ii] += tau * (divm + d_lambda1[ii] - d_lambda0[ii]);
}


