# vtacML

vtacML is a Python package designed for real-time analysis of data from the Visible Telescope (VT) on board the SVOM satellite. It uses machine learning models to analyze the features of observed VT sources and identify potential gamma-ray burst (GRB) optical afterglow candidates. vtacML is integrated into the real-time SVOM VT VHF pipeline, where it flags each detected source with the probability that it is a GRB candidate. This information helps the Burst Advocates (BAs) on shift identify which source is the real GRB counterpart.

## Table of Contents

- [Overview](#overview)
- [Installation](#installation)
- [Usage](#usage)
- [Config File](#config-file)
- [Documentation](#documentation)

## Overview

The SVOM mission, a collaboration between the China National Space Administration (CNSA) and the French space agency CNES, aims to study gamma-ray bursts (GRBs), the most energetic explosions in the universe. The Visible Telescope (VT) on SVOM plays a critical role in observing these events in the optical wavelength range.

vtacML leverages machine learning to analyze VT data, providing a probability score for each observation to indicate its likelihood of being a GRB candidate. The package includes tools for data preprocessing, model training, evaluation, and visualization.

## Installation

To install vtacML, you can use pip:

```bash
pip install vtacML
```

Alternatively, you can clone the repository and install the package locally:

```bash
git clone https://github.com/jerbeario/VTAC_ML.git
cd VTAC_ML
pip install .
```

## Usage

### Quick Start

Here’s a quick example to get you started with vtacML:

```python
from vtacML.pipeline import VTACMLPipe

# Initialize the pipeline
pipeline = VTACMLPipe()

# Load configuration
pipeline.load_config('path/to/config.yaml')

# Train the model
pipeline.train()

# Evaluate the model
pipeline.evaluate('evaluation_name', plot=True)

# Predict GRB candidates
predictions = pipeline.predict(observation_dataframe, prob=True)
print(predictions)
```
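
Note that `observation_dataframe` is supplied by the user. As a minimal, hypothetical sketch, it can be built as a pandas DataFrame whose columns match the feature list in the default config (the values below are invented, one row per detected source):

```python
import pandas as pd

# One row per detected VT source; column names follow the
# `Inputs.columns` list from the default config. Values are illustrative.
observation_dataframe = pd.DataFrame({
    "MAGCAL_R0": [18.2], "MAGCAL_B0": [18.9],
    "MAGERR_R0": [0.05], "MAGERR_B0": [0.07],
    "MAGCAL_R1": [18.6], "MAGCAL_B1": [19.3],
    "MAGERR_R1": [0.06], "MAGERR_B1": [0.08],
    "MAGVAR_R1": [0.4],  "MAGVAR_B1": [0.4],
    "EFLAG_R0": [0], "EFLAG_R1": [0],
    "EFLAG_B0": [0], "EFLAG_B1": [0],
    "NEW_SRC": [1], "DMAG_CAT": [0.0],
})
```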

### Grid Search and Model Training

vtacML can perform a grid search over the models and parameter grids specified in the configuration file. Initialize the `VTACMLPipe` class with a config file (or use the default), train it, and then save the best model for future use.

```python
from vtacML.pipeline import VTACMLPipe

# Initialize the pipeline with a configuration file
pipeline = VTACMLPipe(config_file='path/to/config.yaml')

# Train the model with grid search
pipeline.train()

# Save the best model
pipeline.save_model('path/to/save/best_model.pkl')
```

### Loading and Using the Best Model

After training and saving the best model, you can create a new instance of the `VTACMLPipe` class and load the best model for further use.

```python
from vtacML.pipeline import VTACMLPipe

# Initialize a new pipeline instance
pipeline = VTACMLPipe()

# Load the best model
pipeline.load_model('path/to/save/best_model.pkl')

# Predict GRB candidates
predictions = pipeline.predict(observation_dataframe, prob=True)
print(predictions)
```

### Using a Pre-trained Model for Immediate Prediction

If you already have a trained model, you can use the wrapper function `predict_from_best_pipeline` to run predictions immediately. A pre-trained model is available by default.

```python
from vtacML.pipeline import predict_from_best_pipeline

# Predict GRB candidates using the pre-trained model
predictions = predict_from_best_pipeline(observation_dataframe, model_path='path/to/pretrained_model.pkl')
print(predictions)
```
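
With `prob=True`, the examples above treat the output as per-source GRB probabilities. As a rough, hypothetical sketch of how a downstream consumer might use them (the 0.9 cut is purely illustrative, not a recommended operating point):

```python
import numpy as np

# Assume `predictions` is an array-like of per-source GRB probabilities.
probs = np.asarray(predictions, dtype=float)
candidates = np.flatnonzero(probs >= 0.9)  # illustrative threshold
print(f"{candidates.size} of {probs.size} sources flagged for BA review")
```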

## Config File

The config file configures the model search: the input data file and feature columns, the set of models and parameter grids to search over, and the output paths.

```yaml
# Default config file, used to search for best model using only first two sequences (X0, X1) from the VT pipeline
Inputs:
  file: 'combined_qpo_vt_all_cases_with_GRB_with_flags.parquet' # Data file used for training. Located in /data/
#  path: 'combined_qpo_vt_with_GRB.parquet'
#  path: 'combined_qpo_vt_faint_case_with_GRB_with_flags.parquet'
  columns: [
    "MAGCAL_R0",
    "MAGCAL_B0",
    "MAGERR_R0",
    "MAGERR_B0",
    "MAGCAL_R1",
    "MAGCAL_B1",
    "MAGERR_R1",
    "MAGERR_B1",
    "MAGVAR_R1",
    "MAGVAR_B1",
    'EFLAG_R0',
    'EFLAG_R1',
    'EFLAG_B0',
    'EFLAG_B1',
    "NEW_SRC",
    "DMAG_CAT"
    ] # features used for training
  target_column: 'IS_GRB' # feature column that holds the class information to be predicted

# Set of models and parameters to perform GridSearchCV over
Models:
  rfc:
    class: RandomForestClassifier()
    param_grid:
      'rfc__n_estimators': [100, 200, 300]  # Number of trees in the forest
      'rfc__max_depth': [4, 6, 8]  # Maximum depth of the tree
      'rfc__min_samples_split': [2, 5, 10]  # Minimum number of samples required to split an internal node
      'rfc__min_samples_leaf': [1, 2, 4]  # Minimum number of samples required to be at a leaf node
      'rfc__bootstrap': [True, False]  # Whether bootstrap samples are used when building trees
  ada:
    class: AdaBoostClassifier()
    param_grid:
      'ada__n_estimators': [50, 100, 200]  # Number of weak learners
      'ada__learning_rate': [0.01, 0.1, 1]  # Learning rate
      'ada__algorithm': ['SAMME']  # Algorithm for boosting
  svc:
    class: SVC()
    param_grid:
      'svc__C': [0.1, 1, 10, 100]  # Regularization parameter
      'svc__kernel': ['poly', 'rbf', 'sigmoid']  # Kernel type to be used in the algorithm
      'svc__gamma': ['scale', 'auto']  # Kernel coefficient
      'svc__degree': [3, 4, 5]  # Degree of the polynomial kernel function (if `kernel` is 'poly')
  knn:
    class: KNeighborsClassifier()
    param_grid:
      'knn__n_neighbors': [3, 5, 7, 9]  # Number of neighbors to use
      'knn__weights': ['uniform', 'distance']  # Weight function used in prediction
      'knn__algorithm': ['ball_tree', 'kd_tree', 'brute']  # Algorithm used to compute the nearest neighbors
      'knn__p': [1, 2]  # Power parameter for the Minkowski metric
  lr:
    class: LogisticRegression()
    param_grid:
      'lr__penalty': ['l1', 'l2', 'elasticnet']  # Specify the norm of the penalty
      'lr__C': [0.01, 0.1, 1, 10]  # Inverse of regularization strength
      'lr__solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']  # Algorithm to use in the optimization problem
      'lr__max_iter': [100, 200, 300]  # Maximum number of iterations taken for the solvers to converge
  dt:
    class: DecisionTreeClassifier()
    param_grid:
      'dt__criterion': ['gini', 'entropy']  # The function to measure the quality of a split
      'dt__splitter': ['best', 'random']  # The strategy used to choose the split at each node
      'dt__max_depth': [4, 6, 8, 10]  # Maximum depth of the tree
      'dt__min_samples_split': [2, 5, 10]  # Minimum number of samples required to split an internal node
      'dt__min_samples_leaf': [1, 2, 4]  # Minimum number of samples required to be at a leaf node

# Output directories
Outputs:
  model_path: '/output/models'
  viz_path: '/output/visualizations/'
  plot_correlation:
    flag: True
    path: 'output/corr_plots/'
```
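
The double-underscore prefixes in each `param_grid` (e.g. `rfc__n_estimators`) follow scikit-learn's convention for addressing a named step inside a `Pipeline`. As a rough sketch of what the search plausibly does under the hood (the scaler step, scoring metric, and CV folds here are assumptions for illustration, not vtacML's actual settings):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# 'rfc__<param>' keys target the step named 'rfc' in the pipeline.
pipe = Pipeline([("scaler", StandardScaler()),
                 ("rfc", RandomForestClassifier())])
param_grid = {
    "rfc__n_estimators": [100, 200, 300],
    "rfc__max_depth": [4, 6, 8],
}
search = GridSearchCV(pipe, param_grid, scoring="f1", cv=5)
# search.fit(X, y) would evaluate every combination and keep the best model.
```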

## Documentation