diff --git a/README.md b/README.md index 7d6bb0c..3d5c28f 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,26 @@ func AddBalance(ctx context.Context, acc LedgerAccount, amount int) (err error) } ``` + + + +### UpdateItem + +```go +err := table.UpdateItem( + context.TODO(), + dynago.StringValue("partitionKey"), + dynago.StringValue("sortKey"), + aws.String("ADD Income :increment"), + map[string]ddbtypes.AttributeValue{ + ":increment": &ddbtypes.AttributeValueMemberN{Value: "1"}, + }, + []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_NEW"), + }, +) +``` + ### Query ```go diff --git a/interface.go b/interface.go index 640044b..c9f654d 100644 --- a/interface.go +++ b/interface.go @@ -4,6 +4,7 @@ import ( "context" "strconv" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" ) @@ -28,6 +29,7 @@ type WriteAPI interface { DeleteItem(ctx context.Context, pk, sk string) error BatchDeleteItems(ctx context.Context, input []AttributeRecord) []AttributeRecord BatchPutItems(ctx context.Context, items []BatchPutItemsInput) error + UpdateItem(ctx context.Context, pk Attribute, sk Attribute, updateExpression *string, expressionAttributeValues map[string]types.AttributeValue, opts ...UpdateOption) (*dynamodb.UpdateItemOutput, error) } type TransactionAPI interface { diff --git a/tests/updateitem_test.go b/tests/updateitem_test.go new file mode 100644 index 0000000..f70b27a --- /dev/null +++ b/tests/updateitem_test.go @@ -0,0 +1,190 @@ +package tests + +import ( + "context" + "crypto/rand" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/oolio-group/dynago" +) + +func TestUpdateItem(t *testing.T) { + ddbClient := prepareTable(t) + + partitionKey := "org_123#" + rand.Text() // avoid interference with other tests + + testCases := []struct { + name string + givePk dynago.Attribute + giveSk dynago.Attribute + giveUpdateExpr *string + giveExpressionAttributeValues map[string]ddbtypes.AttributeValue + giveOptions []dynago.UpdateOption + wantAttributes map[string]ddbtypes.AttributeValue + wantErrStr string + }{ + { + name: "inserting new item and requesting ALL_OLD attributes returns no attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-jan"), + giveUpdateExpr: aws.String("SET Income = :v"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":v": &ddbtypes.AttributeValueMemberN{Value: "1000"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_OLD"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{}, + }, + { + name: "inserting new item and requesting ALL_NEW attributes returns all attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-feb"), + giveUpdateExpr: aws.String("SET Income = :v"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":v": &ddbtypes.AttributeValueMemberN{Value: "1000"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_NEW"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{ + "Income": &ddbtypes.AttributeValueMemberN{Value: "1000"}, + "pk": &ddbtypes.AttributeValueMemberS{Value: partitionKey}, + "sk": &ddbtypes.AttributeValueMemberS{Value: "2026-feb"}, + }, + }, + { + name: "updating existing item and requesting ALL_OLD attributes returns all attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-jan"), + giveUpdateExpr: aws.String("SET Income = :v"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":v": &ddbtypes.AttributeValueMemberN{Value: "2000"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_OLD"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{ + "Income": &ddbtypes.AttributeValueMemberN{Value: "1000"}, // old value is returned + "pk": &ddbtypes.AttributeValueMemberS{Value: partitionKey}, + "sk": &ddbtypes.AttributeValueMemberS{Value: "2026-jan"}, + }, + }, + { + name: "updating existing item and requesting ALL_NEW attributes returns new attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-jan"), + giveUpdateExpr: aws.String("SET Income = :v"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":v": &ddbtypes.AttributeValueMemberN{Value: "3000"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_NEW"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{ + "Income": &ddbtypes.AttributeValueMemberN{Value: "3000"}, // new value is returned + "pk": &ddbtypes.AttributeValueMemberS{Value: partitionKey}, + "sk": &ddbtypes.AttributeValueMemberS{Value: "2026-jan"}, + }, + }, + { + name: "incrementing non-existing item with ALL_NEW returns new attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-mar"), + giveUpdateExpr: aws.String("ADD Income :increment"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":increment": &ddbtypes.AttributeValueMemberN{Value: "8"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_NEW"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{ + "Income": &ddbtypes.AttributeValueMemberN{Value: "8"}, + "pk": &ddbtypes.AttributeValueMemberS{Value: partitionKey}, + "sk": &ddbtypes.AttributeValueMemberS{Value: "2026-mar"}, + }, + }, + { + name: "incrementing existing item with ALL_NEW returns new attributes", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-mar"), + giveUpdateExpr: aws.String("ADD Income :increment"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":increment": &ddbtypes.AttributeValueMemberN{Value: "8"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithReturnValues("ALL_NEW"), + }, + wantAttributes: map[string]ddbtypes.AttributeValue{ + "Income": &ddbtypes.AttributeValueMemberN{Value: "16"}, + "pk": &ddbtypes.AttributeValueMemberS{Value: partitionKey}, + "sk": &ddbtypes.AttributeValueMemberS{Value: "2026-mar"}, + }, + }, + { + name: "increment missing item with condition expression causes an error", + givePk: dynago.StringValue(partitionKey), + giveSk: dynago.StringValue("2026-may"), + giveUpdateExpr: aws.String("ADD Income :increment"), + giveExpressionAttributeValues: map[string]ddbtypes.AttributeValue{ + ":increment": &ddbtypes.AttributeValueMemberN{Value: "8"}, + }, + giveOptions: []dynago.UpdateOption{ + dynago.WithConditionExpression("attribute_exists(pk) AND attribute_exists(sk)"), // want failure is the item does not exist + dynago.WithReturnValues("ALL_NEW"), + }, + wantAttributes: nil, + wantErrStr: "ConditionalCheckFailedException: The conditional request failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // t.Parallel() // commented out: DO NOT RUN THE TESTS IN PARALLEL, because they depend on each other (e.g., one test creates an item that another test updates) + + gotResponse, gotErr := ddbClient.UpdateItem(context.TODO(), tc.givePk, tc.giveSk, tc.giveUpdateExpr, tc.giveExpressionAttributeValues, tc.giveOptions...) + + if tc.wantErrStr != "" { + if gotErr == nil { + t.Fatalf("expected error but got nil") + } + + if !strings.Contains(gotErr.Error(), tc.wantErrStr) { + t.Fatalf("error message does not contain expected substring: got %q, want %q", gotErr.Error(), tc.wantErrStr) + } + } else { + if gotErr != nil { + t.Fatalf("unexpected error: %v", gotErr) + } + + if len(gotResponse.Attributes) != len(tc.wantAttributes) { + t.Fatalf("number of attributes does not match: got %d, want %d", len(gotResponse.Attributes), len(tc.wantAttributes)) + } + + for key, value := range tc.wantAttributes { + gotResponseValue := gotResponse.Attributes[key] + if gotResponseValue == nil { + t.Errorf("attribute %q not found in response", key) + continue + } + + switch v := gotResponseValue.(type) { + case *ddbtypes.AttributeValueMemberN: // number + if v.Value != value.(*ddbtypes.AttributeValueMemberN).Value { + t.Errorf("attribute %q does not match: got %q, want %q", key, v.Value, value.(*ddbtypes.AttributeValueMemberN).Value) + } + case *ddbtypes.AttributeValueMemberS: // string + if v.Value != value.(*ddbtypes.AttributeValueMemberS).Value { + t.Errorf("attribute %q does not match: got %q, want %q", key, v.Value, value.(*ddbtypes.AttributeValueMemberS).Value) + } + default: + t.Errorf("unsupported attribute value type for key %q: %T", key, gotResponseValue) + } + } + } + }) + } +} diff --git a/update_item.go b/update_item.go new file mode 100644 index 0000000..48575b8 --- /dev/null +++ b/update_item.go @@ -0,0 +1,86 @@ +package dynago + +import ( + "context" + //"fmt" + "log" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + dynamodbTypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type UpdateOption func(*dynamodb.UpdateItemInput) error + +func WithReturnValues(returnValues string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ReturnValues = dynamodbTypes.ReturnValue(returnValues) + return nil + } +} + +func WithConditionExpression(conditionExpression string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ConditionExpression = &conditionExpression + return nil + } +} + +func WithReturnConsumedCapacity(returnConsumedCapacity string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ReturnConsumedCapacity = dynamodbTypes.ReturnConsumedCapacity(returnConsumedCapacity) + return nil + } +} + +func WithReturnItemCollectionMetrics(returnItemCollectionMetrics string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ReturnItemCollectionMetrics = dynamodbTypes.ReturnItemCollectionMetrics(returnItemCollectionMetrics) + return nil + } +} + +func WithReturnValuesOnConditionCheckFailure(returnValuesOnConditionCheckFailure string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ReturnValuesOnConditionCheckFailure = dynamodbTypes.ReturnValuesOnConditionCheckFailure(returnValuesOnConditionCheckFailure) + return nil + } +} + +// UpdateItem updates a db record from dynamodb given a partition key and sort key +// @param item the item put into the database +// @return true if the record was updated, false otherwise +func (t *Client) UpdateItem( + ctx context.Context, + pk Attribute, + sk Attribute, + updateExpression *string, + expressionAttributeValues map[string]dynamodbTypes.AttributeValue, + opts ...UpdateOption, +) (*dynamodb.UpdateItemOutput, error) { + input := &dynamodb.UpdateItemInput{ + TableName: &t.TableName, + Key: t.NewKeys(pk, sk), + UpdateExpression: updateExpression, + } + + if len(expressionAttributeValues) > 0 { + input.ExpressionAttributeValues = expressionAttributeValues + } + + // Apply option functions + if len(opts) > 0 { + for _, opt := range opts { + if err := opt(input); err != nil { + return nil, err + } + } + } + + ret, err := t.client.UpdateItem(ctx, input) + if err != nil { + log.Println("Failed to Update item" + err.Error()) + return nil, err + } + + return ret, nil +}