diff --git a/Cargo.lock b/Cargo.lock index df81e3f4e6..b988893b9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -570,7 +570,7 @@ dependencies = [ "futures-util", "handlebars", "http", - "indexmap 2.2.2", + "indexmap 2.2.5", "mime", "multer", "num-traits 0.2.17", @@ -621,7 +621,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "323a5143f5bdd2030f45e3f2e0c821c9b1d36e79cf382129c64299c50a7f3750" dependencies = [ "bytes", - "indexmap 2.2.2", + "indexmap 2.2.5", "serde", "serde_json", ] @@ -1125,7 +1125,7 @@ dependencies = [ "cairo-vm", "ctor", "derive_more", - "indexmap 2.2.2", + "indexmap 2.2.5", "itertools 0.10.5", "keccak", "log", @@ -2012,7 +2012,7 @@ checksum = "12d0939f42d40fb1d975cae073d7d4f82d83de4ba2149293115525245425f909" dependencies = [ "env_logger", "hashbrown 0.14.3", - "indexmap 2.2.2", + "indexmap 2.2.5", "itertools 0.11.0", "log", "num-bigint", @@ -5329,7 +5329,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.2.2", + "indexmap 2.2.5", "slab", "tokio", "tokio-util", @@ -5929,9 +5929,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.2" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -8764,7 +8764,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.2", + "indexmap 2.2.5", ] [[package]] @@ -9727,7 +9727,7 @@ dependencies = [ "bitflags 2.4.2", "byteorder", "derive_more", - "indexmap 2.2.2", + "indexmap 2.2.5", "libc", "parking_lot 0.12.1", "reth-mdbx-sys", @@ -10606,10 +10606,11 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ + "indexmap 2.2.5", "itoa", "ryu", "serde", @@ -11143,7 +11144,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap 2.2.2", + "indexmap 2.2.5", "log", "memchr", "once_cell", @@ -11615,7 +11616,7 @@ dependencies = [ "cairo-lang-starknet", "derive_more", "hex", - "indexmap 2.2.2", + "indexmap 2.2.5", "once_cell", "primitive-types", "serde", @@ -12186,7 +12187,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.2.2", + "indexmap 2.2.5", "serde", "serde_spanned", "toml_datetime", @@ -12199,7 +12200,7 @@ version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70f427fce4d84c72b5b732388bf4a9f4531b53f74e2887e3ecb2481f68f66d81" dependencies = [ - "indexmap 2.2.2", + "indexmap 2.2.5", "toml_datetime", "winnow", ] @@ -12210,7 +12211,7 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ - "indexmap 2.2.2", + "indexmap 2.2.5", "toml_datetime", "winnow", ] @@ -12221,7 +12222,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c9ffdf896f8daaabf9b66ba8e77ea1ed5ed0f72821b398aba62352e95062951" dependencies = [ - "indexmap 2.2.2", + "indexmap 2.2.5", "serde", "serde_spanned", "toml_datetime", @@ -12554,7 +12555,11 @@ version = "0.6.0-alpha.6" dependencies = [ "anyhow", "async-trait", + "crypto-bigint", + "dojo-types", + "dojo-world", "futures", + "indexmap 2.2.5", "libp2p", "libp2p-webrtc", "libp2p-webrtc-websys", @@ -12562,9 +12567,14 @@ dependencies = [ "regex", "serde", "serde_json", + "sqlx", + "starknet-core 0.9.0", + "starknet-crypto 0.6.1", + "starknet-ff", "tempfile", "thiserror", "tokio", + "torii-core", "tracing", "tracing-subscriber", "tracing-wasm", diff --git a/bin/torii/src/main.rs b/bin/torii/src/main.rs index e5cbe2ac2d..d2dff89607 100644 --- a/bin/torii/src/main.rs +++ b/bin/torii/src/main.rs @@ -145,7 +145,7 @@ async fn main() -> anyhow::Result<()> { // Get world address let world = WorldContractReader::new(args.world_address, &provider); - let mut db = Sql::new(pool.clone(), args.world_address).await?; + let db = Sql::new(pool.clone(), args.world_address).await?; let processors = Processors { event: vec![ Box::new(RegisterModelProcessor), @@ -161,7 +161,7 @@ async fn main() -> anyhow::Result<()> { let mut engine = Engine::new( world, - &mut db, + db.clone(), &provider, processors, EngineConfig { start_block: args.start_block, ..Default::default() }, @@ -179,6 +179,15 @@ async fn main() -> anyhow::Result<()> { ) .await?; + let mut libp2p_relay_server = torii_relay::server::Relay::new( + db, + args.relay_port, + args.relay_webrtc_port, + args.relay_local_key_path, + args.relay_cert_path, + ) + .expect("Failed to start libp2p relay server"); + let proxy_server = Arc::new(Proxy::new(args.addr, args.allowed_origins, Some(grpc_addr), None)); let graphql_server = spawn_rebuilding_graphql_server( @@ -188,14 +197,6 @@ async fn main() -> anyhow::Result<()> { proxy_server.clone(), ); - let mut libp2p_relay_server = torii_relay::server::Relay::new( - args.relay_port, - args.relay_webrtc_port, - args.relay_local_key_path, - args.relay_cert_path, - ) - .expect("Failed to start libp2p relay server"); - let endpoint = format!("http://{}", args.addr); let gql_endpoint = format!("{}/graphql", endpoint); let encoded: String = diff --git a/crates/torii/client/src/client/mod.rs b/crates/torii/client/src/client/mod.rs index 07c96bf77c..aad72ef099 100644 --- a/crates/torii/client/src/client/mod.rs +++ b/crates/torii/client/src/client/mod.rs @@ -10,8 +10,6 @@ use dojo_types::packing::unpack; use dojo_types::schema::Ty; use dojo_types::WorldMetadata; use dojo_world::contracts::WorldContractReader; -use futures::channel::mpsc::UnboundedReceiver; -use futures_util::lock::Mutex; use parking_lot::{RwLock, RwLockReadGuard}; use starknet::core::utils::cairo_short_string_to_felt; use starknet::providers::jsonrpc::HttpTransport; @@ -22,7 +20,7 @@ use torii_grpc::client::{EntityUpdateStreaming, ModelDiffsStreaming}; use torii_grpc::proto::world::RetrieveEntitiesResponse; use torii_grpc::types::schema::Entity; use torii_grpc::types::{KeysClause, Query}; -use torii_relay::client::{EventLoop, Message}; +use torii_relay::types::Message; use crate::client::error::{Error, ParseError}; use crate::client::storage::ModelStorage; @@ -106,41 +104,17 @@ impl Client { self.relay_client.command_sender.wait_for_relay().await.map_err(Error::RelayClient) } - /// Subscribes to a topic. - /// Returns true if the topic was subscribed to. - /// Returns false if the topic was already subscribed to. - pub async fn subscribe_topic(&mut self, topic: String) -> Result { - self.relay_client.command_sender.subscribe(topic).await.map_err(Error::RelayClient) - } - - /// Unsubscribes from a topic. - /// Returns true if the topic was subscribed to. - pub async fn unsubscribe_topic(&mut self, topic: String) -> Result { - self.relay_client.command_sender.unsubscribe(topic).await.map_err(Error::RelayClient) - } - /// Publishes a message to a topic. /// Returns the message id. - pub async fn publish_message(&mut self, topic: &str, message: &[u8]) -> Result, Error> { + pub async fn publish_message(&mut self, message: Message) -> Result, Error> { self.relay_client .command_sender - .publish(topic.to_string(), message.to_vec()) + .publish(message) .await .map_err(Error::RelayClient) .map(|m| m.0) } - /// Returns the event loop of the relay client. - /// Which can then be used to run the relay client - pub fn relay_client_runner(&self) -> Arc> { - self.relay_client.event_loop.clone() - } - - /// Returns the message receiver of the relay client. - pub fn relay_client_stream(&self) -> Arc>> { - self.relay_client.message_receiver.clone() - } - /// Returns a read lock on the World metadata that the client is connected to. pub fn metadata(&self) -> RwLockReadGuard<'_, WorldMetadata> { self.metadata.read() diff --git a/crates/torii/core/src/engine.rs b/crates/torii/core/src/engine.rs index 01e9d82acf..c9825e0b6f 100644 --- a/crates/torii/core/src/engine.rs +++ b/crates/torii/core/src/engine.rs @@ -41,9 +41,9 @@ impl Default for EngineConfig { } } -pub struct Engine<'db, P: Provider + Sync> { +pub struct Engine { world: WorldContractReader

, - db: &'db mut Sql, + db: Sql, provider: Box

, processors: Processors

, config: EngineConfig, @@ -56,10 +56,10 @@ struct UnprocessedEvent { data: Vec, } -impl<'db, P: Provider + Sync> Engine<'db, P> { +impl Engine

{ pub fn new( world: WorldContractReader

, - db: &'db mut Sql, + db: Sql, provider: P, processors: Processors

, config: EngineConfig, @@ -240,7 +240,7 @@ impl<'db, P: Provider + Sync> Engine<'db, P> { async fn process_block(&mut self, block: &BlockWithTxs) -> Result<()> { for processor in &self.processors.block { - processor.process(self.db, self.provider.as_ref(), block).await?; + processor.process(&mut self.db, self.provider.as_ref(), block).await?; } Ok(()) } @@ -255,7 +255,7 @@ impl<'db, P: Provider + Sync> Engine<'db, P> { for processor in &self.processors.transaction { processor .process( - self.db, + &mut self.db, self.provider.as_ref(), block, transaction_receipt, @@ -288,7 +288,7 @@ impl<'db, P: Provider + Sync> Engine<'db, P> { && processor.validate(event) { processor - .process(&self.world, self.db, block, transaction_receipt, event_id, event) + .process(&self.world, &mut self.db, block, transaction_receipt, event_id, event) .await?; } else { let unprocessed_event = UnprocessedEvent { diff --git a/crates/torii/core/src/sql.rs b/crates/torii/core/src/sql.rs index 6289ab2e0a..a0f7d44bb8 100644 --- a/crates/torii/core/src/sql.rs +++ b/crates/torii/core/src/sql.rs @@ -26,7 +26,7 @@ mod test; #[derive(Debug, Clone)] pub struct Sql { world_address: FieldElement, - pool: Pool, + pub pool: Pool, query_queue: QueryQueue, } diff --git a/crates/torii/core/src/sql_test.rs b/crates/torii/core/src/sql_test.rs index 61e21f4bf7..ad6addc46d 100644 --- a/crates/torii/core/src/sql_test.rs +++ b/crates/torii/core/src/sql_test.rs @@ -22,11 +22,11 @@ use crate::sql::Sql; pub async fn bootstrap_engine

( world: WorldContractReader

, - db: &mut Sql, + db: Sql, provider: P, migration: MigrationStrategy, sequencer: TestSequencer, -) -> Result, Box> +) -> Result, Box> where P: Provider + Send + Sync, { @@ -72,7 +72,7 @@ async fn test_load_from_remote() { let world = WorldContractReader::new(migration.world_address().unwrap(), &provider); let mut db = Sql::new(pool.clone(), migration.world_address().unwrap()).await.unwrap(); - let _ = bootstrap_engine(world, &mut db, &provider, migration, sequencer).await; + let _ = bootstrap_engine(world, db.clone(), &provider, migration, sequencer).await; let models = sqlx::query("SELECT * FROM models").fetch_all(&pool).await.unwrap(); assert_eq!(models.len(), 2); diff --git a/crates/torii/graphql/src/tests/mod.rs b/crates/torii/graphql/src/tests/mod.rs index 6fcbde2404..f723cab1e7 100644 --- a/crates/torii/graphql/src/tests/mod.rs +++ b/crates/torii/graphql/src/tests/mod.rs @@ -264,7 +264,7 @@ pub async fn spinup_types_test() -> Result { let target_path = format!("{}/target/dev", base_path); let migration = prepare_migration(base_path.into(), target_path.into()).unwrap(); let config = build_test_config("../types-test/Scarb.toml").unwrap(); - let mut db = Sql::new(pool.clone(), migration.world_address().unwrap()).await.unwrap(); + let db = Sql::new(pool.clone(), migration.world_address().unwrap()).await.unwrap(); let sequencer = TestSequencer::start(SequencerConfig::default(), get_default_test_starknet_config()).await; @@ -316,7 +316,7 @@ pub async fn spinup_types_test() -> Result { let (shutdown_tx, _) = broadcast::channel(1); let mut engine = Engine::new( world, - &mut db, + db, &provider, Processors { event: vec![ diff --git a/crates/torii/grpc/src/server/subscriptions/model_diff.rs b/crates/torii/grpc/src/server/subscriptions/model_diff.rs index ac73d15f6e..2aa15cc1eb 100644 --- a/crates/torii/grpc/src/server/subscriptions/model_diff.rs +++ b/crates/torii/grpc/src/server/subscriptions/model_diff.rs @@ -51,7 +51,7 @@ impl StateDiffManager { &self, reqs: Vec, ) -> Result>, Error> { - let id = rand::thread_rng().gen::(); + let id: usize = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); diff --git a/crates/torii/libp2p/Cargo.toml b/crates/torii/libp2p/Cargo.toml index 51e308975d..0503c861f1 100644 --- a/crates/torii/libp2p/Cargo.toml +++ b/crates/torii/libp2p/Cargo.toml @@ -11,26 +11,36 @@ version.workspace = true futures.workspace = true rand = "0.8.5" serde.workspace = true -serde_json.workspace = true -thiserror.workspace = true -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tracing.workspace = true +# preserve order +anyhow.workspace = true async-trait = "0.1.77" +crypto-bigint.workspace = true +dojo-types.workspace = true regex = "1.10.3" -anyhow.workspace = true +serde_json = { version = "1.0.114", features = [ "preserve_order" ] } +starknet-core = "0.9.0" +starknet-crypto.workspace = true +starknet-ff = "0.3.6" +thiserror.workspace = true +tracing-subscriber = { version = "0.3", features = [ "env-filter" ] } +tracing.workspace = true +indexmap = "2.2.5" [dev-dependencies] +dojo-world = { path = "../../dojo-world", features = [ "metadata" ] } tempfile = "3.9.0" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio.workspace = true libp2p = { git = "https://github.com/libp2p/rust-libp2p", features = [ "ed25519", "gossipsub", "identify", "macros", "noise", "ping", "quic", "relay", "tcp", "tokio", "yamux" ] } -libp2p-webrtc = { git = "https://github.com/libp2p/rust-libp2p", features = [ "tokio", "pem" ] } +libp2p-webrtc = { git = "https://github.com/libp2p/rust-libp2p", features = [ "pem", "tokio" ] } +tokio.workspace = true +torii-core.workspace = true +sqlx.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] libp2p = { git = "https://github.com/libp2p/rust-libp2p", features = [ "ed25519", "gossipsub", "identify", "macros", "ping", "tcp", "wasm-bindgen" ] } libp2p-webrtc-websys = { git = "https://github.com/libp2p/rust-libp2p" } tracing-wasm = "0.2.1" -wasm-bindgen-test = "0.3.40" wasm-bindgen-futures = "0.4.40" +wasm-bindgen-test = "0.3.40" wasm-timer = "0.2.5" diff --git a/crates/torii/libp2p/mocks/example_baseTypes.json b/crates/torii/libp2p/mocks/example_baseTypes.json new file mode 100644 index 0000000000..759c5aae83 --- /dev/null +++ b/crates/torii/libp2p/mocks/example_baseTypes.json @@ -0,0 +1,39 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "n0", "type": "felt" }, + { "name": "n1", "type": "bool" }, + { "name": "n2", "type": "string" }, + { "name": "n3", "type": "selector" }, + { "name": "n4", "type": "u128" }, + { "name": "n5", "type": "ContractAddress" }, + { "name": "n6", "type": "ClassHash" }, + { "name": "n7", "type": "timestamp" }, + { "name": "n8", "type": "shortstring" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "n0": "0x3e8", + "n1": true, + "n2": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", + "n3": "transfer", + "n4": "0x3e8", + "n5": "0x3e8", + "n6": "0x3e8", + "n7": 1000, + "n8": "transfer" + } +} \ No newline at end of file diff --git a/crates/torii/libp2p/mocks/example_enum.json b/crates/torii/libp2p/mocks/example_enum.json new file mode 100644 index 0000000000..c10ae99042 --- /dev/null +++ b/crates/torii/libp2p/mocks/example_enum.json @@ -0,0 +1,28 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [{ "name": "someEnum", "type": "enum", "contains": "MyEnum" }], + "MyEnum": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128,u128*)" }, + { "name": "Variant 3", "type": "(u128)" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "someEnum": { + "Variant 2": [2, [0, 1]] + } + } +} diff --git a/crates/torii/libp2p/mocks/example_presetTypes.json b/crates/torii/libp2p/mocks/example_presetTypes.json new file mode 100644 index 0000000000..f2cc9d7bc5 --- /dev/null +++ b/crates/torii/libp2p/mocks/example_presetTypes.json @@ -0,0 +1,37 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "n0", "type": "TokenAmount" }, + { "name": "n1", "type": "NftId" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "n0": { + "token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + "amount": { + "low": "0x3e8", + "high": "0x0" + } + }, + "n1": { + "collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + "token_id": { + "low": "0x3e8", + "high": "0x0" + } + } + } +} diff --git a/crates/torii/libp2p/mocks/mail_StructArray.json b/crates/torii/libp2p/mocks/mail_StructArray.json new file mode 100644 index 0000000000..6f5b58f31c --- /dev/null +++ b/crates/torii/libp2p/mocks/mail_StructArray.json @@ -0,0 +1,44 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" } + ], + "Person": [ + { "name": "name", "type": "felt" }, + { "name": "wallet", "type": "felt" } + ], + "Post": [ + { "name": "title", "type": "felt" }, + { "name": "content", "type": "felt" } + ], + "Mail": [ + { "name": "from", "type": "Person" }, + { "name": "to", "type": "Person" }, + { "name": "posts_len", "type": "felt" }, + { "name": "posts", "type": "Post*" } + ] + }, + "primaryType": "Mail", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1" + }, + "message": { + "from": { + "name": "Cow", + "wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" + }, + "to": { + "name": "Bob", + "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB" + }, + "posts_len": 2, + "posts": [ + { "title": "Greeting", "content": "Hello, Bob!" }, + { "title": "Farewell", "content": "Goodbye, Bob!" } + ] + } +} diff --git a/crates/torii/libp2p/src/client/mod.rs b/crates/torii/libp2p/src/client/mod.rs index 9d26458917..2efbb9c655 100644 --- a/crates/torii/libp2p/src/client/mod.rs +++ b/crates/torii/libp2p/src/client/mod.rs @@ -5,7 +5,7 @@ use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures::channel::oneshot; use futures::lock::Mutex; use futures::{select, StreamExt}; -use libp2p::gossipsub::{self, IdentTopic, MessageId, TopicHash}; +use libp2p::gossipsub::{self, IdentTopic, MessageId}; use libp2p::swarm::{NetworkBehaviour, Swarm, SwarmEvent}; use libp2p::{identify, identity, ping, Multiaddr, PeerId}; #[cfg(not(target_arch = "wasm32"))] @@ -16,7 +16,7 @@ pub mod events; use crate::client::events::ClientEvent; use crate::constants; use crate::errors::Error; -use crate::types::{ClientMessage, ServerMessage}; +use crate::types::Message; #[derive(NetworkBehaviour)] #[behaviour(out_event = "ClientEvent")] @@ -27,35 +27,18 @@ struct Behaviour { } pub struct RelayClient { - pub message_receiver: Arc>>, pub command_sender: CommandSender, pub event_loop: Arc>, } pub struct EventLoop { swarm: Swarm, - message_sender: UnboundedSender, command_receiver: UnboundedReceiver, } -#[derive(Debug, Clone)] -pub struct Message { - // PeerId of the relay that propagated the message - pub propagation_source: PeerId, - // Peer that published the message - pub source: PeerId, - pub message_id: MessageId, - // Hash of the topic message was published to - pub topic: TopicHash, - // Raw message payload - pub data: Vec, -} - #[derive(Debug)] enum Command { - Subscribe(String, oneshot::Sender>), - Unsubscribe(String, oneshot::Sender>), - Publish(String, Vec, oneshot::Sender>), + Publish(Message, oneshot::Sender>), WaitForRelay(oneshot::Sender>), } @@ -102,12 +85,10 @@ impl RelayClient { info!(target: "torii::relay::client", addr = %relay_addr, "Dialing relay"); swarm.dial(relay_addr.parse::()?)?; - let (message_sender, message_receiver) = futures::channel::mpsc::unbounded(); let (command_sender, command_receiver) = futures::channel::mpsc::unbounded(); Ok(Self { command_sender: CommandSender::new(command_sender), - message_receiver: Arc::new(Mutex::new(message_receiver)), - event_loop: Arc::new(Mutex::new(EventLoop { swarm, message_sender, command_receiver })), + event_loop: Arc::new(Mutex::new(EventLoop { swarm, command_receiver })), }) } @@ -155,12 +136,10 @@ impl RelayClient { info!(target: "torii::relay::client", addr = %relay_addr, "Dialing relay"); swarm.dial(relay_addr.parse::()?)?; - let (message_sender, message_receiver) = futures::channel::mpsc::unbounded(); let (command_sender, command_receiver) = futures::channel::mpsc::unbounded(); Ok(Self { command_sender: CommandSender::new(command_sender), - message_receiver: Arc::new(Mutex::new(message_receiver)), - event_loop: Arc::new(Mutex::new(EventLoop { swarm, message_sender, command_receiver })), + event_loop: Arc::new(Mutex::new(EventLoop { swarm, command_receiver })), }) } } @@ -174,28 +153,10 @@ impl CommandSender { Self { sender } } - pub async fn subscribe(&mut self, room: String) -> Result { + pub async fn publish(&mut self, data: Message) -> Result { let (tx, rx) = oneshot::channel(); - self.sender.unbounded_send(Command::Subscribe(room, tx)).expect("Failed to send command"); - - rx.await.expect("Failed to receive response") - } - - pub async fn unsubscribe(&mut self, room: String) -> Result { - let (tx, rx) = oneshot::channel(); - - self.sender.unbounded_send(Command::Unsubscribe(room, tx)).expect("Failed to send command"); - - rx.await.expect("Failed to receive response") - } - - pub async fn publish(&mut self, topic: String, data: Vec) -> Result { - let (tx, rx) = oneshot::channel(); - - self.sender - .unbounded_send(Command::Publish(topic, data, tx)) - .expect("Failed to send command"); + self.sender.unbounded_send(Command::Publish(data, tx)).expect("Failed to send command"); rx.await.expect("Failed to receive response") } @@ -219,15 +180,9 @@ impl EventLoop { select! { command = self.command_receiver.select_next_some() => { match command { - Command::Subscribe(room, sender) => { - sender.send(self.subscribe(&room)).expect("Failed to send response"); - }, - Command::Unsubscribe(room, sender) => { - sender.send(self.unsubscribe(&room)).expect("Failed to send response"); - }, - Command::Publish(topic, data, sender) => { - sender.send(self.publish(topic, data)).expect("Failed to send response"); - }, + Command::Publish(data, sender) => { + sender.send(self.publish(&data)).expect("Failed to send response"); + } Command::WaitForRelay(sender) => { if is_relay_ready { sender.send(Ok(())).expect("Failed to send response"); @@ -239,37 +194,13 @@ impl EventLoop { }, event = self.swarm.select_next_some() => { match event { - SwarmEvent::Behaviour(event) => { - match event { - // Handle behaviour events. - ClientEvent::Gossipsub(gossipsub::Event::Message { - propagation_source: peer_id, - message_id, - message, - }) => { - // deserialize message payload - let message_payload: ServerMessage = serde_json::from_slice(&message.data) - .expect("Failed to deserialize message"); - - let message = Message { - propagation_source: peer_id, - source: PeerId::from_bytes(&message_payload.peer_id).expect("Failed to parse peer id"), - message_id, - topic: message.topic, - data: message_payload.data, - }; + SwarmEvent::Behaviour(ClientEvent::Gossipsub(gossipsub::Event::Subscribed { topic, .. })) => { + // Handle behaviour events. + info!(target: "torii::relay::client::gossipsub", topic = ?topic, "Relay ready. Received subscription confirmation"); - self.message_sender.unbounded_send(message).expect("Failed to send message"); - } - ClientEvent::Gossipsub(gossipsub::Event::Subscribed { topic, .. }) => { - info!(target: "torii::relay::client::gossipsub", topic = ?topic, "Relay ready. Received subscription confirmation"); - - is_relay_ready = true; - if let Some(tx) = relay_ready_tx.take() { - tx.send(Ok(())).expect("Failed to send response"); - } - } - _ => {} + is_relay_ready = true; + if let Some(tx) = relay_ready_tx.take() { + tx.send(Ok(())).expect("Failed to send response"); } } SwarmEvent::ConnectionClosed { cause: Some(cause), .. } => { @@ -287,23 +218,13 @@ impl EventLoop { } } - fn subscribe(&mut self, room: &str) -> Result { - let topic = IdentTopic::new(room); - self.swarm.behaviour_mut().gossipsub.subscribe(&topic).map_err(Error::SubscriptionError) - } - - fn unsubscribe(&mut self, room: &str) -> Result { - let topic = IdentTopic::new(room); - self.swarm.behaviour_mut().gossipsub.unsubscribe(&topic).map_err(Error::PublishError) - } - - fn publish(&mut self, topic: String, data: Vec) -> Result { + fn publish(&mut self, data: &Message) -> Result { self.swarm .behaviour_mut() .gossipsub .publish( IdentTopic::new(constants::MESSAGING_TOPIC), - serde_json::to_string(&ClientMessage { topic, data }).unwrap(), + serde_json::to_string(data).unwrap(), ) .map_err(Error::PublishError) } diff --git a/crates/torii/libp2p/src/errors.rs b/crates/torii/libp2p/src/errors.rs index b9845817fc..2920e43a6e 100644 --- a/crates/torii/libp2p/src/errors.rs +++ b/crates/torii/libp2p/src/errors.rs @@ -39,4 +39,7 @@ pub enum Error { #[error("Failed to read certificate: {0}")] ReadCertificateError(anyhow::Error), + + #[error("Invalid message provided: {0}")] + InvalidMessageError(String), } diff --git a/crates/torii/libp2p/src/lib.rs b/crates/torii/libp2p/src/lib.rs index db6b58cbeb..1eaf1a17bd 100644 --- a/crates/torii/libp2p/src/lib.rs +++ b/crates/torii/libp2p/src/lib.rs @@ -4,4 +4,5 @@ pub mod errors; #[cfg(not(target_arch = "wasm32"))] pub mod server; mod tests; +pub mod typed_data; pub mod types; diff --git a/crates/torii/libp2p/src/server/mod.rs b/crates/torii/libp2p/src/server/mod.rs index 2d321d1242..0ef61fc8ad 100644 --- a/crates/torii/libp2p/src/server/mod.rs +++ b/crates/torii/libp2p/src/server/mod.rs @@ -2,10 +2,15 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::net::Ipv4Addr; use std::path::Path; +use std::str::FromStr; use std::time::Duration; use std::{fs, io}; +use crypto_bigint::U256; +use dojo_types::primitive::Primitive; +use dojo_types::schema::{Member, Struct, Ty}; use futures::StreamExt; +use indexmap::IndexMap; use libp2p::core::multiaddr::Protocol; use libp2p::core::muxing::StreamMuxerBox; use libp2p::core::Multiaddr; @@ -14,7 +19,11 @@ use libp2p::swarm::{NetworkBehaviour, SwarmEvent}; use libp2p::{identify, identity, noise, ping, relay, tcp, yamux, PeerId, Swarm, Transport}; use libp2p_webrtc as webrtc; use rand::thread_rng; -use tracing::info; +use serde_json::Number; +use starknet_crypto::{poseidon_hash_many, verify}; +use starknet_ff::FieldElement; +use torii_core::sql::Sql; +use tracing::{info, warn}; use webrtc::tokio::Certificate; use crate::constants; @@ -22,8 +31,11 @@ use crate::errors::Error; mod events; +use sqlx::Row; + use crate::server::events::ServerEvent; -use crate::types::{ClientMessage, ServerMessage}; +use crate::typed_data::PrimitiveType; +use crate::types::Message; #[derive(NetworkBehaviour)] #[behaviour(out_event = "ServerEvent")] @@ -36,10 +48,12 @@ pub struct Behaviour { pub struct Relay { swarm: Swarm, + db: Sql, } impl Relay { pub fn new( + pool: Sql, port: u16, port_webrtc: u16, local_key_path: Option, @@ -129,7 +143,7 @@ impl Relay { .subscribe(&IdentTopic::new(constants::MESSAGING_TOPIC)) .unwrap(); - Ok(Self { swarm }) + Ok(Self { swarm, db: pool }) } pub async fn run(&mut self) { @@ -142,45 +156,186 @@ impl Relay { message_id, message, }) => { - // Deserialize message. + // Deserialize typed data. // We shouldn't panic here - let message = serde_json::from_slice::(&message.data); - if let Err(e) = message { - info!( - target: "torii::relay::server::gossipsub", - error = %e, - "Failed to deserialize message" - ); - continue; - } - - let message = message.unwrap(); + let data = match serde_json::from_slice::(&message.data) { + Ok(message) => message, + Err(e) => { + info!( + target: "torii::relay::server::gossipsub", + error = %e, + "Failed to deserialize message" + ); + continue; + } + }; + + let ty = match validate_message(&data.message.message) { + Ok(parsed_message) => parsed_message, + Err(e) => { + info!( + target: "torii::relay::server::gossipsub", + error = %e, + "Failed to validate message" + ); + continue; + } + }; info!( target: "torii::relay::server", message_id = %message_id, peer_id = %peer_id, - topic = %message.topic, - data = %String::from_utf8_lossy(&message.data), + data = ?data, "Received message" ); - // forward message to room - let server_message = - ServerMessage { peer_id: peer_id.to_bytes(), data: message.data }; + // retrieve entity identity from db + let mut pool = match self.db.pool.acquire().await { + Ok(pool) => pool, + Err(e) => { + warn!( + target: "torii::relay::server", + error = %e, + "Failed to acquire pool" + ); + continue; + } + }; + + let keys = match ty_keys(&ty) { + Ok(keys) => keys, + Err(e) => { + warn!( + target: "torii::relay::server", + error = %e, + "Failed to get message model keys" + ); + continue; + } + }; + + // select only identity field, if doesn't exist, empty string + let entity = match sqlx::query("SELECT * FROM ? WHERE id = ?") + .bind(&ty.as_struct().unwrap().name) + .bind(format!("{:#x}", poseidon_hash_many(&keys))) + .fetch_optional(&mut *pool) + .await + { + Ok(entity_identity) => entity_identity, + Err(e) => { + warn!( + target: "torii::relay::server", + error = %e, + "Failed to fetch entity" + ); + continue; + } + }; + + if entity.is_none() { + // we can set the entity without checking identity + if let Err(e) = + self.db.set_entity(ty, &message_id.to_string()).await + { + info!( + target: "torii::relay::server", + error = %e, + "Failed to set message" + ); + continue; + } else { + info!( + target: "torii::relay::server", + message_id = %message_id, + peer_id = %peer_id, + "Message set" + ); + continue; + } + } - if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish( - IdentTopic::new(message.topic), - serde_json::to_string(&server_message) - .expect("Failed to serialize message") - .as_bytes(), + let entity = entity.unwrap(); + let identity = match FieldElement::from_str(&match entity + .try_get::("identity") + { + Ok(identity) => identity, + Err(e) => { + warn!( + target: "torii::relay::server", + error = %e, + "Failed to get identity from model" + ); + continue; + } + }) { + Ok(identity) => identity, + Err(e) => { + warn!( + target: "torii::relay::server", + error = %e, + "Failed to parse identity" + ); + continue; + } + }; + + // TODO: have a nonce in model to check + // against entity nonce and message nonce + // to prevent replay attacks. + + // Verify the signature + let message_hash = if let Ok(message) = data.message.encode(identity) { + message + } else { + info!( + target: "torii::relay::server", + "Failed to encode message" + ); + continue; + }; + + // for the public key used for verification; use identity from model + if let Ok(valid) = verify( + &identity, + &message_hash, + &data.signature_r, + &data.signature_s, ) { + if !valid { + info!( + target: "torii::relay::server", + "Invalid signature" + ); + continue; + } + } else { info!( - target: "torii::relay::server::gossipsub", + target: "torii::relay::server", + "Failed to verify signature" + ); + continue; + } + + if let Err(e) = self + .db + // event id is message id + .set_entity(ty, &message_id.to_string()) + .await + { + info!( + target: "torii::relay::server", error = %e, - "Failed to publish message" + "Failed to set message" ); } + + info!( + target: "torii::relay::server", + message_id = %message_id, + peer_id = %peer_id, + "Message verified and set" + ); } ServerEvent::Gossipsub(gossipsub::Event::Subscribed { peer_id, topic }) => { info!( @@ -233,6 +388,293 @@ impl Relay { } } +fn ty_keys(ty: &Ty) -> Result, Error> { + if let Ty::Struct(s) = &ty { + let mut keys = Vec::new(); + for m in s.keys() { + keys.extend(m.serialize().map_err(|_| { + Error::InvalidMessageError("Failed to serialize model key".to_string()) + })?); + } + Ok(keys) + } else { + Err(Error::InvalidMessageError("Entity is not a struct".to_string())) + } +} + +pub fn parse_ty_to_object(ty: &Ty) -> Result, Error> { + match ty { + Ty::Struct(struct_ty) => { + let mut object = IndexMap::new(); + for member in &struct_ty.children { + let mut member_object = IndexMap::new(); + member_object.insert("key".to_string(), PrimitiveType::Bool(member.key)); + member_object.insert( + "type".to_string(), + PrimitiveType::String(ty_to_string_type(&member.ty)), + ); + member_object.insert("value".to_string(), parse_ty_to_primitive(&member.ty)?); + object.insert(member.name.clone(), PrimitiveType::Object(member_object)); + } + Ok(object) + } + _ => Err(Error::InvalidMessageError("Expected Struct type".to_string())), + } +} + +pub fn ty_to_string_type(ty: &Ty) -> String { + match ty { + Ty::Primitive(primitive) => match primitive { + Primitive::U8(_) => "u8".to_string(), + Primitive::U16(_) => "u16".to_string(), + Primitive::U32(_) => "u32".to_string(), + Primitive::USize(_) => "usize".to_string(), + Primitive::U64(_) => "u64".to_string(), + Primitive::U128(_) => "u128".to_string(), + Primitive::U256(_) => "u256".to_string(), + Primitive::Felt252(_) => "felt".to_string(), + Primitive::ClassHash(_) => "class_hash".to_string(), + Primitive::ContractAddress(_) => "contract_address".to_string(), + Primitive::Bool(_) => "bool".to_string(), + }, + Ty::Struct(_) => "struct".to_string(), + Ty::Tuple(_) => "array".to_string(), + Ty::Enum(_) => "enum".to_string(), + } +} + +pub fn parse_ty_to_primitive(ty: &Ty) -> Result { + match ty { + Ty::Primitive(primitive) => match primitive { + Primitive::U8(value) => { + Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) + } + Primitive::U16(value) => { + Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) + } + Primitive::U32(value) => { + Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) + } + Primitive::USize(value) => { + Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) + } + Primitive::U64(value) => { + Ok(PrimitiveType::Number(Number::from(value.map(|v| v).unwrap_or(0u64)))) + } + Primitive::U128(value) => Ok(PrimitiveType::String( + value.map(|v| v.to_string()).unwrap_or_else(|| "0".to_string()), + )), + Primitive::U256(value) => Ok(PrimitiveType::String( + value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), + )), + Primitive::Felt252(value) => Ok(PrimitiveType::String( + value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), + )), + Primitive::ClassHash(value) => Ok(PrimitiveType::String( + value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), + )), + Primitive::ContractAddress(value) => Ok(PrimitiveType::String( + value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), + )), + Primitive::Bool(value) => Ok(PrimitiveType::Bool(value.unwrap_or(false))), + }, + _ => Err(Error::InvalidMessageError("Expected Primitive type".to_string())), + } +} + +pub fn parse_object_to_ty( + name: String, + object: &IndexMap, +) -> Result { + let mut ty_struct = Struct { name, children: vec![] }; + + for (field_name, value) in object { + // value has to be of type object + let object = if let PrimitiveType::Object(object) = value { + object + } else { + return Err(Error::InvalidMessageError("Value is not an object".to_string())); + }; + + let r#type = if let Some(r#type) = object.get("type") { + if let PrimitiveType::String(r#type) = r#type { + r#type + } else { + return Err(Error::InvalidMessageError("Type is not a string".to_string())); + } + } else { + return Err(Error::InvalidMessageError("Type is missing".to_string())); + }; + + let value = if let Some(value) = object.get("value") { + value + } else { + return Err(Error::InvalidMessageError("Value is missing".to_string())); + }; + + let key = if let Some(key) = object.get("key") { + if let PrimitiveType::Bool(key) = key { + *key + } else { + return Err(Error::InvalidMessageError("Key is not a boolean".to_string())); + } + } else { + return Err(Error::InvalidMessageError("Key is missing".to_string())); + }; + + match value { + PrimitiveType::Object(object) => { + let ty = parse_object_to_ty(field_name.clone(), object)?; + ty_struct.children.push(Member { name: field_name.clone(), ty, key }); + } + PrimitiveType::Array(_) => { + // tuples not supported yet + unimplemented!() + } + PrimitiveType::Number(number) => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: match r#type.as_str() { + "u8" => Ty::Primitive(Primitive::U8(Some(number.as_u64().unwrap() as u8))), + "u16" => { + Ty::Primitive(Primitive::U16(Some(number.as_u64().unwrap() as u16))) + } + "u32" => { + Ty::Primitive(Primitive::U32(Some(number.as_u64().unwrap() as u32))) + } + "usize" => { + Ty::Primitive(Primitive::USize(Some(number.as_u64().unwrap() as u32))) + } + "u64" => Ty::Primitive(Primitive::U64(Some(number.as_u64().unwrap()))), + _ => { + return Err(Error::InvalidMessageError( + "Invalid number type".to_string(), + )); + } + }, + key, + }); + } + PrimitiveType::Bool(boolean) => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::Bool(Some(*boolean))), + key, + }); + } + PrimitiveType::String(string) => match r#type.as_str() { + "u8" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U8(Some(u8::from_str(string).unwrap()))), + key, + }); + } + "u16" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U16(Some(u16::from_str(string).unwrap()))), + key, + }); + } + "u32" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U32(Some(u32::from_str(string).unwrap()))), + key, + }); + } + "usize" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::USize(Some(u32::from_str(string).unwrap()))), + key, + }); + } + "u64" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U64(Some(u64::from_str(string).unwrap()))), + key, + }); + } + "u128" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U128(Some(u128::from_str(string).unwrap()))), + key, + }); + } + "u256" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::U256(Some(U256::from_be_hex(string)))), + key, + }); + } + "felt" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::Felt252(Some( + FieldElement::from_str(string).unwrap(), + ))), + key, + }); + } + "class_hash" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::ClassHash(Some( + FieldElement::from_str(string).unwrap(), + ))), + key, + }); + } + "contract_address" => { + ty_struct.children.push(Member { + name: field_name.clone(), + ty: Ty::Primitive(Primitive::ContractAddress(Some( + FieldElement::from_str(string).unwrap(), + ))), + key, + }); + } + _ => { + return Err(Error::InvalidMessageError("Invalid string type".to_string())); + } + }, + } + } + + Ok(Ty::Struct(ty_struct)) +} + +// Validates the message model +// and returns the identity and signature +fn validate_message(message: &IndexMap) -> Result { + let model_name = if let Some(model_name) = message.get("model") { + if let PrimitiveType::String(model_name) = model_name { + model_name + } else { + return Err(Error::InvalidMessageError("Model name is not a string".to_string())); + } + } else { + return Err(Error::InvalidMessageError("Model name is missing".to_string())); + }; + + let model = if let Some(object) = message.get(model_name) { + if let PrimitiveType::Object(object) = object { + parse_object_to_ty(model_name.clone(), object)? + } else { + return Err(Error::InvalidMessageError("Model is not a struct".to_string())); + } + } else { + return Err(Error::InvalidMessageError("Model is missing".to_string())); + }; + + Ok(model) +} + fn read_or_create_identity(path: &Path) -> anyhow::Result { if path.exists() { let bytes = fs::read(path)?; diff --git a/crates/torii/libp2p/src/tests.rs b/crates/torii/libp2p/src/tests.rs index 225dc15a33..b19a5be1a4 100644 --- a/crates/torii/libp2p/src/tests.rs +++ b/crates/torii/libp2p/src/tests.rs @@ -2,8 +2,6 @@ mod test { use std::error::Error; - use futures::StreamExt; - use crate::client::RelayClient; #[cfg(target_arch = "wasm32")] @@ -15,18 +13,32 @@ mod test { #[cfg(not(target_arch = "wasm32"))] #[tokio::test] async fn test_client_messaging() -> Result<(), Box> { - use std::time::Duration; - + use dojo_types::schema::{Member, Struct, Ty}; + use indexmap::IndexMap; + use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; + use starknet_ff::FieldElement; use tokio::time::sleep; - use tokio::{self, select}; + use torii_core::sql::Sql; - use crate::server::Relay; + use crate::server::{parse_ty_to_object, Relay}; + use crate::typed_data::{Domain, TypedData}; + use crate::types::Message; let _ = tracing_subscriber::fmt() .with_env_filter("torii::relay::client=debug,torii::relay::server=debug") .try_init(); + + // Database + let options = ::from_str("sqlite::memory:") + .unwrap() + .create_if_missing(true); + let pool = SqlitePoolOptions::new().max_connections(5).connect_with(options).await.unwrap(); + sqlx::migrate!("../migrations").run(&pool).await.unwrap(); + + let db = Sql::new(pool.clone(), FieldElement::from_bytes_be(&[0; 32]).unwrap()).await?; + // Initialize the relay server - let mut relay_server: Relay = Relay::new(9900, 9901, None, None)?; + let mut relay_server: Relay = Relay::new(db, 9900, 9901, None, None)?; tokio::spawn(async move { relay_server.run().await; }); @@ -37,27 +49,72 @@ mod test { client.event_loop.lock().await.run().await; }); - client.command_sender.subscribe("mawmaw".to_string()).await?; client.command_sender.wait_for_relay().await?; - client.command_sender.publish("mawmaw".to_string(), "mimi".as_bytes().to_vec()).await?; + let mut data = Struct { name: "Message".to_string(), children: vec![] }; + + data.children.push(Member { + name: "player".to_string(), + ty: dojo_types::schema::Ty::Primitive( + dojo_types::primitive::Primitive::ContractAddress(Some( + FieldElement::from_bytes_be(&[0; 32]).unwrap(), + )), + ), + key: true, + }); - let message_receiver = client.message_receiver.clone(); - let mut message_receiver = message_receiver.lock().await; - - loop { - select! { - event = message_receiver.next() => { - if let Some(message) = event { - println!("Received message from {:?} with id {:?}: {:?}", message.source, message.message_id, message); - return Ok(()); - } - } - _ = sleep(Duration::from_secs(5)) => { - println!("Test Failed: Did not receive message within 5 seconds."); - return Err("Timeout reached without receiving a message".into()); - } - } - } + data.children.push(Member { + name: "message".to_string(), + ty: dojo_types::schema::Ty::Primitive(dojo_types::primitive::Primitive::U8(Some(0))), + key: false, + }); + + let mut typed_data = TypedData::new( + IndexMap::new(), + "Message", + Domain::new("Message", "1", "0x0", Some("1")), + IndexMap::new(), + ); + + typed_data.message.insert( + "model".to_string(), + crate::typed_data::PrimitiveType::String("Message".to_string()), + ); + typed_data.message.insert( + "Message".to_string(), + crate::typed_data::PrimitiveType::Object( + parse_ty_to_object(&Ty::Struct(data.clone())).unwrap(), + ), + ); + + println!("object ty: {:?}", parse_ty_to_object(&Ty::Struct(data)).unwrap()); + + client + .command_sender + .publish(Message { + message: typed_data, + signature_r: FieldElement::from_bytes_be(&[0; 32]).unwrap(), + signature_s: FieldElement::from_bytes_be(&[0; 32]).unwrap(), + }) + .await?; + + sleep(std::time::Duration::from_secs(2)).await; + + Ok(()) + // loop { + // select! { + // entity = sqlx::query("SELECT * FROM entities WHERE id = ?") + // .bind(format!("{:#x}", FieldElement::from_bytes_be(&[0; + // 32]).unwrap())).fetch_one(&pool) => { if let Ok(_) = entity { + // println!("Test OK: Received message within 5 seconds."); + // return Ok(()); + // } + // } + // _ = sleep(Duration::from_secs(5)) => { + // println!("Test Failed: Did not receive message within 5 seconds."); + // return Err("Timeout reached without receiving a message".into()); + // } + // } + // } } #[cfg(target_arch = "wasm32")] diff --git a/crates/torii/libp2p/src/typed_data.rs b/crates/torii/libp2p/src/typed_data.rs new file mode 100644 index 0000000000..4dec9753e9 --- /dev/null +++ b/crates/torii/libp2p/src/typed_data.rs @@ -0,0 +1,678 @@ +use std::str::FromStr; + +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; +use serde_json::Number; +use starknet_core::utils::{ + cairo_short_string_to_felt, get_selector_from_name, CairoShortStringToFeltError, +}; +use starknet_crypto::poseidon_hash_many; +use starknet_ff::FieldElement; + +use crate::errors::Error; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SimpleField { + pub name: String, + pub r#type: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ParentField { + pub name: String, + pub r#type: String, + pub contains: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Field { + ParentType(ParentField), + SimpleType(SimpleField), +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PrimitiveType { + // All of object types. Including preset types + Object(IndexMap), + Array(Vec), + Bool(bool), + // comprehensive representation of + // String, ShortString, Selector and Felt + String(String), + // For JSON numbers. Formed into a Felt + Number(Number), +} + +fn get_preset_types() -> IndexMap> { + let mut types = IndexMap::new(); + + types.insert( + "TokenAmount".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "token_address".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "amount".to_string(), + r#type: "u256".to_string(), + }), + ], + ); + + types.insert( + "NftId".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "collection_address".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "token_id".to_string(), + r#type: "u256".to_string(), + }), + ], + ); + + types.insert( + "u256".to_string(), + vec![ + Field::SimpleType(SimpleField { name: "low".to_string(), r#type: "u128".to_string() }), + Field::SimpleType(SimpleField { name: "high".to_string(), r#type: "u128".to_string() }), + ], + ); + + types +} + +// Get the fields of a specific type +// Looks up both the types hashmap as well as the preset types +// Returns the fields and the hashmap of types +fn get_fields(name: &str, types: &IndexMap>) -> Result, Error> { + if let Some(fields) = types.get(name) { + return Ok(fields.clone()); + } + + Err(Error::InvalidMessageError(format!("Type {} not found", name))) +} + +fn get_dependencies( + name: &str, + types: &IndexMap>, + dependencies: &mut Vec, +) -> Result<(), Error> { + if dependencies.contains(&name.to_string()) { + return Ok(()); + } + + dependencies.push(name.to_string()); + + for field in get_fields(name, types)? { + let mut field_type = match field { + Field::SimpleType(simple_field) => simple_field.r#type.clone(), + Field::ParentType(parent_field) => parent_field.contains.clone(), + }; + + field_type = field_type.trim_end_matches('*').to_string(); + + if types.contains_key(&field_type) && !dependencies.contains(&field_type) { + get_dependencies(&field_type, types, dependencies)?; + } + } + + Ok(()) +} + +pub fn encode_type(name: &str, types: &IndexMap>) -> Result { + let mut type_hash = String::new(); + + // get dependencies + let mut dependencies: Vec = Vec::new(); + get_dependencies(name, types, &mut dependencies)?; + + // sort dependencies + dependencies.sort_by_key(|dep| dep.to_lowercase()); + + for dep in dependencies { + type_hash += &format!("\"{}\"", dep); + + type_hash += "("; + + let fields = get_fields(&dep, types)?; + for (idx, field) in fields.iter().enumerate() { + match field { + Field::SimpleType(simple_field) => { + // if ( at start and ) at end + if simple_field.r#type.starts_with('(') && simple_field.r#type.ends_with(')') { + let inner_types = + &simple_field.r#type[1..simple_field.r#type.len() - 1] + .split(',') + .map(|t| { + if !t.is_empty() { format!("\"{}\"", t) } else { t.to_string() } + }) + .collect::>() + .join(","); + type_hash += &format!("\"{}\":({})", simple_field.name, inner_types); + } else { + type_hash += + &format!("\"{}\":\"{}\"", simple_field.name, simple_field.r#type); + } + } + Field::ParentType(parent_field) => { + type_hash += + &format!("\"{}\":\"{}\"", parent_field.name, parent_field.contains); + } + } + + if idx < fields.len() - 1 { + type_hash += ","; + } + } + + type_hash += ")"; + } + + Ok(type_hash) +} + +fn byte_array_from_string( + target_string: &str, +) -> Result<(Vec, FieldElement, usize), CairoShortStringToFeltError> { + let short_strings: Vec<&str> = split_long_string(target_string); + let remainder = short_strings.last().unwrap_or(&""); + + let mut short_strings_encoded = short_strings + .iter() + .map(|&s| cairo_short_string_to_felt(s)) + .collect::, _>>()?; + + let (pending_word, pending_word_length) = if remainder.is_empty() || remainder.len() == 31 { + (FieldElement::ZERO, 0) + } else { + (short_strings_encoded.pop().unwrap(), remainder.len()) + }; + + Ok((short_strings_encoded, pending_word, pending_word_length)) +} + +fn split_long_string(long_str: &str) -> Vec<&str> { + let mut result = Vec::new(); + + let mut start = 0; + while start < long_str.len() { + let end = (start + 31).min(long_str.len()); + result.push(&long_str[start..end]); + start = end; + } + + result +} + +#[derive(Debug, Default)] +pub struct Ctx { + pub base_type: String, + pub parent_type: String, + pub is_preset: bool, +} + +pub(crate) struct FieldInfo { + _name: String, + r#type: String, + base_type: String, + index: usize, +} + +pub(crate) fn get_value_type( + name: &str, + types: &IndexMap>, +) -> Result { + // iter both "types" and "preset_types" to find the field + for (idx, (key, value)) in types.iter().enumerate() { + if key == name { + return Ok(FieldInfo { + _name: name.to_string(), + r#type: key.clone(), + base_type: "".to_string(), + index: idx, + }); + } + + for (idx, field) in value.iter().enumerate() { + match field { + Field::SimpleType(simple_field) => { + if simple_field.name == name { + return Ok(FieldInfo { + _name: name.to_string(), + r#type: simple_field.r#type.clone(), + base_type: "".to_string(), + index: idx, + }); + } + } + Field::ParentType(parent_field) => { + if parent_field.name == name { + return Ok(FieldInfo { + _name: name.to_string(), + r#type: parent_field.contains.clone(), + base_type: parent_field.r#type.clone(), + index: idx, + }); + } + } + } + } + } + + Err(Error::InvalidMessageError(format!("Field {} not found in types", name))) +} + +fn get_hex(value: &str) -> Result { + if let Ok(felt) = FieldElement::from_str(value) { + Ok(felt) + } else { + // assume its a short string and encode + cairo_short_string_to_felt(value) + .map_err(|_| Error::InvalidMessageError("Invalid short string".to_string())) + } +} + +impl PrimitiveType { + pub fn encode( + &self, + r#type: &str, + types: &IndexMap>, + preset_types: &IndexMap>, + ctx: &mut Ctx, + ) -> Result { + match self { + PrimitiveType::Object(obj) => { + println!("r#type: {}", r#type); + + ctx.is_preset = preset_types.contains_key(r#type); + + let mut hashes = Vec::new(); + + if ctx.base_type == "enum" { + let (variant_name, value) = obj.first().ok_or_else(|| { + Error::InvalidMessageError("Enum value must be populated".to_string()) + })?; + let variant_type = get_value_type(variant_name, types)?; + + let arr: &Vec = match value { + PrimitiveType::Array(arr) => arr, + _ => { + return Err(Error::InvalidMessageError( + "Enum value must be an array".to_string(), + )); + } + }; + + // variant index + hashes.push(FieldElement::from(variant_type.index as u32)); + + // variant parameters + for (idx, param) in arr.iter().enumerate() { + let field_type = &variant_type + .r#type + .trim_start_matches('(') + .trim_end_matches(')') + .split(',') + .nth(idx) + .ok_or_else(|| { + Error::InvalidMessageError("Invalid enum variant type".to_string()) + })?; + + let field_hash = param.encode(field_type, types, preset_types, ctx)?; + hashes.push(field_hash); + } + + return Ok(poseidon_hash_many(hashes.as_slice())); + } + + let type_hash = + encode_type(r#type, if ctx.is_preset { preset_types } else { types })?; + println!("type_hash: {}", type_hash); + hashes.push(get_selector_from_name(&type_hash).map_err(|_| { + Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) + })?); + + for (field_name, value) in obj { + // recheck if we're currently in a preset type + ctx.is_preset = preset_types.contains_key(r#type); + + // pass correct types - preset or types + let field_type = get_value_type( + field_name, + if ctx.is_preset { preset_types } else { types }, + )?; + ctx.base_type = field_type.base_type; + ctx.parent_type = r#type.to_string(); + let field_hash = + value.encode(field_type.r#type.as_str(), types, preset_types, ctx)?; + hashes.push(field_hash); + } + + Ok(poseidon_hash_many(hashes.as_slice())) + } + PrimitiveType::Array(array) => Ok(poseidon_hash_many( + array + .iter() + .map(|x| x.encode(r#type.trim_end_matches('*'), types, preset_types, ctx)) + .collect::, _>>()? + .as_slice(), + )), + PrimitiveType::Bool(boolean) => { + let v = + if *boolean { FieldElement::from(1_u32) } else { FieldElement::from(0_u32) }; + Ok(v) + } + PrimitiveType::String(string) => match r#type { + "shortstring" => get_hex(string), + "string" => { + // split the string into short strings and encode + let byte_array = byte_array_from_string(string).map_err(|_| { + Error::InvalidMessageError("Invalid short string".to_string()) + })?; + + let mut hashes = vec![FieldElement::from(byte_array.0.len())]; + + for hash in byte_array.0 { + hashes.push(hash); + } + + hashes.push(byte_array.1); + hashes.push(FieldElement::from(byte_array.2)); + + Ok(poseidon_hash_many(hashes.as_slice())) + } + "selector" => get_selector_from_name(string).map_err(|_| { + Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) + }), + "felt" => get_hex(string), + "ContractAddress" => get_hex(string), + "ClassHash" => get_hex(string), + "timestamp" => get_hex(string), + "u128" => get_hex(string), + "i128" => get_hex(string), + _ => Err(Error::InvalidMessageError(format!("Invalid type {} for string", r#type))), + }, + PrimitiveType::Number(number) => { + let felt = FieldElement::from_str(&number.to_string()).map_err(|_| { + Error::InvalidMessageError(format!("Invalid number {}", number)) + })?; + Ok(felt) + } + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Domain { + pub name: String, + pub version: String, + #[serde(rename = "chainId")] + pub chain_id: String, + pub revision: Option, +} + +impl Domain { + pub fn new(name: &str, version: &str, chain_id: &str, revision: Option<&str>) -> Self { + Self { + name: name.to_string(), + version: version.to_string(), + chain_id: chain_id.to_string(), + revision: revision.map(|s| s.to_string()), + } + } + + pub fn encode(&self, types: &IndexMap>) -> Result { + let mut object = IndexMap::new(); + + object.insert("name".to_string(), PrimitiveType::String(self.name.clone())); + object.insert("version".to_string(), PrimitiveType::String(self.version.clone())); + object.insert("chainId".to_string(), PrimitiveType::String(self.chain_id.clone())); + if let Some(revision) = &self.revision { + object.insert("revision".to_string(), PrimitiveType::String(revision.clone())); + } + + // we dont need to pass our preset types here. domain should never use a preset type + PrimitiveType::Object(object).encode( + "StarknetDomain", + types, + &IndexMap::new(), + &mut Default::default(), + ) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TypedData { + pub types: IndexMap>, + #[serde(rename = "primaryType")] + pub primary_type: String, + pub domain: Domain, + pub message: IndexMap, +} + +impl TypedData { + pub fn new( + types: IndexMap>, + primary_type: &str, + domain: Domain, + message: IndexMap, + ) -> Self { + Self { types, primary_type: primary_type.to_string(), domain, message } + } + + pub fn encode(&self, account: FieldElement) -> Result { + let preset_types = get_preset_types(); + + if self.domain.revision.clone().unwrap_or("1".to_string()) != "1" { + return Err(Error::InvalidMessageError( + "Legacy revision 0 is not supported".to_string(), + )); + } + + let prefix_message = cairo_short_string_to_felt("StarkNet Message").unwrap(); + + // encode domain separator + let domain_hash = self.domain.encode(&self.types)?; + + // encode message + let message_hash = PrimitiveType::Object(self.message.clone()).encode( + &self.primary_type, + &self.types, + &preset_types, + &mut Default::default(), + )?; + + // return full hash + Ok(poseidon_hash_many(vec![prefix_message, domain_hash, account, message_hash].as_slice())) + } +} + +#[cfg(test)] +mod tests { + use starknet_core::utils::starknet_keccak; + use starknet_ff::FieldElement; + + use super::*; + + #[test] + fn test_read_json() { + // deserialize from json file + let path = "mocks/mail_StructArray.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + println!("{:?}", typed_data); + + let path = "mocks/example_enum.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + println!("{:?}", typed_data); + + let path = "mocks/example_presetTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + println!("{:?}", typed_data); + } + + #[test] + fn test_type_encode() { + let path = "mocks/example_baseTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let encoded = encode_type(&typed_data.primary_type, &typed_data.types).unwrap(); + + assert_eq!( + encoded, + "\"Example\"(\"n0\":\"felt\",\"n1\":\"bool\",\"n2\":\"string\",\"n3\":\"selector\",\"\ + n4\":\"u128\",\"n5\":\"ContractAddress\",\"n6\":\"ClassHash\",\"n7\":\"timestamp\",\"\ + n8\":\"shortstring\")" + ); + + let path = "mocks/mail_StructArray.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let encoded = encode_type(&typed_data.primary_type, &typed_data.types).unwrap(); + + assert_eq!( + encoded, + "\"Mail\"(\"from\":\"Person\",\"to\":\"Person\",\"posts_len\":\"felt\",\"posts\":\"\ + Post*\")\"Person\"(\"name\":\"felt\",\"wallet\":\"felt\")\"Post\"(\"title\":\"felt\",\ + \"content\":\"felt\")" + ); + + let path = "mocks/example_enum.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let encoded = encode_type(&typed_data.primary_type, &typed_data.types).unwrap(); + + assert_eq!( + encoded, + "\"Example\"(\"someEnum\":\"MyEnum\")\"MyEnum\"(\"Variant 1\":(),\"Variant \ + 2\":(\"u128\",\"u128*\"),\"Variant 3\":(\"u128\"))" + ); + + let path = "mocks/example_presetTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let encoded = encode_type(&typed_data.primary_type, &typed_data.types).unwrap(); + + assert_eq!(encoded, "\"Example\"(\"n0\":\"TokenAmount\",\"n1\":\"NftId\")"); + } + + #[test] + fn test_selector_encode() { + let selector = PrimitiveType::String("transfer".to_string()); + let selector_hash = + PrimitiveType::String(starknet_keccak("transfer".as_bytes()).to_string()); + + let types = IndexMap::new(); + let preset_types = get_preset_types(); + + let encoded_selector = + selector.encode("selector", &types, &preset_types, &mut Default::default()).unwrap(); + let raw_encoded_selector = + selector_hash.encode("felt", &types, &preset_types, &mut Default::default()).unwrap(); + + assert_eq!(encoded_selector, raw_encoded_selector); + assert_eq!(encoded_selector, starknet_keccak("transfer".as_bytes())); + } + + #[test] + fn test_domain_hash() { + let path = "mocks/example_baseTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let domain_hash = typed_data.domain.encode(&typed_data.types).unwrap(); + + assert_eq!( + domain_hash, + FieldElement::from_hex_be( + "0x555f72e550b308e50c1a4f8611483a174026c982a9893a05c185eeb85399657" + ) + .unwrap() + ); + } + + #[test] + fn test_message_hash() { + let address = + FieldElement::from_hex_be("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826").unwrap(); + + let path = "mocks/example_baseTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let message_hash = typed_data.encode(address).unwrap(); + + assert_eq!( + message_hash, + FieldElement::from_hex_be( + "0x790d9fa99cf9ad91c515aaff9465fcb1c87784d9cfb27271ed193675cd06f9c" + ) + .unwrap() + ); + + let path = "mocks/example_enum.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let message_hash = typed_data.encode(address).unwrap(); + + assert_eq!( + message_hash, + FieldElement::from_hex_be( + "0x3df10475ad5a8f49db4345a04a5b09164d2e24b09f6e1e236bc1ccd87627cc" + ) + .unwrap() + ); + + let path = "mocks/example_presetTypes.json"; + let file = std::fs::File::open(path).unwrap(); + let reader = std::io::BufReader::new(file); + + let typed_data: TypedData = serde_json::from_reader(reader).unwrap(); + + let message_hash = typed_data.encode(address).unwrap(); + + assert_eq!( + message_hash, + FieldElement::from_hex_be( + "0x26e7b8cedfa63cdbed14e7e51b60ee53ac82bdf26724eb1e3f0710cb8987522" + ) + .unwrap() + ); + } +} diff --git a/crates/torii/libp2p/src/types.rs b/crates/torii/libp2p/src/types.rs index 4b75f26028..a059038d1e 100644 --- a/crates/torii/libp2p/src/types.rs +++ b/crates/torii/libp2p/src/types.rs @@ -1,13 +1,11 @@ use serde::{Deserialize, Serialize}; +use starknet_ff::FieldElement; -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct ClientMessage { - pub topic: String, - pub data: Vec, -} +use crate::typed_data::TypedData; -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct ServerMessage { - pub peer_id: Vec, - pub data: Vec, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub message: TypedData, + pub signature_r: FieldElement, + pub signature_s: FieldElement, }