import os
import re
import time
import calendar
import operator

import numpy as np
import netCDF4 as nc



# Regular-expression fragment matching a signed decimal number with an
# optional fractional part and exponent (e.g. "-9.99e+08").  It is interpolated
# into larger patterns with "%s".  Written as a raw string so that "\." reaches
# the regex engine verbatim instead of being an invalid string escape.
NUMBER = r'[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?'


class CTLReader(object):
    """Reader for a GrADS-style control (``.ctl``) file and its binary data.

    Constructing an instance reads the descriptor text, loads the binary
    file named by its ``DSET`` entry, parses the ``XDEF``/``YDEF``/``ZDEF``/
    ``TDEF`` axes and finally the variable entries listed between the
    ``VARS`` and ``ENDVARS`` lines.

    Attributes:
        filename: path of the control file.
        ctl: decoded text of the control file.
        dimensions: dict mapping axis name to its length.
        variables: dict mapping name to ``Variable`` (axes and data fields).
        undef: missing-value marker taken from the ``UNDEF`` entry.
        yrev: True when the ``OPTIONS`` line contains ``YREV``.
        data: flat masked array with the content of the data file.
    """

    # (descriptor keyword, axis/variable name, units or None), in the order
    # in which the axes are registered.
    _AXES = (
        ('XDEF', 'longitude', 'degrees_east'),
        ('YDEF', 'latitude', 'degrees_north'),
        ('ZDEF', 'levels', 'm'),
        ('TDEF', 'time', None),
    )

    def __init__(self, filename):
        """Parse the control file ``filename`` and load the data it names."""
        self.dimensions = {}
        self.variables = {}

        self.filename = filename

        with open(filename, 'rb') as fp:
            self.ctl = fp.read().decode("utf-8")

        self._read_data()
        self._read_dimensions()
        self._read_vars()

    def _read_data(self):
        """Load the binary file referenced by ``DSET`` into ``self.data``.

        Also sets ``self.undef`` and ``self.yrev``.  A leading ``^`` in the
        ``DSET`` value makes the path relative to the control file's
        directory.  Values equal to ``undef`` are masked.
        """
        # float() instead of eval(): the text comes from a file and only has
        # to be a number; float() also accepts forms such as "0099".
        self.undef = float(
            re.search("UNDEF *(%s)" % NUMBER, self.ctl).group(1))
        self.yrev = bool(re.search("OPTIONS.*YREV", self.ctl))
        big_endian = bool(re.search("OPTIONS.*BIG_ENDIAN", self.ctl))
        dset = re.search("DSET *(.*)", self.ctl).group(1)
        if dset.startswith('^'):
            dset = os.path.join(os.path.dirname(self.filename), dset[1:])
        data = np.fromfile(dset.strip(), 'f')
        if big_endian:
            data = data.byteswap()
        self.data = np.ma.masked_values(data, self.undef)

    def _read_dimensions(self):
        """Register every axis whose keyword occurs in the control file.

        For each axis a coordinate ``Variable`` is stored in
        ``self.variables`` and its length in ``self.dimensions``.
        """
        for keyword, name, units in self._AXES:
            if keyword not in self.ctl:
                continue
            axis = Variable(name, self._parse_dimension(keyword))
            # A real 1-tuple: ('longitude') would just be the string itself.
            axis.dimensions = (name,)
            if units is not None:
                axis.units = units
            self.variables[name] = axis
            self.dimensions[name] = len(axis)

    def _read_vars(self):
        """Parse the entries between ``VARS`` and ``ENDVARS``.

        Each entry is expected to look like ``name levels units description``.
        The variables are cut consecutively out of ``self.data``; a level
        count greater than zero yields a (time, levels, latitude, longitude)
        field, otherwise (time, latitude, longitude).  The units code and
        the description are parsed but currently not stored.

        Raises:
            ValueError: if a non-blank, non-comment line inside the section
                does not match the expected layout.
        """
        entry = re.compile(r'(\w+)\s+(\d+)\s+(\d+)\s+(.*).*')
        reading = False
        offset = 0
        for line in self.ctl.split('\n'):
            if line.startswith('ENDVARS'):
                reading = False
            if reading:
                # Blank lines and '*' comment lines carry no variable.
                if not line.strip() or line.startswith('*'):
                    continue
                m = entry.match(line)
                if m is None:
                    raise ValueError(
                        "Cannot parse variable definition: %r" % line)
                name = m.group(1)
                var = self.variables[name] = Variable(name)
                plane = (self.dimensions['latitude']
                         * self.dimensions['longitude'])
                ntimes = self.dimensions['time']
                # NOTE(review): with more than one time step every 2-D slab
                # is counted as plane + 2 values ("header bytes" according
                # to the original comment), yet the reshape below strips
                # nothing, so the sizes only agree when the slice is cut
                # short by the end of self.data.  Confirm the intended
                # record layout before relying on multi-variable files.
                slab = plane + 2 if ntimes > 1 else plane
                if int(m.group(2)) > 0:
                    var.dimensions = (
                        'time', 'levels', 'latitude', 'longitude')
                    size = ntimes * self.dimensions['levels'] * slab
                else:
                    var.dimensions = ('time', 'latitude', 'longitude')
                    size = ntimes * slab

                var.shape = tuple(
                    self.dimensions[dim] for dim in var.dimensions)
                var.data = (self.data[offset:offset + size]
                            .reshape(-1, plane)[:, :]
                            .reshape(var.shape))
                if self.yrev:
                    # Flip the latitude axis (second to last).
                    var.data = var.data[..., ::-1, :]
                offset += size
            elif line.startswith('VAR'):
                # Section header.  Checked only while not reading, so that a
                # variable whose name starts with "VAR" cannot reset the
                # running offset.
                offset = 0
                reading = True

    def _parse_dimension(self, dim):
        """Return the coordinate values of the axis declared by ``dim``.

        Three declaration forms are tried in this order:

        * ``<dim> n LINEAR start increment`` with numeric values,
        * ``<dim> n LEVELS v1 v2 ...``,
        * ``<dim> n LINEAR date increment`` where the increment is an
          integer followed by a two-letter unit (``mn``, ``hr``, ``dy`` or
          ``mo``).

        Args:
            dim: descriptor keyword, e.g. ``'XDEF'``.

        Returns:
            A 1-D numpy array (float values, or ``datetime64`` for dates).

        Raises:
            NotImplementedError: for the ``yr`` time unit.
            Exception: for any other unsupported time unit.
            ValueError: if no supported declaration of ``dim`` is found.
        """
        m = re.search(
            r"%s\s+(\d+)\s+LINEAR\s+(%s)\s+(%s)" % (dim, NUMBER, NUMBER),
            self.ctl)
        if m:
            length = int(m.group(1))
            start = float(m.group(2))
            increment = float(m.group(3))
            # Index-based construction always yields exactly `length`
            # points; arange(start, stop, step) with a float step may add
            # one more because of rounding in `stop`.
            return start + increment * np.arange(length)

        m = re.search(
            r"%s\s+\d+\s+LEVELS((\s+%s)+)" % (dim, NUMBER), self.ctl)
        if m:
            return np.array(m.group(1).split(), dtype=float)

        m = re.search(
            r"%s\s+(\d+)\s+LINEAR\s+([:\w]+)\s+(\d+)([a-zA-Z]{2})" % dim,
            self.ctl)
        if m:
            length = int(m.group(1))
            start = parse_date(m.group(2))
            value = m.group(3)
            unit = m.group(4).lower()

            units = ['mn', 'hr', 'dy', 'mo']

            if unit in ['mn', 'hr', 'dy']:
                increment = parse_delta(value, unit)
                return np.array(
                    [start + i * increment for i in range(length)])

            elif unit == 'mo':
                nmonth_int = int(value)
                # Always 0.0 at present: the pattern above only captures
                # whole numbers.
                month_frac = float(value) - nmonth_int
                month_step = np.timedelta64(nmonth_int, 'M')
                time_str = get_time_str(start)

                datetimes = [start]
                current = start
                for _ in range(1, length):
                    year, month, day = get_date(current)

                    # Advance from the first day of the current month,
                    # then add the day-of-month offset back as days.
                    stepped = get_yearmonth64(year, month) + month_step
                    year1, month1, _day1 = get_date(stepped)

                    days_left = (day - 1) + int(
                        month_frac * get_month_ndays(year1, month1))
                    candidate = stepped + np.timedelta64(days_left, 'D')

                    year2, month2, day2 = get_date(candidate)
                    # When adding the days skipped an entire month, move
                    # the month back by one (day and year are kept).
                    if (month2 - month1) % 12 == 2:
                        month2 = month1 % 12 + 1

                    date_only = get_date64(year2, month2, day2)
                    current = np.datetime64(
                        get_datetime64_str(date_only) + 'T' + time_str)
                    datetimes.append(current)
                return np.array(datetimes)
            else:
                str_err = "Unsupported time units '"+ unit + "'. Supported units are: " +  str(units)

                if unit == 'yr':
                    raise NotImplementedError(str_err + '. Need to implement year time step.')
                else:
                    raise Exception(str_err)

        raise ValueError("Cannot parse the %s definition" % dim)


class Variable(object):
    """Named data container with optional free-form metadata.

    Attributes:
        name: variable name.
        data: the underlying sequence/array (may be None).
        units: units string, or None when unknown.
        attributes: dict of extra metadata; its keys can also be read as
            instance attributes (``var.long_name``).
    """

    def __init__(self, name, data=None):
        self.name = name
        self.data = data
        self.units = None
        # Backing store consulted by __getattr__.
        self.attributes = {}

    def __getitem__(self, index):
        """Index straight into the underlying data."""
        return self.data[index]

    def __getattr__(self, key):
        """Fall back to ``self.attributes`` for names not set on the object.

        Raises:
            AttributeError: if ``key`` is not a metadata key either.
        """
        # Only invoked when normal lookup fails.  Go through __dict__ so a
        # missing 'attributes' entry cannot re-enter this method and recurse
        # without end (e.g. for instances created without __init__).
        attributes = self.__dict__.get('attributes', {})
        if key in attributes:
            return attributes[key]
        raise AttributeError(
            "%r object has no attribute %r" % (type(self).__name__, key))

    def __len__(self):
        """Length of the underlying data."""
        return len(self.data)


def get_year(datetime64):
    """Return the year of a numpy ``datetime64`` value as an int."""
    # Use the argument itself (the previous body read an undefined name) and
    # drop any time part before splitting the ISO date.
    date_part = np.datetime_as_string(datetime64).split('T')[0]
    return int(date_part.split('-')[0])

def get_month(datetime64):
    """Return the month (1-12) of a numpy ``datetime64`` value as an int."""
    # Use the argument itself (the previous body read an undefined name) and
    # drop any time part before splitting the ISO date.
    date_part = np.datetime_as_string(datetime64).split('T')[0]
    return int(date_part.split('-')[1])

def get_day(datetime64):
    """Return the day of month of a numpy ``datetime64`` value as an int.

    Values without a day component (month resolution) yield 1.
    """
    # Use the argument itself (the previous body read an undefined name).
    # The time part must go first: "2001-02-03T04:05".split('-')[2] would be
    # "03T04:05", which int() cannot convert.
    date_part = np.datetime_as_string(datetime64).split('T')[0]
    fields = date_part.split('-')
    if len(fields) >= 3:
        return int(fields[2])
    return 1

def get_date(datetime64):
    """Split a numpy ``datetime64`` value into ``(year, month, day)`` ints.

    Any time part is ignored.  When the value carries no day component the
    day is reported as 1.
    """
    date_part = np.datetime_as_string(datetime64).split('T')[0]
    fields = date_part.split('-')
    year, month = int(fields[0]), int(fields[1])
    day = int(fields[2]) if len(fields) > 2 else 1
    return year, month, day

def get_time_str(datetime64):
    """Return the time-of-day text (the part after ``T``) of a datetime64.

    Raises IndexError when the value has no time component.
    """
    pieces = np.datetime_as_string(datetime64).split('T')
    return pieces[1]

def get_date_str(year, month, day):
    """Format a date as zero-padded ``YYYY-MM-DD`` text."""
    padded = (
        str(year).zfill(4),
        str(month).zfill(2),
        str(day).zfill(2),
    )
    return '-'.join(padded)


def get_datetime64(year, month=1, day=1, hour=0, minute=0):
    """Build a ``numpy.datetime64`` from calendar fields.

    The fields are zero-padded into ``YYYY-MM-DDTHH:MM`` text, which numpy
    then parses.
    """
    date_part = '-'.join(
        [str(year).zfill(4), str(month).zfill(2), str(day).zfill(2)])
    time_part = ':'.join([str(hour).zfill(2), str(minute).zfill(2)])
    return np.datetime64(date_part + 'T' + time_part)

def get_date64(year, month, day):
    """Build a ``numpy.datetime64`` from a year, month and day."""
    iso_date = '-'.join(
        [str(year).zfill(4), str(month).zfill(2), str(day).zfill(2)])
    return np.datetime64(iso_date)

def get_yearmonth64(year, month):
    """Build a ``numpy.datetime64`` from a year and a month."""
    iso_month = '-'.join([str(year).zfill(4), str(month).zfill(2)])
    return np.datetime64(iso_month)

def get_datetime64_str(datetime64):
    """Return the text form numpy gives for ``datetime64``."""
    text = np.datetime_as_string(datetime64)
    return text
  
def get_month_ndays(year, month):
    """Return the number of days in the given month of the given year."""
    # monthrange() returns (weekday of the first day, number of days).
    return calendar.monthrange(year, month)[1]



def parse_date(s):
    """Parse a GrADS-style absolute date/time into ``numpy.datetime64``.

    The accepted layout is ``hh:mmZddmmmyyyy``.  The time part (``hh`` or
    ``hh:mm``) must be terminated by ``Z`` and defaults to 00:00; the day
    defaults to 1; month and year are required.  Matching is
    case-insensitive.  Only the start of ``s`` has to match.

    Args:
        s: date text such as ``'00:00Z01jan2000'``, ``'15jan2000'`` or
            ``'jan98'``.

    Returns:
        A minute-resolution ``numpy.datetime64``.

    Raises:
        ValueError: if ``s`` does not match the layout or the month
            abbreviation is unknown.
    """
    # The time is only recognised together with its "Z" terminator;
    # otherwise the leading digits of a bare "15jan2000" would be taken for
    # the hour instead of the day.
    DATE = re.compile(r"""
        (?:
            (?P<hour>\d{1,2})?       # hh, default 00
            (?::(?P<minute>\d\d))?   # mm, default 00
            Z                        # separates the time from the date
        )?
        (?P<day>\d{1,2})?            # dd, default 1
        (?P<month>[a-z]{3})          # 3 char month
        (?P<year>\d\d(?:\d\d)?)      # yyyy, or yy meaning 1950..2049
    """, re.VERBOSE | re.IGNORECASE)
    m = DATE.match(s)
    if m is None:
        raise ValueError("Cannot parse date %r" % (s,))
    d = m.groupdict()

    hour = 0 if d['hour'] is None else int(d['hour'])
    minute = 0 if d['minute'] is None else int(d['minute'])
    day = 1 if d['day'] is None else int(d['day'])
    month = time.strptime(d['month'], '%b')[1]

    year = int(d['year'])
    if len(d['year']) != 4:
        # Two-digit years cover 1950..2049: 50-99 -> 19yy, 00-49 -> 20yy.
        year += 1900 if year >= 50 else 2000

    if day <= 0:
        day = 1
    return np.datetime64(
        '%04d-%02d-%02dT%02d:%02d' % (year, month, day, hour, minute))


def parse_delta(value, unit):
    """Convert an increment such as ``(6, 'hr')`` into ``numpy.timedelta64``.

    ``unit`` is matched case-insensitively against ``mn`` (minutes), ``hr``
    (hours) and ``dy`` (days).  Any other unit yields None.
    """
    count = int(value)
    numpy_unit = {'mn': 'm', 'hr': 'h', 'dy': 'D'}.get(unit.lower())
    if numpy_unit is None:
        return None
    return np.timedelta64(count, numpy_unit)


def no_shift(varname):
    """Return a zero offset for every variable name."""
    offset = 0.0
    return offset

def ctl_to_netcdf(ctl_file, out_file, shift=no_shift, diskless=False):
    """Convert the data described by a control file into a netCDF dataset.

    Args:
        ctl_file: path of the control file handed to ``CTLReader``.
        out_file: path of the netCDF file to create.
        shift: callable taking a variable name and returning an offset that
            is added to that variable's values; ``None`` means no offset.
            It is not applied to the ``time`` variable.
        diskless: passed on to ``netCDF4.Dataset``.

    Returns:
        The ``netCDF4.Dataset``.  It has already been closed unless
        ``diskless`` is true.
    """
    f = CTLReader(ctl_file)

    ds = nc.Dataset(out_file, 'w', format='NETCDF4', diskless=diskless)
    try:
        for dim in f.dimensions:
            ds.createDimension(dim, f.dimensions[dim])

        for var in f.variables:
            source = f.variables[var]
            dims = source.dimensions
            if isinstance(dims, str):
                # A lone dimension name; pass it on as a proper tuple.
                dims = (dims,)
            nc_var = ds.createVariable(var, 'f', dims)

            if var == 'time':
                # Time is written as fractional days since its first value.
                nc_var.units = "days since " + str(source[0])
                nc_var[:] = (source[:] - source[0]) / np.timedelta64(1, 'D')
            else:
                if source.units is not None:
                    nc_var.units = source.units

                offset = 0.0 if shift is None else shift(var)
                nc_var[:] = source[:] + offset
    except Exception:
        # Do not leave a half-written dataset open.
        ds.close()
        raise

    if not diskless:
        ds.close()
    return ds

def ctl_read(ctl_file, shift=no_shift):
    """Load a control file through ``ctl_to_netcdf`` with ``diskless=True``.

    The output file name handed to ``ctl_to_netcdf`` is ``"1.nc"``; its
    return value is passed back unchanged.
    """
    dataset = ctl_to_netcdf(ctl_file, "1.nc", shift, diskless=True)
    return dataset


if __name__ == '__main__':
    import sys

    # Command line use: convert CTL_FILE into the netCDF file OUT_FILE.
    if len(sys.argv) < 3:
        sys.exit("usage: %s CTL_FILE OUT_FILE" % sys.argv[0])

    f = ctl_to_netcdf(sys.argv[1], sys.argv[2])