/*
 * =====================================================================================
 *
 *       Filename:  GraphAlgorithm_Matching.hpp
 *
 *    Description:  Find a maximum weighted matching of a bipartite graph.
 *
 *        Version:  1.0
 *        Created:  12/11/2008 08:06:46 PM
 *       Revision:  none
 *       Compiler:  g++
 *
 *         Author:  Jianxing Feng (feeldead), feeldead@gmail.com
 *        Company:  THU
 *
 * =====================================================================================
 */


#ifndef GraphAlgorithm_Matching_H
#define GraphAlgorithm_Matching_H

#include <iostream>
#include <limits>
#include <set>
#include <vector>

#include "GraphEx.hpp"
#include "GraphAlgorithm_BusackerGowen.hpp"


/*
 * =====================================================================================
 *        Class:  GraphAlgorithm_Matching
 *  Description:  Stateless collection of matching routines for (weighted) bipartite
 *                graphs. The graph is passed as a parallel edge list
 *                (from_nodes[i], to_nodes[i], weights[i]).
 *     Template:  WEIGHT_TYPE - type of the edge weights; presumably a numeric type
 *                with a std::numeric_limits specialization (the implementation relies
 *                on numeric_limits).
 * =====================================================================================
 */
template < typename WEIGHT_TYPE >
class GraphAlgorithm_Matching
{
public:
	/* Maximum weighted bipartite matching; returns the total weight of the chosen
	 * edges and marks them in 'matching'. */
	static WEIGHT_TYPE MaxWeightedBiMatching(const std::vector<int>& from_nodes,
	                                         const std::vector<int>& to_nodes,
	                                         const std::vector<WEIGHT_TYPE>& weights,
	                                         std::vector<bool>& matching);
}; /* ----------  end of template class GraphAlgorithm_Matching  ---------- */


/*
 *--------------------------------------------------------------------------------------
 *       Class:  GraphAlgorithm_Matching
 *      Method:  MaxWeightedBiMatching
 * Description:  Find a maximum weighted matching in a bipartite graph.
 *               The graph is turned into a flow network
 *                   source -> every from-node -> (edge i) -> every to-node -> sink
 *               with unit capacities. Edge i costs (max_weight - weights[i]) and the
 *               source/sink arcs are free. For a fixed flow value k the flow cost is
 *               k * max_weight - (sum of matched weights), so a min-cost flow of value
 *               k is a maximum-weight matching of size k. Every k = 1, 2, ... is
 *               solved with Busacker-Gowen, and the best total weight is kept.
 *       Param:  from_nodes    :    The from node of each edge in the bipartite graph
 *               to_nodes      :    The to node of each edge in the bipartite graph
 *               weights       :    The weight of each edge
 *               matching      :    Output. Resized to from_nodes.size(); matching[i] is
 *                                  true if and only if edge
 *                                  (from_nodes[i], to_nodes[i]) is in the matching.
 *      Return:  The total weight of the selected matching. If the graph has no edges
 *               (or no flow is found), every entry of 'matching' is false and the
 *               lowest value of WEIGHT_TYPE is returned.
 *        Note:  The three input arrays must have the same size. No node id may occur
 *               in both from_nodes and to_nodes (the two sides must be disjoint).
 *               At least one edge is always chosen when possible, even if all
 *               weights are negative.
 *--------------------------------------------------------------------------------------
 */
template < typename WEIGHT_TYPE >
/*static*/ WEIGHT_TYPE
GraphAlgorithm_Matching<WEIGHT_TYPE>::MaxWeightedBiMatching(const std::vector<int>& from_nodes,
                                                            const std::vector<int>& to_nodes,
                                                            const std::vector<WEIGHT_TYPE>& weights,
                                                            std::vector<bool>& matching)
{
	// Most negative value of WEIGHT_TYPE. numeric_limits<T>::min() is the smallest
	// *positive* value for floating-point types, so it is not a valid starting
	// point for a running maximum there; -max() is (pre-C++11 stand-in for lowest()).
	const WEIGHT_TYPE lowest_weight = std::numeric_limits<WEIGHT_TYPE>::is_integer
		? std::numeric_limits<WEIGHT_TYPE>::min()
		: -std::numeric_limits<WEIGHT_TYPE>::max();

	const int edge_num = static_cast<int>(from_nodes.size());

	// Start from an empty matching so no stale entries from the caller survive.
	matching.assign(from_nodes.size(), false);
	if (edge_num == 0)
		return lowest_weight;

	// Find the maximum weight; used to shift all costs to be non-negative
	// (presumably required by the Busacker-Gowen solver -- verify).
	WEIGHT_TYPE max_weight = lowest_weight;
	for (int i = 0; i < edge_num; i++)
		if (max_weight < weights[i])
			max_weight = weights[i];

	// Build a flow network for this bipartite graph.
	GraphEx<int, int> bi_graph;
	bi_graph.SetDirected(true);
	bi_graph.BeginAddNodeOrEdge();

	std::set<int> left_nodes;
	std::set<int> right_nodes;
	int source = 0;
	for (int i = 0; i < edge_num; i++)
	{
		bi_graph.AddEdgeEx(from_nodes[i], to_nodes[i], i);
		left_nodes.insert(from_nodes[i]);
		right_nodes.insert(to_nodes[i]);
		if (source < from_nodes[i]) source = from_nodes[i];
		if (source < to_nodes[i]) source = to_nodes[i];
	}

	// Source and sink get ids above every node id in use.
	source++;
	const int sink = source + 1;

	// Exactly one unit-capacity arc per distinct node (not one per edge): a node
	// incident to several edges must still be matched at most once.
	int ext_edge_id = edge_num;
	for (std::set<int>::const_iterator it = left_nodes.begin(); it != left_nodes.end(); ++it)
		bi_graph.AddEdgeEx(source, *it, ext_edge_id++);
	for (std::set<int>::const_iterator it = right_nodes.begin(); it != right_nodes.end(); ++it)
		bi_graph.AddEdgeEx(*it, sink, ext_edge_id++);
	bi_graph.EndAddNodeOrEdge();

	// Buffers indexed by internal edge id; std::vector releases them on every exit path.
	const int arc_cnt = bi_graph.EdgeCnt();
	std::vector<int> capacity(arc_cnt, 1);
	std::vector<int> flow(arc_cnt, 0);
	std::vector<WEIGHT_TYPE> cost(arc_cnt);

	const int* edges = bi_graph.Edges();
	for (int i = 0; i < arc_cnt; i++)
	{
		const int edge = edges[i];
		const int ext_id = bi_graph.GetEdgeExID(edge);
		if (ext_id >= edge_num)
			cost[edge] = 0;                               // source/sink arc
		else
			cost[edge] = max_weight - weights[ext_id];    // matching edge, >= 0
	}

	WEIGHT_TYPE best_sum_weight = lowest_weight;
	GraphAlgorithm_BusackerGowen<WEIGHT_TYPE> bg_algo;
	bg_algo.PrepareBuf(&bi_graph, source, sink, &capacity[0], &cost[0]);

	// A matching uses each node at most once, so its size is bounded by the smaller side.
	const int max_match = static_cast<int>(left_nodes.size() < right_nodes.size()
	                                       ? left_nodes.size() : right_nodes.size());

	// Note, small flow may lead to large sum of weight, so every flow value is tried.
	for (int flow_value = 1; flow_value <= max_match; flow_value++)
	{
		// If no flow has value 'flow_value', larger values are infeasible too.
		if (!bg_algo.OptimalFlow(flow_value))
			break;

		bg_algo.GetFlow(&flow[0]);

		WEIGHT_TYPE sum_of_weight = 0;
		for (int e = 0; e < edge_num; e++)
		{
			const int in_edge = bi_graph.GetEdgeInID(e);
			if (flow[in_edge] == 1)
				sum_of_weight += weights[e];
		}

		// Remember the best matching seen so far.
		if (best_sum_weight < sum_of_weight)
		{
			best_sum_weight = sum_of_weight;
			int match_cnt = 0;
			for (int e = 0; e < edge_num; e++)
			{
				const int in_edge = bi_graph.GetEdgeInID(e);
				const bool used = (flow[in_edge] == 1);
				matching[e] = used;
				if (used)
					match_cnt++;
			}

			// Each unit of flow should cross exactly one matching edge.
			if (match_cnt != flow_value)
			{
				std::cerr << "ERROR : the number of matching (" << match_cnt
				          << ") is different from the flow value (" << flow_value << ")" << std::endl;
			}
		}
#ifdef DEBUG0
		std::cout << "DEBUG0 : flow_value = " << flow_value << std::endl;
		std::cout << "DEBUG0 : best_sum_weight = " << best_sum_weight << std::endl;
#endif
	}

	return best_sum_weight;
}		/* -----  end of method GraphAlgorithm_Matching<WEIGHT_TYPE>::MaxWeightedBiMatching  ----- */

#endif
