diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index a5e991db8d8..a4146823d58 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -387,6 +387,13 @@ func (c *CassandraOnlineStore) validateUniqueFeatureNames(featureViewNames []str return nil } +func convertTimestampParam(value interface{}) interface{} { + if valInt64, ok := value.(int64); ok { + return time.Unix(valInt64, 0) + } + return value +} + func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { if err := c.validateUniqueFeatureNames(featureViewNames); err != nil { return nil, err @@ -692,109 +699,137 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ func (c *CassandraOnlineStore) rangeFilterToCQL(filter *model.SortKeyFilter) (string, []interface{}) { rangeParams := make([]interface{}, 0) - equality := "" + if filter.Equals != nil { - equality = fmt.Sprintf(`"%s" = ?`, filter.SortKeyName) - rangeParams = append(rangeParams, filter.Equals) + paramVal := convertTimestampParam(filter.Equals) + equality = fmt.Sprintf("\"%s\" = ?", filter.SortKeyName) + rangeParams = append(rangeParams, paramVal) return equality, rangeParams } rangeStart := "" if filter.RangeStart != nil { + paramVal := convertTimestampParam(filter.RangeStart) if filter.StartInclusive { rangeStart = fmt.Sprintf(`"%s" >= ?`, filter.SortKeyName) } else { rangeStart = fmt.Sprintf(`"%s" > ?`, filter.SortKeyName) } - rangeParams = append(rangeParams, filter.RangeStart) + rangeParams = append(rangeParams, paramVal) } + rangeEnd := "" if filter.RangeEnd != nil { + paramVal := convertTimestampParam(filter.RangeEnd) if filter.EndInclusive { rangeEnd = fmt.Sprintf(`"%s" <= ?`, filter.SortKeyName) } else { rangeEnd = fmt.Sprintf(`"%s" < ?`, filter.SortKeyName) } - rangeParams = append(rangeParams, filter.RangeEnd) + rangeParams = append(rangeParams, paramVal) } + var condition string if rangeStart != "" && rangeEnd != "" { - return fmt.Sprintf(`%s AND %s`, rangeStart, rangeEnd), rangeParams + condition = fmt.Sprintf("%s AND %s", rangeStart, rangeEnd) } else if rangeStart != "" { - return rangeStart, rangeParams + condition = rangeStart } else if rangeEnd != "" { - return rangeEnd, rangeParams + condition = rangeEnd } else { - return "", rangeParams + condition = "" } + return condition, rangeParams } -func (c *CassandraOnlineStore) getRangeQueryCQLStatement(tableName string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) (string, []interface{}) { +func (c CassandraOnlineStore) getRangeQueryCQLStatement(tableName string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) (string, []interface{}) { // this prevents fetching unnecessary features quotedFeatureNames := make([]string, len(featureNames)) for i, featureName := range featureNames { quotedFeatureNames[i] = fmt.Sprintf(`"%s"`, featureName) } - rangeFilterString := "" - orderByString := "" - params := make([]interface{}, 0) + rangeFilterClauses := make([]string, 0) + orderByClauses := make([]string, 0) + allParams := make([]interface{}, 0) + if len(sortKeyFilters) > 0 { - rangeFilters := make([]string, 0) - orderBy := make([]string, 0) for _, filter := range sortKeyFilters { filterString, filterParams := c.rangeFilterToCQL(filter) if filterString != "" { - rangeFilters = append(rangeFilters, filterString) + rangeFilterClauses = append(rangeFilterClauses, filterString) + allParams = append(allParams, filterParams...) } - orderBy = append(orderBy, fmt.Sprintf(`"%s" %s`, filter.SortKeyName, filter.Order.String())) - params = append(params, filterParams...) + orderByClauses = append(orderByClauses, fmt.Sprintf("\"%s\" %s", filter.SortKeyName, filter.Order.String())) } - if len(rangeFilters) > 0 { - rangeFilterString = fmt.Sprintf(" AND %s", strings.Join(rangeFilters, " AND ")) - } - orderByString = fmt.Sprintf(" ORDER BY %s", strings.Join(orderBy, ", ")) + } + + rangeFilterString := "" + if len(rangeFilterClauses) > 0 { + rangeFilterString = fmt.Sprintf(" AND %s", strings.Join(rangeFilterClauses, " AND ")) + } + + orderByString := "" + if len(orderByClauses) > 0 { + orderByString = fmt.Sprintf(" ORDER BY %s", strings.Join(orderByClauses, ", ")) } limitString := "" if limit > 0 { limitString = " LIMIT ?" - params = append(params, limit) + allParams = append(allParams, limit) } - return fmt.Sprintf( - `SELECT "entity_key", "event_ts", %s FROM %s WHERE "entity_key" = ?%s%s%s`, - strings.Join(quotedFeatureNames, ", "), + selectColumns := append([]string{"\"entity_key\"", "\"event_ts\""}, quotedFeatureNames...) + uniqueSelectColumnsMap := make(map[string]struct{}) + uniqueSelectColumns := []string{} + for _, col := range selectColumns { + if _, exists := uniqueSelectColumnsMap[col]; !exists { + uniqueSelectColumnsMap[col] = struct{}{} + uniqueSelectColumns = append(uniqueSelectColumns, col) + } + } + + cql := fmt.Sprintf( + "SELECT %s FROM %s WHERE \"entity_key\" = ?%s%s%s", + strings.Join(uniqueSelectColumns, ", "), // Use unique columns tableName, rangeFilterString, orderByString, limitString, - ), params + ) + return cql, allParams } -func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) ([][]RangeFeatureData, error) { +func (c CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string, sortKeyFilters []*model.SortKeyFilter, limit int32) ([][]RangeFeatureData, error) { if err := c.validateUniqueFeatureNames(featureViewNames); err != nil { return nil, err } - - serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + serializedEntityKeys, _, err := c.buildCassandraEntityKeys(entityKeys) if err != nil { return nil, fmt.Errorf("error when serializing entity keys for Cassandra: %v", err) } + results := make([][]RangeFeatureData, len(entityKeys)) for i := range results { results[i] = make([]RangeFeatureData, len(featureNames)) + for j := range results[i] { + results[i][j] = RangeFeatureData{ + FeatureView: featureViewNames[0], + FeatureName: featureNames[j], + Values: make([]interface{}, 0), + EventTimestamps: make([]timestamppb.Timestamp, 0), + } + } } featureNamesToIdx := make(map[string]int) for idx, name := range featureNames { featureNamesToIdx[name] = idx } - featureViewName := featureViewNames[0] - // Prepare the query tableName, err := c.getFqTableName(c.clusterConfigs.Keyspace, c.project, featureViewName, c.tableNameFormatVersion) if err != nil { return nil, err @@ -804,84 +839,59 @@ func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys [ var waitGroup sync.WaitGroup waitGroup.Add(len(serializedEntityKeys)) - errorsChannel := make(chan error, len(serializedEntityKeys)) - for _, serializedEntityKey := range serializedEntityKeys { - go func(serEntityKey any) { - defer waitGroup.Done() + for i, serializedEntityKey := range serializedEntityKeys { + go func(serEntityKey interface{}, entityIndex int) { + defer waitGroup.Done() queryParams := append([]interface{}{serEntityKey}, rangeParams...) iter := c.session.Query(cqlStatement, queryParams...).WithContext(ctx).Iter() - rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)] - // fill the row with nulls if not found - if iter.NumRows() == 0 { - for _, featName := range featureNames { - results[rowIdx][featureNamesToIdx[featName]] = RangeFeatureData{ - FeatureView: featureViewName, - FeatureName: featName, - Values: []interface{}{nil}, - Statuses: []serving.FieldStatus{serving.FieldStatus_NOT_FOUND}, - } + rowDataList := make([]map[string]interface{}, 0, iter.NumRows()) + for { + row := make(map[string]interface{}) + if !iter.MapScan(row) { + break } + rowDataList = append(rowDataList, row) + } + + if err := iter.Close(); err != nil { + errorsChannel <- fmt.Errorf("error iterating results for entity %v: %w", serEntityKey, err) return } - for i := 0; i < iter.NumRows(); i++ { - readValues := make(map[string]interface{}) - iter.MapScan(readValues) - eventTs := readValues["event_ts"].(time.Time) + if len(rowDataList) == 0 { + for j := range featureNames { + results[entityIndex][j].Values = []interface{}{nil} + results[entityIndex][j].EventTimestamps = []timestamppb.Timestamp{{}} + } + return + } - rowFeatures := results[rowIdx] - for _, featName := range featureNames { - if val, ok := readValues[featName]; ok { - var status serving.FieldStatus - if val == nil { - status = serving.FieldStatus_NULL_VALUE - } else { - status = serving.FieldStatus_PRESENT - } + entityResults := results[entityIndex] + for _, readValues := range rowDataList { + var eventTs time.Time + if tsVal, ok := readValues["event_ts"].(time.Time); ok { + eventTs = tsVal + } else { + errorsChannel <- fmt.Errorf("event_ts missing or not time.Time for entity %v, row %v", serEntityKey, readValues) + continue + } + eventTsProtoPtr := timestamppb.New(eventTs) + if eventTsProtoPtr == nil { + errorsChannel <- fmt.Errorf("failed to create timestamp proto for entity %v", serEntityKey) + continue + } + eventTsProtoValue := *eventTsProtoPtr - if featureData := &rowFeatures[featureNamesToIdx[featName]]; featureData != nil { - rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{ - FeatureView: featureViewName, - FeatureName: featName, - Values: append(featureData.Values, val), - Statuses: append(featureData.Statuses, status), - EventTimestamps: append(featureData.EventTimestamps, timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}), - } - } else { - rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{ - FeatureView: featureViewName, - FeatureName: featName, - Values: []interface{}{val}, - Statuses: []serving.FieldStatus{status}, - EventTimestamps: []timestamppb.Timestamp{{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}}, - } - } - } else { - if featureData := &rowFeatures[featureNamesToIdx[featName]]; featureData != nil { - rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{ - FeatureView: featureViewName, - FeatureName: featName, - Values: append(featureData.Values, nil), - Statuses: append(featureData.Statuses, serving.FieldStatus_NOT_FOUND), - EventTimestamps: append(featureData.EventTimestamps, timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}), - } - } else { - rowFeatures[featureNamesToIdx[featName]] = RangeFeatureData{ - FeatureView: featureViewName, - FeatureName: featName, - Values: []interface{}{nil}, - Statuses: []serving.FieldStatus{serving.FieldStatus_NOT_FOUND}, - EventTimestamps: []timestamppb.Timestamp{{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}}, - } - } - } + for j, featName := range featureNames { + val, _ := readValues[featName] + entityResults[j].Values = append(entityResults[j].Values, val) + entityResults[j].EventTimestamps = append(entityResults[j].EventTimestamps, eventTsProtoValue) } - results[rowIdx] = rowFeatures } - }(serializedEntityKey) + }(serializedEntityKey, i) } // wait until all concurrent single-key queries are done @@ -890,12 +900,28 @@ func (c *CassandraOnlineStore) OnlineReadRange(ctx context.Context, entityKeys [ var collectedErrors []error for err := range errorsChannel { - if err != nil { - collectedErrors = append(collectedErrors, err) - } + collectedErrors = append(collectedErrors, err) } if len(collectedErrors) > 0 { - return nil, errors.Join(collectedErrors...) + return nil, fmt.Errorf("encountered errors during range read: %v", collectedErrors) + } + + for _, entityRow := range results { + for i := range entityRow { + featureData := &entityRow[i] + if len(featureData.Values) == 1 && featureData.Values[0] == nil && len(featureData.EventTimestamps) == 1 && featureData.EventTimestamps[0].Seconds == 0 && featureData.EventTimestamps[0].Nanos == 0 { + featureData.Statuses = []serving.FieldStatus{serving.FieldStatus_NOT_FOUND} + } else { + featureData.Statuses = make([]serving.FieldStatus, len(featureData.Values)) + for k, val := range featureData.Values { + if val == nil { + featureData.Statuses[k] = serving.FieldStatus_NULL_VALUE + } else { + featureData.Statuses[k] = serving.FieldStatus_PRESENT + } + } + } + } } return results, nil