// =====================================================================================
// 
//       Filename:  ShortRead.h
// 
//    Description:  Defines the data structures and methods for different types of short reads
// 
//        Version:  1.0
//        Created:  01/31/2010 10:50:18 AM
//       Revision:  none
//       Compiler:  g++
// 
//         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
//        Company:  THU
// 
// =====================================================================================

#ifndef ShortRead_H 
#define ShortRead_H

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include "Utility.hpp"
#include "ReadInfoBase.hpp"

using namespace std;

class Pattern
{
	public:
		// One instance of this class corresponds to a short read
		//
		// Each element in mMappedSegs should the mapped
		// segments of an un-gaped short read. For a paired
		// end read, the size of mMappedSegs should be 2
		// and each element of mMappedSegs corresponds to
		// one end of the paired-end read.
		vector<vector<int> >       mMappedSegs;

	public:

		bool operator != (const Pattern& patt) const
		{
			return (mMappedSegs != patt.mMappedSegs);
		}

		bool operator == (const Pattern& patt) const
		{
			return (mMappedSegs == patt.mMappedSegs);
		}

		bool operator < (const Pattern& patt) const
		{
			return (mMappedSegs < patt.mMappedSegs);
		}

		bool operator > (const Pattern& patt) const
		{
			return (mMappedSegs > patt.mMappedSegs);
		}

		// Whether this pattern is empty
		bool IsEmpty() const {return (mMappedSegs.size() == 0);}

		// Flatten the mapped information by combining all the
		// mapped segment IDs and removing duplications. To make 
		// the pattern of single-end reads and paired-end reads
		// consistent, we need to flatten the pattern of paired-end
		// reads such that each read is defined only by the 
		// set of mapped segments.
		//
		// However, in some cases, patterns of paired-end reads
		// should not be flattened. For example, when the pattern
		// of a junction is compared with the pattern of a paired-end
		// read, the mapped segments of the two ends of this paired-end
		// read should not be considered at the same time.
		//
		// Becareful about this method. If not clear, search all
		// occurrances of this method for a better understanding.
		void Flatten()
		{
			if (mMappedSegs.size() == 0) return;

			set<int> uniq;
			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
				for (unsigned j = 0; j < mMappedSegs[i].size(); ++j)
					uniq.insert(mMappedSegs[i][j]);

			mMappedSegs.resize(1);
			mMappedSegs[0].assign(uniq.begin(), uniq.end());
			sort(mMappedSegs[0].begin(), mMappedSegs[0].end());
		}

		// Map the index of the patterns to a new set of indexes
		// index i will be mapped to mapping[i]. If mapping[i] == -1, 
		// i is removed. Duplications will be merged. There are two cases
		// should be considered carefully
		//
		// Case 1: mapping could be like :
		// 0 1 2 3 4 5 6 7 8 9 10 11 12
		// 0 0 0 1 1 2 2 1 1 2 2  3  2
		// In this case, all the reads that covers both 1 and 2 would
		// be removed because "2" in the 5th and 6th positions has been
		// sandwitched by 1 in the 4th and 7th position; all the reads
		// that covers 2 and 3 would be removed because "3" in the 11th
		// position has been sandwitched by 2 in the 10th and 12th position.
		// The above situation would happen if the IsoInfer::EnumerateValid
		// has omitted some junction reads.
		//
		// Case 2: mapping could be like:
		// 0 1 2 3 4 5 6 
		// 0 0 0 1 1 1 1 
		// If some part of the read covers non-consecutive segments, e.g.
		// {0,1,2,4,5}, this read is discarded, because a part of the read
		// is supposed to be a consecutive interval in some isoform, the 
		// mapping means that this read does not belong to the isoform under
		// consideration.
		void Shrink(const vector<int>& old_mapping)
		{
			vector<int> mapping = old_mapping;

			// If this some old index of this read is mapped to -1, remove 
			// this read
			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
			{
				for (unsigned j = 0; j < mMappedSegs[i].size(); ++j)
				{
					if (-1 == mapping[mMappedSegs[i][j]])
					{
						mMappedSegs.resize(0);
						return;
					}
				}
			}

			// Find the max new index
			int max_new_idx = 0;
			for (unsigned i = 0; i < mapping.size(); ++i)
				if (max_new_idx < mapping[i]) max_new_idx = mapping[i];
			++max_new_idx;

			bool b_increased = false;
			for (unsigned i = 0; i < mapping.size(); ++i)
				if (-1 == mapping[i])
				{
					b_increased = true;
					mapping[i] = max_new_idx;
				}
			if (b_increased) ++max_new_idx;


			// For each new index, find out all the old indexes that
			// have been sandwitched by this new index
			vector<vector<int> > sandwitch;
			sandwitch.resize(max_new_idx);
			for (unsigned i = 0; i < sandwitch.size(); ++i)
			{
				unsigned first_pos = mapping.size();
				unsigned last_pos = 0;
				for (unsigned j = 0; j < mapping.size(); ++j)
				{
					if (mapping[j] == i) last_pos = j;
					if (mapping[mapping.size()-j-1] == i) first_pos = mapping.size()-j-1;
				}
				// Find the indexes that have been sandwitched
				for (unsigned j = first_pos + 1; j < last_pos; ++j)
				{
					if (mapping[j] != i) 
						sandwitch[i].push_back(j);
				}
			}

			// Remove this short read if it contains indexes such that some of them has
			// been sandwitched by another one.
			vector<bool> b_mapped;
			b_mapped.assign(mapping.size(), false);
			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
				for (unsigned j = 0; j < mMappedSegs[i].size(); ++j)
					b_mapped[mMappedSegs[i][j]] = true;
			for (unsigned i = 0; i < mapping.size(); ++i)
			{
				if (!b_mapped[i]) continue;
				for (unsigned j = 0; j < sandwitch[mapping[i]].size(); ++j)
					if (b_mapped[sandwitch[mapping[i]][j]])
					{
						mMappedSegs.resize(0);
						return;
					}
			}

			// For each new index, find out the old indexes corresponding to it
			vector<vector<int> > corresponding_old_index;
			corresponding_old_index.resize(max_new_idx);
			for (unsigned i = 0; i < mapping.size(); ++i)
				corresponding_old_index[mapping[i]].push_back(i);
		
			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
			{

				b_mapped.assign(mapping.size(), false);
				for (unsigned j = 0; j < mMappedSegs[i].size(); ++j)
				{
					b_mapped[mMappedSegs[i][j]] = true;
					mMappedSegs[i][j] = mapping[mMappedSegs[i][j]];
				}

				// Remove duplications
				UtilityTemp<int>::RemoveDups(mMappedSegs[i]);

				// Find the first mapped old index of this read
				unsigned old_first = 0;
				for (unsigned j = 0; j < b_mapped.size(); ++j)
					if (b_mapped[j])
					{
						old_first = j;
						break;
					}

				// Find the last mapped old index of this read
				unsigned old_last = 0;
				for (unsigned j = b_mapped.size()-1; j >= 0; --j)
					if (b_mapped[j])
					{
						old_last = j;
						break;
					}

				if (mMappedSegs[i].size() == 1) continue;

				// For the first mapped new index of this read, all the old indexes,
				// greater than old_first, corresponding to this new index, should
				// be covered by this read
				vector<int>& old_first_indexes = corresponding_old_index[mMappedSegs[i][0]];
				for (unsigned j = 0; j < old_first_indexes.size(); ++j)
					if (old_first_indexes[j] > old_first && !b_mapped[old_first_indexes[j]])
					{
						mMappedSegs.resize(0);
						return;
					}

				// For the last mapped new index of this read, all the old indexes,
				// smaller than old_last, corresponding to this new index, should
				// be covered by this read
				vector<int>& old_last_indexes = corresponding_old_index[mMappedSegs[i][mMappedSegs[i].size()-1]];
				for (unsigned j = 0; j < old_last_indexes.size(); ++j)
					if (old_last_indexes[j] < old_last && !b_mapped[old_last_indexes[j]])
					{
						mMappedSegs.resize(0);
						return;
					}

				// For mapped new indexes not the first and the last one of this read, 
				// all the old indexes corresponding to the new indexes, should
				// be covered by this read
				vector<int>& old_mid_indexes = corresponding_old_index[mMappedSegs[i][mMappedSegs[i].size()-1]];
				for (unsigned j = 1; j < mMappedSegs[i].size() - 1; ++j)
				{
					vector<int>& old_mid_indexes = corresponding_old_index[mMappedSegs[i][j]];
					for (unsigned k = 0; k < old_mid_indexes.size(); ++k)
						if (!b_mapped[old_mid_indexes[k]])
						{
							mMappedSegs.resize(0);
							return;
						}
				}
			} // for (unsigned i = 0; i < mMappedSegs.size(); ++i)
		} // Shrink(const vector<int>& mapping)

		// Whether the current pattern is contained in the second_pattern.
		bool IsContainedIn(const Pattern& second_pattern, bool b_jump = false) const
		{
			if (IsEmpty()) return false;

			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
			{
				bool b_contained = false;
				for (unsigned j = 0; j < second_pattern.mMappedSegs.size(); ++j)
				{
					if (IsContainedIn(mMappedSegs[i], second_pattern.mMappedSegs[j], b_jump))
					{
						b_contained = true; break;
					}
				}
				if (!b_contained) return false;
			}

			return true;
		}

		// Return the maximum index in this patterns.
		int MaxSegIndex() const
		{
			int max_id = 0;
			for (unsigned i = 0; i < mMappedSegs.size(); ++i)
				for (unsigned j = 0; j < mMappedSegs[i].size(); ++j)
					if (max_id < mMappedSegs[i][j]) max_id = mMappedSegs[i][j];
			return max_id;
		}


	private:
		// Whether is the first pattern contained in the second pattern?
		// If b_jump = false (true), only patterns that contain the specified
		// pattern as a substring (sequence) will lead to the "true" return value
		bool IsContainedIn(const vector<int>& first_segs, const vector<int>& second_segs, bool b_jump) const
		{
			if (b_jump)
			{
				for (unsigned i = 0; i < first_segs.size(); ++i)
					if (find(second_segs.begin(), second_segs.end(), first_segs[i]) == second_segs.end()) 
						return false;
			}
			else
			{
				// find the position in "second_segs" of the first element in "first_segs" 
				unsigned second_idx = 0;
				for (second_idx = 0; second_idx < second_segs.size(); ++second_idx)
					if (second_segs[second_idx] == first_segs[0]) break;

				if (second_idx != second_segs.size())
				{
					int first_idx = 0;
					for (; first_idx < first_segs.size() && second_idx+first_idx < second_segs.size(); ++first_idx)
						if (second_segs[second_idx+first_idx] != first_segs[first_idx])
							return false;
					if (first_idx < first_segs.size()) 
						return false;
				}
				else return false;
			}
			return true;
		}
};

/*
 * =====================================================================================
 *        Class:  ShortRead
 *  Description:  This class defines mapped information of short reads
 * =====================================================================================
 */
class ShortRead
{
	public:
		// Read-level information used by JuncExp (total read count and
		// junction length). This class never deletes it; it must point to a
		// valid object before JuncExp is called.
		ReadInfoBase*           mpReadInfo;
		// Distinct mapping patterns. mPatternDup[i] is the (possibly
		// fractional) number of reads with pattern mPatterns[i]; the two
		// vectors are kept parallel.
		vector<Pattern> 		mPatterns;
		vector<double> 			mPatternDup;

	public:
		// Map the index of the patterns to a new set of indexes:
		// index i will be mapped to mapping[i]. If mapping[i] == -1,
		// i is removed. Patterns that become empty are dropped, and
		// patterns that become identical are merged (their counts summed).
		// Pattern::Shrink decides when an individual pattern is discarded.
		void Shrink(const vector<int>& mapping)
		{
			// Note that mapping could be like:
			// 0 0 0 1 1 2 2 1 1 2 2 3 3
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				mPatterns[i].Shrink(mapping);
			RemoveEmptyPatterns();
			CombineDup();
		}

		// With seg == -1 (default): return the total number of reads.
		// Otherwise: return the number of reads whose pattern contains
		// segment seg.
		double ReadCnt(int seg = -1) const
		{
			if (-1 != seg)
			{
				Pattern pattern;
				pattern.mMappedSegs.push_back(vector<int>(1, seg));
				return SupPatternCnt(pattern);
			}

			double sum = 0;
			for (unsigned i = 0; i < mPatternDup.size(); ++i)
				sum += mPatternDup[i];
			return sum;
		}

		// Flatten every pattern (see Pattern::Flatten) and merge patterns
		// that become identical.
		void Flatten()
		{
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				mPatterns[i].Flatten();
			CombineDup();
		}

		// Return the number of reads whose pattern equals the given pattern.
		double EquPatternCnt(const Pattern& pattern) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				if (pattern == mPatterns[i])
					sum += mPatternDup[i];
			return sum;
		}

		// Return the number of reads covering a specified pattern.
		// If b_jump = false (true), only patterns that contain the specified
		// pattern as a substring (subsequence) will be considered.
		double SupPatternCnt(const Pattern& pattern, bool b_jump = false) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				if (pattern.IsContainedIn(mPatterns[i], b_jump))
					sum += mPatternDup[i];
			return sum;
		}

		// Return the number of reads covered by a specified pattern
		// (b_jump has the same meaning as in SupPatternCnt).
		double SubPatternCnt(const Pattern& pattern, bool b_jump = false) const
		{
			double sum = 0;
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				if (mPatterns[i].IsContainedIn(pattern, b_jump))
					sum += mPatternDup[i];
			return sum;
		}

		// Calculate the expression level in RPKM of the junction between
		// first_seg and second_seg:
		//   reads covering the junction * 1e6 / total reads * 1e3 / junction length.
		// Requires mpReadInfo to be valid.
		double JuncExp(int first_seg, int second_seg) const
		{
			vector<int> segs;
			segs.push_back(first_seg);
			segs.push_back(second_seg);

			Pattern patt;
			patt.mMappedSegs.push_back(segs);

			// Named "rpkm" rather than "exp" to avoid shadowing std::exp,
			// which "using namespace std" brings into scope.
			double rpkm = SupPatternCnt(patt, false) * 1000000.0 / mpReadInfo->mTotalReadCnt *
						  1000.0 / (mpReadInfo->JuncLen());
			return rpkm;
		}

	private:
		// Remove duplications in the patterns and combine their read counts.
		// NOTE(review): assumes UtilityTempComp<Pattern>::Sort orders
		// mPatterns so equal patterns are adjacent and fills sortedIndex with
		// the permutation that SortByIndex then applies to mPatternDup;
		// confirm in Utility.hpp.
		void CombineDup()
		{
			if (mPatterns.size() < 2) return;
			vector<int> sortedIndex;
			UtilityTempComp<Pattern>::Sort(mPatterns, sortedIndex);
			UtilityTemp<double>::SortByIndex(mPatternDup, sortedIndex);

			// Compact in place: [0, new_size] holds the distinct patterns so far.
			int new_size = 0;
			for (unsigned i = 1; i < mPatterns.size(); ++i)
			{
				if (mPatterns[i] != mPatterns[new_size])
				{
					mPatterns[++new_size] = mPatterns[i];
					mPatternDup[new_size] = mPatternDup[i];
				}
				else
					mPatternDup[new_size] += mPatternDup[i];
			}
			mPatterns.resize(++new_size);
			mPatternDup.resize(new_size);
		}

		// Remove empty patterns and the corresponding pattern dup,
		// preserving the relative order of the remaining ones.
		void RemoveEmptyPatterns()
		{
			int new_size = 0;
			for (unsigned i = 0; i < mPatterns.size(); ++i)
				if (!mPatterns[i].IsEmpty())
				{
					mPatternDup[new_size] = mPatternDup[i];
					mPatterns[new_size++] = mPatterns[i];
				}
			mPatternDup.resize(new_size);
			mPatterns.resize(new_size);
		}

		// Return the maximum segment index over all the patterns (0 if none).
		int MaxSegIndex() const
		{
			int max_id = 0;
			for (unsigned i = 0; i < mPatterns.size(); ++i)
			{
				int curr_max_id = mPatterns[i].MaxSegIndex();
				if (max_id < curr_max_id) max_id = curr_max_id;
			}
			return max_id;
		}
};

#endif
