📈 Training a Model in NeoSR
So you've upscaled a few videos, but maybe you have a specific source in mind that you would like to work on, or maybe you just want to try your hand at training your own neural network model. Well, it has never been easier. This guide assumes that you have an NVIDIA GPU (mandatory, preferably with 8GB+ VRAM) and that you're on a recent Windows OS.
We will be using neosr, which is the overall best option for training superresolution models in 2024. Trainner Redux is also a good alternative, however. Kim has ported this guide over to Trainner Redux here.
Make sure you download each of the following, or have them installed on your system already. The installation of each of these should be just a matter of selecting your system specifications, downloading the installer and running the installer.
- Python 3.12
- CUDA
- PyTorch 2.0 and TorchVision: the default selections should work (Stable, Windows, pip, Python and either CUDA version).
- chaiNNer: not a requirement per se, but I consider it basically mandatory for dataset preparation.
The dataset is probably the single most important component that goes into your model. It is essentially THE determining factor in your model's characteristics. Dataset building is time-consuming and an art in its own right, but also incredibly rewarding when you create a dataset that works. There are two types of datasets: datasets for on the fly degradation [OTF] and image pairs. If you're not too interested in building a dataset, you can also use pre-built datasets online, such as Nomos. You can also find additional datasets in the Enhance Everything Discord.
An oft-overlooked step to dataset construction is selecting a good variety of images that contain useful information for the model to learn. Images that are very simple, such as a picture of a featureless white wall, probably won't do much for your model (unless you're training a model on featureless white walls). Curated datasets such as Nomos will filter out low-information images, but otherwise, you'll have to do it yourself.
On this subject, musl, the developer of neosr, prepared the following examples and his thoughts.
- Has 5 different 'domains': grass, geometric object (car), sand, rocks, water. Diversity is key on any dataset. Anywhere the network "looks" at that tile it will learn something new.
- No noise, sharp all the way through, near perfect dynamic range.
- High and medium frequencies. Textures provide high frequency details, while parts of the water have difficult medium frequencies.
- Has two domains: the foreground line art and the background, both in different styles.
- Complex and varies in shades.
- Sharp, no compression or grain/noise.
Note
Author's note: High frequency details are important for the purposes of art / animation as well-- not just for realistic images.
Note
Author's note 2: This diverges from musl's perspective, but I consider it essential to make sure depth of field and other "intentional blur" effects are included in the dataset, rather than just images which are sharp all the way through. This allows the model to learn when to leave intentional blur effects alone, and not just deblur the entire image to the detriment of faithfulness to the source.
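If you are filtering low-information images yourself, one rough way to flag candidates is to measure how varied the pixel values in a tile are. The snippet below is a minimal sketch using Shannon entropy on grayscale values. The threshold of 2.0 bits is a made-up starting point for illustration, not a value taken from Nomos or neosr, and a real filter would likely look at more than entropy alone.

```python
import math
from collections import Counter


def shannon_entropy(pixels):
    """Return the Shannon entropy (in bits) of a flat sequence of grayscale values."""
    counts = Counter(pixels)
    total = len(pixels)
    return sum(-(n / total) * math.log2(n / total) for n in counts.values())


def is_low_information(pixels, threshold=2.0):
    """Flag a tile whose entropy is below a hypothetical, untuned threshold."""
    return shannon_entropy(pixels) < threshold


# A featureless "white wall" tile versus a tile with 64 distinct values.
flat_tile = [255] * 64
varied_tile = list(range(0, 256, 4))

print(shannon_entropy(flat_tile))    # 0.0 bits: nothing to learn from
print(shannon_entropy(varied_tile))  # 6.0 bits: every value is different
print(is_low_information(flat_tile), is_low_information(varied_tile))
```

Treat a flag from a check like this as a prompt to look at the image, not as a verdict. A tile with intentional blur can score low and still belong in the dataset.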
Image pair datasets are what this guide will be focused on. Image pairs consist of a high resolution image [HR] and a low resolution image [LR]. Before you begin training, you should build your own image pair dataset (more on this shortly). Paired datasets also have the following additional requirements and considerations:
- The HR should be either 1x, 2x or 4x the size of the LR, depending on the scale of the model you choose to train. 3x is a thing as well, though not all archs support 3x.
- The LR should represent the source you plan the model to work on in terms of degradations represented (i.e., the problems the source has, such as blur, compression, etc.). The model will learn to correct these flaws, guided by the HRs.
- The HR should represent what the source should look like after going through the model (hence HRs also being referred to as ground truth).
- The images must be closely aligned. Warping or one image being a few frames off will only serve to confuse the model and produce muddled results.
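The size requirement above is easy to check mechanically before training. Here is a minimal sketch, assuming you already have the pixel dimensions of each image (the function name and the example sizes are made up for illustration):

```python
def check_pair(lr_size, hr_size, scale):
    """Return True when the HR is exactly `scale` times the LR in both dimensions.

    Sizes are (width, height) tuples in pixels.
    """
    lr_w, lr_h = lr_size
    hr_w, hr_h = hr_size
    return hr_w == lr_w * scale and hr_h == lr_h * scale


# A DVD-sized LR with an exact 2x HR passes; a 1080p HR does not.
print(check_pair((720, 480), (1440, 960), scale=2))   # True
print(check_pair((720, 480), (1920, 1080), scale=2))  # False
```

Note that this only confirms the dimensions. Two images can be the right size and still be misaligned, so a visual check is still needed.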
Image pair datasets can be built in three different ways. Note that these methods aren't mutually exclusive-- in many scenarios, you'll want to use a combination of the three.
Real image pairs are typically composed of a LR from an old source, such as VHS or DVD, and a HR from a new source, such as Bluray or a WEB release. Real image pairs have the advantage of being the most realistic representation of the differences between a low resolution source and a modern release. Models created on purely real image pairs can look more "natural" without the oversmoothing of details and artificial sharpness which may be present in models trained on poorly prepared synthetic datasets (more on this later). It should be noted that using only real image pairs without enough source variety can lead to the model "overfitting" to the dataset. It's important to introduce variety if you want your model to generalize well across different sources-- whether through synthetic image pairs (as discussed later) or by drawing real pairs from a good variety of sources.
In addition, real image pairs can often be difficult to work with. Here are the typical steps one might take to create a real image pair dataset:
- Find the LR and HR. This might be a VHS and a BD release for example.
- Extract matching frames from the LR and HR. The frames must match as closely as possible, which can present a significant challenge. I created Image Pearer to semi-automate the process, but this is still a labor-intensive feat that requires time, dedication and patience.
- Align the pairs. After extracting the matching frames, the images generally won't be aligned properly. The HR likely won't be exactly 2x or 4x the size of the LR, and there may be warping which causes the pair to become misaligned, despite being the same frame. ImgAlign comes to the rescue-- it helps resize and warp the images as necessary to make sure the pairs are aligned properly.
Simple_Image_Compare_1.1_AOuUj2L12h.mp4
An example of image warping which will confuse the model.
- Confirm that the images are aligned properly-- when dropped into something like Simple Image Compare or your image comparison software of choice, the LR should scale to "fit" exactly into the HR.
Simple_Image_Compare_1.1_0oXbehxc2p.mp4
Confirming that a LR and HR image pair are perfectly aligned in Simple Image Compare.
- Congratulations, you now have your first image pair dataset! (Though you might want to collect more from other sources).
So yes, real image pairs can be a real pain to collect. But they often pay massive dividends being "natural" representations of the differences between a low-res and a high-res source. But if this section seemed unpalatable, fortunately, there are two other methods of creating paired datasets which would likely be much easier.
This process involves taking the LRs, and generating 2x or 4x versions of them as the HRs. Thus, you would be using "artificial" or "synthetic" HRs. Synthetic HRs are typically generated by using existing upscaling models (through chaiNNer's image iterator) or something like Topaz Video AI or Photo AI. Sure, you could do it manually, but at that point, it would take so much time that you might as well just make a real image pair dataset. While this approach can sound very simple, synthetic HRs have the downside of carrying over the faults of the model used to upscale the LR. For example, if a model has poor detail retention, and you only use synthetic HRs based on that model, your model will also have poor detail retention. Thus, it is important to be very selective of the models you use to help generate HRs or find ways to mitigate the issues.
Simple_Image_Compare_1.1_a1sGtsOzx2.mp4
An example of a synthetic HR-- remember, don't overdo the alterations made to the image.
This is the opposite of synthetic HRs-- you generate LRs from existing HRs. This is typically done by downscaling to 50% or 25% of the original size (for a 2x or 4x model respectively), then applying degradations. Fortunately, degradations can be easily applied through something such as Kim's dataset destroyer or umzi's dataset destroyer (a bit more complicated, but more features). Some degradations can be applied through chaiNNer as well. Power users will often also leverage AviSynth or VapourSynth to assist with degradations. Typically, you'll want to ensure your synthetic LRs have at least blur and some form of compression applied. That way, the trained model will learn to deblur (sharpen, basically) and fix compression artifacting.
A common pitfall to avoid when using synthetic LRs is overzealous application of compression and blur. While it's important to make sure the model learns to deal with compression artifacting, applying too much compression will hurt the model's ability to retain details. Similarly, too much blur will hurt detail retention and confuse the model during the training process. It'll also cause oversharpening, which is rarely desirable. So in essence, keep a balance!
Simple_Image_Compare_1.1_C3wcoSb1FF.mp4
An example of a synthetic LR with blur and JPEG degradations.
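To make the "downscale, then degrade" idea concrete, here is a toy sketch in plain Python that works on a tiny grayscale image stored as a list of rows. It is only an illustration of the order of operations. It is not how the dataset destroyer tools or chaiNNer work internally, and it does not model compression at all. Both helper functions are made up for this example.

```python
def box_downscale(image, factor):
    """Downscale a 2D grayscale image by averaging factor x factor blocks."""
    height, width = len(image), len(image[0])
    if height % factor or width % factor:
        raise ValueError("image dimensions must be divisible by the factor")
    out = []
    for y in range(0, height, factor):
        row = []
        for x in range(0, width, factor):
            block = [image[y + dy][x + dx]
                     for dy in range(factor) for dx in range(factor)]
            row.append(sum(block) / len(block))
        out.append(row)
    return out


def box_blur(image):
    """Apply a mild 3x3 box blur, clamping coordinates at the image borders."""
    height, width = len(image), len(image[0])
    out = []
    for y in range(height):
        row = []
        for x in range(width):
            neighbours = [
                image[min(max(y + dy, 0), height - 1)][min(max(x + dx, 0), width - 1)]
                for dy in (-1, 0, 1)
                for dx in (-1, 0, 1)
            ]
            row.append(sum(neighbours) / 9)
        out.append(row)
    return out


# A 4x4 "HR" with a hard vertical edge, turned into a 2x2 "LR" for a 2x model.
hr = [[0, 0, 255, 255] for _ in range(4)]
downscaled = box_downscale(hr, 2)  # the edge is still hard here
lr = box_blur(downscaled)          # the edge is now softened
print(downscaled)
print(lr)
```

Even this single mild blur pass visibly softens the edge, which is a good reminder of how quickly stacked degradations add up.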
As mentioned earlier, you might often want to mix and match a combination of the three methods discussed above. If a source has a HD release, and you think it looks great, consider generating synthetic LRs for use in the dataset. On the other hand, a show might have never gotten a HD release, but you want to upscale it. You could generate synthetic HRs out of it. Then, there might be a similar show from that same era that also got a Bluray release. You might create a set of real image pairs using the original and the Bluray release. The possibilities are endless!
I'll only discuss OTF briefly, as I don't use OTF much myself. Unlike image pairs, OTF datasets... well, do not have pairs. They're a collection of images which get degraded in real time (typically through the Real-ESRGAN degradation pipeline) while training. This extends the training time, but can produce good results. That being said, OTF is really only suited for models focusing on real-life sources. If you're training an anime or a cartoon model, stick with image pairs. Anime and cartoons have their own specific considerations which OTF does not address.
A pitfall I see in OTF models is the over-aggressive use of degradation. If not carefully tuned, OTF generates very strong degradations, which models often can't handle properly without massive detail loss. @phhofm has a great writeup on OTF here, including recommended settings. I highly recommend checking it out if you choose to pursue the OTF route.
A validation dataset is not required, but it's still highly recommended. It will likely be your primary means of confirming your model's progress. A validation dataset is essentially a dataset consisting of either single images or image pairs (I'd recommend no more than 8-10 images or pairs). At preset points during training, the training platform will generate images using the model's current state. If you're using image pairs and turn on the relevant settings in the config, it will also compare them to the HRs from the validation dataset using PSNR, DISTS and/or SSIM. Image pair validation datasets essentially represent the ideal of what your model should achieve, and the images generated during the validation process help you determine how close you are to that ideal.
As for what to include, for starters, you can include just images from your dataset if you have truly nothing else. But ideally, you'll want the validation dataset to be truly representative of the sources the model is intended for. Going one step further, you can even have individual images/pairs in your validation dataset represent the model's ability in different areas. For example, you can have one image or pair with large numbers of high frequency details serve to judge your model's detail retention ability. Or you could have a heavily degraded image or pair serve to judge your model's anti-compression abilities.
For PSNR and SSIM, you want the values as high as possible. For DISTS, you want the value as low as possible. Frankly however, all three metrics should only be used very loosely as guidance on how the model is progressing. Oftentimes, after visually comparing the output, you may find that checkpoints with worse validation values actually look better than checkpoints with superior validation values. I find that validation metrics are most useful for flagging when the model is generating horrible artifacts.
Thus, you absolutely should validate the results visually. There's no substitute for the Mark 1 eyeball! More on this later when we actually get a model trained up. :)
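For a sense of what the PSNR number actually measures, here is a minimal sketch of the textbook formula, PSNR = 10 * log10(MAX^2 / MSE), on flat lists of 8-bit pixel values. This is the standard definition rather than neosr's own implementation, which may differ in details such as colour handling or border cropping. The pixel values are made up.

```python
import math


def psnr(reference, output, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(reference) != len(output):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((r - o) ** 2 for r, o in zip(reference, output)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse)


hr = [100, 150, 200, 250]
close_output = [101, 149, 200, 250]  # off by 1 in two pixels
far_output = [110, 140, 200, 250]    # off by 10 in two pixels

print(round(psnr(hr, close_output), 2))
print(round(psnr(hr, far_output), 2))
```

The closer output scores higher, which is why higher is better. It also shows the limitation: PSNR only averages per-pixel error, so it can't tell whether an error looks like a pleasing texture or an ugly artifact.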
This has turned into a hell of a section, hasn't it? But we're not quite done yet! There are a few other considerations to dataset creation. First off, dataset size. There's no hard and fast rule on how big or how small your dataset should be. Even small, highly tailored datasets can work well for models that have very specific applications. That being said, you generally don't want your dataset above 5,000 to 7,000 images. With dataset preparation, quality trumps quantity!
On the topic of quality, one metric of quality is the source diversity in many cases. You'll often want a variety of sources within your dataset, to make sure that your model has a diverse set of information to learn from. Even if you're planning to train a model for say, just the Simpsons, don't pull the entire dataset from a single episode. Spread them out, and your model will adapt itself to the other parts of the show you plan on upscaling.
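One simple way to avoid pulling everything from one place is to take a handful of evenly spaced frames from each of several episodes. The sketch below only picks frame indices. The episode names and frame counts are hypothetical, and you would still extract and review the frames with your tool of choice.

```python
def evenly_spaced_frames(total_frames, count):
    """Return `count` frame indices spread evenly across a clip of `total_frames` frames."""
    if count <= 0 or total_frames <= 0:
        return []
    count = min(count, total_frames)
    step = total_frames / count
    # Use the middle of each equal-width segment so the very start and end
    # (often intros and credits) are not over-represented.
    return [int(step * i + step / 2) for i in range(count)]


# Hypothetical episodes with made-up frame counts.
episodes = {"s01e01": 30000, "s05e12": 31000, "s10e03": 29500}
selection = {name: evenly_spaced_frames(frames, 4) for name, frames in episodes.items()}

print(evenly_spaced_frames(100, 4))  # [12, 37, 62, 87]
print(selection)
```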
Architectures, or archs, are essentially frameworks for super resolution work. For example, if you've heard of ESRGAN, that'd be an architecture. Each arch has its own attributes and its own quirks, with some being focused on speed at the expense of robustness, and others being very slow but capable of handling whatever you throw at them. For simplicity's sake, we'll go with SRVGGNet for the purposes of this guide. SRVGGNet, aka Compact, is from the developers of ESRGAN. As a lighter arch, it is much faster than ESRGAN (think 10x+ in some scenarios). While it isn't quite as robust as some of its slower counterparts, it is still quite capable and a perfect starting point due to its inference speed (aka upscaling speed), training speed, stability while training (other archs might explode on you if you're not careful) and low VRAM requirements. If you look on OpenModelDB, you'll see that these attributes have rendered it an extremely popular arch for very good reason.
Now that you have a dataset, the next step will be to install the actual training platform to train your models on. There are a few options out there, but neosr is the newest and most updated. It's also the best documented, which makes life so much easier.
- Download neosr per the installation instructions.
- If you don't know how the command line works, all you have to do is navigate to the folder where you want to install neosr.
- In the Windows Explorer address bar, type `cmd` and hit Enter. Command Prompt should pop up.
- Copy `git clone https://github.com/muslll/neosr` and then press RMB in Command Prompt to paste the command in.
- Press Enter, and you should see the download begin.
- You should have installed all the prerequisites earlier, but install them now if not.
Configs are how you prepare the training software to actually train the model on the dataset you prepared.
- Navigate to the `options` folder in your neosr base folder, and find the `train_compact.toml` file.
- Open it up, and take a look at the Configuration Walkthrough in the neosr wiki page. The default config is very good, but you'll want to at least fill in the paths to your datasets.
You should generally be following the default settings in the config, but in the order of appearance, here are a few notes regarding the config:
- I recommend keeping `use_amp`, `bfloat16` and `compile` off. You can experiment with AMP and bfloat16 later, but they can be unstable with some systems/configs.
- The default `patch_size` and `batch_size` in the config are perfectly reasonable values. But if you have the VRAM and system resources, turning them up can greatly help with your model quality (though it'll slow training too). Typical `patch_size` values are 32, 48, 64, 96 and 128. It's recommended that `patch_size` is kept as a multiple of 16. `batch_size` helps with training stability, which can be very important for less stable archs. Please refer to the Monitor Training section for an example of training instability.
- You should make sure to use a pretrain. Pretrains for compact can be found here. Pretrains serve as a starting point to your training, and will speed up the process substantially. Make sure to pick just the normal Compact version for your scale (don't pick UltraCompact for example). Without a pretrain, you also won't be able to combine your model with other Compact models that have compatible pretrains (aka interpolation). And for clarification, pretrains are very much just a starting point for your model to help accelerate training, and have limited influence on the actual results. The dataset still remains the most significant factor in your model's output.
- Make sure to uncomment and fill out the validation dataset information if you have a validation dataset.
- Don't worry about anything under the `train` section of the config-- the default settings are optimal. Some more advanced users may want to tweak loss values for specific purposes, but the returns are dubious and likely minimal.
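If you'd like a quick sanity check of the notes above, here is a rough sketch that inspects an already-parsed config. It searches for keys by name anywhere in the nested dictionary, because I'm not assuming where neosr places them. The layout of `sample_config` is hypothetical and only there to exercise the check.

```python
def find_key(config, key):
    """Recursively search nested dicts for the first value stored under `key`."""
    if isinstance(config, dict):
        if key in config:
            return config[key]
        for value in config.values():
            found = find_key(value, key)
            if found is not None:
                return found
    return None


def check_config(config):
    """Return a list of warnings based on the recommendations in this guide."""
    warnings = []
    for flag in ("use_amp", "bfloat16", "compile"):
        if find_key(config, flag):
            warnings.append(f"{flag} is enabled; consider leaving it off at first")
    patch_size = find_key(config, "patch_size")
    if patch_size is not None and patch_size % 16 != 0:
        warnings.append(f"patch_size {patch_size} is not a multiple of 16")
    return warnings


# Hypothetical layout, NOT copied from a real neosr config file.
sample_config = {
    "use_amp": False,
    "bfloat16": False,
    "compile": True,
    "datasets": {"train": {"patch_size": 40, "batch_size": 8}},
}

for warning in check_config(sample_config):
    print(warning)
```

On Python 3.11 or newer, the standard library's `tomllib.load` (given a file opened in binary mode) can turn the `.toml` file into a dictionary suitable for this kind of check.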
Tensorboard provides a fancy interface with graphs to track your model training. If you like fancy graphs, you can install tensorboard as follows.
- Make sure `use_tb_logger` is set to `true` in your config-- it should be the default.
- Install tensorboard via `pip install tensorboard`.
- After you start training, or have trained a model, you can launch tensorboard via `tensorboard --logdir NEOSRPATH\neosr\experiments\tb_logger\MODELNAME`. If it's working, you should see something like the below.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.14.0 at http://localhost:6006/ (Press CTRL+C to quit)
- By default, you should be able to access it at http://localhost:6006/ via your browser.
You should see graphs like the above when opening the tensorboard page.
Now that you have a dataset created, neosr installed and a config set up, it's finally time to begin training! Return to your neosr base folder, launch cmd again, and paste in the following: `python train.py -opt options/train_compact.toml`. Note that if you changed the name of the default compact config, you'll have to update the command accordingly.
- If you want to end training, press Ctrl+C.
- If you want to resume training, you can set a resume state in the config file (found in the `experiments` > folder with your model name > `training_states`) or simply use `--auto_resume`, such as `python train.py -opt options/train_compact.toml --auto_resume`, as your training command to start from the last stopped point.
If all is working as intended, you should see something like this.
2023-10-01 14:52:37,916 INFO:
------------------------ neosr ------------------------
Pytorch Version: 2.1.0.dev20230903+cu118
2023-10-01 14:52:38,189 INFO: Dataset [paired] is built.
2023-10-01 14:52:38,189 INFO: Training statistics:
Number of train images: 4079
Dataset enlarge ratio: 8
Batch size per gpu: 20
World size (gpu number): 1
Require iter number per epoch: 1632
Total epochs: 307; iters: 500000.
2023-10-01 14:52:38,191 INFO: Dataset [paired] is built.
2023-10-01 14:52:38,191 INFO: Number of val images/folders in Validation: 5
2023-10-01 14:52:38,497 INFO: Network [compact] is created.
2023-10-01 14:52:38,570 INFO: Network [unet] is created.
2023-10-01 14:52:38,643 INFO: Loading compact model from D:\Users\Sirosky\Jottacloud\Media\Upscaling\Models\Pretrains\Compact\1x_AniScale2_Refiner_i4_10K.pth, with param key: [params].
2023-10-01 14:52:38,651 INFO: Loss [L1Loss] is created.
2023-10-01 14:52:39,152 INFO: Loss [PerceptualLoss] is created.
2023-10-01 14:52:39,160 INFO: Loss [GANLoss] is created.
2023-10-01 14:52:39,160 INFO: Loss [colorloss] is created.
2023-10-01 14:52:39,161 INFO: Model [default] is created.
2023-10-01 14:53:47,746 INFO: Using CUDA prefetch dataloader.
2023-10-01 14:53:47,746 INFO: Start training from epoch: 0, iter: 0
This means that the model has begun training. It might take a few minutes, but then you should see something like this show up.
2023-10-01 14:55:26,227 INFO: [epoch: 0] [iter: 100] [performance: 1.015 it/s] [lr:(1.000e-04)] [eta: 1 day, 23:38:36, data_time: 4.7605e-01 l_g_pix: 2.3688e-03 l_percep: 7.7059e-01 l_g_color: 9.1230e-04 l_g_gan: 6.9329e-02 l_d_real: 6.9330e-01 out_d_real: -3.0924e-04 l_d_fake: 6.9301e-01 out_d_fake: -2.8798e-04
This means the model is training properly. As for all the numbers:
- `performance` shows the iterations per second. By default, the model will save every 1000 iterations, and run a validation check every 1000 iterations. I have mine set to save every 2,500 iterations, with validation every 5,000 iterations to avoid the spam.
- The `eta` is determined by your config and your current speed. I usually just ignore this.
- The loss values such as `l_g_pix` should generally be close to 0 (except for `l_g_gan`). That being said, don't worry about it too much for now.
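If you're curious where the epoch figures in the startup log come from, they can be reproduced with a bit of arithmetic. The formula below is inferred from the numbers in the log (images x enlarge ratio, divided by batch size x GPU count, rounded up). I haven't checked it against neosr's source, so treat it as an assumption.

```python
import math


def iters_per_epoch(num_images, enlarge_ratio, batch_size, world_size=1):
    """Iterations needed to go through the (enlarged) dataset once."""
    return math.ceil(num_images * enlarge_ratio / (batch_size * world_size))


def total_epochs(total_iters, per_epoch):
    """Number of epochs needed to reach the configured iteration count."""
    return math.ceil(total_iters / per_epoch)


# Values taken from the example log: 4079 images, enlarge ratio 8,
# batch size 20, one GPU, 500000 total iterations.
per_epoch = iters_per_epoch(4079, 8, 20)
epochs = total_epochs(500000, per_epoch)

print(per_epoch)  # 1632, matching "Require iter number per epoch"
print(epochs)     # 307, matching "Total epochs"
```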
Now, all there's left to do is to monitor the progress of your model. Ideally, you will have created a validation dataset, and you should see cmd spewing out validation results every once in a while, like this:
2023-10-01 15:56:00,819 INFO: Validation Validation
# psnr: 45.6487 Best: 45.9801 @ 5000 iter
# ssim: 0.9978 Best: 0.9979 @ 5000 iter
Here, you can see a PSNR value of over 45, which is extremely high. The SSIM value is also in agreement, at .9978 (out of 1). So this model is in very good shape. But what's this? Why is our current PSNR value (on the left, 45.6487) lower than the Best value of 45.9801 at 5,000 iterations? This may indicate that the model actually peaked at 5,000 iterations. As mentioned before, it's time to bring out the big guns: the Mark 1 eyeball. Never just look at the value and take it at face value!
You can find the model's generated validation images under `neosr` > `experiments` > folder with your model name > `visualization`. Here, you'll see images generated from your validation dataset's LRs. You will then want to compare the generated images between themselves, and also with the validation dataset's HRs. With a visual check, you can determine whether the model is still training, stagnating or has gone FUBAR.
Simple_Image_Compare_1.1_BaQMEXFy9a.mp4
An example of training instability-- these validation images were 5000 iterations apart. Ideally, these artifacts will go away with continued training.
If after several validation checks your model's validation scores slip or remain stagnant, it's a likely indication that your model is done training. At the risk of sounding like a broken record, please do confirm visually as well. I typically complete Compact models at less than 40K iterations. Some even complete at 5K or 10K iterations, if I use a very on-point pretrain.
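If you jot down the validation scores as training goes, spotting a stall can be sketched in a few lines. In the example below, the first and last PSNR values come from the log shown earlier. The two in between, the iteration numbers after 5,000, and the `patience` of 3 checks are made up for illustration.

```python
def best_checkpoint(history):
    """Return the (iteration, score) pair with the highest score."""
    return max(history, key=lambda entry: entry[1])


def has_stalled(history, patience=3):
    """True when the last `patience` checks all failed to beat the earlier best."""
    if len(history) <= patience:
        return False
    earlier_best = max(score for _, score in history[:-patience])
    return all(score <= earlier_best for _, score in history[-patience:])


# (iteration, PSNR) pairs; the middle two entries are hypothetical.
history = [(5000, 45.9801), (10000, 45.71), (15000, 45.66), (20000, 45.6487)]

print(best_checkpoint(history))  # (5000, 45.9801)
print(has_stalled(history))      # True: nothing since 5000 has beaten it
```

A stall flagged this way is just a cue to go compare the validation images, since the best-scoring checkpoint isn't always the best-looking one.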
Once you determine your model is complete, you can find it in the `models` folder next to the `visualization` folder. The file with `_g` is the actual model-- the `_d` is a discriminator used in the training process. You don't need it for inference purposes, but it's necessary if you ever want to resume training the same model.
With that, congratulations on your trained model!
If you've made it this far, congratulations on your first trained model. Model training has quite a steep learning curve, but I hope this guide made it a bit easier to decrypt the process. With that being said, if you do decide to continue training models, consider checking out my (much shorter) writeup on Model Training Principles. While much of it is anime-focused, there is still plenty applicable to models dedicated to real-life sources and others. Obviously, some of this will be up to subjective taste, but hopefully it'll provide a sense of what to look for as you prepare datasets and train your own models. Good luck, may the neural network gambling parlor be ever in your favor!
Feel free to contact me on Discord (sirosky), or join the Enhance Everything Discord server, which contains additional resources for trainers.
- A big thank you to muslll for actually creating neosr, without which this guide would not exist, and for validating the guide via Mark 1 eyeball.
- Credits to zarxrax, who wrote the incredibly helpful guide Training a Compact model in Real-ESRGAN, which started my model training addiction. In addition, this guide is very much inspired by his guide!