import numpy as np
import math
from matplotlib import rc
rc('font', size=18)
#rc('text', usetex=True)
#rc('font', family='serif')

import pylab as pl
from h5py import File
import os
import string
from matplotlib import pyplot 
from matplotlib import pyplot as plt
from matplotlib import pylab
from matplotlib import colors
import numpy as np
from matplotlib.ticker import NullFormatter

# ---------------------------------------------------------------------------
# Binning configuration
# ---------------------------------------------------------------------------

# "Now" (in code time units) for the first snapshot: lookback times of the
# Redshift17.91_DD11 stars are computed as currenttime - creation time.
currenttime = 9.9284345716569

# Halo-mass axis: `bins` logarithmic bins of width `binwidth` dex from massmin.
massmin = 1.0e6
bins = 31
binwidth = 0.1
# Upper mass edge of the grid (NOTE(review): appears unused below).
massmax = 10**(math.log10(massmin) + bins*binwidth)

# Lookback-time axis (code units): `tbins` logarithmic bins from tmin to tmax.
tmin, tmax = 0.04, 8.0
tbins = 100
twidth = (np.log10(tmax) - np.log10(tmin))/tbins

# Multiplies code-unit times to give Myr (3.15569e7 s per yr, 1e6 yr per Myr);
# 6.88027e+14 is presumably the simulation time unit in seconds -- confirm.
tunit = 6.88027e+14/3.15569e7/1e6

def _read_popiii_haloes(type5_path, type1_path, overlap_path, min_star_mass=5):
    """Parse one snapshot's pair of Pop III halo catalogues.

    The two catalogue files are read in lockstep, one halo record at a time.
    As consumed below, each record is:
      * a header line (the halo mass is column 1 of the type-5 header; the
        type-1 header is read but not used);
      * a line whose column 0 is the number of star lines that follow;
      * that many star lines (column 1: star mass, column 2: the time value
        collected into the result);
      * two further lines, which are skipped.
    Only the type-5 file's EOF ends the loop. The type-1 file is assumed to
    list the same haloes in the same order (NOTE(review): not checked).

    Args:
        type5_path: path of the "type5" catalogue.
        type1_path: path of the "type1" catalogue.
        overlap_path: file with one integer per line. A halo whose 0-based
            record index is listed there is dropped.
        min_star_mass: keep only stars whose mass is strictly above this
            (presumably Msun -- TODO confirm units).

    Returns:
        (halo_masses, halo_times): parallel lists. halo_times[k] lists the
        time values of the kept stars of halo k. Haloes with no kept star,
        or listed in the overlap file, are omitted.
    """
    with open(overlap_path, 'r') as overlap_file:
        # A set gives O(1) membership tests; blank lines are ignored.
        overlapping = set(int(entry) for entry in overlap_file if entry.strip())

    halo_masses = []
    halo_times = []
    with open(type5_path, 'r') as f5, open(type1_path, 'r') as f1:
        header = f5.readline()
        f1.readline()
        record = 0
        while header != '':
            mass = float(header.split()[1])
            n5 = int(f5.readline().split()[0])
            n1 = int(f1.readline().split()[0])
            times = []
            for _ in range(n5):
                fields = f5.readline().split()
                starmass = float(fields[1])
                # NOTE(review): type-5 masses above 1e20 are divided by 1e20,
                # apparently undoing an encoding offset -- confirm.
                if starmass > 1e20:
                    starmass = starmass / 1e20
                if starmass > min_star_mass:
                    times.append(float(fields[2]))
            for _ in range(n1):
                fields = f1.readline().split()
                if float(fields[1]) > min_star_mass:
                    times.append(float(fields[2]))
            if record not in overlapping and times:
                halo_masses.append(mass)
                halo_times.append(times)
            # Skip the two trailing lines of this record in both files.
            for _ in range(2):
                f5.readline()
                f1.readline()
            header = f5.readline()
            f1.readline()
            record += 1
    return halo_masses, halo_times


def _bin_popiii(halo_masses, halo_times, now):
    """Count Pop III stars on the (halo mass, lookback time) grid.

    Mass bin i covers [massmin*10**(i*binwidth), massmin*10**((i+1)*binwidth)).
    Time bin k covers lookback times [tmin*10**(k*twidth),
    tmin*10**((k+1)*twidth)), where lookback = now - time, clamped below at tmin.

    Returns a (bins+1) x (tbins+1) float array of counts. Stars falling
    outside that grid are skipped and reported. Before this change they
    raised IndexError (too large), or, for haloes below massmin, were
    silently added to the highest mass bins through negative indexing.
    """
    counts = np.zeros([bins + 1, tbins + 1])
    log_tmin = np.log10(tmin)
    skipped = 0
    for halomass, times in zip(halo_masses, halo_times):
        # floor (not int) so masses just below massmin get index -1, not 0.
        mass_bin = int(math.floor(math.log10(halomass / massmin) / binwidth))
        for time in times:
            lookback = now - time if now - time > tmin else tmin
            time_bin = int((np.log10(lookback) - log_tmin) / twidth)
            if 0 <= mass_bin <= bins and time_bin <= tbins:
                counts[mass_bin][time_bin] += 1
            else:
                skipped += 1
    if skipped:
        print('Skipped %d Pop III star(s) outside the mass/lookback-time grid' % skipped)
    return counts


# Snapshot Redshift17.91_DD11, binned relative to the module-level `currenttime`.
halolist, PopIIItime = _read_popiii_haloes(
    '../Redshift17.91_DD11/PopIII/PopIII_haloes_0011_type5.txt',
    '../Redshift17.91_DD11/PopIII/PopIII_haloes_0011_type1.txt',
    '../Redshift17.91_DD11/PopIII/Overlaped_haloes_list.txt')
totalPopIII = _bin_popiii(halolist, PopIIItime, currenttime)

# Snapshot Redshift15, with its own "now".
currenttime = 12.756587046328

halolist1, PopIIItime1 = _read_popiii_haloes(
    '../Redshift15/PopIII/PopIII_haloes_0002_type5.txt',
    '../Redshift15/PopIII/PopIII_haloes_0002_type1.txt',
    '../Redshift15/PopIII/Overlaped_haloes_list.txt')
totalPopIII1 = _bin_popiii(halolist1, PopIIItime1, currenttime)

def _plot_popiii_panel(fig, position, data):
    """Draw one Pop III count map into subplot `position` of `fig`.

    Args:
        fig: the matplotlib figure to draw into.
        position: subplot code passed to fig.add_subplot (e.g. 211).
        data: (bins+1) x (tbins+1) array of counts, indexed [mass bin, time bin].

    Returns:
        The created axes.
    """
    xlabel = r'Halo Mass [M$_{\odot}$]'
    ylabel = r'Lookback Time [Myr]'

    xvals = 10**(math.log10(massmin) + np.arange(bins + 1)*binwidth)
    # NOTE(review): the time-bin edges in code units are tmin*10**(k*twidth).
    # The extra "+ tunit*tmin" shifts every edge by tmin -- confirm intended.
    yvals = 10**(np.arange(tbins + 1)*twidth)*tunit*tmin + tunit*tmin
    yvals[0] = 1.5  # hard-coded lower edge of the time axis (Myr)

    # Colour scale runs from 1 (10**intmin) to the panel's peak count. The
    # peak is floored at 1 so an empty panel does not hit log10(0).
    intmin = 0
    peak = max(data.max(), 1.0)
    intmax = np.log10(peak)
    # Bug fix: the original LogNorm([intmin, intmax]) passed a list as vmin and
    # left vmax unset. Current matplotlib rejects that with a TypeError.
    mynorm = colors.LogNorm(vmin=10**intmin, vmax=peak)

    ax = fig.add_subplot(position)
    if intmax > intmin:
        # contourf needs strictly increasing levels: 1, sqrt(peak), peak.
        ax.contourf(xvals, yvals, data.T, np.logspace(intmin, intmax, 3),
                    cmap=pl.cm.jet, norm=mynorm)

    # NOTE(review): X, Y have the same shape as data.T. Depending on the
    # matplotlib version/shading, pcolor either drops the last row/column or
    # centres cells on these coordinates -- confirm.
    X, Y = np.meshgrid(xvals, yvals)
    im = ax.pcolor(X, Y, data.T, cmap=pl.cm.jet, norm=mynorm)
    ax.set_xscale('log')
    ax.set_yscale('log')
    cbar = fig.colorbar(im, ax=ax, fraction=0.05, ticks=(1, 10, 100, 1000),
                        pad=0.01, shrink=0.9)
    cbar.set_label('Pop III Number')

    ax.set_xlim(xvals.min(), xvals.max())
    ax.set_ylim(yvals.min(), yvals.max())
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax


# Top panel: Redshift17.91_DD11 counts; bottom panel: Redshift15 counts.
fig = pl.figure(figsize=(8, 12))
_plot_popiii_panel(fig, 211, totalPopIII)
_plot_popiii_panel(fig, 212, totalPopIII1)

fig.savefig('PopIIItime_halo_phase_nooverlap_log.eps', format='eps')
fig.clf()

