#!/usr/bin/env python


from pNbody import *
from pNbody import libgrid
from pNbody import myNumeric

import time

from optparse import OptionParser



########################################  
#					  
# parser				  
#					  
######################################## 



def parse_options(argv=None):
  """Parse the command-line options of the N-body integrator.

  Parameters
  ----------
  argv : list of str, optional
    Arguments to parse, without the program name. When None (the
    default) optparse reads them from ``sys.argv[1:]``, exactly as the
    script always did.

  Returns
  -------
  optparse.Values
    The parsed options. When ``-i`` is not given but a positional
    argument is (as advertised by the usage string), the first
    positional argument is used as ``input_file``.
  """

  usage = "usage: %prog [options] file"
  parser = OptionParser(usage=usage)

  parser.add_option("-t",
                    action="store",
                    dest="ftype",
                    type="string",
                    default='gadget',
                    help="type of the file",
                    metavar=" TYPE")

  parser.add_option("-i",
                    action="store",
                    dest="input_file",
                    type="string",
                    default=None,
                    help="input file",
                    metavar=" FILE")

  parser.add_option("--eps",
                    action="store",
                    dest="eps",
                    type="float",
                    default=0.25,
                    help="softening length",
                    metavar=" FLOAT")

  parser.add_option("--dTime",
                    action="store",
                    dest="dTime",
                    type="float",
                    default=0.1,
                    help="time step",
                    metavar=" FLOAT")

  parser.add_option("--TimeEnd",
                    action="store",
                    dest="TimeEnd",
                    type="float",
                    default=10,
                    help="final time",
                    metavar=" FLOAT")

  parser.add_option("--dOutputTime",
                    action="store",
                    dest="dOutputTime",
                    type="float",
                    default=0,
                    help="time between output",
                    metavar=" FLOAT")

  parser.add_option("--dStatTime",
                    action="store",
                    dest="dStatTime",
                    type="float",
                    default=0,
                    help="time between system statistics",
                    metavar=" FLOAT")

  # The value is forwarded to getTree() as ErrTolTheta by the main script.
  parser.add_option("--theta",
                    action="store",
                    dest="theta",
                    type="float",
                    default=0.7,
                    help="tree opening criterion (ErrTolTheta)",
                    metavar=" FLOAT")

  (options, args) = parser.parse_args(argv)

  # The usage string advertises a positional "file": honour it when -i
  # was not used, instead of silently ignoring it.
  if options.input_file is None and args:
    options.input_file = args[0]

  return options


  
  
  
  
##############################################
# write integrals
##############################################


def WriteIntegrals(fe,t,T,U,C,P,L,I):
  """Write one line of system integrals to the open stream ``fe``.

  The 16 columns are: t, E (= T + U), T, U, then the first three
  components of each of C, P, L and I. Every value is printed with
  ``%20.10e``, columns are separated by a single space, and the stream
  is flushed after the line is written.
  """
  columns = [t, T + U, T, U]
  for vector in (C, P, L, I):
    columns.extend([vector[0], vector[1], vector[2]])

  template = " ".join(["%20.10e"] * len(columns)) + "\n"
  fe.write(template % tuple(columns))
  fe.flush()
  
  
##############################################
# compute energy kin
##############################################

def ComputeEnergyKin(nb):
  """Return the total kinetic energy, sum of 0.5 * m * |v|**2 over particles."""
  vx = nb.vel[:,0]
  vy = nb.vel[:,1]
  vz = nb.vel[:,2]
  half_mass = 0.5 * nb.mass
  return sum(half_mass * (vx**2 + vy**2 + vz**2))


##############################################
# compute energy pot
##############################################

def ComputeEnergyPot(nb,eps):
  """Return the potential energy as the sum of ``nb.Epot(eps)``.

  NOTE(review): assumes ``nb.Epot`` returns per-particle energies that
  only need summing (pair double-counting already handled) - confirm
  against the pNbody implementation.
  """
  per_particle = nb.Epot(eps)
  return sum(per_particle)
  
    
##############################################
# compute mass center
##############################################

def ComputeMassCenter(nb):
  """Return the mass-weighted mean position of the particles (3-element array)."""
  total_mass = sum(nb.mass)
  return array([sum(nb.pos[:,axis]*nb.mass) / total_mass for axis in (0, 1, 2)])


##############################################
# compute momentum
##############################################

def ComputeMomentum(nb):
  """Return the total linear momentum, sum of m * v over particles (3-element array)."""
  return array([sum(nb.vel[:,axis]*nb.mass) for axis in (0, 1, 2)])

##############################################
# compute angular momentum
##############################################

def ComputeAngularMomentum(nb):
  """Return the total angular momentum as computed by ``nb.Ltot()``."""
  total_angular_momentum = nb.Ltot()
  return total_angular_momentum
    

##############################################
# inertial momentum (moment of inertia)
##############################################

def ComputeInertialMomentum(nb):
  """Return the inertia values (Ix, Iy, Iz columns) as computed by ``nb.minert()``."""
  inertia = nb.minert()
  return inertia


      
################################################################################
#
#                                    MAIN
#
################################################################################

options = parse_options()

ftype       = options.ftype
input_file  = options.input_file
eps         = options.eps
dOutputTime = options.dOutputTime
dStatTime   = options.dStatTime
TimeEnd     = options.TimeEnd
dTime       = options.dTime
ErrTolTheta = options.theta

# A zero or negative time step never advances Time towards TimeEnd, so the
# main loop below would run forever.
if dTime <= 0:
  raise SystemExit("error: --dTime must be > 0 (got %g)" % dTime)


# open file
# (single-argument print(...) calls behave identically on Python 2 and 3)
print("open initial conditions")
nb = Nbody(input_file,ftype=ftype)

print("build the tree")
nb.getTree(ErrTolTheta=ErrTolTheta,force_computation=True)

print("compute acceleration")
nb.acc = nb.TreeAccel(nb.pos,eps)


# some init

Time0 = nb.atime      # start time, taken from the initial conditions
Time  = Time0

Step            = 0
CPUTimeRef      = time.time()
CPUTime         = 0.0

OutputTime      = 0
OutputNumber    = 0

StatTime        = 0

# open output; try/finally guarantees the integrals file is closed (and
# its buffered content written) even if the integration raises.
fi = open("integrals.dat",'w')
try:
  fi.write("# t E T U Cx Cy Cz Px Py Pz Lx Ly Lz Ix Iy Iz\n")

  #####################
  # main loop
  #####################

  while (Time<TimeEnd):

    # write a snapshot once the next output time has been reached
    if (Time >= OutputTime):

      outputname = 'snap_%04d'%(OutputNumber)
      print("Step %06d  writing %s"%(Step,outputname))
      nb.rename(outputname)
      nb.atime = Time
      nb.write()

      OutputNumber += 1
      OutputTime = Time + dOutputTime

    print("Step %06d  Time = %8.3f CPUTime=%8.1f"%(Step,Time,CPUTime))

    # leap-frog (kick-drift-kick), first stage
    nb.vel = nb.vel + nb.acc*dTime/2.    # vel to step n+1/2
    nb.pos = nb.pos + nb.vel*dTime       # pos to step n+1

    # make the tree for the new positions
    nb.getTree(ErrTolTheta=ErrTolTheta,force_computation=True)

    # compute acceleration
    nb.acc = nb.TreeAccel(nb.pos,eps)

    # leap-frog, second stage
    nb.vel = nb.vel + nb.acc*dTime/2.    # vel to step n+1

    # increment time. Time is derived from the step count rather than by
    # adding dTime repeatedly, so rounding errors do not accumulate (ten
    # additions of 0.1 give 0.9999999999999999 < 1.0, which would run one
    # spurious extra step and shift the output/statistics times).
    Step = Step + 1
    Time = Time0 + Step*dTime
    CPUTime = time.time()-CPUTimeRef

    # write stats
    if (Time >= StatTime):

      print("Compute System Statistic")

      # compute integrals
      # for the potential computation, we do not need to recompute the tree
      T = ComputeEnergyKin(nb)
      U = ComputeEnergyPot(nb,eps)
      C = ComputeMassCenter(nb)
      P = ComputeMomentum(nb)
      L = ComputeAngularMomentum(nb)
      I = ComputeInertialMomentum(nb)

      WriteIntegrals(fi,Time,T,U,C,P,L,I)

      StatTime = Time + dStatTime

  # NOTE(review): snapshots are only written at the top of the loop, so the
  # state reached after the final step is never saved as a snapshot.

finally:
  # close files
  fi.close()



