diff --git a/go.sum b/go.sum index 1943dc8d..65f5adb4 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -160,6 +162,8 @@ golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sU golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -188,6 +192,8 @@ golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= diff --git a/internal/api/chat/create_conversation_message_stream_v2.go b/internal/api/chat/create_conversation_message_stream_v2.go index b82adf5d..8c715a36 100644 --- a/internal/api/chat/create_conversation_message_stream_v2.go +++ b/internal/api/chat/create_conversation_message_stream_v2.go @@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( APIKey: settings.OpenAIAPIKey, } - openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) + openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider) if err != nil { return s.sendStreamError(stream, err) } @@ -307,7 +307,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( for i, bsonMsg := range conversation.InappChatHistory { protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg) } - title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider) + title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, protoMessages, llmProvider) if err != nil { s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex()) return diff --git a/internal/api/grpc.go b/internal/api/grpc.go index ed9dc2b0..3451d667 100644 --- a/internal/api/grpc.go +++ b/internal/api/grpc.go @@ -15,6 +15,7 @@ import ( chatv2 "paperdebugger/pkg/gen/api/chat/v2" commentv1 "paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" // "github.com/grpc-ecosystem/go-grpc-middleware" @@ -106,6 +107,7 @@ func NewGrpcServer( userServer userv1.UserServiceServer, projectServer projectv1.ProjectServiceServer, commentServer commentv1.CommentServiceServer, + usageServer usagev1.UsageServiceServer, ) *GrpcServer { grpcServer := &GrpcServer{} grpcServer.userService = userService @@ -121,5 +123,6 @@ func NewGrpcServer( userv1.RegisterUserServiceServer(grpcServer.Server, userServer) projectv1.RegisterProjectServiceServer(grpcServer.Server, projectServer) commentv1.RegisterCommentServiceServer(grpcServer.Server, commentServer) + usagev1.RegisterUsageServiceServer(grpcServer.Server, usageServer) return grpcServer } diff --git a/internal/api/server.go b/internal/api/server.go index b093c767..60e790d9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -6,17 +6,23 @@ import ( "fmt" "net" "net/http" + "os" + "os/signal" "strings" + "syscall" "paperdebugger/internal/libs/logger" "paperdebugger/internal/libs/metadatautil" "paperdebugger/internal/libs/shared" + "paperdebugger/internal/services" + aiclient "paperdebugger/internal/services/toolkit/client" authv1 "paperdebugger/pkg/gen/api/auth/v1" chatv1 "paperdebugger/pkg/gen/api/chat/v1" chatv2 "paperdebugger/pkg/gen/api/chat/v2" commentv1 "paperdebugger/pkg/gen/api/comment/v1" projectv1 "paperdebugger/pkg/gen/api/project/v1" sharedv1 "paperdebugger/pkg/gen/api/shared/v1" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" userv1 "paperdebugger/pkg/gen/api/user/v1" "github.com/gin-gonic/gin" @@ -30,8 +36,10 @@ import ( ) type Server struct { - grpcServer *GrpcServer - ginServer *GinServer + grpcServer *GrpcServer + ginServer *GinServer + pricingService *services.PricingService + aiClientV2 *aiclient.AIClientV2 logger *logger.Logger } @@ -39,16 +47,26 @@ type Server struct { func NewServer( grpcServer *GrpcServer, ginServer *GinServer, + pricingService *services.PricingService, + aiClientV2 *aiclient.AIClientV2, logger *logger.Logger, ) *Server { return &Server{ - grpcServer: grpcServer, - ginServer: ginServer, - logger: logger, + grpcServer: grpcServer, + ginServer: ginServer, + pricingService: pricingService, + aiClientV2: aiClientV2, + logger: logger, } } func (s *Server) Run(addr string) { + // Start the pricing updater in the background + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.pricingService.StartPriceUpdater(ctx) + listener, err := net.Listen("tcp", ":0") if err != nil { s.logger.Fatalf("failed to start grpc server listener: %v", err) @@ -105,6 +123,22 @@ func (s *Server) Run(addr string) { s.logger.Fatalf("failed to register comment service grpc gateway: %v", err) return } + err = usagev1.RegisterUsageServiceHandler(context.Background(), mux, client) + if err != nil { + s.logger.Fatalf("failed to register usage service grpc gateway: %v", err) + return + } + + // Set up signal handling for graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigChan + s.logger.Info("[PAPERDEBUGGER] received shutdown signal, shutting down gracefully...") + s.Shutdown() + os.Exit(0) + }() s.logger.Infof("[PAPERDEBUGGER] http server listening on %s", addr) s.ginServer.Any("/_pd/api/*path", func(c *gin.Context) { mux.ServeHTTP(c.Writer, c.Request) }) @@ -114,6 +148,16 @@ func (s *Server) Run(addr string) { } } +// Shutdown gracefully shuts down all server components. +func (s *Server) Shutdown() { + s.logger.Info("[PAPERDEBUGGER] shutting down AI client (draining usage records)...") + s.aiClientV2.Shutdown() + s.logger.Info("[PAPERDEBUGGER] AI client shutdown complete") + + s.grpcServer.GracefulStop() + s.logger.Info("[PAPERDEBUGGER] gRPC server shutdown complete") +} + func (s *Server) metadataAnnotator() func(ctx context.Context, req *http.Request) metadata.MD { return func(ctx context.Context, req *http.Request) metadata.MD { md := metadata.New(map[string]string{}) diff --git a/internal/api/usage/get_session_usage.go b/internal/api/usage/get_session_usage.go new file mode 100644 index 00000000..8c6db098 --- /dev/null +++ b/internal/api/usage/get_session_usage.go @@ -0,0 +1,40 @@ +package usage + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" + + "google.golang.org/protobuf/types/known/timestamppb" +) + +func (s *UsageServer) GetSessionUsage( + ctx context.Context, + req *usagev1.GetSessionUsageRequest, +) (*usagev1.GetSessionUsageResponse, error) { + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + // Get session with costs already calculated by the service layer + session, err := s.usageService.GetActiveSessionWithCosts(ctx, actor.ID) + if err != nil { + return nil, err + } + + if session == nil { + return &usagev1.GetSessionUsageResponse{ + Session: nil, + }, nil + } + + return &usagev1.GetSessionUsageResponse{ + Session: &usagev1.SessionUsage{ + SessionExpiry: timestamppb.New(session.SessionExpiry), + Models: convertModelsToProto(session.Models), + TotalCostUsd: session.TotalCostUSD, + }, + }, nil +} diff --git a/internal/api/usage/get_weekly_usage.go b/internal/api/usage/get_weekly_usage.go new file mode 100644 index 00000000..248d2c15 --- /dev/null +++ b/internal/api/usage/get_weekly_usage.go @@ -0,0 +1,32 @@ +package usage + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" +) + +func (s *UsageServer) GetWeeklyUsage( + ctx context.Context, + req *usagev1.GetWeeklyUsageRequest, +) (*usagev1.GetWeeklyUsageResponse, error) { + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + // Get weekly stats with costs already calculated by the service layer + stats, err := s.usageService.GetWeeklyUsageWithCosts(ctx, actor.ID) + if err != nil { + return nil, err + } + + return &usagev1.GetWeeklyUsageResponse{ + Usage: &usagev1.WeeklyUsage{ + Models: convertModelsToProto(stats.Models), + SessionCount: stats.SessionCount, + TotalCostUsd: stats.TotalCostUSD, + }, + }, nil +} diff --git a/internal/api/usage/server.go b/internal/api/usage/server.go new file mode 100644 index 00000000..8221ac9d --- /dev/null +++ b/internal/api/usage/server.go @@ -0,0 +1,42 @@ +package usage + +import ( + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/services" + usagev1 "paperdebugger/pkg/gen/api/usage/v1" +) + +type UsageServer struct { + usagev1.UnimplementedUsageServiceServer + + usageService *services.UsageService + logger *logger.Logger +} + +func NewUsageServer( + usageService *services.UsageService, + logger *logger.Logger, +) usagev1.UsageServiceServer { + return &UsageServer{ + usageService: usageService, + logger: logger, + } +} + +// convertModelsToProto converts ModelUsageStats to proto format. +// Costs are already calculated by the service layer. +func convertModelsToProto(models map[string]*services.ModelUsageStats) map[string]*usagev1.ModelTokens { + protoModels := make(map[string]*usagev1.ModelTokens, len(models)) + + for modelName, stats := range models { + protoModels[modelName] = &usagev1.ModelTokens{ + PromptTokens: stats.PromptTokens, + CompletionTokens: stats.CompletionTokens, + TotalTokens: stats.TotalTokens, + RequestCount: stats.RequestCount, + CostUsd: stats.CostUSD, + } + } + + return protoModels +} diff --git a/internal/libs/db/db.go b/internal/libs/db/db.go index 52a5548c..7f76387d 100644 --- a/internal/libs/db/db.go +++ b/internal/libs/db/db.go @@ -6,6 +6,7 @@ import ( "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" @@ -43,5 +44,47 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) { } logger.Info("[MONGO] initialized") - return &DB{Client: client, cfg: cfg, logger: logger}, nil + + db := &DB{Client: client, cfg: cfg, logger: logger} + db.ensureIndexes() + return db, nil +} + +// ensureIndexes creates necessary indexes for the database collections. +func (db *DB) ensureIndexes() { + sessions := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName()) + + // TTL index: auto-delete sessions after 30 days past their expiry time + _, err := sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{{Key: "session_expiry", Value: 1}}, + Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60), + }) + if err != nil { + db.logger.Error("Failed to create TTL index on llm_sessions", "error", err) + } + + // Compound index for efficient active session lookups + _, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "session_expiry", Value: -1}, + }, + }) + if err != nil { + db.logger.Error("Failed to create compound index on llm_sessions", "error", err) + } + + // Unique compound index for session creation and queries. + // session_start is rounded to the second, so concurrent requests within the same + // second will conflict, triggering duplicate key handling in RecordUsage. + _, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "session_start", Value: -1}, + }, + Options: options.Index().SetUnique(true), + }) + if err != nil { + db.logger.Error("Failed to create session_start index on llm_sessions", "error", err) + } } diff --git a/internal/models/model_pricing.go b/internal/models/model_pricing.go new file mode 100644 index 00000000..adfbf114 --- /dev/null +++ b/internal/models/model_pricing.go @@ -0,0 +1,23 @@ +package models + +import ( + "time" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// ModelPricing stores the pricing information for an LLM model. +// Prices are in USD per token. +type ModelPricing struct { + ID bson.ObjectID `bson:"_id"` + ModelID string `bson:"model_id"` // e.g., "openai/gpt-4" + ModelSlug string `bson:"model_slug"` // e.g., "gpt-4" (short name used in our app) + Name string `bson:"name"` // e.g., "OpenAI: GPT-4" + PromptPrice float64 `bson:"prompt_price"` // USD per token + CompletionPrice float64 `bson:"completion_price"` // USD per token + UpdatedAt time.Time `bson:"updated_at"` +} + +func (m ModelPricing) CollectionName() string { + return "model_pricing" +} diff --git a/internal/models/usage.go b/internal/models/usage.go new file mode 100644 index 00000000..0f0aa6f2 --- /dev/null +++ b/internal/models/usage.go @@ -0,0 +1,25 @@ +package models + +import "go.mongodb.org/mongo-driver/v2/bson" + +// ModelTokens stores token counts for a specific model. +type ModelTokens struct { + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` +} + +// LLMSession represents a user's session for tracking LLM usage and token counts. +// Tokens are stored per model in the Models map. +type LLMSession struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + SessionStart bson.DateTime `bson:"session_start"` + SessionExpiry bson.DateTime `bson:"session_expiry"` + Models map[string]*ModelTokens `bson:"models"` +} + +func (s LLMSession) CollectionName() string { + return "llm_sessions" +} diff --git a/internal/services/pricing.go b/internal/services/pricing.go new file mode 100644 index 00000000..3d734759 --- /dev/null +++ b/internal/services/pricing.go @@ -0,0 +1,206 @@ +package services + +import ( + "context" + "encoding/json" + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +const ( + OpenRouterModelsURL = "https://openrouter.ai/api/v1/models" + PriceRefreshInterval = 24 * time.Hour +) + +type PricingService struct { + BaseService + collection *mongo.Collection + httpClient *http.Client +} + +// OpenRouterModel represents a model from the OpenRouter API. +type OpenRouterModel struct { + ID string `json:"id"` + Name string `json:"name"` + Pricing struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` + } `json:"pricing"` +} + +// OpenRouterResponse is the response from the OpenRouter models API. +type OpenRouterResponse struct { + Data []OpenRouterModel `json:"data"` +} + +func NewPricingService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *PricingService { + base := NewBaseService(db, cfg, logger) + return &PricingService{ + BaseService: base, + collection: base.db.Collection((models.ModelPricing{}).CollectionName()), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// FetchAndUpdatePrices fetches model prices from OpenRouter and updates the database. +func (s *PricingService) FetchAndUpdatePrices(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, OpenRouterModelsURL, nil) + if err != nil { + return err + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + var openRouterResp OpenRouterResponse + if err := json.NewDecoder(resp.Body).Decode(&openRouterResp); err != nil { + return err + } + + now := time.Now() + for _, model := range openRouterResp.Data { + promptPrice, _ := strconv.ParseFloat(model.Pricing.Prompt, 64) + completionPrice, _ := strconv.ParseFloat(model.Pricing.Completion, 64) + + // Skip models with no pricing + if promptPrice == 0 && completionPrice == 0 { + continue + } + + // Extract model slug (short name) from the full model ID + // e.g., "openai/gpt-4" -> "gpt-4" + modelSlug := extractModelSlug(model.ID) + + filter := bson.M{"model_id": model.ID} + update := bson.M{ + "$set": bson.M{ + "model_id": model.ID, + "model_slug": modelSlug, + "name": model.Name, + "prompt_price": promptPrice, + "completion_price": completionPrice, + "updated_at": now, + }, + "$setOnInsert": bson.M{ + "_id": bson.NewObjectID(), + }, + } + opts := options.UpdateOne().SetUpsert(true) + _, err := s.collection.UpdateOne(ctx, filter, update, opts) + if err != nil { + s.logger.Warn("Failed to update model pricing", "modelID", model.ID, "error", err) + } + } + + s.logger.Info("Updated model pricing", "count", len(openRouterResp.Data)) + return nil +} + +// GetPricing returns the pricing for a model by its slug. +func (s *PricingService) GetPricing(ctx context.Context, modelSlug string) (*models.ModelPricing, error) { + // Try exact match first + filter := bson.M{"model_slug": modelSlug} + var pricing models.ModelPricing + err := s.collection.FindOne(ctx, filter).Decode(&pricing) + if err == nil { + return &pricing, nil + } + if err != mongo.ErrNoDocuments { + return nil, err + } + + // Try partial match (model slug might be a prefix) + // Use QuoteMeta to escape any regex special characters in the model slug + filter = bson.M{"model_slug": bson.M{"$regex": "^" + regexp.QuoteMeta(modelSlug)}} + err = s.collection.FindOne(ctx, filter).Decode(&pricing) + if err == mongo.ErrNoDocuments { + return nil, nil + } + if err != nil { + return nil, err + } + return &pricing, nil +} + +// GetAllPricing returns all model pricing. +func (s *PricingService) GetAllPricing(ctx context.Context) ([]models.ModelPricing, error) { + cursor, err := s.collection.Find(ctx, bson.M{}) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var pricings []models.ModelPricing + if err := cursor.All(ctx, &pricings); err != nil { + return nil, err + } + return pricings, nil +} + +// GetPricingMap returns a map of model slug to pricing for quick lookup. +func (s *PricingService) GetPricingMap(ctx context.Context) (map[string]*models.ModelPricing, error) { + pricings, err := s.GetAllPricing(ctx) + if err != nil { + return nil, err + } + + result := make(map[string]*models.ModelPricing) + for i := range pricings { + result[pricings[i].ModelSlug] = &pricings[i] + } + return result, nil +} + +// extractModelSlug extracts the short model name from a full model ID. +// e.g., "openai/gpt-4" -> "gpt-4", "anthropic/claude-3-opus" -> "claude-3-opus" +func extractModelSlug(modelID string) string { + parts := strings.Split(modelID, "/") + if len(parts) > 1 { + return parts[len(parts)-1] + } + return modelID +} + +// StartPriceUpdater starts a background goroutine that periodically updates prices. +func (s *PricingService) StartPriceUpdater(ctx context.Context) { + // Fetch immediately on startup + go func() { + if err := s.FetchAndUpdatePrices(ctx); err != nil { + s.logger.Error("Failed to fetch initial model pricing", "error", err) + } + }() + + // Then fetch periodically + go func() { + ticker := time.NewTicker(PriceRefreshInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.FetchAndUpdatePrices(context.Background()); err != nil { + s.logger.Error("Failed to update model pricing", "error", err) + } + } + } + }() +} diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go index 87a1e26a..d68dd191 100644 --- a/internal/services/toolkit/client/client_v2.go +++ b/internal/services/toolkit/client/client_v2.go @@ -1,6 +1,9 @@ package client import ( + "context" + "sync" + "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" "paperdebugger/internal/libs/logger" @@ -13,6 +16,8 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo" ) +const usageChannelBufferSize = 100 + type AIClientV2 struct { toolCallHandler *handler.ToolCallHandlerV2 db *mongo.Database @@ -20,8 +25,16 @@ type AIClientV2 struct { reverseCommentService *services.ReverseCommentService projectService *services.ProjectService + usageService *services.UsageService cfg *cfg.Cfg logger *logger.Logger + + // usageChan buffers usage records for async processing + usageChan chan services.UsageRecord + // usageWg tracks in-flight usage recording operations + usageWg sync.WaitGroup + // shutdownOnce ensures shutdown logic runs only once + shutdownOnce sync.Once } // SetOpenAIClient sets the appropriate OpenAI client based on the LLM provider config. @@ -60,6 +73,7 @@ func NewAIClientV2( reverseCommentService *services.ReverseCommentService, projectService *services.ProjectService, + usageService *services.UsageService, cfg *cfg.Cfg, logger *logger.Logger, ) *AIClientV2 { @@ -107,9 +121,53 @@ func NewAIClientV2( reverseCommentService: reverseCommentService, projectService: projectService, + usageService: usageService, cfg: cfg, logger: logger, + + usageChan: make(chan services.UsageRecord, usageChannelBufferSize), } + // Start the usage recording worker + client.usageWg.Add(1) + go client.usageWorker() + return client } + +// usageWorker processes usage records from the channel. +// It runs until the channel is closed during shutdown. +func (a *AIClientV2) usageWorker() { + defer a.usageWg.Done() + + for record := range a.usageChan { + ctx := context.Background() + if err := a.usageService.RecordUsage(ctx, record); err != nil { + a.logger.Error("Failed to store usage", "error", err) + } + } +} + +// RecordUsageAsync queues a usage record for async processing. +// Returns false if the channel is full (record dropped). +func (a *AIClientV2) RecordUsageAsync(record services.UsageRecord) bool { + select { + case a.usageChan <- record: + return true + default: + a.logger.Warn("Usage channel full, dropping record", + "userID", record.UserID, + "model", record.Model, + "tokens", record.TotalTokens) + return false + } +} + +// Shutdown gracefully stops the usage worker, ensuring all pending +// records are processed before returning. +func (a *AIClientV2) Shutdown() { + a.shutdownOnce.Do(func() { + close(a.usageChan) + a.usageWg.Wait() + }) +} diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index f10082bf..0aea3be8 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "paperdebugger/internal/models" + "paperdebugger/internal/services" "paperdebugger/internal/services/toolkit/handler" chatv2 "paperdebugger/pkg/gen/api/chat/v2" "strings" "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" ) // define []openai.ChatCompletionMessageParamUnion as OpenAIChatHistory @@ -25,8 +27,8 @@ import ( // 1. The full chat history sent to the language model (including any tool call results). // 2. The incremental chat history visible to the user (including tool call results and assistant responses). // 3. An error, if any occurred during the process. -func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { - openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider) +func (a *AIClientV2) ChatCompletionV2(ctx context.Context, userID bson.ObjectID, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { + openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, userID, "", modelSlug, messages, llmProvider) if err != nil { return nil, nil, err } @@ -54,7 +56,7 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes // - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop. // - If no tool calls are needed, it appends the assistant's response and exits the loop. // - Finally, it returns the updated chat histories and any error encountered. -func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { +func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, userID bson.ObjectID, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) { openaiChatHistory := messages inappChatHistory := AppChatHistory{} @@ -96,8 +98,17 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chunk := stream.Current() if len(chunk.Choices) == 0 { - // Handle usage information - // fmt.Printf("Usage: %+v\n", chunk.Usage) + // Handle usage information - only record for non-BYOK users + if chunk.Usage.TotalTokens > 0 && !llmProvider.IsCustom() { + // Queue usage record for async processing + a.RecordUsageAsync(services.UsageRecord{ + UserID: userID, + Model: modelSlug, + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + }) + } continue } @@ -185,7 +196,6 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream // answer_content += chunk.Choices[0].Delta.Content // fmt.Printf("answer_content: %s\n", answer_content) streamHandler.HandleTextDoneItem(chunk, answer_content, reasoning_content) - break } } diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go index 1995d590..5cc43ce5 100644 --- a/internal/services/toolkit/client/get_citation_keys.go +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -241,7 +241,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI // Bibliography is placed at the start of the prompt to leverage prompt caching message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userId, "gpt-5.2", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/get_citation_keys_test.go b/internal/services/toolkit/client/get_citation_keys_test.go index 4d2a857d..43cbaa6f 100644 --- a/internal/services/toolkit/client/get_citation_keys_test.go +++ b/internal/services/toolkit/client/get_citation_keys_test.go @@ -25,10 +25,13 @@ func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService } projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + pricingService := services.NewPricingService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + usageService := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger(), pricingService) aiClient := client.NewAIClientV2( dbInstance, &services.ReverseCommentService{}, projectService, + usageService, cfg.GetCfg(), logger.GetLogger(), ) diff --git a/internal/services/toolkit/client/get_conversation_title_v2.go b/internal/services/toolkit/client/get_conversation_title_v2.go index 6c92f0c2..f3fd5c8c 100644 --- a/internal/services/toolkit/client/get_conversation_title_v2.go +++ b/internal/services/toolkit/client/get_conversation_title_v2.go @@ -11,9 +11,10 @@ import ( "github.com/openai/openai-go/v3" "github.com/samber/lo" + "go.mongodb.org/mongo-driver/v2/bson" ) -func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { +func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, userID bson.ObjectID, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) { messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string { if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok { return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent()) @@ -29,7 +30,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor message := strings.Join(messages, "\n") message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. Give me the title only, no other text including any other words.", message) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{ + _, resp, err := a.ChatCompletionV2(ctx, userID, "gpt-5-nano", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."), openai.UserMessage(message), }, llmProvider) diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go index 69e73071..47829575 100644 --- a/internal/services/toolkit/client/utils_v2.go +++ b/internal/services/toolkit/client/utils_v2.go @@ -74,6 +74,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } } @@ -85,6 +88,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) Tools: toolRegistry.GetTools(), // Tool registration is managed centrally by the registry ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), // Must set to false, because we are construct our own chat history. + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } diff --git a/internal/services/usage.go b/internal/services/usage.go new file mode 100644 index 00000000..756481f8 --- /dev/null +++ b/internal/services/usage.go @@ -0,0 +1,345 @@ +package services + +import ( + "context" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +const SessionDuration = 5 * time.Hour + +type UsageService struct { + BaseService + sessionCollection *mongo.Collection + pricingService *PricingService +} + +type UsageRecord struct { + UserID bson.ObjectID + Model string + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 +} + +// ModelUsageStats stores aggregated usage statistics for a specific model. +type ModelUsageStats struct { + PromptTokens int64 `bson:"prompt_tokens"` + CompletionTokens int64 `bson:"completion_tokens"` + TotalTokens int64 `bson:"total_tokens"` + RequestCount int64 `bson:"request_count"` + CostUSD float64 `bson:"-"` // Calculated field, not stored +} + +type UsageStats struct { + Models map[string]*ModelUsageStats `bson:"models"` + SessionCount int64 `bson:"session_count"` + TotalCostUSD float64 `bson:"-"` // Calculated field, not stored +} + +// SessionUsageStats represents session usage with calculated costs. +type SessionUsageStats struct { + SessionExpiry time.Time + Models map[string]*ModelUsageStats + TotalCostUSD float64 +} + +// calculateCosts calculates the cost in USD for each model and total. +// pricingMap maps model slug to pricing info. +func (s *UsageStats) calculateCosts(pricingMap map[string]*models.ModelPricing) { + s.TotalCostUSD = 0 + for modelSlug, stats := range s.Models { + if pricing, ok := pricingMap[modelSlug]; ok && pricing != nil { + stats.CostUSD = float64(stats.PromptTokens)*pricing.PromptPrice + + float64(stats.CompletionTokens)*pricing.CompletionPrice + s.TotalCostUSD += stats.CostUSD + } + } +} + +func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger, pricingService *PricingService) *UsageService { + base := NewBaseService(db, cfg, logger) + return &UsageService{ + BaseService: base, + sessionCollection: base.db.Collection((models.LLMSession{}).CollectionName()), + pricingService: pricingService, + } +} + +// RecordUsage updates the active session or creates a new one if none exists. +// Falls back to update if insert fails (handles race when another request created a session). +func (s *UsageService) RecordUsage(ctx context.Context, record UsageRecord) error { + now := time.Now() + // Round session_start to the second so concurrent requests get the same timestamp. + // Combined with the unique index on (user_id, session_start), this ensures only one + // session is created per user per second, with conflicts handled via duplicate key retry. + sessionStart := now.Truncate(time.Second) + nowBson := bson.DateTime(now.UnixMilli()) + + // Build field paths for per-model token storage + modelPrefix := "models." + record.Model + filter := bson.M{ + "user_id": record.UserID, + "session_expiry": bson.M{"$gt": nowBson}, + } + update := bson.M{ + "$inc": bson.M{ + modelPrefix + ".prompt_tokens": record.PromptTokens, + modelPrefix + ".completion_tokens": record.CompletionTokens, + modelPrefix + ".total_tokens": record.TotalTokens, + modelPrefix + ".request_count": 1, + }, + } + + result, err := s.sessionCollection.UpdateOne(ctx, filter, update) + if err != nil { + return err + } + if result.MatchedCount > 0 { + return nil + } + + // No active session found - create a new one + session := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: record.UserID, + SessionStart: bson.DateTime(sessionStart.UnixMilli()), + SessionExpiry: bson.DateTime(sessionStart.Add(SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + record.Model: { + PromptTokens: record.PromptTokens, + CompletionTokens: record.CompletionTokens, + TotalTokens: record.TotalTokens, + RequestCount: 1, + }, + }, + } + _, err = s.sessionCollection.InsertOne(ctx, session) + if err != nil { + // Only retry with update if insert failed due to duplicate key (race condition) + if mongo.IsDuplicateKeyError(err) { + _, updateErr := s.sessionCollection.UpdateOne(ctx, filter, update) + if updateErr != nil { + // Log both errors for debugging + s.logger.Warn("Insert failed with duplicate key, update also failed", + "insertErr", err, + "updateErr", updateErr, + "userID", record.UserID) + return updateErr + } + // Race condition handled successfully + return nil + } + // Insert failed for non-duplicate-key reason (network, validation, etc.) + return err + } + return nil +} + +// GetActiveSession returns the current active session for a user, if any. +func (s *UsageService) GetActiveSession(ctx context.Context, userID bson.ObjectID) (*models.LLMSession, error) { + now := bson.DateTime(time.Now().UnixMilli()) + filter := bson.M{ + "user_id": userID, + "session_expiry": bson.M{"$gt": now}, + } + + var session models.LLMSession + err := s.sessionCollection.FindOne(ctx, filter).Decode(&session) + if err == mongo.ErrNoDocuments { + return nil, nil + } + if err != nil { + return nil, err + } + return &session, nil +} + +// GetWeeklyUsage returns aggregated usage for a user for the current week (Monday-Sunday). +func (s *UsageService) GetWeeklyUsage(ctx context.Context, userID bson.ObjectID) (*UsageStats, error) { + weekStart := startOfWeek(time.Now()) + return s.getUsageSince(ctx, userID, weekStart) +} + +func (s *UsageService) getUsageSince(ctx context.Context, userID bson.ObjectID, since time.Time) (*UsageStats, error) { + pipeline := bson.A{ + bson.M{"$match": bson.M{ + "user_id": userID, + "session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())}, + }}, + // Convert models map to array for aggregation + bson.M{"$project": bson.M{ + "models_array": bson.M{"$objectToArray": "$models"}, + "session_count": bson.M{"$literal": 1}, + }}, + // Unwind the models array to aggregate per model + bson.M{"$unwind": bson.M{ + "path": "$models_array", + "preserveNullAndEmptyArrays": true, + }}, + // Group by model name and sum tokens + bson.M{"$group": bson.M{ + "_id": "$models_array.k", + "prompt_tokens": bson.M{"$sum": "$models_array.v.prompt_tokens"}, + "completion_tokens": bson.M{"$sum": "$models_array.v.completion_tokens"}, + "total_tokens": bson.M{"$sum": "$models_array.v.total_tokens"}, + "request_count": bson.M{"$sum": "$models_array.v.request_count"}, + }}, + // Reshape into array of model stats + bson.M{"$group": bson.M{ + "_id": nil, + "models": bson.M{"$push": bson.M{ + "k": "$_id", + "v": bson.M{ + "prompt_tokens": "$prompt_tokens", + "completion_tokens": "$completion_tokens", + "total_tokens": "$total_tokens", + "request_count": "$request_count", + }, + }}, + }}, + // Convert back to object + bson.M{"$project": bson.M{ + "models": bson.M{"$arrayToObject": "$models"}, + }}, + } + + cursor, err := s.sessionCollection.Aggregate(ctx, pipeline) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + // Get session count separately (simpler query) + countPipeline := bson.A{ + bson.M{"$match": bson.M{ + "user_id": userID, + "session_start": bson.M{"$gte": bson.DateTime(since.UnixMilli())}, + }}, + bson.M{"$count": "session_count"}, + } + countCursor, err := s.sessionCollection.Aggregate(ctx, countPipeline) + if err != nil { + return nil, err + } + defer countCursor.Close(ctx) + + var sessionCount int64 + if countCursor.Next(ctx) { + var countResult struct { + SessionCount int64 `bson:"session_count"` + } + if err := countCursor.Decode(&countResult); err != nil { + return nil, err + } + sessionCount = countResult.SessionCount + } + + if cursor.Next(ctx) { + var result UsageStats + if err := cursor.Decode(&result); err != nil { + return nil, err + } + result.SessionCount = sessionCount + return &result, nil + } + return &UsageStats{Models: make(map[string]*ModelUsageStats)}, nil +} + +// startOfWeek returns the start of the week (Monday 00:00:00 UTC). +func startOfWeek(t time.Time) time.Time { + t = t.UTC() + daysFromMonday := (int(t.Weekday()) + 6) % 7 // Sunday=6, Monday=0, Tuesday=1, ... + return time.Date(t.Year(), t.Month(), t.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) +} + +// GetActiveSessionWithCosts returns the current active session with costs calculated. +func (s *UsageService) GetActiveSessionWithCosts(ctx context.Context, userID bson.ObjectID) (*SessionUsageStats, error) { + session, err := s.GetActiveSession(ctx, userID) + if err != nil { + return nil, err + } + if session == nil { + return nil, nil + } + + // Get pricing map for cost calculation + pricingMap, err := s.pricingService.GetPricingMap(ctx) + if err != nil { + s.logger.Warn("Failed to get pricing map for session costs", "error", err) + pricingMap = make(map[string]*models.ModelPricing) + } + + // Convert session models to ModelUsageStats and calculate costs + modelsWithCosts := make(map[string]*ModelUsageStats, len(session.Models)) + var totalCostUSD float64 + + for modelName, tokens := range session.Models { + stats := &ModelUsageStats{ + PromptTokens: tokens.PromptTokens, + CompletionTokens: tokens.CompletionTokens, + TotalTokens: tokens.TotalTokens, + RequestCount: tokens.RequestCount, + } + if pricing, ok := pricingMap[modelName]; ok && pricing != nil { + stats.CostUSD = float64(stats.PromptTokens)*pricing.PromptPrice + + float64(stats.CompletionTokens)*pricing.CompletionPrice + totalCostUSD += stats.CostUSD + } + modelsWithCosts[modelName] = stats + } + + return &SessionUsageStats{ + SessionExpiry: session.SessionExpiry.Time(), + Models: modelsWithCosts, + TotalCostUSD: totalCostUSD, + }, nil +} + +// GetWeeklyUsageWithCosts returns aggregated weekly usage with costs calculated. +func (s *UsageService) GetWeeklyUsageWithCosts(ctx context.Context, userID bson.ObjectID) (*UsageStats, error) { + stats, err := s.GetWeeklyUsage(ctx, userID) + if err != nil { + return nil, err + } + + // Get pricing map for cost calculation + pricingMap, err := s.pricingService.GetPricingMap(ctx) + if err != nil { + s.logger.Warn("Failed to get pricing map for weekly costs", "error", err) + pricingMap = make(map[string]*models.ModelPricing) + } + + // Calculate costs using the existing method + stats.calculateCosts(pricingMap) + + return stats, nil +} + +// ListRecentSessions returns the most recent sessions for a user. +func (s *UsageService) ListRecentSessions(ctx context.Context, userID bson.ObjectID, limit int64) ([]models.LLMSession, error) { + filter := bson.M{"user_id": userID} + opts := options.Find(). + SetSort(bson.D{{Key: "session_start", Value: -1}}). + SetLimit(limit) + + cursor, err := s.sessionCollection.Find(ctx, filter, opts) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var sessions []models.LLMSession + if err := cursor.All(ctx, &sessions); err != nil { + return nil, err + } + return sessions, nil +} diff --git a/internal/services/usage_test.go b/internal/services/usage_test.go new file mode 100644 index 00000000..02171946 --- /dev/null +++ b/internal/services/usage_test.go @@ -0,0 +1,698 @@ +package services_test + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + "paperdebugger/internal/services" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +func setupTestUsageService(t *testing.T) (*services.UsageService, *mongo.Collection) { + os.Setenv("PD_MONGO_URI", "mongodb://localhost:27017") + dbInstance, err := db.NewDB(cfg.GetCfg(), logger.GetLogger()) + if err != nil { + t.Fatalf("failed to connect to test db: %v", err) + } + + pricingService := services.NewPricingService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + svc := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger(), pricingService) + collection := dbInstance.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName()) + + return svc, collection +} + +func cleanupSessions(t *testing.T, collection *mongo.Collection, userID bson.ObjectID) { + ctx := context.Background() + _, err := collection.DeleteMany(ctx, bson.M{"user_id": userID}) + if err != nil { + t.Logf("cleanup warning: %v", err) + } +} + +func TestUsageService_RecordUsage_NewSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + record := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + + err := svc.RecordUsage(ctx, record) + require.NoError(t, err) + + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + + assert.Equal(t, userID, session.UserID) + require.NotNil(t, session.Models) + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(100), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), session.Models["gpt-4"].RequestCount) + + // Verify session expiry is set correctly (5 hours from now) + now := time.Now() + expiryTime := time.UnixMilli(int64(session.SessionExpiry)) + expectedExpiry := now.Add(services.SessionDuration) + assert.WithinDuration(t, expectedExpiry, expiryTime, 2*time.Second) +} + +func TestUsageService_RecordUsage_ExistingActiveSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Record first usage (creates session) + record1 := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + err := svc.RecordUsage(ctx, record1) + require.NoError(t, err) + + // Record second usage to same session with same model + record2 := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record2) + require.NoError(t, err) + + // Verify tokens are accumulated for the model + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(150), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(275), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(425), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), session.Models["gpt-4"].RequestCount) +} + +func TestUsageService_RecordUsage_MultipleModels(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Record usage for gpt-4 + record1 := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + err := svc.RecordUsage(ctx, record1) + require.NoError(t, err) + + // Record usage for claude-3 + record2 := services.UsageRecord{ + UserID: userID, + Model: "claude-3", + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record2) + require.NoError(t, err) + + // Record more usage for gpt-4 + record3 := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 25, + CompletionTokens: 30, + TotalTokens: 55, + } + err = svc.RecordUsage(ctx, record3) + require.NoError(t, err) + + // Verify per-model token storage + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Models) + + // Check gpt-4 tokens (accumulated from 2 records) + require.NotNil(t, session.Models["gpt-4"]) + assert.Equal(t, int64(125), session.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(230), session.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(355), session.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), session.Models["gpt-4"].RequestCount) + + // Check claude-3 tokens (single record) + require.NotNil(t, session.Models["claude-3"]) + assert.Equal(t, int64(50), session.Models["claude-3"].PromptTokens) + assert.Equal(t, int64(75), session.Models["claude-3"].CompletionTokens) + assert.Equal(t, int64(125), session.Models["claude-3"].TotalTokens) + assert.Equal(t, int64(1), session.Models["claude-3"].RequestCount) + + // Verify weekly usage aggregates per model + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats.Models) + + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(125), stats.Models["gpt-4"].PromptTokens) + + require.NotNil(t, stats.Models["claude-3"]) + assert.Equal(t, int64(50), stats.Models["claude-3"].PromptTokens) +} + +func TestUsageService_RecordUsage_ExpiredSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create an expired session manually + now := time.Now() + expiredSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-6 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1 * time.Hour).UnixMilli()), // Expired 1 hour ago + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, + } + _, err := collection.InsertOne(ctx, expiredSession) + require.NoError(t, err) + + // Record new usage - should create a new session, not update the expired one + record := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + } + err = svc.RecordUsage(ctx, record) + require.NoError(t, err) + + // Get active session + activeSession, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + require.NotNil(t, activeSession) + + // Should be a new session with only the new usage + assert.NotEqual(t, expiredSession.ID, activeSession.ID) + require.NotNil(t, activeSession.Models["gpt-4"]) + assert.Equal(t, int64(50), activeSession.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(75), activeSession.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(125), activeSession.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), activeSession.Models["gpt-4"].RequestCount) +} + +func TestUsageService_RecordUsage_RaceCondition(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Simulate concurrent requests trying to create sessions + concurrentRequests := 10 + var wg sync.WaitGroup + errors := make([]error, concurrentRequests) + + // Use a channel to synchronize goroutine starts for maximum race condition + start := make(chan struct{}) + + for i := range concurrentRequests { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start // Wait for signal to start + record := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + } + errors[idx] = svc.RecordUsage(ctx, record) + }(i) + } + + // Start all goroutines at once + close(start) + wg.Wait() + + // All requests should succeed (no errors) + for i, err := range errors { + assert.NoError(t, err, "Request %d should not have errored", i) + } + + // Count total sessions created (should be 1 or possibly more if race occurred) + filter := bson.M{"user_id": userID} + count, err := collection.CountDocuments(ctx, filter) + require.NoError(t, err) + + // Get all sessions to see the full picture + cursor, err := collection.Find(ctx, filter) + require.NoError(t, err) + var sessions []models.LLMSession + err = cursor.All(ctx, &sessions) + require.NoError(t, err) + + // Calculate total usage across all sessions for all models + var totalPrompt, totalCompletion, totalTokens, totalRequests int64 + for _, s := range sessions { + for _, m := range s.Models { + totalPrompt += m.PromptTokens + totalCompletion += m.CompletionTokens + totalTokens += m.TotalTokens + totalRequests += m.RequestCount + } + } + + // All tokens should be accumulated in a single session + assert.Equal(t, int64(100), totalPrompt, "Expected 10 requests * 10 tokens each") + assert.Equal(t, int64(200), totalCompletion, "Expected 10 requests * 20 tokens each") + assert.Equal(t, int64(300), totalTokens, "Expected 10 requests * 30 tokens each") + assert.Equal(t, int64(10), totalRequests, "Expected 10 requests recorded") + + // With the unique index on (user_id, session_start) and second-level truncation, + // only one session should be created. Concurrent inserts trigger duplicate key + // errors which are handled by falling back to update. + assert.Equal(t, int64(1), count, "Expected exactly 1 session due to unique index") +} + +func TestUsageService_GetActiveSession_NoSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + session, err := svc.GetActiveSession(ctx, userID) + require.NoError(t, err) + assert.Nil(t, session) +} + +func TestUsageService_GetWeeklyUsage_SingleSession(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Record some usage + record := services.UsageRecord{ + UserID: userID, + Model: "gpt-4", + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + } + err := svc.RecordUsage(ctx, record) + require.NoError(t, err) + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(100), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(1), stats.Models["gpt-4"].RequestCount) + assert.Equal(t, int64(1), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_MultipleSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create multiple sessions within the current week + now := time.Now() + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), // 2 days ago + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), // 1 day ago + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 3, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), // Now + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 10, + }, + }, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Verify aggregation per model + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(350), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(575), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(925), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(18), stats.Models["gpt-4"].RequestCount) + assert.Equal(t, int64(3), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_ExcludesOldSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + now := time.Now() + + // Create an old session (from last week) + oldSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-10 * 24 * time.Hour).UnixMilli()), // 10 days ago + SessionExpiry: bson.DateTime(now.Add(-10*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 1000, + CompletionTokens: 2000, + TotalTokens: 3000, + RequestCount: 50, + }, + }, + } + _, err := collection.InsertOne(ctx, oldSession) + require.NoError(t, err) + + // Create a current session + currentSession := models.LLMSession{ + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 5, + }, + }, + } + _, err = collection.InsertOne(ctx, currentSession) + require.NoError(t, err) + + // Get weekly usage + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should only include the current session + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(100), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(300), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(5), stats.Models["gpt-4"].RequestCount) + assert.Equal(t, int64(1), stats.SessionCount) +} + +func TestUsageService_GetWeeklyUsage_NoSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should return empty models map + assert.Empty(t, stats.Models) + assert.Equal(t, int64(0), stats.SessionCount) +} + +func TestUsageService_ListRecentSessions(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Create multiple sessions at different times + now := time.Now() + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-3 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-3*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-2 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-2*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 150, + CompletionTokens: 250, + TotalTokens: 400, + RequestCount: 2, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(now.Add(-1 * 24 * time.Hour).UnixMilli()), + SessionExpiry: bson.DateTime(now.Add(-1*24*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 200, + CompletionTokens: 300, + TotalTokens: 500, + RequestCount: 3, + }, + }, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + // List recent sessions (limit 2) + recent, err := svc.ListRecentSessions(ctx, userID, 2) + require.NoError(t, err) + assert.Len(t, recent, 2) + + // Should be in reverse chronological order (most recent first) + assert.Equal(t, int64(200), recent[0].Models["gpt-4"].PromptTokens) // Most recent + assert.Equal(t, int64(150), recent[1].Models["gpt-4"].PromptTokens) // Second most recent + + // List all sessions + all, err := svc.ListRecentSessions(ctx, userID, 10) + require.NoError(t, err) + assert.Len(t, all, 3) +} + +func TestStartOfWeek(t *testing.T) { + tests := []struct { + name string + input time.Time + expected time.Time + }{ + { + name: "Monday should return same day at 00:00", + input: time.Date(2024, 1, 1, 15, 30, 45, 0, time.UTC), // Monday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Tuesday should return previous Monday", + input: time.Date(2024, 1, 2, 15, 30, 45, 0, time.UTC), // Tuesday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Sunday should return previous Monday", + input: time.Date(2024, 1, 7, 15, 30, 45, 0, time.UTC), // Sunday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Wednesday mid-week should return Monday", + input: time.Date(2024, 1, 3, 12, 0, 0, 0, time.UTC), // Wednesday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "Saturday should return previous Monday", + input: time.Date(2024, 1, 6, 23, 59, 59, 0, time.UTC), // Saturday + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We need to test the private function indirectly via GetWeeklyUsage + // But for this specific test, we'll verify the logic manually + input := tt.input.UTC() + daysFromMonday := (int(input.Weekday()) + 6) % 7 + result := time.Date(input.Year(), input.Month(), input.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) + + assert.Equal(t, tt.expected, result) + assert.Equal(t, time.Monday, result.Weekday(), "Start of week should be Monday") + }) + } +} + +func TestUsageService_GetWeeklyUsage_WeekBoundary(t *testing.T) { + svc, collection := setupTestUsageService(t) + ctx := context.Background() + userID := bson.NewObjectID() + defer cleanupSessions(t, collection, userID) + + // Get the start of this week + now := time.Now().UTC() + daysFromMonday := (int(now.Weekday()) + 6) % 7 + weekStart := time.Date(now.Year(), now.Month(), now.Day()-daysFromMonday, 0, 0, 0, 0, time.UTC) + + // Create sessions on both sides of the week boundary + sessions := []models.LLMSession{ + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(-1 * time.Hour).UnixMilli()), // Just before week start + SessionExpiry: bson.DateTime(weekStart.Add(-1*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 100, + CompletionTokens: 200, + TotalTokens: 300, + RequestCount: 1, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.UnixMilli()), // Exactly at week start + SessionExpiry: bson.DateTime(weekStart.Add(services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 50, + CompletionTokens: 75, + TotalTokens: 125, + RequestCount: 1, + }, + }, + }, + { + ID: bson.NewObjectID(), + UserID: userID, + SessionStart: bson.DateTime(weekStart.Add(1 * time.Hour).UnixMilli()), // Just after week start + SessionExpiry: bson.DateTime(weekStart.Add(1*time.Hour + services.SessionDuration).UnixMilli()), + Models: map[string]*models.ModelTokens{ + "gpt-4": { + PromptTokens: 25, + CompletionTokens: 50, + TotalTokens: 75, + RequestCount: 1, + }, + }, + }, + } + + for _, session := range sessions { + _, err := collection.InsertOne(ctx, session) + require.NoError(t, err) + } + + stats, err := svc.GetWeeklyUsage(ctx, userID) + require.NoError(t, err) + require.NotNil(t, stats) + + // Should only include sessions at or after week start (last 2 sessions) + require.NotNil(t, stats.Models) + require.NotNil(t, stats.Models["gpt-4"]) + assert.Equal(t, int64(75), stats.Models["gpt-4"].PromptTokens) + assert.Equal(t, int64(125), stats.Models["gpt-4"].CompletionTokens) + assert.Equal(t, int64(200), stats.Models["gpt-4"].TotalTokens) + assert.Equal(t, int64(2), stats.Models["gpt-4"].RequestCount) + assert.Equal(t, int64(2), stats.SessionCount) +} diff --git a/internal/wire.go b/internal/wire.go index f823bc2e..8c7a111e 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -9,6 +9,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -32,6 +33,7 @@ var Set = wire.NewSet( user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, + usage.NewUsageServer, aiclient.NewAIClient, aiclient.NewAIClientV2, @@ -43,6 +45,8 @@ var Set = wire.NewSet( services.NewProjectService, services.NewPromptService, services.NewOAuthService, + services.NewUsageService, + services.NewPricingService, cfg.GetCfg, logger.GetLogger, diff --git a/internal/wire_gen.go b/internal/wire_gen.go index 75c4e91a..da5cafb0 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -13,6 +13,7 @@ import ( "paperdebugger/internal/api/chat" "paperdebugger/internal/api/comment" "paperdebugger/internal/api/project" + "paperdebugger/internal/api/usage" "paperdebugger/internal/api/user" "paperdebugger/internal/libs/cfg" "paperdebugger/internal/libs/db" @@ -38,21 +39,24 @@ func InitializeApp() (*api.Server, error) { aiClient := client.NewAIClient(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) chatService := services.NewChatService(dbDB, cfgCfg, loggerLogger) chatServiceServer := chat.NewChatServer(aiClient, chatService, projectService, userService, loggerLogger, cfgCfg) - aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) + pricingService := services.NewPricingService(dbDB, cfgCfg, loggerLogger) + usageService := services.NewUsageService(dbDB, cfgCfg, loggerLogger, pricingService) + aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, usageService, cfgCfg, loggerLogger) chatServiceV2 := services.NewChatServiceV2(dbDB, cfgCfg, loggerLogger) chatv2ChatServiceServer := chat.NewChatServerV2(aiClientV2, chatServiceV2, projectService, userService, loggerLogger, cfgCfg) promptService := services.NewPromptService(dbDB, cfgCfg, loggerLogger) userServiceServer := user.NewUserServer(userService, promptService, cfgCfg, loggerLogger) projectServiceServer := project.NewProjectServer(projectService, loggerLogger, cfgCfg) commentServiceServer := comment.NewCommentServer(projectService, chatService, reverseCommentService, loggerLogger, cfgCfg) - grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer) + usageServiceServer := usage.NewUsageServer(usageService, loggerLogger) + grpcServer := api.NewGrpcServer(userService, cfgCfg, authServiceServer, chatServiceServer, chatv2ChatServiceServer, userServiceServer, projectServiceServer, commentServiceServer, usageServiceServer) oAuthService := services.NewOAuthService(dbDB, cfgCfg, loggerLogger) oAuthHandler := auth.NewOAuthHandler(oAuthService) ginServer := api.NewGinServer(cfgCfg, oAuthHandler) - server := api.NewServer(grpcServer, ginServer, loggerLogger) + server := api.NewServer(grpcServer, ginServer, pricingService, aiClientV2, loggerLogger) return server, nil } // wire.go: -var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, cfg.GetCfg, logger.GetLogger, db.NewDB) +var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, usage.NewUsageServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, services.NewPricingService, cfg.GetCfg, logger.GetLogger, db.NewDB) diff --git a/pkg/gen/api/chat/v2/chat.pb.go b/pkg/gen/api/chat/v2/chat.pb.go index 0d312c55..485bfd0f 100644 --- a/pkg/gen/api/chat/v2/chat.pb.go +++ b/pkg/gen/api/chat/v2/chat.pb.go @@ -7,13 +7,12 @@ package chatv2 import ( - reflect "reflect" - sync "sync" - unsafe "unsafe" - _ "google.golang.org/genproto/googleapis/api/annotations" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" ) const ( diff --git a/pkg/gen/api/usage/v1/usage.pb.go b/pkg/gen/api/usage/v1/usage.pb.go new file mode 100644 index 00000000..33530afd --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.go @@ -0,0 +1,488 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ModelTokens struct { + state protoimpl.MessageState `protogen:"open.v1"` + PromptTokens int64 `protobuf:"varint,1,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int64 `protobuf:"varint,2,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + TotalTokens int64 `protobuf:"varint,3,opt,name=total_tokens,json=totalTokens,proto3" json:"total_tokens,omitempty"` + RequestCount int64 `protobuf:"varint,4,opt,name=request_count,json=requestCount,proto3" json:"request_count,omitempty"` + CostUsd float64 `protobuf:"fixed64,5,opt,name=cost_usd,json=costUsd,proto3" json:"cost_usd,omitempty"` // Cost in USD for this model + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ModelTokens) Reset() { + *x = ModelTokens{} + mi := &file_usage_v1_usage_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ModelTokens) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelTokens) ProtoMessage() {} + +func (x *ModelTokens) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelTokens.ProtoReflect.Descriptor instead. +func (*ModelTokens) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{0} +} + +func (x *ModelTokens) GetPromptTokens() int64 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *ModelTokens) GetCompletionTokens() int64 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *ModelTokens) GetTotalTokens() int64 { + if x != nil { + return x.TotalTokens + } + return 0 +} + +func (x *ModelTokens) GetRequestCount() int64 { + if x != nil { + return x.RequestCount + } + return 0 +} + +func (x *ModelTokens) GetCostUsd() float64 { + if x != nil { + return x.CostUsd + } + return 0 +} + +type SessionUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionExpiry *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` + // Tokens per model (model_slug -> tokens) + Models map[string]*ModelTokens `protobuf:"bytes,2,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + TotalCostUsd float64 `protobuf:"fixed64,3,opt,name=total_cost_usd,json=totalCostUsd,proto3" json:"total_cost_usd,omitempty"` // Total cost in USD across all models + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SessionUsage) Reset() { + *x = SessionUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SessionUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SessionUsage) ProtoMessage() {} + +func (x *SessionUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SessionUsage.ProtoReflect.Descriptor instead. +func (*SessionUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{1} +} + +func (x *SessionUsage) GetSessionExpiry() *timestamppb.Timestamp { + if x != nil { + return x.SessionExpiry + } + return nil +} + +func (x *SessionUsage) GetModels() map[string]*ModelTokens { + if x != nil { + return x.Models + } + return nil +} + +func (x *SessionUsage) GetTotalCostUsd() float64 { + if x != nil { + return x.TotalCostUsd + } + return 0 +} + +type WeeklyUsage struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Tokens per model (model_slug -> tokens) + Models map[string]*ModelTokens `protobuf:"bytes,1,rep,name=models,proto3" json:"models,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + SessionCount int64 `protobuf:"varint,2,opt,name=session_count,json=sessionCount,proto3" json:"session_count,omitempty"` + TotalCostUsd float64 `protobuf:"fixed64,3,opt,name=total_cost_usd,json=totalCostUsd,proto3" json:"total_cost_usd,omitempty"` // Total cost in USD across all models + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WeeklyUsage) Reset() { + *x = WeeklyUsage{} + mi := &file_usage_v1_usage_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WeeklyUsage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WeeklyUsage) ProtoMessage() {} + +func (x *WeeklyUsage) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WeeklyUsage.ProtoReflect.Descriptor instead. +func (*WeeklyUsage) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{2} +} + +func (x *WeeklyUsage) GetModels() map[string]*ModelTokens { + if x != nil { + return x.Models + } + return nil +} + +func (x *WeeklyUsage) GetSessionCount() int64 { + if x != nil { + return x.SessionCount + } + return 0 +} + +func (x *WeeklyUsage) GetTotalCostUsd() float64 { + if x != nil { + return x.TotalCostUsd + } + return 0 +} + +type GetSessionUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageRequest) Reset() { + *x = GetSessionUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageRequest) ProtoMessage() {} + +func (x *GetSessionUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageRequest.ProtoReflect.Descriptor instead. +func (*GetSessionUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{3} +} + +type GetSessionUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Active session usage, null if no active session + Session *SessionUsage `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetSessionUsageResponse) Reset() { + *x = GetSessionUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetSessionUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetSessionUsageResponse) ProtoMessage() {} + +func (x *GetSessionUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetSessionUsageResponse.ProtoReflect.Descriptor instead. +func (*GetSessionUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{4} +} + +func (x *GetSessionUsageResponse) GetSession() *SessionUsage { + if x != nil { + return x.Session + } + return nil +} + +type GetWeeklyUsageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageRequest) Reset() { + *x = GetWeeklyUsageRequest{} + mi := &file_usage_v1_usage_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageRequest) ProtoMessage() {} + +func (x *GetWeeklyUsageRequest) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageRequest.ProtoReflect.Descriptor instead. +func (*GetWeeklyUsageRequest) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{5} +} + +type GetWeeklyUsageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Usage *WeeklyUsage `protobuf:"bytes,1,opt,name=usage,proto3" json:"usage,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWeeklyUsageResponse) Reset() { + *x = GetWeeklyUsageResponse{} + mi := &file_usage_v1_usage_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWeeklyUsageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWeeklyUsageResponse) ProtoMessage() {} + +func (x *GetWeeklyUsageResponse) ProtoReflect() protoreflect.Message { + mi := &file_usage_v1_usage_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWeeklyUsageResponse.ProtoReflect.Descriptor instead. +func (*GetWeeklyUsageResponse) Descriptor() ([]byte, []int) { + return file_usage_v1_usage_proto_rawDescGZIP(), []int{6} +} + +func (x *GetWeeklyUsageResponse) GetUsage() *WeeklyUsage { + if x != nil { + return x.Usage + } + return nil +} + +var File_usage_v1_usage_proto protoreflect.FileDescriptor + +const file_usage_v1_usage_proto_rawDesc = "" + + "\n" + + "\x14usage/v1/usage.proto\x12\busage.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc2\x01\n" + + "\vModelTokens\x12#\n" + + "\rprompt_tokens\x18\x01 \x01(\x03R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x02 \x01(\x03R\x10completionTokens\x12!\n" + + "\ftotal_tokens\x18\x03 \x01(\x03R\vtotalTokens\x12#\n" + + "\rrequest_count\x18\x04 \x01(\x03R\frequestCount\x12\x19\n" + + "\bcost_usd\x18\x05 \x01(\x01R\acostUsd\"\x85\x02\n" + + "\fSessionUsage\x12A\n" + + "\x0esession_expiry\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\rsessionExpiry\x12:\n" + + "\x06models\x18\x02 \x03(\v2\".usage.v1.SessionUsage.ModelsEntryR\x06models\x12$\n" + + "\x0etotal_cost_usd\x18\x03 \x01(\x01R\ftotalCostUsd\x1aP\n" + + "\vModelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\xe5\x01\n" + + "\vWeeklyUsage\x129\n" + + "\x06models\x18\x01 \x03(\v2!.usage.v1.WeeklyUsage.ModelsEntryR\x06models\x12#\n" + + "\rsession_count\x18\x02 \x01(\x03R\fsessionCount\x12$\n" + + "\x0etotal_cost_usd\x18\x03 \x01(\x01R\ftotalCostUsd\x1aP\n" + + "\vModelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + + "\x05value\x18\x02 \x01(\v2\x15.usage.v1.ModelTokensR\x05value:\x028\x01\"\x18\n" + + "\x16GetSessionUsageRequest\"K\n" + + "\x17GetSessionUsageResponse\x120\n" + + "\asession\x18\x01 \x01(\v2\x16.usage.v1.SessionUsageR\asession\"\x17\n" + + "\x15GetWeeklyUsageRequest\"E\n" + + "\x16GetWeeklyUsageResponse\x12+\n" + + "\x05usage\x18\x01 \x01(\v2\x15.usage.v1.WeeklyUsageR\x05usage2\x9a\x02\n" + + "\fUsageService\x12\x85\x01\n" + + "\x0fGetSessionUsage\x12 .usage.v1.GetSessionUsageRequest\x1a!.usage.v1.GetSessionUsageResponse\"-\x82\xd3\xe4\x93\x02'\x12%/_pd/api/v1/users/@self/usage/session\x12\x81\x01\n" + + "\x0eGetWeeklyUsage\x12\x1f.usage.v1.GetWeeklyUsageRequest\x1a .usage.v1.GetWeeklyUsageResponse\",\x82\xd3\xe4\x93\x02&\x12$/_pd/api/v1/users/@self/usage/weeklyB\x87\x01\n" + + "\fcom.usage.v1B\n" + + "UsageProtoP\x01Z*paperdebugger/pkg/gen/api/usage/v1;usagev1\xa2\x02\x03UXX\xaa\x02\bUsage.V1\xca\x02\bUsage\\V1\xe2\x02\x14Usage\\V1\\GPBMetadata\xea\x02\tUsage::V1b\x06proto3" + +var ( + file_usage_v1_usage_proto_rawDescOnce sync.Once + file_usage_v1_usage_proto_rawDescData []byte +) + +func file_usage_v1_usage_proto_rawDescGZIP() []byte { + file_usage_v1_usage_proto_rawDescOnce.Do(func() { + file_usage_v1_usage_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc))) + }) + return file_usage_v1_usage_proto_rawDescData +} + +var file_usage_v1_usage_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_usage_v1_usage_proto_goTypes = []any{ + (*ModelTokens)(nil), // 0: usage.v1.ModelTokens + (*SessionUsage)(nil), // 1: usage.v1.SessionUsage + (*WeeklyUsage)(nil), // 2: usage.v1.WeeklyUsage + (*GetSessionUsageRequest)(nil), // 3: usage.v1.GetSessionUsageRequest + (*GetSessionUsageResponse)(nil), // 4: usage.v1.GetSessionUsageResponse + (*GetWeeklyUsageRequest)(nil), // 5: usage.v1.GetWeeklyUsageRequest + (*GetWeeklyUsageResponse)(nil), // 6: usage.v1.GetWeeklyUsageResponse + nil, // 7: usage.v1.SessionUsage.ModelsEntry + nil, // 8: usage.v1.WeeklyUsage.ModelsEntry + (*timestamppb.Timestamp)(nil), // 9: google.protobuf.Timestamp +} +var file_usage_v1_usage_proto_depIdxs = []int32{ + 9, // 0: usage.v1.SessionUsage.session_expiry:type_name -> google.protobuf.Timestamp + 7, // 1: usage.v1.SessionUsage.models:type_name -> usage.v1.SessionUsage.ModelsEntry + 8, // 2: usage.v1.WeeklyUsage.models:type_name -> usage.v1.WeeklyUsage.ModelsEntry + 1, // 3: usage.v1.GetSessionUsageResponse.session:type_name -> usage.v1.SessionUsage + 2, // 4: usage.v1.GetWeeklyUsageResponse.usage:type_name -> usage.v1.WeeklyUsage + 0, // 5: usage.v1.SessionUsage.ModelsEntry.value:type_name -> usage.v1.ModelTokens + 0, // 6: usage.v1.WeeklyUsage.ModelsEntry.value:type_name -> usage.v1.ModelTokens + 3, // 7: usage.v1.UsageService.GetSessionUsage:input_type -> usage.v1.GetSessionUsageRequest + 5, // 8: usage.v1.UsageService.GetWeeklyUsage:input_type -> usage.v1.GetWeeklyUsageRequest + 4, // 9: usage.v1.UsageService.GetSessionUsage:output_type -> usage.v1.GetSessionUsageResponse + 6, // 10: usage.v1.UsageService.GetWeeklyUsage:output_type -> usage.v1.GetWeeklyUsageResponse + 9, // [9:11] is the sub-list for method output_type + 7, // [7:9] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_usage_v1_usage_proto_init() } +func file_usage_v1_usage_proto_init() { + if File_usage_v1_usage_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_usage_v1_usage_proto_rawDesc), len(file_usage_v1_usage_proto_rawDesc)), + NumEnums: 0, + NumMessages: 9, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_usage_v1_usage_proto_goTypes, + DependencyIndexes: file_usage_v1_usage_proto_depIdxs, + MessageInfos: file_usage_v1_usage_proto_msgTypes, + }.Build() + File_usage_v1_usage_proto = out.File + file_usage_v1_usage_proto_goTypes = nil + file_usage_v1_usage_proto_depIdxs = nil +} diff --git a/pkg/gen/api/usage/v1/usage.pb.gw.go b/pkg/gen/api/usage/v1/usage.pb.gw.go new file mode 100644 index 00000000..3a455736 --- /dev/null +++ b/pkg/gen/api/usage/v1/usage.pb.gw.go @@ -0,0 +1,211 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: usage/v1/usage.proto + +/* +Package usagev1 is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package usagev1 + +import ( + "context" + "errors" + "io" + "net/http" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// Suppress "imported and not used" errors +var ( + _ codes.Code + _ io.Reader + _ status.Status + _ = errors.New + _ = runtime.String + _ = utilities.NewDoubleArray + _ = metadata.Join +) + +func request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetSessionUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_UsageService_GetSessionUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetSessionUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetSessionUsage(ctx, &protoReq) + return msg, metadata, err +} + +func request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, client UsageServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetWeeklyUsage(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_UsageService_GetWeeklyUsage_0(ctx context.Context, marshaler runtime.Marshaler, server UsageServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetWeeklyUsageRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetWeeklyUsage(ctx, &protoReq) + return msg, metadata, err +} + +// RegisterUsageServiceHandlerServer registers the http handlers for service UsageService to "mux". +// UnaryRPC :call UsageServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterUsageServiceHandlerFromEndpoint instead. +// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call. +func RegisterUsageServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server UsageServiceServer) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + + return nil +} + +// RegisterUsageServiceHandlerFromEndpoint is same as RegisterUsageServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterUsageServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.NewClient(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + return RegisterUsageServiceHandler(ctx, mux, conn) +} + +// RegisterUsageServiceHandler registers the http handlers for service UsageService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterUsageServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterUsageServiceHandlerClient(ctx, mux, NewUsageServiceClient(conn)) +} + +// RegisterUsageServiceHandlerClient registers the http handlers for service UsageService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "UsageServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "UsageServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "UsageServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares. +func RegisterUsageServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client UsageServiceClient) error { + mux.Handle(http.MethodGet, pattern_UsageService_GetSessionUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetSessionUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/session")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetSessionUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetSessionUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_UsageService_GetWeeklyUsage_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/usage.v1.UsageService/GetWeeklyUsage", runtime.WithHTTPPathPattern("/_pd/api/v1/users/@self/usage/weekly")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_UsageService_GetWeeklyUsage_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_UsageService_GetWeeklyUsage_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + return nil +} + +var ( + pattern_UsageService_GetSessionUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "session"}, "")) + pattern_UsageService_GetWeeklyUsage_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6}, []string{"_pd", "api", "v1", "users", "@self", "usage", "weekly"}, "")) +) + +var ( + forward_UsageService_GetSessionUsage_0 = runtime.ForwardResponseMessage + forward_UsageService_GetWeeklyUsage_0 = runtime.ForwardResponseMessage +) diff --git a/pkg/gen/api/usage/v1/usage_grpc.pb.go b/pkg/gen/api/usage/v1/usage_grpc.pb.go new file mode 100644 index 00000000..7d33c1dd --- /dev/null +++ b/pkg/gen/api/usage/v1/usage_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc (unknown) +// source: usage/v1/usage.proto + +package usagev1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + UsageService_GetSessionUsage_FullMethodName = "/usage.v1.UsageService/GetSessionUsage" + UsageService_GetWeeklyUsage_FullMethodName = "/usage.v1.UsageService/GetWeeklyUsage" +) + +// UsageServiceClient is the client API for UsageService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type UsageServiceClient interface { + GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) + GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) +} + +type usageServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewUsageServiceClient(cc grpc.ClientConnInterface) UsageServiceClient { + return &usageServiceClient{cc} +} + +func (c *usageServiceClient) GetSessionUsage(ctx context.Context, in *GetSessionUsageRequest, opts ...grpc.CallOption) (*GetSessionUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetSessionUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetSessionUsage_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *usageServiceClient) GetWeeklyUsage(ctx context.Context, in *GetWeeklyUsageRequest, opts ...grpc.CallOption) (*GetWeeklyUsageResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetWeeklyUsageResponse) + err := c.cc.Invoke(ctx, UsageService_GetWeeklyUsage_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// UsageServiceServer is the server API for UsageService service. +// All implementations must embed UnimplementedUsageServiceServer +// for forward compatibility. +type UsageServiceServer interface { + GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) + GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) + mustEmbedUnimplementedUsageServiceServer() +} + +// UnimplementedUsageServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedUsageServiceServer struct{} + +func (UnimplementedUsageServiceServer) GetSessionUsage(context.Context, *GetSessionUsageRequest) (*GetSessionUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetSessionUsage not implemented") +} +func (UnimplementedUsageServiceServer) GetWeeklyUsage(context.Context, *GetWeeklyUsageRequest) (*GetWeeklyUsageResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetWeeklyUsage not implemented") +} +func (UnimplementedUsageServiceServer) mustEmbedUnimplementedUsageServiceServer() {} +func (UnimplementedUsageServiceServer) testEmbeddedByValue() {} + +// UnsafeUsageServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to UsageServiceServer will +// result in compilation errors. +type UnsafeUsageServiceServer interface { + mustEmbedUnimplementedUsageServiceServer() +} + +func RegisterUsageServiceServer(s grpc.ServiceRegistrar, srv UsageServiceServer) { + // If the following call panics, it indicates UnimplementedUsageServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&UsageService_ServiceDesc, srv) +} + +func _UsageService_GetSessionUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetSessionUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetSessionUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetSessionUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UsageServiceServer).GetSessionUsage(ctx, req.(*GetSessionUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _UsageService_GetWeeklyUsage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetWeeklyUsageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UsageServiceServer).GetWeeklyUsage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UsageService_GetWeeklyUsage_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UsageServiceServer).GetWeeklyUsage(ctx, req.(*GetWeeklyUsageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// UsageService_ServiceDesc is the grpc.ServiceDesc for UsageService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var UsageService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "usage.v1.UsageService", + HandlerType: (*UsageServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetSessionUsage", + Handler: _UsageService_GetSessionUsage_Handler, + }, + { + MethodName: "GetWeeklyUsage", + Handler: _UsageService_GetWeeklyUsage_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "usage/v1/usage.proto", +} diff --git a/proto/usage/v1/usage.proto b/proto/usage/v1/usage.proto new file mode 100644 index 00000000..f5f480ce --- /dev/null +++ b/proto/usage/v1/usage.proto @@ -0,0 +1,53 @@ +syntax = "proto3"; + +package usage.v1; + +import "google/api/annotations.proto"; +import "google/protobuf/timestamp.proto"; + +option go_package = "paperdebugger/pkg/gen/api/usage/v1;usagev1"; + +service UsageService { + rpc GetSessionUsage(GetSessionUsageRequest) returns (GetSessionUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/session"}; + } + + rpc GetWeeklyUsage(GetWeeklyUsageRequest) returns (GetWeeklyUsageResponse) { + option (google.api.http) = {get: "/_pd/api/v1/users/@self/usage/weekly"}; + } +} + +message ModelTokens { + int64 prompt_tokens = 1; + int64 completion_tokens = 2; + int64 total_tokens = 3; + int64 request_count = 4; + double cost_usd = 5; // Cost in USD for this model +} + +message SessionUsage { + google.protobuf.Timestamp session_expiry = 1; + // Tokens per model (model_slug -> tokens) + map models = 2; + double total_cost_usd = 3; // Total cost in USD across all models +} + +message WeeklyUsage { + // Tokens per model (model_slug -> tokens) + map models = 1; + int64 session_count = 2; + double total_cost_usd = 3; // Total cost in USD across all models +} + +message GetSessionUsageRequest {} + +message GetSessionUsageResponse { + // Active session usage, null if no active session + SessionUsage session = 1; +} + +message GetWeeklyUsageRequest {} + +message GetWeeklyUsageResponse { + WeeklyUsage usage = 1; +} diff --git a/webapp/_webapp/src/paperdebugger.tsx b/webapp/_webapp/src/paperdebugger.tsx index 5cdc5e5d..172a897e 100644 --- a/webapp/_webapp/src/paperdebugger.tsx +++ b/webapp/_webapp/src/paperdebugger.tsx @@ -2,6 +2,7 @@ import { Chat } from "./views/chat"; import { Tabs } from "./components/tabs"; import { Settings } from "./views/settings"; import { Prompts } from "./views/prompts"; +import { Usage } from "./views/usage"; import { PdAppBodyContainer } from "./components/pd-app-body-container"; export const PaperDebugger = () => { @@ -23,6 +24,13 @@ export const PaperDebugger = () => { children: , tooltip: "Prompt Library", }, + { + key: "usage", + title: "Usage", + icon: "tabler:chart-bar", + children: , + tooltip: "Usage Statistics", + }, { key: "settings", title: "Settings", diff --git a/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts new file mode 100644 index 00000000..e38175ee --- /dev/null +++ b/webapp/_webapp/src/pkg/gen/apiclient/usage/v1/usage_pb.ts @@ -0,0 +1,203 @@ +// @generated by protoc-gen-es v2.11.0 with parameter "target=ts" +// @generated from file usage/v1/usage.proto (package usage.v1, syntax proto3) +/* eslint-disable */ + +import type { GenFile, GenMessage, GenService } from "@bufbuild/protobuf/codegenv2"; +import { fileDesc, messageDesc, serviceDesc } from "@bufbuild/protobuf/codegenv2"; +import { file_google_api_annotations } from "@buf/googleapis_googleapis.bufbuild_es/google/api/annotations_pb"; +import type { Timestamp } from "@bufbuild/protobuf/wkt"; +import { file_google_protobuf_timestamp } from "@bufbuild/protobuf/wkt"; +import type { Message } from "@bufbuild/protobuf"; + +/** + * Describes the file usage/v1/usage.proto. + */ +export const file_usage_v1_usage: GenFile = /*@__PURE__*/ + fileDesc("ChR1c2FnZS92MS91c2FnZS5wcm90bxIIdXNhZ2UudjEifgoLTW9kZWxUb2tlbnMSFQoNcHJvbXB0X3Rva2VucxgBIAEoAxIZChFjb21wbGV0aW9uX3Rva2VucxgCIAEoAxIUCgx0b3RhbF90b2tlbnMYAyABKAMSFQoNcmVxdWVzdF9jb3VudBgEIAEoAxIQCghjb3N0X3VzZBgFIAEoASLUAQoMU2Vzc2lvblVzYWdlEjIKDnNlc3Npb25fZXhwaXJ5GAEgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcBIyCgZtb2RlbHMYAiADKAsyIi51c2FnZS52MS5TZXNzaW9uVXNhZ2UuTW9kZWxzRW50cnkSFgoOdG90YWxfY29zdF91c2QYAyABKAEaRAoLTW9kZWxzRW50cnkSCwoDa2V5GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudXNhZ2UudjEuTW9kZWxUb2tlbnM6AjgBIrUBCgtXZWVrbHlVc2FnZRIxCgZtb2RlbHMYASADKAsyIS51c2FnZS52MS5XZWVrbHlVc2FnZS5Nb2RlbHNFbnRyeRIVCg1zZXNzaW9uX2NvdW50GAIgASgDEhYKDnRvdGFsX2Nvc3RfdXNkGAMgASgBGkQKC01vZGVsc0VudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEoCzIVLnVzYWdlLnYxLk1vZGVsVG9rZW5zOgI4ASIYChZHZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0IkIKF0dldFNlc3Npb25Vc2FnZVJlc3BvbnNlEicKB3Nlc3Npb24YASABKAsyFi51c2FnZS52MS5TZXNzaW9uVXNhZ2UiFwoVR2V0V2Vla2x5VXNhZ2VSZXF1ZXN0Ij4KFkdldFdlZWtseVVzYWdlUmVzcG9uc2USJAoFdXNhZ2UYASABKAsyFS51c2FnZS52MS5XZWVrbHlVc2FnZTKaAgoMVXNhZ2VTZXJ2aWNlEoUBCg9HZXRTZXNzaW9uVXNhZ2USIC51c2FnZS52MS5HZXRTZXNzaW9uVXNhZ2VSZXF1ZXN0GiEudXNhZ2UudjEuR2V0U2Vzc2lvblVzYWdlUmVzcG9uc2UiLYLT5JMCJxIlL19wZC9hcGkvdjEvdXNlcnMvQHNlbGYvdXNhZ2Uvc2Vzc2lvbhKBAQoOR2V0V2Vla2x5VXNhZ2USHy51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlcXVlc3QaIC51c2FnZS52MS5HZXRXZWVrbHlVc2FnZVJlc3BvbnNlIiyC0+STAiYSJC9fcGQvYXBpL3YxL3VzZXJzL0BzZWxmL3VzYWdlL3dlZWtseUKHAQoMY29tLnVzYWdlLnYxQgpVc2FnZVByb3RvUAFaKnBhcGVyZGVidWdnZXIvcGtnL2dlbi9hcGkvdXNhZ2UvdjE7dXNhZ2V2MaICA1VYWKoCCFVzYWdlLlYxygIIVXNhZ2VcVjHiAhRVc2FnZVxWMVxHUEJNZXRhZGF0YeoCCVVzYWdlOjpWMWIGcHJvdG8z", [file_google_api_annotations, file_google_protobuf_timestamp]); + +/** + * @generated from message usage.v1.ModelTokens + */ +export type ModelTokens = Message<"usage.v1.ModelTokens"> & { + /** + * @generated from field: int64 prompt_tokens = 1; + */ + promptTokens: bigint; + + /** + * @generated from field: int64 completion_tokens = 2; + */ + completionTokens: bigint; + + /** + * @generated from field: int64 total_tokens = 3; + */ + totalTokens: bigint; + + /** + * @generated from field: int64 request_count = 4; + */ + requestCount: bigint; + + /** + * Cost in USD for this model + * + * @generated from field: double cost_usd = 5; + */ + costUsd: number; +}; + +/** + * Describes the message usage.v1.ModelTokens. + * Use `create(ModelTokensSchema)` to create a new message. + */ +export const ModelTokensSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 0); + +/** + * @generated from message usage.v1.SessionUsage + */ +export type SessionUsage = Message<"usage.v1.SessionUsage"> & { + /** + * @generated from field: google.protobuf.Timestamp session_expiry = 1; + */ + sessionExpiry?: Timestamp; + + /** + * Tokens per model (model_slug -> tokens) + * + * @generated from field: map models = 2; + */ + models: { [key: string]: ModelTokens }; + + /** + * Total cost in USD across all models + * + * @generated from field: double total_cost_usd = 3; + */ + totalCostUsd: number; +}; + +/** + * Describes the message usage.v1.SessionUsage. + * Use `create(SessionUsageSchema)` to create a new message. + */ +export const SessionUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 1); + +/** + * @generated from message usage.v1.WeeklyUsage + */ +export type WeeklyUsage = Message<"usage.v1.WeeklyUsage"> & { + /** + * Tokens per model (model_slug -> tokens) + * + * @generated from field: map models = 1; + */ + models: { [key: string]: ModelTokens }; + + /** + * @generated from field: int64 session_count = 2; + */ + sessionCount: bigint; + + /** + * Total cost in USD across all models + * + * @generated from field: double total_cost_usd = 3; + */ + totalCostUsd: number; +}; + +/** + * Describes the message usage.v1.WeeklyUsage. + * Use `create(WeeklyUsageSchema)` to create a new message. + */ +export const WeeklyUsageSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 2); + +/** + * @generated from message usage.v1.GetSessionUsageRequest + */ +export type GetSessionUsageRequest = Message<"usage.v1.GetSessionUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetSessionUsageRequest. + * Use `create(GetSessionUsageRequestSchema)` to create a new message. + */ +export const GetSessionUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 3); + +/** + * @generated from message usage.v1.GetSessionUsageResponse + */ +export type GetSessionUsageResponse = Message<"usage.v1.GetSessionUsageResponse"> & { + /** + * Active session usage, null if no active session + * + * @generated from field: usage.v1.SessionUsage session = 1; + */ + session?: SessionUsage; +}; + +/** + * Describes the message usage.v1.GetSessionUsageResponse. + * Use `create(GetSessionUsageResponseSchema)` to create a new message. + */ +export const GetSessionUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 4); + +/** + * @generated from message usage.v1.GetWeeklyUsageRequest + */ +export type GetWeeklyUsageRequest = Message<"usage.v1.GetWeeklyUsageRequest"> & { +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageRequest. + * Use `create(GetWeeklyUsageRequestSchema)` to create a new message. + */ +export const GetWeeklyUsageRequestSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 5); + +/** + * @generated from message usage.v1.GetWeeklyUsageResponse + */ +export type GetWeeklyUsageResponse = Message<"usage.v1.GetWeeklyUsageResponse"> & { + /** + * @generated from field: usage.v1.WeeklyUsage usage = 1; + */ + usage?: WeeklyUsage; +}; + +/** + * Describes the message usage.v1.GetWeeklyUsageResponse. + * Use `create(GetWeeklyUsageResponseSchema)` to create a new message. + */ +export const GetWeeklyUsageResponseSchema: GenMessage = /*@__PURE__*/ + messageDesc(file_usage_v1_usage, 6); + +/** + * @generated from service usage.v1.UsageService + */ +export const UsageService: GenService<{ + /** + * @generated from rpc usage.v1.UsageService.GetSessionUsage + */ + getSessionUsage: { + methodKind: "unary"; + input: typeof GetSessionUsageRequestSchema; + output: typeof GetSessionUsageResponseSchema; + }, + /** + * @generated from rpc usage.v1.UsageService.GetWeeklyUsage + */ + getWeeklyUsage: { + methodKind: "unary"; + input: typeof GetWeeklyUsageRequestSchema; + output: typeof GetWeeklyUsageResponseSchema; + }, +}> = /*@__PURE__*/ + serviceDesc(file_usage_v1_usage, 0); + diff --git a/webapp/_webapp/src/query/api.ts b/webapp/_webapp/src/query/api.ts index 4098a018..3ae67e4b 100644 --- a/webapp/_webapp/src/query/api.ts +++ b/webapp/_webapp/src/query/api.ts @@ -224,3 +224,29 @@ export const acceptComments = async (data: PlainMessage const response = await apiclient.post(`/comments/accepted`, data); return fromJson(CommentsAcceptedResponseSchema, response); }; + +// Usage +import { + GetSessionUsageResponseSchema, + GetWeeklyUsageResponseSchema, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; + +export const getSessionUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/session", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetSessionUsageResponseSchema, response); +}; + +export const getWeeklyUsage = async () => { + if (!apiclient.hasToken()) { + throw new Error("No token"); + } + const response = await apiclient.get("/users/@self/usage/weekly", undefined, { + ignoreErrorToast: true, + }); + return fromJson(GetWeeklyUsageResponseSchema, response); +}; diff --git a/webapp/_webapp/src/query/index.ts b/webapp/_webapp/src/query/index.ts index 2c05d959..4c9ea5cc 100644 --- a/webapp/_webapp/src/query/index.ts +++ b/webapp/_webapp/src/query/index.ts @@ -22,6 +22,8 @@ import { upsertUserInstructions, getProjectInstructions, upsertProjectInstructions, + getSessionUsage, + getWeeklyUsage, } from "./api"; import { CreatePromptResponse, @@ -37,6 +39,10 @@ import { GetProjectInstructionsResponse, UpsertProjectInstructionsResponse, } from "../pkg/gen/apiclient/project/v1/project_pb"; +import { + GetSessionUsageResponse, + GetWeeklyUsageResponse, +} from "../pkg/gen/apiclient/usage/v1/usage_pb"; import { useAuthStore } from "../stores/auth-store"; export const useGetProjectQuery = (projectId: string, opts?: UseQueryOptionsOverride) => { @@ -166,3 +172,24 @@ export const useUpsertProjectInstructionsMutation = ( ...opts, }); }; + +// Usage +export const useGetSessionUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getSessionUsage().queryKey, + queryFn: () => getSessionUsage(), + enabled: !!user, + ...opts, + }); +}; + +export const useGetWeeklyUsageQuery = (opts?: UseQueryOptionsOverride) => { + const { user } = useAuthStore(); + return useQuery({ + queryKey: queryKeys.usage.getWeeklyUsage().queryKey, + queryFn: () => getWeeklyUsage(), + enabled: !!user, + ...opts, + }); +}; diff --git a/webapp/_webapp/src/query/keys.ts b/webapp/_webapp/src/query/keys.ts index e09bfd7e..dfa3fc34 100644 --- a/webapp/_webapp/src/query/keys.ts +++ b/webapp/_webapp/src/query/keys.ts @@ -5,6 +5,10 @@ export const queryKeys = createQueryKeyStore({ getUser: () => ["users", "@self"], getUserInstructions: () => ["users", "@self", "instructions"], }, + usage: { + getSessionUsage: () => ["users", "@self", "usage", "session"], + getWeeklyUsage: () => ["users", "@self", "usage", "weekly"], + }, prompts: { listPrompts: () => ["users", "@self", "prompts"], }, diff --git a/webapp/_webapp/src/views/usage/index.tsx b/webapp/_webapp/src/views/usage/index.tsx new file mode 100644 index 00000000..3465d9ad --- /dev/null +++ b/webapp/_webapp/src/views/usage/index.tsx @@ -0,0 +1,179 @@ +import { Spinner, Button } from "@heroui/react"; +import { Icon } from "@iconify/react"; +import { useState, useEffect } from "react"; +import { TabHeader } from "../../components/tab-header"; +import { useGetSessionUsageQuery, useGetWeeklyUsageQuery } from "../../query"; +import CellWrapper from "../../components/cell-wrapper"; +import { useSettingStore } from "../../stores/setting-store"; + +const formatCost = (cost: number | undefined): string => { + if (cost === undefined || cost === 0) return "USD $0.00"; + if (cost < 0.01) return `USD $${cost.toFixed(4)}`; + return `USD $${cost.toFixed(2)}`; +}; + +const formatTimeRemaining = (timestamp: { seconds?: bigint; nanos?: number } | undefined): string => { + if (!timestamp || !timestamp.seconds) return ""; + const expiryMs = Number(timestamp.seconds) * 1000; + const nowMs = Date.now(); + const diffMs = expiryMs - nowMs; + + if (diffMs <= 0) return ""; + + const totalMinutes = Math.floor(diffMs / 60000); + const hours = Math.floor(totalMinutes / 60); + const minutes = totalMinutes % 60; + + if (hours > 0) { + return `resets in ${hours} hr ${minutes} min`; + } + return `resets in ${minutes} min`; +}; + +const formatLastUpdated = (timestamp: number): string => { + const diffMs = Date.now() - timestamp; + const seconds = Math.floor(diffMs / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + + if (seconds < 10) return "just now"; + if (seconds < 60) return `${seconds} seconds ago`; + if (minutes === 1) return "1 minute ago"; + if (minutes < 60) return `${minutes} minutes ago`; + if (hours === 1) return "1 hour ago"; + return `${hours} hours ago`; +}; + +const SectionContainer = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const SectionTitle = ({ children }: { children: React.ReactNode }) => { + return
{children}
; +}; + +const CostDisplay = ({ cost }: { cost: number | undefined }) => { + return ( +
+ {formatCost(cost)} +
+ ); +}; + +export const Usage = () => { + const { settings } = useSettingStore(); + const isBYOK = Boolean(settings?.openaiApiKey); + + const { + data: sessionData, + isLoading: sessionLoading, + dataUpdatedAt: sessionUpdatedAt, + refetch: refetchSession, + isFetching: sessionFetching, + } = useGetSessionUsageQuery(); + const { + data: weeklyData, + isLoading: weeklyLoading, + refetch: refetchWeekly, + isFetching: weeklyFetching, + } = useGetWeeklyUsageQuery(); + + const [, setTick] = useState(0); + + // Update the "last updated" text periodically + useEffect(() => { + const interval = setInterval(() => setTick((t) => t + 1), 10000); + return () => clearInterval(interval); + }, []); + + const isLoading = sessionLoading || weeklyLoading; + const isFetching = sessionFetching || weeklyFetching; + + const handleRefresh = () => { + refetchSession(); + refetchWeekly(); + }; + + if (isLoading) { + return ( +
+ +
+ ); + } + + // Show message for BYOK users + if (isBYOK) { + return ( +
+ +
+ +
+ Usage tracking is not available when using your own API key. +
+
+
+
+ ); + } + + const session = sessionData?.session; + const weekly = weeklyData?.usage; + + return ( +
+ +
+ + + Current Session Usage + {session?.sessionExpiry && ( + ({formatTimeRemaining(session.sessionExpiry)}) + )} + + {session ? ( + + + + ) : ( + +
No active session
+
+ )} +
+ + + Weekly Usage + {weekly ? ( + + + + ) : ( + +
No usage data available
+
+ )} +
+
+ All costs displayed are fully covered by the PaperDebugger Team. +
+
+ + Last updated: {formatLastUpdated(sessionUpdatedAt)} + + +
+
+
+ ); +};