/*
 * =====================================================================================
 *
 *       Filename:  PredictionStat.h
 *
 *    Description:  This is the header file for class PredictionStat
 *
 *        Version:  1.0
 *        Created:  04/10/2009 05:13:39 PM
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
 *        Company:  THU
 *
 * =====================================================================================
 */


#ifndef PredictionStat_H 
#define PredictionStat_H

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include "InstanceHandler.h"
#include "Utility2.h"

using namespace std;

/*
 * =====================================================================================
 *        Class:  PredictionStat
 *  Description:  Calculate various statistics of the predictions.
 * =====================================================================================
 */
class PredictionStat : public InstanceHandler
{
	public:
		/* ====================  LIFECYCLE     ======================================= */

		/*
		 * p_infer  : inference engine whose predictions are evaluated. It is only
		 *            stored here; the destructor does not delete it.
		 * p_output : forwarded to InstanceHandler. All reports are written through
		 *            (*mpOutput), so presumably it must be non-NULL by the time
		 *            OnInstance()/Output() run — TODO confirm in InstanceHandler.
		 */
		PredictionStat (IsoInfer* p_infer, ostream* p_output = NULL) : InstanceHandler(p_output)
		{
			mpInfer = p_infer;
			mScaleCnt = 20;
			mHideCnt = 0;
			mMaxIsoCnt = 10;
			mMaxExonCnt = 40;
			mMaxOrderCnt= 10;
			// Give the skip counters a defined value even before Initialize() runs.
			mSmallSkippedCnt = 0;
			mLargeSkippedCnt = 0;
		}                             /* constructor */

		virtual ~PredictionStat (){}                             /* destructor */

		/*
		 * hide_cnt  : number of known isoforms randomly hidden per instance; the
		 *             predictions are scored against the hidden ones.
		 * scale_cnt : number of coverage thresholds (i * 0.05 for i < scale_cnt)
		 *             in the "Ins : Valid : ..." summary.
		 * Must be called before Initialize(), which sizes the per-threshold
		 * counters from mScaleCnt.
		 */
		void
		SetParameters(int hide_cnt, int scale_cnt)
		{
			mHideCnt = hide_cnt;
			mScaleCnt = scale_cnt;
		}

		/* Resets every counter and sizes all tables from the current limits. */
		virtual
		void
		Initialize()
		{
			mSmallSkippedCnt = mLargeSkippedCnt = 0;

			mAllValidCnt.assign(mScaleCnt, 0);
			mAllKnownCnt.assign(mScaleCnt, 0);
			mAllPredCnt.assign(mScaleCnt, 0);
			mAllRecoveredCnt.assign(mScaleCnt, 0);
			mAllNonSkippedCnt.assign(mScaleCnt, 0);

			// Indexed [exon_cnt][iso_cnt]. Index 0 is allocated, but Output()
			// prints these tables starting from 1.
			mValidCnt.assign(mMaxExonCnt+1, vector<int>(mMaxIsoCnt+1, 0));
			mKnownCnt.assign(mMaxExonCnt+1, vector<int>(mMaxIsoCnt+1, 0));
			mRecoveredCnt.assign(mMaxExonCnt+1, vector<int>(mMaxIsoCnt+1, 0));
			mPredCnt.assign(mMaxExonCnt+1, vector<int>(mMaxIsoCnt+1, 0));

			// Indexed [iso_cnt][order].
			mOrderCnt.assign(mMaxIsoCnt+1, vector<int>(mMaxOrderCnt+1, 0));
			mOrderRightCnt.assign(mMaxIsoCnt+1, vector<int>(mMaxOrderCnt+1, 0));
		}

		/*
		 * Cross-validates one instance:
		 *  - hides up to mHideCnt of its known isoforms;
		 *  - lets mpInfer predict from the remaining ones;
		 *  - counts how many hidden isoforms were predicted again;
		 *  - accumulates all counters and prints the running summaries.
		 *
		 * Side effect: an_instance.mIsoforms is replaced by the kept known
		 * isoforms and an_instance.mKnownCnt is updated before mpInfer sees it.
		 * Instances with more than mMaxExonCnt, or fewer than 2, segments are
		 * skipped and only counted.
		 */
		virtual
		void
		OnInstance(Instance& an_instance)
		{
			vector<int> 			&set_sizes =        an_instance.mSegLen;
			vector<double> 			&sample_cnt =       an_instance.mSampleCnt;
			vector<vector<bool> > 	&isoforms =         an_instance.mIsoforms;
			vector<vector<double> > &splice_read_cnt =  an_instance.mSpliceReadCnt;
			vector<Gene> 			&genes =   		    an_instance.mGenes;
			int                     &known_cnt =        an_instance.mKnownCnt;

			const int seg_cnt = static_cast<int>(set_sizes.size());

			// Skip against the table limit rather than a duplicated literal, so the
			// skip rule and the table dimensions cannot drift apart.
			if (seg_cnt > mMaxExonCnt)
			{
				mLargeSkippedCnt++;
				cerr << "WARNING : Too many exons (over " << mMaxExonCnt << ") on instance " << an_instance.mInstanceCnt << " skipped." << endl;
				return;
			}
			if (seg_cnt <= 1)
			{
				mSmallSkippedCnt++;
				cerr << "WARNING : Too less exons (no more than 1) on instance " << an_instance.mInstanceCnt << " skipped." << endl;
				return;
			}

			// The guards above make seg_cnt a valid row index of the exon tables.
			const int exon_cnt = seg_cnt;

			const int total_iso = static_cast<int>(isoforms.size());
			const int iso_cnt = (total_iso > mMaxIsoCnt) ? mMaxIsoCnt : total_iso;

			// For cross validation. Randomly hide up to mHideCnt known isoforms.
			vector<bool> b_hided(total_iso, false);
			for (int i = 0; i < mHideCnt && i < total_iso; i++)
			{
				// Probe forward from a random start. A free slot always exists
				// because fewer than total_iso slots are hidden so far.
				int r = rand() % total_iso;
				while (b_hided[r]) r = (r+1) % total_iso;
				b_hided[r] = true;
			}

			// Fill hided_idx and hided_known_isos in the same ascending pass, so
			// hided_idx[k] is the original index of hided_known_isos[k]. The gene
			// lookup genes[hided_idx[k]] below depends on this pairing.
			vector<int> hided_idx;
			vector<vector<bool> > hided_known_isos;
			vector<vector<bool> > kept_known_isos;
			for (int i = 0; i < total_iso; i++)
			{
				if (b_hided[i])
				{
					hided_idx.push_back(i);
					hided_known_isos.push_back(isoforms[i]);
				}
				else
					kept_known_isos.push_back(isoforms[i]);
			}

			// Remove the known isoforms that contain a junction which can not
			// be explained by the observed data.
			// NOTE(review): the counts below use hided_known_isos.size() -
			// removed_idx.size(). That suggests this call only reports indices in
			// removed_idx and leaves hided_known_isos intact — confirm, since
			// hided_idx must stay parallel to hided_known_isos.
			vector<int> removed_idx;
			mpInfer->RemoveUnExpressedIsoform(hided_known_isos, removed_idx, splice_read_cnt);

			// Only the kept isoforms are visible to the predictor.
			isoforms = kept_known_isos;
			known_cnt = static_cast<int>(kept_known_isos.size());

			mpInfer->OnInstance(an_instance);

			// Entries after the first mKnownCnt isoforms are taken as the new predictions.
			vector<vector<bool> > new_iso(an_instance.mIsoforms.begin() + an_instance.mKnownCnt, an_instance.mIsoforms.end());

			// Score each prediction. Its 1-based rank, capped at mMaxOrderCnt,
			// selects the column of the order tables.
			int inter_size = 0;
			for (size_t i = 0; i < new_iso.size(); i++)
			{
				int order = static_cast<int>(i) + 1;
				if (order > mMaxOrderCnt) order = mMaxOrderCnt;

				for (size_t j = 0; j < hided_known_isos.size(); j++)
				{
					if (hided_known_isos[j] == new_iso[i])
					{
						inter_size++;
						(*mpOutput) << "Predicted Right Isoforms: " << genes[hided_idx[j]].mName << endl;

						mOrderRightCnt[iso_cnt][order]++;
						break;
					}
				}

				mOrderCnt[iso_cnt][order]++;
			}

			// Sum of all sample_cnt and splice_read_cnt entries, and of segment lengths.
			double tot_reads = 0;
			int gene_len = 0;
			for (int i = 0; i < seg_cnt; i++)
			{
				tot_reads += sample_cnt[i];
				gene_len += set_sizes[i];
				for (int j = 0; j < seg_cnt; j++)
					tot_reads += splice_read_cnt[i][j];
			}

			const int hidden_cnt = static_cast<int>(hided_known_isos.size());
			const int valid_cnt  = hidden_cnt - static_cast<int>(removed_idx.size());
			const int pred_cnt   = static_cast<int>(new_iso.size());

			// Threshold i accumulates every instance whose reads-per-length is >= i * 0.05.
			const double coverage = tot_reads / gene_len;
			for (int i = 0; i < mScaleCnt; i++)
			{
				if (coverage >= i * 0.05)
				{
					mAllNonSkippedCnt[i]++;
					mAllKnownCnt[i] += hidden_cnt;
					mAllValidCnt[i] += valid_cnt;
					mAllRecoveredCnt[i] += inter_size;
					mAllPredCnt[i] += pred_cnt;
				}
			}

			mValidCnt[exon_cnt][iso_cnt] += valid_cnt;
			mKnownCnt[exon_cnt][iso_cnt] += hidden_cnt;
			mRecoveredCnt[exon_cnt][iso_cnt] += inter_size;
			mPredCnt[exon_cnt][iso_cnt] += pred_cnt;

			// Running summaries, printed after every processed instance.
			OutputScaleSummary();
			OutputRatiosByIsoform();
		}/* -----  end of method OnInstance  ----- */

		virtual
		void
		CleanUp(){Output();}

		/*
		 * Prints the full report:
		 *  - every table by exon count x isoform count;
		 *  - marginals by exon count and by isoform count;
		 *  - the prediction-order tables;
		 *  - the coverage-threshold summary.
		 * Ratio cells use floating-point division, so an empty denominator
		 * prints inf/nan (as before).
		 */
		void
		Output()
		{
			(*mpOutput) << " Number of instances skipped -- Large : Small = " << mLargeSkippedCnt << " : " << mSmallSkippedCnt << endl;

			(*mpOutput) << " By exon * isoform : ======================================================================" << endl;
			OutputCell(" ");
			OutputIndexHeader(mMaxExonCnt);
			OutputMatrix("Known", mKnownCnt);
			OutputMatrix("Valid", mValidCnt);
			OutputMatrix("Pred", mPredCnt);
			OutputMatrix("Recover", mRecoveredCnt);
			OutputRatioMatrix("Specificity", mRecoveredCnt, mPredCnt);
			OutputRatioMatrix("All Sensitivity", mRecoveredCnt, mKnownCnt);
			OutputRatioMatrix("Valid Sensitivity", mRecoveredCnt, mValidCnt);

			(*mpOutput) << " By exon: ======================================================================" << endl;
			OutputIndexHeader(mMaxExonCnt);
			OutputSumsByExon("Known", mKnownCnt);
			OutputSumsByExon("Valid", mValidCnt);
			OutputSumsByExon("Pred", mPredCnt);
			OutputSumsByExon("Recover", mRecoveredCnt);
			OutputRatiosByExon("Specificity", mRecoveredCnt, mPredCnt);
			OutputRatiosByExon("All Sensitivity", mRecoveredCnt, mKnownCnt);
			OutputRatiosByExon("Valid Sensitivity", mRecoveredCnt, mValidCnt);

			(*mpOutput) << " By isoform: ======================================================================" << endl;
			OutputIndexHeader(mMaxIsoCnt);
			OutputSumsByIsoform("Known", mKnownCnt);
			OutputSumsByIsoform("Valid ", mValidCnt);	// trailing space is part of the existing output format
			OutputSumsByIsoform("Pred", mPredCnt);
			OutputSumsByIsoform("Recover", mRecoveredCnt);
			OutputRatiosByIsoform();

			// Order tables: rows are known-isoform counts 0..mMaxIsoCnt, columns are prediction ranks.
			(*mpOutput) << " Order result by isoform: ======================================================================" << endl;
			(*mpOutput) << "Order cnt" << endl;
			for (int i = 0; i <= mMaxIsoCnt; i++)
			{
				OutputCell(i);
				for (int j = 1; j <= mMaxOrderCnt; j++)
					OutputCell(mOrderCnt[i][j]);
				(*mpOutput) << endl;
			}

			(*mpOutput) << "Order Right cnt" << endl;
			for (int i = 0; i <= mMaxIsoCnt; i++)
			{
				OutputCell(i);
				for (int j = 1; j <= mMaxOrderCnt; j++)
					OutputCell(mOrderRightCnt[i][j]);
				(*mpOutput) << endl;
			}

			(*mpOutput) << "Order Ratio" << endl;
			for (int i = 0; i <= mMaxIsoCnt; i++)
			{
				OutputCell(i);
				for (int j = 1; j <= mMaxOrderCnt; j++)
					OutputCell(static_cast<double>(mOrderRightCnt[i][j]) / mOrderCnt[i][j]);
				(*mpOutput) << endl;
			}

			(*mpOutput) << " Order result overall: ======================================================================" << endl;
			(*mpOutput) << "Order cnt" << endl;
			for (int j = 1; j <= mMaxOrderCnt; j++)
				OutputCell(SumOrderColumn(mOrderCnt, j));
			(*mpOutput) << endl;

			(*mpOutput) << "Order Right cnt" << endl;
			for (int j = 1; j <= mMaxOrderCnt; j++)
				OutputCell(SumOrderColumn(mOrderRightCnt, j));
			(*mpOutput) << endl;


			(*mpOutput) << "Order Ratio" << endl;
			for (int j = 1; j <= mMaxOrderCnt; j++)
				OutputCell(static_cast<double>(SumOrderColumn(mOrderRightCnt, j)) / SumOrderColumn(mOrderCnt, j));
			(*mpOutput) << endl;

			OutputScaleSummary();
		}


	protected:

	private:
		/* ====================  REPORT HELPERS  ==================================== */

		// Prints one cell: a tab padded to width 6 (as used throughout the report), then the value.
		template <typename T>
		void
		OutputCell(const T& value)
		{
			(*mpOutput) << setw(6) << "\t"  << value;
		}

		// Prints the column indices 1..max_index, then a newline.
		void
		OutputIndexHeader(int max_index)
		{
			for (int i = 1; i <= max_index; i++)
				OutputCell(i);
			(*mpOutput) << endl;
		}

		// Sum of table[exon][iso] over exon = 1..mMaxExonCnt.
		int
		SumOverExons(const vector<vector<int> >& table, int iso) const
		{
			int sum = 0;
			for (int exon = 1; exon <= mMaxExonCnt; exon++)
				sum += table[exon][iso];
			return sum;
		}

		// Sum of table[exon][iso] over iso = 1..mMaxIsoCnt.
		int
		SumOverIsos(const vector<vector<int> >& table, int exon) const
		{
			int sum = 0;
			for (int iso = 1; iso <= mMaxIsoCnt; iso++)
				sum += table[exon][iso];
			return sum;
		}

		// Sum of table[iso][order] over iso = 0..mMaxIsoCnt (order tables include row 0).
		int
		SumOrderColumn(const vector<vector<int> >& table, int order) const
		{
			int sum = 0;
			for (int iso = 0; iso <= mMaxIsoCnt; iso++)
				sum += table[iso][order];
			return sum;
		}

		// Prints title, then one row per isoform count (1..mMaxIsoCnt) with one cell per exon count (1..mMaxExonCnt).
		void
		OutputMatrix(const char* title, const vector<vector<int> >& table)
		{
			(*mpOutput) << title << endl;
			for (int iso = 1; iso <= mMaxIsoCnt; iso++)
			{
				OutputCell(iso);
				for (int exon = 1; exon <= mMaxExonCnt; exon++)
					OutputCell(table[exon][iso]);
				(*mpOutput) << endl;
			}
		}

		// Same layout as OutputMatrix, with each cell num / den.
		void
		OutputRatioMatrix(const char* title, const vector<vector<int> >& num, const vector<vector<int> >& den)
		{
			(*mpOutput) << title << endl;
			for (int iso = 1; iso <= mMaxIsoCnt; iso++)
			{
				OutputCell(iso);
				for (int exon = 1; exon <= mMaxExonCnt; exon++)
					OutputCell(static_cast<double>(num[exon][iso]) / den[exon][iso]);
				(*mpOutput) << endl;
			}
		}

		// Prints title, then one cell per exon count holding the table summed over isoform counts.
		void
		OutputSumsByExon(const char* title, const vector<vector<int> >& table)
		{
			(*mpOutput) << title << endl;
			for (int exon = 1; exon <= mMaxExonCnt; exon++)
				OutputCell(SumOverIsos(table, exon));
			(*mpOutput) << endl;
		}

		// Per exon count: (sum of num) / (sum of den) over isoform counts.
		void
		OutputRatiosByExon(const char* title, const vector<vector<int> >& num, const vector<vector<int> >& den)
		{
			(*mpOutput) << title << endl;
			for (int exon = 1; exon <= mMaxExonCnt; exon++)
				OutputCell(static_cast<double>(SumOverIsos(num, exon)) / SumOverIsos(den, exon));
			(*mpOutput) << endl;
		}

		// Prints title, then one cell per isoform count holding the table summed over exon counts.
		void
		OutputSumsByIsoform(const char* title, const vector<vector<int> >& table)
		{
			(*mpOutput) << title << endl;
			for (int iso = 1; iso <= mMaxIsoCnt; iso++)
				OutputCell(SumOverExons(table, iso));
			(*mpOutput) << endl;
		}

		// Per isoform count: (sum of num) / (sum of den) over exon counts.
		void
		OutputRatiosByIsoformRow(const char* title, const vector<vector<int> >& num, const vector<vector<int> >& den)
		{
			(*mpOutput) << title << endl;
			for (int iso = 1; iso <= mMaxIsoCnt; iso++)
				OutputCell(static_cast<double>(SumOverExons(num, iso)) / SumOverExons(den, iso));
			(*mpOutput) << endl;
		}

		// Specificity / All Sensitivity / Valid Sensitivity by isoform count (shared by OnInstance and Output).
		void
		OutputRatiosByIsoform()
		{
			OutputRatiosByIsoformRow("Specificity", mRecoveredCnt, mPredCnt);
			OutputRatiosByIsoformRow("All Sensitivity", mRecoveredCnt, mKnownCnt);
			OutputRatiosByIsoformRow("Valid Sensitivity", mRecoveredCnt, mValidCnt);
		}

		// One line per coverage threshold with the accumulated instance/valid/known/pred/recovered counts.
		void
		OutputScaleSummary()
		{
			for (int i = 0; i < mScaleCnt; i++)
				(*mpOutput) << "Ins : Valid : Known : Pred : Recovered " << i << " -- " << mAllNonSkippedCnt[i] << " : " << mAllValidCnt[i] << " : " <<  mAllKnownCnt[i] << " : " << mAllPredCnt[i] << " : " << mAllRecoveredCnt[i] << endl;
		}

		/* ====================  DATA MEMBERS  ====================================== */

		IsoInfer* mpInfer;			// not deleted by this class

		int mSmallSkippedCnt;		// instances skipped for having <= 1 segment
		int mLargeSkippedCnt;		// instances skipped for having > mMaxExonCnt segments

		// Per coverage threshold i (reads per length >= i * 0.05), size mScaleCnt.
		vector<int> mAllValidCnt;
		vector<int> mAllKnownCnt;
		vector<int> mAllRecoveredCnt;
		vector<int> mAllPredCnt;
		vector<int> mAllNonSkippedCnt;
		int mHideCnt;
		int mScaleCnt;

		// Indexed [exon_cnt][iso_cnt].
		vector<vector<int> > mValidCnt;
		vector<vector<int> > mKnownCnt;
		vector<vector<int> > mRecoveredCnt;
		vector<vector<int> > mPredCnt;

		// Generally, the one first predicted are more likely to be true
		// mOrderRightCnt[i][j] stores the number of cases that j'th prediction
		// is right and mOrderCnt[i][j] stores the number of j'th prediction when
		// the known iso cnt is i.
		vector<vector<int> > mOrderRightCnt;
		vector<vector<int> > mOrderCnt;

		int mMaxExonCnt;
		int mMaxIsoCnt;
		int mMaxOrderCnt;
}; /* -----  end of class PredictionStat  ----- */

#endif
