forked from official-monty/Monty
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild.rs
165 lines (139 loc) · 5.24 KB
/
build.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#[cfg(feature = "embed")]
use sha2::{Digest, Sha256};
#[cfg(feature = "embed")]
use std::fs;
#[cfg(feature = "embed")]
use std::path::Path;
use chrono::Utc;
use std::process::Command;
/// Emits `FORMATTED_NAME` (e.g. `Monty-dev-20240101-abcdef12`) as a compile-time
/// environment variable for the crate, built from today's UTC date and the first
/// 8 characters of the current Git `HEAD` commit hash.
///
/// If `git` cannot be run, exits unsuccessfully (e.g. not inside a Git checkout),
/// or produces unusable output, the hash part falls back to `unknown` instead of
/// aborting the whole build.
fn get_name() {
    // Full commit hash of HEAD; any failure along the way yields an empty string.
    let git_commit_hash = Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|output| output.status.success())
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|hash| hash.trim().to_string())
        .unwrap_or_default();

    // `get` returns None for short/empty output instead of panicking like `[..8]`.
    let short_hash = git_commit_hash.get(..8).unwrap_or("unknown");

    // Current date in YYYYMMDD format.
    let current_date = Utc::now().format("%Y%m%d").to_string();

    let formatted_name = format!("Monty-dev-{}-{}", current_date, short_hash);

    // Cargo exposes this to the crate as `env!("FORMATTED_NAME")`.
    println!("cargo:rustc-env=FORMATTED_NAME={}", formatted_name);
}
#[cfg(feature = "embed")]
/// Build-script entry point when the `embed` feature is enabled.
///
/// Emits the build version name, then makes sure the value and policy network
/// files named in `src/networks/value.rs` / `src/networks/policy.rs` exist locally
/// (downloading them when missing or mismatched), and finally tells Cargo which
/// files should cause this script to rerun.
fn main() {
    get_name();

    // Network file names are read from the constants in the respective sources.
    let value_file_name = extract_network_name("src/networks/value.rs", "ValueFileDefaultName");
    let policy_file_name = extract_network_name("src/networks/policy.rs", "PolicyFileDefaultName");

    // Fixed local paths where the networks are stored.
    let value_path = "value.network";
    let policy_path = "policy.network";

    validate_and_download_network(&value_file_name, value_path);
    validate_and_download_network(&policy_file_name, policy_path);

    // NOTE(review): once any `rerun-if-changed` is emitted, Cargo reruns this script
    // only when these paths change, so FORMATTED_NAME (date + commit hash) can go
    // stale across commits. Consider also tracking the Git HEAD ref if that matters.
    println!("cargo:rerun-if-changed=src/networks/value.rs");
    println!("cargo:rerun-if-changed=src/networks/policy.rs");
    println!("cargo:rerun-if-changed={}", value_path);
    println!("cargo:rerun-if-changed={}", policy_path);
}
#[cfg(not(feature = "embed"))]
/// Build-script entry point when the `embed` feature is disabled: only the
/// `FORMATTED_NAME` build version is emitted and no network files are touched.
fn main() {
    get_name()
}
#[cfg(feature = "embed")]
/// Returns the first double-quoted string assigned to `const_name` in the Rust
/// source at `file_path`.
///
/// Lines mentioning `const_name` are scanned in order; the first one of the shape
/// `... = ... "<value>" ...` wins. For example
/// `pub const ValueFileDefaultName: &str = "nn-0123456789ab.network";`
/// yields `nn-0123456789ab.network`.
///
/// # Panics
/// If the file cannot be read, or no matching line contains a quoted value after `=`.
fn extract_network_name(file_path: &str, const_name: &str) -> String {
    let content = fs::read_to_string(file_path)
        .unwrap_or_else(|err| panic!("Unable to read networks file {}: {}", file_path, err));

    content
        .lines()
        .filter(|line| line.contains(const_name))
        // Split on the first '=' only, so an '=' inside the value is preserved.
        .filter_map(|line| line.split_once('='))
        // The value is the text between the first pair of double quotes.
        .find_map(|(_, value)| value.split('"').nth(1))
        .map(str::to_string)
        .unwrap_or_else(|| {
            panic!(
                "Network name not found or could not be parsed in {}",
                file_path
            )
        })
}
#[cfg(feature = "embed")]
/// Ensures the file at `dest_path` is the network named `expected_name`.
///
/// The expected SHA-256 prefix is taken from `expected_name` via
/// `extract_sha_prefix`. An existing file whose SHA-256 starts with that prefix is
/// kept as-is; otherwise the network is downloaded to `dest_path`, and the
/// downloaded file is hashed again to confirm it matches.
///
/// # Panics
/// If the name is malformed, the download fails, or the downloaded file cannot be
/// hashed or does not match the expected prefix.
fn validate_and_download_network(expected_name: &str, dest_path: &str) {
    let path = Path::new(dest_path);
    let expected_prefix = extract_sha_prefix(expected_name);

    // Reuse an existing file only if its hash matches the expected prefix.
    if path.exists() {
        match calculate_sha256(path) {
            Ok(existing_sha) => {
                println!("Expected SHA-256 prefix: {}", expected_prefix);
                println!("Actual SHA-256: {}", &existing_sha[..12]);
                if existing_sha.starts_with(&expected_prefix) {
                    println!(
                        "File at {} is valid with matching SHA-256 prefix.",
                        dest_path
                    );
                    return; // No need to download
                }
                println!(
                    "File at {} has a mismatching SHA-256 prefix, redownloading...",
                    dest_path
                );
            }
            Err(_) => println!(
                "Failed to calculate SHA-256 for {}, redownloading...",
                dest_path
            ),
        }
    }

    download_network(expected_name, dest_path);

    // Verify the fresh download: a server error page or a truncated transfer
    // would otherwise be accepted silently.
    let downloaded_sha = calculate_sha256(path).unwrap_or_else(|err| {
        panic!(
            "Failed to calculate SHA-256 for downloaded file {}: {}",
            dest_path, err
        )
    });
    if !downloaded_sha.starts_with(&expected_prefix) {
        panic!(
            "Downloaded file {} has SHA-256 {}, expected prefix {}",
            dest_path, downloaded_sha, expected_prefix
        );
    }
}
#[cfg(feature = "embed")]
/// Extracts the 12-hex-digit SHA-256 prefix from a network file name of the form
/// `<tag>-<sha_prefix>...`, e.g. `nn-0123456789abcdef.network` -> `0123456789ab`.
///
/// The result is lowercased so it compares against lowercase hex digests
/// (as produced by `format!("{:x}", ...)`).
///
/// # Panics
/// With `Invalid file name format` if the name does not contain exactly one `-`,
/// or if fewer than 12 hex digits follow it.
fn extract_sha_prefix(file_name: &str) -> String {
    // Exactly one '-' is required; take the first 12 bytes after it, if present.
    // `get` returns None (rather than panicking) when the tail is too short.
    let prefix = match file_name.split_once('-') {
        Some((_, rest)) if !rest.contains('-') => rest.get(..12),
        _ => None,
    };

    match prefix {
        Some(hex) if hex.bytes().all(|b| b.is_ascii_hexdigit()) => hex.to_ascii_lowercase(),
        _ => panic!("Invalid file name format: {}", file_name),
    }
}
#[cfg(feature = "embed")]
/// Hashes the entire file at `path` with SHA-256 and returns the digest as a
/// lowercase hexadecimal string.
///
/// # Errors
/// Propagates any I/O error from opening or reading the file.
fn calculate_sha256(path: &Path) -> Result<String, std::io::Error> {
    let mut hasher = Sha256::new();
    // Stream the file straight into the hasher instead of loading it into memory.
    std::io::copy(&mut fs::File::open(path)?, &mut hasher)?;
    let digest = hasher.finalize();
    Ok(format!("{:x}", digest))
}
#[cfg(feature = "embed")]
/// Downloads the network called `network_name` and writes it to `dest_path`.
///
/// Each source URL is tried in order with `curl`. `-f` (`--fail`) makes HTTP error
/// responses (e.g. 404) count as failures, rather than having the error page
/// written out as the network file. `-S` keeps curl's error text on stderr so it
/// can be reported.
///
/// # Panics
/// If `curl` cannot be executed, every source fails, or the file cannot be written.
fn download_network(network_name: &str, dest_path: &str) {
    let urls = [format!(
        "https://tests.montychess.org/api/nn/{}",
        network_name
    )];
    let mut last_error = String::new();

    for url in &urls {
        let output = Command::new("curl")
            .arg("-fsSL")
            .arg(url)
            .output()
            .expect("Failed to execute curl");

        if output.status.success() {
            fs::write(dest_path, output.stdout).expect("Failed to write network file");
            println!("Downloaded {}", dest_path);
            return;
        }

        // Remember why this source failed so the final panic is actionable.
        last_error = format!(
            "{} ({}): {}",
            url,
            output.status,
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }

    panic!(
        "Failed to download network file from any source. Last error: {}",
        last_error
    );
}