diff --git a/docs/source/jmath.rst b/docs/source/jmath.rst index 4868b87..47f70fa 100644 --- a/docs/source/jmath.rst +++ b/docs/source/jmath.rst @@ -20,6 +20,14 @@ Subpackages Submodules ---------- +jmath.autodiff module +--------------------- + +.. automodule:: jmath.autodiff + :members: + :undoc-members: + :show-inheritance: + jmath.complex module -------------------- diff --git a/docs/source/jmath.universal.rst b/docs/source/jmath.universal.rst index 486ca4b..bbba1e2 100644 --- a/docs/source/jmath.universal.rst +++ b/docs/source/jmath.universal.rst @@ -4,6 +4,14 @@ jmath.universal package Submodules ---------- +jmath.universal.hyperbolic module +--------------------------------- + +.. automodule:: jmath.universal.hyperbolic + :members: + :undoc-members: + :show-inheritance: + jmath.universal.logarithms module --------------------------------- diff --git a/jmath/autodiff.py b/jmath/autodiff.py new file mode 100644 index 0000000..fef80a6 --- /dev/null +++ b/jmath/autodiff.py @@ -0,0 +1,321 @@ +''' + Automatic Differentiation +''' + +# - Imports + +import operator as op +import inspect +import string +from functools import wraps +from types import FunctionType +from .uncertainties import Uncertainty +from typing import Any, Union, Callable, Tuple + +# - Typing + +Supported = Union[int, float, Uncertainty, 'Function', 'Variable'] +Numeric = Union[int, float, Uncertainty] + +# - Classes + +class Function: + ''' + Automatic Differentiation Function Object + + Parameters + ----------- + + func + Represented function. + derivatives + Tuple of partial derivatives of the function with respect to function variables. + Note that derivatives may be numeric values or functions with built-in operations and Functions. + ''' + def __init__(self, func: Callable, derivatives: Tuple[Callable]): + + self.inputs = None + self.func = func + self.derivatives = derivatives + + # Check if diff is not a tuple + if not isinstance(self.derivatives, tuple): + # If not then we shall make it one + self.derivatives = (self.derivatives,) + + def __str__(self): + + # Get parameters + params = tuple() + if self.inputs is not None: + params = tuple(str(input) for input in self.inputs) + + # Cases for operations + if self.func == op.mul: + return f"{params[0]}*{params[1]}" + elif self.func == op.truediv: + return f"{params[0]}/{params[1]}" + elif self.func == op.sub: + return f"{params[0]}-{params[1]}" + elif self.func == op.add: + return f"{params[0]}+{params[1]}" + + # Standard function + return f"{self.func.__name__}{str(params)[:-2]})" + + def __call__(self, **kwargs): + + if not isinstance(self.func, Callable): + return self.func + + # Input collection + inputs = [] + # Start computing inputs + for input in self.inputs: + if isinstance(input, Variable): + # Variable case + # Check if the variable has been assigned a value + if input.id in kwargs.keys(): + inputs.append(kwargs[input.id]) + else: + # There is no value for the variable??? + # Throw a value error + raise KeyError(f"Variable '{input.id}' was not assigned a value on function call!") + elif isinstance(input, (int, float, Uncertainty)): + # Const case + input.append(input) + elif isinstance(input, Function): + # Function case + inputs.append(input(**kwargs)) + + return self.func(*tuple(inputs)) + + def __add__(self, other: Supported) -> 'Function': + + if other == 0: + # Special case + return self + elif isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: x + other, 1) + f.register(self) + return f + else: + # Variable case + f = Function(op.add, (1, 1)) + f.register(self, other) + return f + + def __radd__(self, other: Supported) -> 'Function': + + return self + other + + def __sub__(self, other: Supported) -> 'Function': + + if other == 0: + # Special case + return self + elif isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: x - other, 1) + f.register(self) + return f + else: + # Variable case + f = Function(op.sub, (1, -1)) + f.register(self, other) + return f + + def __rsub__(self, other: Supported) -> 'Function': + + if isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: other - x, -1) + f.register(self) + return f + else: + # Variable case + f = Function(op.sub, (-1, 1)) + f.register(self, other) + return f + + def __mul__(self, other: Supported) -> 'Function': + + if other == 1: + # Special case + return self + elif other == 0: + return 0 + elif isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: other * x, other) + f.register(self) + return f + else: + # Variable case + f = Function(op.mul, (lambda x, y: y, lambda x, y: x)) + f.register(self, other) + return f + + def __neg__(self) -> 'Function': + + f = Function(op.neg, -1) + f.register(self) + return f + + def __pow__(self, power: Union[int, float, Uncertainty]) -> 'Function': + + if power == 1: + f = Function(lambda x: x, 1) + f.register(self) + return f + elif power == 0: + return 1 + else: + # Non-trivial case + f = Function(lambda x: x**power, lambda x: power*x**(power - 1)) + f.register(self) + return f + + def __rmul__(self, other: Supported) -> 'Function': + + return self * other + + def __truediv__(self, other: Supported) -> 'Function': + + if other == 1: + # Special case + return other + elif other == 0: + # Error case + raise ZeroDivisionError + elif isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: x/other, 1/other) + f.register(self) + return f + else: + # Variable case + f = Function(op.truediv, (lambda x, y: 1/y, lambda x, y: -x/(y**2))) + f.register(self, other) + return f + + def __rtruediv__(self, other: Supported) -> 'Function': + + if isinstance(other, (int, float, Uncertainty)): + # Numeric case + f = Function(lambda x: other/x, lambda x: -other/(x**2)) + f.register(self) + return f + else: + # Variable case + f = Function(op.rtruediv, (lambda x, y: 1/y, lambda x, y: -x/(y**2))) + f.register(self, other) + return f + + def register(self, *inputs: 'Function'): + ''' + Registers inputs to the function. + + Parameters + ---------- + + inputs + Args, the functions to register as inputs. + ''' + self.inputs = inputs + + def differentiate(self, wrt: Union['Variable', str]) -> 'Function': + ''' + Differentiates the function with respect to a variable. + + Parameters + ---------- + + wrt + The variable to differentiate with respect to. + ''' + # The differentiated function + func = 0 + # Move across inputs + for i, input in enumerate(self.inputs): + # Get respective derivative + partial = self.derivatives[i] + if isinstance(partial, FunctionType): + partial = partial(*self.inputs) + func += partial * input.differentiate(wrt) + + return func + + @wraps(differentiate) + def d(self, wrt: Union['Variable', str]) -> 'Function': + return self.differentiate(wrt) + +class Variable(Function): + ''' + Variables for function differentiation. + + Parameters + ---------- + + id + Unique identifier string. + ''' + def __init__(self, id: str = None): + + super().__init__(lambda x: x, None) + self.id = id + self.inputs = None + self.derivatives = None + + def __str__(self): + + return self.id + + def __call__(self, input: Any) -> Any: + + return input + + def differentiate(self, wrt: 'Variable') -> int: + + if ((wrt == self) or (wrt == self.id and self.id is not None)): + return 1 + else: + return 0 + +# - Functions + +def analyse(f: Callable) -> Function: + ''' + Automatically analyses the given function and produces a Function object. + + Parameters + ---------- + + f + The function to analyse + + Returns + ------- + + Function + A differentiable function object representing the given function. + ''' + # Get the list of parameters from the function + names = inspect.getargspec(f)[0] + # Convert these into variables for the function + vars = tuple(Variable(name) for name in names) + # Pass these to the function + f = f(*vars) + # Register the input variables + f.register(*vars) + # And return + return f + +# - Main + +# Define all english letters as 'Variable's +# This code is very silly +# I'm not sure it should stay here +for letter in string.ascii_letters: + globals()[letter] = Variable(letter) \ No newline at end of file diff --git a/jmath/universal/__init__.py b/jmath/universal/__init__.py index 79e781b..e6b084a 100644 --- a/jmath/universal/__init__.py +++ b/jmath/universal/__init__.py @@ -9,6 +9,6 @@ # - Defaults from .trigonometry import sin, asin, cos, acos, tan, atan -from .logarithms import log, log10, log2 +from .logarithms import log, log10, log2, ln from .natural import exp from .hyperbolic import sinh, asinh, cosh, acosh, tanh, atanh \ No newline at end of file diff --git a/jmath/universal/hyperbolic.py b/jmath/universal/hyperbolic.py index df9ab5f..afcfe2f 100644 --- a/jmath/universal/hyperbolic.py +++ b/jmath/universal/hyperbolic.py @@ -19,7 +19,7 @@ def cosh(value: Supported) -> Supported: value The number to compute the hyberbolic cosine of. """ - return generic_function(math.cosh, value) + return generic_function(math.cosh, value, derivative = sinh) def acosh(value: Supported) -> Supported: """ @@ -43,7 +43,7 @@ def sinh(value: Supported) -> Supported: value The number to compute the hyberbolic sine of. """ - return generic_function(math.sinh, value) + return generic_function(math.sinh, value, derivative = cosh) def asinh(value: Supported) -> Supported: """ diff --git a/jmath/universal/logarithms.py b/jmath/universal/logarithms.py index 4b4c6a7..7db342b 100644 --- a/jmath/universal/logarithms.py +++ b/jmath/universal/logarithms.py @@ -21,6 +21,18 @@ def log(value: Supported, base: float = math.e) -> Supported: """ return generic_function(math.log, value, base) +def ln(value: Supported) -> Supported: + """ + Calculates the natural logarithm of a number. + + Parameters + ---------- + + value + The value to compute the natural logarithm of. + """ + return generic_function(math.log, value, derivative = lambda x: 1/x) + def log10(value: Supported) -> Supported: """ Calculates the log base 10 of a number. diff --git a/jmath/universal/natural.py b/jmath/universal/natural.py index 30d61ea..098a711 100644 --- a/jmath/universal/natural.py +++ b/jmath/universal/natural.py @@ -25,4 +25,4 @@ def exp(value: Supported) -> Supported: value The value to calculate the exponential of. """ - return generic_function(math.exp, value) \ No newline at end of file + return generic_function(math.exp, value, derivative = exp) \ No newline at end of file diff --git a/jmath/universal/tools.py b/jmath/universal/tools.py index e82c7f9..38826fe 100644 --- a/jmath/universal/tools.py +++ b/jmath/universal/tools.py @@ -7,14 +7,16 @@ from typing import Union, Callable from ..uncertainties import Uncertainty from ..units import Unit +from ..autodiff import Variable, Function # - Typing -Supported = Union[float, int, Uncertainty, Unit] +Supported = Union[float, int, Uncertainty, Unit, Variable, Function] +Numeric = Union[float, int] # - Functions -def generic_function(func: Callable[[float], float], input: Supported, *args) -> Supported: +def generic_function(func: Callable[[Numeric], Numeric], input: Supported, *args, derivative: Callable = None) -> Union[Supported, Function]: """ Applies a function with generic cases for special objects. @@ -25,7 +27,10 @@ def generic_function(func: Callable[[float], float], input: Supported, *args) -> The function to apply. args Arguments to send to the function + derivatives + The partial derivatives of the function with respect to its variables """ + if isinstance(input, Unit): # Units # Return function applied to unit value @@ -33,6 +38,15 @@ def generic_function(func: Callable[[float], float], input: Supported, *args) -> elif isinstance(input, Uncertainty): # Uncertainties return input.apply(func, *args) + elif isinstance(input, (Function, Variable)): + # Auto-diff Variables/Functions + # Check that there is a derivative + if derivative is None: + return NotImplemented + # Build auto-diff function + f = Function(func, derivative) + f.register(input) + return f else: # Anything else return func(input, *args) \ No newline at end of file diff --git a/jmath/universal/trigonometry.py b/jmath/universal/trigonometry.py index 8dec133..9f086b8 100644 --- a/jmath/universal/trigonometry.py +++ b/jmath/universal/trigonometry.py @@ -21,7 +21,7 @@ def sin(value: other.radian) -> Supported: value Value (in radians) to compute the sine of. ''' - return generic_function(math.sin, value) + return generic_function(math.sin, value, derivative = cos) @annotate def asin(value: Supported) -> other.radian: @@ -53,7 +53,7 @@ def cos(value: other.radian) -> Supported: value Value (in radians) to compute the cosine of ''' - return generic_function(math.cos, value) + return generic_function(math.cos, value, derivative = lambda x: -sin(x)) @annotate def acos(value: Supported) -> other.radian: diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py new file mode 100644 index 0000000..f41803a --- /dev/null +++ b/tests/test_autodiff.py @@ -0,0 +1,57 @@ +''' + Tests the auto-differentiation module +''' + +# - Imports + +from ..jmath.autodiff import x, y, z, a, b, c +from ..jmath import sin, ln, exp +from .tools import random_integer, repeat + +# - Tests + +@repeat +def test_linear_derivative(): + '''Tests the derivatives of linear functions.''' + m = random_integer(non_zero = True) + c = random_integer() + y = m*x + c + assert y.d(x) == m + +@repeat +def test_power_rule(): + '''Tests the power rule.''' + n = random_integer(2, 100) + y = x**n + assert y.d(x)(x = 1) == n + +@repeat +def test_sin_derivative(): + '''Tests derivatives of the sine function.''' + a = random_integer(non_zero = True) + f = random_integer(non_zero = True) + c = random_integer() + y = a*sin(f*x) + c + y_x = y.d(x) + assert y_x(x = 0) == a*f + +@repeat +def test_natural_log_derivative(): + '''Tests the derivative of the natural log.''' + a = random_integer(non_zero = True) + y = a*ln(x) + assert y.d(x)(x = 1) == a + +@repeat +def test_exponential_chain_rule(): + '''Tests the chain rule with the natural log.''' + a = random_integer(non_zero = True) + b = random_integer(non_zero = True) + y = a*exp(b * x) + assert y.d(x)(x = 0) == a*b + +def test_trivial_partial(): + '''Tests a trivial partial derivative example.''' + f = x * y + assert f.d(x) == y + assert f.d(y) == x \ No newline at end of file diff --git a/tests/tools.py b/tests/tools.py index a540d5b..598f17c 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -21,7 +21,7 @@ def inner(): return inner -def random_integer(min: int = -100, max: int = 100) -> int: +def random_integer(min: int = -100, max: int = 100, non_zero = False) -> int: """ Generates a random integer. Wrapper of random.randint. @@ -33,8 +33,12 @@ def random_integer(min: int = -100, max: int = 100) -> int: max The maximum value int to produce. """ - - return randint(min, max) + r = 0 + while r == 0: + r = randint(min, max) + if not non_zero: + break + return r def random_integers(length: int): """