dumant/app/main.py
2025-07-13 09:46:56 -07:00

244 lines
7.2 KiB
Python

import asyncio
import itertools
import os
import random

import asyncpg
import httpx
import numpy as np
from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer
# Postgres connection settings, read from the environment once at import time.
# Only the port has a default ("5432", converted to int); any other variable
# that is unset is passed through as None.
DB_CONFIG = dict(
    host=os.getenv("DB_HOST"),
    port=int(os.getenv("DB_PORT", "5432")),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    database=os.getenv("DB_NAME"),
)

# Base URL of the text-generation service; requests go to f"{LLM_API_URL}/generate".
LLM_API_URL = os.getenv("LLM_API_URL", "http://llm:80")

app = FastAPI()

# Sentence-embedding model used to vectorise triggers for nearest-neighbour
# lookup; loaded once when the module is imported.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
@app.on_event("startup")
async def startup():
    """Create the shared asyncpg connection pool and publish it on ``app.state.db``."""
    pool = await asyncpg.create_pool(**DB_CONFIG)
    app.state.db = pool
@app.on_event("shutdown")
async def shutdown():
    """Close the connection pool stored on ``app.state.db``."""
    pool = app.state.db
    await pool.close()
def _parse_trigger_output(output: str) -> list[str]:
    """Turn the LLM's comma-separated reply into a de-duplicated trigger list.

    Anything from the first ChatML marker onward is discarded in case the model
    runs past the end of its turn. Surrounding brackets/quotes are stripped from
    each line and each item, empty items are dropped, and duplicates are removed
    while keeping first-seen order.
    """
    for marker in ("<|im_end|>", "<|im_start|>"):
        output = output.split(marker, 1)[0]
    triggers = []
    for line in output.splitlines():
        line = line.strip(" []\"'")
        for item in line.split(","):
            # Strip *before* testing emptiness so items such as '""' are dropped
            # instead of surviving as an empty-string trigger.
            trigger = item.strip().strip("\"'").strip()
            if trigger:
                triggers.append(trigger)
    # dict.fromkeys keeps insertion order; set() would make the trigger order
    # (and thus the downstream responses/summary order) vary between runs.
    return list(dict.fromkeys(triggers))


async def extract_triggers(user_query: str) -> list[str]:
    """Ask the LLM to extract factoid-like terms from a user query.

    Sends a few-shot ChatML prompt to ``{LLM_API_URL}/generate`` and parses the
    comma-separated completion.

    Returns:
        Unique trigger strings in the order the model produced them, or ``[]``
        when the LLM cannot be reached or answers with a non-200 status.
    """
    prompt = f"""<|im_start|>system
You are a factoid trigger extractor. Extract a list of keywords or short phrases that might match entries in an infobot-style knowledge base.
Respond only with a comma-separated list of triggers. Do not answer the question or correct spelling or grammar.
<|im_end|>
<|im_start|>user
Who is steve?
<|im_end|>
<|im_start|>assistant
steve
<|im_end|>
<|im_start|>user
Tell me about HTTP status codes.
<|im_end|>
<|im_start|>assistant
HTTP status codes
<|im_end|>
<|im_start|>user
Who are you and your Gary
<|im_end|>
<|im_start|>assistant
you, Gary
<|im_end|>
<|im_start|>user
{user_query}
<|im_end|>
<|im_start|>assistant
"""
    payload = {
        "inputs": prompt,
        "parameters": {
            "temperature": 0.3,
            "max_new_tokens": 64,
        },
    }
    print("PAYLOAD:", payload)
    try:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{LLM_API_URL}/generate",
                json=payload,
                timeout=httpx.Timeout(120.0),
            )
    except httpx.HTTPError as exc:
        # Treat transport failures (timeouts, refused connections) like a
        # non-200 answer: no triggers, so the caller can degrade gracefully.
        print("LLM ERROR:", exc)
        return []
    if r.status_code != 200:
        return []
    output = r.json().get("generated_text", "")
    print("OUTPUT:", output)
    return _parse_trigger_output(output)
def _split_options(response: str) -> list[str]:
    """Split a pipe-separated factoid response into its candidate options.

    Options are whitespace-stripped and empty ones are dropped. If nothing
    survives (e.g. the response is empty or only pipes), ``[response]`` is
    returned so callers always have at least one candidate.
    """
    options = [part.strip() for part in response.split("|") if part.strip()]
    return options or [response]


def _render_response(row, choice: str) -> str:
    """Format one chosen factoid option for display.

    An option starting with ``<reply>`` is returned without that tag (stripped);
    any other option is rendered as "<trigger> <copula> <option>".
    ``row`` only needs to support ``row["trigger"]`` and ``row["copula"]``.
    """
    if choice.startswith("<reply>"):
        return choice[len("<reply>"):].strip()
    return f"{row['trigger']} {row['copula']} {choice}"


async def _lookup_trigger(conn, trigger: str):
    """Fetch one factoid row for ``trigger``.

    An exact ``trigger`` match wins; otherwise the row whose ``embedding`` is
    closest (pgvector ``<->``) to the trigger's sentence embedding is returned.
    Returns ``None`` only when neither query yields a row.
    """
    row = await conn.fetchrow(
        """
        SELECT trigger, response, copula
        FROM factoids
        WHERE trigger = $1
        LIMIT 1
        """,
        trigger,
    )
    if row is not None:
        return row
    # SentenceTransformer.encode is synchronous and CPU-bound; run it in a
    # worker thread so it does not stall the event loop for other requests.
    embedding = await asyncio.to_thread(model.encode, trigger)
    # Text form "[x1, x2, ...]" is cast to a vector inside the SQL below.
    embedding_str = f"[{', '.join(map(str, embedding))}]"
    # NOTE(review): there is no distance cutoff, so a non-empty table always
    # produces a "match", however unrelated. Consider filtering on distance.
    return await conn.fetchrow(
        """
        SELECT trigger, response, copula
        FROM factoids
        ORDER BY embedding <-> $1::vector
        LIMIT 1
        """,
        embedding_str,
    )


async def _summarize(query: str, responses: list[str]) -> str:
    """Ask the LLM to condense ``responses`` into a one-line reply for ``query``.

    Falls back to ``"\\n".join(responses)`` when the LLM is unreachable, answers
    with a non-200 status or a non-JSON body, or produces no usable text.
    """
    fallback = "\n".join(responses)
    responses_str = "\nValue: ".join(responses)
    # The prompt ends by opening the assistant turn (as extract_triggers does);
    # otherwise the model has to emit the turn header itself and it can leak
    # into the reply.
    summary_prompt = f"""<|im_start|>system
You are a summarizer for a fact-based chatbot. Your task is to condense database entries into short, accurate one-line summaries. Do not speculate, define, or add new facts. Do not correct spelling or phrasing from the facts or triggers. Do not mix context from prior triggers.
<|im_end|>
<|im_start|>user
Summarize the following database entry.
Trigger: Paris
Value: Paris is the capital of France
Value: Paris is located in the north-central part of the country.
<|im_end|>
<|im_start|>assistant
Paris is the capital of France and located in the north-central part of the country.
<|im_end|>
<|im_start|>user
Summarize the following database entry.
Trigger: squinky, spacehobo
Value: spacehobo is a Citizen.
Value: squinky is kinky
<|im_end|>
<|im_start|>assistant
spacehobo is a Citizen and squinky is kinky
<|im_end|>
<|im_start|>user
Summarize the following database entry.
Trigger: sky
Value: sky is blue
Value: the sky is the big thing outside when you look up
Value: Look!
<|im_end|>
<|im_start|>assistant
Look up at that big blue thing outside!
<|im_end|>
<|im_start|>user
Summarize the following database entry.
Trigger: {query}
Value: {responses_str}
<|im_end|>
<|im_start|>assistant
"""
    print("SUMMARY PAYLOAD:", summary_prompt)
    try:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{LLM_API_URL}/generate",
                json={"inputs": summary_prompt, "parameters": {"temperature": 0.8, "max_new_tokens": 200}},
                timeout=httpx.Timeout(120.0),
            )
        data = r.json() if r.status_code == 200 else None
    except (httpx.HTTPError, ValueError) as exc:
        print("SUMMARY ERROR:", exc)
        return fallback
    text = data.get("generated_text") if isinstance(data, dict) else None
    if not isinstance(text, str):
        return fallback
    # Drop anything past the end of the assistant turn, if the model overran it.
    text = text.split("<|im_end|>", 1)[0].strip()
    return text or fallback


@app.post("/ask")
async def ask(request: Request):
    """Answer a free-text query from the ``factoids`` table.

    Expects a JSON object body with a ``"query"`` string.

    1. If some factoid's trigger equals the whole query, one option is chosen at
       random from the pipe-separated responses of *all* rows with that trigger.
    2. Otherwise the LLM extracts candidate triggers from the query, and each is
       resolved by exact match, falling back to the nearest embedding.
    3. The matched responses are condensed into one reply by the LLM.

    Returns:
        ``{"error": ...}`` for a missing or invalid query, ``{"response": ...}``
        when nothing matches, otherwise ``{"reply", "matches", "triggers"}``.
    """
    try:
        body = await request.json()
    except ValueError:  # malformed JSON (json.JSONDecodeError is a ValueError)
        body = None
    raw_query = body.get("query", "") if isinstance(body, dict) else ""
    query = raw_query.strip() if isinstance(raw_query, str) else ""
    if not query:
        return {"error": "Missing query"}

    # Keep this connection short-lived: releasing it before the (slow) LLM call
    # stops concurrent requests from exhausting the pool.
    async with app.state.db.acquire() as conn:
        exact_rows = await conn.fetch(
            """
            SELECT trigger, response, copula
            FROM factoids
            WHERE trigger = $1
            """,
            query,
        )

    if exact_rows:
        triggers = [query]
        # Pool every option of every row sharing this trigger, then pick one.
        candidates = [
            (row, option)
            for row in exact_rows
            for option in _split_options(row["response"])
        ]
        row, choice = random.choice(candidates)
        print("CHOICE:", dict(row) | {"response": choice})
        responses = [_render_response(row, choice)]
    else:
        triggers = await extract_triggers(query)
        if not triggers:
            return {"response": "I don't know that one."}
        print("triggers: ", triggers)
        responses = []
        async with app.state.db.acquire() as conn:
            for trigger in triggers:
                row = await _lookup_trigger(conn, trigger)
                if row is None:
                    continue
                print("ROW:", row)
                choice = random.choice(_split_options(row["response"]))
                responses.append(_render_response(row, choice))

    if not responses:
        return {"response": "I don't know any of those."}

    final_response = await _summarize(query, responses)
    print("FINAL RESPONSE:", final_response)
    return {"reply": final_response, "matches": responses, "triggers": triggers}