// =====================================================================================
// 
//       Filename:  ResultSummary.h
// 
//    Description:  Summarizes all the prediction results
// 
//        Version:  1.0
//        Created:  09/28/2009 10:22:09 AM
//       Revision:  none
//       Compiler:  g++
// 
//         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
//        Company:  THU
// 
// =====================================================================================

#ifndef ResultSummary_H 
#define ResultSummary_H

#include <string>
#include <iostream>
#include "InstanceHandler.h"
#include "Utility2.h"
#include "MishMash.h"

using namespace std;

/*
 * =====================================================================================
 *        Class:  ResultSummary
 *  Description:  Calculate various statistics of the predictions.
 * =====================================================================================
 */
class ResultSummary : public InstanceHandler
{
	public:
		/* ====================  LIFECYCLE     ======================================= */
		// Builds a summary handler around an inference engine.
		//   p_infer  : engine that OnInstance() runs on every accepted instance. The
		//              pointer is only stored; nothing in this class deletes it.
		//   p_output : stream handed to the InstanceHandler base class.
		// Every scalar member receives a defined value here, so OnInstance() can be
		// called safely even if no setter and no Initialize() call came first.
		ResultSummary (IsoInferPE* p_infer, ostream* p_output = NULL) : InstanceHandler(p_output)
		{
			mpInfer = p_infer;
			mMinIsoCnt = 0;                // no filtering until SetMinIsoCnt() is called
			mExpScale = 0;                 // computed by the first ResultStatistics() call
			mIsoCntScale = 9;
			mOutputFormat = 0;
			mPredictCnt = 0;
		};                             /* constructor */

		// Empty on purpose of brevity: the body releases nothing.
		virtual ~ResultSummary ()
		{
		}                              /* destructor */

		void
		SetOutputFormat(int format){mOutputFormat = format;}

		void
		SetMinIsoCnt(int min_iso_cnt){mMinIsoCnt = min_iso_cnt;}

		void
		SetIsoCntScale(int iso_cnt_scale){mIsoCntScale = iso_cnt_scale;}

		virtual
		void
		Initialize()
		{
			mPredictCnt = 0;	
		};


		// Processes one instance: runs the inference engine on it and, depending on
		// mOutputFormat, writes statistics (0 or 2) and/or predictions (1 or 2).
		// Instances with fewer than mMinIsoCnt isoforms are skipped entirely.
		virtual
		void
		OnInstance(Instance& an_instance)
		{
			// Only compare when the threshold is positive. A negative int would be
			// converted to a huge unsigned value by the comparison and every instance
			// would be filtered out silently.
			if (mMinIsoCnt > 0 && an_instance.mIsoforms.size() < static_cast<size_t>(mMinIsoCnt)) return;
			cout << "Inst : " << an_instance.mInstanceCnt << endl;

			// ResultStatistics() reads the current instance through mpInstance.
			mpInstance = &an_instance;
			mpInfer->OnInstance(an_instance);

			const vector<vector<bool> >& valid_isoforms = mpInfer->GetValidIsoforms();
			const vector<int>& solution = mpInfer->GetSolution();

			if (0 == mOutputFormat || 2 == mOutputFormat)
			{
				// Turn every exon-membership mask into a Gene holding the selected exons.
				vector<Gene> valid_genes;
				valid_genes.resize(valid_isoforms.size());
				for (unsigned i = 0; i < valid_isoforms.size(); i++)
				{
					Gene& gene = valid_genes[i];
					for (unsigned j = 0; j <  valid_isoforms[i].size(); j++)
						if (valid_isoforms[i][j])
							gene.mExons.push_back(an_instance.mExons[j]);
					gene.CalculateRange();
				}

				ResultStatistics(an_instance.mExons, an_instance.mGenes, valid_genes, an_instance.mIsoExp,
						         an_instance.mSpliceReadCnt, solution);
			}
			if (1 == mOutputFormat || 2 == mOutputFormat)
				OutputPrediction(an_instance.mExons, valid_isoforms, solution);

		}/* -----  end of method OnInstance  ----- */

		virtual
		void
		CleanUp(){Output();}

		// Currently a no-op; called by CleanUp().
		void
		Output()
		{
			return;
		}


	protected:
		//--------------------------------------------------------------------------------------
		//       Class:  ResultSummary
		//      Method:  ResultStatistics
		// Description:  Group the result in different formats
		//  Parameters:  solution  :  The indexes of the final recovered isoforms in valid_isoforms
		//--------------------------------------------------------------------------------------
			void
		ResultStatistics(const vector<Exon>& exons,
						 const vector<Gene>& known_isoforms,
						 const vector<Gene>& valid_isoforms,
						 const vector<double>& known_isoforms_exp, 
						 const vector<vector<double> >& junc_cnt,
						 const vector<int>& solution)
		{
			int scale_low = -15;
			int scale_high = 10;

			if (mKnown_by_Exp_2.size() == 0)
			{
				mExpScale = scale_high - scale_low + 1;

				mKnown_by_Exp_2.assign(mExpScale, 0);
				mKnown_by_IsoCnt.assign(mIsoCntScale, 0);
				mHighKnown_by_IsoCnt.assign(mIsoCntScale, 0);

				mValid_by_IsoCnt.assign(mIsoCntScale, 0);
				mKnownInValid_by_Exp_2.assign(mExpScale, 0);
				mKnownInValid_by_IsoCnt.assign(mIsoCntScale, 0);
				mValidInKnown_by_IsoCnt.assign(mIsoCntScale, 0);

				mSolution_by_IsoCnt.assign(mIsoCntScale, 0);
				mKnownInSolution_by_Exp_2.assign(mExpScale, 0);
				mKnownInSolution_by_IsoCnt.assign(mIsoCntScale, 0);
				mSolutionInKnown_by_IsoCnt.assign(mIsoCntScale, 0);

				mIsoCntScale--;
			}

			vector<bool> is_solution;
			is_solution.assign(valid_isoforms.size(), false);
			for (unsigned i = 0; i < solution.size(); i++)
				is_solution[solution[i]] = true;

			int iso_cnt = known_isoforms.size();
			if (iso_cnt > mIsoCntScale) iso_cnt = mIsoCntScale;

			mKnown_by_IsoCnt[iso_cnt] += known_isoforms.size();
			mValid_by_IsoCnt[iso_cnt] += valid_isoforms.size();
			mSolution_by_IsoCnt[iso_cnt] += solution.size();

			set<int> known_in_valid_idx;
			set<int> known_in_solution_idx;
			set<int> valid_in_known_idx;
			set<int> solution_in_known_idx;
			for (unsigned i = 0; i < valid_isoforms.size(); i++)
			{	
				vector<int> in_idx;
				MatchedIndex(valid_isoforms[i], known_isoforms, in_idx);

				if (in_idx.size() > 0)
				{
					valid_in_known_idx.insert(i);
					if (is_solution[i])
						solution_in_known_idx.insert(i);	
				}

				for(unsigned j = 0; j < in_idx.size(); j++)
					known_in_valid_idx.insert(in_idx[j]);
				if (is_solution[i])
				{
					for(unsigned j = 0; j < in_idx.size(); j++)
						known_in_solution_idx.insert(in_idx[j]);
				}
			}

			int known_curr = known_isoforms.size();
			int valid_curr = valid_isoforms.size();
			int solution_curr = solution.size();

			int known_in_valid_curr = known_in_valid_idx.size();
			int known_in_solution_curr = known_in_solution_idx.size();
			int valid_in_known_curr = valid_in_known_idx.size();
			int solution_in_known_curr = solution_in_known_idx.size();

			mKnownInValid_by_IsoCnt[iso_cnt] += known_in_valid_curr;
			mKnownInSolution_by_IsoCnt[iso_cnt] += known_in_solution_curr;
			mValidInKnown_by_IsoCnt[iso_cnt] += valid_in_known_curr;
			mSolutionInKnown_by_IsoCnt[iso_cnt] += solution_in_known_curr;

			for (unsigned i = 0; i < known_isoforms.size(); i++)
			{
				double f_scale = log(known_isoforms_exp[i]) / log(2);
				if (f_scale < 0) f_scale -= 0.9999999;
				int exp_scale_2 = (int)(f_scale);
				if (exp_scale_2 < scale_low) exp_scale_2 = scale_low;
				if (exp_scale_2 > scale_high) exp_scale_2 = scale_high;
				exp_scale_2 -= scale_low;

				mKnown_by_Exp_2[exp_scale_2]++;

				if (known_in_valid_idx.find(i) != known_in_valid_idx.end())
					mKnownInValid_by_Exp_2[exp_scale_2]++;
				if (known_in_solution_idx.find(i) != known_in_solution_idx.end())
					mKnownInSolution_by_Exp_2[exp_scale_2]++;
			}

			int high_curr = 0;
			for (unsigned i = 0; i < known_isoforms.size(); i++)
				if (IsSupportedByJunc(exons, known_isoforms[i], junc_cnt))
					high_curr++;

			mHighKnown_by_IsoCnt[iso_cnt] += high_curr;

			// Over all result
			int known_all = 0;
			for (unsigned i = 0; i < mKnown_by_IsoCnt.size(); i++)
				known_all += mKnown_by_IsoCnt[i];

			int valid_all = 0;
			for (unsigned i = 0; i < mValid_by_IsoCnt.size(); i++)
				valid_all += mValid_by_IsoCnt[i];

			int known_in_valid_all = 0;
			for (unsigned i = 0; i < mKnownInValid_by_IsoCnt.size(); i++)
				known_in_valid_all += mKnownInValid_by_IsoCnt[i];

			int valid_in_known_all = 0;
			for (unsigned i = 0; i < mValidInKnown_by_IsoCnt.size(); i++)
				valid_in_known_all += mValidInKnown_by_IsoCnt[i];

			int solution_all = 0;
			for (unsigned i = 0; i < mSolution_by_IsoCnt.size(); i++)
				solution_all += mSolution_by_IsoCnt[i];

			int known_in_solution_all = 0;
			for (unsigned i = 0; i < mKnownInSolution_by_IsoCnt.size(); i++)
				known_in_solution_all += mKnownInSolution_by_IsoCnt[i];

			int solution_in_known_all = 0;
			for (unsigned i = 0; i < mSolutionInKnown_by_IsoCnt.size(); i++)
				solution_in_known_all += mSolutionInKnown_by_IsoCnt[i];

			int high_all = 0;
			for (unsigned i = 0; i < mHighKnown_by_IsoCnt.size(); i++)
				high_all += mHighKnown_by_IsoCnt[i];

			// Output result
			(*mpOutput) << "Inst : " << mpInstance->mInstanceCnt << endl;
			(*mpOutput) << "Results current: " << endl;
			(*mpOutput) << "    known_curr             : " << setw(6) << known_curr << endl;
			(*mpOutput) << "    high_curr              : " << setw(6) << high_curr << endl;
			(*mpOutput) << "    valid_curr             : " << setw(6) << valid_curr << endl;
			(*mpOutput) << "    solution_curr          : " << setw(6) << solution_curr << endl;
			(*mpOutput) << "    known_in_valid_curr    : " << setw(6) << known_in_valid_curr;
			(*mpOutput) << "    sen = : " << (double) known_in_valid_curr / known_curr << endl;
			(*mpOutput) << "    valid_in_known_curr    : " << setw(6) << valid_in_known_curr;
			(*mpOutput) << "    spe = : " << (double) valid_in_known_curr / valid_curr << endl;
			(*mpOutput) << "    known_in_solution_curr : " << setw(6) << known_in_solution_curr;
			(*mpOutput) << "    sen = : " << (double) known_in_solution_curr / known_curr << endl;
			(*mpOutput) << "    solution_in_known_curr : " << setw(6) << solution_in_known_curr;
			(*mpOutput) << "    spe = : " << (double) solution_in_known_curr / solution_curr << endl;
			(*mpOutput) << "    high ratio             : " << (double) high_curr / known_curr << endl;

			(*mpOutput) << "Results overall: " << endl;
			(*mpOutput) << "     known_all             : " << setw(6) << known_all << endl;
			(*mpOutput) << "     high_all              : " << setw(6) << high_all << endl;
			(*mpOutput) << "     valid_all             : " << setw(6) << valid_all << endl;
			(*mpOutput) << "     solution_all          : " << setw(6) << solution_all << endl;
			(*mpOutput) << "     known_in_valid_all    : " << setw(6) << known_in_valid_all;
			(*mpOutput) << "     sen = : " << (double) known_in_valid_all / known_all << endl;
			(*mpOutput) << "     valid_in_known_all    : " << setw(6) << valid_in_known_all;
			(*mpOutput) << "     spe = : " << (double) valid_in_known_all / valid_all << endl;
			(*mpOutput) << "     known_in_solution_all : " << setw(6) << known_in_solution_all;
			(*mpOutput) << "     sen = : " << (double) known_in_solution_all / known_all << endl;
			(*mpOutput) << "     solution_in_known_all : " << setw(6) << solution_in_known_all;
			(*mpOutput) << "     spe = : " << (double) solution_in_known_all / solution_all << endl;
			(*mpOutput) << "     high ratio            : " << (double) high_all / known_all << endl;

			(*mpOutput) << "==================================================================" << endl;
			(*mpOutput) << "Results by Exp 2 : scale                =  ";
			for (int i = scale_low; i <= scale_high; i++)
				(*mpOutput) << i << "\t";
			(*mpOutput) <<                                    endl;

			(*mpOutput) << "Results by Exp 2 : known                =  ";
			for (unsigned i = 0; i < mKnown_by_Exp_2.size(); i++)
				(*mpOutput) << mKnown_by_Exp_2[i] << "\t";
			(*mpOutput) <<                                    endl;

			(*mpOutput) << "Results by Exp 2 : Known in valid       =  ";
			for (unsigned i = 0; i < mKnownInValid_by_Exp_2.size(); i++)
				(*mpOutput) << mKnownInValid_by_Exp_2[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by Exp 2 : Sensitivity          =  ";
			for (unsigned i = 0; i < mKnownInValid_by_Exp_2.size(); i++)
				(*mpOutput) << (double)mKnownInValid_by_Exp_2[i] / mKnown_by_Exp_2[i] << "\t";
			(*mpOutput) <<                                    endl;

			(*mpOutput) << "Results by Exp 2 : Known in solution    =  ";
			for (unsigned i = 0; i < mKnownInSolution_by_Exp_2.size(); i++)
				(*mpOutput) << mKnownInSolution_by_Exp_2[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by Exp 2 : Sensitivity          =  ";
			for (unsigned i = 0; i < mKnownInSolution_by_Exp_2.size(); i++)
				(*mpOutput) << (double)mKnownInSolution_by_Exp_2[i] / mKnown_by_Exp_2[i] << "\t";
			(*mpOutput) <<                                    endl;

			
			(*mpOutput) << "==================================================================" << endl;
			(*mpOutput) << "Results by IsoCnt : known                  =  ";
			for (unsigned i = 0; i < mKnown_by_IsoCnt.size(); i++)
				(*mpOutput) << mKnown_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : High                   =  ";
			for (unsigned i = 0; i < mHighKnown_by_IsoCnt.size(); i++)
				(*mpOutput) << mHighKnown_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : High  ratio            =  ";
			for (unsigned i = 0; i < mHighKnown_by_IsoCnt.size(); i++)
				(*mpOutput) << (double)mHighKnown_by_IsoCnt[i] / mKnown_by_IsoCnt[i]<< "\t";
			(*mpOutput) <<                                    endl;


			(*mpOutput) << "Results by IsoCnt : Valid                  =  ";
			for (unsigned i = 0; i < mKnownInValid_by_IsoCnt.size(); i++)
				(*mpOutput) << mValid_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Known in valid         =  ";
			for (unsigned i = 0; i < mKnownInValid_by_IsoCnt.size(); i++)
				(*mpOutput) << mKnownInValid_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Sensitivity            =  ";
			for (unsigned i = 0; i < mKnownInValid_by_IsoCnt.size(); i++)
				(*mpOutput) << (double)mKnownInValid_by_IsoCnt[i] / mKnown_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Specificity            =  ";
			for (unsigned i = 0; i < mValid_by_IsoCnt.size(); i++)
				(*mpOutput) << (double)mValidInKnown_by_IsoCnt[i] / mValid_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;

			(*mpOutput) << "Results by IsoCnt : Solution               =  ";
			for (unsigned i = 0; i < mSolution_by_IsoCnt.size(); i++)
				(*mpOutput) << mSolution_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Solution in known      =  ";
			for (unsigned i = 0; i < mSolutionInKnown_by_IsoCnt.size(); i++)
				(*mpOutput) << mSolutionInKnown_by_IsoCnt[i]<< "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Sensitivity (Solution) =  ";
			for (unsigned i = 0; i < mSolutionInKnown_by_IsoCnt.size(); i++)
				(*mpOutput) << (double)mKnownInSolution_by_IsoCnt[i] / mKnown_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "Results by IsoCnt : Specificity (Solution) =  ";
			for (unsigned i = 0; i < mSolution_by_IsoCnt.size(); i++)
				(*mpOutput) << (double)mSolutionInKnown_by_IsoCnt[i] / mSolution_by_IsoCnt[i] << "\t";
			(*mpOutput) <<                                    endl;
			(*mpOutput) << "==================================================================" << endl;
			(*mpOutput) <<                                    endl;
		}

		//--------------------------------------------------------------------------------------
		//       Class:  ResultSummary
		//      Method:  IsSupportedByJunc
		// Description:  Whether every pair of consecutive exons of an isoform is joined by a
		//               junction with a positive count.
		//  Parameters:  exons     :  Exons whose indexes address junc_cnt
		//               isoform   :  The isoform to check (its mExons are walked in order)
		//               junc_cnt  :  junc_cnt[e][s] is read for the exon e whose end equals the
		//                            end of the left exon and the exon s whose start equals the
		//                            start of the right exon
		//     Returns:  false as soon as a boundary is not found among exons or a junction has
		//               a count <= 0; true otherwise (also for isoforms with fewer than 2 exons).
		//        Note:  exons should be sorted.
		//--------------------------------------------------------------------------------------
		bool
		IsSupportedByJunc(const vector<Exon>& exons, const Gene& isoform, const vector<vector<double> >& junc_cnt)
		{
			const vector<Exon>& gene_exons = isoform.mExons;

			// Boundary positions of all exons, in the layout BinarySearch expects.
			vector<int> start_position(exons.size());
			vector<int> end_position(exons.size());
			for (unsigned i = 0; i < exons.size(); i++)
			{
				start_position[i] = exons[i].mStart;
				end_position[i] = exons[i].mEnd;
			}

			// "i + 1 < size" rather than "i < size - 1": the subtraction is unsigned and
			// wraps around for an isoform without exons, which made the loop read
			// far beyond the end of gene_exons.
			for (unsigned i = 0; i + 1 < gene_exons.size(); i++)
			{
				int64 end_pos = gene_exons[i].mEnd;
				int64 start_pos = gene_exons[i+1].mStart;

				// The returned index is validated before use; an index outside the
				// table is treated like "boundary not found".
				int start_idx = Utility2Temp<int>::BinarySearch(start_position, start_pos);
				if (start_idx < 0 || static_cast<size_t>(start_idx) >= start_position.size()) return false;
				if (start_position[start_idx] != start_pos) return false;

				int end_idx = Utility2Temp<int>::BinarySearch(end_position, end_pos);
				if (end_idx < 0 || static_cast<size_t>(end_idx) >= end_position.size()) return false;
				if (end_position[end_idx] != end_pos) return false;

				// A junction table smaller than the exon list cannot support the junction.
				if (static_cast<size_t>(end_idx) >= junc_cnt.size()) return false;
				if (static_cast<size_t>(start_idx) >= junc_cnt[end_idx].size()) return false;

				//cout << end_idx << "," << start_idx << "," << junc_cnt[end_idx][start_idx] << endl;
				if (junc_cnt[end_idx][start_idx] <= 0) return false;
			}

			return true;
		}

		//--------------------------------------------------------------------------------------
		//       Class:  ResultSummary
		//      Method:  MatchedIndex 
		// Description:  Whether a given gene is matching to another set of genes
		//  Parameters:  
		//--------------------------------------------------------------------------------------
		void
		MatchedIndex(const Gene& gene, const vector<Gene>& genes, vector<int>& matched_idx)
		{
			matched_idx.clear();

			for (unsigned i = 0; i < genes.size(); i++)
			{
				if (gene.mStart >= genes[i].mEnd || gene.mEnd <= genes[i].mStart) continue;

				double match_score = MishMash::CompareTwoIsoforms(gene, genes[i]);

				if (1 == match_score)
					matched_idx.push_back(i);
			}
		}

		//--------------------------------------------------------------------------------------
		//       Class:  ResultSummary
		//      Method:  OutputPrediction
		// Description:  Group the result in different formats
		//  Parameters:  solution  :  The indexes of the final recovered isoforms in valid_isoforms
		//--------------------------------------------------------------------------------------
			void
		OutputPrediction(const vector<Exon>& exons, 
						 const vector<vector<bool> >& valid_isoforms, 
						 const vector<int>& solution)
		{
			for (unsigned i = 0; i < solution.size(); i++)
			{
				(*mpOutput) << "Pred" << mPredictCnt++ << "\t" << exons[0].mChr << "\t" << (exons[0].mStrand ? '+' : '-') << "\t";
				
				const vector<bool>& an_iso = valid_isoforms[solution[i]];

				int first_idx = 0;
				for (first_idx = 0; first_idx < an_iso.size(); first_idx++)
					if (an_iso[first_idx])
						break;
				int last_idx = 0;
				for (last_idx = an_iso.size() - 1; last_idx >= 0; last_idx--)
					if (an_iso[last_idx])
						break;

				(*mpOutput) << exons[first_idx].mStart << "\t" << exons[last_idx].mEnd << "\t";

				bool b_first = true;
				for (unsigned j = 0; j < an_iso.size(); j++)
				{
					if (an_iso[j])
					{
						if (b_first) (*mpOutput) << exons[j].mStart;
						else (*mpOutput) << "," << exons[j].mStart;
						b_first = false;
					}
				}
				(*mpOutput) << "\t";

				b_first = true;
				for (unsigned j = 0; j < an_iso.size(); j++)
				{
					if (an_iso[j])
					{
						if (b_first) (*mpOutput) << exons[j].mEnd;
						else (*mpOutput) << "," << exons[j].mEnd;
						b_first = false;
					}
				}
				(*mpOutput) << endl;
			}
		}

	private:
		// Inference engine run on every accepted instance (see OnInstance).
		IsoInferPE* mpInfer;

		int mMinIsoCnt;          // Only genes with at least mMinIsoCnt isoforms will be considered

		// Group the result by
		// 1. Expression levels of hidden isoforms
		// 2. #isoforms/gene
		// 3. Overall
		// mExpScale is the number of expression buckets; ResultStatistics computes it on
		// its first call. mIsoCntScale is the scale from which ResultStatistics sizes the
		// *_by_IsoCnt counters (set by the constructor and SetIsoCntScale).
		int mExpScale;
		int mIsoCntScale;

		// #Known isoforms. *_by_Exp_2 counters are indexed by expression bucket,
		// *_by_IsoCnt counters by isoform-count bucket (see ResultStatistics).
		vector<int> mKnown_by_Exp_2;
		vector<int> mKnown_by_IsoCnt;

		// #Known isoforms for which IsSupportedByJunc returned true.
		vector<int> mHighKnown_by_IsoCnt;

		// The result of step 1
		// #Valid isoforms
		vector<int> mValid_by_IsoCnt;
		// #Intersection of valid and known isoforms.
		// KnownInValid counts known isoforms matched by a valid one,
		// ValidInKnown counts valid isoforms matching a known one.
		vector<int> mKnownInValid_by_Exp_2;
		vector<int> mKnownInValid_by_IsoCnt;
		vector<int> mValidInKnown_by_IsoCnt;

		// The result of step 2
		// #Solutions
		vector<int> mSolution_by_IsoCnt;
		// #Intersection of solution and known isoforms.
		// KnownInSolution counts known isoforms matched by a solution isoform,
		// SolutionInKnown counts solution isoforms matching a known one.
		vector<int> mKnownInSolution_by_Exp_2;
		vector<int> mKnownInSolution_by_IsoCnt;
		vector<int> mSolutionInKnown_by_IsoCnt;

		// What OnInstance emits: 0 = statistics, 1 = predictions, 2 = both.
		int mOutputFormat;
		// Running number appended to "Pred" by OutputPrediction; reset by Initialize.
		int mPredictCnt; 
}; /* -----  end of class ResultSummary  ----- */

#endif
