From 72b76bc836f1185eb7d59e76ada91915050c0530 Mon Sep 17 00:00:00 2001 From: Pondonda Date: Sun, 26 Jan 2025 19:19:18 +0100 Subject: [PATCH] adding downloadable_model and more downloadable_model python binding --- Cargo.toml | 6 + downloadable_model_graph.md | 51 ++ src/downloadable_model.rs | 517 ++++++++++++++++++ src/file_handler/mod.rs | 2 +- src/gtfs/mod.rs | 6 +- src/gtfs/write.rs | 2 +- src/lib.rs | 4 + src/model_builder.rs | 4 +- src/objects.rs | 4 +- transit_model_python/Cargo.toml | 6 +- transit_model_python/example.text | 40 ++ transit_model_python/pyproject.toml | 2 +- transit_model_python/python_test.py | 5 - transit_model_python/src/lib.rs | 253 +-------- .../src/modules/downloadable_model.rs | 292 ++++++++++ transit_model_python/src/modules/mod.rs | 2 + .../src/modules/python_transit_model.rs | 263 +++++++++ 17 files changed, 1203 insertions(+), 256 deletions(-) create mode 100644 downloadable_model_graph.md create mode 100644 src/downloadable_model.rs create mode 100644 transit_model_python/example.text delete mode 100644 transit_model_python/python_test.py create mode 100644 transit_model_python/src/modules/downloadable_model.rs create mode 100644 transit_model_python/src/modules/mod.rs create mode 100644 transit_model_python/src/modules/python_transit_model.rs diff --git a/Cargo.toml b/Cargo.toml index 830f61de3..b5d4145fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ members = [ xmllint = ["proj"] gtfs = [] parser = [] +downloadable = [] [dependencies] anyhow = "1" @@ -69,6 +70,8 @@ wkt = "0.10" zip = { version = "2", default-features = false, features = ["deflate"] } git-version = "0.3" pyo3 = "0.23.4" +reqwest = { version = "0.12.12", features = ["json"] } +tokio = "1.43.0" [[test]] name = "write_netex_france" @@ -80,3 +83,6 @@ approx = "0.5" log = "0.4" rust_decimal_macros = "1" testing_logger = "0.1" +mockito = "0.31" +tokio = { version = "1.0", features = ["full", "test-util"] } +futures = "0.3" diff --git 
a/downloadable_model_graph.md b/downloadable_model_graph.md new file mode 100644 index 000000000..b039b0e9e --- /dev/null +++ b/downloadable_model_graph.md @@ -0,0 +1,51 @@ +```mermaid +graph TD + A[DownloadableTransitModel] --> B[ModelConfig] + A --> C[NavitiaConfig] + A --> D[Downloader] + A --> E[Arc>] + A --> F[Arc>] + + subgraph Initialization + G[Start] --> H[initialize_model] + H --> I[get_remote_version] + I --> J[Navitia API] + H --> K[run_download] + K --> L[Downloader] + H --> M[ntfs::read] + end + + subgraph Background Updater + N[start_background_updater] --> O[check_and_update] + O --> P[get_remote_version] + O --> Q{Version Newer?} + Q -->|Yes| R[run_download] + Q -->|No| S[Return false] + R --> T[ntfs::read] + T --> U[Atomic Model Swap] + U --> V[Version Update] + end + + subgraph External Components + J -->|HTTP GET| W((Navitia Server)) + L -->|Storage| X[(S3/File System)] + M --> Y[NTFS Parser] + end + + subgraph Concurrency + E -.->|RwLock| Z[Thread-safe Reads] + E -.->|Write Lock| AA[Atomic Updates] + F -.->|Mutex| AB[Version Safety] + N --> AC[Tokio Spawn] + end + + style A fill:#f9f,stroke:#333 + style B fill:#ccf,stroke:#333 + style C fill:#ccf,stroke:#333 + style D fill:#cfc,stroke:#333 + style J fill:#fcc,stroke:#333 + style L fill:#fcc,stroke:#333 + style W fill:#cff,stroke:#333 + style X fill:#cff,stroke:#333 + style Y fill:#cff,stroke:#333 +``` \ No newline at end of file diff --git a/src/downloadable_model.rs b/src/downloadable_model.rs new file mode 100644 index 000000000..4f3a4cb29 --- /dev/null +++ b/src/downloadable_model.rs @@ -0,0 +1,517 @@ +// Copyright (C) 2025 Hove and/or its affiliates. +// +// This program is free software: you can redistribute it and/or modify it +// under the terms of the GNU Affero General Public License as published by the +// Free Software Foundation, version 3. 
+ +// This program is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +// details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see + +//! A module for managing downloadable transit models with Navitia integration +//! +//! This system provides: +//! 1. Automatic updates from Navitia-powered transit feeds +//! 2. Secure authentication for Navitia API +//! 3. Thread-safe model access +//! 4. Configurable update intervals +//! 5. Pluggable download implementations + +use crate::ntfs; +use pyo3::prelude::*; +use reqwest::header; +use serde::Deserialize; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tokio::sync::RwLock; + +/// Configuration for model management +#[derive(Clone, Deserialize)] +#[pyclass] +pub struct ModelConfig { + /// Update check interval in seconds + pub check_interval_secs: u64, + /// Local path for storing downloaded models + pub path: String, +} + +#[pymethods] +impl ModelConfig { + /// Creates a new model configuration + /// + /// # Arguments + /// * `check_interval_secs` - Update check interval in seconds + /// * `path` - Local path for storing downloaded models + /// + /// # Returns + /// A new model configuration + #[new] + pub fn new(check_interval_secs: u64, path: &str) -> Self { + Self { + check_interval_secs, + path: path.into(), + } + } +} + +/// Configuration for Navitia API connection +#[pyclass] +#[derive(Clone, Deserialize)] +pub struct NavitiaConfig { + /// Base URL of Navitia API + pub navitia_url: String, + /// Coverage area identifier + pub coverage: String, + /// Authentication token for Navitia + pub navitia_token: String, +} + +#[pymethods] +impl NavitiaConfig { + /// Creates a new 
Navitia API configuration + /// + /// # Arguments + /// * `navitia_url` - Base URL of Navitia API + /// * `coverage` - Coverage area identifier + /// * `navitia_token` - Authentication token for Navitia + /// + /// # Returns + /// A new Navitia API configuration + #[new] + pub fn new(navitia_url: &str, coverage: &str, navitia_token: &str) -> Self { + Self { + navitia_url: navitia_url.into(), + coverage: coverage.into(), + navitia_token: navitia_token.into(), + } + } +} + +type DownloadResult = + Pin>> + Send>>; + +/// Trait defining model download functionality +pub trait Downloader: Send + Sync + 'static { + /// Downloads a specific model version + /// + /// # Arguments + /// * `config` - Model configuration + /// * `version` - Target version identifier + /// + /// # Returns + /// Future resolving to local path of downloaded model + fn run_download(&self, config: &ModelConfig, version: &str) -> DownloadResult; +} + +/// Main structure managing the transit model lifecycle +pub struct DownloadableTransitModel { + /// Thread-safe model access + pub current_model: Arc>, + /// Current model version + version: Arc>, + /// Download implementation + downloader: D, + /// Model configuration + config: ModelConfig, + /// Navitia API configuration + navitia_config: NavitiaConfig, +} + +impl DownloadableTransitModel { + /// Initializes a new model manager + /// + /// # Arguments + /// * `navitia_config` - Navitia API credentials + /// * `config` - Model management settings + /// * `downloader` - Download implementation + /// + /// # Flow + /// 1. Fetch current version from Navitia + /// 2. Download initial model + /// 3. 
Start background updater + pub async fn new( + navitia_config: NavitiaConfig, + config: ModelConfig, + downloader: D, + ) -> Result> { + let (model, version) = + Self::initialize_model(&config, &navitia_config, &downloader).await?; + + let instance = Self { + current_model: Arc::new(RwLock::new(model)), + version: Arc::new(Mutex::new(version)), + downloader, + config, + navitia_config, + }; + + instance.start_background_updater(); + Ok(instance) + } + + /// Fetches the current model + /// + /// # Arguments + /// * `config` - Model configuration + /// * `navitia_config` - Navitia API credentials + /// * `downloader` - Download implementation + /// + /// # Returns + /// A reference to the current model + /// A reference to the current model version path + async fn initialize_model( + config: &ModelConfig, + navitia_config: &NavitiaConfig, + downloader: &D, + ) -> Result<(crate::Model, String), Box> { + let version = Self::get_remote_version(navitia_config).await?; + let folder_saved_at_path = downloader.run_download(config, &version).await?; + let model = ntfs::read(&folder_saved_at_path).map_err(anyhow::Error::from)?; + Ok((model, version)) + } + + /// Starts the background updater + /// + /// # Arguments + /// + /// * `self` - The current instance + /// + /// # Flow + /// + /// 1. Periodically checks for updates + /// 2. Downloads and updates the model if a new version is available + /// 3. 
Logs the update + /// + /// # Note + /// + /// This function runs indefinitely in the background + /// + fn start_background_updater(&self) { + let config = self.config.clone(); + let downloader = self.downloader.clone(); + let model_ref = self.current_model.clone(); + let version_ref = self.version.clone(); + let navitia_config = self.navitia_config.clone(); + + tokio::spawn(async move { + let mut interval = + tokio::time::interval(Duration::from_secs(config.check_interval_secs)); + loop { + interval.tick().await; + match Self::check_and_update( + &config, + &downloader, + &model_ref, + &version_ref, + &navitia_config, + ) + .await + { + Ok(updated) => { + if updated { + println!("Updated to version {}", *version_ref.lock().await); + } + } + Err(e) => eprintln!("Background update failed: {}", e), + } + } + }); + } + + /// Checks for updates and updates the model if a new version is available + /// + /// # Arguments + /// + /// * `config` - Model configuration + /// * `downloader` - Download implementation + /// * `model` - The current model + /// * `version` - The current model version + /// * `navitia_config` - Navitia API credentials + /// + /// # Returns + /// + /// A boolean indicating whether the model was updated + async fn check_and_update( + config: &ModelConfig, + downloader: &D, + model: &Arc>, + version: &Mutex, + navitia_config: &NavitiaConfig, + ) -> Result> { + let remote_version = Self::get_remote_version(navitia_config).await?; + let current_version = version.lock().await.clone(); + + if remote_version > current_version { + // Download and load the model before acquiring the write lock + let saved_to_path = downloader.run_download(config, &remote_version).await?; + let new_model = ntfs::read(&saved_to_path).map_err(anyhow::Error::from)?; + + let mut model_lock = model.write().await; + *model_lock = new_model; + + // Update version + let mut version_lock = version.lock().await; + *version_lock = remote_version; + + Ok(true) + } else { + Ok(false) + 
} + } + + /// Fetches the current version from Navitia + /// + /// # Arguments + /// + /// * `config` - Navitia API configuration + /// + /// # Returns + /// + /// The current version identifier + async fn get_remote_version( + config: &NavitiaConfig, + ) -> Result> { + #[derive(Deserialize)] + struct Status { + dataset_created_at: String, + } + + #[derive(Deserialize)] + struct StatusResponse { + status: Status, + } + + let url = format!("{}/coverage/{}/status", config.navitia_url, config.coverage); + let client = reqwest::Client::new(); + let response = client + .get(&url) + .header( + header::AUTHORIZATION, + format!("Bearer {}", config.navitia_token), + ) + .send() + .await?; + + if !response.status().is_success() { + return Err(format!("Failed to fetch status: {}", response.status()).into()); + } + + let res = response.json::().await?; + Ok(res.status.dataset_created_at) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::mock; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tokio::time::{self, Duration}; + + #[derive(Clone)] + struct MockDownloader; + + impl Downloader for MockDownloader { + fn run_download(&self, _config: &ModelConfig, _version: &str) -> DownloadResult { + Box::pin(async move { + time::sleep(Duration::from_secs(1)).await; + Ok("tests/fixtures/minimal_ntfs/".into()) + }) + } + } + + fn create_test_config() -> (ModelConfig, NavitiaConfig) { + ( + ModelConfig { + check_interval_secs: 1, + path: "test_path".into(), + }, + NavitiaConfig { + navitia_url: mockito::server_url(), + coverage: "test_coverage".into(), + navitia_token: "test_token".into(), + }, + ) + } + + async fn create_mock_navitia_response(version: &str) -> mockito::Mock { + mock("GET", "/coverage/test_coverage/status") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(format!( + r#"{{"status": {{"dataset_created_at": "{}"}}}}"#, + version + )) + .create() + } + + #[tokio::test] + async fn test_initialization() { + let (config, 
navitia_config) = create_test_config(); + let _m = create_mock_navitia_response("1.0.0").await; + + let model = DownloadableTransitModel::::new( + navitia_config, + config, + MockDownloader {}, + ) + .await + .unwrap(); + let version = model.version.lock().await; + assert_eq!(*version, "1.0.0"); + } + + #[tokio::test] + async fn test_background_update() { + let (config, navitia_config) = create_test_config(); + let _m = create_mock_navitia_response("1.0.0").await; + + let model = DownloadableTransitModel::::new( + navitia_config.clone(), + config.clone(), + MockDownloader {}, + ) + .await + .unwrap(); + let version = model.version.clone(); + let model_ref = model.current_model.clone(); + + // update to 1.0.1 version + let _m = create_mock_navitia_response("1.0.1").await; + + // Force an update + let updated = DownloadableTransitModel::::check_and_update( + &config, + &MockDownloader {}, + &model_ref, + &version, + &navitia_config, + ) + .await + .unwrap(); + assert!(updated); + + let version = model.version.lock().await; + assert_eq!(*version, "1.0.1"); + } + + #[tokio::test] + async fn test_no_update_when_version_same() { + let _m = create_mock_navitia_response("1.0.0").await; + let (model_config, navitia_config) = create_test_config(); + let downloader = MockDownloader {}; + + let model_manager = DownloadableTransitModel::new( + navitia_config.clone(), + model_config.clone(), + downloader.clone(), + ) + .await + .unwrap(); + + let updated = DownloadableTransitModel::check_and_update( + &model_config, + &downloader, + &model_manager.current_model, + &model_manager.version, + &navitia_config, + ) + .await + .unwrap(); + + assert!(!updated); + } + + #[tokio::test] + async fn test_concurrent_access_during_update() { + let _m = create_mock_navitia_response("1.0.0").await; + let (model_config, navitia_config) = create_test_config(); + let downloader = MockDownloader {}; + + let model_manager = DownloadableTransitModel::new(navitia_config, model_config, downloader) + 
.await + .unwrap(); + + let handle = model_manager.current_model.clone(); + let (tx, rx) = tokio::sync::oneshot::channel(); + + // Spawn a reader that holds the lock + tokio::spawn(async move { + let _guard = handle.read().await; + let lines = _guard + .lines + .get_idx("M1") + .iter() + .map(|idx| _guard.lines[*idx].name.to_string()) + .collect::>(); + assert_eq!(lines, vec!["Metro 1".to_string(),]); + tx.send(()).unwrap(); + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for read lock acquisition + rx.await.unwrap(); + + // Attempt update + let update_handle = model_manager.current_model.clone(); + tokio::spawn(async move { + let _guard = update_handle.write().await; + }); + } + + #[tokio::test] + async fn test_atomic_update() { + let version_counter = Arc::new(AtomicUsize::new(0)); + let (model_config, navitia_config) = create_test_config(); + + #[derive(Clone)] + struct CountingDownloader { + counter: Arc, + } + + impl Downloader for CountingDownloader { + fn run_download(&self, _: &ModelConfig, _: &str) -> DownloadResult { + let counter = self.counter.clone(); + Box::pin(async move { + counter.fetch_add(1, Ordering::SeqCst); + Ok("tests/fixtures/minimal_ntfs/".into()) + }) + } + } + + let downloader = CountingDownloader { + counter: version_counter.clone(), + }; + + let _m = create_mock_navitia_response("1.0.0").await; + let model_manager = DownloadableTransitModel::new( + navitia_config.clone(), + model_config.clone(), + downloader.clone(), + ) + .await + .unwrap(); + + // Trigger multiple update checks + for _ in 0..5 { + let _ = DownloadableTransitModel::check_and_update( + &model_config, + &downloader, + &model_manager.current_model, + &model_manager.version, + &navitia_config, + ) + .await; + } + + assert_eq!(version_counter.load(Ordering::SeqCst), 1); + } +} diff --git a/src/file_handler/mod.rs b/src/file_handler/mod.rs index 0c825f165..eb3ac3773 100644 --- a/src/file_handler/mod.rs +++ b/src/file_handler/mod.rs @@ -45,7 +45,7 @@ 
impl> PathFileHandler

{ } } -impl<'a, P: AsRef> FileHandler for &'a mut PathFileHandler

{ +impl> FileHandler for &mut PathFileHandler

{ type Reader = File; fn get_file_if_exists(self, name: &str) -> Result<(Option, PathBuf)> { let f = self.base_path.as_ref().join(name); diff --git a/src/gtfs/mod.rs b/src/gtfs/mod.rs index 3f89e92ec..2a67734c5 100644 --- a/src/gtfs/mod.rs +++ b/src/gtfs/mod.rs @@ -69,7 +69,7 @@ struct Agency { ticketing_deep_link_id: Option, } -impl<'a> From<&'a objects::Network> for Agency { +impl From<&objects::Network> for Agency { fn from(obj: &objects::Network) -> Agency { Agency { id: Some(obj.id.clone()), @@ -247,7 +247,7 @@ struct BookingRule { booking_url: Option, } -impl<'a> From<&'a objects::BookingRule> for BookingRule { +impl From<&objects::BookingRule> for BookingRule { fn from(obj: &objects::BookingRule) -> BookingRule { BookingRule { id: obj.id.clone(), @@ -285,7 +285,7 @@ struct Transfer { min_transfer_time: Option, } -impl<'a> From<&'a objects::Transfer> for Transfer { +impl From<&objects::Transfer> for Transfer { fn from(obj: &objects::Transfer) -> Transfer { Transfer { from_stop_id: obj.from_stop_id.clone(), diff --git a/src/gtfs/write.rs b/src/gtfs/write.rs index 6ca8461e9..8e7944450 100644 --- a/src/gtfs/write.rs +++ b/src/gtfs/write.rs @@ -375,7 +375,7 @@ where .collect() } -impl<'a> From<&'a objects::PhysicalMode> for RouteType { +impl From<&objects::PhysicalMode> for RouteType { fn from(obj: &objects::PhysicalMode) -> RouteType { match obj.id.as_str() { "RailShuttle" | "Tramway" => RouteType::Tramway, diff --git a/src/lib.rs b/src/lib.rs index b78c53283..7ec76240c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,6 +77,10 @@ pub mod validity_period; mod version_utils; pub mod vptranslator; +/// The `downloadable_model` module is used to manage a transit model that can be downloaded from a remote location. 
+// #[cfg(feature = "downloadable")] +pub mod downloadable_model; + // Good average size for initialization of the `StopTime` collection in `VehicleJourney` // Note: they are shrinked down in `Model::new()` to fit the real size pub(crate) const STOP_TIMES_INIT_CAPACITY: usize = 50; diff --git a/src/model_builder.rs b/src/model_builder.rs index 6606b4f01..156119c41 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -606,7 +606,7 @@ impl AsDateTime for &NaiveDateTime { } } -impl<'a> VehicleJourneyBuilder<'a> { +impl VehicleJourneyBuilder<'_> { fn find_or_create_sp(&mut self, sp: &str) -> Idx { self.model .collections @@ -1039,7 +1039,7 @@ impl<'a> VehicleJourneyBuilder<'a> { } } -impl<'a> Drop for VehicleJourneyBuilder<'a> { +impl Drop for VehicleJourneyBuilder<'_> { fn drop(&mut self) { use std::ops::DerefMut; let collections = &mut self.model.collections; diff --git a/src/objects.rs b/src/objects.rs index 4794afeef..22334e68d 100644 --- a/src/objects.rs +++ b/src/objects.rs @@ -924,7 +924,9 @@ impl<'de> ::serde::Deserialize<'de> for Time { // using the visitor pattern to avoid a string allocation struct TimeVisitor; - impl<'de> Visitor<'de> for TimeVisitor { + + // Use anonymous lifetime for Visitor implementation + impl Visitor<'_> for TimeVisitor { type Value = Time; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("a time in the format HH:MM:SS") diff --git a/transit_model_python/Cargo.toml b/transit_model_python/Cargo.toml index f2e30b6df..123929f81 100644 --- a/transit_model_python/Cargo.toml +++ b/transit_model_python/Cargo.toml @@ -14,7 +14,9 @@ name = "transit_model_python" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.23.3" -transit_model = { path = "../" } +pyo3 = { version = "0.23.3", features = ["extension-module"] } +transit_model = { path = "../", features = ["downloadable"] } typed_index_collection = { git = "https://github.com/hove-io/typed_index_collection", tag = "v2" } relational_types 
= { git = "https://github.com/hove-io/relational_types", tag = "v2" } +tokio = { version = "1.0", features = ["full"] } + diff --git a/transit_model_python/example.text b/transit_model_python/example.text new file mode 100644 index 000000000..0eae08211 --- /dev/null +++ b/transit_model_python/example.text @@ -0,0 +1,40 @@ +# from transit_model_python import PythonTransitModel, PythonDownloadableModel, NavitiaConfig, ModelConfig, PyDownloader + +# print("Hello World! from python_test.py") +# model = PythonTransitModel("../tests/fixtures/minimal_ntfs/") +# lines = model.get_lines("B42") +# print(lines) # Output: ["Metro 1"] + + +# use downloadable model + +# class NTFSDownloader(PyDownloader): +# def __new__(cls): +# # Create instance by passing the class as 'obj' to parent __new__ +# return super().__new__(cls, cls) + +# def run_download(self, config, version): +# print("Downloading NTFS data") +# # Download NTFS with any method you want and return the path to the downloaded file +# return "../tests/fixtures/minimal_ntfs/" + + +# config = ModelConfig( +# check_interval_secs=3600, +# path="./models" +# ) + +# navitia_config = NavitiaConfig( +# navitia_url="", +# navitia_token="", +# coverage="" +# ) + +# model = PythonDownloadableModel( +# navitia_config=navitia_config, +# model_config=config, +# downloader=NTFSDownloader() +# ) + + +# model.get_line("B42") diff --git a/transit_model_python/pyproject.toml b/transit_model_python/pyproject.toml index d4a14c0a5..c7dc8cd74 100644 --- a/transit_model_python/pyproject.toml +++ b/transit_model_python/pyproject.toml @@ -21,5 +21,5 @@ authors = [ ] maintainers = [ - {name = "Prince Merveil ONDONDA",email = "pondonda@gmail.com"} + {name = "Prince Merveil ONDONDA",email = "pondonda@hove.com"} ] diff --git a/transit_model_python/python_test.py b/transit_model_python/python_test.py deleted file mode 100644 index c4cce7eb8..000000000 --- a/transit_model_python/python_test.py +++ /dev/null @@ -1,5 +0,0 @@ -from transit_model_python 
import PythonTransitModel - -model = PythonTransitModel("../tests/fixtures/minimal_ntfs/") -lines = model.get_lines("M1") -print(lines) # Output: ["Metro 1"] \ No newline at end of file diff --git a/transit_model_python/src/lib.rs b/transit_model_python/src/lib.rs index 36807b1a0..e31cbd9f9 100644 --- a/transit_model_python/src/lib.rs +++ b/transit_model_python/src/lib.rs @@ -12,251 +12,24 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see -use pyo3::{exceptions::PyValueError, prelude::*}; -use std::sync::Arc; -use transit_model as intern_transit_model; -use transit_model::{model::Model, objects::StopTime}; -use typed_index_collection::Id; +//! PyO3 bindings for transit model data structures and operations +//! +//! Provides Python interfaces to access and manipulate transit model data +//! from NTFS (NeTEx France Standard) datasets. -#[pyclass] -pub struct PythonTransitModel { - model: Arc, -} - -#[pymethods] -impl PythonTransitModel { - /// Create a new PythonTransitModel - /// - /// # Arguments - /// * `path` - The path to the NTFS file - /// - /// # Returns - /// * A new PythonTransitModel - /// - #[new] - pub fn new(path: &str) -> Self { - let transit_objects = - intern_transit_model::ntfs::read(path).expect("Failed to read transit objects"); - Self { - model: Arc::new(transit_objects), - } - } - - /// Get the name of the stop - /// - /// # Arguments - /// - /// * `idx` - The index of the stop - /// - /// # Returns - /// - /// * The name of the line - pub fn get_lines(&self, idx: String) -> PyResult> { - Ok(self - .model - .lines - .get_idx(&idx) - .iter() - .map(|idx| self.model.lines[*idx].name.to_string()) - .collect()) - } - - /// Get the contributors providing the data for the stop - /// - /// # Arguments - /// - /// * `idx` - The index of the stop - /// - /// # Returns - /// - /// * The list of contributors providing the data for the stop - pub fn get_contributors(&self, idx: String) -> 
PyResult> { - Ok(self - .model - .contributors - .get_idx(&idx) - .iter() - .map(|idx| self.model.contributors[*idx].name.to_string()) - .collect()) - } - - /// Get the networks the stop belongs to - /// - /// # Arguments - /// - /// * `idx` - The index of the stop - /// - /// # Returns - /// - /// * The list of networks the stop belongs to - pub fn get_networks(&self, idx: String) -> PyResult> { - Ok(self - .model - .networks - .get_idx(&idx) - .iter() - .map(|idx| self.model.networks[*idx].id().to_string()) - .collect()) - } - - /// Get the name of the stop - /// - /// # Arguments - /// - /// * `idx` - The index of the stop - /// - /// # Returns - /// - /// * The name of the stop - pub fn get_stop_area_by_id(&self, idx: String) -> PyResult { - Ok(self - .model - .stop_areas - .get_idx(&idx) - .iter() - .map(|idx| self.model.stop_areas[*idx].name.clone()) - .collect()) - } - - /// Get the vehicule journey by id - /// - /// # Arguments - /// - /// * `idx` - The index of the vehicule journey - /// - /// # Returns - /// - /// * The vehicule journey id - pub fn get_vehicule_journey_by_id(&self, idx: String) -> PyResult { - Ok(self - .model - .vehicle_journeys - .get_idx(&idx) - .iter() - .map(|idx| self.model.vehicle_journeys[*idx].id.as_str()) - .collect()) - } - - /// Get the vehicule journey stop times - /// - /// # Arguments - /// - /// * `idx` - The index of the vehicule journey - /// - /// # Returns - /// - /// * The vehicule journey stop times - pub fn get_vehicule_journey_stop_times(&self, idx: String) -> PyResult> { - Ok(self - .model - .vehicle_journeys - .get_idx(&idx) - .iter() - .flat_map(|idx| self.model.vehicle_journeys[*idx].stop_times.iter().cloned()) - .collect()) - } +pub mod modules; - /// Get the vehicule journey stop times by vehicule journey id and stop id - /// - /// # Arguments - /// - /// * `vehicule_id` - The index of the vehicule journey - /// * `stop_id` - The index of the stop - /// - /// # Returns - /// - /// * The vehicule journey stop 
times - pub fn get_vehicule_journey_stop_times_by_stop_id( - &self, - vehicule_id: String, - stop_id: String, - ) -> PyResult> { - let stop_point = match self.model.stop_points.get_idx(&stop_id) { - Some(idx) => Some(idx), - None => None, - }; - if stop_point.is_none() { - return Err(PyValueError::new_err("StopPoint not found")); - } - let stop_times: Vec = self - .model - .vehicle_journeys - .get_idx(&vehicule_id) - .iter() - .flat_map(|idx| { - self.model.vehicle_journeys[*idx] - .stop_times - .iter() - .filter(|st| st.stop_point_idx == stop_point.unwrap()) - .cloned() - }) - .collect(); - - if stop_times.is_empty() { - Err(PyValueError::new_err("StopTime not found")) - } else { - Ok(stop_times) - } - } -} +use modules::downloadable_model::{PyDownloader, PythonDownloadableModel}; +use modules::python_transit_model::PythonTransitModel; +use pyo3::prelude::*; +use transit_model::downloadable_model::{ModelConfig, NavitiaConfig}; #[pymodule] fn transit_model_python(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_get_lines() { - let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") - .expect("Failed to read transit objects"); - let python_transit_model = PythonTransitModel { - model: Arc::new(transit_objects), - }; - let lines = python_transit_model.get_lines("M1".to_string()).unwrap(); - assert_eq!(lines, vec!["Metro 1".to_string(),]); - } - - #[test] - fn test_get_contributors() { - let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") - .expect("Failed to read transit objects"); - let python_transit_model = PythonTransitModel { - model: Arc::new(transit_objects), - }; - let contributors = python_transit_model - .get_contributors("TGC".to_string()) - .unwrap(); - assert_eq!(contributors, vec!["The Great 
Contributor".to_string(),]); - } - - #[test] - fn test_get_networks() { - let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") - .expect("Failed to read transit objects"); - let python_transit_model = PythonTransitModel { - model: Arc::new(transit_objects), - }; - let networks: Vec = python_transit_model - .get_networks("TGN".to_string()) - .unwrap(); - assert_eq!(networks, vec!["TGN".to_string(),]); - } - - #[test] - fn test_get_stop_area_by_id() { - let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") - .expect("Failed to read transit objects"); - let python_transit_model = PythonTransitModel { - model: Arc::new(transit_objects), - }; - let stop_area = python_transit_model - .get_stop_area_by_id("GDL".to_string()) - .unwrap(); - assert_eq!(stop_area, "Gare de Lyon".to_string()); - } -} diff --git a/transit_model_python/src/modules/downloadable_model.rs b/transit_model_python/src/modules/downloadable_model.rs new file mode 100644 index 000000000..d54ad42f9 --- /dev/null +++ b/transit_model_python/src/modules/downloadable_model.rs @@ -0,0 +1,292 @@ +// Copyright (C) 2024 Hove and/or its affiliates. +// +// This program is free software: you can redistribute it and/or modify it +// under the terms of the GNU Affero General Public License as published by the +// Free Software Foundation, version 3. + +// This program is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +// details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see +//! 
PyO3 bindings for transit model components to enable Python interoperability + +use pyo3::{exceptions::PyValueError, prelude::*}; +use std::{error::Error, future::Future, pin::Pin}; +use tokio::runtime::Runtime; +use transit_model::{ + downloadable_model::{DownloadableTransitModel, Downloader, ModelConfig, NavitiaConfig}, + objects::StopTime, +}; + +/// Python wrapper for a downloader component implementing the Downloader trait +/// +/// Acts as a bridge between Rust's Downloader trait and Python implementations, +/// allowing Python classes to provide download functionality to Rust code +#[pyclass(subclass)] +pub struct PyDownloader { + /// The underlying Python object implementing the download logic + inner: Py, +} + +impl Clone for PyDownloader { + /// Creates a new reference to the same Python downloader object + /// + /// Uses Python's GIL to safely clone the Python object reference + fn clone(&self) -> Self { + Python::with_gil(|py| PyDownloader { + inner: self.inner.clone_ref(py), + }) + } +} + +#[pymethods] +impl PyDownloader { + /// Creates a new PyDownloader wrapping a Python object + /// + /// # Arguments + /// * `obj` - Python object implementing the `run_download` method + #[new] + fn new(obj: Py) -> Self { + PyDownloader { inner: obj } + } +} + +impl Downloader for PyDownloader { + /// Executes the download operation by calling into Python implementation + /// + /// # Arguments + /// * `config` - Model configuration parameters + /// * `version` - Target version identifier for download + /// + /// # Returns + /// Future resolving to local path of downloaded model or error + fn run_download( + &self, + config: &ModelConfig, + version: &str, + ) -> Pin>> + Send>> { + let model = Python::with_gil(|py| self.inner.clone_ref(py)); + let version = version.to_string(); + let config = config.clone(); + + Box::pin(async move { + Python::with_gil(|py| { + model + .bind(py) + .call_method("run_download", (config, version), None) + .map_err(|e| Box::new(e) as Box) 
+ .and_then(|result| { + result + .extract() + .map_err(|e| Box::new(e) as Box) + }) + }) + }) + } +} + +/// Python-exposed interface for interacting with downloadable transit models +/// +/// Provides thread-safe access to transit model data with async-aware locking +#[pyclass] +pub struct PythonDownloadableModel { + /// The underlying Rust implementation of the downloadable transit model + model: DownloadableTransitModel, +} + +#[pymethods] +impl PythonDownloadableModel { + /// Initializes a new downloadable transit model instance + /// + /// # Arguments + /// * `navitia_config` - Configuration for Navitia integration + /// * `model_config` - General model configuration parameters + /// * `downloader` - Downloader component implementing the download logic + /// + /// # Errors + /// Returns `PyValueError` if initialization fails at any stage + #[new] + pub fn new( + navitia_config: NavitiaConfig, + model_config: ModelConfig, + downloader: Py, + ) -> PyResult { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = rt + .block_on(async { + let rust_downloader = Python::with_gil(|py| match downloader.bind(py).extract() { + Ok(downloader) => Ok(downloader), + Err(e) => Err(PyValueError::new_err(format!( + "Failed to extract downloader: {}", + e + ))), + })?; + + DownloadableTransitModel::new(navitia_config, model_config, rust_downloader).await + }) + .map_err(|e| PyValueError::new_err(format!("Failed to create model: {}", e)))?; + + Ok(Self { model }) + } + + /// Retrieves transit lines associated with a given stop identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the stop + /// + /// # Returns + /// List of line names servicing the specified stop + pub fn get_lines(&self, idx: String) -> PyResult> { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let lines = rt.block_on(async { + 
let guard = model.current_model.read().await; + guard + .lines + .get_idx(&idx) + .iter() + .map(|index| guard.lines[*index].name.clone()) + .collect::>() + }); + + Ok(lines) + } + + /// Retrieves contributors associated with a given stop identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the stop + /// + /// # Returns + /// List of contributor names providing data for the specified stop + pub fn get_contributors(&self, idx: String) -> PyResult> { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let contributors = rt.block_on(async { + let guard = model.current_model.read().await; + guard + .contributors + .get_idx(&idx) + .iter() + .map(|index| guard.contributors[*index].name.clone()) + .collect::>() + }); + + Ok(contributors) + } + + /// Retrieves networks associated with a given stop identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the stop + /// + /// # Returns + /// List of network names the specified stop belongs to + pub fn get_networks(&self, idx: String) -> PyResult> { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let networks = rt.block_on(async { + let guard = model.current_model.read().await; + guard + .networks + .get_idx(&idx) + .iter() + .map(|index| guard.networks[*index].name.clone()) + .collect::>() + }); + + Ok(networks) + } + + /// Retrieves the name of a stop area by its identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the stop area + /// + /// # Returns + /// Name of the specified stop area + pub fn get_stop_area_by_id(&self, idx: String) -> PyResult { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let stop_area = rt.block_on(async { + let guard = model.current_model.read().await; + guard + 
.stop_areas + .get_idx(&idx) + .iter() + .map(|index| guard.stop_areas[*index].name.clone()) + .collect::() + }); + + Ok(stop_area) + } + + /// Retrieves a vehicle journey by its identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the vehicle journey + /// + /// # Returns + /// ID of the specified vehicle journey + pub fn get_vehicle_journey_by_id(&self, idx: String) -> PyResult { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let vehicle_journey = rt.block_on(async { + let guard = model.current_model.read().await; + guard + .vehicle_journeys + .get_idx(&idx) + .iter() + .map(|index| guard.vehicle_journeys[*index].id.clone()) + .collect::() + }); + + Ok(vehicle_journey) + } + + /// Retrieves stop times for a specific vehicle journey + /// + /// # Arguments + /// * `idx` - Unique identifier for the vehicle journey + /// + /// # Returns + /// List of stop times associated with the specified vehicle journey + pub fn get_vehicle_journey_stop_times(&self, idx: String) -> PyResult> { + let rt = Runtime::new() + .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {}", e)))?; + + let model = &self.model; + + let vehicle_journey_stop_times = rt.block_on(async { + let guard = model.current_model.read().await; + guard + .vehicle_journeys + .get_idx(&idx) + .iter() + .flat_map(|index| guard.vehicle_journeys[*index].stop_times.iter().cloned()) + .collect::>() + }); + + Ok(vehicle_journey_stop_times) + } +} diff --git a/transit_model_python/src/modules/mod.rs b/transit_model_python/src/modules/mod.rs new file mode 100644 index 000000000..213444b1f --- /dev/null +++ b/transit_model_python/src/modules/mod.rs @@ -0,0 +1,2 @@ +pub mod downloadable_model; +pub mod python_transit_model; diff --git a/transit_model_python/src/modules/python_transit_model.rs b/transit_model_python/src/modules/python_transit_model.rs new file mode 100644 index 
000000000..e17c15ab3 --- /dev/null +++ b/transit_model_python/src/modules/python_transit_model.rs @@ -0,0 +1,263 @@ +// Copyright (C) 2024 Hove and/or its affiliates. +// +// This program is free software: you can redistribute it and/or modify it +// under the terms of the GNU Affero General Public License as published by the +// Free Software Foundation, version 3. + +// This program is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +// details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/> + +//! PyO3 bindings for transit model data structures and operations +//! +//! Provides Python interfaces to access and manipulate transit model data +//! from NTFS (Navitia Transit Feed Specification) datasets. +//! + +use pyo3::{exceptions::PyValueError, prelude::*}; +use std::sync::Arc; +use transit_model as intern_transit_model; +use transit_model::{model::Model, objects::StopTime}; +use typed_index_collection::Id; + +/// Thread-safe wrapper for transit model data with Python bindings +/// +/// Provides read-only access to transit model components through +/// atomic reference counting for safe concurrent access. 
+#[pyclass] +pub struct PythonTransitModel { + /// Shared reference to the underlying transit model data + model: Arc, +} + +#[pymethods] +impl PythonTransitModel { + /// Creates a new transit model instance from NTFS data + /// + /// # Arguments + /// * `path` - Path to the NTFS dataset directory + /// + /// # Panics + /// Will panic if the NTFS data cannot be read from the specified path + /// + /// # Example + /// ```python + /// model = PythonTransitModel("/path/to/ntfs/data") + /// ``` + #[new] + pub fn new(path: &str) -> Self { + let transit_objects = + intern_transit_model::ntfs::read(path).expect("Failed to read transit objects"); + Self { + model: Arc::new(transit_objects), + } + } + + /// Retrieves line names by line identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the transit line + /// + /// # Returns + /// Vector of line names matching the identifier (empty if not found) + /// + /// # Example + /// ```python + /// line_names = model.get_lines("line:123") + /// ``` + pub fn get_lines(&self, idx: String) -> PyResult> { + Ok(self + .model + .lines + .get_idx(&idx) + .iter() + .map(|idx| self.model.lines[*idx].name.to_string()) + .collect()) + } + + /// Retrieves contributor names by contributor identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the contributor + /// + /// # Returns + /// Vector of contributor names matching the identifier (empty if not found) + pub fn get_contributors(&self, idx: String) -> PyResult> { + Ok(self + .model + .contributors + .get_idx(&idx) + .iter() + .map(|idx| self.model.contributors[*idx].name.to_string()) + .collect()) + } + + /// Retrieves network identifiers by network identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the network + /// + /// # Returns + /// Vector containing the network ID if found (empty if not found) + /// + /// # Note + /// Returns the same identifier provided as input when exists in the model + pub fn get_networks(&self, idx: 
String) -> PyResult> { + Ok(self + .model + .networks + .get_idx(&idx) + .iter() + .map(|idx| self.model.networks[*idx].id().to_string()) + .collect()) + } + + /// Retrieves stop area name by stop area identifier + /// + /// # Arguments + /// * `idx` - Unique identifier for the stop area + /// + /// # Returns + /// Concatenated string of stop area names (empty if not found) + pub fn get_stop_area_by_id(&self, idx: String) -> PyResult { + Ok(self + .model + .stop_areas + .get_idx(&idx) + .iter() + .map(|idx| self.model.stop_areas[*idx].name.clone()) + .collect()) + } + + /// Retrieves vehicle journey identifier by journey ID + /// + /// # Arguments + /// * `idx` - Unique identifier for the vehicle journey + /// + /// # Returns + /// Concatenated string of journey IDs (empty if not found) + pub fn get_vehicule_journey_by_id(&self, idx: String) -> PyResult { + Ok(self + .model + .vehicle_journeys + .get_idx(&idx) + .iter() + .map(|idx| self.model.vehicle_journeys[*idx].id.as_str()) + .collect()) + } + + /// Retrieves all stop times for a specific vehicle journey + /// + /// # Arguments + /// * `idx` - Unique identifier for the vehicle journey + /// + /// # Returns + /// Vector of StopTime objects for the specified journey (empty if not found) + pub fn get_vehicule_journey_stop_times(&self, idx: String) -> PyResult> { + Ok(self + .model + .vehicle_journeys + .get_idx(&idx) + .iter() + .flat_map(|idx| self.model.vehicle_journeys[*idx].stop_times.iter().cloned()) + .collect()) + } + + /// Filters stop times for a specific vehicle journey and stop point + /// + /// # Arguments + /// * `vehicule_id` - Vehicle journey identifier + /// * `stop_id` - Stop point identifier + /// + /// # Returns + /// Vector of matching StopTime objects + /// + /// # Errors + /// Returns PyValueError if either the stop point or vehicle journey is not found + pub fn get_vehicule_journey_stop_times_by_stop_id( + &self, + vehicule_id: String, + stop_id: String, + ) -> PyResult> { + let 
stop_point_idx = self + .model + .stop_points + .get_idx(&stop_id) + .ok_or_else(|| PyValueError::new_err("StopPoint not found"))?; + + let stop_times: Vec = self + .model + .vehicle_journeys + .get_idx(&vehicule_id) + .into_iter() + .flat_map(|idx| &self.model.vehicle_journeys[idx].stop_times) + .filter(|st| st.stop_point_idx == stop_point_idx) + .cloned() + .collect(); + + if stop_times.is_empty() { + Err(PyValueError::new_err("StopTime not found")) + } else { + Ok(stop_times) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_lines() { + let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") + .expect("Failed to read transit objects"); + let python_transit_model = PythonTransitModel { + model: Arc::new(transit_objects), + }; + let lines = python_transit_model.get_lines("M1".to_string()).unwrap(); + assert_eq!(lines, vec!["Metro 1".to_string(),]); + } + + #[test] + fn test_get_contributors() { + let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") + .expect("Failed to read transit objects"); + let python_transit_model = PythonTransitModel { + model: Arc::new(transit_objects), + }; + let contributors = python_transit_model + .get_contributors("TGC".to_string()) + .unwrap(); + assert_eq!(contributors, vec!["The Great Contributor".to_string(),]); + } + + #[test] + fn test_get_networks() { + let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") + .expect("Failed to read transit objects"); + let python_transit_model = PythonTransitModel { + model: Arc::new(transit_objects), + }; + let networks: Vec = python_transit_model + .get_networks("TGN".to_string()) + .unwrap(); + assert_eq!(networks, vec!["TGN".to_string(),]); + } + + #[test] + fn test_get_stop_area_by_id() { + let transit_objects = intern_transit_model::ntfs::read("../tests/fixtures/minimal_ntfs/") + .expect("Failed to read transit objects"); + let python_transit_model = 
PythonTransitModel { + model: Arc::new(transit_objects), + }; + let stop_area = python_transit_model + .get_stop_area_by_id("GDL".to_string()) + .unwrap(); + assert_eq!(stop_area, "Gare de Lyon".to_string()); + } +}