Skip to content

Commit

Permalink
sample_probs is extracted out of sample_state (#6354)
Browse files Browse the repository at this point in the history
**Context:**
Currently, the sample_state helper function accepts a state vector,
turns it into probabilities, and then samples it. This task is to break
out the code for sample_probs from this function.

By breaking this code into it's own helper function, we can reuse it for
a "sample density matrix" implementation.
**Description of the Change:**
- [x] Separate `devices.qubit.sampling.sample_probs` from `sample_state`
- [x] Separate `devices.qubit.sampling.sample_probs_jax` from
`_sample_state_jax`
 - [x] similar but for qutrit mixed
 - [x] similar but for qutrit mixed

**Benefits:**
Better modularization; disentanglement

Future usage in new devices

**Possible Drawbacks:**

**Related GitHub Issues:**

**Related ShortCut Stories:**
[sc-73317]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
2 people authored and austingmhuang committed Oct 23, 2024
1 parent f54e053 commit 3dd186f
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 84 deletions.
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

<h3>New features since last release</h3>

* Introduced `sample_probs` function for the `qml.devices.qubit` and `qml.devices.qutrit_mixed` modules:
- This function takes probability distributions as input and returns sampled outcomes.
- Simplifies the sampling process by separating it from other operations in the measurement chain.
- Improves modularity: The same code can be easily adapted for other devices (e.g., a potential `default_mixed` device).
- Enhances maintainability by isolating the sampling logic.
[(#6354)](https://github.com/PennyLaneAI/pennylane/pull/6354)

* `qml.transforms.decompose` is added for stepping through decompositions to a target gate set.
[(#6334)](https://github.com/PennyLaneAI/pennylane/pull/6334)

Expand Down
7 changes: 4 additions & 3 deletions pennylane/devices/qubit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
measure
measure_with_samples
sample_state
sample_probs
simulate
adjoint_jacobian
adjoint_jvp
adjoint_vjp
"""

from .apply_operation import apply_operation
from .adjoint_jacobian import adjoint_jacobian, adjoint_jvp, adjoint_vjp
from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
from .sampling import sample_state, measure_with_samples
from .simulate import simulate, get_final_state, measure_final_state
from .sampling import measure_with_samples, sample_probs, sample_state
from .simulate import get_final_state, measure_final_state, simulate
111 changes: 57 additions & 54 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,78 +471,90 @@ def sample_state(
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
"""
if prng_key is not None or qml.math.get_interface(state) == "jax":
return _sample_state_jax(
state, shots, prng_key, is_state_batched=is_state_batched, wires=wires, seed=rng
)

rng = np.random.default_rng(rng)

total_indices = len(state.shape) - is_state_batched
state_wires = qml.wires.Wires(range(total_indices))

wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
basis_states = np.arange(2**num_wires)

flat_state = flatten_state(state, total_indices)
with qml.queuing.QueuingManager.stop_recording():
probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires)
# Keep same interface (e.g. jax) as in the device

# when using the torch interface with float32 as default dtype,
# probabilities must be renormalized as they may not sum to one
# see https://github.com/PennyLaneAI/pennylane/issues/5444
norm = qml.math.sum(probs, axis=-1)
abs_diff = qml.math.abs(norm - 1.0)
cutoff = 1e-07
return sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key)

if is_state_batched:
normalize_condition = False

for s in abs_diff:
if s != 0:
normalize_condition = True
if s > cutoff:
normalize_condition = False
break
def sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key=None):
"""
Sample from given probabilities, dispatching between JAX and NumPy implementations.
Args:
probs (array): The probabilities to sample from
shots (int): The number of samples to take
num_wires (int): The number of wires to sample
is_state_batched (bool): whether the state is batched or not
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]):
A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
"""
if qml.math.get_interface(probs) == "jax" or prng_key is not None:
return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, seed=rng)

return _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng)


def _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng):
"""
Sample from given probabilities using NumPy's random number generator.
Args:
probs (array): The probabilities to sample from
shots (int): The number of samples to take
num_wires (int): The number of wires to sample
is_state_batched (bool): whether the state is batched or not
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]):
A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used
"""
rng = np.random.default_rng(rng)
norm = qml.math.sum(probs, axis=-1)
norm_err = qml.math.abs(norm - 1.0)
cutoff = 1e-07

if normalize_condition:
probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm
norm_err = norm_err[..., np.newaxis] if not is_state_batched else norm_err
if qml.math.any(norm_err > cutoff):
raise ValueError("probabilities do not sum to 1")

# rng.choice doesn't support broadcasting
basis_states = np.arange(2**num_wires)
if is_state_batched:
probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:
if not 0 < abs_diff < cutoff:
norm = 1.0
probs = probs / norm

samples = rng.choice(basis_states, shots, p=probs)

powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(np.int64)


# pylint:disable = unused-argument
def _sample_state_jax(
state,
shots: int,
prng_key,
is_state_batched: bool = False,
wires=None,
seed=None,
) -> np.ndarray:
def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None, seed=None):
"""
Returns a series of samples of a state for the JAX interface based on the PRNG.
Args:
state (array[complex]): A state vector to be sampled
probs (array): The probabilities to sample from
shots (int): The number of samples to take
prng_key (jax.random.PRNGKey): A``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator.
num_wires (int): The number of wires to sample
is_state_batched (bool): whether the state is batched or not
wires (Sequence[int]): The wires to sample
seed (numpy.random.Generator): seed to use to generate a key if a ``prng_key`` is not present. ``None`` by default.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
seed (Optional[int]): A seed for the random number generator. This is only used if ``prng_key``
is not provided.
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
Expand All @@ -554,19 +566,10 @@ def _sample_state_jax(
if prng_key is None:
prng_key = jax.random.PRNGKey(np.random.default_rng(seed).integers(100000))

total_indices = len(state.shape) - is_state_batched
state_wires = qml.wires.Wires(range(total_indices))

wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
basis_states = np.arange(2**num_wires)

flat_state = flatten_state(state, total_indices)
with qml.queuing.QueuingManager.stop_recording():
probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires)
basis_states = jnp.arange(2**num_wires)

if is_state_batched:
keys = jax_random_split(prng_key, num=len(state))
keys = jax_random_split(prng_key, num=probs.shape[0])
samples = jnp.array(
[
jax.random.choice(_key, basis_states, shape=(shots,), p=prob)
Expand All @@ -577,6 +580,6 @@ def _sample_state_jax(
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << np.arange(num_wires, dtype=int)[::-1]
powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(int)
return (states_sampled_base_ten > 0).astype(jnp.int64)
2 changes: 1 addition & 1 deletion pennylane/devices/qutrit_mixed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
from .sampling import sample_state, measure_with_samples
from .sampling import sample_state, measure_with_samples, sample_probs
from .simulate import simulate, get_final_state, measure_final_state
103 changes: 94 additions & 9 deletions pennylane/devices/qutrit_mixed/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,25 +250,74 @@ def _sample_state_jax(
ndarray[int]: Sample values of the shape (shots, num_wires)
"""
# pylint: disable=import-outside-toplevel
import jax
import jax.numpy as jnp

key = prng_key

total_indices = get_num_wires(state, is_state_batched)
state_wires = qml.wires.Wires(range(total_indices))

wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
basis_states = np.arange(QUDIT_DIM**num_wires)

with qml.queuing.QueuingManager.stop_recording():
probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors)

state_len = len(state)

return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)


def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len):
"""
Sample from a probability distribution for a qutrit system using JAX.
This function generates samples based on the given probability distribution
for a qutrit system with a specified number of wires. It can handle both
batched and non-batched probability distributions. This function uses JAX
for potential GPU acceleration and improved performance.
Args:
probs (jnp.ndarray): Probability distribution to sample from. For non-batched
input, this should be a 1D array of length QUDIT_DIM**num_wires. For
batched input, this should be a 2D array where each row is a separate
probability distribution.
shots (int): Number of samples to generate.
num_wires (int): Number of wires in the qutrit system.
is_state_batched (bool): Whether the input probabilities are batched.
prng_key (jax.random.PRNGKey): JAX PRNG key for random number generation.
state_len (int): Length of the state (relevant for batched inputs).
Returns:
jnp.ndarray: An array of samples. For non-batched input, the shape is
(shots, num_wires). For batched input, the shape is
(batch_size, shots, num_wires).
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> probs = jnp.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system
>>> shots = 1000
>>> num_wires = 1
>>> is_state_batched = False
>>> prng_key = jax.random.PRNGKey(42)
>>> state_len = 1
>>> samples = _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)
>>> samples.shape
(1000, 1)
Note:
This function requires JAX to be installed. It internally imports JAX
and its numpy module (jnp).
"""
# pylint: disable=import-outside-toplevel
import jax
import jax.numpy as jnp

key = prng_key

basis_states = np.arange(QUDIT_DIM**num_wires)
if is_state_batched:
# Produce separate keys for each of the probabilities along the broadcasted axis
keys = []
for _ in state:
for _ in range(state_len):
key, subkey = jax.random.split(key)
keys.append(subkey)
samples = jnp.array(
Expand Down Expand Up @@ -323,18 +372,54 @@ def sample_state(
readout_errors=readout_errors,
)

rng = np.random.default_rng(rng)

total_indices = get_num_wires(state, is_state_batched)
state_wires = qml.wires.Wires(range(total_indices))

wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
basis_states = np.arange(QUDIT_DIM**num_wires)

with qml.queuing.QueuingManager.stop_recording():
probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors)

return sample_probs(probs, shots, num_wires, is_state_batched, rng)


def sample_probs(probs, shots, num_wires, is_state_batched, rng):
"""
Sample from a probability distribution for a qutrit system.
This function generates samples based on the given probability distribution
for a qutrit system with a specified number of wires. It can handle both
batched and non-batched probability distributions.
Args:
probs (ndarray): Probability distribution to sample from. For non-batched
input, this should be a 1D array of length QUDIT_DIM**num_wires. For
batched input, this should be a 2D array where each row is a separate
probability distribution.
shots (int): Number of samples to generate.
num_wires (int): Number of wires in the qutrit system.
is_state_batched (bool): Whether the input probabilities are batched.
rng (Optional[Generator]): Random number generator to use. If None, a new
generator will be created.
Returns:
ndarray: An array of samples. For non-batched input, the shape is
(shots, num_wires). For batched input, the shape is
(batch_size, shots, num_wires).
Example:
>>> probs = np.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system
>>> shots = 1000
>>> num_wires = 1
>>> is_state_batched = False
>>> rng = np.random.default_rng(42)
>>> samples = sample_probs(probs, shots, num_wires, is_state_batched, rng)
>>> samples.shape
(1000, 1)
"""
rng = np.random.default_rng(rng)
basis_states = np.arange(QUDIT_DIM**num_wires)
if is_state_batched:
# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
Expand Down
Loading

0 comments on commit 3dd186f

Please sign in to comment.