Commit ce06570 (parent e28d533). Showing 3 changed files with 159 additions and 27 deletions.
@@ -0,0 +1,110 @@
from fastapi import APIRouter
from pydantic import BaseModel
import re

from unsloth import FastLanguageModel
import torch

# Router instance for the chat endpoints below
router = APIRouter()

max_seq_length = 2048  # Choose any! Unsloth supports RoPE scaling internally.
dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

alpaca_prompt = """Below is an instruction that describes a task, along with an input that provides additional context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""


class Question(BaseModel):
    query: str

@router.post("/generate_answer")
def generate_answer(value: Question):
    try:
        # Note: both models are (re)loaded on every request in this version.
        # Stage 1: a 4-bit Llama-3 8B fine-tune drafts the child-friendly answer.
        llama_model, llama_tokenizer = FastLanguageModel.from_pretrained(
            model_name="Antonio27/llama3-8b-4-bit-for-sugar",
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )

        # Stage 2: Gemma-2 9B (4-bit) later rewrites the draft more simply.
        gemma_model, gemma_tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/gemma-2-9b-it-bnb-4bit",
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )

        FastLanguageModel.for_inference(llama_model)
        llama_tokenizer.pad_token = llama_tokenizer.eos_token
        llama_tokenizer.add_eos_token = True

        inputs = llama_tokenizer(
            [
                alpaca_prompt.format(
                    f'''
                    Your task is to answer children's questions using simple language.
                    Explain any difficult words in a way a 3-year-old can understand.
                    Keep responses under 60 words.
                    \n\nQuestion: {value.query}
                    ''',  # instruction
                    "",  # input
                    "",  # output - leave this blank for generation!
                )
            ], return_tensors="pt").to("cuda")

        outputs = llama_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
        decoded_outputs = llama_tokenizer.batch_decode(outputs)

        response_text = decoded_outputs[0]

        # Keep only the text after "### Response:" in the Alpaca template.
        match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
        if match:
            initial_response = match.group(1).strip()
        else:
            initial_response = ""

        FastLanguageModel.for_inference(gemma_model)
        gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
        gemma_tokenizer.add_eos_token = True

        inputs = gemma_tokenizer(
            [
                alpaca_prompt.format(
                    f'''
                    Modify the given content for a 5-year-old.
                    Use simple words and phrases.
                    Remove any repetitive information.
                    Keep responses under 50 words.
                    \n\nGiven Content: {initial_response}
                    ''',  # instruction
                    "",  # input
                    "",  # output - leave this blank for generation!
                )
            ], return_tensors="pt").to("cuda")

        outputs = gemma_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
        decoded_outputs = gemma_tokenizer.batch_decode(outputs)

        response_text = decoded_outputs[0]

        match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
        if match:
            adjusted_response = match.group(1).strip()
        else:
            adjusted_response = ""

        return {
            'success': True,
            'response': {
                "result": adjusted_response
            }
        }

    except Exception as e:
        return {'success': False, 'response': str(e)}
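The application setup in the next file mounts this router under /sugar-ai/chat, so the full path becomes /sugar-ai/chat/generate_answer. A minimal client sketch, assuming the server is running locally on uvicorn's default port 8000:

import requests

# Assumes the app is being served at localhost:8000.
resp = requests.post(
    "http://localhost:8000/sugar-ai/chat/generate_answer",
    json={"query": "Why is the sky blue?"},
)
payload = resp.json()
if payload["success"]:
    print(payload["response"]["result"])
else:
    print("Error:", payload["response"])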
@@ -1,27 +1,22 @@
from transformers import GPT2Tokenizer, GPT2LMHeadModel


# We should rename this
class AI_Test:
    def __init__(self):
        pass

    def generate_bot_response(self, question):
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        model = GPT2LMHeadModel.from_pretrained("distilgpt2")

        prompt = '''
        Your task is to answer children's questions using simple language.
        Explain any difficult words in a way a 3-year-old can understand.
        Keep responses under 60 words.
        \n\nQuestion:
        '''

        input_text = prompt + question

        inputs = tokenizer.encode(input_text, return_tensors='pt')
        outputs = model.generate(inputs, max_length=150, num_return_sequences=1)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return answer

import os
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from chat.router import router as chat_router
# from piggy.router import router as piggy_router

app = FastAPI(
    docs_url="/sugar-ai/docs",
)

app.include_router(chat_router, prefix="/sugar-ai/chat")
# app.include_router(piggy_router, prefix="/sugar-ai/piggy")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
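This hunk replaces the old AI_Test stub with the FastAPI application setup; the stub moves into its own module, shown in the next file. With the app defined, the server can be launched via the uvicorn import above. A minimal sketch, assuming this module is named main (the actual filename is not shown in the diff):

# Hypothetical entrypoint; "main" is an assumed module name.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)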
@@ -0,0 +1,27 @@
from transformers import GPT2Tokenizer, GPT2LMHeadModel


# We should rename this
class AI_Test:
    def __init__(self):
        pass

    def generate_bot_response(self, question):
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        model = GPT2LMHeadModel.from_pretrained("distilgpt2")

        prompt = '''
        Your task is to answer children's questions using simple language.
        Explain any difficult words in a way a 3-year-old can understand.
        Keep responses under 60 words.
        \n\nQuestion:
        '''

        input_text = prompt + question

        inputs = tokenizer.encode(input_text, return_tensors='pt')
        outputs = model.generate(inputs, max_length=150, num_return_sequences=1)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return answer
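The relocated stub can be exercised directly; a minimal sketch (distilgpt2 is downloaded from the Hugging Face Hub on first use):

# Quick manual test of the GPT-2 stub
bot = AI_Test()
print(bot.generate_bot_response("Why do birds sing?"))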