from yt.mods import *

from matplotlib import rc
# Global matplotlib text settings for every figure in this script:
# all text is rendered through LaTeX (a TeX installation is required at
# run time), in a serif family, at 18 pt unless overridden per call.
rc('text', usetex=True)
rc('font', family='serif')
rc('font', size=18)

import pylab as pl
#from scipy import *
#import scipy.signal
from h5py import File
#import pyfits
import os
import string
from matplotlib import pyplot 
from matplotlib import pyplot as plt
from matplotlib import pylab
from matplotlib import colors
#from pyreadcol import *
import numpy as np
from matplotlib.ticker import NullFormatter
nullfmt = NullFormatter()

# Number embedded in the output filename ('RD%04i_<filename>.eps').
n = 28

# Read the whitespace-separated halo table. Columns used (0-based):
#   1 -> Mass, 2 -> stellarmass, 4 -> sfr, 6 -> gasfraction, 7 -> sfe
# NOTE(review): column meanings are taken from the variable names only;
# confirm against whatever writes PopII_haloes_SFR.txt.
Mass = []; sfr = []; stellarmass = []; sfe = []; gasfraction = []
with open('PopII_haloes_SFR.txt', 'r') as f:
    for line in f:
        cols = line.split()
        # Every data row is used, including the last one; blank lines and
        # '#' comment lines are skipped rather than crashing float().
        if not cols or cols[0].startswith('#'):
            continue
        # Keep only haloes with a positive stellar mass (log10 is taken below).
        if float(cols[2]) <= 0.0:
            continue
        Mass.append(float(cols[1]))
        sfr.append(float(cols[4]))
        stellarmass.append(float(cols[2]))
        sfe.append(float(cols[7]))
        gasfraction.append(float(cols[6]))

# Mass-like and rate quantities are binned in log10; gas fraction stays linear.
# A zero sfr/sfe becomes -inf here (numpy emits a RuntimeWarning) and then
# falls outside the histogram ranges used below, so it is not counted.
Mass = np.log10(Mass)
sfr = np.log10(sfr)
stellarmass = np.log10(stellarmass)
sfe = np.log10(sfe)
gasfraction = np.array(gasfraction)
# Common x binning for all four panels: log10 of the halo (virial) mass.
x_min, x_max = 6.5, 10   # upper edge could instead be int(Mass.max()+1)
x_bins = 35

# One entry per panel, in plotting order:
# (y quantity, lower y edge, upper y edge, number of y bins).
panel_specs = [
    (stellarmass, 3, 9, 60),
    (sfr, -4, 1, 50),
    (sfe, -3.5, 0, 35),
    (gasfraction, 0, 0.2, 40),
]

# data[k] holds the 2-D counts for panel k; xvals[k]/yvals[k] its bin edges.
data = []; xvals = []; yvals = []
for yquantity, y_min, y_max, y_bins in panel_specs:
    datat, xvalst, yvalst = np.histogram2d(
        Mass, yquantity,
        bins=[x_bins, y_bins],
        range=[[x_min, x_max], [y_min, y_max]],
    )
    data.append(datat)
    xvals.append(xvalst)
    yvals.append(yvalst)

filename = 'virialmass_4plots'
# Per-panel settings, in subplot order (top-left, top-right, bottom-left,
# bottom-right): colorbar label, colour-scale minimum, and axis labels.
fields_label = ["", "", "", ""]
fields_min = [0, 0, 0, 0]
ylabels = [r'$log_{10}(M_{\star}/M_{\odot})$', r'$log_{10}(SFR/(M_{\odot} yr^{-1}))$',r'$log_{10}(f_{\star})$',r'$f_{gas}$']

# Only the bottom row gets an x label; top-row x tick labels are hidden below.
xlabels = ['','',r'$log_{10}(M_{vir}/M_{\odot})$',r'$log_{10}(M_{vir}/M_{\odot})$']

fig = plt.figure(figsize=[16, 8])
# Near-zero hspace butts the two rows against each other.
plt.subplots_adjust(left=0.07, right=0.98, bottom=0.07, top=0.98, hspace=1e-3)

panels = zip(fields_label, fields_min, ylabels, xlabels, xvals, yvals, data)
for i, (cbarlabel, field_min, ylabel, xlabel, xvalst, yvalst, datat) in enumerate(panels, 1):
    # Colour scale spans [configured minimum, largest bin count].
    # NOTE(review): the first x and first y bin are excluded from the maximum
    # (datat[1:, 1:]); the reason is not visible here -- confirm intent.
    intmin = field_min
    intmax = datat[1:, 1:].max()

    ax = fig.add_subplot(2, 2, i)
    # vmin/vmax must be passed separately; a single list is not a valid vmin.
    mynorm = colors.Normalize(vmin=intmin, vmax=intmax)

    # Draw the counts on the true bin edges. histogram2d indexes its result
    # as [x, y], so it is transposed to the [row=y, col=x] layout pcolormesh expects.
    X, Y = np.meshgrid(xvalst, yvalst)
    im = ax.pcolormesh(X, Y, datat.T, cmap=pl.cm.binary, norm=mynorm)

    # One tick per integer count; linspace requires an integer sample count,
    # but histogram counts (and hence intmax) are floats.
    nticks = int(intmax - intmin) + 1
    cbar = fig.colorbar(im, ax=ax, fraction=0.05, format='%i',
                        ticks=np.linspace(intmin, intmax, nticks),
                        pad=0.01, shrink=0.9)
    cbar.set_label(cbarlabel)

    ax.set_xlim(xvalst.min() - 0.1, xvalst.max())
    if i >= 3:
        # Bottom row: top y edge pulled in slightly, presumably so its top
        # tick label does not collide with the panel above -- confirm.
        ax.set_ylim(yvalst.min(), yvalst.max() - 0.01)
    else:
        ax.set_ylim(yvalst.min(), yvalst.max())

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Fixed label position so the y labels line up across panels.
    ax.yaxis.set_label_coords(-0.1, 0.5)

    if i <= 2:
        # Top row: suppress x tick labels (a fresh formatter per axis, since
        # matplotlib formatters are bound to the axis they are attached to).
        ax.xaxis.set_major_formatter(NullFormatter())

fig.savefig('RD%04i_%s.eps' % (n, filename), format='eps')
fig.clf()






