Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zoe authored and zoe committed Feb 13, 2025
1 parent aa642d5 commit daaf950
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 40 deletions.
Empty file.
82 changes: 42 additions & 40 deletions mlptrain/descriptors/_base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from abc import ABC, abstractmethod
import numpy as np
import logging
import mlptrain as mlp
from typing import Union
import mlptrain.log as logger
import mlptrain as mlptrain

# Setup logging
logging.basicConfig(
level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class DescriptorBase(ABC):
class Descriptor(ABC):
"""Abstract base class for molecular feature descriptors."""

def __init__(self, name: str):
Expand All @@ -24,7 +19,12 @@ def __init__(self, name: str):
logger.info(f'Initialized {self.name} descriptor.')

@abstractmethod
def compute(self, configurations: mlp.ConfigurationSet) -> np.ndarray:
def compute_representation(
self,
configurations: Union[
mlptrain.Configuration, mlptrain.ConfigurationSet
],
) -> np.ndarray:
"""
Compute descriptor representation for a given molecular configuration.
Expand All @@ -33,44 +33,46 @@ def compute(self, configurations: mlp.ConfigurationSet) -> np.ndarray:
Returns:
np.ndarray: The computed descriptor representation as a vector/matrix.
"""
pass


@abstractmethod
def kernel_vector(
self, configuration, configurations, zeta: int = 4
) -> np.ndarray:
"""
Calculate the kernel matrix between a set of configurations where the
kernel is:
.. math::
@abstractmethod
def kernel_vector(
self, configuration, configurations, zeta: int = 4
) -> np.ndarray:
"""Calculate the kernel matrix between a set of configurations where the kernel is: .. math::
K(p_a, p_b) = (p_a . p_b / (p_a.p_a x p_b.p.b)^1/2 )^ζ
---------------------------------------------------------------------------
Arguments:
configuration:
---------------------------------------------------------------------------
Arguments:
configuration:
configurations:
configurations:
zeta: Power to raise the kernel matrix to
zeta: Power to raise the kernel matrix to
Returns:
(np.ndarray): Vector, shape = len(configurations)"""

Returns:
(np.ndarray): Vector, shape = len(configurations)
"""
pass
def normalize(self, vector: np.ndarray) -> np.ndarray:
"""
Normalize a feature vector to unit norm.
Arguments:
vector (np.ndarray): Input vector.
def normalize(self, vector: np.ndarray) -> np.ndarray:
"""
Normalize a feature vector to unit norm.
Returns:
np.ndarray: Normalized vector.
"""
norm = np.linalg.norm(vector)
return vector if norm == 0 else vector / norm

Arguments:
vector (np.ndarray): Input vector.
def average(self, average_method: str = 'no_average'):
"""
Compute the average of the descriptor representation to accommodate systems of different sizes.
Returns:
np.ndarray: Normalized vector.
"""
norm = np.linalg.norm(vector)
return vector if norm == 0 else vector / norm
Arguments:
average_method (str): Specifies the averaging method:
- "inner" (default), "outer", or "no_average" for soap_descriptor
- "average" or "no_average" (default) for ace_descriptor
- No parameter needed for mace_descriptor
"""

0 comments on commit daaf950

Please sign in to comment.