From daac33a70bd76c32eeb31f11687116f087b8e983 Mon Sep 17 00:00:00 2001 From: Tsing <2719584131@qq.com> Date: Sat, 1 Apr 2023 12:06:53 +0800 Subject: [PATCH] Update clientapfl.py --- system/flcore/clients/clientapfl.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/system/flcore/clients/clientapfl.py b/system/flcore/clients/clientapfl.py index befb99c3..d665de92 100644 --- a/system/flcore/clients/clientapfl.py +++ b/system/flcore/clients/clientapfl.py @@ -27,6 +27,7 @@ def train(self): # self.model.to(self.device) self.model.train() + self.model_per.train() max_local_steps = self.local_steps if self.train_slow: @@ -110,9 +111,7 @@ def test_metrics(self): def train_metrics(self): trainloader = self.load_train_data() - # self.model = self.load_model('model') - # self.model.to(self.device) - self.model.eval() + self.model_per.train() train_num = 0 losses = 0 @@ -128,7 +127,4 @@ def train_metrics(self): train_num += y.shape[0] losses += loss_per.item() * y.shape[0] - # self.model.cpu() - # self.save_model(self.model, 'model') - return losses, train_num