package models import ( "time" "git.kirsle.net/apps/gosocial/pkg/log" "gorm.io/gorm" ) // Notification table. type Notification struct { ID uint64 `gorm:"primaryKey"` UserID uint64 `gorm:"index"` // who it belongs to AboutUserID *uint64 `form:"index"` // the other party of this notification User User `gorm:"foreignKey:about_user_id"` Type NotificationType // like, comment, ... Read bool `gorm:"index"` TableName string // on which of your tables (photos, comments, ...) TableID uint64 Message string // text associated, e.g. copy of comment added CreatedAt time.Time UpdatedAt time.Time } // Preload related tables for the forum (classmethod). func (n *Notification) Preload() *gorm.DB { return DB.Preload("User.ProfilePhoto") } type NotificationType string const ( NotificationLike NotificationType = "like" NotificationComment = "comment" NotificationCustom = "custom" // custom message pushed ) // CreateNotification func CreateNotification(n *Notification) error { return DB.Create(n).Error } // GetNotification by ID. func GetNotification(id uint64) (*Notification, error) { var n *Notification result := DB.Model(n).First(&n, id) return n, result.Error } // RemoveNotification about a table ID, e.g. when removing a like. func RemoveNotification(tableName string, tableID uint64) error { result := DB.Where( "table_name = ? AND table_id = ?", tableName, tableID, ).Delete(&Notification{}) return result.Error } // MarkNotificationsRead sets all a user's notifications to read. func MarkNotificationsRead(user *User) error { return DB.Model(&Notification{}).Where( "user_id = ? AND read IS NOT TRUE", user.ID, ).Update("read", true).Error } // CountUnreadNotifications gets the count of unread Notifications for a user. func CountUnreadNotifications(userID uint64) (int64, error) { query := DB.Where( "user_id = ? AND read = ?", userID, false, ) var count int64 result := query.Model(&Notification{}).Count(&count) return count, result.Error } // PaginateNotifications returns the user's notifications. func PaginateNotifications(user *User, pager *Pagination) ([]*Notification, error) { var ns = []*Notification{} query := (&Notification{}).Preload().Where( "user_id = ?", user.ID, ).Order( pager.Sort, ) query.Model(&Notification{}).Count(&pager.Total) result := query.Offset(pager.GetOffset()).Limit(pager.PerPage).Find(&ns) return ns, result.Error } // Save a notification. func (n *Notification) Save() error { return DB.Save(n).Error } // NotificationBody can store remote tables mapped. type NotificationBody struct { PhotoID uint64 Photo *Photo } type NotificationMap map[uint64]*NotificationBody // Get a notification's body from the map. func (m NotificationMap) Get(id uint64) *NotificationBody { if body, ok := m[id]; ok { return body } return &NotificationBody{} } // MapNotifications loads associated assets, like Photos, mapped to their notification ID. func MapNotifications(ns []*Notification) NotificationMap { var ( IDs = []uint64{} result = NotificationMap{} ) // Collect notification IDs. for _, row := range ns { IDs = append(IDs, row.ID) result[row.ID] = &NotificationBody{} } type scanner struct { PhotoID uint64 NotificationID uint64 } var scan []scanner // Load all of these that have photos. err := DB.Table( "notifications", ).Joins( "JOIN photos ON (notifications.table_name='photos' AND notifications.table_id=photos.id)", ).Select( "photos.id AS photo_id", "notifications.id AS notification_id", ).Where( "notifications.id IN ?", IDs, ).Scan(&scan) if err != nil { log.Error("Couldn't select photo IDs for notifications: %s", err) } // Collect and load all the photos by ID. var photoIDs = []uint64{} for _, row := range scan { // Store the photo ID in the result now. result[row.NotificationID].PhotoID = row.PhotoID photoIDs = append(photoIDs, row.PhotoID) } // Load the photos. if len(photoIDs) > 0 { if photos, err := GetPhotos(photoIDs); err != nil { log.Error("Couldn't load photo IDs for notifications: %s", err) } else { // Marry them to their notification IDs. for _, body := range result { if photo, ok := photos[body.PhotoID]; ok { body.Photo = photo } } } } return result }