#include <cassert>
#include <string>
#include <ctime>
#include <iostream>
#include <cmath>
#include "somt.h"
#include "mex.h"




// Stream buffer that forwards every character written to it to mexPrintf.
// It keeps no internal buffer: bulk writes arrive through xsputn() and single
// characters through overflow().
class mystream : public std::streambuf
{
protected:
  // Writes the n characters starting at s and reports n characters consumed.
  // As with any "%s" conversion, printing of a chunk stops early at an
  // embedded NUL character.
  virtual std::streamsize xsputn(const char *s, std::streamsize n) {
    // "%.*s" takes its precision as an int; handing it a std::streamsize
    // (usually 64-bit) is undefined behaviour. Emit the data in chunks whose
    // length always fits in an int.
    const std::streamsize chunk_limit = 1 << 20;
    std::streamsize done = 0;
    while (done < n) {
      const std::streamsize left = n - done;
      const int len = static_cast<int>(left < chunk_limit ? left : chunk_limit);
      mexPrintf("%.*s", len, s + done);
      done += len;
    }
    return n;
  }

  // Writes the single character c (unless it is EOF). Always returns 1, a
  // non-EOF value, which tells the stream the write succeeded.
  virtual int overflow(int c = EOF) {
    if (c != EOF) {
      // Print through a real char: passing the address of the int itself only
      // yields the right byte on little-endian machines.
      const char ch = static_cast<char>(c);
      mexPrintf("%.1s", &ch);
    }
    return 1;
  }
};
class scoped_redirect_cout
{
public:
	scoped_redirect_cout() { old_buf = std::cout.rdbuf(); std::cout.rdbuf(&mout); }
	~scoped_redirect_cout() { std::cout.rdbuf(old_buf); }
private:
	mystream mout;
	std::streambuf *old_buf;
};
static scoped_redirect_cout mycout_redirect;







void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray const *prhs[]) {
  const int d = 2;
  float mu;
  float tau;
  int max_iter;
  int norm_type;
  
  
  mu = float(*mxGetPr(prhs[2]));
  tau = float(*mxGetPr(prhs[3]));
  max_iter = int(*mxGetPr(prhs[4]));
  norm_type = int(*mxGetPr(prhs[5]));
  

  
  if ( nrhs!=6 ){
    mexErrMsgIdAndTxt("error:error",
                      "wrong number of inputs.");
  }
  
  if ( !(nlhs==1 || nlhs==3) ) {
    mexErrMsgIdAndTxt("error:error",
                      "output should be dist or [dist U Phi]");
  }

  const int n = mxGetDimensions(prhs[0])[0];
  const float dx = 1. / float(n - 1);
  
  const double* lambda0_float = mxGetPr(prhs[0]);
  const double* lambda1_float = mxGetPr(prhs[1]);
  
  float* lambda0 = new float[n*n];
  float* lambda1 = new float[n*n];
  
  //convert input to float
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < n; j++) {
      lambda0[i + n*j] = float(lambda0_float[i + n*j]);
      lambda1[i + n*j] = float(lambda1_float[i + n*j]);
    }
  }

  
  if ( !( n==1 || n==2 || n==4 || n==8 || n==16 || n==32 || n==64 || n==128 || n==256) ) {
    mexErrMsgIdAndTxt("error:error",
                      "n has to be a power of 2");
  }
  
  
  
  int threads_per_block;
  if (n==1) {
    threads_per_block = 1;
  } else if (n==2) {
    threads_per_block = 4;
  } else if (n==4) {
    threads_per_block = 16;
  } else if (n==8) {
    threads_per_block = 64;
  } else if (n==16) {
    threads_per_block = 128;
  } else if (n==32) {
    threads_per_block = 128;
  } else if (n==64) {
    threads_per_block = 256;
  } else if (n==128) {
    threads_per_block = 256;
  } else if (n==256) {
    threads_per_block = 256;
  } else {
    mexErrMsgIdAndTxt("error:error",
                      "unknown error");
  }
  
  



  //Problem data initialization over
  
  //initialize m and Phi to be 0
  float* m = new float[n*n*d];
  memset(m, 0, sizeof(float)*n*n*d);
  float* Phi = new float[n*n];
  memset(Phi, 0, sizeof(float)*n*n);
  float* m_prev = new float[n*n*d];
  float* Phi_prev = new float[n*n];

  //---------------------------------------------------------------------
  //start error checking
  //thread_per_block must divide number of points
  assert((n*n) % threads_per_block == 0);
  if (n*n / threads_per_block > 65535) {
    cout << "maximum block number exceeded" << endl;
    cout << "threads_per_block should be bigger" << endl;
    assert(false);
  }
  //end error checking
  //---------------------------------------------------------------------

  //create CUDA pointers
  float* d_lambda0;
  float* d_lambda1;
  float* d_m;
  float* d_m_temp;
  float* d_Phi;

  cout << "Allocating GPU memory" << endl;
  err_chk(cudaMalloc((void**)&d_lambda0, sizeof(float)*n*n));
  err_chk(cudaMalloc((void**)&d_lambda1, sizeof(float)*n*n));
  err_chk(cudaMalloc((void**)&d_m, sizeof(float)*n*n*d));
  err_chk(cudaMalloc((void**)&d_m_temp, sizeof(float)*n*n*d));
  err_chk(cudaMalloc((void**)&d_Phi, sizeof(float)*n*n));


  cout << "copying memory from host to GPU" << endl;
  err_chk(cudaMemcpy(d_m, m, sizeof(float)*n*n*d, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_Phi, Phi, sizeof(float)*n*n, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_lambda0, lambda0, sizeof(float)*n*n, cudaMemcpyHostToDevice));
  err_chk(cudaMemcpy(d_lambda1, lambda1, sizeof(float)*n*n, cudaMemcpyHostToDevice));






  cout << "starting iteration" << endl;
  clock_t begin = clock();
  for (int k = 0; k < max_iter; k++) {
  
    if (norm_type == 1) {
      mUpdate_l1 << < n*n / threads_per_block, threads_per_block >> > (d_Phi, d_m, d_m_temp, n, dx, mu);
    } else if (norm_type == 2) {
      mUpdate_l2 << < n*n / threads_per_block, threads_per_block >> > (d_Phi, d_m, d_m_temp, n, dx, mu);
    } else {
      assert(false);
    }

    //Phi = Phi + tau*(div m_temp-lambda0_lambda1)
    PhiUpdate << < n*n / threads_per_block, threads_per_block >> > (d_Phi, d_m_temp, d_lambda0, d_lambda1, n, dx, tau);

    //err_chk(cudaGetLastError());
  }
  float runtime = float(clock() - begin) / CLOCKS_PER_SEC;
  cout << "Total runtime is " << runtime << "s." << endl;
  cout << "This is " << runtime / float(max_iter) * 1000 << "ms per iteration" << endl;

  err_chk(cudaMemcpy(m, d_m, sizeof(float)*n*n*d, cudaMemcpyDeviceToHost));
  err_chk(cudaMemcpy(Phi, d_Phi, sizeof(float)*n*n, cudaMemcpyDeviceToHost));

  
  
  if (norm_type == 1) {
    plhs[0] = mxCreateDoubleScalar(wass_l1(m, n, d));
  } else if (norm_type == 2) {
    plhs[0] = mxCreateDoubleScalar(wass_l2(m, n, d));
  } else {
    assert(false);
  }
  if (nlhs==3) {
    size_t m_dim[3] = {n,n,d};
    size_t phi_dim[2] = {n,n};
    plhs[1] = mxCreateNumericArray(3, m_dim, mxDOUBLE_CLASS, mxREAL);
    plhs[2] = mxCreateNumericArray(2, phi_dim, mxDOUBLE_CLASS, mxREAL);
    
    
    double* m_out = mxGetPr(plhs[1]);
    double* phi_out = mxGetPr(plhs[2]);
    
    for (int ii=0; ii<n*n*d; ii++)
      m_out[ii] = double(m[ii]);
    
    for (int ii=0; ii<n*n; ii++)
      phi_out[ii] = double(Phi[ii]);
    
  }
  
  
  


  cout << "freeing CUDA resources" << endl;
  cudaFree(d_lambda0);
  cudaFree(d_lambda1);
  cudaFree(d_m);
  cudaFree(d_m_temp);
  cudaFree(d_Phi);

  cout << "freeing host resources" << endl;
  delete[] lambda0;
  delete[] lambda1;
  delete[] m;
  delete[] Phi;
  delete[] m_prev;
  delete[] Phi_prev;
}






