Skip to content

Commit

Permalink
Update light_benchmark (#344)
Browse files Browse the repository at this point in the history
* Update light_benchmark.py

* Update light_benchmark.py

* Update light_benchmark.py

* Update light_benchmark.py

* Update pipelines.py

* Update light_benchmark.py

* Update light_benchmark.py

* Update pipelines.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update light_benchmark.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update light_benchmark.py

* Update light_benchmark.py

* Update light_benchmark.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gcattan and pre-commit-ci[bot] authored Jan 15, 2025
1 parent 5e84d11 commit d99eacf
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# Modified from plot_classify_P300_bi.py of pyRiemann
# License: BSD (3-clause)

import random
import warnings

import numpy as np
import qiskit_algorithms
from lb_base import run
from moabb import set_log_level
from pyriemann.estimation import Shrinkage, XdawnCovariances
Expand All @@ -29,6 +32,14 @@

print(__doc__)

##############################################################################
# Set random seeds

seed = 42
random.seed(seed)
np.random.seed(seed)
qiskit_algorithms.utils.algorithm_globals.random_seed

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
Expand All @@ -49,14 +60,14 @@
pipelines = {}

pipelines["RG_QSVM"] = QuantumClassifierWithDefaultRiemannianPipeline(
shots=100,
shots=1024,
nfilter=2,
dim_red=PCA(n_components=5),
params={"seed": 42, "use_fidelity_state_vector_kernel": True},
params={"seed": seed, "use_fidelity_state_vector_kernel": True},
)

pipelines["RG_VQC"] = QuantumClassifierWithDefaultRiemannianPipeline(
shots=100, spsa_trials=1, two_local_reps=2, params={"seed": 42}
shots=100, spsa_trials=1, two_local_reps=2, params={"seed": seed}
)

pipelines["QMDM_mean"] = QuantumMDMWithRiemannianPipeline(
Expand All @@ -70,7 +81,7 @@
pipelines["QMDM_dist"] = QuantumMDMWithRiemannianPipeline(
metric={"mean": "logeuclid", "distance": "qlogeuclid_hull"},
quantum=True,
seed=42,
seed=seed,
shots=100,
)

Expand Down

0 comments on commit d99eacf

Please sign in to comment.