import os
import errno
import sys

import datetime as dt  # Python standard library datetime  module
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
from matplotlib.ticker import MultipleLocator
from matplotlib import ticker
import time
import argparse


def ncdump(nc_fid, verb=True):
    """
    Summarise the structure of an open NetCDF dataset, in the spirit of
    NCAR's ``ncdump`` utility.

    Parameters
    ----------
    nc_fid : netCDF4.Dataset
        An open dataset (any object exposing ``ncattrs``, ``getncattr``,
        ``dimensions`` and ``variables`` works).
    verb : Boolean
        When true, print the global attributes, the dimensions and the
        non-dimension variables together with their attributes.

    Returns
    -------
    nc_attrs : list
        Names of the global attributes.
    nc_dims : list
        Names of the dimensions.
    nc_vars : list
        Names of the variables.
    """
    def show_attributes(name):
        # Print dtype and attributes of variable ``name``.  A dimension with
        # no coordinate variable of the same name lands in the KeyError branch.
        try:
            variable = nc_fid.variables[name]
            print("\t\ttype: " + repr(variable.dtype))
            for attr in variable.ncattrs():
                print("\t\t%s: %s" % (attr, repr(variable.getncattr(attr))))
        except KeyError:
            print("\t\tWARNING: %s does not contain variable attributes" % name)

    global_attrs = nc_fid.ncattrs()
    dim_names = list(nc_fid.dimensions)
    var_names = list(nc_fid.variables)

    if not verb:
        return global_attrs, dim_names, var_names

    print("NetCDF Global Attributes:")
    for attr in global_attrs:
        print("\t%s: %s" % (attr, repr(nc_fid.getncattr(attr))))

    print("NetCDF dimension information:")
    for dim in dim_names:
        print("\tName: %s" % (dim,))
        print("\t\tsize: %d" % len(nc_fid.dimensions[dim]))
        show_attributes(dim)

    print("NetCDF variable information:")
    for var in var_names:
        if var in dim_names:
            # Coordinate variables were already reported with their dimension.
            continue
        variable = nc_fid.variables[var]
        print("\tName: %s" % (var,))
        print("\t\tdimensions: %s" % (variable.dimensions,))
        print("\t\tsize: %s" % (variable.size,))
        show_attributes(var)

    return global_attrs, dim_names, var_names



# Fixed colour-bar unit labels, keyed by figure type.  "currents" takes its
# unit from the U variable; every other type falls back to the ``units``
# attribute of the variable with the same name.
_FIG_UNITS = {
    "P": "m",
    "P_anomaly": "m",
    "heat_up": "W/m^2",
    "swr_down": "W/m^2",
    "taux": "dyn/cm^2",
    "tauy": "dyn/cm^2",
    "e_p": "g/cm2/s",
    "T_anomaly": "degree C",
    "S_anomaly": "ppt",
}


def _fig_norm_cmap(figType, data):
    """Return the (Normalize, colormap) pair used to colour ``figType``."""
    if figType == "P_anomaly":
        return mpl.colors.Normalize(vmin=-0.5, vmax=0.5), plt.cm.seismic
    if figType == "T_anomaly":
        return mpl.colors.Normalize(vmin=-4, vmax=4), plt.cm.seismic
    if figType == "S_anomaly":
        return mpl.colors.Normalize(vmin=-1, vmax=1), plt.cm.seismic
    if figType == "currents":
        # Narrow the colour range to the middle 60% of [min, max]:
        # vmin = min + 0.2*span, vmax = min + 0.8*span.
        lo, hi = data.min(), data.max()
        span = hi - lo
        return (mpl.colors.Normalize(vmin=hi - 0.8 * span, vmax=0.8 * span + lo),
                plt.cm.jet)
    if figType in ("taux", "tauy"):
        # Symmetric range around zero so opposite stresses get mirrored colours.
        maxv = np.abs(data).max()
        return mpl.colors.Normalize(vmin=-maxv, vmax=maxv), plt.cm.jet
    return mpl.colors.Normalize(vmin=data.min(), vmax=data.max()), plt.cm.jet


def createFig(prefix, outPath, nc_fid, data, lons, lats, figType, depth_idx):
    """
    Draw ``data`` as a filled-contour map and save it, together with a
    separate colour-bar image, under ``outPath/prefix/figType/``.

    Parameters
    ----------
    prefix : string
        Main input file name without ``.nc``; used as output sub-directory.
    outPath : string
        Output root directory ("" means the current directory).
    nc_fid : netCDF4.Dataset
        Dataset the field came from; used to look up ``units`` for the
        colour-bar label.
    data : np.ma.MaskedArray
        2-D field to plot, shaped (len(lats), len(lons)).
    lons : numpy.ndarray
        1-D longitudes of ``data``.
    lats : numpy.ndarray
        1-D latitudes of ``data``.
    figType : string
        Field name ("P", "T", "currents", "taux", ...); selects colour scale,
        unit label and output sub-directory.
    depth_idx : string
        Depth index, embedded in the output file names.

    Returns
    -------
    None.  Writes ``<figType>_<depth_idx>.png`` (transparent map) and
    ``<figType>_<depth_idx>_cbar.png`` (colour bar).
    """
    tStart_f = time.time()

    # os.path.join keeps the path relative when outPath is "" (no -op given);
    # plain "outPath + '/' + ..." would point at the filesystem root.
    path = os.path.join(outPath, prefix, figType, "")

    if figType in ("P", "P_anomaly"):
        # Float divisor so integer input is not floor-divided under Python 2.
        # Presumably converts dyn/cm^2 to metres of water (rho*g = 980 in CGS,
        # /100 for cm -> m) since the bar is labelled "m" -- TODO confirm.
        data = data / 98000.0

    fig = plt.figure(figsize=(20, 20))
    fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                        wspace=None, hspace=None)
    fig.add_subplot(1, 1, 1)  # Basemap plots into the current axes by default

    tEnd_f = time.time()
    print("creatFig Cal1 cost %f sec" % (tEnd_f - tStart_f))
    tStart_f = time.time()

    ll_lat = 0
    ll_lon = 100
    ur_lat = 50
    ur_lon = 150
    m = Basemap(projection="merc", llcrnrlat=ll_lat, urcrnrlat=ur_lat,
                llcrnrlon=ll_lon, urcrnrlon=ur_lon, resolution="c")

    tEnd_f = time.time()
    print("creatFig Cal2 cost %f sec" % (tEnd_f - tStart_f))
    tStart_f = time.time()

    # 2-D lat/lon grids, transformed into projection coordinates.
    lon2d, lat2d = np.meshgrid(lons, lats)
    x, y = m(lon2d, lat2d)

    color_num = 101  # number of contour levels
    norm, cmap = _fig_norm_cmap(figType, data)
    m.contourf(x, y, data, color_num, cmap=cmap, norm=norm)

    fig.tight_layout()
    fig.savefig(path + figType + "_" + depth_idx + ".png", transparent=True,
                format="png", pad_inches=0, bbox_inches="tight")

    # Stand-alone colour-bar figure; the "P" bar is wider than the others.
    width = 2 if figType == "P" else 1.5
    fig = plt.figure(figsize=(width, 4.0), frameon=False, facecolor=None)
    ax = fig.add_axes([0.2, 0.1, 0.2, 0.8])
    # Green axes background, as the old ``axisbg='g'`` keyword set it; that
    # keyword was removed in matplotlib 2.2, the patch call works everywhere.
    ax.patch.set_facecolor("g")

    cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm)
    cb.set_label("", color="k")
    if figType == "currents":
        plt.xlabel(" (%s)" % nc_fid.variables["U"].units, x=2)
    elif figType in _FIG_UNITS:
        plt.xlabel(" (%s)" % _FIG_UNITS[figType])
    elif "units" in nc_fid.variables[figType].ncattrs():
        plt.xlabel(" (%s)" % nc_fid.variables[figType].units)

    tEnd_f = time.time()
    print("creatFig Cal4 cost %f sec" % (tEnd_f - tStart_f))

    tStart_f = time.time()
    fig.savefig(path + figType + "_" + depth_idx + "_cbar.png",
                transparent=False, format="png", facecolor="white")
    tEnd_f = time.time()
    print("createFig Save step cost %f sec" % (tEnd_f - tStart_f))

    plt.close("all")

def smooth(INs, data, depth, end_lat=488, end_lon=406):
    """
    Smooth a 2-D field with an IN-weighted five-point stencil before the
    vector figure is drawn.

    Each output point is ``0.5 * data * IN`` at the centre cell plus
    ``0.5 / AVG`` times the IN-weighted sum of its four edge-adjacent cells
    (previous/next row, previous/next column).  ``AVG`` is the *sum* (not the
    mean) of IN over the 2x2 block formed by the centre cell, the next row,
    the next column and the diagonal cell.  The outermost row and column on
    every side are dropped.

    Parameters
    ----------
    INs : array-like
        Stack of IN weights read from the IN file, indexed by depth first.
    data : np.ma.MaskedArray
        2-D field with at least (end_lat, end_lon) points.
    depth : string
        Depth index; ``int(depth)`` selects the IN layer.
    end_lat : int, optional
        Number of grid rows used (default 488, the original hard-coded value).
    end_lon : int, optional
        Number of grid columns used (default 406, the original hard-coded value).

    Returns
    -------
    data : np.ma.MaskedArray
        Smoothed interior of shape (end_lat - 2, end_lon - 2), of the same
        array kind as the arithmetic on ``data`` produces.

    Notes
    -----
    Where ``AVG`` is zero the ``0.5 / AVG`` factor divides by zero.
    NOTE(review): the normaliser uses a 2x2 block anchored at the centre while
    the neighbour sum uses the four edge neighbours -- presumably a staggered
    grid; confirm this weighting is intended.
    """
    IN = INs[int(depth)]
    rows = slice(1, end_lat - 1)     # interior rows
    cols = slice(1, end_lon - 1)     # interior columns
    rows_next = slice(2, end_lat)
    cols_next = slice(2, end_lon)
    rows_prev = slice(None, end_lat - 2)
    cols_prev = slice(None, end_lon - 2)

    AVG = (IN[rows, cols] + IN[rows_next, cols] +
           IN[rows, cols_next] + IN[rows_next, cols_next])

    neighbours = (data[rows_prev, cols] * IN[rows_prev, cols] +
                  data[rows_next, cols] * IN[rows_next, cols] +
                  data[rows, cols_prev] * IN[rows, cols_prev] +
                  data[rows, cols_next] * IN[rows, cols_next])

    # Single pass: the former ``for i in range(1, 2, 1)`` loop ran exactly once.
    return 0.5 * data[rows, cols] * IN[rows, cols] + 0.5 / AVG * neighbours



def standardizeSeq(u, level=25):
    """
    Quantise a field into integer steps of ``max(|u|) / level`` so vectors
    drawn from it are neither too long nor too short.

    Parameters
    ----------
    u : np.ma.MaskedArray
        Field to rescale (plain ndarrays work too).
    level : int, optional
        Number of steps between 0 and ``max(|u|)`` (default 25).

    Returns
    -------
    u : np.ma.MaskedArray
        ``ceil(u / (max(|u|) / level))``.  The input mask is preserved, so
        masked points stay masked instead of being replaced by raw fill data.
        An all-zero input returns zeros rather than NaN.
    """
    vmax_u = np.max(abs(u))
    if not vmax_u:
        # All-zero field: avoid 0/0 and return zeros of the same array kind.
        return u * 0.0
    # float() keeps integer input from being floor-divided under Python 2.
    level_diff = vmax_u / float(level)
    # Vectorised ufunc call keeps the mask; the former per-row list wrapped in
    # np.array() silently dropped it.
    return np.ceil(u / level_diff)



def drawVector(prefix, outPath, INs, U, V, lons, lats, figType, depth, zoom, feq, scale, arrow_width):
    """
    drawVector smooths and rescales a U/V velocity pair and saves it as a
    transparent quiver (arrow) overlay under ``outPath/prefix/figType/``.

    Parameters
    ----------
    prefix : string
        Main input file name without ``.nc``; used as output sub-directory.
    outPath : string
        Output root directory ("" means the current directory).
    INs : np.ma.MaskedArray
        Stack of IN weights read from the IN file, indexed by depth first.
    U : np.ma.MaskedArray
        Eastward velocity component at this depth.
    V : np.ma.MaskedArray
        Northward velocity component at this depth.
    lons : numpy.ndarray
        1-D longitudes.  smooth() trims one point off every edge, so these
        should describe that trimmed grid.
    lats : numpy.ndarray
        1-D latitudes of the same trimmed grid.
    figType : string
        Output sub-directory and file-name prefix (e.g. "currents").
    depth : string
        Depth index; selects the IN layer and is embedded in the file name.
    zoom : int
        Zoom level, embedded in the file name.
    feq : int
        Draw one arrow every ``feq`` grid points in each direction.
    scale : float
        Data units per inch of arrow length (``scale_units="inches"``).
    arrow_width : float
        Arrow shaft width in inches (``units="inches"``).

    Returns
    -------
    None.  Writes ``<figType>_<depth>_<zoom>_v.png``.
    """
    tStart_f = time.time()

    U = smooth(INs, U, depth)
    V = smooth(INs, V, depth)
    U = standardizeSeq(U)
    V = standardizeSeq(V)

    # os.path.join keeps the path relative when outPath is "" (no -op given);
    # plain "outPath + '/' + ..." would point at the filesystem root.
    path = os.path.join(outPath, prefix, figType, "")

    ll_lat = 0
    ll_lon = 100
    ur_lat = 50
    ur_lon = 150

    fig = plt.figure(figsize=(20, 20))
    fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                        wspace=None, hspace=None)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(ll_lon, ur_lon)
    ax.set_ylim(ll_lat, ur_lat)
    ax.set_xbound(ll_lon, ur_lon)
    ax.set_ybound(ll_lat, ur_lat)

    lon2d, lat2d = np.meshgrid(lons, lats)
    m = Basemap(projection="merc", llcrnrlat=ll_lat, urcrnrlat=ur_lat,
                llcrnrlon=ll_lon, urcrnrlon=ur_lon, resolution="i")
    x, y = m(lon2d, lat2d)
    m.drawmapboundary(fill_color="aqua")
    # Subsample every feq-th point in both directions before drawing arrows.
    m.quiver(x[::feq, ::feq], y[::feq, ::feq], U[::feq, ::feq], V[::feq, ::feq],
             scale=scale, headlength=5, headwidth=6, width=arrow_width,
             minshaft=1, minlength=1, scale_units="inches", units="inches",
             color="w")

    plt.axis("off")
    fig.tight_layout()

    tEnd_f = time.time()
    print("drawVector Cal step cost %f sec" % (tEnd_f - tStart_f))

    tStart_f = time.time()
    fig.savefig(path + figType + "_" + depth + "_" + str(zoom) + "_v.png",
                transparent=True, format="png", pad_inches=0,
                bbox_inches="tight", dpi=250)
    tEnd_f = time.time()
    print("drawVector Save step cost %f sec" % (tEnd_f - tStart_f))
    plt.close("all")

def _ensure_dir(path):
    """Create ``path`` (and any missing parents) unless it is already a directory."""
    try:
        os.makedirs(path)
    except OSError as err:
        # An existing directory is fine; any other failure is a real error.
        if err.errno != errno.EEXIST or not os.path.isdir(path):
            raise


def checkDir(prefix, outPath):
    """
    Make sure ``outPath + prefix`` and one sub-directory per figure type exist.

    Parameters
    ----------
    prefix : string
        Output sub-directory name (main input file name without ``.nc``).
    outPath : string
        Output root.  It is concatenated directly with ``prefix``, so it should
        be "" or end with a path separator (initReader() appends "/").

    Raises
    ------
    OSError
        If a directory cannot be created (e.g. permission denied, or a regular
        file is in the way).
    """
    path = outPath + prefix
    sub_dirs = ("P", "S", "T", "currents", "e_p", "heat_up", "swr_down",
                "taux", "tauy", "P_anomaly", "T_anomaly", "S_anomaly")
    for sub_dir in sub_dirs:
        # makedirs also creates ``path`` itself and any missing parents.
        _ensure_dir(os.path.join(path, sub_dir))


def initReader():
    """
    Parse the command line and validate the input files.

    Exits with status 1 after printing a message if the main, flux or anomaly
    input file does not exist.

    Returns
    -------
    prefix : string
        Base name of the main input file with a trailing ``.nc`` removed.
    inputFile : string
        Main input file (``-ip``).
    inputFileFluxes : string
        Flux input file (``-ifp``).
    inputFileAnomaly : string
        Anomaly input file (``-iap``).
    inFile : string
        IN file (``-iip``, default "IN.nc"); its existence is not checked here.
    outPath : string
        Output root with a trailing "/" (``-op``), or "" when not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-ip", "--filePath", help="input path")
    parser.add_argument("-ifp", "--iFluxesPath", help="input fluxes path")
    parser.add_argument("-iap", "--iAnomalyPath", help="input anomaly path")
    parser.add_argument("-iip", "--iInPath", help="input in.nc path")
    parser.add_argument("-op", "--outPath", help="output path")

    args = parser.parse_args()
    inputFile = args.filePath or ""
    inputFileFluxes = args.iFluxesPath or ""
    inputFileAnomaly = args.iAnomalyPath or ""
    inFile = args.iInPath or "IN.nc"
    outPath = args.outPath + "/" if args.outPath else ""

    for fileName in (inputFile, inputFileFluxes, inputFileAnomaly):
        if not os.path.exists(fileName):
            print('file ' + fileName + ' is not exist')
            # Non-zero status so calling scripts can detect the failure.
            sys.exit(1)

    # Strip only a trailing ".nc"; str.replace would also remove ".nc"
    # occurring elsewhere in the name (e.g. "a.ncep.nc" -> "aep").
    baseName = os.path.basename(inputFile)
    prefix = baseName[:-len(".nc")] if baseName.endswith(".nc") else baseName

    return prefix, inputFile, inputFileFluxes, inputFileAnomaly, inFile, outPath

def readMainNc(fileName):
    """
    Open the main NetCDF file, print its header and extract the fields used
    for plotting.  Exits the program if the file does not exist.

    Parameters
    ----------
    fileName: string
        Path of the main NetCDF file.

    Returns
    -------
    nc_fid : netCDF4.Dataset
        The open dataset; the caller is responsible for closing it.
    lats_, lons_ : numpy.ndarray
        Latitude elements [1:487] and longitude elements [1:405].
    U, V, P, T, S : np.ma.MaskedArray
        Record 0 of each variable.
    mask : numpy mask
        Mask of ``P[0]`` restricted to the same [1:487, 1:405] window.
    """
    if not os.path.exists(fileName):
        print('file ' + fileName + ' is not exist')
        sys.exit()

    nc_fid = Dataset(fileName, 'r')
    ncdump(nc_fid)  # printed header only; the returned name lists are not needed

    variables = nc_fid.variables
    lats_ = variables['Latitude'][:][1:487]
    lons_ = variables['Longitude'][:][1:405]
    U, V, P, T, S = [variables[name][0][:] for name in ('U', 'V', 'P', 'T', 'S')]
    mask = np.ma.getmask(P[0][1:487, 1:405])
    return nc_fid, lats_, lons_, U, V, P, T, S, mask

def readFlux(fileName, mask):
    """
    Open the surface-flux NetCDF file and wrap its fields in masked arrays.
    Exits the program if the file does not exist.

    Parameters
    ----------
    fileName: string
        Path of the flux NetCDF file.
    mask: numpy mask
        Mask applied to every flux field; must be compatible with the flux
        variables' shape.  main() passes the mask returned by readMainNc.

    Returns
    -------
    nc_fid : netCDF4.Dataset
        The open dataset; the caller is responsible for closing it.
    heat_up, swr_down, taux, tauy, e_p : np.ma.MaskedArray
        The flux fields, in that order.
    """
    if not os.path.exists(fileName):
        print('file ' + fileName + ' is not exist')
        sys.exit()

    nc_fid = Dataset(fileName, 'r')
    ncdump(nc_fid)  # printed header only; the returned name lists are not needed

    flux_names = ('heat_up', 'swr_down', 'taux', 'tauy', 'e_p')
    fields = [np.ma.MaskedArray(nc_fid.variables[name], mask=mask)
              for name in flux_names]
    return (nc_fid,) + tuple(fields)



def readAnomaly(fileName, mask):
    """
    Open the anomaly NetCDF file and wrap record 0 of P, T and S in masked
    arrays.  Exits the program if the file does not exist.

    Parameters
    ----------
    fileName: string
        Path of the anomaly NetCDF file.
    mask: numpy mask
        Mask applied to each field; main() passes the mask of the main file's
        full P field.

    Returns
    -------
    nc_fid : netCDF4.Dataset
        The open dataset; the caller is responsible for closing it.
    P, T, S : np.ma.MaskedArray
        Record 0 of the P, T and S anomaly variables.
    """
    if not os.path.exists(fileName):
        print('file ' + fileName + ' is not exist')
        sys.exit()

    nc_fid = Dataset(fileName, 'r')
    ncdump(nc_fid)  # printed header only; the returned name lists are not needed

    P, T, S = [np.ma.MaskedArray(nc_fid.variables[name][0], mask=mask)
               for name in ('P', 'T', 'S')]
    return nc_fid, P, T, S
       
def readIn(fileName="IN.nc"):
    """
    Read the ``IN`` variable from the IN NetCDF file.

    Parameters
    ----------
    fileName: string
        Path of the IN file (initReader() defaults it to "IN.nc"; the ``-iip``
        option overrides it).

    Returns
    -------
    INs : numpy.ndarray or np.ma.MaskedArray
        Full contents of the ``IN`` variable.  smooth() indexes its first axis
        by depth.
    """
    # Open the path we were given; the former hard-coded 'IN.nc' silently
    # ignored fileName and therefore the -iip option.
    nc = Dataset(fileName, 'r')
    try:
        ncdump(nc)  # printed header only
        INs = nc.variables['IN'][:]  # slicing copies the data out before close
    finally:
        # Close even if the dump or the read fails.
        nc.close()
    return INs


def main(argv):
    """
    Program entry point: read the input files and write every figure.

    Parameters
    ----------
    argv : list
        Raw command line.  Not used directly; initReader() parses the options
        through argparse.
    """
    # Grid window plotted, matching the lats_/lons_ slices in readMainNc.
    interior = (slice(1, 487), slice(1, 405))
    d_size = 25  # number of depth levels plotted
    # (zoom, feq, scale, arrow_width) for each current-vector overlay.  The
    # output file name depends only on zoom, so a former extra zoom-6 entry
    # (6, 6, 25, 0.01) was always overwritten by the next one; it is dropped.
    vector_styles = (
        (4, 8, 25, 0.01),
        (5, 8, 25, 0.01),
        (6, 6, 50, 0.01),
        (7, 4, 50, 0.0075),
        (8, 2, 100, 0.005),
    )

    tStart = time.time()
    prefix, inputFile, inputFileFluxes, inputFileAnomaly, inFile, outPath = initReader()
    checkDir(prefix, outPath)
    tEnd = time.time()
    print("Init step cost %f sec" % (tEnd - tStart))

    INs = readIn(inFile)

    opened = []  # datasets to close on exit, in the order they were opened
    try:
        tStart = time.time()
        nc_fid_main, lats_, lons_, U, V, P, T, S, mask = readMainNc(inputFile)
        opened.append(nc_fid_main)
        tEnd = time.time()
        print("read file step cost %f sec" % (tEnd - tStart))

        tStart = time.time()
        createFig(prefix, outPath, nc_fid_main, P[0][interior], lons_, lats_, 'P', '0')
        for i in range(d_size):
            depth = str(i)
            Ui = U[i]
            Vi = V[i]
            currents = np.sqrt(Ui * Ui + Vi * Vi)  # current speed
            createFig(prefix, outPath, nc_fid_main, T[i][interior], lons_, lats_, 'T', depth)
            createFig(prefix, outPath, nc_fid_main, S[i][interior], lons_, lats_, 'S', depth)
            createFig(prefix, outPath, nc_fid_main, currents[interior], lons_, lats_, 'currents', depth)
            for zoom, feq, scale, arrow_width in vector_styles:
                drawVector(prefix, outPath, INs, Ui, Vi, lons_, lats_, 'currents',
                           depth, zoom, feq, scale, arrow_width)
        tEnd = time.time()
        print("createFig total cost %f sec" % (tEnd - tStart))

        tStart = time.time()
        nc_fid_flux, heat_up, swr_down, taux, tauy, e_p = readFlux(inputFileFluxes, mask)
        opened.append(nc_fid_flux)
        tEnd = time.time()
        print("read file Flux total cost %f sec" % (tEnd - tStart))

        tStart = time.time()
        flux_fields = (('heat_up', heat_up), ('swr_down', swr_down),
                       ('taux', taux), ('tauy', tauy), ('e_p', e_p))
        for name, field in flux_fields:
            createFig(prefix, outPath, nc_fid_flux, field, lons_, lats_, name, '0')
        tEnd = time.time()
        print("createFig Flux total cost %f sec" % (tEnd - tStart))

        nc_fid_anomaly, P_anomaly, T_anomaly, S_anomaly = readAnomaly(
            inputFileAnomaly, np.ma.getmask(P))
        opened.append(nc_fid_anomaly)
        for i in range(d_size):
            depth = str(i)
            createFig(prefix, outPath, nc_fid_anomaly, P_anomaly[i][interior], lons_, lats_, 'P_anomaly', depth)
            createFig(prefix, outPath, nc_fid_anomaly, T_anomaly[i][interior], lons_, lats_, 'T_anomaly', depth)
            createFig(prefix, outPath, nc_fid_anomaly, S_anomaly[i][interior], lons_, lats_, 'S_anomaly', depth)
    finally:
        # Close every dataset that was opened, even on error or sys.exit().
        for nc_fid in opened:
            nc_fid.close()

if __name__ == "__main__":
    # Script entry point: hand the raw command line to main(); option parsing
    # itself happens in initReader() via argparse.
    main(argv=sys.argv)