Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix band stats k-point labels #253

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sumo/cli/bandplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def bandplot(
if code == "vasp":
for vr_file in filenames:
vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
bs = vr.get_band_structure(line_mode=True)
bs = vr.get_band_structure(line_mode=True, efermi="smart")
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures)
elif code == "castep":
Expand Down
37 changes: 28 additions & 9 deletions sumo/cli/bandstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,26 @@ def bandstats(
bandstructures = []
for vr_file in filenames:
vr = BSVasprun(vr_file, parse_projected_eigen=False)
bs = vr.get_band_structure(line_mode=True)
bs = vr.get_band_structure(line_mode=True, efermi="smart")
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures, force_kpath_branches=False)
bs, kpt_mapping = get_reconstructed_band_structure(
bandstructures, force_kpath_branches=True, return_forced_branch_kpt_map=True
)

if bs.is_metal():
logging.error("ERROR: System is metallic!")
sys.exit()

_log_band_gap_information(bs)
_log_band_gap_information(bs, kpt_mapping=kpt_mapping)

vbm_data = bs.get_vbm()
cbm_data = bs.get_cbm()

logging.info("\nValence band maximum:")
_log_band_edge_information(bs, vbm_data)
_log_band_edge_information(bs, vbm_data, kpt_mapping=kpt_mapping)

logging.info("\nConduction band minimum:")
_log_band_edge_information(bs, cbm_data)
_log_band_edge_information(bs, cbm_data, kpt_mapping=kpt_mapping)

if parabolic:
logging.info("\nUsing parabolic fitting of the band edges")
Expand Down Expand Up @@ -179,11 +181,14 @@ def bandstats(
return {"hole_data": hole_data, "electron_data": elec_data}


def _log_band_gap_information(bs):
def _log_band_gap_information(bs, kpt_mapping=None):
"""Log data about the direct and indirect band gaps.

Args:
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
band structure with forced branches to the original band structure.

"""
bg_data = bs.get_band_gap()
if not bg_data["direct"]:
Expand All @@ -199,6 +204,7 @@ def _log_band_gap_information(bs):
direct_kpoint = bs.kpoints[direct_kindex].frac_coords
direct_kpoint = kpt_str.format(k=direct_kpoint)
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
k_indices = ", ".join(map(str, eq_kpoints))

# add 1 to band indices to be consistent with VASP band numbers.
Expand All @@ -215,7 +221,9 @@ def _log_band_gap_information(bs):

direct_kindex = direct_data[Spin.up]["kpoint_index"]
direct_kpoint = kpt_str.format(k=bs.kpoints[direct_kindex].frac_coords)
k_indices = ", ".join(map(str, bs.get_equivalent_kpoints(direct_kindex)))
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
k_indices = ", ".join(map(str, eq_kpoints))
b_indices = ", ".join(
[str(i + 1) for i in direct_data[Spin.up]["band_indices"]]
)
Expand All @@ -225,14 +233,16 @@ def _log_band_gap_information(bs):
logging.info(f" Band indices: {b_indices}")


def _log_band_edge_information(bs, edge_data):
def _log_band_edge_information(bs, edge_data, kpt_mapping=None):
"""Log data about the valence band maximum or conduction band minimum.

Args:
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
The band structure.
edge_data (dict): The :obj:`dict` from ``bs.get_vbm()`` or
``bs.get_cbm()``
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
band structure with forced branches to the original band structure.
"""
if bs.is_spin_polarized:
spins = edge_data["band_index"].keys()
Expand All @@ -247,7 +257,9 @@ def _log_band_edge_information(bs, edge_data):

kpoint = edge_data["kpoint"]
kpoint_str = kpt_str.format(k=kpoint.frac_coords)
k_indices = ", ".join(map(str, edge_data["kpoint_index"]))
k_indices = ", ".join(
map(str, _map_kpoints(edge_data["kpoint_index"], kpt_mapping))
)
k_degen = bs.get_kpoint_degeneracy(kpoint=kpoint.frac_coords)

if kpoint.label:
Expand Down Expand Up @@ -311,6 +323,13 @@ def _log_effective_mass_data(data, is_spin_polarized, mass_type="m_e"):
logging.info(f" {mass_type}: {eff_mass:.3f} | {band_str} | {kpoint_str}")


def _map_kpoints(kpt_idxs, kpt_mapping):
"""Map k-point indices to the original band structure."""
if not kpt_mapping:
return kpt_idxs
return sorted(set([kpt_mapping.get(k, k) for k in kpt_idxs]))


def _get_parser():
parser = argparse.ArgumentParser(
description="""
Expand Down
37 changes: 29 additions & 8 deletions sumo/electronic_structure/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def get_projections(bs, selection, normalise=None):
return spec_proj


def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=True):
def get_reconstructed_band_structure(
list_bs, efermi=None, force_kpath_branches=True, return_forced_branch_kpt_map=False
):
"""Combine a list of band structures into a single band structure.

This is typically very useful when you split non self consistent
Expand All @@ -210,12 +212,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
across all band structures is used.
force_kpath_branches (bool): Force a linemode band structure to contain
branches by adding repeated high-symmetry k-points in the path.
return_forced_branch_kpt_map (bool): If True, return a mapping of the
the new k-points to the original k-points.

Returns:
:obj:`pymatgen.electronic_structure.bandstructure.BandStructure` or \
:obj:`pymatgen.electronic_structure.bandstructureBandStructureSymmLine`:
A band structure object. The type depends on the type of the band
structures in ``list_bs``.
If return_forced_branch_kpt_map is True, then a tuple is returned
containing the band structure and the mapping from the new k-points
to the original k-points.
"""
if efermi is None:
efermi = sum(b.efermi for b in list_bs) / len(list_bs)
Expand Down Expand Up @@ -244,13 +251,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
structure=list_bs[0].structure,
projections=projections,
)
if force_kpath_branches:
return force_branches(bs)
else:
return bs
branch_bs, mapping = force_branches(bs, return_mapping=True)
if force_kpath_branches and return_forced_branch_kpt_map:
return branch_bs, mapping
elif force_kpath_branches:
return branch_bs
elif return_forced_branch_kpt_map:
return bs, mapping
return bs


def force_branches(bandstructure):
def force_branches(bandstructure, return_mapping=False):
"""Force a linemode band structure to contain branches.

Branches give a specific portion of the path from one high-symmetry point
Expand All @@ -262,9 +273,14 @@ def force_branches(bandstructure):

Args:
bandstructure: A band structure object.
return_mapping: If True, return a mapping of the new k-points (with branches)
to the original k-points.

Returns:
A band structure with brnaches.
A band structure with branches.
If return_forced_branch_kpt_map is True, then a tuple is returned
containing the band structure and the mapping from the new k-points
to the original k-points.
"""
kpoints = np.array([k.frac_coords for k in bandstructure.kpoints])
labels_dict = {k: v.frac_coords for k, v in bandstructure.labels_dict.items()}
Expand All @@ -275,6 +291,7 @@ def force_branches(bandstructure):
# already.
dup_ids = []
high_sym_kpoints = tuple(map(tuple, labels_dict.values()))
mapping = {}
for i, k in enumerate(kpoints):
dup_ids.append(i)
if (
Expand All @@ -287,6 +304,7 @@ def force_branches(bandstructure):
)
):
dup_ids.append(i)
mapping[len(dup_ids) - 1] = i

kpoints = kpoints[dup_ids]

Expand All @@ -297,7 +315,7 @@ def force_branches(bandstructure):
if len(bandstructure.projections) != 0:
projections[spin] = bandstructure.projections[spin][:, dup_ids]

return type(bandstructure)(
bs = type(bandstructure)(
kpoints,
eigenvals,
bandstructure.lattice_rec,
Expand All @@ -306,6 +324,9 @@ def force_branches(bandstructure):
structure=bandstructure.structure,
projections=projections,
)
if return_mapping:
return bs, mapping
return bs


def string_to_spin(spin_string):
Expand Down
2 changes: 1 addition & 1 deletion sumo/electronic_structure/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def load_dos(
else:
vr = vasprun

band = vr.get_band_structure()
band = vr.get_band_structure(efermi="smart")
dos = vr.complete_dos

dos, band = _scissor_dos(dos, band, scissor)
Expand Down
Loading