Comments (7)
After careful review, we found an error in the upload process for our model checkpoint. We sincerely apologize for this oversight and any inconvenience it caused.
To fix this, we have uploaded a corrected model checkpoint. Please replace the previous LLaMA diff checkpoint with the updated version, zjunlp/llama-molinst-protein-7b; you should then be able to recover the full parameters using the weight_diff.py
script we provide.
Again, we apologize for any trouble this caused. Your feedback is greatly valued and helps us improve this work. Please do not hesitate to reach out if you have further issues or need clarification.
from mol-instructions.
Thanks for the update. Inference now works normally for me.
confirmed working with the new model. Thanks for fixing this so quickly!
Hi,
What generation strategy do you use? Beam search? We have previously seen similar issues with transformers' generation on other tasks (not in this work), primarily caused by a problem with beam search inside generation
(refer to Make beam sample more robust). Consequently, we switched to sampling strategies, specifically top-k or top-p sampling.
In this work, we use sampling methods such as top-k or top-p sampling during generation for the protein-related tasks. The configuration we use for generation can be illustrated with the following example:
generation_config = GenerationConfig(
    do_sample=True,
    top_k=gen_args.top_k,  # 10 or more
    repetition_penalty=gen_args.repetition_penalty,  # 1.0 or 1.2 to penalize repetition
    pad_token_id=0,  # hard-coded for the Llama tokenizer
    bos_token_id=1,
    eos_token_id=2,
)
The selection of sampling strategies and hyperparameters depends on your use case; it may be beneficial to iterate and experiment with different strategies and hyperparameters to find the optimal combination.
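For intuition, top-k sampling keeps only the k highest-scoring tokens, renormalizes, and samples from that restricted distribution. A minimal sketch in plain PyTorch (a hypothetical helper for illustration, not the transformers implementation):

```python
import torch

def top_k_sample(logits, k=10):
    # Keep only the k largest logits, renormalize them into a
    # probability distribution, and draw one token id from it.
    topk_vals, topk_idx = torch.topk(logits, k)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx[choice].item()
```

With k=1 this degenerates to greedy decoding; larger k trades determinism for diversity, which is the knob being tuned above.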
Thanks for your feedback. I tried several GenerationConfig
combinations, with repetition_penalty=1.2
as the default, and experimented with:
- top_k = 10 / 15
- top_p = 0.75 / 0.9
- repetition_penalty = 1.0 / 1.2
- num_beams = 1 / 4
Unfortunately, all of these attempts failed with the same error (RuntimeError: probability tensor contains either `inf`, `nan` or element < 0). I used the same instruction and input as above:
"""
instruction: "Find and list any domains or motifs that are likely present in this protein sequence:"
input: "MEFDTIAAISTFPGEAGIGIVRISGDEALEIISKIFRPFRKKDIKSVKSHTIHYGHIVDPETGEVYDEVLVTVMRKPNTYTREDVVEINCHGGIVVSSKILELVLKHGARLAEPGEFTKRAFLNGRIDLSQAEAVIDIITSKTMLANRYAQKQLAGVLGQKMKDLKNKIMELLSHLLALIDFPEEDVEELEREEIKRRAKDILNDIEYLIASSESGRIIREGLKTAIIGKPNVGKSSLLNALLKQNRAIVTDIPGTTRDVIEEYMNIKGIPIKLIDTAGIRHTDELVEKIGVEKSKEVLAEADLILFVLDASRDLTKEDYEIFDILSGKNIIFVLNKVDLPKKIDEEELKKLVGNGIIVEVSTVERTGLDKLESEIYNLVFKGKVSATEEEIITNARHREVLINAKKHMESVIEAIEKGYSEDLITIDVNGALNEIGKITGETATEDVINQIFERFCVGK"
"""
Have you been able to check the released checkpoint? (A colleague deployed it and hit the same error at inference, so I don't think it is a machine- or environment-specific problem.)
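One way to check this empirically is to scan the recovered checkpoint's parameters for NaN/Inf values, which would distinguish a corrupted checkpoint from a decoding-strategy issue. A minimal sketch (the helper name is ours, not part of the repo):

```python
import torch

def find_bad_params(named_params):
    # Return the names of parameters containing any NaN or Inf entries.
    return [name for name, p in named_params
            if torch.isnan(p).any() or torch.isinf(p).any()]

# Usage on the recovered model:
#   bad = find_bad_params(model.named_parameters())
#   print("parameters with NaN/Inf:", bad or "none")
```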
Thank you for your feedback; we truly appreciate it. We will review the released checkpoint soon.
As another datapoint, I'm seeing NaNs in the hidden states, even before trying to generate the sequences. I'm not sure this will reproduce exactly, but running on a p3.2xlarge instance:
Run first:
python weight_diff.py recover \
    --path_raw "decapoda-research/llama-7b-hf" \
    --path_diff "zjunlp/llama-molinst-protein-7b" \
    --path_tuned /home/ubuntu/zjunlp/llama-molinst-protein-7b_recovered
Then load the model and tokenizer (adapted from generate.py)
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model_str = "decapoda-research/llama-7b-hf"
model_str = "zjunlp/llama-molinst-protein-7b"
recovered_model_path = "/home/ubuntu/zjunlp/llama-molinst-protein-7b_recovered"

model = LlamaForCausalLM.from_pretrained(
    recovered_model_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    # device_map="auto",
    device_map={"": 0},
    cache_dir=cache_dir,
)
tokenizer = LlamaTokenizer.from_pretrained(
    base_model_str,
    bos_token='<s>',
    eos_token='</s>',
    add_bos_token=True,
    add_eos_token=False,
)
# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
model.cuda()
Try a prompt:
# get a UniProt sequence
import requests

url = "https://rest.uniprot.org/uniprotkb/P05067.fasta"
f = requests.get(url)
seq = "".join(f.text.split("\n")[1:])

prompter = Prompter()
prompt = prompter.generate_prompt("Is the input is a protein sequence?", seq)
encoder_text_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    output = model(encoder_text_ids, output_hidden_states=True)

# check the last hidden layer
output["hidden_states"][-1][0, -1]
which outputs:
tensor([nan, nan, nan, ..., nan, nan, nan], device='cuda:0',
dtype=torch.float16)
When I tweak the prompt or the length of the sequence, I sometimes do get finite numbers in the hidden states.
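To localize the failure, one can walk the hidden_states tuple returned above and report the first layer that contains NaN; if an early layer is already NaN, the problem lies in the weights rather than in decoding. A small sketch (the helper name is ours):

```python
import torch

def first_nan_layer(hidden_states):
    # hidden_states: sequence of tensors, one per layer (embeddings first),
    # as returned by a forward pass with output_hidden_states=True.
    # Returns the index of the first tensor containing NaN, or None.
    for i, h in enumerate(hidden_states):
        if torch.isnan(h).any():
            return i
    return None

# Usage: first_nan_layer(output["hidden_states"])
```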