Skip to content

Commit c7ccb50

Browse files
committed
[usage] implement CancelSubscription
1 parent 58c2413 commit c7ccb50

File tree

5 files changed

+74
-21
lines changed

5 files changed

+74
-21
lines changed

components/usage/pkg/apiv1/billing.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,18 @@ import (
2222
"gorm.io/gorm"
2323
)
2424

25-
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB) *BillingService {
25+
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService {
2626
return &BillingService{
2727
stripeClient: stripeClient,
2828
conn: conn,
29+
ccManager: ccManager,
2930
}
3031
}
3132

3233
type BillingService struct {
3334
conn *gorm.DB
3435
stripeClient *stripe.Client
36+
ccManager *db.CostCenterManager
3537

3638
v1.UnimplementedBillingServiceServer
3739
}
@@ -76,19 +78,18 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
7678
return nil, status.Errorf(codes.Internal, "Failed to retrieve subscription details from invoice.")
7779
}
7880

79-
teamID, found := subscription.Metadata[stripe.AttributionIDMetadataKey]
81+
attrID, found := subscription.Metadata[stripe.AttributionIDMetadataKey]
8082
if !found {
8183
logger.Error("Failed to find teamID from subscription metadata.")
8284
return nil, status.Errorf(codes.Internal, "Failed to extra teamID from Stripe subscription.")
8385
}
84-
logger = logger.WithField("team_id", teamID)
86+
logger = logger.WithField("attribution_id", attrID)
8587

8688
// To support individual `user`s, we'll need to also extract the `userId` from metadata here and handle separately.
87-
attributionID := db.NewTeamAttributionID(teamID)
89+
attributionID := db.NewTeamAttributionID(attrID)
8890
finalizedAt := time.Unix(invoice.StatusTransitions.FinalizedAt, 0)
8991

9092
logger = logger.
91-
WithField("attribution_id", attributionID).
9293
WithField("invoice_finalized_at", finalizedAt)
9394

9495
if invoice.Lines == nil || len(invoice.Lines.Data) == 0 {
@@ -126,6 +127,29 @@ func (s *BillingService) FinalizeInvoice(ctx context.Context, in *v1.FinalizeInv
126127
return &v1.FinalizeInvoiceResponse{}, nil
127128
}
128129

130+
func (s *BillingService) CancelSubscription(ctx context.Context, in *v1.CancelSubscriptionRequest) (*v1.CancelSubscriptionResponse, error) {
131+
logger := log.WithField("subscription_id", in.GetSubscriptionId())
132+
if in.GetSubscriptionId() == "" {
133+
return nil, status.Errorf(codes.InvalidArgument, "subscriptionId is required")
134+
}
135+
136+
attributionID, err := s.stripeClient.GetAttributionIdForSubscriptionId(ctx, in.GetSubscriptionId())
137+
if err != nil {
138+
return nil, err
139+
}
140+
logger.Infof("Subscription ended. Setting cost center back to free.")
141+
costCenter, err := s.ccManager.GetOrCreateCostCenter(ctx, attributionID)
142+
if err != nil {
143+
return nil, err
144+
}
145+
costCenter.BillingStrategy = db.CostCenter_Other
146+
_, err = s.ccManager.UpdateCostCenter(ctx, costCenter)
147+
if err != nil {
148+
return nil, err
149+
}
150+
return &v1.CancelSubscriptionResponse{}, nil
151+
}
152+
129153
func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcomingInvoiceRequest) (*v1.GetUpcomingInvoiceResponse, error) {
130154
if in.GetTeamId() == "" && in.GetUserId() == "" {
131155
return nil, status.Errorf(codes.InvalidArgument, "teamId or userId is required")

components/usage/pkg/db/cost_center.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,16 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos
112112

113113
now := time.Now()
114114

115+
// we always update the creationTime
116+
costCenter.CreationTime = NewVarcharTime(now)
115117
// we don't allow setting the creationTime or the nextBillingTime from outside
116118
costCenter.CreationTime = existingCostCenter.CreationTime
117119
costCenter.NextBillingTime = existingCostCenter.NextBillingTime
118120

119121
// Do we have a billing strategy update?
120122
if costCenter.BillingStrategy != existingCostCenter.BillingStrategy {
121-
if existingCostCenter.BillingStrategy == CostCenter_Other {
123+
switch costCenter.BillingStrategy {
124+
case CostCenter_Stripe:
122125
// moving to stripe -> let's run a finalization
123126
finalizationUsage, err := c.ComputeInvoiceUsageRecord(ctx, costCenter.ID)
124127
if err != nil {
@@ -130,12 +133,21 @@ func (c *CostCenterManager) UpdateCostCenter(ctx context.Context, costCenter Cos
130133
return CostCenter{}, err
131134
}
132135
}
136+
// we don't manage stripe billing cycle
137+
costCenter.NextBillingTime = VarcharTime{}
138+
139+
case CostCenter_Other:
140+
// cancelled from stripe reset the spending limit
141+
if costCenter.ID.IsEntity(AttributionEntity_Team) {
142+
costCenter.SpendingLimit = c.cfg.ForTeams
143+
} else {
144+
costCenter.SpendingLimit = c.cfg.ForUsers
145+
}
146+
// see you next month
147+
costCenter.NextBillingTime = NewVarcharTime(now.AddDate(0, 1, 0))
133148
}
134-
c.updateNextBillingTime(&costCenter, now)
135149
}
136150

137-
// we update the creationTime
138-
costCenter.CreationTime = NewVarcharTime(now)
139151
db := c.conn.Save(&costCenter)
140152
if db.Error != nil {
141153
return CostCenter{}, fmt.Errorf("failed to save cost center for attributionID %s: %w", costCenter.ID, db.Error)
@@ -163,8 +175,3 @@ func (c *CostCenterManager) ComputeInvoiceUsageRecord(ctx context.Context, attri
163175
Draft: false,
164176
}, nil
165177
}
166-
167-
func (c *CostCenterManager) updateNextBillingTime(costCenter *CostCenter, now time.Time) {
168-
nextMonth := NewVarcharTime(time.Now().AddDate(0, 1, 0))
169-
costCenter.NextBillingTime = nextMonth
170-
}

components/usage/pkg/db/cost_center_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,28 @@ func TestCostCenterManager_UpdateCostCenter(t *testing.T) {
8383
func TestSaveCostCenterMovedToStripe(t *testing.T) {
8484
conn := dbtest.ConnectForTests(t)
8585
mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{
86-
ForTeams: 0,
86+
ForTeams: 20,
8787
ForUsers: 500,
8888
})
8989
team := db.NewTeamAttributionID(uuid.New().String())
9090
cleanUp(t, conn, team)
9191
teamCC, err := mnr.GetOrCreateCostCenter(context.Background(), team)
9292
require.NoError(t, err)
93-
require.Equal(t, int32(0), teamCC.SpendingLimit)
93+
require.Equal(t, int32(20), teamCC.SpendingLimit)
9494

9595
teamCC.BillingStrategy = db.CostCenter_Stripe
96-
newTeamCC, err := mnr.UpdateCostCenter(context.Background(), teamCC)
96+
teamCC.SpendingLimit = 400050
97+
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
98+
require.NoError(t, err)
99+
require.Equal(t, db.CostCenter_Stripe, teamCC.BillingStrategy)
100+
require.Equal(t, db.VarcharTime{}, teamCC.NextBillingTime)
101+
require.Equal(t, int32(400050), teamCC.SpendingLimit)
102+
103+
teamCC.BillingStrategy = db.CostCenter_Other
104+
teamCC, err = mnr.UpdateCostCenter(context.Background(), teamCC)
97105
require.NoError(t, err)
98-
require.Equal(t, db.CostCenter_Stripe, newTeamCC.BillingStrategy)
99-
require.Equal(t, newTeamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), newTeamCC.NextBillingTime.Time().Truncate(time.Second))
106+
require.Equal(t, teamCC.CreationTime.Time().AddDate(0, 1, 0).Truncate(time.Second), teamCC.NextBillingTime.Time().Truncate(time.Second))
107+
require.Equal(t, int32(20), teamCC.SpendingLimit)
100108
}
101109

102110
func cleanUp(t *testing.T, conn *gorm.DB, attributionIds ...db.AttributionID) {

components/usage/pkg/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s
156156
if stripeClient == nil {
157157
v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{})
158158
} else {
159-
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn))
159+
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager))
160160
}
161161
return nil
162162
}

components/usage/pkg/stripe/stripe.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import (
88
"context"
99
"encoding/json"
1010
"fmt"
11-
"github.com/gitpod-io/gitpod/usage/pkg/db"
1211
"os"
1312
"strings"
1413

14+
"github.com/gitpod-io/gitpod/usage/pkg/db"
15+
1516
"github.com/gitpod-io/gitpod/common-go/log"
1617
"github.com/stripe/stripe-go/v72"
1718
"github.com/stripe/stripe-go/v72/client"
@@ -244,6 +245,19 @@ func (c *Client) GetInvoice(ctx context.Context, invoiceID string) (*stripe.Invo
244245
return invoice, nil
245246
}
246247

248+
func (c *Client) GetAttributionIdForSubscriptionId(ctx context.Context, subscriptionID string) (db.AttributionID, error) {
249+
subscription, err := c.sc.Subscriptions.Get(subscriptionID, nil)
250+
if err != nil {
251+
return "", fmt.Errorf("failed to search for subscription (%s): %w", subscriptionID, err)
252+
}
253+
attributionIDRaw := subscription.Customer.Metadata[AttributionIDMetadataKey]
254+
attributionID, err := db.ParseAttributionID(attributionIDRaw)
255+
if err != nil {
256+
return "", fmt.Errorf("failed to fetch AttributionID for subscription (%s): %w", subscriptionID, err)
257+
}
258+
return attributionID, nil
259+
}
260+
247261
// queriesForCustomersWithAttributionIDs constructs Stripe query strings to find the Stripe Customer for each teamId
248262
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
249263
// `clausesPerQuery` is a limit enforced by the Stripe API.

0 commit comments

Comments
 (0)