Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
EdoardoChidichimo authored Jun 24, 2024
1 parent 4ffee66 commit 99eca8f
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions hyperit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,38 @@ def set_estimator(estimator_type: str, measure: str, params: dict) -> tuple:
'symbolic': ('Symbolic Estimator', None, {})
},
'te': {
'ksg': ('KSG Estimator', 'infodynamics.measures.continuous.kraskov.TransferEntropyCalculatorMultiVariateKraskov', {
"k_HISTORY": str(params.get('k', 1)), "k_TAU": str(params.get('k_tau', 1)),
"l_HISTORY": str(params.get('l', 1)), "l_TAU": str(params.get('l_tau', 1)),
"DELAY": str(params.get('delay', 1)), "k": str(params.get('kraskov_param', 4))
}),
'kernel': ('Box Kernel Estimator', 'infodynamics.measures.continuous.kernel.TransferEntropyCalculatorMultiVariateKernel', {
"KERNEL_WIDTH": str(params.get('kernel_width', 0.5))
}),
'gaussian': ('Gaussian Estimator', 'infodynamics.measures.continuous.gaussian.TransferEntropyCalculatorMultiVariateGaussian', {
"k_HISTORY": str(params.get('k', 1)), "k_TAU": str(params.get('k_tau', 1)),
"l_HISTORY": str(params.get('l', 1)), "l_TAU": str(params.get('l_tau', 1)),
"DELAY": str(params.get('delay', 1)), "BIAS_CORRECTION": str(params.get('bias_correction', False)).lower()
}),
'symbolic': ('Symbolic Estimator', 'infodynamics.measures.continuous.symbolic.TransferEntropyCalculatorSymbolic', {"k_HISTORY": str(params.get('k', 1))}, 2)
'ksg': ('KSG Estimator',
'infodynamics.measures.continuous.kraskov.TransferEntropyCalculatorMultiVariateKraskov',
{
"k_HISTORY": str(params.get('k', 1)), "k_TAU": str(params.get('k_tau', 1)),
"l_HISTORY": str(params.get('l', 1)), "l_TAU": str(params.get('l_tau', 1)),
"DELAY": str(params.get('delay', 1)), "k": str(params.get('kraskov_param', 4))
},
(params.get('k', 1), params.get('k_tau', 1), params.get('l', 1), params.get('l_tau', 1), params.get('delay', 1))),
# TEviaCondMI requires initialise(sourceDim, destDim, k, ktau, l, ltau, delay) so sourceDim and destDim will be added later

'kernel': ('Box Kernel Estimator',
'infodynamics.measures.continuous.kernel.TransferEntropyCalculatorMultiVariateKernel',
{
"KERNEL_WIDTH": str(params.get('kernel_width', 0.5))
},
(params.get('k', 1), params.get('k_tau', 1), params.get('l', 1), params.get('l_tau', 1), params.get('delay', 1))),

'gaussian': ('Gaussian Estimator',
'infodynamics.measures.continuous.gaussian.TransferEntropyCalculatorMultiVariateGaussian',
{
"k_HISTORY": str(params.get('k', 1)), "k_TAU": str(params.get('k_tau', 1)),
"l_HISTORY": str(params.get('l', 1)), "l_TAU": str(params.get('l_tau', 1)),
"DELAY": str(params.get('delay', 1)), "BIAS_CORRECTION": str(params.get('bias_correction', False)).lower()
},
(params.get('k', 1), params.get('k_tau', 1), params.get('l', 1), params.get('l_tau', 1), params.get('delay', 1))),

'symbolic': ('Symbolic Estimator',
'infodynamics.measures.continuous.symbolic.TransferEntropyCalculatorSymbolic',
{
"k_HISTORY": str(params.get('k', 1))
},
(params.get('k', 1), params.get('k_tau', 1), params.get('l', 1), params.get('l_tau', 1), params.get('delay', 1)))
}
}

Expand All @@ -138,6 +156,4 @@ def set_estimator(estimator_type: str, measure: str, params: dict) -> tuple:
else:
raise ValueError(f"Estimator type {estimator_type} not supported for measure {measure}.")

print(properties, flush=True)
print(initialise_parameter, flush=True)
return estimator_name, calculator, properties, initialise_parameter

0 comments on commit 99eca8f

Please sign in to comment.