// This repository was archived on 2022-08-26. You can view files and clone it,
// but you cannot push or open issues or pull requests.
//
// File: gosocial/pkg/models/thread.go (259 lines, 6.3 KiB, Go).

package models
import (
"errors"
"fmt"
"strings"
"time"
"git.kirsle.net/apps/gosocial/pkg/log"
"gorm.io/gorm"
)
// Thread table - a post within a Forum.
//
// Each thread belongs to one Forum and is opened by a Comment, which is
// referenced through CommentID.
type Thread struct {
ID uint64 `gorm:"primaryKey"`
// ForumID identifies the forum this thread was posted in.
ForumID uint64 `gorm:"index"`
Forum Forum
// Pinned threads are returned by PinnedThreads and left out of PaginateThreads.
Pinned bool `gorm:"index"`
// Explicit threads are hidden by PaginateThreads unless the viewer has opted
// in to explicit content or is an admin.
Explicit bool `gorm:"index"`
// NoReply is only set by CreateThread when the author is an admin.
// NOTE(review): presumably marks the thread as closed to replies; Reply in
// this file does not check it - confirm where it is enforced.
NoReply bool
Title string
CommentID uint64 `gorm:"index"`
Comment Comment // first comment of the thread
// Views is the view counter bumped by View.
Views uint64
CreatedAt time.Time
// UpdatedAt is refreshed when Reply saves the thread; View avoids advancing it.
UpdatedAt time.Time
}
// Preload returns a query that loads a thread's related rows along with it:
// its Forum, and its opening Comment together with that comment's User and
// the user's ProfilePhoto.
//
// The receiver is never read, so this can be called on an empty Thread
// (classmethod style).
func (t *Thread) Preload() *gorm.DB {
	withForum := DB.Preload("Forum")
	return withForum.Preload("Comment.User.ProfilePhoto")
}
// GetThread looks up a single thread by its ID, with the related tables from
// Preload attached.
//
// The Thread pointer is returned even when the query fails, so check the
// error before using it.
func GetThread(id uint64) (*Thread, error) {
	thread := new(Thread)
	err := thread.Preload().First(&thread, id).Error
	return thread, err
}
// GetThreads fetches the threads with the given IDs and returns them keyed by
// thread ID. IDs that match no row are simply absent from the map.
//
// The map is always non-nil; it is returned alongside any query error.
func GetThreads(IDs []uint64) (map[uint64]*Thread, error) {
	var rows []*Thread
	err := (&Thread{}).Preload().Where("id IN ?", IDs).Find(&rows).Error

	// Index whatever rows came back by their primary key.
	byID := make(map[uint64]*Thread, len(rows))
	for i := range rows {
		byID[rows[i].ID] = rows[i]
	}
	return byID, err
}
// CreateThread posts a new thread in the given forum, opened by a Comment
// written by user with the given message.
//
// The noReply flag is only honored when user is an admin. After the initial
// insert, the opening comment is saved again so that its TableName/TableID
// point back at the new thread.
//
// Returns nil and the error if the insert fails. If only the follow-up
// comment save fails, the thread is returned together with that error.
func CreateThread(user *User, forumID uint64, title, message string, pinned, explicit, noReply bool) (*Thread, error) {
	thread := new(Thread)
	thread.ForumID = forumID
	thread.Title = title
	thread.Pinned = pinned
	thread.Explicit = explicit
	thread.NoReply = noReply && user.IsAdmin
	thread.Comment = Comment{
		User:    *user,
		Message: message,
	}
	log.Error("CreateThread: Going to post %+v", thread)

	// Insert the thread (and its comment) before anything else, so the
	// thread has an ID to refer to.
	if err := DB.Create(thread).Error; err != nil {
		return nil, err
	}

	// Point the opening comment back at the thread that owns it.
	opener := &thread.Comment
	opener.TableName = "threads"
	opener.TableID = thread.ID
	log.Error("Saving updated comment: %+v", thread)

	return thread, DB.Save(opener).Error
}
// Reply adds another comment, written by user, to this thread and returns it.
//
// The thread is saved first so that its timestamp moves forward. A failure
// there is only logged; the comment is still added.
func (t *Thread) Reply(user *User, message string) (*Comment, error) {
	err := t.Save()
	if err != nil {
		log.Error("Thread.Reply: couldn't ping UpdatedAt on thread: %s", err)
	}

	return AddComment(user, "threads", t.ID, message)
}
// DeleteReply removes one comment from this thread.
//
// The comment must be attached to this thread (TableName "threads" and a
// TableID equal to the thread's ID), otherwise an error is returned and
// nothing is deleted. Deleting the comment that opened the thread deletes the
// whole thread; any other comment is deleted on its own.
func (t *Thread) DeleteReply(comment *Comment) error {
	// Refuse comments that are attached to anything other than this thread.
	if comment.TableName != "threads" || comment.TableID != t.ID {
		return errors.New("that comment doesn't belong to this thread")
	}

	// An ordinary reply: only the comment itself goes away.
	if comment.ID != t.CommentID {
		return comment.Delete()
	}

	// The opening comment: the thread cannot outlive it.
	log.Error("DeleteReply(%d): this is the parent comment of a thread (%d '%s'), remove the whole thread", comment.ID, t.ID, t.Title)
	return t.Delete()
}
// PinnedThreads lists every pinned thread in the given forum, most recently
// updated first. There is no pagination, as there should generally be few
// pinned threads.
func PinnedThreads(forum *Forum) ([]*Thread, error) {
	pinned := []*Thread{}
	err := (&Thread{}).Preload().
		Where("forum_id = ? AND pinned IS TRUE", forum.ID).
		Order("updated_at desc").
		Find(&pinned).
		Error
	return pinned, err
}
// PaginateThreads returns one page of a forum's threads for its index view,
// leaving out pinned threads.
//
// Explicit threads are also left out unless the user has opted in to explicit
// content or is an admin. The total number of matching threads is written to
// pager.Total, and the page itself is selected with pager.GetOffset() and
// pager.PerPage, ordered by pager.Sort.
//
// NOTE(review): pager.Sort is handed to Order as-is; presumably it is
// validated by the caller - confirm.
func PaginateThreads(user *User, forum *Forum, pager *Pagination) ([]*Thread, error) {
	query := (&Thread{}).Preload()

	// Conditions that apply to every listing.
	conditions := []string{"forum_id = ? AND pinned IS NOT TRUE"}
	args := []interface{}{forum.ID}

	// Viewers who have not opted in to explicit content do not see NSFW threads.
	if showExplicit := user.Explicit || user.IsAdmin; !showExplicit {
		conditions = append(conditions, "explicit IS NOT TRUE")
	}

	query = query.Where(strings.Join(conditions, " AND "), args...).Order(pager.Sort)

	// Count all matching threads first, then fetch just the requested page.
	query.Model(&Thread{}).Count(&pager.Total)

	page := []*Thread{}
	err := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&page).Error
	return page, err
}
// View records one view of the thread, incrementing its view count but not
// its UpdatedAt.
//
// The increment is done by the database in a single UPDATE
// (views = views + 1) instead of writing back t.Views + 1 from this in-memory
// copy, which could overwrite increments made by other requests since the
// thread was loaded. Only the views column is written, so updated_at is left
// exactly as stored.
//
// The table name is spelled out, matching the "threads" name this file
// already stores in Comment.TableName. The in-memory t.Views is not modified.
func (t *Thread) View() error {
	return DB.Exec(
		"UPDATE threads SET views = views + 1 WHERE id = ?",
		t.ID,
	).Error
}
// Save writes the thread to the database, which also updates its timestamp.
func (t *Thread) Save() error {
	result := DB.Save(t)
	return result.Error
}
// Delete removes the thread together with all of its comments.
//
// Comments attached to this thread (table_name "threads", table_id = t.ID)
// are deleted first. If that fails, the thread itself is left in place and
// the underlying error is returned wrapped, so callers can still inspect it
// with errors.Is / errors.As. Otherwise the result of deleting the thread row
// is returned.
func (t *Thread) Delete() error {
	// Remove all comments.
	result := DB.Where(
		"table_name = ? AND table_id = ?",
		"threads", t.ID,
	).Delete(&Comment{})
	if result.Error != nil {
		return fmt.Errorf("deleting comments for thread: %w", result.Error)
	}

	// Remove the thread itself.
	return DB.Delete(t).Error
}
// Statistics types produced by MapThreadStatistics.
type (
	// ThreadStatistics holds the reply and view counts for one thread.
	// Replies does not include the comment that opened the thread.
	ThreadStatistics struct {
		Replies uint64
		Views   uint64
	}

	// ThreadStatsMap maps a thread ID to that thread's statistics.
	ThreadStatsMap map[uint64]*ThreadStatistics
)
// MapThreadStatistics looks up reply and view statistics for a set of threads.
//
// The returned map has one entry per given thread, keyed by thread ID. Views
// is copied from the Thread itself. Replies is the number of comments
// attached to the thread minus one, so the comment that opened the thread is
// not counted.
//
// The function has no error return: if the count query fails, the failure is
// logged and every Replies count stays at zero. An empty input yields an
// empty, non-nil map without querying the database.
func MapThreadStatistics(threads []*Thread) ThreadStatsMap {
	var (
		result = make(ThreadStatsMap, len(threads))
		IDs    = make([]uint64, 0, len(threads))
	)

	// Collect thread IDs and initialize the map.
	for _, thread := range threads {
		IDs = append(IDs, thread.ID)
		result[thread.ID] = &ThreadStatistics{
			Views: thread.Views,
		}
	}

	// Nothing to count.
	if len(IDs) == 0 {
		return result
	}

	// Hold the result of the count/group by query.
	type group struct {
		ID      uint64
		Replies uint64
	}
	var groups []group

	// Count comments grouped by thread IDs.
	scan := DB.Table(
		"comments",
	).Select(
		"table_id AS id, count(id) AS replies",
	).Where(
		"table_name = ? AND table_id IN ?",
		"threads", IDs,
	).Group("table_id").Scan(&groups)

	// Scan hands back the query handle, which is never nil; the failure, if
	// any, is carried in its Error field.
	if scan.Error != nil {
		log.Error("MapThreadStatistics: SQL error: %s", scan.Error)
	}

	// Map the results in.
	for _, row := range groups {
		// NOTE(review): this is emitted at Error level for every row even
		// though it is not an error condition - consider a lower log level.
		log.Error("Got row: %+v", row)

		stats, ok := result[row.ID]
		if !ok {
			continue
		}

		// Remove the OG comment from the count.
		stats.Replies = row.Replies
		if stats.Replies > 0 {
			stats.Replies--
		}
	}

	return result
}
// Has reports whether the map holds an entry for the given thread ID
// (it normally should).
func (ts ThreadStatsMap) Has(threadID uint64) bool {
	if _, found := ts[threadID]; found {
		return true
	}
	return false
}
// Get returns the statistics stored for the given thread ID, or nil when the
// map has no entry for it.
func (ts ThreadStatsMap) Get(threadID uint64) *ThreadStatistics {
	// Indexing a map of pointers yields nil for a missing key.
	return ts[threadID]
}