import os
import uuid

import streamlit as st
import torch
from datasets import load_dataset
from peft import PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
|
|
|
|
USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
MAX_HISTORY_LENGTH = 5
|
|
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id
|
|
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
|
|
| if "chats" not in st.session_state: |
| st.session_state.chats = [ |
| { |
| 'id': 0, |
| 'question': '', |
| 'answer': '' |
| } |
| ] |
|
|
| if "questions" not in st.session_state: |
| st.session_state.questions = [] |
|
|
| if "answers" not in st.session_state: |
| st.session_state.answers = [] |
|
|
| if "input" not in st.session_state: |
| st.session_state.input = "" |
|
|
| st.markdown(""" |
| <style> |
| .block-container { |
| padding-top: 32px; |
| padding-bottom: 32px; |
| padding-left: 0; |
| padding-right: 0; |
| } |
| .element-container img { |
| background-color: #000000; |
| } |
| |
| .main-header { |
| font-size: 24px; |
| } |
| </style> |
| """, unsafe_allow_html=True) |
|
|
def write_top_bar():
    col1, col2, col3 = st.columns([1, 10, 2])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        header = "Cogwise Intelligent Assistant"
        st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
    with col3:
        clear = st.button("Clear Chat")
    return clear
|
|
clear = write_top_bar()

if clear:
    st.session_state.questions = []
    st.session_state.answers = []
    st.session_state.input = ""
    st.session_state["chat_history"] = []
|
|
def handle_input():
    # Callback for the text input at the bottom of the page. `generate_response`
    # is defined further down; by the time this callback fires, the whole
    # script has already been executed, so the name is available.
    question = st.session_state.input
    question_with_id = {
        'question': question,
        'id': len(st.session_state.questions)
    }
    st.session_state.questions.append(question_with_id)

    # Cap the history at MAX_HISTORY_LENGTH turns, dropping the oldest first.
    chat_history = st.session_state["chat_history"]
    while len(chat_history) >= MAX_HISTORY_LENGTH:
        chat_history.pop(0)

    answer = generate_response(question)
    chat_history.append((question, answer))

    st.session_state.answers.append({
        'answer': answer,
        'id': len(st.session_state.questions)
    })
    st.session_state.input = ""
|
|
|
|
# Make only the first GPU visible to PyTorch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
# Dataset the adapter was fine-tuned on; loaded here for reference only,
# since the inference path below does not touch it.
dataset_name = "nisaar/Lawyer_GPT_India"
dataset = load_dataset(dataset_name, split="train")
# 4-bit NF4 quantization so the Falcon-7B base model fits on a single GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
|
|
# Load the quantized base model, then apply the LoRA adapter weights on top.
peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

model = PeftModel.from_pretrained(model, peft_model_id)
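# Streamlit reruns this script top to bottom on every interaction, so the model
# load above repeats on each rerun. A minimal sketch of one way to load it once
# per process (assuming Streamlit >= 1.18, where st.cache_resource is available)
# would be to wrap the loading code in a cached function:
#
#     @st.cache_resource
#     def load_model_and_tokenizer():
#         ...  # the loading code above
#         return model, tokenizer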
|
|
| """## Inference |
| |
| You can then directly use the trained model or the model that you have loaded from the 🤗 Hub for inference as you would do it usually in `transformers`. |
| """ |
|
|
generation_config = model.generation_config
generation_config.max_new_tokens = 200
generation_config.temperature = 1
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id

DEVICE = "cuda:0"
|
|
|
|
def generate_response(question: str) -> str:
    # Wrap the question in the <human>/<assistant> format the adapter was tuned on.
    prompt = f"""
<human>: {question}
<assistant>:
""".strip()
    encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Return only the text after the assistant tag; fall back to the full
    # decoded output if the tag is missing.
    assistant_start = '<assistant>:'
    response_start = response.find(assistant_start)
    if response_start == -1:
        return response.strip()
    return response[response_start + len(assistant_start):].strip()
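# Example usage (hypothetical question, for illustration only):
#
#     generate_response("What does Article 21 of the Indian Constitution guarantee?")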
|
|
|
|
def write_user_message(md):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.warning(md['question'])
|
|
def render_answer(answer):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        st.info(answer)
|
|
def write_chat_message(md, q):
    chat = st.container()
    with chat:
        render_answer(md['answer'])
|
|
with st.container():
    for (q, a) in zip(st.session_state.questions, st.session_state.answers):
        write_user_message(q)
        write_chat_message(a, q)
|
|
st.markdown('---')
st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
|
|