from math import log10
import numpy as na
#nullfmt = NullFormatter()
from mpi4py import MPI

# Handles for the global communicator: this process's rank and the total
# number of ranks that share the grid.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# --- Run parameters ---------------------------------------------------------
redshift = 15.0   # referenced only by the commented-out rhoH estimate below
nx = 512          # cells per side; the grid holds nx**3 cells
Ex = 1e3          # photon energy (eV)
cfl = 0.5         # not referenced elsewhere in this script

# --- Physical constants (cgs unless noted) ----------------------------------
# Boltzmann constant in eV/K. CH() multiplies a temperature in K by it to get
# kT in eV. NOTE(review): CODATA gives 8.617333e-5 eV/K; the value is kept
# unchanged so results stay reproducible -- confirm it is intentional.
kb_eV = 8.61423e-5
# Legacy alias kept for existing code. Despite its name this is NOT an
# erg <-> eV conversion factor; that conversion is `eV` below.
erg_eV = kb_eV
G = 6.673e-8        # gravitational constant (cm^3 g^-1 s^-2)
H0 = 71.0           # Hubble constant (km s^-1 Mpc^-1)
Omega_b = 0.0449    # baryon density parameter
Omega_m = 0.266     # matter density parameter
yH = 0.76           # hydrogen mass fraction (scales the input density)
IH = 13.6           # hydrogen ionization potential (eV)
h = 6.626e-27       # Planck constant (erg s)
kb = 1.38e-16       # Boltzmann constant (erg/K)
eV = 1.602e-12      # erg per eV
yr = 3.1557e7       # seconds per year
# Mass per hydrogen nucleus (g), used to turn density into number density.
# NOTE(review): 1.661e-24 g is the atomic mass unit, not the H-atom mass
# (1.6735e-24 g); kept unchanged -- confirm intent.
mH = 1.661e-24

#rhoH = 3 * (H0/3.086e19)**2 / (8 * na.pi * G) * yH * (1.0+redshift)**3
nu = Ex*eV / h      # frequency (Hz) of an Ex-eV photon; not referenced elsewhere

# Cooling table (primordial abundances are in the 2nd column). Only columns
# 0 and 1 are kept: column 0 is compared against log10(T) and column 1 is
# used as a log10 value by the main loop.
cool_table = na.loadtxt("zcool_sd93.dat")[:, 0:2]

# Scatter/Gather give every rank an equal, contiguous slice of the grid, so
# the total cell count must divide evenly. Every rank evaluates the same
# condition, so all ranks raise together instead of one rank hanging inside
# a collective call.
if (nx**3) % size != 0:
    raise ValueError("nx**3 = %d cells cannot be split evenly over %d MPI ranks"
                     % (nx**3, size))
# Floor division keeps this an int: na.zeros() rejects the float length that
# `/` produces under Python 3.
n_local = nx**3 // size

if rank == 0:
    # Root reads the full grids (na.fromfile default dtype, float64).
    # Assumes each file holds exactly nx**3 values -- TODO confirm.
    rho = na.fromfile('Density_512.dat') * yH   # density scaled by yH
    Tem = na.fromfile('Temperature_512.dat')
    ef = na.fromfile('HII_Fraction_512.dat')
else:
    # Non-root ranks only receive; Scatter ignores their send buffer.
    rho = None
    Tem = None
    ef = None
# NOTE: a per-cell X-ray flux ('XrayFlux.dat') is not read or scattered; the
# main loop uses a single constant flux instead.

rho_local = na.zeros(n_local)
Tem_local = na.zeros(n_local)
ef_local = na.zeros(n_local)

comm.Scatter(rho, rho_local, root=0)
comm.Scatter(Tem, Tem_local, root=0)
comm.Scatter(ef, ef_local, root=0)
comm.Barrier()

# Drop the full-size arrays (only root holds real data) before the long loop.
del rho, Tem, ef

def CH(T):
    """Return the collisional ionization rate coefficient used for neutral H.

    Evaluates exp(sum_k c_k * (ln T_eV)**k), where T_eV = T * erg_eV is the
    temperature T (K) converted to eV with the module constant erg_eV.
    The nine coefficients appear to be the Abel et al. (1997) fit for
    H + e -> H+ + 2e (presumably cm^3 s^-1) -- confirm against the source.
    Works on scalars or numpy arrays.
    """
    fit = (-32.71396786, 13.536556, -5.73932875, 1.56315498, -0.2877056,
           3.48255977e-2, -2.63197617e-3, 1.11954395e-4, -2.03914985e-6)
    ln_te = na.log(T * erg_eV)
    # Terms are summed in increasing power, the same order as a running total.
    exponent = sum(c * ln_te**k for k, c in enumerate(fit))
    return na.exp(exponent)

def alphaB(T):
    """Return the recombination coefficient as a power law in temperature.

    alphaB = 2.59e-13 * (T / 1e4 K)**-0.7; the name suggests case-B
    recombination (presumably cm^3 s^-1). T is a temperature in K.
    """
    t4 = T / 1e4
    return 2.59e-13 * t4 ** -0.7

def sigmaH(E):
    """Return the photoionization cross section of neutral hydrogen at energy E.

    The functional form and constants match the Verner et al. (1996) H I fit
    (E0 = 0.4298, sigma0 = 5.475e-14, P = 2.963). Presumably E is in eV and
    the result is in cm^2; the caller passes the photon energy Ex (eV).
    The fit is not cut off below the 13.6 eV ionization threshold.
    """
    y = E / 0.4298
    tail = (1 + na.sqrt(E / 14.13)) ** (-2.963)
    # Multiply in the same left-to-right order as the closed-form expression.
    return 5.475e-14 * (y - 1) ** 2 * y ** (-4.0185) * tail

def secondary_ion(x):
    """Return the fraction of absorbed photon energy that goes into ionization.

    The caller turns this fraction into a number of secondary ionizations
    by multiplying by Ex/IH. The form appears to be the Shull & van
    Steenberg (1985) fit in the ionized fraction x. The fit is meant for
    0 <= x <= 1; x > 1 yields NaN for numpy floats.
    """
    neutral_weight = 1.0 - x**0.4092
    return 0.3908 * neutral_weight**1.7592

def secondary_heat(x):
    """Return the fraction of absorbed photon energy deposited as heat.

    x is the ionized fraction. The form appears to be the Shull & van
    Steenberg (1985) fit. It gives 0 at x = 0, rises to 0.9971 at x = 1,
    and is meant for 0 <= x <= 1.
    """
    residual = (1.0 - x**0.2663) ** 1.3163
    return 0.9971 * (1.0 - residual)



# ---------------------------------------------------------------------------
# Integrate the ionized fraction x and thermal energy Eth of every local cell
# from t = 0 to tfinal under a constant X-ray flux.
# ---------------------------------------------------------------------------
F = 1e-6            # X-ray flux, presumably erg s^-1 cm^-2; intended to become
                    # per-cell (XrayFlux.dat) -- currently one value for all cells
tfinal = 1e7 * yr   # integration end time (s)
dt0 = 1e6 * yr      # largest allowed step (s)
step_frac = 0.01    # max fractional change of x or Eth in one step
x_floor = 1e-10     # lower bound on x inside the step limiter (see _timestep)


def _cooling_rate(T):
    """Return the tabulated cooling term for temperature T (K).

    Returns 0.0 below 1e4 K. Otherwise column 1 of `cool_table` (a log10
    value) is linearly interpolated in log10(T) between the two bracketing
    rows of column 0, and 10**result is returned. Outside the table,
    na.interp clamps to the end values, so the lookup never indexes past the
    table. Assumes column 0 is sorted ascending (na.interp requires it).
    """
    if T < 1e4:
        return 0.0
    log_cool = na.interp(log10(T), cool_table[:, 0], cool_table[:, 1])
    return 10.0 ** log_cool


def _timestep(x, dx_dt, Eth, dEth_dt):
    """Return a step that changes x and Eth by at most `step_frac` of their
    current values, capped at dt0.

    A zero rate imposes no limit, rather than dividing by zero. x is
    floored at `x_floor` so that a fully neutral cell (x == 0) still
    advances instead of receiving dt == 0 forever.
    """
    dt = dt0
    if dx_dt != 0.0:
        dt = min(dt, step_frac * max(x, x_floor) / abs(dx_dt))
    if dEth_dt != 0.0:
        dt = min(dt, step_frac * abs(Eth / dEth_dt))
    return dt


def _evolve_cell(rhoH, T0, x0):
    """Integrate one cell to tfinal; return (final temperature, final x).

    rhoH : input density scaled by yH; electron density is ne = rhoH*x/mH.
    T0   : initial temperature (K); the thermal energy is tracked as kb*T.
    x0   : initial ionized fraction.
    """
    Eth = kb * T0
    x = x0
    t = 0.0
    # Flux-dependent factors are identical for every step of every cell.
    photon_flux = F / (Ex * eV)    # presumably photons s^-1 cm^-2
    sigma = sigmaH(Ex)
    while t < tfinal:
        ne = rhoH * x / mH
        T = Eth / kb
        kph = secondary_ion(x) * photon_flux * sigma * (Ex / IH)
        heat = secondary_heat(x) * photon_flux * sigma * (Ex * eV)
        # NOTE(review): `heat` carries no density factor and `cool` is the
        # bare tabulated value; confirm the table's normalization matches the
        # units of Eth before trusting the energy equation.
        cool = _cooling_rate(T)
        # Collisional ionization (CH) turns neutral H into ionized H, so it
        # adds to photoionization; recombination (alphaB) removes ionization.
        dx_dt = (1.0 - x) * (kph + ne * CH(T)) - x * ne * alphaB(T)
        dEth_dt = heat - cool
        # Never step past tfinal.
        dt = min(_timestep(x, dx_dt, Eth, dEth_dt), tfinal - t)
        t += dt
        # Keep x physical: secondary_ion/secondary_heat return NaN for x > 1
        # (a negative base raised to a fractional power).
        x = min(max(x + dx_dt * dt, 0.0), 1.0)
        Eth += dEth_dt * dt
    return Eth / kb, x


# Same length and dtype as the scattered input slice.
T_final_local = na.zeros_like(rho_local)
x_final_local = na.zeros_like(rho_local)

for ii, (rhoH, T, x) in enumerate(zip(rho_local, Tem_local, ef_local)):
    T_final_local[ii], x_final_local[ii] = _evolve_cell(rhoH, T, x)

# Collect the per-rank results into full-grid arrays on the root rank; the
# other ranks pass None as the (ignored) receive buffer.
if rank != 0:
    T_final = x_final = None
else:
    T_final = na.zeros(nx**3)
    x_final = na.zeros(nx**3)

for local_part, full_grid in ((T_final_local, T_final),
                              (x_final_local, x_final)):
    comm.Gather(local_part, full_grid, root=0)

# Root writes raw binary dumps via ndarray.tofile (no header).
# NOTE: "fration" is the existing output filename spelling; it is left as-is
# because downstream readers may expect it.
if rank == 0:
    outputs = ((T_final, 'Final_Temperature.dat'),
               (x_final, 'Final_Electron_fration.dat'))
    for grid, path in outputs:
        grid.tofile(path)
