diff --git a/.gitignore b/.gitignore index c4a2655..423d2f5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ tmp* build/* dist/* bayesh.egg-info/* -bin/bayesh \ No newline at end of file +bin/bayesh +coverage* \ No newline at end of file diff --git a/src/db.go b/src/db.go index aff595e..4b79a95 100644 --- a/src/db.go +++ b/src/db.go @@ -117,20 +117,15 @@ func (q *Queries) ConditionalEventCounts(ctx context.Context, cwd *string, previ conditions = append(conditions, colPreviousCmd+" = ?") args = append(args, *previousCmd) } - if minEventCount != nil { - conditions = append(conditions, colEventCounter+" >= ?") - args = append(args, *minEventCount) - } if len(conditions) > 0 { query.WriteString("WHERE " + strings.Join(conditions, " AND ") + newLine) } query.WriteString("GROUP BY " + colCurrentCmd + newLine) - slog.Debug("Inferring current command with", - "query", query.String(), - "cwd", cwd, - "previousCmd", previousCmd, - ) + if minEventCount != nil { + query.WriteString("HAVING SUM(" + colEventCounter + ") >= ?" + newLine) + args = append(args, *minEventCount) + } rows, err := q.db.QueryContext(ctx, query.String(), args...) if err != nil { diff --git a/src/db_test.go b/src/db_test.go index 30100d0..7f02baf 100644 --- a/src/db_test.go +++ b/src/db_test.go @@ -400,12 +400,16 @@ func TestConditionalEventCount(t *testing.T) { for _, row := range testData.allRows { matchesCwd := !input.targetCwd || row.Cwd == testData.targetCwd matchesPrevCmd := !input.targetPrevCmd || row.PreviousCmd == testData.targetPrevCmd - matchesMinEventCount := !input.minEventCount || row.EventCounter >= testData.targetMinEventCount - if matchesCwd && matchesPrevCmd && matchesMinEventCount { + if matchesCwd && matchesPrevCmd { expectedData[row.CurrentCmd] += row.EventCounter } } + for cmd, count := range expectedData { + if input.minEventCount && count < testData.targetMinEventCount { + delete(expectedData, cmd) + } + } var cwd *string = nil if input.targetCwd {