Improved model preprocessing, new batch structure #471
jlamypoirier wants to merge 14 commits into `jlp_simplify_mtp` from
✨ Description
Highlights:
- `ModelInput` structure to handle this preprocessing, and potentially replace the arbitrary `kwargs` we pass to the layers.
- `samples` (mostly). We now pack documents directly into batches, and always use the varlen implementation of mixers (no cross-document attention). Merge the `batch` and `sequence` dimensions into a single `token` dimension.

Data changes
Data
- `sampling` field, inherit from sampling config instead. Add `micro_batch_size`, `maximum_document_length` and `truncate_documents` fields (moved from batch config).
- `setup` method with fine-grained `sample_dataset`, which prepares one dataset at a time.

Batch configuration
- `BatchConfig` altogether. Micro-batches are now defined in the data config, and sequential micro-batches and other scheduling options (e.g. micro-sequences) are defined in the schedule config.

Datasets
- `SamplingData` structure with a simpler `SamplingConfig`. Remove `SamplingParameters`, now directly part of the sampling/data configuration (`sequence_length`, `truncate_documents`) or provided separately to `dataset.sample` (`num_samples`, `seed`).
- `SampledDatasetUpdateConfig`, `SamplingConfig.update_config`, no longer necessary since `num_samples` and `seed` are no longer in the config.

Memmap dataset
- `config` file) to `data.auto` so dynamic reader classes work out of the box.

Document, ModelInput, Batch
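A rough, self-contained sketch of how these three structures could fit together. `Document`, `Batch`, and `ModelInput` are names from this PR, but the fields and the `preprocess` helper are assumptions for illustration:

```python
from dataclasses import dataclass


@dataclass
class Document:
    """A single, unprocessed document; the only structure datasets deal with."""
    tokens: list[int]


@dataclass
class Batch:
    """Concatenation of documents; intermediate state for model preprocessing."""
    documents: list[Document]

    def token_count(self) -> int:
        return sum(len(d.tokens) for d in self.documents)


@dataclass
class ModelInput:
    """A preprocessed batch, ready to be fed to the model."""
    tokens: list[int]      # flat token dimension (batch and sequence dims merged)
    cu_seqlens: list[int]  # document boundaries, for the varlen mixer path


def preprocess(batch: Batch) -> ModelInput:
    # Merge all documents into one flat token stream and record the boundaries.
    tokens: list[int] = []
    cu_seqlens = [0]
    for doc in batch.documents:
        tokens.extend(doc.tokens)
        cu_seqlens.append(len(tokens))
    return ModelInput(tokens=tokens, cu_seqlens=cu_seqlens)
```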
- `Sample`, `Batch` structure.
- `Document`: Represents a single, unprocessed document. This is the only structure datasets deal with.
- `Batch`: Intermediate state representing a concatenation of documents. Its main purpose is to handle model preprocessing.
- `ModelInput`: A preprocessed batch ready to be fed to the model. A batch may lead to more than one `ModelInput` (with micro-sequences).

Preprocessing
- `BatchPreprocessingConfig`, provided by the model (in `data.sample_dataset`).
- `Batch.get_model_inputs`, called in `data._collate_fn`.

Preparation, offline preprocessing
Other data changes
- `SampledIndexedDataset` (1 -> `num_heads` in a few places)

Modeling changes
Preprocessing
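A sketch of the `get_preprocessing_config` pattern on blocks/layers, where each layer reports the arguments it needs and the model merges them into the arguments for the data loader. The layer classes and merge helper are illustrative assumptions:

```python
class Layer:
    def get_preprocessing_config(self) -> dict:
        """Arguments this layer needs the data loader to precompute."""
        return {}


class AttentionLayer(Layer):
    def get_preprocessing_config(self) -> dict:
        # Varlen attention needs document boundaries (hypothetical key name).
        return {"need_cu_seqlens": True}


class EmbeddingLayer(Layer):
    def get_preprocessing_config(self) -> dict:
        # Embedding masking needs token counts (hypothetical key name).
        return {"need_token_counts": True}


def collect_preprocessing_config(layers: list[Layer]) -> dict:
    # Merge the per-layer requirements into one dict of arguments, analogous
    # to building the batch preprocessing config handed to the data loader.
    merged: dict = {}
    for layer in layers:
        merged.update(layer.get_preprocessing_config())
    return merged
```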
- `get_preprocessing_config` to blocks/layers, following the same pattern as `preprocess` (and meant to eventually replace it completely), but instead returning a dict of arguments for the `BatchPreprocessingConfig` so the data loader knows what preprocessing to do.
- `preprocessed_meta`. `preprocess_batch` now handles meta inputs directly for making the schedule.
- `model.preprocess_batch` now mainly consists of a `model_input.to_kwargs()` call, plus some handling of reference models.

Training
- `trainer._get_completion_metrics` (replaces `TrainingProgress`), `schedule._get_compute_metrics` to handle common logging between trainer and evaluator.
- `StepSchedule`, `ScheduleRunner`.
- `consumed_samples`, which no longer makes sense.
- `test_iters`, `PhaseType.test`.
- `PhaseType` values with lower case to simplify things (may affect logging/wandb a bit).

Evaluation
- `Evaluator` itself.
- `TrainingEvaluatorConfig`. The trainer now defines a dict of `EvaluatorConfig`, which inherit from the `Interval` config.
- `get_sampling_parameters`; evaluators may now call `data.sample_dataset` directly instead.

Attention
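A pure-Python illustration of the varlen attention semantics: causal attention restricted to tokens within the same document, as delimited by cumulative sequence lengths (`cu_seqlens`). Real implementations use fused varlen kernels; this predicate (an assumed helper, not code from this PR) only shows which positions may attend:

```python
def allowed_attention(cu_seqlens: list[int], query: int, key: int) -> bool:
    """True iff `query` may attend to `key` under causal, same-document rules."""
    if key > query:
        return False  # causal: no attending to future positions
    # Find the document containing the query; the key must be in the same one.
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        if start <= query < end:
            return start <= key
    return False
```

For example, with `cu_seqlens = [0, 3, 5]`, position 3 (start of the second document) cannot attend to position 2, even though it is in the causal past.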
- `cross_document_attention`; always use the varlen implementation.
- `present`, breaking training with micro-sequences. (Was previously unknown because of an implicit copy in a contiguous call.)

SSM
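A sketch of what a loop-based reference causal convolution might look like (the actual `CausalConv1d` reference in this PR may differ in weight layout and batching); being plain Python, it runs in CPU-only test environments:

```python
def causal_conv1d_reference(x: list[float], weight: list[float]) -> list[float]:
    """Reference causal 1d convolution, computed in an explicit loop.

    Left-pads the input with zeros so each output position only sees the
    current and past inputs (weight[-1] applies to the current timestep).
    Slow, but has no kernel/device requirements.
    """
    k = len(weight)
    padded = [0.0] * (k - 1) + list(x)
    return [
        sum(padded[t + j] * weight[j] for j in range(k))
        for t in range(len(x))
    ]
```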
- `gdn` and `kda` implementations, mostly by merging redundant reshapes. (Moved to Simplify MTP #470.)
- `CausalConv1d` (runs in a loop). This allows running `gdn` tests on CPU.

Language model
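A sketch of the per-head label format: one label list per prediction head, with targets that cross a document boundary masked out, which is the reason each head needs its own tensor. The helper below and its `ignore_index` convention are assumptions for illustration:

```python
def make_labels(tokens: list[int], cu_seqlens: list[int], num_heads: int,
                ignore_index: int = -100) -> list[list[int]]:
    """One label list per prediction head; head h predicts distance h + 1 ahead.

    Positions whose target falls outside the current document are masked
    with `ignore_index` (loss-masked for cross-document predictions).
    """
    # Map each position to the end of its document.
    doc_end = {}
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        for t in range(start, end):
            doc_end[t] = end

    labels = []
    for head in range(num_heads):
        distance = head + 1
        head_labels = []
        for t in range(len(tokens)):
            target = t + distance
            # Mask targets that would cross into the next document.
            head_labels.append(tokens[target] if target < doc_end[t]
                               else ignore_index)
        labels.append(head_labels)
    return labels
```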
- `num_tokens` kwarg.
- `mask_inputs` kwarg; the embedding layer can figure it out based on `num_tokens` and patch embeddings.
- `prediction_distance` to start at 1.
- `labels` and `loss_mask` kwarg format to a list, one tensor per prediction head. Add masking for cross-document predictions (hence the need for separate tensors).

Debugging/QOL features
- `force_cpu_initialization` in distributed config, which does the initialization on CPU and allows matching initialization.
- `full_tensors` to `TensorLogsConfig`, which allows saving/printing full tensors in tensor logs.

Tweaks
Testing
TODO
External models
(outdated)
Future steps:
- `kwargs` in layers with the `MicroBatch` structure.
- `preprocess` method in base models and layers. (Still needed for rotary embeddings and stochastic mixer.)
- `preprocess_batch`, which now takes an already preprocessed batch as input. Seems to be mostly about running reference models now.

Open questions: