Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class CopilotToolProcessor extends AbstractProcessor {

@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
for (Element element : roundEnv.getElementsAnnotatedWith(CopilotTool.class)) {
List<Element> annotatedElements = getCopilotToolAnnotatedElements(roundEnv);
for (Element element : annotatedElements) {
if (element.getKind() != ElementKind.METHOD) {
continue;
}
Expand All @@ -75,7 +76,7 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment

// Group methods by enclosing type
Map<TypeElement, List<ExecutableElement>> methodsByClass = new LinkedHashMap<>();
for (Element element : roundEnv.getElementsAnnotatedWith(CopilotTool.class)) {
for (Element element : annotatedElements) {
if (element.getKind() != ElementKind.METHOD) {
continue;
}
Expand All @@ -95,6 +96,15 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment
return false;
}

private List<Element> getCopilotToolAnnotatedElements(RoundEnvironment roundEnv) {
TypeElement copilotToolType = processingEnv.getElementUtils()
.getTypeElement("com.github.copilot.tool.CopilotTool");
if (copilotToolType != null) {
return new ArrayList<>(roundEnv.getElementsAnnotatedWith(copilotToolType));
}
return new ArrayList<>(roundEnv.getElementsAnnotatedWith(CopilotTool.class));
}

private void generateMetaClass(TypeElement classElement, List<ExecutableElement> methods) {
String packageName = processingEnv.getElementUtils().getPackageOf(classElement).getQualifiedName().toString();
String simpleClassName = classElement.getSimpleName().toString();
Expand Down
23 changes: 21 additions & 2 deletions java/src/test/java/com/github/copilot/CopilotSessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,27 @@ void testShouldGetLastSessionId() throws Exception {
ctx.configureForTest("session", "should_get_last_session_id");

try (CopilotClient client = ctx.createClient()) {
CopilotSession session = client
.createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get();
CopilotSession session = null;
for (int attempt = 1; attempt <= 2; attempt++) {
CompletableFuture<CopilotSession> createFuture = client
.createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL));
try {
session = createFuture.get(45, TimeUnit.SECONDS);
break;
} catch (java.util.concurrent.TimeoutException e) {
createFuture.cancel(true);
if (attempt == 2) {
throw e;
}
} catch (java.util.concurrent.ExecutionException e) {
if (e.getCause() instanceof java.util.concurrent.TimeoutException && attempt < 2) {
createFuture.cancel(true);
continue;
}
throw e;
}
}
assertNotNull(session, "Session should be created");
Comment thread
Copilot marked this conversation as resolved.

session.sendAndWait(new MessageOptions().setPrompt("Say hello")).get(60, TimeUnit.SECONDS);
String sessionId = session.getSessionId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,24 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.io.FilterWriter;
import java.io.IOException;
import java.io.Writer;
import java.net.URI;
import java.nio.file.Path;
import java.security.CodeSource;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.ForwardingJavaFileObject;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
Expand Down Expand Up @@ -540,25 +548,28 @@ private CompilationResult compileWithProcessor(List<JavaFileObject> sources) {

String classpath = resolveClasspath();
List<String> options = new ArrayList<>();
options.add("-proc:full");
options.addAll(List.of("-processor", "com.github.copilot.tool.CopilotToolProcessor"));
options.addAll(List.of("-classpath", classpath));
options.addAll(List.of("-d", tempDir.toString()));
options.addAll(List.of("-s", tempDir.toString()));
// Allow experimental APIs during test compilation
options.add("-Acopilot.experimental.allowed=true");

try {
StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null);
try (StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null)) {
fileManager.setLocation(StandardLocation.SOURCE_OUTPUT, List.of(tempDir.toFile()));
fileManager.setLocation(StandardLocation.CLASS_OUTPUT, List.of(tempDir.toFile()));
CollectingFileManager collectingFileManager = new CollectingFileManager(fileManager);

JavaCompiler.CompilationTask task = compiler.getTask(null, fileManager, diagnostics, options, null,
sources);
task.setProcessors(List.of(new CopilotToolProcessor()));
JavaCompiler.CompilationTask task = compiler.getTask(null, collectingFileManager, diagnostics, options,
null, sources);
task.call();

// Collect generated sources
List<String> generatedSources = new ArrayList<>();
collectGeneratedFiles(tempDir, generatedSources);
List<String> generatedSources = collectingFileManager.getGeneratedSources();
if (generatedSources.isEmpty()) {
// Fallback for file-manager implementations that only materialize on disk.
collectGeneratedFiles(tempDir, generatedSources);
}

return new CompilationResult(diagnostics.getDiagnostics(), generatedSources, tempDir);
} catch (Exception e) {
Expand Down Expand Up @@ -666,4 +677,52 @@ String getGeneratedSource(String qualifiedName) {
return null;
}
}

private static class CollectingFileManager extends ForwardingJavaFileManager<StandardJavaFileManager> {
private final Map<String, StringBuilder> generatedByClass = new LinkedHashMap<>();

CollectingFileManager(StandardJavaFileManager fileManager) {
super(fileManager);
}

@Override
public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind,
FileObject sibling) throws IOException {
JavaFileObject delegate = super.getJavaFileForOutput(location, className, kind, sibling);
if (kind != JavaFileObject.Kind.SOURCE) {
return delegate;
}
StringBuilder captured = new StringBuilder();
generatedByClass.put(className, captured);
return new ForwardingJavaFileObject<>(delegate) {
@Override
public Writer openWriter() throws IOException {
Writer target = delegate.openWriter();
return new FilterWriter(target) {
@Override
public void write(char[] cbuf, int off, int len) throws IOException {
captured.append(cbuf, off, len);
super.write(cbuf, off, len);
}

@Override
public void write(int c) throws IOException {
captured.append((char) c);
super.write(c);
}

@Override
public void write(String str, int off, int len) throws IOException {
captured.append(str, off, off + len);
super.write(str, off, len);
}
};
}
};
}

List<String> getGeneratedSources() {
return generatedByClass.values().stream().map(StringBuilder::toString).toList();
}
}
}
28 changes: 28 additions & 0 deletions test/snapshots/session/should_abort_a_session.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,31 @@ conversations:
content: What is 2+2?
- role: assistant
content: 2 + 2 = 4
- messages:
- role: system
content: ${system}
- role: user
content: run the shell command 'sleep 100' (note this works on both bash and PowerShell)
- role: assistant
content: I'll run the sleep command for 100 seconds.
tool_calls:
- id: toolcall_0
type: function
function:
name: report_intent
arguments: '{"intent":"Running sleep command"}'
- id: toolcall_1
type: function
function:
name: ${shell}
arguments: '{"command":"sleep 100","description":"Run sleep 100 command","mode":"sync","initial_wait":105}'
- role: tool
tool_call_id: toolcall_0
content: The execution of this tool, or a previous tool was interrupted.
- role: tool
tool_call_id: toolcall_1
content: The execution of this tool, or a previous tool was interrupted.
- role: user
content: What is 2+2?
- role: assistant
content: 2 + 2 = 4
Loading