// =====================================================================================
// 
//       Filename:  ShortReadGroup.hpp
// 
//    Description:  This file defines class ShortReadGroup
// 
//        Version:  1.0
//        Created:  02/05/2010 05:02:32 PM
//       Revision:  none
//       Compiler:  g++
// 
//         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
//        Company:  THU
// 
// =====================================================================================

#ifndef ShortReadGroup_H 
#define ShortReadGroup_H

#include "ShortRead.hpp"

//#define DEBUG0

/*
 * =====================================================================================
 *        Class:  ShortReadGroup
 *  Description:  This class groups a set of short reads with different properties
 * =====================================================================================
 */

class ShortReadGroup
{
	public:
		// The grouped short-read collections. Most methods below aggregate
		// (sum or average) over all of them.
		vector<ShortRead>      mShortReads;

		// Dump every pattern of every collection to stdout. For each collection
		// a "Reads\t<pattern count>" header is printed, then one line per pattern:
		// each mapped-segment list (tab-separated, ';'-terminated) followed by
		// the pattern's duplicate count (mPatternDup).
		void Print() const
		{
			for (unsigned i = 0; i < mShortReads.size(); ++i)
			{
				const vector<Pattern>& patterns = mShortReads[i].mPatterns;
				const vector<double>& pattern_dup = mShortReads[i].mPatternDup;
				cout << "Reads\t" << patterns.size() << endl;
				for (unsigned j = 0; j < patterns.size(); ++j)
				{
					const vector<vector<int> >& segs = patterns[j].mMappedSegs;

					for (unsigned k = 0; k < segs.size(); ++k)
					{
						for (unsigned l = 0; l < segs[k].size(); ++l)
							cout << segs[k][l] << "\t";
						cout << ";";
					}
					cout << pattern_dup[j] << endl;
				}
			}
		}

	private:
		// Call ShortRead::Flatten on every collection (used on a private copy
		// of the group before counting patterns).
		void Flatten()
		{
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				mShortReads[i].Flatten();
		}

		// Return the number of reads of a specified pattern, summed over all
		// collections.
		double EquPatternCnt(const Pattern& pattern) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				sum += mShortReads[i].EquPatternCnt(pattern);
			return sum;
		}
		
		// Return the number of reads covering a specified pattern, summed over
		// all collections. If b_jump = false (true), only patterns that contain
		// the specified pattern as a substring (subsequence) are considered.
		double SupPatternCnt(const Pattern& pattern, bool b_jump = false) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				sum += mShortReads[i].SupPatternCnt(pattern, b_jump);
			return sum;
		}

		// Return the number of reads covered by a specified pattern, summed over
		// all collections. b_jump selects substring (false) vs subsequence (true)
		// matching.
		// NOTE(review): presumably the considered patterns are those contained
		// in the specified pattern; confirm against ShortRead::SubPatternCnt.
		double SubPatternCnt(const Pattern& pattern, bool b_jump = false) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				sum += mShortReads[i].SubPatternCnt(pattern, b_jump);
			return sum;
		}

	public:
		// Return the total number of reads in the whole experiment
		// (sum of mTotalReadCnt over all collections).
		double TotalReadCnt() const
		{
			double sum = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				sum += mShortReads[i].mpReadInfo->mTotalReadCnt;
			return sum;
		}
		
		// Return the total number of short reads of all types (seg = -1), or
		// the number of short reads that cover segment `seg`, summed over all
		// collections.
		double ReadCnt(int seg = -1) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				sum += mShortReads[i].ReadCnt(seg);
			return sum;
		}

		// Build a seg_cnt x seg_cnt matrix where junc_cnt[i][j] (for j > i)
		// holds the number of reads covering the two-segment pattern (i, j).
		// Entries with j <= i are set to 0.
		void JuncCnt(int seg_cnt, vector<vector<double> >& junc_cnt) const
		{
			// One pattern object is reused; only its two segment ids change.
			Pattern patt;
			patt.mMappedSegs.resize(1);
			vector<int>& segs = patt.mMappedSegs[0];

			segs.resize(2);
			junc_cnt.resize(seg_cnt);
			for (unsigned i = 0; i < junc_cnt.size(); ++i)
			{
				segs[0] = i;
				junc_cnt[i].assign(seg_cnt, 0);
				for (unsigned j = i+1; j < junc_cnt.size(); ++j)
				{
					segs[1] = j;
					junc_cnt[i][j] = SupPatternCnt(patt);
				}
			}
		}

		// Calculate the expression level in RPKM of the junction
		// (first_seg, second_seg), averaged over all collections.
		// Returns 0 for an empty group.
		double JuncExp(int first_seg, int second_seg) const
		{
			if (mShortReads.empty()) return 0;

			double ave = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				ave += mShortReads[i].JuncExp(first_seg, second_seg);
			return ave / mShortReads.size();
		}

		// Calculate the expression level in RPKM of a segment of length
		// seg_len, averaged over all collections:
		//   reads on seg * 1e6 / total reads * 1000 / seg_len.
		// Note that this estimate is not accurate (it over-estimates),
		// especially when the read length is close to the length of an isoform.
		// Returns 0 for an empty group.
		double SegExp(int seg_len, int seg) const
		{
			if (mShortReads.empty()) return 0;

			double ave = 0;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
			{
				const double rpkm = mShortReads[i].ReadCnt(seg) * 1000000.0 / mShortReads[i].mpReadInfo->mTotalReadCnt *
							 1000.0 / seg_len;
				ave += rpkm;
			}
			return ave / mShortReads.size();
		}

		// Given an isoform, the length of each segment and the index of a
		// segment, decide whether the specified segment is valid/consistent
		// with all the segments before it. Every informative pattern (as
		// reported by ReadInfo::InformativePatterns) that is neither a single
		// segment nor an adjacent junction of the isoform must be observed in
		// at least one (flattened) collection. Used by
		// IsoInfer::EnumerateValidDFS.
		bool IsCurrentSegValid(const vector<bool>& isoform, const vector<int>& seg_lens, int curr_seg) const
		{
			vector<Pattern> patterns;
			vector<double> vir_lens;
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				mShortReads[i].mpReadInfo->InformativePatterns(isoform, seg_lens, patterns, vir_lens, curr_seg);

			ShortReadGroup flat_group = *this;
			flat_group.Flatten();

			bool b_valid = true;
			for (unsigned i = 0; i < patterns.size(); ++i)
			{
				// Because the validity of a junction pair has been guaranteed by
				// the splice-graph, we don't check it here. Actually, we should
				// not check it because the splice-graph may add some adjacent
				// junction reads, and such junctions would not pass the following
				// validity check.
				//
				// Similarly, the validity of a single segment is also guaranteed
				// during the intron-removing step. Again, it is possible that
				// some single segment would not pass the following validity check.
				const vector<int>& segs = patterns[i].mMappedSegs[0];
				if (segs.size() == 1) continue;
				if (segs.size() == 2 && IsGapEmpty(isoform, segs[0], segs[1])) continue;

				b_valid = false;
				for (unsigned j = 0; j < flat_group.mShortReads.size(); ++j)
					if (flat_group.mShortReads[j].EquPatternCnt(patterns[i]) > 0)
					{
						b_valid = true;
						break;
					}
				if (!b_valid) break;
			}
#ifdef DEBUG0
			// Debug trace only; this method is called repeatedly during enumeration.
			if (!b_valid)
				cout << "Invalid" << endl;
#endif
			return b_valid;
		}

		// Apply a segment `mapping` to every collection via ShortRead::Shrink.
		// NOTE(review): presumably mapping[i] is the new id of old segment i;
		// confirm against ShortRead::Shrink.
		void Shrink(const vector<int>& mapping)
		{
			for (unsigned i = 0; i < mShortReads.size(); ++i)
				mShortReads[i].Shrink(mapping);
		}

		// Given all the isoforms and the length of each segment, construct all
		// the measures. virtual_len_matrix[i][j] is the virtual length of
		// isoform j on measure (pattern) i, scaled so that
		// virtual length * expression level in RPKM = #expected reads.
		// observed_reads[i] is the number of reads of pattern i. Patterns that
		// are not supported by any given isoform (all-zero rows) are removed,
		// so the set of given isoforms is always a feasible solution.
		void ContructMeasures(const vector<vector<bool> >& isoforms, const vector<int>& seg_lens,
				vector<double>& observed_reads, vector<vector<double> >& virtual_len_matrix) const
		{
			ShortReadGroup flat_group = *this;
			flat_group.Flatten();

			// Candidate patterns: informative patterns of every isoform plus
			// every observed pattern, from every collection.
			vector<Pattern> patterns;
			for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
			{
				// Make sure that all the junctions are informative.
				// NOTE(review): mpReadInfo is a pointer. Unless ShortRead's copy
				// deep-copies it, this also changes the ReadInfo used by *this,
				// and the change persists after this const method returns.
				// Confirm that this is intended.
				flat_group.mShortReads[k].mpReadInfo->mMinExpLevel = 
					flat_group.mShortReads[k].mpReadInfo->MinExpForJunctionToBeInformative() + 1;
				for (unsigned i = 0; i < isoforms.size(); ++i)
				{
					vector<double> vir_lens;
					flat_group.mShortReads[k].mpReadInfo->InformativePatterns(isoforms[i], seg_lens, patterns, vir_lens);
				}

				for (unsigned i = 0; i < flat_group.mShortReads[k].mPatterns.size(); ++i)
					patterns.push_back(flat_group.mShortReads[k].mPatterns[i]);
			}

			// Deduplicate (and sort) the candidate patterns.
			set<Pattern> uniq_patterns(patterns.begin(), patterns.end());
			patterns.assign(uniq_patterns.begin(), uniq_patterns.end());

			virtual_len_matrix.resize(patterns.size());
			observed_reads.resize(patterns.size());
			vector<bool> b_covered(patterns.size(), false);

			for (unsigned i = 0; i < patterns.size(); ++i)
			{
				virtual_len_matrix[i].resize(isoforms.size());
				for (unsigned j = 0; j < isoforms.size(); ++j)
				{
					double sum = 0;
					for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
					{
						ReadInfoBase& read_info = (*flat_group.mShortReads[k].mpReadInfo);
						sum += ToVirtualLength(read_info.PatternVirtualLength(isoforms[j], seg_lens, patterns[i]),
								read_info.mTotalReadCnt);
					}
					virtual_len_matrix[i][j] = sum;

					if (sum > 0) b_covered[i] = true;
				}
				
				observed_reads[i] = flat_group.EquPatternCnt(patterns[i]);
			}

			// Remove uncovered measures (patterns are kept aligned for the debug dump).
			KeepCovered(b_covered, patterns);
			KeepCovered(b_covered, observed_reads);
			KeepCovered(b_covered, virtual_len_matrix);

#ifdef DEBUG0
			PrintMeasures(__func__, isoforms, seg_lens, observed_reads, virtual_len_matrix, patterns);
#endif
		}
		
		// Same purpose as ContructMeasures, but reads are assigned to measures
		// differently. A read belongs to a segment iff both its start and end
		// points are in the segment (virtual length seg_len - read_len + 1).
		// A read belongs to a junction iff it covers both segments forming the
		// junction. The measures only contain single segments and junctions
		// (pairs of consecutive expressed segments of some isoform).
		void ContructMeasuresSimple(const vector<vector<bool> >& isoforms, const vector<int>& seg_lens,
				vector<double>& observed_reads, vector<vector<double> >& virtual_len_matrix) const
		{
			set<pair<int, int> > junctions;
			CollectJunctions(isoforms, junctions);

			const int seg_cnt = static_cast<int>(seg_lens.size());
			const unsigned measure_cnt = seg_lens.size() + junctions.size();
			vector<bool> b_covered(measure_cnt, false);
			virtual_len_matrix.resize(measure_cnt);
			observed_reads.resize(measure_cnt);

			ShortReadGroup flat_group = *this;
			flat_group.Flatten();

			int row = 0;
			for (int seg = 0; seg < seg_cnt; ++seg, ++row)
			{
				virtual_len_matrix[row].assign(isoforms.size(), 0);
				for (unsigned j = 0; j < isoforms.size(); ++j)
				{
					double len_sum = 0;
					if (isoforms[j][seg])
					{
						for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
						{
							ReadInfoBase& read_info = (*flat_group.mShortReads[k].mpReadInfo);
							// Number of start positions of a read lying entirely inside
							// the segment; computed in double so that an unsigned read
							// length cannot wrap around.
							const double eff_len = seg_lens[seg] - static_cast<double>(read_info.ReadLen()) + 1;
							if (eff_len > 0)
								len_sum += ToVirtualLength(eff_len, read_info.mTotalReadCnt);
						}
					}
					if (len_sum > 0) b_covered[row] = true;
					virtual_len_matrix[row][j] = len_sum;
				}

				// EquPatternCnt already sums over all collections; call it once.
				observed_reads[row] = flat_group.EquPatternCnt(SingleSegPattern(seg));
			}

			FillJunctionMeasures(flat_group, isoforms, junctions, row,
					b_covered, observed_reads, virtual_len_matrix);

			// Remove uncovered measures
			KeepCovered(b_covered, observed_reads);
			KeepCovered(b_covered, virtual_len_matrix);

#ifdef DEBUG0
			PrintMeasures(__func__, isoforms, seg_lens, observed_reads, virtual_len_matrix, vector<Pattern>());
#endif
		}

		// Same purpose as ContructMeasures, but reads are assigned to measures
		// differently. A read belongs to a segment iff its start point is in
		// the segment (virtual length = seg_len). A read belongs to a junction
		// iff it covers both segments forming the junction. The measures only
		// contain single segments and junctions (pairs of consecutive expressed
		// segments of some isoform).
		void ContructMeasuresStart(const vector<vector<bool> >& isoforms, const vector<int>& seg_lens,
				vector<double>& observed_reads, vector<vector<double> >& virtual_len_matrix) const
		{
			set<pair<int, int> > junctions;
			CollectJunctions(isoforms, junctions);

			const int seg_cnt = static_cast<int>(seg_lens.size());
			const unsigned measure_cnt = seg_lens.size() + junctions.size();
			vector<bool> b_covered(measure_cnt, false);
			virtual_len_matrix.resize(measure_cnt);
			observed_reads.resize(measure_cnt);

			ShortReadGroup flat_group = *this;
			flat_group.Flatten();

			int row = 0;
			for (int seg = 0; seg < seg_cnt; ++seg, ++row)
			{
				virtual_len_matrix[row].assign(isoforms.size(), 0);
				for (unsigned j = 0; j < isoforms.size(); ++j)
				{
					double len_sum = 0;
					if (isoforms[j][seg])
					{
						for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
						{
							ReadInfoBase& read_info = (*flat_group.mShortReads[k].mpReadInfo);
							len_sum += ToVirtualLength(seg_lens[seg], read_info.mTotalReadCnt);
						}
					}
					if (len_sum > 0) b_covered[row] = true;
					virtual_len_matrix[row][j] = len_sum;
				}

				// Count reads whose first mapped segment is `seg`.
				double read_sum = 0;
				for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
				{
					const ShortRead& reads = flat_group.mShortReads[k];
					for (unsigned j = 0; j < reads.mPatterns.size(); ++j)
						if (reads.mPatterns[j].mMappedSegs[0][0] == seg)
							read_sum += reads.mPatternDup[j];
				}
				observed_reads[row] = read_sum;
			}

			FillJunctionMeasures(flat_group, isoforms, junctions, row,
					b_covered, observed_reads, virtual_len_matrix);

			// Remove uncovered measures
			KeepCovered(b_covered, observed_reads);
			KeepCovered(b_covered, virtual_len_matrix);

#ifdef DEBUG0
			PrintMeasures(__func__, isoforms, seg_lens, observed_reads, virtual_len_matrix, vector<Pattern>());
#endif
		}

	private:
		// Convert a length in bases into the scaled virtual length used by the
		// measures: (len / 1000) * (total reads / 1e6), so that
		// virtual length * RPKM = #expected reads. Floating-point constants are
		// used so that integer lengths are not truncated.
		static double ToVirtualLength(double len_in_bases, double total_read_cnt)
		{
			return len_in_bases / 1000.0 * total_read_cnt / 1000000.0;
		}

		// True iff no segment strictly between `first` and `second` is
		// expressed in `isoform`.
		static bool IsGapEmpty(const vector<bool>& isoform, int first, int second)
		{
			for (int k = first + 1; k < second; ++k)
				if (isoform[k]) return false;
			return true;
		}

		// Insert into `junctions` every pair (a, b) of consecutive expressed
		// segments of any isoform.
		static void CollectJunctions(const vector<vector<bool> >& isoforms, set<pair<int, int> >& junctions)
		{
			for (unsigned i = 0; i < isoforms.size(); ++i)
			{
				int prev = -1;
				for (unsigned j = 0; j < isoforms[i].size(); ++j)
				{
					if (!isoforms[i][j]) continue;
					if (-1 != prev)
						junctions.insert(pair<int, int>(prev, j));
					prev = j;
				}
			}
		}

		// Build a pattern made of the single segment `seg`.
		static Pattern SingleSegPattern(int seg)
		{
			Pattern pattern;
			pattern.mMappedSegs.push_back(vector<int>(1, seg));
			return pattern;
		}

		// Build a pattern made of the two segments `first` and `second`.
		static Pattern JunctionPattern(int first, int second)
		{
			vector<int> segs;
			segs.push_back(first);
			segs.push_back(second);
			Pattern pattern;
			pattern.mMappedSegs.push_back(segs);
			return pattern;
		}

		// Fill one measure row per junction, starting at `row`. The junction's
		// virtual length for an isoform is the sum of JuncLen() over the
		// collections when both segments are expressed and nothing in between
		// is; otherwise it is 0. The observed count is the number of reads
		// covering the junction. Returns the next unused row.
		static int FillJunctionMeasures(const ShortReadGroup& flat_group,
				const vector<vector<bool> >& isoforms, const set<pair<int, int> >& junctions,
				int row, vector<bool>& b_covered, vector<double>& observed_reads,
				vector<vector<double> >& virtual_len_matrix)
		{
			for (set<pair<int, int> >::const_iterator iter = junctions.begin(); iter != junctions.end(); ++iter, ++row)
			{
				const int first = iter->first;
				const int second = iter->second;

				virtual_len_matrix[row].resize(isoforms.size());
				for (unsigned j = 0; j < isoforms.size(); ++j)
				{
					const bool b_in_isoform = isoforms[j][first] && isoforms[j][second]
						&& IsGapEmpty(isoforms[j], first, second);

					double len_sum = 0;
					if (b_in_isoform)
					{
						for (unsigned k = 0; k < flat_group.mShortReads.size(); ++k)
						{
							ReadInfoBase& read_info = (*flat_group.mShortReads[k].mpReadInfo);
							len_sum += ToVirtualLength(read_info.JuncLen(), read_info.mTotalReadCnt);
						}
					}
					if (len_sum > 0) b_covered[row] = true;
					virtual_len_matrix[row][j] = len_sum;
				}
				observed_reads[row] = flat_group.SupPatternCnt(JunctionPattern(first, second));
			}
			return row;
		}

		// Keep only items[i] with b_covered[i] set, preserving order, and
		// shrink `items` to that count. Expects items.size() >= b_covered.size().
		template <class T>
		static void KeepCovered(const vector<bool>& b_covered, vector<T>& items)
		{
			unsigned kept = 0;
			for (unsigned i = 0; i < b_covered.size(); ++i)
				if (b_covered[i])
					items[kept++] = items[i];
			items.resize(kept);
		}

		// Debug dump of a constructed measure system to stdout (called only
		// when DEBUG0 is defined). When `patterns` has an entry for a row, that
		// pattern's mapped segments are printed between the matrix row and the
		// observed read count.
		static void PrintMeasures(const char* func_name, const vector<vector<bool> >& isoforms,
				const vector<int>& seg_lens, const vector<double>& observed_reads,
				const vector<vector<double> >& virtual_len_matrix, const vector<Pattern>& patterns)
		{
			cout << endl;
			cout << func_name << endl;
			for (unsigned i = 0; i < seg_lens.size(); ++i)
				cout << seg_lens[i] << "\t";
			cout << endl;
			cout << "The isoforms are :" << endl;
			for (unsigned i = 0; i < isoforms.size(); ++i)
			{
				for (unsigned j = 0; j < isoforms[i].size(); ++j)
					cout << isoforms[i][j] << "\t";
				cout << endl;
			}
			cout << "The virtual_len_matrix is :" << endl;
			for (unsigned i = 0; i < virtual_len_matrix.size(); ++i)
			{
				cout << i << " | ";
				for (unsigned j = 0; j < virtual_len_matrix[i].size(); ++j)
					cout << virtual_len_matrix[i][j] << "\t";
				if (i < patterns.size())
				{
					cout << " | ";
					const vector<vector<int> >& segs = patterns[i].mMappedSegs;
					for (unsigned j = 0; j < segs.size(); ++j)
					{
						for (unsigned k = 0; k < segs[j].size(); ++k)
							cout << segs[j][k] << "\t";
						cout << ";";
					}
				}
				cout << " | " << observed_reads[i];
				cout << endl;
			}
		}
};

#endif

