#pragma once

// [bin-stamp.h]: binStamp data structure for file I/O
//
// -------------------------------------------------------------------------------------------- //

#include "grid-id.h"
#include <stdio.h>

namespace nse
{
	// binStamp< T >: growable array of T values with binary FILE* I/O
	// (an int element count followed by the raw elements) and an MPI
	// broadcast. Elements are written/read as raw bytes, so T is presumably
	// expected to be trivially copyable -- NOTE(review): confirm with callers.
	template< typename T >
	class binStamp {

	public:
		binStamp();
		~binStamp();

		// Deep copy. Each binStamp owns its `value` buffer (the destructor
		// deletes it), so a member-wise copy would free the same buffer twice;
		// a copy gets its own buffer sized to the source's element count.
		binStamp(const binStamp& src) : mem_size(0), size(0), value(NULL)
		{
			if (src.size > 0) {
				value = new T[src.size];
				mem_size = src.size;
				for (int i = 0; i < src.size; i++)
					value[i] = src.value[i];
				size = src.size;
			}
		}

		// Copy-and-swap: `tmp` is built first (if that throws, *this is
		// untouched), then takes over the old buffer and releases it.
		binStamp& operator=(const binStamp& src)
		{
			if (this != &src) {
				binStamp tmp(src);

				T *tmp_value = value; value = tmp.value; tmp.value = tmp_value;
				int tmp_mem = mem_size; mem_size = tmp.mem_size; tmp.mem_size = tmp_mem;
				int tmp_size = size; size = tmp.size; tmp.size = tmp_size;
			}
			return *this;
		}

		// Copies element idx into *out; returns false (and leaves *out
		// untouched) if idx is outside [0, size).
		bool get(const int idx, T* out) const;
		// Adds element idx to *out (*out += value[idx]); returns false (and
		// leaves *out untouched) if idx is outside [0, size).
		bool update(const int idx, T* out) const;

		// Appends one element, growing storage when it is full.
		void push(const T in);

		// Writes the element count followed by the elements; returns the
		// number of items written as reported by ::fwrite.
		int fwrite(FILE *ptr) const;

		// Reads data in the fwrite() layout; returns the number of items read.
		int fread(FILE* ptr);
#ifdef _USE_DEPRECATED_WST_FORMAT
		// Legacy layout with no stored count: reads _size elements.
		int fread(FILE* ptr, const int _size);
#endif

		// Broadcasts the element count and elements from rank `host` to all
		// ranks of `comm`.
		void mpi_broadcast(const int host, const MPI_Comm comm);

		// read-only: //
		int mem_size, size;	// allocated capacity / number of stored elements
		T *value;		// element buffer, owned by this object

	private:
		static const int c_alloc_step = 16;	// minimum capacity increment on growth

		// Ensures capacity >= msize; current contents are discarded on growth.
		void allocate(const int msize);
		// Ensures capacity >= msize; the first `size` elements are preserved.
		void resize(const int msize);
	};
}
// -------------------------------------------------------------------------------------------- //

// Implementation
// -------------------------------------------------------------------------------------------- //
template< typename T >
// Creates an empty stamp: no elements and no storage. `value` is explicitly
// NULL (rather than left indeterminate) so the pointer is never garbage, even
// before anything has been allocated.
nse::binStamp< T >::binStamp() : mem_size(0), size(0), value(NULL) {}
template< typename T >
nse::binStamp< T >::~binStamp()
{
	// Only a positive capacity means `value` points at a buffer from new[];
	// with mem_size == 0 the pointer must not be touched.
	const bool owns_buffer = (mem_size > 0);
	if (owns_buffer)
		delete[] value;

	size = 0;
	mem_size = 0;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
bool nse::binStamp< T >::get(const int idx, T* out) const
{
	// Copy element `idx` into *out. An index outside [0, size) leaves *out
	// untouched and reports failure.
	const bool in_range = (idx >= 0) && (idx < size);
	if (in_range)
		*out = value[idx];

	return in_range;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
bool nse::binStamp< T >::update(const int idx, T* out) const
{
	// Accumulate element `idx` into *out (*out += value[idx]). An index
	// outside [0, size) leaves *out untouched and reports failure.
	const bool valid = (idx >= 0) && (idx < size);
	if (!valid) return false;

	T& acc = *out;
	acc += value[idx];
	return true;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
void nse::binStamp< T >::push(const T in)
{
	// Append `in` at the end, growing storage (contents preserved) if needed.
	const int new_size = size + 1;
	resize(new_size);

	value[new_size - 1] = in;
	size = new_size;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
void nse::binStamp< T >::allocate(const int req_size)
{
	// Ensure capacity for at least req_size elements WITHOUT preserving the
	// current contents (callers overwrite the buffer right after). `size` is
	// not modified. Capacity grows by at least c_alloc_step.
	if (req_size <= mem_size) return;

	const int alloc_size = (req_size > mem_size + c_alloc_step) ?
		req_size : mem_size + c_alloc_step;

	// Allocate before releasing: if new[] throws, `value` and `mem_size`
	// still describe the old (valid) buffer, so the destructor cannot
	// delete a dangling pointer.
	T *fresh = new T[alloc_size];
	if (mem_size > 0) delete[] value;

	value = fresh;
	mem_size = alloc_size;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
void nse::binStamp< T >::resize(const int req_size)
{
	// Ensure capacity for at least req_size elements, preserving the first
	// `size` elements. `size` itself is not modified. Capacity grows by at
	// least c_alloc_step.
	if (req_size <= mem_size) return;

	const int alloc_size = (req_size > mem_size + c_alloc_step) ?
		req_size : mem_size + c_alloc_step;

	T *grown = new T[alloc_size];

	// Element-wise copy instead of memcpy: memcpy is not declared by this
	// header's own includes and is only defined for trivially copyable T.
	// For trivially copyable T, compilers lower this loop to a block copy.
	for (int i = 0; i < size; i++)
		grown[i] = value[i];

	if (mem_size > 0) delete[] value;

	value = grown;
	mem_size = alloc_size;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
int nse::binStamp< T >::fwrite(FILE *ptr) const
{
	// Stream layout: one int element count, then `size` raw T values.
	// Returns the total item count reported by ::fwrite
	// (1 + size when everything is written).
	int n_written = 0;
	n_written += ::fwrite(&size, sizeof(int), 1, ptr);

	if (size <= 0) return n_written;	// header only

	n_written += ::fwrite(value, sizeof(T), size, ptr);
	return n_written;
}
// -------------------------------------------------------------------------------------------- //

template< typename T >
int nse::binStamp< T >::fread(FILE* ptr)
{
	// Reads the layout produced by fwrite(): one int element count, then that
	// many raw T values. Returns the number of items read (1 + size when the
	// whole record is present).
	//
	// If the count cannot be read, the bytes of `size` may be indeterminate.
	// A negative count can only come from a corrupt stream, and storing it
	// would make push() write to value[negative]. In either case the stamp is
	// left empty (size == 0) and 0 is returned so callers see the failure.
	if (::fread(&size, sizeof(int), 1, ptr) != 1 || size < 0) {
		size = 0;
		return 0;
	}

	int status = 1;	// header read successfully
	allocate(size);	// existing contents are discarded; overwritten below

	if (size > 0)
		status += ::fread(value, sizeof(T), size, ptr);

	return status;
}
// -------------------------------------------------------------------------------------------- //

#ifdef _USE_DEPRECATED_WST_FORMAT
template< typename T >
int nse::binStamp< T >::fread(FILE* ptr, const int _size)
{
	// Legacy layout: the element count is not stored in the stream; the
	// caller supplies it as _size. A negative _size is treated as empty so
	// that `size` never goes negative (push() indexes value[size]).
	// Returns the number of elements read.
	size = (_size > 0) ? _size : 0;
	allocate(size);	// existing contents are discarded; overwritten below

	int status = 0;
	if (size > 0)
		status += ::fread(value, sizeof(T), size, ptr);

	return status;
}
// -------------------------------------------------------------------------------------------- //
#endif

template< typename T >
void nse::binStamp< T >::mpi_broadcast(
	const int host, const MPI_Comm comm)
{
	// Copies this stamp from rank `host` to every rank of `comm`. MPI_Bcast
	// is collective, so all ranks in `comm` must call this with the same host.
	int rank;
	MPI_Comm_rank(comm, &rank);

	// Element count first, so that receivers can size their buffers.
	MPI_Bcast(&size, 1, MPI_INT, host, comm);

	const bool is_receiver = (rank != host);
	if (is_receiver)
		allocate(size);	// old contents discarded; filled by the broadcast below

	if (size <= 0) return;
	MPI_Bcast(value, size, mpi_type< T >(), host, comm);
}
// -------------------------------------------------------------------------------------------- //