Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

python - PyTorch: How to unflatten/get back the network from a flattened network?

I am using the following function to flatten the network:

import numpy as np

#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    """Flatten every parameter of `net` into a single 1-D numpy array.

    Returns (finalNet, shapes): the flat array and the original shape of
    each parameter, both in the order yielded by net.parameters().
    """
    flatNet = []
    shapes = []
    for param in net.parameters():
        # Detach from autograd and move to CPU before converting to numpy.
        arr = param.detach().cpu().numpy()
        shapes.append(arr.shape)
        # ravel() flattens 2-D (linear) and 4-D (conv) weights and 1-D biases
        # alike, and also any other rank, e.g. 3-D Conv1d weights.
        flatNet.append(arr.ravel())
    finalNet = np.concatenate(flatNet)
    return finalNet, shapes

The above function returns all the parameters (weights and biases) as a single 1-D numpy array, finalNet, together with shapes, the list of each parameter's shape. I want to see the effect of weight modifications on the prediction accuracy, so I change the weights. How can I copy this modified weight vector back into the original network? Please help. Thank you.

Question from: https://stackoverflow.com/questions/65941834/pytorch-how-to-unflatten-get-back-the-network-from-flattened-network


1 Answer

There is a difference between the model definition (its layers and its forward function) and its parameter configuration (the model state, which is easily accessible as a dictionary through state_dict()).
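To illustrate the split, here is a minimal sketch with a toy nn.Sequential standing in for your network (make_net and the layer sizes are hypothetical). The same definition is instantiated twice, and the state of one is loaded into the other:

```python
import torch
import torch.nn as nn


def make_net():
    # Model definition: the layers and forward(), no particular weight values.
    # Hypothetical toy architecture standing in for the real `net`.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


net = make_net()
state = net.state_dict()  # model state: an ordered dict of name -> tensor
print(list(state.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

other = make_net()            # same definition, different random weights
other.load_state_dict(state)  # `other` now holds exactly net's weights
```

The ReLU at index 1 has no parameters, so it leaves no trace in the state dict at all; this is one reason a list of weights and shapes alone cannot tell you what the network looked like.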

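You can extract a model's state, as you did with your flattenNetwork implementation. However, reverting this operation when you only have the weights and the layer shapes is not possible for pretty much any model: the shapes do not tell you which layer types they belong to, which parameter-free layers (activations, pooling, etc.) sit between them, or how the layers are connected in forward.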
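As a sketch of where the line falls: while net itself still exists, PyTorch's own torch.nn.utils.parameters_to_vector and vector_to_parameters can stand in for flattenNetwork and its inverse, because the module supplies the order and shape of every parameter. This swaps those library utilities in for the hand-written function; the toy network is hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)
# Hypothetical toy network standing in for `net`.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Same role as flattenNetwork: one 1-D vector in net.parameters() order.
flat = parameters_to_vector(net.parameters()).detach().cpu().numpy()

# Modify the flat vector, e.g. zero out the first five weights of layer 0.
flat[:5] = 0.0

# Inverse step. It needs `net`: the module, not the vector, knows how to
# split and reshape it. torch.tensor() copies, so `net` does not end up
# sharing memory with the numpy array.
vector_to_parameters(torch.tensor(flat), net.parameters())
```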

Now, assuming you still have access to net, my advice is to work with net.state_dict() directly: modify it, then load the dictionary of weights back with net.load_state_dict(). This way, you avoid having to serialize and deserialize the model's parameters yourself.
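A minimal sketch of that workflow, assuming a toy nn.Sequential in place of your net and Gaussian noise as the (hypothetical) weight modification. The state dict is cloned first because the tensors returned by state_dict() share memory with the model's parameters, so editing them in place would change net immediately, before load_state_dict is even called.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model standing in for the real `net`.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Clone so that the edits below do not touch `net` until we load them.
state = {name: t.clone() for name, t in net.state_dict().items()}

# Modify the weights: add small Gaussian noise to the first layer's weight.
state["0.weight"] += 0.01 * torch.randn_like(state["0.weight"])

# Load the modified weights back into the same architecture.
net.load_state_dict(state)

# ...then evaluate `net` on your validation data to measure the accuracy change.
```

Because the keys are parameter names such as "0.weight", you can target exactly the layer you want to perturb instead of computing offsets into a flat vector.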

