"""Train a GPT-2 model with aitextgen from a cached, tokenized dataset.

For a text name ``<text>`` the script reads

* ``<tokensdir>/<text>.tokenizer.json``         (tokenizer file)
* ``<tokensdir>/<text>_bs=<blocksize>.tar.gz``  (cached TokenDataset)

and writes the model to
``<outputdir>/<text>_bs=<blocksize>_ns=<numsteps>``.
"""
import argparse
import os
import sys
from typing import Any, Optional, Sequence, Tuple

# References for the config helpers used below:
# https://github.com/minimaxir/aitextgen/blob/master/aitextgen/utils.py
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/configuration_gpt2.py

# Training settings shared by the CPU and GPU code paths.
_BATCH_SIZE = 16
_GENERATE_EVERY = 1000
_SAVE_EVERY = 1000
_NUM_WORKERS = 4

# NOTE: vocab_size is fixed since this is not yet in train_tokenizer
_GPU_VOCAB_SIZE = 1000


def _train(config: Any, te: str, tok: str, dat: str, blocksize: int,
           num_steps: int, **train_kwargs: Any) -> None:
    """Build the model and dataset, then run one aitextgen training job.

    Args:
        config: GPT-2 configuration object handed to ``aitextgen``.
        te: Output directory for the trained model.
        tok: Path to the tokenizer JSON file.
        dat: Path to the cached dataset file.
        blocksize: Block size passed to ``TokenDataset``.
        num_steps: Number of training steps.
        **train_kwargs: Extra keyword arguments forwarded to ``ai.train``.
    """
    # Imported here rather than at module level so that ``--help`` and
    # argument errors are reported without loading the training stack.
    from aitextgen import aitextgen
    from aitextgen.TokenDataset import TokenDataset

    ai = aitextgen(tokenizer_file=tok, config=config)
    data = TokenDataset(dat, tokenizer_file=tok, block_size=blocksize,
                        from_cache=True)
    ai.train(
        data,
        output_dir=te,
        batch_size=_BATCH_SIZE,
        num_steps=num_steps,
        generate_every=_GENERATE_EVERY,
        save_every=_SAVE_EVERY,
        num_workers=_NUM_WORKERS,
        **train_kwargs,
    )


def run_cpu(te: str, tok: str, dat: str, blocksize: int,
            num_steps: int = 10000) -> int:
    """Train with the ``GPT2ConfigCPU`` configuration.

    Args:
        te: Output directory for the trained model.
        tok: Path to the tokenizer JSON file.
        dat: Path to the cached dataset file.
        blocksize: Block size passed to ``TokenDataset``.
        num_steps: Number of training steps.

    Returns:
        ``0`` once training has finished.
    """
    from aitextgen.utils import GPT2ConfigCPU

    # NOTE(review): GPT2ConfigCPU() is called without a block size, so its
    # context length is whatever the library default is -- confirm that it
    # is compatible with every value allowed for --blocksize.
    _train(GPT2ConfigCPU(), te, tok, dat, blocksize, num_steps)
    return 0


def run_gpu(te: str, tok: str, dat: str, blocksize: int,
            num_steps: int = 10000) -> int:
    """Train with a config from ``build_gpt2_config`` sized to ``blocksize``.

    Args:
        te: Output directory for the trained model.
        tok: Path to the tokenizer JSON file.
        dat: Path to the cached dataset file.
        blocksize: Block size; also used as the model's ``max_length``.
        num_steps: Number of training steps.

    Returns:
        ``0`` once training has finished.
    """
    from aitextgen.utils import build_gpt2_config

    # The keyword is ``max_length``; the former ``max_lenght`` spelling did
    # not match that parameter, so the block size never reached the config.
    config = build_gpt2_config(vocab_size=_GPU_VOCAB_SIZE,
                               max_length=blocksize)
    # NOTE(review): ``to_gpu`` is forwarded to ``train()`` exactly as before;
    # confirm against the installed aitextgen version that ``train()`` (and
    # not the ``aitextgen`` constructor) is where this flag belongs.
    _train(config, te, tok, dat, blocksize, num_steps, to_gpu=True)
    return 0


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    p = argparse.ArgumentParser()
    p.add_argument("text", type=str, help="text to create model from")
    p.add_argument("-b", "--blocksize", type=int,
                   choices=[32, 64, 128, 256, 1024], default=64,
                   help="block size, default=64 "
                        "(corresponds to GPT-2 'max_length' config)")
    p.add_argument("-s", "--numsteps", type=int, default=10000)
    p.add_argument("--tokensdir", type=str, default="data/tokens/")
    # ``--ouputdir`` is the historical (misspelled) flag; it stays accepted
    # as an alias so existing invocations keep working.
    p.add_argument("--outputdir", "--ouputdir", dest="outputdir", type=str,
                   default="data/models/")
    p.add_argument("--gpu", action="store_true")
    return p


def artifact_paths(text: str, blocksize: int, numsteps: int,
                   tokensdir: str, outputdir: str) -> Tuple[str, str, str]:
    """Return ``(tokenizer_file, dataset_file, output_dir)`` for one run.

    ``os.path.join`` is used so the directories work with or without a
    trailing path separator.
    """
    tok_file = os.path.join(tokensdir, f"{text}.tokenizer.json")
    dat_file = os.path.join(tokensdir, f"{text}_bs={blocksize}.tar.gz")
    output_dir = os.path.join(outputdir,
                              f"{text}_bs={blocksize}_ns={numsteps}")
    return tok_file, dat_file, output_dir


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse the command line and run CPU or GPU training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The exit status of the selected training function.
    """
    args = build_parser().parse_args(argv)
    tok_file, dat_file, output_dir = artifact_paths(
        args.text, args.blocksize, args.numsteps,
        args.tokensdir, args.outputdir,
    )
    # Both back ends receive num_steps; previously the CPU path dropped it
    # and always trained for the default number of steps.
    runner = run_gpu if args.gpu else run_cpu
    return runner(te=output_dir, tok=tok_file, dat=dat_file,
                  blocksize=args.blocksize, num_steps=args.numsteps)


if __name__ == '__main__':
    sys.exit(main())