Skip to content

Commit

Permalink
made (select|keep)_blades take indices, added (select|keep)_blades_wi…
Browse files Browse the repository at this point in the history
…th_name with old behavior
  • Loading branch information
RobinKa committed Jun 21, 2020
1 parent 23bf1b3 commit 26ecef3
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 32 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ ga.print(ga.reversion(quaternion))
ga.print(quaternion[0])

# tf.Tensor of shape [1]: -5 (ie. reversed sign of e_01 component)
ga.print(ga.select_blades(quaternion, "10"))
ga.print(ga.select_blades_with_name(quaternion, "10"))

# tf.Tensor of shape [8] with only e_01 component equal to 5
ga.print(ga.keep_blades(quaternion, "10"))
ga.print(ga.keep_blades_with_name(quaternion, "10"))
```

Alternatively we can convert the geometric algebra [`tf.Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor) instance to [`MultiVector`](https://tfga.warlock.ai/tfga.html#tfga.mv.MultiVector)
Expand Down
6 changes: 3 additions & 3 deletions notebooks/qed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@
"sta.print(\"b3 * b2^-1:\", sta.geom_prod(b3, sta.inverse(b2)), \"should be\", b1)\n",
"\n",
"sta.print(\"~b2 (Grade reversal):\", sta.reversion(b2))\n",
"sta.print(\"Scalar part of b2:\", sta.keep_blades(b2, \"\"))\n",
"sta.print(\"e_01 part of b2:\", sta.keep_blades(b2, \"01\"))"
"sta.print(\"Scalar part of b2:\", sta.keep_blades_with_name(b2, \"\"))\n",
"sta.print(\"e_01 part of b2:\", sta.keep_blades_with_name(b2, \"01\"))"
]
},
{
Expand Down Expand Up @@ -193,7 +193,7 @@
"\n",
"sta.print(\"A(t=0, x=5, y=3, z=9):\", a[0, 5, 3, 9])\n",
"sta.print(\"A(t=0, z=[3,4,5]):\", a[0, :, :, 3:6])\n",
"sta.print(\"e_0 part of A(X):\", sta.select_blades(a, \"0\").shape)\n",
"sta.print(\"e_0 part of A(X):\", sta.select_blades_with_name(a, \"0\").shape)\n",
"\n",
"sta.print(\"A(0, 0, 0, 0) * ~A(0, 0, 0, 0):\", sta.geom_prod(a, sta.reversion(a))[0, 0, 0, 0])"
]
Expand Down
10 changes: 5 additions & 5 deletions notebooks/tfga.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,11 @@
"source": [
"v = sta.from_tensor_with_kind(tf.ones(16, dtype=tf.float32), \"mv\")\n",
"sta.print(v)\n",
"sta.print(sta.keep_blades(v, [\"10\", \"1\"]))\n",
"sta.print(sta.keep_blades(v, \"2\"))\n",
"sta.print(\"R:\", sta.select_blades(v, [\"0\", \"01\", \"10\"]))\n",
"sta.print(\"R:\", sta.select_blades(v, [\"123\", \"01\", \"0\", \"0\"]))\n",
"sta.print(\"R:\", sta.select_blades(v, \"312\"))\n",
"sta.print(sta.keep_blades_with_name(v, [\"10\", \"1\"]))\n",
"sta.print(sta.keep_blades_with_name(v, \"2\"))\n",
"sta.print(\"R:\", sta.select_blades_with_name(v, [\"0\", \"01\", \"10\"]))\n",
"sta.print(\"R:\", sta.select_blades_with_name(v, [\"123\", \"01\", \"0\", \"0\"]))\n",
"sta.print(\"R:\", sta.select_blades_with_name(v, \"312\"))\n",
"sta.print(v[..., 0])"
]
}
Expand Down
18 changes: 9 additions & 9 deletions tests/test_dual_ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ def test_auto_diff_square(self):

# f(1) = 1^2 = 1, f'(1) = 2
x_squared = ga.geom_prod(x, x)
self.assertTensorsEqual(ga.select_blades(x_squared, ""), 1.0)
self.assertTensorsEqual(ga.select_blades(x_squared, "0"), 2.0)
self.assertTensorsEqual(ga.select_blades_with_name(x_squared, ""), 1.0)
self.assertTensorsEqual(ga.select_blades_with_name(x_squared, "0"), 2.0)

y = five + eps

# f(5) = 5^2 = 25, f'(5) = 10
y_squared = ga.geom_prod(y, y)
self.assertTensorsEqual(ga.select_blades(y_squared, ""), 25.0)
self.assertTensorsEqual(ga.select_blades(y_squared, "0"), 10.0)
self.assertTensorsEqual(ga.select_blades_with_name(y_squared, ""), 25.0)
self.assertTensorsEqual(ga.select_blades_with_name(y_squared, "0"), 10.0)

def test_batched_auto_diff_square(self):
"""Test automatic differentiation using
Expand All @@ -124,15 +124,15 @@ def test_batched_auto_diff_square(self):

# f(1) = 1^2 = 1, f'(1) = 2
x_squared = ga.geom_prod(x, x)
self.assertTensorsEqual(ga.select_blades(x_squared, ""), 1.0)
self.assertTensorsEqual(ga.select_blades(x_squared, "0"), 2.0)
self.assertTensorsEqual(ga.select_blades_with_name(x_squared, ""), 1.0)
self.assertTensorsEqual(ga.select_blades_with_name(x_squared, "0"), 2.0)

y = five + eps

# f(5) = 5^2 = 25, f'(5) = 10
y_squared = ga.geom_prod(y, y)
self.assertTensorsEqual(ga.select_blades(y_squared, ""), 25.0)
self.assertTensorsEqual(ga.select_blades(y_squared, "0"), 10.0)
self.assertTensorsEqual(ga.select_blades_with_name(y_squared, ""), 25.0)
self.assertTensorsEqual(ga.select_blades_with_name(y_squared, "0"), 10.0)

def test_mul_inverse(self):
ga = GeometricAlgebra(metric=dual_metric)
Expand All @@ -149,7 +149,7 @@ def test_mul_inverse(self):

# a^-1 = 1 / 2
a_inv = ga.inverse(a)
self.assertTensorsEqual(ga.select_blades(a_inv, ""), 0.5)
self.assertTensorsEqual(ga.select_blades_with_name(a_inv, ""), 0.5)

# c = a * b
# => a_inv * c = b
Expand Down
8 changes: 4 additions & 4 deletions tfga/mv.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ def __pow__(self, n: int) -> self:
)

def __getitem__(self, key: Union[str, List[str]]) -> self:
"""`MultiVector` with only passed blades as non-zeros."""
"""`MultiVector` with only passed blade names as non-zeros."""
return MultiVector(
self._algebra.keep_blades(self._blade_values, key),
self._algebra.keep_blades_with_name(self._blade_values, key),
self._algebra
)

def __call__(self, key: Union[str, List[str]]):
"""`tf.Tensor` with passed blades on last axis."""
return self._algebra.select_blades(self._blade_values, key)
"""`tf.Tensor` with passed blade names on last axis."""
return self._algebra.select_blades_with_name(self._blade_values, key)

def __repr__(self) -> str:
return self._algebra.mv_repr(self._blade_values)
Expand Down
57 changes: 48 additions & 9 deletions tfga/tfga.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
It exposes methods for operating on `tf.Tensor` instances where their last
axis is interpreted as blades of the algebra.
"""
from typing import List, Any, Union
from typing import List, Any, Union, Optional
import numbers
import tensorflow as tf
import numpy as np
Expand Down Expand Up @@ -640,30 +640,69 @@ def int_pow(self, a: tf.Tensor, n: int) -> tf.Tensor:
result = self.geom_prod(result, a)
return result

def keep_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor:
def keep_blades(self, a: tf.Tensor, blade_indices: List[int]) -> tf.Tensor:
"""Takes a geometric algebra tensor and returns it with only the given
blades as non-zeros.
blade_indices as non-zeros.
Args:
a: Geometric algebra tensor to copy
blade_names: Blades to keep
blade_indices: Indices for blades to keep
Returns:
`a` with only `blade_names` elements as non-zeros
`a` with only `blade_indices` components as non-zeros
"""
a = tf.convert_to_tensor(a, dtype_hint=tf.float32)
blade_indices = tf.cast(
tf.convert_to_tensor(blade_indices, dtype_hint=tf.int64),
dtype=tf.int64
)

blade_values = tf.gather(a, blade_indices, axis=-1)

return self.from_tensor(blade_values, blade_indices)

def keep_blades_with_name(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor:
"""Takes a geometric algebra tensor and returns it with only the given
blades as non-zeros.
Args:
a: Geometric algebra tensor to copy
blade_names: Blades to keep
Returns:
`a` with only `blade_names` components as non-zeros
"""
if isinstance(blade_names, str):
blade_names = [blade_names]

_, blade_indices = get_blade_indices_from_names(
blade_names, self.blades)

blade_values = tf.gather(a, blade_indices, axis=-1)
return self.keep_blades(a, blade_indices)

return self.from_tensor(blade_values, blade_indices)
def select_blades(self, a: tf.Tensor, blade_indices: List[int]) -> tf.Tensor:
"""Takes a geometric algebra tensor and returns a `tf.Tensor` with the
blades in blade_indices on the last axis.
Args:
a: Geometric algebra tensor to copy
blade_indices: Indices for blades to select
Returns:
`tf.Tensor` based on `a` with `blade_indices` on last axis.
"""
a = tf.convert_to_tensor(a, dtype_hint=tf.float32)
blade_indices = tf.cast(
tf.convert_to_tensor(blade_indices, dtype_hint=tf.int64),
dtype=tf.int64
)

result = tf.gather(a, blade_indices, axis=-1)

return result

def select_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor:
def select_blades_with_name(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.Tensor:
"""Takes a geometric algebra tensor and returns a `tf.Tensor` with the
blades in blade_names on the last axis.
Expand All @@ -684,7 +723,7 @@ def select_blades(self, a: tf.Tensor, blade_names: Union[List[str], str]) -> tf.
blade_signs, blade_indices = get_blade_indices_from_names(
blade_names, self.blades)

result = blade_signs * tf.gather(a, blade_indices, axis=-1)
result = blade_signs * self.select_blades(a, blade_indices)

if is_single_blade:
return result[..., 0]
Expand Down

0 comments on commit 26ecef3

Please sign in to comment.