Improved model preprocessing, new batch structure #471

Open

jlamypoirier wants to merge 14 commits into jlp_simplify_mtp from jlp_batch

Conversation


@jlamypoirier jlamypoirier commented Feb 17, 2026

✨ Description

Highlights:

  • Move (most of) model preprocessing to the data loader. This makes it not only simpler, but also faster since it runs in parallel processes.
  • Add a ModelInput structure to handle this preprocessing, and potentially replace the arbitrary kwargs we pass to the layers.
  • Drop the artificial concept of samples (mostly). We now pack documents directly into batches and always use the varlen implementation of mixers (no cross-document attention). The batch and sequence dimensions are merged into a single token dimension.
  • Tons of simplifications, tweaks and fixes. See details below.
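The packing idea from the highlights can be sketched in a few lines. This is an illustrative reconstruction, not the actual Fast-LLM code; the names (`pack_documents`, `cu_seqlens`) are assumptions modeled on the common varlen convention.

```python
# Illustrative sketch only: documents are packed directly into a single
# token stream (no batch dimension), and varlen mixers receive cumulative
# sequence lengths marking the document boundaries.
import torch

def pack_documents(documents):
    """Concatenate documents into one 1D token tensor plus cu_seqlens."""
    tokens = torch.cat(documents)  # (total_tokens,): batch and sequence merged
    lengths = torch.tensor([0] + [len(d) for d in documents])
    cu_seqlens = torch.cumsum(lengths, dim=0)  # document boundaries
    return tokens, cu_seqlens

docs = [torch.arange(3), torch.arange(5), torch.arange(2)]
tokens, cu_seqlens = pack_documents(docs)
# cu_seqlens.tolist() == [0, 3, 8, 10]
```

With the boundaries available, a varlen mixer can attend within each document without any cross-document attention mask.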

Data changes

Data

  • Config: remove sampling field, inherit from sampling config instead. Add micro_batch_size, maximum_document_length and truncate_documents fields (moved from batch config).
  • Replace the setup method with a fine-grained sample_dataset that prepares one dataset at a time.

Batch configuration

  • Remove BatchConfig altogether. Micro-batches are now defined in the data config, while sequential micro-batches and other scheduling options (e.g. micro-sequences) are defined in the schedule config.
  • Note: all batch-related quantities are now given in tokens. It is no longer possible to set the actual batch size, which is now inferred from the micro-batch size and count.
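The inference of the batch size is simple arithmetic; the variable names below are assumptions for illustration, not the actual config fields.

```python
# Illustrative arithmetic only: with all batch quantities given in tokens,
# the effective batch size is derived from the micro-batch settings rather
# than set directly.
micro_batch_size = 4096   # tokens per micro-batch
micro_batch_count = 8     # sequential micro-batches per optimizer step
batch_size = micro_batch_size * micro_batch_count  # inferred, in tokens
```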

Datasets

  • Replace SamplingData structure with a simpler SamplingConfig. Remove SamplingParameters, now directly part of the sampling/data configuration (sequence_length, truncate_documents) or provided separately to dataset.sample (num_samples, seed).
  • Remove SampledDatasetUpdateConfig, SamplingConfig.update_config, no longer necessary since num_samples and seed are no longer in the config.
  • Datasets now return documents instead of samples. Sampled datasets return lists of documents, concatenation is now performed in collate_fn.
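The new division of labor can be sketched as follows (hypothetical names, not the actual implementation): sampled datasets yield lists of documents, and concatenation only happens in the collate function.

```python
# Hypothetical sketch: each dataset item is a list of documents (each a
# list of token ids); the collate function flattens them into one packed
# token stream.
def collate_fn(samples):
    """Flatten per-sample document lists into one packed token list."""
    return [token for documents in samples for doc in documents for token in doc]

batch = [[[1, 2], [3, 4, 5]], [[6]]]
packed = collate_fn(batch)  # -> [1, 2, 3, 4, 5, 6]
```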

Memmap dataset

  • Move memmap to a new directory together with readers and writers.
  • Remove image normalization from memmap reader. (moved to model preprocessing)
  • Add reader configs (now in a separate config file) to data.auto so dynamic reader classes work out of the box.
  • Memmap readers now read all dataset content whether it's needed or not, and no longer know about model preprocessing. The justification is that spans are cheap to read, and it doesn't make much sense to ignore image patches in a dataset because of the padding tokens.

Document, ModelInput, Batch

  • New structures replace and expand on the previous Sample, Batch structure.
  • Document: Represent a single, unprocessed document. This is the only structure datasets deal with.
  • Batch: Intermediate state representing a concatenation of documents. Its main purpose is to handle model preprocessing.
  • ModelInput: A preprocessed batch ready to be fed to the model. A batch may lead to more than one ModelInput (with micro-sequences).
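The three-stage pipeline can be sketched with small dataclasses. The field names and the micro-sequence splitting below are assumptions for illustration, not the actual Fast-LLM definitions.

```python
# Hypothetical sketch of the Document -> Batch -> ModelInput pipeline.
from dataclasses import dataclass

@dataclass
class Document:
    """A single, unprocessed document: the only structure datasets see."""
    tokens: list

@dataclass
class ModelInput:
    """A preprocessed chunk, ready to be fed to the model."""
    tokens: list

@dataclass
class Batch:
    """A concatenation of documents; its main job is model preprocessing."""
    documents: list

    def get_model_inputs(self, micro_sequence_length):
        tokens = [t for d in self.documents for t in d.tokens]
        # With micro-sequences, one Batch may yield several ModelInputs.
        return [
            ModelInput(tokens[i:i + micro_sequence_length])
            for i in range(0, len(tokens), micro_sequence_length)
        ]

batch = Batch([Document([1, 2, 3]), Document([4, 5])])
inputs = batch.get_model_inputs(micro_sequence_length=3)
# -> two ModelInputs: tokens [1, 2, 3] and [4, 5]
```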

Preprocessing

  • Model preprocessing is now configured through a BatchPreprocessingConfig, provided by the model (in data.sample_dataset).
  • Model preprocessing is done in Batch.get_model_inputs, called in data._collate_fn.

Preparation, offline preprocessing

  • Move offline preprocessing configs (tokenizer, image patch) to the (renamed) preparation directory, and make them completely distinct from other preprocessing configs.
  • Simplify the dataset discovery preparator quite a bit.

Other data changes

  • Fix some MTP issues in SampledIndexedDataset (1 -> num_heads in a few places)

Modeling changes

Preprocessing

  • Remove nearly all preprocessing, now handled in the data. Notable exceptions are the backup attention mask (hard to generalize), rotary (model-dependent, possibility of incompatible configs) and stochastic mixer.
  • Add get_preprocessing_config to blocks/layers, following the same pattern as preprocess (and meant to eventually replace it completely), but instead returning a dict of arguments for the BatchPreprocessingConfig so the data loader knows what preprocessing to do.
  • Remove preprocessed_meta. preprocess_batch now handles meta inputs directly for making the schedule.
  • model.preprocess_batch now mainly consists of a model_input.to_kwargs() call, plus some handling of reference models.
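The to_kwargs pattern mentioned above can be illustrated as follows; the field names are assumptions, not the actual signature.

```python
# Hedged illustration: the preprocessed input flattens itself into the
# keyword arguments passed to the layers, replacing ad-hoc in-model
# preprocessing.
from dataclasses import dataclass, asdict

@dataclass
class ModelInput:
    token_ids: list
    cu_seqlens: list
    num_tokens: int

    def to_kwargs(self):
        # One dict of layer kwargs built from the preprocessed fields.
        return asdict(self)

kwargs = ModelInput(token_ids=[5, 6, 7], cu_seqlens=[0, 3], num_tokens=3).to_kwargs()
# kwargs["num_tokens"] == 3
```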

Training

  • Add trainer._get_completion_metrics (replaces TrainingProgress), schedule._get_compute_metrics to handle common logging between trainer and evaluator.
  • Adjust Step Schedule, ScheduleRunner.
  • Remove consumed_samples which no longer makes sense.
  • Remove the test phase which was both broken and made unnecessary by the addition of evaluators. Remove test_iters, PhaseType.test.
  • Replace PhaseType values with lower case to simplify things (may affect logging/wandb a bit).

Evaluation

  • Greatly simplify evaluators, removing redundant wrappers. There is only one structure left, the Evaluator itself.
  • Remove TrainingEvaluatorConfig. The trainer now defines a dict of EvaluatorConfig, which inherit from Interval config.
  • Remove get_sampling_parameters, evaluators may now call data.sample_dataset directly instead.
  • Tentatively fix the LM eval evaluator.

Attention

  • Remove cross_document_attention, always use the varlen implementation.
  • Remove the batch dimension completely.
  • Move the rotary embedding computation to right after the query and key_value computations. This improves efficiency (removes redundant computation for past key) and fixes an important bug where the in-place rotary computation could wrongly modify the present, breaking training with micro-sequences. (Was previously unknown because of an implicit copy in a contiguous call)
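The aliasing hazard behind the rotary fix can be distilled to a few lines of tensor code. This is a schematic illustration, not the attention implementation itself.

```python
# Schematic illustration of the in-place hazard: applying "rotary" in place
# to a key tensor that still aliases the cached present corrupts the cache.
import torch

present = torch.ones(4)           # stands in for the cached key/value state
key = present[:2]                 # a view into the cache, not a copy
key.mul_(2.0)                     # in-place op silently rewrites the cache
assert present[0].item() == 2.0   # cache corrupted

present = torch.ones(4)
key = present[:2].clone()         # an explicit copy keeps the cache intact
key.mul_(2.0)                     # (the old code was only saved by an
assert present[0].item() == 1.0   #  implicit copy in a contiguous() call)
```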

SSM

  • Simplify the gdn and kda implementations, mostly by merging redundant reshapes. (Moved to Simplify MTP #470.)
  • Add varlen support to backup torch implementation of CausalConv1d (runs in a loop). This allows running gdn tests on cpu.
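The looped fallback can be sketched roughly as follows; the names and shapes are illustrative assumptions, not the actual CausalConv1d code. Each packed sequence, delimited by cu_seqlens, is convolved independently, so no state leaks across document boundaries.

```python
# Hedged sketch of a looped varlen causal depthwise conv1d fallback.
import torch
import torch.nn.functional as F

def causal_conv1d_varlen(x, weight, cu_seqlens):
    """x: (total_tokens, channels); weight: (channels, kernel_size)."""
    channels, kernel_size = weight.shape
    out = torch.empty_like(x)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        seq = x[start:end].T.unsqueeze(0)        # (1, channels, seq_len)
        seq = F.pad(seq, (kernel_size - 1, 0))   # left-pad -> causal
        y = F.conv1d(seq, weight.unsqueeze(1), groups=channels)  # depthwise
        out[start:end] = y.squeeze(0).T
    return out

x = torch.randn(10, 4)
w = torch.randn(4, 3)
cu = torch.tensor([0, 6, 10])
y = causal_conv1d_varlen(x, w, cu)  # same shape as x, no cross-sequence mixing
```

At the first position of each sequence the padded history is all zeros, so the output depends only on that token, which is exactly the no-leakage property the loop is meant to preserve.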

Language model

  • Remove explicit padding mask, as padding is now entirely defined by num_tokens kwarg.
  • Remove mask_inputs kwarg, embedding layer can figure it out based on num_tokens and patch embeddings.
  • Standardize prediction_distance to start at 1.
  • Change labels and loss_mask kwarg format to a list, one tensor per prediction head. Add masking for cross document predictions (hence the need for separate tensors).
  • Drop support for DPO at least for now, because it caused trouble and hasn't been working for a long time anyway (if ever).
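The per-head label format can be illustrated with a small sketch (hypothetical helper, not the actual code): head k predicts the token k positions ahead, and positions whose target falls in a different document are masked out, which is why each head needs its own label and mask tensors.

```python
# Illustrative sketch of per-head labels with cross-document masking.
import torch

def build_labels(tokens, doc_ids, num_heads):
    labels, masks = [], []
    for distance in range(1, num_heads + 1):  # prediction_distance starts at 1
        shifted = torch.roll(tokens, -distance)
        valid = torch.zeros_like(tokens, dtype=torch.bool)
        # A prediction is valid only if the target is in the same document.
        valid[:-distance] = doc_ids[:-distance] == doc_ids[distance:]
        labels.append(shifted)
        masks.append(valid)
    return labels, masks

tokens = torch.tensor([1, 2, 3, 4, 5])
doc_ids = torch.tensor([0, 0, 0, 1, 1])   # two packed documents
labels, masks = build_labels(tokens, doc_ids, num_heads=2)
# head 1 masks the last token of each document; head 2 masks the last two.
```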

Debugging/QOL features

  • Add force_cpu_initialization to the distributed config, which performs the initialization on CPU and allows initializations to match across distributed configurations.
  • Add full_tensors to TensorLogsConfig which allows saving/printing full tensors in tensor logs.

Tweaks

  • Adjust huggingface wrappers for the new data structures.

Testing

TODO

External models

  • Make Apriel 2 model return hidden states and vision hidden states, as needed for testing.

(outdated)

Future steps:

  • Replace kwargs in layers with the MicroBatch structure.
  • Remove the preprocess method in base models and layers. (Still needed for rotary embeddings and stochastic mixer)
  • Clarify preprocess_batch, which now takes an already preprocessed batch as input. Seems to be mostly about running reference models now.
  • Expand the data tests to cover model preprocessing.

Open questions:

  • Preference spans have not been working for a while, and are causing trouble. Do we still want them?
  • Do we want to bring back the option to ignore loss masking spans in some way? Is it ever needed?
  • What to do with Mamba, which doesn't support varlen? [bug] Can't compile varlen mamba with base image 25.11 #416
  • Blended datasets take each "sample" from a single dataset. Effectively this means each micro-batch takes only documents from one of the datasets, so sampling is uneven unless we have lots of sequential or parallel micro-batches. Is this fixable?

jlamypoirier marked this pull request as ready for review on March 13, 2026.
jlamypoirier changed the title from [Prototype] Improved model preprocessing, new batch structure to Improved model preprocessing, new batch structure on March 13, 2026.