Hotfix/3.6.2/validator logit parameters #1057

Merged
merged 8 commits on Jan 19, 2023
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-3.6.1
+3.6.2
2 changes: 1 addition & 1 deletion bittensor/__init__.py
@@ -23,7 +23,7 @@
nest_asyncio.apply()

# Bittensor code and protocol version.
-__version__ = '3.6.1'
+__version__ = '3.6.2'
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))

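For reference, a quick sanity check of the packed integer form this bump produces, reusing the arithmetic from the hunk above (a standalone sketch, not repo code):

```python
# Same packing as bittensor/__init__.py: major*100 + minor*10 + patch.
version_split = '3.6.2'.split(".")
version_as_int = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))
assert version_as_int == 362
```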
2 changes: 1 addition & 1 deletion bittensor/_cli/cli_impl.py
@@ -477,7 +477,7 @@ def list(self):
coldkeypub_str = '?'

wallet_tree = root.add("\n[bold white]{} ({})".format(w_name, coldkeypub_str))
-hotkeys_path = self.config.wallet.path + w_name + '/hotkeys'
+hotkeys_path = os.path.join(self.config.wallet.path, w_name, 'hotkeys')
try:
hotkeys = next(os.walk(os.path.expanduser(hotkeys_path)))
if len( hotkeys ) > 1:
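A minimal sketch of why the `os.path.join` change matters (the wallet path below is an illustrative value, not a repo default): plain string concatenation silently drops the separator whenever the configured path lacks a trailing slash:

```python
import os

wallet_path = '~/.bittensor/wallets'  # hypothetical config value, no trailing slash
w_name = 'default'

old = wallet_path + w_name + '/hotkeys'             # '~/.bittensor/walletsdefault/hotkeys' (broken)
new = os.path.join(wallet_path, w_name, 'hotkeys')  # '~/.bittensor/wallets/default/hotkeys'
```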
36 changes: 23 additions & 13 deletions bittensor/_neuron/text/core_validator/__init__.py
@@ -169,7 +169,7 @@ def __init__(
self.device = torch.device ( device = self.config.neuron.device )
self.nucleus = nucleus ( config = self.config, device = self.device, subtensor = self.subtensor ).to( self.device )
self.dataset = (bittensor.dataset(config=self.config, batch_size=self.subtensor.validator_batch_size,
-block_size=self.subtensor.validator_sequence_length + self.config.neuron.validation_len)
+block_size=self.subtensor.validator_sequence_length + self.config.neuron.validation_len + self.subtensor.prune_len)
if dataset is None else dataset)
self.optimizer = torch.optim.SGD(
self.nucleus.parameters(), lr=self.config.neuron.learning_rate, momentum=self.config.neuron.momentum
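To make the new block size concrete, a worked example with illustrative numbers (the chain-sourced values here are assumptions, not fixed constants):

```python
sequence_length = 256  # subtensor.validator_sequence_length (illustrative)
validation_len = 8     # --neuron.validation_len default
prune_len = 1          # subtensor.prune_len (illustrative)

# Each dataset sample now reserves room for the pruned tokens as well.
block_size = sequence_length + validation_len + prune_len  # 265 tokens
```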
@@ -234,7 +234,7 @@ def add_args( cls, parser ):
parser.add_argument('--neuron.blocks_per_epoch', type=int, help='Blocks per epoch, -1 value means we use the chain value.', default = -1 )
parser.add_argument('--neuron.epochs_until_reset', type=int, help='Number of epochs before weights are reset.', default = -1 )
parser.add_argument('--neuron.validation_len', type=int, help='Number of tokens to holdout for phrase validation beyond sequence context.', default=8)
-parser.add_argument('--neuron.prune_len', type=int, help='Number of tokens to prune from each validation input sequence.', default=1)
+parser.add_argument('--neuron.prune_len', type=int, help='Number of tokens to prune from each validation input sequence. (default value: -1, pulling from subtensor directly)', default=-1)
parser.add_argument('--neuron.device', type=str, help='miner default training device cpu/cuda', default=("cuda" if torch.cuda.is_available() else "cpu"))
parser.add_argument('--neuron.clip_gradients', type=float, help='Implement gradient clipping to avoid exploding loss on smaller architectures.', default=1.0 )
parser.add_argument('--neuron.track_hotkey_changes', action='store_true', help='If True, track hotkey changes.', default=False)
@@ -400,13 +400,15 @@ def run_epoch( self ):
batch_size = self.subtensor.validator_batch_size
sequence_length = self.subtensor.validator_sequence_length
validation_len = self.config.neuron.validation_len # Number of tokens to holdout for phrase validation beyond sequence context
-prune_len = self.config.neuron.prune_len # Number of tokens to holdout for phrase validation beyond sequence context
+# Number of tokens to prune for phrase validation beyond sequence context
+prune_len = self.config.neuron.prune_len = self.subtensor.prune_len
min_allowed_weights = self.subtensor.min_allowed_weights
max_weight_limit = self.subtensor.max_weight_limit
blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch
epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset
self.config.nucleus.scaling_law_power = self.subtensor.scaling_law_power
self.config.nucleus.synergy_scaling_law_power = self.subtensor.synergy_scaling_law_power
+self.config.nucleus.logits_divergence = self.subtensor.logits_divergence

# === Logs Prometheus ===
self.prometheus_gauges.labels("current_block").set( current_block )
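Note the chained assignment on the new `prune_len` line: it writes the chain value into the config and binds the local name in one statement. A self-contained sketch of the semantics (stand-in objects, not repo classes):

```python
class Cfg: pass
config, chain = Cfg(), Cfg()
chain.prune_len = 1  # stand-in for subtensor.prune_len

# Chained assignment evaluates the right-hand side once, then assigns
# the targets left to right, so the local and the config field agree.
prune_len = config.prune_len = chain.prune_len
assert prune_len == config.prune_len == 1
```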
@@ -688,7 +690,7 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):

if 'logits_excess_nxt' in stats:
# penalize by logits divergence excess
-extra_stats['shapley_values_nxt'] /= 1 + stats['logits_excess_nxt']
+extra_stats['shapley_values_nxt'] /= 1 + self.config.nucleus.logits_divergence * stats['logits_excess_nxt']

# === EMA zeroing update ===
# Push zero into EMA for synapse_keys to exponentially decay weighting keys if neuron non-responsive
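A worked example of the reweighted penalty (numbers illustrative): the divergence excess is now scaled by the chain-configured `logits_divergence` before discounting the Shapley value:

```python
logits_divergence = 0.5   # chain-configured penalty scale (illustrative)
logits_excess_nxt = 0.2   # per-neuron divergence excess (illustrative)
shapley_values_nxt = 1.0

# Old: shapley /= 1 + 0.2        -> ~0.833
# New: shapley /= 1 + 0.5 * 0.2  -> ~0.909
shapley_values_nxt /= 1 + logits_divergence * logits_excess_nxt
assert round(shapley_values_nxt, 3) == 0.909
```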
@@ -825,6 +827,7 @@ def __init__( self, config, device, subtensor ):

self.config.nucleus.scaling_law_power = subtensor.scaling_law_power if self.config.nucleus.scaling_law_power == -1 else self.config.nucleus.scaling_law_power
self.config.nucleus.synergy_scaling_law_power = subtensor.synergy_scaling_law_power if self.config.nucleus.synergy_scaling_law_power == -1 else self.config.nucleus.synergy_scaling_law_power
+self.config.nucleus.logits_divergence = subtensor.logits_divergence if self.config.nucleus.logits_divergence == -1 else self.config.nucleus.logits_divergence

self.device = device
self.max_n = subtensor.max_n
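The resolution pattern for the new parameter mirrors the existing scaling-law ones: a config value of -1 is a sentinel meaning "use the chain value". A minimal sketch of that rule:

```python
def resolve(cli_value: float, chain_value: float) -> float:
    """Return the chain value when the flag was left at its -1 default."""
    return chain_value if cli_value == -1 else cli_value

assert resolve(-1, 0.3) == 0.3   # default: pull from subtensor
assert resolve(0.5, 0.3) == 0.5  # explicit flag wins
```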
@@ -872,6 +875,7 @@ def add_args( cls, parser ):
parser.add_argument('--nucleus.no_dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)
parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)
+parser.add_argument('--nucleus.logits_divergence', type=float, help=' the divergence value for logit anomaly detection (default value: -1, pulling from subtensor directly)', default=-1)

@classmethod
def config ( cls ):
@@ -983,7 +987,7 @@ def forward(
num_endpoints = len(random_endpoints) # in case len(self.permute_uids) < num_endpoints during random_uids select

logger.info(f'Forward \t| Routing forward <dim>[{time.time() - start_time:.3g}s]</dim>')
-logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
+logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)} (prune_len={prune_len})')
request_start_time = time.time()

# === Define which synapse we want to use ===
@@ -1028,6 +1032,7 @@ def forward(
validation_params = (random_uids, query_responses, return_ops, times, routing_score,
inputs, val_len, self.loss_fct,
self.config.nucleus.scaling_law_power, self.config.nucleus.synergy_scaling_law_power,
+self.config.nucleus.logits_divergence,
console_width, self.config.logging.debug or self.config.logging.trace)

loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses
@@ -1057,7 +1062,7 @@ def scaling_law_loss_to_params(loss):
def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
-scaling_law_power: float, synergy_scaling_law_power: float,
+scaling_law_power: float, synergy_scaling_law_power: float, logits_divergence_penalty: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
@@ -1084,6 +1089,8 @@
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+logits_divergence_penalty (:obj:`float`, `required`):
+    Penalty scaling for logits divergence.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -1135,7 +1142,7 @@ def _synergy(first, second, target, _ext):
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='')

-logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f}) '
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

synergy_start_time = time.time()
@@ -1162,7 +1169,7 @@ def _synergy(first, second, target, _ext):
if hasattr(s[key], 'item'):
s[key] = s[key].item()

-logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f}) '
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

if logging:
@@ -1184,8 +1191,8 @@ def _synergy(first, second, target, _ext):
def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
-scaling_law_power: float, synergy_scaling_law_power: float,
-console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0
+scaling_law_power: float, synergy_scaling_law_power: float, logits_divergence_penalty: float,
+console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0,
) -> Tuple[torch.FloatTensor, Dict]:
r"""
Calculate Shapley values and neuron response validation measure statistics, given TextCausalLMNext synapse responses.
@@ -1211,6 +1218,8 @@
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+logits_divergence_penalty (:obj:`float`, `required`):
+    Penalty scaling for logits divergence.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -1250,17 +1259,18 @@ def _synergy(first, second, target, ext):
shapley_start_time = time.time()
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='_nxt')
-logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f}) '
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

divergence_start_time = time.time()
with torch.no_grad():
logits_divergence(stats, uids, query_responses, return_ops, times, index_s, ext='_nxt')
-logger.info(f'{str(synapse)} \t| Logits divergences <dim>[{time.time() - divergence_start_time:.3g}s]</dim>')
+logger.info(f'{str(synapse)} \t| Logits divergences (penalty={logits_divergence_penalty}) '
+    f'<dim>[{time.time() - divergence_start_time:.3g}s]</dim>')

synergy_start_time = time.time()
syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=synergy_scaling_law_power)
-logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
+logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f}) '
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

# === Shapley value combination ===
26 changes: 26 additions & 0 deletions bittensor/_subtensor/subtensor_impl.py
@@ -424,6 +424,32 @@ def make_substrate_call_with_retry():
).value
return make_substrate_call_with_retry()

+@property
+def prune_len (self) -> int:
+    r""" Returns PruneLen
+    Returns:
+        prune_len (int):
+            the number of pruned tokens from each requests
+    """
+    @retry(delay=2, tries=3, backoff=2, max_delay=4)
+    def make_substrate_call_with_retry():
+        with self.substrate as substrate:
+            return substrate.query( module='SubtensorModule', storage_function = 'ValidatorPruneLen' ).value
+    return make_substrate_call_with_retry()
+
+@property
+def logits_divergence (self) -> int:
+    r""" Returns logits_divergence
+    Returns:
+        logits_divergence (int):
+            the divergence value for logit distances, a measure for anomaly detection
+    """
+    @retry(delay=2, tries=3, backoff=2, max_delay=4)
+    def make_substrate_call_with_retry():
+        with self.substrate as substrate:
+            U64MAX = 18446744073709551615
+            return substrate.query( module='SubtensorModule', storage_function = 'ValidatorLogitsDivergence' ).value/U64MAX
+    return make_substrate_call_with_retry()
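The division by `U64MAX` rescales the raw u64 storage value into a float in [0, 1]. A small sanity check of that normalization (standalone sketch):

```python
U64MAX = 18446744073709551615  # 2**64 - 1

def normalize_u64(raw: int) -> float:
    """Map a raw on-chain u64 into [0, 1], as logits_divergence does."""
    return raw / U64MAX

assert normalize_u64(0) == 0.0
assert normalize_u64(U64MAX) == 1.0
assert abs(normalize_u64(U64MAX // 2) - 0.5) < 1e-9
```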
Comment on lines +428 to +452

Contributor: These functions are missing unit tests. In my view, if we are hotfixing, we should define what we are covering by creating the missing tests that did not exist and led us to the scenario we are fixing. What do you think?

Contributor (Author): Yes, I agree. In this case, the master branch was failing tests, so I pulled a few updates that fixed the tests. Those in and of themselves should have been a hotfix. I will update the description.
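For the coverage the reviewer is asking about, something like the following could exercise the new properties (a sketch only: the fake substrate and test name are assumptions, not the repo's actual test fixtures):

```python
from unittest import mock

U64MAX = 18446744073709551615

def test_logits_divergence_normalizes_u64():
    # Fake the substrate context manager the property enters.
    record = mock.MagicMock()
    record.value = U64MAX // 2
    substrate = mock.MagicMock()
    substrate.__enter__.return_value = substrate
    substrate.query.return_value = record

    # Exercise the same query-and-normalize path the property performs.
    with substrate as s:
        value = s.query(module='SubtensorModule',
                        storage_function='ValidatorLogitsDivergence').value / U64MAX

    substrate.query.assert_called_once_with(module='SubtensorModule',
                                            storage_function='ValidatorLogitsDivergence')
    assert abs(value - 0.5) < 1e-9
```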


def serve_axon (
self,