Pathfinder task #11
Hi @cifkao, Thanks for raising these issues and for helping us improve the codebase. For LRA, we use Also regarding
I'll address this issue in my PR as well. |
Thanks. I set
I1208 11:46:39.121981 140014513174336 dataset_builder.py:529] Constructing tf.data.Dataset for split hard[:80%], from /mnt/beegfs/projects/tpt-s2a-4/data/lra_release/pathfinder32/1.0.0
Traceback (most recent call last):
  File "lra_benchmarks/image/train.py", line 420, in <module>
    app.run(main)
  File "/mnt/beegfs/home/cifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/mnt/beegfs/home/cifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/image/train.py", line 337, in main
    normalize=normalize)
  File "/mnt/beegfs/home/cifka/d/projects/long-range-arena/lra_benchmarks/image/input_pipeline.py", line 182, in get_pathfinder_base_datasets
    train_dataset = get_split(f'{split}[:80%]')
  File "/mnt/beegfs/home/cifka/d/projects/long-range-arena/lra_benchmarks/image/input_pipeline.py", line 175, in get_split
    split=split, decoders={'image': tfds.decode.SkipDecoding()})
  File "/mnt/beegfs/home/cifka/venv/lra/lib/python3.7/site-packages/tensorflow_datasets/core/dataset_builder.py", line 535, in as_dataset
    ) % (self.name, self._data_dir_root))
AssertionError: Dataset pathfinder32: could not find data in /mnt/beegfs/projects/tpt-s2a-4/data/lra_release/. Please make sure to call dataset_builder.download_and_prepare(), or pass download=True to tfds.load() before trying to access the tf.data.Dataset object.
I think the problem is my
$ ls /mnt/beegfs/projects/tpt-s2a-4/data/lra_release/pathfinder32/
curv_baseline curv_contour_length_14 curv_contour_length_9 |
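The assertion above fires when TFDS finds no prepared data for pathfinder32 under the given data dir, and the listing shows only raw image folders. A minimal sketch for checking a directory before launching training, assuming the usual TFDS layout of <data_dir>/<name>/<version>/ containing a dataset_info.json plus *.tfrecord* shards (find_prepared_tfds_dir is a hypothetical helper, not part of LRA or TFDS):

```python
from pathlib import Path


def find_prepared_tfds_dir(data_dir, name="pathfinder32"):
    """Return version directories under data_dir/name that look like prepared TFDS data.

    Assumption: a prepared dataset lives in <data_dir>/<name>/<version>/ and holds
    a dataset_info.json plus at least one *.tfrecord* shard. Raw image folders
    such as curv_baseline contain neither, so they are skipped.
    """
    root = Path(data_dir) / name
    if not root.is_dir():
        return []
    prepared = []
    for version_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        has_info = (version_dir / "dataset_info.json").is_file()
        has_shards = any(".tfrecord" in f.name for f in version_dir.iterdir())
        if has_info and has_shards:
            prepared.append(version_dir)
    return prepared
```

If pathfinder32/ holds only the raw curv_* folders, this returns an empty list, which matches the situation in the error above.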
There is a small problem with the zip file we released: some extra files slipped into it during archiving. We are fixing that and will upload a new zip file with a cleaner directory structure and no unnecessary files. |
I had in fact deleted everything except for Either way, I checked the archive and I cannot find anything called |
Thank you @cifkao for checking this, and sorry for the trouble. We checked, and it turns out that we released the raw images for the pathfinder datasets, so you need to generate the TFDS files yourself using this code: However, we now also provide the generated TFDS files to make LRA more convenient to use. You can download the TFDS files for pathfinder here: https://storage.cloud.google.com/long-range-arena/pathfinder_tfds.gz and then set |
Seems to work, thanks! |
Perfect! Let's keep this issue open until I send the PR that adds this information to the README :) |
Is anyone able to reproduce the paper's results using Performer on pathfinder? Accuracy is much worse (62% vs. 77%). I was able to approximately reproduce the results with Transformer and BigBird. |
@renebidart
|
This version of the link does not require you to log in to a Google account: https://storage.googleapis.com/long-range-arena/pathfinder_tfds.gz |
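For anyone scripting the setup, here is a rough sketch that fetches the archive from the public link above and unpacks it. It assumes pathfinder_tfds.gz is a gzip-compressed tar archive (this is not confirmed in the thread), and the helper names are hypothetical:

```python
import tarfile
import urllib.request
from pathlib import Path

# Public (no-login) URL quoted in this thread.
PATHFINDER_TFDS_URL = "https://storage.googleapis.com/long-range-arena/pathfinder_tfds.gz"


def download(url, dest):
    """Download url to the file path dest and return it as a Path."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return dest


def extract_tfds_archive(archive, out_dir):
    """Extract a gzip-compressed tar archive into out_dir and return its member names.

    Assumption: the archive is tar+gzip; if it is a single gzipped file instead,
    tarfile.open raises tarfile.ReadError.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, mode="r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons support extraction filters that block unsafe paths.
            tar.extractall(out_dir, filter="data")
        else:
            tar.extractall(out_dir)
        return sorted(m.name for m in tar.getmembers())
```

Typical use would be download(PATHFINDER_TFDS_URL, "/tmp/pathfinder_tfds.gz") followed by extract_tfds_archive on the returned path; the extracted directory then still has to be pointed to in the LRA setup, as described in the maintainer's comment above.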
I can't reproduce Performer's result on the pathfinder32_hard task either. I get only 50.47% as the best eval result. |
Me neither. Furthermore, I've taken a look at the model config and it doesn't make sense: the QKV dim is set to 16, while the MLP and hidden dims are 32. I've skimmed through the code, and these are the actual dimensions, not the per-head ones after the split.
Similar problems exist with other tasks. |
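To make the distinction concrete: in standard multi-head attention the total QKV width is divided evenly among the heads, so a total QKV dim of 16 leaves very little per head. A small sketch of that split (the head counts used here are hypothetical; the thread does not state the configured number of heads):

```python
def split_heads(qkv_dim: int, num_heads: int) -> int:
    """Per-head width when a qkv_dim-wide projection is split across num_heads heads."""
    if num_heads <= 0 or qkv_dim % num_heads != 0:
        raise ValueError(f"qkv_dim={qkv_dim} must be divisible by num_heads={num_heads}")
    return qkv_dim // num_heads


def split_into_heads(projection: list, num_heads: int) -> list:
    """Split one token's projected query/key/value vector into contiguous per-head chunks."""
    head_dim = split_heads(len(projection), num_heads)
    return [projection[i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]


# Hypothetical example: a 16-dim query vector for one token, split over 4 heads.
query = list(range(16))
heads = split_into_heads(query, num_heads=4)
```

With 8 heads, for instance, split_heads(16, 8) gives a per-head dimension of 2, whereas reading 16 as the per-head size would imply a total width of 128.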
@EternalSorrrow, I'll soon send a fix for the issue with the configs of pathfinder. |
@MostafaDehghani |
Hello, |
Could you please specify which pathfinder task is used in the paper? I'm assuming it's pathfinder32, but which difficulty?
Also, the task is broken. There is no way to specify the path to the data, and the pipeline code tries to reference
_PATHFINER_TFDS_PATH
(note the typo), which is never defined (even without the typo).
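Until the constant is fixed upstream, one way to make the path configurable is to resolve it explicitly and fail early with a clear message. A minimal sketch using argparse and an environment variable; the flag --pathfinder_tfds_path and the variable PATHFINDER_TFDS_PATH are hypothetical names, not options LRA provides (the training script itself uses absl, per the traceback above, where an absl flag would be the natural fit):

```python
import argparse
import os

# Hypothetical environment variable name; not defined anywhere in LRA.
PATHFINDER_TFDS_PATH_ENV = "PATHFINDER_TFDS_PATH"


def resolve_pathfinder_tfds_path(argv=None, environ=None):
    """Return the Pathfinder TFDS directory from a CLI flag, falling back to an env var.

    The flag takes precedence. Raises ValueError if neither is set, instead of
    failing later on an undefined module-level constant.
    """
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--pathfinder_tfds_path", default=None)  # hypothetical flag
    args, _unknown = parser.parse_known_args(argv)
    path = args.pathfinder_tfds_path or environ.get(PATHFINDER_TFDS_PATH_ENV)
    if not path:
        raise ValueError(
            f"Set --pathfinder_tfds_path or ${PATHFINDER_TFDS_PATH_ENV} "
            "to the directory containing the pathfinder TFDS files"
        )
    return path
```

The resolved path would then be passed as the data_dir argument, which both tfds.load and tfds.builder accept.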