/*
 * =====================================================================================
 *
 *       Filename:  LPsolver.h
 *
 *    Description:  The header file for LPsolver
 *
 *        Version:  1.0
 *        Created:  04/09/2009 12:46:31 PM
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
 *        Company:  THU
 *
 * =====================================================================================
 */


#ifndef LPsolver_H
#define LPsolver_H

#include <vector>
#include <string>
#include <numeric>
#include <limits>
// For sqrt()
#include <cmath>

// For linear programming
#include <glpk.h>
#include "Utility2.h"
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_multifit.h>
#include "QuadProg++.hh"

using namespace QuadProgPP;


/*
 * =====================================================================================
 *        Class:  LPsolver
 *  Description:  This class solves the LP problem related to a gene
 * =====================================================================================
 */
class LPsolver
{
	public:
		/* ====================  LIFECYCLE     ======================================= */
		LPsolver ()
		{
		};                             /* constructor      */

		// Record the three tuning parameters supplied by the caller. They are
		// only stored here; no method in this header reads them back.
		void
		SetParameters(double noise_level, int read_len, int cross_strength)
		{
			mCrossStrength = cross_strength;
			mReadLen       = read_len;
			mNoiseLevel    = noise_level;
		}

		/* destructor: virtual, so deleting through a base-class pointer also runs a derived destructor */
		virtual ~LPsolver ()
		{
		}

		// Copy the per-isoform expression levels of the latest solve into exp_level.
		void GetExpLevel(vector<double>& exp_level)
		{
			exp_level = mExpLevel;
		}
		// Copy the per-measure fitted values of the latest solve into fit_value.
		void GetFitValue(vector<double>& fit_value)
		{
			fit_value = mFitValue;
		}
		void GetError(vector<double>& error){error = mError;}

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  LPsolver
		 *       Method:  ResultPvalue
		 *  Description:  Return the p-value of the result on the latest instance
		 *                solved by this LP solver
		 *
		 *        Param:  (none)
		 *       Return:  The p-value stored by the most recent call to SolveLeastSquareCon
		 *--------------------------------------------------------------------------------------
		 */
		double
		ResultPvalue()
		{
			return mPvalue;
		}		/* -----  end of method LPsolver::ResultPvalue  ----- */

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  LPsolver
		 *       Method:  SolveLeastSquareCon
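		 *  Description:  This function calculates the least square error sum with constraints
		 *        Param:  measure_in_isoform     : for each isoform, the indexes of the measures it contains
		 *                measure_virtual_length : for each isoform, the coefficient (virtual length) of
		 *                                         each measure listed in measure_in_isoform
		 *                measure_read           : the observed value of each measure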
		 *       Return:  The objective function value
		 *--------------------------------------------------------------------------------------
		 */
		double
		SolveLeastSquareCon(const vector<vector<int> >& measure_in_isoform,
							const vector<vector<double> >& measure_virtual_length,
							const vector<double>& measure_read)
		{
			int iso_cnt = measure_in_isoform.size();
			int measure_cnt = measure_read.size();

			if (0 == measure_cnt)
			{
				mPvalue = 1;
				return 0;
			}
			if (0 == iso_cnt) 
			{
				mObjValue = measure_cnt;
				mPvalue = gsl_cdf_chisq_Q (mObjValue, measure_cnt);
				return mPvalue;
			}

//			if (iso_cnt > measure_cnt)
//			{
//				mPvalue = 1;
//				return 0;
//			}

			mFitValue = measure_read;

			Matrix<double> A(0.0, measure_cnt, iso_cnt);
			for (unsigned i = 0; i < measure_virtual_length.size(); i++)
				for (unsigned j = 0; j < measure_virtual_length[i].size(); j++)
					A[measure_in_isoform[i][j]][i] = measure_virtual_length[i][j];

			int iter_cnt = 10;
			for (int iter = 0; iter < iter_cnt; iter++)
			{
				Matrix<double> C(0.0, measure_cnt, measure_cnt);
				for (int i = 0; i < measure_cnt; i++)
					//C[i][i] = 1.0;
					if (mFitValue[i] < 1) C[i][i] = 1;  // large enough
					else C[i][i] = 1.0 / mFitValue[i];

				Matrix<double> T(measure_cnt, iso_cnt);
				for (int i = 0; i < measure_cnt; i++)
					for (int j = 0; j < iso_cnt; j++)
					{
						double sum = 0;
						for (int k = 0; k < measure_cnt; k++)
							sum += C[i][k] * A[k][j];
						T[i][j] = sum;
					}

				Matrix<double> obj_quad_coeff(0.0, iso_cnt, iso_cnt);

				for (int i = 0; i < iso_cnt; i++)
					for (int j = 0; j < iso_cnt; j++)
					{
						double sum = 0;
						for (int k = 0; k < measure_cnt; k++)
							sum += A[k][i] * T[k][j];
						obj_quad_coeff[i][j] = sum;
						if (i == j) obj_quad_coeff[i][j] += 0.0001;
					}

				//print_matrix("C", C);
				//print_matrix("obj_quad_coeff", obj_quad_coeff);

				Vector<double> obj_linear_coeff(0.0, iso_cnt);
				for (int i = 0; i < iso_cnt; i++)
				{
					double sum = 0;
					for (int j = 0; j < measure_cnt; j++)
					 // :TRICKY:01/15/2010 11:25:36 PM:feeldead: 
					 // Here, sum += measure_read[j] * T[j][i], instead of
					 // 	  sum += mFitValue[j] * T[j][i]
					 // according to the formula. Because mFitValue is used as
					 // the square of the standard deviation but measure_read is
					 // the observed value. Here, the observed value is needed.
					 // The improvement using measure_read instead of mFitValue 
					 // is obvious especially in the accuracy of calculating 
					 // expression values given isoforms.
						sum += measure_read[j] * T[j][i];
					obj_linear_coeff[i] = -sum;
				}

				Matrix<double> con_quad_coeff_eq_zero(iso_cnt, 0);
				Vector<double> con_linear_coeff_eq_zero;
				Matrix<double> con_quad_coeff_ge_zero(iso_cnt, iso_cnt);
				for (int i = 0; i < iso_cnt; i++)
					for (int j = 0; j < iso_cnt; j++)
						con_quad_coeff_ge_zero[j][i] = 0;
				for (int i = 0; i < iso_cnt; i++)
					con_quad_coeff_ge_zero[i][i] = 1;

				Vector<double> con_linear_coeff_ge_zero(iso_cnt);
				for (int i = 0; i < iso_cnt; i++)
					con_linear_coeff_ge_zero[i] = 0;

				Vector<double> exp_level;
				exp_level.resize(iso_cnt);
				mObjValue = solve_quadprog (obj_quad_coeff, obj_linear_coeff,
											con_quad_coeff_eq_zero, con_linear_coeff_eq_zero,
											con_quad_coeff_ge_zero, con_linear_coeff_ge_zero,
											exp_level);
				for (int i = 0; i < measure_cnt; i++)
				{
					double sum = 0;
					for (int j = 0; j < iso_cnt; j++)
						sum += exp_level[j] * A[i][j];
					mFitValue[i] = sum;
				}

				mExpLevel.resize(iso_cnt);
				bool b_identical = true;
				for (int i = 0; i < iso_cnt; i++)
				{
					if (mExpLevel[i] != exp_level[i])
						b_identical = false;
					mExpLevel[i] = exp_level[i];
				}

				/* 
				cout << __func__ << " iter = " << iter << endl;
				for (int i = 0; i < measure_cnt; i++)
					cout << "     " << old_fit[i] << "  -->  " << mFitValue[i] << endl;
				cout << endl;
				*/

				if (b_identical) break;
			} // for iter

			mError.resize(measure_cnt);
			for (int i = 0; i < measure_cnt; i++)
				mError[i] = measure_read[i] - mFitValue[i];

			// Take the top ranked measures as the indicator
			vector<double> sorted_measure_cnt = mFitValue;
			vector<double> sorted_error = mError;
			vector<int> sortedIdx;
			UtilityTemp<double, greater<double> >::Sort(sorted_measure_cnt, sortedIdx);
			UtilityTemp<double>::SortByIndex(sorted_error, sortedIdx);
			int top_cnt;
			for (top_cnt = 0; top_cnt < sorted_measure_cnt.size() && top_cnt < 10; top_cnt++)
				if (sorted_measure_cnt[top_cnt] < 1) break;

			mObjValue = 0;
			for (int i = 0; i < top_cnt; i++)
				mObjValue += sorted_error[i] * sorted_error[i] / sorted_measure_cnt[i];

			if (0 == mObjValue)
				mPvalue = 1;
			else
				mPvalue = gsl_cdf_chisq_Q (mObjValue, top_cnt);

			return mObjValue;
		}		/* -----  end of method LPsolver::SolveLeastSquareCon  ----- */

		void print_matrix(const char* name, const Matrix<double>& A, int n = -1, int m = -1)
		{
		  std::string t;
		  if (n == -1)
			n = A.nrows();
		  if (m == -1)
			m = A.ncols();
			
		  cout << name << ": " << std::endl;
		  for (int i = 0; i < n; i++)
		  {
			cout << " ";
			for (int j = 0; j < m; j++)
			  cout << A[i][j] << ", ";
			cout << std::endl;
		  }
		  std::cout << t << std::endl;
		}


	protected:
		// The next three members are not referenced by any method in this
		// header. NOTE(review): presumably used by derived classes — confirm.
		double mNonZero;
		double mReads;
		int mGeneLen;
		// Tuning parameters stored by SetParameters(); nothing in this header
		// reads them back.
		double mNoiseLevel;
		int mReadLen;
		int mCrossStrength;

		// Results of the latest SolveLeastSquareCon() call.
		// mFitValue : fitted value of every measure; also used as the
		//             variance estimate while re-weighting.
		// mError    : observed value minus fitted value, per measure.
		// mExpLevel : solved expression level, one entry per isoform.
		vector<double> mFitValue;
		vector<double> mError;
		vector<double> mExpLevel;
		// Statistic computed by SolveLeastSquareCon() from the top ranked measures.
		double mObjValue;
		// P-value belonging to mObjValue; handed out by ResultPvalue().
		double mPvalue;
		// Not referenced by any method in this header.
		// NOTE(review): presumably used by derived classes — confirm.
		bool mbAllExplained;
		int mSetCnt;

	private:
		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  LPsolver
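		 *       Method:  AreAllCovered
		 *  Description:  Check whether all the junctions and exons with non-zero reads have
		 *                been covered by the given isoforms
		 *
		 *        Param:  set_sizes  : one value per exon set; every set with a value > 0
		 *                             must be covered
		 *                isoforms   : isoforms[i][j] is true if isoform i contains exon set j
		 *                sample_cnt : not used by this method
		 *                junc_cnt   : junc_cnt[i][j] > 0 means that the junction from set i
		 *                             to set j must be covered
		 *       Return:  true if everything required is covered, false otherwise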
		 *--------------------------------------------------------------------------------------
		 */
		bool
		AreAllCovered(const vector<int>& set_sizes, const vector<vector<bool> >& isoforms,
		              const vector<double>& sample_cnt, const vector<vector<double> >& junc_cnt)
		{
			// Returns true when every exon set with set_sizes[e] > 0 occurs in
			// at least one isoform AND every pair (a, b) with junc_cnt[a][b] > 0
			// is used as a pair of consecutive exons by at least one isoform.
			// sample_cnt is accepted for interface compatibility and not read.
			const unsigned set_cnt = static_cast<unsigned>(set_sizes.size());

			// exon_seen[e]    : some isoform contains exon set e.
			// junc_seen[a][b] : some isoform goes directly from a to b (a < b).
			vector<bool> exon_seen(set_cnt, false);
			vector<vector<bool> > junc_seen(set_cnt, vector<bool>(set_cnt, false));

			for (unsigned iso = 0; iso < isoforms.size(); iso++)
			{
				const vector<bool>& members = isoforms[iso];
				int prev = -1;              // previous exon of this isoform, -1 = none yet
				for (unsigned e = 0; e < set_cnt; e++)
				{
					if (!members[e]) continue;

					exon_seen[e] = true;
					if (prev != -1)
						junc_seen[prev][e] = true;
					prev = e;
				}
			}

			// Any exon set that has to be explained but is not?
			for (unsigned e = 0; e < set_cnt; e++)
				if (set_sizes[e] > 0 && !exon_seen[e])
					return false;

			// Any junction that has to be explained but is not?
			for (unsigned a = 0; a < set_cnt; a++)
				for (unsigned b = 0; b < set_cnt; b++)
					if (junc_cnt[a][b] > 0 && !junc_seen[a][b])
						return false;

			return true;
		}

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  LPsolver
		 *       Method:  GetIndex
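		 *  Description:  Given a subset of exons and the total number of
		 *                exons, this method calculates all the corresponding
		 *                indexes of mFitValue and error.
		 *
		 *        Param:  set_cnt    : the total number of exons
		 *                kept_exons : the subset of exons
		 *                index      : output, receives the calculated indexes
		 *       Return:  (none; the result is stored in index)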
		 *--------------------------------------------------------------------------------------
		 */
		void
		GetIndex(int set_cnt, vector<int>& kept_exons, vector<int>& index)
		{
			// Fills index with the positions that belong to the exons NOT
			// listed in kept_exons: first the exon positions themselves, then,
			// offset by set_cnt, the position of every exon pair (i, j), i < j,
			// in which neither exon is listed.
			// NOTE(review): despite the parameter name, the exons listed in
			// kept_exons are the ones left out of index — confirm with callers.
			vector<bool> remove_exon;
			remove_exon.assign(set_cnt, false);

			// Walk over the subset itself. It may hold fewer than set_cnt
			// entries, so looping set_cnt times would read past its end.
			// Values outside [0, set_cnt) cannot name an exon and are ignored
			// rather than used as an index.
			for (unsigned k = 0; k < kept_exons.size(); k++)
			{
				const int exon = kept_exons[k];
				if (exon < 0 || exon >= set_cnt) continue;
				remove_exon[exon] = true;
			}

			index.clear();
			// Be careful about the following index, make sure that
			// it is consistent with the junction index in method 'Solve'
			for (int i = 0; i < set_cnt; i++)
				if (!remove_exon[i]) index.push_back(i);

			for (int i = 0; i < set_cnt; i++)
			{
				if (remove_exon[i]) continue;
				for (int j = i+1; j < set_cnt; j++)
				{
					if (remove_exon[j]) continue;
					// Rank of pair (i, j) when all pairs with i < j are
					// enumerated row by row: (0,1), (0,2), ..., (1,2), ...
					int idx = (2*set_cnt - i - 1) * i / 2 + j - i - 1;
					index.push_back(set_cnt + idx);
				}
			}
		}		/* -----  end of method LPsolver::GetIndex  ----- */


}; /* -----  end of class LPsolver  ----- */

#endif
