Skip to content

Commit 33c613c

Browse files
easyCZroboquat
authored andcommitted
[stripe] Set reportId on invoices after updating credits
1 parent 9df045e commit 33c613c

File tree

3 files changed

+70
-19
lines changed

3 files changed

+70
-19
lines changed

components/usage/pkg/apiv1/billing.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (s *BillingService) UpdateInvoices(ctx context.Context, in *v1.UpdateInvoic
4949
return nil, status.Errorf(codes.Internal, "Failed to download usage report with ID: %s", in.GetReportId())
5050
}
5151

52-
credits, err := s.creditSummaryForTeams(report)
52+
credits, err := s.creditSummaryForTeams(report, in.GetReportId())
5353
if err != nil {
5454
log.Log.WithError(err).Errorf("Failed to compute credit summary.")
5555
return nil, status.Errorf(codes.InvalidArgument, "failed to compute credit summary")
@@ -100,7 +100,7 @@ func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcom
100100
}, nil
101101
}
102102

103-
func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[string]int64, error) {
103+
func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport, reportID string) (map[string]stripe.CreditSummary, error) {
104104
creditsPerTeamID := map[string]float64{}
105105

106106
for _, session := range sessions {
@@ -120,9 +120,12 @@ func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[str
120120
creditsPerTeamID[id] += session.CreditsUsed
121121
}
122122

123-
rounded := map[string]int64{}
123+
rounded := map[string]stripe.CreditSummary{}
124124
for teamID, credits := range creditsPerTeamID {
125-
rounded[teamID] = int64(math.Ceil(credits))
125+
rounded[teamID] = stripe.CreditSummary{
126+
Credits: int64(math.Ceil(credits)),
127+
ReportID: reportID,
128+
}
126129
}
127130

128131
return rounded, nil

components/usage/pkg/apiv1/billing_test.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,19 @@ import (
1818
func TestCreditSummaryForTeams(t *testing.T) {
1919
teamID_A, teamID_B := uuid.New().String(), uuid.New().String()
2020
teamAttributionID_A, teamAttributionID_B := db.NewTeamAttributionID(teamID_A), db.NewTeamAttributionID(teamID_B)
21+
reportID := "report_id_1"
2122

2223
scenarios := []struct {
2324
Name string
2425
Sessions db.UsageReport
2526
BillSessionsAfter time.Time
26-
Expected map[string]int64
27+
Expected map[string]stripe.CreditSummary
2728
}{
2829
{
2930
Name: "no instances in report, no summary",
3031
BillSessionsAfter: time.Time{},
3132
Sessions: nil,
32-
Expected: map[string]int64{},
33+
Expected: map[string]stripe.CreditSummary{},
3334
},
3435
{
3536
Name: "skips user attributions",
@@ -39,7 +40,7 @@ func TestCreditSummaryForTeams(t *testing.T) {
3940
AttributionID: db.NewUserAttributionID(uuid.New().String()),
4041
},
4142
},
42-
Expected: map[string]int64{},
43+
Expected: map[string]stripe.CreditSummary{},
4344
},
4445
{
4546
Name: "two workspace instances",
@@ -56,9 +57,12 @@ func TestCreditSummaryForTeams(t *testing.T) {
5657
CreditsUsed: 10,
5758
},
5859
},
59-
Expected: map[string]int64{
60+
Expected: map[string]stripe.CreditSummary{
6061
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
61-
teamID_A: 480,
62+
teamID_A: {
63+
Credits: 480,
64+
ReportID: reportID,
65+
},
6266
},
6367
},
6468
{
@@ -76,10 +80,16 @@ func TestCreditSummaryForTeams(t *testing.T) {
7680
CreditsUsed: (24) * 10,
7781
},
7882
},
79-
Expected: map[string]int64{
83+
Expected: map[string]stripe.CreditSummary{
8084
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
81-
teamID_A: 120,
82-
teamID_B: 240,
85+
teamID_A: {
86+
Credits: 120,
87+
ReportID: reportID,
88+
},
89+
teamID_B: {
90+
Credits: 240,
91+
ReportID: reportID,
92+
},
8393
},
8494
},
8595
{
@@ -99,16 +109,19 @@ func TestCreditSummaryForTeams(t *testing.T) {
99109
StartedAt: time.Now().AddDate(0, 0, -3),
100110
},
101111
},
102-
Expected: map[string]int64{
103-
teamID_A: 120,
112+
Expected: map[string]stripe.CreditSummary{
113+
teamID_A: {
114+
Credits: 120,
115+
ReportID: reportID,
116+
},
104117
},
105118
},
106119
}
107120

108121
for _, s := range scenarios {
109122
t.Run(s.Name, func(t *testing.T) {
110123
svc := NewBillingService(&stripe.Client{}, s.BillSessionsAfter, &gorm.DB{})
111-
actual, err := svc.creditSummaryForTeams(s.Sessions)
124+
actual, err := svc.creditSummaryForTeams(s.Sessions, reportID)
112125
require.NoError(t, err)
113126
require.Equal(t, s.Expected, actual)
114127
})

components/usage/pkg/stripe/stripe.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ import (
1616
"github.com/stripe/stripe-go/v72/client"
1717
)
1818

19+
const (
20+
reportIDMetadataKey = "reportId"
21+
)
22+
1923
type Client struct {
2024
sc *client.API
2125
}
@@ -58,9 +62,14 @@ type Invoice struct {
5862
Credits int64
5963
}
6064

65+
type CreditSummary struct {
66+
Credits int64
67+
ReportID string
68+
}
69+
6170
// UpdateUsage updates teams' Stripe subscriptions with usage data
6271
// `usageForTeam` is a map from team name to total workspace seconds used within a billing period.
63-
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]int64) error {
72+
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]CreditSummary) error {
6473
teamIds := make([]string, 0, len(creditsPerTeam))
6574
for k := range creditsPerTeam {
6675
teamIds = append(teamIds, k)
@@ -117,7 +126,7 @@ func (c *Client) findCustomers(ctx context.Context, query string) ([]*stripe.Cus
117126
return customers, nil
118127
}
119128

120-
func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, credits int64) (*UsageRecord, error) {
129+
func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, summary CreditSummary) (*UsageRecord, error) {
121130
subscriptions := customer.Subscriptions.Data
122131
if len(subscriptions) != 1 {
123132
return nil, fmt.Errorf("customer has an unexpected number of subscriptions %v (expected 1, got %d)", subscriptions, len(subscriptions))
@@ -136,15 +145,27 @@ func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Cu
136145
Context: ctx,
137146
},
138147
SubscriptionItem: stripe.String(subscriptionItemId),
139-
Quantity: stripe.Int64(credits),
148+
Quantity: stripe.Int64(summary.Credits),
140149
})
141150
if err != nil {
142151
return nil, fmt.Errorf("failed to register usage for customer %q on subscription item %s", customer.Name, subscriptionItemId)
143152
}
144153

154+
invoice, err := c.GetUpcomingInvoice(ctx, customer.ID)
155+
if err != nil {
156+
return nil, fmt.Errorf("failed to find upcoming invoice for customer %s: %w", customer.ID, err)
157+
}
158+
159+
_, err = c.UpdateInvoiceMetadata(ctx, invoice.ID, map[string]string{
160+
reportIDMetadataKey: summary.ReportID,
161+
})
162+
if err != nil {
163+
return nil, fmt.Errorf("failed to udpate invoice %s metadata with report ID: %w", invoice.ID, err)
164+
}
165+
145166
return &UsageRecord{
146167
SubscriptionItemID: subscriptionItemId,
147-
Quantity: credits,
168+
Quantity: summary.Credits,
148169
}, nil
149170
}
150171

@@ -205,6 +226,20 @@ func (c *Client) GetUpcomingInvoice(ctx context.Context, customerID string) (*In
205226
}, nil
206227
}
207228

229+
func (c *Client) UpdateInvoiceMetadata(ctx context.Context, invoiceID string, metadata map[string]string) (*stripe.Invoice, error) {
230+
invoice, err := c.sc.Invoices.Update(invoiceID, &stripe.InvoiceParams{
231+
Params: stripe.Params{
232+
Context: ctx,
233+
Metadata: metadata,
234+
},
235+
})
236+
if err != nil {
237+
return nil, fmt.Errorf("failed to update invoice %s metadata: %w", invoiceID, err)
238+
}
239+
240+
return invoice, nil
241+
}
242+
208243
// queriesForCustomersWithTeamIds constructs Stripe query strings to find the Stripe Customer for each teamId
209244
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
210245
// `clausesPerQuery` is a limit enforced by the Stripe API.

0 commit comments

Comments
 (0)