// =====================================================================================
// 
//       Filename:  ReadInfoPE.hpp
// 
//    Description:  The implementation of class ReadInfoPE
// 
//        Version:  1.0
//        Created:  02/05/2010 04:08:26 PM
//       Revision:  none
//       Compiler:  g++
// 
//         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
//        Company:  THU
// 
// =====================================================================================

#ifndef ReadInfoPE_H
#define ReadInfoPE_H

#include "ReadInfoBase.hpp"
//#define DEBUG0

/*
 * =====================================================================================
 *        Class:  ReadInfoPE
 *  Description:  This class defines information of paired end reads
 * =====================================================================================
 */
class ReadInfoPE: public ReadInfoBase
{
	public:
		// =========--------------------------------==========
		// |<------------------ mReadLen ------------------->|
		//          |<--------- Gap -------------->|
		// |<--R-->|          : R : mEndLen;
		
		// Length of each of the two ends of a paired-end read (R above).
		int mEndLen;

	public:
		// Read mEndLen, then let the base class read its own fields.
		// The "1" type tag emitted by Write() is not consumed here;
		// presumably the caller reads it to choose the subclass -- TODO confirm.
		virtual void Read(istream* p_in)
		{
			(*p_in) >> mEndLen;
			ReadInfoBase::Read(p_in);
		}

		// Write the type tag "1" on its own line, then mEndLen followed by
		// the base-class fields.
		virtual void Write(ostream* p_out)
		{
			(*p_out) << "1" << endl;
			(*p_out) << mEndLen << " ";
			ReadInfoBase::Write(p_out);
		}

		virtual bool IsSingleEnd() {return false;}

		// Effective junction length. A paired-end read has two ends, which
		// doubles the chance of a junction being covered compared with one end.
		virtual double JuncLen() 
		{
			return 2 * (mEndLen - mCrossStrength);
		}

		// Given an isoform, the length of each segment and the index of
		// a segment, collect every informative pattern that contains the
		// specified segment plus some segments before it, together with
		// the virtual length of each pattern.
		//
		// Patterns generated (all segment indices in ascending order):
		//   1) (curr_seg) and (i, curr_seg) for earlier isoform segments i;
		//   2) (pair_start, last_one, curr_seg), where pair_start and
		//      last_one are consecutive in the isoform;
		//   3) (first_seg, predecessor, curr_seg), where predecessor is the
		//      isoform segment immediately before curr_seg.
		// A pattern is kept only if IsInformative() accepts its virtual length.
		//
		// WARNING : The following way would not be proper to handle longer
		// ends
		//
		virtual void InformativePatterns(const vector<bool>& isoform, const vector<int>& seg_lens, 
				vector<Pattern>& patterns, vector<double>& vir_lens, int curr_seg)
		{
			if (!isoform[curr_seg]) return;

			int range_low = mpReadLenDist->RangeLow();
			int range_high = mpReadLenDist->RangeHigh();

			// Pair check: len_sum is the total length of the isoform
			// segments i..curr_seg (inclusive).
			int len_sum = 0;
			for (int i = curr_seg; i >= 0; --i)
			{
				if (!isoform[i]) continue;
				len_sum += seg_lens[i];
				if (len_sum < range_low) continue;
				if (len_sum - seg_lens[curr_seg] - seg_lens[i] >= range_high - 2 * mCrossStrength) break;

				vector<int> segs;
				if (i != curr_seg) 
					segs.push_back(i);
				segs.push_back(curr_seg);

				Pattern patt;
				patt.mMappedSegs.push_back(segs);

				double vir_len = PatternVirtualLength(isoform, seg_lens, patt);
				if (IsInformative(vir_len))
				{
					patterns.push_back(patt);
					vir_lens.push_back(vir_len);
				}
			}

			// Triple check
			// === === ................ ====
			// pair_start last_one ....... curr_seg
			// len_sum is the total length of the isoform segments strictly
			// between pair_start and curr_seg.
			len_sum = 0;
			int pair_start = curr_seg;
			while (true)
			{
				int last_one = pair_start;
				pair_start--;
				while (pair_start >= 0 && !isoform[pair_start])
					pair_start--;
				if (pair_start < 0) break;
				if (curr_seg == last_one) continue;
				len_sum += seg_lens[last_one];

				if (len_sum + seg_lens[pair_start] + seg_lens[curr_seg] < range_low) continue;
				if (len_sum >= range_high) break;

				vector<int> segs;
				segs.push_back(pair_start);
				segs.push_back(last_one);
				segs.push_back(curr_seg);
				Pattern patt;
				patt.mMappedSegs.push_back(segs);

				double vir_len = PatternVirtualLength(isoform, seg_lens, patt);
				if (IsInformative(vir_len))
				{
					patterns.push_back(patt);
					vir_lens.push_back(vir_len);
				}
			}

			// === .................. === ====
			//  first_seg ..... predecessor curr_seg
			// len_sum is the total length of the isoform segments strictly
			// between first_seg and curr_seg.
			len_sum = 0;
			int first_seg = curr_seg;
			int predecessor = curr_seg;
			bool b_next_is_consecutive = false;
			while (true)
			{
				first_seg--;
				while (first_seg >= 0 && !isoform[first_seg])
					first_seg--;
				if (first_seg < 0) break;
				if (predecessor == curr_seg)
				{ 
					// The first isoform segment before curr_seg becomes predecessor.
					predecessor = first_seg;
					len_sum += seg_lens[predecessor];
					b_next_is_consecutive = true;
					continue;
				}

				// If first_seg immediately precedes predecessor, the three
				// segments are consecutive. The triple check above has already
				// emitted this exact pattern under identical length checks.
				bool b_consecutive = b_next_is_consecutive;
				b_next_is_consecutive = false;

				// Add first_seg to the running span BEFORE any early
				// `continue`. Otherwise segments rejected as too short would
				// be missing from the span of every farther candidate.
				int inner_len = len_sum;
				len_sum += seg_lens[first_seg];

				if (inner_len + seg_lens[first_seg] + seg_lens[curr_seg] < range_low) continue;
				if (inner_len >= range_high) break;
				if (b_consecutive) continue;

				vector<int> segs;
				segs.push_back(first_seg);
				segs.push_back(predecessor);
				segs.push_back(curr_seg);
				Pattern patt;
				patt.mMappedSegs.push_back(segs);

				double vir_len = PatternVirtualLength(isoform, seg_lens, patt);
				if (IsInformative(vir_len))
				{
					patterns.push_back(patt);
					vir_lens.push_back(vir_len);
				}
			}
		}

		// Virtual length of a single-mapping pattern: the sum, over every
		// valid start position of the paired-end read, of the read-length
		// distribution mass that makes the read cover exactly the pattern.
		//
		// Returns 0 in these cases:
		//   - the pattern is empty;
		//   - a pattern segment is absent from the isoform;
		//   - the pattern spans more than two maximal runs of consecutive
		//     isoform segments.
		// Assumes pattern segment indices are in ascending order, as built by
		// InformativePatterns().
		// NOTE(review): ProbInRange(a, b) appears to be P(a <= len < b) for
		// the whole PE read length -- confirm against ReadInfoBase.
		virtual double PatternVirtualLength(const vector<bool>& isoform,
				const vector<int>& seg_lens, const Pattern& pattern)
		{
			assert(pattern.mMappedSegs.size() == 1);

			const vector<int>& segs_in_pattern = pattern.mMappedSegs[0];
			// An empty pattern would otherwise index segs_in_pattern[0] below.
			if (segs_in_pattern.empty()) return 0;

			vector<bool> b_seg_in_pattern;
			b_seg_in_pattern.assign(isoform.size(), false);
			for (unsigned i = 0; i < segs_in_pattern.size(); ++i)
				b_seg_in_pattern[segs_in_pattern[i]] = true;

			// Each group corresponds to a set of maximal consecutive intervals 
			// in the isoform. group[k] is the group index of pattern segment k.
			vector<int> group;
			group.resize(segs_in_pattern.size());
			int group_cnt = 0;
			int patt_idx = 0;
			bool b_first = false;
			for (unsigned i = 0; i < isoform.size(); ++i)
			{
				if (isoform[i])
				{
					if (b_seg_in_pattern[i])
					{
						group[patt_idx] = group_cnt;
						++patt_idx;
						b_first = true;
					}
					else if (b_first)
					{
						++group_cnt;
						b_first = false;
					}
				}
				else if (b_seg_in_pattern[i])
				{
					return 0;
				}
			}
			if (b_first) ++group_cnt;

			// The pattern is split into more than two maximal runs of
			// consecutive isoform segments; a paired-end read (two ends)
			// cannot produce it.
			if (group_cnt > 2) return 0;
			assert(group_cnt > 0);

			// Length of the isoform from the first to the last pattern segment
			int max_pe_read_len = 0;
			for (unsigned i = segs_in_pattern[0]; i <= segs_in_pattern[segs_in_pattern.size()-1]; ++i)
				if (isoform[i]) max_pe_read_len += seg_lens[i];

			double prob_sum = 0;

#ifdef DEBUG0
			cout << "seg len : ";
			for (unsigned i = 0; i < seg_lens.size(); ++i)
				cout << seg_lens[i] << "\t";
			cout << endl;
			cout << "Isoform : ";
			for (unsigned i = 0; i < isoform.size(); ++i)
				cout << isoform[i] << "\t";
			cout << endl;
			cout << "Pattern : ";
			for (unsigned i = 0; i < segs_in_pattern.size(); ++i)
				cout << segs_in_pattern[i] << "\t";
			cout << endl;
			cout << "Group   : ";
			for (unsigned i = 0; i < group.size(); ++i)
				cout << group[i] << "\t";
			cout << endl;
			cout << "max_pe_read_len = " << max_pe_read_len << endl;
#endif
			if (1 == group_cnt)
			{
				// |======|=============|==|====|
				//       =====---------======
				// i is the start position of the PE read within the pattern.
				for (int i = 0; i <= seg_lens[segs_in_pattern[0]] - mCrossStrength; ++i)
				{
					// Find the index of the segs_in_pattern where the end position of 
					// the first end of the paired-end read resides
					unsigned start_idx = 0;
					int sum = 0;
					while (start_idx < segs_in_pattern.size() && sum < i + mEndLen)
						sum += seg_lens[segs_in_pattern[start_idx++]];
					if (segs_in_pattern.size() == start_idx && sum < i + mEndLen) break;
					if (start_idx < segs_in_pattern.size())
						sum += seg_lens[segs_in_pattern[start_idx]];

					// Find the proper range of the end position of the whole 
					// PE read. The conditions are:
					// 1) i + mEndLen <= end_pos - mEndLen <= sum - mCrossStrength
					// 2) max_pe_read_len - length of the last segment in the segs_in_pattern + mCrossStrength - 1 
					//     <= end_pos <= length of the whole segs_in_pattern
					// NOTE(review): the "- 1" in the lower bound (also used in
					// the two-group branch) has no counterpart on the start
					// side (start <= first seg len - mCrossStrength). Confirm
					// that this asymmetry is intended.
					int low = max_pe_read_len - seg_lens[segs_in_pattern[segs_in_pattern.size()-1]] + mCrossStrength - 1;
					if (low < i + 2 * mEndLen) low = i + 2 * mEndLen;

					int high = sum - mCrossStrength + mEndLen;
					if (high > max_pe_read_len) high = max_pe_read_len;
					
					if (low > high) break;
					prob_sum += mpReadLenDist->ProbInRange(low - i, high - i + 1);
				}
			}
			else // if (2 == group_cnt)
			{
				// 00000000000???????111111111111111111111        :   Group
				// |======|==|=======|===========|==|====|
				//     =====--------------------======
				int last_idx_group0 = 0;
				int len_group0 = 0;
				for (unsigned i = 0; i < segs_in_pattern.size(); ++i)
					if (0 == group[i]) 
					{
						last_idx_group0 = i;
						len_group0 += seg_lens[segs_in_pattern[i]];
					}

				// Find the proper range of the start position of the whole PE read
				// The conditions are:
				// 1) len_group0 - seg_lens[segs_in_pattern[last_idx_group0]] + mCrossStrength - 1
				//    <= start_pos + mEndLen <= len_group0
				// 2) 0 <= start_pos <= seg_lens[segs_in_pattern[0]] - mCrossStrength
				int start_pos_low = len_group0 - seg_lens[segs_in_pattern[last_idx_group0]] + mCrossStrength - 1 - mEndLen;
				if (start_pos_low < 0) start_pos_low = 0;

				int start_pos_high = seg_lens[segs_in_pattern[0]] - mCrossStrength;
				if (start_pos_high > len_group0 - mEndLen) start_pos_high = len_group0 - mEndLen;


				int len_group1 = 0;
				for (unsigned i = last_idx_group0 + 1; i < segs_in_pattern.size(); ++i)
					len_group1 += seg_lens[segs_in_pattern[i]];

				// Find the proper range of the end position of the whole PE read
				// The conditions are:
				// 1) max_pe_read_len - seg_lens[segs_in_pattern[segs_in_pattern.size()-1]] + mCrossStrength - 1
				//    <= end_pos <= max_pe_read_len 
				// 2) max_pe_read_len - len_group1  <= end_pos - mEndLen <=  
				//    max_pe_read_len - len_group1 + seg_lens[segs_in_pattern[last_idx_group0 + 1]] - mCrossStrength
				int end_pos_low = max_pe_read_len - seg_lens[segs_in_pattern[segs_in_pattern.size()-1]] + mCrossStrength - 1;
				if (end_pos_low < max_pe_read_len - len_group1 + mEndLen) end_pos_low = max_pe_read_len - len_group1 + mEndLen;

				int end_pos_high = max_pe_read_len - len_group1 + seg_lens[segs_in_pattern[last_idx_group0+1]] - mCrossStrength + mEndLen;
				if (end_pos_high > max_pe_read_len) end_pos_high = max_pe_read_len;

				if (start_pos_low <= start_pos_high && end_pos_low <= end_pos_high)
					for (int i = start_pos_low; i <= start_pos_high; ++i)
						prob_sum += mpReadLenDist->ProbInRange(end_pos_low - i, end_pos_high - i + 1);
			}
#ifdef DEBUG0
			cout << "vir len = " << prob_sum << endl;
#endif
			return prob_sum;
		}

};

#endif
