---
tags:
- text-generation-inference
- transformers
- trl
- sft
license: apache-2.0
language:
- en
---

# INFERENCE
|
```python
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

# Run on the first GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
finetuned_model = AutoModelForCausalLM.from_pretrained("Mr-Vicky-01/sql-assistant")
finetuned_model.to(device)
tokenizer = AutoTokenizer.from_pretrained("Mr-Vicky-01/sql-assistant")

# ChatML-formatted prompt: exactly one system turn carrying the table schema,
# one user turn, and an open assistant turn for the model to complete.
# (A duplicated "<|im_start|>system" header here would hand the model a
# malformed turn structure.)
prompt = """<|im_start|>system
You are a helpful SQL assistant named Securitron. Your working table is 'scans' with the following schema:

CREATE TABLE scans (
id SERIAL PRIMARY KEY,
findings_sca INT,
findings_secrets INT,
findings_compliance INT,
findings_iac INT,
findings_malware INT,
findings_api INT,
findings_pii INT,
findings_container INT,
timestamp TIMESTAMP,
total_findings INT,
fp_vulnerabilities INT,
tp_vulnerabilities INT,
unverified_vulnerabilities INT,
findings_sast INT,
group_id INT,
project_link TEXT,
project TEXT,
repository TEXT,
scan_link TEXT,
scan_id TEXT,
branch TEXT,
commit TEXT,
tags TEXT,
initiator TEXT
);<|im_end|>
<|im_start|>user
Show me yesterday's scan with the fewest API findings.<|im_end|>
<|im_start|>assistant
"""

# Each turn in the prompt ends with <|im_end|>, so generation should stop on it
# too. Look the id up from the tokenizer instead of hard-coding a vocabulary id.
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

s = time.perf_counter()  # monotonic clock, suited to measuring elapsed time

inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
text_streamer = TextStreamer(tokenizer, skip_prompt=True)  # print only new tokens

# Pass the attention mask explicitly. Pad and EOS share the same id below, so
# generate() cannot reliably infer the mask from input_ids alone.
# Increase max_new_tokens if the generated SQL gets cut off.
response = finetuned_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    streamer=text_streamer,
    max_new_tokens=512,
    use_cache=True,
    pad_token_id=im_end_id,
    eos_token_id=im_end_id,
    num_return_sequences=1,
)
e = time.perf_counter()
print(f'time taken:{e-s}')
```