Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scripts to train a downstream k-sparse autoencoder #14

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,011 changes: 1,011 additions & 0 deletions scripts/euclid/anomaly_detection.py

Large diffs are not rendered by default.

147 changes: 147 additions & 0 deletions scripts/euclid/create_datacubes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
import glob
from astropy.io import fits
import numpy as np
from scipy.spatial import cKDTree



# --- Configuration (edit these paths for your environment) ----------------
image_dir = "/path/to/cutouts/"
datacubes_dir = "/path/to/datacubes/"
file_properties = "/path/to/galaxy_catalogue/EuclidMorphPhysPropSpecZ.fits"

# Pixel size of each plane in the output datacube.
image_size_x = 64  # Cutout size of the datacube
image_size_y = 64

os.makedirs(datacubes_dir, exist_ok=True)

# Euclid bands stacked into every cube, in plane order.
filters = ["VIS", "NIR-H", "NIR-J", "NIR-Y"]

# Maximum catalogue-to-cutout matching radius.
search_radius_arcsec = 1.0
search_radius_deg = search_radius_arcsec / 3600.0

# Load object identifiers and sky positions from the property catalogue.
with fits.open(file_properties) as hdul:
    data = hdul[1].data
    object_ids = data['object_id']
    ras = data['right_ascension']
    decs = data['declination']

def parse_filename_info(fname):
    """Extract (RA, DEC) in degrees from a cutout filename.

    Filenames are expected to follow the pattern ``...CUTOUT_{RA}_{DEC}.fits``.
    """
    base = os.path.basename(fname)
    coords = base.split("CUTOUT_")[1].replace(".fits", "")
    ra_str, dec_str = coords.split("_")
    return float(ra_str), float(dec_str)

def sph_to_cart(ra_deg, dec_deg):
    """Convert equatorial coordinates (degrees) to a unit Cartesian vector.

    Used so angular neighbours can be found with a Euclidean KD-tree.
    """
    deg2rad = np.pi / 180.0
    ra = ra_deg * deg2rad
    dec = dec_deg * deg2rad
    cos_dec = np.cos(dec)
    return cos_dec * np.cos(ra), cos_dec * np.sin(ra), np.sin(dec)

# Index every available cutout per filter:
# filter_dict[filter] = [(RA, DEC, filename), ...]
filter_dict = {band: [] for band in filters}
for band in filters:
    pattern = os.path.join(image_dir, band, f"MOSAIC-{band}*.fits")
    for path in glob.glob(pattern):
        ra, dec = parse_filename_info(path)
        filter_dict[band].append((ra, dec, path))

# One KD-tree per filter over unit-sphere Cartesian positions.
# filter_kd[filter] is (tree, source list), or None if the filter has no cutouts.
filter_kd = {}
for band in filters:
    entries = filter_dict[band]
    if not entries:
        filter_kd[band] = None
        continue
    xyz = np.array([sph_to_cart(ra, dec) for ra, dec, _ in entries])
    filter_kd[band] = (cKDTree(xyz), entries)

def find_closest_image(RA_ref, DEC_ref, f):
    """Return the filename of the cutout in filter ``f`` nearest to
    (``RA_ref``, ``DEC_ref``), or ``None`` if no cutout lies within
    ``search_radius_deg`` (or the filter has no cutouts at all).
    """
    if filter_kd[f] is None:
        return None
    tree, sources = filter_kd[f]
    x_ref, y_ref, z_ref = sph_to_cart(RA_ref, DEC_ref)
    dist, idx = tree.query([x_ref, y_ref, z_ref], k=1)
    # `dist` is the 3-D chord length between unit vectors: chord = 2*sin(d/2),
    # where d is the angular separation in radians. Invert this exactly rather
    # than using the small-angle approximation (d ~ chord), so the radius cut
    # stays correct for any search radius. The min() guards the arcsin domain
    # against floating-point chord values marginally above 2.
    d_deg = np.degrees(2.0 * np.arcsin(min(dist / 2.0, 1.0)))
    if d_deg <= search_radius_deg:
        return sources[idx][2]  # filename of the matched cutout
    return None

# Number of planes in each output datacube (one per filter).
n_filters = len(filters)

# Build one (n_filters, Ny, Nx) cube per catalogue object and save it as FITS.
for obj_id, RA_ref, DEC_ref in zip(object_ids, ras, decs):
    # VIS is the reference band: objects without a VIS cutout are skipped.
    vis_file = find_closest_image(RA_ref, DEC_ref, "VIS")
    if vis_file is None:
        # No VIS image found, skip this object
        continue

    # Read VIS image; its header is reused for the output cube.
    with fits.open(vis_file) as vishdu:
        vis_data = vishdu[0].data
        vis_header = vishdu[0].header

    Nx = image_size_x
    Ny = image_size_y
    # Planes left unfilled (no matching cutout in that band) remain all-zero.
    cube = np.zeros((n_filters, Nx, Ny), dtype=float)

    # Central crop of the VIS data to (Ny, Nx).
    # NOTE(review): this assumes every cutout is at least Nx x Ny pixels; a
    # smaller image would make the slice smaller than the cube plane and raise
    # a broadcasting error -- confirm upstream cutout sizes.
    ysize, xsize = vis_data.shape
    xstart = (xsize - Nx) // 2
    ystart = (ysize - Ny) // 2
    xend = xstart + Nx
    yend = ystart + Ny
    cube[filters.index("VIS"), :, :] = vis_data[ystart:yend, xstart:xend]

    # For each remaining filter, find the closest cutout (if any) and insert
    # its central crop into the corresponding cube plane.
    for i, f in enumerate(filters):
        if f == "VIS":
            continue
        match_file = find_closest_image(RA_ref, DEC_ref, f)

        if match_file is not None:
            with fits.open(match_file) as hdu:
                fdata = hdu[0].data

            # Central crop, same size assumption as for VIS above.
            yfsize, xfsize = fdata.shape
            xfstart = (xfsize - Nx) // 2
            yfstart = (yfsize - Ny) // 2
            xfend = xfstart + Nx
            yfend = yfstart + Ny
            cube[i, :, :] = fdata[yfstart:yfend, xfstart:xfend]

    # Name the output file using the object_id
    out_fname = f"{obj_id}.fits"

    out_path = os.path.join(datacubes_dir, out_fname)

    # Write the cube with the VIS header plus filter/object metadata.
    primary_hdu = fits.PrimaryHDU(data=cube, header=vis_header)
    # Record which filter each plane corresponds to (FILTER1..FILTERn).
    for i, fil in enumerate(filters):
        primary_hdu.header[f'FILTER{i+1}'] = fil
    primary_hdu.header['OBJID'] = obj_id
    primary_hdu.header['RA_OBJ'] = RA_ref
    primary_hdu.header['DEC_OBJ'] = DEC_ref

    hdul = fits.HDUList([primary_hdu])
    hdul.writeto(out_path, overwrite=True)
    print(f"Saved cube for object {obj_id}: {out_path}")
189 changes: 189 additions & 0 deletions scripts/euclid/downstream_tasks/Euclid_NNPZ/Euclid_NNPZ_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import numpy as np
import math
from astropy.io import fits
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import seaborn as sns
import matplotlib.pyplot as plt


def write_results_to_file(metrics, parameter):
    """Write the evaluation metrics to ``{parameter}_Euclid_NNPZ.txt``."""
    out_path = f"{parameter}_Euclid_NNPZ.txt"

    header_lines = [
        "Final Evaluation Statistics\n",
        "===========================\n",
        "Percentage: 100\n",
    ]
    # One "name: value" line per metric, 4 decimal places.
    metric_lines = [f"{name}: {metrics[name]:.4f}\n" for name in metrics]

    with open(out_path, 'w') as txt_file:
        txt_file.writelines(header_lines + metric_lines)


def compute_catalog_statistics(catalog_path, parameter):
    """Compare catalogue reference values against NNPZ estimates.

    Reads the FITS catalogue at ``catalog_path``, selects valid sources for
    ``parameter`` ('PhotoZ', 'logM' or 'logM_Enia'), saves a KDE comparison
    plot to ``{parameter}_Euclid_NNPZ.png``, prints summary statistics, and
    returns a dict with keys mse, mae, r2, bias, nmad, sigma_68 and
    outlier_fraction.

    NOTE(review): if ``parameter`` is not one of the three recognised values,
    only a message is printed and the later use of ``specz``/``photoz``
    raises a NameError.
    """
    with fits.open(catalog_path) as hdul:
        catalog_data = hdul[1].data

    # Select the reference column (specz) and the estimate (photoz), plus a
    # per-parameter validity mask over the catalogue rows.
    if parameter == 'PhotoZ':
        specz = catalog_data['Z']
        photoz = catalog_data['phz_median']
        valid_indices = (
            ~np.isnan(specz) &
            ~np.isnan(photoz)
        )
    elif parameter == 'logM':
        # Stellar mass: additional quality cuts on detection, fit chi2,
        # mass error and photo-z flags.
        specz = catalog_data['LOGM']
        photoz = catalog_data['phz_pp_median_stellarmass']
        valid_indices = (
            ~np.isnan(specz) &
            ~np.isnan(photoz) &
            (catalog_data['spurious_flag'] == 0) &
            (catalog_data['det_quality_flag'] < 4) &
            (catalog_data['mumax_minus_mag'] > -2.6) &
            (catalog_data['LOGM'] > 0) &
            (catalog_data['CHI2'] < 17) &
            (catalog_data['LOGM_ERR'] < 0.25) &
            (catalog_data['phz_flags'] == 0)
        )
    elif parameter == 'logM_Enia':
        # Alternative stellar-mass column with a looser set of cuts.
        specz = catalog_data['LOGM']
        photoz = catalog_data['opp_median_stellarmass']
        valid_indices = (
            ~np.isnan(specz) &
            ~np.isnan(photoz) &
            (catalog_data['spurious_flag'] == 0) &
            (catalog_data['det_quality_flag'] < 4) &
            (catalog_data['mumax_minus_mag'] > -2.6) &
            (catalog_data['phz_flags'] == 0)
        )
    else:
        print("The parameter is not given")

    # Apply the mask and coerce to plain float arrays for metrics/plotting.
    valid_specz = specz[valid_indices]
    valid_photoz = photoz[valid_indices]
    print(len(valid_specz))
    print(len(valid_photoz))
    valid_specz = np.asarray(valid_specz, dtype=float)
    valid_photoz = np.asarray(valid_photoz, dtype=float)

    # Plot the true vs predicted values as a 2-D KDE.
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    fig = plt.figure(figsize=(22, 10))
    plt.subplots_adjust(left=0.18, bottom=None, right=None, top=None, wspace=0.02, hspace=0)

    ax = sns.kdeplot(x=valid_specz, y=valid_photoz, cmap="RdYlBu_r", fill=True,
                     levels=15, thresh=0.05)  # `fill=False` avoids the filled area

    plt.gca().set_facecolor('white')

    # Add the y=x identity line plus dotted outlier-threshold lines
    # (+-0.15 for photo-z, +-0.25 dex for stellar mass).
    max_val = max(valid_specz.max(), valid_photoz.max())
    min_val = min(valid_specz.min(), valid_photoz.min())
    if parameter == "PhotoZ":
        plt.plot([-0.5, 1.1], [-0.5, 1.1], '--', linewidth=3, color='black')
        plt.plot([-0.5, 1.1], [-0.5 + 0.15, 1.1 + 0.15], ':', linewidth=3, color='black')  # Upper threshold
        plt.plot([-0.5, 1.1], [-0.5 - 0.15, 1.1 - 0.15], ':', linewidth=3, color='black')  # Lower threshold
    else:
        plt.plot([0.5*min_val, 1.5*max_val], [0.5*min_val, 1.5*max_val], '--', linewidth=3, color='black')
        plt.plot([0.5 * min_val, 1.5 * max_val], [0.5 * min_val + 0.25, 1.5 * max_val + 0.25], ':', linewidth=3, color='black')  # Upper threshold
        plt.plot([0.5 * min_val, 1.5 * max_val], [0.5 * min_val - 0.25, 1.5 * max_val - 0.25], ':', linewidth=3, color='black')  # Lower threshold

    #plt.text(
    #0.15, 0.75, r'$\rm \eta_{out} = %.2f \pm %.2f\%%$' % (outlier_fraction, outlier_fraction_err),
    #horizontalalignment='left', verticalalignment='bottom',
    #fontsize=50, transform=plt.gca().transAxes
    #)

    # Axis labels depend on which quantity is being compared.
    if parameter == 'PhotoZ':
        plt.xlabel(r'$z_{\mathrm{DESI}}$', fontsize=50)
        plt.ylabel(r'$\mathrm{photo-}z_{{\tt NNPZ}}$', fontsize=50)
    elif parameter == 'logM':
        plt.xlabel(r'$\mathrm{log(M_{star}/M_{\odot})_{DESI}}$', fontsize=50)
        plt.ylabel(r"$\rm log(M_{star}/M_{\odot})_{\tt NNPZ}$", fontsize=50)
    elif parameter == 'logM_Enia':
        plt.xlabel(r'$\mathrm{True \, z \, log(M_{star}/M_{\odot})_{DESI}}$', fontsize=50)
        plt.ylabel(r"$\rm Euclid\mbox{ } log(M_{star}/M_{\odot})_{IRAC}$", fontsize=50)

    # Adjust ticks and formatting
    plt.minorticks_on()
    plt.tick_params(axis='x', which='major', labelsize=45)
    plt.tick_params(axis='both', which='both', direction='in', top=True, right=True, labelsize=45)
    plt.tick_params(axis='both', which='major', length=15)  # Major ticks length
    plt.tick_params(axis='both', which='minor', length=10)  # Minor ticks length

    plt.xticks(fontsize=50)
    plt.yticks(fontsize=50)
    # Fixed, parameter-specific axis limits; data-driven limits otherwise.
    if parameter == 'PhotoZ':
        plt.xlim(-0.18, 1.1)
        plt.ylim(-0.18, 1.1)
    elif parameter == 'logM':
        plt.xlim(6.8, 13)
        plt.ylim(6.8, 13)
    elif parameter == 'logM_Enia':
        plt.xlim(5, 14)
        plt.ylim(5, 14)
    else:
        plt.xlim(0.5*min_val, 1.5*max_val)
        plt.ylim(0.5*min_val, 1.5*max_val)

    # Remove grid and legend, set white background
    plt.gca().set_facecolor('white')
    plt.grid(False)
    plt.legend([], [], frameon=False)

    # Save the plot
    plt.tight_layout()
    plot_file = f"{parameter}_Euclid_NNPZ.png"
    plt.savefig(plot_file)
    plt.close()

    # Compute regression metrics on the valid sources.
    mse = mean_squared_error(valid_specz, valid_photoz)
    mae = mean_absolute_error(valid_specz, valid_photoz)
    r2 = r2_score(valid_specz, valid_photoz)

    # Standard photo-z statistics on the normalised residual
    # delta_z = (pred - true) / (1 + true).
    delta_z = (valid_photoz - valid_specz) / (1 + valid_specz)
    mean_delta_z = np.mean(delta_z)
    # Half the central 68.2% spread of delta_z.
    sigma_68 = (np.percentile(delta_z, 84.1) - np.percentile(delta_z, 15.9)) / 2
    # Percentage of sources with |delta_z| >= 0.15.
    outlier_fraction = 100 * len(valid_photoz[np.abs(delta_z) >= 0.15]) / len(valid_photoz)
    # Normalised median absolute deviation.
    nmad = 1.48 * np.median(np.abs(delta_z - np.median(delta_z)))

    print("Evaluation Statistics")
    print("=====================")
    print(f"Number of valid sources: {len(valid_specz)}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"R-squared: {r2:.4f}")
    print(f"Bias: {mean_delta_z:.4f}")
    print(f"Sigma 68: {sigma_68:.4f}")
    print(f"Outlier Fraction (|Δz| ≥ 0.15): {outlier_fraction:.2f}%")
    print(f"NMAD: {nmad:.4f}")

    return {
        "mse": mse,
        "mae": mae,
        "r2": r2,
        "bias": mean_delta_z,
        "nmad": nmad,
        "sigma_68": sigma_68,
        "outlier_fraction": outlier_fraction,
    }

# --- Script entry point ---------------------------------------------------
# Quantity to evaluate: 'PhotoZ', 'logM', or 'logM_Enia'.
parameter = 'PhotoZ'
if parameter == 'PhotoZ':
    catalog_path = "../../../Q1_data/DESISpecZ.fits"
elif parameter == 'logM' or parameter == 'logM_Enia':
    catalog_path = "../../../Q1_data/DESI_logM.fits"
else:
    # Fail fast with a clear error. The original code only printed "No catalog"
    # here and then crashed with a NameError because `catalog_path` was never
    # assigned before being passed to compute_catalog_statistics().
    raise ValueError(f"No catalog available for parameter {parameter!r}")

metrics = compute_catalog_statistics(catalog_path, parameter)
write_results_to_file(metrics, parameter)
Loading