import numpy as np
import math
from matplotlib import rc
rc('font', size=18)
#rc('text', usetex=True)
#rc('font', family='serif')

import pylab as pl
from h5py import File
import os
import string
from matplotlib import pyplot 
from matplotlib import pyplot as plt
from matplotlib import pylab
from matplotlib import colors
import numpy as np
from matplotlib.ticker import NullFormatter

# Halo-mass binning for the histograms: `bins` logarithmic bins, each
# `binwidth` dex wide, starting at `massmin` (halo mass in solar masses, per
# the axis label).  `massmax` is the upper edge implied by those three values.
massmin = 1.0e6
binwidth = 0.1
bins = 31
massmax = 10**(bins * binwidth + math.log10(massmin))

# Pop III number axis: counts 0..numbins each get their own histogram column.
numbins = 101

def _bin_popIII_haloes(path):
    """Histogram the haloes in *path* by log halo mass and Pop III number.

    Each non-blank line of the file is read as: column 0 = halo mass,
    column 1 = integer Pop III count.

    Returns an array of shape (bins + 1, numbins + 1).  Element [i, n] counts
    the haloes in logarithmic mass bin i that host exactly n Pop III stars.
    Bin i covers [massmin * 10**(i*binwidth), massmin * 10**((i+1)*binwidth)).

    Some rows do not fit the array: non-positive masses, masses below
    `massmin` or beyond the last bin, and Pop III counts above `numbins`.
    These rows are skipped and their number is reported on stderr.  They are
    not wrapped around by negative indexing, and they do not crash with an
    IndexError.
    """
    import sys

    counts = np.zeros([bins + 1, numbins + 1])
    skipped = 0
    with open(path, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Tolerate blank / trailing lines instead of crashing on fields[0].
                continue
            halomass = float(fields[0])
            npop = int(fields[1])
            if halomass <= 0:
                # log10 is undefined here; cannot be placed in a log bin.
                skipped += 1
                continue
            # Use floor rather than int(): int() truncates toward zero, so
            # masses just below massmin would be counted in bin 0, and
            # smaller masses would give negative indices that wrap silently.
            index1 = int(math.floor(math.log10(halomass / massmin) / binwidth))
            if not (0 <= index1 <= bins and 0 <= npop <= numbins):
                skipped += 1
                continue
            counts[index1][npop] += 1
    if skipped:
        sys.stderr.write('%s: skipped %d halo(es) outside the histogram range\n'
                         % (path, skipped))
    return counts


# Histograms for the two catalogues; both panels are plotted below.
totalPopIII = _bin_popIII_haloes(
    '../Redshift17.91_DD11/PopIII/PopIII_haloes_nooverlap_allPopIII.txt')
totalPopIII1 = _bin_popIII_haloes(
    '../Redshift15/PopIII/PopIII_haloes_nooverlap_allPopIII.txt')


def _plot_popIII_panel(fig, ax, data):
    """Draw one halo-mass vs Pop III-number histogram on *ax*.

    *data* is indexed data[mass_bin, popIII_number] with shape
    (bins + 1, numbins + 1), as built from the halo catalogues.  Both axes
    are logarithmic, and so is the colour scale.  Empty cells (count 0) are
    masked by the log norm and left blank.  A colorbar labelled
    'Halo Number' is attached to *ax*.
    """
    # pcolor needs cell *edges*: one more entry than cells along each axis.
    # Passing arrays the same size as the data silently drops the last
    # row/column on old matplotlib and is an error on matplotlib >= 3.5.
    xedges = 10**(math.log10(massmin) + np.arange(bins + 2) * binwidth)
    yedges = np.arange(numbins + 2, dtype=float)
    # Pop III number 0 cannot sit on a log axis; draw that row from 0.8 to 1.
    yedges[0] = 0.8

    # Colour range in decades: 10**0 = 1 up to the largest cell count.
    # Clamp to 1 so an all-empty histogram does not take log10(0).
    intmin = 0
    vmax = max(float(data.max()), 1.0)
    intmax = np.log10(vmax)
    print(data)
    print('%s %s' % (intmin, intmax))

    # LogNorm takes vmin and vmax as two scalar counts.  The original code
    # passed a single list of exponents as vmin.
    mynorm = colors.LogNorm(vmin=10**intmin, vmax=vmax)
    im = ax.pcolor(xedges, yedges, data.T, cmap=pl.cm.jet, norm=mynorm)
    ax.set_xscale('log')
    ax.set_yscale('log')
    cbar = fig.colorbar(im, ax=ax, fraction=0.05, ticks=(1, 10, 100, 1000),
                        pad=0.01, shrink=0.9)
    cbar.set_label('Halo Number')

    ax.set_xlim(xedges[0], xedges[-1])
    ax.set_xlabel(r'Halo Mass [M$_{\odot}$]')
    ax.set_ylabel(r'Pop III Number')


# Top panel: catalogue read into totalPopIII (Redshift17.91_DD11 directory).
# Bottom panel: catalogue read into totalPopIII1 (Redshift15 directory).
fig = pl.figure(figsize=(8, 12))
_plot_popIII_panel(fig, fig.add_subplot(211), totalPopIII)
_plot_popIII_panel(fig, fig.add_subplot(212), totalPopIII1)



# Write the two-panel figure to EPS, then clear it.
# NOTE(review): 'nooverlapi' may be a typo for 'nooverlap'.  It is kept as-is
# in case downstream tooling expects this exact filename.
fig.savefig('PopIII_halo_phase_nooverlapi_log.eps', format='eps')
fig.clf()

