"""
.. module:: Charge_charge
    :synopsis: Generate the charge_charge interaction matrix.
.. moduleauthor:: D. Wang <dwang5@zoho.com>
"""
import numpy as np
from netCDF4 import Dataset
import time
from math import atan, sqrt, log
from scipy import special
from supercell import Supercell
from ewald_cc import cc_set_parameters, cc_sum_over_k
class Charge_charge(Supercell):
    """
    Charge_charge inherits the *Supercell* class, it first initializes the supercell it works on.

    After :meth:`generate_charge_matrix` runs, ``self.ccij[ia]`` holds the
    (Ewald-summed, 0.5-scaled) interaction between site 0 and site ``ia``.

    :param n1: Number of unit cells along the first Bravais vector of 'lattice'.
    :param n2: Number of unit cells along the second Bravais vector of 'lattice'.
    :param nz: Number of unit cells along the third Bravais vector of 'lattice'.
    :param lattice: The lattice of the **unit cell**, not the supercell.
    """

    def __init__(self, n1, n2, nz, lattice):
        super().__init__(n1, n2, nz, lattice)
        # ccij[ia]: interaction between site 0 and site ia; one entry per site.
        self.ccij = np.zeros(self.nsites)
        # Set to True by generate_charge_matrix once self.ccij is filled.
        self.charge_matrix_calculated = False

    def write_charge_matrix(self, fn):
        """
        Write the calculated interaction matrix to a netCDF file for later use.

        If the matrix has not been computed yet, :meth:`generate_charge_matrix`
        is called first.

        :param fn: The name of the netCDF file (opened in mode "w").
        :return: Nothing.
        """
        if not self.charge_matrix_calculated:
            print("Calculate the matrix first ...")
            self.generate_charge_matrix()
        # The context manager guarantees the file is closed even if a write fails.
        with Dataset(fn, "w", format="NETCDF4") as ccm:
            # Unlimited dimension; its length is fixed by the assignment below.
            ccm.createDimension("ia", None)
            # The actual 1-D variable: entry ia is the interaction with site ia.
            matrix = ccm.createVariable('matrix', np.float64, 'ia')
            ccm.description = 'Charge matrix: interaction matrix'
            ccm.history = 'Created at ' + time.ctime(time.time())
            matrix[:] = self.ccij[:]

    def _cc_site_position(self, idx):
        """
        Return the position vector of site ``idx``.

        It is the site's coordinates (self.ixa/iya/iza[idx]) times the rows of
        self.lattice.
        """
        return (self.ixa[idx] * self.lattice[0, :]
                + self.iya[idx] * self.lattice[1, :]
                + self.iza[idx] * self.lattice[2, :])

    def generate_charge_matrix(self):
        """
        This is the core of the interaction matrix calculation.

        For every site ``ia`` the displacement ``d = pos(ia) - pos(0)`` is
        formed and the Ewald sum is evaluated as follows.

        * The reciprocal-space (long-range) part is computed by
          'cc_sum_over_k' and scaled by 4*pi/self.celvol. 'cc_set_parameters'
          and 'cc_sum_over_k' are defined in 'ewald_cc.cpp' using PyBind11.
        * The real-space (short-range) part keeps only the R = 0 image:
          erfc(eta*|d|)/|d| for d != 0. For site 0 itself, the self term
          2*eta/sqrt(pi) is subtracted instead.

        The result, multiplied by 0.5, is stored in self.ccij, and
        self.charge_matrix_calculated is set to True.

        :return: nothing.
        """
        pi = np.pi
        pi2 = pi * 2.0
        tol = 1.0e-12
        # eta is chosen so that exp(-eta**2) == tol.
        eta = sqrt(-log(tol))
        gcut = 2.0 * eta ** 2
        gcut2 = gcut ** 2
        eta4 = 1.0 / (4 * eta ** 2)

        # Length of each row of self.a. It sets how many reciprocal vectors
        # (mg1, mg2, mg3) are needed along each direction to reach gcut.
        lengths = [sqrt(sum(self.a[i, k] ** 2 for k in range(3)))
                   for i in range(3)]
        mg1, mg2, mg3 = (int(gcut * length / pi2) + 1 for length in lengths)
        print('Gcut: ', gcut, ' mg1, mg2, mg3: ', mg1, mg2, mg3)
        # Set parameters to be used in the PYBIND11 C++ computation.
        cc_set_parameters(self.b, mg1, mg2, mg3, gcut2, eta4)

        c = 4.0 * pi / self.celvol
        # Self-interaction correction, applied only to the zero-displacement site.
        residue = 2.0 * eta / sqrt(pi)

        # Real-space images with R != 0 are not included: as explained in our
        # paper, the real-space image sum can be ignored. (Their contribution
        # would be erfc(eta*|d - R|)/|d - R| over supercell vectors R.)
        pos0 = self._cc_site_position(0)
        for ia in range(self.nsites):
            # Note how the three dimensional (dx,dy,dz) is mapped into a
            # single array index ia, which is important for correct later
            # use of the generated matrix.
            print('site: ', ia)
            d = self._cc_site_position(ia) - pos0
            krslt = cc_sum_over_k(d[0], d[1], d[2])
            self.ccij[ia] = krslt * c
            if ia == 0:
                self.ccij[ia] -= residue
            else:
                # Use the same displacement as the k-space term.
                dist = sqrt(d[0] * d[0] + d[1] * d[1] + d[2] * d[2])
                self.ccij[ia] += 1.0 / dist * special.erfc(dist * eta)
        # NOTE(review): the 0.5 factor presumably compensates for double
        # counting of pairs in the energy sum -- confirm against the paper.
        self.ccij *= 0.5
        self.charge_matrix_calculated = True