# https://www.ncdc.noaa.gov/ibtracs/index.php?name=ib-v4-access
import sys
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from scipy import stats
import csv
from datetime import date
from elevations import Elevations

# Basin configuration. Storm header lines in FILENAME whose ID starts with
# BASIN1 or BASIN2 are analysed; BASINTITLE is used only in the plot title.
BASINTITLE = "ATLANTIC"
FILENAME = "hurdat.csv"
BASIN1 = "AL"
BASIN2 = "AL"
# Alternative configuration for the eastern/central Pacific file:
#BASINTITLE = "PACIFIC"
#FILENAME = "hurdat2-nepac.csv"
#BASIN1 = "EP"
#BASIN2 = "CP"

# Maximum-wind bands: speedMin[i] / speedMax[i] bound band i, which is drawn
# in colors[i] with legend text labels[i]. Wind units are whatever the CSV
# uses (presumably knots -- confirm). The speedMin values double as the keys
# of the dicts built in read_data().
speedMin = [35, 55, 70]
speedMax = [54, 69, 84]
colors = ['green', 'aqua', 'blue']
labels = ['35-54', '55-69', '70-84']

# Alternative, stronger bands (999 acts as an open upper bound):
#speedMin = [85, 105, 125]
#speedMax = [104, 124, 999]
#colors = ['purple', 'red', 'orange']
#labels = ['85-104', '105-124', '>=125']

# Inclusive range of storm years to analyse: [first year, last year].
YEARS = [1970, 2019]

def convertLatitude(latString):
	"""Convert a latitude string such as '28.0N' or ' 12.5S' to a signed float.

	The last non-blank character is the hemisphere letter: 'N' yields a
	positive value, anything else (normally 'S') a negative one. Leading and
	trailing whitespace, which CSV fields may carry, is ignored.
	"""
	text = latString.strip()
	value = float(text[:-1])
	return value if text[-1] == 'N' else -value

def convertLongitude(lonString):
	"""Convert a longitude string such as '94.8W' or ' 120.0E' to a signed float.

	The last non-blank character is the hemisphere letter: 'W' yields a
	negative value, anything else (normally 'E') a positive one. Leading and
	trailing whitespace, which CSV fields may carry, is ignored.
	"""
	text = lonString.strip()
	value = float(text[:-1])
	return -value if text[-1] == 'W' else value

def _flush_year(year, pLand, storms, stormsNL):
	"""Record one season's near-land fraction per speed band, then reset the counters.

	Prints the raw counts for ``year``. For every band key in ``speedMin`` it
	appends ``stormsNL / storms`` to ``pLand``, or the 99.99 sentinel when the
	band had no storms (plot() skips that sentinel). It then zeroes both counters.
	"""
	print(year, storms, stormsNL)
	for speed in speedMin:
		if storms[speed] > 0:
			pLand[speed].append(stormsNL[speed] / storms[speed])
		else:
			pLand[speed].append(99.99)
		storms[speed] = 0
		stormsNL[speed] = 0


def read_data(filename):
	"""Read a HURDAT-style CSV and return, per speed band, the yearly near-land fraction.

	Storm header lines start with BASIN1 or BASIN2. The year is read from
	characters 4-7 of the ID (assumed format like ``AL011851``). The ``rows``
	column gives how many track lines follow the header.

	A storm is counted only if at least one track point has type TS or HU.
	Its intensity is the largest wind over all non-``EX`` points. It counts
	as near land if any TS/HU point has ``elevations.getElevation(lon, lat)``
	above -500 (units per Elevations -- presumably metres; confirm).

	Args:
		filename: path to the CSV file.

	Returns:
		A dict mapping each value of ``speedMin`` to a list with exactly one
		entry per year in YEARS[0]..YEARS[1] inclusive. Each entry is the
		fraction (0-1) of that band's storms that came near land, or 99.99 when
		the band had no storms that year.
	"""
	elevations = Elevations()
	elevations.loadFiles()
	pLand = {speed: [] for speed in speedMin}
	columns1 = {"desig": 0, "name": 1, "rows": 2}
	columns2 = {"ymd": 0, "time": 1, "landfall": 2, "stormType": 3, "lat": 4, "lon": 5, "wind": 6, "pressure": 7}
	current_year = YEARS[0]
	storms = {speed: 0 for speed in speedMin}
	stormsNL = {speed: 0 for speed in speedMin}
	# None until the first storm header; track lines before any header are ignored.
	year = None
	with open(filename, 'r', newline='') as f:
		for line in csv.reader(f, delimiter=','):
			if not line:
				continue  # csv.reader yields [] for blank lines
			firstToken = line[0]
			if firstToken.startswith((BASIN1, BASIN2)):
				# Header line for a new storm: reset the per-storm state.
				rows = int(line[columns1["rows"]])
				row = 0
				maxWind = 0
				legitStorm = False
				landfall = False
				year = int(firstToken[4:8])
				if year > YEARS[1]:
					break
				if year < current_year:
					continue
				# Close out every season from current_year up to this storm's
				# year. A year with no storms still gets an entry, so each list
				# stays aligned with range(YEARS[0], YEARS[1] + 1).
				while current_year < year:
					_flush_year(current_year, pLand, storms, stormsNL)
					current_year += 1
			elif year is not None and year >= YEARS[0]:
				lat = convertLatitude(line[columns2["lat"]])
				lon = convertLongitude(line[columns2["lon"]])
				wind = int(line[columns2["wind"]])
				stormType = line[columns2["stormType"]].strip()
				ele = None
				# Only storms that reach TS or HU count, and elevation is only
				# sampled at those points.
				if stormType == 'TS' or stormType == 'HU':
					legitStorm = True
					ele = elevations.getElevation(lon, lat)
				if wind > maxWind and stormType != 'EX':
					maxWind = wind
				if ele is not None and ele > -500:
					landfall = True
				row = row + 1
				if row == rows and legitStorm:
					# Last track line of this storm: file it under its speed band.
					# Bounds are inclusive, matching labels such as '35-54'.
					for sMin, sMax in zip(speedMin, speedMax):
						if sMin <= maxWind <= sMax:
							storms[sMin] = storms[sMin] + 1
							if landfall:
								stormsNL[sMin] = stormsNL[sMin] + 1
	# Flush the last season read, plus any trailing years the file never reached.
	while current_year <= YEARS[1]:
		_flush_year(current_year, pLand, storms, stormsNL)
		current_year += 1
	return pLand
		
def plot():
	"""Plot, per speed band, the yearly fraction of storms that came near land.

	Reads FILENAME via read_data() and scatters every year with a real value,
	skipping the 99.99 "no storms" sentinel. For each band with at least two
	such years it overlays a least-squares trend line. It then saves the
	figure to output.png and shows it.
	"""
	pLand = read_data(FILENAME)
	years = range(YEARS[0], YEARS[1] + 1)

	plt.figure(figsize=(8,4))
	# Fixing the y-limits once disables y-autoscaling for the later scatters,
	# which matches setting them on every iteration.
	axes = plt.gca()
	axes.set_ylim([0,1.2])

	for speed, color, label in zip(speedMin, colors, labels):
		# zip stops at the shorter sequence, so a short list cannot raise IndexError.
		points = [(year, p) for year, p in zip(years, pLand[speed]) if p != 99.99]
		validYears = [year for year, _ in points]
		validPs = [p for _, p in points]

		plt.scatter(validYears, validPs, color=color, label=label, alpha=0.6)

		# linregress raises ValueError when all x values are identical, which
		# includes the single-point case, so require at least two years.
		if len(validYears) >= 2:
			slope, intercept, *_ = stats.linregress(validYears, validPs)
			trend = [slope * y + intercept for y in validYears]
			plt.plot(validYears, trend, color=color, alpha=0.5)

	plt.legend(ncol = 4, loc='upper right')
	# NOTE(review): the title says "Percent" but the plotted values are fractions (0-1).
	plt.title("Percent of " + BASINTITLE + " storms near land " + str(YEARS[0]) + "-" + str(YEARS[1]))
	pngFile = "output.png"
	plt.savefig(pngFile)
	plt.show()

if __name__ == '__main__':
	# Script entry point: read FILENAME, then build, save and show the chart.
	plot()
