import os
import errno
import sys

import datetime as dt  # Python standard library datetime  module
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
from matplotlib.ticker import MultipleLocator
from matplotlib import ticker
import time
import argparse


def ncdump(nc_fid, verb=True):
    """
    Describe a NetCDF dataset in the spirit of NCAR's ncdump utility.

    When ``verb`` is true the global attributes, every dimension (its size
    and, if a variable of the same name exists, that variable's dtype and
    attributes) and every remaining variable (dimensions, size, dtype and
    attributes) are printed to standard output.

    Parameters
    ----------
    nc_fid : netCDF4.Dataset
        An open netCDF4 dataset object
    verb : Boolean
        whether or not the description is printed

    Returns
    -------
    nc_attrs : list
        Names of the NetCDF file global attributes
    nc_dims : list
        Names of the NetCDF file dimensions
    nc_vars : list
        Names of the NetCDF file variables
    """
    def describe_variable(name):
        # Print dtype and attributes of variable `name`; a name without a
        # matching variable (e.g. a plain dimension) only gets a warning.
        try:
            variable = nc_fid.variables[name]
            print("\t\ttype: %s" % repr(variable.dtype))
            for attr_name in variable.ncattrs():
                print("\t\t%s: %s" % (attr_name, repr(variable.getncattr(attr_name))))
        except KeyError:
            print("\t\tWARNING: %s does not contain variable attributes" % name)

    attr_names = nc_fid.ncattrs()
    dim_names = list(nc_fid.dimensions)
    var_names = list(nc_fid.variables)

    if not verb:
        return attr_names, dim_names, var_names

    print("NetCDF Global Attributes:")
    for attr_name in attr_names:
        print("\t%s: %s" % (attr_name, repr(nc_fid.getncattr(attr_name))))

    print("NetCDF dimension information:")
    for dim_name in dim_names:
        print("\tName: %s" % dim_name)
        print("\t\tsize: %s" % len(nc_fid.dimensions[dim_name]))
        describe_variable(dim_name)

    print("NetCDF variable information:")
    for var_name in var_names:
        if var_name in dim_names:
            # Variables named like a dimension were described with it above.
            continue
        variable = nc_fid.variables[var_name]
        print("\tName: %s" % var_name)
        print("\t\tdimensions: %s" % (variable.dimensions,))
        print("\t\tsize: %s" % variable.size)
        describe_variable(var_name)

    return attr_names, dim_names, var_names



def createFig(prefix, outPath, nc_fid, data, lons, lats, figType, time_idx, gridName):
    """
    createFig plots a 2-D field on a Mercator map and saves it, plus a separate
    colorbar image, to outPath/prefix/figType/.

    Two files are written:
        <figType>_<time_idx>.png       the filled-contour map (transparent)
        <figType>_<time_idx>_cbar.png  the matching colorbar (white background)

    Parameters
    ----------
    prefix : string
        Main data input file name (without the .nc suffix); output sub-directory
    outPath : string
        Root directory the figures are saved to ("" means the current directory)
    nc_fid : netCDF4.Dataset
        A netCDF4 dataset object (not used by the plotting)
    data : np.ma.MaskedArray
        Field to plot; must have the shape of np.meshgrid(lons, lats),
        i.e. (len(lats), len(lons))
    lons : numpy.ndarray
        1-D longitudes read from the Dataset
    lats : numpy.ndarray
        1-D latitudes read from the Dataset
    figType : string
        Variable name; selects the colour scale, the colorbar unit and the
        output sub-directory. One of "dpt", "hs", "wnd", "lm", "fp", "dp",
        "dm", "t0m1".
    time_idx : string
        Time index of the data; used in the output file names
    gridName : string
        Map extent: one of "grd1", "grd2", "grd3", "grd4"

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If gridName or figType is not one of the supported values.
    """
    # gridName -> (ll_lat, ll_lon, ur_lat, ur_lon) map corners in degrees.
    grid_bounds = {
        "grd1": (0, 95, 60, 155),
        "grd2": (15, 110, 35, 140),
        "grd3": (21, 119, 26, 123),
        "grd4": (24.4, 121.65, 24.9, 122.15),
    }
    # figType -> (vmin, vmax, colormap name, colorbar unit).
    # vmin/vmax of None means "use the data range"; a unit of None leaves the
    # colorbar without a label.
    color_scales = {
        "dpt": (0, 7200, "jet", "m"),
        "hs": (0, 10, "jet", "m"),
        "wnd": (None, None, "jet", None),
        "lm": (0, 320, "jet", "m"),
        "fp": (0, 0.25, "jet", None),
        "dp": (0, 360, "hsv", "degree"),
        "dm": (0, 360, "hsv", "degree"),
        "t0m1": (0, 15, "jet", "sec"),
    }

    # Validate before any figure is created: previously an unknown value
    # surfaced later as an UnboundLocalError on ll_lat / norm.
    if gridName not in grid_bounds:
        raise ValueError("unknown gridName %r; expected one of %s"
                         % (gridName, ", ".join(sorted(grid_bounds))))
    if figType not in color_scales:
        raise ValueError("unknown figType %r; expected one of %s"
                         % (figType, ", ".join(sorted(color_scales))))
    ll_lat, ll_lon, ur_lat, ur_lon = grid_bounds[gridName]
    vmin, vmax, cmap_name, unit = color_scales[figType]

    tStart_f = time.time()

    # os.path.join keeps this consistent with checkDir(); plain "/" joining
    # produced "/<prefix>/..." (the filesystem root) when outPath was "".
    path = os.path.join(outPath, prefix, figType)

    try:
        fig = plt.figure(figsize=(20, 20))
        fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=None, hspace=None)
        fig.add_subplot(1, 1, 1)  # current axes, used by Basemap below

        tEnd_f = time.time()
        print("creatFig Cal1 cost %f sec" % (tEnd_f - tStart_f))
        tStart_f = time.time()

        m = Basemap(projection="merc", llcrnrlat=ll_lat, urcrnrlat=ur_lat,
                    llcrnrlon=ll_lon, urcrnrlon=ur_lon, resolution="c")

        tEnd_f = time.time()
        print("creatFig Cal2 cost %f sec" % (tEnd_f - tStart_f))
        tStart_f = time.time()

        # Coastlines and the map boundary are not drawn.

        tEnd_f = time.time()
        print("creatFig Cal3 cost %f sec" % (tEnd_f - tStart_f))
        tStart_f = time.time()

        # Create 2D lat/lon arrays and project them into map coordinates.
        lon2d, lat2d = np.meshgrid(lons, lats)
        x, y = m(lon2d, lat2d)

        if vmin is None:
            vmin, vmax = data.min(), data.max()
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        cmap = getattr(plt.cm, cmap_name)

        color_num = 101  # number of contour levels
        m.contourf(x, y, data, color_num, cmap=cmap, norm=norm)

        fig.tight_layout()
        fig.savefig(os.path.join(path, "%s_%s.png" % (figType, time_idx)),
                    transparent=True, format="png", pad_inches=0,
                    bbox_inches="tight")

        # Stand-alone colorbar image for the same colour scale.
        cbar_fig = plt.figure(figsize=(1.5, 4.0), frameon=False, facecolor=None)
        cbar_rect = [0.2, 0.1, 0.2, 0.8]
        try:
            ax = cbar_fig.add_axes(cbar_rect, facecolor="g")
        except (AttributeError, TypeError):
            # Older matplotlib only accepts the former keyword name `axisbg`
            # (newer releases reject `axisbg` instead).
            ax = cbar_fig.add_axes(cbar_rect, axisbg="g")

        cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm)
        cb.set_label("", color="k")
        if unit is not None:
            ax.set_xlabel("%s (%s)" % ("", unit))

        tEnd_f = time.time()
        print("creatFig Cal4 cost %f sec" % (tEnd_f - tStart_f))

        tStart_f = time.time()
        cbar_fig.savefig(os.path.join(path, "%s_%s_cbar.png" % (figType, time_idx)),
                         transparent=False, format="png", facecolor="white")
        tEnd_f = time.time()
        print("createFig Save step cost %f sec" % (tEnd_f - tStart_f))
    finally:
        # Release both figures even if plotting or saving fails.
        plt.close("all")

def smooth(INs, data, depth, end_lon=406, end_lat=488):
    """
    Smooth the data with a masked 5-point stencil (used before drawing vectors).

    Every interior point of the window data[:end_lat, :end_lon] becomes half
    of its own masked value plus half of the average over its valid direct
    neighbours (previous/next row, previous/next column):

        out = 0.5 * d_c * m_c + 0.5 * sum_n(d_n * m_n) / sum_n(m_n)

    where m = INs[int(depth)]. m is assumed to be 1 at valid points and 0
    elsewhere (confirm against IN.nc). Points with no valid neighbour divide
    by zero; for masked arrays numpy masks such points.

    Parameters
    ----------
    INs : np.ma.MaskedArray
        Mask read from file IN.nc, indexed first by depth
    data : np.ma.MaskedArray
        Value reads from Dataset, 2-D
    depth : string
        Depth index into INs (converted with int())
    end_lon : int, optional
        Number of columns (axis 1) of the window to smooth; default 406
    end_lat : int, optional
        Number of rows (axis 0) of the window to smooth; default 488

    Returns
    -------
    data : np.ma.MaskedArray
        Smoothed data of shape (end_lat - 2, end_lon - 2)
    """
    IN = INs[int(depth)]

    # Interior window and its four direct neighbours, in array-index terms.
    center = (slice(1, end_lat - 1), slice(1, end_lon - 1))
    prev_row = (slice(None, end_lat - 2), slice(1, end_lon - 1))
    next_row = (slice(2, end_lat), slice(1, end_lon - 1))
    prev_col = (slice(1, end_lat - 1), slice(None, end_lon - 2))
    next_col = (slice(1, end_lat - 1), slice(2, end_lon))
    neighbours = (prev_row, next_row, prev_col, next_col)

    # The normaliser must count exactly the neighbours used in the weighted
    # sum. The previous version summed the centre/next-row/next-col/diagonal
    # masks instead, which is not an average near mask edges.
    valid_count = sum(IN[s] for s in neighbours)
    weighted_sum = sum(data[s] * IN[s] for s in neighbours)

    return 0.5 * data[center] * IN[center] + 0.5 / valid_count * weighted_sum



def standardizeSeq(u, level=25):
    """
    Standardize the data onto a common integer scale so that drawn vectors are
    neither too long nor too short.

    Every value is divided by max(|u|) / level and rounded up (np.ceil), so
    the result lies (up to float rounding) in [-level, level].

    Parameters
    ----------
    u : np.ma.MaskedArray
        Value reads from Dataset
    level : int, optional
        Step count the largest magnitude is mapped to (default 25)

    Returns
    -------
    u : np.ma.MaskedArray
        Standardized values; the mask (if any) is preserved. An all-zero
        input yields zeros instead of NaNs.
    """
    vmax_u = np.max(abs(u))
    if not vmax_u:
        # Identically zero (or fully masked) field: avoid 0/0 -> nan.
        return u * 0.0
    # float() avoids Python 2 integer division for integer input.
    level_diff = float(vmax_u) / level
    return np.ceil(u / level_diff)



def drawVector(prefix, outPath, INs, U, V, lons, lats, figType, depth, zoom, feq, scale, arrow_width, gridName):
    """
    drawVector draws a thinned (U, V) arrow field on a Mercator map and saves
    it as a transparent PNG:
        outPath/prefix/figType/<figType>_<depth>_<zoom>_v.png

    Parameters
    ----------
    prefix : string
        Main data input file name (without the .nc suffix); output sub-directory
    outPath : string
        Root directory the figure is saved to ("" means the current directory)
    INs : np.ma.MaskedArray
        Mask read from file IN.nc; currently unused (smoothing is disabled)
    U : np.ma.MaskedArray
        First vector component; same shape as np.meshgrid(lons, lats)
    V : np.ma.MaskedArray
        Second vector component; same shape as U
    lons : numpy.ndarray
        1-D longitudes read from the Dataset
    lats : numpy.ndarray
        1-D latitudes read from the Dataset
    figType : string
        Output sub-directory and file-name prefix
    depth : string
        Label used in the file name (main() passes the time index here)
    zoom : int
        Zoom level; used in the file name
    feq : int
        Only every feq-th point along each axis gets an arrow
    scale : number
        Data units per inch of arrow length (scale_units="inches")
    arrow_width : number
        Arrow shaft width in inches (units="inches")
    gridName : string
        Map extent: one of "grd1", "grd2", "grd3", "grd4"

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If gridName is not one of the supported values.
    """
    # gridName -> (ll_lat, ll_lon, ur_lat, ur_lon) map corners in degrees.
    grid_bounds = {
        "grd1": (0, 95, 60, 155),
        "grd2": (15, 110, 35, 140),
        "grd3": (21, 119, 26, 123),
        "grd4": (24.4, 121.65, 24.9, 122.15),
    }
    # Fail early with a clear message instead of an UnboundLocalError later.
    if gridName not in grid_bounds:
        raise ValueError("unknown gridName %r; expected one of %s"
                         % (gridName, ", ".join(sorted(grid_bounds))))
    ll_lat, ll_lon, ur_lat, ur_lon = grid_bounds[gridName]

    tStart_f = time.time()

    # os.path.join keeps this consistent with checkDir(); plain "/" joining
    # produced "/<prefix>/..." (the filesystem root) when outPath was "".
    path = os.path.join(outPath, prefix, figType)

    try:
        fig = plt.figure(figsize=(20, 20))
        fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=None, hspace=None)
        ax = fig.add_subplot(1, 1, 1)
        # NOTE(review): these limits are in degrees while Basemap plots
        # projected coordinates; presumably Basemap resets them when drawing.
        ax.set_xlim(ll_lon, ur_lon)
        ax.set_ylim(ll_lat, ur_lat)
        ax.set_xbound(ll_lon, ur_lon)
        ax.set_ybound(ll_lat, ur_lat)

        lon2d, lat2d = np.meshgrid(lons, lats)
        m = Basemap(projection="merc", llcrnrlat=ll_lat, urcrnrlat=ur_lat,
                    llcrnrlon=ll_lon, urcrnrlon=ur_lon, resolution="i")
        x, y = m(lon2d, lat2d)
        m.drawmapboundary(fill_color="aqua")
        # Thin the field: one arrow every `feq` points along each axis.
        m.quiver(x[::feq, ::feq], y[::feq, ::feq], U[::feq, ::feq], V[::feq, ::feq],
                 scale=scale, headlength=5, headwidth=6, width=arrow_width,
                 minshaft=1, minlength=1, scale_units="inches", units="inches",
                 color="w")

        plt.axis("off")
        fig.tight_layout()

        tEnd_f = time.time()
        print("drawVector Cal step cost %f sec" % (tEnd_f - tStart_f))

        tStart_f = time.time()
        fig.savefig(os.path.join(path, "%s_%s_%s_v.png" % (figType, depth, zoom)),
                    transparent=True, format="png", pad_inches=0,
                    bbox_inches="tight")
        tEnd_f = time.time()
        print("drawVector Save step cost %f sec" % (tEnd_f - tStart_f))
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close("all")

def checkDir(prefix, outPath, subdirs=("dpt", "hs", "wnd", "fp", "dp", "lm", "dm", "t0m1")):
    """
    Make sure the output directory tree outPath/prefix/<subdir> exists.

    Parameters
    ----------
    prefix : string
        Main data input file name; becomes the per-run output directory
    outPath : string
        Root output directory ("" means the current directory); a trailing
        "/" is optional. Missing parent directories are created too.
    subdirs : sequence of string, optional
        Sub-directories to create below outPath/prefix; defaults to one per
        figure type this script writes.

    Raises
    ------
    OSError
        If a directory cannot be created for a reason other than it already
        existing (e.g. permissions, or a regular file in the way).
    """
    base = os.path.join(outPath, prefix)
    for sub in subdirs:
        target = os.path.join(base, sub)
        try:
            os.makedirs(target)
        except OSError as err:
            # An existing directory is fine (Python 2 has no exist_ok);
            # anything else is a real error and must not be swallowed.
            if err.errno != errno.EEXIST or not os.path.isdir(target):
                raise


def initReader():
    """
    Initial the reader: parse the command line.

    Options
    -------
    -ip / --filePath : input NetCDF file; must exist
    -op / --outPath  : output root directory (default: current directory)
    -gd / --gridName : map extent, one of grd1, grd2, grd3, grd4

    Parameters
    ----------

    Returns
    -------
    prefix : string
        Base name of the input file with a trailing ".nc" removed
    inputFile : string
        The input file path as given
    outPath : string
        The output path with "/" appended, or "" when not given
    gridName : string
        The grid name, or "" when not given

    Exits the process with status 1 if the input file does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-ip", "--filePath", help="input path")
    parser.add_argument("-op", "--outPath", help="output path")
    parser.add_argument("-gd", "--gridName", help="grid Name",
                        choices=["grd1", "grd2", "grd3", "grd4"])
    args = parser.parse_args()

    inputFile = args.filePath or ""
    outPath = (args.outPath + "/") if args.outPath else ""
    gridName = args.gridName or ""

    if not os.path.isfile(inputFile):
        print('file ' + inputFile + ' is not exist')
        # Non-zero status so wrapper scripts can detect the failure.
        sys.exit(1)

    # Strip only a trailing ".nc"; str.replace also removed ".nc" occurring
    # in the middle of the name (e.g. "run.ncfile.nc" -> "runfile").
    prefix = os.path.basename(inputFile)
    if prefix.endswith('.nc'):
        prefix = prefix[:-len('.nc')]

    return prefix, inputFile, outPath, gridName

def readMainNc(fileName):
    """
    Read Main nc file: print its layout with ncdump() and load the fields used
    for plotting.

    Parameters
    ----------
    fileName: string
        main nc file name

    Returns
    -------
    tuple
        (nc_fid, lats_, lons_, uwnd, vwnd, dpt, hs, fp, dp, dm, lm, t0m1, mask)

        nc_fid is the still-open netCDF4.Dataset; the caller must close it.
        dpt and mask are the first record ([0]) of 'dpt' and 'MAPSTA';
        dm is read from the 'dir' variable; every other array is the full
        variable of the same name ('latitude'/'longitude' for lats_/lons_).

    Exits the process with status 1 if fileName does not exist. If a required
    variable is missing, the dataset is closed and the KeyError propagates.
    """
    if not os.path.exists(fileName):
        print('file ' + fileName + ' is not exist')
        # Non-zero status so wrapper scripts can detect the failure.
        sys.exit(1)

    nc_fid = Dataset(fileName, 'r')
    try:
        ncdump(nc_fid)
        variables = nc_fid.variables
        lats_ = variables['latitude'][:]
        lons_ = variables['longitude'][:]
        uwnd = variables['uwnd'][:]
        vwnd = variables['vwnd'][:]
        dpt = variables['dpt'][0]
        hs = variables['hs'][:]
        fp = variables['fp'][:]
        dp = variables['dp'][:]
        dm = variables['dir'][:]
        lm = variables['lm'][:]
        t0m1 = variables['t0m1'][:]
        mask = variables['MAPSTA'][0]
    except Exception:
        # Do not leak the open file when a variable is missing/unreadable.
        nc_fid.close()
        raise
    return nc_fid, lats_, lons_, uwnd, vwnd, dpt, hs, fp, dp, dm, lm, t0m1, mask

def main(argv):
    """
    The start of the program: parse the command line, read the main NetCDF
    file and render every field / time step to PNG images.

    Parameters
    ----------
    argv : list of string
        Command-line arguments. Currently unused: initReader() parses
        sys.argv itself via argparse.

    Returns
    -------
    None
    """
    # Maximum number of time steps rendered (capped by what the file holds).
    max_time_steps = 8
    # (zoom, feq) pairs for the wind-vector layers; all use scale=10 and
    # arrow_width=0.05.
    wind_vector_levels = ((4, 18), (5, 15), (6, 12), (7, 10), (8, 8))

    tStart = time.time()
    prefix, inputFile, outPath, gridName = initReader()
    checkDir(prefix, outPath)
    tEnd = time.time()
    print("Init step cost %f sec" % (tEnd - tStart))

    tStart = time.time()
    (nc_fid_main, lats_, lons_, uwnd, vwnd, dpt, hs, fp, dp, dm, lm, t0m1,
     mask) = readMainNc(inputFile)
    tEnd = time.time()
    print("read file step cost %f sec" % (tEnd - tStart))

    try:
        tStart = time.time()

        createFig(prefix, outPath, nc_fid_main, dpt, lons_, lats_, 'dpt', str(0), gridName)

        # A fixed 8 raised IndexError for shorter files. Assumes all
        # time-dependent variables share hs's first (time) dimension.
        t_size = min(max_time_steps, len(hs))
        for i in range(t_size):
            time_idx = str(i)
            Ui = uwnd[i]
            Vi = vwnd[i]
            wnd = np.sqrt(Ui * Ui + Vi * Vi)  # wind speed magnitude

            scalar_fields = (
                ('hs', hs[i]),
                ('lm', lm[i]),
                ('dm', dm[i]),
                ('fp', fp[i]),
                ('dp', dp[i]),
                ('t0m1', t0m1[i]),
                ('wnd', wnd),
            )
            for figType, field in scalar_fields:
                createFig(prefix, outPath, nc_fid_main, field, lons_, lats_,
                          figType, time_idx, gridName)

            for zoom, feq in wind_vector_levels:
                drawVector(prefix, outPath, mask, Ui, Vi, lons_, lats_, 'wnd',
                           time_idx, zoom, feq, 10, 0.05, gridName)

        tEnd = time.time()
        print("createFig total cost %f sec" % (tEnd - tStart))
    finally:
        # Close original NetCDF file, even if rendering fails.
        nc_fid_main.close()

if __name__ == "__main__":
    # Script entry point; the argument vector is forwarded to main() by keyword.
    main(argv=sys.argv)