line_by_line

This commit is contained in:
gauthiier 2022-02-09 18:42:32 +01:00
parent 0323f1a00e
commit 58dd03ba73

View File

@@ -30,6 +30,8 @@ def train(ouputdir: Path, blocksize: int, vocabsize: int, num_steps: int, gpu: b
config = build_gpt2_config(vocab_size=vocabsize, max_lenght=blocksize) config = build_gpt2_config(vocab_size=vocabsize, max_lenght=blocksize)
print(config)
ai = aitextgen(tokenizer_file=tok, config=config) ai = aitextgen(tokenizer_file=tok, config=config)
data = TokenDataset(dat, tokenizer_file=tok, block_size=blocksize, from_cache=True) data = TokenDataset(dat, tokenizer_file=tok, block_size=blocksize, from_cache=True)
@@ -70,10 +72,10 @@ def encode(filepath: str, blocksize: int, vocabsize: int, ouputdir: Path, verbos
print(dataset_fn) print(dataset_fn)
if type(text) is str: if type(text) is str:
data = TokenDataset(file_path=text, tokenizer_file=tok_fn, block_size=blocksize) data = TokenDataset(file_path=text, tokenizer_file=tok_fn, block_size=blocksize, line_by_line=True)
else: else:
texts = [x.read_text() for x in text] texts = [x.read_text() for x in text]
data = TokenDataset(texts=texts, tokenizer_file=tok_fn, block_size=blocksize) data = TokenDataset(texts=texts, tokenizer_file=tok_fn, block_size=blocksize, line_by_line=True)
data.save(cache_destination=dataset_fn) data.save(cache_destination=dataset_fn)
return "encode success" return "encode success"