diff --git a/llm.py b/llm.py
new file mode 100644
index 0000000..c95a0a7
--- /dev/null
+++ b/llm.py
@@ -0,0 +1,53 @@
+import os
+import logging
+
+from ollama import Client
+from openai import OpenAI
+
+logging.basicConfig(level=logging.INFO)
+
+
+class LLMClient:
+    def __init__(self):
+        try:
+            # Probe the Ollama server; if it is unreachable, fall back to OpenAI.
+            self.ollama_client = Client(host=os.getenv("OLLAMA_URL", "http://localhost:11434"))
+            self.ollama_client.chat(
+                model="gemma3:4b", messages=[{"role": "system", "content": "test"}]
+            )
+            self.PROVIDER = "ollama"
+            logging.info("Using Ollama as LLM backend")
+        except Exception:
+            self.openai_client = OpenAI()
+            self.PROVIDER = "openai"
+            logging.info("Using OpenAI as LLM backend")
+
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: str,
+    ) -> str:
+        if self.PROVIDER == "ollama":
+            response = self.ollama_client.chat(
+                model="gemma3:4b",
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": prompt},
+                ],
+            )
+            output = response["message"]["content"]
+        elif self.PROVIDER == "openai":
+            response = self.openai_client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": prompt},
+                ],
+            )
+            output = response.choices[0].message.content
+        return output
+
+
+if __name__ == "__main__":
+    client = LLMClient()
+    print(client.chat(prompt="hack", system_prompt="test"))
diff --git a/main.py b/main.py
index b116fcd..becc72b 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,7 @@ from request import PaperlessNGXService
 from chunker import Chunker
 from query import QueryGenerator
 from cleaner import pdf_to_image, summarize_pdf_image
+from llm import LLMClient
 
 from dotenv import load_dotenv
 
@@ -39,8 +40,7 @@ parser.add_argument("--index", help="index a file")
 
 ppngx = PaperlessNGXService()
 
-openai_client = OpenAI()
-
+llm_client = LLMClient()
 
 def index_using_pdf_llm():
     files = ppngx.get_data()
@@ -98,8 +98,9 @@ def chunk_text(texts: list[str], collection):
 
 
 def consult_oracle(input: str, collection):
     print(input)
     import time
+
     chunker = Chunker(collection)
 
     start_time = time.time()
@@ -132,27 +133,9 @@ def consult_oracle(input: str, collection):
     # Generate
     print("Starting LLM generation")
     llm_start = time.time()
-    if USE_OPENAI:
-        response = openai_client.chat.completions.create(
-            model="gpt-4o-mini",
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that understands veterinary terms.",
-                },
-                {
-                    "role": "user",
-                    "content": f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
-                },
-            ],
-        )
-        output= response.choices[0].message.content
-    else:
-        response = ollama_client.generate(
-            model="gemma3:4b",
-            prompt=f"You are a helpful assistant that understandings veterinary terms. Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}",
-        )
-        output = response["response"]
+    system_prompt = "You are a helpful assistant that understands veterinary terms."
+    prompt = f"Using the following data, help answer the user's query by providing as many details as possible. Using this data: {results}. Respond to this prompt: {input}"
+    output = llm_client.chat(prompt=prompt, system_prompt=system_prompt)
     llm_end = time.time()
     print(f"LLM generation took {llm_end - llm_start:.2f} seconds")
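
A minimal usage sketch for the new LLMClient, outside of consult_oracle (assuming either an Ollama server reachable at OLLAMA_URL with the gemma3:4b model pulled, or an OPENAI_API_KEY set for the fallback path; the user prompt below is only an illustrative placeholder):

    from llm import LLMClient

    client = LLMClient()  # uses Ollama if the probe call in __init__ succeeds, otherwise OpenAI
    answer = client.chat(
        prompt="What does an elevated ALT value in a canine blood panel suggest?",
        system_prompt="You are a helpful assistant that understands veterinary terms.",
    )
    print(answer)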