#include <stdlib.h>
#include "parlib.h"

int P_Transpose_init ( ndims, dim_source, lblks_source, dim_dest, 
		lblks_dest, stride, blklen, overlap, datatype, comm, period,
		transp )
	int ndims, dim_source, *lblks_source, dim_dest, *lblks_dest, *stride;
   int *blklen, *overlap;
	MPI_Datatype datatype;
   MPI_Comm comm;
	int period;
   Transposition *transp;
{
	int idim, nproc, iproc, ip, strd, count;
	int wblka, wblkb, begb;
	int ifsta, ifstb, idir, suma, sumb;
	MPI_Aint fsize;
	MPI_Datatype oldtype, *stype, *rtype;
	int *sbeg, *rbeg;
/*
 * Check input parameters
 */
	if ( ndims < 2) { return 1; }
	if ( dim_source < 1 || dim_source > ndims ) { return 2; }
	if ( dim_dest < 1 || dim_dest > ndims ) { return 3; }
   if ( dim_source == dim_dest ) { return 4; }
   for ( idim = 0; idim < ndims; idim++ ) {
		if ( stride[idim] <= 0) { return 5; }
	}
   for (idir = 0; idir < 2; idir++ ) {
		if ( overlap[idir] < 0 ) { return 6; }
	}
/*
 * Define the number of processors in the group and the rank
 */
	MPI_Comm_size ( comm, &nproc );
	if ( nproc == 0 ) { return 0; } 
	MPI_Comm_rank ( comm, &iproc );
	if ( iproc == MPI_UNDEFINED ) { return 0; }

	suma = sumb = 0;
	for ( ip = 0; ip < nproc; ip++ ) {
		suma += lblks_source[ip];
		sumb += lblks_dest[ip];
		if ( lblks_source[ip] <= 0 ) { return 14; }
		if ( lblks_dest[ip] <= 0 ) { return 15; }
	}

	if ( lblks_source[iproc] > blklen[dim_source-1] ) { return 8; }
   if ( lblks_dest[iproc] > blklen[dim_dest-1] ) { return 9; }
	if ( suma > stride[dim_source-1] ) { return 10; }
   if ( sumb > stride[dim_dest-1] ) { return 11; }
	for ( idim = 0; idim < ndims; idim++ ) {
		if ( idim != dim_source-1 && idim != dim_dest-1 ) {
			if ( blklen[idim] > stride[idim] ) { return 7; }
		}
	}
	if ( overlap[0] > lblks_dest[0] ) { return 12; }
	if ( overlap[1] > lblks_dest[nproc-1] ) { return 13; }

	MPI_Type_extent ( datatype, &fsize );
/*
 * Allocate memory
 */
	stype = transp->stype =
		(MPI_Datatype *) malloc ( sizeof(MPI_Datatype)*nproc ); 
	rtype = transp->rtype =
		(MPI_Datatype *) malloc ( sizeof(MPI_Datatype)*nproc ); 
	sbeg = transp->sbeg = (int *) malloc ( sizeof(int)*nproc ); 
	rbeg = transp->rbeg = (int *) malloc ( sizeof(int)*nproc ); 
/*
 * Define data types for the blocks and the beginings of the blocks
 */
   ifsta = ifstb = 1;
   for ( ip = 0; ip < nproc; ip++ ) {
      wblka = lblks_source[iproc];
      wblkb = lblks_dest[ip];
      if ( ip > 0 || period ) wblkb += overlap[0];
      if ( ip < nproc-1 || period ) wblkb += overlap[1];

      oldtype = datatype;
      strd = 1;
      for ( idim = 0; idim < ndims; idim++ ) {
         if ( idim == dim_source-1 ) {
            count = wblka;
         } else if ( idim == dim_dest-1 ) {
            count = wblkb;
         } else {
            count = blklen[idim];
         }
         MPI_Type_hvector ( count, 1, strd*fsize, oldtype, stype+ip );
			if ( idim > 0 ) { MPI_Type_free ( &oldtype ); }
         oldtype = stype[ip];
         if ( idim == dim_source-1 ) {
            strd *= blklen[idim];
         } else {
            strd *= stride[idim];
         }
      }
      MPI_Type_commit ( stype+ip );

      wblka = lblks_source[ip];
      wblkb = lblks_dest[iproc];
      if ( iproc > 0 || period ) wblkb += overlap[0];
      if ( iproc < nproc-1 || period ) wblkb += overlap[1];
      oldtype = datatype;
      strd = 1;
      for ( idim = 0; idim < ndims; idim++ ) {
         if ( idim == dim_source-1 ) {
            count = wblka;
         } else if ( idim == dim_dest-1 ) {
            count = wblkb;
         } else {
            count = blklen[idim];
         }
         MPI_Type_hvector ( count, 1, strd*fsize, oldtype, rtype+ip );
			if ( idim > 0 ) { MPI_Type_free ( &oldtype ); }
         oldtype = rtype[ip];
         if ( idim == dim_dest-1 ) {
            strd *= blklen[idim];
         } else {
            strd *= stride[idim];
         }
      }
      MPI_Type_commit ( rtype+ip );

      begb = ifstb;
      if ( ip > 0 || period ) begb -= overlap[0];
      strd = 1;
      for ( idim = 0; idim < dim_dest-1; idim++ ) {
         if ( idim == dim_source-1 ) {
            strd *= blklen[idim];
         } else {
            strd *= stride[idim];
         }
      }
      sbeg[ip] = strd*(begb-1);

      rbeg[ip] = 0;
      if ( iproc > 0 || period ) rbeg[ip] -= overlap[0]*strd;

      strd = 1;
      for ( idim = 0; idim < dim_source-1; idim++ ) {
         if ( idim == dim_dest-1 ) {
            strd *= blklen[idim];
         } else {
            strd *= stride[idim];
         }
      }
      rbeg[ip] += strd*(ifsta-1);
      ifsta += lblks_source[ip];
      ifstb += lblks_dest[ip];
	}
	transp->nproc = nproc;
	transp->iproc = iproc;
	transp->comm = comm;
	transp->fsize = fsize;
	return 0;
}

/*
 * Post the nonblocking receives and sends of a transposition prepared by
 * P_Transpose_init.  For every peer ip, two messages with tag 0 are posted:
 *   - a receive of rtype[ip] into arr_dest at offset rbeg[ip] (in elements);
 *   - a send of stype[ip] from arr_source at offset sbeg[ip].
 * When the block this rank exchanges with itself already sits at the same
 * address in both arrays, nothing is posted.  Both request handles are then
 * set to MPI_REQUEST_NULL.
 *
 * The request arrays are allocated here and stored in transp->sreq/rreq.
 * Complete them with P_Transpose_end.
 * NOTE(review): calling this twice without P_Transpose_free in between
 * overwrites, and leaks, the previous request arrays.
 *
 * Returns 0 on success (also when there is nothing to do).  Returns 16 if the
 * request arrays cannot be allocated; nothing is posted in that case, and
 * transp->sreq / transp->rreq are NULL.
 */
int P_Transpose_start ( arr_source, arr_dest, transp )
	void *arr_source, *arr_dest;
	Transposition *transp;
{
	char *arr_source_ch = (char *) arr_source;
	char *arr_dest_ch = (char *) arr_dest;
	int nproc = transp->nproc;
	int iproc = transp->iproc;
	MPI_Aint fsize;
	int *sbeg, *rbeg;
	MPI_Datatype *stype, *rtype;
	MPI_Comm comm;
	MPI_Request *sreq, *rreq;
	int ip;
	char *src, *dest;

	if ( nproc == 0 ) { return 0; }
	if ( iproc == MPI_UNDEFINED ) { return 0; }

	/* Read the remaining fields only once the descriptor is known to be set up. */
	fsize = transp->fsize;
	sbeg = transp->sbeg;
	rbeg = transp->rbeg;
	stype = transp->stype;
	rtype = transp->rtype;
	comm = transp->comm;
/*
 * Allocate memory
 */
	sreq = malloc ( sizeof *sreq * nproc );
	rreq = malloc ( sizeof *rreq * nproc );
	if ( sreq == NULL || rreq == NULL ) {
		free ( sreq );
		free ( rreq );
		transp->sreq = NULL;
		transp->rreq = NULL;
		return 16;
	}
	transp->sreq = sreq;
	transp->rreq = rreq;
/*
 * Start the communication
 */
	for ( ip = 0; ip < nproc; ip++ ) {
		dest = arr_dest_ch + rbeg[ip]*fsize;
		src = arr_source_ch + sbeg[ip]*fsize;
		if ( dest == src && iproc == ip ) {
			/* Local block is already in place: no copy needed. */
			rreq[ip] = sreq[ip] = MPI_REQUEST_NULL;
		} else {
			MPI_Irecv ( dest, 1, rtype[ip], ip, 0, comm, rreq+ip );
			MPI_Isend ( src, 1, stype[ip], ip, 0, comm, sreq+ip );
		}
	}
	return 0;
}

/*
 * Complete a transposition started by P_Transpose_start.  Peer by peer, this
 * waits first for the receive request and then for the send request.
 * Requests that P_Transpose_start set to MPI_REQUEST_NULL complete
 * immediately.
 *
 * Returns 0.
 */
int P_Transpose_end ( transp )
	Transposition *transp;
{
	MPI_Request *recv_req, *send_req, *recv_stop;
	MPI_Status status;

	if ( transp->nproc == 0 || transp->iproc == MPI_UNDEFINED ) {
		return 0;
	}

	send_req = transp->sreq;
	recv_stop = transp->rreq + transp->nproc;
	for ( recv_req = transp->rreq; recv_req < recv_stop;
			recv_req++, send_req++ ) {
		MPI_Wait ( recv_req, &status );
		MPI_Wait ( send_req, &status );
	}
	return 0;
}

/*
 * Release everything owned by a transposition descriptor: the committed send
 * and receive datatypes, the offset arrays and the request arrays.  Every
 * pointer field is reset to NULL afterwards.  A repeated call therefore frees
 * nothing twice and touches no freed handle array.
 * NOTE(review): pending requests are not completed here; call
 * P_Transpose_end first if P_Transpose_start was used.
 *
 * Returns 0.
 */
int P_Transpose_free ( transp )
	Transposition *transp;
{
	int nproc = transp->nproc;
	int ip;

	if ( nproc == 0 ) { return 0; }
	if ( transp->iproc == MPI_UNDEFINED ) { return 0; }

	/* After an earlier call the handle arrays are NULL; skip them then. */
	for ( ip = 0; ip < nproc; ip++ ) {
		if ( transp->rtype != NULL ) { MPI_Type_free ( transp->rtype+ip ); }
		if ( transp->stype != NULL ) { MPI_Type_free ( transp->stype+ip ); }
	}
	free ( transp->stype );
	transp->stype = NULL;
	free ( transp->rtype );
	transp->rtype = NULL;
	free ( transp->sbeg );
	transp->sbeg = NULL;
	free ( transp->rbeg );
	transp->rbeg = NULL;
	free ( transp->sreq );
	transp->sreq = NULL;
	free ( transp->rreq );
	transp->rreq = NULL;
	return 0;
}

/*
 * One-shot transposition.  It runs P_Transpose_init, P_Transpose_start,
 * P_Transpose_end and P_Transpose_free in sequence on a local descriptor.
 * Parameters are those of P_Transpose_init, plus the source and destination
 * arrays passed to P_Transpose_start.
 *
 * Returns 0 on success.  Otherwise it returns the nonzero code reported by
 * P_Transpose_init or P_Transpose_start.
 */
int P_Transpose ( ndims, arr_source, dim_source, lblks_source, arr_dest, 
		dim_dest, lblks_dest, stride, blklen, overlap, datatype, comm, 
		period )
	void *arr_source, *arr_dest;
	int ndims, dim_source, *lblks_source, dim_dest, *lblks_dest, *stride; 
	int *blklen, *overlap;
	MPI_Datatype datatype;
	MPI_Comm comm;
	int period;
{
	/*
	 * Zero-initialised: if init returns 0 without filling the descriptor,
	 * start/end/free then see nproc == 0 instead of garbage.
	 */
	Transposition transp = { 0 };
	int ierr;

	/*
	 * Assign first and compare separately.  The former
	 * `ierr = f() != 0` parsed as `ierr = (f() != 0)`, which lost the
	 * actual error code.
	 */
	ierr = P_Transpose_init ( ndims, dim_source, lblks_source, dim_dest,
		lblks_dest, stride, blklen, overlap, datatype, comm, period,
		&transp );
	if ( ierr != 0 ) {
		return ierr;
	}
	ierr = P_Transpose_start ( arr_source, arr_dest, &transp );
	if ( ierr == 0 ) {
		P_Transpose_end ( &transp );
	}
	P_Transpose_free ( &transp );
	return ierr;
}