diff --git a/tokenise+train.py b/tokenise+train.py index a5dd7d7..dff775b 100644 --- a/tokenise+train.py +++ b/tokenise+train.py @@ -19,18 +19,14 @@ def train(filepath: str, ouputdir: Path, blocksize: int, vocabsize: int, num_ste from aitextgen.utils import build_gpt2_config from aitextgen import aitextgen - exts = ['.json', '.gz'] - files = [x for x in ouputdir.glob('*') if x.suffix in exts and x.name != "config.json"] + files = [x for x in ouputdir.glob('*') if x.name.endswith(".tokenizer.json")] print(files) - if len(files) == 2: - if files[0].suffix == '.json': - tok = str(files[0]) - dat = str(files[1]) - else: - tok = str(files[1]) - dat = str(files[0]) + if len(files) == 1: + tok = str(files[0]) + else: + return "No valid tokenizer in " + str(ouputdir) config = build_gpt2_config(vocab_size=vocabsize, max_lenght=blocksize, dropout=0.0, n_embd=256, n_layer=8, n_head=8)