Skip to content

Commit

Permalink
fix bug of duplicated atom names in vasp output (#250)
Browse files Browse the repository at this point in the history
* skip UTs when we do not have parmed, ase, pymatgen

* fix bug of ase fmt

* fix bug of duplicated atom names in vasp output

* handle duplicated atom names in poscar

* add missing test poscar file

* fixed bug in UT

* uniq_atom_name -> uniq_atom_names

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
amcadmus and Han Wang authored Feb 26, 2022
1 parent b62f67a commit 2c90f17
Show file tree
Hide file tree
Showing 7 changed files with 3,252 additions and 48 deletions.
8 changes: 6 additions & 2 deletions dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dpdata.vasp.outcar
import numpy as np
from dpdata.format import Format

from dpdata.utils import sort_atom_names, uniq_atom_names

@Format.register("poscar")
@Format.register("contcar")
Expand All @@ -14,7 +14,9 @@ class VASPPoscarFormat(Format):
def from_system(self, file_name, **kwargs):
with open(file_name) as fp:
lines = [line.rstrip('\n') for line in fp]
return dpdata.vasp.poscar.to_system_data(lines)
data = dpdata.vasp.poscar.to_system_data(lines)
data = uniq_atom_names(data)
return data

def to_system(self, data, file_name, frame_idx=0, **kwargs):
"""
Expand Down Expand Up @@ -71,6 +73,7 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
for ii in range(data['cells'].shape[0]):
vol = np.linalg.det(np.reshape(data['cells'][ii], [3, 3]))
data['virials'][ii] *= v_pref * vol
data = uniq_atom_names(data)
return data


Expand Down Expand Up @@ -102,4 +105,5 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
for ii in range(data['cells'].shape[0]):
vol = np.linalg.det(np.reshape(data['cells'][ii], [3, 3]))
data['virials'][ii] *= v_pref * vol
data = uniq_atom_names(data)
return data
55 changes: 9 additions & 46 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from dpdata.plugin import Plugin
from dpdata.format import Format

from dpdata.utils import (
elements_index_map,
remove_pbc,
sort_atom_names,
add_atom_names,
)

def load_format(fmt):
fmt = fmt.lower()
formats = Format.get_formats()
Expand Down Expand Up @@ -341,27 +348,7 @@ def sort_atom_names(self, type_map=None):
type_map : list
type_map
"""
if type_map is not None:
# assign atom_names index to the specify order
# atom_names must be a subset of type_map
assert (set(self.data['atom_names']).issubset(set(type_map)))
# for the condition that type_map is a proper superset of atom_names
# new_atoms = set(type_map) - set(self.data["atom_names"])
new_atoms = [e for e in type_map if e not in self.data["atom_names"]]
if new_atoms:
self.add_atom_names(new_atoms)
# index that will sort an array by type_map
# a[as[a]] == b[as[b]] as == argsort
# as[as[b]] == as^{-1}[b]
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(self.data['atom_names'])[np.argsort(np.argsort(type_map))]
else:
# index that will sort an array by alphabetical order
idx = np.argsort(self.data['atom_names'])
# sort atom_names, atom_numbs, atom_types by idx
self.data['atom_names'] = list(np.array(self.data['atom_names'])[idx])
self.data['atom_numbs'] = list(np.array(self.data['atom_numbs'])[idx])
self.data['atom_types'] = np.argsort(idx)[self.data['atom_types']]
self.data = sort_atom_names(self.data, type_map=type_map)

def check_type_map(self, type_map):
"""
Expand Down Expand Up @@ -489,8 +476,7 @@ def add_atom_names(self, atom_names):
"""
Add atom_names that do not exist.
"""
self.data['atom_names'].extend(atom_names)
self.data['atom_numbs'].extend([0 for _ in atom_names])
self.data = add_atom_names(self.data, atom_names)

def replicate(self, ncopy):
"""
Expand Down Expand Up @@ -1298,26 +1284,3 @@ def check_LabeledSystem(data):
assert( len(data['cells']) == len(data['coords']) == len(data['energies']) )


def elements_index_map(elements,standard=False,inverse=False):
if standard:
elements.sort(key=lambda x: Element(x).Z)
if inverse:
return dict(zip(range(len(elements)),elements))
else:
return dict(zip(elements,range(len(elements))))
# %%

def remove_pbc(system, protect_layer = 9):
nframes = len(system["coords"])
natoms = len(system['coords'][0])
for ff in range(nframes):
tmpcoord = system['coords'][ff]
cog = np.average(tmpcoord, axis = 0)
dist = tmpcoord - np.tile(cog, [natoms, 1])
max_dist = np.max(np.linalg.norm(dist, axis = 1))
h_cell_size = max_dist + protect_layer
cell_size = h_cell_size * 2
shift = np.array([1,1,1]) * h_cell_size - cog
system['coords'][ff] = system['coords'][ff] + np.tile(shift, [natoms, 1])
system['cells'][ff] = cell_size * np.eye(3)
return system
91 changes: 91 additions & 0 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
from dpdata.periodic_table import Element

def elements_index_map(elements,standard=False,inverse=False):
if standard:
elements.sort(key=lambda x: Element(x).Z)
if inverse:
return dict(zip(range(len(elements)),elements))
else:
return dict(zip(elements,range(len(elements))))
# %%

def remove_pbc(system, protect_layer = 9):
nframes = len(system["coords"])
natoms = len(system['coords'][0])
for ff in range(nframes):
tmpcoord = system['coords'][ff]
cog = np.average(tmpcoord, axis = 0)
dist = tmpcoord - np.tile(cog, [natoms, 1])
max_dist = np.max(np.linalg.norm(dist, axis = 1))
h_cell_size = max_dist + protect_layer
cell_size = h_cell_size * 2
shift = np.array([1,1,1]) * h_cell_size - cog
system['coords'][ff] = system['coords'][ff] + np.tile(shift, [natoms, 1])
system['cells'][ff] = cell_size * np.eye(3)
return system

def add_atom_names(data, atom_names):
"""
Add atom_names that do not exist.
"""
data['atom_names'].extend(atom_names)
data['atom_numbs'].extend([0 for _ in atom_names])
return data

def sort_atom_names(data, type_map=None):
"""
Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
to atom_names. If type_map is not given, atom_names will be sorted by
alphabetical order. If type_map is given, atom_names will be type_map.
Parameters
----------
type_map : list
type_map
"""
if type_map is not None:
# assign atom_names index to the specify order
# atom_names must be a subset of type_map
assert (set(data['atom_names']).issubset(set(type_map)))
# for the condition that type_map is a proper superset of atom_names
# new_atoms = set(type_map) - set(data["atom_names"])
new_atoms = [e for e in type_map if e not in data["atom_names"]]
if new_atoms:
data = add_atom_names(data, new_atoms)
# index that will sort an array by type_map
# a[as[a]] == b[as[b]] as == argsort
# as[as[b]] == as^{-1}[b]
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(data['atom_names'])[np.argsort(np.argsort(type_map))]
else:
# index that will sort an array by alphabetical order
idx = np.argsort(data['atom_names'])
# sort atom_names, atom_numbs, atom_types by idx
data['atom_names'] = list(np.array(data['atom_names'])[idx])
data['atom_numbs'] = list(np.array(data['atom_numbs'])[idx])
data['atom_types'] = np.argsort(idx)[data['atom_types']]
return data

def uniq_atom_names(data):
"""
Make the atom names uniq. For example
['O', 'H', 'O', 'H', 'O'] -> ['O', 'H']
Parameters
----------
data : dict
data dict of `System`, `LabeledSystem`
"""
unames = []
uidxmap = []
for idx,ii in enumerate(data['atom_names']):
if ii not in unames:
unames.append(ii)
uidxmap.append(unames.index(ii))
data['atom_names'] = unames
tmp_type = list(data['atom_types']).copy()
data['atom_types'] = np.array([uidxmap[jj] for jj in tmp_type], dtype=int)
data['atom_numbs'] = [sum( ii == data['atom_types'] ) for ii in range(len(data['atom_names'])) ]
return data
Loading

0 comments on commit 2c90f17

Please sign in to comment.