[BIT-601] Scaling law on EMA loss (#1022)
* Compute scaling law on EMA loss

The neural language model scaling law is typically meant to be computed on a loss averaged over the entire training sample. Currently it is computed within-batch only, which frequently produces losses below 1.69 nats, the natural entropy of text.

Here we now compute the scaling law, and the resultant effective number of model parameters, on the exponentially moving average (EMA) loss of a server, which should make the result much better defined (see the sketch after this commit list).

* Convert to tensor for calcs

* Ascending sort loss tables

* Add top and bottom weights to validator table

* Add top and bottom weights to validator table

* Add top and bottom weights to validator table

* Change mark_uids in weights table

* Update scaling law powers each epoch

* Fix neuron.ip_version
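
To make the mechanism concrete, here is a minimal, hedged sketch of the per-server update this commit describes. The `scaling_law_loss_to_params` helper below is an illustrative reconstruction from the Kaplan et al. (2020) scaling law and the 1.69-nat entropy floor mentioned above, not a verbatim copy of the repository's function; `alpha = 0.1` matches the EMA coefficient in the diff.

```python
import torch

def scaling_law_loss_to_params(loss: torch.Tensor) -> torch.Tensor:
    # Invert the Kaplan et al. (2020) scaling law L(N) = (N_c / N) ** 0.076,
    # with N_c ~= 8.8e13, clamping the loss at 1.69 nats (the approximate
    # entropy of natural text) so sub-floor noise cannot imply infinite params.
    return torch.exp(torch.log(torch.tensor(8.8e13))
                     - torch.log(torch.clamp(loss, 1.69)) / 0.076)

alpha = 0.1            # EMA coefficient, matching self.alpha in the diff
ema_loss = 2.40        # server's previous EMA loss (illustrative value)
batch_loss = 2.20      # loss measured on the current batch
ema_loss = (1 - alpha) * ema_loss + alpha * batch_loss

# Effective parameters are now estimated from the smoothed EMA loss, then
# powered down by scaling_law_power to compress their dynamic range.
scaling_law_power = 0.5
est_params = scaling_law_loss_to_params(torch.tensor(ema_loss))
base_params = torch.pow(est_params, scaling_law_power)
print(f'est_params={est_params.item():.3g}  base_params={base_params.item():.3g}')
```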
opentaco authored Dec 9, 2022
1 parent a92976f commit 332ba29
Showing 2 changed files with 67 additions and 32 deletions.
97 changes: 66 additions & 31 deletions bittensor/_neuron/text/core_validator/__init__.py
@@ -182,6 +182,7 @@ def __init__(
self.neuron_changes = {} # neuron hotkey changes dict of dicts of dicts: [uid] -> [block] -> {'new_hotkey': , 'old_hotkey': , 'old_stats':}
self.alpha = 0.1 # EMA coefficient in [0, 1], higher alpha discounts older observations faster


if self.config.neuron.validation_synapse == 'TextCausalLMNext':
self.weight_key = 'shapley_values_nxt' # stat key + ! to calculate neuron weights with
# stat keys to duplicate (['key']->['key!']) and push zero to its EMA if neuron non-responsive
@@ -399,6 +400,8 @@ def run_epoch( self ):
max_weight_limit = self.subtensor.max_weight_limit
blocks_per_epoch = self.subtensor.validator_epoch_length if self.config.neuron.blocks_per_epoch == -1 else self.config.neuron.blocks_per_epoch
epochs_until_reset = self.subtensor.validator_epochs_per_reset if self.config.neuron.epochs_until_reset == -1 else self.config.neuron.epochs_until_reset
self.config.nucleus.scaling_law_power = self.subtensor.scaling_law_power if self.config.nucleus.scaling_law_power == -1 else self.config.nucleus.scaling_law_power
self.config.nucleus.synergy_scaling_law_power = self.subtensor.synergy_scaling_law_power if self.config.nucleus.synergy_scaling_law_power == -1 else self.config.nucleus.synergy_scaling_law_power

# === Logs Prometheus ===
self.prometheus_gauges.labels("current_block").set( current_block )
@@ -408,6 +411,8 @@ def run_epoch( self ):
self.prometheus_gauges.labels("min_allowed_weights").set( min_allowed_weights )
self.prometheus_gauges.labels("blocks_per_epoch").set( blocks_per_epoch )
self.prometheus_gauges.labels("epochs_until_reset").set( epochs_until_reset )
self.prometheus_gauges.labels("scaling_law_power").set( self.config.nucleus.scaling_law_power )
self.prometheus_gauges.labels("synergy_scaling_law_power").set( self.config.nucleus.synergy_scaling_law_power )

# === Update dataset size ===
if (batch_size != self.dataset.batch_size) or (sequence_length + validation_len != self.dataset.block_size):
@@ -537,7 +542,7 @@ def run_epoch( self ):
# === Calculate neuron weights ===
sample_uids, sample_weights = self.calculate_weights()
self.weights_table(sample_uids, sample_weights,
include_uids=list(stats.keys()), num_rows=2 * len(stats)) # print weights table
include_uids=list(stats.keys()), num_rows=len(stats) + 25) # print weights table

# === Logs ===
if self.config.using_wandb:
@@ -649,6 +654,33 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
for _uid, _stats in neuron_stats.items():
stats = self.neuron_stats.setdefault(_uid, {})

# === EMA normal update ===
# If synapse responsive push available values into EMA for normal update.
# Normal EMA values provide a view on neuron performance if fully responsive.
for key in _stats: # detailed neuron evaluation fields, e.g. loss, shapley_values, synergy
if math.isnan(_stats[key]):
continue
if key in stats:
stats[key] = (1 - self.alpha) * stats[key] + self.alpha * _stats[key] # update EMA
else:
stats.setdefault(key, _stats[key])

# === Extra stats computation ===
# Compute values on EMA stats, such as the scaling law on EMA loss.
# Required for values that need to be computed on longer-term stats.
extra_stats = {}
if 'loss_nxt' in _stats and 'loss_nxt' in stats: # elif neuron not responsive then omit
# estimate the effective number of model parameters from EMA loss
_num_params = scaling_law_loss_to_params(torch.tensor(stats['loss_nxt']))

# powered down number of params, e.g. dynamic range 3 → 6 nats for scaling_law_power=0.5
_pow_num_params = torch.pow(_num_params, self.config.nucleus.scaling_law_power)

extra_stats.update({'est_params_nxt': _num_params.item(), 'base_params_nxt': _pow_num_params.item()})

if 'synergy_nxt' in stats:
extra_stats['shapley_values_nxt'] = extra_stats['base_params_nxt'] + stats['synergy_nxt']

# === EMA zeroing update ===
# Push zero into EMA for synapse_keys to exponentially decay weighting keys if neuron non-responsive
if 'updates!' in stats:
@@ -662,27 +694,30 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
if key in _stats and not math.isnan(_stats[key]):
responsive_uids += [_uid]
stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * _stats[key]
elif key in extra_stats and not math.isnan(extra_stats[key]):
responsive_uids += [_uid]
stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * extra_stats[key]
else:
stats[zkey] = (1 - self.alpha) * stats[zkey] # + self.alpha * 0

# === EMA normal update ===
# If synapse responsive push available values into EMA for normal update.
# Normal EMA values provide a view on neuron performance if fully responsive.
for key in self.synapse_keys:
if key in _stats:
if key in _stats or key in extra_stats:
updates = 'updates_' + key
if updates in stats:
stats[updates] += 1 # increment number of normal EMA updates made
else:
stats.setdefault(updates, 1) # add updates fields for new uid entries

for key in _stats: # detailed neuron evaluation fields, e.g. loss, shapley_values, synergy
if math.isnan(_stats[key]):
for key in extra_stats: # detailed neuron evaluation fields, e.g. loss, shapley_values, synergy
if math.isnan(extra_stats[key]):
continue
if key in stats:
stats[key] = (1 - self.alpha) * stats[key] + self.alpha * _stats[key] # update EMA
stats[key] = (1 - self.alpha) * stats[key] + self.alpha * extra_stats[key] # update EMA
else:
stats.setdefault(key, _stats[key])
stats.setdefault(key, extra_stats[key])

return responsive_uids, list(neuron_stats.keys()) # responsive_uids, queried_uids
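
Read together, the zeroing and normal updates above maintain two EMA views per stat: a normal key that tracks performance when the neuron responds, and a '!' weighting key that is decayed toward zero on every missed query. A hedged standalone sketch (simplified, not the repository's exact logic):

```python
from typing import Optional

ALPHA = 0.1  # same role as self.alpha in the diff

def ema_update(stats: dict, measured: Optional[float]) -> None:
    # Responsive server: push the measured Shapley value into both the normal
    # EMA ('shapley_values_nxt') and the weighting EMA ('shapley_values_nxt!').
    # Non-responsive: decay only the '!' key toward zero so the neuron's weight
    # falls off, while the normal EMA still shows performance when responsive.
    if measured is not None:
        for key in ('shapley_values_nxt', 'shapley_values_nxt!'):
            stats[key] = (1 - ALPHA) * stats.get(key, measured) + ALPHA * measured
    else:
        stats['shapley_values_nxt!'] = (1 - ALPHA) * stats.get('shapley_values_nxt!', 0.0)

stats = {}
ema_update(stats, 2.3)    # responsive query: both keys -> 2.3
ema_update(stats, None)   # missed query: only the '!' key decays, to ~2.07
```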

@@ -739,21 +774,25 @@ def weights_table(self, sample_uids, sample_weights, include_uids=None, num_rows
# === Weight table ===
# Prints exponential moving average statistics of valid neurons and latest weights
_neuron_stats = {}
uid_weights = [] # (uid, weight) tuples for sorting to find top/bottom weights
unvalidated = []
for uid, weight in zip(sample_uids.tolist(), sample_weights.tolist()):
if uid in self.neuron_stats:
_neuron_stats[uid] = {k: v for k, v in self.neuron_stats[uid].items()}
_neuron_stats[uid]['weight'] = weight
uid_weights += [(uid, weight)]
else:
unvalidated += [uid]

avail_include_uids = None
if include_uids is not None and num_rows is not None:
avail_include_uids = list(set(_neuron_stats.keys()) & set(include_uids)) # exclude include_uids with no stats
sorted_uids = sorted(uid_weights, key=lambda tup: tup[1])
top_bottom_uids = [_uid for _uid, _ in sorted_uids[:5] + sorted_uids[-10:]]
_include_uids = set(include_uids) | set(top_bottom_uids)
avail_include_uids = list(set(_neuron_stats.keys()) & _include_uids) # exclude include_uids with no stats
if len(_neuron_stats) > num_rows: # limit table to included_uids and remaining sample up to num_rows
remaining_uids = set(_neuron_stats.keys()) - set(include_uids) # find sample remaining, loses sample ordering
remaining_uids = set(_neuron_stats.keys()) - _include_uids # find sample remaining, loses sample ordering
remaining_uids = [uid for uid in _neuron_stats if uid in remaining_uids] # recover sample ordering
limited_uids = avail_include_uids + remaining_uids[:num_rows - len(include_uids)]
limited_uids = avail_include_uids + remaining_uids[:num_rows - len(_include_uids)]
_neuron_stats = {uid: stats for uid, stats in _neuron_stats.items() if uid in limited_uids}

print()
Expand All @@ -765,7 +804,7 @@ def weights_table(self, sample_uids, sample_weights, include_uids=None, num_rows
f'[white] max:[bold]{sample_weights.max().item():.4g}[/bold] / '
f'min:[bold]{sample_weights.min().item():.4g}[/bold] [/white] '
f'\[{max_weight_limit:.4g} allowed]', # caption
mark_uids=avail_include_uids)
mark_uids=include_uids)
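
Together with the commits titled "Add top and bottom weights to validator table", the hunk above forces the extremes of the weight distribution into the table: the sampled (uid, weight) pairs are sorted ascending, and the 5 lowest plus 10 highest weighted uids are unioned into the include set. A simplified sketch with made-up values:

```python
# Illustrative values; in the validator these come from calculate_weights().
uid_weights = [(3, 0.010), (7, 0.200), (1, 0.050), (9, 0.000),
               (4, 0.120), (6, 0.080), (2, 0.150), (8, 0.005)]
sorted_uids = sorted(uid_weights, key=lambda tup: tup[1])    # ascending by weight
top_bottom_uids = [uid for uid, _ in sorted_uids[:5] + sorted_uids[-10:]]
_include_uids = {1, 3} | set(top_bottom_uids)                # union with include_uids
# the set union deduplicates overlap between the include list and the extremes
```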


class nucleus( torch.nn.Module ):
@@ -1086,7 +1125,8 @@ def _synergy(first, second, target, _ext):
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='')

logger.info(f'{str(synapse)} \t| Shapley base values <dim>[{time.time() - shapley_start_time:.3g}s]</dim>')
logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

synergy_start_time = time.time()

@@ -1112,7 +1152,8 @@ def _synergy(first, second, target, _ext):
if hasattr(s[key], 'item'):
s[key] = s[key].item()

logger.info(f'{str(synapse)} \t| Shapley synergy values <dim>[{time.time() - synergy_start_time:.3g}s]</dim>')
logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

if logging:
# === Synergy table ===
@@ -1186,14 +1227,7 @@ def _base_params(_stats, query_response):
_loss_val = _losses_val.mean()
_loss = _losses.mean()

# estimate the effective number of model parameters, modified with the scaling_law_power
_num_params = scaling_law_loss_to_params(_loss)

# powered down number of params, e.g. dynamic range 3 → 6 nats for scaling_law_power=0.5
_pow_num_params = torch.pow(_num_params, scaling_law_power)

_stats.update({'loss_val_nxt': _loss_val, 'losses_nxt': _losses, 'loss_nxt': _loss,
'est_params_nxt': _num_params, 'base_params_nxt': _pow_num_params,
'synergy_nxt': 0, 'synergy_loss_diff_nxt': 0})

def _synergy(first, second, target, ext):
@@ -1208,7 +1242,8 @@ def _synergy(first, second, target, ext):
loss, stats, unsuccessful = shapley_base(uids, query_responses, return_ops, times, routing_score,
_base_params, index_s, ext='_nxt')

logger.info(f'{str(synapse)} \t| Shapley base values <dim>[{time.time() - shapley_start_time:.3g}s]</dim>')
logger.info(f'{str(synapse)} \t| Shapley base values (power={scaling_law_power:.1f})'
f'<dim>[{time.time() - shapley_start_time:.3g}s]</dim>')

synergy_start_time = time.time()

@@ -1217,31 +1252,29 @@ def _synergy(first, second, target, ext):
# === Shapley value combination ===
# Combine base values with synergy approximation to get final Shapley values.
for s in stats.values():
if 'base_params_nxt' in s and 'synergy_nxt' in s:
s['shapley_values_nxt'] = s['base_params_nxt'] + s['synergy_nxt']

if 'losses_nxt' in s:
del s['losses_nxt'] # remove batch losses - not needed for stats anymore

for key in s:
if hasattr(s[key], 'item'):
s[key] = s[key].item()

logger.info(f'{str(synapse)} \t| Shapley synergy values <dim>[{time.time() - synergy_start_time:.3g}s]</dim>')
logger.info(f'{str(synapse)} \t| Shapley synergy values (power={synergy_scaling_law_power:.1f})'
f'<dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

if logging:
# === Response table ===
# Prints the query response table: top prediction probabilities and texts for batch tasks
batch_predictions = format_predictions(uids, query_responses, return_ops, inputs, validation_len, index_s)
response_table(batch_predictions, stats, sort_col='shapley_values_nxt', console_width=console_width)
response_table(batch_predictions, stats, sort_col='loss_nxt', console_width=console_width)

# === Synergy table ===
# Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal)
synergy_table(stats, syn_loss_diff, 'shapley_values_nxt', console_width)
synergy_table(stats, syn_loss_diff, 'loss_nxt', console_width)

# === Neuron responses (table) ===
# Prints the evaluation of the neuron responses to the validator request
synapse_table(str(synapse), stats, 'shapley_values_nxt', console_width, shapley_start_time)
synapse_table(str(synapse), stats, 'loss_nxt', console_width, shapley_start_time)

# === Unsuccessful responses ===
# Prints the return codes and response times of unsuccessful responses
@@ -1453,7 +1486,8 @@ def response_table(batch_predictions: List, stats: Dict, sort_col: str, console_
col_keys = [c[1] for c in columns]

# === Sort rows ===
sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s], reverse=True, key=lambda _row: _row[1])
sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s],
reverse='loss' not in sort_col, key=lambda _row: _row[1])
if sort_col in col_keys:
sort_idx = col_keys.index(sort_col) # sort column with key of sort_col
columns[sort_idx][0] += '\u2193' # ↓ downwards arrow (sort)
@@ -1503,7 +1537,8 @@ def response_table(batch_predictions: List, stats: Dict, sort_col: str, console_
def synergy_table(stats, syn_loss_diff, sort_col, console_width):
r""" Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal)
"""
sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s], reverse=True, key=lambda _row: _row[1])
sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s],
reverse='loss' not in sort_col, key=lambda _row: _row[1])
uid_col = neuron_stats_columns[0] # [Column_name, key_name, format_string, rich_style]
columns = [uid_col] + [[f'{s[0]}', '', '{:.2f}', ''] for s in sort]
rows = [[uid_col[2].format(s[0])] +
@@ -1553,7 +1588,7 @@ def stats_table(stats, sort_col, console_width, title, caption, mark_uids=None):
if sort_col in col_keys:
sort_idx = col_keys.index(sort_col) # sort column with key of sort_col
columns[sort_idx][0] += '\u2193' # ↓ downwards arrow (sort)
rows = sorted(rows, reverse=True, key=lambda _row: _row[sort_idx][1]) # sort according to sortcol
rows = sorted(rows, reverse='loss' not in sort_col, key=lambda _row: _row[sort_idx][1]) # sort according to sortcol

# === Instantiate stats table ===
table = Table(width=console_width, box=None, row_styles=[Style(bgcolor='grey15'), ""])
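
The recurring `reverse='loss' not in sort_col` change in the three table functions above implements the "Ascending sort loss tables" commit: any column whose key contains 'loss' now sorts ascending (lower loss is better), while other stats keep descending order. A minimal sketch of the convention:

```python
stats = {12: {'loss_nxt': 2.1}, 5: {'loss_nxt': 1.9}, 8: {'loss_nxt': 2.5}}
sort_col = 'loss_nxt'
rows = sorted(((uid, s[sort_col]) for uid, s in stats.items() if sort_col in s),
              reverse='loss' not in sort_col, key=lambda row: row[1])
assert rows == [(5, 1.9), (12, 2.1), (8, 2.5)]  # best (lowest) loss listed first
```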
2 changes: 1 addition & 1 deletion bittensor/_subtensor/subtensor_impl.py
@@ -686,7 +686,7 @@ def serve (
'version': neuron.version,
'ip': neuron.ip,
'port': neuron.port,
'ip_type': neuron.ip_version,
'ip_type': neuron.ip_type,
'modality': neuron.modality,
'coldkey': neuron.coldkey
}
