
Comments (7)

Alexzhuan commented on July 3, 2024

Hello @zdk123 , @CiaoHe ,

After careful review, we discovered an error in our model checkpoint's upload process. We sincerely apologize for this oversight and any inconvenience it may have caused.

To rectify this, we have promptly uploaded the corrected model checkpoint. We suggest replacing the previous LLaMA Diff checkpoint with the updated version, zjunlp/llama-molinst-protein-7b. You should then be able to recover the parameters using the weight_diff.py script we've provided.

Once again, we apologize for any trouble this may have caused. Your feedback is deeply valued and contributes significantly to the continual improvement of our work. Please do not hesitate to reach out if you run into any additional issues or need further clarification.

from mol-instructions.

CiaoHe commented on July 3, 2024

Thanks for the update. Inference now works normally for me.


zdk123 commented on July 3, 2024

Confirmed working with the new model. Thanks for fixing this so quickly!


Alexzhuan commented on July 3, 2024

Hi,

Which generation strategy are you using? Beam search? We've previously experienced similar issues on other generation tasks (not in this work) with transformers' generation, primarily attributed to a problem with beam search (see the transformers PR "Make beam sample more robust"). Consequently, we switched to sampling strategies, specifically top-k or top-p sampling.

In this work, we use sampling methods such as top-k or top-p sampling during generation for the protein-oriented tasks. The configuration we use for generation looks like the following:

generation_config = GenerationConfig(
    do_sample=True,
    top_k=gen_args.top_k,                            # 10 or more
    repetition_penalty=gen_args.repetition_penalty,  # 1.0, or 1.2 to penalize repetition
    pad_token_id=0,                                  # hard-coded for the Llama tokenizer
    bos_token_id=1,
    eos_token_id=2,
)

The selection of sampling strategies and hyperparameters depends on your use case; it may be beneficial to iterate and experiment with different strategies and hyperparameters to find the optimal combination.
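One lightweight way to organize such an experiment is a small grid sweep. This is a plain-Python sketch of my own (the parameter values are illustrative examples, not recommendations from the paper); each resulting dict could be unpacked into a GenerationConfig:

```python
from itertools import product

# Illustrative hyperparameter grid; the values are examples only.
grid = {
    "top_k": [10, 40],
    "top_p": [0.75, 0.9],
    "repetition_penalty": [1.0, 1.2],
}

# Enumerate every combination as a dict of keyword arguments.
configs = [dict(zip(grid, values)) for values in product(*grid.values())]

print(len(configs))  # 8 combinations (2 * 2 * 2)
print(configs[0])    # {'top_k': 10, 'top_p': 0.75, 'repetition_penalty': 1.0}
```

Each entry can then be passed along as `GenerationConfig(do_sample=True, **cfg)` while scoring the outputs per task.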


CiaoHe commented on July 3, 2024

Thanks for your feedback. I tried several combinations of GenerationConfig, with repetition_penalty=1.2 (as the default), experimenting with:

  1. top_k = 10 / 15
  2. top_p = 0.75 / 0.9
  3. repetition_penalty = 1.0 / 1.2
  4. num_beams = 1 / 4

Unfortunately, all attempts failed with the same error (RuntimeError: probability tensor contains either inf, nan or element < 0). I used the same instruction and input as above:

"""
instruction: "Find and list any domains or motifs that are likely present in this protein sequence:"
input: "MEFDTIAAISTFPGEAGIGIVRISGDEALEIISKIFRPFRKKDIKSVKSHTIHYGHIVDPETGEVYDEVLVTVMRKPNTYTREDVVEINCHGGIVVSSKILELVLKHGARLAEPGEFTKRAFLNGRIDLSQAEAVIDIITSKTMLANRYAQKQLAGVLGQKMKDLKNKIMELLSHLLALIDFPEEDVEELEREEIKRRAKDILNDIEYLIASSESGRIIREGLKTAIIGKPNVGKSSLLNALLKQNRAIVTDIPGTTRDVIEEYMNIKGIPIKLIDTAGIRHTDELVEKIGVEKSKEVLAEADLILFVLDASRDLTKEDYEIFDILSGKNIIFVLNKVDLPKKIDEEELKKLVGNGIIVEVSTVERTGLDKLESEIYNLVFKGKVSATEEEIITNARHREVLINAKKHMESVIEAIEKGYSEDLITIDVNGALNEIGKITGETATEDVINQIFERFCVGK"
"""

Have you been able to check the released checkpoint? (Another colleague deployed it and hit the same error during inference, so I don't think it's a machine/environment problem on our end.)


Alexzhuan commented on July 3, 2024

Thank you for your feedback; we truly appreciate it. We will review the released checkpoint soon.


zdk123 commented on July 3, 2024

As another data point, I'm seeing NaNs in the hidden states even before trying to generate any sequences. I'm not sure this will reproduce exactly, but I'm running on a p3.2xlarge instance:

Run first:

python weight_diff.py recover \
    --path_raw "decapoda-research/llama-7b-hf" \
    --path_diff "zjunlp/llama-molinst-protein-7b" \
    --path_tuned /home/ubuntu/zjunlp/llama-molinst-protein-7b_recovered

Then load the model and tokenizer (adapted from generate.py):

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model_str = "decapoda-research/llama-7b-hf"
model_str = "zjunlp/llama-molinst-protein-7b"
recovered_model_path = "/home/ubuntu/zjunlp/llama-molinst-protein-7b_recovered"

model = LlamaForCausalLM.from_pretrained(
    recovered_model_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    # device_map="auto",
    device_map={"": 0},
    cache_dir=cache_dir,  # cache_dir defined elsewhere
)

tokenizer = LlamaTokenizer.from_pretrained(base_model_str, 
             bos_token='<s>', 
             eos_token='</s>',
             add_bos_token=True,
             add_eos_token=False)

# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

model.cuda()

Try a prompt:

# get a uniprot sequence
import requests

url = "https://rest.uniprot.org/uniprotkb/P05067.fasta"
f = requests.get(url)
seq = "".join(f.text.split("\n")[1:])  # join the sequence lines, dropping the FASTA header

prompter = Prompter()  # Prompter is the prompt-template helper from the repo's generate.py
prompt = prompter.generate_prompt("Is the input is a protein sequence?", seq)
encoder_text_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

with torch.no_grad():
    output = model(encoder_text_ids, output_hidden_states=True)

# check the last hidden layer
output["hidden_states"][-1][0, -1]

which outputs:

tensor([nan, nan, nan,  ..., nan, nan, nan], device='cuda:0',
       dtype=torch.float16)

Tweaking the prompts and the sequence length, I sometimes do get finite numbers back from the hidden states.
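A plausible mechanism (my speculation, not something confirmed for this checkpoint): float16 saturates around 65504, so an overflowing activation becomes inf, and subsequent operations on inf produce NaN, which then propagates through every later layer. The arithmetic side of this is easy to demonstrate with plain floats:

```python
import math

# float16's largest finite value is 65504; anything bigger overflows to inf.
x = float("inf")

# Once an activation is inf, common follow-up operations yield NaN...
print(x - x)    # nan (e.g. subtracting a mean that also overflowed)
print(x / x)    # nan (e.g. a normalization step)
print(0.0 * x)  # nan

# ...and NaN propagates through everything downstream.
print(math.isnan((x - x) + 1.0))  # True
```

If that's what's happening here, loading in float32 (or checking intermediate activations for overflow) might help narrow down which layer first produces the inf.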

