import numpy as np
import math
from matplotlib import rc
rc('text', usetex=True)
rc('font', family='serif')

import pylab as pl
#from scipy import *
#import scipy.signal
from h5py import File
#import pyfits
import os
import string
from matplotlib import pyplot 
from matplotlib import pyplot as plt
from matplotlib import pylab
from matplotlib import colors
#from pyreadcol import *
import numpy as np
from matplotlib.ticker import NullFormatter

# Time of the snapshot being analysed. Star times read from the catalogues are
# subtracted from it to form look-back times (see the histogram step), so it
# presumably shares their units -- TODO confirm.
currenttime = 12.756587046328

# Halo-mass binning: bins of width `binwidth` in log10(mass), starting at
# massmin.
massmin = 1.0e6
bins =31  # 300
binwidth = 0.1 #math.log10(massmax/massmin)/bins
# Upper mass edge implied by the binning above (not referenced elsewhere in
# this script as shown).
massmax = 10**(math.log10(massmin)+bins*binwidth)
# Look-back-time binning: `tbins` linear bins over [tmin, tmax], in the same
# units as currenttime. NOTE: tmin only enters twidth -- the histogram index is
# (currenttime - time)/twidth with no tmin offset, so binning assumes tmin == 0.
tbins = 100
tmin = 0.0
tmax = 8.0
twidth = (tmax-tmin)/tbins
# Scale factor applied to the time-bin edges for the Myr axis:
# 6.88027e14 (presumably the simulation time unit in seconds -- confirm)
# / 3.15569e7 seconds per year / 1e6 years per Myr.
tunit = 6.88027e+14/3.15569e7/1e6

# Read the two PopIII halo catalogues in lock-step. As parsed here, each halo
# block in either file is:
#   header line  -- column 1 is the halo mass (only the type-5 header is used)
#   count line   -- column 0 is the number N of star lines that follow
#   N star lines -- column 1 is the star mass, column 2 the value kept (used
#                   later as a formation time)
#   2 more lines -- skipped
# NOTE(review): the type-1 header is read but never compared with the type-5
# one, so this assumes both files list the same haloes in the same order --
# confirm against whatever writes them.
halolist = []    # halo mass of every halo with at least one qualifying star
PopIIItime = []  # per halo: column-2 values of stars with mass above 5

# `with` guarantees both files are closed, even if a malformed line raises.
with open('PopIII_haloes_0002_type5.txt', 'r') as f, \
        open('PopIII_haloes_0002_type1.txt', 'r') as f1:
    line = f.readline()
    line1 = f1.readline()
    while line != '':
        mass = float(line.split()[1])
        line = f.readline()
        line1 = f1.readline()
        pn = line.split()[0]
        pn1 = line1.split()[0]
        m0 = []
        for i in range(int(pn)):
            cols = f.readline().split()
            starmass = float(cols[1])
            # NOTE(review): type-5 masses above 1e20 are divided by 1e20
            # (looks like an encoded value); the type-1 loop below does not do
            # this -- confirm the asymmetry is intended.
            if starmass > 1e20:
                starmass = starmass / 1e20
            if starmass > 5:
                m0.append(float(cols[2]))
        for i in range(int(pn1)):
            cols = f1.readline().split()
            starmass = float(cols[1])
            if starmass > 5:
                m0.append(float(cols[2]))
        # Only haloes contributing at least one star are recorded.
        if m0:
            halolist.append(mass)
            PopIIItime.append(m0)
        # Skip the two trailing lines of this halo block in each file.
        for fh in (f, f1):
            fh.readline()
            fh.readline()
        line = f.readline()
        line1 = f1.readline()


# 2-D histogram of PopIII star counts: axis 0 is the halo-mass bin, axis 1 the
# look-back-time bin. Mass bin i covers
# [massmin*10**(i*binwidth), massmin*10**((i+1)*binwidth)); time bin j covers
# [j*twidth, (j+1)*twidth) of (currenttime - time).
totalPopIII = np.zeros([bins+1,tbins+1])

# floor() rather than int(): int() truncates toward zero, which would fold
# values just below a range's lower end into bin 0. Indices outside the array
# are counted and skipped, instead of wrapping to the far end (negative index)
# or raising IndexError (too large).
n_outside = 0
for halomass, PopIII in zip(halolist, PopIIItime):
    index1 = int(math.floor(math.log10(halomass/massmin)/binwidth))
    for time in PopIII:
        index2 = int(math.floor((currenttime-time)/twidth))
        if 0 <= index1 <= bins and 0 <= index2 <= tbins:
            totalPopIII[index1, index2] += 1
        else:
            n_outside += 1
if n_outside:
    print('Skipped %d PopIII entries outside the histogram range' % n_outside)

# Both catalogue files are done with; calling close() on an already-closed
# file is allowed, so this is safe however they were opened.
f.close()
f1.close()

# Histogram to plot: rows are halo-mass bins, columns are look-back-time bins.
data = totalPopIII
# Lower bin edges on each axis, matching how counts were binned: halo mass is
# uniform in log10 from massmin; look-back time is linear from 0, scaled by
# tunit for the Myr axis.
xvals = 10**(math.log10(massmin)+np.array(range(bins+1))*binwidth)
yvals = np.array(range(tbins+1))*twidth*tunit

# Labels are typeset by LaTeX (usetex is enabled at import time). Plain words
# stay outside math mode, where LaTeX would drop the spaces and italicise them.
# The solar-mass symbol is \odot, written in a raw string so the backslash
# reaches LaTeX.
ylabel = r'Look Back Time [Myr]'
xlabel = r'Halo Mass [$M_{\odot}$]'

# Bin widths on each axis. They are not used by the plotting code in this file;
# the counts are deliberately left un-normalised (no division by bin area).
# No log-width is taken for the time axis because yvals[0] is 0.
dy = yvals[1]-yvals[0]
dlogx = np.log10(xvals[1])-np.log10(xvals[0])

# Colour-scale limits as log10 exponents of the halo count. Non-empty bins
# hold at least one count, so the lower exponent is 0. The maximum is floored
# at one count so an empty histogram gives 0 instead of log10(0) = -inf.
intmin = 0
intmax = np.log10(max(data.max(), 1.0))

# Single-argument print() behaves the same on Python 2 and 3.
print(data)
print('%s %s' % (intmin, intmax))

fig = pl.figure()
ax = fig.add_subplot(111)

# Logarithmic colour scale over halo counts. intmin/intmax are log10 exponents,
# so convert them back to counts. vmin/vmax are passed by keyword because
# LogNorm's first positional parameter is vmin alone.
# Only the pcolor() image drawn later is used. A contourf() of the same data
# would treat the lower bin edges as sample points and sit half a bin off
# the pcolor cells.
mynorm = colors.LogNorm(vmin=10.0**intmin, vmax=10.0**intmax)

# pcolor with flat shading needs cell *edges*: one more x edge than data.T has
# columns and one more y edge than it has rows. xvals/yvals hold lower edges,
# so append the upper edge of the last bin on each axis.
if xvals.size == data.T.shape[1]:
    # Mass bins are uniform in log10, so the next edge is a geometric step.
    xvals = np.append(xvals, xvals[-1] * (xvals[-1] / xvals[-2]))
if yvals.size == data.T.shape[0]:
    # Time bins are linear, so step forward by one bin width.
    yvals = np.append(yvals, yvals[-1] + (yvals[-1] - yvals[-2]))

fontsize=24  # NOTE(review): not used by any call in this script as shown.

# Draw the count map on `ax` with the shared log norm; X/Y are cell edges.
# Zero-count cells are presumably masked by LogNorm and left unpainted.
X, Y = np.meshgrid(xvals, yvals)
im = ax.pcolor(X, Y, data.T, cmap=pl.cm.jet, norm=mynorm)
ax.set_xscale('log')
cbar = fig.colorbar(im, fraction=0.05, ticks=(1, 10, 100, 1000), pad=0.01, shrink=0.9)
cbar.set_label('Halo Number')

# Operate on `ax`/`fig` explicitly rather than pyplot's implicit current
# axes/figure, which the colorbar call also creates axes alongside.
ax.set_xlim(xvals.min(), xvals.max())
ax.set_ylim(yvals.min(), yvals.max())
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
fig.savefig('PopIIItime_halo_phase.png', format='png')
# Unregister the figure from pyplot so its memory is released (clf() only
# emptied it).
plt.close(fig)

