
Training on VOCASET produces static template output #5

Closed · leohku opened this issue Feb 20, 2024 · 16 comments

@leohku commented Feb 20, 2024

Symptoms

Training on VOCASET with the supplied wav2vec2 script produces static output of the template for any audio input. Here's an example:
[screenshot: rendered output frozen at the template mesh]
Here are the TensorBoard loss graphs:
[screenshot: TensorBoard loss curves]

Notably, the loss seems to be weirdly small for both components.

Steps to reproduce

  • In the datasets/vocaset folder, I copied over the vertices_npy and wav folders that are also used for FaceFormer training
  • Then, I trained DiffSpeaker on VOCASET using the provided script scripts/diffusion/vocaset_training/diffspeaker_wav2vec2_vocaset.sh.
  • Finally, I ran the demo script with the following arguments:
python demo_vocaset.py \
    --cfg configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml \
    --cfg_assets configs/assets/vocaset.yaml \
    --template datasets/vocaset/templates.pkl \
    --example demo/wavs/speech_long.wav \
    --ply datasets/vocaset/templates/FLAME_sample.ply \
    --checkpoint experiments/vocaset/diffusion_bias/diffspeaker_wav2vec2_vocaset/checkpoints/epoch=8999.ckpt \
    --id FaceTalk_170809_00138_TA

Troubleshooting steps tried

  • I verified that the scale of the data in the provided templates.pkl matches that of the self-supplied .npy files
  • The input .npy files are 60 FPS, but I left the [::2,:] slice in the load_data function untouched (see the quick check after this list)
  • Adding a simple print(len(self.data_splits['train'])) in alm/data/vocaset.py shows that 314 training samples are loaded
  • Using git, I verified that neither the source code nor the YAML files have been altered
  • Except for scipy (mine is 1.12.0 vs 1.9.1 in requirements.txt), all pip packages match the versions in the supplied requirements.txt
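For reference, here is the quick frame-rate check I mean (a rough sketch; the file name is just an example from my data):

import numpy as np

# VOCASET vertex sequences are flattened FLAME meshes, (T, 15069) in my case, stored at 60 FPS
vertices = np.load("datasets/vocaset/vertices_npy/FaceTalk_170809_00138_TA_sentence01.npy")
downsampled = vertices[::2, :]  # the [::2,:] slice in load_data halves 60 FPS down to the 30 FPS the model expects

print(vertices.shape, downsampled.shape)  # the downsampled sequence should have roughly half as many frames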

Logs

Here are the truncated training logs:

2024-02-19 16:33:21,268 SEED_VALUE: 1234
DEBUG: false
TRAIN:
  SPLIT: train
  NUM_WORKERS: 1
  BATCH_SIZE: 32
  START_EPOCH: 0
  END_EPOCH: 9000
  RESUME: ''
  PRETRAINED: ''
  OPTIM:
    OPTIM.TYPE: AdamW
    OPTIM.LR: 0.0001
    TYPE: AdamW
    LR: 0.0001
  ABLATION:
    SKIP_CONNECT: false
    MLP_DIST: false
    IS_DIST: false
  DATASETS:
  - vocaset
EVAL:
  SPLIT: gtest
  BATCH_SIZE: 32
  NUM_WORKERS: 1
  DATASETS:
  - vocaset
TEST:
  TEST_DIR: ''
  CHECKPOINTS: checkpoints/biwi/diffspeaker_wav2vec2_vocaset.ckpt
  NUM_WORKERS: 1
  BATCH_SIZE: 1
  REPLICATION_TIMES: 10
  SAVE_PREDICTIONS: false
  COUNT_TIME: false
  DATASETS:
  - vocaset
  SPLIT: test
  FOLDER: ./results
model:
  target: diffusion/diffusion_bias_modules
  audio_encoded_dim: 768
  model_type: diffusion_bias
  latent_dim: 512
  id_dim: 8
  ff_size: 1024
  num_layers: 1
  num_heads: 4
  dropout: 0.1
  max_len: 600
  activation: gelu
  normalize_before: true
  require_start_token: true
  arch: default
  predict_epsilon: false
  freq_shift: 0
  flip_sin_to_cos: true
  mem_attn_scale: 1.0
  tgt_attn_scale: 1.0
  audio_fps: 50
  hidden_fps: 30
  guidance_scale: 0
  guidance_uncondp: 0.0
  period: 30
  no_cross: false
  smooth_output: true
  denoiser:
    target: alm.models.architectures.adpt_bias_denoiser.Adpt_Bias_Denoiser
    params:
      audio_encoded_dim: ${model.audio_encoded_dim}
      ff_size: ${model.ff_size}
      num_layers: ${model.num_layers}
      num_heads: ${model.num_heads}
      dropout: ${model.dropout}
      normalize_before: ${model.normalize_before}
      activation: ${model.activation}
      return_intermediate_dec: false
      arch: ${model.arch}
      latent_dim: ${model.latent_dim}
      nfeats: ${DATASET.NFEATS}
      freq_shift: ${model.freq_shift}
      flip_sin_to_cos: ${model.flip_sin_to_cos}
      max_len: 3000
      id_dim: ${model.id_dim}
      require_start_token: ${model.require_start_token}
      mem_attn_scale: ${model.mem_attn_scale}
      tgt_attn_scale: ${model.tgt_attn_scale}
      audio_fps: ${model.audio_fps}
      hidden_fps: ${model.hidden_fps}
      guidance_scale: ${model.guidance_scale}
      guidance_uncondp: ${model.guidance_uncondp}
      period: ${model.period}
      no_cross: ${model.no_cross}
  scheduler:
    target: diffusers.DDIMScheduler
    num_inference_timesteps: 50
    eta: 0.0
    params:
      num_train_timesteps: 1000
      beta_start: 0.00085
      beta_end: 0.012
      beta_schedule: scaled_linear
      clip_sample: false
      prediction_type: sample
      set_alpha_to_one: false
      steps_offset: 1
  noise_scheduler:
    target: diffusers.DDPMScheduler
    params:
      num_train_timesteps: 1000
      beta_start: 0.00085
      beta_end: 0.012
      beta_schedule: scaled_linear
      variance_type: fixed_small
      prediction_type: sample
      clip_sample: false
LOSS:
  TYPE: voca
  VERTICE_ENC: 1
  VERTICE_ENC_V: 1
  LIP_ENC: 0
  LIP_ENC_V: 0
  DIST_SYNC_ON_STEP: true
METRIC: None
DATASET:
  VOCASET:
    NONE: none
    ROOT: ./datasets/vocaset
  JOINT_TYPE: vocaset
DEMO:
  EAMPLE: null
  ID: null
  CHECKPOINTS: templates
  TEMPLATE: datasets/vocaset/templates.pkl
  PLY: datasets/vocaset/templates/FLAME_sample.ply
  FPS: 30
LOGGER:
  SACE_CHECKPOINT_EPOCH: 1000
  LOG_EVERY_STEPS: 100
  VAL_EVERY_STEPS: 1000
  TENSORBOARD: true
  WANDB:
    OFFLINE: false
    PROJECT: null
    RESUME_ID: null
NAME: diffspeaker_wav2vec2_vocaset
ACCELERATOR: gpu
DEVICE:
- 0
audio_encoder:
  train_audio_encoder: true
  model_name_or_path: facebook/wav2vec2-base-960h
target: diffusion/diffusion_bias_modules
audio_encoded_dim: 768
model_type: diffusion_bias
latent_dim: 512
id_dim: 8
ff_size: 1024
num_layers: 1
num_heads: 4
dropout: 0.1
max_len: 600
activation: gelu
normalize_before: true
require_start_token: true
arch: default
predict_epsilon: false
freq_shift: 0
flip_sin_to_cos: true
mem_attn_scale: 1.0
tgt_attn_scale: 1.0
audio_fps: 50
hidden_fps: 30
guidance_scale: 0
guidance_uncondp: 0.0
period: 30
no_cross: false
smooth_output: true
denoiser:
  target: alm.models.architectures.adpt_bias_denoiser.Adpt_Bias_Denoiser
  params:
    audio_encoded_dim: ${model.audio_encoded_dim}
    ff_size: ${model.ff_size}
    num_layers: ${model.num_layers}
    num_heads: ${model.num_heads}
    dropout: ${model.dropout}
    normalize_before: ${model.normalize_before}
    activation: ${model.activation}
    return_intermediate_dec: false
    arch: ${model.arch}
    latent_dim: ${model.latent_dim}
    nfeats: ${DATASET.NFEATS}
    freq_shift: ${model.freq_shift}
    flip_sin_to_cos: ${model.flip_sin_to_cos}
    max_len: 3000
    id_dim: ${model.id_dim}
    require_start_token: ${model.require_start_token}
    mem_attn_scale: ${model.mem_attn_scale}
    tgt_attn_scale: ${model.tgt_attn_scale}
    audio_fps: ${model.audio_fps}
    hidden_fps: ${model.hidden_fps}
    guidance_scale: ${model.guidance_scale}
    guidance_uncondp: ${model.guidance_uncondp}
    period: ${model.period}
    no_cross: ${model.no_cross}
scheduler:
  target: diffusers.DDIMScheduler
  num_inference_timesteps: 50
  eta: 0.0
  params:
    num_train_timesteps: 1000
    beta_start: 0.00085
    beta_end: 0.012
    beta_schedule: scaled_linear
    clip_sample: false
    prediction_type: sample
    set_alpha_to_one: false
    steps_offset: 1
noise_scheduler:
  target: diffusers.DDPMScheduler
  params:
    num_train_timesteps: 1000
    beta_start: 0.00085
    beta_end: 0.012
    beta_schedule: scaled_linear
    variance_type: fixed_small
    prediction_type: sample
    clip_sample: false
FOLDER: ./experiments/vocaset
FOLDER_EXP: experiments/vocaset/diffusion_bias/diffspeaker_wav2vec2_vocaset
TIME: 2024-02-19-16-33-21

2024-02-19 16:33:57,177 datasets module vocaset initialized
2024-02-19 16:33:57,504 No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'
2024-02-19 16:34:01,962 model diffusion_bias loaded
2024-02-19 16:34:01,962 Callbacks initialized
2024-02-19 16:34:01,975 Trainer initialized
2024-02-19 16:34:02,529 Training started
2024-02-19 16:34:09,795 Epoch 0: Train_vertice_recon 3.744e-07   Train_vertice_reconv 2.512e-08   Train_lip_recon 0.000e+00   Train_lip_reconv 0.000e+00   Memory 18.4%
2024-02-19 16:34:13,748 Epoch 1: Train_vertice_recon 3.773e-07   Train_vertice_reconv 2.527e-08   Train_lip_recon 0.000e+00   Train_lip_reconv 0.000e+00   Memory 18.4%
2024-02-19 16:34:18,036 Epoch 2: Train_vertice_recon 3.716e-07   Train_vertice_reconv 2.496e-08   Train_lip_recon 0.000e+00   Train_lip_reconv 0.000e+00   Memory 18.1%
...
2024-02-20 02:55:02,321 Epoch 8998: Train_vertice_recon 3.687e-07   Train_vertice_reconv 2.469e-08   Train_lip_recon 0.000e+00   Train_lip_reconv 0.000e+00   Val_vertice_recon 4.144e-07   Val_vertice_reconv 3.175e-08   Val_lip_recon 0.000e+00   Val_lip_reconv 0.000e+00   Memory 10.6%
2024-02-20 02:55:08,536 Epoch 8999: Train_vertice_recon 3.634e-07   Train_vertice_reconv 2.439e-08   Train_lip_recon 0.000e+00   Train_lip_reconv 0.000e+00   Val_vertice_recon 4.144e-07   Val_vertice_reconv 3.175e-08   Val_lip_recon 0.000e+00   Val_lip_reconv 0.000e+00   Memory 10.6%
2024-02-20 02:55:09,241 Training done
2024-02-20 02:55:09,920 The checkpoints are stored in experiments/vocaset/diffusion_bias/diffspeaker_wav2vec2_vocaset/checkpoints
2024-02-20 02:55:09,921 The outputs of this experiment are stored in experiments/vocaset/diffusion_bias/diffspeaker_wav2vec2_vocaset
2024-02-20 02:55:09,921 Training ends!
@theEricMa (Owner)

🤔 My log shows that after initialization, the Train_vertice_recon loss is around 1e-05. I think there should be something wrong with the dataset you used.

2023-08-08 05:22:20,923 Epoch 0: Train_vertice_recon 2.833e-05   Train_vertice_reconv 8.574e-07   Memory 8.3%

Can you check whether this still happens on the BIWI dataset?

@hqm0810 commented Feb 22, 2024

Hi, I'm hitting the same problem. When I debug, I find that self.motion_decoder is initialized with all zeros and is never updated during training when following your instructions, so the variable vertice_out stays 0 the whole time.

I suspect there is a bug in the code, for example a trainable variable that is not registered with the optimizer. Could you train it from scratch using the master branch code, or give some tips? I'd appreciate your kindness.
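One way to check that kind of thing (a rough sketch for a typical PyTorch Lightning setup, not verified against this repo):

# inside the LightningModule during training: list parameters that require grad
# but are missing from the optimizer's param groups
optimizer = self.trainer.optimizers[0]
opt_params = {id(p) for group in optimizer.param_groups for p in group['params']}
for name, p in self.named_parameters():
    if p.requires_grad and id(p) not in opt_params:
        print('trainable but not optimized:', name)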

Here are some of my training results:
[screenshot: training loss output]

@leohku (Author) commented Feb 22, 2024

@hqm0810 I've narrowed down the problem to an issue in the loss update function. Specifically, in allsplit_step,

# training
if split == "train":
    if self.guidance_uncondp > 0: # we randomly mask the audio feature
        audio_mask = torch.rand(batch['audio'].shape[0]) < self.guidance_uncondp
        batch['audio'][audio_mask] = 0

    rs_set = self._diffusion_forward(batch, batch_idx, phase="train")
    loss = self.losses[split].update(rs_set)
    return loss

While the rs_set looks normal, the loss returned to the Trainer is None, which prevents backprop. This means the loss is indeed being calculated and logged by torchmetrics, but there is a difference between the tensor computed inside the loss update() function and the value received by allsplit_step().

Another issue I've identified is that the loss tensor seems to have requires_grad=False.
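A quick way to see both symptoms (a rough sketch reusing the names from the snippet above, not exact code):

rs_set = self._diffusion_forward(batch, batch_idx, phase="train")
loss = self.losses[split].update(rs_set)
print(loss)                                  # prints None here, even though a loss is computed inside update()
print(rs_set['vertice_pred'].requires_grad)  # check whether the prediction is still attached to the autograd graph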

Can you see if this is the cause for your problem?

@hqm0810 commented Feb 22, 2024

Thank you. The cause of the problem is that VOCALosses always returns None when its update function is called (even though the loss computed inside update is normal), so I simply reimplemented the loss computation in DIFFUSION_BIAS and everything works now. Thank you very much for your answer again.

@leohku (Author) commented Feb 22, 2024

It seems that, for some reason, computing the loss term inside the Metric class causes the output tensors to lose their gradients and also produces incorrectly small losses. Using @hqm0810's answer as inspiration, I was able to construct a minimal working example:

For file alm/models/modeltype/diffusion_bias.py,

@@ -6,7 +6,7 @@ from transformers import Wav2Vec2Model
 
 from alm.config import instantiate_from_config
 from alm.models.modeltype.base import BaseModel
-from alm.models.losses.voca import VOCALosses
+from alm.models.losses.voca import VOCALosses, MaskedConsistency, MaskedVelocityConsistency
 from alm.utils.demo_utils import animate
 from .base import BaseModel
 
@@ -44,6 +42,8 @@ class DIFFUSION_BIAS(BaseModel):
             key: self._losses["losses_" + key]
             for key in ["train", "test", "val", ] # "train_val"
         }
+        self.reconstruct = MaskedConsistency()
+        self.reconstruct_v = MaskedVelocityConsistency()
 
         # set up model
         self.audio_encoder = Wav2Vec2Model.from_pretrained(cfg.audio_encoder.model_name_or_path)
@@ -114,7 +114,12 @@ class DIFFUSION_BIAS(BaseModel):
                 batch['audio'][audio_mask] = 0
 
             rs_set = self._diffusion_forward(batch, batch_idx, phase="train")
-            loss = self.losses[split].update(rs_set)
+            
+            mask = rs_set['vertice_attention'].unsqueeze(-1)
+            loss1 = self.reconstruct(rs_set['vertice'], rs_set['vertice_pred'], mask)
+            loss2 = self.reconstruct_v(rs_set['vertice'], rs_set['vertice_pred'], mask)
+            loss = loss1 + loss2
+            self.losses[split].update(loss1, loss2, loss)
             return loss

For file alm/models/losses/voca.py,

@@ -118,31 +118,13 @@ class VOCALosses(Metric):
     #     lip_vertice = vertice.view(shape[0], shape[1], -1, 3)[:, :, mouth_map, :].view(shape[0], shape[1], -1)
     #     return lip_vertice
 
-    def update(self, rs_set):
-        # rs_set.keys() = dict_keys(['latent', 'latent_pred', 'vertice', 'vertice_recon', 'vertice_pred', 'vertice_attention'])
-
-        total: float = 0.0
-        # Compute the losses
-        # Compute instance loss
-
-        # padding mask
-        mask = rs_set['vertice_attention'].unsqueeze(-1)
-
+    def update(self, recon, recon_v, ttl):
         if self.split in ['losses_train', 'losses_val']: 
-            # vertice loss
-            total += self._update_loss("vertice_enc", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)
-            total += self._update_loss("vertice_encv", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)
-
-            # lip loss
-            # lip_vertice = self.vert2lip(rs_set['vertice'])
-            # lip_vertice_pred = self.vert2lip(rs_set['vertice_pred'])
-            # total += self._update_loss("lip_enc", lip_vertice, lip_vertice_pred, mask = mask)
-            # total += self._update_loss("lip_encv", lip_vertice, lip_vertice_pred, mask = mask)
-
-            self.total += total.detach()
+            self.vertice_enc += recon.detach()
+            self.vertice_encv += recon_v.detach()
+            self.total += ttl.detach()
             self.count += 1
-
-            return total
+            return ttl
         
         if self.split in ['losses_test']:
             raise ValueError(f"split {self.split} not supported")

This allows the model to train with the correct losses (Train_vertice_recon loss being around 1e-05 in the first epoch). But further modifications are required to implement it for the Validation stage.
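For the validation stage, the same change should carry over. A rough, untested sketch (assuming allsplit_step has a matching val branch and that _diffusion_forward accepts phase="val"):

# validation (untested sketch mirroring the training fix above)
if split == "val":
    rs_set = self._diffusion_forward(batch, batch_idx, phase="val")

    mask = rs_set['vertice_attention'].unsqueeze(-1)
    loss1 = self.reconstruct(rs_set['vertice'], rs_set['vertice_pred'], mask)
    loss2 = self.reconstruct_v(rs_set['vertice'], rs_set['vertice_pred'], mask)
    self.losses[split].update(loss1, loss2, loss1 + loss2)  # logging only; no backprop during validation
    return loss1 + loss2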

@aixiaodewugege

Hi all, is this bug fixed in the master branch?

@aixiaodewugege commented May 8, 2024

Thanks for your code! Have you trained on the dataset and verified the results?

@theEricMa (Owner)

Dear all, thanks for your effort. We have released the missing part for training. You can now train the model with decreasing losses.

@aixiaodewugege

Hi, thanks for your reply, but I couldn't find your latest commit. How should I use the latest code?

@moliq1 commented May 29, 2024

Hi @aixiaodewugege, did you get any results? I rewrote the loss calculation according to the above and trained on VOCASET for around 3500 epochs, but when I tested, the results were still not good; the mouth doesn't even open.

@aixiaodewugege commented May 29, 2024

Hi. I can train it on VOCASET and get good results after 9000 epochs.

@aixiaodewugege

@moliq1, have you trained it on multiple GPUs? I was only able to train it on a single GPU.

@moliq1 commented Jun 7, 2024

I haven't tried it on multiple GPUs yet, but a single GPU is OK for me too.

@aixiaodewugege

Thanks. If you manage to fix the multi-GPU problem, please let me know how~~~

@chenyinlin1

Hi @aixiaodewugege, I tried to modify the loss function a bit, but I still can't get the expected results. May I ask how you made the modification? Thank you!

@yangyifan18

In torchmetrics 0.11.4, the update() function of a Metric always returns None because of the _wrap_update() wrapper, so you're not supposed to return anything from your rewritten update() function.
[screenshot: the _wrap_update wrapper in the torchmetrics source]
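A minimal standalone illustration (just a sketch against torchmetrics 0.11.x, not project code):

import torch
from torchmetrics import Metric

class DummyLoss(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.total += x.detach()
        return x  # the _wrap_update wrapper discards this, so the caller never sees it

    def compute(self):
        return self.total

metric = DummyLoss()
out = metric.update(torch.tensor(1.0))
print(out)  # None -> compute the loss you backprop outside the Metric, and use update() only for logging

If that version also runs update() with gradients disabled by default (I haven't double-checked), that would likewise explain the requires_grad=False observation earlier in this thread.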
