import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hub access token read from the environment. The lookup uses .get() so that a
# missing variable yields None instead of raising KeyError at import time
# (e.g. in a copy of this app where the secret was never configured).
# NOTE(review): None is expected to fall back to anonymous Hub access, which
# only works while the checkpoint is publicly readable — confirm.
token = os.environ.get("modeldatabaseHUB_API_TOKEN")

# Hub repository id of the checkpoint loaded for both the model and tokenizer.
model_id = 'Deci/DeciLM-6b-instruct'

# Prompt scaffold wrapped around every user instruction. `{instruction}` is the
# only placeholder and is filled in with str.format by get_prompt_with_template.
# The trailing "### Response:" marker is the same text that
# extract_response_content searches for when isolating the model's answer.
SYSTEM_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:

{instruction}

### Response:
"""

# Markdown/HTML header rendered at the top of the page through gr.Markdown.
# A GPU notice is appended to this string further down when CUDA is unavailable.
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://modeldatabase.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B parameter instruction-tuned language model and released under the Llama license. It's an instruction-tuned model, not a chat-tuned model;  you should prompt the model with an instruction that describes a task, and the model will respond appropriately to complete the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model <a href="https://deci.ai/blog/decilm-15-times-faster-than-llama2-nas-generated-llm-with-variable-gqa/" style="color: #3264ff;">DeciLM-6B.</a></span></p>
"""

# The model weights are loaded only when a CUDA device is present. On hosts
# without one, `model` is left as None and a notice is appended to the page
# description instead.
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_auth_token=token,
        trust_remote_code=True,
        device_map='auto',
        torch_dtype=torch.float16,
    )
else:
    model = None
    DESCRIPTION += 'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'

# The tokenizer is loaded regardless of device; its EOS token doubles as the
# padding token.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token

# Function to construct the prompt using the new system prompt template
def get_prompt_with_template(message: str) -> str:
    """Return the complete model prompt for *message*.

    The text is substituted for the ``{instruction}`` placeholder of
    ``SYSTEM_PROMPT_TEMPLATE``; nothing else in the template changes.
    """
    filled_prompt = SYSTEM_PROMPT_TEMPLATE.format(instruction=message)
    return filled_prompt

# Function to generate the model's response
def generate_model_response(message: str) -> str:
    """Generate a completion for *message* and return the decoded text.

    The instruction is wrapped in the prompt template, tokenized, and passed
    to ``model.generate`` with fixed generation settings. The first returned
    sequence is decoded with special tokens skipped. The decoded text is
    expected to still contain the prompt; ``extract_response_content`` is the
    helper that isolates the part after the response marker.

    Raises:
        gr.Error: If no model was loaded (``model`` is ``None``, which is the
            case when no CUDA device was available at startup).
    """
    if model is None:
        # Without this guard the generate call below fails with an opaque
        # "'NoneType' object has no attribute 'generate'".
        raise gr.Error(
            'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'
        )

    prompt = get_prompt_with_template(message)
    # Put the input tensors on the same device the model was placed on.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    # Generation settings are hard-coded here; there are no user-defined
    # options.
    output = model.generate(
        **inputs,
        max_new_tokens=3000,
        num_beams=2,
        no_repeat_ngram_size=4,
        early_stopping=True,
        do_sample=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Function to extract the content after "### Response:"
def extract_response_content(full_response: str, ) -> str:
    """Return the text that follows the first ``### Response:`` marker.

    The extracted text is stripped of surrounding whitespace. When the marker
    does not occur at all, *full_response* is returned unchanged (and
    unstripped).
    """
    _, marker, answer = full_response.partition("### Response:")
    if not marker:
        # partition() yields an empty separator when the marker is absent.
        return full_response
    return answer.strip()

# Main entry point used by the UI: generate a completion and return only the text after the response marker
def get_response_with_template(message: str) -> str:
    """Run the model on *message* and return only the response portion.

    This chains ``generate_model_response`` (prompting and decoding) with
    ``extract_response_content`` (dropping everything up to the response
    marker).
    """
    return extract_response_content(generate_model_response(message))

# Page layout: description header, output box, instruction input with
# Submit/Clear buttons, a list of example instructions, and a banner image.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        # Despite its name, `chatbot` is a plain Textbox holding the latest
        # model output.
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox"
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button"
            )

            # Clear button: empties both the instruction box and the output box
            clear_button = gr.Button(
                '🗑️ Clear',
                variant='secondary',
            )

    clear_button.click(
        fn=lambda: ('',''),
        outputs=[textbox, chatbot],
        queue=False,
        api_name=False,
    )

    submit_button.click(
        fn=get_response_with_template,
        inputs=textbox,
        outputs= chatbot,
        queue=False,
        api_name=False,
    )

    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        # Caching runs `fn` on every example up front, which needs a loaded
        # model. When `model` is None (no GPU) caching is disabled so the page
        # can still start and show the GPU notice.
        cache_examples=model is not None,
        elem_id="examples"
    )


    gr.HTML(label="Keep in touch", value="<img src='https://modeldatabase.co/spaces/Deci/DeciLM-6b-instruct/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")

demo.launch()