From 26ecef36ae0941b12319e312baffb75c2035dadf Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Sun, 21 Jun 2020 18:14:26 +0100 Subject: [PATCH] made (select|keep)_blades take indices, added (select|keep)_blades_with_name with old behavior --- README.md | 4 +-- notebooks/qed.ipynb | 6 ++--- notebooks/tfga.ipynb | 10 ++++---- tests/test_dual_ga.py | 18 +++++++------- tfga/mv.py | 8 +++--- tfga/tfga.py | 57 ++++++++++++++++++++++++++++++++++++------- 6 files changed, 71 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 81d169c..80ff350 100644 --- a/README.md +++ b/README.md @@ -56,10 +56,10 @@ ga.print(ga.reversion(quaternion)) ga.print(quaternion[0]) # tf.Tensor of shape [1]: -5 (ie. reversed sign of e_01 component) -ga.print(ga.select_blades(quaternion, "10")) +ga.print(ga.select_blades_with_name(quaternion, "10")) # tf.Tensor of shape [8] with only e_01 component equal to 5 -ga.print(ga.keep_blades(quaternion, "10")) +ga.print(ga.keep_blades_with_name(quaternion, "10")) ``` Alternatively we can convert the geometric algebra [`tf.Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor) instance to [`MultiVector`](https://tfga.warlock.ai/tfga.html#tfga.mv.MultiVector) diff --git a/notebooks/qed.ipynb b/notebooks/qed.ipynb index 1f4b8f0..5d070d4 100644 --- a/notebooks/qed.ipynb +++ b/notebooks/qed.ipynb @@ -161,8 +161,8 @@ "sta.print(\"b3 * b2^-1:\", sta.geom_prod(b3, sta.inverse(b2)), \"should be\", b1)\n", "\n", "sta.print(\"~b2 (Grade reversal):\", sta.reversion(b2))\n", - "sta.print(\"Scalar part of b2:\", sta.keep_blades(b2, \"\"))\n", - "sta.print(\"e_01 part of b2:\", sta.keep_blades(b2, \"01\"))" + "sta.print(\"Scalar part of b2:\", sta.keep_blades_with_name(b2, \"\"))\n", + "sta.print(\"e_01 part of b2:\", sta.keep_blades_with_name(b2, \"01\"))" ] }, { @@ -193,7 +193,7 @@ "\n", "sta.print(\"A(t=0, x=5, y=3, z=9):\", a[0, 5, 3, 9])\n", "sta.print(\"A(t=0, z=[3,4,5]):\", a[0, :, :, 3:6])\n", - "sta.print(\"e_0 part of A(X):\", sta.select_blades(a, \"0\").shape)\n", + "sta.print(\"e_0 part of A(X):\", sta.select_blades_with_name(a, \"0\").shape)\n", "\n", "sta.print(\"A(0, 0, 0, 0) * ~A(0, 0, 0, 0):\", sta.geom_prod(a, sta.reversion(a))[0, 0, 0, 0])" ] diff --git a/notebooks/tfga.ipynb b/notebooks/tfga.ipynb index 23fa38a..b842197 100644 --- a/notebooks/tfga.ipynb +++ b/notebooks/tfga.ipynb @@ -295,11 +295,11 @@ "source": [ "v = sta.from_tensor_with_kind(tf.ones(16, dtype=tf.float32), \"mv\")\n", "sta.print(v)\n", - "sta.print(sta.keep_blades(v, [\"10\", \"1\"]))\n", - "sta.print(sta.keep_blades(v, \"2\"))\n", - "sta.print(\"R:\", sta.select_blades(v, [\"0\", \"01\", \"10\"]))\n", - "sta.print(\"R:\", sta.select_blades(v, [\"123\", \"01\", \"0\", \"0\"]))\n", - "sta.print(\"R:\", sta.select_blades(v, \"312\"))\n", + "sta.print(sta.keep_blades_with_name(v, [\"10\", \"1\"]))\n", + "sta.print(sta.keep_blades_with_name(v, \"2\"))\n", + "sta.print(\"R:\", sta.select_blades_with_name(v, [\"0\", \"01\", \"10\"]))\n", + "sta.print(\"R:\", sta.select_blades_with_name(v, [\"123\", \"01\", \"0\", \"0\"]))\n", + "sta.print(\"R:\", sta.select_blades_with_name(v, \"312\"))\n", "sta.print(v[..., 0])" ] } diff --git a/tests/test_dual_ga.py b/tests/test_dual_ga.py index 2db3232..c6d8c21 100644 --- a/tests/test_dual_ga.py +++ b/tests/test_dual_ga.py @@ -97,15 +97,15 @@ def test_auto_diff_square(self): # f(1) = 1^2 = 1, f'(1) = 2 x_squared = ga.geom_prod(x, x) - self.assertTensorsEqual(ga.select_blades(x_squared, ""), 1.0) - self.assertTensorsEqual(ga.select_blades(x_squared, "0"), 2.0) + self.assertTensorsEqual(ga.select_blades_with_name(x_squared, ""), 1.0) + self.assertTensorsEqual(ga.select_blades_with_name(x_squared, "0"), 2.0) y = five + eps # f(5) = 5^2 = 25, f'(5) = 10 y_squared = ga.geom_prod(y, y) - self.assertTensorsEqual(ga.select_blades(y_squared, ""), 25.0) - self.assertTensorsEqual(ga.select_blades(y_squared, "0"), 10.0) + self.assertTensorsEqual(ga.select_blades_with_name(y_squared, ""), 25.0) + self.assertTensorsEqual(ga.select_blades_with_name(y_squared, "0"), 10.0) def test_batched_auto_diff_square(self): """Test automatic differentiation using @@ -124,15 +124,15 @@ def test_batched_auto_diff_square(self): # f(1) = 1^2 = 1, f'(1) = 2 x_squared = ga.geom_prod(x, x) - self.assertTensorsEqual(ga.select_blades(x_squared, ""), 1.0) - self.assertTensorsEqual(ga.select_blades(x_squared, "0"), 2.0) + self.assertTensorsEqual(ga.select_blades_with_name(x_squared, ""), 1.0) + self.assertTensorsEqual(ga.select_blades_with_name(x_squared, "0"), 2.0) y = five + eps # f(5) = 5^2 = 25, f'(5) = 10 y_squared = ga.geom_prod(y, y) - self.assertTensorsEqual(ga.select_blades(y_squared, ""), 25.0) - self.assertTensorsEqual(ga.select_blades(y_squared, "0"), 10.0) + self.assertTensorsEqual(ga.select_blades_with_name(y_squared, ""), 25.0) + self.assertTensorsEqual(ga.select_blades_with_name(y_squared, "0"), 10.0) def test_mul_inverse(self): ga = GeometricAlgebra(metric=dual_metric) @@ -149,7 +149,7 @@ def test_mul_inverse(self): # a^-1 = 1 / 2 a_inv = ga.inverse(a) - self.assertTensorsEqual(ga.select_blades(a_inv, ""), 0.5) + self.assertTensorsEqual(ga.select_blades_with_name(a_inv, ""), 0.5) # c = a * b # => a_inv * c = b diff --git a/tfga/mv.py b/tfga/mv.py index 0ccf019..222b321 100644 --- a/tfga/mv.py +++ b/tfga/mv.py @@ -137,15 +137,15 @@ def __pow__(self, n: int) -> self: ) def __getitem__(self, key: Union[str, List[str]]) -> self: - """`MultiVector` with only passed blades as non-zeros.""" + """`MultiVector` with only passed blade names as non-zeros.""" return MultiVector( - self._algebra.keep_blades(self._blade_values, key), + self._algebra.keep_blades_with_name(self._blade_values, key), self._algebra ) def __call__(self, key: Union[str, List[str]]): - """`tf.Tensor` with passed blades on last axis.""" - return self._algebra.select_blades(self._blade_values, key) + """`tf.Tensor` with passed blade names on last axis.""" + return self._algebra.select_blades_with_name(self._blade_values, key) def __repr__(self) -> str: return self._algebra.mv_repr(self._blade_values) diff --git a/tfga/tfga.py b/tfga/tfga.py index eb42768..5f450f6 100644 --- a/tfga/tfga.py +++ b/tfga/tfga.py @@ -5,7 +5,7 @@ It exposes methods for operating on `tf.Tensor` instances where their last axis is interpreted as blades of the algebra. """ -from typing import List, Any, Union +from typing import List, Any, Union, Optional import numbers import tensorflow as tf import numpy as np @@ -640,30 +640,69 @@ def int_pow(self, a: tf.Tensor, n: int) -> tf.Tensor: result = self.geom_prod(result, a) return result - def keep_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor: + def keep_blades(self, a: tf.Tensor, blade_indices: List[int]) -> tf.Tensor: """Takes a geometric algebra tensor and returns it with only the given - blades as non-zeros. + blade_indices as non-zeros. Args: a: Geometric algebra tensor to copy - blade_names: Blades to keep + blade_indices: Indices for blades to keep Returns: - `a` with only `blade_names` elements as non-zeros + `a` with only `blade_indices` components as non-zeros """ a = tf.convert_to_tensor(a, dtype_hint=tf.float32) + blade_indices = tf.cast( + tf.convert_to_tensor(blade_indices, dtype_hint=tf.int64), + dtype=tf.int64 + ) + + blade_values = tf.gather(a, blade_indices, axis=-1) + + return self.from_tensor(blade_values, blade_indices) + def keep_blades_with_name(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor: + """Takes a geometric algebra tensor and returns it with only the given + blades as non-zeros. + + Args: + a: Geometric algebra tensor to copy + blade_names: Blades to keep + + Returns: + `a` with only `blade_names` components as non-zeros + """ if isinstance(blade_names, str): blade_names = [blade_names] _, blade_indices = get_blade_indices_from_names( blade_names, self.blades) - blade_values = tf.gather(a, blade_indices, axis=-1) + return self.keep_blades(a, blade_indices) - return self.from_tensor(blade_values, blade_indices) + def select_blades(self, a: tf.Tensor, blade_indices: List[int]) -> tf.Tensor: + """Takes a geometric algebra tensor and returns a `tf.Tensor` with the + blades in blade_indices on the last axis. + + + Args: + a: Geometric algebra tensor to copy + blade_indices: Indices for blades to select + + Returns: + `tf.Tensor` based on `a` with `blade_indices` on last axis. + """ + a = tf.convert_to_tensor(a, dtype_hint=tf.float32) + blade_indices = tf.cast( + tf.convert_to_tensor(blade_indices, dtype_hint=tf.int64), + dtype=tf.int64 + ) + + result = tf.gather(a, blade_indices, axis=-1) + + return result - def select_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor: + def select_blades_with_name(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor: """Takes a geometric algebra tensor and returns a `tf.Tensor` with the blades in blade_names on the last axis. @@ -684,7 +723,7 @@ def select_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf. blade_signs, blade_indices = get_blade_indices_from_names( blade_names, self.blades) - result = blade_signs * tf.gather(a, blade_indices, axis=-1) + result = blade_signs * self.select_blades(a, blade_indices) if is_single_blade: return result[..., 0]