244 lines
7.2 KiB
Python
244 lines
7.2 KiB
Python
import os
|
|
import asyncpg
|
|
import numpy as np
|
|
from fastapi import FastAPI, Request
|
|
from sentence_transformers import SentenceTransformer
|
|
import httpx
|
|
import random
|
|
import itertools
|
|
|
|
# Connection settings for the asyncpg pool, read from the environment.
# DB_PORT falls back to 5432 when it is unset *or empty*: a blank value
# (e.g. an unset variable interpolated by a compose file) would otherwise
# crash the import with int("").
DB_CONFIG = {
    "host": os.getenv("DB_HOST"),
    "port": int(os.getenv("DB_PORT") or "5432"),
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "database": os.getenv("DB_NAME"),
}

# Base URL of the text-generation server; request code appends "/generate".
# An empty value falls back to the default like an unset one, and a trailing
# slash is removed so the joined URL never contains "//".
LLM_API_URL = (os.getenv("LLM_API_URL") or "http://llm:80").rstrip("/")
|
|
|
|
# Sentence-embedding model; its vectors are compared with the `embedding`
# column when a trigger has no exact match.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)

# The ASGI application that the lifecycle hooks and routes register on.
app = FastAPI()
|
|
|
|
@app.on_event("startup")
async def startup():
    """Create the asyncpg connection pool and keep it on ``app.state.db``.

    Registered as a startup hook; request handlers acquire their database
    connections from this pool.
    """
    pool = await asyncpg.create_pool(**DB_CONFIG)
    app.state.db = pool
|
|
|
|
@app.on_event("shutdown")
async def shutdown():
    """Close the connection pool stored on ``app.state.db`` at shutdown."""
    pool = app.state.db
    await pool.close()
|
|
|
|
def _parse_triggers(output: str) -> list[str]:
    """Turn the LLM's raw completion into an ordered, de-duplicated trigger list.

    Accepts a comma-separated reply or a loosely list-like one (``["a", "b"]``),
    possibly spread over several lines. Only text before the first
    ``<|im_end|>`` marker is used, since anything after it lies outside the
    model's answer turn. Whitespace, brackets and quotes around entries are
    removed, entries that end up empty are dropped, and duplicates are removed
    while keeping the order in which the model listed them.
    """
    answer = output.split("<|im_end|>", 1)[0]
    triggers: list[str] = []
    for line in answer.splitlines():
        line = line.strip(" []\"'")
        for part in line.split(","):
            # Filter on the fully stripped value so quote-only fragments
            # such as ' "' do not become empty-string triggers.
            trigger = part.strip().strip("\"'").strip()
            if trigger:
                triggers.append(trigger)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(triggers))


async def extract_triggers(user_query: str) -> list[str]:
    """Ask the LLM to extract factoid-like terms from a user query.

    Builds a ChatML few-shot prompt, POSTs it to ``{LLM_API_URL}/generate``
    and parses the ``generated_text`` of the reply with ``_parse_triggers``.

    Returns:
        The extracted triggers in the order the model produced them, without
        duplicates or blank entries. An empty list is returned when the server
        answers with a non-200 status, the request fails or times out, or the
        body is not a JSON object.
    """
    # NOTE(review): user_query is spliced into the ChatML prompt unescaped, so
    # a query containing "<|im_end|>" can change the conversation structure.
    prompt = f"""<|im_start|>system
You are a factoid trigger extractor. Extract a list of keywords or short phrases that might match entries in an infobot-style knowledge base.
Respond only with a comma-separated list of triggers. Do not answer the question or correct spelling or grammar.
<|im_end|>
<|im_start|>user
Who is steve?
<|im_end|>
<|im_start|>assistant
steve
<|im_end|>
<|im_start|>user
Tell me about HTTP status codes.
<|im_end|>
<|im_start|>assistant
HTTP status codes
<|im_end|>
<|im_start|>user
Who are you and your Gary
<|im_end|>
<|im_start|>assistant
you, Gary
<|im_end|>
<|im_start|>user
{user_query}
<|im_end|>
<|im_start|>assistant
"""

    payload = {
        "inputs": prompt,
        "parameters": {
            "temperature": 0.3,
            "max_new_tokens": 64,
        },
    }

    print("PAYLOAD:", payload)
    try:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{LLM_API_URL}/generate",
                json=payload,
                timeout=httpx.Timeout(120.0),
            )
        if r.status_code != 200:
            return []
        data = r.json()
    except (httpx.HTTPError, ValueError) as exc:
        # Transport errors, timeouts and non-JSON bodies take the same
        # "no triggers" path as a non-200 reply instead of propagating.
        print("LLM ERROR:", repr(exc))
        return []

    output = data.get("generated_text", "") if isinstance(data, dict) else ""
    if not isinstance(output, str):
        output = ""

    print("OUTPUT:", output)

    return _parse_triggers(output)
|
|
|
|
# Exact-match lookup: every row stored under the given trigger.
_EXACT_MATCH_SQL = """
    SELECT trigger, response, copula
    FROM factoids
    WHERE trigger = $1
"""

# Nearest-neighbour lookup: the single row whose `embedding` is closest
# (pgvector `<->` distance) to the supplied vector literal.
_NEAREST_MATCH_SQL = """
    SELECT trigger, response, copula
    FROM factoids
    ORDER BY embedding <-> $1::vector
    LIMIT 1
"""


def _split_options(response: str) -> list[str]:
    """Split a pipe-separated response into its non-empty, stripped alternatives.

    Falls back to ``[response]`` when no alternative is non-blank, so every row
    always contributes at least one option.
    """
    options = [part.strip() for part in response.split("|") if part.strip()]
    return options or [response]


def _format_factoid(trigger: str, copula: str, choice: str) -> str:
    """Render one chosen response in infobot style.

    A response starting with ``<reply>`` is returned without the marker;
    otherwise it is prefixed with the row's trigger and copula.
    """
    if choice.startswith("<reply>"):
        return choice[len("<reply>"):].strip()
    return f"{trigger} {copula} {choice}"


def _pick_factoid(rows) -> str:
    """Pick one alternative uniformly across all *rows* and format it.

    *rows* is a non-empty sequence of records exposing ``trigger``,
    ``response`` and ``copula`` by key.
    """
    candidates = [
        (row, option)
        for row in rows
        for option in _split_options(row["response"])
    ]
    row, choice = random.choice(candidates)
    print("CHOICE:", {"trigger": row["trigger"], "response": choice, "copula": row["copula"]})
    return _format_factoid(row["trigger"], row["copula"], choice)


async def _summarize(prompt: str, fallback: str) -> str:
    """Ask the LLM to summarize *prompt*; return *fallback* if that fails.

    Failure covers transport errors, timeouts, a non-200 status, a body that is
    not a JSON object, or an empty/missing ``generated_text``. Text after the
    first ``<|im_end|>`` marker is discarded.
    """
    try:
        async with httpx.AsyncClient() as client:
            r = await client.post(
                f"{LLM_API_URL}/generate",
                json={"inputs": prompt, "parameters": {"temperature": 0.8, "max_new_tokens": 200}},
                timeout=httpx.Timeout(120.0),
            )
        data = r.json()
    except (httpx.HTTPError, ValueError) as exc:
        print("SUMMARY ERROR:", repr(exc))
        return fallback
    if r.status_code != 200 or not isinstance(data, dict):
        return fallback
    text = data.get("generated_text")
    if not isinstance(text, str):
        return fallback
    text = text.split("<|im_end|>", 1)[0].strip()
    return text or fallback


@app.post("/ask")
async def ask(request: Request):
    """Answer a free-text question from the factoid knowledge base.

    Expects a JSON object body with a ``"query"`` string. The whole query is
    first looked up as an exact trigger. If that finds nothing, the LLM proposes
    triggers, and each is resolved by exact match, falling back to the nearest
    embedding. One random alternative per match is collected, and the LLM then
    summarizes the collected matches.

    Returns one of:
        ``{"error": "Missing query"}`` when no non-blank query string is given;
        ``{"response": ...}`` when no trigger could be extracted or matched;
        ``{"reply": summary, "matches": [...], "triggers": [...]}`` otherwise.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is reported like a missing query instead of raising.
        body = None
    query = body.get("query", "") if isinstance(body, dict) else ""
    if not isinstance(query, str) or not query.strip():
        return {"error": "Missing query"}
    query = query.strip()

    responses: list[str] = []

    # Pass 1: the entire query as a trigger.
    async with app.state.db.acquire() as conn:
        rows = await conn.fetch(_EXACT_MATCH_SQL, query)

    if rows:
        triggers = [query]
        responses.append(_pick_factoid(rows))
    else:
        # Pass 2: let the LLM propose triggers. This runs with no pool
        # connection checked out, because the LLM round-trip can be slow.
        triggers = await extract_triggers(query)
        if not triggers:
            return {"response": "I don't know that one."}

        print("triggers: ", triggers)

        async with app.state.db.acquire() as conn:
            for trigger in triggers:
                rows = await conn.fetch(_EXACT_MATCH_SQL, trigger)
                if not rows:
                    # NOTE(review): encode() is a synchronous call inside this
                    # coroutine; consider asyncio.to_thread if it proves slow.
                    embedding = model.encode(trigger)
                    embedding_str = f"[{', '.join(map(str, embedding))}]"
                    nearest = await conn.fetchrow(_NEAREST_MATCH_SQL, embedding_str)
                    rows = [nearest] if nearest else []
                if rows:
                    responses.append(_pick_factoid(rows))

        if not responses:
            return {"response": "I don't know any of those."}

    # Ask LLM to summarize. The prompt ends by opening the assistant turn (as
    # the trigger-extraction prompt does) so the completion is the answer itself.
    responses_str = "\nValue: ".join(responses)
    summary_prompt = f"""<|im_start|>system
You are a summarizer for a fact-based chatbot. Your task is to condense database entries into short, accurate one-line summaries. Do not speculate, define, or add new facts. Do not correct spelling or phrasing from the facts or triggers. Do not mix context from prior triggers.
<|im_end|>

<|im_start|>user
Summarize the following database entry.

Trigger: Paris
Value: Paris is the capital of France
Value: Paris is located in the north-central part of the country.
<|im_end|>

<|im_start|>assistant
Paris is the capital of France and located in the north-central part of the country.
<|im_end|>

<|im_start|>user
Summarize the following database entry.

Trigger: squinky, spacehobo
Value: spacehobo is a Citizen.
Value: squinky is kinky
<|im_end|>

<|im_start|>assistant
spacehobo is a Citizen and squinky is kinky
<|im_end|>

<|im_start|>user
Summarize the following database entry.

Trigger: sky
Value: sky is blue
Value: the sky is the big thing outside when you look up
Value: Look!
<|im_end|>

<|im_start|>assistant
Look up at that big blue thing outside!
<|im_end|>

<|im_start|>user
Summarize the following database entry.

Trigger: {query}
Value: {responses_str}
<|im_end|>

<|im_start|>assistant
"""

    print("SUMMARY PAYLOAD:", summary_prompt)
    final_response = await _summarize(summary_prompt, "\n".join(responses))
    print("FINAL RESPONSE:", final_response)

    return {"reply": final_response, "matches": responses, "triggers": triggers}
|
|
|