|

楼主 |
发表于 2023-6-25 16:12:15
|
显示全部楼层
最终代码
- import numpy as np
- import matplotlib.pyplot as plt
# --- Real-space grid and finite-difference operators -----------------------
n_grid = 200                      # number of grid points
x = np.linspace(-5, 5, n_grid)    # uniform grid on [-5, 5]
h = x[1] - x[0]                   # grid spacing

# Forward-difference first-derivative matrix: (D f)[i] = (f[i+1] - f[i]) / h.
D = -np.eye(n_grid) + np.diagflat(np.ones(n_grid - 1), 1)
D = D / h

# Second-derivative matrix built as D @ (-D.T): tridiagonal with 1/h**2 on the
# off-diagonals and -2/h**2 on the diagonal, except for the last diagonal
# entry, which comes out as -1/h**2 and is overwritten to match the others.
D2 = D.dot(-D.T)
D2[-1, -1] = D2[0, 0]

# Reference problem without interaction: H = -D2/2 + diag(x**2).
X = np.diagflat(x * x)
eig_harm, psi_harm = np.linalg.eigh(-D2 / 2 + X)
def integral(x, y, axis=0):
    """Integrate samples *y* over the uniform grid *x* with a rectangle rule.

    The grid spacing is taken from the first two entries of ``x``; every
    sample is weighted by that spacing and the products are summed.

    Args:
        x: Grid coordinates; only ``x[0]`` and ``x[1]`` are read, so the grid
            is treated as uniformly spaced.
        y: Sample values (array-like). May be multi-dimensional.
        axis: Axis of ``y`` to sum over (passed to ``np.sum``).

    Returns:
        ``np.sum(y * dx, axis=axis)``.
    """
    dx = x[1] - x[0]
    # asarray lets plain Python sequences be passed as well as ndarrays.
    return np.sum(np.asarray(y) * dx, axis=axis)
# Number of electrons placed into the orbitals when building the density.
num_electron = 17
def get_nx(num_electron, psi, x):
    """Build the electron density from a set of orbitals.

    Each column of ``psi`` is normalised so that ``integral(x, psi_k**2)``
    equals one, then the columns are filled in order: two electrons per
    orbital, plus a single electron in the next orbital when
    ``num_electron`` is odd.

    Args:
        num_electron: Total number of electrons (non-negative integer).
        psi: 2-D array whose columns are orbitals sampled on ``x``. The
            columns are filled first-to-last, so they are presumably ordered
            by increasing energy (as ``np.linalg.eigh`` returns them).
        x: Grid coordinates, forwarded to ``integral``.

    Returns:
        1-D array ``sum_k f_k * psi_k**2`` over the occupied columns.

    Raises:
        ValueError: If ``num_electron`` is negative or needs more orbitals
            than ``psi`` provides (previously the surplus electrons were
            silently dropped).
    """
    if num_electron < 0:
        raise ValueError(f"num_electron must be non-negative, got {num_electron}")

    # Normalise every orbital (column) on the grid.
    norms = integral(x, psi ** 2, axis=0)
    normed_psi = psi / np.sqrt(norms)[None, :]

    # Occupation numbers: doubly occupied orbitals, then one singly occupied
    # orbital if the electron count is odd.
    n_double, n_single = divmod(num_electron, 2)
    occupations = np.array([2.0] * n_double + [1.0] * n_single)
    n_occupied = occupations.size

    if n_occupied > normed_psi.shape[1]:
        raise ValueError(
            f"{num_electron} electrons need {n_occupied} orbitals, "
            f"but only {normed_psi.shape[1]} are available"
        )

    # Density: occupation-weighted sum of squared orbitals.
    return (normed_psi[:, :n_occupied] ** 2) @ occupations
def get_exchange(nx, x):
    """Return the exchange energy and exchange potential for density ``nx``.

    Uses the local-density (Dirac-type) expressions

        energy    = -3/4 * (3/pi)**(1/3) * integral(nx**(4/3))
        potential =       -(3/pi)**(1/3) * nx**(1/3)

    Args:
        nx: Density sampled on the grid (expected to be non-negative, since
            fractional powers of negative values give NaN).
        x: Grid coordinates, forwarded to ``integral``.

    Returns:
        Tuple ``(energy, potential)``: a scalar and an array shaped like
        ``nx``.
    """
    prefactor = (3. / np.pi) ** (1. / 3.)
    energy = -3. / 4. * prefactor * integral(x, nx ** (4. / 3.))
    potential = -prefactor * nx ** (1. / 3.)
    return energy, potential
def get_hatree(nx, x, eps=1e-1):
    """Return the Hartree energy and potential for density ``nx``.

    The interaction kernel is the softened form
    ``1 / sqrt((x_i - x_j)**2 + eps)``, so the diagonal (``i == j``) stays
    finite.

        potential[i] = sum_j nx[j] * h / sqrt((x[j] - x[i])**2 + eps)
        energy       = 1/2 * sum_ij nx[i] * nx[j] * h**2 / sqrt(... + eps)

    Args:
        nx: 1-D density sampled on ``x``.
        x: 1-D grid coordinates; the spacing ``h`` is taken from
            ``x[1] - x[0]``.
        eps: Softening term added under the square root.

    Returns:
        Tuple ``(energy, potential)``: a scalar and an array shaped like
        ``nx``.
    """
    h = x[1] - x[0]
    # Build the pairwise kernel once; both results are contractions of it.
    kernel = 1.0 / np.sqrt((x[None, :] - x[:, None]) ** 2 + eps)
    potential = h * (kernel @ nx)
    # energy = 1/2 * sum_i nx[i] * potential[i] * h
    energy = 0.5 * h * np.dot(nx, potential)
    return energy, potential
def print_log(i, log):
    """Print one progress line for iteration ``i``.

    Args:
        i: Iteration index shown in the ``step`` column.
        log: Dict with list values under ``"energy"`` and ``"energy_diff"``;
            the last entry of each list is printed.
    """
    print(f"step: {i:<5} energy: {log['energy'][-1]:<10.4f} energy_diff: {log['energy_diff'][-1]:.10f}")
# --- Self-consistent iteration ---------------------------------------------
max_iter = 1000
energy_tolerance = 1e-5

# Both histories are seeded with inf so that the first iteration has a
# "previous" entry to compare against and can never count as converged.
log = {"energy": [float("inf")], "energy_diff": [float("inf")]}

# Start from an empty density.
nx = np.zeros(n_grid)

for i in range(max_iter):
    # Only the potentials enter the Hamiltonian; the two energies returned
    # here are not used further in this loop.
    ex_energy, ex_potential = get_exchange(nx, x)
    ha_energy, ha_potential = get_hatree(nx, x)

    # Hamiltonian: kinetic term plus exchange, Hartree and x**2 potentials.
    H = -D2 / 2 + np.diagflat(ex_potential + ha_potential + x * x)

    energy, psi = np.linalg.eigh(H)

    # Log the lowest eigenvalue and its change since the previous iteration.
    # NOTE(review): convergence is judged on the lowest eigenvalue only, not
    # on a total energy -- confirm this is the intended criterion.
    log["energy"].append(energy[0])
    energy_diff = energy[0] - log["energy"][-2]
    log["energy_diff"].append(energy_diff)
    print_log(i, log)

    # Convergence check.
    if abs(energy_diff) < energy_tolerance:
        print("converged!")
        break

    # Update the density from the new orbitals for the next iteration.
    nx = get_nx(num_electron, psi, x)
else:
    # Runs only when the loop used all max_iter iterations without a break.
    print("not converged")
# --- Plot the final density against the non-interacting reference ----------
# Both curves carry a label so that plt.legend() lists each of them; an
# unlabelled line is left out of the legend.
plt.plot(nx, label="LDA exchange + Hartree")
plt.plot(get_nx(num_electron, psi_harm, x), label="no-interaction")
plt.legend()
# Run the GUI event loop briefly so the figure is drawn.
# NOTE(review): when run as a plain script the window may close as soon as
# the interpreter exits -- consider plt.show() if it should stay open.
plt.pause(0.01)
复制代码 |
|