#include <boost/multi_array.hpp>
#include <complex>
#include <fstream>
#include "fourier.h"
#include "linalg.h"
#include "image_op.h"
#include "cameraProjectorCalibration.hh"

using namespace boost;
using namespace mimas;
using namespace std;

// Construct a calibration object for the given search pattern.
//
// The searchPattern member is initialised from _searchPattern, and
// pointPairsModified starts out false, so getHomography() will not call
// genHomography() until addCalibFrame() has recorded at least one frame.
CameraProjectorCalibration::CameraProjectorCalibration
  ( const image< unsigned char > &_searchPattern ):
     searchPattern( _searchPattern ),
     pointPairsModified(false)
{
}

// Record one calibration frame.
//
// _frame is appended to patternFrame, the search pattern is located in it
// via findPatternImg(), and the pair ( _pos, located position ) is appended
// to pointPairs. The cached homography is flagged as stale afterwards.
void CameraProjectorCalibration::addCalibFrame
   ( const mimas::image< unsigned char > &_frame,
     const Vector &_pos )
{
  patternFrame.push_back( _frame );

  // Position at which the pattern was detected in this frame.
  const Vector detected = findPatternImg( _frame );
  pointPairs.push_back( make_pair( _pos, detected ) );

  // Force getHomography() to recompute on its next call.
  pointPairsModified = true;
}

// Return the homography estimated from the recorded point pairs.
//
// The matrix is cached in the homography member and only recomputed with
// genHomography() when pointPairsModified says the pairs have changed.
CameraProjectorCalibration::Matrix
   CameraProjectorCalibration::getHomography(void)
{
  // Nothing changed since the last computation: hand back the cached value.
  if ( !pointPairsModified )
    return homography;

  homography = genHomography( pointPairs );
  pointPairsModified = false;
  return homography;
}

// Return a 2-D view onto a rectangular window of `in`.
//
// The first array dimension is indexed with [ y, y + h ) and the second
// with [ x, x + w ); the view refers to the storage of `in`, so writing
// through it modifies `in`.
multi_array< double, 2 >::array_view< 2 >::type CameraProjectorCalibration::view
  ( multi_array< double, 2 > &in, int x, int y, int w, int h )
{
  typedef multi_array< double, 2 > Field;

  const Field::index_range rows( y, y + h );
  const Field::index_range cols( x, x + w );
  Field::index_gen gen;

  return in[ gen[ rows ][ cols ] ];
}

// Estimate a 3x3 homography from point correspondences.
//
// Every pair ( first, second ) contributes two rows to a (2n x 9) linear
// system built from components [0] and [1] of both points. The SVD of that
// system is computed with gesvd() (only Vt is requested) and row 8 of Vt,
// i.e. the last column of V, is reshaped row-major into the returned matrix.
CameraProjectorCalibration::Matrix CameraProjectorCalibration::genHomography
   ( const std::vector< std::pair< Vector, Vector > > &pointPairs )
{
  const int numPairs = static_cast< int >( pointPairs.size() );

  Matrix equations( 2 * pointPairs.size(), 9 );
  Matrix Vt;

  // Build the two equation rows of every correspondence.
  for ( int k = 0; k < numPairs; k++ ) {

    const Vector &src = pointPairs[k].first;
    const Vector &dst = pointPairs[k].second;

    const double u  = src[0];
    const double v  = src[1];
    const double us = dst[0];
    const double vs = dst[1];

    const double rowA[9] =
      { u, v, 1, 0, 0, 0, -u * us, -v * us, -us };
    const double rowB[9] =
      { 0, 0, 0, -u, -v, -1, u * vs, v * vs, vs };

    for ( int c = 0; c < 9; c++ ) {
      equations( 2 * k,     c ) = rowA[c];
      equations( 2 * k + 1, c ) = rowB[c];
    }
  }

  // SVD of the system, A = U.sigma.Vt; neither U nor sigma is needed.
  gesvd< double >( equations, NULL, &Vt );

  // The solution is the last column of V, which is the last row of Vt.
  // Unpack its nine entries row by row into the 3x3 result.
  Matrix retVal( 3, 3 );
  for ( int r = 0; r < 3; r++ )
    for ( int c = 0; c < 3; c++ )
      retVal( r, c ) = Vt( 8, 3 * r + c );

  return retVal;
}

// Locate the template `tpl` inside the difference image `diffImg`.
//
// Both images are copied into the top-left corner of two equally sized,
// padded fields, the fields are cross-correlated through the FFT
// ( invrfft( rfft( img ) * conj( rfft( tpl ) ) ) ), and the position of the
// largest correlation value is taken as the template's top-left corner.
//
// Returns a 3-element vector ( x, y, 1.0 ) where x / y are the column / row
// of the correlation maximum shifted by half the template width / height.
//
// The padded field has ( image height + template height ) rows and
// ( image width + template width ) columns. The padding must match the
// template size along the *same* axis: FFT-based correlation is circular,
// so too little padding along an axis lets the template wrap around the
// border and correlate with the opposite side of the image.
// NOTE(review): the padding is assumed to be zero, i.e. it relies on
// multi_array's initial element value - confirm.
CameraProjectorCalibration::Vector CameraProjectorCalibration::findPatternDiffImg
   ( const image< double > &diffImg,
     const image< double > &tpl )
{
  // Common extents of both padded fields (rows first, then columns).
  const int fieldRows = diffImg.getHeight() + tpl.getHeight();
  const int fieldCols = diffImg.getWidth()  + tpl.getWidth();

  multi_array< double, 2 > imgField( extents[ fieldRows ][ fieldCols ] );
  view( imgField, 0, 0, diffImg.getWidth(), diffImg.getHeight() ) =
    const_multi_array_ref< double, 2 >( diffImg.rawData(),
                                        extents[ diffImg.getHeight() ]
                                               [ diffImg.getWidth() ] );

  multi_array< double, 2 > tplField( extents[ fieldRows ][ fieldCols ] );
  view( tplField, 0, 0, tpl.getWidth(), tpl.getHeight() ) =
    const_multi_array_ref< double, 2 >( tpl.rawData(),
                                        extents[ tpl.getHeight() ]
                                               [ tpl.getWidth() ] );

  // Correlation theorem: multiply the image spectrum with the conjugated
  // template spectrum and transform back.
  multi_array< std::complex< double >, 2 > tplSpectrum( rfft( tplField ) );
  multi_array< double, 2 > crossCorrelation
    ( invrfft( rfft( imgField ) *
               conj( tplSpectrum ) ) );

  // Flat offset of the strongest correlation response.
  int maxIndex =
    max_element( crossCorrelation.data(),
                 crossCorrelation.data() + crossCorrelation.num_elements() ) -
    crossCorrelation.data();

  // Convert the flat offset into column / row and move from the template's
  // top-left corner to its centre.
  Vector retVal( 3 );
  retVal[0] = maxIndex % crossCorrelation.shape()[1] + tpl.getWidth() / 2;
  retVal[1] = maxIndex / crossCorrelation.shape()[1] + tpl.getHeight() / 2;
  retVal[2] = 1.0;
  return retVal;
}

// Locate the search pattern in a camera frame.
//
// The frame, the stored backgroundFrame and the search pattern are each
// converted to image< double >; the background is subtracted from the frame
// and the result is searched with findPatternDiffImg().
// searchPattern must be initialised (checked with assert).
CameraProjectorCalibration::Vector CameraProjectorCalibration::findPatternImg
   ( const image< unsigned char > &img )
{
  assert( searchPattern.initialised() );

  image< double > frame( img );
  image< double > background( backgroundFrame );
  image< double > pattern( searchPattern );

  return findPatternDiffImg( frame - background, pattern );
}