Loading Trained Parameters Weights
Matt Norman edited this page May 16, 2022
To load Keras weights, which are stored in HDF5 format, use the `ponni::load_h5_weights` routine. For example:

```cpp
// fname: path to the .h5 file; <2> loads a 2-D kernel matrix, <1> a 1-D bias vector
ponni::Matvec matvec_1( ponni::load_h5_weights<2>( fname , "/dense/dense"     , "kernel:0" ) );
ponni::Bias   bias_1  ( ponni::load_h5_weights<1>( fname , "/dense/dense"     , "bias:0"   ) );
ponni::Matvec matvec_2( ponni::load_h5_weights<2>( fname , "/dense_1/dense_1" , "kernel:0" ) );
ponni::Bias   bias_2  ( ponni::load_h5_weights<1>( fname , "/dense_1/dense_1" , "bias:0"   ) );
```

PyTorch typically stores its weights in pickle-based files (such as `.pkl`, `.pt`, or `.pth`), but ponni expects HDF5. To save your PyTorch trainable parameters in HDF5, use the following Python code:
```python
import h5py

# `model` is your trained torch.nn.Module.
# Each state_dict entry becomes its own dataset in the root group "/",
# named after the PyTorch parameter (e.g. "0.0.0.0.1.weight").
with h5py.File('file_name_goes_here.h5', 'w') as f_id:
    for param_name, param in model.state_dict().items():
        print(param_name, "\t", param.size())
        # Move to CPU and convert to NumPy so h5py can write the values
        f_id.create_dataset(param_name,
                            data=param.detach().cpu().numpy(),
                            dtype='f')  # 'f' = 32-bit float
```

Also note that PyTorch stores matrices in (row, column) format (an `nn.Linear` weight has shape `(out_features, in_features)`), whereas ponni prefers weights in (column, row) format. Therefore, set the `transpose` parameter to `true` in each call to `ponni::load_h5_weights`, as in the example below:
```cpp
// Dataset names are the state_dict keys printed by the export script; all live in the root group "/"
bool transpose = true;
ponni::Matvec matvec_1( ponni::load_h5_weights<2>( fname , "/" , "0.0.0.0.1.weight"            , transpose ) );
ponni::Bias   bias_1  ( ponni::load_h5_weights<1>( fname , "/" , "0.0.0.0.1.bias"              , transpose ) );
ponni::Matvec matvec_2( ponni::load_h5_weights<2>( fname , "/" , "0.0.0.2.sequential.0.weight" , transpose ) );
ponni::Bias   bias_2  ( ponni::load_h5_weights<1>( fname , "/" , "0.0.0.2.sequential.0.bias"   , transpose ) );
ponni::Matvec matvec_3( ponni::load_h5_weights<2>( fname , "/" , "0.0.2.sequential.0.weight"   , transpose ) );
ponni::Bias   bias_3  ( ponni::load_h5_weights<1>( fname , "/" , "0.0.2.sequential.0.bias"     , transpose ) );
ponni::Matvec matvec_4( ponni::load_h5_weights<2>( fname , "/" , "0.2.sequential.0.weight"     , transpose ) );
ponni::Bias   bias_4  ( ponni::load_h5_weights<1>( fname , "/" , "0.2.sequential.0.bias"       , transpose ) );
ponni::Matvec matvec_5( ponni::load_h5_weights<2>( fname , "/" , "2.weight"                    , transpose ) );
ponni::Bias   bias_5  ( ponni::load_h5_weights<1>( fname , "/" , "2.bias"                      , transpose ) );
```
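If you are unsure which group and dataset names to pass to `ponni::load_h5_weights`, it can help to list what an HDF5 file actually contains. The sketch below assumes only h5py (the same package used in the export script above); `list_h5_datasets` is a hypothetical helper name, not part of ponni. It splits each dataset path into a group and a dataset name, mirroring the second and third arguments in the examples above.

```python
import h5py


def list_h5_datasets(fname):
    """Return (group, dataset_name, shape) for every dataset in an HDF5 file.

    Hypothetical helper (not part of ponni). The group/dataset split mirrors
    the 2nd and 3rd arguments of ponni::load_h5_weights in the examples above.
    """
    found = []

    def visitor(name, obj):
        # visititems calls this for every group and dataset below the root;
        # only datasets hold weights, so groups are skipped.
        if isinstance(obj, h5py.Dataset):
            # "dense/dense/kernel:0" -> ("dense/dense", "kernel:0")
            # "2.weight"             -> ("", "2.weight")  (root group)
            group, _, dset = name.rpartition("/")
            found.append(("/" + group, dset, obj.shape))
        # Returning None lets visititems continue to the next object.

    with h5py.File(fname, "r") as f:
        f.visititems(visitor)
    return found
```

For example, `for group, dset, shape in list_h5_datasets("file_name_goes_here.h5"): print(group, dset, shape)`. In a PyTorch export, weight shapes should appear as `(out_features, in_features)`, the layout the `transpose` flag accounts for; a Keras `Dense` kernel should instead appear as `(input_dim, units)`.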