#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include "util.h"
#include "gmap.h"
#include "graph.h"
#include "spanning_tree.h"
#include "lists.h"

/*
 * One "row" of the input: an index map read by read_i_maps.
 */
typedef struct index_map {
  int internal_index;   /* first number of the row's header in the input */
  int mapped_index;     /* initialised to 0 by IM_alloc */
  int range;            /* per-row maximum element, as read from the input */
  int * entries;        /* 1-based: entries[1..length] hold the columns;
                           entries[0] is allocated but unused */
} * IM;

/*
 * IM_alloc -- allocate an index map with room for `length` columns.
 *
 * The entries array has length + 1 zero-initialised slots so that columns
 * can be addressed 1..length.  Aborts on a negative length or when memory
 * cannot be obtained.
 */
IM IM_alloc(int internal_index, int range, int length)
{
  IM x;

  if (length < 0) {
    fprintf(stderr, "IM_ALLOC negative length %d\n", length);
    abort();
  }

  x = malloc(sizeof *x);
  if (x == NULL) {
    fprintf(stderr, "IM_ALLOC out of memory\n");
    abort();
  }

  /* calloc zero-fills and checks the count * size product for overflow. */
  x->entries = calloc((size_t) length + 1, sizeof *x->entries);
  if (x->entries == NULL) {
    free(x);
    fprintf(stderr, "IM_ALLOC out of memory\n");
    abort();
  }

  x->internal_index = internal_index;
  x->mapped_index = 0;
  x->range = range;
  return x;
}

/*
 * skip_to_eol -- discard characters from fd up to and including the next
 * newline.  Stops early at end of file or on a read error instead of
 * spinning forever.
 *
 * Returns '\n' when a newline was consumed, EOF otherwise.
 */
int skip_to_eol(FILE * fd)
{
  int c;

  do
    c = fgetc(fd);
  while (c != '\n' && c != EOF);
  return c;
}

/*
 * skip_prologue -- discard header lines until one starts with a space.
 *
 * Every line whose first character is not ' ' is dropped entirely (a blank
 * line counts as complete on its own).  The leading space of the first line
 * that does start with ' ' is consumed, so reading resumes right after it.
 *
 * Returns ' ' when such a line was found, EOF if the input ended first.
 */
int skip_prologue(FILE * fd)
{
  int c;

  for (;;) {
    c = fgetc(fd);
    if (c == ' ' || c == EOF)
      return c;
    /* Drop the rest of this line.  If c is already '\n' the line was
       empty and there is nothing more to drop. */
    while (c != '\n') {
      c = fgetc(fd);
      if (c == EOF)
        return EOF;
    }
  }
}

/*
 * readn -- read the next whitespace-separated decimal integer from fd.
 *
 * Aborts the process (after a message on stderr) if no integer can be read.
 * abort() takes no arguments; the former abort(1) was an invalid call.
 */
int readn(FILE * fd)
{
  int value;

  if (fscanf(fd, "%d", &value) != 1) {
    fprintf(stderr,"READN failed to read an integer\n");
    abort();
  }
  return value;
}

/*
 * File-scope state filled in by read_i_maps.  hash_column and
 * compare_columns are handed to int_g_new_map and receive only column
 * indices, so the rows they inspect are reached through these globals.
 */
int global_count;          /* number of rows in global_maps */
IM *global_maps;           /* the rows read from the input */
GMAP colmap;               /* collapses identical columns; g_unmap(colmap, k)
                              yields a representative column index */
GRAPH colgraph;            /* one node per distinct column */
GRAPH_INDEX basic_index;   /* node data: the column a node stands for */
GRAPH_INDEX temp_index;    /* candidate edges; deleted after the tree is built */
GRAPH_INDEX sp;            /* index returned by graph_spanning_tree */

/*
 * column_cost -- number of rows among maps[0..count-1] whose entries differ
 * between columns k1 and k2.
 */
int column_cost(int k1, int k2, int count, IM * maps)
{
  int differing = 0;
  int row = count;

  /* Walk the rows from last to first; the order does not affect the sum. */
  while (row > 0)
    {
      row--;
      if (maps[row]->entries[k1] != maps[row]->entries[k2])
        differing++;
    }
  return differing;
}

/*
 * node_cost -- distance between two graph nodes: the column_cost of the
 * columns stored in their basic_index data, over all global rows.
 */
int node_cost(GRAPH_NODE n1, GRAPH_NODE n2)
{
  return column_cost(node_data(n1, basic_index), node_data(n2, basic_index),
                     global_count, global_maps);
}

/*
 * compare_columns -- equality predicate for colmap: 1 when columns j1 and
 * j2 agree in every global row, 0 otherwise.
 */
int compare_columns(int j1, int j2)
{
  int differences = column_cost(j1, j2, global_count, global_maps);

  return !differences;
}

/*
 * hash_column -- hash function for colmap: folds the entries of column j
 * across all global_count rows into one value.
 *
 * Columns that compare_columns treats as equal (identical in every row)
 * always hash equally.  The arithmetic is unsigned, so overflow wraps
 * instead of being undefined behaviour.  The result is masked to a
 * non-negative int.  A polynomial combine is used rather than a product so
 * that a zero entry does not force the whole hash to zero.
 */
int hash_column(int j)
{
  unsigned int h = 1u;
  int i;

  for (i = 0; i < global_count; i++)
    h = h * 31u + (unsigned int) global_maps[i]->entries[j];
  return (int) (h & (unsigned int) INT_MAX);
}

/* copy_column -- identity key/value copier for colmap: a column is its index. */
int copy_column(int j)
{
  int column = j;

  return column;
}

/* countup -- fold step for foreach_edge_from/_to: ignores the edge and
   returns the running count s incremented by one. */
int countup(GRAPH_EDGE e, int s)
{
  (void) e;
  return 1 + s;
}

/* Sum of every degree computed by has_odd_degree since it was last reset.
   read_i_maps compares it with 2 * (nodes - 1), the degree sum of a tree. */
int sanity_check = 0;

/*
 * has_odd_degree -- 1 when node n has an odd number of sp-indexed edges in
 * colgraph (edges leaving n plus edges entering it), else 0.  The edge count
 * is also added to sanity_check.
 */
int has_odd_degree(GRAPH_NODE n)
{
  int degree;

  degree = foreach_edge_from(colgraph, sp, n, countup, 0);
  degree = foreach_edge_to(colgraph, sp, n, countup, degree);
  sanity_check += degree;
  return (degree & 1) != 0;
}

/*
 * matching_cost -- total node_cost of the matching pairs_of, where
 * pairs_of[k] is the partner of odd_map[k] for k in 0..odd_count-1.
 *
 * Assumes pairs_of is symmetric (pairs_of[pairs_of[k]] == k).  Every pair
 * is then visited from both ends, so the sum is halved.
 */
int matching_cost(GRAPH_NODE * odd_map, int * pairs_of, int odd_count)
{
  int total = 0;
  int k;

  for (k = 0; k < odd_count; ++k)
    {
      GRAPH_NODE here = odd_map[k];
      GRAPH_NODE partner = odd_map[pairs_of[k]];

      total += node_cost(here, partner);
    }
  return total / 2;
}

/*
 * improve_matching -- one greedy pass of pairwise re-pairing over pairs_of.
 *
 * Presumes pairs_of is a perfect matching on 0..odd_count-1
 * (pairs_of[pairs_of[k]] == k).  For each a >= 1, every b < a that is not
 * a's partner is tried.  Pairs {a,pa} and {b,pb} would become {a,b} and
 * {pa,pb}; the gain is the cost of the old pairs minus the new ones.  The
 * first b with the largest strictly positive gain is applied, and pairs_of
 * is updated in place.  The alternative re-pairing {a,pb},{b,pa} is not
 * considered.
 */
int improve_matching(GRAPH_NODE * odd_map, int * pairs_of, int odd_count)
{
  int a;

  for (a = 1; a < odd_count; a++)
    {
      int partner_a = pairs_of[a];
      int best_gain = 0;
      int best_b = -1;
      int b;

      for (b = 0; b < a; b++)
        {
          int partner_b, current, proposed, gain;

          if (b == partner_a || b == a)
            continue;

          partner_b = pairs_of[b];
          current = node_cost(odd_map[a], odd_map[partner_a])
                    + node_cost(odd_map[b], odd_map[partner_b]);
          proposed = node_cost(odd_map[a], odd_map[b])
                     + node_cost(odd_map[partner_a], odd_map[partner_b]);
          gain = current - proposed;

          /* Strict '>' keeps the earliest b among equal gains. */
          if (gain > best_gain)
            {
              best_gain = gain;
              best_b = b;
            }
        }

      if (best_gain <= 0)
        continue;

      {
        /* Read b's old partner before any slot is overwritten. */
        int partner_best = pairs_of[best_b];

        pairs_of[a] = best_b;
        pairs_of[best_b] = a;
        pairs_of[partner_a] = partner_best;
        pairs_of[partner_best] = partner_a;
      }
    }
  return 0;
}


read_i_maps(fd) FILE * fd;
{
  int count, size, internal_index, range;

  /* count is the number of "rows"
     size is the number of "columns"
     range is the (per-row) maximum element
     */


  int i,j,k,l,m,n;
  int size_without_duplicates;
  IM * maps;

  int treecost; /* cost of our spanning tree */

  GRAPH_NODE * odd_map;
  int * pairs_of;
  int odd_count;
      
  int mcost;
  int omcost = 0;

  skip_prologue(fd);
  count = readn(fd);
  size = readn(fd);

  maps = (IM *) malloc (sizeof(IM) * count);
  
  global_maps = maps;
  global_count = count;

  /* Read in the bulk of the maps. */
  
  for (i = 0; i < count; i++)
    {
      internal_index = readn(fd);
      range = readn(fd);
      maps[i] = IM_alloc(internal_index, range, size);
      for (j = 1; j <= size; j++)
	{
	  maps[i] -> entries[j] = readn(fd);
	}
    }
  printf("Read in matrix of %d rows, %d columns\n", count, size);

  /* compute the raw transitions cost */
  
  {
    int cost = 0;
    for (j = 2; j <= size; j++)
      {
	cost = cost + column_cost(j - 1, j, count,maps);
      }
    printf("Raw cost is %d (total internal transition count)\n", cost);
  }

  /* Remove duplicates */
  
  colmap = int_g_new_map(hash_column,
			 copy_column,
			 copy_column,
			 compare_columns,
			 size + 1, size + 1);
  
  for (j = 1; j <= size; j++)
    g_map(colmap,j);
  
  size_without_duplicates = g_map_size(colmap);
  printf("%d columns before removing duplicates, %d after\n",
	 size, size_without_duplicates);
  
  /* Compute cost after removing duplicates */

  {
    int cost = 0;
    for (j = 2; j <= size_without_duplicates; j++)
      {
	cost = cost + column_cost(g_unmap(colmap,j - 1),
				  g_unmap(colmap,j),
				  count,
				  maps);
      }
    printf("Reduced cost is %d (total internal transition count)\n", cost);
  }

  /* Compute a spanning tree of the columns (treated as nodes,
     edge cost equal to column-to-column cost) */

  colgraph = new_graph();
  basic_index = next_graph_index(colgraph);
  temp_index = next_graph_index(colgraph);

  /* the data (basic_index) for each node contains the index of a column
     the other field for each non-duplicated node contains the address of a node. */
  
  printf("(forming nodes)\n");
  for (j = 1; j <= size_without_duplicates; j++)
    {
      GRAPH_NODE n = new_node(colgraph);
      node_data(n,basic_index) = g_unmap(colmap,j);
      g_other(colmap,j) = (CELL) n;
    }

  /* Each edge will be tagged with a cost.  Put the edges in temp_index,
     because we'll be disposing of them. */
  
  printf("(forming edges)\n");
  for (i = 1; i <= size_without_duplicates; i++)
    for (j = i + 1; j <= size_without_duplicates; j++)
      {
	int c = column_cost(g_unmap(colmap,i), g_unmap(colmap,j), count, maps);
	GRAPH_EDGE e = new_edge_after(colgraph,
				      temp_index,
				      c,
				      g_other(colmap,i),
				      g_other(colmap,j));
      }

  /* Now ask for a spanning tree */
  printf("(forming spanning tree)\n");
  sp = graph_spanning_tree(colgraph, temp_index, count, 0, &treecost);
  
  printf("A spanning tree has cost %d\n", treecost);
  
  delete_graph_index(colgraph,temp_index);

  /* Now find all nodes in the tree such their degree is odd */
  printf("(finding nodes with odd degree)\n");
  {
    odd_count = 0;
    sanity_check = 0;
    for (j = 1; j <= size_without_duplicates; j++)
      {
	GRAPH_NODE n = (GRAPH_NODE) g_other(colmap,j);
	int is_odd = has_odd_degree(n);
	node_data(n,sp) = is_odd;
	odd_count += is_odd;
      }
    printf("Sum of degrees = %d, should be %d\n",
	   sanity_check, 2 * size_without_duplicates - 2);
    printf("%d nodes out of %d have odd degree\n",
	   odd_count, size_without_duplicates);
    
    /* Now we go after a matching */
    odd_map = (GRAPH_NODE *) malloc (odd_count * sizeof(GRAPH_NODE));
    pairs_of = (int *) malloc (odd_count * sizeof(int));

    l = 0;
    for (j = 1; j <= size_without_duplicates; j++)
      {
	GRAPH_NODE n = (GRAPH_NODE) g_other(colmap,j);
	int is_odd = has_odd_degree(n);
	if (is_odd)
	  {
	    odd_map[l] = n;
	    l++;
	  }
      }
    /* Initial, probably bogus, matching */
    for (j = 0; j < odd_count; j += 2)
      {
	pairs_of[j] = j + 1;
	pairs_of[j + 1] = j;
      }
    
    mcost = matching_cost(odd_map, pairs_of, odd_count);
    omcost = 0;
    
    printf("Initial (bogus) matching cost is %d\n", mcost);
    
      while (omcost != mcost)
	{
	  omcost = mcost;
	  improve_matching(odd_map, pairs_of, odd_count);
	  mcost = matching_cost(odd_map, pairs_of, odd_count);
	  printf("Improved matching cost is %d\n", mcost);
	}
    /* at this point we have a good matching */
    printf("Estimated total cost is %d\n", mcost + treecost);
  }
}

/* stream -- copy every remaining byte of `in` to `out`, stopping at EOF. */
int stream(FILE * in, FILE * out)
{
  int ch;

  while ((ch = fgetc(in)) != EOF)
    fputc(ch, out);
  return 0;
}

/* Entry point: analyse the index maps on stdin.  When built with
   CHECK_INPUT, also echo whatever input read_i_maps left unread. */
int main(void)
{
  read_i_maps(stdin);

#ifdef CHECK_INPUT
  printf("Remainder of input is:\n----\n");
  stream(stdin,stdout);
  printf("----\n");
#endif
  return 0;
}
