API Reference
astroemu.network
Neural network implementations for emu package.
initialise_mlp(in_size, out_size, hidden_size, nlayers, key, scale=0.1)
Initialize MLP parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_size | int | Input size. | required |
| out_size | int | Output size. | required |
| hidden_size | int | Hidden layer size. | required |
| nlayers | int | Number of hidden layers. | required |
| key | int | JAX random key. | required |
| scale | float | Scale for weight initialization. Defaults to 1e-1. | 0.1 |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | MLP parameters. |
Source code in astroemu/network.py
mlp(params, input, act='relu')
Multi-layer perceptron with residual connections.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | dict | MLP parameters. | required |
| input | ndarray | Input array of shape [..., in_size]. | required |
| act | str | Activation function name from jax.nn. Defaults to "relu". Must be treated as a static argument if JIT-compiling mlp directly (static_argnames=("act",)). | 'relu' |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Output array of shape [..., out_size]. |
Source code in astroemu/network.py
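The following is a minimal sketch of a residual MLP forward pass, written in NumPy rather than JAX so that it needs nothing beyond NumPy. The parameter layout (a dict with "weights" and "biases" lists), the scaled-normal initialisation, and the rule of adding the skip connection only between layers of equal width are assumptions made for illustration; they are not taken from astroemu/network.py.

```python
import numpy as np


def init_params(in_size, out_size, hidden_size, nlayers, seed=0, scale=0.1):
    """Hypothetical initialiser: normal weights scaled by `scale`, zero biases."""
    rng = np.random.default_rng(seed)
    sizes = [in_size] + [hidden_size] * nlayers + [out_size]
    weights = [scale * rng.standard_normal((m, n))
               for m, n in zip(sizes[:-1], sizes[1:])]
    biases = [np.zeros(n) for n in sizes[1:]]
    return {"weights": weights, "biases": biases}


def relu(z):
    return np.maximum(z, 0.0)


def mlp_forward(params, inputs):
    """Forward pass; inputs has shape [..., in_size]."""
    h = inputs
    last = len(params["weights"]) - 1
    for i, (w, b) in enumerate(zip(params["weights"], params["biases"])):
        z = h @ w + b
        if i == last:
            return z  # linear output layer
        a = relu(z)
        # Residual connection only when input and output widths match.
        h = a + h if a.shape == h.shape else a
    return h


params = init_params(in_size=3, out_size=1, hidden_size=8, nlayers=2)
out = mlp_forward(params, np.ones((5, 3)))
print(out.shape)  # (5, 1)
```

Because the matrix product acts on the last axis only, the same function accepts inputs with extra leading batch dimensions, matching the [..., in_size] convention above.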
astroemu.dataloaders
Data loaders for emu package.
SpectrumDataset
Dataset for loading spectra from .npz files.
Allows for optional preprocessing via a forward pipeline and selection of variable input parameters.
Source code in astroemu/dataloaders.py
__getitem__(idx)
Get spectrum and input parameters for given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | int | Index of the data point. | required |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, ndarray, ndarray] | tuple[jnp.ndarray, jnp.ndarray]: Tuple of (spectrum, input parameters). |
Source code in astroemu/dataloaders.py
__init__(files, x, y, forward_pipeline=None, variable_input=None, tiling=True, allow_pickle=False)
Initialize SpectrumDataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| files | list[str] | List of file paths to .npz files. | required |
| x | str | Key for independent variable in .npz files. | required |
| y | str | Key for dependent variable in .npz files. | required |
| forward_pipeline | Any | Preprocessing pipeline. Defaults to None. | None |
| variable_input | list[str] \| str \| None | Keys for variable input parameters. If None, all parameters except x and y are used. Defaults to None. | None |
| tiling | bool | Whether to tile input/output parameters. This is True by default because tiling is what makes astroemu (and globalemu) tick. Turn it off if you want to use the dataset for something other than emulation, or if you want to calculate rolling statistics using astroemu.utils.compute_mean_std. Note that normalisation is applied before tiling. Defaults to True. | True |
| allow_pickle | bool | Whether to allow loading pickled objects from .npz files. Defaults to False. | False |
Source code in astroemu/dataloaders.py
__len__()
Return number of files in dataset.
Source code in astroemu/dataloaders.py
get_batch_iterator(batch_size, shuffle=True, key=None)
Yield batches of spectra and inputs as jnp.ndarray.
When tiling=True, yields (specs_flat, concat_inputs) where specs_flat has shape (batch * len_x,) and concat_inputs has shape (batch * len_x, n_params + 1) with x prepended as the first column.
When tiling=False, yields (specs, x, inputs) with shapes (batch, len_x), (batch, len_x), and (batch, n_params) respectively. This mode is suitable for computing rolling statistics via astroemu.utils.compute_mean_std and building normalisation pipelines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of samples per batch. | required |
| shuffle | bool | Whether to shuffle indices. Defaults to True. When tiling=True, this also shuffles the tiled samples within each batch so the network doesn't see x values in sequential order. | True |
| key | Array \| None | JAX PRNG key for shuffling. Required when shuffle=True. Defaults to None. | None |
Yields:
| Type | Description |
|---|---|
| Generator | tiling=True: tuple[jnp.ndarray, jnp.ndarray] (specs_flat, concat_inputs) |
| Generator | tiling=False: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] (specs, x, inputs) |
Source code in astroemu/dataloaders.py
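The tiled layout described for tiling=True can be illustrated with a short NumPy sketch. The helper name tile_batch is hypothetical, and the sketch assumes x is a single shared grid of shape (len_x,); it shows only how the documented shapes arise, not how astroemu builds its batches (shuffling is omitted).

```python
import numpy as np


def tile_batch(specs, x, inputs):
    """Flatten a batch into one row per (spectrum, x value) pair.

    specs:  (batch, len_x) spectra
    x:      (len_x,) shared independent variable (assumed layout)
    inputs: (batch, n_params) input parameters
    """
    batch, len_x = specs.shape
    specs_flat = specs.reshape(batch * len_x)
    # Repeat each parameter row len_x times and cycle x once per spectrum.
    tiled_inputs = np.repeat(inputs, len_x, axis=0)
    tiled_x = np.tile(x, batch)[:, None]
    concat_inputs = np.concatenate([tiled_x, tiled_inputs], axis=1)
    return specs_flat, concat_inputs


specs = np.arange(6.0).reshape(2, 3)         # batch=2, len_x=3
x = np.array([10.0, 20.0, 30.0])
inputs = np.array([[0.1, 0.2], [0.3, 0.4]])  # n_params=2

specs_flat, concat_inputs = tile_batch(specs, x, inputs)
print(specs_flat.shape, concat_inputs.shape)  # (6,) (6, 3)
```

Each row of concat_inputs pairs one x value with the parameters of the spectrum it belongs to, so the network maps (x, parameters) to a single spectrum value.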
load_spectrum(file, allow_pickle=False)
Load spectrum data from .npz file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| file | str | Path to .npz file. | required |
| allow_pickle | bool | Whether to allow loading pickled objects. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | Dictionary containing data from .npz file. |
Source code in astroemu/dataloaders.py
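As a rough sketch of what loading an .npz file into a dictionary involves, the snippet below uses numpy.load directly. The helper name load_npz_as_dict and the keys frequency, signal, and amplitude are made up for the example; the real load_spectrum may differ in its details.

```python
import os
import tempfile

import numpy as np


def load_npz_as_dict(file, allow_pickle=False):
    """Read every array in an .npz archive into a plain dict."""
    with np.load(file, allow_pickle=allow_pickle) as data:
        return {key: data[key] for key in data.files}


# Write a small example file with made-up keys, then read it back.
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "spectrum_0.npz")
np.savez(path, frequency=np.linspace(50.0, 200.0, 4),
         signal=np.zeros(4), amplitude=0.5)

spectrum = load_npz_as_dict(path)
print(sorted(spectrum))  # ['amplitude', 'frequency', 'signal']
```

Leaving allow_pickle=False is the safer default, since unpickling data from an untrusted file can execute arbitrary code.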
astroemu.normalisation
Normalisation pipelines for emu package.
NormalisationPipeline
Base class for normalisation pipelines.
Source code in astroemu/normalisation.py
backward(_y, _x, _params)
Apply backward transformation.
Source code in astroemu/normalisation.py
forward(_y, _x, _params)
Apply forward transformation.
Source code in astroemu/normalisation.py
log_base_10
Bases: NormalisationPipeline
Logarithm base 10 transformation for numerical stability.
Source code in astroemu/normalisation.py
__init__(y_selector=None, x_selector=None, params_selector=None, log_all_y=False, log_all_x=False, log_all_params=False, eps=1e-15)
Logarithm base 10 transformation for numerical stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y_selector | list[int] \| None | Columns of the spectrum to apply the log transformation to. Assumes spectra are in the last dimension. None returns y without any transformation. | None |
| x_selector | list[int] \| None | Indices of the independent variable to apply the log transformation to. None returns x unchanged. | None |
| params_selector | list[int] \| None | Columns of the input parameters to apply the log transformation to. Assumes parameters are in the last dimension. None returns params without any transformation. | None |
| log_all_y | bool | If True, apply the log transformation to all columns of the spectrum. Overrides y_selector. | False |
| log_all_x | bool | If True, apply the log transformation to all elements of the independent variable. Overrides x_selector. | False |
| log_all_params | bool | If True, apply the log transformation to all columns of the input parameters. Overrides params_selector. | False |
| eps | float | Small value added to avoid log(0). | 1e-15 |
Source code in astroemu/normalisation.py
backward(y, x, params)
Apply inverse log10 transformation to selected columns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | ndarray | Transformed spectrum array, shape (batch, len_x). | required |
| x | ndarray | Transformed independent variable, shape (batch, len_x). | required |
| params | ndarray | Transformed input parameters, shape (batch, n_params). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[ndarray, ndarray, ndarray] | Inverse transformed spectrum, independent variable, and parameters. |
Source code in astroemu/normalisation.py
forward(y, x, params)
Apply log10 transformation to selected columns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | ndarray | Spectrum array, shape (batch, len_x). | required |
| x | ndarray | Independent variable array, shape (batch, len_x). | required |
| params | ndarray | Input parameters array, shape (batch, n_params). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[ndarray, ndarray, ndarray] | Transformed spectrum, independent variable, and parameters. |
Source code in astroemu/normalisation.py
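A selective log10 transform and its inverse can be sketched as follows. This assumes the forward step computes log10(value + eps) and the backward step computes 10**value - eps on the selected columns; those exact formulas, and the helper names, are assumptions rather than a copy of the log_base_10 implementation.

```python
import numpy as np

EPS = 1e-15  # same value as the documented eps default


def log10_forward(arr, selector, eps=EPS):
    """Return a copy of arr with log10(arr + eps) applied to selected columns."""
    if selector is None:
        return arr
    out = np.array(arr, dtype=float)
    out[..., selector] = np.log10(out[..., selector] + eps)
    return out


def log10_backward(arr, selector, eps=EPS):
    """Undo log10_forward on the selected columns."""
    if selector is None:
        return arr
    out = np.array(arr, dtype=float)
    out[..., selector] = 10.0 ** out[..., selector] - eps
    return out


params = np.array([[1e-3, 5.0], [1e2, 7.0]])   # (batch, n_params)
logged = log10_forward(params, selector=[0])    # log only the first column
restored = log10_backward(logged, selector=[0])
print(logged[:, 0])                   # approximately -3 and 2
print(np.allclose(restored, params))  # True
```

Logging only the columns that span many orders of magnitude keeps the remaining parameters on their original scale.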
standardise
Bases: NormalisationPipeline
Standardisation normalisation pipeline.
Source code in astroemu/normalisation.py
__init__(y_mean, y_std, x_mean, x_std, params_mean, params_std, standardise_y=False, standardise_x=False, standardise_params=False)
Standardises the spectrum, independent variable, and parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y_mean | ndarray | Mean of the spectrum. | required |
| y_std | ndarray | Standard deviation of the spectrum. | required |
| x_mean | float | Mean of the independent variable. | required |
| x_std | float | Standard deviation of the independent variable. | required |
| params_mean | ndarray | Mean of the input parameters. | required |
| params_std | ndarray | Standard deviation of the input parameters. | required |
| standardise_y | bool | Whether to standardise the spectrum. Defaults to False. | False |
| standardise_x | bool | Whether to standardise the independent variable. Defaults to False. | False |
| standardise_params | bool | Whether to standardise the input parameters. Defaults to False. | False |
Source code in astroemu/normalisation.py
backward(y, x, params)
Destandardise spectrum, independent variable, and input parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | ndarray | Standardised spectrum array, shape (batch, len_x). | required |
| x | ndarray | Standardised independent variable, shape (batch, len_x). | required |
| params | ndarray | Standardised input parameters, shape (batch, n_params). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[ndarray, ndarray, ndarray] | Destandardised spectrum, independent variable, and parameters. |
Source code in astroemu/normalisation.py
forward(y, x, params)
Standardise spectrum, independent variable, and input parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | ndarray | Spectrum array, shape (batch, len_x). | required |
| x | ndarray | Independent variable array, shape (batch, len_x). | required |
| params | ndarray | Input parameters array, shape (batch, n_params). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[ndarray, ndarray, ndarray] | Standardised spectrum, independent variable, and parameters. |
Source code in astroemu/normalisation.py
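Standardisation and its inverse reduce to a shift and a scale. The sketch below assumes the usual (value - mean) / std convention and uses hypothetical helper names; it illustrates the round trip for the spectrum only and is not the standardise class itself.

```python
import numpy as np


def standardise_forward(arr, mean, std):
    """Shift to zero mean and scale to unit standard deviation."""
    return (arr - mean) / std


def standardise_backward(arr, mean, std):
    """Invert standardise_forward."""
    return arr * std + mean


y = np.array([[1.0, 2.0, 3.0], [3.0, 6.0, 9.0]])  # (batch, len_x)
y_mean = y.mean(axis=0)   # per-column mean, shape (len_x,)
y_std = y.std(axis=0)     # per-column std, shape (len_x,)

y_norm = standardise_forward(y, y_mean, y_std)
y_back = standardise_backward(y_norm, y_mean, y_std)
print(np.allclose(y_norm.mean(axis=0), 0.0))  # True
print(np.allclose(y_back, y))                 # True
```

In practice the statistics would come from the training set only, so that validation and test data are transformed with the same mean and standard deviation.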
astroemu.train
Training loop for astroemu.
train(train_dataset, val_dataset, hidden_size, nlayers, act='relu', epochs=1000, patience=50, learning_rate=0.001, weight_decay=0.0001, batch_size=32, key=0, loss_fn=mse, loss_kwargs={})
Train an MLP emulator on spectral data using AdamW.
Initialises an MLP via initialise_mlp, then trains it using batches from train_dataset and val_dataset (both must have tiling=True). The per-batch train and validation steps are JIT-compiled. Training stops early if the validation loss does not improve for patience consecutive epochs, and the best parameters are returned.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_dataset | SpectrumDataset | Training dataset with tiling=True. | required |
| val_dataset | SpectrumDataset | Validation dataset with tiling=True. | required |
| hidden_size | int | Number of nodes in each hidden layer. | required |
| nlayers | int | Number of hidden layers. | required |
| act | str | Activation function name from jax.nn. Defaults to "relu". | 'relu' |
| epochs | int | Maximum number of training epochs. Defaults to 1000. | 1000 |
| patience | int | Early stopping patience in epochs. Defaults to 50. | 50 |
| learning_rate | float | AdamW learning rate. Defaults to 1e-3. | 0.001 |
| weight_decay | float | AdamW weight decay. Defaults to 1e-4. | 0.0001 |
| batch_size | int | Number of spectra per batch. Defaults to 32. | 32 |
| key | int | Integer seed for JAX PRNG. Defaults to 0. | 0 |
| loss_fn | Callable | Loss function to use. Defaults to mse. | mse |
| loss_kwargs | dict | Additional keyword arguments for the loss function. Defaults to an empty dict. | {} |
Returns:
| Type | Description |
|---|---|
| tuple[dict, list[float], list[float]] | tuple[dict, list[float], list[float]]: Best network parameters, per-epoch training losses, and per-epoch validation losses. |
Source code in astroemu/train.py
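The early stopping rule described for train can be sketched in plain Python by replaying a list of validation losses. The function name and return values are hypothetical, and the real loop also performs the optimiser updates and stores the best parameters; this only shows the patience bookkeeping.

```python
def run_with_early_stopping(val_losses, patience):
    """Replay a sequence of validation losses and apply early stopping.

    Returns (best_epoch, best_loss, epochs_run).
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_since_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        if loss < best_loss:
            # New best: this is where the parameters would be copied.
            best_loss, best_epoch = loss, epoch
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                break
    return best_epoch, best_loss, epochs_run


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.6]
print(run_with_early_stopping(losses, patience=3))  # (2, 0.7, 6)
```

With patience=3 the run stops after the sixth epoch and never sees the final loss of 0.6, which is the usual trade-off between training time and the chance of a late improvement.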
astroemu.losses
Loss functions for astroemu.
kl(predictions, targets, noise, ndata=1)
Kullback-Leibler divergence loss.
From https://ui.adsabs.harvard.edu/abs/2025MNRAS.544..375B/abstract.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted probability distributions. | required |
| targets | ndarray | Target probability distributions. | required |
| noise | ndarray | Some estimate of noise in the data. | required |
| ndata | int | Number of data points. Defaults to 1. | 1 |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Scalar KL divergence loss. |
Source code in astroemu/losses.py
mse(predictions, targets)
Mean squared error loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | ndarray | Predicted values. | required |
| targets | ndarray | Target values. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Scalar MSE loss. |
Source code in astroemu/losses.py
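For reference, mean squared error is the mean of the squared element-wise differences. The NumPy version below is a stand-in for illustration and is not the astroemu.losses.mse source.

```python
import numpy as np


def mse(predictions, targets):
    """Mean of the squared element-wise differences, as a scalar."""
    return np.mean((predictions - targets) ** 2)


predictions = np.array([1.0, 2.0, 4.0])
targets = np.array([1.0, 3.0, 2.0])
# Differences are 0, -1, and 2, so the loss is (0 + 1 + 4) / 3 = 5/3.
print(mse(predictions, targets))
```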
astroemu.serialisation
Serialisation utilities for saving and loading trained emulators.
load(path)
Load a trained emulator from a .astroemu file.
For each saved dataset, this function attempts to reconstruct a SpectrumDataset using the saved file paths and pipeline. If any files are missing the raw config dict is returned under the dataset key instead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Path to a .astroemu file. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| dict | dict | Dictionary with the following keys: |
Source code in astroemu/serialisation.py
save(path, params, train_losses, val_losses, hidden_size, nlayers, loss, train_dataset, val_dataset, test_dataset, act='relu', epochs=1000, patience=50, learning_rate=0.001, weight_decay=0.0001)
Save a trained emulator to a .astroemu file.
The .astroemu extension is appended automatically if not present. The file is a zip archive containing:
- config.json: hyperparameters, training history, loss criterion, code version, and dataset configurations.
- params.npz: network weight arrays.
- pipeline.pkl: pickled normalisation pipeline instances for the training, validation, and test datasets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Destination path. The .astroemu extension is appended if not already present, e.g. "emulator" → "emulator.astroemu". | required |
| params | dict | Trained network parameters returned by train(). | required |
| train_losses | list[float] | Per-epoch training losses. | required |
| val_losses | list[float] | Per-epoch validation losses. | required |
| hidden_size | int | Number of nodes in each hidden layer. | required |
| nlayers | int | Number of hidden layers. | required |
| act | str | Activation function name used during training. | 'relu' |
| epochs | int | Max epochs used during training. | 1000 |
| patience | int | Early stopping patience used during training. | 50 |
| learning_rate | float | AdamW learning rate used during training. | 0.001 |
| weight_decay | float | AdamW weight decay used during training. | 0.0001 |
| loss | str | Name of the loss criterion used. | required |
| train_dataset | SpectrumDataset | Training dataset whose file paths and pipeline are saved. | required |
| val_dataset | SpectrumDataset | Validation dataset whose file paths and pipeline are saved. | required |
| test_dataset | SpectrumDataset | Test dataset whose file paths and pipeline are saved. | required |
Source code in astroemu/serialisation.py
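The archive layout described for save can be approximated with the standard zipfile module. This sketch writes only config.json and params.npz (pipeline.pkl is omitted), uses an in-memory buffer instead of a file on disk, and assumes a flat parameter dict with the made-up key "w0"; all helper names are hypothetical.

```python
import io
import json
import zipfile

import numpy as np


def with_extension(path, ext=".astroemu"):
    """Append ext unless the path already ends with it."""
    return path if path.endswith(ext) else path + ext


def write_archive(fileobj, config, params):
    """Write config.json and params.npz into one zip archive."""
    npz_buffer = io.BytesIO()
    np.savez(npz_buffer, **params)
    with zipfile.ZipFile(fileobj, "w") as archive:
        archive.writestr("config.json", json.dumps(config))
        archive.writestr("params.npz", npz_buffer.getvalue())


def read_archive(fileobj):
    """Read the two members back into Python objects."""
    with zipfile.ZipFile(fileobj) as archive:
        config = json.loads(archive.read("config.json"))
        with np.load(io.BytesIO(archive.read("params.npz"))) as data:
            params = {key: data[key] for key in data.files}
    return config, params


buffer = io.BytesIO()  # in-memory stand-in for a file on disk
write_archive(buffer, {"hidden_size": 8, "nlayers": 2}, {"w0": np.ones((3, 8))})
buffer.seek(0)
config, params = read_archive(buffer)
print(with_extension("emulator"), config["nlayers"], params["w0"].shape)
```

Keeping the hyperparameters in JSON and the weights in .npz means both can be inspected without unpickling anything; only the pipeline member relies on pickle.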
astroemu.utils
Utility functions for emu package.
compute_mean_std(loader)
Memory-safe mean and std computation.
Expects the loader to yield three-tuples (spec, x, inputs) as produced by SpectrumDataset.get_batch_iterator() with tiling=False.
Since x (the independent variable) is identical for every sample in the dataset, its mean and std are computed as global scalars over all elements rather than per-column statistics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loader | Iterator[tuple[ndarray, ndarray, ndarray]] | Iterable yielding (spec, x, inputs), where spec has shape (batch_size, len_x), x has shape (batch_size, len_x), and inputs has shape (batch_size, n_params). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| mean_spec | ndarray | (len_x,) - per-frequency mean across batches |
| std_spec | ndarray | (len_x,) - per-frequency std across batches |
| mean_x | ndarray | scalar Array - global mean of the independent variable |
| std_x | ndarray | scalar Array - global std of the independent variable |
| mean_input | ndarray | (n_params,) - per-parameter mean across batches |
| std_input | ndarray | (n_params,) - per-parameter std across batches |
Source code in astroemu/utils.py
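One common way to compute a mean and standard deviation without holding the whole dataset in memory is to accumulate running sums over batches. The sketch below does this for a single array per batch using sums and sums of squares; it is an illustration of the idea under that assumption, not the algorithm used in astroemu/utils.py, and a Welford-style update would be more robust numerically for data with a large mean.

```python
import numpy as np


def streaming_mean_std(batches):
    """Per-column mean and std over an iterable of (batch, n_cols) arrays."""
    count = 0
    total = None
    total_sq = None
    for batch in batches:
        batch = np.asarray(batch, dtype=float)
        if total is None:
            total = np.zeros(batch.shape[1])
            total_sq = np.zeros(batch.shape[1])
        count += batch.shape[0]
        total += batch.sum(axis=0)
        total_sq += (batch ** 2).sum(axis=0)
    mean = total / count
    # Population variance: E[v^2] - E[v]^2, clipped at zero against round-off.
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


data = np.random.default_rng(0).normal(size=(100, 4))
batches = (data[i:i + 32] for i in range(0, 100, 32))  # lazily yields 4 batches
mean, std = streaming_mean_std(batches)
print(np.allclose(mean, data.mean(axis=0)), np.allclose(std, data.std(axis=0)))
```

Only three small accumulators are kept between batches, so memory use is independent of the number of spectra.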