import re
import os
import json

import numpy as np
import requests
import fitz  # PyMuPDF: used to read text out of the PDF
import gradio as gr
import openai
import tensorflow_hub as hub
import tensorflow_text  # noqa: F401 -- registers the ops required by the multilingual USE model
from sklearn.neighbors import NearestNeighbors

# Azure OpenAI configuration. The key is read from an environment variable
# (AZURE_OPENAI_KEY here is an assumed name) rather than hard-coded, so no
# credential ships with the source.
openai.api_key = os.getenv('AZURE_OPENAI_KEY')
openai.api_base = 'https://demopro-oai-we2.openai.azure.com/'
openai.api_type = 'azure'
openai.api_version = '2022-12-01'

def flatten(_2d_list):
    # Flatten a list that may mix scalars and one-level sub-lists.
    flat_list = []
    for element in _2d_list:
        if isinstance(element, list):
            flat_list.extend(element)
        else:
            flat_list.append(element)
    return flat_list

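# Example (illustrative): flatten([[1, 2], 3, [4]]) -> [1, 2, 3, 4]
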
def preprocess(text):
    # Normalise whitespace: newlines become spaces, runs collapse to one space.
    text = text.replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)
    return text

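# Example (illustrative):
#   preprocess('line one\n  line   two') -> 'line one line two'
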
def pdf_to_text(path, start_page=1, end_page=None):
    # Extract and preprocess one string per page over the requested page range
    # (1-indexed, inclusive).
    doc = fitz.open(path)
    total_pages = doc.page_count

    if end_page is None:
        end_page = total_pages

    text_list = []
    for i in range(start_page - 1, end_page):
        text = doc.load_page(i).get_text("text")
        text = preprocess(text)
        text_list.append(text)

    doc.close()
    return text_list

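# Usage sketch (illustrative; 'sample.pdf' is a hypothetical file):
#   pages = pdf_to_text('sample.pdf', start_page=1, end_page=3)
#   assert len(pages) == 3  # one preprocessed string per page
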
def text_to_chunks(texts, word_length=150, start_page=1):
    # Split each page into chunks of `word_length` words. A short tail chunk is
    # carried over to the start of the next page so chunks stay full-length,
    # and every chunk is tagged with the page number it came from.
    text_toks = [t.split(' ') for t in texts]
    chunks = []

    for idx, words in enumerate(text_toks):
        for i in range(0, len(words), word_length):
            chunk = words[i : i + word_length]
            if (
                (i + word_length) > len(words)
                and (len(chunk) < word_length)
                and (len(text_toks) != (idx + 1))
            ):
                text_toks[idx + 1] = chunk + text_toks[idx + 1]
                continue
            chunk = ' '.join(chunk).strip()
            chunk = f'[Page no. {idx + start_page}]' + ' ' + '"' + chunk + '"'
            chunks.append(chunk)
    return chunks

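# Example (illustrative): with word_length=5 and start_page=20, the first
# chunk of the first page might look like
#   '[Page no. 20] "least squares fits a linear"'
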
# Build the corpus: the book's front matter is skipped by starting at PDF page
# 20, and the same offset is passed on so the [Page no. X] tags match the PDF.
history = pdf_to_text('The Elements of Statistical Learning.pdf', start_page=20)
history = text_to_chunks(history, start_page=20)

def encoder(text):
    # Alternative embedding path via Azure OpenAI (not used below; the
    # SemanticSearch class embeds with the Universal Sentence Encoder instead).
    embed = openai.Embedding.create(input=text, engine="text-embedding-ada-002")
    return embed.get('data')[0].get('embedding')

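# Example (illustrative): encoder('hello world') returns the embedding as a
# list of floats (1536 dimensions for text-embedding-ada-002).
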
class SemanticSearch:
    # Embeds text chunks with the multilingual Universal Sentence Encoder and
    # answers queries via a k-nearest-neighbours lookup over those embeddings.

    def __init__(self):
        self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder-multilingual/3')
        self.fitted = False

    def get_text_embedding(self, texts, batch=1000):
        # Embed in batches to keep memory bounded, then stack into one matrix.
        embeddings = []
        for i in range(0, len(texts), batch):
            text_batch = texts[i : (i + batch)]
            emb_batch = self.use(text_batch)
            embeddings.append(emb_batch)
        embeddings = np.vstack(embeddings)
        return embeddings

    def fit(self, data, batch=1000, n_neighbors=5):
        self.data = data
        self.embeddings = self.get_text_embedding(data, batch=batch)
        n_neighbors = min(n_neighbors, len(self.embeddings))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
        self.nn.fit(self.embeddings)
        self.fitted = True

    def __call__(self, text, return_data=True):
        # Embed the query and return its nearest chunks (or their indices).
        inp_emb = self.use([text])
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]

        if return_data:
            return [self.data[i] for i in neighbors]
        else:
            return neighbors

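# Usage sketch (illustrative):
#   search = SemanticSearch()
#   search.fit(['chunk one', 'chunk two', 'chunk three'], n_neighbors=2)
#   search('a query about chunk two')  # -> the two most similar chunks
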
def generate_text(prompt, engine="text-davinci-003"):
    # With api_type='azure', `engine` must name an Azure deployment; here the
    # deployment is assumed to share its name with the underlying model.
    completions = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        max_tokens=512,
        n=1,
        stop=None,
        temperature=0.7,
    )
    message = completions.choices[0].text
    return message

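# Example (illustrative):
#   generate_text('Q: What is 2 + 2?\nA:')  # -> a short completion such as ' 4'
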
def generate_answer(question):
    # Retrieve the most relevant chunks and prepend them to the prompt.
    topn_chunks = recommender(question)
    prompt = 'search results:\n\n'
    for c in topn_chunks:
        prompt += c + '\n\n'

    prompt += (
        "Instructions: If the search results contain no relevant information, "
        "reply only 'No relevant information was found in the document.' "
        "If relevant information is found, answer as precisely and concisely "
        "as possible, and cite each reference at the end of the sentence with "
        "its [Page no. X] marker (every search result starts with one). "
        "If you are not sure an answer is correct, give only the sources of "
        "the similar passages instead of a possibly wrong answer."
    )

    prompt += f"\n\nQuery: {question}\nAnswer:"
    answer = generate_text(prompt, "text-davinci-003")
    return answer

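# Example (illustrative):
#   generate_answer('What does the book say about ridge regression?')
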
# Build the semantic index over the book chunks once at start-up.
recommender = SemanticSearch()
recommender.fit(history)

def ask_api(question):
    if question.strip() == '':
        return '[ERROR]: no question was entered'

    return generate_answer(question)

title = 'Chat With Statistical Learning'
description = """This bot answers questions about The Elements of Statistical Learning: Data Mining, Inference, and Prediction
by Trevor Hastie et al. (the textbook used in our course). If a question is unrelated to the book's content,
it replies 'No relevant information was found in the document.'
"""

with gr.Blocks() as demo:
    gr.Markdown(f'<center><h1>{title}</h1></center>')
    gr.Markdown(description)

    with gr.Row():
        with gr.Group():
            question = gr.Textbox(label='Enter your question')
            btn = gr.Button(value='Submit')
            btn.style(full_width=True)

        with gr.Group():
            answer = gr.Textbox(label='Answer:')

    btn.click(
        ask_api,
        inputs=[question],
        outputs=[answer],
    )

demo.launch()