Parent classes¶
There are several situations in which one might want to fit an IMNN: generating simulations on-the-fly with SimulatorIMNN(), using a fixed set of simulations with analytic gradients with GradientIMNN(), or estimating the derivatives numerically with NumericalGradientIMNN(). When the datasets are very large (either in number of elements - n_s - or in shape - input_shape) the gradients can be aggregated manually using AggregatedSimulatorIMNN(), imnn.AggregatedGradientIMNN() or AggregatedNumericalGradientIMNN(). These are all wrappers around a base class, _IMNN(), and optionally an aggregation class, _AggregatedIMNN(). For completeness these are documented here.
The available modules are:
Base class¶
- 
class imnn.imnn._imnn._IMNN(n_s, n_d, n_params, n_summaries, input_shape, θ_fid, model, optimiser, key_or_state)¶
- Information maximising neural network parent class - This class defines the general fitting framework for information maximising neural networks. It includes the generic calculation of the Fisher information matrix from the outputs of a neural network as well as an XLA compilable fitting routine (with and without a progress bar). This class also provides a plotting routine for the fitting history and a function to calculate the score compression of network outputs to quasi-maximum likelihood estimates of model parameter values. - The outline of the fitting procedure is that a set of \(i\in[1, n_s]\) simulations and \(n_d\) derivatives with respect to physical model parameters are used to calculate network outputs and their derivatives with respect to the physical model parameters, \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\), where \(\alpha\) labels the physical parameter. The exact details of how these are calculated depend on the type of available data (see the list of different IMNNs below). With \({\bf x}^i\) and \(\partial{{\bf x}^i}/\partial\theta_\alpha\) the covariance

\[C_{ab} = \frac{1}{n_s-1}\sum_{i=1}^{n_s}(x^i_a-\mu_a)(x^i_b-\mu_b)\]

and the derivative of the mean of the network outputs with respect to the model parameters

\[\frac{\partial\mu_a}{\partial\theta_\alpha} = \frac{1}{n_d}\sum_{i=1}^{n_d}\frac{\partial{x^i_a}}{\partial\theta_\alpha}\]

can be calculated and used to form the Fisher information matrix

\[F_{\alpha\beta} = \frac{\partial\mu_a}{\partial\theta_\alpha}C^{-1}_{ab}\frac{\partial\mu_b}{\partial\theta_\beta}.\]

The loss function is then defined as

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2)\Lambda_2.\]

Since any linear rescaling of a sufficient statistic is also a sufficient statistic, the negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs. We choose to fix this scale by constraining the covariance of the network outputs via

\[\Lambda_2 = ||{\bf C}-{\bf I}|| + ||{\bf C}^{-1}-{\bf I}||.\]

One benefit of choosing this constraint is that it forces the covariance to be approximately parameter independent, which justifies choosing the covariance-independent Gaussian Fisher information as above. To avoid having a dual optimisation objective, we use a smooth and dynamic regularisation strength which turns off the regularisation, so that the focus moves to maximising the Fisher information once the covariance has set the scale,

\[r(\Lambda_2) = \frac{\lambda\Lambda_2}{\Lambda_2+\exp(-\alpha\Lambda_2)}.\]

Once the loss function is calculated, its automatic gradient is computed and used to update the network parameters via the optimiser function. Note that for large input data sizes, large \(n_s\) or very large networks the gradients may need to be accumulated manually via _AggregatedIMNN(). - _IMNN is designed as the parent class for a range of specific-case IMNNs. There is a helper function (IMNN) which should return the correct case when provided with the correct data. These different subclasses are:
- SimulatorIMNN() – Fit an IMNN using simulations generated on-the-fly from a jax (XLA compilable) simulator
- GradientIMNN() – Fit an IMNN using a precalculated set of fiducial simulations and their derivatives with respect to model parameters
- NumericalGradientIMNN() – Fit an IMNN using a precalculated set of fiducial simulations and simulations generated using parameter values just above and below the fiducial parameter values to make a numerical estimate of the derivatives of the network outputs. Best stability is achieved when the seeds of the simulations are matched between all parameter directions for the numerical derivative
- AggregatedSimulatorIMNN() – SimulatorIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix
- AggregatedGradientIMNN() – GradientIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix
- AggregatedNumericalGradientIMNN() – NumericalGradientIMNN distributed over multiple jax devices with gradients aggregated manually. This might be necessary for very large input sizes as batching cannot be done when calculating the Fisher information matrix
- DatasetGradientIMNN() – AggregatedGradientIMNN with prebuilt TensorFlow datasets
- DatasetNumericalGradientIMNN() – AggregatedNumericalGradientIMNN with prebuilt TensorFlow datasets
- There are currently two other parent classes:
- _AggregatedIMNN() – This is the parent class which provides the fitting routine when the gradients of the network parameters are aggregated manually rather than automatically by jax. This is necessary if the size of an entire batch of simulations (and their derivatives with respect to model parameters) together with the network parameters and their calculated gradients is too large to fit into memory. Note that there is a significant performance loss from using the aggregation, so it should only be used for these large data cases - Parameters
- n_s (int) – Number of simulations used to calculate network output covariance 
- n_d (int) – Number of simulations used to calculate mean of network output derivative with respect to the model parameters 
- n_params (int) – Number of model parameters 
- n_summaries (int) – Number of summaries, i.e. outputs of the network 
- input_shape (tuple) – The shape of a single input to the network 
- θ_fid (float(n_params,)) – The fiducial parameter values used to generate the inputs 
- validate (bool) – Whether a validation set is being used 
- simulate (bool) – Whether input simulations are generated on the fly 
- _run_with_pbar (bool) – Bookkeeping parameter noting that a progress bar is used when fitting (induces a performance hit). If run_with_pbar = True and run_without_pbar = True then a jit compilation error would occur, and so this is prevented
- _run_without_pbar (bool) – Bookkeeping parameter noting that a progress bar is not used when fitting. If run_with_pbar = True and run_without_pbar = True then a jit compilation error would occur, and so this is prevented
- F (float(n_params, n_params)) – Fisher information matrix calculated from the network outputs 
- invF (float(n_params, n_params)) – Inverse Fisher information matrix calculated from the network outputs 
- C (float(n_summaries, n_summaries)) – Covariance of the network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of the network outputs 
- μ (float(n_summaries,)) – Mean of the network outputs 
- dμ_dθ (float(n_summaries, n_params)) – Derivative of the mean of the network outputs with respect to model parameters 
- state (:obj:state) – The optimiser state used for updating the network parameters and optimisation algorithm 
- initial_w (list) – List of the network parameters values at initialisation (to restart) 
- final_w (list) – List of the network parameters values at the end of fitting 
- best_w (list) – List of the network parameters values which provide the maximum value of the determinant of the Fisher matrix 
- w (list) – List of the network parameters values (either final or best depending on setting when calling fit(…)) 
- history (dict) – - A dictionary containing the fitting history. Keys are
- detF – determinant of the Fisher information at the end of each iteration 
- detC – determinant of the covariance of network outputs at the end of each iteration 
- detinvC – determinant of the inverse covariance of network outputs at the end of each iteration 
- Λ2 – value of the covariance regularisation at the end of each iteration 
- r – value of the regularisation coupling at the end of each iteration 
- val_detF – determinant of the Fisher information of the validation data at the end of each iteration 
- val_detC – determinant of the covariance of network outputs given the validation data at the end of each iteration 
- val_detinvC – determinant of the inverse covariance of network outputs given the validation data at the end of each iteration 
- val_Λ2 – value of the covariance regularisation given the validation data at the end of each iteration 
- val_r – value of the regularisation coupling given the validation data at the end of each iteration 
- max_detF – maximum value of the determinant of the Fisher information on the validation data (if available) 
 
 
 
 - 
model:
- Neural network as a function of network parameters and inputs 
 - 
_get_parameters:
- Function which extracts the network parameters from the state 
 - 
_model_initialiser:
- Function to initialise neural network weights from RNG and shape tuple 
 - 
_opt_initialiser:
- Function which generates the optimiser state from network parameters 
 - 
_update:
- Function which updates the state from a gradient 
- Public Methods:
- __init__(n_s, n_d, n_params, n_summaries, …) – Constructor method
- fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
- get_α(λ, ε) – Calculate rate parameter for regularisation from closeness criterion
- set_F_statistics([w, key, validate]) – Set necessary attributes for calculating score compressed summaries
- get_summaries([w, key, validate]) – Gets all network outputs and derivatives wrt model parameters
- get_estimate(d) – Calculate score compressed parameter estimates from network outputs
- plot([ax, expected_detF, colour, figsize, …]) – Plot fitting history
- Private Methods:
- _initialise_parameters(n_s, n_d, n_params, …) – Performs type checking and initialisation of class attributes
- _initialise_model(model, optimiser, key_or_state) – Initialises neural network parameters or loads optimiser state
- _initialise_history() – Initialises history dictionary attribute
- _set_history(results) – Places results from fitting into the history dictionary
- _set_inputs(rng, max_iterations) – Builds list of inputs for the XLA compilable fitting routine
- _get_fitting_keys(rng) – Generates random numbers for simulation generation if needed
- _fit(inputs, λ=None, α=None[, min_iterations]) – Single iteration fitting algorithm
- _fit_cond(inputs, patience, max_iterations) – Stopping condition for the fitting loop
- _update_loop_vars(inputs) – Updates input parameters if max_detF is increased
- _check_loop_vars(inputs, min_iterations) – Updates patience_counter if max_detF not increased
- _update_history(inputs, history, counter, ind) – Puts current fitting statistics into history arrays
- _slogdet(matrix) – Combined summed logarithmic determinant
- _construct_derivatives(derivatives) – Builds derivatives of the network outputs wrt model parameters
- _get_F_statistics([w, key, validate]) – Calculates the Fisher information and returns all statistics used
- _calculate_F_statistics(summaries, derivatives) – Calculates the Fisher information matrix from network outputs
- _get_regularisation_strength(Λ2, λ, α) – Coupling strength of the regularisation (amplified sigmoid)
- _get_regularisation(C, invC) – Difference of the covariance (and its inverse) from identity
- _get_loss(w, λ, α[, key]) – Calculates the loss function and returns auxiliary variables
- _calculate_loss(summaries, derivatives, λ, α) – Calculates the loss function from network summaries and derivatives
- _setup_plot([ax, expected_detF, figsize]) – Builds axes for history plot
 - 
_calculate_F_statistics(summaries, derivatives)¶
- Calculates the Fisher information matrix from network outputs - If the numerical derivative is being calculated then the derivatives are first constructed. If the mean is to be returned (for use in score compression), this is calculated and pushed to the results tuple. Then the covariance of the summaries is taken and inverted and the mean of the derivative of network summaries with respect to the model parameters is found and these are used to calculate the Gaussian form of the Fisher information matrix. - Parameters
- summaries (float(n_s, n_summaries)) – The network outputs 
- derivatives (float(n_d, n_summaries, n_params)) – The derivative of the network outputs wrt the model parameters. Note that when NumericalGradientIMNN is being used the shape is float(n_d, 2, n_params, n_summaries), which is then constructed into the numerical derivative in _construct_derivatives.
 
- Returns
- F (float(n_params, n_params)) – Fisher information matrix 
- C (float(n_summaries, n_summaries)) – Covariance of network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs 
- dμ_dθ (float(n_summaries, n_params)) – The derivative of the mean of network outputs with respect to model parameters 
- μ (float(n_summaries)) – The mean of the network outputs 
 
- Return type
- tuple 
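The following is a minimal sketch (not the class's own implementation) of the calculation described above, assuming summaries and derivatives already have the analytic-derivative shapes given in the Parameters list:

    import jax.numpy as jnp

    def calculate_F_statistics(summaries, derivatives):
        # summaries: (n_s, n_summaries); derivatives: (n_d, n_summaries, n_params)
        n_s = summaries.shape[0]
        μ = jnp.mean(summaries, axis=0)
        Δ = summaries - μ
        C = Δ.T @ Δ / (n_s - 1)                    # covariance of the network outputs
        invC = jnp.linalg.inv(C)
        dμ_dθ = jnp.mean(derivatives, axis=0)      # (n_summaries, n_params)
        # Gaussian Fisher information F_{αβ} = dμ_a/dθ_α C^{-1}_{ab} dμ_b/dθ_β
        F = jnp.einsum("ia,ij,jb->ab", dμ_dθ, invC, dμ_dθ)
        return F, C, invC, dμ_dθ, μ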
 
 - 
_calculate_loss(summaries, derivatives, λ, α)¶
- Calculates the loss function from network summaries and derivatives - Parameters
- summaries (float(n_s, n_summaries)) – The network outputs 
- derivatives (float(n_d, n_summaries, n_params)) – The derivative of the network outputs wrt the model parameters. Note that when NumericalGradientIMNN is being used the shape is float(n_d, 2, n_params, n_summaries), which is then constructed into the numerical derivative in _construct_derivatives.
- λ (float) – Coupling strength of the regularisation 
- α (float) – Rate parameter for the regularisation coupling, calculated from the ϵ closeness criterion (see get_α) 
 
- Returns
- float – Value of the regularised loss function 
- tuple – - Fitting statistics calculated on a single iteration
- F (float(n_params, n_params)) – Fisher information matrix 
- C (float(n_summaries, n_summaries)) – Covariance of network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs 
- Λ2 (float) – Covariance regularisation 
- r (float) – Regularisation coupling strength 
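As a rough sketch, the regularised loss could then be assembled from these Fisher statistics (for instance those returned by the _calculate_F_statistics sketch above) using the regularisation terms described in the class docstring; the choice of the Frobenius norm is an assumption here:

    import jax.numpy as jnp

    def calculate_loss_from_statistics(F, C, invC, λ, α):
        # Distance of the covariance and its inverse from the identity
        I = jnp.eye(C.shape[0])
        Λ2 = jnp.linalg.norm(C - I) + jnp.linalg.norm(invC - I)
        # Smooth, dynamic coupling strength which switches off as Λ2 → 0
        r = λ * Λ2 / (Λ2 + jnp.exp(-α * Λ2))
        # Regularised loss: negative log-determinant of F plus the penalty
        sign, logdetF = jnp.linalg.slogdet(F)
        return -sign * logdetF + r * Λ2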
 
 
 
 
 - 
_check_loop_vars(inputs, min_iterations)¶
- Updates patience_counter if max_detF is not increased - If the determinant of the Fisher information matrix calculated in a given iteration is not larger than the max_detF calculated so far then the patience_counter is increased by one, as long as the number of iterations is greater than the minimum number of iterations that should be run. - Parameters
- inputs (tuple) – - patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix 
- counter (int) – While loop iteration counter 
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix 
- max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far 
- w (list) – Value of the network parameters in the current iteration 
- best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix 
 
- min_iterations (int) – Number of iterations that should be run before considering early stopping using the patience counter 
 
- Returns
- (described in Parameters) 
- Return type
- tuple 
 
 - 
_construct_derivatives(derivatives)¶
- Builds derivatives of the network outputs wrt model parameters - A no-op (identity) in _IMNN, SimulatorIMNN and GradientIMNN, but necessary to construct correctly shaped derivatives when using NumericalGradientIMNN. - Parameters
- derivatives (float(n_d, n_summaries, n_params)) – The derivatives of the network outputs with respect to the model parameters 
- Returns
- The derivatives of the network outputs with respect to the model parameters 
- Return type
- float(n_d, n_summaries, n_params) 
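A hedged sketch of how derivatives of shape float(n_d, 2, n_params, n_summaries) could be combined into the float(n_d, n_summaries, n_params) numerical derivative; the ordering of the lower/upper simulations on the second axis and the meaning of δθ (taken here as the full width between the upper and lower parameter values) are assumptions:

    import jax.numpy as jnp

    def construct_numerical_derivatives(derivatives, δθ):
        # derivatives: (n_d, 2, n_params, n_summaries); δθ: (n_params,)
        below = derivatives[:, 0]                      # summaries below the fiducial
        above = derivatives[:, 1]                      # summaries above the fiducial
        central = (above - below) / δθ[None, :, None]  # finite difference
        # reorder to (n_d, n_summaries, n_params)
        return jnp.transpose(central, (0, 2, 1))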
 
 - 
_fit(inputs, λ=None, α=None, min_iterations=None)¶
- Single iteration fitting algorithm - This function performs the network parameter updates, first getting any necessary random number generators for simulators and then extracting the network parameters from the state. These parameters are used to calculate the gradient of the loss function with respect to the network parameters (see the _IMNN class docstring). Once the loss function is calculated, the gradient is used to update the network parameters via the optimiser function and the current iteration's statistics are saved to the history arrays. If validation is used (recommended for GradientIMNN and NumericalGradientIMNN) then all statistics necessary to calculate the loss function are also calculated for the validation data and pushed to the history arrays. - The patience_counter is increased if the value of the determinant of the Fisher information matrix does not increase over previous iterations; once this happens for patience iterations early stopping occurs, but only if the number of iterations performed so far is greater than a specified min_iterations. - Parameters
- inputs (tuple) – - max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far 
- best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix 
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix 
- detC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the covariance of network outputs 
- detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the inverse covariance of network outputs 
- Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – History of the covariance regularisation 
- r (float(max_iterations, 1) or float(max_iterations, 2)) – History of the regularisation coupling strength 
- counter (int) – While loop iteration counter 
- patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix 
- state (:obj: state) – Optimiser state used for updating the network parameters and optimisation algorithm 
- rng (int(2,)) – Stateless random number generator 
 
- λ (float) – Coupling strength of the regularisation 
- α (float) – Rate parameter for regularisation coupling 
- min_iterations (int) – Number of iterations that should be run before considering early stopping using the patience counter 
 
- Returns
- loop variables (described in Parameters) 
- Return type
- tuple 
 
 - 
_fit_cond(inputs, patience, max_iterations)¶
- Stopping condition for the fitting loop - The loop stops either when the iteration counter reaches max_iterations or when patience_counter reaches patience, i.e. after patience iterations without an increase in the determinant of the Fisher information matrix. - Parameters
- inputs (tuple) – - max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far 
- best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix 
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix 
- detC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the covariance of network outputs 
- detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the inverse covariance of network outputs 
- Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – History of the covariance regularisation 
- r (float(max_iterations, 1) or float(max_iterations, 2)) – History of the regularisation coupling strength 
- counter (int) – While loop iteration counter 
- patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix 
- state (:obj: state) – Optimiser state used for updating the network parameters and optimisation algorithm 
- rng (int(2,)) – Stateless random number generator 
 
- patience (int) – Number of iterations to stop the fitting when there is no increase in the value of the determinant of the Fisher information matrix 
- max_iterations (int) – Maximum number of iterations to run the fitting procedure for 
 
- Returns
- True while the patience_counter has not reached the patience criterion and the counter has not reached max_iterations
- Return type
- bool 
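A minimal sketch of the condition this describes, assuming the relevant counters are unpacked from the inputs tuple listed above:

    import jax.numpy as jnp

    def fit_condition(patience_counter, counter, patience, max_iterations):
        # Continue looping only while neither early stopping nor the
        # iteration limit has been triggered
        return jnp.logical_and(
            jnp.less(patience_counter, patience),
            jnp.less(counter, max_iterations))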
 
 - 
_get_F_statistics(w=None, key=None, validate=False)¶
- Calculates the Fisher information and returns all statistics used - First gets the summaries and derivatives and then uses them to calculate the Fisher information matrix from the outputs, returning all the constituents of the calculation which are needed for the score compression or the regularisation of the loss function. - Parameters
- w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters 
- key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly 
- validate (bool, default=False) – Whether to calculate the Fisher information using the validation set 
 
- Returns
- F (float(n_params, n_params)) – Fisher information matrix 
- C (float(n_summaries, n_summaries)) – Covariance of network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs 
- dμ_dθ (float(n_summaries, n_params)) – The derivative of the mean of network outputs with respect to model parameters 
- μ (float(n_summaries,)) – The mean of the network outputs 
 
- Return type
- tuple 
 
 - 
_get_fitting_keys(rng)¶
- Generates random numbers for simulation generation if needed - Parameters
- rng (int(2,) or None) – A random number generator 
- Returns
- A new random number generator and random number generators for training and validation, or empty values 
- Return type
- int(2,), int(2,), int(2,) or None, None, None 
 
 - 
_get_loss(w, λ, α, key=None)¶
- Calculates the loss function and returns auxiliary variables - First gets the summaries and derivatives and then uses them to calculate the loss function. This function is kept separate so that jax.grad can be used directly, rather than calculating the derivative of the summaries as is done with _AggregatedIMNN. - Parameters
- w (list) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters 
- λ (float) – Coupling strength of the regularisation 
- α (float) – Rate parameter for the regularisation coupling, calculated from the ϵ closeness criterion (see get_α) 
- key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly 
 
- Returns
- float – Value of the regularised loss function 
- tuple – - Fitting statistics calculated on a single iteration
- F (float(n_params, n_params)) – Fisher information matrix 
- C (float(n_summaries, n_summaries)) – Covariance of network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs 
- Λ2 (float) – Covariance regularisation 
- r (float) – Regularisation coupling strength 
 
 
 
 
 - 
_get_regularisation(C, invC)¶
- Difference of the covariance (and its inverse) from identity - The negative logarithm of the determinant of the Fisher information matrix needs to be regularised to fix the scale of the network outputs since any linear rescaling of a sufficient statistic is also a sufficient statistic. We choose to fix this scale by constraining the covariance of network outputs as \[\Lambda_2 = ||\bf{C}-\bf{I}|| + ||\bf{C}^{-1}-\bf{I}||\]- One benefit of choosing this constraint is that it forces the covariance to be approximately parameter independent which justifies choosing the covariance independent Gaussian Fisher information. - Parameters
- C (float(n_summaries, n_summaries)) – Covariance of the network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of the network outputs 
 
- Returns
- Regularisation loss terms for the distance of the covariance and its determinant from the identity matrix 
- Return type
- float 
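A short sketch of this regularisation term; the use of the Frobenius norm is an assumption here:

    import jax.numpy as jnp

    def regularisation(C, invC):
        # Distance of the covariance and its inverse from the identity matrix
        I = jnp.eye(C.shape[0])
        return jnp.linalg.norm(C - I) + jnp.linalg.norm(invC - I)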
 
 - 
_get_regularisation_strength(Λ2, λ, α)¶
- Coupling strength of the regularisation (amplified sigmoid) - To dynamically turn off the regularisation when the scale of the covariance is set to approximately the identity matrix, a smooth sigmoid conditional on the value of the regularisation is used. The rate, α, is calculated from a closeness condition of the covariance (and the inverse covariance) to the identity matrix using - get_α.- Parameters
- Λ2 (float) – Covariance regularisation 
- λ (float) – Coupling strength of the regularisation 
- α (float) – Rate parameter for the regularisation coupling, calculated from the ϵ closeness criterion (see get_α) 
 
- Returns
- Smooth, dynamic regularisation strength 
- Return type
- float 
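A sketch of the coupling strength in the form given in the class docstring, which tends to λ when the covariance is far from the identity and switches off smoothly as Λ2 → 0:

    import jax.numpy as jnp

    def regularisation_strength(Λ2, λ, α):
        return λ * Λ2 / (Λ2 + jnp.exp(-α * Λ2))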
 
 - 
_initialise_history()¶
- Initialises history dictionary attribute - Notes - The contents of the history dictionary are
- detF – determinant of the Fisher information at the end of each iteration 
- detC – determinant of the covariance of network outputs at the end of each iteration 
- detinvC – determinant of the inverse covariance of network outputs at the end of each iteration 
- Λ2 – value of the covariance regularisation at the end of each iteration 
- r – value of the regularisation coupling at the end of each iteration 
- val_detF – determinant of the Fisher information of the validation data at the end of each iteration 
- val_detC – determinant of the covariance of network outputs given the validation data at the end of each iteration 
- val_detinvC – determinant of the inverse covariance of network outputs given the validation data at the end of each iteration 
- val_Λ2 – value of the covariance regularisation given the validation data at the end of each iteration 
- val_r – value of the regularisation coupling given the validation data at the end of each iteration 
- max_detF – maximum value of the determinant of the Fisher information on the validation data (if available) 
 
 
 - 
_initialise_model(model, optimiser, key_or_state)¶
- Initialises neural network parameters or loads optimiser state - Parameters
- model (tuple, len=2) – Tuple containing a function to initialise the neural network fn(rng: int(2), input_shape: tuple) -> tuple, list and the neural network as a function of network parameters and inputs fn(w: list, d: float([None], input_shape)) -> float([None], n_summaries). (Essentially stax-like, see jax.experimental.stax)
- optimiser (tuple or obj, len=3) – Tuple containing functions to generate the optimiser state fn(x0: list) -> :obj:state, to update the state from a list of gradients fn(i: int, g: list, state: :obj:state) -> :obj:state and to extract network parameters from the state fn(state: :obj:state) -> list. (See jax.experimental.optimizers)
- key_or_state (int(2) or :obj:state) – Either a stateless random number generator or the state object of a preinitialised optimiser 
 
 - Notes - The design of the model follows jax’s stax module in that the model is encapsulated by two functions, one to initialise the network and one to call the model, i.e.:

    import jax
    from jax.experimental import stax

    rng = jax.random.PRNGKey(0)
    data_key, model_key = jax.random.split(rng)

    input_shape = (10,)
    inputs = jax.random.normal(data_key, shape=input_shape)

    model = stax.serial(
        stax.Dense(10),
        stax.LeakyRelu,
        stax.Dense(10),
        stax.LeakyRelu,
        stax.Dense(2))

    output_shape, initial_params = model[0](model_key, input_shape)
    outputs = model[1](initial_params, inputs)

Note that the model used in the IMNN is assumed to be totally broadcastable, i.e. any batch shape can be used for inputs. This might require having a layer which reshapes all batch dimensions into a single dimension and then unwraps it at the last layer. A model such as that above is already fully broadcastable. - The optimiser should follow jax’s experimental optimiser module in that the optimiser is encapsulated by three functions, one to initialise the state, one to update the state from a list of gradients and one to extract the network parameters from the state, i.e.:

    from jax.experimental import optimizers
    import jax.numpy as np

    optimiser = optimizers.adam(step_size=1e-3)
    initial_state = optimiser[0](initial_params)
    params = optimiser[2](initial_state)

    def scalar_output(params, inputs):
        return np.sum(model[1](params, inputs))

    counter = 0
    grad = jax.grad(scalar_output, argnums=0)(params, inputs)
    state = optimiser[1](counter, grad, initial_state)

This function either initialises the neural network parameters and optimiser state, if passed a stateless random number generator in key_or_state, or loads a predefined state if the state is passed to key_or_state. The functions get mapped to the class functions

    self.model = model[1]
    self._model_initialiser = model[0]
    self._opt_initialiser = optimiser[0]
    self._update = optimiser[1]
    self._get_parameters = optimiser[2]

The state is made into the state class attribute and the parameters are assigned to the initial_w, final_w, best_w and w class attributes (where w stands for weights). - There is some type checking done, but for freedom of choice of model there will be very few raised warnings. - Raises
- TypeError – If the random number generator is not correct, or if there is no possible way to construct a model or an optimiser from the passed parameters 
- ValueError – If any input is - Noneor if the functions for the model or optimiser do not conform to the necessary specifications
 
 
 - 
_initialise_parameters(n_s, n_d, n_params, n_summaries, input_shape, θ_fid)¶
- Performs type checking and initialisation of class attributes - Parameters
- n_s (int) – Number of simulations used to calculate summary covariance 
- n_d (int) – Number of simulations used to calculate mean of summary derivative 
- n_params (int) – Number of model parameters 
- n_summaries (int) – Number of summaries, i.e. outputs of the network 
- input_shape (tuple) – The shape of a single input to the network 
- θ_fid (float(n_params,)) – The fiducial parameter values used to generate the inputs 
 
- Raises
- TypeError – Any of the parameters are not correct type 
- ValueError – If any of the parameters are None or if θ_fid has the wrong shape
 
 
 - 
_set_history(results)¶
- Places results from fitting into the history dictionary - Parameters
- results (list) – - List of results from fitting procedure. These are:
- detF (float(n_iterations, 2)) – determinant of the Fisher information, - detF[:, 0]for training and- detF[:, 1]for validation
- detC (float(n_iterations, 2)) – determinant of the covariance of network outputs, - detC[:, 0]for training and- detC[:, 1]for validation
- detinvC (float(n_iterations, 2)) – determinant of the inverse covariance of network outputs, - detinvC[:, 0]for training and- detinvC[:, 1]for validation
- Λ2 (float(n_iterations, 2)) – value of the covariance regularisation, - Λ2[:, 0]for training and- Λ2[:, 1]for validation
- r (float(n_iterations, 2)) – value of the regularisation coupling, - r[:, 0]for training and- r[:, 1]for validation
 
 
 
 - 
_set_inputs(rng, max_iterations)¶
- Builds list of inputs for the XLA compilable fitting routine - Parameters
- rng (int(2,) or None) – A stateless random number generator 
- max_iterations (int) – Maximum number of iterations to run the fitting procedure for 
 
 - Notes - The list of inputs to the routine are
- max_detF (float) – The maximum value of the determinant of the Fisher information matrix calculated so far. This is zero if not run before or the value from previous calls to - fit
- best_w (list) – The value of the network parameters which obtained the maximum determinant of the Fisher information matrix. These are the initial network parameter values if not run before 
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the Fisher information matrix during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of - (max_iterations, 1), otherwise validation values are stored in- detF[:, 1].
- detC (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the covariance of network outputs during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of - (max_iterations, 1), otherwise validation values are stored in- detC[:, 1].
- detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the determinant of the inverse covariance of network outputs during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in detinvC[:, 1].
- Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the covariance regularisation during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of - (max_iterations, 1), otherwise validation values are stored in- Λ2[:, 1].
- r (float(max_iterations, 1) or float(max_iterations, 2)) – A container for all possible values of the regularisation coupling strength during each iteration of fitting. If there is no validation (for simulation on-the-fly for example) then this container has a shape of (max_iterations, 1), otherwise validation values are stored in r[:, 1].
- counter (int) – Iteration counter used to note whether the while loop reaches - max_iterations. If not, the history objects (above) get truncated to length- counter. This starts with value zero
- patience_counter (int) – Counts the number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping. This starts with value zero 
- state (:obj:state) – The current optimiser state used for updating the network parameters and optimisation algorithm 
- rng (int(2,)) – A stateless random number generator which gets updated on each iteration 
 
 
 - 
_setup_plot(ax=None, expected_detF=None, figsize=(5, 15))¶
- Builds axes for history plot - Parameters
- ax (mpl.axes or None, default=None) – An axes object of predefined axes to be labelled 
- expected_detF (float or None, default=None) – Value of the expected determinant of the Fisher information to plot a horizontal line at to check fitting progress 
- figsize (tuple, default=(5, 15)) – The size of the figure to be produced 
 
- Returns
- An axes object of labelled axes 
- Return type
- mpl.axes 
 
 - 
_slogdet(matrix)¶
- Combined summed logarithmic determinant - Parameters
- matrix (float(n, n)) – An n x n matrix to calculate the summed logarithmic determinant of 
- Returns
- The summed logarithmic determinant multiplied by its sign 
- Return type
- float 
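A minimal sketch using jax.numpy.linalg.slogdet, which returns the sign and the logarithm of the absolute determinant separately:

    import jax.numpy as jnp

    def slogdet(matrix):
        sign, logdet = jnp.linalg.slogdet(matrix)
        return sign * logdet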
 
 - 
_update_history(inputs, history, counter, ind)¶
- Puts current fitting statistics into history arrays - Parameters
- inputs (tuple) – - Fitting statistics calculated on a single iteration
- F (float(n_params, n_params)) – Fisher information matrix 
- C (float(n_summaries, n_summaries)) – Covariance of network outputs 
- invC (float(n_summaries, n_summaries)) – Inverse covariance of network outputs 
- _Λ2 (float) – Covariance regularisation 
- _r (float) – Regularisation coupling strength 
 
 
- history (tuple) – - History arrays containing fitting statistics for each iteration
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the Fisher information matrix 
- detC (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the covariance of network outputs 
- detinvC (float(max_iterations, 1) or float(max_iterations, 2)) – Determinant of the inverse covariance of network outputs 
- Λ2 (float(max_iterations, 1) or float(max_iterations, 2)) – Covariance regularisation 
- r (float(max_iterations, 1) or float(max_iterations, 2)) – Regularisation coupling strength 
 
 
- counter (int) – Current iteration to insert a single iteration statistics into the history 
- ind (int) – Value of either 0 (fitting) or 1 (validation) to separate the fitting and validation histories 
 
- Returns
- float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the Fisher information matrix 
- float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the covariance of network outputs 
- float(max_iterations, 1) or float(max_iterations, 2) – History of the determinant of the inverse covariance of network outputs 
- float(max_iterations, 1) or float(max_iterations, 2) – History of the covariance regularisation 
- float(max_iterations, 1) or float(max_iterations, 2) – History of the regularisation coupling strength 
 
 
 - 
_update_loop_vars(inputs)¶
- Updates input parameters if max_detF is increased - If the determinant of the Fisher information matrix calculated in a given iteration is larger than the max_detF calculated so far then the patience_counter is reset to zero, max_detF is replaced with the current value of detF, and the network parameters from this iteration replace best_w, the previous parameters which obtained the highest determinant of the Fisher information. - Parameters
- inputs (tuple) – - patience_counter (int) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix 
- counter (int) – While loop iteration counter 
- detF (float(max_iterations, 1) or float(max_iterations, 2)) – History of the determinant of the Fisher information matrix 
- max_detF (float) – Maximum value of the determinant of the Fisher information matrix calculated so far 
- w (list) – Value of the network parameters in the current iteration 
- best_w (list) – Value of the network parameters which obtained the maximum determinant of the Fisher information matrix 
 
- Returns
- (described in Parameters) 
- Return type
- tuple 
 
 - 
fit(λ, ε, rng=None, patience=100, min_iterations=100, max_iterations=100000, print_rate=None, best=True)¶
- Fitting routine for the IMNN - Parameters
- λ (float) – Coupling strength of the regularisation 
- ϵ (float) – Closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs is desired to be 
- rng (int(2,) or None, default=None) – Stateless random number generator 
- patience (int, default=100) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping 
- min_iterations (int, default=100) – Number of iterations that should be run before considering early stopping using the patience counter 
- max_iterations (int, default=int(1e5)) – Maximum number of iterations to run the fitting procedure for 
- print_rate (int or None, default=None) – Number of iterations before updating the progress bar whilst fitting. There is a performance hit from updating the progress bar more often, and there is a large performance hit from using the progress bar at all. (Possible RET_CHECK failure if print_rate is not None when using GPUs.) For this reason it is set to None by default
- best (bool, default=True) – Whether to set the network parameter attribute self.w to the parameter values that obtained the maximum determinant of the Fisher information matrix, or to the parameter values at the final iteration of fitting
 
 - Example - We are going to summarise the mean and variance of some random Gaussian noise with 10 data points per example using a SimulatorIMNN. In this case we are going to generate the simulations on-the-fly with a simulator written in jax (from the examples directory). We will use 1000 simulations to estimate the covariance of the network outputs and the derivative of the mean of the network outputs with respect to the model parameters (Gaussian mean and variance) and generate the simulations at a fiducial μ=0 and Σ=1. The network will be a stax model with hidden layers of [128, 128, 128] activated with leaky relu and outputting 2 summaries. Optimisation will be via Adam with a step size of 1e-3. Rather arbitrarily we’ll set the regularisation strength and covariance identity constraint to λ=10 and ϵ=0.1 (these are relatively unimportant for such an easy model).

    import jax
    import jax.numpy as np
    from jax.experimental import stax, optimizers
    from imnn import SimulatorIMNN

    rng = jax.random.PRNGKey(0)

    n_s = 1000
    n_d = 1000
    n_params = 2
    n_summaries = 2
    input_shape = (10,)
    simulator_args = {"input_shape": input_shape}
    θ_fid = np.array([0., 1.])

    def simulator(rng, θ):
        return θ[0] + jax.random.normal(
            rng, shape=input_shape) * np.sqrt(θ[1])

    model = stax.serial(
        stax.Dense(128),
        stax.LeakyRelu,
        stax.Dense(128),
        stax.LeakyRelu,
        stax.Dense(128),
        stax.LeakyRelu,
        stax.Dense(n_summaries))

    optimiser = optimizers.adam(step_size=1e-3)

    λ = 10.
    ϵ = 0.1

    model_key, fit_key = jax.random.split(rng)

    imnn = SimulatorIMNN(
        n_s=n_s, n_d=n_d, n_params=n_params,
        n_summaries=n_summaries, input_shape=input_shape,
        θ_fid=θ_fid, model=model, optimiser=optimiser,
        key_or_state=model_key, simulator=simulator)

    imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000, patience=250,
             print_rate=None)

 - Notes - A minimum number of iterations should be run before stopping based on a maximum determinant of the Fisher information achieved, since the loss function has dual objectives. Because the determinant of the covariance of the network outputs is forced to 1 quickly, this can be to the detriment of the value of the determinant of the Fisher information matrix early in the fitting procedure. For this reason, starting early stopping only after the covariance has converged is advised. This is not currently implemented but could be considered in the future. - The best fit network parameter values are probably not the most representative set of parameters when simulating on-the-fly since there is a high chance of a statistically overly-informative set of data being generated. Instead, if using fit() consider using best=False, which sets self.w=self.final_w, i.e. the network parameter values obtained in the last iteration. Also consider using a larger patience value if using fit() to overcome the fact that a flukishly high value for the determinant might have been obtained due to the realisation of the dataset. - Due to something unusual that I can’t work out, there is a massive performance hit when calling jax.jit(self._fit) compared with directly decorating _fit with @partial(jax.jit, static_argnums=0). Unfortunately this means having to duplicate _fit to include a version where the loop condition is decorated with a progress bar, because the tqdm module cannot use a jitted tracer. If the progress bar is not used then the fully decorated jitted _fit function is used and it is super quick. Otherwise, just the body of the loop is jitted so that the condition function can be decorated by the progress bar (at the expense of a performance hit). I imagine that something can be improved here. - There is a chance of a RET_CHECK failure when using the progress bar on GPUs (this doesn’t seem to be a problem on CPUs). If this is the case then print_rate=None should be used
 - 
_fit:
- Main fitting function implemented as a - jax.lax.while_loop
 - 
_fit_pbar:
- Main fitting function as a - jax.lax.while_loopwith progress bar
 - Raises
- TypeError – If any input has the wrong type 
- ValueError – If any input (except rng and print_rate) is None 
- ValueError – If rng has the wrong shape 
- ValueError – If rng is None but simulating on-the-fly 
- ValueError – If calling fit with print_rate=None after a previous call with print_rate as an integer value 
- ValueError – If calling fit with print_rate as an integer after a previous call with print_rate=None 
 
 
 - 
get_estimate(d)¶
- Calculate score compressed parameter estimates from network outputs - Using score compression we can get parameter estimates under the transformation

\[\hat{\boldsymbol{\theta}}_\alpha=\theta^{\rm{fid}}_\alpha+ \bf{F}^{-1}_{\alpha\beta}\frac{\partial\mu_i}{\partial\theta_\beta}\bf{C}^{-1}_{ij}(x(\bf{w}, \bf{d})-\mu)_j\]

where \(x_j\) is the \(j\)-th output of the network with network parameters \(\bf{w}\) and input data \(\bf{d}\). - Examples - Assuming that an IMNN has been fit (as in the example in imnn.imnn._imnn.IMNN.fit()) then we can obtain a pseudo-maximum likelihood estimate of some target data (which is generated with parameter values μ=1, Σ=2) using

    rng, target_key = jax.random.split(rng)
    target_data = model_simulator(target_key, np.array([1., 2.]))
    imnn.get_estimate(target_data)
    >>> DeviceArray([0.1108716, 1.7881424], dtype=float32)

The one standard deviation uncertainty on these parameter estimates (assuming the fiducial is at the maximum-likelihood estimate - which we know it isn’t here), estimated by the square root of the diagonal of the inverse Fisher information matrix, is

    np.sqrt(np.diag(imnn.invF))
    >>> DeviceArray([0.31980422, 0.47132865], dtype=float32)

Note that we can compare the values estimated by the IMNN to the mean and the variance of the target data itself, which is what the IMNN should be summarising

    np.mean(target_data)
    >>> DeviceArray(0.10693721, dtype=float32)
    np.var(target_data)
    >>> DeviceArray(1.70872, dtype=float32)

Note that batches of data can be summarised at once using get_estimate. In this example we will draw 10 different values of μ from between \(-10 < \mu < 10\) and 10 different values of Σ from between \(0 < \Sigma < 10\) and generate a batch of 10 different input data which we can summarise using the IMNN.

    rng, mean_keys, var_keys = jax.random.split(rng, num=3)

    mean_vals = jax.random.uniform(
        mean_keys, minval=-10, maxval=10, shape=(10,))
    var_vals = jax.random.uniform(
        var_keys, minval=0, maxval=10, shape=(10,))

    np.stack([mean_vals, var_vals], -1)
    >>> DeviceArray([[ 3.8727236,  1.6727388],
                     [-3.1113386,  8.14554  ],
                     [ 9.87299  ,  1.4134324],
                     [ 4.4837523,  1.5812075],
                     [-9.398947 ,  3.5737753],
                     [-2.0789695,  9.978279 ],
                     [-6.2622285,  6.828809 ],
                     [ 4.6470118,  6.0823894],
                     [ 5.7369494,  8.856505 ],
                     [ 4.248898 ,  5.114669 ]], dtype=float32)

    batch_target_keys = np.array(jax.random.split(rng, num=10))
    batch_target_data = jax.vmap(model_simulator)(
        batch_target_keys, (mean_vals, var_vals))

    imnn.get_estimate(batch_target_data)
    >>> DeviceArray([[ 4.6041985,  8.344688 ],
                     [-3.5172062,  7.7219954],
                     [13.229679 , 23.668312 ],
                     [ 5.745726 , 10.020965 ],
                     [-9.734651 , 21.076218 ],
                     [-1.8083427,  6.1901293],
                     [-8.626409 , 18.894459 ],
                     [ 5.7684307,  9.482665 ],
                     [ 6.7861238, 14.128591 ],
                     [ 4.900367 ,  9.472563 ]], dtype=float32)

- Parameters
- d (float(None, input_shape)) – Input data to be compressed to score compressed parameter estimates 
- Returns
- Score compressed parameter estimates 
- Return type
- float(None, n_params) 
 - 
single_element:
- Returns a single score compressed summary 
 - 
multiple_elements:
- Returns a batch of score compressed summaries 
 - Raises
- ValueError – If the Fisher statistics have not been set by running fit or set_F_statistics.
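A compact sketch of the transformation above, assuming the Fisher statistics (θ_fid, invF, invC, dμ_dθ and μ) have already been set, for example via set_F_statistics, and that d is a single input:

    import jax.numpy as jnp

    def score_estimate(model, w, θ_fid, invF, invC, dμ_dθ, μ, d):
        x = model(w, d)                       # network summary of the data
        # θ_fid + F^{-1} (∂μ/∂θ)^T C^{-1} (x - μ)
        return θ_fid + invF @ (dμ_dθ.T @ (invC @ (x - μ)))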
 
 - 
get_summaries(w=None, key=None, validate=False)¶
- Gets all network outputs and derivatives wrt model parameters - Parameters
- w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters 
- key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly 
- validate (bool, default=False) – Whether to get summaries of the validation set 
 
- Raises
- ValueError – 
 
 - 
get_α(λ, ε)¶
- Calculate rate parameter for regularisation from closeness criterion - Parameters
- λ (float) – coupling strength of the regularisation 
- ϵ (float) – closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs is desired to be 
 
- Returns
- The steepness of the tanh-like function (or rate) which determines how fast the determinant of the covariance of the network outputs should be pushed to 1 
- Return type
- float 
 
 - 
plot(ax=None, expected_detF=None, colour='C0', figsize=(5, 15), label='', filename=None, ncol=1)¶
- Plot fitting history - Plots a three panel vertical plot with the determinant of the Fisher information matrix in the first subplot, the covariance and the inverse covariance in the second, and the regularisation term and the regularisation coupling strength in the final subplot. - A predefined axes object can be passed to be filled, and these axes can be decorated via a call to _setup_plot (for horizontal plots, for example). - Example - Assuming that an IMNN has been fit (as in the example in imnn.imnn._imnn.IMNN.fit()) then we can make a training plot of the history by simply running

    imnn.plot(expected_detF=50, filename="history_plot.png")

Note that we know the analytic value of the determinant of the Fisher information for this problem (\(|\bf{F}|=50\)) so we can add this line to the plot too, and save the output as a png named history_plot. - Parameters
- ax (mpl.axes or None, default=None) – An axes object of predefined axes to be labelled 
- expected_detF (float or None, default=None) – Value of the expected determinant of the Fisher information to plot a horizontal line at to check fitting progress 
- colour (str or rgb/a value or list, default="C0") – Colour to plot the lines 
- figsize (tuple, default=(5, 15)) – The size of the figure to be produced 
- label (str, default="") – Name to add to description in legend 
- filename (str or None, default=None) – Filename to save plot to 
- ncol (int, default=1) – Number of columns to have in the legend 
 
- Returns
- An axes object of the filled plot 
- Return type
- mpl.axes 
 
 - 
set_F_statistics(w=None, key=None, validate=True)¶
- Set necessary attributes for calculating score compressed summaries - Parameters
- w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters 
- key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly 
- validate (bool, default=True) – Whether to calculate Fisher information using the validation set 
 
 
 
Aggregation of gradients¶
- 
class imnn.imnn._aggregated_imnn._AggregatedIMNN(host, devices, n_per_device)¶
- Manual aggregation of gradients for the IMNN parent class - This class defines the overriding fitting functions for _IMNN() which allow gradients to be aggregated manually. This is necessary if networks or input data are extremely large (or if the number of simulations necessary to estimate the covariance of network outputs, n_s, is very large), since all operations may not fit in memory. - The aggregation is done by calculating n_per_device network outputs at once on each available jax.device() and then scanning over all n_s inputs and n_d simulations necessary to calculate the derivative of the mean of the network outputs with respect to the model parameters. This gives a set of summaries and derivatives from which the loss function

\[\Lambda = -\log|{\bf F}| + r(\Lambda_2)\Lambda_2\]

(see IMNN: Information maximising neural networks) can be calculated, along with its gradient with respect to these summaries, \(\frac{\partial\Lambda}{\partial x_i^j}\), and derivatives, \(\frac{\partial\Lambda}{\partial\partial{x_i^j}/\partial\theta_\alpha}\), where \(i\) labels the network output and \(j\) labels the simulation. Note that these are small in comparison to the gradient with respect to the network parameters since their sizes are n_s * n_summaries and n_d * n_summaries * n_params respectively. Once \(\frac{\partial\Lambda}{\partial{x_i^j}}\) and \(\frac{\partial\Lambda}{\partial\partial{x_i^j}/\partial\theta_\alpha}\) are calculated, the gradients of the network outputs with respect to the network parameters, \(\frac{\partial{x_i^j}}{\partial{w_{ab}^l}}\) and \(\frac{\partial\partial{x_i^j}/\partial\theta_\alpha}{\partial{w_{ab}^l}}\), are calculated and the chain rule is used to obtain

\[\frac{\partial\Lambda}{\partial{w_{ab}^l}} = \frac{\partial\Lambda}{\partial{x_i^j}}\frac{\partial{x_i^j}}{\partial{w_{ab}^l}} + \frac{\partial\Lambda}{\partial\partial{x_i^j}/\partial\theta_\alpha}\frac{\partial\partial{x_i^j}/\partial\theta_\alpha}{\partial{w_{ab}^l}}\]

Note that memory use is kept low because only n_per_device simulations are handled at once before being summed into a single gradient list on each device. n_per_device should be as large as possible for the best performance. If everything fits in memory then this class should be avoided. - The _AggregatedIMNN class doesn't directly inherit from _IMNN(), but is meant to be built within a child class of it. For this reason there are attributes which are not explicitly set here but are used within the module. These will be noted in Other Parameters below. - Parameters
- host (jax.device) – The main device where the Fisher information calculation is performed 
- devices (list) – A list of the available jax devices (from - jax.devices())
- n_devices (int) – Number of devices to aggregate the calculation over 
- n_per_device (int) – Number of simulations to handle at once, this should be as large as possible without letting the memory overflow for the best performance 
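As a loose illustration of this aggregation (not the class's implementation; model, data and dΛ_dx are stand-in names with the shapes described above), the first term of the chain rule could be accumulated through per-batch vector-Jacobian products:

    import jax
    import jax.numpy as jnp

    def aggregated_gradient(model, w, data, dΛ_dx, n_per_device=100):
        # Running gradient with the same tree structure as the parameters w
        grad = jax.tree_util.tree_map(jnp.zeros_like, w)
        for start in range(0, data.shape[0], n_per_device):
            d_batch = data[start:start + n_per_device]
            g_batch = dΛ_dx[start:start + n_per_device]
            # Vector-Jacobian product of the network outputs wrt w, this batch only
            _, vjp = jax.vjp(lambda w_: model(w_, d_batch), w)
            (dw,) = vjp(g_batch)
            # Accumulate so only one batch of gradients is held in memory
            grad = jax.tree_util.tree_map(jnp.add, grad, dw)
        return grad

A second pass of the same form over the derivative inputs, contracted with d2Λ_dxdθ, would supply the second term of the chain rule.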
 
 - 
model:
- Neural network as a function of network parameters and inputs 
 - 
_get_parameters:
- Function which extracts the network parameters from the state 
 - 
_model_initialiser:
- Function to initialise neural network weights from RNG and shape tuple 
 - 
_opt_initialiser:
- Function which generates the optimiser state from network parameters 
 - 
_update:
- Function which updates the state from a gradient 
 - 
batch_summaries:
- Jitted function to calculate summaries on each XLA device 
 - 
batch_summaries_with_derivatives:
- Jitted function to calculate summaries from derivative on each device 
 - 
batch_gradients:
- Jitted function to calculate gradient on each XLA device 
 - 
batch_gradients_with_derivatives:
- Jitted function to calculate gradient from derivative on each device 
- Public Methods:
- __init__(host, devices, n_per_device) – Constructor method
- fit(λ, ε[, rng, patience, min_iterations, …]) – Fitting routine for the IMNN
- get_summaries(w[, key, validate]) – Gets all network outputs and derivatives wrt model parameters
- get_gradient(dΛ_dx, w[, key]) – Aggregates gradients together to update the network parameters
- Private Methods:
- _set_devices(devices, n_per_device) – Checks that devices exist and that reshaping onto devices can occur
- _set_batch_functions() – Creates jitted functions placed on desired XLA devices
- _set_shapes() – Calculates the shapes for batching over different devices
- _setup_progress_bar(print_rate, max_iterations) – Construct progress bar
- _update_progress_bar(pbar, counter, …[, close]) – Updates (and closes) progress bar
- _collect_input(key[, validate]) – Returns the dataset to be iterated over
- _get_batch_summaries(inputs, w, θ[, …]) – Vectorised batch calculation of summaries or gradients
- _split_dΛ_dx(dΛ_dx) – Separates dΛ_dx and d2Λ_dxdθ and reshapes them for aggregation
- _construct_gradient(layers[, aux, func]) – Multiuse function to iterate over tuple of network parameters
 - 
_collect_input(key, validate=False)¶
- Returns the dataset to be interated over - Parameters
- key (int(2,)) – A random number generator 
- validate (bool, default=False) – Whether to use the validation set or not 
 
 
 - 
_construct_gradient(layers, aux=None, func='zeros')¶
- Multiuse function to iterate over tuple of network parameters - The options are:
- "zeros"– to create an empty gradient array
- "einsum"– to combine tuple of- dx_dwwith- dΛ_dx
- "derivative_einsum"– to combine tuple of- d2x_dwdθwith- d2Λ_dxdθ
- "sum"– to reduce sum batches of gradients on the first axis
 
 - Parameters
- layers (tuple) – The tuple of tuples of arrays to be iterated over 
- aux (float(various shapes)) – parameter to pass dΛ_dx and d2Λ_dxdθ to einsum 
- func (str) – - Option for the function to apply
- "zeros"– to create an empty gradient array
- "einsum"– to combine tuple of- dx_dwwith- dΛ_dx
- "derivative_einsum"– to combine tuple of- d2x_dwdθwith- d2Λ_dxdθ
- "sum"– to reduce sum batches of gradients on the first axis
 
 
 
- Returns
- Tuple of objects like the gradient of the loss function with respect to the network parameters 
- Return type
- tuple 
- Raises
- ValueError – If applied function is not implemented 
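As a hedged illustration of the func options above (not the library's implementation), the following sketch applies each option leaf-by-leaf over the parameter structure with jax tree utilities; the assumed leaf shape for the "einsum" case is an assumption rather than a documented convention.

import jax
import jax.numpy as np

def construct_gradient(layers, aux=None, func="zeros"):
    # layers is the tuple of tuples of arrays to be iterated over
    if func == "zeros":
        # Empty gradient with the same structure as the network parameters
        return jax.tree_util.tree_map(np.zeros_like, layers)
    if func == "sum":
        # Reduce-sum batches of gradients over the first axis
        return jax.tree_util.tree_map(
            lambda leaf: np.sum(leaf, axis=0), layers)
    if func == "einsum":
        # Combine dx_dw leaves (assumed shape (n_per_device, n_summaries, ...))
        # with aux = dΛ_dx of shape (n_per_device, n_summaries)
        return jax.tree_util.tree_map(
            lambda dx_dw: np.einsum("ij...,ij->...", dx_dw, aux), layers)
    # "derivative_einsum" contracts d2x_dwdθ with aux = d2Λ_dxdθ analogously
    raise ValueError(f"{func} not implemented")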
 
 - 
_get_batch_summaries(inputs, w, θ, gradient=False, derivative=False)¶
- Vectorised batch calculation of summaries or gradients - Parameters
- inputs (tuple) – - dΛ_dx if - gradient(float(n_per_device, n_summaries) or tuple)- dΛ_dx (float(n_per_device, n_summaries)) – The gradient of the loss function with respect to network outputs 
- d2Λ_dxdθ if - derivative(float(n_per_device, n_summaries, n_params)) – The gradient of the loss function with respect to derivative of network outputs with respect to model parameters
 
- keys if SimulatorIMNN() (int(n_per_device, 2)) – The keys for generating simulations on-the-fly
- d if NumericalGradientIMNN() (float(n_per_device, input_shape) or tuple)
  - d (float(n_per_device, input_shape)) – The simulations to be evaluated
  - dd_dθ if derivative (float(n_per_device, input_shape, n_params)) – The derivative of the simulations to be evaluated with respect to model parameters
 
 
- w (list) – Network model parameters 
- θ (float(n_params,)) – The value of the model parameters to generate simulations at/to perform the derivative calculation 
- gradient (bool) – Whether to do the gradient calculation 
- derivative (bool, default=False) – Whether the gradient of the loss function with respect to the derivative of the network outputs with respect to the model parameters is being used
 
- Returns
- x if not gradient (float(n_per_device, n_summaries) or tuple)
  - x (float(n_per_device, n_summaries)) – The network outputs
  - dd_dθ if derivative (float(n_per_device, n_summaries, n_params)) – The derivative of the network outputs with respect to model parameters
 
- if gradient
- (tuple) – The accumulated and aggregated gradient of the loss function with respect to the network parameters
 
 
- Return type
- float(n_devices, n_per_device, n_summaries) or tuple 
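The following is a minimal sketch of the two branches of this method, assuming a toy model callable of the form model(w, data): the summary branch vectorises the network over the n_per_device axis with jax.vmap, and the gradient branch pulls the supplied dΛ_dx back through that vectorised call.

import jax

def batch_summaries(model, w, d):
    # d: (n_per_device,) + input_shape  ->  x: (n_per_device, n_summaries)
    return jax.vmap(lambda data: model(w, data))(d)

def batch_gradients(model, w, d, dΛ_dx):
    # Pull dΛ_dx back through the vectorised forward pass to obtain the
    # contribution of this batch to dΛ/dw (later summed with other batches)
    _, vjp = jax.vjp(lambda w_: batch_summaries(model, w_, d), w)
    (dΛ_dw,) = vjp(dΛ_dx)
    return dΛ_dw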
 
 - 
_set_batch_functions()¶
- Creates jitted functions placed on desired XLA devices - So that each set of summaries is calculated on its designated device, the jitted functions are predefined on each of these devices
 - 
_set_devices(devices, n_per_device)¶
- Checks that devices exist and that reshaping onto devices can occur - Because of the aggregation, the simulations must be split evenly across the different devices, so these splits are checked. - Parameters
- devices (list) – A list of the available jax devices (from - jax.devices())
- n_per_device (int) – Number of simulations to handle at once; for best performance this should be as large as possible without overflowing memory
 
- Raises
- ValueError – If devices or n_per_device are None
- ValueError – If balanced splitting cannot be done 
- TypeError – If devices is not a list or if n_per_device is not an int
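A hedged sketch of the kind of checks described here, assuming n_s and n_d are available to the caller; the exact divisibility condition is an assumption rather than a statement of the library's rule.

def check_devices(devices, n_per_device, n_s, n_d):
    # Both arguments must be provided
    if devices is None or n_per_device is None:
        raise ValueError("devices and n_per_device must not be None")
    if not isinstance(devices, list):
        raise TypeError("devices must be a list of jax devices")
    if not isinstance(n_per_device, int):
        raise TypeError("n_per_device must be an integer")
    n_devices = len(devices)
    # Every pass over the simulations must split evenly into
    # n_per_device-sized batches on each device
    for n in (n_s, n_d):
        if n % (n_devices * n_per_device) != 0:
            raise ValueError(
                f"{n} simulations cannot be balanced across {n_devices} "
                f"devices with {n_per_device} simulations per device")
    return n_devices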
 
 
 - 
_set_shapes()¶
- Calculates the shapes for batching over different devices - Not implemented - Raises
- ValueError – Not implemented in _AggregatedIMNN 
 
 - 
_setup_progress_bar(print_rate, max_iterations)¶
- Construct progress bar - Parameters
- print_rate (int or None) – The rate at which the progress bar is updated (no bar if None) 
- max_iterations (int) – The maximum number of iterations, used to setup bar upper limit 
 
- Returns
- progress bar or None – The TQDM progress bar object 
- int or None – The print rate (after checking for int or None) 
- int or None – The difference between the max_iterations and the print rate 
 
- Raises
- TypeError – If print_rate is not an integer
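A hedged sketch of this setup using the TQDM package; the value returned as the remainder follows one plausible reading of the description above (the iterations left over after the last full print_rate interval), not the library's exact definition.

from tqdm import tqdm

def setup_progress_bar(print_rate, max_iterations):
    # No progress bar requested
    if print_rate is None:
        return None, None, None
    if not isinstance(print_rate, int):
        raise TypeError("print_rate must be an integer or None")
    # Iterations left over after the last full print_rate interval (assumption)
    remainder = max_iterations % print_rate
    pbar = tqdm(total=max_iterations, desc="Fitting")
    return pbar, print_rate, remainder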
 
 - 
_split_dΛ_dx(dΛ_dx)¶
- Separates dΛ_dx and d2Λ_dxdθ and reshapes them for aggregation - Parameters
- dΛ_dx (tuple) –
  - dΛ_dx (float(n_s, n_summaries)) – The derivative of the loss function wrt the network outputs
- d2Λ_dxdθ (float(n_d, n_summaries, n_params)) – The derivative of the loss function wrt the derivative of the network outputs wrt the model parameters 
 
- Raises
- ValueError – function not implemented in parent class 
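A hedged sketch of the kind of reshaping a child class might perform here, assuming the device axis comes first and that n_devices and n_per_device are passed explicitly (the class reads them from attributes).

def split_dΛ_dx(dΛ_dx, d2Λ_dxdθ, n_devices, n_per_device):
    # Reshape the flat (n_s, ...) and (n_d, ...) gradients into
    # (n_devices, n_batches, n_per_device, ...) blocks for aggregation
    def reshape(arr):
        return arr.reshape((n_devices, -1, n_per_device) + arr.shape[1:])
    return reshape(dΛ_dx), reshape(d2Λ_dxdθ)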
 
 - 
_update_progress_bar(pbar, counter, patience_counter, max_detF, detF, detC, detinvC, Λ2, r, print_rate, max_iterations, remainder, close=False)¶
- Updates (and closes) progress bar - Checks whether a pbar is used and, if so, checks whether the iteration coincides with the print rate, or falls within the last print_rate iterations before the final iteration, or whether the final iteration has been reached and the bar should be closed. - Parameters
- pbar (progress bar object) – The TQDM progress bar 
- counter (int) – The value of the current iteration 
- patience_counter (int) – The number of iterations where the maximum of the determinant of the Fisher information matrix has not increased 
- max_detF (float) – Maximum of the determinant of the Fisher information matrix 
- detF (float) – Determinant of the Fisher information matrix
- detC (float) – Determinant of the covariance of the network summaries
- detinvC (float) – Determinant of the inverse covariance of the network summaries
- Λ2 (float) – Value of the regularisation term 
- r (float) – Value of the dynamic regularisation coupling strength 
- print_rate (int or None) – The number of iterations to run before updating the progress bar 
- max_iterations (int) – The maximum number of iterations to run 
- remainder (int or None) – The number of iterations before max_iterations to check progress 
- close (bool, default=False) – Whether to close the progress bar (on final iteration) 
 
 
 - 
fit(λ, ε, rng=None, patience=100, min_iterations=100, max_iterations=100000, print_rate=None, best=True)¶
- Fitting routine for the IMNN - Parameters
- λ (float) – Coupling strength of the regularisation 
- ϵ (float) – Closeness criterion describing how close to 1 the determinant of the covariance (and inverse covariance) of the network outputs should be
- rng (int(2,) or None, default=None) – Stateless random number generator 
- patience (int, default=100) – Number of iterations where there is no increase in the value of the determinant of the Fisher information matrix, used for early stopping
- min_iterations (int, default=100) – Number of iterations that should be run before considering early stopping using the patience counter 
- max_iterations (int, default=int(1e5)) – Maximum number of iterations to run the fitting procedure for 
- print_rate (int or None, default=None) – Number of iterations before updating the progress bar whilst fitting. There is a performance hit from updating the progress bar more often and there is a large performance hit from using the progress bar at all. (Possible RET_CHECK failure if print_rate is not None when using GPUs). For this reason it is set to None as default
- best (bool, default=True) – Whether to set the network parameter attribute self.w to the parameter values that obtained the maximum determinant of the Fisher information matrix or the parameter values at the final iteration of fitting
 
- Example
- We are going to summarise the mean and variance of some random Gaussian noise with 10 data points per example using an AggregatedSimulatorIMNN. In this case we are going to generate the simulations on-the-fly with a simulator written in jax (from the examples directory). These simulations will be generated on-the-fly and passed through the network on each of the GPUs in jax.devices("gpu") and we will make 100 simulations on each device at a time. The main computation will be done on the CPU. We will use 1000 simulations to estimate the covariance of the network outputs and the derivative of the mean of the network outputs with respect to the model parameters (Gaussian mean and variance) and generate the simulations at a fiducial μ=0 and Σ=1. The network will be a stax model with hidden layers of [128, 128, 128] activated with leaky relu and outputting 2 summaries. Optimisation will be via Adam with a step size of 1e-3. Rather arbitrarily we'll set the regularisation strength and covariance identity constraint to λ=10 and ϵ=0.1 (these are relatively unimportant for such an easy model).

import jax
import jax.numpy as np
from jax.experimental import stax, optimizers

from imnn import AggregatedSimulatorIMNN

rng = jax.random.PRNGKey(0)

n_s = 1000
n_d = 1000
n_params = 2
n_summaries = 2
input_shape = (10,)
θ_fid = np.array([0., 1.])

def simulator(rng, θ):
    return θ[0] + jax.random.normal(
        rng, shape=input_shape) * np.sqrt(θ[1])

model = stax.serial(
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(128),
    stax.LeakyRelu,
    stax.Dense(n_summaries))

optimiser = optimizers.adam(step_size=1e-3)

λ = 10.
ϵ = 0.1

model_key, fit_key = jax.random.split(rng)

host = jax.devices("cpu")[0]
devices = jax.devices("gpu")
n_per_device = 100

imnn = AggregatedSimulatorIMNN(
    n_s=n_s, n_d=n_d, n_params=n_params,
    n_summaries=n_summaries, input_shape=input_shape,
    θ_fid=θ_fid, model=model, optimiser=optimiser,
    key_or_state=model_key, simulator=simulator,
    host=host, devices=devices,
    n_per_device=n_per_device)

imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000,
         patience=250, print_rate=None)

- Notes
- A minimum number of iterations should be run before stopping based on the maximum determinant of the Fisher information achieved, since the loss function has dual objectives. Since the determinant of the covariance of the network outputs is forced to 1 quickly, this can be to the detriment of the value of the determinant of the Fisher information matrix early in the fitting procedure. For this reason, starting early stopping after the covariance has converged is advised. This is not currently implemented but could be considered in the future.
- The best fit network parameter values are probably not the most representative set of parameters when simulating on-the-fly since there is a high chance of a statistically overly-informative set of data being generated. Instead, if using fit() consider using best=False which sets self.w = self.final_w, the network parameter values obtained in the last iteration. Also consider using a larger patience value if using fit() to overcome the fact that a flukish high value for the determinant might have been obtained due to the realisation of the dataset.
- Raises
- TypeError – If any input has the wrong type 
- ValueError – If any input (except rng) is None
- ValueError – If rng has the wrong shape
- ValueError – If rng is None but simulating on-the-fly
 
 - 
get_keys_and_params:
- Jitted collection of parameters and random numbers 
 - 
calculate_loss:
- Returns the jitted gradient of the loss function wrt summaries 
 - 
validation_loss:
- Jitted loss and auxiliary statistics from validation set
 
 - 
get_gradient(dΛ_dx, w, key=None)¶
- Aggregates gradients together to update the network parameters - To avoid having to calculate the gradient with respect to all the simulations at once, we aggregate the gradient by addition, looping over the simulations again and combining them with the derivative of the loss function with respect to the network outputs (and their derivatives with respect to the model parameters). Whilst this is expensive, it is necessary since we cannot make an accurate stochastic estimate of the Fisher information and therefore need to use all of the available simulations, which are probably too many to fit in memory. - Parameters
- dΛ_dx (tuple) –
  - dΛ_dx (float(n_s, n_summaries)) – The derivative of the loss function with respect to network summaries
  - d2Λ_dxdθ (float(n_d, n_summaries, n_params)) – The derivative of the loss function with respect to the derivative of network summaries with respect to model parameters
 
- w (list) – Network parameters 
- key (None or int(2,)) – Random number generator used in SimulatorIMNN 
 
- Returns
- The gradient of the loss function with respect to the network parameters calculated by aggregating 
- Return type
- list 
- Raises
- ValueError – function not implemented in parent class 
 
 - 
get_summaries(w, key=None, validate=False)¶
- Gets all network outputs and derivatives wrt model parameters - Parameters
- w (list or None, default=None) – The network parameters if wanting to calculate the Fisher information with a specific set of network parameters 
- key (int(2,) or None, default=None) – A random number generator for generating simulations on-the-fly 
- validate (bool, default=False) – Whether to get summaries of the validation set
 
- Returns
- float(n_s, n_summaries) – The network outputs 
- float(n_d, n_summaries, n_params) – The derivative of the network outputs wrt the model parameters 
 
- Raises
- ValueError – function not implemented in parent class
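As a closing illustration, here is a hedged sketch of how a child class might implement this routine for a fixed array of precomputed simulations: the data is split into n_per_device-sized batches and scanned over with jax.lax.scan so that only one batch is evaluated at a time (the derivative branch would follow the same pattern).

import jax

def get_summaries(model, w, data, n_per_device):
    # data: (n_s,) + input_shape, split into (n_batches, n_per_device, ...)
    batches = data.reshape((-1, n_per_device) + data.shape[1:])

    def body(carry, batch):
        # Evaluate one batch of n_per_device simulations at a time
        x = jax.vmap(lambda d: model(w, d))(batch)
        return carry, x

    _, x = jax.lax.scan(body, None, batches)
    # Flatten back to the (n_s, n_summaries) shape described above
    return x.reshape((-1,) + x.shape[2:])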