
Model Loading

There are several ways to load models.

Continue Training/Finetuning

You can further train a model by reloading a model that has already been trained, using the methods outlined below.

Loading an aitextgen model

For the base case, loading the default 124M GPT-2 model from Hugging Face:

ai = aitextgen()

The model will be downloaded to cache_dir, which is /aitextgen by default.

If you're training a custom GPT-2/GPT-Neo architecture from scratch but with the standard GPT-2 tokenizer, you can pass just a config.

from aitextgen.utils import GPT2ConfigCPU
config = GPT2ConfigCPU()
ai = aitextgen(config=config)

While training/finetuning a model, two files will be created: pytorch_model.bin, which contains the weights for the model, and config.json, which describes the model's architecture. Both of these files are needed to reload the model.
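As a quick sanity check before reloading, you can verify that both files are present. A minimal sketch (has_model_files is an illustrative helper, not part of aitextgen):

```python
from pathlib import Path

def has_model_files(model_folder: str) -> bool:
    """Return True if the folder contains both files aitextgen needs to reload a model."""
    folder = Path(model_folder)
    return (folder / "pytorch_model.bin").is_file() and (folder / "config.json").is_file()
```

For example, you could call has_model_files("trained_model") before constructing aitextgen(model_folder="trained_model").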

If you've finetuned a model using aitextgen (the default model), you can pass the folder name containing the generated pytorch_model.bin and config.json to aitextgen (e.g. trained_model, which is where trained models will be saved by default).

ai = aitextgen(model_folder="trained_model")

Same Directory

If both files are in the current directory, you can pass model_folder=".".

ai = aitextgen(model_folder=".")

These examples assume you are using the default GPT-2 tokenizer. If you have a custom tokenizer, you'll need to pass it as well when loading the model.

ai3 = aitextgen(model_folder="trained_model",
                tokenizer_file="aitextgen.tokenizer.json")

If you want to download an alternative GPT-2 model from Hugging Face's repository of models, pass its name to model.

ai = aitextgen(model="minimaxir/hacker-news")

The model and associated config + tokenizer will be downloaded into cache_dir.

This can also be used to download the pretrained GPT Neo models from EleutherAI.

ai = aitextgen(model="EleutherAI/gpt-neo-125M")

Loading TensorFlow-based GPT-2 models

aitextgen lets you download the original TensorFlow-based GPT-2 models that OpenAI uploaded to Microsoft's servers when GPT-2 was first released in 2019, and then converts them to a PyTorch format.

To use this workflow, pass the corresponding model size to tf_gpt2:

ai = aitextgen(tf_gpt2="124M")

This will cache the converted model locally in cache_dir; subsequent calls with the same parameters will load the converted model from the cache.

The valid TF model names are ["124M","355M","774M","1558M"].
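Since an unsupported size would only fail after a download attempt, you may want to validate the argument up front. A small sketch (validate_tf_gpt2 is a hypothetical helper, not an aitextgen function):

```python
# The four sizes released for the original TensorFlow GPT-2 checkpoints.
VALID_TF_GPT2_SIZES = ["124M", "355M", "774M", "1558M"]

def validate_tf_gpt2(size: str) -> str:
    """Raise a descriptive error for an unsupported checkpoint size."""
    if size not in VALID_TF_GPT2_SIZES:
        raise ValueError(f"tf_gpt2 must be one of {VALID_TF_GPT2_SIZES}, got {size!r}")
    return size
```

validate_tf_gpt2("124M") passes the value through unchanged, while a typo like "125M" raises immediately with the list of valid sizes.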