Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepSSM UI #362

Merged
merged 26 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
668c215
DeepSSM add data tabs and images
JakeWags Apr 3, 2024
9b2296c
Remove hanging TODO comments
JakeWags Apr 3, 2024
40c3cb5
Flip DeepSSM tab vs expansion panel order
JakeWags Apr 4, 2024
f810044
Add data tables to UI, needs styling
JakeWags Apr 4, 2024
a3854e7
Update DeepSSM UI to match backend changes
JakeWags Apr 4, 2024
b392950
Add training_pairs to output
JakeWags Apr 8, 2024
04f09e0
Pull filename from augmentation entries in datatable
JakeWags Apr 8, 2024
8abe7de
Fix datatable styling
JakeWags Apr 8, 2024
d465f05
Cleanup of deepssm tab
JakeWags Apr 8, 2024
c49e866
Update rest to remove lazy import
JakeWags Apr 9, 2024
0c2eae3
Augmentation shape viewer
JakeWags Apr 9, 2024
93c512a
Add index to training image model
JakeWags Apr 9, 2024
ae3eada
Update error with string processing
JakeWags Apr 9, 2024
98f10b5
Training image (WIP)
JakeWags Apr 9, 2024
ff8f968
Add groom and optimize to deepssm data controls
JakeWags Apr 10, 2024
edaeb7a
Update augmentation shapeviewer to have image, particles, and mesh
JakeWags Apr 10, 2024
6e36515
Update image_id field to be str
JakeWags Apr 11, 2024
5b2f6b5
DeepSSM heatmap, shapeviewer, etc
JakeWags Apr 11, 2024
dc1136c
Heatmap, color scales, image viewer fixes (WIP)
JakeWags Apr 11, 2024
61462b2
Fix hanging migration
JakeWags Apr 12, 2024
81f1a08
Uniform and non-uniform scalar bars for deepssm heatmaps
JakeWags Apr 12, 2024
6dab27f
fix: make DeepSSMUtils & DataAugmentationUtils imports lazy
annehaley Apr 12, 2024
d0f174a
fix: shapeworks import should be lazy too
annehaley Apr 12, 2024
a9bdf79
remove old TODO comments
annehaley Apr 12, 2024
1e34888
Merge branch 'master' into deepssm-ui
annehaley Apr 12, 2024
45ab369
style: remove whitespace
annehaley Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions shapeworks_cloud/core/deepssm_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import DataAugmentationUtils
import DeepSSMUtils
from celery import shared_task
from django.conf import settings
from django.contrib.auth.models import User
from rest_framework.authtoken.models import Token
import shapeworks as sw

from shapeworks_cloud.core import models
from swcc.api import swcc_session
Expand All @@ -18,6 +15,9 @@


def run_prep(params, project, project_file, progress):
import DeepSSMUtils
import shapeworks as sw

# //////////////////////////////////////////////
# /// STEP 1: Create Split
# //////////////////////////////////////////////
Expand All @@ -31,6 +31,7 @@ def run_prep(params, project, project_file, progress):
# /// STEP 2: Groom Training Shapes
# /////////////////////////////////////////////////////////////////
project_params = project.get_parameters('groom')
# alignment should always be set to ICP
project_params.set('alignment_method', 'Iterative Closest Point')
project_params.set('alignment_enabled', 'true')
project.set_parameters('groom', project_params)
Expand All @@ -41,13 +42,6 @@ def run_prep(params, project, project_file, progress):
# /////////////////////////////////////////////////////////////////
# /// STEP 3: Optimize Training Particles
# /////////////////////////////////////////////////////////////////

# set num_particles to 16 and iterations_per_split to 1
project_params = project.get_parameters('optimize')
project_params.set('number_of_particles', '16')
project_params.set('iterations_per_split', '1')
project.set_parameters('optimize', project_params)

DeepSSMUtils.optimize_training_particles(project)
project.save(project_file)
progress.update_percentage(12)
Expand Down Expand Up @@ -96,12 +90,14 @@ def run_prep(params, project, project_file, progress):


def run_augmentation(params, project, download_dir, progress):
import DataAugmentationUtils
import DeepSSMUtils

# /////////////////////////////////////////////////////////////////
# /// STEP 7: Augment Data
# /////////////////////////////////////////////////////////////////
num_samples = int(params['aug_num_samples'])
percent_variability = float(params['percent_variability']) / 100.0
# aug_sampler_type to lowecase
percent_variability = float(params['percent_variability'])
aug_sampler_type = params['aug_sampler_type'].lower()

num_dims = 0 # set to 0 to allow for percent variability to be used
Expand All @@ -127,6 +123,8 @@ def run_augmentation(params, project, download_dir, progress):


def run_training(params, project, download_dir, aug_dims, progress):
import DeepSSMUtils

batch_size = int(params['train_batch_size'])

# /////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -176,6 +174,8 @@ def run_training(params, project, download_dir, aug_dims, progress):


def run_testing(params, project, download_dir, progress):
import DeepSSMUtils

test_indices = DeepSSMUtils.get_split_indices(project, 'test')

# /////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -219,6 +219,8 @@ def run_deepssm_command(
post_command_function,
progress_id,
):
import shapeworks as sw

user = User.objects.get(id=user_id)
progress = models.TaskProgress.objects.get(id=progress_id)
token, _created = Token.objects.get_or_create(user=user)
Expand Down Expand Up @@ -256,6 +258,21 @@ def run_deepssm_command(

sw_project.load(sw_project_file)

groom_params = sw_project.get_parameters('groom')

# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
groom_params.set(key, value)

sw_project.set_parameters('groom', groom_params)

optimize_params = sw_project.get_parameters('optimize')
# for each parameter in the form data, set the parameter in the project
for key, value in form_data.items():
optimize_params.set(key, value)

sw_project.set_parameters('optimize', optimize_params)

os.chdir(sw_project.get_project_path())
run_prep(form_data, sw_project, sw_project_file, progress)

Expand All @@ -265,6 +282,9 @@ def run_deepssm_command(
result_data['augmentation'] = {
'total_data_csv': download_dir + '/deepssm/augmentation/TotalData.csv',
'violin_plot': download_dir + '/deepssm/augmentation/violin.png',
'generated_meshes': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Meshes/'
),
'generated_images': os.listdir(
download_dir + '/deepssm/augmentation/Generated-Images/'
),
Expand Down Expand Up @@ -293,6 +313,8 @@ def run_deepssm_command(

run_testing(form_data, sw_project, download_dir, progress)

subjects = sw_project.get_subjects()

result_data['testing'] = {
'world_predictions': os.listdir(
download_dir + '/deepssm/model/test_predictions/world_predictions/'
Expand All @@ -301,6 +323,7 @@ def run_deepssm_command(
download_dir + '/deepssm/model/test_predictions/local_predictions/'
),
'test_distances': download_dir + '/deepssm/test_distances.csv',
'test_split_subjects': subjects,
}

os.chdir('../../')
Expand Down Expand Up @@ -352,6 +375,15 @@ def post_command_function(project, download_dir, result_data, project_filename):
),
)
aug_pair.mesh.save(
result_data['augmentation']['generated_meshes'][i],
open(
download_dir
+ '/deepssm/augmentation/Generated-Meshes/'
+ result_data['augmentation']['generated_meshes'][i],
'rb',
),
)
aug_pair.image.save(
result_data['augmentation']['generated_images'][i],
open(
download_dir
Expand Down Expand Up @@ -402,10 +434,15 @@ def post_command_function(project, download_dir, result_data, project_filename):
file1 = predictions.pop()
filename = file1.split('.')[0]

# filename here represents the SUBJECT INDEX OF THE TEST SPLIT
subject_name = result_data['testing']['test_split_subjects'][
int(filename)
].get_display_name()

test_pair = models.DeepSSMTestingData.objects.create(
project=project,
image_type='world' if predictions == world_predictions else 'local',
image_id=filename,
image_id=subject_name,
)

predictions_path = (
Expand Down Expand Up @@ -457,9 +494,11 @@ def post_command_function(project, download_dir, result_data, project_filename):
for images in [train_images, val_and_test_images]:
for image in images:
image_type = 'train' if images == train_images else 'val_and_test'

train_image = models.DeepSSMTrainingImage.objects.create(
project=project,
validation=True if image_type == 'val_and_test' else False,
index=image.split('.')[0],
)
train_image.image.save(
image,
Expand Down Expand Up @@ -514,7 +553,7 @@ def post_command_function(project, download_dir, result_data, project_filename):
),
)

training_pair.vtk.save(
training_pair.mesh.save(
vtk_file,
open(
download_dir + '/deepssm/model/examples/' + vtk_file,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.25 on 2024-04-08 18:19
# Generated by Django 3.2.25 on 2024-04-11 16:23

from django.db import migrations, models
import django.db.models.deletion
Expand All @@ -23,7 +23,7 @@ class Migration(migrations.Migration):
),
('particles', s3_file_field.fields.S3FileField()),
('scalar', s3_file_field.fields.S3FileField()),
('vtk', s3_file_field.fields.S3FileField()),
('mesh', s3_file_field.fields.S3FileField()),
('index', models.CharField(max_length=255)),
('example_type', models.CharField(max_length=255)),
('validation', models.BooleanField(default=False)),
Expand All @@ -47,6 +47,7 @@ class Migration(migrations.Migration):
),
),
('image', s3_file_field.fields.S3FileField()),
('index', models.CharField(max_length=255)),
('validation', models.BooleanField(default=False)),
(
'project',
Expand All @@ -68,7 +69,7 @@ class Migration(migrations.Migration):
),
),
('image_type', models.CharField(max_length=255)),
('image_id', models.IntegerField()),
('image_id', models.CharField(max_length=255)),
('mesh', s3_file_field.fields.S3FileField()),
('particles', s3_file_field.fields.S3FileField()),
(
Expand Down Expand Up @@ -116,6 +117,7 @@ class Migration(migrations.Migration):
),
),
('sample_num', models.IntegerField()),
('image', s3_file_field.fields.S3FileField()),
('mesh', s3_file_field.fields.S3FileField()),
('particles', s3_file_field.fields.S3FileField()),
(
Expand Down
8 changes: 5 additions & 3 deletions shapeworks_cloud/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class DeepSSMTestingData(models.Model):
Project, on_delete=models.CASCADE, related_name='deepssm_testing_data'
)
image_type = models.CharField(max_length=255)
image_id = models.IntegerField()
image_id = models.CharField(max_length=255)
mesh = S3FileField()
particles = S3FileField()

Expand All @@ -212,8 +212,8 @@ class DeepSSMTrainingPair(models.Model):
)
particles = S3FileField() # .particles
scalar = S3FileField() # .scalar
vtk = S3FileField() # .vtk
index = models.CharField(max_length=255) # subject
mesh = S3FileField() # .vtk
index = models.CharField(max_length=255) # index
example_type = models.CharField(max_length=255) # best, median, worst
validation = models.BooleanField(default=False)

Expand All @@ -223,12 +223,14 @@ class DeepSSMTrainingImage(models.Model):
Project, on_delete=models.CASCADE, related_name='deepssm_training_images'
)
image = S3FileField()
index = models.CharField(max_length=255) # index
validation = models.BooleanField(default=False)


class DeepSSMAugPair(models.Model):
project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='deepssm_aug_pair')
sample_num = models.IntegerField()
image = S3FileField()
mesh = S3FileField()
particles = S3FileField()

Expand Down
8 changes: 3 additions & 5 deletions shapeworks_cloud/core/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from rest_framework.viewsets import GenericViewSet

from . import filters, models, serializers
from .deepssm_tasks import deepssm_run
from .tasks import analyze, groom, optimize

DB_WRITE_ACCESS_LOG_FILE = Path(gettempdir(), 'logging', 'db_write_access.log')
Expand Down Expand Up @@ -50,8 +51,8 @@ def save_thumbnail_image(target, encoded_thumbnail):


class Pagination(PageNumberPagination):
page_size = 25
max_page_size = 100
page_size = 100
max_page_size = 200
page_size_query_param = 'page_size'


Expand Down Expand Up @@ -517,9 +518,6 @@ def analyze(self, request, **kwargs):
methods=['POST'],
)
def deepssm_run(self, request, **kwargs):
# lazy import; requires conda shapeworks env activation
from .deepssm_tasks import deepssm_run

project = self.get_object()
form_data = request.data
form_data = {k: str(v) for k, v in form_data.items()}
Expand Down
10 changes: 7 additions & 3 deletions shapeworks_cloud/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DeepSSMTrainingPairSerializer(serializers.ModelSerializer):
validation = serializers.BooleanField()
particles = S3FileSerializerField()
scalar = S3FileSerializerField()
vtk = S3FileSerializerField()
mesh = S3FileSerializerField()
index = serializers.CharField(max_length=255)

class Meta:
Expand All @@ -87,6 +87,7 @@ class Meta:
class DeepSSMTrainingImageSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image = S3FileSerializerField()
index = serializers.CharField(max_length=255)
validation = serializers.BooleanField()

class Meta:
Expand All @@ -98,6 +99,7 @@ class DeepSSMAugPairSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
sample_num = serializers.IntegerField()
mesh = S3FileSerializerField()
image = S3FileSerializerField()
particles = S3FileSerializerField()

class Meta:
Expand All @@ -122,7 +124,7 @@ class Meta:
class DeepSSMTestingDataReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image_type = serializers.CharField(max_length=255)
image_id = serializers.IntegerField()
image_id = serializers.CharField(max_length=255)
mesh = S3FileSerializerField()
particles = S3FileSerializerField()

Expand All @@ -137,7 +139,7 @@ class DeepSSMTrainingPairReadSerializer(serializers.ModelSerializer):
validation = serializers.BooleanField()
particles = S3FileSerializerField()
scalar = S3FileSerializerField()
vtk = S3FileSerializerField()
mesh = S3FileSerializerField()
index = serializers.CharField(max_length=255)

class Meta:
Expand All @@ -148,6 +150,7 @@ class Meta:
class DeepSSMTrainingImageReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
image = S3FileSerializerField()
index = serializers.CharField(max_length=255)
validation = serializers.BooleanField()

class Meta:
Expand All @@ -158,6 +161,7 @@ class Meta:
class DeepSSMAugPairReadSerializer(serializers.ModelSerializer):
project = ProjectSerializer()
mesh = S3FileSerializerField()
image = S3FileSerializerField()
particles = S3FileSerializerField()
sample_num = serializers.IntegerField()

Expand Down
Loading
Loading