/*
 * =====================================================================================
 *
 *       Filename:  PvalueCalculator.h
 *
 *    Description:  The header file for class PvalueCalculator
 *
 *        Version:  1.0
 *        Created:  04/09/2009 12:33:45 PM
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
 *        Company:  THU
 *
 * =====================================================================================
 */

#ifndef PvalueCalculator_H
#define PvalueCalculator_H

#include "InstanceHandler.h"
#include "ExpEstimator.h"
#include "IsoInfer.h"

/*
 * =====================================================================================
 *        Class:  PvalueCalculator
 *  Description:  This class calculates proper pvalue for different situations, for 
 *                example, different number of isoforms
 * =====================================================================================
 */
class PvalueCalculator : public InstanceHandler
{
	public:

		/* ====================  LIFECYCLE     ======================================= */
		PvalueCalculator (ExpEstimator* p_solver, ostream* p_output = NULL) : InstanceHandler(p_output)
		{
			SetSolver(p_solver);
			mMaxIsoCnt = 10;
			mMaxExonCnt = 40;
		}
		/* constructor      */

		~PvalueCalculator (){};                            /* destructor       */

		void
		SetHideCnt(int hide_cnt)
		{
			mHideCnt = hide_cnt;
		}

		void
		SetConfidencelevel(double confi_level)
		{
			mConfidenceLevel = confi_level;
		}

		virtual
		void
		Initialize()
		{
			mBeforePvalue.resize(mMaxIsoCnt+1);
			mAfterPvalue.resize(mMaxIsoCnt+1);
			for (int i = 1; i <= mMaxIsoCnt; i++)
			{
				mBeforePvalue[i].clear();
				mAfterPvalue[i].clear();
			}

			mBeforePvalueExon.resize(mMaxExonCnt+1);
			mAfterPvalueExon.resize(mMaxExonCnt+1);
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				mBeforePvalueExon[i].clear();
				mAfterPvalueExon[i].clear();
			}

			mFN.resize(mMaxExonCnt+1);
			mFP.resize(mMaxExonCnt+1);
			mInst.resize(mMaxExonCnt+1);
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				mFN[i].assign(mMaxIsoCnt+1, 0);
				mFP[i].assign(mMaxIsoCnt+1, 0);
				mInst[i].assign(mMaxIsoCnt+1, 0);
			}

			mAllInst = mAllFP = mAllFN = 0;
		};

		virtual
		void
		OnInstance(Instance& an_instance)
		{
			int 					&instance_cnt =     an_instance.mInstanceCnt;
			vector<int> 			&set_sizes =        an_instance.mSegLen;
			vector<vector<bool> > 	&isoforms =         an_instance.mIsoforms;
			vector<double> 			&iso_exp =          an_instance.mIsoExp;
			vector<Exon> 			&exons =            an_instance.mExons;
			vector<Gene> 			&genes =   		    an_instance.mGenes;
			vector<vector<int> >    &start_exons =      an_instance.mStartExons;
			vector<vector<int> >    &end_exons =        an_instance.mEndExons;

			vector<double> pred_exp_level;
			vector<double> fit_value;

			vector<vector<int> > measure_in_isoform;
			vector<vector<double> > measure_virtual_length;
			vector<double> measure_read;
			vector<int> exon_type;

			IsoInfer isoinfer(mpSolver);
			isoinfer.SetInstance(&an_instance);
			isoinfer.CalculateExonType(start_exons, end_exons, set_sizes.size(), exon_type);
			isoinfer.ConstructOptModel(set_sizes, isoforms, an_instance.mShortReadGroup, 
									   measure_in_isoform, measure_virtual_length, measure_read);
			//double obj_val = mpSolver->SolvePE(measure_in_isoform, measure_virtual_length, measure_read);
			double obj_val = mpSolver->SolveLeastSquareCon(measure_in_isoform, measure_virtual_length, measure_read);
			double pvalue = mpSolver->ResultPvalue();

			vector<double> before_errors;
			mpSolver->GetError(before_errors);
			vector<double> before_fitvalue;
			mpSolver->GetFitValue(before_fitvalue);
			for (int i = 0; i < before_fitvalue.size(); i++)
				before_errors[i] = before_errors[i] / sqrt(before_fitvalue[i]);

			int iso_cnt = isoforms.size();
			if (iso_cnt > mMaxIsoCnt) iso_cnt = mMaxIsoCnt;

			int set_cnt = set_sizes.size();
			if (set_cnt > mMaxExonCnt) set_cnt = mMaxExonCnt;

			// Randomly hide a certain number of known isoforms.
			vector<vector<bool> > kept_known_isos;
			vector<bool> b_hided;
			b_hided.assign(isoforms.size(), false);
			for (int i = 0; i < mHideCnt && i < isoforms.size(); i++)
			{
				int r = rand() % isoforms.size();
				while (b_hided[r]) r = (r+1) % isoforms.size();
				b_hided[r] = true;
			}
			for (int i = 0; i < isoforms.size(); i++)
			{
				if (!b_hided[i])
					kept_known_isos.push_back(isoforms[i]);
			}

			int old_size = measure_read.size();
			isoinfer.CalculateExonType(start_exons, end_exons, set_sizes.size(), exon_type);
			isoinfer.ConstructOptModel(set_sizes, isoforms, an_instance.mShortReadGroup, 
										 measure_in_isoform, measure_virtual_length, measure_read);
			double obj_val2 = 100000;
			double pvalue2 = 0;
			if (measure_read.size() == old_size)
			{
				//obj_val2 = mpSolver->SolvePE(measure_in_isoform, measure_virtual_length, measure_read);
				obj_val = mpSolver->SolveLeastSquareCon(measure_in_isoform, measure_virtual_length, measure_read);
				pvalue2 = mpSolver->ResultPvalue();
			}

			vector<double> after_errors;
			mpSolver->GetError(after_errors);
			vector<double> after_fitvalue;
			mpSolver->GetFitValue(after_fitvalue);
			for (int i = 0; i < after_fitvalue.size(); i++)
				after_errors[i] = after_errors[i] / sqrt(after_fitvalue[i]);

			if (pvalue < mConfidenceLevel)
			{
				mFP[set_cnt][iso_cnt]++;
				mAllFP++;
			}
			if (pvalue2 > mConfidenceLevel)
			{
				mFN[set_cnt][iso_cnt]++;
				mAllFN++;
			}
			mInst[set_cnt][iso_cnt]++;
			mAllInst++;

			cout << "All : FN : FP = " << mAllInst << " : " << mAllFN << " : " << mAllFP << endl;

			mBeforePvalue[iso_cnt].push_back(pvalue);
			mAfterPvalue[iso_cnt].push_back(pvalue2);
			mBeforePvalueExon[set_cnt].push_back(pvalue);
			mAfterPvalueExon[set_cnt].push_back(pvalue2);
		}

		virtual
		void
		CleanUp(){Output();};

		void
		Output()
		{
			for (int i = 1; i <= mMaxIsoCnt; i++)
			{
				//(*mpOutput) << i << "\tBefore\t";
				for (int j = 0; j < mBeforePvalue[i].size(); j++)
					(*mpOutput) << "," << mBeforePvalue[i][j];
				(*mpOutput) << endl;
				//(*mpOutput) << i << "\tAfter\t";
				for (int j = 0; j < mAfterPvalue[i].size(); j++)
					(*mpOutput) << "," << mAfterPvalue[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "Exon" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				//(*mpOutput) << i << "\tBefore\t";
				for (int j = 0; j < mBeforePvalueExon[i].size(); j++)
					(*mpOutput) << "," << mBeforePvalueExon[i][j];
				(*mpOutput) << endl;
				//(*mpOutput) << i << "\tAfter\t";
				for (int j = 0; j < mAfterPvalueExon[i].size(); j++)
					(*mpOutput) << "," << mAfterPvalueExon[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "Instance cnt" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				for (int j = 0; j < mMaxIsoCnt; j++)
					(*mpOutput) << setw(6) << mInst[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "FN" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				for (int j = 0; j < mMaxIsoCnt; j++)
					(*mpOutput) << setw(6) << mFN[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "FP" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				for (int j = 0; j < mMaxIsoCnt; j++)
					(*mpOutput) << setw(6) << mFP[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "FN rate" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				for (int j = 0; j < mMaxIsoCnt; j++)
					(*mpOutput) << setw(6) << (double)mFN[i][j] / mInst[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "FP rate" << endl;
			for (int i = 1; i <= mMaxExonCnt; i++)
			{
				for (int j = 0; j < mMaxIsoCnt; j++)
					(*mpOutput) << setw(6) << (double)mFP[i][j] / mInst[i][j];
				(*mpOutput) << endl;
			}

			(*mpOutput) << "All : FN : FP = " << mAllInst << " : " << mAllFN << " : " << mAllFP << endl;
		};

	protected:

	private:
		int mMaxIsoCnt;
		int mMaxExonCnt;
		vector<vector<double> > mBeforePvalue;
		vector<vector<double> > mAfterPvalue;
		vector<vector<double> > mBeforePvalueExon;
		vector<vector<double> > mAfterPvalueExon;

		vector<vector<int> > mFN;
		vector<vector<int> > mFP;
		vector<vector<int> > mInst;
		int mAllFN;
		int mAllFP;
		int mAllInst;

		int mHideCnt;
		double mConfidenceLevel;
}; /* -----  end of class PvalueCalculator  ----- */

#endif
