/*
 * =====================================================================================
 *
 *       Filename:  RandomExpReadAssignerIM.h
 *
 *    Description:  This is the header file for class RandomExpReadAssignerIM
 *
 *        Version:  1.0
 *        Created:  04/22/2009 11:26:44 AM
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
 *        Company:  THU
 *
 * =====================================================================================
 */

#ifndef RandomExpReadAssignerIM_H 
#define RandomExpReadAssignerIM_H

#include <string>
#include <vector>
#include <iostream>
#include "InstanceHandler.h"
#include "Utility2.h"

using namespace std;

/*
 * =====================================================================================
 *        Class:  RandomExpReadAssignerIM
 *  Description:  This is an improvement over RandomExpReadAssigner. This class will
 *                randomly generate expression level and reads over all the genes simultaneously.
 * =====================================================================================
 */
class RandomExpReadAssignerIM : public InstanceHandler
{
	public:
		/* ====================  LIFECYCLE     ======================================= */
		RandomExpReadAssignerIM (ostream* p_output = NULL) : InstanceHandler(p_output)
		{
			mSEReadCnt = 0;
			mRng = 0;
		};                             /* constructor */

		virtual ~RandomExpReadAssignerIM (){};                             /* constructor */

		// Set se_read_cnt to be 0 if no random single read and expression level is needed
		void 
		SetSEReadCnt(int se_read_cnt)
		{
			mSEReadCnt = se_read_cnt;
		}

		void
		SetRandExpType(int rand_exp_type)
		{
			mRandExpType = rand_exp_type;
		}

		// Set read_len to be -1 such that the exact partial combination extracted from 
		// the known isoforms are used. Otherwise, if read_len > 0 and mSEReadCnt == 0, 
		// random PE reads will be generated.
		void
		AddPEInfo(const PEInfo& pe_info)
		{
			mPEInfo.push_back(pe_info);
		}

		virtual
		void
		Initialize()
		{
			InstanceHandler::Initialize();
		};

		virtual
		void
		CleanUp()
		{
			if (mSEReadCnt == 0 && mPEInfo.size() == 0) return;

			if (mRng)
				gsl_rng_free(mRng);

			// prepare the random generator
			gsl_rng_env_setup();

			const gsl_rng_type* T = gsl_rng_default;
			mRng = gsl_rng_alloc (T);

			// To check whether the used seed are identical during different run 
			// of the program. Because the only way to set a seed for the generator
			// is to set the environment variable GSL_RNG_SEED
			cerr << "First random number is : " << gsl_rng_uniform(mRng) << endl;

			if (mSEReadCnt > 0)
			{
				GenerateExpRead();
				for (unsigned i = 0; i < mAllInstances.size(); i++)
				{
					TranslateToAbs(mAllInstances[i]);
					InstanceHandler::OnInstance(mAllInstances[i]);
				}
			}
			else if (mPEInfo.size() > 0)
			{
				PEInfo& pe_info = mPEInfo[0];
				GeneratePERead(pe_info.mSpanMean, pe_info.mSpanStd, pe_info.mReadCnt, pe_info.mReadLen);
			}

			if (mRng)
				gsl_rng_free(mRng);
			InstanceHandler::CleanUp();
		};

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  Isoform
		 *       Method:  TranslateToAbs
		 *  Description:  Translate the expression levels of the isoforms to absolute expression
		 *                level (#reads / base)
		 *        Param: 
		 *       Return:
		 *--------------------------------------------------------------------------------------
		 */
		void
		TranslateToAbs(Instance& an_instance)
		{
			double read_sum = 0;
			for (unsigned i = 0; i < an_instance.mSetSizes.size(); i++)
				read_sum += an_instance.mSampleCnt[i];
			for (unsigned i = 0; i < an_instance.mSetSizes.size(); i++)
				for (unsigned j = 0; j < an_instance.mSetSizes.size(); j++)
					read_sum += an_instance.mSpliceReadCnt[i][j];

			double exp_len_sum = 0;
			for (unsigned i = 0; i < an_instance.mIsoforms.size(); i++)
			{
				int len = 0;
				for (unsigned j = 0; j < an_instance.mSetSizes.size(); j++)
					if (an_instance.mIsoforms[i][j]) len += an_instance.mSetSizes[j];
				len -= an_instance.mReadLen - 1;
				exp_len_sum += an_instance.mIsoExp[i] * len;
			}

			if (0 == exp_len_sum) exp_len_sum = 1;
			for (unsigned i = 0; i < an_instance.mIsoExp.size(); i++)
				an_instance.mIsoExp[i] = read_sum * an_instance.mIsoExp[i] / exp_len_sum;
		}/* -----  end of method OnInstance  ----- */

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  Isoform
		 *       Method:  RandomExpReads
		 *  Description:  Given the length of exons and isoforms, randomly generate the expression
		 *                level and reads
		 *        Param: 
		 *       Return:
		 *--------------------------------------------------------------------------------------
		 */
		virtual
		void
		OnInstance(Instance& an_instance)
		{
			mAllInstances.push_back(an_instance);
		}/* -----  end of method OnInstance  ----- */

	protected:

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  Isoform
		 *       Method:  GenerateExpRead
		 *  Description:  Generate random expression levels and single end reads. 
		 *        Param:  There are three types of random expression levels.
		 *                1. pow(10, r), where r is a random variable following standard normal
		 *                    distribution.
		 *                2. pow(2, r), where r ....
		 *                3. r \in [0, 1], uniformly distributed  
		 *       Return:
		 *--------------------------------------------------------------------------------------
		 */
		void
		GenerateExpRead()
		{
			// Generate expression levels of each isoform
			for (unsigned ins = 0; ins < mAllInstances.size(); ins++)
			{
				Instance& an_instance = mAllInstances[ins];
				an_instance.mIsoExp.resize(an_instance.mIsoforms.size());
				for (unsigned i = 0; i < an_instance.mIsoExp.size(); i++)
				{
					double r;
					if (3 == mRandExpType)
					{
						// Generate a random varaible obeying uniform distribution
						r = gsl_rng_uniform(mRng);
					}
					else
					{
						// Generate a random varaible obeying standard normal distribution
						// The expression levels follow a lognormal distribution.
						// Note that even though the expression levels are relative expression levels.
						// pow(10, r) and exp(r) are not equivalent. pow(10, r) will enlarge the
						// difference between expression levels more than exp(r) does.
						r = gsl_ran_gaussian_ziggurat(mRng, 1);
						//r = gsl_ran_ugaussian(mRng);
						if (2 == mRandExpType)
							r = pow(2, r);
						else 
							r = pow(10, r);
					}

					an_instance.mIsoExp[i] = r;
				}
			}

			// Generate reads

			// Prepare partial weight sum
			vector<double> gene_wlength_part_sum;
			gene_wlength_part_sum.assign(mAllInstances.size(), 0);

			vector<vector<double> > iso_wlength_part_sum;
			iso_wlength_part_sum.resize(mAllInstances.size());

			vector<vector<int> > iso_len;
			iso_len.resize(mAllInstances.size());

			// Calculate the (weighted) length of each isoform
			for (unsigned ins = 0; ins < mAllInstances.size(); ins++)
			{
				Instance& an_instance = mAllInstances[ins];
				an_instance.mReadCnt = mSEReadCnt;
				iso_len[ins].resize(an_instance.mIsoforms.size(), 0);
				vector<double>& wlength = iso_wlength_part_sum[ins];
				wlength.assign(an_instance.mIsoforms.size(), 0);
				for (unsigned i = 0; i < an_instance.mIsoforms.size(); i++)
				{
					int length = 0;
					for (unsigned j = 0; j < an_instance.mSetSizes.size(); j++)
						length += an_instance.mIsoforms[i][j] * an_instance.mSetSizes[j];
					length -= an_instance.mReadLen - 1;
					iso_len[ins][i] = length;
					wlength[i] = (double)(length * an_instance.mIsoExp[i]);
				}
				for (unsigned i = 1; i < an_instance.mIsoforms.size(); i++)
					wlength[i] += wlength[i-1];

				gene_wlength_part_sum[ins] = wlength[wlength.size()-1];
				if (ins > 0)
					gene_wlength_part_sum[ins] += gene_wlength_part_sum[ins-1];

				// Initialize sample cnt and splice cnt
				int set_size = an_instance.mSetSizes.size();
				an_instance.mSpliceReadCnt.resize(set_size);
				for (unsigned i = 0; i < an_instance.mSpliceReadCnt.size(); i++)
					an_instance.mSpliceReadCnt[i].assign(set_size, 0);
				an_instance.mSampleCnt.assign(set_size, 0);

				for (unsigned i = 0; i < an_instance.mExons.size(); i++)
				{
					Exon& exon = an_instance.mExons[i];
					exon.mStartCnt = exon.mEndCnt = exon.mBothCnt = 0;
				}
			}

			double tot_gene_wlen = gene_wlength_part_sum[gene_wlength_part_sum.size()-1];
			for (int cnt = 0; cnt < mSEReadCnt; cnt++)
			{
				// Randomly select a gene
				double r = gsl_rng_uniform(mRng) * tot_gene_wlen;
				int gene_idx = Utility2Temp<double>::BinarySearch(gene_wlength_part_sum, r);

				// Randomly select an isoform on this gene
				vector<double>& iso_wlen = iso_wlength_part_sum[gene_idx];
				r = gsl_rng_uniform(mRng) * iso_wlen[iso_wlen.size()-1];
				int iso_idx = Utility2Temp<double>::BinarySearch(iso_wlen, r);

				int read_len = mAllInstances[gene_idx].mReadLen;

				// Randomly select a start position of the read
				int start_pos = (int)(gsl_rng_uniform(mRng) * (iso_len[gene_idx][iso_idx]));

				vector<vector<bool> >& isoforms = mAllInstances[gene_idx].mIsoforms;
				vector<int>& set_sizes = mAllInstances[gene_idx].mSetSizes;
				vector<vector<double> >& splice_read_cnt =  mAllInstances[gene_idx].mSpliceReadCnt;
				vector<double>& sample_cnt =  mAllInstances[gene_idx].mSampleCnt;
				vector<Exon>& curr_exons = mAllInstances[gene_idx].mExons;

				int set_cnt = set_sizes.size();
				int len = 0;
				int start_exon_idx = 0;
				while (start_exon_idx < set_cnt)
				{
					if (isoforms[iso_idx][start_exon_idx])
						len += set_sizes[start_exon_idx];
					if (start_pos < len) break;
					start_exon_idx++;
				}
				int end_pos = start_pos + read_len - 1;
				int end_exon_idx = start_exon_idx;
				int len2 = len;
				while (end_exon_idx < set_cnt)
				{
					if (end_pos < len2) break;
					end_exon_idx++;
					if (isoforms[iso_idx][end_exon_idx])
						len2 += set_sizes[end_exon_idx];
				}
				ASSERT(end_exon_idx < set_cnt, "Internal ERROR : Randomly sampling a read wrong!");

				curr_exons[start_exon_idx].mStartCnt++;
				curr_exons[end_exon_idx].mEndCnt++;
				if (start_exon_idx == end_exon_idx)
				{
					curr_exons[start_exon_idx].mBothCnt++;
					sample_cnt[start_exon_idx]++;
				}
				else
				{
					assert(start_pos + read_len >= len);

					start_pos += read_len;
					int next_exon_idx = start_exon_idx+ 1;

					vector<int> involved;
					involved.push_back(start_exon_idx);
					while (next_exon_idx < set_cnt && len < start_pos)
					{
						if (isoforms[iso_idx][next_exon_idx])
						{
							involved.push_back(next_exon_idx);
							len += set_sizes[next_exon_idx];
							start_exon_idx = next_exon_idx;
						}
						next_exon_idx++;
					}

					for (unsigned i = 0; i < involved.size() - 1; i++)
						splice_read_cnt[involved[i]][involved[i+1]] += 1;
				}

				if (cnt % 100000 == 0) 
					cout << cnt << " reads have been sampled" << endl;
			}
			cout << mSEReadCnt << " reads have been sampled" << endl;

		}

		/*
		 *--------------------------------------------------------------------------------------
		 *        Class:  Isoform
		 *       Method:  GeneratePERead
		 *  Description:  Generate paired end reads. The expression levels in mAllInstances 
		 *                should be initialized before calling this function
		 *        Param: 
		 *       Return:
		 *--------------------------------------------------------------------------------------
		 */
		void
		GeneratePERead(double span_mean, double span_std, int read_cnt, int read_len)
		{
			if (mAllInstances.size() == 0) return;

			// According to the expression levels generated by GenerateExpRead,
			// Prepare partial weight sum
			vector<double> gene_wlength_part_sum;
			gene_wlength_part_sum.assign(mAllInstances.size(), 0);

			vector<vector<double> > iso_wlength_part_sum;
			iso_wlength_part_sum.resize(mAllInstances.size());

			vector<vector<int> > iso_len;
			iso_len.resize(mAllInstances.size());

			// Calculate the (weighted) length of each isoform
			for (unsigned ins = 0; ins < mAllInstances.size(); ins++)
			{
				Instance& an_instance = mAllInstances[ins];
				iso_len[ins].resize(an_instance.mIsoforms.size(), 0);
				vector<double>& wlength = iso_wlength_part_sum[ins];
				wlength.assign(an_instance.mIsoforms.size(), 0);
				for (unsigned i = 0; i < an_instance.mIsoforms.size(); i++)
				{
					int length = 0;
					for (unsigned j = 0; j < an_instance.mSetSizes.size(); j++)
						length += an_instance.mIsoforms[i][j] * an_instance.mSetSizes[j];
					// :WARNING:05/28/2009 01:49:24 PM:feeldead: 
					// This method is only an approximation of the real sampling process.
					// It is OK if enough reads are sampled and the paired-end reads
					// are not used to calculate the expression level.
					if (length < 0) length = 0;
					iso_len[ins][i] = length;
					wlength[i] = (double)(length * an_instance.mIsoExp[i]);
					if (wlength[i] < 0) wlength[i] = 0;
				}
				for (unsigned i = 1; i < an_instance.mIsoforms.size(); i++)
					wlength[i] += wlength[i-1];

				gene_wlength_part_sum[ins] = wlength[wlength.size()-1];
				if (ins > 0)
					gene_wlength_part_sum[ins] += gene_wlength_part_sum[ins-1];

				// Initialize sample cnt and splice cnt
				int set_size = an_instance.mSetSizes.size();
				an_instance.mSpliceReadCnt.resize(set_size);
				for (unsigned i = 0; i < an_instance.mSpliceReadCnt.size(); i++)
					an_instance.mSpliceReadCnt[i].assign(set_size, 0);
				an_instance.mSampleCnt.assign(set_size, 0);
			}

			vector<vector<double> > all_end_at_cnt;
			all_end_at_cnt.resize(mAllInstances.size());
			for (unsigned i = 0; i < all_end_at_cnt.size(); i++)
				all_end_at_cnt[i].assign(mAllInstances[i].mSampleCnt.size(), 0);

			double tot_gene_wlen = gene_wlength_part_sum[gene_wlength_part_sum.size()-1];

			int cnt = 0;
			while (cnt < read_cnt)
			{
				// Randomly select a gene
				double r = gsl_rng_uniform(mRng) * tot_gene_wlen;
				unsigned gene_idx = Utility2Temp<double>::BinarySearch(gene_wlength_part_sum, r);
				// Find the last gene in case some gene has zero contribution to wlength
				while (gene_idx + 1 < gene_wlength_part_sum.size() && 
					gene_wlength_part_sum[gene_idx] == gene_wlength_part_sum[gene_idx+1])
					gene_idx++;

				// Randomly select an isoform on this gene
				vector<double>& iso_wlen = iso_wlength_part_sum[gene_idx];
				r = gsl_rng_uniform(mRng) * iso_wlen[iso_wlen.size()-1];
				int iso_idx = Utility2Temp<double>::BinarySearch(iso_wlen, r);
				// Find the last iso in case some isoform has zero contribution to wlength
				while (iso_idx + 1 < iso_wlen.size() && iso_wlen[iso_idx] == iso_wlen[iso_idx+1])
					iso_idx++;

				// Randomly select a span length following N(span_mean, span_std)
				int span_len = (int)(gsl_ran_gaussian(mRng, 1) * span_std + span_mean);

				if (span_len < 2 * read_len) continue;

				// If this isoform is too short, skip this read
				if (iso_len[gene_idx][iso_idx] < span_len) continue;

				// Randomly select a start position of the read
				int start_pos = (int)(gsl_rng_uniform(mRng) * (iso_len[gene_idx][iso_idx] - span_len));

				vector<vector<bool> >& isoforms = mAllInstances[gene_idx].mIsoforms;
				vector<int>& set_sizes = mAllInstances[gene_idx].mSetSizes;

				(*mpOutput) << mAllInstances[gene_idx].mExons[0].mChr << "\t" << (mAllInstances[gene_idx].mExons[0].mStrand ? '+' : '-') << "\t";

				// Find the start and end positions of the first part of the PE read
				int set_cnt = set_sizes.size();
				int len = 0;
				int start_exon_idx = 0;
				while (start_exon_idx < set_cnt)
				{
					if (isoforms[iso_idx][start_exon_idx])
						len += set_sizes[start_exon_idx];
					if (start_pos < len) break;
					start_exon_idx++;
				}
				int end_pos = start_pos + read_len - 1;
				int end_exon_idx = start_exon_idx;
				int len2 = len;
				while (end_exon_idx < set_cnt)
				{
					if (end_pos < len2) break;
					end_exon_idx++;
					if (isoforms[iso_idx][end_exon_idx])
						len2 += set_sizes[end_exon_idx];
				}
				ASSERT(end_exon_idx < set_cnt, "Internal ERROR : Randomly sampling a read wrong!");

				if (start_exon_idx == end_exon_idx)
					(*mpOutput) << mAllInstances[gene_idx].mExons[start_exon_idx].mStart + start_pos - len + set_sizes[start_exon_idx] << "\t" 
						        << read_len << "\t" << -1 << "\t" << 0 << "\t";
				else
					(*mpOutput) << mAllInstances[gene_idx].mExons[start_exon_idx].mStart + start_pos - len + set_sizes[start_exon_idx] << "\t" 
								<< len - start_pos << "\t" 
								<< mAllInstances[gene_idx].mExons[end_exon_idx].mStart << "\t" << end_pos - len + 1 << "\t";

				// Find the start and end positions of the second part of the PE read
				start_pos += span_len - read_len;
				len = 0;
				start_exon_idx = 0;
				while (start_exon_idx < set_cnt)
				{
					if (isoforms[iso_idx][start_exon_idx])
						len += set_sizes[start_exon_idx];
					if (start_pos < len) break;
					start_exon_idx++;
				}
				end_pos = start_pos + read_len - 1;
				end_exon_idx = start_exon_idx;
				len2 = len;
				while (end_exon_idx < set_cnt)
				{
					if (end_pos < len2) break;
					end_exon_idx++;
					if (isoforms[iso_idx][end_exon_idx])
						len2 += set_sizes[end_exon_idx];
				}
				ASSERT(end_exon_idx < set_cnt, "Internal ERROR : Randomly sampling a read wrong!");

				if (start_exon_idx == end_exon_idx)
					(*mpOutput) << mAllInstances[gene_idx].mExons[start_exon_idx].mStart + start_pos - len + set_sizes[start_exon_idx] << "\t" 
						        << read_len << "\t" << -1 << "\t" << 0 << "\t";
				else
					(*mpOutput) << mAllInstances[gene_idx].mExons[start_exon_idx].mStart + start_pos - len + set_sizes[start_exon_idx] << "\t" 
								<< len - start_pos << "\t" 
								<< mAllInstances[gene_idx].mExons[end_exon_idx].mStart << "\t" << end_pos - len + 1 << "\t";
				(*mpOutput) << endl;

				cnt++;
				if (cnt % 100000 == 0) 
					cout << cnt << " reads have been sampled" << endl;
			}
		}



	private:
		const gsl_rng_type * mRngType;
		gsl_rng * mRng;

		vector<Instance> mAllInstances;
		// The depth of the sampling
		int mSEReadCnt;

		// In the following, the PE information could be classified into several
		// groups with each group corresponding to a span range
		vector<PEInfo> mPEInfo;

		int mRandExpType;
}; /* -----  end of class RandomExpReadAssignerIM  ----- */

#endif
