File size: 4,488 Bytes
dc58d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2eba4c6
 
dc58d54
 
2eba4c6
18b5344
2eba4c6
 
dc58d54
18b5344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc58d54
 
 
 
 
2eba4c6
 
dc58d54
 
 
2eba4c6
abae45a
2eba4c6
 
dc58d54
 
 
2eba4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc58d54
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncpg

class NetworkDB:
    """Async data-access helper for text posts and their comments.

    Every query method opens its own short-lived connection to
    ``database_url``, always closes it (even when the query fails), and
    never raises: read methods return a human-readable
    ``[Internal Message: ...]`` string on failure and write methods
    return ``False``.

    A lazily created connection pool is also available through
    :meth:`get_pool` and is released by :meth:`disconnect`; the query
    methods themselves do not use it.
    """

    # Sentinel returned by every read method when the database call fails.
    _SERVER_ERROR = "[Internal Message: Server Error]"

    def __init__(self, database_url):
        """Store the connection URL; no connection is opened here."""
        self.pool = None
        self.database_url = database_url

    async def get_pool(self):
        """Return the shared connection pool, creating it on first use."""
        if self.pool is None:
            self.pool = await asyncpg.create_pool(
                self.database_url, min_size=1, max_size=10
            )
        return self.pool

    async def _connect(self):
        """Open a dedicated connection; the caller must close it."""
        return await asyncpg.connect(self.database_url)

    @staticmethod
    def _format_post(post_id, content) -> str:
        """Render one post as a ``<|PostId_N|>`` header followed by its text."""
        return f"<|PostId_{post_id}|>\n{content}"

    async def post_text(self, content: str, embeddings: list[float]) -> bool:
        """Insert a text post together with its embedding.

        Args:
            content: Body of the post.
            embeddings: Embedding vector; sent as its ``str()`` form
                (presumably parsed server-side by a vector column type --
                TODO confirm against the schema).

        Returns:
            ``True`` if the insert returned a new row id, ``False`` if it
            did not or if any error occurred.
        """
        try:
            conn = await self._connect()
            try:
                new_id = await conn.fetchval(
                    "INSERT INTO text_posts (content, embedding) VALUES ($1, $2) RETURNING id",
                    content,
                    str(embeddings),
                )
            finally:
                # Close on both success and failure so connections never leak.
                await conn.close()
            return new_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def get_text_post_random(self) -> str:
        """Return one randomly chosen post, formatted with its id header.

        Returns:
            The formatted post, ``"[Internal Message: No post found!]"``
            when the table is empty, or the server-error message when the
            query fails.
        """
        try:
            conn = await self._connect()
            try:
                # fetchrow yields None for an empty result instead of
                # forcing a tuple-unpack of None.
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY random() LIMIT 1"
                )
            finally:
                await conn.close()
            if row is None or row["content"] is None:
                return "[Internal Message: No post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return self._SERVER_ERROR

    async def get_text_posts_latest(self) -> str:
        """Return the five most recently uploaded posts, newest first.

        Returns:
            The formatted posts separated by blank lines,
            ``"[Internal Message: No posts in the database]"`` when there
            are none, or the server-error message when the query fails.
        """
        try:
            conn = await self._connect()
            try:
                rows = await conn.fetch(
                    "SELECT id, content FROM text_posts ORDER BY uploaded_at DESC LIMIT 5"
                )
            finally:
                await conn.close()
            if not rows:
                return "[Internal Message: No posts in the database]"
            return "\n\n".join(
                self._format_post(row["id"], row["content"]) for row in rows
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return self._SERVER_ERROR

    async def get_text_post_similar(self, query_embedding: list[float]) -> str:
        """Return the post whose embedding is nearest to ``query_embedding``.

        Args:
            query_embedding: Query vector; sent as its ``str()`` form and
                compared with the ``<->`` operator.

        Returns:
            The formatted post,
            ``"[Internal Message: No similar post found!]"`` when the
            table is empty, or the server-error message when the query
            fails.
        """
        try:
            conn = await self._connect()
            try:
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY embedding <-> $1 LIMIT 1",
                    str(query_embedding),
                )
            finally:
                await conn.close()
            if row is None or row["content"] is None:
                return "[Internal Message: No similar post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return self._SERVER_ERROR

    async def get_text_post_comments(self, post_id: int) -> str:
        """Return the five most recent comments on a post, newest first.

        Args:
            post_id: Id of the post whose comments are requested.

        Returns:
            Comments rendered as ``<|Comment_i|>`` sections (``i`` is the
            zero-based position in the result) separated by blank lines,
            ``"[Internal Message: No Comments on this post]"`` when there
            are none, or the server-error message when the query fails.
        """
        try:
            conn = await self._connect()
            try:
                comments = await conn.fetch(
                    "SELECT content FROM text_posts_comments WHERE post_id = $1 ORDER BY uploaded_at DESC LIMIT 5",
                    post_id,
                )
            finally:
                await conn.close()
            if not comments:
                return "[Internal Message: No Comments on this post]"
            # join() puts separators only between comments, so there is no
            # leading or trailing blank line.
            return "\n\n".join(
                f"<|Comment_{i}|>\n{comment['content']}"
                for i, comment in enumerate(comments)
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            # Same string sentinel as the other read methods (not a list).
            return self._SERVER_ERROR

    async def comment_on_text_post(self, post_id: int, content: str) -> bool:
        """Add a comment to a post.

        Args:
            post_id: Id of the post being commented on.
            content: Body of the comment.

        Returns:
            ``True`` if the insert returned a new row id, ``False`` if it
            did not or if any error occurred.
        """
        try:
            conn = await self._connect()
            try:
                comment_id = await conn.fetchval(
                    "INSERT INTO text_posts_comments (post_id, content) VALUES ($1, $2) RETURNING id",
                    post_id,
                    content,
                )
            finally:
                await conn.close()
            return comment_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def disconnect(self) -> None:
        """Close the connection pool, if one was created, and forget it."""
        if self.pool is not None:
            # Drop the reference first so a later get_pool() builds a fresh
            # pool instead of handing out a closed one.
            pool, self.pool = self.pool, None
            # The pool's close() must be awaited for it to actually run.
            await pool.close()