I am trying to implement TensorFlow Federated's simple FedAvg with the CIFAR-10 dataset and ResNet18. There is also a PyTorch implementation. Just like the trainable parameters, the non-trainable batch-normalization parameters are aggregated on the server and averaged. I used 5 clients and split the dataset randomly into 5 parts, 50k/5 = 10k training samples per client, so there is no grossly skewed distribution. After training, I test each client on the full test dataset (10k samples), which I also use to test the server. The problem is that after the first training round, although each client has 20-25% accuracy, the server has 10% accuracy and makes nearly the same prediction for every input. This is only the case for the first round; after that, the server almost always has better accuracy than any client had in that round. For example:
Round 0 training loss: 3.0080783367156982
Round 0 client_id: 0 eval_score: 0.2287999987602234
Round 0 client_id: 1 eval_score: 0.2614000141620636
Round 0 client_id: 2 eval_score: 0.22040000557899475
Round 0 client_id: 3 eval_score: 0.24799999594688416
Round 0 client_id: 4 eval_score: 0.2565999925136566
Round 0 validation accuracy: 10.0
Round 1 training loss: 1.920640230178833
Round 1 client_id: 0 eval_score: 0.25220000743865967
Round 1 client_id: 1 eval_score: 0.32199999690055847
Round 1 client_id: 2 eval_score: 0.32580000162124634
Round 1 client_id: 3 eval_score: 0.3513000011444092
Round 1 client_id: 4 eval_score: 0.34689998626708984
Round 1 validation accuracy: 34.470001220703125
Round 2 training loss: 1.65810227394104
Round 2 client_id: 0 eval_score: 0.34369999170303345
Round 2 client_id: 1 eval_score: 0.3138999938964844
Round 2 client_id: 2 eval_score: 0.35580000281333923
Round 2 client_id: 3 eval_score: 0.39649999141693115
Round 2 client_id: 4 eval_score: 0.3917999863624573
Round 2 validation accuracy: 45.0
Round 3 training loss: 1.4956902265548706
Round 3 client_id: 0 eval_score: 0.46380001306533813
Round 3 client_id: 1 eval_score: 0.388700008392334
Round 3 client_id: 2 eval_score: 0.39239999651908875
Round 3 client_id: 3 eval_score: 0.43700000643730164
Round 3 client_id: 4 eval_score: 0.430400013923645
Round 3 validation accuracy: 50.62000274658203
Round 4 training loss: 1.3692104816436768
Round 4 client_id: 0 eval_score: 0.510200023651123
Round 4 client_id: 1 eval_score: 0.42739999294281006
Round 4 client_id: 2 eval_score: 0.4223000109195709
Round 4 client_id: 3 eval_score: 0.45080000162124634
Round 4 client_id: 4 eval_score: 0.45559999346733093
Round 4 validation accuracy: 54.83000183105469
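The aggregation step described at the top (averaging both trainable weights and non-trainable batch-normalization statistics across clients) can be sketched roughly as follows. This is a minimal NumPy illustration, not the actual TFF or PyTorch code: `fedavg`, the toy weight shapes, and the random values are all hypothetical stand-ins.

```python
import numpy as np

def fedavg(client_weights, num_examples):
    """Example-weighted average of per-client weight lists.

    client_weights: one list of np.ndarray per client, same order and shapes.
    num_examples: number of training examples held by each client.
    """
    total = float(sum(num_examples))
    averaged = []
    for tensor_copies in zip(*client_weights):
        # Weighted sum of the same tensor across all clients.
        averaged.append(
            sum(w * (n / total) for w, n in zip(tensor_copies, num_examples))
        )
    return averaged

# Toy stand-ins (assumption): one trainable kernel plus a batch-norm
# moving mean and moving variance per client.
rng = np.random.default_rng(0)
clients = [
    [rng.normal(size=(3, 3)), rng.normal(size=3), rng.uniform(0.5, 1.5, size=3)]
    for _ in range(5)
]
sizes = [10_000] * 5  # 50k / 5 samples per client, as in the setup above

server_weights = fedavg(clients, sizes)
```

With equal client sizes, as here, the weighted average reduces to a plain mean over the five clients, for the batch-norm statistics as well as for the kernel.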
To solve the issue with the first round I tried repeating the dataset, but it didn't help. After that I tried using all the CIFAR-10 training samples for each client: instead of creating 5 different datasets of 10k samples, one per client, I used all 50k samples as each client's dataset.
Round 0 training loss: 1.9335068464279175
Round 0 client_id: 0 eval_score: 0.4571000039577484
Round 0 client_id: 1 eval_score: 0.4514000117778778
Round 0 client_id: 2 eval_score: 0.4738999903202057
Round 0 client_id: 3 eval_score: 0.4560000002384186
Round 0 client_id: 4 eval_score: 0.4697999954223633
Round 0 validation accuracy: 10.0
Round 1 training loss: 1.4404207468032837
Round 1 client_id: 0 eval_score: 0.5945000052452087
Round 1 client_id: 1 eval_score: 0.5909000039100647
Round 1 client_id: 2 eval_score: 0.5864999890327454
Round 1 client_id: 3 eval_score: 0.5871999859809875
Round 1 client_id: 4 eval_score: 0.5684000253677368
Round 1 validation accuracy: 59.57999801635742
Round 2 training loss: 1.0174440145492554
Round 2 client_id: 0 eval_score: 0.7002999782562256
Round 2 client_id: 1 eval_score: 0.6953999996185303
Round 2 client_id: 2 eval_score: 0.6830999851226807
Round 2 client_id: 3 eval_score: 0.6682999730110168
Round 2 client_id: 4 eval_score: 0.6754000186920166
Round 2 validation accuracy: 72.41999816894531
Round 3 training loss: 0.7608759999275208
Round 3 client_id: 0 eval_score: 0.7621999979019165
Round 3 client_id: 1 eval_score: 0.7608000040054321
Round 3 client_id: 2 eval_score: 0.7390000224113464
Round 3 client_id: 3 eval_score: 0.7301999926567078
Round 3 client_id: 4 eval_score: 0.7303000092506409
Round 3 validation accuracy: 78.33000183105469
Round 4 training loss: 0.5893330574035645
Round 4 client_id: 0 eval_score: 0.7814000248908997
Round 4 client_id: 1 eval_score: 0.7861999869346619
Round 4 client_id: 2 eval_score: 0.7804999947547913
Round 4 client_id: 3 eval_score: 0.7694000005722046
Round 4 client_id: 4 eval_score: 0.758400022983551
Round 4 validation accuracy: 81.30000305175781
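The two data setups compared so far (five disjoint random shards of 10k samples versus every client seeing all 50k samples) can be sketched as index partitions. This is only an illustration under assumptions: the seed and the variable names are hypothetical, and the real code may build the client datasets differently.

```python
import numpy as np

NUM_TRAIN = 50_000   # CIFAR-10 training set size
NUM_CLIENTS = 5

rng = np.random.default_rng(seed=0)  # arbitrary seed, for reproducibility only
shuffled = rng.permutation(NUM_TRAIN)

# Setup 1: five disjoint random shards of 10k indices each.
disjoint_shards = np.array_split(shuffled, NUM_CLIENTS)

# Setup 2: every client gets the full set of 50k indices.
full_copies = [np.arange(NUM_TRAIN) for _ in range(NUM_CLIENTS)]
```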
The clients obviously had the same initialization, but I guess due to GPU use there were some minor accuracy differences; still, each had 45+% accuracy. As you can see, even this didn't help with the first round. With a simple CNN, such as the one available in the ".main", and suitable parameters, this problem doesn't exist. Using
learning_rate=0.01 or momentum=0
instead of
learning_rate=0.1 and momentum=0.9
reduces this problem in the first round, but overall performance is worse, and I am trying to reproduce a paper that used the latter parameters.
I have also tried the same with PyTorch and got very similar results. There is a Colab for the PyTorch code, and the results for both are available on GitHub.
I am very confused by this, especially since it also happens when I use the entire training dataset and each client has 45% accuracy. Also, why do I get good results in the following rounds? What changed between the first round and the others? Every time, the clients had the same initialization as each other, the same loss function, and the same optimizer with the same parameters. The only thing that changed is the actual initialization between rounds.
So is there a special initialization that solves this first-round problem, or am I missing something?
Edit:
When the entire CIFAR-10 training set is used for each client and dataset.repeat is used to repeat the data:
Pre-training validation accuracy: 9.029999732971191
Round 0 training loss: 1.6472676992416382
Round 0 client_id: 0 eval_score: 0.5931000113487244
Round 0 client_id: 1 eval_score: 0.5042999982833862
Round 0 client_id: 2 eval_score: 0.5083000063896179
Round 0 client_id: 3 eval_score: 0.5600000023841858
Round 0 client_id: 4 eval_score: 0.6104999780654907
Round 0 validation accuracy: 10.0
What catches my attention is that the client accuracy here is very similar to the second-round (round 1) client accuracy when the dataset wasn't repeated (previous results). So even though the server had 10% accuracy, it didn't affect the results of the next round much.
This is how it works with a simple CNN (defined in the main.py on GitHub).
With the training set divided into 5:
Pre-training validation accuracy: 9.489999771118164
Round 0 training loss: 2.1234841346740723
Round 0 client_id: 0 eval_score: 0.30250000953674316
Round 0 client_id: 1 eval_score: 0.2879999876022339
Round 0 client_id: 2 eval_score: 0.2533999979496002
Round 0 client_id: 3 eval_score: 0.25999999046325684
Round 0 client_id: 4 eval_score: 0.2897999882698059
Round 0 validation accuracy: 31.18000030517578
Entire training set for all the clients:
Pre-training validation accuracy: 9.489999771118164
Round 0 training loss: 1.636365532875061
Round 0 client_id: 0 eval_score: 0.47850000858306885
Round 0 client_id: 1 eval_score: 0.49470001459121704
Round 0 client_id: 2 eval_score: 0.4918000102043152
Round 0 client_id: 3 eval_score: 0.492900013923645
Round 0 client_id: 4 eval_score: 0.4043000042438507
Round 0 validation accuracy: 50.62000274658203
As we can see, when a simple CNN is used, the server accuracy is better than the best client accuracy, and definitely better than the average, from the very first round. I am trying to understand why the ResNet fails to do that and makes the same predictions regardless of input. After the first round the predictions look like:
[[0.02677999 0.02175025 0.10807421 0.25275248 0.08478505 0.20601839
0.16497472 0.09307405 0.01779539 0.02399557]
[0.04087764 0.03603332 0.09987792 0.23636964 0.07425722 0.19982725
0.13649824 0.09779423 0.03454168 0.04392283]
[0.02448712 0.01900426 0.11061406 0.25295085 0.08886322 0.20792796
0.17296027 0.08762561 0.01570844 0.01985822]
[0.01790532 0.01536059 0.11237497 0.2519772 0.09357632 0.20954111
0.18946911 0.08571784 0.01004946 0.01402805]
[0.02116687 0.02263201 0.10294028 0.25523028 0.08544692 0.21299754
0.17604835 0.088608 0.01438032 0.02054946]
[0.01598492 0.01457187 0.10899033 0.25493488 0.09417254 0.20747423
0.19798534 0.08387674 0.0089481 0.01306108]
[0.01432306 0.01214803 0.11237216 0.25138852 0.09796435 0.2036258
0.20656979 0.08344456 0.00726837 0.01089529]
[0.01605278 0.0135905 0.11161591 0.25388476 0.09531546 0.20592561
0.19932476 0.08305667 0.00873495 0.01249863]
[0.02512863 0.0238647 0.10465285 0.24918261 0.08625458 0.21051233
0.16839236 0.09075507 0.01765386 0.02360307]
[0.05418856 0.05830322 0.09909651 0.20211859 0.07324574 0.18549475
0.11666768 0.0990423 0.05081367 0.06102907]]
They all return label 3 (index 3 has the highest probability in every row).
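The predicted class for each row printed above can be checked with an argmax over the class axis. The sketch below copies three of the printed rows (the first, second, and last); the variable names are hypothetical.

```python
import numpy as np

# Three rows copied from the server predictions after the first round.
probs = np.array([
    [0.02677999, 0.02175025, 0.10807421, 0.25275248, 0.08478505,
     0.20601839, 0.16497472, 0.09307405, 0.01779539, 0.02399557],
    [0.04087764, 0.03603332, 0.09987792, 0.23636964, 0.07425722,
     0.19982725, 0.13649824, 0.09779423, 0.03454168, 0.04392283],
    [0.05418856, 0.05830322, 0.09909651, 0.20211859, 0.07324574,
     0.18549475, 0.11666768, 0.0990423, 0.05081367, 0.06102907],
])

# Index of the most probable class per input.
predicted = probs.argmax(axis=1)
print(predicted)  # [3 3 3]
```

Every row picks the same class index, which is consistent with 10% accuracy on a test set with 10 balanced classes.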