-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathweighted_cox_loss.py
42 lines (33 loc) · 1.32 KB
/
weighted_cox_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
def loss_fn(risks, events, weights, device, times=None):
    """
    Calculate the weighted negative Cox partial log-likelihood.

    loss = -sum_i w_i * d_i * (r_i - log(sum_{j in R_i} w_j * exp(r_j)))
           / sum_i w_i * d_i

    where r = risks, d = events, w = weights and R_i is the risk set of
    sample i. Ties in ``times`` are handled with the Breslow approximation
    (every sample tied with i belongs to R_i).

    Parameters:
    - risks: Tensor of predicted risk scores (log hazard ratio) from the
      model; flattened to 1-D.
    - events: Tensor of event indicators (1/True if the event occurred,
      0/False if censored); flattened to 1-D.
    - weights: Tensor of per-sample weights (e.g. for class imbalance);
      flattened to 1-D. Assumed strictly positive -- a zero weight can
      still produce NaN.
    - device: The device (CPU or GPU) to which all tensors are moved.
    - times: Optional tensor of observed (event or censoring) times. When
      given, R_i = {j : times[j] >= times[i]}, the standard Cox risk set.
      When omitted, the legacy behaviour is kept: samples are ranked by
      descending risk score and R_i is every sample ranked at or above i.
      NOTE(review): that legacy ranking is not a Cox risk set; callers
      should pass ``times``.

    Returns:
    - Scalar loss tensor. If the batch contains no events the loss is 0
      (rather than NaN from 0/0).
    """
    risks = risks.to(device).reshape(-1)
    # Cast indicators/weights to the risk dtype so bool/int inputs combine cleanly.
    events = events.to(device).reshape(-1).to(risks.dtype)
    weights = weights.to(device).reshape(-1).to(risks.dtype)

    total_weighted_events = torch.sum(weights * events)

    if times is None:
        # Legacy ordering (see NOTE(review) in the docstring).
        order = torch.argsort(risks, descending=True)
    else:
        times = times.to(device).reshape(-1)
        # Descending time: a prefix sum then covers everyone still at risk.
        order = torch.argsort(times, descending=True)
    risks = risks[order]
    events = events[order]
    weights = weights[order]

    # log(cumsum(w * exp(r))) evaluated as logcumsumexp(r + log w), which does
    # not overflow for large risk scores the way exp() followed by cumsum does.
    log_risk = torch.logcumsumexp(risks + torch.log(weights), dim=0)

    if times is not None:
        # Breslow ties: sample i's risk set extends to the LAST sorted position
        # whose time is >= times[i]. The negated sorted times are ascending,
        # so searchsorted(right=True) counts entries with time >= times[i].
        neg_sorted_times = -times[order]
        last = torch.searchsorted(neg_sorted_times, neg_sorted_times, right=True) - 1
        log_risk = log_risk[last]

    # Per-sample weighted log partial likelihood; only events contribute.
    log_likelihood = weights * (risks - log_risk)
    event_log_likelihood = log_likelihood * events

    # Clamp avoids 0/0 = NaN for batches with no events (numerator is then 0).
    denominator = torch.clamp(
        total_weighted_events, min=torch.finfo(total_weighted_events.dtype).tiny
    )
    return -torch.sum(event_log_likelihood) / denominator