import netCDF4 as nc
import numpy as np
import argparse


def summary(vars, dims):
    """Print a two-line description of every variable in ``vars``.

    The first line shows the variable's name, its dimensions and its dtype;
    the second shows its ``long_name`` and ``standard_name`` attributes
    (``None`` is printed for an attribute the variable does not have).

    Args:
        vars: Iterable of variable objects exposing ``name``, ``dimensions``,
            ``dtype``, ``ncattrs()`` and ``getncattr()``.
        dims: Mapping from dimension name to dimension size.
    """
    for variable in vars:
        # Dimensions are listed in the reverse of the order stored on the variable.
        extents = []
        for dim_name in reversed(variable.dimensions):
            extents.append('%d(%s)' % (dims[dim_name], dim_name))
        shape_text = ' x '.join(extents)
        print('  name =', variable.name, ', dims =', shape_text, ', type =', variable.dtype)

        # Collect every attribute into a plain dict so absent ones read as None.
        metadata = {}
        for key in variable.ncattrs():
            metadata[key] = variable.getncattr(key)
        print(
            '    long_name =', metadata.get("long_name"),
            ', standard_name =', metadata.get("standard_name"),
        )
    

def main():
    """Command-line entry point: list or extract variables of a netCDF4 file.

    Without ``-v/--var`` the variables of the input file are summarised on
    stdout. With one or more ``-v`` options, each requested variable is
    converted to a NumPy array and written with ``ndarray.tofile`` to a file
    whose name is ``--out-format`` with ``{v}`` replaced by the variable name.

    Nothing is written if any requested variable is missing from the file, or
    if several variables are requested while the output pattern has no
    ``{v}`` placeholder (all of them would land in the same file).
    """
    parser = argparse.ArgumentParser(description='Extract STD fields from netCDF4 file')
    parser.add_argument('-i', '--input', help='input netCDF4 file', required=True)
    parser.add_argument('-v', '--var', action='append', help='Field variable to extract, may be specified multiple times')
    # argparse %-formats help strings, so a bare "%s" is not shown literally;
    # the placeholder actually substituted below is "{v}".
    parser.add_argument('-o', '--out-format', default='{v}.std', help='Output filename pattern. Pattern {v} will be replaced with variable name')
    args = parser.parse_args()

    # The context manager closes the dataset on every exit path, including
    # the early returns below.
    with nc.Dataset(args.input, "r") as root:
        dims = {dim.name: dim.size for dim in root.dimensions.values()}
        variables = root.variables

        if not args.var:
            print('No variables to extract, printing variable summary')
            summary(variables.values(), dims)
            return

        # Validate every request up front so a bad name cannot leave a
        # partially extracted set of files behind.
        for var_name in args.var:
            if var_name not in variables:
                print(f"Variable {var_name} is not present in the file")
                return

        if len(args.var) > 1 and '{v}' not in args.out_format:
            print("Several vars are requested but output file is the same for all vars. Please use {v} to produce different files")
            return

        for var_name in args.var:
            var = variables[var_name]
            outfn = args.out_format.format(v=var_name)
            print(f'Extracting {var_name} with type {var.dtype} and size {tuple(reversed(var.shape))} into {outfn}')
            np.array(var).tofile(outfn)

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()