Skip to content

API Reference

astroemu.network

Neural network implementations for emu package.

initialise_mlp(in_size, out_size, hidden_size, nlayers, key, scale=0.1)

Initialize MLP parameters.

Parameters:

Name Type Description Default
in_size int

Input size.

required
out_size int

Output size.

required
hidden_size int

Hidden layer size.

required
nlayers int

Number of hidden layers.

required
key int

JAX random key.

required
scale float

Scale for weight initialization. Defaults to 1e-1.

0.1

Returns:

Name Type Description
dict dict

MLP parameters.

Source code in astroemu/network.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def initialise_mlp(
    in_size: int,
    out_size: int,
    hidden_size: int,
    nlayers: int,
    key: jax.Array,
    scale: float = 1e-1,
) -> dict:
    """Initialize MLP parameters.

    Every weight matrix and bias vector is drawn from a standard normal
    distribution multiplied by ``scale``. Each draw uses its own PRNG
    subkey, so no two parameter tensors share random numbers.

    Args:
        in_size (int): Input size.
        out_size (int): Output size.
        hidden_size (int): Hidden layer size.
        nlayers (int): Number of hidden (hidden_size -> hidden_size) layers.
        key (jax.Array): JAX PRNG key (e.g. from ``jax.random.PRNGKey``).
        scale (float, optional): Scale for weight initialization.
            Defaults to 1e-1.

    Returns:
        dict: MLP parameters with keys ``"weights{i}"`` and ``"bias{i}"``
            for ``i`` in ``0 .. nlayers + 1``. Layer 0 maps
            in_size -> hidden_size, layers 1..nlayers map
            hidden_size -> hidden_size, and the last layer maps
            hidden_size -> out_size.
    """
    # (fan_in, fan_out) for: input projection, nlayers hidden layers, output.
    shapes = (
        [(in_size, hidden_size)]
        + [(hidden_size, hidden_size)] * nlayers
        + [(hidden_size, out_size)]
    )
    # Two dedicated subkeys per layer: keys[2*l] for the weights and
    # keys[2*l + 1] for the bias. This keeps every draw independent.
    keys = random.split(key, 2 * len(shapes))
    params = {}
    for layer, (fan_in, fan_out) in enumerate(shapes):
        params[f"weights{layer}"] = scale * random.normal(
            keys[2 * layer], (fan_in, fan_out)
        )
        params[f"bias{layer}"] = scale * random.normal(
            keys[2 * layer + 1], (fan_out,)
        )
    return params

mlp(params, input, act='relu')

Multi-layer perceptron with residual connections.

Parameters:

Name Type Description Default
params dict

MLP parameters.

required
input ndarray

Input array of shape [..., in_size].

required
act str

Activation function name from jax.nn. Defaults to "relu". Must be treated as a static argument if JIT-compiling mlp directly (static_argnames=("act",)).

'relu'

Returns:

Type Description
ndarray

jnp.ndarray: Output array of shape [..., out_size].

Source code in astroemu/network.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def mlp(params: dict, input: jnp.ndarray, act: str = "relu") -> jnp.ndarray:
    """Multi-layer perceptron with pre-activation residual hidden layers.

    The forward pass is: a linear input projection; then, for every hidden
    layer ``i``, ``h <- (act(h) @ W_i + b_i) + h``; then a linear output
    projection applied directly to ``h`` (no activation before it). With
    zero hidden layers the model is therefore purely affine.

    Args:
        params (dict): MLP parameters keyed ``"weights{i}"`` / ``"bias{i}"``
            for ``i = 0 .. L - 1`` where ``L = len(params) // 2`` (the
            layout built by ``initialise_mlp``).
        input (jnp.ndarray): Input array of shape [..., in_size].
        act (str): Name of an activation function in ``jax.nn``. Defaults
            to "relu". Must be treated as a static argument if
            JIT-compiling mlp directly (static_argnames=("act",)).

    Returns:
        jnp.ndarray: Output array of shape [..., out_size].
    """
    activation = getattr(jax.nn, act)
    final = len(params) // 2 - 1  # index of the output layer

    def affine(h: jnp.ndarray, idx: int) -> jnp.ndarray:
        # One linear layer: h @ W_idx + b_idx.
        return jnp.dot(h, params[f"weights{idx}"]) + params[f"bias{idx}"]

    stream = affine(input, 0)
    for idx in range(1, final):
        # Residual add; requires square hidden weights (as initialise_mlp
        # builds them) so the shapes agree.
        stream = affine(activation(stream), idx) + stream
    return affine(stream, final)

astroemu.dataloaders

Data loaders for emu package.

SpectrumDataset

Dataset for loading spectra from .npz files.

Allows for optional preprocessing via a forward pipeline and selection of variable input parameters.

Source code in astroemu/dataloaders.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class SpectrumDataset:
    """Dataset for loading spectra from .npz files.

    Allows for optional preprocessing via a forward pipeline and
    selection of variable input parameters.
    """

    def __init__(
        self,
        files: list[str],
        x: str,
        y: str,
        forward_pipeline: NormalisationPipeline
        | list[NormalisationPipeline]
        | None = None,
        variable_input: list[str] | str | None = None,
        tiling: bool = True,
        allow_pickle: bool = False,
    ) -> None:
        """Initialize SpectrumDataset.

        Args:
            files (list[str]): List of file paths to .npz files.
            x (str): Key for independent variable in .npz files.
            y (str): Key for dependent variable in .npz files.
            forward_pipeline (NormalisationPipeline |
                list[NormalisationPipeline] | None, optional): Preprocessing
                pipeline(s), applied in order to every batch by
                get_batch_iterator. Defaults to None (no preprocessing).
            variable_input (list[str] | str | None, optional): Keys
                for variable input parameters. A single string is treated
                as a one-element list.
                If None (or empty), all parameters except x and y are used,
                in sorted key order.
                Defaults to None.
            tiling (bool, optional): Whether to tile input/output parameters.
                This is True by default since this is what makes
                astroemu (and globaemu) tick. However, you might want
                to turn it off if you want to use the dataset for
                something other than
                emulation or if you want to calculate things like
                rolling averages using astroemu.utils.compute_mean_std. Note
                normalisation is applied before tiling.
                Defaults to True.
            allow_pickle (bool): Whether to allow loading pickled objects
                from .npz files. Defaults to False.
        """
        self.files = files
        # isinstance rather than an exact type check so str subclasses
        # (e.g. numpy.str_) are wrapped too instead of iterated per char.
        if isinstance(variable_input, str):
            self.varied_input = [variable_input]
        else:
            self.varied_input = variable_input
        if forward_pipeline is None:
            self.forward_pipeline = []
        elif isinstance(forward_pipeline, list):
            self.forward_pipeline = forward_pipeline
        else:
            self.forward_pipeline = [forward_pipeline]
        self.x = x
        self.y = y
        self.tiling = tiling
        self.allow_pickle = allow_pickle

    def __len__(self) -> int:
        """Return number of files in dataset."""
        return len(self.files)

    def __getitem__(
        self, idx: int
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Get spectrum, independent variable and inputs for given index.

        Args:
            idx (int): Index of the data point.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple of
                (spectrum, independent variable, input parameters). The
                input parameters are a float32 vector built from the
                selected scalar entries (dict-valued entries are merged in).
        """
        data = load_spectrum(self.files[idx], allow_pickle=self.allow_pickle)
        x = jnp.array(data[self.x])
        y = jnp.array(data[self.y])
        if self.varied_input:
            keys = self.varied_input
        else:
            keys = [k for k in sorted(data.keys()) if k not in [self.x, self.y]]
        raw = [(k, data[k].item()) for k in keys]

        # Build a single params dict: merge dict-valued entries,
        # add numeric scalars by key, skip non-numeric metadata (e.g. strings).
        params: dict = {}
        for k, val in raw:
            if isinstance(val, dict):
                params.update(val)
            elif isinstance(val, int | float):
                params[k] = val

        inputs = jnp.array(list(params.values()), dtype=jnp.float32)

        return y, x, inputs

    def get_batch_iterator(
        self,
        batch_size: int,
        shuffle: bool = True,
        key: jax.Array | None = None,
    ) -> Generator:
        """Yield batches of spectra and inputs as jnp.ndarray.

        When tiling=True, yields (specs_flat, concat_inputs) where
        specs_flat has shape (batch * len_x,) and concat_inputs has shape
        (batch * len_x, n_params + 1) with x prepended as the first column.

        When tiling=False, yields (specs, x, inputs) with shapes
        (batch, len_x), (batch, len_x), and (batch, n_params) respectively.
        This mode is suitable for computing rolling statistics via
        astroemu.utils.compute_mean_std and building
        normalisation pipelines.

        Args:
            batch_size (int): Number of samples per batch.
            shuffle (bool): Whether to shuffle indices. Defaults to True.
                When tiling=True, this also shuffles the tiled samples
                within each batch so the network doesn't see x values
                in sequential order.
            key (jax.Array | None): JAX PRNG key for shuffling. If None
                while shuffle=True, jax.random.PRNGKey(0) is used, so the
                order is identical on every call; pass a fresh key (e.g.
                one per epoch) to vary it. Defaults to None.

        Yields:
            tiling=True:  tuple[jnp.ndarray, jnp.ndarray]
                (specs_flat, concat_inputs)
            tiling=False: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
                (specs, x, inputs)
        """
        n = len(self)
        indices = jnp.arange(n)
        if shuffle:
            if key is None:
                key = jax.random.PRNGKey(0)
            # Split before use: `key` is split again below for the
            # per-batch shuffles, so consuming it directly here would
            # reuse the same randomness for two different draws.
            key, perm_key = jax.random.split(key)
            indices = jax.random.permutation(perm_key, indices)

        for start in range(0, n, batch_size):
            batch_indices = indices[start : start + batch_size]
            specs, x, inputs = zip(*[self[int(i)] for i in batch_indices])
            specs = jnp.stack(specs)
            x = jnp.stack(x)
            inputs = jnp.stack(inputs)

            for pipeline in self.forward_pipeline:
                specs, x, inputs = pipeline.forward(specs, x, inputs)

            if self.tiling:
                # tile params to match each x point, then prepend x column
                inputs = jnp.repeat(inputs, repeats=specs.shape[-1], axis=0)
                inputs = jnp.concatenate(
                    [x.flatten()[:, None], inputs], axis=-1
                )
                specs_flat = specs.flatten()

                # Shuffle tiled samples so network doesn't see x
                # values in order
                if shuffle:
                    key, subkey = jax.random.split(key)
                    perm = jax.random.permutation(subkey, len(specs_flat))
                    specs_flat = specs_flat[perm]
                    inputs = inputs[perm]

                yield specs_flat, inputs
            else:
                yield specs, x, inputs

__getitem__(idx)

Get spectrum and input parameters for given index.

Parameters:

Name Type Description Default
idx int

Index of the data point.

required

Returns:

Type Description
tuple[ndarray, ndarray, ndarray]

tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple of (spectrum, independent variable, input parameters).

Source code in astroemu/dataloaders.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def __getitem__(
    self, idx: int
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Get spectrum, independent variable and inputs for given index.

    Args:
        idx (int): Index of the data point.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple of
            (spectrum, independent variable, input parameters). The input
            parameters are a float32 vector of the selected numeric scalars,
            with dict-valued entries merged in; other values are skipped.
    """
    record = load_spectrum(self.files[idx], allow_pickle=self.allow_pickle)
    x = jnp.array(record[self.x])
    y = jnp.array(record[self.y])

    if self.varied_input:
        names = self.varied_input
    else:
        # Every stored entry except the x/y arrays, in sorted key order.
        names = [k for k in sorted(record.keys()) if k not in [self.x, self.y]]

    merged: dict = {}
    for name in names:
        value = record[name].item()
        if isinstance(value, dict):
            merged.update(value)
        elif isinstance(value, int | float):
            merged[name] = value

    return y, x, jnp.array(list(merged.values()), dtype=jnp.float32)

__init__(files, x, y, forward_pipeline=None, variable_input=None, tiling=True, allow_pickle=False)

Initialize SpectrumDataset.

Parameters:

Name Type Description Default
files list[str]

List of file paths to .npz files.

required
x str

Key for independent variable in .npz files.

required
y str

Key for dependent variable in .npz files.

required
forward_pipeline Any

Preprocessing pipeline. Defaults to None.

None
variable_input list[str] | str | None

Keys for variable input parameters. If None, all parameters except x and y are used. Defaults to None.

None
tiling bool

Whether to tile input/output parameters. This is True by default since this is what makes astroemu (and globaemu) tick. However, you might want to turn it off if you want to use the dataset for something other than emulation or if you want to calculate things like rolling averages using astroemu.utils.compute_mean_std. Note normalisation is applied before tiling. Defaults to True.

True
allow_pickle bool

Whether to allow loading pickled objects from .npz files. Defaults to False.

False
Source code in astroemu/dataloaders.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __init__(
    self,
    files: list[str],
    x: str,
    y: str,
    forward_pipeline: NormalisationPipeline
    | list[NormalisationPipeline]
    | None = None,
    variable_input: list[str] | str | None = None,
    tiling: bool = True,
    allow_pickle: bool = False,
) -> None:
    """Initialize SpectrumDataset.

    Args:
        files (list[str]): List of file paths to .npz files.
        x (str): Key for independent variable in .npz files.
        y (str): Key for dependent variable in .npz files.
        forward_pipeline (NormalisationPipeline |
            list[NormalisationPipeline] | None, optional): Preprocessing
            pipeline(s), stored as a list. Defaults to None (empty list).
        variable_input (list[str] | str | None, optional): Keys
            for variable input parameters. A single string is treated
            as a one-element list.
            If None, all parameters except x and y are used.
            Defaults to None.
        tiling (bool, optional): Whether to tile input/output parameters.
            This is True by default since this is what makes
            astroemu (and globaemu) tick. However, you might want
            to turn it off if you want to use the dataset for
            something other than
            emulation or if you want to calculate things like
            rolling averages using astroemu.utils.compute_mean_std. Note
            normalisation is applied before tiling.
            Defaults to True.
        allow_pickle (bool): Whether to allow loading pickled objects
            from .npz files. Defaults to False.
    """
    self.files = files
    # isinstance rather than an exact type check so str subclasses
    # (e.g. numpy.str_) are wrapped too instead of iterated per character.
    if isinstance(variable_input, str):
        self.varied_input = [variable_input]
    else:
        self.varied_input = variable_input
    if forward_pipeline is None:
        self.forward_pipeline = []
    elif isinstance(forward_pipeline, list):
        self.forward_pipeline = forward_pipeline
    else:
        self.forward_pipeline = [forward_pipeline]
    self.x = x
    self.y = y
    self.tiling = tiling
    self.allow_pickle = allow_pickle

__len__()

Return number of files in dataset.

Source code in astroemu/dataloaders.py
85
86
87
def __len__(self) -> int:
    """Number of samples, i.e. the number of .npz files in the dataset."""
    files = self.files
    return len(files)

get_batch_iterator(batch_size, shuffle=True, key=None)

Yield batches of spectra and inputs as jnp.ndarray.

When tiling=True, yields (specs_flat, concat_inputs) where specs_flat has shape (batch * len_x,) and concat_inputs has shape (batch * len_x, n_params + 1) with x prepended as the first column.

When tiling=False, yields (specs, x, inputs) with shapes (batch, len_x), (batch, len_x), and (batch, n_params) respectively. This mode is suitable for computing rolling statistics via astroemu.utils.compute_mean_std and building normalisation pipelines.

Parameters:

Name Type Description Default
batch_size int

Number of samples per batch.

required
shuffle bool

Whether to shuffle indices. Defaults to True. When tiling=True, this also shuffles the tiled samples within each batch so the network doesn't see x values in sequential order.

True
key Array | None

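JAX PRNG key for shuffling. If None while shuffle=True, the fixed key jax.random.PRNGKey(0) is used, so the shuffle order is the same on every call; pass a fresh key (e.g. one per epoch) to vary it. Defaults to None.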

None

Yields:

Type Description
Generator

tiling=True: tuple[jnp.ndarray, jnp.ndarray] (specs_flat, concat_inputs)

Generator

tiling=False: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] (specs, x, inputs)

Source code in astroemu/dataloaders.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def get_batch_iterator(
    self,
    batch_size: int,
    shuffle: bool = True,
    key: jax.Array | None = None,
) -> Generator:
    """Yield batches of spectra and inputs as jnp.ndarray.

    When tiling=True, yields (specs_flat, concat_inputs) where
    specs_flat has shape (batch * len_x,) and concat_inputs has shape
    (batch * len_x, n_params + 1) with x prepended as the first column.

    When tiling=False, yields (specs, x, inputs) with shapes
    (batch, len_x), (batch, len_x), and (batch, n_params) respectively.
    This mode is suitable for computing rolling statistics via
    astroemu.utils.compute_mean_std and building
    normalisation pipelines.

    Args:
        batch_size (int): Number of samples per batch.
        shuffle (bool): Whether to shuffle indices. Defaults to True.
            When tiling=True, this also shuffles the tiled samples
            within each batch so the network doesn't see x values
            in sequential order.
        key (jax.Array | None): JAX PRNG key for shuffling.
            Required when shuffle=True. Defaults to None.

    Yields:
        tiling=True:  tuple[jnp.ndarray, jnp.ndarray]
            (specs_flat, concat_inputs)
        tiling=False: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
            (specs, x, inputs)
    """
    n = len(self)
    indices = jnp.arange(n)
    if shuffle:
        if key is None:
            key = jax.random.PRNGKey(0)
        indices = jax.random.permutation(key, indices)

    for start in range(0, n, batch_size):
        batch_indices = indices[start : start + batch_size]
        specs, x, inputs = zip(*[self[int(i)] for i in batch_indices])
        specs = jnp.stack(specs)
        x = jnp.stack(x)
        inputs = jnp.stack(inputs)

        for pipeline in self.forward_pipeline:
            specs, x, inputs = pipeline.forward(specs, x, inputs)

        if self.tiling:
            # tile params to match each x point, then prepend x column
            inputs = jnp.repeat(inputs, repeats=specs.shape[-1], axis=0)
            inputs = jnp.concatenate(
                [x.flatten()[:, None], inputs], axis=-1
            )
            specs_flat = specs.flatten()

            # Shuffle tiled samples so network doesn't see x
            # values in order
            if shuffle:
                key, subkey = jax.random.split(key)
                perm = jax.random.permutation(subkey, len(specs_flat))
                specs_flat = specs_flat[perm]
                inputs = inputs[perm]

            yield specs_flat, inputs
        else:
            yield specs, x, inputs

load_spectrum(file, allow_pickle=False)

Load spectrum data from .npz file.

Parameters:

Name Type Description Default
file str

Path to .npz file.

required
allow_pickle bool

Whether to allow loading pickled objects.

False

Returns:

Name Type Description
dict dict

Dictionary containing data from .npz file.

Source code in astroemu/dataloaders.py
11
12
13
14
15
16
17
18
19
20
21
22
23
def load_spectrum(file: str, allow_pickle: bool = False) -> dict:
    """Load spectrum data from .npz file.

    Every array in the archive is read eagerly. The archive is then closed
    before returning, so no file handle is left open for each call.

    Args:
        file (str): Path to .npz file.
        allow_pickle (bool): Whether to allow loading pickled objects.
            Only enable this for trusted files, because unpickling can
            execute arbitrary code.

    Returns:
        dict: Mapping from every array name stored in the archive to the
            array read from it.
    """
    # jnp.load on an .npz returns an NpzFile, which holds an open handle;
    # the context manager guarantees it is closed.
    with jnp.load(file, allow_pickle=allow_pickle) as archive:
        # Materialise every member while the archive is still open.
        return {name: archive[name] for name in archive.files}

astroemu.normalisation

Normalisation pipelines for emu package.

NormalisationPipeline

Base class for normalisation pipelines.

Source code in astroemu/normalisation.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class NormalisationPipeline:
    """Interface for transformations over a (y, x, params) triple.

    Subclasses override ``forward`` to map raw arrays into the normalised
    space. They override ``backward`` to apply the reverse mapping
    (presumably the inverse of ``forward``). Both take and return the triple
    ``(y, x, params)``. The base implementations only raise
    ``NotImplementedError``.
    """

    def forward(
        self,
        _y: jnp.ndarray,
        _x: jnp.ndarray,
        _params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Forward transformation placeholder; subclasses must override."""
        raise NotImplementedError()

    def backward(
        self,
        _y: jnp.ndarray,
        _x: jnp.ndarray,
        _params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Backward transformation placeholder; subclasses must override."""
        raise NotImplementedError()

backward(_y, _x, _params)

Apply backward transformation.

Source code in astroemu/normalisation.py
20
21
22
23
24
25
26
27
def backward(
    self,
    _y: jnp.ndarray,
    _x: jnp.ndarray,
    _params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Backward transformation placeholder; subclasses must override it."""
    raise NotImplementedError()

forward(_y, _x, _params)

Apply forward transformation.

Source code in astroemu/normalisation.py
11
12
13
14
15
16
17
18
def forward(
    self,
    _y: jnp.ndarray,
    _x: jnp.ndarray,
    _params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Forward transformation placeholder; subclasses must override it."""
    raise NotImplementedError()

log_base_10

Bases: NormalisationPipeline

Logarithm base 10 transformation for numerical stability.

Source code in astroemu/normalisation.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
class log_base_10(NormalisationPipeline):
    """Logarithm base 10 transformation for numerical stability.

    ``forward`` maps selected values ``v`` to ``log10(v + eps)``, and
    ``backward`` maps them back with ``10**v - eps``. Selection is made
    separately for y, x and params. Either the whole array is transformed
    (``log_all_*``) or only chosen indices along its last axis
    (``*_selector``).
    """

    def __init__(
        self,
        y_selector: list[int] | jnp.ndarray | None = None,
        x_selector: list[int] | jnp.ndarray | None = None,
        params_selector: list[int] | jnp.ndarray | None = None,
        log_all_y: bool = False,
        log_all_x: bool = False,
        log_all_params: bool = False,
        eps: float = 1e-15,
    ) -> None:
        """Logarithm base 10 transformation for numerical stability.

        Args:
            y_selector (list[int] | jnp.ndarray | None): columns of the
                spectrum to apply log transformation. Assumes spectra are in
                the last dimension. Lists and tuples are converted to index
                arrays. None returns y without any transformation.
            x_selector (list[int] | jnp.ndarray | None): indices along the
                last dimension of the independent variable to apply log
                transformation. None returns x unchanged.
            params_selector (list[int] | jnp.ndarray | None): columns of the
                input parameters to apply log transformation. Assumes
                parameters are in the last dimension. None returns params
                without any transformation.
            log_all_y (bool): If True, apply log transformation to all columns
                of the spectrum. Overrides y_selector if True.
            log_all_x (bool): If True, apply log transformation to all elements
                of the independent variable. Overrides x_selector if True.
            log_all_params (bool): If True, apply log transformation to all
                columns of the input parameters. Overrides params_selector.
            eps (float): small value added before the log to avoid log(0);
                subtracted again in ``backward``.

        Warns:
            UserWarning: if a ``log_all_*`` flag overrides a selector, or if
                no transformation is configured at all.
        """
        self.log_all_y = log_all_y
        self.log_all_x = log_all_x
        self.log_all_params = log_all_params
        self.eps = eps

        if log_all_y and y_selector is not None:
            warnings.warn("log_all_y is True, overriding y_selector.")
        if log_all_x and x_selector is not None:
            warnings.warn("log_all_x is True, overriding x_selector.")
        if log_all_params and params_selector is not None:
            warnings.warn(
                "log_all_params is True, overriding params_selector."
            )

        self.y_selector = self._as_index(y_selector)
        self.x_selector = self._as_index(x_selector)
        self.params_selector = self._as_index(params_selector)

        if True not in [log_all_y, log_all_x, log_all_params] and all(
            s is None for s in [y_selector, x_selector, params_selector]
        ):
            warnings.warn(
                "No log transformation applied. Consider setting at least one "
                "of log_all_y, log_all_x, or log_all_params "
                "to True or providing "
                "a selector."
            )

    @staticmethod
    def _as_index(selector):
        """Normalise a list/tuple selector into an array usable in .at[]."""
        if isinstance(selector, (list, tuple)):
            indices = jnp.asarray(selector)
            # jnp.asarray([]) is float32, which is not a valid index dtype.
            # A raw tuple would be read by .at[...] as a multi-dim index.
            return indices.astype(jnp.int32) if indices.size == 0 else indices
        return selector

    @staticmethod
    def _apply(values, apply_all, selector, fn):
        """Apply ``fn`` to all of ``values`` or to selected last-axis entries."""
        if apply_all:
            return fn(values)
        if selector is None:
            return values
        mask = jnp.zeros(values.shape[-1], dtype=bool).at[selector].set(True)
        return jnp.where(mask, fn(values), values)

    def _log10(self, values):
        # Forward map: log10(v + eps).
        return jnp.log10(values + self.eps)

    def _pow10(self, values):
        # Inverse of _log10: 10**v - eps.
        return 10**values - self.eps

    def forward(
        self,
        y: jnp.ndarray,
        x: jnp.ndarray,
        params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Apply log10 transformation to selected columns.

        Args:
            y (jnp.ndarray): Spectrum array, shape (batch, len_x).
            x (jnp.ndarray): Independent variable array, shape (batch, len_x).
            params (jnp.ndarray): Input parameters array, shape
                (batch, n_params).

        Returns:
            tuple: Transformed spectrum, independent variable, and parameters.
        """
        y = self._apply(y, self.log_all_y, self.y_selector, self._log10)
        x = self._apply(x, self.log_all_x, self.x_selector, self._log10)
        params = self._apply(
            params, self.log_all_params, self.params_selector, self._log10
        )
        return y, x, params

    def backward(
        self,
        y: jnp.ndarray,
        x: jnp.ndarray,
        params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Apply inverse log10 transformation to selected columns.

        Args:
            y (jnp.ndarray): Transformed spectrum array, shape (batch, len_x).
            x (jnp.ndarray): Transformed independent variable, shape
                (batch, len_x).
            params (jnp.ndarray): Transformed input parameters, shape
                (batch, n_params).

        Returns:
            tuple: Inverse transformed spectrum, independent variable, and
                parameters.
        """
        y = self._apply(y, self.log_all_y, self.y_selector, self._pow10)
        x = self._apply(x, self.log_all_x, self.x_selector, self._pow10)
        params = self._apply(
            params, self.log_all_params, self.params_selector, self._pow10
        )
        return y, x, params

__init__(y_selector=None, x_selector=None, params_selector=None, log_all_y=False, log_all_x=False, log_all_params=False, eps=1e-15)

Logarithm base 10 transformation for numerical stability.

Parameters:

Name Type Description Default
y_selector list[int] | None

columns of the spectrum to apply log transformation. Assumes spectra are in the last dimension. None returns y without any transformation.

None
x_selector list[int] | None

indices of the independent variable to apply log transformation. None returns x unchanged.

None
params_selector list[int] | None

columns of the input parameters to apply log transformation. Assumes parameters are in the last dimension. None returns params without any transformation.

None
log_all_y bool

If True, apply log transformation to all columns of the spectrum. Overrides y_selector if True.

False
log_all_x bool

If True, apply log transformation to all elements of the independent variable. Overrides x_selector if True.

False
log_all_params bool

If True, apply log transformation to all columns of the input parameters. Overrides params_selector.

False
eps float

small value to add to avoid log(0).

1e-15
Source code in astroemu/normalisation.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def __init__(
    self,
    y_selector: list[int] | jnp.ndarray | None = None,
    x_selector: list[int] | jnp.ndarray | None = None,
    params_selector: list[int] | jnp.ndarray | None = None,
    log_all_y: bool = False,
    log_all_x: bool = False,
    log_all_params: bool = False,
    eps: float = 1e-15,
) -> None:
    """Logarithm base 10 transformation for numerical stability.

    Args:
        y_selector (list[int] | None): columns of the spectrum to apply
            log transformation. Assumes spectra are in the last dimension.
            None returns y without any transformation.
        x_selector (list[int] | None): indices of the independent variable
            to apply log transformation. None returns x unchanged.
        params_selector (list[int] | None): columns of the input parameters
            to apply log transformation. Assumes parameters are in the last
            dimension. None returns params without any transformation.
        log_all_y (bool): If True, apply log transformation to all columns
            of the spectrum. Overrides y_selector if True.
        log_all_x (bool): If True, apply log transformation to all elements
            of the independent variable. Overrides x_selector if True.
        log_all_params (bool): If True, apply log transformation to all
            columns of the input parameters. Overrides params_selector.
        eps (float): small value to add to avoid log(0).
    """
    self.y_selector = y_selector
    self.x_selector = x_selector
    self.params_selector = params_selector
    self.log_all_y = log_all_y
    self.log_all_x = log_all_x
    self.log_all_params = log_all_params
    self.eps = eps

    if log_all_y and y_selector is not None:
        warnings.warn("log_all_y is True, overriding y_selector.")
    else:
        if type(self.y_selector) is list:
            self.y_selector = jnp.array(self.y_selector)

    if log_all_x and x_selector is not None:
        warnings.warn("log_all_x is True, overriding x_selector.")
    else:
        if type(self.x_selector) is list:
            self.x_selector = jnp.array(self.x_selector)

    if log_all_params and params_selector is not None:
        warnings.warn(
            "log_all_params is True, overriding params_selector."
        )
    else:
        if type(self.params_selector) is list:
            self.params_selector = jnp.array(self.params_selector)

    if True not in [log_all_y, log_all_x, log_all_params] and all(
        s is None for s in [y_selector, x_selector, params_selector]
    ):
        warnings.warn(
            "No log transformation applied. Consider setting at least one "
            "of log_all_y, log_all_x, or log_all_params "
            "to True or providing "
            "a selector."
        )

backward(y, x, params)

Apply inverse log10 transformation to selected columns.

Parameters:

Name Type Description Default
y ndarray

Transformed spectrum array, shape (batch, len_x).

required
x ndarray

Transformed independent variable, shape (batch, len_x).

required
params ndarray

Transformed input parameters, shape (batch, n_params).

required

Returns:

Name Type Description
tuple tuple[ndarray, ndarray, ndarray]

Inverse transformed spectrum, independent variable, and parameters.

Source code in astroemu/normalisation.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def backward(
    self,
    y: jnp.ndarray,
    x: jnp.ndarray,
    params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Apply inverse log10 transformation to selected columns.

    Every configured entry is mapped back through ``10**v - eps``. A set
    ``log_all_*`` flag inverts the whole array. Otherwise a non-None
    selector restricts the inversion to the chosen positions along the last
    axis, and all other entries are returned as they are.

    Args:
        y (jnp.ndarray): Transformed spectrum array, shape (batch, len_x).
        x (jnp.ndarray): Transformed independent variable, shape
            (batch, len_x).
        params (jnp.ndarray): Transformed input parameters, shape
            (batch, n_params).

    Returns:
        tuple: Inverse transformed spectrum, independent variable, and
            parameters.
    """

    def undo(arr, invert_everything, selector):
        # The "all" flag wins over any selector.
        if invert_everything:
            return 10**arr - self.eps
        if selector is None:
            return arr
        # Boolean mask over the last axis marking the entries to invert.
        chosen = jnp.zeros(arr.shape[-1], dtype=bool).at[selector].set(True)
        return jnp.where(chosen, 10**arr - self.eps, arr)

    y_back = undo(y, self.log_all_y, self.y_selector)
    x_back = undo(x, self.log_all_x, self.x_selector)
    params_back = undo(params, self.log_all_params, self.params_selector)
    return y_back, x_back, params_back

forward(y, x, params)

Apply log10 transformation to selected columns.

Parameters:

Name Type Description Default
y ndarray

Spectrum array, shape (batch, len_x).

required
x ndarray

Independent variable array, shape (batch, len_x).

required
params ndarray

Input parameters array, shape (batch, n_params).

required

Returns:

Name Type Description
tuple tuple[ndarray, ndarray, ndarray]

Transformed spectrum, independent variable, and parameters.

Source code in astroemu/normalisation.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def forward(
    self,
    y: jnp.ndarray,
    x: jnp.ndarray,
    params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Apply log10 transformation to selected columns.

    Every configured entry is mapped through ``log10(v + eps)``. A set
    ``log_all_*`` flag transforms the whole array. Otherwise a non-None
    selector restricts the transform to the chosen positions along the last
    axis, and all other entries are returned as they are.

    Args:
        y (jnp.ndarray): Spectrum array, shape (batch, len_x).
        x (jnp.ndarray): Independent variable array, shape (batch, len_x).
        params (jnp.ndarray): Input parameters array, shape
            (batch, n_params).

    Returns:
        tuple: Transformed spectrum, independent variable, and parameters.
    """

    def apply_log(arr, transform_everything, selector):
        # The "all" flag wins over any selector.
        if transform_everything:
            return jnp.log10(arr + self.eps)
        if selector is None:
            return arr
        # Boolean mask over the last axis marking the entries to transform.
        chosen = jnp.zeros(arr.shape[-1], dtype=bool).at[selector].set(True)
        return jnp.where(chosen, jnp.log10(arr + self.eps), arr)

    y_log = apply_log(y, self.log_all_y, self.y_selector)
    x_log = apply_log(x, self.log_all_x, self.x_selector)
    params_log = apply_log(params, self.log_all_params, self.params_selector)
    return y_log, x_log, params_log

standardise

Bases: NormalisationPipeline

Standardisation normalisation pipeline.

Source code in astroemu/normalisation.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class standardise(NormalisationPipeline):
    """Standardisation normalisation pipeline.

    ``forward`` maps enabled arrays to ``(v - mean) / std`` and ``backward``
    maps them back with ``v * std + mean``. Arrays whose switch is off pass
    through untouched.
    """

    def __init__(
        self,
        y_mean: jnp.ndarray,
        y_std: jnp.ndarray,
        x_mean: jnp.ndarray,
        x_std: jnp.ndarray,
        params_mean: jnp.ndarray,
        params_std: jnp.ndarray,
        standardise_y: bool = False,
        standardise_x: bool = False,
        standardise_params: bool = False,
    ) -> None:
        """Store the statistics and switches used for standardisation.

        Args:
            y_mean (jnp.ndarray): Mean of the spectrum.
            y_std (jnp.ndarray): Standard deviation of the spectrum.
            x_mean (jnp.ndarray | float): Mean of the independent variable.
            x_std (jnp.ndarray | float): Standard deviation of the
                independent variable.
            params_mean (jnp.ndarray): Mean of the input parameters.
            params_std (jnp.ndarray): Standard deviation of the input
                parameters.
            standardise_y (bool): Whether to standardise the spectrum.
                Defaults to False.
            standardise_x (bool): Whether to standardise the independent
                variable. Defaults to False.
            standardise_params (bool): Whether to standardise the input
                parameters. Defaults to False.

        Warns:
            UserWarning: If none of the three switches is enabled.
        """
        # Statistics used by forward/backward.
        self.y_mean, self.y_std = y_mean, y_std
        self.x_mean, self.x_std = x_mean, x_std
        self.params_mean, self.params_std = params_mean, params_std
        # Switches selecting which arrays are transformed.
        self.standardise_y = standardise_y
        self.standardise_x = standardise_x
        self.standardise_params = standardise_params

        enabled = (standardise_y, standardise_x, standardise_params)
        if True not in enabled:
            warnings.warn(
                "No standardisation applied. Consider setting at least one of "
                "standardise_y, standardise_x, or standardise_params to True."
            )

    def forward(
        self,
        y: jnp.ndarray,
        x: jnp.ndarray,
        params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Map enabled arrays to zero mean and unit standard deviation.

        Args:
            y (jnp.ndarray): Spectrum array, shape (batch, len_x).
            x (jnp.ndarray): Independent variable array, shape (batch, len_x).
            params (jnp.ndarray): Input parameters array, shape
                (batch, n_params).

        Returns:
            tuple: Standardised spectrum, independent variable, and parameters.
        """
        y_out = (y - self.y_mean) / self.y_std if self.standardise_y else y
        x_out = (x - self.x_mean) / self.x_std if self.standardise_x else x
        params_out = (
            (params - self.params_mean) / self.params_std
            if self.standardise_params
            else params
        )
        return y_out, x_out, params_out

    def backward(
        self,
        y: jnp.ndarray,
        x: jnp.ndarray,
        params: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Restore the original scale of enabled arrays.

        Args:
            y (jnp.ndarray): Standardised spectrum array, shape (batch, len_x).
            x (jnp.ndarray): Standardised independent variable, shape
                (batch, len_x).
            params (jnp.ndarray): Standardised input parameters, shape
                (batch, n_params).

        Returns:
            tuple: Destandardised spectrum, independent variable, and
                parameters.
        """
        y_out = y * self.y_std + self.y_mean if self.standardise_y else y
        x_out = x * self.x_std + self.x_mean if self.standardise_x else x
        params_out = (
            params * self.params_std + self.params_mean
            if self.standardise_params
            else params
        )
        return y_out, x_out, params_out

__init__(y_mean, y_std, x_mean, x_std, params_mean, params_std, standardise_y=False, standardise_x=False, standardise_params=False)

Standardises the spectrum, independent variable, and parameters.

Parameters:

Name Type Description Default
y_mean ndarray

Mean of the spectrum.

required
y_std ndarray

Standard deviation of the spectrum.

required
x_mean float

Mean of the independent variable.

required
x_std float

Standard deviation of the independent variable.

required
params_mean ndarray

Mean of the input parameters.

required
params_std ndarray

Standard deviation of the input parameters.

required
standardise_y bool

Whether to standardise the spectrum. Defaults to False.

False
standardise_x bool

Whether to standardise the independent variable. Defaults to False.

False
standardise_params bool

Whether to standardise the input parameters. Defaults to False.

False
Source code in astroemu/normalisation.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __init__(
    self,
    y_mean: jnp.ndarray,
    y_std: jnp.ndarray,
    x_mean: jnp.ndarray,
    x_std: jnp.ndarray,
    params_mean: jnp.ndarray,
    params_std: jnp.ndarray,
    standardise_y: bool = False,
    standardise_x: bool = False,
    standardise_params: bool = False,
) -> None:
    """Store the statistics and switches used for standardisation.

    Args:
        y_mean (jnp.ndarray): Mean of the spectrum.
        y_std (jnp.ndarray): Standard deviation of the spectrum.
        x_mean (jnp.ndarray | float): Mean of the independent variable.
        x_std (jnp.ndarray | float): Standard deviation of the independent
            variable.
        params_mean (jnp.ndarray): Mean of the input parameters.
        params_std (jnp.ndarray): Standard deviation of the input
            parameters.
        standardise_y (bool): Whether to standardise the spectrum.
            Defaults to False.
        standardise_x (bool): Whether to standardise the independent
            variable. Defaults to False.
        standardise_params (bool): Whether to standardise the input
            parameters. Defaults to False.

    Warns:
        UserWarning: If none of the three switches is enabled.
    """
    # Statistics used by forward/backward.
    self.y_mean, self.y_std = y_mean, y_std
    self.x_mean, self.x_std = x_mean, x_std
    self.params_mean, self.params_std = params_mean, params_std
    # Switches selecting which arrays are transformed.
    self.standardise_y = standardise_y
    self.standardise_x = standardise_x
    self.standardise_params = standardise_params

    enabled = (standardise_y, standardise_x, standardise_params)
    if True not in enabled:
        warnings.warn(
            "No standardisation applied. Consider setting at least one of "
            "standardise_y, standardise_x, or standardise_params to True."
        )

backward(y, x, params)

Destandardise spectrum, independent variable, and input parameters.

Parameters:

Name Type Description Default
y ndarray

Standardised spectrum array, shape (batch, len_x).

required
x ndarray

Standardised independent variable, shape (batch, len_x).

required
params ndarray

Standardised input parameters, shape (batch, n_params).

required

Returns:

Name Type Description
tuple tuple[ndarray, ndarray, ndarray]

Destandardised spectrum, independent variable, and parameters.

Source code in astroemu/normalisation.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def backward(
    self,
    y: jnp.ndarray,
    x: jnp.ndarray,
    params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Restore the original scale of enabled arrays via ``v * std + mean``.

    Arrays whose ``standardise_*`` switch is off are returned unchanged.

    Args:
        y (jnp.ndarray): Standardised spectrum array, shape (batch, len_x).
        x (jnp.ndarray): Standardised independent variable, shape
            (batch, len_x).
        params (jnp.ndarray): Standardised input parameters, shape
            (batch, n_params).

    Returns:
        tuple: Destandardised spectrum, independent variable, and
            parameters.
    """
    y_out = y * self.y_std + self.y_mean if self.standardise_y else y
    x_out = x * self.x_std + self.x_mean if self.standardise_x else x
    params_out = (
        params * self.params_std + self.params_mean
        if self.standardise_params
        else params
    )
    return y_out, x_out, params_out

forward(y, x, params)

Standardise spectrum, independent variable, and input parameters.

Parameters:

Name Type Description Default
y ndarray

Spectrum array, shape (batch, len_x).

required
x ndarray

Independent variable array, shape (batch, len_x).

required
params ndarray

Input parameters array, shape (batch, n_params).

required

Returns:

Name Type Description
tuple tuple[ndarray, ndarray, ndarray]

Standardised spectrum, independent variable, and parameters.

Source code in astroemu/normalisation.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def forward(
    self,
    y: jnp.ndarray,
    x: jnp.ndarray,
    params: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Map enabled arrays through ``(v - mean) / std``.

    Arrays whose ``standardise_*`` switch is off are returned unchanged.

    Args:
        y (jnp.ndarray): Spectrum array, shape (batch, len_x).
        x (jnp.ndarray): Independent variable array, shape (batch, len_x).
        params (jnp.ndarray): Input parameters array, shape
            (batch, n_params).

    Returns:
        tuple: Standardised spectrum, independent variable, and parameters.
    """
    y_out = (y - self.y_mean) / self.y_std if self.standardise_y else y
    x_out = (x - self.x_mean) / self.x_std if self.standardise_x else x
    params_out = (
        (params - self.params_mean) / self.params_std
        if self.standardise_params
        else params
    )
    return y_out, x_out, params_out

astroemu.train

Training loop for astroemu.

train(train_dataset, val_dataset, hidden_size, nlayers, act='relu', epochs=1000, patience=50, learning_rate=0.001, weight_decay=0.0001, batch_size=32, key=0, loss_fn=mse, loss_kwargs={})

Train an MLP emulator on spectral data using AdamW.

Initialises an MLP via initialise_mlp, then trains it using batches from train_dataset and val_dataset (both must have tiling=True). The per-batch train and validation steps are JIT-compiled. Training stops early if the validation loss does not improve for patience consecutive epochs, and the best parameters are returned.

Parameters:

Name Type Description Default
train_dataset SpectrumDataset

Training dataset with tiling=True.

required
val_dataset SpectrumDataset

Validation dataset with tiling=True.

required
hidden_size int

Number of nodes in each hidden layer.

required
nlayers int

Number of hidden layers.

required
act str

Activation function name from jax.nn. Defaults to "relu".

'relu'
epochs int

Maximum number of training epochs. Defaults to 1000.

1000
patience int

Early stopping patience in epochs. Defaults to 50.

50
learning_rate float

AdamW learning rate. Defaults to 1e-3.

0.001
weight_decay float

AdamW weight decay. Defaults to 1e-4.

0.0001
batch_size int

Number of spectra per batch. Defaults to 32.

32
key int

Integer seed for JAX PRNG. Defaults to 0.

0
loss_fn Callable

Loss function to use. Defaults to mse.

mse
loss_kwargs dict

Additional keyword arguments for the loss function. Defaults to an empty dict.

{}

Returns:

Type Description
tuple[dict, list[float], list[float]]

tuple[dict, list[float], list[float]]: Best network parameters, per-epoch training losses, and per-epoch validation losses.

Source code in astroemu/train.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def train(
    train_dataset: SpectrumDataset,
    val_dataset: SpectrumDataset,
    hidden_size: int,
    nlayers: int,
    act: str = "relu",
    epochs: int = 1000,
    patience: int = 50,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    batch_size: int = 32,
    key: int = 0,
    loss_fn: Callable = mse,
    loss_kwargs: dict | None = None,
) -> tuple[dict, list[float], list[float]]:
    """Train an MLP emulator on spectral data using AdamW.

    Initialises an MLP via initialise_mlp, then trains it using batches
    from train_dataset and val_dataset (both must have tiling=True).
    The per-batch train and validation steps are JIT-compiled. Training
    stops early if the validation loss does not improve for patience
    consecutive epochs, and the best parameters are returned.

    Args:
        train_dataset (SpectrumDataset): Training dataset with tiling=True.
        val_dataset (SpectrumDataset): Validation dataset with tiling=True.
        hidden_size (int): Number of nodes in each hidden layer.
        nlayers (int): Number of hidden layers.
        act (str): Activation function name from jax.nn. Defaults to
            "relu".
        epochs (int): Maximum number of training epochs. Defaults to 1000.
        patience (int): Early stopping patience in epochs. Defaults to 50.
        learning_rate (float): AdamW learning rate. Defaults to 1e-3.
        weight_decay (float): AdamW weight decay. Defaults to 1e-4.
        batch_size (int): Number of spectra per batch. Defaults to 32.
        key (int): Integer seed for JAX PRNG. Defaults to 0.
        loss_fn (Callable): Loss function to use. Defaults to mse.
        loss_kwargs (dict | None): Additional keyword arguments for the loss
            function. Defaults to None, meaning no extra arguments.

    Returns:
        tuple[dict, list[float], list[float]]: Best network parameters,
            per-epoch training losses, and per-epoch validation losses.

    Raises:
        ValueError: If the training or validation dataset yields no
            batches in an epoch (the epoch-mean loss would be undefined).
    """
    # Build a fresh dict per call instead of sharing a mutable default.
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    # Infer input size from the dataset: n_params + 1 (x is prepended
    # to parameters in tiling mode, so in_size = n_params + 1).
    _, _, params_sample = train_dataset[0]
    in_size = int(params_sample.shape[0]) + 1
    out_size = 1

    # Initialise network parameters.
    rng_key = jax.random.PRNGKey(key)
    rng_key, init_key = jax.random.split(rng_key)
    params = initialise_mlp(in_size, out_size, hidden_size, nlayers, init_key)

    # Initialise AdamW optimizer.
    optimizer = optax.adamw(learning_rate, weight_decay=weight_decay)
    opt_state = optimizer.init(params)

    # JIT-compiled training step. act is a closure variable and is
    # therefore treated as a compile-time constant by JAX.
    @jax.jit
    def train_step(
        params: dict,
        opt_state: optax.OptState,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
    ) -> tuple[dict, optax.OptState, jnp.ndarray]:
        """Compute loss and gradients, then apply an AdamW update.

        Args:
            params (dict): Current network parameters.
            opt_state (optax.OptState): Current optimizer state.
            inputs (jnp.ndarray): Tiled input array of shape
                (batch * len_x, n_params + 1).
            targets (jnp.ndarray): Target values of shape (batch * len_x,).

        Returns:
            tuple[dict, optax.OptState, jnp.ndarray]: Updated parameters,
                updated optimizer state, and scalar batch loss.
        """

        def evaluate_loss_fn(p: dict) -> jnp.ndarray:
            preds = mlp(p, inputs, act)
            return loss_fn(preds.squeeze(-1), targets, **loss_kwargs)

        loss, grads = jax.value_and_grad(evaluate_loss_fn)(params)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss

    # JIT-compiled validation step.
    @jax.jit
    def val_step(
        params: dict,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
    ) -> jnp.ndarray:
        """Compute validation loss without updating parameters.

        Args:
            params (dict): Current network parameters.
            inputs (jnp.ndarray): Tiled input array of shape
                (batch * len_x, n_params + 1).
            targets (jnp.ndarray): Target values of shape (batch * len_x,).

        Returns:
            jnp.ndarray: Scalar batch loss.
        """
        preds = mlp(params, inputs, act)
        return loss_fn(preds.squeeze(-1), targets, **loss_kwargs)

    train_losses: list[float] = []
    val_losses: list[float] = []
    best_val_loss = jnp.inf
    best_params = params
    patience_counter = 0

    pbar = tqdm(range(epochs), desc="Training epochs")

    for epoch in pbar:
        rng_key, train_key = jax.random.split(rng_key)

        # Training pass.
        epoch_train_loss = 0.0
        n_train_batches = 0
        for targets, inputs in train_dataset.get_batch_iterator(
            batch_size, shuffle=True, key=train_key
        ):
            params, opt_state, batch_loss = train_step(
                params, opt_state, inputs, targets
            )
            epoch_train_loss += batch_loss
            n_train_batches += 1
        if n_train_batches == 0:
            raise ValueError(
                "train_dataset produced no batches; check the dataset size "
                "and batch_size."
            )
        epoch_train_loss /= n_train_batches

        # Validation pass.
        epoch_val_loss = 0.0
        n_val_batches = 0
        for targets, inputs in val_dataset.get_batch_iterator(
            batch_size, shuffle=False
        ):
            epoch_val_loss += val_step(params, inputs, targets)
            n_val_batches += 1
        if n_val_batches == 0:
            raise ValueError(
                "val_dataset produced no batches; check the dataset size "
                "and batch_size."
            )
        epoch_val_loss /= n_val_batches

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)

        # Early stopping: save best params and reset counter on improvement.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_params = params
            patience_counter = 0
        else:
            patience_counter += 1

        # Refresh the bar before a possible break so the final epoch's
        # losses are shown as well.
        pbar.set_postfix(
            {
                "train_loss": f"{epoch_train_loss:.4f}",
                "val_loss": f"{epoch_val_loss:.4f}",
                "best_val_loss": f"{best_val_loss:.4f}",
            }
        )

        if patience_counter >= patience:
            # tqdm's write() prints without corrupting the active bar.
            pbar.write(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val loss: {float(best_val_loss):.4f})."
            )
            break

    train_losses = [float(loss) for loss in train_losses]
    val_losses = [float(loss) for loss in val_losses]
    return best_params, train_losses, val_losses

astroemu.losses

Loss functions for astroemu.

kl(predictions, targets, noise, ndata=1)

Kullback-Leibler divergence loss.

From https://ui.adsabs.harvard.edu/abs/2025MNRAS.544..375B/abstract.

Parameters:

Name Type Description Default
predictions ndarray

Predicted probability distributions.

required
targets ndarray

Target probability distributions.

required
noise ndarray

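An estimate of the noise level in the data. The RMS error between predictions and targets is divided by this value, so it should be in the same units as the residuals.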

required
ndata int

Number of data points. Defaults to 1.

1

Returns:

Type Description
ndarray

jnp.ndarray: Scalar KL divergence loss.

Source code in astroemu/losses.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@jax.jit
def kl(
    predictions: jnp.ndarray, targets: jnp.ndarray, noise: jnp.ndarray,
    ndata: int = 1
) -> jnp.ndarray:
    """Kullback-Leibler divergence loss.

    From https://ui.adsabs.harvard.edu/abs/2025MNRAS.544..375B/abstract.

    The root-mean-square residual is scaled by ``noise``, squared, and
    multiplied by ``ndata / 2``.

    Args:
        predictions (jnp.ndarray): Predicted values.
        targets (jnp.ndarray): Target values.
        noise (jnp.ndarray): Noise estimate that the RMS residual is
            divided by.
        ndata (int): Number of data points. Defaults to 1.

    Returns:
        jnp.ndarray: Scalar KL divergence loss.
    """
    residual = predictions - targets
    rmse = jnp.sqrt(jnp.mean(residual ** 2))
    signal_to_noise = rmse / noise
    return ndata / 2 * signal_to_noise ** 2

mse(predictions, targets)

Mean squared error loss.

Parameters:

Name Type Description Default
predictions ndarray

Predicted values.

required
targets ndarray

Target values.

required

Returns:

Type Description
ndarray

jnp.ndarray: Scalar MSE loss.

Source code in astroemu/losses.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
@jax.jit
def mse(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error loss.

    Averages the squared element-wise difference over all elements.

    Args:
        predictions (jnp.ndarray): Predicted values.
        targets (jnp.ndarray): Target values.

    Returns:
        jnp.ndarray: Scalar MSE loss.
    """
    diff = predictions - targets
    return (diff ** 2).mean()

astroemu.serialisation

Serialisation utilities for saving and loading trained emulators.

load(path)

Load a trained emulator from a .astroemu file.

For each saved dataset, this function attempts to reconstruct a SpectrumDataset using the saved file paths and pipeline. If any files are missing, the raw config dict is returned under the dataset key instead.

Parameters:

Name Type Description Default
path str

Path to a .astroemu file.

required

Returns:

Name Type Description
dict dict

Dictionary with the following keys:

  • params (dict): Network weight arrays as jnp.ndarrays.
  • hyperparams (dict): Architecture and training hyperparameters.
  • train_losses (list[float]): Per-epoch training losses.
  • val_losses (list[float]): Per-epoch validation losses.
  • loss (str): Name of the loss criterion used.
  • version (str): astroemu version the emulator was trained with.
  • train_pipeline (list): Normalisation pipeline instances used for the training dataset.
  • val_pipeline (list): Normalisation pipeline instances used for the validation dataset.
  • test_pipeline (list): Normalisation pipeline instances used for the test dataset.
  • train_dataset (SpectrumDataset | dict): Reconstructed training dataset if all files are found, otherwise the raw config dict.
  • val_dataset (SpectrumDataset | dict): Reconstructed validation dataset if all files are found, otherwise the raw config dict.
  • test_dataset (SpectrumDataset | dict): Reconstructed test dataset if all files are found, otherwise the raw config dict.
Source code in astroemu/serialisation.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def load(path: str) -> dict:
    """Restore an emulator previously written by ``save`` from a .astroemu archive.

    The archive's ``config.json``, ``params.npz`` and ``pipeline.pkl`` members
    are read back. For every dataset split recorded in the config, a
    SpectrumDataset is rebuilt from the stored file paths and pipeline when
    all of those files exist on this machine. Otherwise the raw config dict
    for that split is returned in its place.

    Warning:
        ``pipeline.pkl`` is deserialised with :func:`pickle.loads`, which can
        execute arbitrary code. Only load .astroemu files from trusted sources.

    Args:
        path (str): Path to a .astroemu file.

    Returns:
        dict: Dictionary with the following keys:

            - **params** (dict): Network weight arrays as jnp.ndarrays.
            - **hyperparams** (dict): Architecture and training
              hyperparameters.
            - **train_losses** (list[float]): Per-epoch training losses.
            - **val_losses** (list[float]): Per-epoch validation losses.
            - **loss** (str): Name of the loss criterion used.
            - **version** (str): astroemu version the emulator was trained
              with.
            - **train_pipeline** (list): Normalisation pipeline instances
              used for the training dataset.
            - **val_pipeline** (list): Normalisation pipeline instances
              used for the validation dataset.
            - **test_pipeline** (list): Normalisation pipeline instances
              used for the test dataset.
            - **train_dataset** (SpectrumDataset | dict): Reconstructed
              training dataset if all files are found, otherwise the raw
              config dict.
            - **val_dataset** (SpectrumDataset | dict): Reconstructed
              validation dataset if all files are found, otherwise the raw
              config dict.
            - **test_dataset** (SpectrumDataset | dict): Reconstructed
              test dataset if all files are found, otherwise the raw config
              dict.

            The ``*_dataset`` keys are present only for splits recorded in
            the archive's config.
    """
    splits = ("train", "val", "test")

    with zipfile.ZipFile(path, "r") as archive:
        config = json.loads(archive.read("config.json"))

        # Copy every weight array out before the in-memory npz is closed.
        with np.load(io.BytesIO(archive.read("params.npz"))) as npz:
            params = {name: jnp.array(npz[name]) for name in npz.files}

        # SECURITY: unpickling runs code embedded in the file (see docstring).
        pipelines = pickle.loads(archive.read("pipeline.pkl"))

    result: dict = {
        "params": params,
        **{
            field: config[field]
            for field in (
                "hyperparams",
                "train_losses",
                "val_losses",
                "loss",
                "version",
            )
        },
        **{f"{split}_pipeline": pipelines[split] for split in splits},
    }

    for split in splits:
        key = f"{split}_dataset"
        if key not in config:
            continue

        split_cfg = config[key]
        files = split_cfg["files"]

        # Any file absent on this machine -> hand back the raw config instead.
        if not all(Path(f).exists() for f in files):
            result[key] = split_cfg
            continue

        result[key] = SpectrumDataset(
            files=files,
            x=split_cfg["x"],
            y=split_cfg["y"],
            forward_pipeline=pipelines[split] or None,
            variable_input=split_cfg["variable_input"],
            tiling=split_cfg["tiling"],
            allow_pickle=split_cfg["allow_pickle"],
        )

    return result

save(path, params, train_losses, val_losses, hidden_size, nlayers, loss, train_dataset, val_dataset, test_dataset, act='relu', epochs=1000, patience=50, learning_rate=0.001, weight_decay=0.0001)

Save a trained emulator to a .astroemu file.

The .astroemu extension is appended automatically if not present. The file is a zip archive containing:

  • config.json: hyperparameters, training history, loss criterion, code version, and dataset configurations.
  • params.npz: network weight arrays.
  • pipeline.pkl: pickled normalisation pipeline instances for the training, validation, and test datasets.

Parameters:

Name Type Description Default
path str

Destination path. The .astroemu extension is appended if not already present, e.g. "emulator" → "emulator.astroemu".

required
params dict

Trained network parameters returned by train().

required
train_losses list[float]

Per-epoch training losses.

required
val_losses list[float]

Per-epoch validation losses.

required
hidden_size int

Number of nodes in each hidden layer.

required
nlayers int

Number of hidden layers.

required
act str

Activation function name used during training.

'relu'
epochs int

Max epochs used during training.

1000
patience int

Early stopping patience used during training.

50
learning_rate float

AdamW learning rate used during training.

0.001
weight_decay float

AdamW weight decay used during training.

0.0001
loss str

Name of the loss criterion used.

required
train_dataset SpectrumDataset

Training dataset whose file paths and pipeline are saved.

required
val_dataset SpectrumDataset

Validation dataset whose file paths and pipeline are saved.

required
test_dataset SpectrumDataset

Test dataset whose file paths and pipeline are saved.

required
Source code in astroemu/serialisation.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def save(
    path: str,
    params: dict,
    train_losses: list[float],
    val_losses: list[float],
    hidden_size: int,
    nlayers: int,
    loss: str,
    train_dataset: SpectrumDataset,
    val_dataset: SpectrumDataset,
    test_dataset: SpectrumDataset,
    act: str = "relu",
    epochs: int = 1000,
    patience: int = 50,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
) -> None:
    """Save a trained emulator to a .astroemu file.

    The .astroemu extension is appended automatically if not present.
    The file is a zip archive containing:
      - config.json  : hyperparameters, training history, loss criterion,
                       code version, and dataset configurations.
      - params.npz   : network weight arrays.
      - pipeline.pkl : pickled normalisation pipeline instances for the
                       training, validation, and test datasets.

    All three archive members are serialised in memory before the
    destination file is opened. A serialisation failure (for example a value
    ``json`` cannot encode, or an unpicklable pipeline) therefore raises
    without truncating an existing file at ``path``.

    Args:
        path (str): Destination path. ``pathlib.Path`` objects are also
            accepted. The .astroemu extension is appended if not already
            present, e.g. "emulator" → "emulator.astroemu".
        params (dict): Trained network parameters returned by train().
        train_losses (list[float]): Per-epoch training losses. Each entry
            is passed through ``float()``, so NumPy/JAX scalars are accepted.
        val_losses (list[float]): Per-epoch validation losses, converted
            the same way as ``train_losses``.
        hidden_size (int): Number of nodes in each hidden layer.
        nlayers (int): Number of hidden layers.
        loss (str): Name of the loss criterion used.
        train_dataset (SpectrumDataset): Training dataset whose file paths
            and pipeline are saved.
        val_dataset (SpectrumDataset): Validation dataset whose file paths
            and pipeline are saved.
        test_dataset (SpectrumDataset): Test dataset whose file paths
            and pipeline are saved.
        act (str): Activation function name used during training.
        epochs (int): Max epochs used during training.
        patience (int): Early stopping patience used during training.
        learning_rate (float): AdamW learning rate used during training.
        weight_decay (float): AdamW weight decay used during training.
    """
    # Accept path-like objects (e.g. pathlib.Path) as well as plain strings.
    path = str(path)
    if not path.endswith(".astroemu"):
        path = path + ".astroemu"

    def _dataset_config(ds: SpectrumDataset) -> dict:
        # Keys mirror the SpectrumDataset keyword arguments that load() uses
        # to rebuild the dataset.
        # NOTE(review): the attribute read is `varied_input`, but the stored
        # key / constructor kwarg is `variable_input`. Confirm the attribute
        # name on SpectrumDataset.
        return {
            "files": ds.files,
            "x": ds.x,
            "y": ds.y,
            "variable_input": ds.varied_input,
            "tiling": ds.tiling,
            "allow_pickle": ds.allow_pickle,
        }

    config: dict = {
        "version": __version__,
        "hyperparams": {
            "hidden_size": hidden_size,
            "nlayers": nlayers,
            "act": act,
            "epochs": epochs,
            "patience": patience,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
        },
        "loss": loss,
        # Losses from a JAX/NumPy training loop are often 0-d arrays or NumPy
        # scalars, which json cannot encode. Store them as plain floats.
        "train_losses": [float(v) for v in train_losses],
        "val_losses": [float(v) for v in val_losses],
        "train_dataset": _dataset_config(train_dataset),
        "val_dataset": _dataset_config(val_dataset),
        "test_dataset": _dataset_config(test_dataset),
    }

    pipelines = {
        "train": train_dataset.forward_pipeline,
        "val": val_dataset.forward_pipeline,
        "test": test_dataset.forward_pipeline,
    }

    # Serialise every member up front. Opening a ZipFile in "w" mode truncates
    # the destination immediately, so failing mid-write would leave a broken
    # (and possibly previously valid) archive behind.
    config_json = json.dumps(config, indent=2)
    params_buf = io.BytesIO()
    np.savez(params_buf, **{k: np.array(v) for k, v in params.items()})
    pipeline_bytes = pickle.dumps(pipelines)

    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("config.json", config_json)
        zf.writestr("params.npz", params_buf.getvalue())
        zf.writestr("pipeline.pkl", pipeline_bytes)

astroemu.utils

Utility functions for the emu package.

compute_mean_std(loader)

Memory-safe computation of the mean and standard deviation.

Expects the loader to yield three-tuples (spec, x, inputs) as produced by SpectrumDataset.get_batch_iterator() with tiling=False.

Since x (the independent variable) is identical for every sample in the dataset, its mean and std are computed as global scalars over all elements rather than per-column statistics.

Parameters:

Name Type Description Default
loader Iterator[tuple[ndarray, ndarray, ndarray]]

Iterable yielding (spec, x, inputs) where:

  • spec: (batch_size, len_x)
  • x: (batch_size, len_x)
  • inputs: (batch_size, n_params)

required

Returns:

Name Type Description
mean_spec ndarray

(len_x,) - per-frequency mean across batches

std_spec ndarray

(len_x,) - per-frequency std across batches

mean_x ndarray

scalar Array - global mean of the independent variable

std_x ndarray

scalar Array - global std of the independent variable

mean_input ndarray

(n_params,) - per-parameter mean across batches

std_input ndarray

(n_params,) - per-parameter std across batches

Source code in astroemu/utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def _merge_batch_moments(state: tuple, batch: jnp.ndarray) -> tuple:
    """Fold one batch into a running ``(count, mean, m2)`` per-column state.

    Each batch's mean and sum of squared deviations (``m2``) are taken about
    the batch's own mean and merged with the pairwise update of Chan, Golub &
    LeVeque. This avoids the catastrophic cancellation of ``E[x^2] - E[x]^2``
    when values are large relative to their spread. Reduction is over axis 0.
    An empty batch leaves the state unchanged.
    """
    count, mean, m2 = state
    n_batch = batch.shape[0]
    if n_batch == 0:
        return state
    batch_mean = batch.mean(axis=0)
    batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
    if count == 0:
        return n_batch, batch_mean, batch_m2
    total = count + n_batch
    delta = batch_mean - mean
    mean = mean + delta * (n_batch / total)
    m2 = m2 + batch_m2 + delta**2 * (count * n_batch / total)
    return total, mean, m2


def _floored_std(count: int, m2: jnp.ndarray) -> jnp.ndarray:
    """Population std from ``m2``, with the variance floored at 1e-6."""
    # The floor makes std >= 1e-3, so later divisions by std never hit zero
    # for constant columns.
    return jnp.sqrt(jnp.maximum(m2 / count, 1e-6))


def compute_mean_std(
    loader: Iterator[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]],
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
]:
    """Memory-safe, streaming mean and std computation.

    Expects the loader to yield three-tuples (spec, x, inputs) as produced
    by SpectrumDataset.get_batch_iterator() with tiling=False. Only one batch
    is held at a time. Per-column statistics are merged batch by batch in a
    numerically stable way.

    Since x (the independent variable) is identical for every sample in the
    dataset, its mean and std are computed as global scalars over one row of
    the last non-empty batch rather than as per-column statistics.

    Standard deviations for spec and inputs are population stds, floored at
    1e-3 (variance floored at 1e-6) so normalising by them never divides by
    zero.

    Args:
        loader: Iterable yielding (spec, x, inputs) where:
            - spec:   (batch_size, len_x)
            - x:      (batch_size, len_x)
            - inputs: (batch_size, n_params)

    Returns:
        mean_spec:   (len_x,)  - per-frequency mean across batches
        std_spec:    (len_x,)  - per-frequency std across batches
        mean_x:      scalar Array - global mean of the independent variable
        std_x:       scalar Array - global std of the independent variable
        mean_input:  (n_params,) - per-parameter mean across batches
        std_input:   (n_params,) - per-parameter std across batches

    Raises:
        ValueError: If the loader yields no samples.
    """
    spec_state = (0, None, None)
    input_state = (0, None, None)
    last_x = None

    for spec, x, input_data in loader:
        spec_state = _merge_batch_moments(spec_state, spec)
        input_state = _merge_batch_moments(input_state, input_data)
        if x.shape[0] > 0:
            last_x = x

    n_spec, mean_spec, m2_spec = spec_state
    n_input, mean_input, m2_input = input_state
    if n_spec == 0 or n_input == 0 or last_x is None:
        raise ValueError(
            "compute_mean_std() received a loader that yielded no samples"
        )

    std_spec = _floored_std(n_spec, m2_spec)
    std_input = _floored_std(n_input, m2_input)

    # x is the same for every sample, so one row suffices for its global stats.
    x_row = last_x[0, :]
    mean_x = x_row.mean()
    std_x = x_row.std()

    return mean_spec, std_spec, mean_x, std_x, mean_input, std_input