Skip to content

Commit

Permalink
feat: pyo3 stub generation (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
cocool97 authored Feb 6, 2025
1 parent 79d96d4 commit 00c387d
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 16 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/python-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ on:
types: [created]

jobs:
gen-stubs:
name: "build-release"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build project
run: cargo run --bin stub_gen

build-python-packages:
runs-on: ubuntu-latest

Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ target
/Cargo.lock
/.vscode
venv
/.mypy_cache
/.mypy_cache
pyadb_client/pyadb_client.pyi
8 changes: 4 additions & 4 deletions adb_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ homedir = { version = "0.3.4" }
image = { version = "0.25.5" }
lazy_static = { version = "1.5.0" }
log = { version = "0.4.22" }
mdns-sd = { version = "0.13.1" }
mdns-sd = { version = "0.13.2" }
num-bigint = { version = "0.8.4", package = "num-bigint-dig" }
num-traits = { version = "0.2.19" }
quick-protobuf = { version = "0.8.1" }
rand = { version = "0.8.5" }
rand = { version = "0.9.0" }
rcgen = { version = "0.13.1" }
regex = { version = "1.11.1", features = ["perf", "std", "unicode"] }
rsa = { version = "0.9.7" }
rusb = { version = "0.9.4", features = ["vendored"] }
rustls = { version = "0.23.18" }
rustls-pki-types = "1.10.0"
rustls = { version = "0.23.22" }
rustls-pki-types = "1.11.0"
serde = { version = "1.0.216", features = ["derive"] }
serde_repr = { version = "0.1.19" }
sha1 = { version = "0.10.6", features = ["oid"] }
Expand Down
4 changes: 2 additions & 2 deletions adb_client/src/device/adb_message_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
}

pub(crate) fn open_session(&mut self, data: &[u8]) -> Result<ADBTransportMessage> {
let mut rng = rand::thread_rng();
let mut rng = rand::rng();

let message = ADBTransportMessage::new(
MessageCommand::Open,
rng.gen(), // Our 'local-id'
rng.random(), // Our 'local-id'
0,
data,
);
Expand Down
4 changes: 2 additions & 2 deletions adb_client/src/device/commands/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {

let file_size = apk_file.metadata()?.len();

let mut rng = rand::thread_rng();
let mut rng = rand::rng();

let local_id = rng.gen();
let local_id = rng.random();

self.open_session(format!("exec:cmd package 'install' -S {}\0", file_size).as_bytes())?;

Expand Down
3 changes: 1 addition & 2 deletions adb_client/src/device/models/adb_rsa_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use base64::{engine::general_purpose::STANDARD, Engine};
use num_bigint::{BigUint, ModInverse};
use num_traits::cast::ToPrimitive;
use num_traits::FromPrimitive;
use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey;
use rsa::traits::PublicKeyParts;
use rsa::{Pkcs1v15Sign, RsaPrivateKey};
Expand Down Expand Up @@ -52,7 +51,7 @@ pub struct ADBRsaKey {
impl ADBRsaKey {
pub fn new_random() -> Result<Self> {
Ok(Self {
private_key: RsaPrivateKey::new(&mut OsRng, ADB_PRIVATE_KEY_SIZE)?,
private_key: RsaPrivateKey::new(&mut rsa::rand_core::OsRng, ADB_PRIVATE_KEY_SIZE)?,
})
}

Expand Down
4 changes: 2 additions & 2 deletions benches/benchmark_adb_push.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use adb_client::ADBServer;
use anyhow::Result;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::{thread_rng, Rng};
use rand::{rng, Rng};
use std::fs::File;
use std::io::Write;
use std::process::Command;
Expand All @@ -14,7 +14,7 @@ const REMOTE_TEST_FILE_PATH: &str = "/data/local/tmp/test_file.bin";
fn generate_test_file(size_in_bytes: usize) -> Result<()> {
let mut test_file = File::create(LOCAL_TEST_FILE_PATH)?;

let mut rng = thread_rng();
let mut rng = rng();

const BUFFER_SIZE: usize = 64 * 1024;
let mut buffer = [0u8; BUFFER_SIZE];
Expand Down
12 changes: 9 additions & 3 deletions pyadb_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ readme = "README.md"

[lib]
name = "pyadb_client"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]

[[bin]]
name = "stub_gen"
doc = false

[dependencies]
anyhow = { version = "1.0.94" }
adb_client = { version = "2.0.6" }
anyhow = { version = "1.0.95" }
adb_client = { version = "2.1.5" }
pyo3 = { version = "0.23.4", features = ["extension-module", "anyhow", "abi3-py37"] }
pyo3-stub-gen = "0.7.0"
pyo3-stub-gen-derive = "0.7.0"
3 changes: 3 additions & 0 deletions pyadb_client/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ pip install ".[build]"
# Build development package
maturin develop

# Build stub file (.pyi)
cargo run --bin stub_gen

# Build release Python package
maturin build --release
```
3 changes: 3 additions & 0 deletions pyadb_client/src/adb_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ use std::net::SocketAddrV4;
use adb_client::ADBServer;
use anyhow::Result;
use pyo3::{pyclass, pymethods, PyResult};
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pymethods};

use crate::{PyADBServerDevice, PyDeviceShort};

#[gen_stub_pyclass]
#[pyclass]
pub struct PyADBServer(ADBServer);

#[gen_stub_pymethods]
#[pymethods]
impl PyADBServer {
#[new]
Expand Down
13 changes: 13 additions & 0 deletions pyadb_client/src/adb_server_device.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use adb_client::{ADBDeviceExt, ADBServerDevice};
use anyhow::Result;
use pyo3::{pyclass, pymethods};
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::{fs::File, path::PathBuf};

#[gen_stub_pyclass]
#[pyclass]
pub struct PyADBServerDevice(pub ADBServerDevice);

#[gen_stub_pymethods]
#[pymethods]
impl PyADBServerDevice {
#[getter]
Expand All @@ -29,6 +32,16 @@ impl PyADBServerDevice {
let mut writer = File::create(dest)?;
Ok(self.0.pull(&input.to_string_lossy(), &mut writer)?)
}

/// Install a package installed on the device
pub fn install(&mut self, apk_path: PathBuf) -> Result<()> {
Ok(self.0.install(&apk_path)?)
}

/// Uninstall a package installed on the device
pub fn uninstall(&mut self, package: &str) -> Result<()> {
Ok(self.0.uninstall(package)?)
}
}

impl From<ADBServerDevice> for PyADBServerDevice {
Expand Down
19 changes: 19 additions & 0 deletions pyadb_client/src/adb_usb_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,52 @@ use std::{fs::File, path::PathBuf};
use adb_client::{ADBDeviceExt, ADBUSBDevice};
use anyhow::Result;
use pyo3::{pyclass, pymethods};
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pymethods};

#[gen_stub_pyclass]
#[pyclass]
/// Represent a device directly reachable over USB.
pub struct PyADBUSBDevice(ADBUSBDevice);

#[gen_stub_pymethods]
#[pymethods]
impl PyADBUSBDevice {
#[staticmethod]
/// Autodetect a device reachable over USB.
/// This method raises an error if multiple devices or none are connected.
pub fn autodetect() -> Result<Self> {
Ok(ADBUSBDevice::autodetect()?.into())
}

/// Run shell commands on device and return the output (stdout + stderr merged)
pub fn shell_command(&mut self, commands: Vec<String>) -> Result<Vec<u8>> {
let mut output = Vec::new();
let commands: Vec<&str> = commands.iter().map(|x| &**x).collect();
self.0.shell_command(&commands, &mut output)?;
Ok(output)
}

/// Push a local file from input to dest
pub fn push(&mut self, input: PathBuf, dest: PathBuf) -> Result<()> {
let mut reader = File::open(input)?;
Ok(self.0.push(&mut reader, &dest.to_string_lossy())?)
}

/// Pull a file from device located at input, and drop it to dest
pub fn pull(&mut self, input: PathBuf, dest: PathBuf) -> Result<()> {
let mut writer = File::create(dest)?;
Ok(self.0.pull(&input.to_string_lossy(), &mut writer)?)
}

/// Install a package installed on the device
pub fn install(&mut self, apk_path: PathBuf) -> Result<()> {
Ok(self.0.install(&apk_path)?)
}

/// Uninstall a package installed on the device
pub fn uninstall(&mut self, package: &str) -> Result<()> {
Ok(self.0.uninstall(package)?)
}
}

impl From<ADBUSBDevice> for PyADBUSBDevice {
Expand Down
5 changes: 5 additions & 0 deletions pyadb_client/src/bin/stub_gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use pyo3_stub_gen::Result;

fn main() -> Result<()> {
pyadb_client::stub_info()?.generate()
}
6 changes: 6 additions & 0 deletions pyadb_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub use adb_usb_device::*;
pub use models::*;

use pyo3::prelude::*;
use pyo3_stub_gen::StubInfo;

#[pymodule]
fn pyadb_client(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -18,3 +19,8 @@ fn pyadb_client(m: &Bound<'_, PyModule>) -> PyResult<()> {

Ok(())
}

pub fn stub_info() -> anyhow::Result<StubInfo> {
// Need to be run from workspace root directory
StubInfo::from_pyproject_toml("pyproject.toml")
}
3 changes: 3 additions & 0 deletions pyadb_client/src/models/devices.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use adb_client::DeviceShort;
use pyo3::{pyclass, pymethods};
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pymethods};

// Check https://docs.rs/rigetti-pyo3/latest/rigetti_pyo3 to automatically build this code

#[gen_stub_pyclass]
#[pyclass]
pub struct PyDeviceShort(DeviceShort);

#[gen_stub_pymethods]
#[pymethods]
impl PyDeviceShort {
#[getter]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ build = ["maturin", "patchelf"]

[tool.maturin]
include = [{ path = "adb_client/**/*", format = "sdist" }]
features = ["pyo3/extension-module"]
manifest-path = "pyadb_client/Cargo.toml"

0 comments on commit 00c387d

Please sign in to comment.