from ase.io import read
import numpy as np

# rxyz: Cartesian atomic positions, shape (nat, 3)
def hessian_lennard_jones(rxyz):
    """Return the analytic Cartesian Hessian of a Lennard-Jones cluster.

    The pair potential is ``V(r) = 4 * (r**-12 - r**-6)`` (reduced units,
    epsilon = sigma = 1), summed over every distinct atom pair with no
    cutoff and no periodic images.

    Parameters
    ----------
    rxyz : array_like, shape (nat, 3)
        Cartesian atomic positions.

    Returns
    -------
    numpy.ndarray, shape (3*nat, 3*nat)
        Symmetric matrix of second derivatives of the total energy. Rows and
        columns are ordered atom by atom: (x0, y0, z0, x1, y1, z1, ...).

    Raises
    ------
    ValueError
        If ``rxyz`` is not two-dimensional with three columns.
    """
    rxyz = np.asarray(rxyz, dtype=float)
    if rxyz.ndim != 2 or rxyz.shape[1] != 3:
        raise ValueError(
            "rxyz must have shape (nat, 3), got {}".format(rxyz.shape))

    nat = rxyz.shape[0]
    hess = np.zeros((3*nat, 3*nat))
    identity = np.eye(3)

    for iat in range(nat - 1):
        i = 3*iat
        for jat in range(iat + 1, nat):
            j = 3*jat

            # Separation vector d = r_i - r_j and inverse powers of r.
            d = rxyz[iat] - rxyz[jat]
            inv_r2 = 1.0 / (d[0]**2 + d[1]**2 + d[2]**2)
            inv_r8 = inv_r2**4
            inv_r10 = inv_r8 * inv_r2
            inv_r14 = inv_r10 * inv_r2**2
            inv_r16 = inv_r14 * inv_r2

            # With s = r^2:  dV/dd_a = (-48 s^-7 + 24 s^-4) d_a, hence
            #   d2V/dd_a dd_b = (672 s^-8 - 192 s^-5) d_a d_b
            #                   - (48 s^-7 - 24 s^-4) delta_ab.
            radial = 672.0*inv_r16 - 192.0*inv_r10
            isotropic = 48.0*inv_r14 - 24.0*inv_r8
            block = radial*np.outer(d, d) - isotropic*identity

            # d/dr_i = +d/dd and d/dr_j = -d/dd, so the self blocks gain
            # +block and the cross blocks -block. Writing both cross blocks
            # keeps the matrix symmetric without a separate mirroring pass.
            hess[i:i+3, i:i+3] += block
            hess[j:j+3, j:j+3] += block
            hess[i:i+3, j:j+3] -= block
            hess[j:j+3, i:i+3] -= block

    return hess


def save_hessian_to_file(hess, filename='hessian_output.txt'):
    """Write ``hess`` to ``filename`` as plain text, one matrix row per line.

    Each value is rendered in scientific notation with ten digits after the
    decimal point (format spec ``.10E``); values within a row are separated
    by single spaces and every row ends with a newline. An existing file is
    overwritten.
    """
    with open(filename, 'w') as out:
        out.writelines(
            ' '.join(format(value, '.10E') for value in row) + '\n'
            for row in hess
        )

from lj_python import lenjon
if __name__ == "__main__":
    # Read the atomic configuration from the "LJ38.xyz" file.
    atoms = read('./LJ38.xyz')

    # Number of atoms and their Cartesian positions, shape (nat, 3).
    nat = len(atoms)
    rxyz = atoms.get_positions()

    # Analytic Hessian of the Lennard-Jones energy.
    hess = hessian_lennard_jones(rxyz)

    # Flattened coordinates (x0, y0, z0, x1, ...) to perturb one at a time.
    x = rxyz.reshape(3 * nat)
    dx = 1e-6

    # Finite-difference Hessian from central differences of the forces:
    # column i is -(F(x + dx*e_i) - F(x - dx*e_i)) / (2*dx).
    # NOTE(review): assumes lenjon(nat, rxyz) returns (energy, forces) with
    # forces shaped like rxyz -- confirm against lj_python.
    hess2 = np.zeros((3*nat, 3*nat))
    for i in range(3*nat):
        xl = x.copy()
        xr = x.copy()
        xl[i] -= dx
        xr[i] += dx
        _, fl = lenjon(nat, xl.reshape((nat, 3)))
        _, fr = lenjon(nat, xr.reshape((nat, 3)))
        # Forces are minus the gradient, so Fl - Fr approximates 2*dx*dGrad.
        hess2[:, i] = (np.reshape(fl, 3*nat) - np.reshape(fr, 3*nat)) / (2 * dx)

    # print(hess[:10, 0])
    # print(hess2[:10, 0])
    # Largest absolute deviation between analytic and numerical Hessians.
    print(np.max(np.abs(hess - hess2)))

    # Optional: compare against a Fortran reference Hessian.
    # fortran_output = np.loadtxt('fortran_hessian_output.dat')

    # is_close = np.allclose(fortran_output, hess, rtol=1e-05, atol=1e-08)

    # print(f"Are the outputs close?: {'Yes' if is_close else 'No'}")

    # Save the hessian matrix to a file
    # save_hessian_to_file(hess, filename='python_hessian_output.txt')
