#include "parlib.h"

/*
 *  P_BExchange_init - prepare a descriptor for exchanging the boundary
 *  layers of a block-distributed array with the neighbouring ranks.
 *
 *  The routine validates its arguments, determines the neighbours of the
 *  calling rank in `comm`, builds one committed MPI datatype per direction
 *  that describes the boundary slab, and records everything in *bexchange
 *  for use by P_BExchange_start / P_BExchange_end / P_BExchange_free.
 *
 *  Parameters:
 *     ndims     - number of array dimensions
 *     stride    - stride[i] is the factor by which the element offset grows
 *                 from dimension i to dimension i+1 (dimension 0 has
 *                 element stride 1)
 *     blklen    - blklen[i] is the number of elements taken along
 *                 dimension i
 *     bdim      - 1-based index of the dimension whose boundaries are
 *                 exchanged
 *     overlap   - overlap[0]: width of the slab sent from the start of the
 *                 block to rank-1 and received just past the end of the
 *                 block from rank+1;
 *                 overlap[1]: width of the slab sent from the end of the
 *                 block to rank+1 and received just before the start of
 *                 the block from rank-1
 *     datatype  - MPI datatype of one array element
 *     comm      - communicator whose ranks form the 1-D chain of neighbours
 *     period    - nonzero if rank 0 and rank nproc-1 are neighbours
 *     bexchange - descriptor to fill in
 *
 *  Error codes:
 *     0 - success
 *     1 - nonpositive number of dimensions
 *     2 - wrong communicated dimension
 *     3 - negative boundary width
 *     4 - nonpositive dimension (an entry of stride[] is <= 0)
 *     5 - boundary width exceeds the array block length
 */
int P_BExchange_init ( int ndims, int *stride, int *blklen, int bdim,
      int overlap[2], MPI_Datatype datatype, MPI_Comm comm, int period,
      BExchange *bexchange )
{
   int            nproc, iproc, direct, idim;
   int            count, strd, left, right;
   MPI_Aint       fsize;
   MPI_Datatype   oldtype, newtype;

/*
 * Check input parameters
 */
   if ( ndims < 1 ) {return 1;}
   if ( bdim < 1 || bdim > ndims ) {return 2;}
/*
 * Put the descriptor into an inert state first: the start/end/free
 * routines only act on directions with overlap > 0, so every early
 * "success" return below leaves a descriptor that is safe to pass on.
 */
   for ( direct = 0; direct < 2; direct++ ) {
      bexchange->overlap[direct] = 0;
      bexchange->send[direct] = 0;
      bexchange->recv[direct] = 0;
   }
   if ( overlap[0] == 0 && overlap[1] == 0 ) {return 0;} /* nothing to exchange */
   for ( idim = 0; idim < ndims; idim++ ) {
      if ( stride[idim] <= 0 ) {return 4;}
   }
   for ( direct = 0; direct < 2; direct++ ) {
      if ( overlap[direct] < 0 ) {return 3;}
      if ( overlap[direct] > blklen[bdim-1] ) {return 5;}
   }
/*
 * Define the number of processors in the group and the rank
 */
   MPI_Comm_size ( comm, &nproc );
   if ( nproc == 0 ) {return 0;} /* success */
   MPI_Comm_rank ( comm, &iproc );
   if ( iproc == MPI_UNDEFINED ) {return 0;} /* the process does not belong to the group */
/*
 * Neighbours in the chain, wrapping around at both ends.  Direction 0
 * sends to the left neighbour and receives from the right one;
 * direction 1 is the mirror image.
 */
   left  = ( iproc == 0 ? nproc-1 : iproc-1 );
   right = ( iproc == nproc-1 ? 0 : iproc+1 );
   bexchange->sendproc[0] = left;
   bexchange->recvproc[0] = right;
   bexchange->sendproc[1] = right;
   bexchange->recvproc[1] = left;
/*
 * The wrap-around transfers at the two ends of the chain take place
 * only for a periodic exchange.
 */
   bexchange->send[0] = iproc > 0 || period;
   bexchange->recv[0] = iproc < nproc-1 || period;
   bexchange->send[1] = bexchange->recv[0];
   bexchange->recv[1] = bexchange->send[0];
/*
 * NOTE(review): MPI_Type_extent and MPI_Type_hvector were removed from
 * the standard in MPI-3.0 (replacements: MPI_Type_get_extent and
 * MPI_Type_create_hvector).  They are kept here unchanged; confirm
 * against the MPI library the project is built with.
 */
   MPI_Type_extent ( datatype, &fsize );
/*
 * Define data types for the boundaries: nest one hvector per dimension,
 * innermost dimension first.  Along bdim the slab is overlap[direct]
 * elements thick, along every other dimension it spans blklen[idim].
 */
   for ( direct = 0; direct < 2; direct++ ) {
      if ( overlap[direct] <= 0 ) {continue;}
      oldtype = datatype;
      strd = 1;
      for ( idim = 0; idim < ndims; idim++ ) {
         count = ( idim+1 == bdim ? overlap[direct] : blklen[idim] );
         MPI_Type_hvector ( count, 1, strd * fsize, oldtype, &newtype );
         /* Intermediate types are ours to release; the caller's
            element type (used at idim == 0) is not. */
         if ( idim > 0 ) {
            MPI_Type_free ( &oldtype );
         }
         oldtype = newtype;
         strd = strd * stride[idim];
      }
      MPI_Type_commit ( &newtype );
      /* Stored only when actually built, so no indeterminate handle
         is ever copied into the descriptor. */
      bexchange->btype[direct] = newtype;
   }
/*
 * Determine the beginning of the boundaries, as element offsets from
 * the start of the block.  strd is the element stride of dimension bdim.
 */
   strd = 1;
   for ( idim = 0; idim < bdim - 1; idim++ ) {
      strd = strd * stride[idim];
   }
   bexchange->sbind[0] = 0;
   bexchange->rbind[0] = blklen[bdim-1]*strd;
   bexchange->sbind[1] = (blklen[bdim-1]-overlap[1])*strd;
   bexchange->rbind[1] = -overlap[1]*strd;

/*
 * Everything is in place: activate the directions.
 */
   for ( direct = 0; direct < 2; direct++ ) {
      bexchange->overlap[direct] = overlap[direct];
   }
   bexchange->comm = comm;
   bexchange->fsize = fsize;
   return 0;
}

/*
 *  P_BExchange_start - post the nonblocking sends and receives of a
 *  boundary exchange described by *bexchange.
 *
 *  Parameters:
 *     a         - address from which the send/receive offsets stored in
 *                 the descriptor (in elements of size bexchange->fsize)
 *                 are measured.  The receive offset of a direction may be
 *                 negative, so the buffer has to extend before `a`
 *                 accordingly.
 *     bexchange - descriptor filled in by P_BExchange_init; the request
 *                 handles of the operations posted here are stored in it
 *                 for P_BExchange_end.
 *
 *  Only directions with a positive overlap are handled, and within a
 *  direction only the transfers enabled by the send/recv flags.
 *  Always returns 0.
 */
int P_BExchange_start ( void *a, BExchange *bexchange )
{
   int            direct;
   MPI_Aint       fsize = bexchange->fsize;
   MPI_Datatype   btype;   /* a datatype handle must not be held in an int */
   MPI_Request    req;
   char           *ach = (char *) a;

   for ( direct = 0; direct < 2; direct++ ) {
      if ( bexchange->overlap[direct] <= 0 ) {continue;}
      btype = bexchange->btype[direct];
      if ( bexchange->send[direct] ) {
         MPI_Isend ( ach + bexchange->sbind[direct]*fsize, 1, btype,
               bexchange->sendproc[direct], 0, bexchange->comm, &req );
         /* Record a request only when one was really posted. */
         bexchange->sreq[direct] = req;
      }
      if ( bexchange->recv[direct] ) {
         MPI_Irecv ( ach + bexchange->rbind[direct]*fsize, 1, btype,
               bexchange->recvproc[direct], 0, bexchange->comm, &req );
         bexchange->rreq[direct] = req;
      }
   }
   return 0;
}

/*
 *  P_BExchange_end - wait for completion of the transfers posted by
 *  P_BExchange_start.
 *
 *  For every direction with a positive overlap the send request and the
 *  receive request are waited for, provided the corresponding send/recv
 *  flag of the descriptor is set.  The handle as updated by MPI_Wait is
 *  written back into the descriptor, so the descriptor never keeps a
 *  handle of an already completed request.
 *
 *  Always returns 0.
 */
int P_BExchange_end ( BExchange *bexchange )
{
   MPI_Status     status;
   MPI_Request    req;
   int            direct;

   for ( direct = 0; direct < 2; direct++ ) {
      if ( bexchange->overlap[direct] <= 0 ) {continue;}
      if ( bexchange->send[direct] ) {
         req = bexchange->sreq[direct];
         MPI_Wait ( &req, &status );
         bexchange->sreq[direct] = req;
      }
      if ( bexchange->recv[direct] ) {
         req = bexchange->rreq[direct];
         MPI_Wait ( &req, &status );
         bexchange->rreq[direct] = req;
      }
   }
   return 0;
}

/*
 *  P_BExchange_free - release the boundary datatypes held by a descriptor.
 *
 *  A datatype exists only for directions with a positive overlap.  After
 *  freeing it, the handle as updated by MPI_Type_free is written back and
 *  the overlap of that direction is reset to 0, so that the start / end /
 *  free routines skip the direction instead of using a released handle.
 *
 *  Always returns 0.
 */
int P_BExchange_free ( BExchange *bexchange )
{
   int            direct;
   MPI_Datatype   btype;

   for ( direct = 0; direct < 2; direct++ ) {
      if ( bexchange->overlap[direct] <= 0 ) {continue;}
      btype = bexchange->btype[direct];
      MPI_Type_free ( &btype );
      bexchange->btype[direct] = btype;
      bexchange->overlap[direct] = 0;   /* direction is no longer usable */
   }
   return 0;
}

/*
 *  P_BExchange - one-shot boundary exchange: initialise a descriptor,
 *  start the transfers, wait for them and release the descriptor.
 *
 *  Parameters:
 *     a         - address passed on to P_BExchange_start
 *     ndims, stride, blklen, bdim, overlap, datatype, comm, period
 *               - passed on unchanged to P_BExchange_init
 *
 *  Returns 0 on success, otherwise the error code reported by
 *  P_BExchange_init.
 */
int P_BExchange ( void *a, int ndims, int *stride, int *blklen, int bdim,
      int overlap[2], MPI_Datatype datatype, MPI_Comm comm, int period )
{
   BExchange bexchange;
   int ierr;

   /* P_BExchange_init may report success without filling in the
      descriptor (e.g. when both overlaps are zero).  The routines below
      act only on directions with overlap > 0, so start from "nothing to
      exchange". */
   bexchange.overlap[0] = 0;
   bexchange.overlap[1] = 0;

   /* Keep the assignment separate from the test: written as
      `ierr = f(...) != 0` the comparison binds tighter and ierr would
      always be 0 or 1 instead of the real error code. */
   ierr = P_BExchange_init ( ndims, stride, blklen, bdim, overlap,
         datatype, comm, period, &bexchange );
   if ( ierr != 0 ) {return ierr;}

   P_BExchange_start ( a, &bexchange );
   P_BExchange_end ( &bexchange );
   P_BExchange_free ( &bexchange );
   return 0;
}