* For the Direct Message History database, use gorm.io as the ORM so that Postgres can be used instead of SQLite for larger chat room instances.
* In settings.toml, the new DatabaseType field defaults to 'sqlite3' but can be set to 'postgres', in which case the credentials in the new PostgresDatabase field are used.
* The DirectMessage table schema is also updated to deprecate the Timestamp int field in favor of a proper CreatedAt datetime field. Existing SQLite instances will upgrade their table in the background, converting Timestamp to CreatedAt and blanking out the legacy Timestamp column.
* Fix some DB queries so that, when paginating the username list of your DM history, sorting it by timestamp now works reliably.
* For existing SQLite instances that want to switch to Postgres, use the scripts/sqlite2psql.py script to transfer your database over.
355 lines
8.9 KiB
Go
355 lines
8.9 KiB
Go
package models
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"sort"
|
|
"time"
|
|
|
|
"git.kirsle.net/apps/barertc/pkg/config"
|
|
"git.kirsle.net/apps/barertc/pkg/log"
|
|
"git.kirsle.net/apps/barertc/pkg/messages"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// DirectMessage is one stored message in the DM history table (direct_messages).
type DirectMessage struct {
	// MessageID is the primary key; LogMessage copies it from the chat message and
	// refuses to store a zero ID.
	MessageID int64 `gorm:"primaryKey"`

	// ChannelID identifies the conversation, as built by CreateChannelID
	// ("@a:@b" with the two usernames sorted).
	ChannelID string `gorm:"index"`

	// Username is the sender of the message.
	Username string `gorm:"index"`

	// Message is the message body.
	Message string

	// Timestamp is the legacy (pre-V2) creation time in Unix seconds.
	//
	// Deprecated: use CreatedAt. MigrateV2 copies non-zero values into CreatedAt
	// and resets this field to 0.
	Timestamp int64 // deprecated

	// CreatedAt is when the message was stored; Init uses it for the retention
	// cutoff and PaginateUsernames for newest/oldest sorting. LogMessage does not
	// set it explicitly, relying on GORM to fill it in on insert.
	CreatedAt time.Time `gorm:"index"`

	// DeletedAt marks a soft-deleted message (ClearMessages sets it).
	DeletedAt gorm.DeletedAt `gorm:"index"`
}
|
|
|
|
const (
	// DirectMessagePerPage is how many messages PaginateDirectMessages returns per call.
	DirectMessagePerPage = 20
)
|
|
|
|
// MigrateV2 upgrades the DirectMessage table for the V2 schema, where we switched
// from using SQLite to GORM so we can optionally use Postgres instead.
//
// During this switch, we also deprecate the old Timestamp column in favor of CreatedAt.
// Legacy rows (timestamp > 0) are processed in batches of up to 1000: each row gets
// CreatedAt = time.Unix(Timestamp, 0) and its Timestamp reset to 0, which removes it
// from the next batch query. The function returns once no legacy rows remain, or on
// the first database error (which is logged), so a failing update cannot loop forever.
//
// Only the created_at and timestamp columns are written, so running this concurrently
// with ClearMessages (soft delete) or TakebackMessage (hard delete) does not undo them.
func (dm DirectMessage) MigrateV2() {
	if DB == nil {
		return
	}

	for page := 1; ; page++ {
		var rows = []*DirectMessage{}
		res := DB.Model(&DirectMessage{}).Where(
			"timestamp > 0",
		).Limit(1000).Find(&rows)
		if res.Error != nil {
			log.Error("DirectMessage.MigrateV2: %s", res.Error)
			return
		}

		if len(rows) == 0 {
			return
		}

		log.Warn("DirectMessage.MigrateV2: Updating %d DMs (page %d)", len(rows), page)
		for _, row := range rows {
			// Update just the two migrated columns. A full Save() would rewrite every
			// column (including deleted_at, undoing a concurrent soft delete) and falls
			// back to an INSERT when no row matched (re-creating a taken-back message).
			res := DB.Model(row).Updates(map[string]interface{}{
				"created_at": time.Unix(row.Timestamp, 0),
				"timestamp":  0,
			})
			if res.Error != nil {
				// Bail out: the row still has timestamp > 0, so continuing would fetch
				// the same batch again and never terminate.
				log.Error("DirectMessage.MigrateV2: updating message %d: %s", row.MessageID, res.Error)
				return
			}
		}
	}
}
|
|
|
|
// Init runs initialization tasks for the DMs table (migrate V2 and expire old rows).
//
// The V2 migration runs synchronously before the retention sweep. The DELETE below
// only considers rows whose created_at is non-NULL, so rows that have not yet received
// a CreatedAt (presumably unmigrated legacy rows) would otherwise escape expiry.
//
// Returns ErrNotInitialized when the database is not set up, or a wrapped database
// error if deleting expired messages fails.
func (dm DirectMessage) Init() error {
	if DB == nil {
		return ErrNotInitialized
	}

	// Migrate old rows to the V2 schema.
	dm.MigrateV2()

	// Delete old messages past the retention period (a value <= 0 disables this).
	if days := config.Current.DirectMessageHistory.RetentionDays; days > 0 {
		cutoff := time.Now().Add(time.Duration(-days) * 24 * time.Hour)
		log.Info("Deleting old DM history past %d days (cutoff: %s)", days, cutoff.Format(time.RFC3339))

		res := DB.Exec(
			"DELETE FROM direct_messages WHERE created_at IS NOT NULL AND created_at < ?",
			cutoff,
		)
		if res.Error != nil {
			// %w keeps the driver error inspectable with errors.Is / errors.As.
			return fmt.Errorf("deleting old DMs past the cutoff period: %w", res.Error)
		}
	}

	return nil
}
|
|
|
|
// LogMessage adds a message to the DM history between two users.
|
|
func (dm DirectMessage) LogMessage(fromUsername, toUsername string, msg messages.Message) error {
|
|
if DB == nil {
|
|
return ErrNotInitialized
|
|
}
|
|
|
|
if msg.MessageID == 0 {
|
|
return errors.New("message did not have a MessageID")
|
|
}
|
|
|
|
var (
|
|
channelID = CreateChannelID(fromUsername, toUsername)
|
|
)
|
|
|
|
m := &DirectMessage{
|
|
MessageID: msg.MessageID,
|
|
ChannelID: channelID,
|
|
Username: fromUsername,
|
|
Message: msg.Message,
|
|
}
|
|
return DB.Create(&m).Error
|
|
}
|
|
|
|
// ClearMessages soft-deletes every stored DM that username is a participant in and
// returns how many messages were cleared.
//
// A message matches when its channel ID starts with "@username:" or ends with
// ":@username" (the CreateChannelID format), or when username sent it. Messages that
// are already soft-deleted are neither counted nor have their deleted_at re-stamped.
//
// NOTE(review): username is placed into LIKE patterns unescaped, so a '%' or '_' in a
// username acts as a wildcard and could match other users' channels. Confirm which
// characters usernames may contain.
func (dm DirectMessage) ClearMessages(username string) (int, error) {
	if DB == nil {
		return 0, ErrNotInitialized
	}

	// The COUNT and the UPDATE must select exactly the same rows, so they share one
	// filter and one argument list (previously the COUNT was handed the UPDATE's
	// arguments, shifting every placeholder by one).
	const filter = `((channel_id LIKE ? OR channel_id LIKE ?) OR username = ?) AND deleted_at IS NULL`
	filterArgs := []interface{}{
		fmt.Sprintf("@%s:%%", username), // `@alice:%`
		fmt.Sprintf("%%:@%s", username), // `%:@alice`
		username,
	}

	// Count all the messages we'll delete.
	var count int
	if res := DB.Raw(`SELECT COUNT(message_id) FROM direct_messages WHERE `+filter, filterArgs...).Scan(&count); res.Error != nil {
		return 0, res.Error
	}

	// Soft-delete them all: the deleted_at value fills the first placeholder, ahead of
	// the filter's own arguments.
	updateArgs := append([]interface{}{time.Now()}, filterArgs...)
	res := DB.Exec(`UPDATE direct_messages SET deleted_at = ? WHERE `+filter, updateArgs...)

	return count, res.Error
}
|
|
|
|
// TakebackMessage removes a message by its MID from the DM history.
//
// Because the MessageID may have been from a previous chat session, the server can't immediately
// verify the current user had permission to take it back. This function instead will check whether
// a DM history exists sent by this username for that messageID, and if so, returns a
// boolean true that the username/messageID matched which will satisfy the permission check
// in the OnTakeback handler.
//
// Admins (isAdmin) skip the ownership check. Returns (false, err) when the database is
// not initialized, the lookup fails, the message is not owned by username, or the
// delete fails.
func (dm DirectMessage) TakebackMessage(username string, messageID int64, isAdmin bool) (bool, error) {
	if DB == nil {
		return false, ErrNotInitialized
	}

	// Does this messageID exist as sent by the user?
	if !isAdmin {
		// gorm's Scan returns the *gorm.DB (never nil), not an error, and does not fail
		// when no row matches. Ask for a COUNT, which always yields one row, and check
		// res.Error and the count explicitly.
		var found int64
		res := DB.Raw(
			"SELECT COUNT(message_id) FROM direct_messages WHERE username = ? AND message_id = ?",
			username, messageID,
		).Scan(&found)
		if res.Error != nil {
			return false, fmt.Errorf("looking up message %d: %w", messageID, res.Error)
		}
		if found == 0 {
			return false, errors.New("no such message ID found as owned by that user")
		}
	}

	// Delete it.
	res := DB.Exec(
		"DELETE FROM direct_messages WHERE message_id = ?",
		messageID,
	)

	// Return that it was successfully validated and deleted.
	return res.Error == nil, res.Error
}
|
|
|
|
// PaginateDirectMessages returns a page of messages, the count of remaining, and an error.
//
// Messages come from the conversation between fromUsername and toUsername, ordered by
// descending message_id, at most DirectMessagePerPage per call. beforeID is an
// exclusive upper bound on message_id; 0 means "start from the newest". The remaining
// count covers non-deleted messages older than the last one returned, i.e. what
// further pages can still deliver.
func PaginateDirectMessages(fromUsername, toUsername string, beforeID int64) ([]messages.Message, int, error) {
	if DB == nil {
		return nil, 0, ErrNotInitialized
	}

	var (
		result    = []messages.Message{}
		channelID = CreateChannelID(fromUsername, toUsername)
		rows      = []*DirectMessage{}

		// Compute the remaining messages after finding the final messageID this page.
		lastMessageID int64
		remaining     int64
	)

	if beforeID == 0 {
		beforeID = math.MaxInt64
	}

	res := DB.Model(&DirectMessage{}).Where(
		"channel_id = ? AND message_id < ?",
		channelID, beforeID,
	).Order("message_id DESC").Limit(DirectMessagePerPage).Find(&rows)
	if res.Error != nil {
		return nil, 0, res.Error
	}

	for _, row := range rows {
		created := row.CreatedAt
		if created.IsZero() && row.Timestamp > 0 {
			// Legacy row not (yet) upgraded by MigrateV2: use its Unix-seconds timestamp
			// instead of reporting the zero time.
			created = time.Unix(row.Timestamp, 0)
		}

		result = append(result, messages.Message{
			MessageID: row.MessageID,
			Username:  row.Username,
			Message:   row.Message,
			Timestamp: created.Format(time.RFC3339),
		})
		lastMessageID = row.MessageID
	}

	// Get a count of the remaining messages. Querying through the model applies GORM's
	// soft-delete scope, like the page query above; a raw COUNT would also include
	// messages removed by ClearMessages and promise pages that come back empty.
	if res := DB.Model(&DirectMessage{}).Where(
		"channel_id = ? AND message_id < ?",
		channelID, lastMessageID,
	).Count(&remaining); res.Error != nil {
		return nil, 0, res.Error
	}

	return result, int(remaining), nil
}
|
|
|
|
// PaginateUsernames returns a page of usernames that the current user has conversations with.
//
// A partner is listed when they have sent at least one non-deleted message in a
// channel fromUsername belongs to. sortBy is "a-z", "z-a", "oldest", or anything else
// for newest first, where "newest"/"oldest" refer to each partner's most recent
// message. page is 1-based; a perPage below 1 falls back to DirectMessagePerPage.
//
// Returns the usernames, total count, pages, and error. The count and pages describe
// the same filtered partner list that is being paginated.
func PaginateUsernames(fromUsername, sortBy string, page, perPage int) ([]string, int, int, error) {
	if DB == nil {
		return nil, 0, 0, ErrNotInitialized
	}

	// A zero/negative page size would divide by zero when computing the page count.
	if perPage < 1 {
		perPage = DirectMessagePerPage
	}

	offset := (page - 1) * perPage
	if offset < 0 {
		offset = 0
	}

	// Ties on newest_time are broken by username so pages stay stable between requests.
	var orderBy string
	switch sortBy {
	case "a-z":
		orderBy = "username ASC"
	case "z-a":
		orderBy = "username DESC"
	case "oldest":
		orderBy = "newest_time ASC, username ASC"
	default:
		orderBy = "newest_time DESC, username ASC"
	}

	// Get all our distinct channel IDs to filter the query down: otherwise doing an ORDER BY timestamp
	// causes a full table scan index which is very inefficient!
	channelIDs, err := GetDistinctChannelIDs(fromUsername)
	if err != nil {
		return nil, 0, 0, err
	}

	// No channel IDs = no response to fetch.
	if len(channelIDs) == 0 {
		return nil, 0, 0, errors.New("you have no direct messages stored on this chat server")
	}

	// Shared by the page query and the total count so both describe the same set.
	const filter = "channel_id IN ? AND username <> ?"

	// Only Username is scanned; MAX(created_at) is selected purely so ORDER BY can use it.
	type record struct {
		Username string
	}
	var records []record

	res := DB.Model(&DirectMessage{}).Select(`
		username,
		MAX(created_at) AS newest_time
	`).Where(
		filter, channelIDs, fromUsername,
	).Group("username").Order(orderBy).Offset(offset).Limit(perPage).Scan(&records)
	if res.Error != nil {
		return nil, 0, 0, res.Error
	}

	result := make([]string, 0, len(records))
	for _, row := range records {
		result = append(result, row.Username)
	}

	// Count distinct partners under the same filter. Counting channel IDs instead would
	// include conversations where only fromUsername has written, which never appear in
	// the list above and would produce trailing empty pages.
	var total int64
	if res := DB.Model(&DirectMessage{}).Where(
		filter, channelIDs, fromUsername,
	).Distinct("username").Count(&total); res.Error != nil {
		return nil, 0, 0, res.Error
	}
	count := int(total)

	pages := int(math.Ceil(float64(count) / float64(perPage)))
	if pages < 1 {
		pages = 1
	}

	return result, count, pages, nil
}
|
|
|
|
// GetDistinctChannelIDs collects all of the conversation thread IDs the current user is a party to.
//
// Channel IDs have the form "@a:@b" (see CreateChannelID), so username takes part in a
// channel when its ID starts with "@username:" or ends with ":@username". The query
// goes through the DirectMessage model, so GORM's soft-delete scope applies.
//
// Returns ErrNotInitialized if the database is not set up, like the other DM helpers.
func GetDistinctChannelIDs(username string) ([]string, error) {
	if DB == nil {
		return nil, ErrNotInitialized
	}

	var (
		result   = []string{}
		patterns = []string{
			fmt.Sprintf(`@%s:%%`, username),
			fmt.Sprintf(`%%:@%s`, username),
		}
	)

	res := DB.Model(&DirectMessage{}).Select(
		"DISTINCT(channel_id)",
	).Where(
		"(channel_id LIKE ? OR channel_id LIKE ?)",
		patterns[0], patterns[1],
	).Scan(&result)
	if res.Error != nil {
		return nil, res.Error
	}

	return result, nil
}
|
|
|
|
// CreateChannelID returns a deterministic channel ID for a direct message conversation.
|
|
//
|
|
// The usernames (passed in any order) are sorted alphabetically and composed into the channel ID.
|
|
func CreateChannelID(fromUsername, toUsername string) string {
|
|
var parts = []string{fromUsername, toUsername}
|
|
sort.Strings(parts)
|
|
return fmt.Sprintf(
|
|
"@%s:@%s",
|
|
parts[0],
|
|
parts[1],
|
|
)
|
|
}
|