Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 139 additions & 23 deletions backend/pkg/bot/dingtalk/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ type DingTalkClient struct {
accessToken string
expireAt time.Time
}
tokenMutex sync.RWMutex
tokenMutex sync.RWMutex
messageMu sync.Mutex
messageSeenAt map[string]time.Time
messageTTL time.Duration
nowFunc func() time.Time
processMessageFn func(ctx context.Context, data *chatbot.BotCallbackDataModel) error
}

func NewDingTalkClient(ctx context.Context, cancel context.CancelFunc, clientId, clientSecret, templateID string, logger *log.Logger, getQA bot.GetQAFun) (*DingTalkClient, error) {
Expand All @@ -54,17 +59,22 @@ func NewDingTalkClient(ctx context.Context, cancel context.CancelFunc, clientId,
if err != nil {
return nil, fmt.Errorf("failed to create card client: %w", err)
}
return &DingTalkClient{
ctx: ctx,
cancel: cancel,
clientID: clientId,
clientSecret: clientSecret,
templateID: templateID,
oauthClient: oauthClient,
cardClient: cardClient,
getQA: getQA,
logger: logger,
}, nil
client := &DingTalkClient{
ctx: ctx,
cancel: cancel,
clientID: clientId,
clientSecret: clientSecret,
templateID: templateID,
oauthClient: oauthClient,
cardClient: cardClient,
getQA: getQA,
logger: logger,
messageSeenAt: make(map[string]time.Time),
messageTTL: 5 * time.Minute,
nowFunc: time.Now,
}
client.startMessageCleanup()
return client, nil
}

func (c *DingTalkClient) GetAccessToken() (string, error) {
Expand Down Expand Up @@ -191,6 +201,75 @@ func (c *DingTalkClient) CreateAndDeliverCard(ctx context.Context, trackID strin
return nil
}

func (c *DingTalkClient) startMessageCleanup() {
go func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.cleanupExpiredMessages()
}
}
}()
}

func (c *DingTalkClient) cleanupExpiredMessages() {
now := c.nowFunc()

c.messageMu.Lock()
defer c.messageMu.Unlock()

for msgID, seenAt := range c.messageSeenAt {
if now.Sub(seenAt) > c.messageTTL {
delete(c.messageSeenAt, msgID)
}
}
}

func (c *DingTalkClient) tryMarkMessage(msgID string) bool {
if strings.TrimSpace(msgID) == "" {
return true
}

now := c.nowFunc()

c.messageMu.Lock()
defer c.messageMu.Unlock()

if seenAt, ok := c.messageSeenAt[msgID]; ok && now.Sub(seenAt) <= c.messageTTL {
return false
}

c.messageSeenAt[msgID] = now
return true
}

func (c *DingTalkClient) markMessageCompleted(msgID string) {
if strings.TrimSpace(msgID) == "" {
return
}

c.messageMu.Lock()
defer c.messageMu.Unlock()

c.messageSeenAt[msgID] = c.nowFunc()
}

func (c *DingTalkClient) clearMessageMark(msgID string) {
if strings.TrimSpace(msgID) == "" {
return
}

c.messageMu.Lock()
defer c.messageMu.Unlock()

delete(c.messageSeenAt, msgID)
}

func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
select {
case <-c.ctx.Done():
Expand All @@ -199,6 +278,40 @@ func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *cha
default:
}

if !c.tryMarkMessage(data.MsgId) {
c.logger.Info("ignore duplicate dingtalk message", log.String("msg_id", data.MsgId))
return []byte(""), nil
}

processor := c.processMessageFn
if processor == nil {
processor = c.processMessage
}

payload := *data
go c.processMessageAsync(c.ctx, &payload, processor)

return []byte(""), nil
}

func (c *DingTalkClient) processMessageAsync(ctx context.Context, data *chatbot.BotCallbackDataModel, processor func(context.Context, *chatbot.BotCallbackDataModel) error) {
defer func() {
if r := recover(); r != nil {
c.clearMessageMark(data.MsgId)
c.logger.Error("process dingtalk message panicked", log.String("msg_id", data.MsgId), log.Any("panic", r))
}
}()

if err := processor(ctx, data); err != nil {
c.clearMessageMark(data.MsgId)
c.logger.Error("process dingtalk message failed", log.String("msg_id", data.MsgId), log.Error(err))
return
}

c.markMessageCompleted(data.MsgId)
}

func (c *DingTalkClient) processMessage(ctx context.Context, data *chatbot.BotCallbackDataModel) error {
question := data.Text.Content
question = strings.TrimSpace(question)
trackID := uuid.New().String()
Expand All @@ -207,14 +320,14 @@ func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *cha
// create and deliver card
if err := c.CreateAndDeliverCard(ctx, trackID, data); err != nil {
c.logger.Error("CreateAndDeliverCard", log.Error(err))
return nil, err
return err
}

initialContent := fmt.Sprintf("**%s**\n\n%s", question, "稍等,让我想一想……")

if err := c.UpdateAIStreamCard(trackID, initialContent, false); err != nil {
c.logger.Error("UpdateInteractiveCard", log.Error(err))
return nil, nil
return err
}
// 初始化 默认为空
convInfo := &domain.ConversationInfo{
Expand Down Expand Up @@ -242,10 +355,11 @@ func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *cha
contentCh, err := c.getQA(ctx, question, *convInfo, "")
if err != nil {
c.logger.Error("dingtalk client failed to get answer", log.Error(err))
if err := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); err != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
return fmt.Errorf("get answer failed: %w; update error card failed: %w", err, updateErr)
}
return nil, nil
return nil
}

updateTicker := time.NewTicker(1500 * time.Millisecond)
Expand All @@ -259,11 +373,12 @@ func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *cha
if !ok {
if err := c.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
c.logger.Error("UpdateInteractiveCard in contentCh", log.Error(err))
if err := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); err != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
return fmt.Errorf("final update card failed: %w; fallback update failed: %w", err, updateErr)
}
}
return []byte(""), nil
return nil
}
fullContent += content
case <-updateTicker.C:
Expand All @@ -272,10 +387,11 @@ func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *cha
}
if err := c.UpdateAIStreamCard(trackID, fullContent, false); err != nil {
c.logger.Error("UpdateInteractiveCard in ticker", log.Error(err))
if err := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); err != nil {
c.logger.Error("UpdateInteractiveCard in ticker failed", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in ticker failed", log.Error(updateErr))
return fmt.Errorf("stream update card failed: %w; fallback update failed: %w", err, updateErr)
}
return []byte(""), nil
return nil
}
}
}
Expand Down
Loading