Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cases for qml.sample(..., counts=True) #2839

Merged
merged 21 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ of operators. [(#2622)](https://github.com/PennyLaneAI/pennylane/pull/2622)

* Samples can be grouped into counts by passing the `counts=True` flag to `qml.sample`.
[(#2686)](https://github.com/PennyLaneAI/pennylane/pull/2686)
[(#2839)](https://github.com/PennyLaneAI/pennylane/pull/2839)

Note that the change included creating a new `Counts` measurement type in `measurements.py`.

Expand Down Expand Up @@ -245,8 +246,7 @@ of operators. [(#2622)](https://github.com/PennyLaneAI/pennylane/pull/2622)
... return qml.sample(qml.PauliZ(0), counts=True), qml.sample(qml.PauliZ(1), counts=True)
>>> result = circuit()
>>> print(result)
[tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)
tensor({-1: 526, 1: 474}, dtype=object, requires_grad=True)]
({-1: 470, 1: 530}, {-1: 470, 1: 530})
```

* The `qml.state` and `qml.density_matrix` measurements now support custom wire
Expand Down
40 changes: 28 additions & 12 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def execute(self, circuit, **kwargs):
self._samples = self.generate_samples()

multiple_sampled_jobs = circuit.is_sampled and self._has_partitioned_shots()
ret_types = [m.return_type for m in circuit.measurements]
no_counts = all(ret is not qml.measurements.Counts for ret in ret_types)

# compute the required statistics
if not self.analytic and self._shot_vector is not None:
Expand All @@ -280,20 +282,29 @@ def execute(self, circuit, **kwargs):

if qml.math._multi_dispatch(r) == "jax": # pylint: disable=protected-access
r = r[0]
elif not isinstance(r[0], dict):
elif no_counts:
# Measurement types except for Counts
r = qml.math.squeeze(r)
if isinstance(r, (np.ndarray, list)) and r.shape and isinstance(r[0], dict):
# This happens when measurement type is Counts
results.append(r)

if not no_counts:

# This happens when at least one measurement type is Counts
for result_group in r:
if isinstance(result_group, list):
# List that contains one or more dictionaries
results.extend(result_group)
else:
# Other measurement results
results.append(result_group.T)

elif shot_tuple.copies > 1:
results.extend(r.T)
else:
results.append(r.T)

s1 = s2

if not multiple_sampled_jobs:
if not multiple_sampled_jobs and no_counts:
# Can only stack single element outputs
results = qml.math.stack(results)

Expand All @@ -302,8 +313,6 @@ def execute(self, circuit, **kwargs):

if not circuit.is_sampled:

ret_types = [m.return_type for m in circuit.measurements]

if len(circuit.measurements) == 1:
if circuit.measurements[0].return_type is qml.measurements.State:
# State: assumed to only be allowed if it's the only measurement
Expand All @@ -318,15 +327,18 @@ def execute(self, circuit, **kwargs):
):
# Measurements with expval or var
results = self._asarray(results, dtype=self.R_DTYPE)
elif any(ret is not qml.measurements.Counts for ret in ret_types):
# all the other cases except all counts
elif no_counts:
# all the other cases except any counts
results = self._asarray(results)

elif circuit.all_sampled and not self._has_partitioned_shots():

results = self._asarray(results)
else:
results = tuple(self._asarray(r) for r in results)
results = tuple(
qml.math.squeeze(self._asarray(r)) if not isinstance(r, dict) else r
for r in results
)

# increment counter for number of executions of qubit device
self._num_executions += 1
Expand Down Expand Up @@ -1012,6 +1024,7 @@ def _samples_to_counts(samples, no_observable_provided):
# Before converting to str, we need to extract elements from arrays
# to satisfy the case of jax interface, as jax arrays do not support str.
samples = ["".join([str(s.item()) for s in sample]) for sample in samples]

states, counts = np.unique(samples, return_counts=True)
return dict(zip(states, counts))

Expand Down Expand Up @@ -1055,14 +1068,17 @@ def _samples_to_counts(samples, no_observable_provided):
if counts:
return _samples_to_counts(samples, no_observable_provided)
return samples

num_wires = len(device_wires) if len(device_wires) > 0 else self.num_wires
if counts:
shape = (-1, bin_size, 3) if no_observable_provided else (-1, bin_size)
shape = (-1, bin_size, num_wires) if no_observable_provided else (-1, bin_size)
return [
_samples_to_counts(bin_sample, no_observable_provided)
for bin_sample in samples.reshape(shape)
]

return (
samples.reshape((3, bin_size, -1))
samples.reshape((num_wires, bin_size, -1))
if no_observable_provided
else samples.reshape((bin_size, -1))
)
Expand Down
9 changes: 3 additions & 6 deletions pennylane/interfaces/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,12 @@ def _execute(

for i, r in enumerate(res):

if any(m.return_type is qml.measurements.Counts for m in tapes[i].measurements):
continue

if isinstance(r, np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
try:
if isinstance(r[0][0], dict):
# This happens when measurement type is Counts and shot vector is passed
continue
except (IndexError, KeyError):
pass
r = np.hstack(r) if r.dtype == np.dtype("object") else r
res[i] = np.tensor(r)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/interfaces/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,26 @@ def cp_tape(t, a):
tc.set_parameters(a)
return tc

def array_if_not_counts(tape, r):
"""Auxiliary function to convert the result of a tape to an array,
unless the tape had Counts measurements that are represented with
dictionaries. JAX NumPy arrays don't support dictionaries."""
return (
jnp.array(r)
if not any(m.return_type is qml.measurements.Counts for m in tape.measurements)
else r
)

@jax.custom_vjp
def wrapped_exec(params):
new_tapes = [cp_tape(t, a) for t, a in zip(tapes, params)]
with qml.tape.Unwrap(*new_tapes):
res, _ = execute_fn(new_tapes, **gradient_kwargs)

if len(tapes) > 1:
res = [jnp.array(r) for r in res]
res = [array_if_not_counts(tape, r) for tape, r in zip(tapes, res)]
else:
res = jnp.array(res)
res = array_if_not_counts(tapes[0], res)

return res

Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
for i, tape in enumerate(tapes):
# convert output to TensorFlow tensors

if any(m.return_type is qml.measurements.Counts for m in tape.measurements):
continue

if isinstance(res[i], np.ndarray):
# For backwards compatibility, we flatten ragged tape outputs
# when there is no sampling
Expand Down
3 changes: 3 additions & 0 deletions pennylane/interfaces/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def forward(ctx, kwargs, *parameters): # pylint: disable=arguments-differ
# For backwards compatibility, we flatten ragged tape outputs
r = np.hstack(r)

if any(m.return_type is qml.measurements.Counts for m in ctx.tapes[i].measurements):
continue

if isinstance(r, (list, tuple)):
res[i] = [torch.as_tensor(t) for t in r]

Expand Down
11 changes: 5 additions & 6 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,21 +654,20 @@ def circuit(x):

.. code-block:: python3

dev = qml.device('default.qubit', wires=3, shots=10)
dev = qml.device("default.qubit", wires=3, shots=10)


@qml.qnode(dev)
def my_circ():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0,1])
qml.CNOT(wires=[0, 1])
qml.PauliX(wires=2)
return qml.sample(qml.PauliZ(0), counts = True), qml.sample(counts=True)
return qml.sample(qml.PauliZ(0), counts=True), qml.sample(counts=True)

Executing this QNode:

>>> my_circ()
tensor([tensor({-1: 5, 1: 5}, dtype=object, requires_grad=True),
tensor({'001': 5, '111': 5}, dtype=object, requires_grad=True)],
dtype=object, requires_grad=True)
({-1: 3, 1: 7}, {'001': 7, '111': 3})

.. note::

Expand Down
22 changes: 19 additions & 3 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,25 @@ def __call__(self, *args, **kwargs):

res = res[0]

if (
not isinstance(self._qfunc_output, Sequence)
and self._qfunc_output.return_type is qml.measurements.Counts
):

if not self.device._has_partitioned_shots():
# return a dictionary with counts not as a single-element array
return res[0]

return tuple(res)

if isinstance(self._qfunc_output, Sequence) and qml.measurements.Counts in set(
out.return_type for out in self._qfunc_output
):
# If Counts was returned with other measurements, then apply the
# data structure used in the qfunc
qfunc_output_type = type(self._qfunc_output)
return qfunc_output_type(res)

if override_shots is not False:
# restore the initialization gradient function
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn
Expand All @@ -650,9 +669,6 @@ def __call__(self, *args, **kwargs):
self.tape.is_sampled and self.device._has_partitioned_shots()
):
return res
if self._qfunc_output.return_type is qml.measurements.Counts:
# return a dictionary with counts not as a single-element array
return res[0]

return qml.math.squeeze(res)

Expand Down
Loading