Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of duplicated atom names in vasp output #250

Merged
merged 8 commits into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dpdata.vasp.outcar
import numpy as np
from dpdata.format import Format

from dpdata.utils import sort_atom_names, uniq_atom_name

@Format.register("poscar")
@Format.register("contcar")
Expand All @@ -14,7 +14,9 @@ class VASPPoscarFormat(Format):
def from_system(self, file_name, **kwargs):
with open(file_name) as fp:
lines = [line.rstrip('\n') for line in fp]
return dpdata.vasp.poscar.to_system_data(lines)
data = dpdata.vasp.poscar.to_system_data(lines)
data = uniq_atom_name(data)
return data

def to_system(self, data, file_name, frame_idx=0, **kwargs):
"""
Expand Down Expand Up @@ -71,6 +73,7 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
for ii in range(data['cells'].shape[0]):
vol = np.linalg.det(np.reshape(data['cells'][ii], [3, 3]))
data['virials'][ii] *= v_pref * vol
data = uniq_atom_name(data)
return data


Expand Down Expand Up @@ -102,4 +105,5 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
for ii in range(data['cells'].shape[0]):
vol = np.linalg.det(np.reshape(data['cells'][ii], [3, 3]))
data['virials'][ii] *= v_pref * vol
data = uniq_atom_name(data)
return data
55 changes: 9 additions & 46 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from dpdata.plugin import Plugin
from dpdata.format import Format

from dpdata.utils import (
elements_index_map,
remove_pbc,
sort_atom_names,
add_atom_names,
)

def load_format(fmt):
fmt = fmt.lower()
formats = Format.get_formats()
Expand Down Expand Up @@ -341,27 +348,7 @@ def sort_atom_names(self, type_map=None):
type_map : list
type_map
"""
if type_map is not None:
# assign atom_names index to the specify order
# atom_names must be a subset of type_map
assert (set(self.data['atom_names']).issubset(set(type_map)))
# for the condition that type_map is a proper superset of atom_names
# new_atoms = set(type_map) - set(self.data["atom_names"])
new_atoms = [e for e in type_map if e not in self.data["atom_names"]]
if new_atoms:
self.add_atom_names(new_atoms)
# index that will sort an array by type_map
# a[as[a]] == b[as[b]] as == argsort
# as[as[b]] == as^{-1}[b]
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(self.data['atom_names'])[np.argsort(np.argsort(type_map))]
else:
# index that will sort an array by alphabetical order
idx = np.argsort(self.data['atom_names'])
# sort atom_names, atom_numbs, atom_types by idx
self.data['atom_names'] = list(np.array(self.data['atom_names'])[idx])
self.data['atom_numbs'] = list(np.array(self.data['atom_numbs'])[idx])
self.data['atom_types'] = np.argsort(idx)[self.data['atom_types']]
self.data = sort_atom_names(self.data, type_map=type_map)

def check_type_map(self, type_map):
"""
Expand Down Expand Up @@ -489,8 +476,7 @@ def add_atom_names(self, atom_names):
"""
Add atom_names that do not exist.
"""
self.data['atom_names'].extend(atom_names)
self.data['atom_numbs'].extend([0 for _ in atom_names])
self.data = add_atom_names(self.data, atom_names)

def replicate(self, ncopy):
"""
Expand Down Expand Up @@ -1298,26 +1284,3 @@ def check_LabeledSystem(data):
assert( len(data['cells']) == len(data['coords']) == len(data['energies']) )


def elements_index_map(elements,standard=False,inverse=False):
if standard:
elements.sort(key=lambda x: Element(x).Z)
if inverse:
return dict(zip(range(len(elements)),elements))
else:
return dict(zip(elements,range(len(elements))))
# %%

def remove_pbc(system, protect_layer = 9):
nframes = len(system["coords"])
natoms = len(system['coords'][0])
for ff in range(nframes):
tmpcoord = system['coords'][ff]
cog = np.average(tmpcoord, axis = 0)
dist = tmpcoord - np.tile(cog, [natoms, 1])
max_dist = np.max(np.linalg.norm(dist, axis = 1))
h_cell_size = max_dist + protect_layer
cell_size = h_cell_size * 2
shift = np.array([1,1,1]) * h_cell_size - cog
system['coords'][ff] = system['coords'][ff] + np.tile(shift, [natoms, 1])
system['cells'][ff] = cell_size * np.eye(3)
return system
91 changes: 91 additions & 0 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
from dpdata.periodic_table import Element

def elements_index_map(elements,standard=False,inverse=False):
if standard:
elements.sort(key=lambda x: Element(x).Z)
if inverse:
return dict(zip(range(len(elements)),elements))
else:
return dict(zip(elements,range(len(elements))))
# %%

def remove_pbc(system, protect_layer = 9):
nframes = len(system["coords"])
natoms = len(system['coords'][0])
for ff in range(nframes):
tmpcoord = system['coords'][ff]
cog = np.average(tmpcoord, axis = 0)
dist = tmpcoord - np.tile(cog, [natoms, 1])
max_dist = np.max(np.linalg.norm(dist, axis = 1))
h_cell_size = max_dist + protect_layer
cell_size = h_cell_size * 2
shift = np.array([1,1,1]) * h_cell_size - cog
system['coords'][ff] = system['coords'][ff] + np.tile(shift, [natoms, 1])
system['cells'][ff] = cell_size * np.eye(3)
return system

def add_atom_names(data, atom_names):
"""
Add atom_names that do not exist.
"""
data['atom_names'].extend(atom_names)
data['atom_numbs'].extend([0 for _ in atom_names])
return data

def sort_atom_names(data, type_map=None):
"""
Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
to atom_names. If type_map is not given, atom_names will be sorted by
alphabetical order. If type_map is given, atom_names will be type_map.

Parameters
----------
type_map : list
type_map
"""
if type_map is not None:
# assign atom_names index to the specify order
# atom_names must be a subset of type_map
assert (set(data['atom_names']).issubset(set(type_map)))
# for the condition that type_map is a proper superset of atom_names
# new_atoms = set(type_map) - set(data["atom_names"])
new_atoms = [e for e in type_map if e not in data["atom_names"]]
if new_atoms:
data = add_atom_names(data, new_atoms)
# index that will sort an array by type_map
# a[as[a]] == b[as[b]] as == argsort
# as[as[b]] == as^{-1}[b]
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(data['atom_names'])[np.argsort(np.argsort(type_map))]
else:
# index that will sort an array by alphabetical order
idx = np.argsort(data['atom_names'])
# sort atom_names, atom_numbs, atom_types by idx
data['atom_names'] = list(np.array(data['atom_names'])[idx])
data['atom_numbs'] = list(np.array(data['atom_numbs'])[idx])
data['atom_types'] = np.argsort(idx)[data['atom_types']]
return data

def uniq_atom_name(data):
"""
Make the atom names uniq. For example
['O', 'H', 'O', 'H', 'O'] -> ['O', 'H']

Parameters
----------
data : dict
data dict of `System`, `LabeledSystem`

"""
unames = []
uidxmap = []
for idx,ii in enumerate(data['atom_names']):
if ii not in unames:
unames.append(ii)
uidxmap.append(unames.index(ii))
data['atom_names'] = unames
tmp_type = list(data['atom_types']).copy()
data['atom_types'] = np.array([uidxmap[jj] for jj in tmp_type], dtype=int)
data['atom_numbs'] = [sum( ii == data['atom_types'] ) for ii in range(len(data['atom_names'])) ]
return data
Loading