aitextgen has a special class, TokenDataset, used for managing tokenized datasets to be fed into model training. (This is in contrast with other GPT-2 finetuning approaches, which tokenize at training time, although you can still do that by passing a file_path and other relevant parameters to ai.train().)

This has a few nice bonuses, including:

  • Tokenize a dataset on a local machine ahead of time and compress it, saving time/bandwidth transporting data to a remote machine
  • Read a dataset either line-by-line (including single-column CSVs) or as bulk text.
  • Debug and log the loaded texts.
  • Merge datasets together without using external libraries
  • Cross-train on multiple datasets to "blend" them together.

Creating a TokenDataset For GPT-2 Finetuning

The easiest way to create a TokenDataset is to provide a target file. If no tokenizer_file is provided, it will use the default GPT-2 tokenizer.

from aitextgen.TokenDataset import TokenDataset

data = TokenDataset("shakespeare.txt")

If you pass a single-column CSV and specify line_by_line=True, the TokenDataset will parse it row-by-row. This is the recommended way to handle multiline texts.

data = TokenDataset("politics.csv", line_by_line=True)
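To see why a single-column CSV suits multiline texts, the sketch below writes two texts to a CSV with Python's standard csv module and reads them back row by row. The sample texts and the file name are made up for illustration, and the sketch only demonstrates the file format; it does not call aitextgen.

```python
import csv
import os
import tempfile

# Hypothetical sample texts; the second one spans several lines.
texts = ["A short one-line text.", "A longer text\nthat spans\nthree lines."]

path = os.path.join(tempfile.mkdtemp(), "multiline.csv")

# Write one text per row; the csv module quotes fields containing newlines.
with open(path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    for text in texts:
        writer.writerow([text])

# Reading row by row recovers each text intact, embedded newlines included.
with open(path, newline="", encoding="utf-8") as f:
    rows = [row[0] for row in csv.reader(f)]

# Splitting the raw file on newlines would instead break the second text apart.
with open(path, encoding="utf-8") as f:
    raw_lines = f.read().splitlines()
```

Here rows holds the two original texts, while raw_lines holds four fragments, which is the failure mode a plain line-by-line text file would have with multiline texts.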

You can also manually pass a list of texts to the texts parameter instead if you've processed them elsewhere.

data = TokenDataset(texts=["Lorem", "Ipsum", "Dolor"])

Block Size

block_size is another parameter that can be passed when creating a TokenDataset, more useful for custom models. This should match the context window (e.g. the n_positions or max_position_embeddings config parameters). By default, it will choose 1024: the GPT-2 context window.

When implicitly loading a dataset via ai.train(), the block_size will be set to what is supported by the corresponding model config.
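To make the role of block_size concrete, the sketch below cuts a token stream into fixed-length windows. It is an illustration only, not aitextgen's code: the integer list stands in for token IDs, and the sliding-window scheme and the helper name get_block are assumptions made for this example.

```python
# Stand-in for a tokenized dataset: 10 made-up token IDs.
tokens = list(range(100, 110))
block_size = 4  # would be 1024 for the default GPT-2 context window


def get_block(tokens, start, block_size):
    """Return the window of block_size tokens beginning at index start."""
    return tokens[start:start + block_size]


# Number of full windows available when sliding one token at a time.
num_blocks = len(tokens) - block_size + 1

first_block = get_block(tokens, 0, block_size)
last_block = get_block(tokens, num_blocks - 1, block_size)
```

No window is longer than block_size, so under this scheme the model never receives more tokens than its context window can hold.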

Debugging a TokenDataset

When loading a dataset, a progress bar will appear showing how many texts are loaded.

If you want to see what exactly is input to the model during training, you can access a slice via data[0].

Saving/Loading a TokenDataset

When creating a TokenDataset, you can automatically save it as a compressed gzipped numpy array when completed.

data = TokenDataset("shakespeare.txt", save_cache=True)

Or save it after you've loaded it with the save() function.

data = TokenDataset("shakespeare.txt")
data.save()

By default, it will save to dataset_cache.tar.gz. You can then reload that into another Python session by specifying the cache.

data = TokenDataset("dataset_cache.tar.gz", from_cache=True)
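As a rough picture of why a cached dataset is compact, the sketch below packs token IDs as fixed-width integers, gzips them, and restores them. It uses only the standard library (array and gzip) as a stand-in for the gzipped numpy array; the token values and the file name are made up, and this is not the file format aitextgen writes.

```python
import array
import gzip
import os
import tempfile

# Made-up token IDs; GPT-2's vocabulary (50,257 entries) fits in 16 bits.
token_ids = [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13] * 1000

path = os.path.join(tempfile.mkdtemp(), "tokens_sketch.gz")

# Pack the IDs as unsigned short integers and gzip the raw bytes.
packed = array.array("H", token_ids)
with gzip.open(path, "wb") as f:
    f.write(packed.tobytes())

# Restore: decompress and unpack back into integers.
restored = array.array("H")
with gzip.open(path, "rb") as f:
    restored.frombytes(f.read())

raw_size = len(packed) * packed.itemsize
compressed_size = os.path.getsize(path)
```

The restored IDs match the originals exactly. For this highly repetitive sample, compressed_size is far smaller than raw_size; real text compresses less dramatically.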


You can quickly create a tokenized dataset using the command line, e.g. aitextgen encode text.txt. This will drastically reduce the file size, and is recommended before moving the file to cloud services (where it can be loaded using the from_cache parameter noted above).

Using TokenDatasets with a Custom GPT-2 Model

The default TokenDataset has a block_size of 1024, which corresponds to the context window of the default GPT-2 model. If you're using a custom model with a different maximum, set block_size to match its context window. Additionally, you must explicitly provide the tokenizer file to rebuild the tokenizer, as the tokenizer will be different from the normal GPT-2 one.

See the Model From Scratch docs for more info.

Merging TokenDatasets

Merging processed TokenDatasets can be done with the merge_datasets() function. By default, it will take samples equal to the smallest dataset from each TokenDataset, randomly sampling the appropriate number of texts from the larger datasets. This ensures that model training does not bias toward one particular dataset. (It can be disabled by setting equalize=False.)

(You can merge bulk datasets and line-by-line datasets, but the training output may be bizarre!)

About Merging

The current implementation merges by subset count, so equalization may not be perfect, but it will not significantly impact training.

from aitextgen.TokenDataset import TokenDataset, merge_datasets

data1 = TokenDataset("politics1000.csv", line_by_line=True)   # 1000 samples
data2 = TokenDataset("politics4000.csv", line_by_line=True)   # 4000 samples

data_merged = merge_datasets([data1, data2])   # ~2000 samples
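The equalization idea can be sketched in plain Python. This is a simplified illustration, not the library's implementation: the helper name equalized_merge is hypothetical, and it samples from lists of texts rather than from tokenized subsets.

```python
import random


def equalized_merge(datasets, seed=0):
    """Sample as many items from each dataset as the smallest one holds."""
    rng = random.Random(seed)
    smallest = min(len(d) for d in datasets)
    merged = []
    for d in datasets:
        # sample() draws without replacement, so the smallest dataset is kept whole.
        merged.extend(rng.sample(d, smallest))
    return merged


small = [f"a{i}" for i in range(1000)]   # stands in for 1000 samples
large = [f"b{i}" for i in range(4000)]   # stands in for 4000 samples

merged = equalized_merge([small, large])
```

Each source contributes 1000 items to merged, so neither dataset dominates even though one is four times larger.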