# Wavefunction database and computation (WfBase)
# April 26th, 2024
__version__='0.0.2'
# Copyright 2024 by Sinisa Coh
#
# This file is part of WfBase. WfBase is free software: you can
# redistribute it and/or modify it under the terms of the GNU General
# Public License as published by the Free Software Foundation, either
# version 3 of the License, or (at your option) any later version.
#
# WfBase is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
# License for more details.
#
# A copy of the GNU General Public License should be available
# alongside this source in a file named gpl-3.0.txt. If not,
# see <http://www.gnu.org/licenses/>.
#
# WfBase is available at https://coh.ucr.edu/wfbase
import wannierberri as wberri
import numpy as np
from numba import njit
from opt_einsum import contract as opteinsum
from copy import deepcopy
import textwrap
import traceback
import fnmatch
from fractions import Fraction
from PIL import Image
import re
import sys
import os
import io
import time
import imgcat
import matplotlib
import matplotlib.pyplot as plt
import requests
from zipfile import ZipFile
import pyparsing as pp
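# Raise the recursion limit and enable pyparsing's packrat memoization;
# both help when parsing long or deeply nested expressions.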
sys.setrecursionlimit(10000)
pp.ParserElement.enablePackrat()
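# Physical constants in SI units.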
hbar_SI = 1.0545718176461565e-34
electron_charge_SI = 1.602176487e-19
epsilon_zero_SI = 8.854187817620389e-12
speed_of_light_SI = 299792458.0
electron_mass_SI = 9.1093837015e-31
angstrom_SI = 1.0e-10
hartree_SI = 4.35974394E-18
rydberg_SI = hartree_SI / 2.0
bohr_radius_SI = 0.52917720859E-10
QE_MAGN_SIGN = -1.0
def download_data_if_needed(silent = False):
r"""
Downloads the latest version of the WfBase database (file *data.zip*)
and unpacks it in the *data/* folder. It will not erase a previous *data/*
folder if one already exists.
You can manually download the same database by following
:ref:`these instructions <database>`.
The database can be in any folder, it doesn't need to be in a folder
called *data* (which is the default used in this documentation). You
might also want to place this folder at a fixed place on your machine
and then simply load it by providing an absolute path to the file
when you use the *load* function. This way you don't need to have multiple
copies of the database on your machine.
:param silent: If set to *True*, the function will not print a message if it is
unable to download or unzip the database, or if the database was
already downloaded. The default is *False*.
Example usage::
import wfbase as wf
# download the database
wf.download_data_if_needed()
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# you can also open the database file by providing
# absolute path.
db = wf.load("~/work/calculations/data/fe_bcc.wf")
"""
folder = "data"
fname = "data.zip"
if os.path.exists(folder) == True:
if silent == False:
_print_without_stopping(""" You called function *download_data_if_needed* in your script.
This function is supposed to download the WfBase database from a website, and unpack
it in a folder named """ + folder + """. However, this folder seems to exist already
in the current path. Therefore, this script will not download anything. You probably should
remove the call to function .download_data_if_needed() from your script. If you insist
on forcing a new download of the database, you could rename or move the folder """ + folder + """
and then run the script again.""")
return
elif os.path.exists(fname) == True:
if silent == False:
_print_without_stopping(""" You called function *download_data_if_needed* in your script.
This function is supposed to download the WfBase database from a website as a
single zip file called """ + fname + """. However, this zip file seems to exist already
in the current path. Therefore, this script will not download anything. You probably should
remove the call to function .download_data_if_needed() from your script. If you insist
on forcing a new download of the database, you could rename or move the zip file """ + fname + """
and then run the script again.""")
return
url = "https://coh.engr.ucr.edu/wfbase/latest/" + fname
if silent == False:
print("Trying to connect to " + url)
r = requests.get(url)
with open(fname, 'wb') as fzip:
    fzip.write(r.content)
if silent == False:
print("Download successful!")
with ZipFile(fname, "r") as f:
f.extractall()
if silent == False:
print("Unzipped file " + fname)
def load(*args, **kwargs):
r"""
This is the function used to open a database file containing information
about one of the DFT calculations.
:param data_path: Path to the database .wf file containing information about
a calculation. The user should download database .wf file by following
:ref:`these instructions <database>`.
:param perform_consistency_check: A flag specifying whether the code
should check that the computation of various physical quantities is
done correctly. The default is True. Set to False only if you are really
sure what is going on. It is strongly advised to keep this parameter as is.
:returns:
* **db** -- database object of type :class:`DatabaseWf <wfbase.DatabaseWf>`.
This object can be used next to create computators of various physical
quantities.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
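# the consistency check described above can be skipped
# (strongly discouraged; shown here only to illustrate the parameter)
db_unchecked = wf.load("data/fe_bcc.wf", perform_consistency_check = False)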
"""
ret = DatabaseWf()
ret._load_from_wfbase_database_file(*args, **kwargs)
return ret
def load_from_wannierberri(*args, **kwargs):
r"""
Loads a calculation directly using the Wannier Berri package, without using
a .wf database file from WfBase.
See documentation of `Wannier Berri <https://wannier-berri.org>`_ for more
details on creation of object *System_w90*.
:param system: This is *System_w90* object from Wannier Berri.
:param global_fermi_level_ev: This is the Fermi level in eV. You
can get this number at the end of your self-consistent DFT calculation.
WfBase will later shift all band energies by this number, so that
the new Fermi level is zero.
:returns:
* **db** -- database object of type :class:`DatabaseWf <wfbase.DatabaseWf>`.
This object can next be used to create computators of various physical
quantities.
Example usage::
import wfbase as wf
import wannierberri as wberri
def main():
system = wberri.System_w90("run_dft_output/x", berry = True, spin = False)
db = wf.load_from_wannierberri(system, global_fermi_level_ev = 18.3776)
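# the database object can then be used like one loaded from a
# .wf file; the mesh size below is an arbitrary choice
comp = db.do_mesh([8, 8, 8])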
if __name__ == "__main__":
main()
"""
ret = DatabaseWf()
ret._load_from_wannierberri_system(*args, **kwargs)
return ret
class DatabaseWf():
r"""
An object of this class contains information about the DFT calculations. Use
the function :func:`load <wfbase.load>` to create an object of this class.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
"""
def __init__(self):
self._only_essentials = False
self._rng = np.random.default_rng(8318)
self._loaded_from_wannierberri = False
def _load_from_wfbase_database_file(self, data_path, perform_consistency_check = True):
self.__data_path = data_path
self._system, self._add_info = _read_interface_to_wberri_from_file(self.__data_path)
if perform_consistency_check == True:
hashes_current = self._compute_own_hashes()
for k in hashes_current.keys():
hash_stored = self._add_info["hash_" + k]
if k in ["nonhermA", "nonhermS"]:
comp = _are_hashes_similar_absolute(hashes_current[k], hash_stored)
else:
comp = _are_hashes_similar_relative(hashes_current[k], hash_stored)
if comp == False:
_raise_value_error("""Something is wrong with computation of
quantity """ + k + """. Value obtained from your combination
of database file/installed software/used hardware is different
from what is expected.
You likely should update your Wannier Berri installation,
or you should download newer version of the database.
If you know what you are doing and you want to make this
message go away, set perform_consistency_check to False,
but this is strongly discouraged.""")
def _load_from_wannierberri_system(self, system, global_fermi_level_ev):
self._system = system
self._loaded_from_wannierberri = True
self._add_info = {}
self._add_info["fermi_scf_ev"] = global_fermi_level_ev
self._add_info["num_wann"] = system.num_wann
self._add_info["cell"] = np.array(system.real_lattice, dtype = float)
def info(self, print_to_screen = True, full = False):
r"""
Returns information about the computation stored in the database .wf file. This
information could be used by the user to redo the DFT calculation
from scratch, as shown in
:ref:`this example <sphx_glr_all_examples_example_standalone_prepare.py>`.
Note that there is a function with the same name that provides information
about the computator, not about the database .wf file. See here
for more information on how to use this other function :func:`info <wfbase._ComputatorWf.info>`.
:param print_to_screen: Whether the code should print the information
to the screen. The default is True. If set to False, nothing is printed
but instead this function returns a string.
:param full: Whether output should be cut to 50 lines per entry or the
entire information should be shown.
:returns:
* **txt** -- string of the text with the information. This is returned
only if *print_to_screen* is set to False.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# print information about the database
db.info()
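# instead of printing, capture the entire text as a string
# (for example, to write it into a file later)
txt = db.info(print_to_screen = False, full = True)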
"""
if full == False:
max_line = 50
else:
max_line = None
out = ""
if self._loaded_from_wannierberri == False:
out += "\n"
out += _make_rst_title("Database file *" + self.__data_path.strip() + "*")
out += "\n"
out += _make_rst_field("The version of this database file")
out += "\n"
out += _format_one_block_simple_indent(str(self._add_info["data_version"]), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Created using WfBase version")
out += "\n"
out += _format_one_block_simple_indent(str(self._add_info["wfbase_version"]), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Currently loaded WfBase version by this script")
out += "\n"
out += _format_one_block_simple_indent(str(__version__), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
if self._loaded_from_wannierberri == False:
out += _make_rst_field("Created using Wannier Berri version")
out += "\n"
out += _format_one_block_simple_indent(str(self._add_info["wberri_version"]), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Currently loaded Wannier Berri version")
out += "\n"
out += _format_one_block_simple_indent(str(wberri.__version__), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
if self._loaded_from_wannierberri == False:
out += _make_rst_field("Input file for the SCF computation using pw.x from Quantum ESPRESSO")
out += "\n"
out += _format_one_block_simple_indent(_adjust_input_file("scf", str(self._add_info["input_scf"])), indent = 4,
start_and_end = True, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Input file for the NSCF computation using pw.x from Quantum ESPRESSO")
out += "\n"
out += _format_one_block_simple_indent(_adjust_input_file("nscf", str(self._add_info["input_nscf"])), indent = 4,
start_and_end = True, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Pseudopotentials used in the calculation")
out += "\n"
out += _format_one_block_simple_indent(" , ".join(self._get_psp()), indent = 4,
start_and_end = False, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Input file for pw2wannier90.x from Quantum ESPRESSO")
out += "\n"
out += _format_one_block_simple_indent(_adjust_input_file("pw2wan", str(self._add_info["input_pw2wan"])), indent = 4,
start_and_end = True, max_line = max_line)
out += "\n\n"
out += _make_rst_field("Input file for Wannier90")
out += "\n"
out += _format_one_block_simple_indent(str(self._add_info["input_w90"]), indent = 4,
start_and_end = True, max_line = max_line)
out += "\n\n"
if print_to_screen:
print(out)
else:
return out
def _get_psp(self):
if self._loaded_from_wannierberri == True:
_stop_because_loaded_from_wannierberri()
pref = "https://coh.ucr.edu/wfbase/" + str(self._add_info["data_version"]) + "/psp/"
ret = list(map(lambda s: pref + str(s).strip(), self._add_info["pseudopotentials"]))
return ret
def do_mesh(self, k_mesh = None, shift_k = [0.0, 0.0, 0.0],
to_compute = ["psi", "A", "S", "dEdk"],
formatted_output_latex = True, doublet_indices = False, reorder_orbitals = False):
r"""
Compute various quantities on a regular k-mesh in the Brillouin zone.
By default, it will compute the following quantities,
.. _common_quantites:
+----------+--------------------------------------------+--------------------------------------------------+
| Quantity | Short description | How to get more information? |
+==========+============================================+==================================================+
| E | Electron band energy. | Call function *comp.info("E")*. |
| | | Example :ref:`output <quantities>`. |
+----------+--------------------------------------------+--------------------------------------------------+
| psi | Electron wavefunction. | Call function *comp.info("psi")*. |
| | | Example :ref:`output <quantities>`. |
+----------+--------------------------------------------+--------------------------------------------------+
| A | Berry connection. | Call function *comp.info("A")*. |
| | | Example :ref:`output <quantities>`. |
+----------+--------------------------------------------+--------------------------------------------------+
| S | Electron spin magnetic moment. | Call function *comp.info("S")*. |
| | | Example :ref:`output <quantities>`. |
+----------+--------------------------------------------+--------------------------------------------------+
| dEdk | Electron Fermi velocity (times hbar). | Call function *comp.info("dEdk")*. |
| | | Example :ref:`output <quantities>`. |
+----------+--------------------------------------------+--------------------------------------------------+
| ... | | |
+----------+--------------------------------------------+--------------------------------------------------+
The code will also compute some other quantities not listed here. For example,
it will construct the default range of energies *hbaromega*, choose a
default smearing parameter *eta*, etc. These can be changed by the user.
See the examples below.
A complete list of all computed quantities can be obtained by calling the
:func:`info <wfbase._ComputatorWf.info>` function on the object returned by this function.
Example output of the function :func:`info <wfbase._ComputatorWf.info>` can be found
:ref:`here <quantities>`.
.. note::
For more details on the computator object returned by this function, see
the description of the computator class :class:`_ComputatorWf <wfbase._ComputatorWf>`.
:param k_mesh: Size of the uniform k-mesh on which you want to compute
these quantities. This should be a vector with three components, one
for the number of k-points in each direction. The default is the coarse mesh used to
construct the Wannier functions, but you may want to use a denser
mesh. Note that the code will precompute all the quantities here on
the mesh you specify, and this might take up a lot of your RAM. If
you wish to save RAM, you could instead compute several smaller grids, one at
a time. For example, this can be achieved by randomly shifting the
k-grid. See :ref:`this example <sphx_glr_all_examples_example_conv.py>`
for more details.
:param shift_k: Shift of the uniform k-mesh. The coordinates for the shift
are given as dimensionless, reduced coordinates. Therefore, this
parameter expects a set of three numbers between 0 and 1. If you
specify numbers outside of this range, then the code will automatically
reduce them to the range from 0 to 1 by removing the integer part.
If you set *shift_k* to a string "random" then the code will shift the k-mesh
by a random amount in all three directions. This might be useful for
sampling the k-points.
:param to_compute: Quantity "E" is always computed. Here you can list additional
quantities that you want to compute. These are any combination of
"psi", "A", "S", "dEdk". The default is to use all of them: ["psi", "A", "S", "dEdk"].
See here for :ref:`more information <common_quantites>`.
If you don't need some of these quantities, don't list them here, so they will not
be computed, and the code will use fewer resources.
:param formatted_output_latex: Boolean value of True or False. The default
is True. If set to False then any quantity evaluated from this computator
will have a less formatted latex output. For example, the latex output
will not use bra and ket notation, etc.
:param doublet_indices: The default is False. If set to True then the code will check
if your band-structure is (at least) two-fold degenerate at each k-point.
(This happens when a product of inversion and time-reversion symmetry is present,
such as in inversion symmetric non-magnetic systems, for example.) If the
band structure is (at least) two-fold degenerate at each k-point, and this
parameter is set to True, then the band indices of all quantities will be
changed, as follows. If you initially had, for example, 18 bands, then
band index "n" would normally go over those 18 bands. But if this parameter
is set to True then this same system will now have two indices, call them "a"
and "A". Index "a" now goes over 9 values while index "A" goes over 2 values.
Therefore, "a" corresponds to the index of a doublet, and "A" indexes states
in the doublet. The choice of the two states in the doublet is randomized
by the diagonalizer in Wannier Berri, and there is no special meaning to it.
Also, if your band structure at some point is 4-fold degenerate, this routine
will still use the doublet notation, and there is again no special meaning in the
choice of those two doublets out of the 4 degenerate states. If you want to avoid
these high-symmetry points then set *shift_k* to "random", as that will ensure
that no k-point of the uniform mesh has more than doubly degenerate bands.
:param reorder_orbitals: The default is False. If set to True it will reorder orbitals
in quantities *wfc* and *orbitallabels* so that the spin index of the orbital
is the slow index. All other quantities are left unchanged.
:returns:
* **comp** -- computator object of type :class:`_ComputatorWf <wfbase._ComputatorWf>`.
This object can next be used to evaluate various physical quantities.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a uniform mesh
comp = db.do_mesh()
# create a different computator object, from the same database db
# but now use a different k-mesh. There is no need to load
# the database object again using load function!
comp_alter = db.do_mesh([4, 4, 4])
# print information about all quantities in the computator
comp.info()
# print information about one of the quantities in the computator
comp.info("A")
# change the energy range
comp.compute_photon_energy("hbaromega", 0.0, 5.0, 51)
# change the value of the smearing parameter
comp.replace("eta", "0.2 eV")
# different way to change the value of the smearing parameter
comp["eta"] = 0.25
# you can also do more complex operations on quantities
comp["S"] = comp["S"][:, :, :, 0]
db = wf.load("data/au_bcc.wf")
comp = db.do_mesh()
print("Shape of matrix A without doubling: ", comp.get_shape("A"))
comp_new = db.do_mesh(doublet_indices = True)
print("Shape of matrix A with doubling : ", comp_new.get_shape("A"))
"""
start_counter_do = time.perf_counter()
if k_mesh is None:
k_mesh = self._system.mp_grid
if isinstance(shift_k, str):
if shift_k.lower().strip() == "random":
shift_k = self._rng.random((3))
else:
_raise_value_error("Unrecognized string provided to shift_k: " + shift_k + ".")
props = {}
text_trap = io.StringIO()
sys.stdout = text_trap
__grid = wberri.Grid(system = self._system, NKFFT = k_mesh, NKdiv = [1, 1, 1])
sys.stdout = sys.__stdout__
self._data_K = wberri.data_K.Data_K_R(system = self._system, dK = shift_k, grid = __grid)
if False:
# The code below throws an error in the current version of wannierberri.
props["kredvec"] = {"value": self._data_K.kpoints_all(),
"units": Units(eV = 0, Ang = 0, muB = 0),
"origin_story": "...",
"indices_info": {
"definition": None,
"canonical_names": "...",
"explanation": ["...",
"..."],
"bands": [],
},
}
del __grid
computed = {}
computed["E"] = np.copy(self._data_K.E_K)
for thing in to_compute:
if thing == "psi":
computed[thing] = _potentially_reorder_orbitals(
np.copy(np.transpose(self._data_K.UU_K, (0, 2, 1))),
2, reorder_orbitals)
elif thing == "A":
computed[thing] = np.copy(self._data_K.A_H)
elif thing == "S":
computed[thing] = np.copy(self._data_K.Xbar("SS"))
elif thing == "dEdk":
computed[thing] = np.copy(
np.diagonal(self._data_K.Xbar('Ham', 1), axis1 = 1, axis2 = 2).transpose(0, 2, 1))
elif thing == "E":
continue
else:
_raise_value_error("Unknown quantity " + thing)
del self._data_K
props["numk"] = {
"value": int(np.prod(k_mesh)),
"origin_story": "Total number of k-points in the mesh.",
"units": Units(eV = 0, Ang = 0, muB = 0),
"format" : r"N_{\rm k}",
"format_conjugate": r"N_{\rm k}",
}
props = self.__comp_common_essential(computed, props)
if self._only_essentials == False:
props = self.__comp_common(computed, props, reorder_orbitals)
comp = _ComputatorWf(props, formatted_output_latex,
doublet_indices, self._loaded_from_wannierberri)
comp._computated_using = "do_mesh"
comp.compute_photon_energy("hbaromega", 0.01, 3.0, 31)
comp.compute_occupation("f", "E", "ef")
comp.new("eta", {"value": 0.1, "units": Units(eV = 1)})
time_do = time.perf_counter() - start_counter_do
comp._total_seconds_initialize = time_do
return comp
def do_list(self, k_list, to_compute = [],
formatted_output_latex = True,
doublet_indices = False,
reorder_orbitals = False):
r"""
Similar to :func:`do_mesh <wfbase.DatabaseWf.do_mesh>` with the difference that
now the returned computator contains information on an arbitrary list
of k-vectors.
Use :func:`do_path <wfbase.DatabaseWf.do_path>` for a simple way to generate a path
between the special k-points in the Brillouin zone.
:param k_list: List of k-vectors on which you want to do a computation. These
vectors are specified as dimensionless reduced coordinates of the reciprocal vectors.
:param to_compute: List of additional quantities that you want to compute.
You can take any of these: "psi", "A", "S", "dEdk". The default is to compute none
of them. See here for :ref:`more information <common_quantites>` about these quantities.
:param formatted_output_latex: Boolean value of True or False. The default
is True. If set to False then any quantity evaluated from this computator
will have a less formatted latex output. For example, the latex output
will not use bra and ket notation, etc.
:param doublet_indices: Same meaning as in :func:`do_mesh <wfbase.DatabaseWf.do_mesh>`.
:param reorder_orbitals: The default is False. If set to True it will reorder orbitals
in quantities *wfc* and *orbitallabels* so that the spin index of the orbital
is the slow index. All other quantities are left unchanged.
:returns:
* **comp** -- computator object of type :class:`_ComputatorWf <wfbase._ComputatorWf>`. This object
can next be used to evaluate various physical quantities.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a list of two k-points
comp = db.do_list([[0.2, 0.4, 0.2], [0.1, 0.9, 0.4]])
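# band energies at the two requested k-points; the first index of the
# array runs over the k-points in the list, the second over bands
print(comp["E"].shape)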
"""
start_counter_do = time.perf_counter()
props = {}
computed = {}
computed["E"] = []
for thing in to_compute:
if thing == "E":
continue
computed[thing] = []
for ii in range(len(k_list)):
k_one = np.array(k_list[ii])
text_trap = io.StringIO()
sys.stdout = text_trap
__grid = wberri.Grid(system = self._system, NKFFT = [1, 1, 1], NKdiv = [1, 1, 1])
sys.stdout = sys.__stdout__
self._data_K = wberri.data_K.Data_K_R(system = self._system, dK = k_one, grid = __grid)
del __grid
computed["E"].append(np.copy(self._data_K.E_K[0]))
for thing in to_compute:
if thing == "psi":
computed[thing].append(_potentially_reorder_orbitals(
np.copy(np.transpose(self._data_K.UU_K, (0, 2, 1))[0]),
2, reorder_orbitals))
elif thing == "A":
computed[thing].append(np.copy(self._data_K.A_H[0]))
elif thing == "S":
computed[thing].append(np.copy(self._data_K.Xbar("SS")[0]))
elif thing == "dEdk":
computed[thing].append(
np.copy(np.diagonal(self._data_K.Xbar('Ham', 1), axis1 = 1, axis2 = 2).transpose(0, 2, 1)[0]))
elif thing == "E":
continue
else:
_raise_value_error("Unknown quantity " + thing)
del self._data_K
for k in computed.keys():
computed[k] = np.array(computed[k])
props["numk"] = {
"value": len(k_list),
"origin_story": "Total number of k-points in the list",
"units": Units(eV = 0, Ang = 0, muB = 0),
"format" : r"N_{\rm k}",
"format_conjugate": r"N_{\rm k}",
}
props["kredvec"] = {
"value": np.array(k_list),
"units": Units(eV = 0, Ang = 0, muB = 0),
"origin_story": "These are reduced coordinates of all k-points in the list.",
"indices_info": {
"definition": None,
"canonical_names": "kr",
"explanation": ["k-points in the list",
"reduced coordinate"],
"bands": [],
},
}
props = self.__comp_common_essential(computed, props)
if self._only_essentials == False:
props = self.__comp_common(computed, props, reorder_orbitals)
comp = _ComputatorWf(props, formatted_output_latex, doublet_indices, self._loaded_from_wannierberri)
comp._computated_using = "do_list"
comp.compute_photon_energy("hbaromega", 0.01, 3.0, 31)
comp.compute_occupation("f", "E", "ef")
comp.new("eta", {"value": 0.1, "units": Units(eV = 1)})
time_do = time.perf_counter() - start_counter_do
comp._total_seconds_initialize = time_do
return comp
def __comp_common_essential(self, computed, props):
props["ef"] = \
{
"value": 0.0,
"origin_story": """This is the Fermi level. It is set to zero as we subtract the DFT-computed Fermi
level on a coarse mesh from the Hamiltonian in the Wannier basis. Note: if you compute band structure
on a very fine mesh, and then recompute the Fermi level from that band-structure, you will likely
get a Fermi level that is not exactly zero. This is to be expected, as the Fermi level will be somewhat
sensitive on the k-mesh you choose and the smearing you use for the occupations. However, one can expect
that 0.0 is a good approximation of the Fermi level for any k-mesh you use, but you may need to check
this on your own for a very fine mesh.""",
"units": Units(eV = 1, Ang = 0, muB = 0),
"format": r"{\rm E}_{\rm F}",
"format_conjugate": r"{\rm E}_{\rm F}",
}
props["eV"] = \
{
"value": 1.0,
"origin_story": "Constant equal to 1 eV.",
"units": Units(eV = 1, Ang = 0, muB = 0),
"format": r"{\rm \, eV}",
"format_conjugate": r"{\rm \, eV}",
}
props["Ang"] = \
{
"value": 1.0,
"origin_story": "Constant equal to 1 angstrom.",
"units": Units(eV = 0, Ang = 1, muB = 0),
"format": r"{\rm \, \AA}",
"format_conjugate": r"{\rm \, \AA}",
}
props["muB"] = \
{
"value": 1.0,
"origin_story": "Constant equal to 1 bohr magneton.",
"units": Units(eV = 0, Ang = 0, muB = 1),
"format": r"\, \mu_{\rm B}",
"format_conjugate": r"\, \mu_{\rm B}",
}
for thing in computed.keys():
if thing == "E":
props[thing] = {
"value": computed[thing] - float(self._add_info["fermi_scf_ev"]),
"origin_story": """Band energies E computed from Wannier interpolation.
The Fermi level was computed on a DFT coarse mesh and then bands were shifted
so that the Fermi level was set
to zero. If you use a very fine k-mesh then the Fermi level will have
to be recomputed for your k-mesh, but it will likely still be close to zero.
Note: due to the nature of the Wannier interpolation, some of the bands
far above the Fermi level do not correspond to the actual bands computed
in the DFT. Typically this is not a problem as one usually cares about states
close to the Fermi level. All electron states around the Fermi level, as well as
in the valence bands, should
be well reproduced. Quantities "reliableminenergy" and "reliablemaxenergy"
give you an energy window in which band energies are reliable.
Similarly, due to the nature of pseudopotential calculations, these band
structures do not contain deep core states, but again, these are often not needed.
""",
"indices_info": {
"definition": r"H_*0 u_*0*1 = E_*0*1 u_*0*1",
"canonical_names": "kn",
"explanation": ["index of a k-point",
"electron band index"],
"bands": [1],
},
"units": Units(eV = 1, Ang = 0, muB = 0),
"format": r"E_{*0*1}",
"format_conjugate": r"E_{*0*1}",
}
elif thing == "psi":
props[thing] = {
"value": computed[thing],
"origin_story":"""This is the electron wavefunction written in terms
of the localized Wannier states. The approximate atomic-like orbital
characters of these localized Wannier states are given by quantity
*orbitallabels* (available only if your database was loaded from
the WfBase database).
""",
"indices_info": {
"definition": r"psi_*0*1*2 = < W_*2 | psi_*0*1 >",
"canonical_names": "knp",
"explanation": ["index of a k-point",
"electron band index of the state",
"localized orbital index",
],
"bands": [1],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
"format": r"\psi_{*0*1*2}",
"format_conjugate": r"\overline{\psi_{*0*1*2}}",
}
elif thing == "A":
props[thing] = {
"value": computed[thing],
"origin_story":"""This is the Berry connection that can be used to
compute the optical matrix elements and other related quantities.
""",
"indices_info": {
"definition": r"< u_*0*1 | i delk_*3 | u_*0*2 >",
"canonical_names": "knma",
"explanation": ["index of a k-point",
"electron band index of the bra state",
"electron band index of the ket state",
"k-derivative in Cartesian axes (0 for x, 1 for y, 2 for z)",
],
"bands": [1, 2],
},
"units": Units(eV = 0, Ang = 1, muB = 0),
"format": r"\langle u_{*0*1} \lvert "+ _process_latex_imag_j("j") + r" \partial_{k_{*3}} \rvert u_{*0*2} \rangle",
"format_conjugate": r"\langle u_{*0*2} \lvert "+ _process_latex_imag_j("j") + r" \partial_{k_{*3}} \rvert u_{*0*1} \rangle",
}
elif thing == "S":
props[thing] = {
"value": QE_MAGN_SIGN * computed[thing],
"origin_story": """
Matrix elements of the spin magnetic moment operator.
""",
"indices_info": {
"definition": r"< u_*0*1 | Mspin_*3 | u_*0*2 >",
"canonical_names": "knma",
"explanation": ["index of a k-point",
"electron band index of the bra state",
"electron band index of the ket state",
"direction of spin in Cartesian axes (0 for x, 1 for y, 2 for z)",
],
"bands": [1, 2],
},
"units": Units(eV = 0, Ang = 0, muB = 1),
"format": r"\langle \psi_{*0*1} \lvert M^{\rm spin}_{*3} \rvert \psi_{*0*2} \rangle",
"format_conjugate": r"\langle \psi_{*0*2} \lvert M^{\rm spin}_{*3} \rvert \psi_{*0*1} \rangle",
}
elif thing == "dEdk":
props[thing] = {
"value": computed[thing],
"origin_story":"""This is hbar times the Fermi velocity.
""",
"indices_info": {
"definition": r"d E_*0*1 / d k_*2 (= hbar Vfermi_*0*1*2)",
"canonical_names": "kna",
"explanation": ["index of a k-point",
"electron band index",
"k-derivative in Cartesian axes (0 for x, 1 for y, 2 for z)",
],
"bands": [1],
},
"units": Units(eV = 1, Ang = 1, muB = 0),
"format": r"\frac{\partial E_{*0*1}}{\partial k_{*2}}",
"format_conjugate": r"\frac{\partial E_{*0*1}}{\partial k_{*2}}",
}
else:
_raise_value_error("Unknown quantity: " + thing + " !")
return props
def __comp_common(self, computed, props, reorder_orbitals):
_cell = self._add_info["cell"]
props["numwann"] = \
{
"value": int(self._add_info["num_wann"]),
"origin_story": "The number of Wannier bands. " +
" Also, the number of electron states that are computed at each k-point.",
"units": Units(eV = 0, Ang = 0, muB = 0),
}
props["coarsekmesh"] = \
{
"value": np.array(self._system.mp_grid),
"origin_story": "The size of the coarse k-mesh used to create Wannier functions",
"indices_info": {
"canonical_names": "r",
"explanation": ["reduced reciprocal axis",
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
}
props["cell"] = \
{
"value": np.array(_cell, dtype = float),
"origin_story": """Computational unit cell vectors.
""",
"indices_info": {
"canonical_names": "ia",
"explanation": ["index of the cell vector",
"Cartesian axis (0 for x, 1 for y, 2 for z)",
],
"bands": [],
},
"units": Units(eV = 0, Ang = 1, muB = 0),
}
props["recip"] = \
{
"value": np.array(2.0*np.pi*_real_to_recip_no2pi([_cell[0], _cell[1], _cell[2]]), dtype = float),
"origin_story": """Reciprocal unit cell vectors.
""",
"indices_info": {
"canonical_names": "ia",
"explanation": ["index of the reciprocal cell vector",
"Cartesian axis (0 for x, 1 for y, 2 for z)",
],
"bands": [],
},
"units": Units(eV = 0, Ang =-1, muB = 0),
}
props["volume"] = \
{
"value": np.linalg.det(_cell),
"origin_story": "The volume of the computational unit cell.",
"units": Units(eV = 0, Ang = 3, muB = 0),
"format" : r"V_{\rm c}",
"format_conjugate": r"V_{\rm c}",
}
if self._loaded_from_wannierberri == False:
props["orbitallabels"] = \
{
"value": _potentially_reorder_orbitals(np.array(self._add_info["orbital_labels"]), 0, reorder_orbitals),
"origin_story": "Array of names of localized atomic-like orbitals" +
" used in the decomposition of the wavefunction *psi*. Uparrow and " +
"downarrow refer to spin angular momentum (opposite to the spin magnetic moment)." +
" " + str(self._add_info["orbital_labels_description"]),
"indices_info": {
"canonical_names": "p",
"explanation": ["localized orbital index",
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
}
props["atomname"] = \
{
"value": np.array(self._add_info["atom_name"]),
"origin_story": "Array of names of atoms in the computational unit cell.",
"indices_info": {
"canonical_names": "j",
"explanation": ["index of the atom in the unit cell",
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
}
props["atomred"] = \
{
"value": np.array(self._add_info["atom_reduced"], dtype = float),
"origin_story": """Reduced coordinates of atom positions.
""",
"indices_info": {
"canonical_names": "jr",
"explanation": ["index of the atom in the unit cell",
"reduced coordinates axis",
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
}
props["reliablemaxenergy"] = \
{
"value": np.array(self._add_info["frozen_max"], dtype = float) - float(self._add_info["fermi_scf_ev"]),
"origin_story": """Due to the nature of the Wannier interpolation, the electronic
properties are well reproduced for the valence band, up to some energy around the Fermi level.
This number gives you a maximal energy up to which you should trust the electronic properties.
For a lot of calculations, one does not need to worry about the states that are too far from the Fermi level.
This energy window already accounts for the fact that the Fermi level is set to zero.
""",
"units": Units(eV = 1, Ang = 0, muB = 0),
}
props["reliableminenergy"] = \
{
"value": np.array(self._add_info["frozen_min_adjusted"], dtype = float) - float(self._add_info["fermi_scf_ev"]),
"origin_story": """Similarly to reliablemaxenergy, parameter reliableminenergy gives you the minimal energy
for the range with reliable band properties. States below reliableminenergy, such as core states, are
not reliably given by the Wannier interpolation. The interaction of valence with core and semi-core states is included, of course,
but the energies of core states themselves will not show up here.
""",
"units": Units(eV = 1, Ang = 0, muB = 0),
}
for thing in computed.keys():
if thing == "E":
if props[thing]["value"].shape != (props["numk"]["value"], props["numwann"]["value"]):
_raise_value_error("Object " + thing + " returned from Wannier Berri has a different shape than expected.")
elif thing == "psi":
if props[thing]["value"].shape != (props["numk"]["value"], props["numwann"]["value"], props["numwann"]["value"]):
_raise_value_error("Object " + thing + " returned from Wannier Berri has a different shape than expected.")
elif thing == "A":
if props[thing]["value"].shape != (props["numk"]["value"], props["numwann"]["value"], props["numwann"]["value"], 3):
_raise_value_error("Object " + thing + " returned from Wannier Berri has a different shape than expected.")
elif thing == "S":
if props[thing]["value"].shape != (props["numk"]["value"], props["numwann"]["value"], props["numwann"]["value"], 3):
_raise_value_error("Object " + thing + " returned from Wannier Berri has a different shape than expected.")
elif thing == "dEdk":
if props[thing]["value"].shape != (props["numk"]["value"], props["numwann"]["value"], 3):
_raise_value_error("Object " + thing + " returned from Wannier Berri has a different shape than expected.")
else:
_raise_value_error("Unknown quantity: " + thing + " !")
if props["volume"]["value"] < 0.0:
_raise_value_error("Unit cell vectors are not right-handed, as volume is negative.")
return props
def do_path(self, k_str,
to_compute = [],
num_steps_first_segment = 30,
latex_tick_labels = True,
formatted_output_latex = True,
doublet_indices = False,
reorder_orbitals = False):
r"""
Similar to :func:`do_mesh <wfbase.DatabaseWf.do_mesh>` with the difference that
now the returned computator contains information on a list
of k-vectors between special k-points.
:param k_str: String describing the path between the special k-points for the given
symmetry of the system. Conventions here follow those from the Bilbao Crystallographic
server.
:param to_compute: List of additional quantities that you want to compute.
You can take any of these: "psi", "A", "S", "dEdk". The default is to compute
none of them. See here for :ref:`more information <common_quantites>` about these quantities.
:param num_steps_first_segment: The number of points between the first two special k-points
in the list. The number of points between other special k-points is computed so that
the density of k-points in the Brillouin zone is nearly constant. The default is 30.
:param latex_tick_labels: True or False. The default is True. If False, then labels
of special points will not use LaTeX. For example, Gamma point will simply be rendered
as "GM" instead of using "$\\Gamma$".
:param formatted_output_latex: Boolean value of True or False. The default
is True. If set to False then any quantity evaluated from this computator
will have a less formatted latex output. For example, the latex output
will not use bra and ket notation, etc.
:param doublet_indices: Same meaning as in :func:`do_mesh <wfbase.DatabaseWf.do_mesh>`.
:param reorder_orbitals: The default is False. If set to True it will reorder orbitals
in quantities *wfc* and *orbitallabels* so that the spin index of the orbital
is the slow index. All other quantities are left unchanged.
:returns:
* **comp** -- computator object of type :class:`_ComputatorWf <wfbase._ComputatorWf>`. This object
can next be used to evaluate various physical quantities.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a path between these special points
comp = db.do_path("GM--H--N")
# plot the band structure
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
comp.plot_bs(ax)
ax.set_title("Band structure of Fe bcc")
fig.tight_layout()
fig.savefig("a.pdf")
"""
if self._loaded_from_wannierberri == True:
_stop_because_loaded_from_wannierberri()
cell = self._add_info["cell"]
recip = _real_to_recip_no2pi([cell[0], cell[1], cell[2]])
# get information for the labels of k-points
kpoint_info = _get_kpoint_label_info(cell,
self._add_info["atom_name"],
self._add_info["atom_reduced"])
# parse the k-path
fmt = pp.one_of(list("-_,. "))
par = list(fmt.split(k_str, include_separators = False))
pnts = []
for p in par:
if p.strip() != "":
pnts.append(p.strip())
if len(pnts) == 0:
_raise_value_error("Need to specify at least one special k-point on the path.")
if len(pnts) > 1 and num_steps_first_segment < 2:
_raise_value_error("num_steps_first_segment must be 2 or larger.")
# get coordinates of endpoints in reduced
kred_end = []
for p in pnts:
if p.upper() not in kpoint_info.keys():
_raise_value_error("User specified incorrect label of a special k-point \"" + p +
"\". Allowed values for this system are " + str(list(kpoint_info.keys())))
kred_end.append(kpoint_info[p.upper()][0])
kred_end = np.array(kred_end)
pnts_new = []
for p in pnts:
if len(kpoint_info[p.upper()]) > 1 and latex_tick_labels == True:
pnts_new.append(kpoint_info[p.upper()][1])
else:
pnts_new.append(p.upper())
pnts = pnts_new
if kred_end.shape[0] > 1:
def distance_kpts_red(k0_red, k1_red):
_c0 = _red_to_cart(recip[0], recip[1], recip[2], k0_red)
_c1 = _red_to_cart(recip[0], recip[1], recip[2], k1_red)
return 2.0*np.pi*np.sqrt(np.dot(_c0 - _c1, _c0 - _c1))
num_steps = []
num_steps.append(int(round(num_steps_first_segment)))
kdist_zero = distance_kpts_red(kred_end[0], kred_end[1])
for i in range(2, kred_end.shape[0]):
tmp = np.round(num_steps_first_segment * distance_kpts_red(kred_end[i - 1],
kred_end[i]) / kdist_zero)
if tmp < 2:
tmp = 2
num_steps.append(int(tmp))
k_all_red = []
k_spec_index = []
k_spec_index.append(0)
for i in range(1, kred_end.shape[0]):
for j in range(num_steps[i - 1]):
if i > 1 and j == 0:
continue
kone = kred_end[i - 1] + (float(j)/float(num_steps[i - 1] - 1))*(kred_end[i] - kred_end[i - 1])
k_all_red.append(kone)
if j == num_steps[i - 1] - 1:
k_spec_index.append(len(k_all_red) - 1)
k_all_red = np.array(k_all_red)
else:
k_all_red = np.array([kred_end[0]])
k_spec_index = [0]
k_dist = []
k_dist.append(0.0)
for i in range(1, k_all_red.shape[0]):
k_dist.append(distance_kpts_red(k_all_red[i - 1], k_all_red[i]))
k_dist = np.cumsum(k_dist)
k_spec_dist = k_dist[k_spec_index]
# now compute things in wannierberri
comp = self.do_list(k_all_red, to_compute, formatted_output_latex, doublet_indices, reorder_orbitals)
comp._computated_using = "do_path"
comp.new("kdist", {"value": k_dist,
"units": Units(eV = 0, Ang = -1, muB = 0),
"origin_story": """Distance between k-points along the specified k-path.
You might want to use this for the x-axis of your band structure plot.""",
"indices_info": {
"canonical_names": "l",
"explanation": ["index of all k-points along the path",
],
"bands": [],
},
})
comp.new("kspecdist", {"value": k_spec_dist,
"units": Units(eV = 0, Ang = -1, muB = 0),
"origin_story": """Distance of special points along the specified k-path.
For example, you can use this for x-location of special
points on your band structure plot.""",
"indices_info": {
"canonical_names": "s",
"explanation": ["index of the special k-points along the path",
],
"bands": [],
},
})
comp.new("kspeclabels", {"value": np.array(pnts),
"units": Units(eV = 0, Ang = 0, muB = 0),
"origin_story": "Labels of special k-points along the specified k-path.",
"indices_info": {
"canonical_names": "s",
"explanation": ["index of the special k-points along the path",
],
"bands": [],
},
})
return comp
def _compute_own_hashes(self):
ret = {}
comp = self.do_mesh([4, 2, 3], [0.123, 0.298, 0.784], ["psi", "A", "S", "dEdk"])
m_e = comp["E"]
m_psi = comp["psi"]
m_a = comp["A"]
m_s = comp["S"]
m_dEdk = comp["dEdk"]
# use energies as one of the hashes
ret["E"] = m_e
# use non-hermiticity of A as a check. This quantity should
# ideally be very small.
ret["nonhermA"] = np.sum(np.abs(m_a - m_a.transpose((0, 2, 1, 3)).conjugate()), axis = (1, 2))
ret["nonhermS"] = np.sum(np.abs(m_s - m_s.transpose((0, 2, 1, 3)).conjugate()), axis = (1, 2))
# force A and S to be hermitean, so that later there are no
# issues in computing eigenvalues
m_a = 0.5*(m_a + m_a.transpose((0, 2, 1, 3)).conjugate())
m_s = 0.5*(m_s + m_s.transpose((0, 2, 1, 3)).conjugate())
# The signs in front of the Bloch states depend on the machine
# you use, so these signs are randomized here: each time you run
# this code you get a different set of signs.
# This ensures that signs don't match accidentally.
rndS = np.random.RandomState(seed = None)
rnd = (rndS.random(m_a.shape[1]) > 0.5)
for i,r in enumerate(rnd):
if r == 1:
m_a[:,:,i,:] *= -1.0
m_a[:,i,:,:] *= -1.0
m_s[:,:,i,:] *= -1.0
m_s[:,i,:,:] *= -1.0
m_de = 1.0j*np.sin(m_e[:,:,None] - m_e[:,None,:])
# this hash will keep track of A and S. Need to do product to get rid of dependence
# on the trivial gauge (multiplying a single u-state with a minus sign)
ret["combined_hash"] = \
opteinsum("AijB,AjiB, Aij -> AB", m_s, m_s, m_de) + \
opteinsum("AijB,AjiB, Aij -> AB", m_s, m_a, m_de) + \
opteinsum("AijB,AjkB,AkiB,Aij -> AB", m_s, m_s, m_a, m_de) + \
opteinsum("AijB,AjkB,AkiB,Aij -> AB", m_s, m_a, m_a, m_de) + \
opteinsum("AijB,AjkB,AkiB,Aij -> AB", m_a, m_a, m_a, m_de)
# this hash will keep track of dEdk
ret["dEdk_hash"] = np.sum(m_dEdk) + 1.0j*np.sin(m_dEdk)
# this hash will keep track of psi
m_psi_abs2 = np.abs(m_psi)**2
ret["psi_hash_0"] = np.sum(np.sin(m_psi_abs2[:, 0::2, :] + m_psi_abs2[:, 1::2, :]))
ret["psi_hash_1"] = opteinsum("knp, kmp -> ", np.conjugate(m_psi), m_psi)
ret["psi_hash_2"] = opteinsum("knp, knr -> ", np.conjugate(m_psi), m_psi)
return ret
class _ComputatorWf():
r"""
This is a class for an object that stores various quantities and then parses
mathematical expressions to compute various other physical quantities. It has
a funny archaic name so that it is not confused with "calculators" from Wannier Berri.
In most cases, you will not need to create a computator object on your own using
a constructor. Instead, after loading a database you should use
:func:`do_mesh <wfbase.DatabaseWf.do_mesh>`, :func:`do_path <wfbase.DatabaseWf.do_path>`,
or :func:`do_list <wfbase.DatabaseWf.do_list>` to create the object from this class.
See the example below.
Each quantity has a "value", a physical "unit", and sometimes other data, such as
information about how the quantity was constructed, etc.
The "value" of the quantity is simply a numpy array and it can be modified
by the user with any numpy operation, such as transpose or reshape or by slicing
using the [] operator.
Quantities stored in the computator are constructed in one of these three ways.
* First, the quantity could be precomputed by the code. For example, if you
construct this object using :func:`do_mesh <wfbase.DatabaseWf.do_mesh>` the
code will precompute various quantities such as electron band energies,
Berry connection, and so on.
* Second, quantities could be added to the computator by the user by calling
the :func:`new <wfbase._ComputatorWf.new>` function. These quantities are added
as ordinary numpy arrays, along with some additional information sent
to the :func:`new <wfbase._ComputatorWf.new>` function.
* Third, quantities can be evaluated using the :func:`evaluate <wfbase._ComputatorWf.evaluate>`
function. This function parses a mathematical expression using currently available
quantities and it then stores the resulting quantity in the computator. For example,
one can use precomputed band energy and Berry connection to compute the optical
conductivity, and so on.
.. _structure_computator:
Here is an example of the structure of the computator object *comp*. Two quantities
(*E* and *A*) in this computator were created using function :func:`do_mesh <wfbase.DatabaseWf.do_mesh>`.
The third quantity (*sigma*) was computed using :func:`evaluate <wfbase._ComputatorWf.evaluate>`.
See the :ref:`examples page <examples>` for various examples that use computators with these quantities.
For example, you could take a look at :ref:`this example <sphx_glr_all_examples_example_ahc.py>`.
Each quantity below contains several keys, such as *value* or *units*. The third column below shows
you how to access these keys.
+----------+----------+---------------------------------+
| Quantity | Key | How to access? |
+==========+==========+=================================+
| E | value | | *comp["E"]* |
| | | | *comp.get("E", "value")* |
+ +----------+---------------------------------+
| | units | | *comp.get_units("E")* |
| | | | *comp.get("E", "units")* |
+ +----------+---------------------------------+
| | ... | |
+----------+----------+---------------------------------+
| A | value | | *comp["A"]* |
| | | | *comp.get("A", "value")* |
+ +----------+---------------------------------+
| | units | | *comp.get_units("A")* |
| | | | *comp.get("A", "units")* |
+ +----------+---------------------------------+
| | ... | |
+----------+----------+---------------------------------+
| sigma | value | | *comp["sigma"]* |
| | | | *comp.get("sigma", "value")* |
+ +----------+---------------------------------+
| | units | | *comp.get_units("sigma")* |
| | | | *comp.get("sigma", "units")* |
+ +----------+---------------------------------+
| | latex | | *comp.get_latex("sigma")* |
| | | | *comp.get("sigma", "latex")* |
+ +----------+---------------------------------+
| | ... | |
+----------+----------+---------------------------------+
| ... | | |
+----------+----------+---------------------------------+
To get a list of all quantities in the computator, use the function
:func:`all_quantities <wfbase._ComputatorWf.all_quantities>`. To get all keys stored for a
specific quantity, use function :func:`all_quantity_keys <wfbase._ComputatorWf.all_quantity_keys>`.
You can get more information about the quantities in the computator using the function
:func:`info <wfbase._ComputatorWf.info>`. :ref:`Here <quantities>` you can find an example
output of function :func:`info <wfbase._ComputatorWf.info>`.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create a computator object called "comp"
# the type of this object is _ComputatorWf
comp = db.do_mesh(formatted_output_latex = False)
# add new quantity "sigma" to the object comp,
# evaluate this quantity from the expression below
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# add new quantity called "omicron" to the object comp
comp.new("omicron", "0.7 * eV")
# change the value of one of the previously stored quantities
comp.replace("eta", "0.05 * eV")
# print information about all quantities in this computator
comp.info()
# print information about quantity "sigma"
comp.info("sigma")
# access some of the stored quantities
print(comp["omicron"])
print(comp["sigma"])
# change the value stored for quantity "sigma"
# Now sigma is no longer a 3x3 matrix, but a
# vector. Also, we multiplied it with 10.
comp["sigma"] = 10.0 * comp["sigma"][0, :]
# print information about quantity "sigma"
# Now a warning will be printed that quantity sigma
# has been modified.
comp.info("sigma")
# since [] returns a copy of the stored object,
# the following will NOT change the value stored for quantity "sigma"
# but it will only change the copy of the array.
cpy = comp["sigma"]
cpy = 20.0 * cpy
print(comp["sigma"]) # --> unchanged by the previous line
# the following also does not change the value of sigma
comp["sigma"][:] = 1.0
# the following, on the other hand, does change the value of sigma
comp["sigma"] = 1.0
# list all quantities stored in this computator
print(comp.all_quantities())
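# shapes and physical units of stored quantities can be inspected too
print(comp.get_shape("E"))
print(comp.get_units("E"))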
"""
def __init__(self,
quantities={},
formatted_output_latex = True,
doublet_indices = False,
loaded_from_wannierberri = False):
self.__quantities = quantities
self._computated_using = ""
self._total_seconds_initialize = None
self._db_loaded_from_wannierberri = loaded_from_wannierberri
for k in quantities.keys():
self.__check_core(k)
self.__check_quantity_has_required(quantities[k])
self.__set_to_numpy_array(k)
self.__allow_early_changes = True
if doublet_indices == True:
self.__try_to_convert_all_band_indices_from_singlets_to_doublets()
for k in quantities.keys():
self.__check_core(k)
self.__check_quantity_has_required(quantities[k])
self.__set_to_numpy_array(k)
self.__doublet_indices = doublet_indices
self.__allow_early_changes = False
self.__did_user_mess_with_values = {}
for core in self.all_quantities():
self.__did_user_mess_with_values[core] = False
self.__added_later_by_user = []
self.__formatted_output_latex = formatted_output_latex
self._order_parsed = 0
self._reorg_parser = ParserReorg()
# create parsing object from pyparsing
# This code is heavily based on eval_arith.py and simpleArith.py from the pyparsing GitHub repository
#
# basic building blocks are either integers or symbols that have letters with additional special characters
# order of these three things below matters
operand = _get_operand()
# This class will be called whenever you wish to evaluate one operand
operand.set_parse_action(EvalConstVar)
self._parser = pp.infix_notation(
operand,
[
(pp.oneOf("Real Imag") , 1, pp.opAssoc.RIGHT, EvalFuncOp ), # various function calls
( "#" , 1, pp.opAssoc.RIGHT, EvalConjugOp ), # complex conjugation
( "&" , 1, pp.opAssoc.RIGHT, EvalDOneOp ), # operation of dividing 1 by the object
( "^" , 2, pp.opAssoc.LEFT , EvalPowerOp ), # power raising (strictly speaking this should be a RIGHT not LEFT to follow conventions. But we don't allow user to do A^B^C so it doesn"t matter.)
(pp.oneOf("+ -") , 1, pp.opAssoc.RIGHT, EvalSignOp ), # sign in front of an object
(pp.oneOf("* /") , 2, pp.opAssoc.LEFT , EvalMultDivOp ), # multiplication and division
(pp.oneOf("+ -") , 2, pp.opAssoc.LEFT , EvalAddSubOp ), # addition and subtraction
(pp.oneOf("<= <+= <<="), 2, pp.opAssoc.LEFT , EvalArrowOp ), # perform assignment
],
)
# this is parser for brute force sums
operand_bfs = _get_operand()
operand_bfs.set_parse_action(BfsConstVar)
self._parser_brute_force_sums = pp.infix_notation(
operand_bfs,
[
(pp.oneOf("Real Imag") , 1, pp.opAssoc.RIGHT, BfsFuncOp ),
( "#" , 1, pp.opAssoc.RIGHT, BfsConjugOp ),
( "&" , 1, pp.opAssoc.RIGHT, BfsDOneOp ),
( "^" , 2, pp.opAssoc.LEFT , BfsPowerOp ),
(pp.oneOf("+ -") , 1, pp.opAssoc.RIGHT, BfsSignOp ),
(pp.oneOf("* /") , 2, pp.opAssoc.LEFT , BfsMultDivOp ),
(pp.oneOf("+ -") , 2, pp.opAssoc.LEFT , BfsAddSubOp ),
(pp.oneOf("<= <+= <<="), 2, pp.opAssoc.LEFT , BfsArrowOp ),
],
)
self._verbose_evaluate = False
def __try_to_convert_all_band_indices_from_singlets_to_doublets(self):
if self["numwann"]%2 != 0:
_raise_value_error("It is not possible to use doublet-index notation if you have odd number of states.")
ene = self["E"]
de = ene[:, 1::2] - ene[:, :-1:2]
if np.max(np.abs(de)) > 1.0E-11:
_raise_value_error("It is not possible to use doublet-index notation as your bands are not at least twice degenerate. Your system either has broken P*T symmetry, or something went wrong with symmetrization.")
for core in self.all_quantities():
if "indices_info" in self.all_quantity_keys(core):
old_quant = self.__get_entire_quantity(core)
new_quant = self.__do_doublets_one_quant(old_quant, core, value_already_doubled = False)
self.__change_quantity(core, new_quant)
def __do_doublets_one_quant(self, old_quant, core, value_already_doubled = False):
quant = deepcopy(old_quant)
indices_info = quant["indices_info"]
if value_already_doubled == False:
for b in indices_info["bands"]:
if quant["value"].shape[b] != self["numwann"]:
_raise_value_error("What is marked as band index in quantity " + core + " does not have the right shape.")
new_shape = []
tmp0 = 0
orig_inds = []
tmp1 = 0
new_inds = []
for i,ii in enumerate(quant["value"].shape):
if i in indices_info["bands"]:
new_shape.append(ii//2)
new_shape.append(2)
new_inds.append(str(tmp1) + str(tmp1 + 1))
tmp1 += 2
else:
new_shape.append(ii)
new_inds.append(str(tmp1))
tmp1 += 1
orig_inds.append(str(tmp0))
tmp0 += 1
if value_already_doubled == False:
quant["value"] = quant["value"].reshape(list(new_shape))
if "format" in quant.keys():
new_format = quant["format"]
for i in range(len(new_inds) - 1, -1, -1):
new_format = new_format.replace("*" + orig_inds[i], "*" + "*".join(new_inds[i]))
quant["format"] = new_format
if "format_conjugate" in quant.keys():
new_format_conjugate = quant["format_conjugate"]
for i in range(len(new_inds) - 1, -1, -1):
new_format_conjugate = new_format_conjugate.replace("*" + orig_inds[i], "*" + "*".join(new_inds[i]))
quant["format_conjugate"] = new_format_conjugate
new_indices_info = {}
new_canonical_names = ""
for j, jj in enumerate(indices_info["canonical_names"]):
if j in indices_info["bands"]:
new_canonical_names += jj.lower() + jj.upper()
else:
new_canonical_names += jj
if len(new_canonical_names) != len(list(set(new_canonical_names))):
_raise_value_error("Indices not unique: " + new_canonical_names)
new_indices_info["canonical_names"] = new_canonical_names
if "definition" in indices_info.keys():
new_definition = indices_info["definition"]
if new_definition is not None:
for i in range(len(new_inds) - 1, -1, -1):
new_definition = new_definition.replace("*" + orig_inds[i], "*" + "*".join(new_inds[i]))
new_indices_info["definition"] = new_definition
new_bands = []
for b in indices_info["bands"]:
new_bands.append(list(map(int, list(new_inds[b]))))
new_indices_info["bands"] = new_bands
new_explanation = []
for j, exp in enumerate(indices_info["explanation"]):
if j in indices_info["bands"]:
new_explanation.append("doublet index (" + exp + ")")
new_explanation.append("index within the doublet (0 or 1)")
else:
new_explanation.append(exp)
new_indices_info["explanation"] = new_explanation
quant["indices_info"] = new_indices_info
return quant
def get_initialization_time(self):
r"""
Returns the time, in seconds, it took to initialize this computator.
Most of this time is spent in calls to Wannier Berri to get
all required quantities, such as band energy, Berry connection,
etc.
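Example usage (a minimal sketch, assuming the *data/fe_bcc.wf* file has already been downloaded)::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# print how long the initialization took
print("Initialization took", comp.get_initialization_time(), "seconds")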
"""
return self._total_seconds_initialize
def __getitem__(self, core):
self.__does_core_exist(core)
return deepcopy(self.__quantities[core]["value"])
def __setitem__(self, core, value):
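# assigning a value by hand invalidates the metadata that was derived from the
# original expression (origin_story, latex, exec); remember the call stack so
# that later operations know the user changed this quantity directly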
if core not in self.all_quantities():
_raise_value_error("Quantity "+ core + " is not defined. You must add it using .new(...) method function.")
self.__quantities[core]["value"] = value
self.__did_user_mess_with_values[core] = traceback.extract_stack()
for k in ["origin_story", "latex", "exec"]:
if k in self.all_quantity_keys(core):
del self.__quantities[core][k]
self.__quantities[core]["origin_story"] = ""
self.__set_to_numpy_array(core)
def __change_quantity(self, core, quantity):
self.__does_core_exist(core)
if self.__allow_early_changes == True:
self.__quantities[core] = quantity
def all_quantities(self):
r"""
Returns a list of all quantities stored in the computator object.
:returns:
* **lst** -- list containing all quantities stored in the object.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a mesh
comp = db.do_mesh()
# get a list of all quantities stored in comp
lst = comp.all_quantities()
# access one of the quantities in comp
print(comp["E"])
"""
return list(self.__quantities.keys())
def all_quantity_keys(self, core):
r"""
Returns a list of keys for all data stored about the single quantity *core*.
:returns:
* **lst** -- list containing keys for quantity *core*
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a mesh
comp = db.do_mesh()
# get a list of all quantities stored in comp
keys = comp.all_quantity_keys("E")
# access one of the keys (units) stored about quantity "E"
print(str(comp.get("E", "units")))
"""
self.__does_core_exist(core)
return sorted(list(self.__quantities[core].keys()))
def get(self, core, key = "value"):
r"""
Returns the value of the specified quantity *core*. This returned
value (typically an array of numbers) is a copy of the value
stored in the computator object. Therefore, if you change the copy
of this value, the one stored in the computator object will not
be changed. If you want to actually change the value of the
quantity stored in the computator use the [] operator, as shown
in the example below, or use the :func:`replace <wfbase._ComputatorWf.replace>`
function.
:param core: name of the quantity you wish to get
:param key: Which part of the quantity you want to get. The default is "value"
which returns the value of the quantity (typically an array of numbers, for
example). This could also be "units" to get the physical unit of the
quantity.
:returns:
* **val** -- returned value of the specified quantity.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# now "energy" is a shallow copy of quantity "E"
energy = comp.get("E")
units = comp.get("E", "units")
print("Energy of the first band at the first kpoint is", energy[0, 0], "in units of", units)
# this will change array "energy" but not quantity "E" stored in the comp!
energy = energy + 10.0
# this will change the quantity "E" stored in the comp
comp["energy"] = energy * 3.4
"""
self.__does_core_exist(core)
return deepcopy(self.__quantities[core][key])
def __get_entire_quantity(self, core):
self.__does_core_exist(core)
return deepcopy(self.__quantities[core])
def get_shape(self, core):
r"""
Returns the shape of the specified quantity *core*.
:param core: name of the quantity whose shape you want to get.
:returns:
* **shp** -- shape of the quantity
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
print("Band energy is stored in array of this shape: ", comp.get_shape("E"))
"""
self.__does_core_exist(core)
return self.__quantities[core]["value"].shape
def get_ndim(self, core):
r"""
Returns the dimensionality of the specified quantity *core*.
:param core: name of the quantity whose dimensionality you want to get.
:returns:
* **ndim** -- dimensionality. 0 for single number. 1 for vector, etc.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
print("Band energy is stored in array of this dimensionality: ", comp.get_ndim("E"))
"""
self.__does_core_exist(core)
return self.__quantities[core]["value"].ndim
def get_latex(self, core):
r"""
This returns an object that stores information about LaTeX'ed
definition of *core*. Here *core* is a quantity that was computed
using the :func:`evaluate <wfbase._ComputatorWf.evaluate>` function.
:param core: Name of the quantity.
:returns:
* **lat** -- object that contains information about LaTeX'ed
definition of *core*.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# now get LaTeX'ed data about this object
lat = comp.get_latex("sigma")
wf.render_latex(lat, "test.png")
wf.display_in_separate_window("test.png")
wf.display_in_terminal("test.png")
"""
self.__does_core_exist(core)
if self.__did_user_mess_with_values[core] == False:
if self._is_parsed(core):
if "ind" in self.all_quantity_keys(core):
use_ind = self.get(core, "ind")
else:
use_ind = ""
if "latex" in self.all_quantity_keys(core):
use_latex = self.get(core, "latex")
else:
use_latex = ""
return _LatexExpression(core, use_ind, use_latex)
return None
def get_units(self, core):
r"""
Returns units of quantity *core*. This is a product of arbitrary powers
of angstroms, electron-volts, and Bohr magnetons.
:param core: Name of the quantity.
:returns:
* **unit** -- object that contains information about units.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
print("Units of sigma are: ", comp.get_units("sigma"))
"""
self.__does_core_exist(core)
return self.get(core, "units")
def compute_in_SI(self, core, prefactor = None):
r"""
Internally in WfBase all quantities are specified in units of eV, Angstrom, and
Bohr magneton. This function will return the numerical value of the
physical quantity *core* in SI units. If *prefactor* is not None (default),
then the returned numerical value will be multiplied by *prefactor*.
Here *prefactor* is a string that consists of various constants
of nature (hbar, electron charge, etc, as listed below). If prefactors are specified
then this function will return an additional object that contains LaTeX expression
for the product of the *prefactor* and *core*.
This function does not change any property of the quantity *core* itself. Everything
in the computator, and all cores, are always specified in eV, Angstrom, and Bohr
magneton. The only numbers in SI units are those returned by this function.
:param core: name of the quantity you wish to get.
:param prefactor: optional parameter. If specified then the returned quantity will
be multiplied by this prefactor. The prefactor provided to this function is a string
of the form "e^2 / (hbar * epszero)" or similar. The allowed constants are
"e" for electron charge, "epszero" for vacuum permittivity, "c" for speed of light
"me" for electron mass, and "hbar" for reduced Planck's constant.
:returns:
* **val** -- returned value of the specified quantity.
* **lat** -- object that contains LaTeX expression for the product of *core* and *prefactor*.
Returned only if *prefactor* is not *None* (default).
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# get units of this object
print("Units of sigma are: ", comp.get_units("sigma"))
# convert units to SI and multiply with e^2/hbar
result, result_latex = comp.compute_in_SI("sigma", "e^2 / hbar")
print("result = ", result)
wf.render_latex(result_latex, "latex.png")
wf.display_in_separate_window("latex.png")
wf.display_in_terminal("latex.png")
"""
self.__does_core_exist(core)
ret_value = self.__quantities[core]["units"]._to_SI(self.get(core, "value"))
if prefactor is not None:
ret_latex_obj= self.get_latex(core)
data = _parse_prefactor_SI_units_fundamental_constants(prefactor)
ret_value = data._numerical_value() * ret_value
ret_latex_obj = _LatexExpression(core = "",
ind = "",
rhs = ret_latex_obj._rhs,
prefactor = data._to_latex())
return (ret_value, ret_latex_obj)
else:
return ret_value
def get_as_dictionary(self, want_cores_in, key = "value"):
r"""
Returns values of multiple quantities at once in the form of a dictionary.
:param want_cores_in: A string or a list of strings. Each string is either
a name of the quantity you want to return, or it contains an asterisk or question mark
to match possibly multiple quantities at once. See the example below.
:param key: Which part of the quantity you want to get. The default is "value"
which returns the values of the quantities (typically an array of numbers, for
example).
:returns:
* **dic** -- dictionary containing values of all quantities that match.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# this will return all quantities that match
dic = comp.get_as_dictionary(["atom*", "A", "num?"])
print(dic["numk"])
"""
if isinstance(want_cores_in, str):
want_cores = [want_cores_in]
else:
want_cores = want_cores_in
ret = {}
for k in self.all_quantities():
found_match = False
for c in want_cores:
if fnmatch.fnmatch(k, c):
found_match = True
if found_match == True:
ret[k] = self.get(k, key)
return ret
def new(self, core, data, units_as = None):
r"""
Adds a new quantity to the computator.
:param core: This is the name of the new quantity.
:param data: This is the data associated with the new quantity.
This can be one of three things.
* First option -- is that *data* can be a dictionary that contains key
"value" that is a number or a numpy array. The
dictionary can also contain key "units" (defaults to dimensionless).
The key "units" should be of type :class:`Units <wfbase.Units>`
as shown in the example below.
User can also specify key "format" and "format_conjugate" which
give a way to format this quantity in LaTeX. See the example
below how to use "format". In short, one needs to specify \*0 at the place
in the LaTeX expression where the first index of the quantity goes,
\*1 for the second, etc. Another entry in the dictionary
could be "origin_story" which is a string describing the quantity *core*.
* Second option -- is that *data* is a string, such
as "3.0 eV * muB^2 / Ang" or similar. Allowed units are eV, Ang, and muB,
for electronvolt, angstrom, and bohr radius. You must use multiplication
signs between units, such as "eV * muB". (You are not allowed to
use "eV muB" or "3.0 eV".) You can use parentheses, division, and power
operator (^).
* Third option -- is to simply make *data* a number or a numpy array. The
units will be set to dimensionless by default (unless you specified those
with parameter *units_as*).
:param units_as: Ignored if set to *None* (default). Otherwise, units of the
new quantity will be equal to the units of quantity *units_as*. Stops
if units were specified through parameter *data* (either first or second option
above).
Example usage::
import wfbase as wf
import numpy as np
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# create new quantity with dimensionless value of 3.0
comp.new("alpha", 3.0)
# create new quantity with value of 3.0 * eV^2 * muB^2 / Ang
# (notice that below there must be a multiplication sign after "3.0")
comp.new("beta", "3.0 * (eV * muB)^2 / Ang")
# different way to achieve the same thing
comp.new("gamma", {"value": 3.0, "units": wf.Units(eV = 1)})
# adding extra information about the quantity
comp.new("delta", {"value": 3.0,
"units": wf.Units(eV = 1),
"origin_story": "Give here some information about delta."})
# information about how to display this quantity
comp.new("epsilon", {"value": np.array([[ 10.0 , 9.0 + 3.0j, 1.0 + 2.0j],
[ 9.0 - 3.0j, 12.0 , 22.0 - 8.0j],
[ 1.0 - 2.0j, 22.0 + 8.0j, 45.0 ]]),
"units": wf.Units(eV = 1),
"format": r"\langle v_{*0} \lvert X \rvert v_{*1} \rangle",
"format_conjugate": r"\langle v_{*1} \lvert X \rvert v_{*0} \rangle"})
# one can also define quantities that are vectors, or tensors in general
comp.new("zeta", [3.0, 4.0, 5.0, 6.0])
# units of new quantity eta will be copy-pasted from gamma
comp.new("eta", 3.0, units_as = "gamma")
comp.info("alpha")
comp.info("beta")
comp.info("gamma")
comp.info("delta")
comp.info("epsilon")
# evaluate new quantity using quantity defined earlier
comp.evaluate("kappa_ik <= epsilon_ij * #epsilon_kj")
# expression for kappa should be formatted using the bra-ket notation for epsilon
comp.info("kappa", display = True)
"""
if core in self.all_quantities():
_raise_value_error("Quantity \"" + str(core) + """\" already exists. You can't change its value using function .new.
Btw, you can change the value of the variable with comp[\"""" + core + """\"] = ... if you want. You can also
use .replace instead of .new to replace the previous quantity with a new one.""")
self.__check_similar_tilde(core)
if isinstance(data, str):
data = _parse_value_and_units(data)
if isinstance(data, int) or \
isinstance(data, float) or \
isinstance(data, complex) or \
isinstance(data, np.ndarray) or \
isinstance(data, list):
data = {"value": np.array(data)}
if "units" not in data.keys():
if units_as is not None:
data["units"] = self.get_units(units_as)
else:
data["units"] = Units()
else:
if units_as is not None:
_raise_value_error("You specified units twice. Once using the parameter *data*, second time using parameter *units_as*. Do one or the other, but not both.")
self.__check_core(core)
self.__check_quantity_has_required(data)
self.__quantities[core] = data
self.__did_user_mess_with_values[core] = False
self.__set_to_numpy_array(core)
self.__added_later_by_user.append(core)
def replace(self, core, data, units_as = None):
r"""
Removes previously existing quantity *core* and replaces it with new
quantity *core* with data provided in *data*.
:param core: This is the name of the quantity that you want to replace.
:param data: This is the data associated with the new quantity.
Same as *data* parameter used in :func:`new <wfbase._ComputatorWf.new>`.
:param units_as: Same as *units_as* parameter used in :func:`new <wfbase._ComputatorWf.new>`.
Example usage::
import wfbase as wf
db = wf.load("data/fe_bcc.wf")
comp = db.do_mesh()
comp.new("beta", "3.0 * (eV * muB)^2 / Ang")
comp.replace("beta", "4.0 * (eV * muB)^2 / Ang")
"""
if core not in self.all_quantities():
_raise_value_error("Quantity \"" + str(core) + "\" does not exist already, so you can't replace it. Use the function .new() instead.")
self.remove(core)
self.new(core, data, units_as)
def remove(self, core):
r"""
Removes previously existing quantity *core*.
:param core: This is the name of the quantity that you want to remove.
Example usage::
import wfbase as wf
db = wf.load("data/fe_bcc.wf")
comp = db.do_mesh()
comp.new("beta", "3.0 * (eV * muB)^2 / Ang")
comp.remove("beta")
"""
if core not in self.all_quantities():
_raise_value_error("Quantity \"" + str(core) + "\" doesn't exist. Therefore, it can't be removed.")
del self.__quantities[core]
del self.__did_user_mess_with_values[core]
if core in self.__added_later_by_user:
self.__added_later_by_user.remove(core)
def __check_core(self, core):
if core.count("~") > 1:
_raise_value_error("Variable name: " + core + " is invalid. It must contain at most only one ~.")
if core.startswith("~") == True:
_raise_value_error("Variable name: " + core + " is invalid. It can't start with ~.")
if core.endswith("~") == True:
_raise_value_error("Variable name: " + core + " is invalid. It can't end with ~.")
if not core.replace("~","").isalpha() or core == "":
_raise_value_error("Variable name: " + core + " is invalid. It must contain only letters and at most one ~.")
if core == "j":
_raise_value_error("Quantity name j is not allowed as it might be confusing " +
"as we use the same symbol to represent square root of negative one.")
def __does_core_exist(self, core):
if core not in self.all_quantities():
_raise_value_error("Specified core \"" + core + "\" does not exist.")
def __check_similar_tilde(self, core):
core_use = core.replace("~", "")
keys = self.all_quantities()
for k in keys:
if k.replace("~", "") == core_use:
_raise_value_error("Quantity \"" + str(core) + """\" does not already exists, but a similarly named quantity
(ignoring the tilde symbol, ~) does exist! This is not allowed as these are too similar. Pick
a more unique name.""")
def __check_quantity_has_required(self, data):
if "value" not in data.keys():
_raise_value_error("Did not specify value.")
if "units" not in data.keys():
_raise_value_error("Did not specify units.")
def __set_to_numpy_array(self, core):
self.__quantities[core]["value"] = np.array(self.__quantities[core]["value"])
def _return_in_latex(self, core, ind, do_latex_conjugate = False):
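# if the quantity provides a LaTeX "format" template (and formatted output is
# enabled and the user has not overwritten the value by hand), substitute the
# concrete indices into that template; otherwise fall back to a generic
# core_{indices} rendering; an overline marks complex conjugation whenever no
# conjugate template is available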
if "format" in self.all_quantity_keys(core) and \
self.__formatted_output_latex == True and \
self.__did_user_mess_with_values[core] == False:
performed_latex_conjugate = False
if do_latex_conjugate == False:
use_str = self.get(core, "format")
else:
if "format_conjugate" in self.all_quantity_keys(core):
use_str = self.get(core, "format_conjugate")
performed_latex_conjugate = True
else:
use_str = self.get(core, "format")
ret = _replace_star_with_indices(use_str, ind)
if do_latex_conjugate == True and performed_latex_conjugate == False:
ret = r" \overline{ " + ret.strip() + r" } "
else:
if ind is None:
ret = _nicefy_core(core)
else:
ret = _nicefy_core(core) + r"_{" + _nicefy_subscript(ind) + r"}"
if do_latex_conjugate == True:
ret = r" \overline{ " + ret.strip() + r" } "
return ret
def compute_occupation(self, out_core = "f", energy = "E", fermi = "ef", kbtemp = None):
r"""
Computes a quantity that has entries close to 1 at all places
where *energy* is less than *fermi* and 0 otherwise. Uses a Fermi-Dirac
distribution if temperature is specified. Otherwise, temperature is zero.
.. note::
This quantity can be used to enforce occupations of states while
evaluating physical quantities. The same effect can be achieved
using the *conditions* tag while calling :func:`evaluate <wfbase._ComputatorWf.evaluate>`
function. The benefit of using the *conditions* tag is that it reduces
the number of operations needed to do the computation. See the examples below.
:param out_core: Name of the occupation factor quantity. Defaults to "f".
(This function will remove previously existing quantity with the same name.)
:param energy: Name of the energy quantity. Defaults to "E".
:param fermi: Name of the Fermi level quantity (or a float). Defaults to "ef".
:param kbtemp: The default is None, which means zero temperature. You can also specify
a quantity, or give a floating point number in units of eV.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
comp.compute_occupation("f", "E", "ef")
# computes band energy of occupied states
comp.evaluate("sumA <= E_nk * f_nk")
# same but without using "f"
comp.evaluate("sumB <= E_nk", "E_nk < ef")
# same computation by directly using numpy operations
import numpy as np
sumC = np.sum(comp["E"][ comp["E"] < comp["ef"]])
print(comp["sumA"], comp["sumB"], sumC)
"""
if out_core in self.all_quantities():
self.remove(out_core)
use_energy = np.real(self.get(energy, "value"))
if isinstance(fermi, str):
if self.get(fermi, "units")._check_units_the_same(self.get(energy, "units")) == False:
_raise_value_error("Units of " + energy + " and " + fermi + " are not the same!")
if self.get_ndim(fermi) != 0:
_raise_value_error("Fermi level must be a single number.")
use_fermi = np.real(self.get(fermi, "value"))
else:
use_fermi = np.real(float(fermi))
if kbtemp is None:
value = np.array(use_energy < use_fermi, dtype = float)
data = {
"value": value,
"origin_story": "Has an entry equal to 1 for all entries where " +
str(energy) + " < " + str(fermi) + " otherwise it is 0.",
"units": Units(eV = 0, Ang = 0, muB = 0),
}
else:
if isinstance(kbtemp, str):
if self.get(kbtemp, "units")._check_units_the_same(self.get(energy, "units")) == False:
_raise_value_error("Units of " + energy + " and " + kbtemp + " are not the same!")
if self.get_ndim(kbtemp) != 0:
_raise_value_error("Temperature must be a single number.")
use_kbtemp = np.real(self.get(kbtemp, "value"))
else:
use_kbtemp = np.real(float(kbtemp))
value = _fermi_dirac(use_energy, use_fermi, use_kbtemp)
data = {
"value": value,
"origin_story": "Fermi-Dirac occupation factor between 1 and 0. " +
"It is close to 1 for all entries where " +
str(energy) + " < " + str(fermi) + " otherwise it is close to 0.",
"units": Units(eV = 0, Ang = 0, muB = 0),
}
if "indices_info" in self.all_quantity_keys(energy):
energy_data = self.get(energy, "indices_info")
data["indices_info"] = {"canonical_names": energy_data["canonical_names"],
"explanation": energy_data["explanation"],
"bands": energy_data["bands"],
}
else:
data["origin_story"] += " The indices of this quantity are the same as those of *" + str(energy) + "*."
self.new(out_core, data)
def compute_occupation_derivative(self, out_core = "dfdE", energy = "E", fermi = "ef", kbtemp = 0.05):
r"""
Computes a derivative of the Fermi-Dirac distribution.
:param out_core: Name of the occupation factor quantity. Defaults to "dfdE".
(This function will remove previously existing quantity with the same name.)
:param energy: Name of the energy quantity. Defaults to "E".
:param fermi: Name of the Fermi level quantity (or a float). Defaults to "ef".
:param kbtemp: Quantity (or a floating point number) giving kb*temperature in eV.
Defaults to 0.05 eV.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
comp.compute_occupation_derivative("dfdE", "E", "ef", 0.01)
"""
if out_core in self.all_quantities():
self.remove(out_core)
use_energy = np.real(self.get(energy, "value"))
if isinstance(fermi, str):
if self.get(fermi, "units")._check_units_the_same(self.get(energy, "units")) == False:
_raise_value_error("Units of " + energy + " and " + fermi + " are not the same!")
if self.get_ndim(fermi) != 0:
_raise_value_error("Fermi level must be a single number.")
use_fermi = np.real(self.get(fermi, "value"))
else:
use_fermi = np.real(float(fermi))
if isinstance(kbtemp, str):
if self.get(kbtemp, "units")._check_units_the_same(self.get(energy, "units")) == False:
_raise_value_error("Units of " + energy + " and " + kbtemp + " are not the same!")
if self.get_ndim(kbtemp) != 0:
_raise_value_error("Temperature must be a single number.")
use_kbtemp = np.real(self.get(kbtemp, "value"))
else:
use_kbtemp = np.real(float(kbtemp))
value = _fermi_dirac_deriv(use_energy, use_fermi, use_kbtemp)
data = {
"value": value,
"origin_story": "Derivative of the Fermi-Dirac occupation factor.",
"units": Units(eV = -1, Ang = 0, muB = 0),
}
if "indices_info" in self.all_quantity_keys(energy):
energy_data = self.get(energy, "indices_info")
data["indices_info"] = {"canonical_names": energy_data["canonical_names"],
"explanation": energy_data["explanation"],
"bands": energy_data["bands"],
}
else:
data["origin_story"] += " The indices of this quantity are the same as those of *" + str(energy) + "*."
allinds = ""
for j in range(self.get_ndim(energy)):
allinds += "*" + str(j)
data["format"] = r"\frac{\partial f}{\partial " + energy.strip() + r"_{" + allinds + r"} }"
data["format_conjugate"] = data["format"]
self.new(out_core, data)
def compute_kronecker(self, out_core, core, ind, core2 = None, ind2 = None):
r"""
Computes a Kronecker delta symbol that has ones on the diagonal and zeros otherwise.
:param out_core: Name of the Kronecker delta quantity. No default.
(This function will remove previously existing quantity with the same name.)
:param core: The Kronecker delta will have the first index the same shape as
*ind*-th index of the quantity *core*.
:param ind: First index to use in the construction of the Kronecker delta.
:param core2: The Kronecker delta will have the second index the same shape as
*ind2*-th index of the quantity *core2*. The default is to use the same as *ind*.
:param ind2: Second index to use in the construction of the Kronecker delta. The default
is to have the same as *ind*.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# this will create kronecker with shape numwann * numwann
comp.compute_kronecker("d", "E", 1)
# this will include only diagonal parts
comp.evaluate("B_nm <= d_nm/(E_km - E_kn + j*eta)")
# this will exclude diagonal parts, where n == m
comp.evaluate("C_nm <= (1.0 - d_nm)/(E_km - E_kn + j*eta)")
"""
if out_core in self.all_quantities():
self.remove(out_core)
if core2 is None:
core2 = core
if ind2 is None:
ind2 = ind
self.new(out_core, {"value": np.eye(self.get_shape(core)[ind], self.get_shape(core2)[ind2]),
"origin_story": """Kronecker delta. Equals one when two indices
are the same (on the diagonal) and equals zero otherwise.""",
"indices_info": {
"canonical_names": "xy",
"explanation": ["same as index #" + str(ind) + " of quantity *" + core,
"same as index #" + str(ind2) + " of quantity *" + core2,
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0)}
)
def compute_identity(self, out_core, size):
r"""
Gives an identity matrix of shape (size, size). One on diagonal, zero off-diagonal.
:param out_core: Name of this quantity.
:param size: Size of the matrix.
:returns:
* **mat** -- The identity matrix.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# this will create 3x3 identity matrix
comp.compute_identity("one", 3)
"""
if out_core in self.all_quantities():
self.remove(out_core)
self.new(out_core, {"value": np.eye(size, dtype = complex),
"origin_story": "Identity matrix.",
"indices_info": {
"canonical_names": "xy",
"explanation": ["generic index.",
"generic index."
],
"bands": [],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
"format": r"I^{" + str(size) + "}",
"format_conjugate": r"I^{" + str(size) + "}"}
)
def compute_photon_energy(self, out_core = "hbaromega", emin = 0.5, emax = 3.0, steps = 31):
r"""
Computes an array of photon energy in eV.
:param out_core: Name of the quantity for photon energy. Defaults to "hbaromega".
(This function will remove previously existing quantity with the same name.)
:param emin: Minimal energy in the range. Defaults to 0.5. Units are eV.
:param emax: Maximal energy in the range. Defaults to 3.0. Units are eV.
:param steps: The number of equidistant steps. Defaults to 31.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# this will create photon energies from 0.5 to 5 eV in 51 steps
comp.compute_photon_energy("omicron", 0.5, 5.0, 51)
# this achieves a similar thing directly with numpy
import numpy as np
comp.new("delta", {"value": np.linspace(0.01, 5.0, 51, endpoint = True),
"units": wf.Units(eV = 1)})
"""
if out_core in self.all_quantities():
self.remove(out_core)
self.new(out_core, {"value": np.linspace(emin, emax, steps, endpoint = True),
"origin_story": "Photon energies hbar*omega on an equidistant mesh of values.",
"indices_info": {
"canonical_names": "o",
"explanation": ["index of the photon energy",
],
"bands": [],
},
"units": Units(eV = 1, Ang = 0, muB = 0),
"format": r"\hbar \omega_{*0}",
"format_conjugate": r"\hbar \omega_{*0}"})
def __confirm_consistent_definition_of_variables(self, cores, txt):
for c in cores:
if c not in self.all_quantities():
_raise_value_error("Can't compute " + txt.strip() + " without having quantity " + c.strip() + ".")
if self.__did_user_mess_with_values[c] == True:
_raise_value_error("Can't compute " + txt.strip() + " because user has changed quantity A, so I'm not sure what happened to it.")
if c in self.__added_later_by_user:
_raise_value_error("Can't compute " + txt.strip() + " because user has created their own quantity A.")
def compute_orbital_character(self, out_core):
r"""
Computes approximate orbital character of the wavefunction.
|| < W_p | psi_kn > ||^2
(If you earlier used doublet_indices = True then the indices
above on psi are *knN* instead of *kn*.)
This is a dimensionless number. The sum of this number over index *p* is 1.0.
:param out_core: Name of the quantity for the orbital character.
(This function will remove previously existing quantity with the same name.)
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# computes orbital character
comp.compute_orbital_character("O")
"""
self.__confirm_consistent_definition_of_variables(["psi"], "orbital character")
if out_core in self.all_quantities():
self.remove(out_core)
if self.__doublet_indices == False:
tmp = self.evaluate("_knp <= Real(psi_knp * #psi_knp)")
else:
tmp = self.evaluate("_knNp <= Real(psi_knNp * #psi_knNp)")
quant = \
{"value": tmp,
"origin_story": """Approximate orbital character of wavefunction
in terms of localized atomic-like function. The sum over orbitals
adds up to 1. The description of localized orbitals are given by
quantity *orbitallabels* (available only if your database was loaded
from the WfBase's database).
""",
"indices_info": {
"definition": r"|| < W_*2 | u_*0*1 > ||^2",
"canonical_names": "knmdo",
"explanation": ["index of a k-point",
"electron band index",
"localized orbital index",
],
"bands": [1],
},
"units": Units(eV = 0, Ang = 0, muB = 0),
"format" : \
r"\lvert \langle W_{*2} \vert \psi_{*0 *1} \rangle \rvert^2",
"format_conjugate" :\
r"\lvert \langle W_{*2} \vert \psi_{*0 *1} \rangle \rvert^2"}
if self.__doublet_indices == True:
quant = self.__do_doublets_one_quant(quant, out_core, value_already_doubled = True)
self.new(out_core, quant)
def compute_optical_offdiagonal(self, out_core, hbaromega):
r"""
Computes matrix elements for the off-diagonal (interband) interaction
of electrons with electromagnetic waves.
< psi_kn | H_offdiagonal | psi_km >
(If you earlier used doublet_indices = True then the indices above are *knN* and *kmM*
instead of *kn* and *km*.)
Units of this quantity are eV. This effectively assumes that the maximum
electric field of the applied electromagnetic wave is 1 V/Ang.
The exact computed quantity is
(1 / 2) (E_kn - E_km) * < u_kn | i del_k_d | u_km > / hbaromega_o
The diagonal elements (n = m) are set to zero. (When indices are doubled then
matrix elements within the doublet are set to zero.)
:param out_core: Name of the quantity for the matrix element.
(This function will remove previously existing quantity with the same name.)
:param hbaromega: Name of the quantity containing photon energies.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# compute interband optical matrix element for predefined
# photon energies "hbaromega"
comp.compute_optical_offdiagonal("L", "hbaromega")
# this will create photon energies from 0.01 to 5 eV in 51 steps
comp.compute_photon_energy("omicron", 0.01, 5.0, 51)
# recompute matrix elements for these new energies
comp.compute_optical_offdiagonal("L", "omicron")
"""
if self.get(hbaromega, "units")._check_units_the_same(Units(eV = 1)) == False:
_raise_value_error("Units of " + hbaromega + " are not eV!")
self.__confirm_consistent_definition_of_variables(["A", "E"], "optical matrix element")
if out_core in self.all_quantities():
self.remove(out_core)
if self.__doublet_indices == False:
tmp = self.evaluate("_knmdo <= 0.5 * (E_kn - E_km) * A_knmd / "+hbaromega+"_o")
for i in range(tmp.shape[1]):
tmp[:,i,i,:,:] = 0.0
else:
tmp = self.evaluate("_knNmMdo <= 0.5 * (E_knN - E_kmM) * A_knNmMd / "+hbaromega+"_o")
for i in range(tmp.shape[1]):
tmp[:,i,:,i,:,:,:] = 0.0
quant = \
{"value": tmp,
"origin_story": """Off-diagonal matrix element for electron-light interaction.
Diagonal elements of this matrix are set to zero by hand.
The maximal electric field strength of the incoming light is set to 1 V/Ang.
The unit of the matrix element is therefore energy (eV).
""",
"indices_info": {
"definition": "0.5 (E_*0*1 - E_*0*2) < u_*0*1 | i del_k_*3 | u_*0*2 > / " + hbaromega + "_*4",
"canonical_names": "knmdo",
"explanation": ["index of a k-point",
"electron band index of the bra state",
"electron band index of the ket state",
"direction of the E-field of light in Cartesian (0 for x, 1 for y, 2 for z)",
"index for photon energy " + hbaromega,
],
"bands": [1, 2],
},
"units": Units(eV = 1, Ang = 0, muB = 0),
"format" : \
r"\langle \psi_{*0 *1} \lvert H^{\rm inter}_{*3 *4} \rvert \psi_{*0 *2} \rangle",
"format_conjugate" :\
r"\langle \psi_{*0 *2} \lvert H^{\rm inter}_{*3 *4} \rvert \psi_{*0 *1} \rangle"}
if self.__doublet_indices == True:
quant = self.__do_doublets_one_quant(quant, out_core, value_already_doubled = True)
self.new(out_core, quant)
def compute_optical_offdiagonal_polarization(self, out_core, hbaromega, polarization):
r"""
Similar to :func:`compute_optical_offdiagonal <wfbase._ComputatorWf.compute_optical_offdiagonal>` but now
the matrix element is computed for a specified polarization only.
:param out_core: Name of the quantity for the matrix element.
(This function will remove previously existing quantity with the same name.)
:param hbaromega: Name of the quantity containing photon energies.
:param polarization: String describing the polarization direction. For linearly
polarized light use "x", or "y", or "z". For circularly polarized light,
use one of these "x + i y", "x - i y", "x + i z", "x - i z", "y + i z", "y - i z".
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# compute interband optical matrix element for predefined
# photon energies "hbaromega"
comp.compute_optical_offdiagonal_polarization("L", "hbaromega", "x + i y")
"""
if self.get(hbaromega, "units")._check_units_the_same(Units(eV = 1)) == False:
_raise_value_error("Units of " + hbaromega + " are not eV!")
self.__confirm_consistent_definition_of_variables(["A", "E"], "optical matrix element")
if out_core in self.all_quantities():
self.remove(out_core)
if self.__doublet_indices == False:
tmp = self.evaluate("_knmdo <= 0.5 * (E_kn - E_km) * A_knmd / "+hbaromega+"_o")
for i in range(tmp.shape[1]):
tmp[:,i,i,:,:] = 0.0
else:
tmp = self.evaluate("_knNmMdo <= 0.5 * (E_knN - E_kmM) * A_knNmMd / "+hbaromega+"_o")
for i in range(tmp.shape[1]):
tmp[:,i,:,i,:,:,:] = 0.0
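# normalize the polarization string: strip spaces, lowercase it, and write the
# imaginary unit as j (so "x + i y" becomes "x+jy")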
use_pol = polarization.replace(" ", "").lower().replace("i", "j")
if use_pol == "x":
direction = np.array([1.0, 0.0, 0.0], dtype = complex)
elif use_pol == "y":
direction = np.array([0.0, 1.0, 0.0], dtype = complex)
elif use_pol == "z":
direction = np.array([0.0, 0.0, 1.0], dtype = complex)
elif use_pol in ["x+jy", "x+yj"]:
direction = np.array([1.0, 1.0j, 0.0], dtype = complex)
elif use_pol in ["x-jy", "x-yj"]:
direction = np.array([1.0,-1.0j, 0.0], dtype = complex)
elif use_pol in ["x+jz", "x+zj"]:
direction = np.array([1.0, 0.0, 1.0j], dtype = complex)
elif use_pol in ["x-jz", "x-zj"]:
direction = np.array([1.0, 0.0,-1.0j], dtype = complex)
elif use_pol in ["y+jz", "y+zj"]:
direction = np.array([0.0, 1.0, 1.0j], dtype = complex)
elif use_pol in ["y-jz", "y-zj"]:
direction = np.array([0.0, 1.0,-1.0j], dtype = complex)
else:
_raise_value_error("Unknown polarization. Must be one of these: x, y, z, x+iy, x-iy, x+iz, x-iz, y+iz, y-iz.")
if use_pol in ["x", "y", "z"]:
pol_type = "linear"
else:
pol_type = "circular"
latex_pol = use_pol.replace("j", "i").replace("+", " + ").replace("-", " - ")
latex_pol = latex_pol.replace("x", r"{\rm x}").replace("y", r"{\rm y}").replace("z", r"{\rm z}")
if "+" in latex_pol:
latex_pol_conj = latex_pol.replace("+", "-")
elif "-" in latex_pol:
latex_pol_conj = latex_pol.replace("-", "+")
else:
latex_pol_conj = latex_pol
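# contract the Cartesian direction index with the (conjugated) polarization
# vector, leaving matrix elements for this specific polarization only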
if self.__doublet_indices == False:
tmp = opteinsum("knmdo, d -> knmo", tmp, np.conjugate(direction))
else:
tmp = opteinsum("knNmMdo, d -> knNmMo", tmp, np.conjugate(direction))
quant = {"value": tmp,
"origin_story": """Off-diagonal matrix element for electron-light interaction
for a """ + pol_type + " " + use_pol + """ polarization of light.
Diagonal elements of this matrix are set to zero by hand.
The maximal electric field strength of the incoming light is set to 1 V/Ang.
The unit of the matrix element is therefore energy (eV).
""",
"indices_info": {
"definition": "0.5 (E_*0*1 - E_*0*2) < u_*0*1 | i delk_" + use_pol.strip() + " | u_*0*2 > / " + hbaromega + "_*3",
"canonical_names": "knmo",
"explanation": ["index of a k-point",
"electron band index of the bra state",
"electron band index of the ket state",
"index for photon energy " + hbaromega,
],
"bands": [1, 2],
},
"units": Units(eV = 1, Ang = 0, muB = 0),
"format" : \
r"\langle \psi_{*0 *1} \lvert H^{\rm inter}_{ " + latex_pol + r", *3 } \rvert \psi_{*0 *2} \rangle",
"format_conjugate" :\
r"\langle \psi_{*0 *2} \lvert H^{\rm inter}_{ " + latex_pol_conj + r", *3 } \rvert \psi_{*0 *1} \rangle"}
if self.__doublet_indices == True:
quant = self.__do_doublets_one_quant(quant, out_core, value_already_doubled = True)
self.new(out_core, quant)
def compute_hbar_velocity(self, out_core):
r"""
Computes matrix elements for the diagonal (intraband) and
off-diagonal (interband) velocity operator times hbar,
< psi_kn | hbar * v_d | psi_km >
Units of this quantity are eV * Ang
The exact computed quantity for the diagonal is
(d E_kn / d k_d) delta_nm
for off-diagonal (n != m) it is
i (E_kn - E_km) * < u_kn | i del_k_d | u_km >
(If you earlier used doublet_indices = True then the indices above are *knN* and *kmM*
instead of *kn* and *km*.)
:param out_core: Name of the quantity for the matrix element.
(This function will remove previously existing quantity with the same name.)
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
comp.compute_hbar_velocity("hbarv")
"""
self.__confirm_consistent_definition_of_variables(["A", "E", "dEdk"], "velocity matrix element")
if out_core in self.all_quantities():
self.remove(out_core)
if self.__doublet_indices == False:
tmp = self.evaluate("_knmd <= 1.0j * (E_kn - E_km) * A_knmd")
diagonal = self["dEdk"]
for i in range(tmp.shape[1]):
tmp[:,i,i,:] = diagonal[:,i,:]
else:
tmp = self.evaluate("_knNmMd <= 1.0j * (E_knN - E_kmM) * A_knNmMd")
diagonal = self["dEdk"]
for i in range(tmp.shape[1]):
for I in range(tmp.shape[2]):
for J in range(tmp.shape[4]):
if I == J:
tmp[:,i,I,i,J,:] = diagonal[:,i,I,:]
else:
tmp[:,i,I,i,J,:] = 0.0
quant = {"value": tmp,
"origin_story": """Matrix element of hbar*velocity operator.
Includes both diagonal elements (computed from the Fermi velocity) and off-diagonal
elements (computed from the Berry connection).
""",
"indices_info": {
"definition": "[if *1 == *2] (d E_*0*1 / d k_*3) delta_*1*2\n[if *1 != *2] i (E_*0*1 - E_*0*2) < u_*0*1 | i del_k_*3 | u_*0*2 >",
"canonical_names": "knmd",
"explanation": ["index of a k-point",
"electron band index of the bra state",
"electron band index of the ket state",
"direction of velocity in Cartesian (0 for x, 1 for y, 2 for z)",
],
"bands": [1, 2],
},
"units": Units(eV = 1, Ang = 1, muB = 0),
"format" : \
r"\langle \psi_{*0 *1} \lvert \hbar v_{*3} \rvert \psi_{*0 *2} \rangle",
"format_conjugate" :\
r"\langle \psi_{*0 *2} \lvert \hbar v_{*3} \rvert \psi_{*0 *1} \rangle"}
if self.__doublet_indices == True:
quant = self.__do_doublets_one_quant(quant, out_core, value_already_doubled = True)
self.new(out_core, quant)
def _process_all_greater_lesser(self, conditions, code, brute_force_sums):
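# parse conditions such as "E_kn < ef": each one becomes a boolean filter array
# (_f in the generated code); later, per-index slices (_s) are built that drop
# index values for which no element satisfies the conditions involving that index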
conditions = re.split(",|:|;|\n", conditions)
self._filters = []
counter = 0
if brute_force_sums == False:
code._add_raw("_f = []\n", to_preamble = True)
for cond in conditions:
cond = cond.strip()
if cond == "":
continue
if "<" in cond:
sp = cond.split("<")
op = "<"
elif ">" in cond:
sp = cond.split(">")
op = ">"
elif "!=" in cond:
continue
else:
_raise_value_error("Condition must contain either < or > or !=.")
if len(sp) != 2:
_raise_value_error("Condition must contain ony one of < or >.")
left = sp[0].strip()
right = sp[1].strip()
sp = left.split("_")
if len(sp) != 2:
_raise_value_error("Left of " + op + " there has to be a tensor written as A_ijk or similar.")
left_core = sp[0].strip()
left_ind = sp[1].strip()
if "_" in right:
_raise_value_error("Left of " + op + " there has to be a constant not a tensor.")
if left_core not in self.all_quantities():
_raise_value_error("Tensor " + left_core + " in the greater/lesser condition on the left, is not defined.")
if len(left_ind) == 0:
_raise_value_error("Tensor " + left_core + " in the greater/lesser condition on the left should have indices specified.")
if len(left_ind) != len(self.get_shape(left_core)):
_raise_value_error("Tensor " + left_core + " in the greater/lesser condition on the left was not specified with the correct number of indices.")
if len(left_ind) != len(set(left_ind)):
_raise_value_error("Tensor " + left_core + " in the greater/lesser condition on the left should not have repeating indices.")
parsing_right = pp.Combine( pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine( pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Word( pp.nums, min = 1).set_results_name("integer") | \
pp.Word("-" + pp.nums, min = 1).set_results_name("integer") | \
pp.Combine(pp.Word(pp.alphas, min = 1) + "~" + pp.Word(pp.alphas, min = 1)).set_results_name("constant") | \
pp.Word(pp.alphas, min = 1).set_results_name("constant")
par = _my_parse_string(parsing_right, right, parse_all = True)
if par.get_name() in ["integer", "float"]:
right = str(right)
right_bfs = right
right_latex = str(right)
right_type = "int or float"
elif par.get_name() in ["constant"]:
if right not in self.all_quantities():
_raise_value_error("Constant: " + right + " in the greater/lesser condition on the right, is not defined.")
right_latex = self._return_in_latex(right, None)
right_bfs = "__object_" + right
right = "__object[\"" + right + "\"]"
right_type = "constant"
else:
_raise_value_error("Should not happen.")
if brute_force_sums == False:
code._add_raw("_f.append(np.real(__object[\"" + left_core + "\"]) " + op + " np.real(" + right + "))")
# these are filters to be used to enforce which terms appear in the sum
txt = "_f[" + str(counter) + "]"
for i, one_ind in enumerate(left_ind):
tmp = []
for j in range(len(left_ind)):
if i == j:
tmp.append("_s[\"" + one_ind + "\"]")
else:
tmp.append(":")
txt += "[" + ",".join(tmp) + "]"
self._filters.append({
"cond_value": txt,
"cond_inds": left_ind,
"cond_latex": self._return_in_latex(left_core, left_ind) + " " + op + " " + right_latex,
"cond_for_brute_force": {"left_core": left_core,
"left_index": left_ind,
"op": op,
"right": right_bfs,
"right_type": right_type}
})
counter += 1
# now compute slices which simplify the computation by removing terms that would be discarded eventually anyway.
# I call this "partially filtered" as slicing operations always produce rectangular arrays:
# you can't slice indices one by one and get anything that is not rectangular.
tt = ""
for f in self._filters:
tt += f["cond_inds"]
tt = "".join(sorted(set(tt)))
self._partially_filtered_indices = tt
if brute_force_sums == False:
code._add_raw("_s = {}\n", to_preamble = True)
for ind in self._partially_filtered_indices:
txt = ""
txt += "_s[\"" + str(ind) + "\"] = ("
# now go over all filters and find in which of them the index ind appears
tmp = []
for f, ff in enumerate(self._filters):
if ind in ff["cond_inds"]:
to_sum_over = list(range(len(ff["cond_inds"])))
to_sum_over.remove(ff["cond_inds"].index(ind))
tmp.append("np.sum(_f[" + str(f) + "], axis = " + str(tuple(to_sum_over)) + ")")
if len(tmp) == 0:
_raise_value_error("This should not happen.")
txt += " + ".join(tmp)
txt += ") > 0\n"
if brute_force_sums == False:
code._add_raw(txt)
#code._add_raw("print(\"Value of _s[" + str(ind) +"] is : \", _s[\"" + str(ind) + "\"])" + "\n")
to_loop_over = ", ".join(list(map(lambda x: "\"" + x + "\"", list(self._partially_filtered_indices))))
if len(to_loop_over) > 0:
txt = ""
txt += "for ind in [" + to_loop_over + "]:" + "\n"
# make sure you don't slice too much
txt += " if True not in _s[ind]:\n"
txt += " raise ValueError(\"Condition on index \" + ind + \" is so restrictive that it removes all elements.\")\n"
# no point in slicing things if there is nothing to slice
txt += " if False not in _s[ind]:\n"
txt += " _s[ind] = slice(None)\n"
if brute_force_sums == False:
code._add_raw(txt)
def _process_all_diagonals(self, conditions, code, brute_force_sums):
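# collect conditions of the form "n != m": each one records a pair of indices
# whose diagonal terms should be excluded from the sums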
conditions = re.split(",|:|;|\n", conditions)
self._diagonals = []
if brute_force_sums == False:
code._add_raw("_orig_shp = {}\n", to_preamble = True)
for cond in conditions:
cond = cond.strip()
if cond == "":
continue
if "<" in cond:
continue
elif ">" in cond:
continue
elif "!=" in cond:
sp = cond.split("!=")
else:
_raise_value_error("Condition must contain either < or > or !=.")
if len(sp) != 2:
_raise_value_error("Condition on diagonals must contain ony one of !=")
left = sp[0].strip()
right = sp[1].strip()
parsing_indices = pp.Word(pp.alphas, min = 1, max = 1).set_results_name("single_index")
left = _my_parse_string(parsing_indices, left, parse_all = True)
left = left[0]
right = _my_parse_string(parsing_indices, right, parse_all = True)
right = right[0]
if left == right:
_raise_value_error("Condition on diagonals can't be between same indices!")
if [left, right] not in self._diagonals and [right, left] not in self._diagonals:
self._diagonals.append(sorted([left, right]))
tt = ""
for dia in self._diagonals:
tt += "".join(dia)
tt = "".join(sorted(set(tt)))
self._all_diagonal_indices = tt
self._all_diagonal_indices_stored_shape = ""
for dia in self._diagonals:
# if indices in the diagonal don't have the same length, that could lead to ambiguities
if brute_force_sums == False:
code._add_raw("if \"" + dia[0] +"\" in _orig_shp.keys() and \"" +
dia[1] +"\" in _orig_shp.keys():\n if _orig_shp[\"" +
dia[0] + "\"] != _orig_shp[\"" + dia[1] + "\"]:\n raise ValueError(\"Indices " +
dia[0] + " and " + dia[1] + " must be of the same length, as they are used in a condition.\")", \
to_preamble = False, to_the_top = True)
def evaluate(self,
in_eqns,
conditions = "",
brute_force_sums = False,
optimize_divisions = True,
optimize_recomputation = True,
show_latex_with_div_opts = False):
r"""
This function can be used to evaluate a wide range of mathematical expressions.
Einstein summation over repeating indices is assumed.
Quick example::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# calculate something
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
:param in_eqns: String containing mathematical expression(s) you wish to evaluate.
This string can contain the following.
* "Numbers," -- either integer, real, or complex. For example, you can
use numbers in these formats 2, 3.14, 3j, 4.3 + 2.0j, j.
* "Quantities" -- defined in this computator. For example, "E_kn" corresponds
to the band energy associated with k-point k and band n. This quantity
can be anyone that is currently defined for the computator. You can use the
function :func:`all_quantity_keys <wfbase._ComputatorWf.all_quantity_keys>` to
get a list of all quantities currently in the computator. You can
also use :func:`info <wfbase._ComputatorWf.info>` to get
extended information about all quantities in the computator.
Note that you can use any set of letters as indices for quantities.
For example, if your expression contains multiple
sums over bands, you can use "E_kn" and "E_km", or any other combination
of indices such as "E_qw". Quantities that are constants (not tensors with
indices) you can simply refer to as "Q" or "QwR" or similar. For constants
don't use the underscore symbol, as these quantities don't have indices.
In the name of the quantity, you can use ~ symbol (once at most). When doing
LaTeX rendering everything after ~ will be rendered as a superscript. For example
"Abc~intra_knm" will be rendered in LaTeX as "Abc^{intra}_{knm}".
* "Basic operators" -- such as addition (+), subtraction (-), multiplication (*),
division (/), exponentiation (^). Note that the multiplication symbol *
is required. For example, you must use "A_ij * B_jk". If you simply use
"A_ij B_jk" you will get an error message. Certain potentially ambiguous expressions
are not allowed, such as "A / B * C". You should either use "A / (B * C)"
or "(A / B) * C" depending on your intention. Similarly, you
are not allowed to write "A^B^C" as that could be ambiguous. Instead,
write "(A^B)^C" or "A^(B^C)".
* "Plus and minus symbol" -- can be used either in between two quantities,
such as "A_ij - B_ij", but you can also use it in front of a symbol, such
as "-A_ij - B_ij". This expression will change the sign of the
quantity "A" relative to the former expression, as expected.
* "Parentheses." -- Only regular parentheses "(" and ")" are allowed.
* "Assignment operators" -- are one of these three: <=, <+=, or <<=.
The assignment operator assigns to the new quantity on the
left the expression given on the right. For example, "A_ik <= B_ij * C_jk"
would do a regular matrix multiplication of B and C and assign
the result to a new quantity A. Related assignment operator "<+=" assumes
that the quantity "A" on the left already exists. This operator will compute
the value of the right hand side and add it to the value of the
preexisting quantity "A". This is similar to the behavior of "+="
operator in python. Finally, "<<=" operator will erase the previously existing
quantity "A" on the left, and it will create a new quantity with the
value of whatever is on the right of "<<=".
On the left of "<=" you can either have a tensor quantity "A_ik", or a constant
quantity "D", or you can simply have "_ik", or simply "_". If you have
something like "_ik" on the left of "<=" then this function
will simply return a numerical value of the tensor and will not create
any new quantity. Finally, if you simply write "A_ij * B_jk" without even
using "<=" operator, you will again get a numerical value of the matrix product
without creating a new quantity. The indices of the returned quantity in this case
will be simply sorted in alphabetical order ("ik") in the present case. To reduce
ambiguity, it is therefore recommended to always use one of the assignment operators
in your computations. (This option is not allowed when *brute_force_sums* is *True*.
One must use one of the arrow operators (<=, <+=, <<=) if *brute_force_sums* is *True*.)
Assignment operators can be used only once per expression.
* "Complex conjugation" -- is done using the # operator. Note that this operator
does not transpose any of the indices of the matrix, it only does complex conjugation.
The # operator must appear before the quantity whose complex conjugation you
are taking. For example, in this case, quantity B would be conjugated: "A_ij * #B_jl".
If you also wish to transpose B then you need to do so explicitly "A_ij * #B_lj" by
swapping indices on B. (If quantity B_jl is a matrix element of operator O,
"< j | O | l >" then #B_jl will refer simply to "conjugate(< j | O | l >)". Now, if operator O
happens to be a Hermitian operator, then this quantity is by definition equal to
<l|O|j>. LaTeX parser will display in this case <l|O|j> instead of
conjugate(<j|O|l>). In other words, #B_jl will be parsed into <l|O|j> instead
of conjugate(<j|O|l>) which doesn't look as pretty.)
* "Functions" -- Real and Imag can also be used in the expression. These take
real and imaginary parts of the complex numbers. For example, "Real(A_ij + B_ji)"
:param conditions: This string contains all restrictions you wish to perform on the
sums in your computation. The string can contain more than one condition. Conditions
must be separated from each other by a comma. For example, "E_kn < ef , E_km > ef, n != m"
would limit any sum containing indices k and n to only those for which E_kn is less than ef.
An additional limitation would occur for the sum over k and m. The third condition would ensure that no
sum over n and m includes the n==m term. Also, if n and m appear in the output indices, the diagonal
terms are set to zero. (The code will also avoid dividing by n==m term of the tensors in the denominator.
This is useful if you divide 1/(A_n - A_m) and then use the condition n!=m.)
The lesser/greater conditions must be formatted so that
on the left of < or > there is a single tensor quantity. The indices should not repeat.
If one of the indices is summed over, the other index in the condition must appear in the same sum.
Otherwise, the meaning of the condition is imprecise. (For example, this would occur if one has
the condition "E_kn < ef" but in the expression we have a sum only over n, but not over k. For example,
if the expression "B_k <= E_kn * E_kn" we have a sum over repeated index n but not over k.)
On the right of < or > you must use either a constant (quantity without indices) or a number.
No other formatting is allowed for conditions with lesser/greater.
Only real parts of the left and right hand side are used in determining conditions.
The formatting of conditions involving != is much simpler. On the left and right of != you
can only have a single index.
:param brute_force_sums: Boolean. The default is False. The code will use different algorithms to evaluate
the quantity, depending on the value of this parameter. If False (default) then the code will
evaluate quantity using numpy vectorization. If True, then the code will use brute-force for loops
compiled via Numba. The two approaches should give numerically the same values. Depending
on your machine, it is probably more optimal to use default (False) when you have moderately
dense k-grids. If you need to sample denser k-grids, you should probably use random sampling
and perform several calculations with a moderately dense k-grid until you reach convergence.
See :ref:`this example <sphx_glr_all_examples_example_conv.py>` for more details on how to do this. If you really
have a need to do a single-shot calculation with very many k-points, you may want to set this
flag to *True* and test if that makes your calculation faster. (Whether brute-force sums are faster or not
will depend not only on the number of k-points but also the type of calculation you are evaluating.)
(There is a slight difference between the two approaches (numpy vectorization vs Numba)
since in certain edge cases the numpy vectorization approach will stop, while the Numba approach will still do the
calculation. For example, this can happen in certain edge cases with *conditions*: if
some condition is too restrictive on the sum, numpy vectorization will stop, but with this parameter set to *True* the code
will still do the calculation. Another difference is that when *brute_force_sums* is *True* you must
use one of the arrow operators (<=, <+=, or <<=) in all of your evaluation expressions.)
:param optimize_divisions: Boolean. The default is True. If True then the code will internally
preoptimize the expression by replacing expressions such as A*B*C/(D*E) with something like A*B*C*(1/D)*(1/E).
In many cases this is faster to evaluate (less overhead), since the expression is then a single product
of five quantities (A, B, C, 1/D, and 1/E). For debugging purposes, the user could set this flag
to False, but True should give the same result more quickly in most cases. If in doubt, inspect the output of
:func:`info <wfbase._ComputatorWf.info>` (with the *show_code* flag set to True) to see what exact operation the code
uses to evaluate your expression. (This parameter is ignored if *brute_force_sums* is *True*.)
:param optimize_recomputation: Boolean. The default is True. If True then the code will optimize certain
computations. For example, if the expression (E_nk - E_mk) appears twice in the single call
to this function (it could be either in the same line or in two different lines, as long
as it is in the same call to this function), then at the second appearance of this expression
the code will reuse the previously computed value. For debugging purposes, the user could set this
flag to False, but True should give the same result faster. If in doubt, inspect the output of
:func:`info <wfbase._ComputatorWf.info>` (with *show_code* flag set to True) to see what exact operation the code
uses to evaluate your expression. (This parameter is ignored if *brute_force_sums* is *True*.)
:param show_latex_with_div_opts: Boolean. The default is False. If True then the code will return LaTeX'd
code that includes optimizations that were potentially used if the *optimize_divisions* flag is True.
Example usage::
import wfbase as wf
import numpy as np
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
comp.verbose_evaluate()
# This creates new quantity, called "test".
# This quantity depends on one index (k).
# The value of this quantity is the sum of squares
# of all band energies with that k. The sum over *n* is implied as
# index *n* does not appear on the left of <=
comp.evaluate("test_k <= E_kn * E_kn")
print(comp["test"])
# direct way to compute the same thing using numpy
direct = np.sum(comp["E"]*comp["E"], axis = 1)
print(direct)
# this removes quantity "test"
comp.remove("test")
# same calculation as above, but using brute-forced sums and Numba
# instead of numpy vectorization
comp.evaluate("test_k <= E_kn * E_kn", brute_force_sums = True)
comp.remove("test")
# since now index *n* appears next to the tensor "test"
# on the left of <= we don't sum over n
comp.evaluate("test_kn <= E_kn * E_kn")
comp.remove("test")
# now we sum over both *k* and *n* as neither index
# appears next to the tensor "test" on the left of <=
comp.evaluate("test <= E_kn * E_kn")
comp.remove("test")
# this basically makes "test" be a copy of "E", as nothing
# is summed over
comp.evaluate("test_kn <= E_kn")
comp.remove("test")
# this makes "test" be a transpose of "E"
comp.evaluate("test_nk <= E_kn")
comp.remove("test")
# now the sum is performed only over occupied states
comp.evaluate("test <= E_kn * E_kn", "E_kn < ef")
comp.remove("test")
# now the sum is performed only over occupied states with energy above -2 eV.
comp.evaluate("test <= E_kn * E_kn", "E_kn < ef, E_kn > -2.0")
comp.remove("test")
# now the sum is performed only over occupied states with energy above -2 eV,
# the diagonal terms (n == m) are not included in the sum
comp.evaluate("test <= E_kn * E_kn", "E_kn < ef, E_kn > -2.0, n != m")
comp.remove("test")
# this will create a tensor with four indices. Tensor "E" will be
# repeated over the missing indices.
comp.evaluate("test_knmi <= E_kn * A_kmni")
comp.remove("test")
# same thing, but now we perform sum over three of the indices
# so we are left with one index (k)
comp.evaluate("test_k <= E_kn * A_kmni")
comp.remove("test")
# this will create a tensor with three indices. E_kn does
# not depend on m, so it will simply copy E_kn along m'th index.
comp.evaluate("test_knm <= E_kn - E_km")
comp.remove("test")
# this can be made more complicated as well...
comp.evaluate("test_ki <= (E_kn - E_km) * A_knmi * (E_kn + E_km)")
comp.remove("test")
# One is allowed to use numbers as well (but there must be a
# multiplication sign between any two quantities if you want them
# multiplied. So, you must use ... 2 * E_kn ... and not simply ... 2 E_kn ...)
comp.evaluate("test_ki <= (E_kn - E_km) * A_knmi * (E_kn + 2 * E_km)")
comp.remove("test")
# One can use "eV", "Ang", or "muB" to introduce quantities in the expression with units
comp.evaluate("test_ki <= (E_kn - E_km) * A_knmi * (E_kn + 2 * E_km + 0.1 * eV)")
comp.remove("test")
# ... division is also allowed, as well as complex numbers
comp.evaluate("test_ki <= (E_kn - E_km) * A_knmi / (E_kn + 2j * E_km)")
comp.remove("test")
# ... and multiple levels of parentheses
comp.evaluate("test_ki <= (E_kn - E_km) * A_knmi / ((E_kn + 2j * E_km) * A_knni)")
comp.remove("test")
# ... you can also raise to a power
comp.evaluate("test_ki <= (E_kn - E_km)^2 * A_knmi")
comp.remove("test")
# ... and take complex conjugate, or real, or imaginary part
# (note that complex conjugation does not transpose any indices! It only
# complex conjugates each element of the tensor)
comp.evaluate("test_ki <= (E_kn - E_km) * #A_knmi * Real(1 / (E_kn - 2j* E_km)) * Imag(A_knni)")
# you can redefine quantity "test" that already exists, without the need to call "remove"
# notice how we use here <<= instead of <=
comp.evaluate("test_knm <<= E_kn - E_km")
# you can also add to the previously existing quantity using <+=
comp.evaluate("test_knm <+= E_kn + 3.0 * E_km")
comp.remove("test")
# notice how the regular use of <= creates quantity "test"
comp.evaluate("test_knm <= E_kn - E_km")
# if you simply want to get the numerical array, and you
# don't want to create a new quantity, you can simply do:
value = comp.evaluate("_knm <= E_kn - E_km")
# if you don't use <= at all, then there will be no sums performed
# and the order of indices is alphabetical (kab in the case below)
value = comp.evaluate("E_kb - E_ka")
"""
split_eqns = []
for ln in in_eqns.split("\n"):
l = ln.strip()
if l != "":
split_eqns.append(l)
if len(split_eqns) < 1:
_raise_value_error("Need to specify at least one valid equation when calling evaluate function.")
# put code for each line in input equations in here
code = _StorePythonCode(optimize_recomputation = optimize_recomputation)
code_only_for_latex = _StorePythonCode(optimize_recomputation = optimize_recomputation)
self._process_all_diagonals(conditions, code, brute_force_sums)
self._process_all_greater_lesser(conditions, code, brute_force_sums)
if brute_force_sums == False:
code_dic = {"opteinsum": opteinsum, "np": np}
else:
code_dic = {"njit": njit, "np": np}
num_eqns = len(split_eqns)
# go over each equation
for eqn_i in range(num_eqns):
raw_eqn = split_eqns[eqn_i]
# this removes all divisions and uses the inverse operator & instead.
# it also removes some parentheses that are not needed.
# That should generate more optimized code.
use_eqn = raw_eqn
if brute_force_sums == False:
if optimize_divisions == True:
use_eqn = self._reorg_parser.reorganize(raw_eqn)
ret, store_to, mode = self._work_on_one_equation_using_vectorizations(use_eqn, code, code_dic, eqn_i, num_eqns)
else:
ret, store_to, mode = self._work_on_one_equation_using_brute_force_sums(use_eqn, code, code_dic, eqn_i, num_eqns)
# return numpy array immediately if you are doing things like _ik <= A_ij * B_jk or simply A_ij * B_jk
if store_to is None:
return ret["value"]
if store_to == "no use of left arrow":
_check_that_return_indices_compatible_with_filter(ret["ind"], self._filters)
return ret["value"]
# get latex by parsing everything again, but now with raw eqn
if mode in ["<=", "<<="]:
if show_latex_with_div_opts == True:
use_for_latex_eqn = use_eqn
else:
use_for_latex_eqn = raw_eqn
par_only_for_latex = _my_parse_string(self._parser, use_for_latex_eqn, parse_all = True)
par_only_for_latex = par_only_for_latex[0]
if isinstance(par_only_for_latex, EvalArrowOp):
ret_only_for_latex, _ignore, _ignore = par_only_for_latex.eval(self,
code_only_for_latex, allow_storing_data = True, call_from_main_evaluate = True)
else:
ret_only_for_latex = par_only_for_latex.eval(self, code_only_for_latex)
ret["latex"] = ret_only_for_latex["latex"]
# otherwise, store everything that is needed
ret["parsed"] = True
if mode == "<=": # definition
ret["parsed_string"] = raw_eqn
ret["parsed_condition"] = conditions
ret["order_parsed"] = self._order_parsed
self._order_parsed += 1
ret["brute_force_sums"] = brute_force_sums
if store_to in self.all_quantities():
_raise_value_error("Quantity " + store_to + \
" already defined! You can't use <= if quantity on the left of it was defined already. Use <<= to overwrite previously defined quantity.")
self.new(store_to, ret)
elif mode == "<<=": # replacement
ret["parsed_string"] = raw_eqn
ret["parsed_condition"] = conditions
ret["brute_force_sums"] = brute_force_sums
if store_to not in self.all_quantities():
_raise_value_error("Quantity " + store_to + \
" was not defined earlier! Operator <<= is used to replace the tensor on the left of it with new value.")
self.remove(store_to)
self.new(store_to, ret)
if brute_force_sums == False:
code.changed_value_of_this_core("__object[\"" + store_to + "\"]")
elif mode == "<+=": # addition
if store_to not in self.all_quantities():
_raise_value_error("Quantity " + store_to + " not defined! Operator <+= can only be used when quantity on the left of it was defined already")
if self.get_units(store_to)._check_units_the_same(ret["units"]) == False:
_raise_value_error("Units don't match when using the <+= operator. On the left of <+= units are: " + str(self.get_units(store_to)) + \
" while on the right of <+= units are: " + str(ret["units"]) + ".")
self[store_to] += ret["value"]
if brute_force_sums == False:
code.changed_value_of_this_core("__object[\"" + store_to + "\"]")
else:
_raise_value_error("Unknown mode " + mode + ".")
if mode in ["<=", "<<="]:
if self._verbose_evaluate == True:
self.info(store_to)
# how many times did code do recomputation optimization
self._from_last_evaluation_num_used_stored = code.how_many_times_used_stored()
def _work_on_one_equation_using_vectorizations(self, use_eqn, code, code_dic, eqn_i, num_eqns):
code_txt_import = ""
code_txt_import += "import numpy as np\n"
code_txt_import += "from opt_einsum import contract as opteinsum\n"
# This will parse the string. Code is not yet generated. That happens later when you call par.eval(...)
par = _my_parse_string(self._parser, use_eqn, parse_all = True)
if len(par) != 1:
_raise_value_error("PyParsing returned something not expected?!")
par = par[0]
# check whether this topmost thing in the parsing tree is using one of "<=" operators
if isinstance(par, EvalArrowOp):
# this generates python code equivalent for one equation
# if user specified _ij <= A_ij * ... then store_to will be None
# otherwise store_to will be X if you did X_ij <= A_ij * ...
ret, store_to, mode = par.eval(self, code, allow_storing_data = True, call_from_main_evaluate = True)
else:
ret = par.eval(self, code)
store_to = "no use of left arrow"
mode = None
if store_to is None and mode != "<=":
_raise_value_error("If equation is of the form _ij <= ... or _ <= ... then you must use <= operator, not <+= or <<=.")
if store_to is None or store_to == "no use of left arrow":
if num_eqns != 1:
_raise_value_error("If equation is of the form _ij <= ... or you don't use <= at the top of the parsing tree, then you can't have more than one equation per evaluation.")
code._add_raw("__value = " + ret["value"])
if eqn_i < num_eqns - 1:
code._add_raw("#")
code.start_new_chunk()
code_txt = code.get_code_chunk(eqn_i, prefix = "")
# run this part of the code
code_dic["__object"] = self
time_exec = _nice_exec(code_txt, code_dic)
# store what came out of exec
ret["value"] = code_dic["__value"]
# store code used to compute this quantity
ret["exec"] = {"code": _decorate_code_into_a_function(code_txt_import + "\n" + code_txt),
"eqn_order_from_1": eqn_i + 1,
"num_eqns": num_eqns}
# store time spent in exec
ret["total_seconds_exec"] = time_exec
# release memory for variables inside exec's
if eqn_i == num_eqns - 1:
def_vars = code.get_all_defined_variables()
def_vars.sort()
del_txt = ""
for dd in def_vars:
del_txt += "del " + dd + "\n"
del_txt +="del _f\n"
del_txt +="del _s\n"
exec(del_txt, code_dic)
return ret, store_to, mode
def _work_on_one_equation_using_brute_force_sums(self, use_eqn, code, code_dic, eqn_i, num_eqns):
# There are certain syntax checks that we do with the regular parsing (the one
# without brute-forced sums). Since we don't do these checks for the brute-force
# sums, a dummy parse is done here just to make sure that the
# same checks are applied in both cases. In the future one would probably
# want to find a better way to deal with this. One could either separate the
# checks out of the (non-bfs) parser, or duplicate the code; neither option
# looks clean. A third option is to simply remove the dummy parsing below.
# Another problem is that self._parser_brute_force_sums does not handle units,
# so the units are taken from the dummy parse below. Finally, there are
# some checks inside opteinsum that are skipped when it is not used.
if True:
__par = _my_parse_string(self._parser, use_eqn, parse_all = True)
__par = __par[0]
if isinstance(__par, EvalArrowOp):
__ret, __store_to, __mode = __par.eval(self, deepcopy(code), allow_storing_data = True, call_from_main_evaluate = True)
else: __ret = __par.eval(self, deepcopy(code))
code_txt_import = ""
code_txt_import += "from numba import njit\n"
code_txt_import += "import numpy as np\n"
data_about_used_tensors = []
par = _my_parse_string(self._parser_brute_force_sums, use_eqn, parse_all = True)
if len(par) != 1:
_raise_value_error("PyParsing returned something not expected?!")
par = par[0]
if isinstance(par, BfsArrowOp):
ret, store_to, mode = par.eval(data_about_used_tensors, self, allow_storing_data = True, call_from_main_evaluate = True)
ret["units"] = __ret["units"]
else:
# this check is here because otherwise one would need to figure
# out which indices were duplicated (and therefore summed over) and which were not,
# and that choice would have to be consistent with what is done when brute_force_sums = False
_raise_value_error("If parameter brute_force_sums is set to True, then you must use one of the arrow operators " +
"in your expressions: <=, <+= or <<=.")
if store_to is None and mode != "<=":
_raise_value_error("If equation is of the form _ij <= ... or _ <= ... then you must use <= operator, not <+= or <<=.")
if store_to is None or store_to == "no use of left arrow":
if num_eqns != 1:
_raise_value_error("If equation is of the form _ij <= ... or you don't use <= at the top of the parsing tree, then you can't have more than one equation per evaluation.")
filters_bfs = []
for ii in range(len(self._filters)):
filters_bfs.append(self._filters[ii]["cond_for_brute_force"])
# now create code that involves all brute force sums
initialization_code = _create_brute_force_sums(self, code, code_dic, eqn_i, expression = ret["value"], indices_want = ret["ind"], data_about_used_tensors = data_about_used_tensors,
filters_bfs = filters_bfs, diagonals = self._diagonals)
if eqn_i < num_eqns - 1:
code._add_raw("#")
code.start_new_chunk()
code_txt = code.get_code_chunk(eqn_i, prefix = "")
# run this part of the code
time_exec = _nice_exec(code_txt, code_dic)
# store what came out of exec
ret["value"] = code_dic["__value"]
# store code used to compute this quantity
ret["exec"] = {"code": _decorate_code_into_a_function(code_txt_import + "\n" + initialization_code + "\n" + code_txt),
"eqn_order_from_1": eqn_i + 1,
"num_eqns": num_eqns}
# store time spent in exec
ret["total_seconds_exec"] = time_exec
if eqn_i == num_eqns - 1:
del code_dic
return ret, store_to, mode
def _is_parsed(self, core):
ret = False
if "parsed" in self.all_quantity_keys(core):
if self.get(core, "parsed") == True:
ret = True
return ret
def _what_to_say_if_user_messed(self, core):
if self.__did_user_mess_with_values[core] == False:
_raise_value_error("User did not mess with "+core)
out = ""
out += _format_one_block("""The quantity """ + core + """ has been accessed at some point and the user
might have changed its numerical value.""")
out += "\n\n"
out += _format_one_block( """**Therefore it is possible that the stored description of this
object is no longer valid. If you want more information about this object, try printing documentation about it before you modify it.** """)
out += "\n"
out += _format_one_block("Here is the location in the code where this quantity was accessed by the user and it might have been modified by the code.")
out += "\n"
for ln in self.__did_user_mess_with_values[core]:
out += _format_one_block(str(ln), indent = 8, width = 200)
out += "\n"
out += "\n"
return out
def info(self, core = None, print_to_screen = True, display = False, full = False, show_code = False, allow_multiple_expressions = False):
r"""
This function provides information about various quantities in the computator.
Note that there is a function with the same name that provides information
about the database .wf file, not the computator. See
:func:`info <wfbase.DatabaseWf.info>` for more information on how to use that other function.
:param core: If set to None then the function will return information
about all quantities in the computator. Otherwise, it will give
information only about the quantity *core*.
:param print_to_screen: Boolean. If set to True (default) then the information
about the quantity will be printed on the screen by this function. If set to False
nothing is printed on the screen. Instead, in this case function returns a string
with the same information.
:param display: Boolean. If set to True it will display inside the terminal LaTeX-ed
equations if they are present in the description of the quantity. This feature works only
in terminals that support imgcat. One such terminal is iTerm2 on OS X. You can find information about
`installing imgcat in iTerm2 here <https://iterm2.com/documentation-images.html>`_.
If set to False (default) nothing is displayed in the terminal.
:param full: Boolean. If set to False (default) only some of the numerical values of the quantity are
shown. If set to True, all values are shown.
:param show_code: Boolean. If set to True the function will show the python code
used to compute the quantity. The default is False.
:param allow_multiple_expressions: Boolean. If set to True, and if *show_code* set to True, will
display python code used to compute the quantity, even if multiple expressions were processed
in a single call to :func:`evaluate <wfbase._ComputatorWf.evaluate>` and you are trying to access information
for a quantity that is not the first one that was evaluated. The default is False.
:returns:
* **txt** -- String with information about the quantities. This string is returned only
if *print_to_screen* is set to False.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# shows information about quantity "E"
comp.info("E")
# compute some quantity
comp.evaluate("sigma_oij <= (j / (numk * volume)) * (f_km - f_kn) * Real((E_km - E_kn) / (E_km - E_kn - hbaromega_o - j*eta)) * A_knmi * A_kmnj")
# shows information about newly computed quantity
comp.info("sigma")
# ... also displays LaTeX-ed equation defining this quantity
# This works only in terminals that support imgcat, such as iTerm2 on OS X.
# You can find information about installing imgcat in iTerm2 here:
# https://iterm2.com/documentation-images.html
comp.info("sigma", display = True)
"""
if core is not None:
ret = self.__info_one_core(core, print_to_screen = print_to_screen, display = display,
full = full, show_code = show_code, allow_multiple_expressions = allow_multiple_expressions)
if print_to_screen == False:
return ret
else:
keys = self.all_quantities()
keys.sort()
ret = []
for k in keys:
ret.append(self.__info_one_core(k, print_to_screen = print_to_screen, display = display,
full = full, show_code = show_code, allow_multiple_expressions = allow_multiple_expressions))
if print_to_screen == False:
return "\n".join(ret)
def __info_one_core(self, core, print_to_screen, display, full, show_code, allow_multiple_expressions):
self.__does_core_exist(core)
out = ""
out += "\n"
out += _make_rst_title("Quantity *" + core.strip() + "*")
out += "\n"
out += _make_rst_field("Shape")
out += _format_one_block(str(self.get_shape(core)))
out += "\n\n"
out += _make_rst_field("Value")
if full == False:
with np.printoptions(precision = 3, linewidth = 80):
out += _format_one_block_simple_indent(str(np.array(self[core])), indent = 4, start_and_end = False, max_line = 8)
else:
with np.printoptions(precision = 6, linewidth = 80, threshold = sys.maxsize):
out += _format_one_block_simple_indent(str(np.array(self[core])), indent = 4, start_and_end = False, max_line = None)
out += "\n"
out += _make_rst_field("Units")
out += _format_one_block(str(self.get(core, "units")))
out += "\n\n"
show_latex = False
if "origin_story" in self.all_quantity_keys(core) or self._is_parsed(core):
if self.__did_user_mess_with_values[core] != False:
out += _make_rst_field("Origin story")
out += self._what_to_say_if_user_messed(core)
else:
if self._is_parsed(core):
out += _make_rst_field("Origin story")
out += _format_one_block("This quantity was computed by parsing the following string")
out += "\n\n"
out += _format_one_block(self.get(core, "parsed_string"), indent = 8)
out += "\n\n"
if self.get(core, "parsed_condition") != "":
out += _format_one_block("... under the following conditions")
out += "\n\n"
out += _format_one_block(self.get(core, "parsed_condition"), indent = 8)
out += "\n\n"
#
if show_code == True:
out += _make_rst_field("Python code")
out += "\n"
#
data_exec = self.get(core, "exec")
code = data_exec["code"]
which_eq = data_exec["eqn_order_from_1"]
number_eq = data_exec["num_eqns"]
#
if number_eq > 1 and which_eq > 1:
if allow_multiple_expressions == True:
out += _format_one_block(\
"*Note!* You evaluated " + str(number_eq) + " equation(s) in the same string at the same time. " +
"Your quantity " + str(core) + " was evaluated as the equation number " + str(which_eq) + \
" . Therefore, there were " + str(which_eq - 1) + " expression(s) computed before this one. " +
"The output below includes computations of only one quantity. However, some of the terms here " +
"might have been computed in earlier quantities. Therefore, the code below potentially can't be executed " +
"on its own, in isolation from the earlier " + str(which_eq - 1) + " expression(s). Additionally, " +
"it is possible that some of the quantities used in this expression were modified by the previous expression(s).")
out += "\n\n"
out += _format_one_block_simple_indent(code, indent = 0, start_and_end = True, dont_indent_first = False)
else:
out += _format_one_block(\
"Will not display code because this quantity was evaluated as equation number " + str(which_eq) + \
" out of total " + str(number_eq) + " equations. This can lead to ambiguities. If you still " +
"insist on getting partial code for this quantity, " +
"please set parameter allow_multiple_expressions to True when you call .info() function.")
out += "\n\n"
else:
out += _format_one_block_simple_indent(code, indent = 0, start_and_end = True, dont_indent_first = False)
out += "\n"
#
out += _make_rst_field("LaTeX")
latex_obj = self.get_latex(core)
latex_source = "$" + latex_obj.get_string() + "$"
out += _format_one_block(latex_source, indent = 4)
out += "\n\n"
show_latex = True
else:
out += _make_rst_field("Origin story")
out += _format_one_block(self.get(core, "origin_story"), indent = 4)
out += "\n\n"
if "indices_info" in self.all_quantity_keys(core):
txt_ind, txt_def = _process_index_information(self.get(core, "indices_info"))
out += _make_rst_field("Indices")
out += _format_one_block_simple_indent(txt_ind, indent = 4, start_and_end = False)
out += "\n"
out += _format_one_block_simple_indent(txt_def, indent = 4, start_and_end = False)
out += "\n\n"
if "latex" in self.all_quantity_keys(core):
out += _make_rst_field("LaTeX")
out += _format_one_block(self.get(core, "latex"), indent = 4)
out += "\n\n"
else:
if "latex" in self.all_quantity_keys(core):
out += _make_rst_field("LaTeX")
out += _format_one_block(self.get(core, "latex"), indent = 4)
out += "\n\n"
if print_to_screen:
print(out)
if show_latex == True:
if display == True:
latex_obj = self.get_latex(core)
render_latex(latex_obj, "__tmp.png")
display_in_terminal("__tmp.png")
print("\n\n")
return out
def verbose_evaluate(self, verbose = True):
r"""
If *verbose* is set to True then the code will print information about
each expression it evaluates.
:param verbose: whether to print information or not. If not specified,
it defaults to True.
Example usage::
import wfbase as wf
import numpy as np
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
comp.verbose_evaluate()
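# ... and this call turns the verbose output back off
comp.verbose_evaluate(False)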
"""
self._verbose_evaluate = verbose
def report(self):
r"""
This function returns a LaTeX equation for all quantities that were parsed
by the user, in the order in which they were parsed. (The function skips
quantities that were later modified by the user.)
:returns:
* **txt** -- String with LaTeX equation.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# compute some quantity
comp.evaluate("sigma_oij <= (j / (numk * volume)) * (f_km - f_kn) * Real((E_km - E_kn) / (E_km - E_kn - hbaromega_o - j*eta)) * A_knmi * A_kmnj")
# you can print the equation to the screen
print(comp.report())
# ... or, you can also render the equation
wf.render_latex(comp.report(), "latex.pdf")
"""
out = r"\begin{align}" + "\n"
keys = self.all_quantities()
use_cores = []
order_cores = []
for core in keys:
if self._is_parsed(core):
if self.__did_user_mess_with_values[core] == False:
use_cores.append(core)
order_cores.append(self.get(core, "order_parsed"))
srt = np.argsort(order_cores)
use_cores = np.array(use_cores)[srt].tolist()
for core in use_cores:
latex_obj = self.get_latex(core)
latex_source = latex_obj.get_string(inside_align = True)
out += latex_source + "\n" + r"\notag" + "\n" + r"\\" + "\n"
if out.endswith(r"\\" + "\n"):
out = out[:-3]
out += r"\end{align}"
return out
def plot_bs(self, ax, plot_bands = True, plot_spec = True, plot_fermi = True, plot_xticks = True):
"""
Plots the band structure.
:param ax: Matplotlib's axes onto which you wish to plot the band structure.
:param plot_bands: If *True* (default), plots the electron bands.
:param plot_spec: If *True* (default), plots the special k-points on the path.
:param plot_fermi: If *True* (default), plots the Fermi level.
:param plot_xticks: If *True* (default), plots x-ticks at the special k-points.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# compute quantities on a path between these special points
comp = db.do_path("GM--H--N")
# plot the band structure
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
comp.plot_bs(ax)
ax.set_title("Band structure of Fe bcc")
fig.tight_layout()
fig.savefig("a.pdf")
"""
if self._computated_using not in ["do_path", "do_list"]:
_raise_value_error("Function .plot_bs you can use only on computators that were constructed using .do_path or .do_list")
if self._computated_using == "do_path":
if plot_bands == True:
for i in range(self["E"].shape[1]):
ax.plot(self["kdist"], self["E"][:, i], "k-", lw = 1.0, zorder = -200)
if plot_spec == True:
for i in range(1, self["kspecdist"].shape[0] - 1):
ax.axvline(self["kspecdist"][i], c = "b", ls = "--", lw = 0.7, zorder = -300)
if self._computated_using == "do_list":
if plot_bands == True:
for i in range(self["E"].shape[1]):
ax.plot(range(self["E"].shape[0]), self["E"][:, i], "k-")
if plot_fermi == True:
ax.axhline(self["ef"], c = "r", zorder = -100, lw = 0.7)
if self._computated_using == "do_path":
if plot_xticks == True:
ax.set_xticks(self["kspecdist"])
ax.set_xticklabels(self["kspeclabels"])
ax.set_xlim(self["kdist"][0], self["kdist"][-1])
if self._db_loaded_from_wannierberri == False:
ax.set_ylim(self["reliableminenergy"], self["reliablemaxenergy"])
def _decorate_code_into_a_function(code):
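# wraps the generated code into a function evaluate_directly(__object)
# that returns __value; the body is indented by four spaces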
use = ""
use += "def evaluate_directly(__object):\n"
use += code
use += "return __value\n"
use = _format_one_block_simple_indent(use, indent = 4, start_and_end = False, dont_indent_first = True)
return use
def _replace_star_with_indices(use_str, ind):
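# replaces placeholders of the form *0, *1, ... in *use_str* with the
# corresponding index names from *ind*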
fmt = pp.White() |\
pp.Combine("*" + pp.Word(pp.nums, min = 1)).set_results_name("index") |\
pp.Word(pp.printables, exclude_chars = "* ")
par = _my_parse_string(fmt[...], use_str, parse_all = True)
ret = ""
for p in par:
if p[0] == "*":
num_index = int(p[1:])
if num_index >= len(ind):
_raise_value_error("Index #" + str(num_index) + " does not exist in the formatting or definition of the quantity. " + use_str)
ret += ind[num_index]
else:
ret += p
return ret
def _process_index_information(indices_info):
ret_ind = ""
ret_ind += "This quantity has these indices\n[" + ", ".join(indices_info["canonical_names"]) + "]"
ret_ind += "\n"
if "definition" in indices_info.keys():
ret_ind += "\nDefinition of the quantity in terms of the same indices as above\n" + \
_replace_star_with_indices(indices_info["definition"], list(indices_info["canonical_names"]))
ret_ind += "\n"
ret_def = ""
for i, exp in enumerate(indices_info["explanation"]):
ret_def += "#" + str(i) + " index \"" + indices_info["canonical_names"][i] + "\" corresponds to the " + exp + "\n"
return ret_ind, ret_def
def _create_brute_force_sums(comp, code, code_dic, eqn_i, expression, indices_want, data_about_used_tensors, filters_bfs, diagonals):
# go over all indices that appear on the right of the arrow
all_indices_on_right = ""
for d in data_about_used_tensors:
all_indices_on_right += d["indices"]
all_indices_on_right = sorted(list(set(all_indices_on_right)))
for j in indices_want:
if j not in all_indices_on_right:
_raise_value_error("Index " + j + " appears on the left of arrow operator, but not on the right.")
# flush all old shared data from code_dic
keys = list(code_dic.keys())
for k in keys:
if str(k).startswith("__object_") or str(k).startswith("__size_"):
del code_dic[k]
# also need to include quantities that appear in filters_bfs
data_filters_bfs = []
for f in filters_bfs:
data_filters_bfs.append({"key": f["left_core"].replace("__object_", ""), "indices": f["left_index"]})
if f["right_type"] == "constant":
data_filters_bfs.append({"key": f["right"].replace("__object_", ""), "indices": ""})
initialization_code = ""
# add all matrices needed in code_dic
for d in data_about_used_tensors + data_filters_bfs:
key = d["key"]
key_reduced = key.replace("~", "")
# this will make a deepcopy
if "__object_" + key_reduced not in code_dic.keys():
code_dic["__object_" + key_reduced] = comp[key]
initialization_code += "__object_" + key_reduced + " = __object[\"" + key + "\"]" + "\n"
shape = code_dic["__object_" + key_reduced].shape
indices = d["indices"]
for i, j in enumerate(indices):
if "__size_" + j not in code_dic.keys():
code_dic["__size_" + j] = shape[i]
initialization_code += "__size_" + j + " = __object.get_shape(\"" + key + "\")[" + str(i) + "]" + "\n"
else:
if code_dic["__size_" + j] != shape[i]:
_raise_value_error("Index " + j + " of quantity " + key + " has shape " + shape[i] + " but earlier in the expression this index appeared with shape " + code_dic["__size_" + j] + " which is different!")
func_name = "_tmp_func__" + str(eqn_i).zfill(3)
code._add_raw(r"@njit" + "\n", indent4 = 0)
code._add_raw(r"def " + func_name + "(__value):" + "\n", indent4 = 0)
num_indents = 0
# order indices in the order in which Numba will be the fastest (this involves some guessing...)
optimal_order = _find_optimal_order_indices(all_indices_on_right, data_about_used_tensors + data_filters_bfs, indices_want)
for j in optimal_order:
code._add_raw("for " + j + " in range(__size_" + j + "):" + "\n", indent4 = num_indents + 1)
num_indents += 1
things_in_if_statement = []
#
for f in filters_bfs:
num_appear_on_right = 0
for j in f["left_index"]:
if j in all_indices_on_right:
num_appear_on_right += 1
num_appear_on_left = 0
for j in f["left_index"]:
if j in indices_want:
num_appear_on_left += 1
if num_appear_on_left > 0:
_raise_value_error("Some, of the indices \"" + f["left_index"] + "\" used in the condition appear on the left hand side of the arrow operator.")
if num_appear_on_right == len(f["left_index"]):
if f["right_type"] == "constant":
use_right = f["right"].replace("~", "").strip() + ".real"
else:
use_right = f["right"].strip()
things_in_if_statement.append("__object_" + f["left_core"].replace("~", "") + "[" + ",".join(f["left_index"]) + "].real " + f["op"] + " " + use_right)
elif num_appear_on_right != 0:
_raise_value_error("Some, but not all indices \"" + f["left_index"] + "\" used in the condition appear on the right hand side of the arrow operator.")
#
for d in diagonals:
num_appear_on_right = 0
for j in d:
if j in all_indices_on_right:
num_appear_on_right += 1
num_appear_on_left = 0
for j in d:
if j in indices_want:
num_appear_on_left += 1
if num_appear_on_right == 2 and num_appear_on_left in [0, 2]:
things_in_if_statement.append(d[0] + " != " + d[1])
if code_dic["__size_" + d[0]] != code_dic["__size_" + d[1]]:
_raise_value_error("Condition " + d[0] + " != " + d[1] + " is applied onto indices that are not of the same shape.")
elif num_appear_on_right == 2 and num_appear_on_left == 1:
_raise_value_error("In condition " + d[0] + " != " + d[1] + " you have both indices appear on the right of arrow and only one appear on the left of it. This is ambiguous.")
if len(things_in_if_statement) > 0:
code._add_raw("if " + " and ".join(things_in_if_statement) + ":" + "\n", indent4 = num_indents + 1)
num_indents += 1
if indices_want == "":
value_indices = ""
else:
value_indices = "[" + ",".join(indices_want) + "]"
_expr_multi = _pack_code_multiline("__value" + value_indices + " += ", expression)
for em in _expr_multi:
code._add_raw(em + "\n", indent4 = num_indents + 1)
if indices_want == "":
code._add_raw("__value = np.array([0.0], dtype = complex)" + "\n")
else:
size_indices = []
for j in indices_want:
size_indices.append("__size_" + j)
code._add_raw("__value = np.zeros((" + ",".join(size_indices) + "), dtype = complex)" + "\n")
code._add_raw(func_name + "(__value)" + "\n")
if indices_want == "":
code._add_raw("__value = __value[0]" + "\n")
return initialization_code
def _pack_code_multiline(first, second):
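# splits a long expression *second* over several lines: the first line is
# prefixed with *first*, continuation lines are aligned under it, and every
# non-final line ends with a backslash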
second_parts = textwrap.wrap(second, width = 80, break_long_words = False)
ret = []
for i in range(len(second_parts)):
if i == 0:
tmp = first + second_parts[i]
else:
tmp = " " * len(first) + second_parts[i]
if i != len(second_parts) - 1:
tmp = tmp + " \\"
ret.append(tmp)
return ret
def _find_optimal_order_indices(all_indices_on_right, data_tensors, indices_want):
# take indices that appear on the right of the arrow, but not on the left
# as we want to have the ones on the left at the end anyways
only_right = sorted(list(set(all_indices_on_right).difference(indices_want)))
# keep score for each index
score_for_only_right = np.zeros(len(only_right), dtype = int)
longest_tensor = 0
for d in data_tensors:
ind = deepcopy(d["indices"])
if len(ind) > longest_tensor:
longest_tensor = len(ind)
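# score each index that is only summed over: an index appearing closer to the
# front of the (longest) tensors gets a higher score and will therefore end up
# as one of the outermost loops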
for d in data_tensors:
ind = deepcopy(d["indices"])
if ind != "":
for i, l in enumerate(only_right):
if l in ind:
where = ind.index(l)
score_for_only_right[i] += longest_tensor - where
srt = np.argsort(score_for_only_right)[::-1]
ret = "".join(np.array(list(only_right))[srt]) + indices_want
return list(ret)
class Units():
"""
Class that stores information about a physical quantity's
units, expressed as powers of electron volts (eV), angstroms (Ang), and Bohr magnetons (muB).
:param eV: Power associated with electron-volts. For example, if this
parameter has value 2 then that represents units of eV^2. The default is zero.
:param Ang: Same as *eV* but for angstroms.
:param muB: Same as *eV* but for Bohr magnetons.
Example usage::
import wfbase as wf
import numpy as np
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# defines a new quantity, gamma, with value 3.0 in units of eV
comp.new("gamma", {"value": 3.0, "units": wf.Units(eV = 1)})
# defines a new quantity, delta, with value 2.5 in units of eV * Ang^2 / muB
comp.new("delta", {"value": 2.5, "units": wf.Units(eV = 1, Ang = 2, muB = -1)})
"""
def __init__(self, eV = 0.0, Ang = 0.0, muB = 0.0):
self._eV = float(eV)
self._Ang = float(Ang)
self._muB = float(muB)
self.__tolerance = 1.0E-8
def _to_SI(self, value):
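# converts *value* from the eV/Ang/muB unit system into SI by multiplying
# with the appropriate powers of the SI conversion factors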
factor = 1.0
if np.abs(self._eV) > self.__tolerance:
factor *= electron_charge_SI**self._eV
if np.abs(self._Ang) > self.__tolerance:
factor *= angstrom_SI**self._Ang
if np.abs(self._muB) > self.__tolerance:
factor *= (electron_charge_SI * hbar_SI / (2.0 * electron_mass_SI))**self._muB
return value * factor
def _check_units_the_same(self, another):
unit_diffs = np.abs(np.array([
self._eV - another._eV,
self._Ang - another._Ang,
self._muB - another._muB]))
return all(unit_diffs < self.__tolerance)
def _multiply(self, another):
return Units(eV = self._eV + float(another._eV ),
Ang = self._Ang + float(another._Ang),
muB = self._muB + float(another._muB))
def _inverse(self):
return Units(eV = (-1.0)*self._eV ,
Ang = (-1.0)*self._Ang,
muB = (-1.0)*self._muB)
def _divide(self, another):
return Units(eV = self._eV - float(another._eV ),
Ang = self._Ang - float(another._Ang),
muB = self._muB - float(another._muB))
def _exponent(self, ex):
if self._is_trivial():
if ex["units"]._is_trivial() == False:
_raise_value_error("You can't have exponent with units!")
return Units(eV = 0.0, Ang = 0.0 , muB = 0.0)
else:
# we need to get the numerical value of the exponent before it is evaluated,
# which is a chicken-and-egg problem. Therefore, only simple exponents are
# allowed here: those consisting solely of numerical quantities,
# parentheses, and operations, as these might appear in the parsing.
_simple = ex["value"]
for s in _simple:
if s not in "()*+-/^j.0123456789 ":
_raise_value_error("""If you use (...)^(,,,) and if "..." has units
then ",,," can only be a numerical value.
You are not allowed to use constants or tensors, such as A_ij^(B + 3),
as long as A has units. You are allowed to do things
like A_ij^(-3.0) or similar. """ + _simple)
try:
numerical_value = eval(_simple)
except Exception:
traceback.print_stack()
print()
_raise_value_error("""Currently exponents can only be numerical values.
You are not allowed to use anything too complex in the exponent.
You are allowed to do things like A_ij^(-3.0) or similar. """ + _simple)
if abs(np.imag(numerical_value)/np.abs(numerical_value)) > self.__tolerance:
_raise_value_error("You can't raise something with units to exponent that is not real!")
return Units(eV = self._eV * float(np.real(numerical_value)),
Ang = self._Ang * float(np.real(numerical_value)),
muB = self._muB * float(np.real(numerical_value)))
def _exponent_float(self, vv):
return Units(eV = self._eV * vv,
Ang = self._Ang * vv,
muB = self._muB * vv)
def _is_trivial(self):
return (np.abs(self._eV ) < self.__tolerance) and \
(np.abs(self._Ang) < self.__tolerance) and \
(np.abs(self._muB) < self.__tolerance)
def __str__(self):
out = ""
if np.abs(self._eV - 1.0) < self.__tolerance:
out += r" eV "
elif self._eV > self.__tolerance:
out += r" eV^"+str(self._eV).strip() + " "
elif self._eV < (-1.0)*self.__tolerance:
out += r" eV^("+str(self._eV).strip() + ") "
if np.abs(self._Ang - 1.0) < self.__tolerance:
out += r" Ang "
elif self._Ang > self.__tolerance:
out += r" Ang^"+str(self._Ang).strip() + " "
elif self._Ang < (-1.0)*self.__tolerance:
out += r" Ang^("+str(self._Ang).strip() + ") "
if np.abs(self._muB - 1.0) < self.__tolerance:
out += r" muB "
elif self._muB > self.__tolerance:
out += r" muB^"+str(self._muB).strip() + " "
elif self._muB < (-1.0)*self.__tolerance:
out += r" muB^("+str(self._muB).strip() + ") "
if out == "":
out = "1"
return out.strip()
def render_latex(latex_str, fname, dpi = 300):
"""
Renders mathematical LaTeX expression *latex_str* and saves it into file *fname*.
:param latex_str: mathematical LaTeX expression to be rendered.
:param fname: name of the output file. Either png or pdf format.
:param dpi: dots per inch for png file (ignored for pdf).
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# now get LaTeX'ed data about this object
lat = comp.get_latex("sigma")
wf.render_latex(lat, "test.png")
"""
if len(fname.strip()) < 5:
_raise_value_error("Must provide a valid filename of a png or a pdf file.")
if fname.strip()[-4:].lower() not in [".png", ".pdf"]:
_raise_value_error("Must provide a valid filename of a png or a pdf file.")
else:
file_format = fname.strip()[-3:].lower()
expression = str(latex_str).strip()
if expression.startswith(r"\begin{align}") and expression.endswith(r"\end{align}"):
expression = expression
elif expression.startswith(r"$") and expression.endswith(r"$"):
expression = expression
else:
expression = "$" + expression + "$"
expression = expression.replace("\n", " ")
matplotlib.rcParams["text.usetex"] = True
matplotlib.rcParams["mathtext.fontset"] = "cm"
matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath} \usepackage{amssymb} \usepackage{amsfonts}"
fig = plt.figure()
text = fig.text(0.0, 0.0, expression, ha = "center", va = "center")
if file_format == "png":
fig.savefig(fname, dpi = dpi)
bbox = text.get_window_extent()
width, height = (bbox.size / float(dpi)) + 0.1
fig.set_size_inches((width, height))
text.set_position((0.5, 0.5))
fig.savefig(fname.strip(), dpi = dpi)
elif file_format == "pdf":
fig.savefig(fname)
bbox = text.get_window_extent()
width, height = bbox.size
fig.set_size_inches(((width /72.0) + 0.05,
(height/72.0) + 0.05))
text.set_position((0.5, 0.5))
fig.savefig(fname.strip())
else:
print("Should not happen.")
exit()
def display_in_terminal(fname):
"""
Displays the image file *fname* in the terminal. This is supported only in terminals
that support imgcat, such as iTerm2 on OS X. You can find information about
`installing imgcat in iTerm2 here <https://iterm2.com/documentation-images.html>`_.
:param fname: name of the file.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# now get LaTeX'ed data about this object
lat = comp.get_latex("sigma")
wf.render_latex(lat, "test.png")
wf.display_in_terminal("test.png")
"""
# open in binary mode since the file is an image
f = open(fname, "rb")
imgcat.imgcat(f)
f.close()
def display_in_separate_window(fname):
"""
Displays image file *fname* in a separate window. The result will
depend on your operating system and how it interacts with python's
PIL module.
:param fname: name of the file.
Example usage::
import wfbase as wf
# open a database file on bcc phase of iron
db = wf.load("data/fe_bcc.wf")
# create now a computator from the database
comp = db.do_mesh()
# evaluate some object
comp.evaluate("sigma_ij <= (j / (numk * volume)) * (f_km - f_kn) * A_knmi * A_kmnj")
# now get LaTeX'ed data about this object
lat = comp.get_latex("sigma")
wf.render_latex(lat, "test.png")
wf.display_in_separate_window("test.png")
"""
im = Image.open(fname)
im.show()
####################################################################
# Below this point we have private classes and functions that user #
# is not meant to interact with directly. #
####################################################################
class _StorePythonCode():
# It is essential here that this code simply defines, and modifies, variables
# one by one, so that if you work on variable A first and later on B, you
# will never go back to working on A.
def __init__(self, optimize_recomputation = True):
self.__code_chunks = [""]
self.__code_preamble = [""]
self.__variable_names = []
# used to retrieve previously computed values
self.__broadcast_info = {}
self.__einsum_info = {}
# remember which variables might have been updated with <+= or <<= operators. You can't then use stored values involving those variables.
self.__which_cores_got_updated = []
self.__optimize_recomputation = optimize_recomputation
self.__counter_used_stored = {
"einsum_simple": 0,
"einsum_jumbled_indices":0,
"broadcast_simple": 0,
"broadcast_jumbled_indices":0
}
self._added_soft_div = False
def get_num_chunks(self):
return len(self.__code_chunks)
def get_code_chunk(self, ii, prefix = " "*4):
ret = ""
for l in self.__code_preamble[ii].split("\n"):
if l == "":
continue
ret += prefix + l.rstrip() + "\n"
for l in self.__code_chunks[ii].split("\n"):
if l == "":
continue
ret += prefix + l.rstrip() + "\n"
return ret
def _add_raw(self, txt, to_preamble = False, to_the_top = False, indent4 = 0):
inde = " " * indent4
if to_preamble == False:
if to_the_top == False:
self.__code_chunks[-1] += inde + txt.rstrip() + "\n"
else:
self.__code_chunks[-1] = inde + txt.rstrip() + "\n" + self.__code_chunks[-1]
else:
if to_the_top == False:
self.__code_preamble[-1] += inde + txt.rstrip() + "\n"
else:
self.__code_preamble[-1] = inde + txt.rstrip() + "\n" + self.__code_preamble[-1]
def give_me_unique_variable_name(self, base):
i = 0
while True:
ret = base + str(i).zfill(2)
if ret not in self.__variable_names:
return ret
i = i + 1
def start_new_chunk(self):
self.__code_chunks.append("")
self.__code_preamble.append("")
def get_all_defined_variables(self):
return deepcopy(self.__variable_names)
def _add_definition(self, variable, right_code, adjust_newlines = False, to_preamble = False):
if variable in self.__variable_names:
_raise_value_error("Duplicate variable name!")
self.__variable_names.append(variable.strip())
if adjust_newlines == True:
if len(right_code) > 50:
right_code = right_code.replace("\\\n", "\\\n" + " "*(len(variable.strip()) + 13))
else:
right_code = right_code.replace("\\\n", " ")
if to_preamble == False:
self.__code_chunks[-1] += variable.strip() + " = " + right_code + "\n"
else:
self.__code_preamble[-1] += variable.strip() + " = " + right_code + "\n"
return variable
def _add_definition_from_einsum(self, base, einsum_1, einsum_2_with_newlines, do_copy):
einsum_2 = einsum_2_with_newlines.replace("\\\n", " ")
sp = einsum_1.split("->")
left_sp = sp[0].replace(",", "").strip()
right_sp = sp[1].strip()
for i in right_sp:
if i not in left_sp:
_raise_value_error("Index \"" + i + "\" appears on the left of assignment operator (<=, <<=, <+=) but this index does not appear on the right." )
if self.__optimize_recomputation:
# check if we ever did the same einsum before
for var in self.__einsum_info.keys():
# check if there is something with the same tensors
if einsum_2 == self.__einsum_info[var]["einsum_2"]:
# check if any of the variables got changed
values_got_changed = False
for core in self.__which_cores_got_updated:
if core in einsum_2:
values_got_changed = True
if values_got_changed == False:
if einsum_1 == self.__einsum_info[var]["einsum_1"]:
self.__counter_used_stored["einsum_simple"] += 1
return var
else:
# this checks if we earlier computed something with equivalent indices, but not exactly the same
left = self.__einsum_info[var]["einsum_1"]
right = einsum_1
mm = _find_1_to_1_map_from_left_to_right(left, right)
if mm is not None:
self.__counter_used_stored["einsum_jumbled_indices"] += 1
return var
# compute it, as we didn't find this before
variable = self.give_me_unique_variable_name(base)
if do_copy == False:
ret = self._add_definition(variable, "opteinsum(\"" + einsum_1 + "\",\\\n" + einsum_2_with_newlines + ")" , adjust_newlines = True)
else:
# probably could go without np.copy here, and at other places,
# but better be safe than sorry, as opteinsum in many cases
# returns a shallow copy (even if indices are contracted)
ret = self._add_definition(variable, "np.copy(opteinsum(\"" + einsum_1 + "\",\\\n" + einsum_2_with_newlines + "))", adjust_newlines = True)
if self.__optimize_recomputation:
# store in case we need it later
if variable in self.__einsum_info.keys():
_raise_value_error("Duplicate variable name, einsum!")
self.__einsum_info[variable] = {"einsum_1": einsum_1, "einsum_2": einsum_2}
return ret
def changed_value_of_this_core(self, core):
self.__which_cores_got_updated.append(core)
def store_broadcast_info_for_lookup(self, variable, info):
if self.__optimize_recomputation:
if variable in self.__broadcast_info.keys():
_raise_value_error("Duplicate variable name, broadcast!")
self.__broadcast_info[variable] = info
def check_if_did_this_broadcast_before(self, vals, inds, operations):
if self.__optimize_recomputation:
for var in self.__broadcast_info.keys():
# check if there is an equivalent computation that was done already
if operations == self.__broadcast_info[var]["input_operations"]:
if vals == self.__broadcast_info[var]["input_vals"]:
# check if any of the variables got changed in the meantime
values_got_changed = False
for core in self.__which_cores_got_updated:
for v in vals:
if core in v:
values_got_changed = True
if values_got_changed == False:
# this means that earlier we computed the exact same thing, with the same set of indices
if inds == self.__broadcast_info[var]["input_inds"]:
self.__counter_used_stored["broadcast_simple"] += 1
return {"value": var,
"ind": self.__broadcast_info[var]["ret_ind"],
"units": self.__broadcast_info[var]["ret_units"]}
else:
# this checks if we earlier computed something with equivalent indices, but jumbled around
# This applies only to the case when the jumbled indices don't appear in conditions.
# If they appear in conditions (even if equivalent) the code will ignore that and
# will recompute everything from scratch.
left = ",".join(self.__broadcast_info[var]["input_inds"])
right = ",".join(inds)
mm = _find_1_to_1_map_from_left_to_right(left, right)
if mm is not None:
self.__counter_used_stored["broadcast_jumbled_indices"] += 1
return {"value": var,
"ind": _convert_left_to_right(self.__broadcast_info[var]["ret_ind"], mm),
"units": self.__broadcast_info[var]["ret_units"]}
return None
def how_many_times_used_stored(self):
return deepcopy(self.__counter_used_stored)
def _check_const(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
def _check_func(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
if (len(tokens[0]) != 2) or (str(tokens[0][0]) not in ["Real", "Imag"]):
_raise_value_error("Incorrect input. Problem with function parsing.")
def _check_conjug(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
if (len(tokens[0]) != 2) or (str(tokens[0][0]) != "#"):
_raise_value_error("Incorrect input. Problem with complex conjugation.")
def _check_d_one(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
if (len(tokens[0]) != 2) or (str(tokens[0][0]) != "&"):
_raise_value_error("Incorrect input. Problem with & operator.")
def _check_sign(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
if len(tokens[0]) != 2 or str(tokens[0][0]) not in ["+", "-"]:
_raise_value_error("Incorrect input. Problem with sign in front of a symbol.")
def _check_power(tokens):
if len(tokens) != 1 or tokens[0][1] != "^":
_raise_value_error("Unexpected behavior of pyparsing.")
if len(tokens[0]) != 3:
_raise_value_error("Don't allow things like A^B^C. Use parentheses instead. For example (A^B)^C.")
def _check_mult_div(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
for j in range(1, len(tokens[0]), 2):
op = tokens[0][j]
if op != "*" and op != "/":
_raise_value_error("This should never happen. Expected * or /")
count_divisions = 0
for j in range(1, len(tokens[0]), 2):
op = tokens[0][j]
if op == "/":
count_divisions +=1
if count_divisions > 1:
_raise_value_error("Not allowing terms like A / B / C or similar, as it might be ambiguous. Use parantheses to clarify what you mean.")
if count_divisions == 1:
if tokens[0][-2] != "/":
_raise_value_error("Not allowing terms like A / B * C or similar, as it might be ambiguous. Use parantheses to clarify what you mean.")
def _check_add_sub(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
for j in range(1, len(tokens[0]), 2):
op = tokens[0][j]
if op != "+" and op != "-":
_raise_value_error("This should never happen. Expected + or -")
def _check_ein(tokens):
if len(tokens) != 1:
_raise_value_error("Unexpected behavior of pyparsing.")
if len(tokens[0]) != 3 or str(tokens[0][1]) not in ["<=", "<+=", "<<="]:
_raise_value_error("Arrow must appear as ... <= ... If you want something fancier, use parantheses to clear up what you mean.")
class ReorgConstVar():
def __init__(self, s, loc, tokens):
_check_const(tokens)
self.value = tokens[0]
def eval(self, parent):
return self.value.strip()
class ReorgFuncOp():
def __init__(self, s, loc, tokens):
_check_func(tokens)
self.func = tokens[0][0]
self.value = tokens[0][1]
def eval(self, parent):
return self.func + "(" + self.value.eval(parent = self) + ")"
class ReorgConjugOp():
def __init__(self, s, loc, tokens):
_check_conjug(tokens)
self.value = tokens[0][1]
def eval(self, parent):
return "#" + "(" + self.value.eval(parent = self) + ")"
class ReorgSignOp():
def __init__(self, s, loc, tokens):
_check_sign(tokens)
self.sign, self.value = tokens[0]
def eval(self, parent):
return self.sign + "(" + self.value.eval(parent = self) + ")"
class ReorgPowerOp():
def __init__(self, s, loc, tokens):
_check_power(tokens)
self.value = tokens[0]
def eval(self, parent):
return "(" + self.value[0].eval(parent = self) + ")^(" + self.value[2].eval(parent = self) + ")"
def _take_inverse_of(obj, parent):
if isinstance(obj, ReorgConstVar):
return "&" + obj.eval(parent = parent)
elif isinstance(obj, ReorgFuncOp):
return "&(" + obj.eval(parent = parent) + ")"
elif isinstance(obj, ReorgConjugOp):
return "&(" + obj.eval(parent = parent) + ")"
elif isinstance(obj, ReorgSignOp):
return "&(" + obj.eval(parent = parent) + ")"
elif isinstance(obj, ReorgPowerOp):
return "&(" + obj.eval(parent = parent) + ")"
elif isinstance(obj, ReorgMultDivOp):
return obj.eval(parent = parent, add_inverse = True)
elif isinstance(obj, ReorgAddSubOp):
return "&" + obj.eval(parent = parent)
elif isinstance(obj, ReorgArrowOp):
return "&" + obj.eval(parent = parent)
else:
_raise_value_error("Unknown object in taking inverse.")
class ReorgMultDivOp():
def __init__(self, s, loc, tokens):
_check_mult_div(tokens)
self.value = tokens[0]
def eval(self, parent, add_inverse = False):
if add_inverse == False:
ret = self.value[0].eval(parent = self)
for j in range(1, len(self.value), 2):
op = self.value[j]
if op == "*":
nxt = self.value[j + 1].eval(parent = self)
elif op == "/":
nxt = _take_inverse_of(self.value[j + 1], self)
else:
_raise_value_error("Unexpected operator!")
ret = ret + " * " + nxt
else:
ret = _take_inverse_of(self.value[0], self)
for j in range(1, len(self.value), 2):
op = self.value[j]
if op == "*":
nxt = _take_inverse_of(self.value[j + 1], self)
elif op == "/":
nxt = self.value[j + 1].eval(parent = self)
else:
_raise_value_error("Unexpected operator!")
ret = ret + " * " + nxt
# it is more optimal NOT to have extra parentheses here as then
# something like A*B*C/(D*E) will be one product A*B*C*(1/D)*(1/E)
# instead of A*B*C*((1/D)*(1/E)) which has more overhead
# if len(self.value) > 1:
# ret = "( " + ret + " )"
return ret
class ReorgAddSubOp():
def __init__(self, s, loc, tokens):
_check_add_sub(tokens)
self.value = tokens[0]
def eval(self, parent):
ret = self.value[0].eval(parent = self)
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j+1].eval(parent = self)
ret = ret + " " + op + " " + nxt
if len(self.value) > 1:
ret = "(" + ret + ")"
return ret
class ReorgArrowOp():
def __init__(self, s, loc, tokens):
_check_ein(tokens)
self.value = tokens[0]
def eval(self, parent):
if parent is None:
return self.value[0].eval(parent = self) + " " + self.value[1] + " " + self.value[2].eval(parent = self)
else:
return "(" + self.value[0].eval(parent = self) + " " + self.value[1] + " " + self.value[2].eval(parent = self) + ")"
class ParserReorg():
def __init__(self):
operand = _get_operand()
operand.set_parse_action(ReorgConstVar)
self._rparser = pp.infix_notation(
operand,
[
(pp.oneOf("Real Imag") , 1, pp.opAssoc.RIGHT, ReorgFuncOp ), # various function calls
( "#" , 1, pp.opAssoc.RIGHT, ReorgConjugOp ), # complex conjugation
( "^" , 2, pp.opAssoc.LEFT , ReorgPowerOp ), # power raising (strictly speaking this should be a RIGHT not LEFT to follow conventions. But we don't allow user to do A^B^C so it doesn't matter.)
(pp.oneOf("+ -") , 1, pp.opAssoc.RIGHT, ReorgSignOp ), # sign in front of an object
(pp.oneOf("* /") , 2, pp.opAssoc.LEFT , ReorgMultDivOp ), # multiplication or division
(pp.oneOf("+ -") , 2, pp.opAssoc.LEFT , ReorgAddSubOp ), # addition and subtraction
(pp.oneOf("<= <+= <<="), 2, pp.opAssoc.LEFT , ReorgArrowOp ), # perform assignment
]
)
def reorganize(self, txt):
return _my_parse_string(self._rparser, txt, parse_all = True)[0].eval(parent = None)
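# Illustrative sketch only (this hypothetical helper is not used by WfBase itself): ParserReorg
# rewrites an expression of the mini-language into an equivalent string in which divisions are
# replaced by multiplications with inverse terms marked by "&" (see _take_inverse_of above).
def _example_parser_reorg():
    reorganized = ParserReorg().reorganize("A_ij * B_j / C")
    # should come back as something like "A_ij * B_j * &C"
    return reorganized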
def _perform_partial_filtering(obj, ind, partially_filtered_indices):
ret = deepcopy(obj)
for i, ii in enumerate(ind):
if ii in partially_filtered_indices:
tmp = []
for j in range(len(ind)):
if j == i:
tmp.append("_s[\"" + ii + "\"]")
else:
tmp.append(":")
ret += "[" + ",".join(tmp) + "]"
return ret
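# Illustrative sketch only (this hypothetical helper is not called anywhere in WfBase): for indices
# that carry a partial (greater/lesser) filter, _perform_partial_filtering appends numpy-style
# slicing with the precomputed selector _s[...].
def _example_perform_partial_filtering():
    out = _perform_partial_filtering('__object["A"]', "nm", ["n"])
    # out == '__object["A"][_s["n"],:]'
    return out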
def _process_latex_imag_j(s):
return s.replace("j", r" i")
def _save_info_about_original_shapes(orig_val, ind, comp, code):
for i, ii in enumerate(ind):
if ii in comp._all_diagonal_indices:
if ii not in comp._all_diagonal_indices_stored_shape:
code._add_definition("_orig_shp[\"" + ii + "\"]", orig_val + ".shape[" + str(i) + "]", to_preamble = True)
comp._all_diagonal_indices_stored_shape += ii
class EvalConstVar():
def __init__(self, s, loc, tokens):
_check_const(tokens)
self.text = tokens[0]
self.name = tokens.get_name()
def eval(self, comp, code, calling_from_left_arrow = False, try_latex_conjugate = False):
did_latex_conjugate = False
if self.name == "ind_for_left_arrow":
_raise_value_error("You can't evaluate thing that starts with underscore. This thing needs to be to the left of <= operator or similar.")
elif self.name in ["float", "integer"]:
ret = {"value": str(self.text),
"ind": "",
"units": Units(eV = 0, Ang = 0 , muB = 0),
"latex": self.text.strip()}
elif self.name in ["imaginary float", "imaginary integer"]:
ret = {"value": str(complex(self.text)),
"ind": "",
"units": Units(eV = 0, Ang = 0 , muB = 0),
"latex": _process_latex_imag_j(self.text.strip())}
elif self.name in ["imaginary one"]:
ret = {"value": "complex(1.0j)",
"ind": "",
"units": Units(eV = 0, Ang = 0 , muB = 0),
"latex": _process_latex_imag_j("j")}
elif self.name in ["tensor"]:
core, ind = self.text.split("_")
if comp[core].shape == ():
_raise_value_error("Object " + core + " does not have a shape of a tensor.")
orig_val = "__object[\"" + core + "\"]"
ret = {"value": _perform_partial_filtering(orig_val, ind, comp._partially_filtered_indices),
"ind": ind,
"units": comp.get(core, "units"),
"latex": comp._return_in_latex(core.strip(), ind.strip(), try_latex_conjugate)}
did_latex_conjugate = True
_save_info_about_original_shapes(orig_val, ret["ind"], comp, code)
elif self.name in ["constant"]:
if comp[self.text].shape != ():
_raise_value_error("Object " + self.text + " is not a constant.")
ret = {"value": "__object[\""+ self.text +"\"]",
"ind": "",
"units": comp.get(self.text, "units"),
"latex": comp._return_in_latex(self.text.strip(), None, try_latex_conjugate)}
did_latex_conjugate = True
else:
_raise_value_error("Unknown name!")
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
if try_latex_conjugate == True:
return ret, did_latex_conjugate
else:
return ret
class EvalFuncOp():
def __init__(self, s, loc, tokens):
_check_func(tokens)
self.func = tokens[0][0]
self.value = tokens[0][1]
def eval(self, comp, code, calling_from_left_arrow = False):
ret = self.value.eval(comp = comp, code = code)
if self.func == "Real":
ret["value"] = "np.real(" + ret["value"] + ")"
ret["latex"] = r" {\rm Re} \left( " + ret["latex"].strip() + r" \right) "
elif self.func == "Imag":
ret["value"] = "np.imag(" + ret["value"] + ")"
ret["latex"] = r" {\rm Im} \left( " + ret["latex"].strip() + r" \right) "
else:
_raise_value_error("Unknown function.")
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
class EvalConjugOp():
def __init__(self, s, loc, tokens):
_check_conjug(tokens)
self.value = tokens[0][1]
def eval(self, comp, code, calling_from_left_arrow = False):
if isinstance(self.value, EvalConstVar):
ret, did_latex_conjugate = self.value.eval(comp = comp, code = code, try_latex_conjugate = True)
else:
ret = self.value.eval(comp = comp, code = code)
did_latex_conjugate = False
ret["value"] = "np.conjugate(" + ret["value"] + ")"
if did_latex_conjugate == False: # revert to basic conjugation.
ret["latex"] = r" \overline{ " + ret["latex"].strip() + r" } "
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
def _prepare_soft_divide(numer, denom_value, denom_ind, code, comp, parantheses_around):
all_whch = []
for dia in comp._diagonals:
howmany = 0
for i in dia:
if i in denom_ind:
howmany += 1
if howmany == 0 or howmany == 1:
continue
whch = {"ind": [],
"osh": None,
"sli": []}
for j, jj in enumerate(denom_ind):
if jj is None:
continue
if jj in dia:
whch["ind"].append(j)
if whch["osh"] is None:
whch["osh"] = "_orig_shp[[[\"" + jj + "\"]]]"
if jj in comp._partially_filtered_indices:
whch["sli"].append("_s[[[\"" + jj + "\"]]]")
else:
whch["sli"].append(None)
all_whch.append(whch)
if len(all_whch) == 0:
return None
if code._added_soft_div == False:
txt = """
def _soft_divide(_x, _y, _whc_all):
if len(_whc_all) == 0: raise ValueError("Should not occur.")
_yc = np.copy(_y)
_all_dd = False
for _whc in _whc_all:
_dd = np.zeros(tuple([_whc["osh"]] * len(_whc["ind"])), dtype = bool)
np.fill_diagonal(_dd, True)
for _i in range(len(_whc["ind"])):
if _whc["sli"][_i] is not None:
_dd = _dd[tuple([slice(None) if _i != _j else _whc["sli"][_i] for _j in range(len(_whc["ind"]))])]
_slic = tuple([slice(None) if _i in _whc["ind"] else None for _i in range(_y.ndim)])
_dd = _dd[_slic] + np.zeros(_y.shape, dtype = bool)
_all_dd = np.logical_or(_dd, _all_dd)
_yc[_all_dd] = np.inf
_ret = np.nan_to_num(_x / _yc)
return _ret
""".strip() + "\n"
code._add_raw(txt, to_preamble = True, to_the_top = True)
code._added_soft_div = True
str_all_whch = str(all_whch)
str_all_whch = str_all_whch.replace(r"'_s[[[", "_s[").replace("]]]'", "]").replace(r"'_orig_shp[[[", "_orig_shp[")
str_all_whch = str_all_whch.replace("{", "\n {")
ret = "_soft_divide(" + numer + ", " + denom_value + ", " + str_all_whch + ")"
if parantheses_around == True:
ret = "(" + ret + ")"
return ret
class EvalDOneOp():
def __init__(self, s, loc, tokens):
_check_d_one(tokens)
self.value = tokens[0][1]
def eval(self, comp, code, calling_from_left_arrow = False):
ret = self.value.eval(comp = comp, code = code)
tsd = _prepare_soft_divide("1.0", ret["value"], list(ret["ind"]), code, comp, parantheses_around = True)
if tsd is None:
ret["value"] = "(1.0/(" + ret["value"] + "))"
else:
ret["value"] = tsd
ret["latex"] = r" \left( " + ret["latex"].strip() + r" \right)^{-1} "
ret["units"] = ret["units"]._inverse()
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
class EvalSignOp():
def __init__(self, s, loc, tokens):
_check_sign(tokens)
self.sign, self.value = tokens[0]
def eval(self, comp, code, calling_from_left_arrow = False):
mult = {"+": 1.0, "-": -1.0}[self.sign]
ret = self.value.eval(comp = comp, code = code)
ret["value"] = "(" + str(mult) + ") * (" + ret["value"] + ")"
ret["latex"] = r"\left( " + self.sign.strip() + " " + ret["latex"] + r" \right)"
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
class EvalPowerOp():
def __init__(self, s, loc, tokens):
_check_power(tokens)
self.value = tokens[0]
def eval(self, comp, code, calling_from_left_arrow = False):
trmL = self.value[0].eval(comp = comp, code = code)
trmR = self.value[2].eval(comp = comp, code = code)
if trmR["ind"] != "":
_raise_value_error("Exponents must be simply numbers.")
if trmR["units"]._is_trivial() == False:
_raise_value_error("Exponents can't have units.")
ret = {"value": "np.power(" + trmL["value"] + ", " + trmR["value"] + ")",
"ind": trmL["ind"],
"units": trmL["units"]._exponent(trmR),
"latex": r"\left(" + trmL["latex"].strip() + r"\right)"+ r"^{ " + trmR["latex"].strip() + r" } "} # maybe parantheses not needed if trmL is a single object?
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
def _check_that_return_indices_compatible_with_filter(indices_to, filters):
for filt in filters:
found = ""
for i in filt["cond_inds"]:
if i in indices_to:
found += i
if found != "":
_raise_value_error("You specified greater/lesser condition involving indices " + filt["cond_inds"] + " but now you are not summing over index " + found + ".")
def _add_filters_to_einsum(indices_from, indices_to, filters):
filter_vals = []
filter_from = []
filter_latex = []
_check_that_return_indices_compatible_with_filter(indices_to, filters)
for filt in filters:
howmany = 0
for i in filt["cond_inds"]:
if i in indices_from:
howmany += 1
if howmany == 0:
continue
elif howmany != len(filt["cond_inds"]):
_raise_value_error("You specified greater/lesser condition involving indices " + filt["cond_inds"] + ". " +
"Therefore, for any sum in your expression, if any of these indices appear " +
"in the sum, all other should appear as well. Otherwise, greater/lesser condition you " +
"specified makes no sense.")
filter_from.append(filt["cond_inds"])
filter_vals.append(filt["cond_value"])
filter_latex.append(filt["cond_latex"])
return filter_vals, filter_from, filter_latex
def _add_diagonals_to_einsum(indices_from, indices_to, comp):
diagonal_vals = []
diagonal_from = []
diagonal_latex = []
for idia, dia in enumerate(comp._diagonals):
howmany_from = 0
for i in dia:
if i in indices_from:
howmany_from += 1
howmany_to = 0
for i in dia:
if i in indices_to:
howmany_to += 1
if howmany_from == 2 and howmany_to == 1:
_raise_value_error("You have specified this condition on indices ( " + dia[0] + " != " + dia[1] + " ) but now you have one index " +
"on the left of arrow operator and both on the right of it. It is ambiguous what you want to do with this condition.")
if howmany_from == 0 or howmany_from == 1:
continue
if howmany_to == 1 or howmany_to == 2:
continue
diagonal_from.append("".join(dia))
val = "(1.0 - np.eye(_orig_shp[\"" + dia[0] + "\"], _orig_shp[\"" + dia[1] + "\"]))"
val = _perform_partial_filtering(val, dia, comp._partially_filtered_indices)
diagonal_vals.append(val)
diagonal_latex.append(dia[0] + r" \neq " + dia[1])
return diagonal_vals, diagonal_from, diagonal_latex
def _remove_outer_para(s):
# Removes a single pair of outermost LaTeX \left( ... \right) parentheses, e.g. around ( 4 + 5 * (3 - 1) ),
# but leaves alone strings like (3 + 43) (2 - 3) whose outer parentheses do not enclose the whole expression.
if len(s) < 13:
return s
if s[:6] != r"\left(" or s[-7:] != r"\right)":
return s
counter = 0
for i in range(6, len(s) - 7):
if s[i:].startswith(r"\left("):
counter += 1
elif s[i:].startswith(r"\right)"):
counter -= 1
if counter < 0:
return s
if counter != 0:
_raise_value_error("Missing parentheses!")
return s[6:-7]
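# Illustrative sketch only (hypothetical helper, not part of WfBase): _remove_outer_para strips a
# single outermost \left( ... \right) pair, but leaves the string alone when the outer parentheses
# do not enclose the whole expression.
def _example_remove_outer_para():
    stripped = _remove_outer_para(r"\left( x + y \right)")                 # " x + y "
    untouched = _remove_outer_para(r"\left( a \right) \left( b \right)")   # returned unchanged
    return stripped, untouched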
class EvalMultDivOp():
def __init__(self, s, loc, tokens):
_check_mult_div(tokens)
self.value = tokens[0]
def eval(self, comp, code, calling_from_left_arrow = False):
vals = []
inds = []
operations = []
latex = ""
tmpL = self.value[0].eval(comp = comp, code = code)
vals.append(tmpL["value"])
inds.append(tmpL["ind"])
units = tmpL["units"]
latex += tmpL["latex"].strip()
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j + 1]
if op == "/" and isinstance(nxt, EvalAddSubOp):
tmpR = nxt.eval(comp = comp, code = code, no_outer_para = True)
else:
tmpR = nxt.eval(comp = comp, code = code)
vals.append(tmpR["value"])
inds.append(tmpR["ind"])
operations.append(op)
if op == "*":
units = units._multiply(tmpR["units"])
latex += r" \, " + tmpR["latex"].strip()
elif op == "/":
units = units._divide(tmpR["units"])
latex = r" \frac{ " + _remove_outer_para(latex.strip()) + r" }{ " + tmpR["latex"].strip() + r" } "
else:
_raise_value_error("Wrong operator!?")
# if operations[-1] != "/":
# latex = r" \left( " + latex + r" \right) "
if calling_from_left_arrow == False:
ret = _broadcast_indices(vals, inds, operations, units, code, comp)
ret["latex"] = latex
else:
tmp_indices_from = ",".join(inds)
indices_to = calling_from_left_arrow
filter_vals, filter_from, filter_latex = _add_filters_to_einsum(tmp_indices_from, indices_to, comp._filters)
diagonal_vals, diagonal_from, diagonal_latex = _add_diagonals_to_einsum(tmp_indices_from, indices_to, comp)
indices_from = ",".join(inds + filter_from + diagonal_from)
indices = indices_from + "->" + indices_to
summed_over_indices = sorted(set("".join(inds)).difference(indices_to))
if operations[-1] == "/":
tsd = _prepare_soft_divide("1.0", vals[-1], list(inds[-1]), code, comp, parantheses_around = True)
if tsd is None:
vals[-1] = "(1.0/(" + vals[-1] + "))"
else:
vals[-1] = tsd
einsum_1 = indices
einsum_2 = ",\\\n".join(vals + filter_vals + diagonal_vals)
ret = {"value": code._add_definition_from_einsum("__mult", einsum_1, einsum_2, do_copy = False),
"ind": calling_from_left_arrow,
"units": units}
summed_over_index = "".join(sorted(summed_over_indices)).strip()
if len(summed_over_index) > 0:
ret["latex"] = r" \displaystyle\sum_{ " + _nicefy_subscript(summed_over_index) + r" }"
else:
ret["latex"] = " "
if len(filter_latex) > 0 or len(diagonal_latex) > 0:
if len(summed_over_index) == 0:
_raise_value_error("This should not happen. If you have nothing summing over, then there should be no conditions.")
ret["latex"] += r"^{\substack{"
ret["latex"] += r" \\ ".join(filter_latex + diagonal_latex)
ret["latex"] += r"}}"
ret["latex"] += " " + latex.strip() + r" "
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
class EvalAddSubOp():
def __init__(self, s, loc, tokens):
_check_add_sub(tokens)
self.value = tokens[0]
def eval(self, comp, code, calling_from_left_arrow = False, no_outer_para = False):
vals = []
inds = []
latex = ""
operations = []
tmpL = self.value[0].eval(comp = comp, code = code)
vals.append(tmpL["value"])
inds.append(tmpL["ind"])
units = tmpL["units"]
latex += tmpL["latex"].strip()
prev_latex = tmpL["latex"].strip()
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j+1]
tmpR = nxt.eval(comp = comp, code = code)
vals.append(tmpR["value"])
inds.append(tmpR["ind"])
operations.append(op)
latex += " " + str(op).strip() + " " + tmpR["latex"].strip()
if units._check_units_the_same(tmpR["units"]) == False:
_raise_value_error("Units in your expression do not match! You are trying to " +
{"+": "add", "-": "subtract"}[str(op).strip()] + " term with units of [" +
str(units) + "] to term with units of [" + str(tmpR["units"]) + "]." +
"BREAK The first term is: " + prev_latex.strip() + "\n" +
"BREAK The second term is: " + tmpR["latex"].strip() + "\n")
prev_latex = tmpR["latex"].strip()
# THIS MIGHT BE DONE BETTER. MAYBE NO NEED TO BROADCAST
# HERE TO INDICES THAT YOU WILL LATER COLLAPSE IN THE EINSUM?
ret = _broadcast_indices(vals, inds, operations, units, code, comp)
latex = latex.strip()
if no_outer_para == False:
latex = r" \left( " + latex + r" \right) "
ret["latex"] = latex
ret = _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp)
if calling_from_left_arrow != False:
ret = _apply_diagonals_on_output_if_indices_remain(code, ret, comp)
return ret
class EvalArrowOp():
def __init__(self, s, loc, tokens):
_check_ein(tokens)
self.value = tokens[0]
def eval(self, comp, code, allow_storing_data = False, call_from_main_evaluate = False):
if self.value[0].name not in ["ind_for_left_arrow", "tensor", "constant"]:
_raise_value_error("Incorrect usage of <= operator. On the left of <= there should be something of the form X_Y where X can be a new variable, or X can be empty string. Y should be set of indices or empty.")
if call_from_main_evaluate == False:
# It is not clear how to deal with conditions. For example, if one says m!=n but then
# m and n appear both in the nested arrow and in the main arrow, these m and n indices could
# refer to different things. For example, they might correspond to matrices with different
# shapes, while our _orig_shp bookkeeping stores a single global shape per index.
_raise_value_error("Not allowing nested assignment operators, as that might lead to ambiguity.")
if allow_storing_data == False:
if self.value[0].name != "ind_for_left_arrow":
_raise_value_error("Incorrect usage of <= operator. On the left of <= there should be _ or _ijk or similar.")
if "_" in self.value[0].text:
store_to = self.value[0].text.split("_")[0]
if store_to == "":
store_to = None
indices_want = self.value[0].text.split("_")[1]
else:
store_to = self.value[0].text
indices_want = ""
if len(indices_want) != len(list(set(indices_want))):
_raise_value_error("You have duplicate indices on the left of the assignment operator <=, <<=, or <+= .")
ret = self.value[2].eval(comp = comp, code = code, calling_from_left_arrow = indices_want)
if allow_storing_data == False:
return ret
else:
return ret, store_to, self.value[1]
def _resolve_indices_if_called_from_left_arrow(calling_from_left_arrow, ret, code, comp):
if calling_from_left_arrow == False:
return ret
elif ret["ind"] == calling_from_left_arrow:
_check_that_return_indices_compatible_with_filter(calling_from_left_arrow, comp._filters)
return ret
else:
summed_over_indices = sorted(set("".join(ret["ind"])).difference(calling_from_left_arrow))
how_many_indices_summing_over = len(summed_over_indices)
latex_start = ret["latex"]
tmp_indices_from = ret["ind"]
indices_to = calling_from_left_arrow
filter_vals, filter_from, filter_latex = _add_filters_to_einsum(tmp_indices_from, indices_to, comp._filters)
diagonal_vals, diagonal_from, diagonal_latex = _add_diagonals_to_einsum(tmp_indices_from, indices_to, comp)
if how_many_indices_summing_over == 0 and (len(filter_vals) > 0 or len(diagonal_vals) > 0):
_raise_value_error("Should not happen. Can't apply conditions if there is no summation.")
if how_many_indices_summing_over > 0:
summed_over_index = "".join(sorted(summed_over_indices)).strip()
if len(summed_over_index) > 0:
latex = r" \displaystyle\sum_{ " + _nicefy_subscript(summed_over_index) + r" }"
else:
latex = r" "
if len(filter_latex)>0 or len(diagonal_vals) > 0:
if len(summed_over_index) == 0:
_raise_value_error("This should not happen. If you have nothing summing over, then there should be no conditions.")
latex += r"^{\substack{"
latex += r" \\ ".join(filter_latex + diagonal_latex)
latex += r"}}"
latex += r" \left( " + latex_start + r" \right) "
else:
latex = latex_start
indices_from = ",".join([ret["ind"]] + filter_from + diagonal_from)
indices = indices_from + "->" + indices_to
einsum_1 = indices
einsum_2 = ",\\\n".join([ret["value"]] + filter_vals + diagonal_vals)
if len([ret["value"]] + filter_vals + diagonal_vals) > 1:
do_copy = False
else:
do_copy = True
ret = {"value": code._add_definition_from_einsum("__reso", einsum_1 , einsum_2, do_copy = do_copy),
"ind": calling_from_left_arrow,
"units": ret["units"],
"latex": latex}
return ret
def _apply_diagonals_on_output_if_indices_remain(code, ret, comp):
all_val = []
for idia, dia in enumerate(comp._diagonals):
howmany = 0
for i in dia:
if i in ret["ind"]:
howmany += 1
if howmany == 0 or howmany == 1:
continue
where_dia_0 = ret["ind"].index(dia[0])
where_dia_1 = ret["ind"].index(dia[1])
if where_dia_0 < where_dia_1:
val = "(1.0 - np.eye(_orig_shp[\"" + dia[0] + "\"], _orig_shp[\"" + dia[1] + "\"]))"
else:
val = "(1.0 - np.eye(_orig_shp[\"" + dia[1] + "\"], _orig_shp[\"" + dia[0] + "\"]))"
# no need here for partial filtering as there is no < or > condition applied to the output indices!
# val = _perform_partial_filtering(val, dia, comp.partially_filtered_indices)
# val = "(" + val + ")"
tmp = []
for j in ret["ind"]:
if j in dia:
tmp.append(":")
else:
tmp.append("None")
val = val + "[" + ", ".join(tmp) + "]"
all_val.append(val)
if len(all_val) == 0:
return ret
all_val = " * ".join(all_val)
variable = code.give_me_unique_variable_name("__removediag")
code._add_definition(variable, "(" + ret["value"] + ") * (" + all_val + ")")
ret["value"] = variable
return ret
def _broadcast_indices(vals, inds, operations, units, code, comp):
if (len(vals) != len(inds)) or (len(vals) != len(operations) + 1):
_raise_value_error("Inconsistent input!")
# check if we ever called broadcast with these same parameters
bef = code.check_if_did_this_broadcast_before(vals, inds, operations)
if bef is not None:
return bef
# broadcast function will create new variable in the python code for exec
# this will be the name of that variable
result_val_name = code.give_me_unique_variable_name("__brod")
l_val, l_ind = vals[0], inds[0]
#obtain a sorted set representing the indices that the end-result should have
ret_ind = "".join(sorted(set("".join(inds))))
# construct variable where we will store result
#
# place None at missing indices
result_axes = ["None" for i in ret_ind]
# leave other indices as they are
for i in l_ind:
result_axes[ret_ind.find(i)] = ":"
# now make sure that indices are in the right order
if l_ind != "":
jss_l_ind = "".join(sorted(set(l_ind)))
if l_ind != jss_l_ind:
einsum_1 = l_ind + "->" + jss_l_ind
einsum_2 = l_val
code._add_definition(result_val_name, "np.copy(opteinsum(\"" + einsum_1 + "\", " + einsum_2 + "))")
else:
code._add_definition(result_val_name, "np.copy(" + l_val + ")")
else:
code._add_definition(result_val_name, "np.array(" + l_val + ")")
# if len(result_axes) > 0:
# code._add_raw(result_val_name + " = " + result_val_name + "[" + ",".join(result_axes) + "]")
for j in range(1, len(vals)):
r_val, r_ind = vals[j], inds[j]
if r_ind == "":
if operations[j - 1] == "+":
code._add_raw(result_val_name + " = " + result_val_name + " + (" + r_val + ")")
elif operations[j - 1] == "-":
code._add_raw(result_val_name + " = " + result_val_name + " - (" + r_val + ")")
elif operations[j - 1] == "*":
code._add_raw(result_val_name + " = " + result_val_name + " * (" + r_val + ")")
elif operations[j - 1] == "/":
code._add_raw(result_val_name + " = " + result_val_name + " / (" + r_val + ")")
else:
_raise_value_error("Unknown operation!")
else:
r_axes = ["None" for i in ret_ind]
for i in r_ind:
r_axes[ret_ind.find(i)] = ":"
if len(r_axes) > 0:
r_axes_str = "[" + ",".join(r_axes) + "]"
else:
r_axes_str = ""
r_ind_ssj = "".join(sorted(set(r_ind)))
if r_ind == r_ind_ssj:
rot_r_val_code = "(" + r_val + ")"
else:
rot_r_val_code = "np.copy(opteinsum(\"" + r_ind + "->" + r_ind_ssj + "\", " + r_val + "))"
if operations[j - 1] in ["+", "-", "*", "/"]:
if len(result_axes) > 0:
result_axes_str = "[" + ",".join(result_axes) + "]"
else:
result_axes_str = ""
if operations[j - 1] != "/":
code._add_raw(result_val_name + " = " +
result_val_name + result_axes_str + " " +
operations[j - 1] + " " +
rot_r_val_code + r_axes_str)
else:
r_ind_ssj_expanded = []
pos = 0
for ra in r_axes:
if ra == ":":
r_ind_ssj_expanded.append(r_ind_ssj[pos])
pos += 1
elif ra == "None":
r_ind_ssj_expanded.append(None)
else:
_raise_value_error("Unexpected axis marker. This should not happen.")
if pos != len(r_ind_ssj):
_raise_value_error("Missing something. Should not happen.")
tsd = _prepare_soft_divide(result_val_name + result_axes_str, rot_r_val_code + r_axes_str, list(r_ind_ssj_expanded), \
code, comp, parantheses_around = False)
if tsd is None:
code._add_raw(result_val_name + " = " +
result_val_name + result_axes_str + " " +
operations[j - 1] + " " +
rot_r_val_code + r_axes_str)
else:
code._add_raw(result_val_name + " = " + tsd)
# here the problem is that it might happen that result_axes_str and r_axes_str have None at the same place
# in that case you should squeeze out those indices
to_squeeze = []
for k in range(len(result_axes)):
if result_axes[k] == "None" and r_axes[k] == "None":
to_squeeze.append(k)
if len(to_squeeze) > 0:
code._add_raw(result_val_name + " = " + result_val_name + ".squeeze(axis = " + str(tuple(to_squeeze)) + ")")
else:
_raise_value_error("Unknown operation!")
for i in r_ind:
result_axes[ret_ind.find(i)] = ":"
if "None" in result_axes:
_raise_value_error("Hm, this shouldn't happen. All indices must be sliced eventually...")
ret = {"value": result_val_name, "ind": ret_ind, "units": units}
# store parameters you sent to this function
# in case we ever call this function again with the same parameters
code.store_broadcast_info_for_lookup(result_val_name, {"input_vals": vals, "input_inds": inds, "input_operations": operations, "ret_ind": ret_ind, "ret_units": units})
return ret
def _nicefy_core(txt):
ll = txt.strip()
if ll.count("~") > 1:
_raise_value_error("Wrong format of tensor/constant: " + ll)
if ll.startswith("~") == True:
_raise_value_error("Name: " + ll + " is invalid. It can't start with ~.")
if ll.endswith("~") == True:
_raise_value_error("Name: " + ll + " is invalid. It can't end with ~.")
if ll.count("~") == 0:
core = ll
superscript = ""
else:
sp = ll.split("~")
core = sp[0].strip()
superscript = sp[1].strip()
if core in "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi pi rho sigma tau upsilon phi chi psi omega Gamma Delta Theta Lambda Xi Pi Sigma Upsilon Phi Psi Omega".split(" "):
core = "\\" + core
else:
if len(core) > 1:
core = r"\mathrm{" + core + r"}"
if superscript == "":
return core
else:
return core + r"^{\mathrm{ " + superscript +" }}"
def _nicefy_subscript(txt):
ll = txt.strip()
return "".join(ll)
class BfsConstVar():
def __init__(self, s, loc, tokens):
_check_const(tokens)
self.value = tokens[0]
self.name = tokens.get_name()
def eval(self, data, parent):
if self.name in ["tensor"]:
core, ind = self.value.split("_")
core = core.strip()
ind = ind.strip()
data.append({"key": core, "indices": ind})
return "__object_" + core.replace("~", "") + "[" + ",".join(ind) + "]"
elif self.name in ["constant"]:
core = self.value.strip()
data.append({"key": core, "indices": ""})
return "__object_" + core.replace("~", "")
elif self.name in ["imaginary one"]:
return "(1.0j)"
return self.value.strip()
class BfsFuncOp():
def __init__(self, s, loc, tokens):
_check_func(tokens)
self.func = tokens[0][0]
self.value = tokens[0][1]
def eval(self, data, parent):
if self.func == "Real":
return "(" + self.value.eval(data, parent = self) + ").real"
elif self.func == "Imag":
return "(" + self.value.eval(data, parent = self) + ").imag"
else:
_raise_value_error("Unknown function.")
class BfsConjugOp():
def __init__(self, s, loc, tokens):
_check_conjug(tokens)
self.value = tokens[0][1]
def eval(self, data, parent):
return "(" + self.value.eval(data, parent = self) + ").conjugate()"
class BfsSignOp():
def __init__(self, s, loc, tokens):
_check_sign(tokens)
self.sign, self.value = tokens[0]
def eval(self, data, parent):
return self.sign + "(" + self.value.eval(data, parent = self) + ")"
class BfsDOneOp():
def __init__(self, s, loc, tokens):
_check_d_one(tokens)
self.sign, self.value = tokens[0]
def eval(self, data, parent):
return "(1.0/(" + self.value.eval(data, parent = self) + "))"
class BfsPowerOp():
def __init__(self, s, loc, tokens):
_check_power(tokens)
self.value = tokens[0]
def eval(self, data, parent):
term = self.value[0].eval(data, parent = self)
exponent = self.value[2].eval(data, parent = self)
if exponent in ["1", "1.0"]:
return "(" + term + ")"
elif exponent in ["2", "2.0"]:
return "((" + term + ")*(" + term + "))"
elif exponent in ["-1", "-1.0"]:
return "(1.0/(" + term + "))"
else:
return "(" + term + ")**(" + exponent + ")"
class BfsMultDivOp():
def __init__(self, s, loc, tokens):
_check_mult_div(tokens)
self.value = tokens[0]
def eval(self, data, parent):
ret = self.value[0].eval(data, parent = self)
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j + 1].eval(data, parent = self)
ret = ret + " " + op + " " + nxt
if len(self.value) > 1:
ret = "(" + ret + ")"
return ret
class BfsAddSubOp():
def __init__(self, s, loc, tokens):
_check_add_sub(tokens)
self.value = tokens[0]
def eval(self, data, parent):
ret = self.value[0].eval(data, parent = self)
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j+1].eval(data, parent = self)
ret = ret + " " + op + " " + nxt
if len(self.value) > 1:
ret = "(" + ret + ")"
return ret
class BfsArrowOp():
def __init__(self, s, loc, tokens):
_check_ein(tokens)
self.value = tokens[0]
def eval(self, data, parent, allow_storing_data = False, call_from_main_evaluate = False):
if self.value[0].name not in ["ind_for_left_arrow", "tensor", "constant"]:
_raise_value_error("Incorrect usage of <= operator. On the left of <= there should be something of the form X_Y where X can be a new variable, or X can be empty string. Y should be set of indices or empty.")
if call_from_main_evaluate == False:
_raise_value_error("Not allowing nested assignment operators, as that might lead to ambiguity.")
if allow_storing_data == False:
if self.value[0].name != "ind_for_left_arrow":
_raise_value_error("Incorrect usage of <= operator. On the left of <= there should be _ or _ijk or similar.")
if "_" in self.value[0].value:
store_to = self.value[0].value.split("_")[0]
if store_to == "":
store_to = None
indices_want = self.value[0].value.split("_")[1]
else:
store_to = self.value[0].value
indices_want = ""
if len(indices_want) != len(list(set(indices_want))):
_raise_value_error("You have duplicate indices on the left of the assignment operator <=, <<=, or <+= .")
ret = {}
ret["value"] = self.value[2].eval(data, parent = self)
ret["ind"] = indices_want
return ret, store_to, self.value[1]
class UnitConstVar():
def __init__(self, s, loc, tokens):
_check_const(tokens)
self.text = tokens[0]
self.name = tokens.get_name()
def eval(self):
if self.name in ["float", "integer"]:
ret = float(self.text)
elif self.name == "basic_unit":
if self.text.lower() == "ev":
ret = Units(eV = 1)
elif self.text.lower() == "ang":
ret = Units(Ang = 1)
elif self.text.lower() == "mub":
ret = Units(muB = 1)
else:
_raise_value_error("Unknown unit: " + self.text + ".")
else:
_raise_value_error("Unknown name!")
return ret
class UnitPowerOp():
def __init__(self, s, loc, tokens):
_check_power(tokens)
self.value = tokens[0]
def eval(self):
trmL = self.value[0].eval()
trmR = self.value[2].eval()
if isinstance(trmL, Units) == False:
_raise_value_error("Must raise one of basic units to a power.")
if isinstance(trmR, float) == False:
_raise_value_error("Exponents must be simply a number.")
return trmL._exponent_float(trmR)
class UnitMultDivOp():
def __init__(self, s, loc, tokens):
_check_mult_div(tokens)
self.value = tokens[0]
def eval(self):
ret = self.value[0].eval()
if isinstance(ret, Units) == False:
_raise_value_error("Needs to be one of basic units.")
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j + 1]
tmpR = nxt.eval()
if isinstance(tmpR, Units) == False:
_raise_value_error("Needs to be one of basic units.")
if op == "*":
ret = ret._multiply(tmpR)
elif op == "/":
ret = ret._divide(tmpR)
else:
_raise_value_error("Wrong operator!?")
return ret
def _parse_value_and_units(raw_string):
ret = {}
string = raw_string.strip()
if " " not in string:
ret["value"] = float(string)
ret["units"] = Units(eV = 0, Ang = 0, muB = 0)
return ret
if string.count("*") == 0:
_raise_value_error("""If you are specifying this value as string, then the string must
be of the form \"3.0 * eV/Ang\" or similar. In other words, you _must_ include
multiplication sign between number and units.""")
sp = string.index("*")
ret["value"] = float(string[:sp])
units_str = string[sp + 1:]
# now parse units out of the string
operand = \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Word("-" + pp.nums, min = 1).set_results_name("integer") | \
pp.Word(pp.nums, min = 1).set_results_name("integer") | \
pp.Word(pp.alphas, min = 1).set_results_name("basic_unit")
operand.set_parse_action(UnitConstVar)
parser = pp.infix_notation(
operand,
[
( "^" , 2, pp.opAssoc.LEFT , UnitPowerOp ), # power raising
(pp.oneOf("* /") , 2, pp.opAssoc.LEFT , UnitMultDivOp ), # multiplication and division
],
)
par = _my_parse_string(parser, units_str, parse_all = True)
if len(par) != 1:
_raise_value_error("PyParsing returned something not expected?!")
par = par[0]
ret["units"] = par.eval()
return ret
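# Illustrative sketch only (hypothetical helper, not part of WfBase): _parse_value_and_units
# accepts either a bare number or a string of the form "number * units".
def _example_parse_value_and_units():
    bare = _parse_value_and_units("0.05")            # value 0.05, trivial units
    full = _parse_value_and_units("3.0 * eV/Ang")    # value 3.0, units behaving as eV * Ang^-1
    return bare, full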
class FundConst():
def __init__(self, cnst):
self._data = {"hbar": 0.0, "e": 0.0, "epszero": 0.0, "c": 0.0, "me": 0.0}
self._data[cnst] = 1.0
def _exponent_float(self, ex):
for k in self._data.keys():
self._data[k] *= ex
return self
def _multiply_with(self, obj):
for k in self._data.keys():
self._data[k] = self._data[k] + obj._data[k]
return self
def _divide_by(self, obj):
for k in self._data.keys():
self._data[k] = self._data[k] - obj._data[k]
return self
def _to_latex(self):
keys = list(self._data.keys())
keys.sort()
numerator = []
denominator = []
for k in keys:
if np.abs(self._data[k]) < 1.0E-8:
continue
tmp = _fund_const_in_latex(k)
if self._data[k] > 0.0:
sign = 1
else:
sign = -1
expon = str(Fraction(sign * self._data[k]).limit_denominator(1000))
if expon == "1":
expon = ""
else:
expon = r"^{" + expon.strip() + r"}"
if sign == 1:
numerator.append(tmp + expon)
else:
denominator.append(tmp + expon)
if len(numerator) == 0 and len(denominator) == 0:
return ""
if len(numerator) != 0 and len(denominator) == 0:
return " ".join(numerator)
if len(numerator) == 0 and len(denominator) != 0:
return r"\frac{1}{" + " ".join(denominator) + "}"
if len(numerator) != 0 and len(denominator) != 0:
return r"\frac{" + " ".join(numerator) + "}{" + " ".join(denominator) + "}"
def _numerical_value(self):
ret = 1.0
for k in self._data.keys():
ret *= np.power(_fund_const_to_numerical(k), self._data[k])
return ret
def _fund_const_in_latex(key):
if key == "hbar" : return r"\hbar"
if key == "e" : return r"e"
if key == "epszero": return r"\epsilon_0"
if key == "c" : return r"c"
if key == "me" : return r"m_{\rm e}"
_raise_value_error("Unknown fundamental constant.")
def _fund_const_to_numerical(key):
if key == "hbar" : return hbar_SI
if key == "e" : return electron_charge_SI
if key == "epszero": return epsilon_zero_SI
if key == "c" : return speed_of_light_SI
if key == "me" : return electron_mass_SI
_raise_value_error("Unknown fundamental constant.")
class PrefactorSIConstVar():
def __init__(self, s, loc, tokens):
_check_const(tokens)
self.text = tokens[0]
self.name = tokens.get_name()
def eval(self):
if self.name in ["float", "integer"]:
ret = float(self.text)
elif self.name == "fundamental_constant":
if self.text.lower() in ["hbar", "e", "epszero", "c", "me"]:
ret = FundConst(self.text.lower())
else:
_raise_value_error("Unknown fundamental constant: " + self.text + ".")
else:
_raise_value_error("Unknown name!")
return ret
class PrefactorSIPowerOp():
def __init__(self, s, loc, tokens):
_check_power(tokens)
self.value = tokens[0]
def eval(self):
trmL = self.value[0].eval()
trmR = self.value[2].eval()
if isinstance(trmL, FundConst) == False:
_raise_value_error("Only allowed to raise one of the fundamental constants to a power. " + str(trmL))
if isinstance(trmR, float) == False:
_raise_value_error("Exponents must be simply a number. " + str(trmR))
return trmL._exponent_float(trmR)
class PrefactorSIMultDivOp():
def __init__(self, s, loc, tokens):
_check_mult_div(tokens)
self.value = tokens[0]
def eval(self):
ret = self.value[0].eval()
if isinstance(ret, FundConst) == False:
_raise_value_error("Only allowed to multiply/divide fundamental constants. " + str(ret))
for j in range(1, len(self.value), 2):
op = self.value[j]
nxt = self.value[j + 1]
tmpR = nxt.eval()
if isinstance(tmpR, FundConst) == False:
_raise_value_error("Only allowed to multiply/divide fundamental constants. " + str(tmpR))
if op == "*":
ret = ret._multiply_with(tmpR)
elif op == "/":
ret = ret._divide_by(tmpR)
else:
_raise_value_error("Wrong operator!?")
return ret
def _parse_prefactor_SI_units_fundamental_constants(raw_string):
string = raw_string.strip()
operand = \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine("-" + pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Word("-" + pp.nums, min = 1).set_results_name("integer") | \
pp.Word(pp.nums, min = 1).set_results_name("integer") | \
pp.Word(pp.alphas, min = 1).set_results_name("fundamental_constant")
operand.set_parse_action(PrefactorSIConstVar)
parser = pp.infix_notation(
operand,
[
( "^" , 2, pp.opAssoc.LEFT , PrefactorSIPowerOp ), # power raising
(pp.oneOf("* /") , 2, pp.opAssoc.LEFT , PrefactorSIMultDivOp ), # multiplication and division
],
)
par = _my_parse_string(parser, string, parse_all = True)
if len(par) != 1:
_raise_value_error("PyParsing returned something not expected?!")
par = par[0]
return par.eval()
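# Illustrative sketch only (hypothetical helper, not part of WfBase): the SI prefactor parser
# reduces an expression in fundamental constants to a FundConst object that knows both its
# numerical value in SI units and its LaTeX form.
def _example_parse_prefactor_si():
    pref = _parse_prefactor_SI_units_fundamental_constants("e^2 / hbar")
    value = pref._numerical_value()   # numerically equal to electron_charge_SI**2 / hbar_SI
    latex = pref._to_latex()          # something like r"\frac{e^{2}}{\hbar}"
    return value, latex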
def _are_hashes_similar_relative(hash_0, hash_1, tol = 1.0E-5):
if np.max(np.abs(hash_0 - hash_1))/np.max(np.abs(hash_1)) < tol:
return True
else:
return False
def _are_hashes_similar_absolute(hash_0, hash_1, tol = 1.0E-5):
if np.max(np.abs(hash_0 - hash_1)) < tol:
return True
else:
return False
def _get_kpoint_label_info(cell, name, red):
if red.shape[0] != 1:
_raise_value_error("Crystal structure not supported")
p = cell[0, 2]
if np.abs(p) < 1.0E-9:
_raise_value_error("Crystal structure not supported")
tmp_bcc = np.array([[ p, p, p], [-p, p, p], [-p,-p, p]])
tmp_fcc = np.array([[-p, 0, p], [ 0, p, p], [-p, p, 0]])
if np.max(np.abs(cell - tmp_bcc)) < 1.0E-9:
struc_kind = "bcc"
elif np.max(np.abs(cell - tmp_fcc)) < 1.0E-9:
struc_kind = "fcc"
else:
_raise_value_error("Crystal structure not supported")
# the nomenclature is from the Bilbao Crystallographic Server
if struc_kind == "bcc":
ret = {
"GM": [[0. , 0. , 0. ], r"$\Gamma$"],
"H" : [[0. , 1./2., 0. ]],
"N" : [[1./4., 1./4., 0. ]],
"P" : [[1./4., 1./4., 1./4.]],
}
recip_alt = cell[0, 2] * _real_to_recip_no2pi([cell[0], cell[1], cell[2]])
elif struc_kind == "fcc":
ret = {
"GM": [[0. , 0. , 0. ], r"$\Gamma$"],
"X" : [[0. , 1./2., 0. ]],
"M" : [[1./2., 1./2., 0. ]],
"U" : [[1./8., 1./2., 1./8.]],
"K" : [[3./8., 3./8., 0. ]],
"L" : [[1./4., 1./4., 1./4.]],
"W" : [[1./4., 1./2., 0. ]],
}
recip_alt = cell[0, 2] * _real_to_recip_no2pi([cell[0], cell[1], cell[2]])
else:
_raise_value_error("Crystal structure not supported")
# convert to conventions used in QE
for k in ret.keys():
ret[k][0] = _cart_to_red(recip_alt[0], recip_alt[1], recip_alt[2], ret[k][0])
return ret
def _cart_to_red(a1, a2, a3, cart):
cnv = np.array([a1, a2, a3])
cnv = cnv.T
cnv = np.linalg.inv(cnv)
return np.dot(cnv, cart)
def _red_to_cart(a1, a2, a3, red):
return np.array(red[0]*a1 + red[1]*a2 + red[2]*a3)
def _real_to_recip_no2pi(real):
ret = []
ret.append(np.cross(real[1], real[2]))
ret.append(np.cross(real[2], real[0]))
ret.append(np.cross(real[0], real[1]))
ret = np.array(ret) / np.dot(real[0], np.cross(real[1], real[2]))
return ret
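# Illustrative sketch only (hypothetical helper, not part of WfBase): round trip between reduced
# and cartesian coordinates, and the reciprocal lattice without the 2*pi factor, for a
# simple-cubic cell with lattice constant a.
def _example_lattice_helpers():
    a = 2.0
    cell = np.eye(3) * a
    recip = _real_to_recip_no2pi([cell[0], cell[1], cell[2]])              # rows are unit vectors / a
    red = _cart_to_red(cell[0], cell[1], cell[2], np.array([a, a, 0.0]))   # close to [1, 1, 0]
    cart = _red_to_cart(cell[0], cell[1], cell[2], red)                    # back to [a, a, 0]
    return recip, red, cart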
def _get_operand():
return \
pp.Combine(pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1) + "j").set_results_name("imaginary float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + "." + "j").set_results_name("imaginary float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + "j").set_results_name("imaginary integer") | \
pp.Literal("j").set_results_name("imaginary one") | \
pp.Combine(pp.Word(pp.nums, min = 1) + "." + pp.Word(pp.nums, min = 1)).set_results_name("float") | \
pp.Combine(pp.Word(pp.nums, min = 1) + ".").set_results_name("float") | \
pp.Word(pp.nums, min = 1).set_results_name("integer") | \
pp.Combine("_" + pp.Word(pp.alphas, min = 1)).set_results_name("ind_for_left_arrow") | \
pp.Literal("_").set_results_name("ind_for_left_arrow") | \
pp.Combine(pp.Word(pp.alphas, min = 1) + "~" + pp.Word(pp.alphas, min = 1) + "_" + pp.Word(pp.alphas, min = 1)).set_results_name("tensor") | \
pp.Combine(pp.Word(pp.alphas, min = 1) + "~" + pp.Word(pp.alphas, min = 1)).set_results_name("constant") | \
pp.Combine(pp.Word(pp.alphas, min = 1) + "_" + pp.Word(pp.alphas, min = 1)).set_results_name("tensor") | \
pp.Word(pp.alphas, min = 1).set_results_name("constant")
def _my_parse_string(obj, *args, **kwargs):
try:
return obj.parse_string(*args, **kwargs)
except pp.ParseException as pe:
traceback.print_stack()
print()
print("\n" +
_format_one_block_simple_indent("Problem with parsing of this string (note, problem might occur after symbol ^ below):\n\n" + pe.explain(depth = 0),
indent = 0,
string_indent = "+++ ",
start_and_end = True,
special_string = "+") + \
"\n")
exit()
def _raise_value_error(s):
raise ValueError("\n\n" + "*"*81 + "\n" +
_format_one_block(s,
indent = None,
width = 80,
initial_indent = "*** ",
subsequent_indent = "*** ") + \
"\n" + "*"*81 + "\n")
def _print_without_stopping(s):
print("\n\n" + "&"*81 + "\n" +
_format_one_block(s,
indent = None,
width = 80,
initial_indent = "&&& ",
subsequent_indent = "&&& ") + \
"\n" + "&"*81 + "\n")
def _nice_exec(code_txt, code_dic):
try:
exec_time = time.perf_counter()
exec(code_txt, code_dic)
exec_time = time.perf_counter() - exec_time
return exec_time
except Exception as err:
traceback.print_stack()
print()
print(repr(err) + " exception occured while trying to execute this code:")
#
if err.__class__.__name__ == "SyntaxError":
line_number = err.lineno
else:
cl, exc, tb = sys.exc_info()
line_number = traceback.extract_tb(tb)[-1][1]
#
tmp = code_txt.split("\n")
code_use_txt = ""
for i in range(len(tmp)):
if i != line_number - 1:
code_use_txt += " " + tmp[i] + "\n"
else:
code_use_txt += "==PROBLEM==> " + tmp[i] + "\n"
out = _format_one_block_simple_indent(code_use_txt, indent = 0, start_and_end = True, special_string = "$")
print(out)
raise err
class _LatexExpression():
def __init__(self, core, ind, rhs, prefactor = ""):
self._core = core.strip()
self._ind = ind.strip()
self._prefactor = prefactor.strip()
self._rhs = rhs.strip()
def get_string(self, inside_align = False):
latex_source = ""
if self._core != "":
latex_source += _nicefy_core(self._core)
if self._ind != "":
latex_source += r"_{" + _nicefy_subscript(self._ind) + r"}"
if self._core != "" and self._rhs != "":
if inside_align == False:
latex_source += r" \Leftarrow "
else:
latex_source += r" & \Leftarrow "
if self._rhs != "":
if self._prefactor != "":
latex_source += r"\left[ " + self._prefactor + r" \right] \times \left[ " + self._rhs + r" \right]"
else:
latex_source += self._rhs
return latex_source
def __str__(self):
return self.get_string()
def _format_one_block(raw_text, indent = 4, width = 80, initial_indent = None, subsequent_indent = None):
text = raw_text.strip()
if "BREAK" in text:
sp = text.split("BREAK")
ret = []
for s in sp:
ret.append(_format_one_block(s, indent, width, initial_indent, subsequent_indent))
return ("\n" + subsequent_indent + "\n").join(ret)
text = text.replace("\n", "")
while " " in text:
text = text.replace(" ", " ")
# while " " in text:
# text = text.replace(" ", " ")
# while "\n " in text:
# text = text.replace("\n ", " ")
# while " \n" in text:
# text = text.replace(" \n", " ")
if initial_indent is None:
initial_indent = " "*indent
if subsequent_indent is None:
subsequent_indent = " "*indent
wrapper = textwrap.TextWrapper(width = width, # this includes indents
expand_tabs = True,
tabsize = 4,
replace_whitespace = True,
drop_whitespace = True,
initial_indent = initial_indent,
subsequent_indent = subsequent_indent,
break_long_words = True,
break_on_hyphens = True)
return wrapper.fill(text = text)
def _format_one_block_simple_indent(raw_text, indent = 4, string_indent = "", start_and_end = True, max_line = None, dont_indent_first = False, special_string = "#"):
max_len = 1
ret = ""
sp = raw_text.split("\n")
for i, s in enumerate(sp):
if max_line is not None:
if i >= max_line:
one = " " * indent + string_indent + "... cutting long output ... set parameter \"full\" to True to get complete output." + "\n"
if len(one) > max_len:
max_len = len(one)
ret += one
break
if dont_indent_first == False:
one = " "*indent + string_indent + s + "\n"
else:
if i == 0:
one = string_indent + s + "\n"
else:
one = " "*indent + string_indent + s + "\n"
if len(one) > max_len:
max_len = len(one)
ret += one
if start_and_end == True:
tmp = special_string*(max_len + 2) + "\n"
else:
tmp = ""
ret = tmp + ret.rstrip() + "\n" + tmp
return ret
def _make_rst_title(title):
out = ""
out += "="*(len(title.strip()) + 2)
out += "\n"
out += " " + title.strip()
out += "\n"
out += "="*(len(title.strip()) + 2)
out += "\n"
return out
def _make_rst_subtitle(subtitle):
out = ""
out += " " + subtitle.strip()
out += "\n"
out += "-"*(len(subtitle.strip()) + 2)
out += "\n"
return out
def _make_rst_field(title):
out = ""
out += ":"
out += title.strip()
out += ":"
out += "\n"
return out
def _find_1_to_1_map_from_left_to_right(left, right):
if len(left) != len(right):
return None
given_left_return_right = {}
for i in range(len(left)):
l = left[i]
r = right[i]
if l in given_left_return_right.keys():
if given_left_return_right[l] != r:
return None
else:
given_left_return_right[l] = r
return given_left_return_right
def _convert_left_to_right(left, given_left_return_right):
ret = ""
for l in left:
ret += given_left_return_right[l]
return ret
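# Illustrative sketch only (hypothetical helper, not part of WfBase): build a one-to-one map
# between two index strings and use it to translate another index string.
def _example_index_map():
    mapping = _find_1_to_1_map_from_left_to_right("abc", "xyz")             # {"a": "x", "b": "y", "c": "z"}
    translated = _convert_left_to_right("cab", mapping)                     # "zxy"
    inconsistent = _find_1_to_1_map_from_left_to_right("aa", "xy")          # None, since "a" can't map to both
    return translated, inconsistent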
class _InterfaceToWberri(wberri.System_w90):
# This is a class that derives from Wannier Berri's System_w90 class.
# This class is only used as an interface to Wannier Berri, and this
# class is therefore not supposed to be used directly by the user.
def __init__(self, syst_raw = None):
if syst_raw is not None:
self.mp_grid = deepcopy(syst_raw.mp_grid)
self.num_wann = deepcopy(syst_raw.num_wann)
self.periodic = deepcopy(syst_raw.periodic)
self.real_lattice = deepcopy(syst_raw.real_lattice)
self.recip_lattice = deepcopy(syst_raw.recip_lattice)
self.use_wcc_phase = deepcopy(syst_raw.use_wcc_phase)
self.iRvec = deepcopy(syst_raw.iRvec)
self._XX_R = deepcopy(syst_raw._XX_R)
self.__add_empty_symmetry()
def fill_in_from_dictionary(self, ff):
self.mp_grid = ff["mp_grid"]
self.num_wann = ff["num_wann"]
self.periodic = ff["periodic"]
self.real_lattice = ff["real_lattice"]
self.recip_lattice = ff["recip_lattice"]
self.use_wcc_phase = ff["use_wcc_phase"]
self.iRvec = ff["iRvec"]
self.__add_empty_symmetry()
do_regular = False
if "use_reduced_sym" in ff.keys():
if ff["use_reduced_sym"] == True:
self._XX_R = _from_reduced_dic_to_XX(ff)
else:
do_regular = True
else:
do_regular = True
if do_regular == True:
self._XX_R = {}
for k in ff.keys():
if len(k) > 7:
if k[:7] == "_XX_R__":
self._XX_R[k[7:]] = _few_to_many_bits(ff[k])
def __add_empty_symmetry(self):
# Symmetry is not stored in a numpy array
# Therefore I need to reconstruct symmetry from scratch
# as I don't want to store this object.
self.symgroup = wberri.symmetry.Group([],
recip_lattice = self.recip_lattice,
real_lattice = self.real_lattice)
def _replace_with_reduced_data(fn_in, fn_out, reduced_XX, reduced_XX_common):
g = open(fn_in, "rb")
gg = np.load(g, allow_pickle = False)
f = open(fn_out, "wb")
# get stuff that is already stored, except for full matrices
data = {}
for k in gg.keys():
if len(k) > 7:
if k[:7] == "_XX_R__":
continue
data[k] = gg[k]
# add stuff for symmetry reduction
data["use_reduced_sym"] = True
for X in reduced_XX.keys():
for k in reduced_XX[X].keys():
data["reduced_sym__" + X.strip() + "__" + k.strip()] = _many_to_few_bits(reduced_XX[X][k])
for k in reduced_XX_common.keys():
data["reduced_sym__common__" + k.strip()] = _many_to_few_bits(reduced_XX_common[k])
np.savez_compressed(f, **data)
f.close()
g.close()
def _from_reduced_dic_to_XX(ff):
all_X = []
for k in ff.keys():
if len(k) > 13:
if k[:13] == "reduced_sym__":
sp = k.split("__")
if len(sp) != 3:
print("problem")
exit()
if sp[1] == "common":
continue
all_X.append(sp[1])
all_X = list(sorted(set(all_X)))
s_rotc = ff["reduced_sym__common__s_rotc"]
s_inv = ff["reduced_sym__common__s_inv"]
s_tr = ff["reduced_sym__common__s_tr"]
r_irr = ff["reduced_sym__common__r_irr"]
s_max_denom = ff["reduced_sym__common__s_max_denom"]
s_orb = _squarerootform_to_complex(ff["reduced_sym__common__s_orb"], s_max_denom)
r_rel = _oned_to_llist(ff["reduced_sym__common__r_rel"])
r_rel_oper = _oned_to_llist(ff["reduced_sym__common__r_rel_oper"])
r_star = _oned_to_llist(ff["reduced_sym__common__r_star"])
_XX_R = {}
for X in all_X:
mat_in_eigen_space = _oned_to_llist(ff["reduced_sym__" + X + "__mat_in_eigen_space"], nump = True)
#
e_max_denom = ff["reduced_sym__" + X + "__e_max_denom"]
e_data = _squarerootform_to_complex(ff["reduced_sym__" + X + "__e_data"], e_max_denom)
#
cache_eig_matrices = _oned_to_complicated(ff["reduced_sym__" + X + "__e_info"],
e_data,
ff["reduced_sym__" + X + "__e_keys"])
reconstructed = _from_eigen_to_general(X,
mat_in_eigen_space,
cache_eig_matrices,
s_rotc,
s_inv,
s_tr,
s_orb,
r_irr,
r_rel,
r_rel_oper,
r_star,
ff["reduced_sym__" + X + "__parity_I"],
ff["reduced_sym__" + X + "__parity_TR"])
_XX_R[X] = _few_to_many_bits(reconstructed)
return _XX_R
def _many_to_few_bits(x):
if not isinstance(x, np.ndarray):
return x
if x.dtype == np.complex128:
return np.array(x, dtype = np.complex64)
elif x.dtype == np.float64:
return np.array(x, dtype = np.float32)
else:
return x
def _few_to_many_bits(x):
if not isinstance(x, np.ndarray):
return x
if x.dtype == np.complex64:
return np.array(x, dtype = np.complex128)
elif x.dtype == np.float32:
return np.array(x, dtype = np.float64)
else:
return x
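# Illustrative sketch only (hypothetical helper, not part of WfBase): matrices are stored on disk
# in single precision and promoted back to double precision on loading, so a round trip keeps
# roughly single-precision (~1E-7) relative accuracy.
def _example_bit_roundtrip():
    x = np.array([1.0 + 1.0e-7j], dtype = np.complex128)
    stored = _many_to_few_bits(x)        # complex64
    loaded = _few_to_many_bits(stored)   # complex128 again, up to single-precision rounding
    return loaded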
def _write_interface_to_wberri_to_file(fname, obj, add_info = {}):
f = open(fname, "wb")
data={
"mp_grid": obj.mp_grid,
"num_wann": obj.num_wann,
"periodic": obj.periodic,
"real_lattice": obj.real_lattice,
"recip_lattice": obj.recip_lattice,
"use_wcc_phase": obj.use_wcc_phase,
"iRvec": obj.iRvec,
}
for k in obj._XX_R.keys():
data["_XX_R" + "__" + k] = _many_to_few_bits(obj._XX_R[k])
for keys in add_info.keys():
data["add_info__" + keys.strip()] = add_info[keys]
np.savez_compressed(f, **data)
f.close()
def _read_interface_to_wberri_from_file(fname):
f = open(fname, "rb")
ff = np.load(f, allow_pickle = False)
ret = _InterfaceToWberri(None)
ret.fill_in_from_dictionary(ff)
add_info = {}
for k in ff.keys():
if "add_info__" in k:
if ff[k].dtype.type is np.string_:
add_info[k.replace("add_info__", "")] = str(ff[k])
else:
add_info[k.replace("add_info__", "")] = ff[k]
f.close()
return ret, add_info
def _append_to_npz(fin, fout, extra_add_info, overwrite_key = False):
if os.path.exists(fout) == True:
_raise_value_error("File " + fout + " already exists. Stopping.")
f = open(fin, "rb")
ff = np.load(f, allow_pickle = False)
data = dict(ff)
for k in extra_add_info.keys():
usek = "add_info__" + k.strip()
if overwrite_key == False and usek in data.keys():
_raise_value_error("Duplicate key " + usek + ". Stopping.")
data[usek] = extra_add_info[k]
g = open(fout, "wb")
np.savez_compressed(g, **data)
g.close()
f.close()
def _llist_to_oned(ll):
num_terms = len(ll)
len_each_term = []
for l in ll:
len_each_term.append(len(l))
ret = [num_terms] + len_each_term
for l in ll:
ret = ret + l
ret = np.array(ret)
if ret.ndim != 1:
print("Not a list of lists?!")
exit()
return ret
def _oned_to_llist(oned, nump = False):
num_terms = int(np.round(np.real(oned[0])))
len_each_term = np.array(np.round(np.real(oned[1 : num_terms + 1])), dtype = int).tolist()
use = oned[num_terms + 1:]
ret = []
for i in range(num_terms):
if nump == False:
ret.append(use[:len_each_term[i]].tolist())
else:
ret.append(np.array(use[:len_each_term[i]]))
use = use[len_each_term[i]:]
if len(use) != 0:
print("Missing!")
exit()
return ret
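# Illustrative sketch only (hypothetical helper, not part of WfBase): a ragged list of lists is
# flattened into a single 1D array (number of lists, then each length, then the data) and
# recovered exactly.
def _example_llist_roundtrip():
    packed = _llist_to_oned([[1, 2], [3, 4, 5]])   # array([2, 2, 3, 1, 2, 3, 4, 5])
    unpacked = _oned_to_llist(packed)              # [[1, 2], [3, 4, 5]]
    return packed, unpacked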
def _np_to_ls(lmat):
ret = []
for l in lmat:
ret.append(l.tolist())
return ret
def _complicated_to_oned(dicmat):
keys = list(dicmat.keys())
keys.sort()
num_terms = len(keys)
first_shape_each_term = []
for k in keys:
first_shape_each_term.append(dicmat[k].shape[0])
remaining_shape = list(dicmat[keys[0]].shape[1:])
for k in keys:
if remaining_shape != list(dicmat[k].shape[1:]):
print(remaining_shape, list(dicmat[k].shape[1:]))
print("Inconsistent shapes!")
exit()
ret_info = [num_terms] + first_shape_each_term + [len(remaining_shape)] + remaining_shape
ret_data = np.copy(dicmat[keys[0]])
for k in keys[1:]:
ret_data = np.vstack((ret_data, dicmat[k]))
keys = ",".join(keys)
return np.array(ret_info), ret_data, np.array(keys)
def _oned_to_complicated(e_info, e_data, e_keys):
num_terms = e_info[0]
first_shape_each_term = e_info[1 : num_terms + 1]
#remaining_shape_len = e_info[num_terms + 1]
#remaining_shape = e_info[num_terms + 2 : num_terms + 2 + remaining_shape_len]
keys = str(e_keys).split(",")
dicmat = {}
start = 0
for i, k in enumerate(keys):
dicmat[k] = e_data[start : start + first_shape_each_term[i]]
start += first_shape_each_term[i]
if start != e_data.shape[0]:
print("Missing terms!")
exit()
return dicmat
def _complex_to_squarerootform(mat, max_denom):
tmp_re = float(max_denom)*(np.real(mat)**2)*np.sign(np.real(mat))
tmp_im = float(max_denom)*(np.imag(mat)**2)*np.sign(np.imag(mat))
if np.max(np.abs(tmp_re - np.round(tmp_re))) > 1.0E-8:
print(np.max(np.abs(tmp_re - np.round(tmp_re))))
print(tmp_re - np.round(tmp_re))
print("Problem!")
exit()
if np.max(np.abs(tmp_im - np.round(tmp_im))) > 1.0E-8:
print(np.max(np.abs(tmp_im - np.round(tmp_im))))
print("Problem! im")
exit()
if np.max(np.abs(tmp_re)) > 128.0*256.0 - 1.0:
print("Too large!")
exit()
if np.max(np.abs(tmp_im)) > 128.0*256.0 - 1.0:
print("Too large!")
exit()
ret = np.array([np.array(np.round(tmp_re), dtype = np.int16),
np.array(np.round(tmp_im), dtype = np.int16)])
return ret
def _squarerootform_to_complex(mat, max_denom):
tmp_re = np.array(mat[0], dtype = float)
tmp_im = np.array(mat[1], dtype = float)
tmp_re = np.sqrt(np.abs(tmp_re) / float(max_denom)) * np.sign(tmp_re)
tmp_im = np.sqrt(np.abs(tmp_im) / float(max_denom)) * np.sign(tmp_im)
ret = tmp_re + 1.0j*tmp_im
ret[np.abs(ret) < 1.0E-8] = 0.0
return ret
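# Illustrative sketch only (hypothetical helper, not part of WfBase): matrix elements whose
# squares are simple fractions (here 1/2) are stored as signed int16 numerators of
# max_denom * element**2, and recovered by taking the square root again.
def _example_squareroot_roundtrip():
    mat = np.array([1.0 / np.sqrt(2.0) + 1.0j / np.sqrt(2.0)])
    stored = _complex_to_squarerootform(mat, 2)          # int16 numerators, here [[1], [1]]
    recovered = _squarerootform_to_complex(stored, 2)    # numerically close to mat
    return stored, recovered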
def _from_eigen_to_general(X, mat_in_eigen_space, cache_eig_matrices, s_rotc, s_inv, s_tr, s_orb, r_irr, r_rel, r_rel_oper, r_star, parity_I, parity_TR):
maxR = np.max(r_irr)
for ll in r_rel:
maxR = np.max([maxR, np.max(ll)])
shp = list(cache_eig_matrices[list(cache_eig_matrices.keys())[0]].shape)
shp = shp[1:3] + [maxR + 1] + shp[3:]
ret = np.zeros(shp, dtype = complex)
populated = []
for ii in range(len(r_irr)):
pick_one_ind = r_irr[ii]
subgr = " ".join(list(map(str, r_star[ii])))
mat_reconstructed = opteinsum("b, b... -> ...", mat_in_eigen_space[ii], cache_eig_matrices[subgr])
ret[:,:,pick_one_ind] = mat_reconstructed
populated.append(pick_one_ind)
for i in range(len(r_rel[ii])):
tmp_trns = _transform_one(s_rotc, s_inv, s_tr, s_orb, r_rel_oper[ii][i], mat_reconstructed, parity_I, parity_TR)
if r_rel[ii][i] not in populated:
ret[:,:,r_rel[ii][i]] = tmp_trns
else:
if np.max(np.abs(tmp_trns - ret[:,:,r_rel[ii][i]])) > 1.0E-10:
print("Hm, this should not happen. One can arrive at this R vector by applying to different operations to different irreducible R vectors.")
print("But, depending on which combination you use, you get a different matrix?!")
exit()
populated.append(r_rel[ii][i])
return ret
def _rotate_matrix(X, L, R):
if X.ndim == 2:
return L.dot(X).dot(R)
elif X.ndim == 3:
X_shift = X.transpose(2, 0, 1)
tmpX = L.dot(X_shift).dot(R)
return tmpX.transpose(0, 2, 1).reshape(X.shape)
else:
raise ValueError()
def _transform_one(s_rotc, s_inv, s_tr, s_orb, i, mat_in, parity_I, parity_TR):
mat = np.copy(mat_in)
# on AA and SS you need to deal with vectors
# here everything is daggered relative to wberri, since we are rotating from the irreducible R vector to the rotated one
if mat.ndim == 3:
mat = np.tensordot(mat, s_rotc[i].T.conj(), axes=1).reshape(mat.shape)
elif mat.ndim > 3:
raise ValueError("transformation of tensors is not implemented")
if s_inv[i] == True:
mat *= parity_I
if s_tr[i] == True:
mat = (mat/parity_TR).conj()
return _rotate_matrix(mat, s_orb[i], s_orb[i].T.conj())
def _nice_exp(x):
arg = np.maximum(np.minimum(x, 100.0), -100.0)
return np.exp(arg)
def _fermi_dirac(en, mu, kbt):
return 1.0 / (_nice_exp((en - mu)/kbt) + 1.0)
def _fermi_dirac_deriv(en, mu, kbt):
return (-1.0 / (_nice_exp((en - mu)/kbt) + 1.0)**2) * _nice_exp((en - mu)/kbt) * (1.0 / kbt)
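# Illustrative sketch only (hypothetical helper, not part of WfBase): at the chemical potential
# the Fermi-Dirac occupation equals 1/2 and its energy derivative equals -1 / (4 * kbt), which
# the clipped exponential in _nice_exp reproduces.
def _example_fermi_dirac():
    mu, kbt = 5.0, 0.025
    assert abs(_fermi_dirac(mu, mu, kbt) - 0.5) < 1.0e-12
    assert abs(_fermi_dirac_deriv(mu, mu, kbt) + 1.0 / (4.0 * kbt)) < 1.0e-9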
def _potentially_reorder_orbitals(mat, ind, reorder_orbitals):
if reorder_orbitals == True:
if ind == 0:
d = mat.shape[0]
return np.reshape(np.transpose(np.reshape(mat, (d//2, 2)), (1, 0)), (d))
elif ind == 2:
d = mat.shape[2]
return np.reshape(np.transpose(np.reshape(mat, (mat.shape[0], mat.shape[1], d//2, 2)), (0, 1, 3, 2)), (mat.shape[0], mat.shape[1], d))
else:
_raise_value_error("This should not happen.")
else:
return mat
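# Illustrative sketch only (hypothetical helper, not part of WfBase): with reordering turned on,
# a length-6 vector with interleaved pairs [a0, b0, a1, b1, a2, b2] is regrouped into
# [a0, a1, a2, b0, b1, b2]; with reordering turned off it is returned unchanged.
def _example_reorder_orbitals():
    vec = np.arange(6.0)
    reordered = _potentially_reorder_orbitals(vec, 0, True)    # [0., 2., 4., 1., 3., 5.]
    unchanged = _potentially_reorder_orbitals(vec, 0, False)   # returned as-is
    return reordered, unchanged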
def _adjust_input_file(kind, text):
if kind == "scf" or kind == "nscf":
ln = text.split("\n")
ret = ""
for l in ln:
if "pseudo_dir = " in l:
ret += l.split("=")[0] + r"= '.'" + "\n"
else:
ret += l + "\n"
elif kind == "pw2wan":
ln = text.split("\n")
ret = ""
for l in ln:
if "outdir = " in l:
ret += l.split("=")[0] + r"= '_work'" + "\n"
else:
ret += l + "\n"
else:
_raise_value_error("This should not happen.")
return ret
def _stop_because_loaded_from_wannierberri():
traceback.print_stack()
print("\n\n\n")
_raise_value_error("""You loaded your calculation directly from Wannier Berri using
function load_from_wannierberri, and not through the WfBase database. Therefore
some information you are trying to access is not available""")