68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
import os
|
|
import psycopg2
|
|
import csv
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Database connection settings. Each value is read with os.environ[...], so a
# missing variable stops the script with a KeyError naming that variable.
DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME = (
    os.environ[var]
    for var in ("DB_HOST", "DB_PORT", "DB_USER", "DB_PASSWORD", "DB_NAME")
)

# Tab-separated input files, one "trigger<TAB>response" row per line. Rows
# from TSV_IS are stored with copula "is", rows from TSV_ARE with "are".
TSV_IS, TSV_ARE = "/scripts/is.txt", "/scripts/are.txt"

# Open one PostgreSQL connection for the whole load. Autocommit is enabled, so
# every statement executed on `cur` is committed as soon as it runs.
_connect_kwargs = {
    "host": DB_HOST,
    "port": DB_PORT,
    "dbname": DB_NAME,
    "user": DB_USER,
    "password": DB_PASSWORD,
}
conn = psycopg2.connect(**_connect_kwargs)
conn.autocommit = True
cur = conn.cursor()

# Make sure the pgvector extension and the factoids table exist. Both
# statements use IF NOT EXISTS, so this part is safe to re-run; the INSERTs
# later in the script are plain inserts and are not deduplicated.
# NOTE(review): VECTOR(384) must match the output size of the embedding
# model loaded below — confirm if the model is ever changed.
_SCHEMA_DDL = (
    "CREATE EXTENSION IF NOT EXISTS vector;",
    """
CREATE TABLE IF NOT EXISTS factoids (
id SERIAL PRIMARY KEY,
trigger TEXT NOT NULL,
copula TEXT NOT NULL DEFAULT 'is',
response TEXT NOT NULL,
embedding VECTOR(384)
);
""",
)
for _ddl in _SCHEMA_DDL:
    cur.execute(_ddl)

# Sentence-embedding model used to turn each factoid trigger into a vector.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

print("Loading embedding model...")
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)

# Stream each TSV from disk and load it into the factoids table. Rows are
# embedded in groups of BATCH_SIZE (one model.encode call per group rather
# than one per row), so memory stays bounded by the batch, not the file.
BATCH_SIZE = 256


def _read_factoids(f):
    """Yield ``(trigger, response)`` pairs from an open tab-separated file.

    Each row must have exactly two tab-separated fields; other rows are
    skipped, as are rows whose trigger or response is empty after stripping
    surrounding whitespace.

    Quote characters are read as ordinary text (``csv.QUOTE_NONE``). With the
    default dialect, a field that begins with ``"`` is parsed as a quoted
    field: its quotes are dropped, and an unbalanced quote makes the reader
    swallow the following lines into a single field.
    """
    for row in csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE):
        if len(row) != 2:
            continue
        trigger, response = row[0].strip(), row[1].strip()
        if trigger and response:
            yield trigger, response


def _insert_batch(copula, batch):
    """Embed the triggers of *batch* in one model call and insert each row.

    *batch* is a non-empty list of ``(trigger, response)`` pairs; every row is
    stored with the given *copula*. Uses the module-level ``model`` and ``cur``.
    """
    embeddings = model.encode([trigger for trigger, _ in batch]).tolist()
    for (trigger, response), embedding in zip(batch, embeddings):
        cur.execute(
            "INSERT INTO factoids (trigger, copula, response, embedding) VALUES (%s, %s, %s, %s)",
            (trigger, copula, response, embedding),
        )


try:
    for copula, tsv_path in (("is", TSV_IS), ("are", TSV_ARE)):
        print(f"Loading and inserting from {tsv_path}...")
        # newline="" is how the csv module expects its input files opened.
        with open(tsv_path, "r", encoding="utf-8", errors="replace", newline="") as f:
            count = 0
            batch = []
            for trigger, response in _read_factoids(f):
                # Progress sample: show every 100th accepted factoid.
                if count % 100 == 0:
                    print(count, trigger, copula, response)
                count += 1
                batch.append((trigger, response))
                if len(batch) >= BATCH_SIZE:
                    _insert_batch(copula, batch)
                    batch = []
            # Flush the final, partially filled batch (never encode an empty list).
            if batch:
                _insert_batch(copula, batch)
        print(f"Inserted {count} factoids from {tsv_path}.")

    print("All factoids loaded.")
finally:
    # Release the cursor and connection even if loading fails part-way.
    # NOTE: the connection is in autocommit mode, so rows inserted before a
    # failure stay committed and a failed run leaves a partial load behind.
    cur.close()
    conn.close()
