BareRTC/pkg/models/direct_messages.go
Noah Petherbridge 52df45b2e9 GORM for the database for Postgres support
* For the Direct Message History database, use gorm.io as ORM so that Postgres
  can be used instead of SQLite for bigger chat room instances.
* In settings.toml: the new DatabaseType field defaults to 'sqlite3' but can be
  set to 'postgres' and use the credentials in the new PostgresDatabase field.
* The DirectMessage table schema is also updated to deprecate the Timestamp int
  field in favor of a proper CreatedAt datetime field. Existing SQLite instances
  will upgrade their table in the background, converting Timestamp to CreatedAt
  and blanking out the legacy Timestamp column.
* Fix some DB queries so that, when paginating your DM history username list,
  sorting it by timestamp now works reliably.
* For existing SQLite instances that want to switch to Postgres, use the
  scripts/sqlite2psql.py script to transfer your database over.
2025-07-06 11:56:20 -07:00

355 lines
8.9 KiB
Go

package models
import (
"errors"
"fmt"
"math"
"sort"
"time"
"git.kirsle.net/apps/barertc/pkg/config"
"git.kirsle.net/apps/barertc/pkg/log"
"git.kirsle.net/apps/barertc/pkg/messages"
"gorm.io/gorm"
)
// DirectMessage is one stored message in the direct message history table.
//
// Rows are written by LogMessage and read back by PaginateDirectMessages; the
// table's columns and indexes are described to GORM through the struct tags.
type DirectMessage struct {
	MessageID int64  `gorm:"primaryKey"`
	ChannelID string `gorm:"index"` // conversation ID of the two participants, built by CreateChannelID
	Username  string `gorm:"index"` // username of the sender
	Message   string
	Timestamp int64          // deprecated: legacy Unix seconds; MigrateV2 copies it into CreatedAt and zeroes it
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"` // stamped by ClearMessages to soft-delete a message
}

// DirectMessagePerPage is the page size used by PaginateDirectMessages.
const DirectMessagePerPage = 20
// MigrateV2 upgrades the DirectMessage table for the V2 schema, where we switched
// from using SQLite to GORM so we can optionally use Postgres instead.
//
// During this switch, we also deprecate the old Timestamp column in favor of CreatedAt.
// Legacy rows (Timestamp > 0) are converted in batches of 1000: CreatedAt is set from
// the Unix timestamp and Timestamp is zeroed, which drops the row out of the next
// batch's query. Init calls this before purging expired messages.
//
// Errors are logged, not returned. If a row fails to save, the migration stops: that
// row would still match "timestamp > 0" and the loop would otherwise never finish.
func (dm DirectMessage) MigrateV2() {
	const batchSize = 1000

	for page := 1; ; page++ {
		var rows []*DirectMessage
		res := DB.Model(&DirectMessage{}).Where("timestamp > 0").Limit(batchSize).Find(&rows)
		if res.Error != nil {
			log.Error("DirectMessage.MigrateV2: %s", res.Error)
			return
		}
		if len(rows) == 0 {
			return
		}

		log.Warn("DirectMessage.MigrateV2: Updating %d DMs (page %d)", len(rows), page)
		for _, row := range rows {
			row.CreatedAt = time.Unix(row.Timestamp, 0)
			row.Timestamp = 0
			if err := DB.Save(row).Error; err != nil {
				log.Error("DirectMessage.MigrateV2: saving message %d: %s", row.MessageID, err)
				return
			}
		}
	}
}
// Init runs initialization tasks for the DMs table: it migrates legacy rows to the
// V2 schema (MigrateV2) and then, when DirectMessageHistory.RetentionDays is
// positive, permanently deletes messages created before the retention cutoff.
//
// Returns ErrNotInitialized if the database has not been set up, or a wrapped
// database error if the purge fails.
func (dm DirectMessage) Init() error {
	if DB == nil {
		return ErrNotInitialized
	}

	// Migrate old rows to the V2 schema first, so legacy rows it converts get a
	// CreatedAt and become subject to the retention cutoff below.
	dm.MigrateV2()

	days := config.Current.DirectMessageHistory.RetentionDays
	if days <= 0 {
		return nil
	}

	// Delete old messages past the retention period. This is a raw DELETE, so it also
	// removes soft-deleted rows; rows whose created_at is still NULL are left alone.
	cutoff := time.Now().Add(time.Duration(-days) * 24 * time.Hour)
	log.Info("Deleting old DM history past %d days (cutoff: %s)", days, cutoff.Format(time.RFC3339))
	res := DB.Exec(
		"DELETE FROM direct_messages WHERE created_at IS NOT NULL AND created_at < ?",
		cutoff,
	)
	if res.Error != nil {
		return fmt.Errorf("deleting old DMs past the cutoff period: %w", res.Error)
	}
	return nil
}
// LogMessage adds a message to the DM history between two users.
//
// The message is stored under the channel ID of the two usernames (CreateChannelID)
// and attributed to fromUsername as the sender. CreatedAt is left zero so GORM's
// CreatedAt auto-timestamping fills it in on insert.
//
// Returns ErrNotInitialized if the database has not been set up, or an error if msg
// carries no MessageID (it is the table's primary key).
func (dm DirectMessage) LogMessage(fromUsername, toUsername string, msg messages.Message) error {
	if DB == nil {
		return ErrNotInitialized
	}
	if msg.MessageID == 0 {
		return errors.New("message did not have a MessageID")
	}

	m := &DirectMessage{
		MessageID: msg.MessageID,
		ChannelID: CreateChannelID(fromUsername, toUsername),
		Username:  fromUsername,
		Message:   msg.Message,
	}
	return DB.Create(m).Error
}
// ClearMessages soft-deletes every stored DM in the conversations that username is a
// participant in, and returns how many messages were cleared.
//
// Conversations are found with GetDistinctChannelIDs; messages sent by username are
// matched as well. Rows that are already soft-deleted are neither counted nor
// re-stamped, so repeated calls report only newly cleared messages.
func (dm DirectMessage) ClearMessages(username string) (int, error) {
	if DB == nil {
		return 0, ErrNotInitialized
	}

	channelIDs, err := GetDistinctChannelIDs(username)
	if err != nil {
		return 0, err
	}

	// LogMessage always files a message under a channel that includes its sender, so
	// no live channels means there is nothing left to clear.
	if len(channelIDs) == 0 {
		return 0, nil
	}

	// Count all the messages we'll delete.
	var count int
	if res := DB.Raw(`
		SELECT COUNT(message_id)
		FROM direct_messages
		WHERE (channel_id IN ? OR username = ?)
		AND deleted_at IS NULL
	`, channelIDs, username).Scan(&count); res.Error != nil {
		return 0, res.Error
	}

	// Soft-delete them all.
	res := DB.Exec(`
		UPDATE direct_messages
		SET deleted_at = ?
		WHERE (channel_id IN ? OR username = ?)
		AND deleted_at IS NULL
	`, time.Now(), channelIDs, username)
	if res.Error != nil {
		return 0, res.Error
	}
	return count, nil
}
// TakebackMessage removes a message by its MID from the DM history.
//
// Because the MessageID may have been from a previous chat session, the server can't immediately
// verify the current user had permission to take it back. This function instead will check whether
// a DM history exists sent by this username for that messageID, and if so, returns a
// boolean true that the username/messageID matched which will satisfy the permission check
// in the OnTakeback handler.
//
// Admins skip the ownership check. The delete is a hard delete via raw SQL, so the row is
// removed outright rather than soft-deleted.
func (dm DirectMessage) TakebackMessage(username string, messageID int64, isAdmin bool) (bool, error) {
	if DB == nil {
		return false, ErrNotInitialized
	}

	// Does this messageID exist as sent by the user?
	if !isAdmin {
		var foundMsgID int64
		res := DB.Raw(
			"SELECT message_id FROM direct_messages WHERE username = ? AND message_id = ?",
			username, messageID,
		).Scan(&foundMsgID)
		if res.Error != nil {
			return false, res.Error
		}
		// Finding no row is not an error for Scan, so check that a row actually came back.
		if res.RowsAffected == 0 || foundMsgID != messageID {
			return false, errors.New("no such message ID found as owned by that user")
		}
	}

	// Delete it.
	res := DB.Exec(
		"DELETE FROM direct_messages WHERE message_id = ?",
		messageID,
	)

	// Return that it was successfully validated and deleted.
	return res.Error == nil, res.Error
}
// PaginateDirectMessages returns a page of messages, the count of remaining, and an error.
//
// The page holds up to DirectMessagePerPage messages of the conversation between the two
// users, newest first, with message IDs below beforeID (0 means "start from the newest").
// The remaining count is the number of messages older than the last one on this page.
// Soft-deleted (cleared) messages are excluded from both the page and the count.
func PaginateDirectMessages(fromUsername, toUsername string, beforeID int64) ([]messages.Message, int, error) {
	if DB == nil {
		return nil, 0, ErrNotInitialized
	}

	channelID := CreateChannelID(fromUsername, toUsername)
	if beforeID == 0 {
		beforeID = math.MaxInt64
	}

	var rows []*DirectMessage
	res := DB.Model(&DirectMessage{}).Where(
		"channel_id = ? AND message_id < ?",
		channelID, beforeID,
	).Order("message_id DESC").Limit(DirectMessagePerPage).Find(&rows)
	if res.Error != nil {
		return nil, 0, res.Error
	}

	result := make([]messages.Message, 0, len(rows))
	for _, row := range rows {
		created := row.CreatedAt
		// Legacy rows not yet converted by MigrateV2 only carry the Unix Timestamp.
		if created.IsZero() && row.Timestamp > 0 {
			created = time.Unix(row.Timestamp, 0)
		}
		result = append(result, messages.Message{
			MessageID: row.MessageID,
			Username:  row.Username,
			Message:   row.Message,
			Timestamp: created.Format(time.RFC3339),
		})
	}

	// An empty page means there is nothing older either.
	if len(rows) == 0 {
		return result, 0, nil
	}

	// Count what is older than the oldest message on this page. Going through the model
	// applies the same soft-delete filter as the Find above, so the two agree.
	var remaining int64
	oldestID := rows[len(rows)-1].MessageID
	if err := DB.Model(&DirectMessage{}).Where(
		"channel_id = ? AND message_id < ?",
		channelID, oldestID,
	).Count(&remaining).Error; err != nil {
		return nil, 0, err
	}

	return result, int(remaining), nil
}
// PaginateUsernames returns a page of usernames that the current user has conversations with.
//
// Conversation partners are derived from the user's channel IDs (see CreateChannelID), so a
// partner is listed even when only fromUsername has sent messages in that conversation, and
// the returned count always matches the number of listable partners.
//
// sortBy is "a-z" or "z-a" (by username), "oldest" (least recently active conversation
// first), or anything else for most recently active first. page is 1-based; perPage values
// below 1 fall back to DirectMessagePerPage.
//
// Returns the usernames, total count, pages, and error.
func PaginateUsernames(fromUsername, sortBy string, page, perPage int) ([]string, int, int, error) {
	if DB == nil {
		return nil, 0, 0, ErrNotInitialized
	}
	if perPage < 1 {
		perPage = DirectMessagePerPage
	}
	offset := (page - 1) * perPage
	if offset < 0 {
		offset = 0
	}

	// Get all our distinct channel IDs to filter the query down: otherwise doing an ORDER BY timestamp
	// causes a full table scan index which is very inefficient!
	channelIDs, err := GetDistinctChannelIDs(fromUsername)
	if err != nil {
		return nil, 0, 0, err
	}

	// Resolve each channel to the other participant. Channels that do not really belong to
	// fromUsername (e.g. a LIKE false positive), and a conversation with oneself, are skipped.
	var (
		ownChannels = make([]string, 0, len(channelIDs))
		partners    = make(map[string]string, len(channelIDs))
	)
	for _, id := range channelIDs {
		partner, ok := partnerFromChannelID(id, fromUsername)
		if !ok || partner == fromUsername {
			continue
		}
		ownChannels = append(ownChannels, id)
		partners[id] = partner
	}

	// No channel IDs = no response to fetch.
	if len(ownChannels) == 0 {
		return nil, 0, 0, errors.New("you have no direct messages stored on this chat server")
	}

	var (
		count  = len(ownChannels)
		pages  = int(math.Ceil(float64(count) / float64(perPage)))
		result = []string{}
	)

	switch sortBy {
	case "a-z", "z-a":
		// Every partner is already known, so alphabetical order needs no extra query.
		names := make([]string, 0, count)
		for _, id := range ownChannels {
			names = append(names, partners[id])
		}
		if sortBy == "a-z" {
			sort.Strings(names)
		} else {
			sort.Sort(sort.Reverse(sort.StringSlice(names)))
		}
		if offset < len(names) {
			end := len(names)
			if end-offset > perPage {
				end = offset + perPage
			}
			result = append(result, names[offset:end]...)
		}
	default:
		orderBy := "newest_time DESC"
		if sortBy == "oldest" {
			orderBy = "newest_time ASC"
		}

		// Order conversations by their latest message. Only channel_id is scanned;
		// newest_time is selected so that ORDER BY can refer to it.
		type record struct {
			ChannelID string
		}
		var records []record
		res := DB.Model(&DirectMessage{}).Select(`
			channel_id,
			MAX(created_at) AS newest_time
		`).Where(
			"channel_id IN ?", ownChannels,
		).Group("channel_id").Order(orderBy).Offset(offset).Limit(perPage).Scan(&records)
		if res.Error != nil {
			return nil, 0, 0, res.Error
		}
		for _, row := range records {
			result = append(result, partners[row.ChannelID])
		}
	}

	return result, count, pages, nil
}

// partnerFromChannelID returns the other participant of a channel ID built by
// CreateChannelID ("@a:@b"), given one participant's username. ok is false when
// username is not one of the two participants.
func partnerFromChannelID(channelID, username string) (partner string, ok bool) {
	// username sorted first: "@username:@partner".
	if prefix := "@" + username + ":@"; len(channelID) > len(prefix) && channelID[:len(prefix)] == prefix {
		return channelID[len(prefix):], true
	}
	// username sorted second: "@partner:@username".
	if suffix := ":@" + username; len(channelID) > len(suffix)+1 && channelID[0] == '@' &&
		channelID[len(channelID)-len(suffix):] == suffix {
		return channelID[1 : len(channelID)-len(suffix)], true
	}
	return "", false
}
// GetDistinctChannelIDs collects all of the conversation thread IDs the current user is a party to.
//
// Channel IDs have the form "@a:@b" (see CreateChannelID), so the user's threads are those
// starting with "@username:" or ending with ":@username". Queries go through the GORM model,
// so threads whose messages are all soft-deleted are not returned.
//
// LIKE treats "%" and "_" as wildcards. Both are escaped in the username (with "!" as the
// ESCAPE character, which works the same in SQLite and Postgres) so that, for example,
// "a_b" does not also match the threads of "axb".
func GetDistinctChannelIDs(username string) ([]string, error) {
	if DB == nil {
		return nil, ErrNotInitialized
	}

	// Escape byte-by-byte: the metacharacters are ASCII, so UTF-8 sequences pass through.
	escaped := make([]byte, 0, len(username))
	for i := 0; i < len(username); i++ {
		switch c := username[i]; c {
		case '!', '%', '_':
			escaped = append(escaped, '!', c)
		default:
			escaped = append(escaped, c)
		}
	}
	name := string(escaped)

	result := []string{}
	res := DB.Model(&DirectMessage{}).Select(
		"DISTINCT(channel_id)",
	).Where(
		"(channel_id LIKE ? ESCAPE '!' OR channel_id LIKE ? ESCAPE '!')",
		fmt.Sprintf(`@%s:%%`, name),
		fmt.Sprintf(`%%:@%s`, name),
	).Scan(&result)
	if res.Error != nil {
		return nil, res.Error
	}

	return result, nil
}
// CreateChannelID builds the channel ID shared by both sides of a direct message
// conversation.
//
// The two usernames are ordered lexically (byte-wise), so the result is the same
// no matter which participant is passed first: "@<lower>:@<higher>".
func CreateChannelID(fromUsername, toUsername string) string {
	pair := sort.StringSlice{fromUsername, toUsername}
	pair.Sort()
	return fmt.Sprintf("@%s:@%s", pair[0], pair[1])
}