diff --git a/trials/incsum.py b/trials/incsum.py
new file mode 100644
index 00000000..41b3d500
--- /dev/null
+++ b/trials/incsum.py
@@ -0,0 +1,162 @@
+# Observe incremental summaries by summarizing transcript.txt in chunks
+import spacy
+
+
+def split_text_file(filename, token_count):
+    """Split a text file into consecutive chunks of roughly token_count spaCy tokens."""
+    nlp = spacy.load('en_core_web_md')
+
+    with open(filename, 'r') as file:
+        text = file.read()
+
+    doc = nlp(text)
+    total_tokens = len(doc)
+
+    parts = []
+    start_index = 0
+
+    while start_index < total_tokens:
+        end_index = start_index + token_count
+        part_tokens = doc[start_index:end_index]
+        part = ' '.join(token.text for token in part_tokens)
+        parts.append(part)
+        start_index = end_index
+
+    return parts
+
+
+# Set the chunk length here to split the transcript and test
+MAX_CHUNK_LENGTH = 1000
+
+chunks = split_text_file("transcript.txt", MAX_CHUNK_LENGTH)
+print("Number of chunks", len(chunks))
+
+# Write chunks to file to refer to input vs output, separated by blank lines
+with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a") as f:
+    for c in chunks:
+        f.write(c + "\n\n")
+
+# To run only one approach, pass its number on the command line,
+# e.g. `python incsum.py 1` runs approach 1 only.
+# With no argument, all approaches are run.
+import sys
+try:
+    index = sys.argv[1]
+except IndexError:
+    index = None
+
+
+# Approach 1: facebook/bart-large-cnn
+if index == "1" or index is None:
+    SUMMARY_MODEL = "facebook/bart-large-cnn"
+    BEAM_SIZE = 2
+
+    print("Performing chunk summary : " + SUMMARY_MODEL)
+
+    from transformers import BartTokenizer, BartForConditionalGeneration
+
+    tokenizer = BartTokenizer.from_pretrained(SUMMARY_MODEL)
+    model = BartForConditionalGeneration.from_pretrained(SUMMARY_MODEL)
+    summaries = []
+    for c in chunks:
+        input_ids = tokenizer.encode(c,
+                                     truncation=True,
+                                     max_length=MAX_CHUNK_LENGTH,
+                                     padding="max_length",
+                                     return_tensors='pt')
+        summary_ids = model.generate(
+            input_ids,
+            num_beams=BEAM_SIZE,
+            max_length=56,
+            early_stopping=True,
+            length_penalty=1.0)
+        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        summaries.append(summary)
+
+    with open("bart-summaries.txt", "a") as f:
+        for summary in summaries:
+            f.write(summary + "\n\n")
+
+
+# Approach 2: EleutherAI/gpt-neo-1.3B
+if index == "2" or index is None:
+    print("Performing chunk summary : " + "gpt-neo-1.3B")
+
+    import torch
+    from transformers import GPTNeoForCausalLM, GPT2Tokenizer
+
+    model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
+    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+    summaries = []
+
+    for c in chunks:
+        input_ids = tokenizer.encode(c,
+                                     truncation=True,
+                                     return_tensors='pt')
+        input_length = input_ids.shape[1]
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
+
+        max_summary_length = 100
+        max_length = input_length + max_summary_length
+
+        output = model.generate(input_ids,
+                                max_length=max_length,
+                                attention_mask=attention_mask,
+                                pad_token_id=model.config.eos_token_id,
+                                num_beams=4,
+                                length_penalty=2.0,
+                                early_stopping=True)
+        # Keep only the newly generated tokens, not the echoed prompt
+        summary_ids = output[0, input_length:]
+        summary = tokenizer.decode(summary_ids, skip_special_tokens=True)
+        summaries.append(summary)
+        with open("gptneo1.3B-summaries.txt", "a") as f:
+            f.write(summary + "\n\n")
+
+
+# Approach 3: mosaicml/mpt-7b
+if index == "3" or index is None:
+    print("Performing chunk summary : " + "mpt-7B")
+
+    import torch
+    import transformers
+    from transformers import AutoTokenizer
+
+    config = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b',
+                                                     trust_remote_code=True)
+    config.attn_config['attn_impl'] = 'triton'  # 'triton' attention needs a GPU; default is 'torch'
+    config.max_seq_len = 1024
+    config.init_device = "meta"
+
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        'mosaicml/mpt-7b',
+        config=config,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16
+    )
+
+    # MPT-7B uses the GPT-NeoX tokenizer
+    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
+
+    summaries = []
+    for c in chunks:
+        # Truncate so the prompt plus 25 new tokens stays within max_seq_len
+        input_ids = tokenizer.encode(c,
+                                     truncation=True,
+                                     max_length=config.max_seq_len - 25,
+                                     return_tensors="pt")
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
+        output = model.generate(input_ids,
+                                max_new_tokens=25,
+                                attention_mask=attention_mask,
+                                pad_token_id=model.config.eos_token_id,
+                                num_return_sequences=1)
+        # Keep only the newly generated tokens, not the echoed prompt
+        summary = tokenizer.decode(output[0, input_ids.shape[1]:],
+                                   skip_special_tokens=True)
+        summaries.append(summary)
+
+    with open("mpt-7b-summaries.txt", "a") as f:
+        for summary in summaries:
+            f.write(summary + "\n\n")
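Once the script has run, the chunk file and the per-model summary files can be lined up for side-by-side review. The helper below is only a sketch and not part of the patch: it assumes a single run with MAX_CHUNK_LENGTH = 1000 (so the chunk file is chunks1000.txt), relies on the blank-line-separated format the script writes, and read_entries is a hypothetical name introduced here for illustration.

# Sketch (assumption: one run of trials/incsum.py with MAX_CHUNK_LENGTH = 1000).
# read_entries is a hypothetical helper, not defined anywhere in the patch.
def read_entries(path):
    # Each entry was written as "text\n\n", so split on blank lines.
    with open(path) as f:
        return [e.strip() for e in f.read().split("\n\n") if e.strip()]

chunks = read_entries("chunks1000.txt")
summaries = read_entries("bart-summaries.txt")

for i, (chunk, summary) in enumerate(zip(chunks, summaries)):
    print(f"--- chunk {i} ---")
    print(chunk[:200] + "...")
    print("summary:", summary)
    print()

The same pairing works for gptneo1.3B-summaries.txt and mpt-7b-summaries.txt, since all three approaches write one blank-line-separated entry per chunk.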