Predictors

Predictors in the Mellon framework can be invoked directly via their __call__ method to produce function estimates at new locations. They can also act as Gaussian Processes, offering uncertainty estimation. Predictors also come with serialization capabilities, described in the Serialization section.

Basic Usage

To generate estimates for new, out-of-sample locations, instantiate a predictor and call it like a function:

Example of accessing the mellon.Predictor from the mellon.model.DensityEstimator in Mellon Framework
 import mellon
 # X: training data of shape (n_samples, n_features); Xnew: new locations with the same n_features
 model = mellon.model.DensityEstimator(...)  # Initialize the model with appropriate arguments
 model.fit(X)  # Fit the model to the data
 predictor = model.predict  # Obtain the predictor object
 predicted_values = predictor(Xnew)  # Generate predictions for new locations

Uncertainty

If the predictor was generated with uncertainty estimates (typically by passing predictor_with_uncertainty=True and optimizer="advi" to the model class, e.g., mellon.model.DensityEstimator), it provides methods for computing the variance at new locations and the covariance with any other location.

Sub-Classes

The Predictor module in the Mellon framework features a variety of specialized subclasses of mellon.Predictor. The specific subclass instantiated by the model is contingent upon two key parameters:

  • gp_type: This argument determines the type of Gaussian Process used internally.

  • The nature of the predicted output: This can be real-valued, strictly positive, or time-sensitive.

The gp_type argument mainly affects the internal mathematical operations, whereas the nature of the predicted value dictates the subclass’s functional capabilities:

Vanilla Predictor

Utilized in the following methods:

class mellon.Predictor

Bases: ABC

Abstract base class for predictor models. It provides a common interface for all subclasses, which are expected to implement the _mean method for making predictions.

An instance predictor of a subclass of Predictor can be used to make a prediction by calling it with input data x:

>>> y = predictor(x)

It is the responsibility of subclasses to define the behaviour of _mean.

__call__(x: Union[array-like, pd.DataFrame], normalize: bool = False):

Equivalent to calling the mean method, this uses the trained model to make predictions based on the input array, x.

The prediction corresponds to the mean of the Gaussian Process conditional distribution of predictive functions.

The input array must be 2D with the length of its second dimension matching the number of features used in training the model.

Parameters:
  • x (array-like) – The input data to the predictor, having shape (n_samples, n_input_features).

  • normalize (bool) – Optional normalization by subtracting log(self.n_obs) (number of cells trained on), applicable only for cell-state density predictions. Default is False.

Returns:

The predicted output generated by the model.

Return type:

array

Raises:

ValueError – If the number of features in ‘x’ does not align with the number of features the predictor was trained on.
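The contract described above (subclasses implement _mean; calling the predictor is equivalent to mean; inputs are validated against n_input_features; normalize subtracts log(n_obs)) can be sketched as a small abstract class. This is a minimal illustration under those documented rules, not Mellon's implementation: the class names SketchPredictor and ConstantLogDensity and the exact validation logic are hypothetical.

```python
from abc import ABC, abstractmethod

import numpy as np


class SketchPredictor(ABC):
    """Hypothetical stand-in mirroring the documented Predictor interface."""

    def __init__(self, n_obs, n_input_features):
        self.n_obs = n_obs                        # number of cells trained on
        self.n_input_features = n_input_features  # expected second dimension of x

    @abstractmethod
    def _mean(self, x):
        """Subclasses define the actual prediction."""

    def mean(self, x, normalize=False):
        x = np.asarray(x, dtype=float)
        # Validate the input: 2D with the trained number of features.
        if x.ndim != 2 or x.shape[1] != self.n_input_features:
            raise ValueError(
                f"Expected shape (n_samples, {self.n_input_features}), got {x.shape}."
            )
        y = self._mean(x)
        if normalize:
            # Normalization as documented: subtract log(n_obs).
            y = y - np.log(self.n_obs)
        return y

    def __call__(self, x, normalize=False):
        # Calling the predictor is equivalent to calling mean().
        return self.mean(x, normalize=normalize)


class ConstantLogDensity(SketchPredictor):
    """Hypothetical toy subclass: predicts the same log-density everywhere."""

    def _mean(self, x):
        return np.full(x.shape[0], 2.0)


predictor = ConstantLogDensity(n_obs=100, n_input_features=3)
y = predictor(np.zeros((4, 3)))                       # four values of 2.0
y_norm = predictor(np.zeros((4, 3)), normalize=True)  # 2.0 - log(100) each
```

Passing an array with the wrong number of columns, such as np.zeros((4, 2)), raises ValueError, matching the documented behaviour.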

n_obs

The number of samples or cells that the model was trained on. This attribute is critical for normalization purposes, particularly when the normalize parameter in the __call__ method is set to True.

Type:

int

n_input_features

The number of features/dimensions of the cell-state representation the predictor was trained on. This is used for validation of input data.

Type:

int

covariance(x, diag=True)

Computes the covariance of the Gaussian Process distribution of functions over new data points or cell states.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the covariance.

  • diag (boolean, optional (default=True)) – Whether to return the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.
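To make the diag switch concrete, here is a small NumPy sketch of the posterior covariance in exact GP regression with a squared-exponential kernel. The kernel, length scale, and noise level are assumptions chosen for illustration; Mellon's kernels and approximations (selected via gp_type) differ. It shows that the diag=True result is the diagonal of the diag=False matrix.

```python
import numpy as np


def rbf(a, b, ls=1.0):
    # Squared-exponential kernel between the rows of a and the rows of b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / ls**2)


def posterior_covariance(x_train, x_new, diag=True, noise=1e-2):
    """Exact GP posterior covariance at x_new (illustrative only)."""
    k_tt = rbf(x_train, x_train) + noise * np.eye(len(x_train))
    k_nt = rbf(x_new, x_train)
    k_nn = rbf(x_new, x_new)
    cov = k_nn - k_nt @ np.linalg.solve(k_tt, k_nt.T)
    return np.diag(cov) if diag else cov


rng = np.random.default_rng(0)
x_train = rng.normal(size=(20, 2))
x_new = rng.normal(size=(5, 2))
var = posterior_covariance(x_train, x_new, diag=True)   # shape (5,)
cov = posterior_covariance(x_train, x_new, diag=False)  # shape (5, 5)
```

Requesting only the diagonal is usually what you want for many points, since the full matrix grows quadratically with n_samples.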

classmethod from_dict(data_dict)

Deserialize the predictor from a Python dictionary.

This method deserializes the predictor from a Python dictionary.

Parameters:

data_dict (dict) – The dictionary from which to deserialize the predictor.

Returns:

An instance of the predictor.

Return type:

Predictor subclass instance

classmethod from_json(filepath, compress=None)

Deserialize the predictor from a JSON file.

This method deserializes the predictor from a JSON file. It automatically detects the compression method based on the file extension or uses the compress keyword to determine the compression method.

Parameters:
  • filepath (str, pathlib.Path, or os.path) – The path of the JSON file from which to deserialize the predictor.

  • compress (str, optional) – The compression method to use (‘gzip’ or ‘bz2’). If None, the compression method is inferred from the file extension, and no decompression is applied if the extension does not indicate one.

Returns:

An instance of the predictor.

Return type:

Predictor subclass instance

classmethod from_json_str(json_str)

Deserialize the predictor from a JSON string.

This method deserializes the predictor from the content of a JSON file.

Parameters:

json_str (str) – The JSON string from which to deserialize the predictor.

Returns:

An instance of the predictor.

Return type:

Predictor subclass instance

gradient(x, jit=True)

Computes the gradient of the prediction function for each row of x.

Parameters:
  • x (array-like) – Data points.

  • jit (bool) – Use JAX just-in-time (JIT) compilation. Defaults to True.

Returns:

gradients – The gradient of the function at each point in x. gradients.shape == x.shape

Return type:

array-like
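Because gradient returns one vector per row of x (gradients.shape == x.shape), a quick way to see what it computes, or to sanity-check it, is central finite differences. The sketch below uses a hypothetical row-wise scalar function f in place of a predictor; Mellon itself differentiates with JAX rather than with finite differences.

```python
import numpy as np


def f(x):
    # Hypothetical row-wise scalar function: sum of squares of each row.
    return (x**2).sum(axis=1)


def finite_difference_gradient(func, x, eps=1e-5):
    """Central-difference gradient of a row-wise scalar function."""
    x = np.asarray(x, dtype=float)
    grads = np.empty_like(x)
    for j in range(x.shape[1]):
        step = np.zeros(x.shape[1])
        step[j] = eps
        grads[:, j] = (func(x + step) - func(x - step)) / (2 * eps)
    return grads


x = np.array([[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]])
grads = finite_difference_gradient(f, x)  # analytic answer: 2 * x
```

With a fitted predictor, the analogous check would compare predictor.gradient(x) against finite_difference_gradient(predictor, x), since calling the predictor returns one scalar per row.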

hessian(x, jit=True)

Computes the Hessian of the prediction function for each row of x.

Parameters:
  • x (array-like) – Data points.

  • jit (bool) – Use JAX just-in-time (JIT) compilation. Defaults to True.

Returns:

hessians – The Hessian matrix of the function at each point in x. hessians.shape == x.shape + x.shape[1:]

Return type:

array-like
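The shape rule hessians.shape == x.shape + x.shape[1:] means one d-by-d matrix per row of x. A finite-difference sketch with a hypothetical quadratic stand-in function makes that layout explicit; as above, Mellon computes Hessians with JAX, not with finite differences.

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])  # hypothetical symmetric matrix


def f(x):
    # Row-wise quadratic form x^T A x; its Hessian is 2A everywhere.
    return np.einsum("ni,ij,nj->n", x, A, x)


def finite_difference_hessian(func, x, eps=1e-4):
    """Central-difference Hessian of a row-wise scalar function."""
    x = np.asarray(x, dtype=float)
    n, d = x.shape
    hess = np.empty((n, d, d))
    steps = np.eye(d) * eps
    for i in range(d):
        for j in range(d):
            hess[:, i, j] = (
                func(x + steps[i] + steps[j])
                - func(x + steps[i] - steps[j])
                - func(x - steps[i] + steps[j])
                + func(x - steps[i] - steps[j])
            ) / (4 * eps**2)
    return hess


x = np.array([[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]])
hessians = finite_difference_hessian(f, x)  # shape (3, 2, 2)
```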

hessian_log_determinant(x, jit=True)

Computes the logarithm of the determinant of the Hessian of the prediction function for each row of x.

Parameters:
  • x (array-like) – Data points.

  • jit (bool) – Use JAX just-in-time (JIT) compilation. Defaults to True.

Returns:

signs, log_determinants – The sign of the determinant at each point in x and the logarithm of its absolute value. signs.shape == log_determinants.shape == (x.shape[0],)

Return type:

array-like, array-like
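Returning (signs, log_determinants) rather than the raw determinant avoids overflow and underflow. For a stack of per-row Hessians, numpy.linalg.slogdet produces exactly this pair, one entry per matrix. The Hessians below are hypothetical example values, and whether Mellon calls slogdet internally is not stated here.

```python
import numpy as np

# Hypothetical stack of Hessians, one 2x2 matrix per data point.
hessians = np.array([
    [[2.0, 0.0], [0.0, 3.0]],        # determinant 6
    [[-1.0, 0.0], [0.0, 4.0]],       # determinant -4
    [[1e-200, 0.0], [0.0, 1e-200]],  # determinant 1e-400 underflows to 0.0 in float64
])

signs, log_determinants = np.linalg.slogdet(hessians)
```

The third matrix shows why the log form matters: np.linalg.det returns 0.0 for it, while slogdet still reports a finite log-determinant of about -921.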

mean(x, normalize=False)

Use the trained model to make a prediction based on the input array, x.

The prediction represents the mean of the Gaussian Process conditional distribution of predictive functions.

The input array should be 2D with its second dimension’s length equal to the number of features used in training the model.

Parameters:
  • x (array-like) – The input data to the predictor. The array should have shape (n_samples, n_input_features).

  • normalize (bool) – Whether to normalize the value by subtracting log(self.n_obs) (number of cells trained on). Applicable only for cell-state density predictions. Default is False.

Returns:

The predicted output generated by the model.

Return type:

array

Raises:

ValueError – If the number of features in ‘x’ does not match the number of features the predictor was trained on.

mean_covariance(x, diag=True)

Computes the uncertainty of the mean of the Gaussian process induced by the uncertainty of the latent representation of the mean function.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • diag (boolean, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.
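One way to read mean_covariance: if the mean function is a linear map of an uncertain latent vector, that uncertainty propagates through the map. The sketch assumes a hypothetical linear representation mean(x) = W(x) @ z with z ~ N(mu, Sigma), where W and Sigma are made-up values; Mellon's actual latent representation (for example from ADVI) is not specified on this page.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_latent = 4, 6

# Hypothetical feature matrix mapping latent coefficients to mean values at 4 points.
W = rng.normal(size=(n_samples, n_latent))
# Hypothetical posterior covariance of the latent vector z (symmetric positive definite).
L = rng.normal(size=(n_latent, n_latent))
Sigma = L @ L.T + 1e-3 * np.eye(n_latent)

# Covariance of W @ z is W Sigma W^T (the diag=False analogue).
mean_cov = W @ Sigma @ W.T
# Per-point variances (the diag=True analogue), computed without the full matrix.
mean_var = np.einsum("ij,jk,ik->i", W, Sigma, W)
```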

to_dict()

Serialize the predictor to a Python dictionary.

Returns:

A python dictionary with the predictor data.

Return type:

dict

to_json(filename=None, compress=None)

Serialize the predictor to a JSON file.

This method serializes the predictor to a JSON file. It can optionally compress the JSON file using gzip or bz2 compression. It automatically detects the compression method based on the file extension or uses the compress keyword to determine the compression method. It also makes sure the file is saved with the appropriate file extension.

Parameters:
  • filename (str or None) – The name of the JSON file to which to serialize the predictor. If filename is None, the JSON string is returned instead.

  • compress (str, optional) – The compression method to use (‘gzip’ or ‘bz2’). If None, the compression method is inferred from the file extension, and no compression is used if the extension does not indicate one.
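The extension-based compression handling described for to_json and from_json can be sketched with the standard library alone. The helper names save_json and load_json, the ".gz"/".bz2" suffix mapping, and the plain-dict payload are assumptions for illustration; they are not Mellon functions and do not reproduce its file format.

```python
import bz2
import gzip
import json
import tempfile
from pathlib import Path

# Hypothetical mapping between compression names and file extensions.
_EXTENSIONS = {"gzip": ".gz", "bz2": ".bz2"}
_OPENERS = {"gzip": gzip.open, "bz2": bz2.open, None: open}


def _detect(path, compress):
    """An explicit compress argument wins; otherwise infer it from the suffix."""
    if compress is not None:
        return compress
    for name, ext in _EXTENSIONS.items():
        if path.suffix == ext:
            return name
    return None


def save_json(data, filename=None, compress=None):
    """Hypothetical helper: return a JSON string, or write it to a file."""
    text = json.dumps(data)
    if filename is None:
        return text
    path = Path(filename)
    compress = _detect(path, compress)
    # Make sure the file name ends with the extension matching the compression.
    ext = _EXTENSIONS.get(compress)
    if ext and path.suffix != ext:
        path = path.with_name(path.name + ext)
    with _OPENERS[compress](str(path), "wt") as fh:
        fh.write(text)
    return path


def load_json(filepath, compress=None):
    """Hypothetical helper: read JSON, decompressing according to the suffix."""
    path = Path(filepath)
    with _OPENERS[_detect(path, compress)](str(path), "rt") as fh:
        return json.load(fh)


payload = {"n_obs": 100, "n_input_features": 3}
with tempfile.TemporaryDirectory() as tmp:
    written = save_json(payload, Path(tmp) / "predictor.json.gz")  # gzip inferred from ".gz"
    restored = load_json(written)
```

Passing compress="bz2" with a name such as "model.json" writes "model.json.bz2", mirroring the documented guarantee that the file ends with the appropriate extension.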

uncertainty(x, diag=True)

Computes the total uncertainty of the predicted values quantified by their variance or covariance.

The total uncertainty is defined by .covariance + .mean_covariance.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • diag (bool, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,) if diag=True) – The variances for each sample in the new data points.

  • cov (array-like, shape (n_samples, n_samples) if diag=False) – The full covariance matrix between the samples in the new data points.
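Because the total uncertainty is defined as the sum of the two documented terms, combining them is plain addition in either mode. The arrays below are hypothetical stand-ins for what covariance and mean_covariance might return for three points with diag=False.

```python
import numpy as np

# Hypothetical diag=False outputs for three points.
gp_cov = np.array([[0.5, 0.1, 0.0],
                   [0.1, 0.4, 0.05],
                   [0.0, 0.05, 0.3]])
mean_cov = np.array([[0.2, 0.0, 0.0],
                     [0.0, 0.1, 0.02],
                     [0.0, 0.02, 0.1]])

total_cov = gp_cov + mean_cov   # mirrors .covariance + .mean_covariance with diag=False
total_var = np.diag(total_cov)  # the diag=True counterpart
std = np.sqrt(total_var)        # e.g. for error bars
```

Note that the variances add, not the standard deviations.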

Exponential Predictor

class mellon.base_predictor.ExpPredictor

Bases: Predictor

Abstract base class for predictor models that return the exponential of their _mean method's output when called.

An instance predictor of a subclass of ExpPredictor can be used to make a prediction by calling it with input data x:

>>> y = predictor(x)

It is the responsibility of subclasses to define the behaviour of _mean.
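The relationship between the two scales of an ExpPredictor can be sketched with a hypothetical subclass whose _mean returns log-scale values. Only the exponentiation on call and the logscale switch of mean come from this page; the class names SketchExpPredictor and LinearLog are illustrative.

```python
from abc import ABC, abstractmethod

import numpy as np


class SketchExpPredictor(ABC):
    """Hypothetical stand-in for the ExpPredictor behaviour described above."""

    @abstractmethod
    def _mean(self, x):
        """Subclasses return predictions on the log scale."""

    def mean(self, x, logscale=False):
        log_values = self._mean(np.asarray(x, dtype=float))
        # logscale=True returns _mean directly; otherwise exponentiate.
        return log_values if logscale else np.exp(log_values)

    def __call__(self, x):
        # Calling the predictor returns the exponential of _mean.
        return self.mean(x)


class LinearLog(SketchExpPredictor):
    def _mean(self, x):
        return x.sum(axis=1)  # hypothetical log-scale prediction


predictor = LinearLog()
x = np.array([[0.0, 0.0], [1.0, 0.5]])
values = predictor(x)                          # [1.0, e**1.5]
log_values = predictor.mean(x, logscale=True)  # [0.0, 1.5]
```

Exponentiating keeps predictions strictly positive, which is why this subclass suits strictly positive outputs.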

covariance(x, diag=True)

Computes the covariance of the Gaussian Process distribution of functions over new data points or cell states.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the covariance.

  • diag (boolean, optional (default=True)) – Whether to return the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.

mean(x, logscale=False)

Use the trained model to make a prediction based on the input array, x.

The input array should be 2D with its second dimension’s length equal to the number of features used in training the model.

Parameters:
  • x (array-like) – The input data to the predictor. The array should have shape (n_samples, n_input_features).

  • logscale (bool) – Whether the predicted value should be returned in log scale. Default is False.

Returns:

The predicted output generated by the model.

Return type:

array

Raises:

ValueError – If the number of features in ‘x’ does not match the number of features the predictor was trained on.

mean_covariance(x, diag=True)

Computes the uncertainty of the mean of the Gaussian process induced by the uncertainty of the latent representation of the mean function.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • diag (boolean, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.

uncertainty(x, diag=True)

Computes the total uncertainty of the predicted values quantified by their variance or covariance.

The total uncertainty is defined by .covariance + .mean_covariance.

Parameters:
  • x (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • diag (bool, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

Returns:

  • var (array-like, shape (n_samples,) if diag=True) – The variances for each sample in the new data points.

  • cov (array-like, shape (n_samples, n_samples) if diag=False) – The full covariance matrix between the samples in the new data points.

Time-sensitive Predictor

class mellon.base_predictor.PredictorTime

Bases: Predictor

Abstract base class for predictor models with a time covariate.

An instance predictor of a subclass of PredictorTime can be used to make a prediction by calling it with input data x and time:

>>> y = predictor(x, time)

It is the responsibility of subclasses to define the behaviour of _mean.

__call__(Xnew: Union[array-like, pd.DataFrame], time=None, normalize: bool = False):

Equivalent to calling the mean method, this uses the trained model to make predictions based on the input array ‘Xnew’, considering the specified ‘time’ or ‘multi_time’.

The predictions represent the mean of the Gaussian Process conditional distribution of predictive functions.

If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

Parameters:
  • Xnew (array-like) – The new data points for prediction.

  • time (scalar or array-like, optional) – The time points associated with each row in ‘Xnew’. If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

  • normalize (bool) – Optional normalization by subtracting log(self.n_obs) (number of cells trained on), applicable only for cell-state density predictions. Default is False.

Returns:

The predicted output generated by the model.

Return type:

array

Raises:

ValueError – If the number of features in ‘Xnew’ does not align with the number of features the predictor was trained on.
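The scalar-time rule can be sketched as a small helper. The function name broadcast_time is hypothetical; it mirrors the documented behaviour by reading "the same size as Xnew" as one time value per row of Xnew, and raises ValueError when an array of times has the wrong length.

```python
import numpy as np


def broadcast_time(Xnew, time):
    """Hypothetical helper: return one time value per row of Xnew."""
    n_rows = np.asarray(Xnew).shape[0]
    time = np.asarray(time, dtype=float)
    if time.ndim == 0:
        # A scalar applies to every row.
        return np.full(n_rows, float(time))
    time = time.reshape(-1)
    if time.size != n_rows:
        raise ValueError(f"time has {time.size} entries but Xnew has {n_rows} rows.")
    return time


Xnew = np.zeros((4, 2))
t_scalar = broadcast_time(Xnew, 3.0)                  # [3.0, 3.0, 3.0, 3.0]
t_array = broadcast_time(Xnew, [0.0, 1.0, 2.0, 3.0])  # passed through unchanged
```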

n_obs

The average number of samples or cells per time point that the model was trained on. This attribute is critical for normalization purposes, particularly when the normalize parameter in the __call__ method is set to True.

Type:

int

n_input_features

The number of features/dimensions of the cell-state representation the predictor was trained on. This is used for validation of input data.

Type:

int

covariance(Xnew, time=None, diag=True, multi_time=None)

Computes the covariance of the Gaussian Process distribution of functions over new data points or cell states.

Parameters:
  • Xnew (array-like, shape (n_samples, n_features)) – The new data points for which to compute the covariance.

  • time (scalar or array-like, optional) – The time points associated with each cell/row in ‘Xnew’. If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

  • diag (boolean, optional (default=True)) – Whether to return the variance (True) or the full covariance matrix (False).

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a covariance for all states in Xnew will be computed separately for each time value in multi_time.

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.

gradient(x, time, jit=True, multi_time=None)

Computes the gradient of the prediction function for each point in x at a given time.

Parameters:
  • x (array-like) – Data points at which the gradient is to be computed.

  • time (array-like or float) – Time point or points at which to evaluate the gradient. If time is a float, the gradient will be computed at this specific time point for all data points in x. If time is an array, it should be 1-D and the gradient will be computed for each data point at the corresponding time in the array.

  • jit (bool, optional) – If True, use JAX’s just-in-time (JIT) compilation to speed up the computation. Defaults to True.

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a gradient for all states in x will be computed separately for each time value in multi_time.

Returns:

The gradient of the prediction function at each point in x. The shape of the output array is the same as x.

Return type:

array-like

hessian(x, time, jit=True, multi_time=None)

Computes the Hessian matrix of the prediction function for each point in x at a given time.

Parameters:
  • x (array-like) – Data points at which the Hessian matrix is to be computed.

  • time (array-like or float) – Time point or points at which to evaluate the Hessian. If time is a float, the Hessian will be computed at this specific time point for all data points in x. If time is an array, it should be 1-D and the Hessian will be computed for each data point at the corresponding time in the array.

  • jit (bool, optional) – If True, use JAX’s just-in-time (JIT) compilation to speed up the computation. Defaults to True.

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a Hessian for all states in x will be computed separately for each time value in multi_time.

Returns:

The Hessian matrix of the prediction function at each point in x. The shape of the output array is x.shape + x.shape[1:].

Return type:

array-like

hessian_log_determinant(x, time, jit=True, multi_time=None)

Computes the logarithm of the determinant of the Hessian of the prediction function for each point in x at a given time.

Parameters:
  • x (array-like) – Data points at which the log determinant is to be computed.

  • time (array-like or float) – Time point or points at which to evaluate the log determinant. If time is a float, it will be computed at this specific time point for all data points in x. If time is an array, it should be 1-D and the log determinant will be computed for each data point at the corresponding time in the array.

  • jit (bool, optional) – If True, use JAX’s just-in-time (JIT) compilation to speed up the computation. Defaults to True.

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a log determinant for all states in x will be computed separately for each time value in multi_time.

Returns:

The sign of the determinant at each point in x and the logarithm of its absolute value. signs.shape == log_determinants.shape == x.shape[0].

Return type:

array-like, array-like

mean(Xnew, time=None, normalize=False, multi_time=None)

Use the trained model to make predictions based on the input array ‘Xnew’, considering the specified ‘time’ or ‘multi_time’.

The predictions represent the mean of the Gaussian Process conditional distribution of predictive functions.

If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

Parameters:
  • Xnew (array-like) – The new data points for prediction.

  • time (scalar or array-like, optional) – The time points associated with each row in ‘Xnew’. If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

  • normalize (bool) – Whether to normalize the value by subtracting log(self.n_obs) (number of cells trained on). Applicable only for cell-state density predictions. Default is False.

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a prediction for all states in Xnew will be made separately for each time value in multi_time.

Returns:

Predictions for ‘Xnew’.

Return type:

array-like

Raises:

ValueError – If ‘time’ is an array and its size does not match ‘Xnew’.
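multi_time evaluates every state in Xnew once per listed time value. A plain loop over the times makes those semantics explicit. The stand-in function f and the choice to stack results with one row per time value are assumptions for illustration; the output layout of Mellon's multi_time results is not stated on this page.

```python
import numpy as np


def f(Xnew, time):
    # Hypothetical time-dependent prediction: a row sum shifted by time.
    return Xnew.sum(axis=1) + time


def evaluate_multi_time(func, Xnew, multi_time):
    """Evaluate func for all rows of Xnew at each time in multi_time separately."""
    Xnew = np.asarray(Xnew, dtype=float)
    n_rows = Xnew.shape[0]
    return np.stack([func(Xnew, np.full(n_rows, t)) for t in multi_time])


Xnew = np.array([[0.0, 1.0], [2.0, 3.0], [1.0, 1.0]])
preds = evaluate_multi_time(f, Xnew, multi_time=[0.0, 10.0])  # shape (2, 3)
```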

mean_covariance(Xnew, time=None, diag=True, multi_time=None)

Computes the uncertainty of the mean of the Gaussian process induced by the uncertainty of the latent representation of the mean function.

Parameters:
  • Xnew (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • time (scalar or array-like, optional) – The time points associated with each cell/row in ‘Xnew’. If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

  • diag (boolean, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a mean covariance for all states in Xnew will be computed separately for each time value in multi_time.

Returns:

  • var (array-like, shape (n_samples,)) – If diag=True, returns the variances for each sample.

  • cov (array-like, shape (n_samples, n_samples)) – If diag=False, returns the full covariance matrix between samples.

time_derivative(x, time, jit=True, multi_time=None)

Computes the time derivative of the prediction function for each row of x.

This function applies a jax-based gradient operation to the density function evaluated at a specific time. The derivative is with respect to time and not the inputs in x.

Parameters:
  • x (array-like) – Data points where the derivative is to be evaluated.

  • time (array-like or float) – Time point or points at which to evaluate the derivative. If time is a float, the derivative will be computed at this specific time point for all data points in x. If time is an array, it should be 1-D and the time derivative will be computed for all data-points at the corresponding time in the array.

  • jit (bool, optional) – If True, use JAX’s just-in-time (JIT) compilation to speed up the computation. Defaults to True.

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, a time derivative for all states in x will be computed separately for each time value in multi_time.

Returns:

The time derivative of the prediction function evaluated at each point in x, with one value per row of x (shape (x.shape[0],)).

Return type:

array-like
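Differentiating a scalar-valued prediction with respect to time alone, while holding each state fixed, gives one value per data point. A central-difference sketch with a hypothetical stand-in f(x, time) shows the idea; Mellon applies a JAX gradient instead, as described above.

```python
import numpy as np


def f(x, time):
    # Hypothetical time-dependent log-density stand-in.
    return -(x**2).sum(axis=1) + 0.5 * time**2


def finite_difference_time_derivative(func, x, time, eps=1e-5):
    """Central difference in time only; the states x are held fixed."""
    x = np.asarray(x, dtype=float)
    # A scalar time applies to all rows; an array gives one time per row.
    time = np.broadcast_to(np.asarray(time, dtype=float), (x.shape[0],))
    return (func(x, time + eps) - func(x, time - eps)) / (2 * eps)


x = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
dt = finite_difference_time_derivative(f, x, time=np.array([0.0, 1.0, 2.0]))
# analytic answer: d/dt (0.5 * t**2) = t, i.e. [0.0, 1.0, 2.0]
```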

uncertainty(Xnew, time=None, diag=True, multi_time=None)

Computes the total uncertainty of the predicted values quantified by their variance or covariance.

The total uncertainty is defined by .covariance + .mean_covariance.

Parameters:
  • Xnew (array-like, shape (n_samples, n_features)) – The new data points for which to compute the uncertainty.

  • time (scalar or array-like, optional) – The time points associated with each cell/row in ‘Xnew’. If ‘time’ is a scalar, it will be converted into a 1D array of the same size as ‘Xnew’.

  • diag (bool, optional (default=True)) – Whether to compute the variance (True) or the full covariance matrix (False).

  • multi_time (array-like, optional) – If ‘multi_time’ is specified, the uncertainty for all states in Xnew will be computed separately for each time value in multi_time.

Returns:

  • var (array-like, shape (n_samples,) if diag=True) – The variances for each sample in the new data points.

  • cov (array-like, shape (n_samples, n_samples) if diag=False) – The full covariance matrix between the samples in the new data points.