Vector Search via Embedding
Because database entries can be numerous and too large to fit into the prompt directly, we need a way to represent them in a more compact form for AI tools to use. To achieve this, we do two things:
- “Get all” endpoints are excluded from being derived as AI tools by using the `@operation.tool({ hidden: true })` decorator.
- Vector embeddings are generated for each user and task entry when it is created or updated. These embeddings use an OpenAI embeddings model and are stored in the database using pgvector, declared as `Unsupported("vector(1536)")` in the Prisma schema under the `embedding` column of each model. See the Database article for more details.
The EmbeddingService class implements the logic of generating embeddings and performing vector search using the pgvector extension in Postgres, exposing two main methods:
- `generateEntityEmbedding` – generates an embedding for a given entity based on its string fields and updates the database entry with the generated vector.
- `vectorSearch` – performs vector search for a given query string and returns the most similar entries based on the stored embeddings.
Both methods are entity-type agnostic and work with both user and task entity types.
src/modules/embedding/EmbeddingService.ts
import { embed } from "ai";
import { openai } from "@ai-sdk/openai";
import { capitalize, omit } from "lodash";
import { Prisma } from "@prisma/client";
import { EntityType } from "@schemas/index";
import { UserType } from "@schemas/models/User.schema";
import { TaskType } from "@schemas/models/Task.schema";
import { BASE_KEYS } from "@/constants";
import DatabaseService from "../database/DatabaseService";
/**
 * Generates OpenAI embeddings for database entities and runs pgvector
 * similarity search over them. All members are static.
 */
export default class EmbeddingService {
  /**
   * Embeds a piece of text with OpenAI's `text-embedding-3-small` model.
   *
   * @param value - Text to embed.
   * @returns The embedding vector.
   */
  static async generateEmbedding(value: string): Promise<number[]> {
    const { embedding } = await embed({
      model: openai.embeddingModel("text-embedding-3-small"),
      value,
    });
    return embedding;
  }

  /**
   * Serializes an embedding into pgvector's text input format (`[a,b,c]`)
   * so it can be bound as a query parameter and cast with `::vector`.
   */
  private static toVectorLiteral(embedding: number[]): string {
    return `[${embedding.join(",")}]`;
  }

  /**
   * Builds an embedding from an entity's string-valued fields (excluding
   * `BASE_KEYS`) and writes it to the entity's `embedding` column.
   *
   * @param entityType - Model to read and update; its capitalized form is
   *   used as the table name.
   * @param entityId - Primary key of the entity.
   * @returns The generated embedding.
   * @throws Error `"<EntityType> not found"` when no row matches `entityId`.
   */
  static generateEntityEmbedding = async (
    entityType: EntityType,
    entityId: UserType["id"] | TaskType["id"],
  ) => {
    // The cast only satisfies the compiler; at runtime the delegate is
    // whichever model `entityType` names.
    const entity = await DatabaseService.prisma[
      entityType as "user"
    ].findUnique({
      where: { id: entityId },
    });
    const capitalizedEntityType = capitalize(entityType);
    if (!entity) throw new Error(`${capitalizedEntityType} not found`);

    // Only string fields outside BASE_KEYS contribute to the embedded text;
    // it is normalized the same way as queries in `vectorSearch`.
    const text = Object.values(omit(entity, BASE_KEYS))
      .filter((v) => typeof v === "string")
      .join(" ")
      .trim()
      .toLowerCase();
    const embedding = await this.generateEmbedding(text);

    // Identifiers cannot be bind parameters, so the table name goes in via
    // Prisma.raw. It is derived from `entityType` (presumably a closed union
    // of model names — confirm); the vector and id remain bound parameters.
    await DatabaseService.prisma.$executeRaw`
      UPDATE ${Prisma.raw(`"${capitalizedEntityType}"`)}
      SET embedding = ${this.toVectorLiteral(embedding)}::vector
      WHERE id = ${entityId}
    `;
    return embedding;
  };

  /**
   * Finds entities whose stored embeddings are most similar to `query`.
   *
   * Similarity is computed as `1 - (embedding <=> query)` using pgvector's
   * `<=>` operator; rows with a NULL embedding are skipped.
   *
   * @param entityType - Model/table to search.
   * @param query - Free-text query; trimmed and lowercased before embedding.
   * @param limit - Maximum number of entities returned.
   * @param similarityThreshold - Only rows whose similarity is strictly
   *   greater than this value are returned.
   * @returns Matching entities ordered from most to least similar.
   */
  static async vectorSearch<T>(
    entityType: EntityType,
    query: string,
    limit: number = 10,
    similarityThreshold: number = 0.4,
  ): Promise<T[]> {
    const queryVector = this.toVectorLiteral(
      await this.generateEmbedding(query.trim().toLowerCase()),
    );
    const capitalizedEntityType = capitalize(entityType);

    // Rank entity IDs by similarity in SQL; full rows are loaded afterwards.
    const vectorResults = await DatabaseService.prisma.$queryRaw<
      { id: string; similarity: number }[]
    >`
      SELECT
        id,
        1 - (embedding <=> ${queryVector}::vector) as similarity
      FROM ${Prisma.raw(`"${capitalizedEntityType}"`)}
      WHERE embedding IS NOT NULL
        AND 1 - (embedding <=> ${queryVector}::vector) > ${similarityThreshold}
      ORDER BY embedding <=> ${queryVector}::vector
      LIMIT ${limit}
    `;
    if (vectorResults.length === 0) return [];

    const ids = vectorResults.map((r) => r.id);
    const entities = await DatabaseService.prisma[
      entityType as "user"
    ].findMany({
      where: { id: { in: ids } },
    });

    // `findMany` with an `in` filter does not return rows in the order of
    // `ids`, so restore the similarity ranking computed above.
    const rank = new Map(ids.map((id, index) => [id, index]));
    entities.sort((a, b) => (rank.get(a.id) ?? 0) - (rank.get(b.id) ?? 0));
    return entities as T[];
  }
}